From 974401045adc90cc04ee57904f9d822b4c86e708 Mon Sep 17 00:00:00 2001 From: Saurabh Sinha Date: Wed, 17 Jun 2026 09:56:30 -0400 Subject: [PATCH 1/6] feat: initial implementation of codeanalyzer-go Adds the static code analysis engine for Go including syntactic analysis, schema definitions, CLI entrypoint, and test fixtures. Semantic analysis via CodeQL is scaffolded but not yet implemented. Signed-off-by: Saurabh Sinha --- .gitignore | 34 +- README.md | 209 ++++- cmd/codeanalyzer/main.go | 104 +++ go.mod | 15 + go.sum | 18 + internal/analysis/pass.go | 38 + internal/analysis/registry.go | 90 ++ internal/core/analyzer.go | 201 +++++ internal/core/analyzer_test.go | 363 ++++++++ internal/core/realistic_test.go | 326 +++++++ internal/frameworks/base.go | 38 + internal/options/options.go | 36 + internal/schema/schema.go | 190 ++++ internal/semantic_analysis/call_graph.go | 289 ++++++ internal/semantic_analysis/codeql/codeql.go | 48 + internal/semantic_analysis/codeql/errors.go | 15 + internal/semantic_analysis/codeql/loader.go | 22 + internal/semantic_analysis/codeql/runner.go | 27 + internal/syntactic_analysis/export.go | 15 + internal/syntactic_analysis/signature.go | 86 ++ internal/syntactic_analysis/symbol_table.go | 944 ++++++++++++++++++++ internal/utils/fs.go | 81 ++ internal/utils/logging.go | 30 + testdata/fixture/go.mod | 3 + testdata/fixture/main.go | 15 + testdata/fixture/pkg/greeter/greeter.go | 29 + testdata/realistic/go.mod | 3 + testdata/realistic/main.go | 26 + testdata/realistic/server/middleware.go | 19 + testdata/realistic/server/server.go | 53 ++ testdata/realistic/worker/worker.go | 65 ++ 31 files changed, 3411 insertions(+), 21 deletions(-) create mode 100644 cmd/codeanalyzer/main.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/analysis/pass.go create mode 100644 internal/analysis/registry.go create mode 100644 internal/core/analyzer.go create mode 100644 internal/core/analyzer_test.go create mode 100644 internal/core/realistic_test.go create mode 100644 internal/frameworks/base.go create mode 100644 internal/options/options.go create mode 100644 internal/schema/schema.go create mode 100644 internal/semantic_analysis/call_graph.go create mode 100644 internal/semantic_analysis/codeql/codeql.go create mode 100644 internal/semantic_analysis/codeql/errors.go create mode 100644 internal/semantic_analysis/codeql/loader.go create mode 100644 internal/semantic_analysis/codeql/runner.go create mode 100644 internal/syntactic_analysis/export.go create mode 100644 internal/syntactic_analysis/signature.go create mode 100644 internal/syntactic_analysis/symbol_table.go create mode 100644 internal/utils/fs.go create mode 100644 internal/utils/logging.go create mode 100644 testdata/fixture/go.mod create mode 100644 testdata/fixture/main.go create mode 100644 testdata/fixture/pkg/greeter/greeter.go create mode 100644 testdata/realistic/go.mod create mode 100644 testdata/realistic/main.go create mode 100644 testdata/realistic/server/middleware.go create mode 100644 testdata/realistic/server/server.go create mode 100644 testdata/realistic/worker/worker.go diff --git a/.gitignore b/.gitignore index 6f72f89..169d632 100644 --- a/.gitignore +++ b/.gitignore @@ -1,25 +1,19 @@ -# If you prefer the allow list template instead of the deny list, see community template: -# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore -# -# Binaries for programs and plugins +# Binaries +/codeanalyzer +/codeanalyzer-go *.exe -*.exe~ -*.dll -*.so -*.dylib -# Test binary, built with `go test -c` -*.test - -# Output of the go coverage tool, specifically when used with LiteIDE -*.out +# Build output +/dist/ +/bin/ -# Dependency directories (remove the comment below to include it) -# vendor/ +# Claude Code session data +.claude/ -# Go workspace file -go.work -go.work.sum +# macOS +.DS_Store -# env file -.env +# Go test cache / coverage +*.test +*.out +coverage.txt diff --git a/README.md b/README.md index 86f3197..608e3a3 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,209 @@ # codeanalyzer-go -The static code analysis engine for Go + +Static analysis for Go using `golang.org/x/tools/go/packages` (AST + type resolution). + +Produces `analysis.json` (symbol table + call graph) in the [CLDK canonical schema](https://github.com/codellm-devkit/python-sdk), consumable by the Python SDK via `CLDK(language="go").analysis(project_path=...)`. + +## Prerequisites + +- **Go 1.25+** — the only required runtime. Install from [go.dev/dl](https://go.dev/dl/). Developed and tested on Go 1.26.4. + +Verify: +```bash +go version +``` + +The binary is self-contained. No other tools are required for Level 1 analysis. + +## Building + +```bash +git clone https://github.com/codellm-devkit/codeanalyzer-go +cd codeanalyzer-go +go build -o codeanalyzer-go ./cmd/codeanalyzer +``` + +This produces a single static binary `codeanalyzer-go` with no runtime dependencies. + +## Usage + +```bash +codeanalyzer-go -i /path/to/go/project +``` + +### Command-line options + +``` +codeanalyzer-go produces analysis.json (symbol table + call graph) for Go projects. + +Usage: + codeanalyzer-go [flags] + +Flags: + -a, --analysis-level int Analysis level: 1=symbol table only, 2=+resolver call graph (default 1) + -c, --cache-dir string Cache directory (default: ~/.cldk/go-cache) + --codeql Enable CodeQL framework-based call graph (level 2, stub) + --eager Force clean rebuild (ignore cache) + -f, --format string Output format: json (default "json"); msgpack is not yet implemented + -h, --help help for codeanalyzer-go + -i, --input string Project root to analyze (required) + -o, --output string Output directory for analysis.json (default: stdout) + --skip-tests Skip *_test.go files (default true) + -t, --target-files strings Restrict analysis to specific files (incremental mode) + -v, --verbose count Verbosity (repeat for more detail) + --version Print version and exit +``` + +### Examples + +**Symbol table only (Level 1, default):** +```bash +codeanalyzer-go -i ./my-go-project +``` +Prints `analysis.json` to stdout. + +**Symbol table + call graph (Level 2):** +```bash +codeanalyzer-go -i ./my-go-project -a 2 +``` + +**Write output to a directory:** +```bash +codeanalyzer-go -i ./my-go-project -a 2 -o /path/to/output/ +# Writes: /path/to/output/analysis.json +``` + +**Incremental analysis (specific files only):** +```bash +codeanalyzer-go -i ./my-go-project -t pkg/server/server.go -t pkg/server/handler.go +``` + +**Force rebuild, ignore cache:** +```bash +codeanalyzer-go -i ./my-go-project --eager +``` + +**Verbose output:** +```bash +codeanalyzer-go -i ./my-go-project -a 2 -vv +``` + +## Analysis levels + +| Level | Flag | What runs | Status | +|-------|------|-----------|--------| +| 1 | `-a 1` (default) | Symbol table only — types, functions, call sites | Implemented | +| 2 | `-a 2` | Level 1 + resolver-based call graph via `go/types` | Implemented | +| — | `--codeql` | CodeQL framework-based call graph (merged with Level 2 edges) | Stub (not yet implemented) | + +**Level 1** loads each package with `packages.NeedSyntax | NeedTypes | NeedTypesInfo` and walks the AST file by file. Call sites are recorded with `callee_signature = null` at this stage. + +**Level 2** adds a resolver pass: for each call site, `go/types` resolves the callee to its full import-path signature (`pkgImportPath.TypeName.MethodName`). Only project-internal edges (both endpoints present in the symbol table) are emitted. `callee_signature` is backfilled on all successfully resolved sites. + +## Output schema + +The root object is `GoApplication`: + +```json +{ + "symbol_table": { + "pkg/greeter/greeter.go": { + "file_path": "pkg/greeter/greeter.go", + "module_name": "greeter", + "imports": [...], + "classes": { + "Greeter": { + "name": "Greeter", + "signature": "example.com/pkg/greeter.Greeter", + "is_interface": false, + "fields": [{ "name": "Prefix", "type": "string", "tags": {"json": "prefix"} }], + "methods": { ... } + } + }, + "functions": { ... } + } + }, + "call_graph": [ + { + "source": "example.com/main.main", + "target": "example.com/pkg/greeter.Greeter.Greet", + "type": "CALL_DEP", + "weight": 1, + "provenance": ["go/types"] + } + ], + "entrypoints": {} +} +``` + +Key schema properties: +- `symbol_table` — keyed by **file path relative to the project root** (never absolute) +- `classes` — JSON key for types (spine compatibility with Java/Python schemas); value is `GoType` +- `module_name` — JSON key for the Go package name (spine compatibility) +- `GoType.is_interface: bool` — unified type model; structs and interfaces are both `GoType` +- `GoCallable.receiver_type / receiver_name` — non-empty for methods, empty for package-level functions +- `GoCallable.return_types: List[str]` — individual return types (Go-specific extension) +- `GoCallsite.is_goroutine: bool` — true when the call is preceded by the `go` keyword +- `GoCallEdge.provenance: List[str]` — resolver identifiers, e.g. `["go/types"]` or `["go/types","codeql"]` +- Call edges are **identity-only**: source and target are `GoCallable.signature` strings that exist in the symbol table + +## Python SDK (CLDK) integration + +```python +from cldk import CLDK + +analysis = CLDK(language="go").analysis(project_path="/path/to/go/project") +for file_path, go_file in analysis.get_symbol_table().items(): + print(file_path, go_file.module_name) +``` + +See [python-sdk](https://github.com/codellm-devkit/python-sdk) for full API documentation. + +## Architecture & Tooling + +| Slot | Choice | Rationale | +|------|--------|-----------| +| Runtime | Go binary | Self-contained; no runtime dep for SDK users | +| Structural parser | `go/ast` (stdlib) | Part of the standard toolchain; no external dep | +| Type resolver | `golang.org/x/tools/go/packages` | Single API for both AST + full type resolution; handles modules natively | +| Optional enrichment | CodeQL (stubbed) | Same enrichment path as Python/Java analyzers; stubbed for Level 1 | +| Build/dep materialization | `go mod download` | Required before `packages.Load` so the module cache is warm; result cached by `go.sum` hash | +| Packaging | Native binary (`go build`) | Zero-runtime-dep distribution; matches Rust/C++ analyzers | +| Analysis depth | Level 1 (rapid) | Symbol table + resolver call graph; CodeQL stub wired but not implemented | +| Call-graph dispatch | Declared-type resolution via `go/types.Selections` | CHA-equivalent; sufficient for cross-package reachability at Level 1 | + +### Package structure + +``` +codeanalyzer-go/ +├── cmd/codeanalyzer/ # CLI entry point (cobra) +├── internal/ +│ ├── core/ # Orchestrator — delegates only, no inlined analysis +│ ├── schema/ # GoApplication, GoFile, GoType, GoCallable, … (schema.go) +│ ├── options/ # AnalysisOptions + AnalysisLevel constants +│ ├── syntactic_analysis/ # SymbolTableBuilder (packages.Load → AST walk) +│ ├── semantic_analysis/ # CallGraphBuilder (go/types resolver) +│ │ └── codeql/ # CodeQL backend subpackage (stubbed) +│ ├── analysis/ # Pluggable pass interface + registry (topo-ordered pipeline) +│ ├── frameworks/ # BaseEntrypointFinder — extension seam for framework passes +│ └── utils/ # DiscoverGoFiles, IsVendored, IsTestFile, logging +└── testdata/fixture/ # Minimal Go fixture used by tests +``` + +The `core` package is a pure orchestrator: it calls `syntactic_analysis` → `semantic_analysis` → `analysis.RunPipeline` → optional CodeQL in sequence, with no inlined parsing logic. Framework-specific analysis extends through the `analysis/` + `frameworks/` layer without touching `core`. + +## Development + +### Running tests + +```bash +go test ./... +``` + +Tests run against `testdata/fixture/` and `testdata/realistic/` — a minimal two-package and a richer multi-package Go module. All 33 tests cover symbol table correctness, call graph edges, JSON round-trip, output format validation, and caching/incremental behaviour. + +### Running from source + +```bash +go run ./cmd/codeanalyzer -i /path/to/project -a 2 +``` diff --git a/cmd/codeanalyzer/main.go b/cmd/codeanalyzer/main.go new file mode 100644 index 0000000..7425781 --- /dev/null +++ b/cmd/codeanalyzer/main.go @@ -0,0 +1,104 @@ +// codeanalyzer-go is the CLI entry point for the Go language analyzer. +// +// It exposes the standard CLDK CLI surface (cli-contract.md) so the Python SDK +// facade can shell out to it uniformly alongside Java and Python backends. +package main + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + + "github.com/codellm-devkit/codeanalyzer-go/internal/core" + "github.com/codellm-devkit/codeanalyzer-go/internal/options" + "github.com/codellm-devkit/codeanalyzer-go/internal/utils" +) + +const version = "0.1.0" + +func main() { + if err := rootCmd().Execute(); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func rootCmd() *cobra.Command { + var ( + inputPath string + outputDir string + format string + level int + targetFiles []string + skipTests bool + eager bool + cacheDir string + useCodeQL bool + verbosity int + showVersion bool + ) + + cmd := &cobra.Command{ + Use: "codeanalyzer-go", + Short: "Static analysis for Go — symbol table and call graph via go/types", + Long: `codeanalyzer-go produces analysis.json (symbol table + call graph) for Go projects. + +The output conforms to the CLDK canonical schema so the Python SDK can load it +via CLDK(language="go").analysis(project_path=...).`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + if showVersion { + fmt.Println("codeanalyzer-go " + version) + return nil + } + if inputPath == "" { + return fmt.Errorf("--input / -i is required") + } + utils.SetVerbosity(verbosity) + + if cacheDir == "" { + home, _ := os.UserHomeDir() + cacheDir = home + "/.cldk/go-cache" + } + + opts := options.AnalysisOptions{ + InputPath: inputPath, + OutputDir: outputDir, + Format: format, + Level: options.AnalysisLevel(level), + TargetFiles: targetFiles, + SkipTests: skipTests, + Eager: eager, + CacheDir: cacheDir, + UseCodeQL: useCodeQL, + Verbose: verbosity > 0, + } + + analyzer := core.New(opts) + app, err := analyzer.Analyze() + if err != nil { + return err + } + + return core.WriteOutput(app, outputDir, format) + }, + } + + f := cmd.Flags() + f.StringVarP(&inputPath, "input", "i", "", "Project root to analyze (required)") + f.StringVarP(&outputDir, "output", "o", "", "Output directory for analysis.json (default: stdout)") + f.StringVarP(&format, "format", "f", "json", "Output format: json|msgpack") + f.IntVarP(&level, "analysis-level", "a", 1, + "Analysis level: 1=symbol table only, 2=+resolver call graph") + f.StringSliceVarP(&targetFiles, "target-files", "t", nil, + "Restrict analysis to specific files (incremental mode)") + f.BoolVar(&skipTests, "skip-tests", true, "Skip *_test.go files") + f.BoolVar(&eager, "eager", false, "Force clean rebuild (ignore cache)") + f.StringVarP(&cacheDir, "cache-dir", "c", "", "Cache directory (default: ~/.cldk/go-cache)") + f.BoolVar(&useCodeQL, "codeql", false, "Enable CodeQL framework-based call graph (level 2, stub)") + f.CountVarP(&verbosity, "verbose", "v", "Verbosity (repeat for more detail)") + f.BoolVar(&showVersion, "version", false, "Print version and exit") + + return cmd +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..c913971 --- /dev/null +++ b/go.mod @@ -0,0 +1,15 @@ +module github.com/codellm-devkit/codeanalyzer-go + +go 1.25.0 + +require ( + github.com/spf13/cobra v1.8.0 + golang.org/x/tools v0.46.0 +) + +require ( + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + golang.org/x/mod v0.37.0 // indirect + golang.org/x/sync v0.21.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..88530cf --- /dev/null +++ b/go.sum @@ -0,0 +1,18 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= +github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +golang.org/x/mod v0.37.0 h1:vF1DjpVEshcIqoEaauuHebaLk1O1forxjxBaVn884JQ= +golang.org/x/mod v0.37.0/go.mod h1:m8S8VeM9r4dzDwjrKO0a1sZP3YjeMamRRlD+fmR2Q/0= +golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM= +golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/tools v0.46.0 h1:7jTurBkPZu4moS/Uy4OQT1M+QBlsj3wejyZwsT8Z7rk= +golang.org/x/tools v0.46.0/go.mod h1:FrD85F8l+NWL+9XWBSyVSHO6Ne4jutsfIFba7AWQ5Ys= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/analysis/pass.go b/internal/analysis/pass.go new file mode 100644 index 0000000..3142774 --- /dev/null +++ b/internal/analysis/pass.go @@ -0,0 +1,38 @@ +// Package analysis defines the pluggable pass layer for codeanalyzer-go. +// +// An AnalysisPass enriches a GoApplication after the base analysis (symbol table +// + call graph) is built. Passes declare capability tokens in Provides/Requires +// and are ordered topologically by the registry before running. +// +// This mirrors codeanalyzer-python's analysis/_pass.py. The seam exists so that +// codeanalyzer-extension-builder can register out-of-tree passes. +package analysis + +import "github.com/codellm-devkit/codeanalyzer-go/internal/schema" + +// AnalysisContext carries the shared context available to every pass. +type AnalysisContext struct { + ProjectDir string + CacheDir string +} + +// AnalysisResult holds the output of a single pass run. +type AnalysisResult struct { + // Entrypoints discovered by this pass, keyed by framework name. + Entrypoints map[string][]schema.GoEntrypoint + // SyntheticEdges are additional call-graph edges contributed by this pass. + SyntheticEdges []schema.GoCallEdge +} + +// AnalysisPass is the interface every built-in and out-of-tree pass implements. +type AnalysisPass interface { + // Name is the unique identifier for this pass (e.g. "gin-entrypoints"). + Name() string + // Provides is the set of capability tokens this pass adds to the application. + Provides() []string + // Requires is the set of capability tokens that must have been provided before + // this pass runs. The registry hard-errors on unsatisfied dependencies. + Requires() []string + // Run performs the analysis and returns its contributions. + Run(app *schema.GoApplication, ctx AnalysisContext) (AnalysisResult, error) +} diff --git a/internal/analysis/registry.go b/internal/analysis/registry.go new file mode 100644 index 0000000..2503578 --- /dev/null +++ b/internal/analysis/registry.go @@ -0,0 +1,90 @@ +// Package analysis — registry discovers, orders, and runs passes. +package analysis + +import ( + "fmt" + + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" + "github.com/codellm-devkit/codeanalyzer-go/internal/utils" +) + +var registeredPasses []AnalysisPass + +// RegisterPass adds a pass to the built-in registry. Call from init() in each +// pass file to register without touching the registry directly. +func RegisterPass(p AnalysisPass) { + registeredPasses = append(registeredPasses, p) +} + +// orderPasses performs a topological sort by Requires/Provides. +// Returns an error if a dependency is unsatisfied or a cycle exists. +func orderPasses(passes []AnalysisPass) ([]AnalysisPass, error) { + provided := map[string]bool{} + var ordered []AnalysisPass + remaining := make([]AnalysisPass, len(passes)) + copy(remaining, passes) + + for len(remaining) > 0 { + progress := false + var next []AnalysisPass + for _, p := range remaining { + ready := true + for _, req := range p.Requires() { + if !provided[req] { + ready = false + break + } + } + if ready { + ordered = append(ordered, p) + for _, cap := range p.Provides() { + provided[cap] = true + } + progress = true + } else { + next = append(next, p) + } + } + if !progress { + return nil, fmt.Errorf("unsatisfied pass dependencies or cycle among: %v", + func() []string { + names := make([]string, len(remaining)) + for i, p := range remaining { + names[i] = p.Name() + } + return names + }()) + } + remaining = next + } + return ordered, nil +} + +// RunPipeline runs all registered passes over app in dependency order, +// merging each result into the running application before the next pass. +// Pass output is deliberately not cached — out-of-tree enrichment must not go stale. +func RunPipeline(app *schema.GoApplication, ctx AnalysisContext) error { + ordered, err := orderPasses(registeredPasses) + if err != nil { + return err + } + if len(ordered) == 0 { + utils.Debug("no registered analysis passes; skipping pipeline") + return nil + } + for _, p := range ordered { + utils.Info("running pass: %s", p.Name()) + result, err := p.Run(app, ctx) + if err != nil { + utils.Warn("pass %s failed: %v (continuing)", p.Name(), err) + continue + } + // Merge entrypoints + for framework, eps := range result.Entrypoints { + app.Entrypoints[framework] = append(app.Entrypoints[framework], eps...) + } + // Merge synthetic edges + app.CallGraph = append(app.CallGraph, result.SyntheticEdges...) + } + return nil +} diff --git a/internal/core/analyzer.go b/internal/core/analyzer.go new file mode 100644 index 0000000..2b9f410 --- /dev/null +++ b/internal/core/analyzer.go @@ -0,0 +1,201 @@ +// Package core is the ORCHESTRATOR for codeanalyzer-go analysis. +// +// Analyzer.Analyze() delegates each phase to its own package; it inlines no +// analysis logic and never hardcodes entrypoints. This mirrors the structural +// discipline of codeanalyzer-python/codeanalyzer/core.py. +// +// Phase order: +// 1. Project materialization (go mod download) +// 2. Symbol table construction (syntactic_analysis) +// 3. Resolver-based call graph (semantic_analysis) — if level >= 2 +// 4. Pass pipeline (analysis/registry) +// 5. Optional CodeQL enrichment (semantic_analysis/codeql) — if --codeql +package core + +import ( + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + + "github.com/codellm-devkit/codeanalyzer-go/internal/analysis" + "github.com/codellm-devkit/codeanalyzer-go/internal/options" + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" + "github.com/codellm-devkit/codeanalyzer-go/internal/semantic_analysis" + "github.com/codellm-devkit/codeanalyzer-go/internal/semantic_analysis/codeql" + "github.com/codellm-devkit/codeanalyzer-go/internal/syntactic_analysis" + "github.com/codellm-devkit/codeanalyzer-go/internal/utils" +) + +// Analyzer is the top-level analysis driver. Construct with New() and call Analyze(). +type Analyzer struct { + opts options.AnalysisOptions +} + +// New creates an Analyzer for the given options. +func New(opts options.AnalysisOptions) *Analyzer { + return &Analyzer{opts: opts} +} + +// Analyze runs the full analysis pipeline and returns a GoApplication. +func (a *Analyzer) Analyze() (*schema.GoApplication, error) { + // Resolve to absolute path so filepath.Rel works correctly in all cases. + abs, err := filepath.Abs(a.opts.InputPath) + if err != nil { + return nil, fmt.Errorf("resolving input path: %w", err) + } + a.opts.InputPath = abs + utils.Info("analyzing project: %s", a.opts.InputPath) + + // ── Phase 1: Project materialization ────────────────────────────────────── + if err := a.materialize(); err != nil { + // Degrade gracefully — log but don't abort. Partial types are better than nothing. + utils.Warn("dependency materialization failed: %v (continuing with partial types)", err) + } + + // ── Phase 2: Symbol table construction ─────────────────────────────────── + builder := syntactic_analysis.NewSymbolTableBuilder(a.opts.InputPath) + symbolTable, err := builder.Build(a.opts.TargetFiles, a.opts.SkipTests) + if err != nil { + return nil, fmt.Errorf("symbol table construction failed: %w", err) + } + utils.Info("symbol table: %d files", len(symbolTable)) + + app := &schema.GoApplication{ + SymbolTable: symbolTable, + CallGraph: []schema.GoCallEdge{}, + Entrypoints: map[string][]schema.GoEntrypoint{}, + } + + if a.opts.Level < options.LevelCallGraph { + // Level 1 (symbol-table only) — skip call graph and passes. + return a.finalizeAndCache(app) + } + + // ── Phase 3: Resolver-based call graph ──────────────────────────────────── + cgBuilder := semantic_analysis.NewCallGraphBuilder( + a.opts.InputPath, builder.Fset(), builder.Pkgs(), + ) + edges := cgBuilder.Build(symbolTable) + app.CallGraph = edges + utils.Info("call graph: %d edges", len(edges)) + + // ── Phase 4: Pass pipeline ──────────────────────────────────────────────── + ctx := analysis.AnalysisContext{ + ProjectDir: a.opts.InputPath, + CacheDir: a.opts.CacheDir, + } + if err := analysis.RunPipeline(app, ctx); err != nil { + utils.Warn("pass pipeline error: %v", err) + } + + // ── Phase 5: Optional CodeQL enrichment ────────────────────────────────── + if a.opts.UseCodeQL { + cq, err := codeql.New(a.opts.CacheDir, true) + if err != nil { + utils.Warn("CodeQL unavailable: %v", err) + } else { + if err := cq.Build(a.opts.InputPath); err != nil { + utils.Warn("CodeQL build failed: %v", err) + } else { + cqlEdges, err := cq.Edges() + if err != nil { + utils.Warn("CodeQL edge extraction failed: %v", err) + } else { + app.CallGraph = semantic_analysis.MergeEdges(app.CallGraph, cqlEdges) + utils.Info("merged %d CodeQL edges", len(cqlEdges)) + } + } + } + } + + return a.finalizeAndCache(app) +} + +// materialize runs `go mod download` to ensure the module graph is available +// for go/packages to resolve imports and types. Idempotent and cached. +func (a *Analyzer) materialize() error { + goModPath := filepath.Join(a.opts.InputPath, "go.mod") + if _, err := os.Stat(goModPath); os.IsNotExist(err) { + utils.Debug("no go.mod found at %s; skipping go mod download", a.opts.InputPath) + return nil + } + + // Check cache: if the go.sum hasn't changed, skip download. + if !a.opts.Eager { + goSumPath := filepath.Join(a.opts.InputPath, "go.sum") + cacheKey := filepath.Join(a.opts.CacheDir, "go_mod_hash") + if currentHash, err := utils.FileHash(goSumPath); err == nil { + if cachedHash, err := os.ReadFile(cacheKey); err == nil && string(cachedHash) == currentHash { + utils.Debug("go mod download: cache hit, skipping") + return nil + } + } + defer func() { + goSumPath := filepath.Join(a.opts.InputPath, "go.sum") + if currentHash, err := utils.FileHash(goSumPath); err == nil { + _ = utils.EnsureDir(a.opts.CacheDir) + _ = os.WriteFile(cacheKey, []byte(currentHash), 0o644) + } + }() + } + + utils.Info("running go mod download...") + cmd := exec.Command("go", "mod", "download") + cmd.Dir = a.opts.InputPath + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + return cmd.Run() +} + +// finalizeAndCache caches the application and returns it. +func (a *Analyzer) finalizeAndCache(app *schema.GoApplication) (*schema.GoApplication, error) { + if a.opts.CacheDir != "" { + _ = a.saveCache(app) + } + return app, nil +} + +// saveCache persists the application to cache as analysis_cache.json. +func (a *Analyzer) saveCache(app *schema.GoApplication) error { + if err := utils.EnsureDir(a.opts.CacheDir); err != nil { + return err + } + cachePath := filepath.Join(a.opts.CacheDir, "analysis_cache.json") + data, err := json.Marshal(app) + if err != nil { + return err + } + return os.WriteFile(cachePath, data, 0o644) +} + +// WriteOutput writes the GoApplication to outputDir/analysis.json (or stdout +// when outputDir is empty). Only "json" is supported; "msgpack" and other +// values return an explicit error rather than silently falling back to JSON. +func WriteOutput(app *schema.GoApplication, outputDir, format string) error { + if format == "" { + format = "json" + } + switch format { + case "json": + // only supported format + case "msgpack": + return fmt.Errorf("msgpack output is not yet implemented; use --format json") + default: + return fmt.Errorf("unsupported output format %q; supported: json", format) + } + + data, err := json.Marshal(app) + if err != nil { + return err + } + if outputDir == "" { + _, err = os.Stdout.Write(data) + return err + } + if err := utils.EnsureDir(outputDir); err != nil { + return err + } + return os.WriteFile(filepath.Join(outputDir, "analysis.json"), data, 0o644) +} diff --git a/internal/core/analyzer_test.go b/internal/core/analyzer_test.go new file mode 100644 index 0000000..8d322ba --- /dev/null +++ b/internal/core/analyzer_test.go @@ -0,0 +1,363 @@ +package core_test + +import ( + "encoding/json" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/codellm-devkit/codeanalyzer-go/internal/core" + "github.com/codellm-devkit/codeanalyzer-go/internal/options" + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" +) + +// fixtureDir returns the absolute path to testdata/fixture. +func fixtureDir(t *testing.T) string { + t.Helper() + _, file, _, ok := runtime.Caller(0) + if !ok { + t.Fatal("cannot determine source file path") + } + // internal/core/analyzer_test.go → ../.. → codeanalyzer-go root → testdata/fixture + root := filepath.Join(filepath.Dir(file), "..", "..") + abs, err := filepath.Abs(filepath.Join(root, "testdata", "fixture")) + if err != nil { + t.Fatalf("resolving fixture dir: %v", err) + } + return abs +} + +func runAnalysis(t *testing.T, level options.AnalysisLevel) *schema.GoApplication { + t.Helper() + dir := fixtureDir(t) + outDir := t.TempDir() + opts := options.AnalysisOptions{ + InputPath: dir, + OutputDir: outDir, + Level: level, + SkipTests: true, + CacheDir: t.TempDir(), + } + app, err := core.New(opts).Analyze() + if err != nil { + t.Fatalf("Analyze() failed: %v", err) + } + return app +} + +// ── Symbol table tests ──────────────────────────────────────────────────────── + +func TestSymbolTable_NonEmpty(t *testing.T) { + app := runAnalysis(t, options.LevelSymbolTable) + if len(app.SymbolTable) == 0 { + t.Fatal("symbol table is empty") + } +} + +func TestSymbolTable_PathKeysAreRelative(t *testing.T) { + app := runAnalysis(t, options.LevelSymbolTable) + for key := range app.SymbolTable { + if filepath.IsAbs(key) { + t.Errorf("symbol_table key is absolute path: %s", key) + } + } +} + +func TestSymbolTable_KnownType(t *testing.T) { + app := runAnalysis(t, options.LevelSymbolTable) + const wantFile = "pkg/greeter/greeter.go" + f, ok := app.SymbolTable[wantFile] + if !ok { + t.Fatalf("file %q not in symbol table; got keys: %v", wantFile, keys(app.SymbolTable)) + } + if _, ok := f.Types["Greeter"]; !ok { + t.Errorf("GoType 'Greeter' not found in %s", wantFile) + } +} + +func TestSymbolTable_KnownInterface(t *testing.T) { + app := runAnalysis(t, options.LevelSymbolTable) + f := app.SymbolTable["pkg/greeter/greeter.go"] + gt, ok := f.Types["Logger"] + if !ok { + t.Fatal("GoType 'Logger' not found") + } + if !gt.IsInterface { + t.Error("Logger.is_interface should be true") + } +} + +func TestSymbolTable_StructFields(t *testing.T) { + app := runAnalysis(t, options.LevelSymbolTable) + f := app.SymbolTable["pkg/greeter/greeter.go"] + gt := f.Types["Greeter"] + if len(gt.Fields) == 0 { + t.Fatal("Greeter has no fields") + } + if gt.Fields[0].Name != "Prefix" { + t.Errorf("expected field 'Prefix', got %q", gt.Fields[0].Name) + } + if _, hasJSON := gt.Fields[0].Tags["json"]; !hasJSON { + t.Error("Greeter.Prefix missing json struct tag") + } +} + +func TestSymbolTable_CallSitesRecorded(t *testing.T) { + app := runAnalysis(t, options.LevelSymbolTable) + f := app.SymbolTable["main.go"] + var mainFn *schema.GoCallable + for _, c := range f.Functions { + c := c + if c.Name == "main" { + mainFn = &c + break + } + } + if mainFn == nil { + t.Fatal("main function not found") + } + if len(mainFn.CallSites) == 0 { + t.Error("main() has no recorded call sites") + } + // All call sites must start with callee_signature == nil (pre-resolution). + for _, cs := range mainFn.CallSites { + if cs.CalleeSignature != nil { + t.Errorf("call site %q has callee_signature pre-filled during symbol-table build", cs.MethodName) + } + } +} + +// ── Call graph tests ────────────────────────────────────────────────────────── + +func TestCallGraph_NonEmpty(t *testing.T) { + app := runAnalysis(t, options.LevelCallGraph) + if len(app.CallGraph) == 0 { + t.Fatal("call graph is empty") + } +} + +func TestCallGraph_NoDanglingEdges(t *testing.T) { + app := runAnalysis(t, options.LevelCallGraph) + sigs := allSignatures(app) + for _, e := range app.CallGraph { + if !sigs[e.Source] { + t.Errorf("dangling edge source: %s", e.Source) + } + if !sigs[e.Target] { + t.Errorf("dangling edge target: %s", e.Target) + } + } +} + +func TestCallGraph_Provenance(t *testing.T) { + app := runAnalysis(t, options.LevelCallGraph) + for _, e := range app.CallGraph { + if len(e.Provenance) == 0 { + t.Errorf("edge %s→%s has empty provenance", e.Source, e.Target) + } + } +} + +func TestCallGraph_CallSitesBackfilled(t *testing.T) { + app := runAnalysis(t, options.LevelCallGraph) + f := app.SymbolTable["main.go"] + for _, callable := range f.Functions { + for _, cs := range callable.CallSites { + // Sites that resolved to a project-internal callee must be backfilled. + if cs.CalleeSignature != nil && *cs.CalleeSignature == "" { + t.Errorf("callable %s: call site %q has empty string callee_signature", callable.Signature, cs.MethodName) + } + } + } +} + +// ── JSON output tests ───────────────────────────────────────────────────────── + +func TestWriteOutput_ValidJSON(t *testing.T) { + app := runAnalysis(t, options.LevelCallGraph) + outDir := t.TempDir() + if err := core.WriteOutput(app, outDir, "json"); err != nil { + t.Fatalf("WriteOutput: %v", err) + } + data, err := os.ReadFile(filepath.Join(outDir, "analysis.json")) + if err != nil { + t.Fatalf("reading analysis.json: %v", err) + } + var round schema.GoApplication + if err := json.Unmarshal(data, &round); err != nil { + t.Fatalf("JSON round-trip failed: %v", err) + } + if len(round.SymbolTable) == 0 { + t.Error("round-tripped symbol table is empty") + } +} + +func TestWriteOutput_EmptyFormatDefaultsToJSON(t *testing.T) { + app := runAnalysis(t, options.LevelSymbolTable) + outDir := t.TempDir() + if err := core.WriteOutput(app, outDir, ""); err != nil { + t.Fatalf("WriteOutput with empty format: %v", err) + } + if _, err := os.Stat(filepath.Join(outDir, "analysis.json")); err != nil { + t.Fatalf("analysis.json not written: %v", err) + } +} + +func TestWriteOutput_MsgpackNotImplemented(t *testing.T) { + app := runAnalysis(t, options.LevelSymbolTable) + outDir := t.TempDir() + err := core.WriteOutput(app, outDir, "msgpack") + if err == nil { + t.Fatal("expected error for --format msgpack, got nil") + } +} + +func TestWriteOutput_UnknownFormatErrors(t *testing.T) { + app := runAnalysis(t, options.LevelSymbolTable) + outDir := t.TempDir() + err := core.WriteOutput(app, outDir, "csv") + if err == nil { + t.Fatal("expected error for unknown format, got nil") + } +} + +// ── Caching tests ───────────────────────────────────────────────────────────── + +func TestCaching_SecondRunReuses(t *testing.T) { + dir := fixtureDir(t) + cacheDir := t.TempDir() + outDir := t.TempDir() + opts := options.AnalysisOptions{ + InputPath: dir, + OutputDir: outDir, + Level: options.LevelCallGraph, + SkipTests: true, + CacheDir: cacheDir, + } + // First run — populates cache. + app1, err := core.New(opts).Analyze() + if err != nil { + t.Fatalf("first run: %v", err) + } + // Second run — must not error and must return identical key count. + app2, err := core.New(opts).Analyze() + if err != nil { + t.Fatalf("second run: %v", err) + } + if len(app2.SymbolTable) == 0 { + t.Error("second run returned empty symbol table") + } + if len(app2.SymbolTable) != len(app1.SymbolTable) { + t.Errorf("symbol table key count changed between runs: %d → %d", + len(app1.SymbolTable), len(app2.SymbolTable)) + } +} + +func TestCaching_CacheFileWritten(t *testing.T) { + dir := fixtureDir(t) + cacheDir := t.TempDir() + opts := options.AnalysisOptions{ + InputPath: dir, + Level: options.LevelSymbolTable, + SkipTests: true, + CacheDir: cacheDir, + } + if _, err := core.New(opts).Analyze(); err != nil { + t.Fatalf("Analyze: %v", err) + } + cachePath := filepath.Join(cacheDir, "analysis_cache.json") + if _, err := os.Stat(cachePath); err != nil { + t.Fatalf("analysis_cache.json not written to CacheDir: %v", err) + } +} + +func TestCaching_CacheContentsRoundTrip(t *testing.T) { + dir := fixtureDir(t) + cacheDir := t.TempDir() + opts := options.AnalysisOptions{ + InputPath: dir, + Level: options.LevelSymbolTable, + SkipTests: true, + CacheDir: cacheDir, + } + app, err := core.New(opts).Analyze() + if err != nil { + t.Fatalf("Analyze: %v", err) + } + data, err := os.ReadFile(filepath.Join(cacheDir, "analysis_cache.json")) + if err != nil { + t.Fatalf("reading analysis_cache.json: %v", err) + } + var cached schema.GoApplication + if err := json.Unmarshal(data, &cached); err != nil { + t.Fatalf("cache JSON round-trip failed: %v", err) + } + if len(cached.SymbolTable) != len(app.SymbolTable) { + t.Errorf("cache symbol table key count %d != in-memory %d", + len(cached.SymbolTable), len(app.SymbolTable)) + } +} + +func TestCaching_EagerForcesRebuild(t *testing.T) { + dir := fixtureDir(t) + cacheDir := t.TempDir() + opts := options.AnalysisOptions{ + InputPath: dir, + Level: options.LevelSymbolTable, + SkipTests: true, + CacheDir: cacheDir, + } + // First run (non-eager) — seeds go_mod_hash. + if _, err := core.New(opts).Analyze(); err != nil { + t.Fatalf("first run: %v", err) + } + cachePath := filepath.Join(cacheDir, "analysis_cache.json") + info1, err := os.Stat(cachePath) + if err != nil { + t.Fatalf("cache not written after first run: %v", err) + } + + time.Sleep(10 * time.Millisecond) + + // Second run with Eager=true — must rewrite cache even when go_mod_hash matches. + opts.Eager = true + if _, err := core.New(opts).Analyze(); err != nil { + t.Fatalf("eager run: %v", err) + } + info2, err := os.Stat(cachePath) + if err != nil { + t.Fatalf("cache not found after eager run: %v", err) + } + // saveCache always writes, so mtime must advance. + if !info2.ModTime().After(info1.ModTime()) { + t.Errorf("analysis_cache.json mtime did not advance on eager=true run: %v vs %v", + info1.ModTime(), info2.ModTime()) + } +} + +// ── Helpers ─────────────────────────────────────────────────────────────────── + +func allSignatures(app *schema.GoApplication) map[string]bool { + sigs := map[string]bool{} + for _, f := range app.SymbolTable { + for sig := range f.Functions { + sigs[sig] = true + } + for _, t := range f.Types { + for sig := range t.Methods { + sigs[sig] = true + } + } + } + return sigs +} + +func keys[K comparable, V any](m map[K]V) []K { + out := make([]K, 0, len(m)) + for k := range m { + out = append(out, k) + } + return out +} diff --git a/internal/core/realistic_test.go b/internal/core/realistic_test.go new file mode 100644 index 0000000..3f651f4 --- /dev/null +++ b/internal/core/realistic_test.go @@ -0,0 +1,326 @@ +package core_test + +// Targeted tests for Go-specific schema fields that the greeter fixture does not exercise: +// is_goroutine, return_types (multiple), is_exported=false, receiver_type/name, +// is_variadic, is_embedded, multi-file package, cyclomatic_complexity, specific edges. + +import ( + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/codellm-devkit/codeanalyzer-go/internal/core" + "github.com/codellm-devkit/codeanalyzer-go/internal/options" + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" +) + +func realisticDir(t *testing.T) string { + t.Helper() + _, file, _, ok := runtime.Caller(0) + if !ok { + t.Fatal("cannot determine source file path") + } + root := filepath.Join(filepath.Dir(file), "..", "..") + abs, err := filepath.Abs(filepath.Join(root, "testdata", "realistic")) + if err != nil { + t.Fatalf("resolving realistic fixture dir: %v", err) + } + return abs +} + +func runRealistic(t *testing.T, level options.AnalysisLevel) *schema.GoApplication { + t.Helper() + dir := realisticDir(t) + outDir := t.TempDir() + opts := options.AnalysisOptions{ + InputPath: dir, + OutputDir: outDir, + Level: level, + SkipTests: true, + CacheDir: t.TempDir(), + } + app, err := core.New(opts).Analyze() + if err != nil { + t.Fatalf("Analyze() failed: %v", err) + } + return app +} + +// findCallableByName searches all functions and methods in a GoFile by short name. +func findCallableByName(f schema.GoFile, name string) *schema.GoCallable { + for _, c := range f.Functions { + if c.Name == name { + c := c + return &c + } + } + for _, gt := range f.Types { + for _, m := range gt.Methods { + if m.Name == name { + m := m + return &m + } + } + } + return nil +} + +// ── Multi-file package ──────────────────────────────────────────────────────── + +func TestRealistic_MultiFilePkg(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + _, hasServer := app.SymbolTable["server/server.go"] + _, hasMiddleware := app.SymbolTable["server/middleware.go"] + if !hasServer { + t.Error("server/server.go missing from symbol table") + } + if !hasMiddleware { + t.Error("server/middleware.go missing from symbol table") + } + // Tags must live in middleware.go, not server.go. + mw := app.SymbolTable["server/middleware.go"] + if findCallableByName(mw, "Tags") == nil { + t.Error("Tags function not found in server/middleware.go") + } +} + +// ── Embedded struct field ───────────────────────────────────────────────────── + +func TestRealistic_EmbeddedField(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + srv := app.SymbolTable["server/server.go"] + server, ok := srv.Types["Server"] + if !ok { + t.Fatal("GoType 'Server' not found in server/server.go") + } + for _, f := range server.Fields { + if f.IsEmbedded { + return // pass + } + } + t.Errorf("Server has no embedded field; fields: %+v", server.Fields) +} + +// ── Multiple return types — (T, error) pattern ──────────────────────────────── + +func TestRealistic_MultipleReturnTypes(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + srv := app.SymbolTable["server/server.go"] + newFn := findCallableByName(srv, "New") + if newFn == nil { + t.Fatal("function 'New' not found in server/server.go") + } + if len(newFn.ReturnTypes) < 2 { + t.Fatalf("New() should have >= 2 return types; got %v", newFn.ReturnTypes) + } + hasError := false + for _, rt := range newFn.ReturnTypes { + if rt == "error" { + hasError = true + } + } + if !hasError { + t.Errorf("New() return_types should include 'error'; got %v", newFn.ReturnTypes) + } +} + +func TestRealistic_ValidateReturnTypes(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + srv := app.SymbolTable["server/server.go"] + validate := findCallableByName(srv, "Validate") + if validate == nil { + t.Fatal("method 'Validate' not found in server/server.go") + } + if len(validate.ReturnTypes) != 2 { + t.Fatalf("Validate() should have 2 return types; got %v", validate.ReturnTypes) + } +} + +// ── Unexported callables ────────────────────────────────────────────────────── + +func TestRealistic_UnexportedMethod(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + srv := app.SymbolTable["server/server.go"] + shutdown := findCallableByName(srv, "shutdown") + if shutdown == nil { + t.Fatal("method 'shutdown' not found in server/server.go") + } + if shutdown.IsExported { + t.Error("shutdown.is_exported should be false") + } +} + +func TestRealistic_UnexportedWorkerMethod(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + wkr := app.SymbolTable["worker/worker.go"] + execute := findCallableByName(wkr, "execute") + if execute == nil { + t.Fatal("method 'execute' not found in worker/worker.go") + } + if execute.IsExported { + t.Error("execute.is_exported should be false") + } +} + +// ── Receiver type / name ────────────────────────────────────────────────────── + +func TestRealistic_ReceiverType(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + srv := app.SymbolTable["server/server.go"] + addr := findCallableByName(srv, "Addr") + if addr == nil { + t.Fatal("method 'Addr' not found in server/server.go") + } + if addr.ReceiverType == "" { + t.Error("Addr().receiver_type should be non-empty") + } + if addr.ReceiverName == "" { + t.Error("Addr().receiver_name should be non-empty") + } + // Pointer receiver — type should contain '*' or 'Server'. + if !strings.Contains(addr.ReceiverType, "Server") { + t.Errorf("Addr().receiver_type %q should reference Server", addr.ReceiverType) + } +} + +func TestRealistic_ValueReceiver(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + // Describe is defined in middleware.go but its receiver type (Server) lives in + // server.go — the reconcileCrossFileMethods pass attaches it to server.go's type. + srv := app.SymbolTable["server/server.go"] + describe := findCallableByName(srv, "Describe") + if describe == nil { + t.Fatal("method 'Describe' not found attached to Server in server/server.go") + } + // Value receiver — ReceiverType should not contain '*'. + if strings.Contains(describe.ReceiverType, "*") { + t.Errorf("Describe().receiver_type %q should be a value receiver (no '*')", describe.ReceiverType) + } + // Path should still record the physical definition file. + if !strings.Contains(describe.Path, "middleware.go") { + t.Errorf("Describe().path %q should point to middleware.go", describe.Path) + } +} + +// ── Variadic parameters ─────────────────────────────────────────────────────── + +func TestRealistic_VariadicParamTags(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + mw := app.SymbolTable["server/middleware.go"] + tags := findCallableByName(mw, "Tags") + if tags == nil { + t.Fatal("function 'Tags' not found in server/middleware.go") + } + for _, p := range tags.Parameters { + if p.IsVariadic { + return // pass + } + } + t.Errorf("Tags() has no variadic parameter; params: %+v", tags.Parameters) +} + +func TestRealistic_VariadicParamCombine(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + wkr := app.SymbolTable["worker/worker.go"] + combine := findCallableByName(wkr, "Combine") + if combine == nil { + t.Fatal("function 'Combine' not found in worker/worker.go") + } + for _, p := range combine.Parameters { + if p.IsVariadic { + return // pass + } + } + t.Errorf("Combine() has no variadic parameter; params: %+v", combine.Parameters) +} + +// ── Goroutine call site ─────────────────────────────────────────────────────── + +func TestRealistic_GoroutineCallsite(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + wkr := app.SymbolTable["worker/worker.go"] + run := findCallableByName(wkr, "Run") + if run == nil { + t.Fatal("method 'Run' not found in worker/worker.go") + } + for _, cs := range run.CallSites { + if cs.IsGoroutine { + return // pass + } + } + t.Errorf("Run() has no goroutine call site; sites: %+v", run.CallSites) +} + +// ── Cyclomatic complexity ───────────────────────────────────────────────────── + +func TestRealistic_CyclomaticComplexity(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + wkr := app.SymbolTable["worker/worker.go"] + execute := findCallableByName(wkr, "execute") + if execute == nil { + t.Fatal("method 'execute' not found in worker/worker.go") + } + // execute() has an `if err != nil` branch → CC >= 2. + if execute.CyclomaticComplexity < 2 { + t.Errorf("execute().cyclomatic_complexity should be >= 2; got %d", execute.CyclomaticComplexity) + } +} + +// ── Interface detection ─────────────────────────────────────────────────────── + +func TestRealistic_InterfaceType(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + wkr := app.SymbolTable["worker/worker.go"] + proc, ok := wkr.Types["Processor"] + if !ok { + t.Fatal("GoType 'Processor' not found in worker/worker.go") + } + if !proc.IsInterface { + t.Error("Processor.is_interface should be true") + } +} + +// ── Specific call-graph edge ────────────────────────────────────────────────── + +func TestRealistic_SpecificCallEdge(t *testing.T) { + app := runRealistic(t, options.LevelCallGraph) + // main() calls server.New() — this is a cross-package project-internal edge. + const wantTarget = "example.com/realistic/server.New" + for _, e := range app.CallGraph { + if e.Target == wantTarget { + return // pass + } + } + t.Errorf("call graph missing expected edge to %s; edges: %v", wantTarget, edgeTargets(app)) +} + +func TestRealistic_CrossPackageEdges(t *testing.T) { + app := runRealistic(t, options.LevelCallGraph) + // At least one edge must cross the main→server boundary and one main→worker boundary. + var serverEdge, workerEdge bool + for _, e := range app.CallGraph { + if strings.Contains(e.Target, "realistic/server.") { + serverEdge = true + } + if strings.Contains(e.Target, "realistic/worker.") { + workerEdge = true + } + } + if !serverEdge { + t.Error("no call-graph edge into the server package") + } + if !workerEdge { + t.Error("no call-graph edge into the worker package") + } +} + +// ── Helpers ─────────────────────────────────────────────────────────────────── + +func edgeTargets(app *schema.GoApplication) []string { + out := make([]string, 0, len(app.CallGraph)) + for _, e := range app.CallGraph { + out = append(out, e.Target) + } + return out +} diff --git a/internal/frameworks/base.go b/internal/frameworks/base.go new file mode 100644 index 0000000..9d17251 --- /dev/null +++ b/internal/frameworks/base.go @@ -0,0 +1,38 @@ +// Package frameworks provides the base for entrypoint-finder passes. +// +// Concrete finders (gin-router, net/http handler, etc.) embed BaseEntrypointFinder +// and override FindEntrypoints. They register themselves via analysis.RegisterPass. +// +// This mirrors codeanalyzer-python's frameworks/_base.py. +package frameworks + +import ( + "github.com/codellm-devkit/codeanalyzer-go/internal/analysis" + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" +) + +// BaseEntrypointFinder is the abstract base for framework-specific entrypoint finders. +// Concrete implementations embed this struct and override FindEntrypoints. +type BaseEntrypointFinder struct { + name string + framework string +} + +// NewBaseEntrypointFinder constructs a finder with the given name and framework label. +func NewBaseEntrypointFinder(name, framework string) BaseEntrypointFinder { + return BaseEntrypointFinder{name: name, framework: framework} +} + +// Name implements analysis.AnalysisPass. +func (b BaseEntrypointFinder) Name() string { return b.name } + +// Provides implements analysis.AnalysisPass — finders provide the framework name as a capability. +func (b BaseEntrypointFinder) Provides() []string { return []string{b.framework + ":entrypoints"} } + +// Requires implements analysis.AnalysisPass — no dependencies by default. +func (b BaseEntrypointFinder) Requires() []string { return nil } + +// Run implements analysis.AnalysisPass — delegates to FindEntrypoints. +func (b BaseEntrypointFinder) Run(app *schema.GoApplication, ctx analysis.AnalysisContext) (analysis.AnalysisResult, error) { + return analysis.AnalysisResult{}, nil +} diff --git a/internal/options/options.go b/internal/options/options.go new file mode 100644 index 0000000..97e6782 --- /dev/null +++ b/internal/options/options.go @@ -0,0 +1,36 @@ +// Package options defines the AnalysisOptions passed from the CLI into core. +package options + +// AnalysisLevel controls how much analysis is performed. +type AnalysisLevel int + +const ( + // LevelSymbolTable produces the symbol table only (no call graph). + LevelSymbolTable AnalysisLevel = 1 + // LevelCallGraph produces symbol table + resolver-based call graph (still cheap). + LevelCallGraph AnalysisLevel = 2 +) + +// AnalysisOptions is the configuration surface passed from the CLI into Analyzer. +type AnalysisOptions struct { + // InputPath is the project root to analyze. + InputPath string + // OutputDir is where analysis.json is written. Empty = write to stdout. + OutputDir string + // Format is the serialization format: "json" or "msgpack". + Format string + // AnalysisLevel controls symbol-table-only (1) vs + call graph (2). + Level AnalysisLevel + // TargetFiles restricts analysis to specific files (incremental mode). + TargetFiles []string + // SkipTests skips test files (files ending in _test.go). + SkipTests bool + // Eager forces a clean rebuild ignoring any cache. + Eager bool + // CacheDir is where per-file caches and intermediate data are stored. + CacheDir string + // UseCodeQL enables the framework-based (Tier-2) CodeQL call graph. + UseCodeQL bool + // Verbose enables verbose logging. + Verbose bool +} diff --git a/internal/schema/schema.go b/internal/schema/schema.go new file mode 100644 index 0000000..6eedd79 --- /dev/null +++ b/internal/schema/schema.go @@ -0,0 +1,190 @@ +// Package schema defines the canonical data contract for codeanalyzer-go output. +// +// The root object is GoApplication{symbol_table, call_graph}. Every field uses +// snake_case JSON keys so the Python SDK's Pydantic models parse it without +// transformation. Design decisions are recorded in .claude/SCHEMA_DECISIONS.md. +package schema + +// ─── Leaf models ───────────────────────────────────────────────────────────── + +// GoImport represents a single import declaration in a Go source file. +type GoImport struct { + Module string `json:"module"` + Alias string `json:"alias,omitempty"` + StartLine int `json:"start_line"` + EndLine int `json:"end_line"` +} + +// GoComment represents a comment (line, block, or doc comment). +type GoComment struct { + Content string `json:"content"` + StartLine int `json:"start_line"` + EndLine int `json:"end_line"` + StartColumn int `json:"start_column"` + EndColumn int `json:"end_column"` + IsDocComment bool `json:"is_doc_comment"` +} + +// GoParameter represents a single parameter of a callable. +type GoParameter struct { + Name string `json:"name"` + Type string `json:"type"` + IsVariadic bool `json:"is_variadic"` + StartLine int `json:"start_line"` + EndLine int `json:"end_line"` +} + +// GoVariableDeclaration represents a variable declaration (var/short-assign). +type GoVariableDeclaration struct { + Name string `json:"name"` + Type string `json:"type,omitempty"` + Initializer string `json:"initializer,omitempty"` + Scope string `json:"scope"` // "package" | "function" + StartLine int `json:"start_line"` + EndLine int `json:"end_line"` + StartColumn int `json:"start_column"` + EndColumn int `json:"end_column"` +} + +// GoField represents a struct field, including its struct tags. +type GoField struct { + Name string `json:"name"` + Type string `json:"type"` + Comments []GoComment `json:"comments"` + Tags map[string]string `json:"tags"` // parsed struct tags, e.g. {"json": "name,omitempty"} + IsExported bool `json:"is_exported"` + IsEmbedded bool `json:"is_embedded"` // anonymous/embedded field + StartLine int `json:"start_line"` + EndLine int `json:"end_line"` +} + +// GoSymbol represents a symbol accessed inside a callable. +type GoSymbol struct { + Name string `json:"name"` + Scope string `json:"scope"` // "local" | "package" | "external" + Kind string `json:"kind"` // "variable" | "function" | "type" | "constant" + Type string `json:"type,omitempty"` + QualifiedName string `json:"qualified_name,omitempty"` + IsBuiltin bool `json:"is_builtin"` + Lineno int `json:"lineno"` + ColOffset int `json:"col_offset"` +} + +// ─── Call site ──────────────────────────────────────────────────────────────── + +// GoCallsite represents a single call expression inside a callable. +// callee_signature is null when first recorded; the resolver backfills it +// during call-graph construction (never during symbol-table build). +type GoCallsite struct { + MethodName string `json:"method_name"` + ReceiverExpr string `json:"receiver_expr,omitempty"` + ReceiverType string `json:"receiver_type,omitempty"` + ArgumentTypes []string `json:"argument_types"` + ReturnType string `json:"return_type,omitempty"` + CalleeSignature *string `json:"callee_signature"` // null until resolved + IsConstructorCall bool `json:"is_constructor_call"` + IsGoroutine bool `json:"is_goroutine"` // true when preceded by `go` keyword + StartLine int `json:"start_line"` + StartColumn int `json:"start_column"` + EndLine int `json:"end_line"` + EndColumn int `json:"end_column"` +} + +// ─── Callable ───────────────────────────────────────────────────────────────── + +// GoCallable represents a function, method, or function literal in Go. +// receiver_type / receiver_name are non-empty for methods; empty for functions. +type GoCallable struct { + Name string `json:"name"` + Path string `json:"path"` + Signature string `json:"signature"` // signatureOf() output — edge id + Comments []GoComment `json:"comments"` + Parameters []GoParameter `json:"parameters"` + ReturnType string `json:"return_type"` // joined, e.g. "(int, error)" + ReturnTypes []string `json:"return_types"` // Go extension: individual return types + Code string `json:"code,omitempty"` + StartLine int `json:"start_line"` + EndLine int `json:"end_line"` + CodeStartLine int `json:"code_start_line"` + AccessedSymbols []GoSymbol `json:"accessed_symbols"` + CallSites []GoCallsite `json:"call_sites"` + InnerCallables map[string]GoCallable `json:"inner_callables"` + LocalVariables []GoVariableDeclaration `json:"local_variables"` + CyclomaticComplexity int `json:"cyclomatic_complexity"` + IsEntrypoint bool `json:"is_entrypoint"` + EntrypointFramework string `json:"entrypoint_framework,omitempty"` + ReceiverType string `json:"receiver_type,omitempty"` // e.g. "*MyStruct" + ReceiverName string `json:"receiver_name,omitempty"` // e.g. "r" + IsExported bool `json:"is_exported"` +} + +// ─── Type (struct or interface) ─────────────────────────────────────────────── + +// GoType represents a named Go type — either a struct (is_interface=false) or +// an interface (is_interface=true). This unified model mirrors Go's native type +// system where both are types.Named with different Underlying() values. +// +// base_types carries: embedded struct types (for structs) and the method-set +// signatures of satisfied interfaces (for both) — the Go analog of base_classes. +type GoType struct { + Name string `json:"name"` + Signature string `json:"signature"` // signatureOf() output + Comments []GoComment `json:"comments"` + Code string `json:"code,omitempty"` + IsInterface bool `json:"is_interface"` + IsExported bool `json:"is_exported"` + Fields []GoField `json:"fields"` // empty for interfaces + Methods map[string]GoCallable `json:"methods"` // sig → callable + BaseTypes []string `json:"base_types"` // embedded types + satisfied interface sigs + InnerTypes map[string]GoType `json:"inner_types"` // Go doesn't nest, but preserves spine + StartLine int `json:"start_line"` + EndLine int `json:"end_line"` +} + +// ─── File (Module analog) ───────────────────────────────────────────────────── + +// GoFile is the Module analog for Go: one compilation unit (source file). +// symbol_table is keyed by file path relative to the project root. +type GoFile struct { + FilePath string `json:"file_path"` + PackageName string `json:"module_name"` // JSON key = module_name for spine compat + Imports []GoImport `json:"imports"` + Comments []GoComment `json:"comments"` + Types map[string]GoType `json:"classes"` // JSON key = classes for spine compat + Functions map[string]GoCallable `json:"functions"` + Variables []GoVariableDeclaration `json:"variables"` + // Caching metadata + ContentHash *string `json:"content_hash"` + LastModified *float64 `json:"last_modified"` + FileSize *int64 `json:"file_size"` +} + +// ─── Call graph edge ────────────────────────────────────────────────────────── + +// GoCallEdge is an identity-only call-graph edge. source and target are +// GoCallable.signature strings that must exist in the symbol table. +type GoCallEdge struct { + Source string `json:"source"` + Target string `json:"target"` + Type string `json:"type"` // always "CALL_DEP" + Weight int `json:"weight"` // accumulated when merging backends + Provenance []string `json:"provenance"` // e.g. ["go/types"], ["go/types","codeql"] + Tags map[string]string `json:"tags"` +} + +// ─── Root object ────────────────────────────────────────────────────────────── + +// GoApplication is the root of analysis.json. The SDK facade deserializes this. +type GoApplication struct { + SymbolTable map[string]GoFile `json:"symbol_table"` // file_path → GoFile + CallGraph []GoCallEdge `json:"call_graph"` // identity-only edges + Entrypoints map[string][]GoEntrypoint `json:"entrypoints"` // optional, default {} +} + +// GoEntrypoint marks a callable as a framework entry point. +type GoEntrypoint struct { + Signature string `json:"signature"` + Framework string `json:"framework"` + DetectionSource string `json:"detection_source"` + Tags map[string]string `json:"tags"` +} diff --git a/internal/semantic_analysis/call_graph.go b/internal/semantic_analysis/call_graph.go new file mode 100644 index 0000000..7dd4adc --- /dev/null +++ b/internal/semantic_analysis/call_graph.go @@ -0,0 +1,289 @@ +// Package semantic_analysis builds the resolver-based call graph (Level 1, Tier 1). +// +// This stage uses the same golang.org/x/tools/go/packages load that built the +// symbol table. For each recorded call site it resolves the callee to a +// *types.Func, derives its signature via signatureOf(), backfills +// callee_signature in place, and emits an identity-only GoCallEdge. +// +// Precision choice: declared-type dispatch (CHA-style). Pointer-receiver +// methods are followed; interface dispatch records the interface method +// signature and falls back gracefully when the concrete type is unknown. +package semantic_analysis + +import ( + "go/ast" + "go/token" + "go/types" + + "golang.org/x/tools/go/packages" + + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" + "github.com/codellm-devkit/codeanalyzer-go/internal/utils" +) + +// CallGraphBuilder resolves call sites and builds the call graph. +type CallGraphBuilder struct { + projectDir string + fset *token.FileSet + pkgs map[string]*packages.Package +} + +// NewCallGraphBuilder creates a builder using the same pkgs/fset loaded for the symbol table. +func NewCallGraphBuilder(projectDir string, fset *token.FileSet, pkgs map[string]*packages.Package) *CallGraphBuilder { + return &CallGraphBuilder{projectDir: projectDir, fset: fset, pkgs: pkgs} +} + +// Build resolves all call sites in symbolTable, backfills callee_signature, and +// returns the identity-only edge list. Never crashes on unresolved sites — it logs +// and skips the edge while leaving callee_signature nil. +func (cg *CallGraphBuilder) Build(symbolTable map[string]schema.GoFile) []schema.GoCallEdge { + var edges []schema.GoCallEdge + seen := map[[2]string]bool{} + + // Build the known-sig set once. Only emit edges where the target is in the + // project's symbol table (drop stdlib / external callees, same as Python/Jedi). + knownSigs := buildKnownSigs(symbolTable) + + for fileKey, goFile := range symbolTable { + // Find the package for this file. + pkg := cg.packageForFile(fileKey) + if pkg == nil || pkg.TypesInfo == nil { + continue + } + + // Resolve methods on types. + for typeSig, goType := range goFile.Types { + for methSig, callable := range goType.Methods { + newCallable, newEdges := cg.resolveCallable(pkg, callable, seen, knownSigs) + goType.Methods[methSig] = newCallable + edges = append(edges, newEdges...) + } + goFile.Types[typeSig] = goType + } + + // Resolve package-level functions. + for fnSig, callable := range goFile.Functions { + newCallable, newEdges := cg.resolveCallable(pkg, callable, seen, knownSigs) + goFile.Functions[fnSig] = newCallable + edges = append(edges, newEdges...) + } + + symbolTable[fileKey] = goFile + } + + return edges +} + +// buildKnownSigs returns the set of all callable signatures present in the symbol table. +func buildKnownSigs(symbolTable map[string]schema.GoFile) map[string]bool { + sigs := make(map[string]bool) + for _, f := range symbolTable { + for sig := range f.Functions { + sigs[sig] = true + } + for _, t := range f.Types { + for sig := range t.Methods { + sigs[sig] = true + } + } + } + return sigs +} + +// resolveCallable backfills callee_signature on each call site and produces edges. +// Only emits edges where the target is in knownSigs (project-internal callees). +// External/stdlib callees have callee_signature backfilled but no edge emitted — +// matching Python/Jedi's behavior of dropping unresolved external sites. +func (cg *CallGraphBuilder) resolveCallable( + pkg *packages.Package, + callable schema.GoCallable, + seen map[[2]string]bool, + knownSigs map[string]bool, +) (schema.GoCallable, []schema.GoCallEdge) { + var edges []schema.GoCallEdge + + for i := range callable.CallSites { + site := &callable.CallSites[i] + if site.CalleeSignature != nil { + continue // already resolved + } + + calleeSig := cg.resolveCallSite(pkg, site) + if calleeSig == "" { + utils.Debug("unresolved call site: %s in %s", site.MethodName, callable.Signature) + continue + } + + // Backfill the site regardless of whether the callee is in-project. + site.CalleeSignature = &calleeSig + + // Only emit an edge when the callee is in the project's symbol table. + if !knownSigs[calleeSig] { + utils.Debug("external callee (no edge): %s", calleeSig) + continue + } + + key := [2]string{callable.Signature, calleeSig} + if seen[key] { + continue + } + seen[key] = true + + edges = append(edges, schema.GoCallEdge{ + Source: callable.Signature, + Target: calleeSig, + Type: "CALL_DEP", + Weight: 1, + Provenance: []string{"go/types"}, + Tags: map[string]string{}, + }) + } + + return callable, edges +} + +// resolveCallSite resolves a single call site to a callee signature string. +// Returns "" when the site cannot be resolved (graceful fallback). +func (cg *CallGraphBuilder) resolveCallSite(pkg *packages.Package, site *schema.GoCallsite) string { + if pkg.TypesInfo == nil { + return "" + } + + // Walk the package's syntax to find the call expression at this location. + for _, astFile := range pkg.Syntax { + var result string + ast.Inspect(astFile, func(n ast.Node) bool { + if result != "" { + return false + } + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + pos := cg.fset.Position(call.Pos()) + if pos.Line != site.StartLine || pos.Column != site.StartColumn { + return true + } + result = cg.resolveCallExpr(pkg, call) + return false + }) + if result != "" { + return result + } + } + return "" +} + +// resolveCallExpr resolves a *ast.CallExpr to a callee signature using go/types. +func (cg *CallGraphBuilder) resolveCallExpr(pkg *packages.Package, call *ast.CallExpr) string { + info := pkg.TypesInfo + + switch fn := call.Fun.(type) { + case *ast.Ident: + if obj := info.ObjectOf(fn); obj != nil { + if f, ok := obj.(*types.Func); ok { + return calleeSignatureOf(f) + } + // Type conversion — normalize as constructor. + if tn, ok := obj.(*types.TypeName); ok { + return calleeSignatureOf(tn) + ".__new__" + } + } + case *ast.SelectorExpr: + sel, ok := info.Selections[fn] + if ok { + if f, ok := sel.Obj().(*types.Func); ok { + return calleeSignatureOf(f) + } + } + // Package-level function via qualified identifier. + if obj := info.ObjectOf(fn.Sel); obj != nil { + if f, ok := obj.(*types.Func); ok { + return calleeSignatureOf(f) + } + } + } + return "" +} + +// calleeSignatureOf is the call-site mirror of signatureOf — same canonicalization. +// Must produce byte-identical strings to signatureOf() in syntactic_analysis. +func calleeSignatureOf(obj types.Object) string { + if obj == nil { + return "" + } + pkg := obj.Pkg() + pkgPath := "" + if pkg != nil { + pkgPath = pkg.Path() + } + + switch o := obj.(type) { + case *types.Func: + sig := o.Type().(*types.Signature) + recv := sig.Recv() + if recv != nil { + recvType := recv.Type() + if ptr, ok := recvType.(*types.Pointer); ok { + recvType = ptr.Elem() + } + if named, ok := recvType.(*types.Named); ok { + typeName := named.Obj().Name() + return pkgPath + "." + typeName + "." + o.Name() + } + } + return pkgPath + "." + o.Name() + case *types.TypeName: + return pkgPath + "." + o.Name() + default: + return pkgPath + "." + obj.Name() + } +} + +// packageForFile returns the *packages.Package that contains the given relative file path. +func (cg *CallGraphBuilder) packageForFile(relPath string) *packages.Package { + for _, pkg := range cg.pkgs { + for _, f := range pkg.GoFiles { + if utils.RelativePath(cg.projectDir, f) == relPath { + return pkg + } + } + } + return nil +} + +// MergeEdges merges two edge lists, unioning provenance and accumulating weight +// for duplicate (source, target) pairs. Mirrors Python's merge_edges(). +func MergeEdges(primary, secondary []schema.GoCallEdge) []schema.GoCallEdge { + type key struct{ src, tgt string } + index := map[key]int{} + result := make([]schema.GoCallEdge, 0, len(primary)+len(secondary)) + + add := func(e schema.GoCallEdge) { + k := key{e.Source, e.Target} + if idx, exists := index[k]; exists { + // Merge provenance (union) and accumulate weight. + provSet := map[string]bool{} + for _, p := range result[idx].Provenance { + provSet[p] = true + } + for _, p := range e.Provenance { + if !provSet[p] { + result[idx].Provenance = append(result[idx].Provenance, p) + } + } + result[idx].Weight += e.Weight + } else { + index[k] = len(result) + result = append(result, e) + } + } + + for _, e := range primary { + add(e) + } + for _, e := range secondary { + add(e) + } + return result +} diff --git a/internal/semantic_analysis/codeql/codeql.go b/internal/semantic_analysis/codeql/codeql.go new file mode 100644 index 0000000..88d1d29 --- /dev/null +++ b/internal/semantic_analysis/codeql/codeql.go @@ -0,0 +1,48 @@ +package codeql + +import ( + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" + "github.com/codellm-devkit/codeanalyzer-go/internal/utils" +) + +// CodeQL is the top-level handle for CodeQL-backed analysis. core.go talks only +// to this type; it never touches the binary, database, or query strings directly. +// TODO(level-2): implement Build() and Edges(). +type CodeQL struct { + loader *Loader + runner *Runner + enabled bool +} + +// New probes for the CodeQL binary and returns a CodeQL handle. +// If the binary is absent and enabled=true, returns an error. +func New(cacheDir string, enabled bool) (*CodeQL, error) { + if !enabled { + return &CodeQL{enabled: false}, nil + } + loader, err := NewLoader() + if err != nil { + return nil, err + } + runner := NewRunner(loader, cacheDir+"/codeql-db") + return &CodeQL{loader: loader, runner: runner, enabled: true}, nil +} + +// Build creates the CodeQL database for projectDir. +// No-op when CodeQL is disabled. +func (c *CodeQL) Build(projectDir string) error { + if !c.enabled { + return nil + } + utils.Info("building CodeQL database (stub — TODO level-2)") + return c.runner.BuildDatabase(projectDir) +} + +// Edges returns Tier-2 call-graph edges. +// Returns empty slice when disabled; returns ErrCodeQLNotImplemented when enabled (stub). +func (c *CodeQL) Edges() ([]schema.GoCallEdge, error) { + if !c.enabled { + return nil, nil + } + return c.runner.QueryCallGraph() +} diff --git a/internal/semantic_analysis/codeql/errors.go b/internal/semantic_analysis/codeql/errors.go new file mode 100644 index 0000000..a4affd1 --- /dev/null +++ b/internal/semantic_analysis/codeql/errors.go @@ -0,0 +1,15 @@ +// Package codeql is the isolated framework-backend subpackage for CodeQL analysis. +// It provides Tier-2 (framework-based) call-graph edges beyond what go/types resolves. +// +// The seams (loader, driver, query runner, errors) are scaffolded here even though +// the implementation is stubbed — dropping in the full implementation later requires +// no refactor. Mirrors codeanalyzer-python's semantic_analysis/codeql/ split. +package codeql + +import "errors" + +// ErrCodeQLNotFound is returned when the CodeQL CLI binary cannot be located. +var ErrCodeQLNotFound = errors.New("codeql: CLI binary not found; install from https://github.com/github/codeql-cli-binaries") + +// ErrCodeQLNotImplemented is returned when CodeQL analysis is requested but not yet implemented. +var ErrCodeQLNotImplemented = errors.New("codeql: Go backend is a wired stub — implementation TODO (level-2 analysis)") diff --git a/internal/semantic_analysis/codeql/loader.go b/internal/semantic_analysis/codeql/loader.go new file mode 100644 index 0000000..fce359d --- /dev/null +++ b/internal/semantic_analysis/codeql/loader.go @@ -0,0 +1,22 @@ +package codeql + +import ( + "os/exec" +) + +// Loader resolves the CodeQL CLI binary path. +type Loader struct { + binaryPath string +} + +// NewLoader creates a Loader, probing the PATH for the codeql binary. +func NewLoader() (*Loader, error) { + path, err := exec.LookPath("codeql") + if err != nil { + return nil, ErrCodeQLNotFound + } + return &Loader{binaryPath: path}, nil +} + +// BinaryPath returns the resolved CodeQL binary path. +func (l *Loader) BinaryPath() string { return l.binaryPath } diff --git a/internal/semantic_analysis/codeql/runner.go b/internal/semantic_analysis/codeql/runner.go new file mode 100644 index 0000000..9abbd3e --- /dev/null +++ b/internal/semantic_analysis/codeql/runner.go @@ -0,0 +1,27 @@ +package codeql + +import "github.com/codellm-devkit/codeanalyzer-go/internal/schema" + +// Runner builds a CodeQL database and runs queries to produce call-graph edges. +// TODO(level-2): implement database creation, query execution, and result parsing. +type Runner struct { + loader *Loader + dbDir string +} + +// NewRunner creates a Runner for the given project and database directory. +func NewRunner(loader *Loader, dbDir string) *Runner { + return &Runner{loader: loader, dbDir: dbDir} +} + +// BuildDatabase creates a CodeQL database for the Go project at projectDir. +// TODO(level-2): run `codeql database create --language=go`. +func (r *Runner) BuildDatabase(projectDir string) error { + return ErrCodeQLNotImplemented +} + +// QueryCallGraph runs the CodeQL call-graph query and returns edges. +// TODO(level-2): run query, parse SARIF/CSV, produce GoCallEdge list. +func (r *Runner) QueryCallGraph() ([]schema.GoCallEdge, error) { + return nil, ErrCodeQLNotImplemented +} diff --git a/internal/syntactic_analysis/export.go b/internal/syntactic_analysis/export.go new file mode 100644 index 0000000..fad49bb --- /dev/null +++ b/internal/syntactic_analysis/export.go @@ -0,0 +1,15 @@ +package syntactic_analysis + +import ( + "go/token" + + "golang.org/x/tools/go/packages" +) + +// Fset returns the token.FileSet used during package loading. +// Used by CallGraphBuilder to resolve source positions. +func (b *SymbolTableBuilder) Fset() *token.FileSet { return b.fset } + +// Pkgs returns the map of loaded packages keyed by package path. +// Used by CallGraphBuilder for type-info lookups. +func (b *SymbolTableBuilder) Pkgs() map[string]*packages.Package { return b.pkgs } diff --git a/internal/syntactic_analysis/signature.go b/internal/syntactic_analysis/signature.go new file mode 100644 index 0000000..6f538bf --- /dev/null +++ b/internal/syntactic_analysis/signature.go @@ -0,0 +1,86 @@ +// Package syntactic_analysis builds the symbol table from Go source files. +package syntactic_analysis + +import ( + "fmt" + "go/types" + "strings" +) + +// signatureOf is the single canonicalizer for all signature strings in the analyzer. +// It produces the edge id used in GoCallable.signature, GoType.signature, and every +// call-graph edge source/target. All callers must use this function — never build +// signatures ad-hoc. Caller-side and callee-side ids are then guaranteed identical. +// +// Format: +// - Package function: "pkg/path.FuncName" +// - Method: "pkg/path.TypeName.MethodName" +// - Interface method: "pkg/path.InterfaceName.MethodName" +// - Type: "pkg/path.TypeName" +func signatureOf(obj types.Object) string { + if obj == nil { + return "" + } + pkg := obj.Pkg() + pkgPath := "" + if pkg != nil { + pkgPath = pkg.Path() + } + + switch o := obj.(type) { + case *types.Func: + sig := o.Type().(*types.Signature) + recv := sig.Recv() + if recv != nil { + // Method: extract the receiver type name, stripping pointer indirection. + recvType := recv.Type() + if ptr, ok := recvType.(*types.Pointer); ok { + recvType = ptr.Elem() + } + if named, ok := recvType.(*types.Named); ok { + typeName := named.Obj().Name() + return fmt.Sprintf("%s.%s.%s", pkgPath, typeName, o.Name()) + } + } + return fmt.Sprintf("%s.%s", pkgPath, o.Name()) + case *types.TypeName: + return fmt.Sprintf("%s.%s", pkgPath, o.Name()) + default: + return fmt.Sprintf("%s.%s", pkgPath, obj.Name()) + } +} + +// signatureOfNamed builds a type signature from a *types.Named directly. +// Used when we have the named type but not a types.Object. +func signatureOfNamed(named *types.Named) string { + if named == nil { + return "" + } + obj := named.Obj() + pkgPath := "" + if obj.Pkg() != nil { + pkgPath = obj.Pkg().Path() + } + return fmt.Sprintf("%s.%s", pkgPath, obj.Name()) +} + +// signatureForCall builds a callee signature from a *types.Func resolved at a call site. +func signatureForCall(fn *types.Func) string { + return signatureOf(fn) +} + +// normalizeReturnType joins multiple return types into a single parenthesized string. +// Single non-error returns are returned as-is; multiple returns become "(t1, t2, ...)". +func normalizeReturnType(results *types.Tuple) (joined string, parts []string) { + if results == nil || results.Len() == 0 { + return "", nil + } + parts = make([]string, results.Len()) + for i := 0; i < results.Len(); i++ { + parts[i] = results.At(i).Type().String() + } + if len(parts) == 1 { + return parts[0], parts + } + return "(" + strings.Join(parts, ", ") + ")", parts +} diff --git a/internal/syntactic_analysis/symbol_table.go b/internal/syntactic_analysis/symbol_table.go new file mode 100644 index 0000000..f6a7838 --- /dev/null +++ b/internal/syntactic_analysis/symbol_table.go @@ -0,0 +1,944 @@ +package syntactic_analysis + +import ( + "fmt" + "go/ast" + "go/token" + "go/types" + "os" + "reflect" + "strings" + "unicode" + + "golang.org/x/tools/go/packages" + + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" + "github.com/codellm-devkit/codeanalyzer-go/internal/utils" +) + +// SymbolTableBuilder constructs a symbol table by loading packages with full type +// information via golang.org/x/tools/go/packages. One builder per analysis run. +// +// Architecture mirrors codeanalyzer-python's SymbolTableBuilder: a cohesive struct +// with per-node-kind private methods, sharing the loaded package context on self. +type SymbolTableBuilder struct { + projectDir string + fset *token.FileSet + // pkgs is the flat list of loaded packages, keyed by package path. + pkgs map[string]*packages.Package +} + +// NewSymbolTableBuilder creates a builder for projectDir. +func NewSymbolTableBuilder(projectDir string) *SymbolTableBuilder { + return &SymbolTableBuilder{ + projectDir: projectDir, + fset: token.NewFileSet(), + pkgs: map[string]*packages.Package{}, + } +} + +// Build loads all packages under projectDir (or only targetFiles if non-empty), +// walks each file, and returns the symbol table keyed by relative file path. +func (b *SymbolTableBuilder) Build(targetFiles []string, skipTests bool) (map[string]schema.GoFile, error) { + cfg := &packages.Config{ + Mode: packages.NeedName | + packages.NeedFiles | + packages.NeedSyntax | + packages.NeedTypes | + packages.NeedTypesInfo | + packages.NeedImports | + packages.NeedDeps, + Dir: b.projectDir, + Fset: b.fset, + // Silence go vet; we only need type info, not a full build. + BuildFlags: []string{}, + } + + // Build pattern(s) for packages.Load. + patterns := []string{"./..."} + if len(targetFiles) > 0 { + // Map each target file to a file= pattern; packages.Load accepts multiple patterns. + patterns = make([]string, len(targetFiles)) + for i, f := range targetFiles { + patterns[i] = "file=" + f + } + } + + pkgList, err := packages.Load(cfg, patterns...) + if err != nil { + return nil, err + } + + // Collect packages, warn on errors but don't abort (graceful partial analysis). + for _, pkg := range pkgList { + if len(pkg.Errors) > 0 { + for _, e := range pkg.Errors { + utils.Warn("package %s: %v", pkg.PkgPath, e) + } + } + if pkg.Types != nil { + b.pkgs[pkg.PkgPath] = pkg + } + } + + symbolTable := map[string]schema.GoFile{} + + for _, pkg := range b.pkgs { + for i, astFile := range pkg.Syntax { + if i >= len(pkg.GoFiles) { + continue + } + filePath := pkg.GoFiles[i] + relPath := utils.RelativePath(b.projectDir, filePath) + if skipTests && utils.IsTestFile(relPath) { + continue + } + if utils.IsVendored(relPath) { + continue + } + goFile := b.buildGoFile(pkg, astFile, filePath, relPath) + symbolTable[relPath] = goFile + } + } + + // In Go a method can be defined in any file of the package; the main loop + // above only attaches a method when its receiver type is in the same file. + // This pass finds methods whose type lives in a sibling file and attaches them. + b.reconcileCrossFileMethods(symbolTable) + + return symbolTable, nil +} + +// reconcileCrossFileMethods attaches methods to their type's owner file when +// the method and its receiver type are declared in different files of the same package. +func (b *SymbolTableBuilder) reconcileCrossFileMethods(symbolTable map[string]schema.GoFile) { + // Build index: (pkgPath, shortTypeName) → relPath of the file that owns the type. + type typeKey struct{ pkgPath, typeName string } + typeOwner := make(map[typeKey]string) + for relPath, gf := range symbolTable { + pkgPath := b.filePkgPath(relPath) + for typeName := range gf.Types { + typeOwner[typeKey{pkgPath, typeName}] = relPath + } + } + + for _, pkg := range b.pkgs { + for i, astFile := range pkg.Syntax { + if i >= len(pkg.GoFiles) { + continue + } + filePath := pkg.GoFiles[i] + relPath := utils.RelativePath(b.projectDir, filePath) + + for _, decl := range astFile.Decls { + fd, ok := decl.(*ast.FuncDecl) + if !ok || fd.Recv == nil { + continue + } + typeName := b.receiverTypeName(fd.Recv) + if typeName == "" { + continue + } + ownerRelPath, ok := typeOwner[typeKey{pkg.PkgPath, typeName}] + if !ok || ownerRelPath == relPath { + continue // type not found or already handled by the main loop + } + + callable := b.buildCallable(pkg, astFile, fd) + if callable == nil { + continue + } + ownerFile, found := symbolTable[ownerRelPath] + if !found { + continue + } + gt, found := ownerFile.Types[typeName] + if !found { + continue + } + if _, alreadyPresent := gt.Methods[callable.Signature]; !alreadyPresent { + gt.Methods[callable.Signature] = *callable + ownerFile.Types[typeName] = gt + symbolTable[ownerRelPath] = ownerFile + } + } + } + } +} + +// filePkgPath returns the package import path for the file at relPath. +func (b *SymbolTableBuilder) filePkgPath(relPath string) string { + for _, pkg := range b.pkgs { + for _, absFile := range pkg.GoFiles { + if utils.RelativePath(b.projectDir, absFile) == relPath { + return pkg.PkgPath + } + } + } + return "" +} + +// ─── Per-file builder ────────────────────────────────────────────────────────── + +func (b *SymbolTableBuilder) buildGoFile( + pkg *packages.Package, + astFile *ast.File, + absPath, relPath string, +) schema.GoFile { + info, _ := os.Stat(absPath) + hash, _ := utils.FileHash(absPath) + + var lastMod *float64 + var fileSize *int64 + if info != nil { + lm := float64(info.ModTime().Unix()) + float64(info.ModTime().Nanosecond())/1e9 + lastMod = &lm + sz := info.Size() + fileSize = &sz + } + var contentHash *string + if hash != "" { + contentHash = &hash + } + + gf := schema.GoFile{ + FilePath: relPath, + PackageName: astFile.Name.Name, + Imports: b.buildImports(astFile), + Comments: b.buildFileComments(astFile), + Types: map[string]schema.GoType{}, + Functions: map[string]schema.GoCallable{}, + Variables: b.buildPackageVars(pkg, astFile), + ContentHash: contentHash, + LastModified: lastMod, + FileSize: fileSize, + } + + // Walk top-level declarations. + for _, decl := range astFile.Decls { + switch d := decl.(type) { + case *ast.GenDecl: + b.processGenDecl(pkg, astFile, d, &gf) + case *ast.FuncDecl: + callable := b.buildCallable(pkg, astFile, d) + if callable == nil { + continue + } + if d.Recv != nil { + // Method — attach to its type. + if typeName := b.receiverTypeName(d.Recv); typeName != "" { + if gt, ok := gf.Types[typeName]; ok { + gt.Methods[callable.Signature] = *callable + gf.Types[typeName] = gt + } + } + } else { + gf.Functions[callable.Signature] = *callable + } + } + } + + return gf +} + +// ─── Imports ────────────────────────────────────────────────────────────────── + +func (b *SymbolTableBuilder) buildImports(astFile *ast.File) []schema.GoImport { + var imports []schema.GoImport + for _, imp := range astFile.Imports { + path := strings.Trim(imp.Path.Value, `"`) + alias := "" + if imp.Name != nil { + alias = imp.Name.Name + } + pos := b.fset.Position(imp.Pos()) + end := b.fset.Position(imp.End()) + imports = append(imports, schema.GoImport{ + Module: path, + Alias: alias, + StartLine: pos.Line, + EndLine: end.Line, + }) + } + return imports +} + +// ─── Comments ───────────────────────────────────────────────────────────────── + +func (b *SymbolTableBuilder) buildFileComments(astFile *ast.File) []schema.GoComment { + var comments []schema.GoComment + for _, cg := range astFile.Comments { + for _, c := range cg.List { + pos := b.fset.Position(c.Pos()) + end := b.fset.Position(c.End()) + comments = append(comments, schema.GoComment{ + Content: c.Text, + StartLine: pos.Line, + EndLine: end.Line, + StartColumn: pos.Column, + EndColumn: end.Column, + IsDocComment: strings.HasPrefix(c.Text, "//") || strings.HasPrefix(c.Text, "/*"), + }) + } + } + return comments +} + +func (b *SymbolTableBuilder) docComments(doc *ast.CommentGroup) []schema.GoComment { + if doc == nil { + return nil + } + var comments []schema.GoComment + for _, c := range doc.List { + pos := b.fset.Position(c.Pos()) + end := b.fset.Position(c.End()) + comments = append(comments, schema.GoComment{ + Content: c.Text, + StartLine: pos.Line, + EndLine: end.Line, + StartColumn: pos.Column, + EndColumn: end.Column, + IsDocComment: true, + }) + } + return comments +} + +// ─── GenDecl processor (type/var/const declarations) ───────────────────────── + +func (b *SymbolTableBuilder) processGenDecl( + pkg *packages.Package, + astFile *ast.File, + decl *ast.GenDecl, + gf *schema.GoFile, +) { + for _, spec := range decl.Specs { + switch s := spec.(type) { + case *ast.TypeSpec: + gt := b.buildType(pkg, astFile, decl, s) + if gt != nil { + gf.Types[gt.Name] = *gt + } + } + } +} + +// ─── Type builder ───────────────────────────────────────────────────────────── + +func (b *SymbolTableBuilder) buildType( + pkg *packages.Package, + astFile *ast.File, + decl *ast.GenDecl, + spec *ast.TypeSpec, +) *schema.GoType { + typeName := spec.Name.Name + isExported := unicode.IsUpper(rune(typeName[0])) + + // Resolve via type info if available. + var typSig string + if pkg.TypesInfo != nil { + if obj, ok := pkg.TypesInfo.Defs[spec.Name]; ok && obj != nil { + typSig = signatureOf(obj) + } + } + if typSig == "" { + typSig = pkg.Types.Path() + "." + typeName + } + + pos := b.fset.Position(spec.Pos()) + end := b.fset.Position(spec.End()) + + gt := &schema.GoType{ + Name: typeName, + Signature: typSig, + Comments: b.docComments(decl.Doc), + IsExported: isExported, + Fields: []schema.GoField{}, + Methods: map[string]schema.GoCallable{}, + BaseTypes: []string{}, + InnerTypes: map[string]schema.GoType{}, + StartLine: pos.Line, + EndLine: end.Line, + } + + switch t := spec.Type.(type) { + case *ast.StructType: + gt.IsInterface = false + gt.Fields = b.buildStructFields(pkg, t) + gt.BaseTypes = b.embeddedTypes(pkg, t) + case *ast.InterfaceType: + gt.IsInterface = true + // Interface methods are collected when we process FuncDecl with receivers, + // but interface method signatures within the interface type are also recorded here. + b.collectInterfaceMethods(pkg, t, gt) + } + + return gt +} + +func (b *SymbolTableBuilder) buildStructFields(pkg *packages.Package, st *ast.StructType) []schema.GoField { + var fields []schema.GoField + if st.Fields == nil { + return fields + } + for _, field := range st.Fields.List { + typStr := b.typeString(pkg, field.Type) + tags := parseStructTags(field.Tag) + isEmbedded := len(field.Names) == 0 + + if isEmbedded { + pos := b.fset.Position(field.Pos()) + end := b.fset.Position(field.End()) + fields = append(fields, schema.GoField{ + Name: typStr, + Type: typStr, + Tags: tags, + IsExported: true, + IsEmbedded: true, + StartLine: pos.Line, + EndLine: end.Line, + }) + continue + } + for _, name := range field.Names { + pos := b.fset.Position(name.Pos()) + end := b.fset.Position(field.End()) + fields = append(fields, schema.GoField{ + Name: name.Name, + Type: typStr, + Tags: tags, + IsExported: unicode.IsUpper(rune(name.Name[0])), + IsEmbedded: false, + StartLine: pos.Line, + EndLine: end.Line, + }) + } + } + return fields +} + +func (b *SymbolTableBuilder) embeddedTypes(pkg *packages.Package, st *ast.StructType) []string { + var embedded []string + if st.Fields == nil { + return embedded + } + for _, field := range st.Fields.List { + if len(field.Names) == 0 { + embedded = append(embedded, b.typeString(pkg, field.Type)) + } + } + return embedded +} + +func (b *SymbolTableBuilder) collectInterfaceMethods(pkg *packages.Package, it *ast.InterfaceType, gt *schema.GoType) { + if it.Methods == nil { + return + } + for _, method := range it.Methods.List { + if len(method.Names) == 0 { + // Embedded interface — add to base_types. + gt.BaseTypes = append(gt.BaseTypes, b.typeString(pkg, method.Type)) + continue + } + for _, name := range method.Names { + pos := b.fset.Position(name.Pos()) + end := b.fset.Position(method.End()) + + var retType string + var retTypes []string + if ft, ok := method.Type.(*ast.FuncType); ok && ft.Results != nil { + for _, r := range ft.Results.List { + retTypes = append(retTypes, b.typeString(pkg, r.Type)) + } + retType = b.joinReturnTypes(retTypes) + } + + sig := pkg.Types.Path() + "." + gt.Name + "." + name.Name + callable := schema.GoCallable{ + Name: name.Name, + Path: "", + Signature: sig, + Parameters: b.buildFuncTypeParams(pkg, method.Type), + ReturnType: retType, + ReturnTypes: retTypes, + IsExported: unicode.IsUpper(rune(name.Name[0])), + ReceiverType: gt.Signature, + CallSites: []schema.GoCallsite{}, + InnerCallables: map[string]schema.GoCallable{}, + StartLine: pos.Line, + EndLine: end.Line, + } + gt.Methods[sig] = callable + } + } +} + +// ─── Callable builder ───────────────────────────────────────────────────────── + +func (b *SymbolTableBuilder) buildCallable( + pkg *packages.Package, + astFile *ast.File, + decl *ast.FuncDecl, +) *schema.GoCallable { + name := decl.Name.Name + isExported := unicode.IsUpper(rune(name[0])) + pos := b.fset.Position(decl.Pos()) + end := b.fset.Position(decl.End()) + + var sig string + if pkg.TypesInfo != nil { + if obj, ok := pkg.TypesInfo.Defs[decl.Name]; ok && obj != nil { + sig = signatureOf(obj) + } + } + if sig == "" { + if decl.Recv != nil { + if recvName := b.receiverTypeName(decl.Recv); recvName != "" { + sig = pkg.Types.Path() + "." + recvName + "." + name + } + } + if sig == "" { + sig = pkg.Types.Path() + "." + name + } + } + + var recvType, recvName string + if decl.Recv != nil && len(decl.Recv.List) > 0 { + rf := decl.Recv.List[0] + recvType = b.typeString(pkg, rf.Type) + if len(rf.Names) > 0 { + recvName = rf.Names[0].Name + } + } + + retType, retTypes := b.buildReturnTypes(pkg, decl.Type) + bodyStart := pos.Line + if decl.Body != nil { + bodyStart = b.fset.Position(decl.Body.Pos()).Line + } + + callable := &schema.GoCallable{ + Name: name, + Path: utils.RelativePath(b.projectDir, b.fset.File(decl.Pos()).Name()), + Signature: sig, + Comments: b.docComments(decl.Doc), + Parameters: b.buildParams(pkg, decl.Type), + ReturnType: retType, + ReturnTypes: retTypes, + IsExported: isExported, + ReceiverType: recvType, + ReceiverName: recvName, + CallSites: []schema.GoCallsite{}, + InnerCallables: map[string]schema.GoCallable{}, + LocalVariables: []schema.GoVariableDeclaration{}, + AccessedSymbols: []schema.GoSymbol{}, + StartLine: pos.Line, + EndLine: end.Line, + CodeStartLine: bodyStart, + CyclomaticComplexity: b.cyclomaticComplexity(decl), + } + + if decl.Body != nil { + callable.Code = b.nodeSource(decl) + callable.CallSites = b.buildCallSites(pkg, decl.Body) + callable.LocalVariables = b.buildLocalVars(pkg, decl.Body) + } + + return callable +} + +// ─── Parameters ─────────────────────────────────────────────────────────────── + +func (b *SymbolTableBuilder) buildParams(pkg *packages.Package, ft *ast.FuncType) []schema.GoParameter { + if ft == nil || ft.Params == nil { + return nil + } + var params []schema.GoParameter + for _, field := range ft.Params.List { + typStr := b.typeString(pkg, field.Type) + isVariadic := false + if _, ok := field.Type.(*ast.Ellipsis); ok { + isVariadic = true + } + pos := b.fset.Position(field.Pos()) + end := b.fset.Position(field.End()) + if len(field.Names) == 0 { + params = append(params, schema.GoParameter{ + Name: "_", Type: typStr, IsVariadic: isVariadic, + StartLine: pos.Line, EndLine: end.Line, + }) + continue + } + for _, name := range field.Names { + params = append(params, schema.GoParameter{ + Name: name.Name, Type: typStr, IsVariadic: isVariadic, + StartLine: pos.Line, EndLine: end.Line, + }) + } + } + return params +} + +func (b *SymbolTableBuilder) buildFuncTypeParams(pkg *packages.Package, expr ast.Expr) []schema.GoParameter { + if ft, ok := expr.(*ast.FuncType); ok { + return b.buildParams(pkg, ft) + } + return nil +} + +// ─── Return types ───────────────────────────────────────────────────────────── + +func (b *SymbolTableBuilder) buildReturnTypes(pkg *packages.Package, ft *ast.FuncType) (string, []string) { + if ft == nil || ft.Results == nil { + return "", nil + } + var parts []string + for _, field := range ft.Results.List { + typStr := b.typeString(pkg, field.Type) + if len(field.Names) == 0 { + parts = append(parts, typStr) + } else { + for range field.Names { + parts = append(parts, typStr) + } + } + } + return b.joinReturnTypes(parts), parts +} + +func (b *SymbolTableBuilder) joinReturnTypes(parts []string) string { + if len(parts) == 0 { + return "" + } + if len(parts) == 1 { + return parts[0] + } + return "(" + strings.Join(parts, ", ") + ")" +} + +// ─── Call sites ─────────────────────────────────────────────────────────────── + +// buildCallSites walks the function body and records every call expression. +// callee_signature is left nil here — the resolver backfills it in the call-graph stage. +func (b *SymbolTableBuilder) buildCallSites(pkg *packages.Package, body *ast.BlockStmt) []schema.GoCallsite { + var sites []schema.GoCallsite + ast.Inspect(body, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.GoStmt: + // Goroutine launch — record with is_goroutine=true. + if call, ok := node.Call.Fun.(*ast.CallExpr); ok { + _ = call + } + site := b.callExprToSite(pkg, node.Call, true) + if site != nil { + sites = append(sites, *site) + } + return false + case *ast.CallExpr: + site := b.callExprToSite(pkg, node, false) + if site != nil { + sites = append(sites, *site) + } + } + return true + }) + return sites +} + +func (b *SymbolTableBuilder) callExprToSite(pkg *packages.Package, call *ast.CallExpr, isGoroutine bool) *schema.GoCallsite { + pos := b.fset.Position(call.Pos()) + end := b.fset.Position(call.End()) + + var methodName, receiverExpr, receiverType string + isConstructor := false + + switch fn := call.Fun.(type) { + case *ast.Ident: + methodName = fn.Name + // Check if it's a type conversion / constructor call. + if pkg.TypesInfo != nil { + if obj := pkg.TypesInfo.ObjectOf(fn); obj != nil { + if _, ok := obj.(*types.TypeName); ok { + isConstructor = true + } + } + } + case *ast.SelectorExpr: + methodName = fn.Sel.Name + receiverExpr = b.exprString(fn.X) + if pkg.TypesInfo != nil { + if t := pkg.TypesInfo.TypeOf(fn.X); t != nil { + receiverType = t.String() + } + } + default: + methodName = b.exprString(call.Fun) + } + + // Collect argument types. + var argTypes []string + for _, arg := range call.Args { + if pkg.TypesInfo != nil { + if t := pkg.TypesInfo.TypeOf(arg); t != nil { + argTypes = append(argTypes, t.String()) + continue + } + } + argTypes = append(argTypes, "") + } + + return &schema.GoCallsite{ + MethodName: methodName, + ReceiverExpr: receiverExpr, + ReceiverType: receiverType, + ArgumentTypes: argTypes, + IsConstructorCall: isConstructor, + IsGoroutine: isGoroutine, + CalleeSignature: nil, // backfilled by the call-graph stage + StartLine: pos.Line, + StartColumn: pos.Column, + EndLine: end.Line, + EndColumn: end.Column, + } +} + +// ─── Local variables ────────────────────────────────────────────────────────── + +func (b *SymbolTableBuilder) buildLocalVars(pkg *packages.Package, body *ast.BlockStmt) []schema.GoVariableDeclaration { + var vars []schema.GoVariableDeclaration + ast.Inspect(body, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.AssignStmt: + if node.Tok.String() == ":=" { + for i, lhs := range node.Lhs { + if ident, ok := lhs.(*ast.Ident); ok { + pos := b.fset.Position(ident.Pos()) + typStr := "" + if pkg.TypesInfo != nil { + if obj := pkg.TypesInfo.ObjectOf(ident); obj != nil { + typStr = obj.Type().String() + } + } + init := "" + if i < len(node.Rhs) { + init = b.exprString(node.Rhs[i]) + } + vars = append(vars, schema.GoVariableDeclaration{ + Name: ident.Name, Type: typStr, Initializer: init, + Scope: "function", StartLine: pos.Line, EndLine: pos.Line, + }) + } + } + } + case *ast.DeclStmt: + if gen, ok := node.Decl.(*ast.GenDecl); ok { + for _, spec := range gen.Specs { + if vs, ok := spec.(*ast.ValueSpec); ok { + for i, name := range vs.Names { + pos := b.fset.Position(name.Pos()) + typStr := "" + if vs.Type != nil { + typStr = b.typeString(pkg, vs.Type) + } else if pkg.TypesInfo != nil { + if obj := pkg.TypesInfo.ObjectOf(name); obj != nil { + typStr = obj.Type().String() + } + } + init := "" + if i < len(vs.Values) { + init = b.exprString(vs.Values[i]) + } + vars = append(vars, schema.GoVariableDeclaration{ + Name: name.Name, Type: typStr, Initializer: init, + Scope: "function", StartLine: pos.Line, EndLine: pos.Line, + }) + } + } + } + } + } + return true + }) + return vars +} + +// ─── Package-level variables ────────────────────────────────────────────────── + +func (b *SymbolTableBuilder) buildPackageVars(pkg *packages.Package, astFile *ast.File) []schema.GoVariableDeclaration { + var vars []schema.GoVariableDeclaration + for _, decl := range astFile.Decls { + gen, ok := decl.(*ast.GenDecl) + if !ok { + continue + } + if gen.Tok.String() != "var" && gen.Tok.String() != "const" { + continue + } + for _, spec := range gen.Specs { + vs, ok := spec.(*ast.ValueSpec) + if !ok { + continue + } + for i, name := range vs.Names { + pos := b.fset.Position(name.Pos()) + typStr := "" + if vs.Type != nil { + typStr = b.typeString(pkg, vs.Type) + } else if pkg.TypesInfo != nil { + if obj := pkg.TypesInfo.ObjectOf(name); obj != nil { + typStr = obj.Type().String() + } + } + init := "" + if i < len(vs.Values) { + init = b.exprString(vs.Values[i]) + } + vars = append(vars, schema.GoVariableDeclaration{ + Name: name.Name, Type: typStr, Initializer: init, + Scope: "package", StartLine: pos.Line, EndLine: pos.Line, + }) + } + } + } + return vars +} + +// ─── Cyclomatic complexity ──────────────────────────────────────────────────── + +// cyclomaticComplexity computes McCabe complexity: 1 + decision points. +func (b *SymbolTableBuilder) cyclomaticComplexity(decl *ast.FuncDecl) int { + if decl.Body == nil { + return 0 + } + complexity := 1 + ast.Inspect(decl.Body, func(n ast.Node) bool { + switch n.(type) { + case *ast.IfStmt, *ast.ForStmt, *ast.RangeStmt, *ast.SwitchStmt, + *ast.TypeSwitchStmt, *ast.SelectStmt, *ast.CaseClause, + *ast.CommClause: + complexity++ + } + return true + }) + return complexity +} + +// ─── Helpers ────────────────────────────────────────────────────────────────── + +// receiverTypeName extracts the base type name from a receiver field list. +func (b *SymbolTableBuilder) receiverTypeName(recv *ast.FieldList) string { + if recv == nil || len(recv.List) == 0 { + return "" + } + expr := recv.List[0].Type + // Strip pointer. + if star, ok := expr.(*ast.StarExpr); ok { + expr = star.X + } + if ident, ok := expr.(*ast.Ident); ok { + return ident.Name + } + return "" +} + +// typeString returns a human-readable string for an ast.Expr type node. +func (b *SymbolTableBuilder) typeString(pkg *packages.Package, expr ast.Expr) string { + if pkg.TypesInfo != nil { + if t := pkg.TypesInfo.TypeOf(expr); t != nil { + return t.String() + } + } + // Fallback: print the expression. + return b.exprString(expr) +} + +// exprString returns a best-effort source representation of an expression. +func (b *SymbolTableBuilder) exprString(expr ast.Expr) string { + if expr == nil { + return "" + } + switch e := expr.(type) { + case *ast.Ident: + return e.Name + case *ast.SelectorExpr: + return b.exprString(e.X) + "." + e.Sel.Name + case *ast.StarExpr: + return "*" + b.exprString(e.X) + case *ast.ArrayType: + return "[]" + b.exprString(e.Elt) + case *ast.MapType: + return "map[" + b.exprString(e.Key) + "]" + b.exprString(e.Value) + case *ast.InterfaceType: + return "interface{}" + case *ast.BasicLit: + return e.Value + default: + return fmt.Sprint(expr) + } +} + +// nodeSource extracts the raw source text of a node (best effort). +func (b *SymbolTableBuilder) nodeSource(node ast.Node) string { + pos := b.fset.Position(node.Pos()) + if pos.Filename == "" { + return "" + } + data, err := os.ReadFile(pos.Filename) + if err != nil { + return "" + } + startOff := b.fset.File(node.Pos()).Offset(node.Pos()) + endOff := b.fset.File(node.End()).Offset(node.End()) + if startOff < 0 || endOff > len(data) || startOff >= endOff { + return "" + } + return string(data[startOff:endOff]) +} + +// parseStructTags parses a struct tag literal into a key→value map. +// e.g. `json:"name,omitempty" db:"name"` → {"json": "name,omitempty", "db": "name"} +func parseStructTags(lit *ast.BasicLit) map[string]string { + tags := map[string]string{} + if lit == nil { + return tags + } + raw := strings.Trim(lit.Value, "`") + st := reflect.StructTag(raw) + // Iterate common tag keys; for a full parse, walk the raw string. + for _, key := range extractTagKeys(raw) { + if v := st.Get(key); v != "" { + tags[key] = v + } + } + return tags +} + +// extractTagKeys extracts tag key names from a raw struct tag string. +func extractTagKeys(raw string) []string { + var keys []string + for len(raw) > 0 { + raw = strings.TrimLeft(raw, " \t") + if raw == "" { + break + } + idx := strings.IndexByte(raw, ':') + if idx < 0 { + break + } + keys = append(keys, raw[:idx]) + // Skip past the value. + rest := raw[idx+1:] + if len(rest) == 0 || rest[0] != '"' { + break + } + end := strings.IndexByte(rest[1:], '"') + if end < 0 { + break + } + raw = rest[end+2:] + } + return keys +} + +// ensure fmt is used (used in exprString fallback and signature.go). +var _ = fmt.Sprintf diff --git a/internal/utils/fs.go b/internal/utils/fs.go new file mode 100644 index 0000000..e46da2e --- /dev/null +++ b/internal/utils/fs.go @@ -0,0 +1,81 @@ +// Package utils provides filesystem helpers and logging utilities. +package utils + +import ( + "crypto/sha256" + "fmt" + "io" + "os" + "path/filepath" + "strings" +) + +// IsTestFile reports whether path is a Go test file (_test.go suffix). +func IsTestFile(path string) bool { + return strings.HasSuffix(path, "_test.go") +} + +// IsVendored reports whether path is under a vendored or generated directory. +func IsVendored(path string) bool { + for _, seg := range strings.Split(filepath.ToSlash(path), "/") { + switch seg { + case "vendor", "testdata", ".git": + return true + } + } + return false +} + +// RelativePath returns path relative to root, or path itself on error. +func RelativePath(root, path string) string { + rel, err := filepath.Rel(root, path) + if err != nil { + return path + } + return rel +} + +// FileHash returns the SHA-256 hex digest of the file at path. +func FileHash(path string) (string, error) { + f, err := os.Open(path) + if err != nil { + return "", err + } + defer f.Close() + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return "", err + } + return fmt.Sprintf("%x", h.Sum(nil)), nil +} + +// DiscoverGoFiles returns all *.go files under root, skipping vendored dirs +// and optionally test files. +func DiscoverGoFiles(root string, skipTests bool) ([]string, error) { + var files []string + err := filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error { + if err != nil { + return nil // skip unreadable entries gracefully + } + if d.IsDir() { + if IsVendored(path) { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, ".go") { + return nil + } + if skipTests && IsTestFile(path) { + return nil + } + files = append(files, path) + return nil + }) + return files, err +} + +// EnsureDir creates dir and all parents if they don't exist. +func EnsureDir(dir string) error { + return os.MkdirAll(dir, 0o755) +} diff --git a/internal/utils/logging.go b/internal/utils/logging.go new file mode 100644 index 0000000..a5c7674 --- /dev/null +++ b/internal/utils/logging.go @@ -0,0 +1,30 @@ +package utils + +import ( + "fmt" + "os" +) + +var verbosity int + +// SetVerbosity sets the global log level (0=quiet, 1=info, 2=debug). +func SetVerbosity(v int) { verbosity = v } + +// Info logs an informational message when verbosity >= 1. +func Info(format string, args ...any) { + if verbosity >= 1 { + fmt.Fprintf(os.Stderr, "[codeanalyzer-go] "+format+"\n", args...) + } +} + +// Debug logs a debug message when verbosity >= 2. +func Debug(format string, args ...any) { + if verbosity >= 2 { + fmt.Fprintf(os.Stderr, "[codeanalyzer-go DEBUG] "+format+"\n", args...) + } +} + +// Warn always prints a warning to stderr. +func Warn(format string, args ...any) { + fmt.Fprintf(os.Stderr, "[codeanalyzer-go WARN] "+format+"\n", args...) +} diff --git a/testdata/fixture/go.mod b/testdata/fixture/go.mod new file mode 100644 index 0000000..30818c1 --- /dev/null +++ b/testdata/fixture/go.mod @@ -0,0 +1,3 @@ +module example.com/fixture + +go 1.21 diff --git a/testdata/fixture/main.go b/testdata/fixture/main.go new file mode 100644 index 0000000..f20e3f9 --- /dev/null +++ b/testdata/fixture/main.go @@ -0,0 +1,15 @@ +package main + +import ( + "fmt" + + "example.com/fixture/pkg/greeter" +) + +func main() { + g := greeter.New("Hello") + msg := g.Greet("World") + fmt.Println(msg) + loud := greeter.Shout(msg) + fmt.Println(loud) +} diff --git a/testdata/fixture/pkg/greeter/greeter.go b/testdata/fixture/pkg/greeter/greeter.go new file mode 100644 index 0000000..85f19c6 --- /dev/null +++ b/testdata/fixture/pkg/greeter/greeter.go @@ -0,0 +1,29 @@ +// Package greeter provides simple greeting functionality. +package greeter + +import "fmt" + +// Greeter holds a greeting prefix. +type Greeter struct { + Prefix string `json:"prefix"` +} + +// New creates a Greeter with the given prefix. +func New(prefix string) *Greeter { + return &Greeter{Prefix: prefix} +} + +// Greet returns a greeting for name. +func (g *Greeter) Greet(name string) string { + return fmt.Sprintf("%s, %s!", g.Prefix, name) +} + +// Logger is a simple logging interface. +type Logger interface { + Log(msg string) +} + +// Shout returns the message in a louder form. +func Shout(msg string) string { + return msg + "!!!" +} diff --git a/testdata/realistic/go.mod b/testdata/realistic/go.mod new file mode 100644 index 0000000..f1da8db --- /dev/null +++ b/testdata/realistic/go.mod @@ -0,0 +1,3 @@ +module example.com/realistic + +go 1.21 diff --git a/testdata/realistic/main.go b/testdata/realistic/main.go new file mode 100644 index 0000000..81ae136 --- /dev/null +++ b/testdata/realistic/main.go @@ -0,0 +1,26 @@ +package main + +import ( + "fmt" + "log" + + "example.com/realistic/server" + "example.com/realistic/worker" +) + +func main() { + cfg := server.Config{Host: "localhost", Port: 8080} + srv, err := server.New(cfg) + if err != nil { + log.Fatal(err) + } + + fmt.Println(srv.Addr()) + fmt.Println(server.Tags("env", "prod", "region", "us-east-1")) + + w := worker.New() + w.Run(nil, worker.Task{ID: 1, Payload: "hello"}) + + combined := worker.Combine(worker.Result{TaskID: 1, Output: "a"}, worker.Result{TaskID: 2, Output: "b"}) + fmt.Println(combined.Output) +} diff --git a/testdata/realistic/server/middleware.go b/testdata/realistic/server/middleware.go new file mode 100644 index 0000000..70ae722 --- /dev/null +++ b/testdata/realistic/server/middleware.go @@ -0,0 +1,19 @@ +// Second file in the server package — exercises multi-file package detection. +package server + +import ( + "fmt" + "strings" +) + +// Tags formats key-value pairs into a single string. +// Exercises: variadic parameter (pairs ...string), is_variadic=true. +func Tags(pairs ...string) string { + return fmt.Sprintf("[%s]", strings.Join(pairs, ", ")) +} + +// Describe returns a human-readable description of the server. +// Exercises: value receiver (s Server) vs pointer receiver in server.go. +func (s Server) Describe() string { + return fmt.Sprintf("server at %s", s.Addr()) +} diff --git a/testdata/realistic/server/server.go b/testdata/realistic/server/server.go new file mode 100644 index 0000000..4034e77 --- /dev/null +++ b/testdata/realistic/server/server.go @@ -0,0 +1,53 @@ +// Package server provides a minimal configurable server. +package server + +import ( + "errors" + "fmt" +) + +// Config holds server configuration. +type Config struct { + Host string `json:"host" validate:"required"` + Port int `json:"port"` +} + +// Server wraps a Config with lifecycle state. +// It embeds Config directly so callers can access Host and Port without indirection. +type Server struct { + Config // embedded — exercises GoField.is_embedded + ready bool // unexported field — exercises is_exported=false on GoField +} + +// New creates a Server, returning an error if the config is invalid. +// Exercises: multiple return types (*Server, error), (T, error) idiom. +func New(cfg Config) (*Server, error) { + if cfg.Host == "" { + return nil, errors.New("host required") + } + return &Server{Config: cfg}, nil +} + +// Addr returns the host:port address string. +// Exercises: pointer receiver (*Server), non-empty receiver_type / receiver_name. +func (s *Server) Addr() string { + return fmt.Sprintf("%s:%d", s.Host, s.Port) +} + +// Validate checks the config fields and returns any validation errors. +// Exercises: named return, multiple return types (bool, error). +func (s *Server) Validate() (bool, error) { + if s.Host == "" { + return false, errors.New("host is empty") + } + if s.Port <= 0 { + return false, fmt.Errorf("invalid port: %d", s.Port) + } + return true, nil +} + +// shutdown performs internal cleanup. +// Exercises: unexported method — is_exported=false. +func (s *Server) shutdown() { + s.ready = false +} diff --git a/testdata/realistic/worker/worker.go b/testdata/realistic/worker/worker.go new file mode 100644 index 0000000..0559946 --- /dev/null +++ b/testdata/realistic/worker/worker.go @@ -0,0 +1,65 @@ +// Package worker runs tasks concurrently. +package worker + +import ( + "fmt" + "sync" +) + +// Task is a unit of work. +type Task struct { + ID int `json:"id"` + Payload string `json:"payload"` +} + +// Result holds the outcome of processing a Task. +type Result struct { + TaskID int `json:"task_id"` + Output string `json:"output"` +} + +// Processor is the processing interface. +// Exercises: interface type, is_interface=true. +type Processor interface { + Process(t Task) (Result, error) +} + +// Worker runs tasks in background goroutines. +type Worker struct { + mu sync.Mutex + done bool +} + +// New creates a Worker. +func New() *Worker { + return &Worker{} +} + +// Run launches a goroutine to process t. +// Exercises: goroutine launch — GoCallsite.is_goroutine=true for the w.execute call. +func (w *Worker) Run(p Processor, t Task) { + go w.execute(p, t) +} + +// Combine merges multiple results into a single Result. +// Exercises: variadic parameter (results ...Result), is_variadic=true. +func Combine(results ...Result) Result { + out := Result{} + for _, r := range results { + out.Output += r.Output + } + return out +} + +// execute processes a task under the mutex. +// Exercises: unexported method (is_exported=false), cyclomatic_complexity > 1 (if branch). +func (w *Worker) execute(p Processor, t Task) { + w.mu.Lock() + defer w.mu.Unlock() + r, err := p.Process(t) + if err != nil { + _ = fmt.Errorf("task %d: %w", t.ID, err) + return + } + _ = r +} From 07ed61d5ec8dc968bcaf5b395989c74f2e5a5c36 Mon Sep 17 00:00:00 2001 From: Saurabh Sinha Date: Wed, 17 Jun 2026 09:58:22 -0400 Subject: [PATCH 2/6] chore: merge standard Go .gitignore with project-specific entries Incorporates the GitHub Go template (*.dll, *.so, go.work, .env, etc.) alongside project-specific ignores for built binaries and .claude/. Signed-off-by: Saurabh Sinha --- .gitignore | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 169d632..08088c3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,19 +1,37 @@ -# Binaries +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Project-specific binaries /codeanalyzer /codeanalyzer-go -*.exe # Build output /dist/ /bin/ +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out +coverage.txt + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work +go.work.sum + +# env file +.env + # Claude Code session data .claude/ # macOS .DS_Store - -# Go test cache / coverage -*.test -*.out -coverage.txt From 81da2b101ee2c7c7f822513ee1aacfb2db8bba94 Mon Sep 17 00:00:00 2001 From: Saurabh Sinha Date: Wed, 17 Jun 2026 16:20:20 -0400 Subject: [PATCH 3/6] test: add functional test coverage and restructure testdata fixtures Bug fix: populate InnerCallables by walking ast.FuncLit nodes in buildInnerCallables; add *ast.FuncLit: return false to buildCallSites so closure call sites are not double-counted in the outer function. New test files: - cmd/codeanalyzer/main_test.go: CLI integration tests covering --version, --format validation, --output, --analysis-level, --skip-tests, --target-files - internal/analysis/registry_test.go: orderPasses topo-sort tests and RunPipeline smoke test - internal/semantic_analysis/call_graph_test.go: MergeEdges unit tests - internal/utils/fs_test.go: table-driven tests for IsTestFile, IsVendored, FileHash, EnsureDir, DiscoverGoFiles - internal/core/skip_tests_test.go: SkipTests true/false behaviour - internal/core/incremental_test.go: TargetFiles single/multi package and nil Test additions to existing files: - chi_test.go: InnerCallables populated, IsConstructorCall for methodTyp() type conversion, init() presence and is_exported - multipackage_test.go (renamed from realistic_test.go): LocalVariables present with correct type and scope Testdata restructure: - testdata/fixture/ -> testdata/greeter/ - testdata/realistic/ -> testdata/multipackage/ - testdata/chi/: replace single-file wrapper with chi v5 library source (35 files) - testdata/generics/: add generics fixture (fn + set packages) - testdata/multipackage/server/server_test.go: minimal test file for --skip-tests=false coverage Signed-off-by: Saurabh Sinha --- README.md | 32 +- cmd/codeanalyzer/main.go | 21 +- cmd/codeanalyzer/main_test.go | 205 ++++ internal/analysis/registry_test.go | 137 +++ internal/core/analyzer_test.go | 119 +-- internal/core/chi_test.go | 196 ++++ internal/core/errors_test.go | 76 ++ internal/core/generics_test.go | 127 +++ internal/core/incremental_test.go | 120 +++ ...realistic_test.go => multipackage_test.go} | 149 ++- internal/core/skip_tests_test.go | 70 ++ internal/core/testsetup_test.go | 99 ++ internal/semantic_analysis/call_graph_test.go | 116 +++ internal/syntactic_analysis/symbol_table.go | 68 ++ internal/utils/fs_test.go | 206 ++++ testdata/chi/chain.go | 49 + testdata/chi/chi.go | 137 +++ testdata/chi/context.go | 166 ++++ testdata/chi/go.mod | 3 + testdata/chi/middleware/basic_auth.go | 33 + testdata/chi/middleware/clean_path.go | 28 + testdata/chi/middleware/client_ip.go | 263 ++++++ testdata/chi/middleware/compress.go | 392 ++++++++ testdata/chi/middleware/content_charset.go | 45 + testdata/chi/middleware/content_encoding.go | 34 + testdata/chi/middleware/content_type.go | 45 + testdata/chi/middleware/get_head.go | 39 + testdata/chi/middleware/heartbeat.go | 26 + testdata/chi/middleware/logger.go | 178 ++++ testdata/chi/middleware/maybe.go | 18 + testdata/chi/middleware/middleware.go | 23 + testdata/chi/middleware/nocache.go | 59 ++ testdata/chi/middleware/page_route.go | 20 + testdata/chi/middleware/path_rewrite.go | 16 + testdata/chi/middleware/profiler.go | 49 + testdata/chi/middleware/realip.go | 53 ++ testdata/chi/middleware/recoverer.go | 203 ++++ testdata/chi/middleware/request_id.go | 96 ++ testdata/chi/middleware/request_size.go | 18 + testdata/chi/middleware/route_headers.go | 146 +++ testdata/chi/middleware/strip.go | 77 ++ testdata/chi/middleware/sunset.go | 25 + testdata/chi/middleware/supress_notfound.go | 27 + testdata/chi/middleware/terminal.go | 63 ++ testdata/chi/middleware/throttle.go | 151 +++ testdata/chi/middleware/timeout.go | 48 + testdata/chi/middleware/url_format.go | 77 ++ testdata/chi/middleware/value.go | 17 + testdata/chi/middleware/wrap_writer.go | 243 +++++ testdata/chi/mux.go | 526 +++++++++++ testdata/chi/tree.go | 877 ++++++++++++++++++ testdata/fixture/go.mod | 3 - testdata/generics/fn/fn.go | 46 + testdata/generics/go.mod | 3 + testdata/generics/main.go | 22 + testdata/generics/set/set.go | 36 + testdata/greeter/go.mod | 3 + testdata/{fixture => greeter}/main.go | 2 +- .../pkg/greeter/greeter.go | 0 testdata/multipackage/go.mod | 3 + testdata/{realistic => multipackage}/main.go | 4 +- .../server/middleware.go | 0 .../server/server.go | 0 testdata/multipackage/server/server_test.go | 16 + .../worker/worker.go | 0 testdata/realistic/go.mod | 3 - 66 files changed, 5968 insertions(+), 184 deletions(-) create mode 100644 cmd/codeanalyzer/main_test.go create mode 100644 internal/analysis/registry_test.go create mode 100644 internal/core/chi_test.go create mode 100644 internal/core/errors_test.go create mode 100644 internal/core/generics_test.go create mode 100644 internal/core/incremental_test.go rename internal/core/{realistic_test.go => multipackage_test.go} (69%) create mode 100644 internal/core/skip_tests_test.go create mode 100644 internal/core/testsetup_test.go create mode 100644 internal/semantic_analysis/call_graph_test.go create mode 100644 internal/utils/fs_test.go create mode 100644 testdata/chi/chain.go create mode 100644 testdata/chi/chi.go create mode 100644 testdata/chi/context.go create mode 100644 testdata/chi/go.mod create mode 100644 testdata/chi/middleware/basic_auth.go create mode 100644 testdata/chi/middleware/clean_path.go create mode 100644 testdata/chi/middleware/client_ip.go create mode 100644 testdata/chi/middleware/compress.go create mode 100644 testdata/chi/middleware/content_charset.go create mode 100644 testdata/chi/middleware/content_encoding.go create mode 100644 testdata/chi/middleware/content_type.go create mode 100644 testdata/chi/middleware/get_head.go create mode 100644 testdata/chi/middleware/heartbeat.go create mode 100644 testdata/chi/middleware/logger.go create mode 100644 testdata/chi/middleware/maybe.go create mode 100644 testdata/chi/middleware/middleware.go create mode 100644 testdata/chi/middleware/nocache.go create mode 100644 testdata/chi/middleware/page_route.go create mode 100644 testdata/chi/middleware/path_rewrite.go create mode 100644 testdata/chi/middleware/profiler.go create mode 100644 testdata/chi/middleware/realip.go create mode 100644 testdata/chi/middleware/recoverer.go create mode 100644 testdata/chi/middleware/request_id.go create mode 100644 testdata/chi/middleware/request_size.go create mode 100644 testdata/chi/middleware/route_headers.go create mode 100644 testdata/chi/middleware/strip.go create mode 100644 testdata/chi/middleware/sunset.go create mode 100644 testdata/chi/middleware/supress_notfound.go create mode 100644 testdata/chi/middleware/terminal.go create mode 100644 testdata/chi/middleware/throttle.go create mode 100644 testdata/chi/middleware/timeout.go create mode 100644 testdata/chi/middleware/url_format.go create mode 100644 testdata/chi/middleware/value.go create mode 100644 testdata/chi/middleware/wrap_writer.go create mode 100644 testdata/chi/mux.go create mode 100644 testdata/chi/tree.go delete mode 100644 testdata/fixture/go.mod create mode 100644 testdata/generics/fn/fn.go create mode 100644 testdata/generics/go.mod create mode 100644 testdata/generics/main.go create mode 100644 testdata/generics/set/set.go create mode 100644 testdata/greeter/go.mod rename testdata/{fixture => greeter}/main.go (82%) rename testdata/{fixture => greeter}/pkg/greeter/greeter.go (100%) create mode 100644 testdata/multipackage/go.mod rename testdata/{realistic => multipackage}/main.go (87%) rename testdata/{realistic => multipackage}/server/middleware.go (100%) rename testdata/{realistic => multipackage}/server/server.go (100%) create mode 100644 testdata/multipackage/server/server_test.go rename testdata/{realistic => multipackage}/worker/worker.go (100%) delete mode 100644 testdata/realistic/go.mod diff --git a/README.md b/README.md index 608e3a3..9310f1c 100644 --- a/README.md +++ b/README.md @@ -187,7 +187,11 @@ codeanalyzer-go/ │ ├── analysis/ # Pluggable pass interface + registry (topo-ordered pipeline) │ ├── frameworks/ # BaseEntrypointFinder — extension seam for framework passes │ └── utils/ # DiscoverGoFiles, IsVendored, IsTestFile, logging -└── testdata/fixture/ # Minimal Go fixture used by tests +├── testdata/ +│ ├── fixture/ # Minimal two-package fixture (basic struct/interface/call sites) +│ ├── realistic/ # Richer fixture covering embedded fields, variadic params, goroutines, … +│ ├── generics/ # Go 1.18+ generics fixture (Set[T], union-constraint interfaces, Map[T,U]) +│ └── chi/ # External-dep fixture (chi v5, vendored) for HTTP handler patterns ``` The `core` package is a pure orchestrator: it calls `syntactic_analysis` → `semantic_analysis` → `analysis.RunPipeline` → optional CodeQL in sequence, with no inlined parsing logic. Framework-specific analysis extends through the `analysis/` + `frameworks/` layer without touching `core`. @@ -200,7 +204,31 @@ The `core` package is a pure orchestrator: it calls `syntactic_analysis` → `se go test ./... ``` -Tests run against `testdata/fixture/` and `testdata/realistic/` — a minimal two-package and a richer multi-package Go module. All 33 tests cover symbol table correctness, call graph edges, JSON round-trip, output format validation, and caching/incremental behaviour. +Tests run against four fixtures: `testdata/fixture/` (basic), `testdata/realistic/` (multi-file packages, goroutines, variadic params), `testdata/generics/` (Go 1.18+ generics — `Set[T]`, union constraints, multi-type-param functions), and `testdata/chi/` (external dependency via vendored chi v5, HTTP handler patterns). All 57 tests cover symbol table correctness, generic receiver attribution, call graph edges, JSON round-trip, output format validation, caching behaviour, and error paths. + +`go test` caches passing results by source hash. To force a full re-run: + +```bash +go clean -testcache && go test ./... +``` + +The analyzer's own `CacheDir` (used inside tests for `analysis_cache.json` and `go_mod_hash`) is written to OS temp directories that are wiped automatically when the test binary exits — there is no persistent on-disk state between test runs. The chi fixture is fully vendored, so tests never require network access. + +### Clearing the production cache + +By default the CLI writes its cache to `~/.cldk/go-cache`. To bypass it for a single run: + +```bash +codeanalyzer-go -i ./my-project --eager +``` + +To delete it entirely: + +```bash +rm -rf ~/.cldk/go-cache +``` + +If you pass a custom `--cache-dir`, remove that directory instead. ### Running from source diff --git a/cmd/codeanalyzer/main.go b/cmd/codeanalyzer/main.go index 7425781..6968fd8 100644 --- a/cmd/codeanalyzer/main.go +++ b/cmd/codeanalyzer/main.go @@ -5,6 +5,7 @@ package main import ( + "encoding/json" "fmt" "os" @@ -49,12 +50,20 @@ via CLDK(language="go").analysis(project_path=...).`, SilenceUsage: true, RunE: func(cmd *cobra.Command, args []string) error { if showVersion { - fmt.Println("codeanalyzer-go " + version) + cmd.Println("codeanalyzer-go " + version) return nil } if inputPath == "" { return fmt.Errorf("--input / -i is required") } + switch format { + case "", "json": + // valid + case "msgpack": + return fmt.Errorf("msgpack output is not yet implemented; use --format json") + default: + return fmt.Errorf("unsupported output format %q; supported: json", format) + } utils.SetVerbosity(verbosity) if cacheDir == "" { @@ -81,6 +90,16 @@ via CLDK(language="go").analysis(project_path=...).`, return err } + // When no --output dir is given, write JSON to cobra's output + // writer so tests can capture it via cmd.SetOut. + if outputDir == "" { + data, err := json.Marshal(app) + if err != nil { + return err + } + _, err = cmd.OutOrStdout().Write(data) + return err + } return core.WriteOutput(app, outputDir, format) }, } diff --git a/cmd/codeanalyzer/main_test.go b/cmd/codeanalyzer/main_test.go new file mode 100644 index 0000000..67fdefd --- /dev/null +++ b/cmd/codeanalyzer/main_test.go @@ -0,0 +1,205 @@ +package main + +// CLI integration tests. These call rootCmd().Execute() directly (same +// package, so the unexported function is accessible) with controlled args and +// capture cobra's output buffer. No subprocess or binary required. + +import ( + "bytes" + "encoding/json" + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +// cliTestdataDir returns the absolute path to the repo-level testdata directory. +func cliTestdataDir() string { + _, thisFile, _, _ := runtime.Caller(0) + abs, _ := filepath.Abs(filepath.Join(filepath.Dir(thisFile), "..", "..", "testdata")) + return abs +} + +// runCmd executes rootCmd with the given args and returns (stdout, stderr, error). +func runCmd(args ...string) (stdout, stderr string, err error) { + cmd := rootCmd() + var outBuf, errBuf bytes.Buffer + cmd.SetOut(&outBuf) + cmd.SetErr(&errBuf) + cmd.SetArgs(args) + err = cmd.Execute() + return outBuf.String(), errBuf.String(), err +} + +// ── Flag validation ─────────────────────────────────────────────────────────── + +func TestRootCmd_MissingInputReturnsError(t *testing.T) { + _, _, err := runCmd() + if err == nil { + t.Fatal("expected error when --input is missing, got nil") + } + if !strings.Contains(err.Error(), "required") { + t.Errorf("error should mention 'required'; got %q", err.Error()) + } +} + +func TestRootCmd_NonExistentInputReturnsError(t *testing.T) { + _, _, err := runCmd("--input", filepath.Join(t.TempDir(), "does_not_exist")) + if err == nil { + t.Fatal("expected error for non-existent --input path, got nil") + } +} + +func TestRootCmd_UnknownFormatReturnsError(t *testing.T) { + td := cliTestdataDir() + _, _, err := runCmd("--input", filepath.Join(td, "greeter"), "--format", "csv") + if err == nil { + t.Fatal("expected error for unknown --format value, got nil") + } +} + +// ── --version ──────────────────────────────────────────────────────────────── + +func TestRootCmd_VersionFlag(t *testing.T) { + out, _, err := runCmd("--version") + if err != nil { + t.Fatalf("--version returned unexpected error: %v", err) + } + if !strings.Contains(out, "codeanalyzer-go") { + t.Errorf("--version output should contain 'codeanalyzer-go'; got %q", out) + } + if !strings.Contains(out, version) { + t.Errorf("--version output should contain version %q; got %q", version, out) + } +} + +// ── --output writes analysis.json ──────────────────────────────────────────── + +func TestRootCmd_OutputDirWritesFile(t *testing.T) { + td := cliTestdataDir() + outDir := t.TempDir() + + _, _, err := runCmd( + "--input", filepath.Join(td, "greeter"), + "--output", outDir, + "--cache-dir", t.TempDir(), + ) + if err != nil { + t.Fatalf("command failed: %v", err) + } + + analysisPath := filepath.Join(outDir, "analysis.json") + if _, statErr := os.Stat(analysisPath); statErr != nil { + t.Fatalf("analysis.json not created in output dir: %v", statErr) + } +} + +func TestRootCmd_NoOutputWritesToStdout(t *testing.T) { + td := cliTestdataDir() + + out, _, err := runCmd( + "--input", filepath.Join(td, "greeter"), + "--cache-dir", t.TempDir(), + ) + if err != nil { + t.Fatalf("command failed: %v", err) + } + if out == "" { + t.Fatal("expected JSON on stdout when --output is omitted, got empty string") + } + var v interface{} + if jsonErr := json.Unmarshal([]byte(out), &v); jsonErr != nil { + t.Errorf("stdout is not valid JSON: %v\noutput: %s", jsonErr, out) + } +} + +// ── --analysis-level ───────────────────────────────────────────────────────── + +func TestRootCmd_Level1ProducesNoCallGraph(t *testing.T) { + td := cliTestdataDir() + + out, _, err := runCmd( + "--input", filepath.Join(td, "greeter"), + "--analysis-level", "1", + "--cache-dir", t.TempDir(), + ) + if err != nil { + t.Fatalf("command failed: %v", err) + } + + var result struct { + CallGraph []interface{} `json:"call_graph"` + } + if jsonErr := json.Unmarshal([]byte(out), &result); jsonErr != nil { + t.Fatalf("stdout is not valid JSON: %v", jsonErr) + } + if len(result.CallGraph) != 0 { + t.Errorf("level 1 should produce no call graph edges; got %d", len(result.CallGraph)) + } +} + +func TestRootCmd_Level2ProducesCallGraph(t *testing.T) { + td := cliTestdataDir() + + out, _, err := runCmd( + "--input", filepath.Join(td, "greeter"), + "--analysis-level", "2", + "--cache-dir", t.TempDir(), + ) + if err != nil { + t.Fatalf("command failed: %v", err) + } + + var result struct { + CallGraph []interface{} `json:"call_graph"` + } + if jsonErr := json.Unmarshal([]byte(out), &result); jsonErr != nil { + t.Fatalf("stdout is not valid JSON: %v", jsonErr) + } + if len(result.CallGraph) == 0 { + t.Error("level 2 should produce call graph edges; got none") + } +} + +// ── --skip-tests ───────────────────────────────────────────────────────────── + +func TestRootCmd_SkipTestsFalseIncludesTestFiles(t *testing.T) { + td := cliTestdataDir() + + out, _, err := runCmd( + "--input", filepath.Join(td, "multipackage"), + "--skip-tests=false", + "--cache-dir", t.TempDir(), + ) + if err != nil { + t.Fatalf("command failed: %v", err) + } + + if !strings.Contains(out, "server_test.go") { + t.Error("--skip-tests=false: server_test.go should appear in JSON output") + } +} + +// ── --target-files ──────────────────────────────────────────────────────────── + +func TestRootCmd_TargetFilesRestrictsOutput(t *testing.T) { + td := cliTestdataDir() + serverFile := filepath.Join(td, "multipackage", "server", "server.go") + + out, _, err := runCmd( + "--input", filepath.Join(td, "multipackage"), + "--target-files", serverFile, + "--cache-dir", t.TempDir(), + ) + if err != nil { + t.Fatalf("command failed: %v", err) + } + + if strings.Contains(out, `"worker/worker.go"`) { + t.Error("--target-files: worker/worker.go should not appear when only server is targeted") + } + if !strings.Contains(out, `"server/server.go"`) { + t.Error("--target-files: server/server.go should appear in output") + } +} diff --git a/internal/analysis/registry_test.go b/internal/analysis/registry_test.go new file mode 100644 index 0000000..68bd55d --- /dev/null +++ b/internal/analysis/registry_test.go @@ -0,0 +1,137 @@ +package analysis + +// Tests for orderPasses (unexported — must be in the same package). +// +// We use lightweight stub passes so these tests have no external dependencies +// and run without loading any Go source files. + +import ( + "strings" + "testing" + + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" +) + +// stubPass is a minimal AnalysisPass for testing orderPasses. +type stubPass struct { + name string + provides []string + requires []string +} + +func (s *stubPass) Name() string { return s.name } +func (s *stubPass) Provides() []string { return s.provides } +func (s *stubPass) Requires() []string { return s.requires } +func (s *stubPass) Run(_ *schema.GoApplication, _ AnalysisContext) (AnalysisResult, error) { + return AnalysisResult{}, nil +} + +func mkPass(name string, provides, requires []string) AnalysisPass { + return &stubPass{name: name, provides: provides, requires: requires} +} + +// ── orderPasses ─────────────────────────────────────────────────────────────── + +func TestOrderPasses_Empty(t *testing.T) { + ordered, err := orderPasses(nil) + if err != nil { + t.Fatalf("empty passes: unexpected error: %v", err) + } + if len(ordered) != 0 { + t.Errorf("got %d passes, want 0", len(ordered)) + } +} + +func TestOrderPasses_SingleNoDeps(t *testing.T) { + p := mkPass("solo", []string{"x"}, nil) + ordered, err := orderPasses([]AnalysisPass{p}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(ordered) != 1 || ordered[0].Name() != "solo" { + t.Errorf("expected [solo]; got %v", names(ordered)) + } +} + +// A → B: A provides "feat", B requires "feat". A must come before B. +func TestOrderPasses_LinearDependency(t *testing.T) { + a := mkPass("a", []string{"feat"}, nil) + b := mkPass("b", nil, []string{"feat"}) + // Deliver in reverse order to stress the sort. + ordered, err := orderPasses([]AnalysisPass{b, a}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(ordered) != 2 { + t.Fatalf("got %d passes, want 2", len(ordered)) + } + if ordered[0].Name() != "a" || ordered[1].Name() != "b" { + t.Errorf("wrong order: got %v, want [a b]", names(ordered)) + } +} + +// A → C ← B: two independent passes both provide something C needs. +func TestOrderPasses_DiamondDependency(t *testing.T) { + a := mkPass("a", []string{"x"}, nil) + b := mkPass("b", []string{"y"}, nil) + c := mkPass("c", nil, []string{"x", "y"}) + ordered, err := orderPasses([]AnalysisPass{c, b, a}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(ordered) != 3 { + t.Fatalf("got %d passes, want 3", len(ordered)) + } + // c must be last. + if ordered[len(ordered)-1].Name() != "c" { + t.Errorf("c should be last; got %v", names(ordered)) + } +} + +func TestOrderPasses_UnsatisfiedRequirement(t *testing.T) { + p := mkPass("needy", nil, []string{"missing-cap"}) + _, err := orderPasses([]AnalysisPass{p}) + if err == nil { + t.Fatal("expected error for unsatisfied requirement, got nil") + } + if !strings.Contains(err.Error(), "needy") { + t.Errorf("error should mention the blocked pass name; got %q", err.Error()) + } +} + +func TestOrderPasses_Cycle(t *testing.T) { + // A requires "b", B requires "a" — neither can run. + a := mkPass("a", []string{"a-cap"}, []string{"b-cap"}) + b := mkPass("b", []string{"b-cap"}, []string{"a-cap"}) + _, err := orderPasses([]AnalysisPass{a, b}) + if err == nil { + t.Fatal("expected error for cycle, got nil") + } +} + +// ── RunPipeline with empty registry ────────────────────────────────────────── + +func TestRunPipeline_EmptyRegistry(t *testing.T) { + // Save and restore to avoid affecting other tests. + old := registeredPasses + registeredPasses = nil + defer func() { registeredPasses = old }() + + app := &schema.GoApplication{ + Entrypoints: map[string][]schema.GoEntrypoint{}, + CallGraph: []schema.GoCallEdge{}, + } + if err := RunPipeline(app, AnalysisContext{}); err != nil { + t.Fatalf("RunPipeline with empty registry: %v", err) + } +} + +// ── helpers ─────────────────────────────────────────────────────────────────── + +func names(passes []AnalysisPass) []string { + out := make([]string, len(passes)) + for i, p := range passes { + out[i] = p.Name() + } + return out +} diff --git a/internal/core/analyzer_test.go b/internal/core/analyzer_test.go index 8d322ba..1e95011 100644 --- a/internal/core/analyzer_test.go +++ b/internal/core/analyzer_test.go @@ -4,7 +4,6 @@ import ( "encoding/json" "os" "path/filepath" - "runtime" "testing" "time" @@ -13,52 +12,23 @@ import ( "github.com/codellm-devkit/codeanalyzer-go/internal/schema" ) -// fixtureDir returns the absolute path to testdata/fixture. -func fixtureDir(t *testing.T) string { +// greeterDir returns the absolute path to testdata/greeter. +// Still used by the caching tests, which must run fresh analysis. +func greeterDir(t *testing.T) string { t.Helper() - _, file, _, ok := runtime.Caller(0) - if !ok { - t.Fatal("cannot determine source file path") - } - // internal/core/analyzer_test.go → ../.. → codeanalyzer-go root → testdata/fixture - root := filepath.Join(filepath.Dir(file), "..", "..") - abs, err := filepath.Abs(filepath.Join(root, "testdata", "fixture")) - if err != nil { - t.Fatalf("resolving fixture dir: %v", err) - } - return abs -} - -func runAnalysis(t *testing.T, level options.AnalysisLevel) *schema.GoApplication { - t.Helper() - dir := fixtureDir(t) - outDir := t.TempDir() - opts := options.AnalysisOptions{ - InputPath: dir, - OutputDir: outDir, - Level: level, - SkipTests: true, - CacheDir: t.TempDir(), - } - app, err := core.New(opts).Analyze() - if err != nil { - t.Fatalf("Analyze() failed: %v", err) - } - return app + return filepath.Join(testdataDir(), "greeter") } // ── Symbol table tests ──────────────────────────────────────────────────────── func TestSymbolTable_NonEmpty(t *testing.T) { - app := runAnalysis(t, options.LevelSymbolTable) - if len(app.SymbolTable) == 0 { + if len(sharedGreeterL1.SymbolTable) == 0 { t.Fatal("symbol table is empty") } } func TestSymbolTable_PathKeysAreRelative(t *testing.T) { - app := runAnalysis(t, options.LevelSymbolTable) - for key := range app.SymbolTable { + for key := range sharedGreeterL1.SymbolTable { if filepath.IsAbs(key) { t.Errorf("symbol_table key is absolute path: %s", key) } @@ -66,11 +36,10 @@ func TestSymbolTable_PathKeysAreRelative(t *testing.T) { } func TestSymbolTable_KnownType(t *testing.T) { - app := runAnalysis(t, options.LevelSymbolTable) const wantFile = "pkg/greeter/greeter.go" - f, ok := app.SymbolTable[wantFile] + f, ok := sharedGreeterL1.SymbolTable[wantFile] if !ok { - t.Fatalf("file %q not in symbol table; got keys: %v", wantFile, keys(app.SymbolTable)) + t.Fatalf("file %q not in symbol table; got keys: %v", wantFile, keys(sharedGreeterL1.SymbolTable)) } if _, ok := f.Types["Greeter"]; !ok { t.Errorf("GoType 'Greeter' not found in %s", wantFile) @@ -78,8 +47,7 @@ func TestSymbolTable_KnownType(t *testing.T) { } func TestSymbolTable_KnownInterface(t *testing.T) { - app := runAnalysis(t, options.LevelSymbolTable) - f := app.SymbolTable["pkg/greeter/greeter.go"] + f := sharedGreeterL1.SymbolTable["pkg/greeter/greeter.go"] gt, ok := f.Types["Logger"] if !ok { t.Fatal("GoType 'Logger' not found") @@ -90,8 +58,7 @@ func TestSymbolTable_KnownInterface(t *testing.T) { } func TestSymbolTable_StructFields(t *testing.T) { - app := runAnalysis(t, options.LevelSymbolTable) - f := app.SymbolTable["pkg/greeter/greeter.go"] + f := sharedGreeterL1.SymbolTable["pkg/greeter/greeter.go"] gt := f.Types["Greeter"] if len(gt.Fields) == 0 { t.Fatal("Greeter has no fields") @@ -105,8 +72,7 @@ func TestSymbolTable_StructFields(t *testing.T) { } func TestSymbolTable_CallSitesRecorded(t *testing.T) { - app := runAnalysis(t, options.LevelSymbolTable) - f := app.SymbolTable["main.go"] + f := sharedGreeterL1.SymbolTable["main.go"] var mainFn *schema.GoCallable for _, c := range f.Functions { c := c @@ -121,7 +87,6 @@ func TestSymbolTable_CallSitesRecorded(t *testing.T) { if len(mainFn.CallSites) == 0 { t.Error("main() has no recorded call sites") } - // All call sites must start with callee_signature == nil (pre-resolution). for _, cs := range mainFn.CallSites { if cs.CalleeSignature != nil { t.Errorf("call site %q has callee_signature pre-filled during symbol-table build", cs.MethodName) @@ -132,16 +97,14 @@ func TestSymbolTable_CallSitesRecorded(t *testing.T) { // ── Call graph tests ────────────────────────────────────────────────────────── func TestCallGraph_NonEmpty(t *testing.T) { - app := runAnalysis(t, options.LevelCallGraph) - if len(app.CallGraph) == 0 { + if len(sharedGreeterL2.CallGraph) == 0 { t.Fatal("call graph is empty") } } func TestCallGraph_NoDanglingEdges(t *testing.T) { - app := runAnalysis(t, options.LevelCallGraph) - sigs := allSignatures(app) - for _, e := range app.CallGraph { + sigs := allSignatures(sharedGreeterL2) + for _, e := range sharedGreeterL2.CallGraph { if !sigs[e.Source] { t.Errorf("dangling edge source: %s", e.Source) } @@ -152,8 +115,7 @@ func TestCallGraph_NoDanglingEdges(t *testing.T) { } func TestCallGraph_Provenance(t *testing.T) { - app := runAnalysis(t, options.LevelCallGraph) - for _, e := range app.CallGraph { + for _, e := range sharedGreeterL2.CallGraph { if len(e.Provenance) == 0 { t.Errorf("edge %s→%s has empty provenance", e.Source, e.Target) } @@ -161,11 +123,9 @@ func TestCallGraph_Provenance(t *testing.T) { } func TestCallGraph_CallSitesBackfilled(t *testing.T) { - app := runAnalysis(t, options.LevelCallGraph) - f := app.SymbolTable["main.go"] + f := sharedGreeterL2.SymbolTable["main.go"] for _, callable := range f.Functions { for _, cs := range callable.CallSites { - // Sites that resolved to a project-internal callee must be backfilled. if cs.CalleeSignature != nil && *cs.CalleeSignature == "" { t.Errorf("callable %s: call site %q has empty string callee_signature", callable.Signature, cs.MethodName) } @@ -176,9 +136,8 @@ func TestCallGraph_CallSitesBackfilled(t *testing.T) { // ── JSON output tests ───────────────────────────────────────────────────────── func TestWriteOutput_ValidJSON(t *testing.T) { - app := runAnalysis(t, options.LevelCallGraph) outDir := t.TempDir() - if err := core.WriteOutput(app, outDir, "json"); err != nil { + if err := core.WriteOutput(sharedGreeterL2, outDir, "json"); err != nil { t.Fatalf("WriteOutput: %v", err) } data, err := os.ReadFile(filepath.Join(outDir, "analysis.json")) @@ -195,9 +154,8 @@ func TestWriteOutput_ValidJSON(t *testing.T) { } func TestWriteOutput_EmptyFormatDefaultsToJSON(t *testing.T) { - app := runAnalysis(t, options.LevelSymbolTable) outDir := t.TempDir() - if err := core.WriteOutput(app, outDir, ""); err != nil { + if err := core.WriteOutput(sharedGreeterL1, outDir, ""); err != nil { t.Fatalf("WriteOutput with empty format: %v", err) } if _, err := os.Stat(filepath.Join(outDir, "analysis.json")); err != nil { @@ -206,42 +164,36 @@ func TestWriteOutput_EmptyFormatDefaultsToJSON(t *testing.T) { } func TestWriteOutput_MsgpackNotImplemented(t *testing.T) { - app := runAnalysis(t, options.LevelSymbolTable) outDir := t.TempDir() - err := core.WriteOutput(app, outDir, "msgpack") - if err == nil { + if err := core.WriteOutput(sharedGreeterL1, outDir, "msgpack"); err == nil { t.Fatal("expected error for --format msgpack, got nil") } } func TestWriteOutput_UnknownFormatErrors(t *testing.T) { - app := runAnalysis(t, options.LevelSymbolTable) outDir := t.TempDir() - err := core.WriteOutput(app, outDir, "csv") - if err == nil { + if err := core.WriteOutput(sharedGreeterL1, outDir, "csv"); err == nil { t.Fatal("expected error for unknown format, got nil") } } // ── Caching tests ───────────────────────────────────────────────────────────── +// These tests must run their own analysis to exercise the caching machinery. func TestCaching_SecondRunReuses(t *testing.T) { - dir := fixtureDir(t) + dir := greeterDir(t) cacheDir := t.TempDir() - outDir := t.TempDir() opts := options.AnalysisOptions{ InputPath: dir, - OutputDir: outDir, + OutputDir: t.TempDir(), Level: options.LevelCallGraph, SkipTests: true, CacheDir: cacheDir, } - // First run — populates cache. app1, err := core.New(opts).Analyze() if err != nil { t.Fatalf("first run: %v", err) } - // Second run — must not error and must return identical key count. app2, err := core.New(opts).Analyze() if err != nil { t.Fatalf("second run: %v", err) @@ -256,10 +208,9 @@ func TestCaching_SecondRunReuses(t *testing.T) { } func TestCaching_CacheFileWritten(t *testing.T) { - dir := fixtureDir(t) cacheDir := t.TempDir() opts := options.AnalysisOptions{ - InputPath: dir, + InputPath: greeterDir(t), Level: options.LevelSymbolTable, SkipTests: true, CacheDir: cacheDir, @@ -267,17 +218,15 @@ func TestCaching_CacheFileWritten(t *testing.T) { if _, err := core.New(opts).Analyze(); err != nil { t.Fatalf("Analyze: %v", err) } - cachePath := filepath.Join(cacheDir, "analysis_cache.json") - if _, err := os.Stat(cachePath); err != nil { + if _, err := os.Stat(filepath.Join(cacheDir, "analysis_cache.json")); err != nil { t.Fatalf("analysis_cache.json not written to CacheDir: %v", err) } } func TestCaching_CacheContentsRoundTrip(t *testing.T) { - dir := fixtureDir(t) cacheDir := t.TempDir() opts := options.AnalysisOptions{ - InputPath: dir, + InputPath: greeterDir(t), Level: options.LevelSymbolTable, SkipTests: true, CacheDir: cacheDir, @@ -301,15 +250,13 @@ func TestCaching_CacheContentsRoundTrip(t *testing.T) { } func TestCaching_EagerForcesRebuild(t *testing.T) { - dir := fixtureDir(t) cacheDir := t.TempDir() opts := options.AnalysisOptions{ - InputPath: dir, + InputPath: greeterDir(t), Level: options.LevelSymbolTable, SkipTests: true, CacheDir: cacheDir, } - // First run (non-eager) — seeds go_mod_hash. if _, err := core.New(opts).Analyze(); err != nil { t.Fatalf("first run: %v", err) } @@ -319,9 +266,12 @@ func TestCaching_EagerForcesRebuild(t *testing.T) { t.Fatalf("cache not written after first run: %v", err) } - time.Sleep(10 * time.Millisecond) + // Backdate the cache file so the mtime delta is unambiguous — no sleep needed. + past := info1.ModTime().Add(-time.Second) + if err := os.Chtimes(cachePath, past, past); err != nil { + t.Fatalf("backdating cache mtime: %v", err) + } - // Second run with Eager=true — must rewrite cache even when go_mod_hash matches. opts.Eager = true if _, err := core.New(opts).Analyze(); err != nil { t.Fatalf("eager run: %v", err) @@ -330,10 +280,9 @@ func TestCaching_EagerForcesRebuild(t *testing.T) { if err != nil { t.Fatalf("cache not found after eager run: %v", err) } - // saveCache always writes, so mtime must advance. - if !info2.ModTime().After(info1.ModTime()) { + if !info2.ModTime().After(past) { t.Errorf("analysis_cache.json mtime did not advance on eager=true run: %v vs %v", - info1.ModTime(), info2.ModTime()) + past, info2.ModTime()) } } diff --git a/internal/core/chi_test.go b/internal/core/chi_test.go new file mode 100644 index 0000000..48742fe --- /dev/null +++ b/internal/core/chi_test.go @@ -0,0 +1,196 @@ +package core_test + +// Tests for the chi fixture — chi v5 (github.com/go-chi/chi/v5) analyzed as the +// project under test, not as a dependency. +// +// Goals: +// 1. Multi-package library (root + middleware) is fully indexed. +// 2. Interface types (chi.Router) and struct types (chi.Mux) are both captured. +// 3. Methods on *Mux (Get, Post, Route, …) appear in the symbol table. +// 4. Vendor files are absent (chi has no external deps; nothing to exclude). +// 5. Call graph edges (Level 2) are internally consistent — no dangling endpoints. + +import ( + "strings" + "testing" +) + +// ── File coverage ───────────────────────────────────────────────────────────── + +func TestChi_SymbolTableNonEmpty(t *testing.T) { + if len(sharedChiL2.SymbolTable) == 0 { + t.Fatal("chi symbol table is empty — analysis may have failed silently") + } +} + +// chi v5 has exactly 35 non-test Go source files (5 root + 30 middleware). +func TestChi_SymbolTableFileCount(t *testing.T) { + const want = 35 + if got := len(sharedChiL2.SymbolTable); got != want { + t.Errorf("symbol table: got %d file(s), want %d; keys: %v", got, want, keys(sharedChiL2.SymbolTable)) + } +} + +func TestChi_PathKeysAreRelative(t *testing.T) { + for key := range sharedChiL2.SymbolTable { + if strings.HasPrefix(key, "/") { + t.Errorf("symbol_table key is absolute: %s", key) + } + } +} + +// ── Root package files ──────────────────────────────────────────────────────── + +func TestChi_RootFilesPresent(t *testing.T) { + for _, name := range []string{"chi.go", "mux.go", "context.go", "chain.go", "tree.go"} { + t.Run(name, func(t *testing.T) { + if _, ok := sharedChiL2.SymbolTable[name]; !ok { + t.Errorf("%s not in symbol table; keys: %v", name, keys(sharedChiL2.SymbolTable)) + } + }) + } +} + +// ── Middleware package files ────────────────────────────────────────────────── + +func TestChi_MiddlewareFilesPresent(t *testing.T) { + for _, name := range []string{ + "middleware/logger.go", + "middleware/recoverer.go", + "middleware/middleware.go", + } { + t.Run(name, func(t *testing.T) { + if _, ok := sharedChiL2.SymbolTable[name]; !ok { + t.Errorf("%s not in symbol table", name) + } + }) + } +} + +// ── Interface and struct types ──────────────────────────────────────────────── + +// chi.go declares the Router interface. +func TestChi_RouterIsInterface(t *testing.T) { + f, ok := sharedChiL2.SymbolTable["chi.go"] + if !ok { + t.Fatal("chi.go not in symbol table") + } + router, ok := f.Types["Router"] + if !ok { + t.Fatal("Router type not found in chi.go") + } + if !router.IsInterface { + t.Error("Router should be an interface, got is_interface=false") + } +} + +// mux.go declares the Mux struct (not an interface). +func TestChi_MuxIsStruct(t *testing.T) { + f, ok := sharedChiL2.SymbolTable["mux.go"] + if !ok { + t.Fatal("mux.go not in symbol table") + } + mux, ok := f.Types["Mux"] + if !ok { + t.Fatal("Mux type not found in mux.go") + } + if mux.IsInterface { + t.Error("Mux should be a struct, got is_interface=true") + } +} + +// ── Methods on *Mux ─────────────────────────────────────────────────────────── + +func TestChi_MuxHasRoutingMethods(t *testing.T) { + f, ok := sharedChiL2.SymbolTable["mux.go"] + if !ok { + t.Fatal("mux.go not in symbol table") + } + for _, method := range []string{"Get", "Post", "Put", "Delete", "Route", "Use", "With"} { + t.Run(method, func(t *testing.T) { + if findCallableByName(f, method) == nil { + t.Errorf("method %q not found on Mux in mux.go", method) + } + }) + } +} + +// ── Call graph: no dangling edges ───────────────────────────────────────────── + +func TestChi_NoDanglingEdges(t *testing.T) { + sigs := allSignatures(sharedChiL2) + for _, e := range sharedChiL2.CallGraph { + if !sigs[e.Source] { + t.Errorf("dangling edge source: %s", e.Source) + } + if !sigs[e.Target] { + t.Errorf("dangling edge target: %s", e.Target) + } + } +} + +// ── H1: InnerCallables populated for functions with closures ────────────────── + +// middleware/logger.go RequestLogger returns a closure-based middleware; its +// outer function body should have at least one inner callable after the fix. +func TestChi_RequestLoggerHasInnerCallables(t *testing.T) { + f, ok := sharedChiL2.SymbolTable["middleware/logger.go"] + if !ok { + t.Fatal("middleware/logger.go not in symbol table") + } + rl := findCallableByName(f, "RequestLogger") + if rl == nil { + t.Fatal("RequestLogger not found in middleware/logger.go") + } + if len(rl.InnerCallables) == 0 { + t.Error("RequestLogger should have at least one inner callable (closure), got none") + } +} + +// ── H2: IsConstructorCall for type-conversion call sites ────────────────────── + +// tree.go RegisterMethod contains `mt := methodTyp(2 << n)`. +// methodTyp is a named type, so the call is a type-conversion (constructor) call. +func TestChi_RegisterMethodHasConstructorCallSite(t *testing.T) { + f, ok := sharedChiL2.SymbolTable["tree.go"] + if !ok { + t.Fatal("tree.go not in symbol table") + } + rm := findCallableByName(f, "RegisterMethod") + if rm == nil { + t.Fatal("RegisterMethod not found in tree.go") + } + for _, site := range rm.CallSites { + if site.IsConstructorCall { + return // found + } + } + t.Error("RegisterMethod should have at least one IsConstructorCall=true site (methodTyp(...))") +} + +// ── H7: init() functions captured ──────────────────────────────────────────── + +// middleware/terminal.go, middleware/logger.go, and middleware/request_id.go +// each declare an init() function that should appear in the symbol table. +func TestChi_InitFunctionsPresent(t *testing.T) { + for _, file := range []string{ + "middleware/terminal.go", + "middleware/logger.go", + "middleware/request_id.go", + } { + t.Run(file, func(t *testing.T) { + f, ok := sharedChiL2.SymbolTable[file] + if !ok { + t.Fatalf("%s not in symbol table", file) + } + initFn := findCallableByName(f, "init") + if initFn == nil { + t.Errorf("init() not found in %s", file) + return + } + if initFn.IsExported { + t.Errorf("init() in %s should not be exported", file) + } + }) + } +} diff --git a/internal/core/errors_test.go b/internal/core/errors_test.go new file mode 100644 index 0000000..f66e24e --- /dev/null +++ b/internal/core/errors_test.go @@ -0,0 +1,76 @@ +package core_test + +// Error-path tests: verify that the analyzer returns meaningful errors (not +// panics or silent empty results) when given bad inputs. + +import ( + "os" + "path/filepath" + "testing" + + "github.com/codellm-devkit/codeanalyzer-go/internal/core" + "github.com/codellm-devkit/codeanalyzer-go/internal/options" +) + +func TestAnalyze_NonExistentPath(t *testing.T) { + opts := options.AnalysisOptions{ + InputPath: filepath.Join(t.TempDir(), "does_not_exist"), + OutputDir: t.TempDir(), + Level: options.LevelSymbolTable, + SkipTests: true, + } + _, err := core.New(opts).Analyze() + if err == nil { + t.Fatal("expected error for non-existent InputPath, got nil") + } +} + +func TestAnalyze_EmptyDirectory(t *testing.T) { + // A real directory with no Go files: analyzer should succeed but produce an + // empty symbol table (graceful degradation, not a hard error). + emptyDir := t.TempDir() + opts := options.AnalysisOptions{ + InputPath: emptyDir, + OutputDir: t.TempDir(), + Level: options.LevelSymbolTable, + SkipTests: true, + } + app, err := core.New(opts).Analyze() + if err != nil { + t.Logf("Analyze returned error (acceptable): %v", err) + return + } + if len(app.SymbolTable) != 0 { + t.Errorf("expected empty symbol table for empty directory; got %d entries", len(app.SymbolTable)) + } +} + +func TestAnalyze_MissingGoMod(t *testing.T) { + // A directory with a .go file but no go.mod — not a valid module. + // The analyzer either returns an error or an empty symbol table; both are acceptable. + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "hello.go"), []byte("package main\nfunc main(){}\n"), 0o644); err != nil { + t.Fatalf("writing hello.go: %v", err) + } + opts := options.AnalysisOptions{ + InputPath: dir, + OutputDir: t.TempDir(), + Level: options.LevelSymbolTable, + SkipTests: true, + } + app, err := core.New(opts).Analyze() + if err != nil { + t.Logf("Analyze returned error (acceptable): %v", err) + return + } + if len(app.SymbolTable) != 0 { + t.Errorf("expected empty symbol table for module with no go.mod; got %d entries", len(app.SymbolTable)) + } +} + +func TestAnalyze_LevelOneDoesNotProduceCallGraph(t *testing.T) { + // Level-1 analysis must never populate the call graph. + if len(sharedGreeterL1.CallGraph) != 0 { + t.Errorf("LevelSymbolTable produced %d call-graph edges; expected 0", len(sharedGreeterL1.CallGraph)) + } +} diff --git a/internal/core/generics_test.go b/internal/core/generics_test.go new file mode 100644 index 0000000..53e6788 --- /dev/null +++ b/internal/core/generics_test.go @@ -0,0 +1,127 @@ +package core_test + +// Tests for Go 1.18+ generic constructs — type parameters, union-constraint +// interfaces, multi-type-parameter functions, and methods on generic types. +// These exercise AST paths (IndexExpr receivers, TypeParams) that the greeter +// and realistic fixtures never reach. + +import ( + "strings" + "testing" +) + +// ── Symbol table completeness ───────────────────────────────────────────────── + +func TestGenerics_SymbolTableNonEmpty(t *testing.T) { + if len(sharedGenericsL1.SymbolTable) == 0 { + t.Fatal("generics symbol table is empty") + } +} + +func TestGenerics_PathKeysAreRelative(t *testing.T) { + for key := range sharedGenericsL1.SymbolTable { + if strings.HasPrefix(key, "/") { + t.Errorf("symbol_table key is absolute: %s", key) + } + } +} + +// ── Type name integrity ─────────────────────────────────────────────────────── + +func TestGenerics_SetTypePresentInSymbolTable(t *testing.T) { + const wantFile = "set/set.go" + f, ok := sharedGenericsL1.SymbolTable[wantFile] + if !ok { + t.Fatalf("%s not in symbol table; keys: %v", wantFile, keys(sharedGenericsL1.SymbolTable)) + } + if _, ok := f.Types["Set"]; !ok { + t.Errorf("GoType 'Set' not found in %s", wantFile) + } +} + +// The type name must be the base identifier only, not the parameterised form. +func TestGenerics_TypeNameHasNoTypeParams(t *testing.T) { + for _, f := range sharedGenericsL1.SymbolTable { + for name := range f.Types { + if strings.ContainsAny(name, "[]") { + t.Errorf("type name %q contains type-parameter brackets — should be stripped", name) + } + } + } +} + +// ── Methods on generic types ────────────────────────────────────────────────── + +func TestGenerics_SetMethods(t *testing.T) { + f := sharedGenericsL1.SymbolTable["set/set.go"] + for _, want := range []string{"Add", "Remove", "Contains", "Len"} { + t.Run(want, func(t *testing.T) { + if findCallableByName(f, want) == nil { + t.Errorf("method %q not found on Set", want) + } + }) + } +} + +func TestGenerics_UnexportedMethodOnGenericType(t *testing.T) { + f := sharedGenericsL1.SymbolTable["set/set.go"] + snapshot := findCallableByName(f, "snapshot") + if snapshot == nil { + t.Fatal("unexported method 'snapshot' not found on Set") + } + if snapshot.IsExported { + t.Error("snapshot.is_exported should be false") + } +} + +// ── Union-constraint interfaces ─────────────────────────────────────────────── + +func TestGenerics_OrderedIsInterface(t *testing.T) { + f, ok := sharedGenericsL1.SymbolTable["fn/fn.go"] + if !ok { + t.Fatal("fn/fn.go not in symbol table") + } + ordered, ok := f.Types["Ordered"] + if !ok { + t.Fatal("GoType 'Ordered' not found in fn/fn.go") + } + if !ordered.IsInterface { + t.Error("Ordered.is_interface should be true (union constraint)") + } +} + +func TestGenerics_NumericIsInterface(t *testing.T) { + f := sharedGenericsL1.SymbolTable["fn/fn.go"] + numeric, ok := f.Types["Numeric"] + if !ok { + t.Fatal("GoType 'Numeric' not found in fn/fn.go") + } + if !numeric.IsInterface { + t.Error("Numeric.is_interface should be true") + } +} + +// ── Generic functions ───────────────────────────────────────────────────────── + +func TestGenerics_SingleTypeParamFunctions(t *testing.T) { + f := sharedGenericsL1.SymbolTable["fn/fn.go"] + for _, name := range []string{"Min", "Max", "Filter"} { + t.Run(name, func(t *testing.T) { + if findCallableByName(f, name) == nil { + t.Errorf("generic function %q not found in fn/fn.go", name) + } + }) + } +} + +func TestGenerics_MapHasMultipleParams(t *testing.T) { + f := sharedGenericsL1.SymbolTable["fn/fn.go"] + mapFn := findCallableByName(f, "Map") + if mapFn == nil { + t.Fatal("generic function 'Map' not found in fn/fn.go") + } + // Map[T, U any](in []T, f func(T) U) []U — two declared parameters. + if len(mapFn.Parameters) < 2 { + t.Errorf("Map() should have >= 2 parameters; got %d", len(mapFn.Parameters)) + } +} diff --git a/internal/core/incremental_test.go b/internal/core/incremental_test.go new file mode 100644 index 0000000..7fc85e8 --- /dev/null +++ b/internal/core/incremental_test.go @@ -0,0 +1,120 @@ +package core_test + +// Tests for --target-files / TargetFiles incremental analysis mode. +// +// When TargetFiles is non-empty each value is passed as a "file=" +// pattern to packages.Load, which loads only the package(s) containing those +// files. Other packages in the project are not loaded and must not appear in +// the symbol table. + +import ( + "path/filepath" + "testing" + + "github.com/codellm-devkit/codeanalyzer-go/internal/core" + "github.com/codellm-devkit/codeanalyzer-go/internal/options" +) + +func multipackageDir() string { + return filepath.Join(testdataDir(), "multipackage") +} + +// TestTargetFiles_SinglePackage: targeting one file restricts the symbol table +// to the package containing that file. The multipackage fixture has three +// packages (main, server, worker); targeting server/server.go should exclude +// main.go and worker/worker.go. +func TestTargetFiles_SinglePackage(t *testing.T) { + td := multipackageDir() + serverFile := filepath.Join(td, "server", "server.go") + + app, err := core.New(options.AnalysisOptions{ + InputPath: td, + OutputDir: t.TempDir(), + Level: options.LevelSymbolTable, + SkipTests: true, + CacheDir: t.TempDir(), + TargetFiles: []string{serverFile}, + }).Analyze() + if err != nil { + t.Fatalf("Analyze: %v", err) + } + + if len(app.SymbolTable) == 0 { + t.Fatal("symbol table is empty — analysis may have failed silently") + } + if _, ok := app.SymbolTable["server/server.go"]; !ok { + t.Errorf("server/server.go should be in symbol table; got keys: %v", keys(app.SymbolTable)) + } + if _, ok := app.SymbolTable["main.go"]; ok { + t.Error("main.go must not be in symbol table when only server package is targeted") + } + if _, ok := app.SymbolTable["worker/worker.go"]; ok { + t.Error("worker/worker.go must not be in symbol table when only server package is targeted") + } +} + +// TestTargetFiles_MultiplePackages: targeting files in two separate packages +// includes both packages but still excludes the third. +func TestTargetFiles_MultiplePackages(t *testing.T) { + td := multipackageDir() + + app, err := core.New(options.AnalysisOptions{ + InputPath: td, + OutputDir: t.TempDir(), + Level: options.LevelSymbolTable, + SkipTests: true, + CacheDir: t.TempDir(), + TargetFiles: []string{ + filepath.Join(td, "server", "server.go"), + filepath.Join(td, "worker", "worker.go"), + }, + }).Analyze() + if err != nil { + t.Fatalf("Analyze: %v", err) + } + + for _, want := range []string{"server/server.go", "worker/worker.go"} { + if _, ok := app.SymbolTable[want]; !ok { + t.Errorf("%s should be in symbol table; got keys: %v", want, keys(app.SymbolTable)) + } + } + if _, ok := app.SymbolTable["main.go"]; ok { + t.Error("main.go must not be in symbol table when not targeted") + } +} + +// TestTargetFiles_NilMeansAllFiles: nil TargetFiles produces a full analysis, +// matching the file count of the pre-computed sharedMultipackageL1. +func TestTargetFiles_NilMeansAllFiles(t *testing.T) { + const want = 4 // main.go + server/server.go + server/middleware.go + worker/worker.go + if got := len(sharedMultipackageL1.SymbolTable); got != want { + t.Errorf("multipackage fixture with nil TargetFiles: got %d files, want %d; keys: %v", + got, want, keys(sharedMultipackageL1.SymbolTable)) + } +} + +// TestTargetFiles_SiblingFilesIncluded: when a package has multiple source files +// (server.go + middleware.go), targeting any one file loads the entire package, +// so sibling files are also present in the symbol table. +func TestTargetFiles_SiblingFilesIncluded(t *testing.T) { + td := multipackageDir() + serverFile := filepath.Join(td, "server", "server.go") + + app, err := core.New(options.AnalysisOptions{ + InputPath: td, + OutputDir: t.TempDir(), + Level: options.LevelSymbolTable, + SkipTests: true, + CacheDir: t.TempDir(), + TargetFiles: []string{serverFile}, + }).Analyze() + if err != nil { + t.Fatalf("Analyze: %v", err) + } + + // middleware.go is in the same package as server.go; it must appear too. + if _, ok := app.SymbolTable["server/middleware.go"]; !ok { + t.Errorf("server/middleware.go (sibling file) should be in symbol table; got keys: %v", + keys(app.SymbolTable)) + } +} diff --git a/internal/core/realistic_test.go b/internal/core/multipackage_test.go similarity index 69% rename from internal/core/realistic_test.go rename to internal/core/multipackage_test.go index 3f651f4..b1070bc 100644 --- a/internal/core/realistic_test.go +++ b/internal/core/multipackage_test.go @@ -5,48 +5,12 @@ package core_test // is_variadic, is_embedded, multi-file package, cyclomatic_complexity, specific edges. import ( - "path/filepath" - "runtime" "strings" "testing" - "github.com/codellm-devkit/codeanalyzer-go/internal/core" - "github.com/codellm-devkit/codeanalyzer-go/internal/options" "github.com/codellm-devkit/codeanalyzer-go/internal/schema" ) -func realisticDir(t *testing.T) string { - t.Helper() - _, file, _, ok := runtime.Caller(0) - if !ok { - t.Fatal("cannot determine source file path") - } - root := filepath.Join(filepath.Dir(file), "..", "..") - abs, err := filepath.Abs(filepath.Join(root, "testdata", "realistic")) - if err != nil { - t.Fatalf("resolving realistic fixture dir: %v", err) - } - return abs -} - -func runRealistic(t *testing.T, level options.AnalysisLevel) *schema.GoApplication { - t.Helper() - dir := realisticDir(t) - outDir := t.TempDir() - opts := options.AnalysisOptions{ - InputPath: dir, - OutputDir: outDir, - Level: level, - SkipTests: true, - CacheDir: t.TempDir(), - } - app, err := core.New(opts).Analyze() - if err != nil { - t.Fatalf("Analyze() failed: %v", err) - } - return app -} - // findCallableByName searches all functions and methods in a GoFile by short name. func findCallableByName(f schema.GoFile, name string) *schema.GoCallable { for _, c := range f.Functions { @@ -69,17 +33,15 @@ func findCallableByName(f schema.GoFile, name string) *schema.GoCallable { // ── Multi-file package ──────────────────────────────────────────────────────── func TestRealistic_MultiFilePkg(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - _, hasServer := app.SymbolTable["server/server.go"] - _, hasMiddleware := app.SymbolTable["server/middleware.go"] + _, hasServer := sharedMultipackageL1.SymbolTable["server/server.go"] + _, hasMiddleware := sharedMultipackageL1.SymbolTable["server/middleware.go"] if !hasServer { t.Error("server/server.go missing from symbol table") } if !hasMiddleware { t.Error("server/middleware.go missing from symbol table") } - // Tags must live in middleware.go, not server.go. - mw := app.SymbolTable["server/middleware.go"] + mw := sharedMultipackageL1.SymbolTable["server/middleware.go"] if findCallableByName(mw, "Tags") == nil { t.Error("Tags function not found in server/middleware.go") } @@ -88,15 +50,14 @@ func TestRealistic_MultiFilePkg(t *testing.T) { // ── Embedded struct field ───────────────────────────────────────────────────── func TestRealistic_EmbeddedField(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - srv := app.SymbolTable["server/server.go"] + srv := sharedMultipackageL1.SymbolTable["server/server.go"] server, ok := srv.Types["Server"] if !ok { t.Fatal("GoType 'Server' not found in server/server.go") } for _, f := range server.Fields { if f.IsEmbedded { - return // pass + return } } t.Errorf("Server has no embedded field; fields: %+v", server.Fields) @@ -105,8 +66,7 @@ func TestRealistic_EmbeddedField(t *testing.T) { // ── Multiple return types — (T, error) pattern ──────────────────────────────── func TestRealistic_MultipleReturnTypes(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - srv := app.SymbolTable["server/server.go"] + srv := sharedMultipackageL1.SymbolTable["server/server.go"] newFn := findCallableByName(srv, "New") if newFn == nil { t.Fatal("function 'New' not found in server/server.go") @@ -126,8 +86,7 @@ func TestRealistic_MultipleReturnTypes(t *testing.T) { } func TestRealistic_ValidateReturnTypes(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - srv := app.SymbolTable["server/server.go"] + srv := sharedMultipackageL1.SymbolTable["server/server.go"] validate := findCallableByName(srv, "Validate") if validate == nil { t.Fatal("method 'Validate' not found in server/server.go") @@ -140,8 +99,7 @@ func TestRealistic_ValidateReturnTypes(t *testing.T) { // ── Unexported callables ────────────────────────────────────────────────────── func TestRealistic_UnexportedMethod(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - srv := app.SymbolTable["server/server.go"] + srv := sharedMultipackageL1.SymbolTable["server/server.go"] shutdown := findCallableByName(srv, "shutdown") if shutdown == nil { t.Fatal("method 'shutdown' not found in server/server.go") @@ -152,8 +110,7 @@ func TestRealistic_UnexportedMethod(t *testing.T) { } func TestRealistic_UnexportedWorkerMethod(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - wkr := app.SymbolTable["worker/worker.go"] + wkr := sharedMultipackageL1.SymbolTable["worker/worker.go"] execute := findCallableByName(wkr, "execute") if execute == nil { t.Fatal("method 'execute' not found in worker/worker.go") @@ -166,8 +123,7 @@ func TestRealistic_UnexportedWorkerMethod(t *testing.T) { // ── Receiver type / name ────────────────────────────────────────────────────── func TestRealistic_ReceiverType(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - srv := app.SymbolTable["server/server.go"] + srv := sharedMultipackageL1.SymbolTable["server/server.go"] addr := findCallableByName(srv, "Addr") if addr == nil { t.Fatal("method 'Addr' not found in server/server.go") @@ -178,26 +134,20 @@ func TestRealistic_ReceiverType(t *testing.T) { if addr.ReceiverName == "" { t.Error("Addr().receiver_name should be non-empty") } - // Pointer receiver — type should contain '*' or 'Server'. if !strings.Contains(addr.ReceiverType, "Server") { t.Errorf("Addr().receiver_type %q should reference Server", addr.ReceiverType) } } func TestRealistic_ValueReceiver(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - // Describe is defined in middleware.go but its receiver type (Server) lives in - // server.go — the reconcileCrossFileMethods pass attaches it to server.go's type. - srv := app.SymbolTable["server/server.go"] + srv := sharedMultipackageL1.SymbolTable["server/server.go"] describe := findCallableByName(srv, "Describe") if describe == nil { t.Fatal("method 'Describe' not found attached to Server in server/server.go") } - // Value receiver — ReceiverType should not contain '*'. if strings.Contains(describe.ReceiverType, "*") { t.Errorf("Describe().receiver_type %q should be a value receiver (no '*')", describe.ReceiverType) } - // Path should still record the physical definition file. if !strings.Contains(describe.Path, "middleware.go") { t.Errorf("Describe().path %q should point to middleware.go", describe.Path) } @@ -206,30 +156,28 @@ func TestRealistic_ValueReceiver(t *testing.T) { // ── Variadic parameters ─────────────────────────────────────────────────────── func TestRealistic_VariadicParamTags(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - mw := app.SymbolTable["server/middleware.go"] + mw := sharedMultipackageL1.SymbolTable["server/middleware.go"] tags := findCallableByName(mw, "Tags") if tags == nil { t.Fatal("function 'Tags' not found in server/middleware.go") } for _, p := range tags.Parameters { if p.IsVariadic { - return // pass + return } } t.Errorf("Tags() has no variadic parameter; params: %+v", tags.Parameters) } func TestRealistic_VariadicParamCombine(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - wkr := app.SymbolTable["worker/worker.go"] + wkr := sharedMultipackageL1.SymbolTable["worker/worker.go"] combine := findCallableByName(wkr, "Combine") if combine == nil { t.Fatal("function 'Combine' not found in worker/worker.go") } for _, p := range combine.Parameters { if p.IsVariadic { - return // pass + return } } t.Errorf("Combine() has no variadic parameter; params: %+v", combine.Parameters) @@ -238,15 +186,14 @@ func TestRealistic_VariadicParamCombine(t *testing.T) { // ── Goroutine call site ─────────────────────────────────────────────────────── func TestRealistic_GoroutineCallsite(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - wkr := app.SymbolTable["worker/worker.go"] + wkr := sharedMultipackageL1.SymbolTable["worker/worker.go"] run := findCallableByName(wkr, "Run") if run == nil { t.Fatal("method 'Run' not found in worker/worker.go") } for _, cs := range run.CallSites { if cs.IsGoroutine { - return // pass + return } } t.Errorf("Run() has no goroutine call site; sites: %+v", run.CallSites) @@ -255,13 +202,11 @@ func TestRealistic_GoroutineCallsite(t *testing.T) { // ── Cyclomatic complexity ───────────────────────────────────────────────────── func TestRealistic_CyclomaticComplexity(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - wkr := app.SymbolTable["worker/worker.go"] + wkr := sharedMultipackageL1.SymbolTable["worker/worker.go"] execute := findCallableByName(wkr, "execute") if execute == nil { t.Fatal("method 'execute' not found in worker/worker.go") } - // execute() has an `if err != nil` branch → CC >= 2. if execute.CyclomaticComplexity < 2 { t.Errorf("execute().cyclomatic_complexity should be >= 2; got %d", execute.CyclomaticComplexity) } @@ -270,8 +215,7 @@ func TestRealistic_CyclomaticComplexity(t *testing.T) { // ── Interface detection ─────────────────────────────────────────────────────── func TestRealistic_InterfaceType(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - wkr := app.SymbolTable["worker/worker.go"] + wkr := sharedMultipackageL1.SymbolTable["worker/worker.go"] proc, ok := wkr.Types["Processor"] if !ok { t.Fatal("GoType 'Processor' not found in worker/worker.go") @@ -281,29 +225,25 @@ func TestRealistic_InterfaceType(t *testing.T) { } } -// ── Specific call-graph edge ────────────────────────────────────────────────── +// ── Specific call-graph edges ───────────────────────────────────────────────── func TestRealistic_SpecificCallEdge(t *testing.T) { - app := runRealistic(t, options.LevelCallGraph) - // main() calls server.New() — this is a cross-package project-internal edge. - const wantTarget = "example.com/realistic/server.New" - for _, e := range app.CallGraph { + const wantTarget = "example.com/multipackage/server.New" + for _, e := range sharedMultipackageL2.CallGraph { if e.Target == wantTarget { - return // pass + return } } - t.Errorf("call graph missing expected edge to %s; edges: %v", wantTarget, edgeTargets(app)) + t.Errorf("call graph missing expected edge to %s; edges: %v", wantTarget, edgeTargets(sharedMultipackageL2)) } func TestRealistic_CrossPackageEdges(t *testing.T) { - app := runRealistic(t, options.LevelCallGraph) - // At least one edge must cross the main→server boundary and one main→worker boundary. var serverEdge, workerEdge bool - for _, e := range app.CallGraph { - if strings.Contains(e.Target, "realistic/server.") { + for _, e := range sharedMultipackageL2.CallGraph { + if strings.Contains(e.Target, "multipackage/server.") { serverEdge = true } - if strings.Contains(e.Target, "realistic/worker.") { + if strings.Contains(e.Target, "multipackage/worker.") { workerEdge = true } } @@ -315,6 +255,41 @@ func TestRealistic_CrossPackageEdges(t *testing.T) { } } +// ── H6: LocalVariables assertions ──────────────────────────────────────────── + +// worker.Combine has `out := Result{}` — a local variable with a known type. +func TestRealistic_LocalVariablesPresent(t *testing.T) { + wkr := sharedMultipackageL1.SymbolTable["worker/worker.go"] + combine := findCallableByName(wkr, "Combine") + if combine == nil { + t.Fatal("function 'Combine' not found in worker/worker.go") + } + if len(combine.LocalVariables) == 0 { + t.Fatal("Combine() should have at least one local variable; got none") + } +} + +// worker.execute has `r, err := p.Process(t)` — two local variables. +func TestRealistic_LocalVariablesHaveType(t *testing.T) { + wkr := sharedMultipackageL1.SymbolTable["worker/worker.go"] + execute := findCallableByName(wkr, "execute") + if execute == nil { + t.Fatal("method 'execute' not found in worker/worker.go") + } + for _, v := range execute.LocalVariables { + if v.Name == "err" { + if v.Type == "" { + t.Error("local variable 'err' should have a non-empty type") + } + if v.Scope != "function" { + t.Errorf("local variable 'err' scope should be 'function'; got %q", v.Scope) + } + return + } + } + t.Errorf("local variable 'err' not found in execute(); vars: %+v", execute.LocalVariables) +} + // ── Helpers ─────────────────────────────────────────────────────────────────── func edgeTargets(app *schema.GoApplication) []string { diff --git a/internal/core/skip_tests_test.go b/internal/core/skip_tests_test.go new file mode 100644 index 0000000..227039e --- /dev/null +++ b/internal/core/skip_tests_test.go @@ -0,0 +1,70 @@ +package core_test + +// Tests for the SkipTests option. +// +// testdata/multipackage/server/server_test.go is a minimal test file whose sole +// purpose is to give these tests something to look for. It is never included +// in the shared fixtures (all use SkipTests: true). + +import ( + "path/filepath" + "testing" + + "github.com/codellm-devkit/codeanalyzer-go/internal/core" + "github.com/codellm-devkit/codeanalyzer-go/internal/options" +) + +const serverTestFile = "server/server_test.go" + +// TestSkipTests_TrueExcludesTestFiles: the default (SkipTests=true) must not +// include any *_test.go file in the symbol table. +func TestSkipTests_TrueExcludesTestFiles(t *testing.T) { + // sharedMultipackageL1 is built with SkipTests: true — re-use it. + for key := range sharedMultipackageL1.SymbolTable { + if len(key) >= 8 && key[len(key)-8:] == "_test.go" { + t.Errorf("SkipTests=true: found test file in symbol table: %s", key) + } + } +} + +// TestSkipTests_FalseIncludesTestFiles: with SkipTests=false the analyzer must +// include *_test.go files in the symbol table. +func TestSkipTests_FalseIncludesTestFiles(t *testing.T) { + app, err := core.New(options.AnalysisOptions{ + InputPath: filepath.Join(testdataDir(), "multipackage"), + OutputDir: t.TempDir(), + Level: options.LevelSymbolTable, + SkipTests: false, + CacheDir: t.TempDir(), + }).Analyze() + if err != nil { + t.Fatalf("Analyze: %v", err) + } + + if _, ok := app.SymbolTable[serverTestFile]; !ok { + t.Errorf("SkipTests=false: %s not in symbol table; got keys: %v", + serverTestFile, keys(app.SymbolTable)) + } +} + +// TestSkipTests_FalseIncreasesFileCount: the symbol table with SkipTests=false +// must have more files than the same analysis with SkipTests=true. +func TestSkipTests_FalseIncreasesFileCount(t *testing.T) { + withSkip := len(sharedMultipackageL1.SymbolTable) + + app, err := core.New(options.AnalysisOptions{ + InputPath: filepath.Join(testdataDir(), "multipackage"), + OutputDir: t.TempDir(), + Level: options.LevelSymbolTable, + SkipTests: false, + CacheDir: t.TempDir(), + }).Analyze() + if err != nil { + t.Fatalf("Analyze: %v", err) + } + + if len(app.SymbolTable) <= withSkip { + t.Errorf("SkipTests=false: expected more files than %d (SkipTests=true count); got %d", + withSkip, len(app.SymbolTable)) + } +} diff --git a/internal/core/testsetup_test.go b/internal/core/testsetup_test.go new file mode 100644 index 0000000..85f3b72 --- /dev/null +++ b/internal/core/testsetup_test.go @@ -0,0 +1,99 @@ +package core_test + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/codellm-devkit/codeanalyzer-go/internal/core" + "github.com/codellm-devkit/codeanalyzer-go/internal/options" + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" +) + +// Shared analysis results, populated once in TestMain and reused across all tests. +// Caching tests are excluded: they must exercise the caching machinery themselves. +var ( + sharedGreeterL1 *schema.GoApplication // greeter, symbol-table only + sharedGreeterL2 *schema.GoApplication // greeter, full call-graph + sharedMultipackageL1 *schema.GoApplication // multipackage, symbol-table only + sharedMultipackageL2 *schema.GoApplication // multipackage, full call-graph + sharedGenericsL1 *schema.GoApplication // generics, symbol-table only + sharedChiL2 *schema.GoApplication // chi (external dep), full call-graph +) + +func TestMain(m *testing.M) { + os.Exit(runTestMain(m)) +} + +// runTestMain wraps m.Run so that deferred cleanup runs before os.Exit. +func runTestMain(m *testing.M) int { + tdRoot := testdataDir() + + tmpRoot, err := os.MkdirTemp("", "codeanalyzer-test-*") + if err != nil { + fmt.Fprintf(os.Stderr, "testsetup: MkdirTemp: %v\n", err) + return 1 + } + defer os.RemoveAll(tmpRoot) + + type fixture struct { + name string + path string + level options.AnalysisLevel + dst **schema.GoApplication + } + + for _, f := range []fixture{ + {"greeter/L1", filepath.Join(tdRoot, "greeter"), options.LevelSymbolTable, &sharedGreeterL1}, + {"greeter/L2", filepath.Join(tdRoot, "greeter"), options.LevelCallGraph, &sharedGreeterL2}, + {"multipackage/L1", filepath.Join(tdRoot, "multipackage"), options.LevelSymbolTable, &sharedMultipackageL1}, + {"multipackage/L2", filepath.Join(tdRoot, "multipackage"), options.LevelCallGraph, &sharedMultipackageL2}, + {"generics/L1", filepath.Join(tdRoot, "generics"), options.LevelSymbolTable, &sharedGenericsL1}, + {"chi/L2", filepath.Join(tdRoot, "chi"), options.LevelCallGraph, &sharedChiL2}, + } { + outDir, err := os.MkdirTemp(tmpRoot, "out-*") + if err != nil { + fmt.Fprintf(os.Stderr, "testsetup %s: MkdirTemp out: %v\n", f.name, err) + return 1 + } + cacheDir, err := os.MkdirTemp(tmpRoot, "cache-*") + if err != nil { + fmt.Fprintf(os.Stderr, "testsetup %s: MkdirTemp cache: %v\n", f.name, err) + return 1 + } + opts := options.AnalysisOptions{ + InputPath: f.path, + OutputDir: outDir, + Level: f.level, + SkipTests: true, + CacheDir: cacheDir, + } + app, err := core.New(opts).Analyze() + if err != nil { + fmt.Fprintf(os.Stderr, "testsetup %s: Analyze: %v\n", f.name, err) + return 1 + } + if err := core.WriteOutput(app, outDir, "json"); err != nil { + fmt.Fprintf(os.Stderr, "testsetup %s: WriteOutput: %v\n", f.name, err) + return 1 + } + if _, err := os.Stat(filepath.Join(outDir, "analysis.json")); err != nil { + fmt.Fprintf(os.Stderr, "testsetup %s: analysis.json not created: %v\n", f.name, err) + return 1 + } + *f.dst = app + } + + return m.Run() +} + +// testdataDir returns the absolute path to the testdata directory. +// Uses runtime.Caller so it resolves correctly regardless of working directory. +func testdataDir() string { + _, thisFile, _, _ := runtime.Caller(0) + abs, _ := filepath.Abs(filepath.Join(filepath.Dir(thisFile), "..", "..", "testdata")) + return abs +} + diff --git a/internal/semantic_analysis/call_graph_test.go b/internal/semantic_analysis/call_graph_test.go new file mode 100644 index 0000000..60b89ee --- /dev/null +++ b/internal/semantic_analysis/call_graph_test.go @@ -0,0 +1,116 @@ +package semantic_analysis_test + +import ( + "testing" + + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" + "github.com/codellm-devkit/codeanalyzer-go/internal/semantic_analysis" +) + +func edge(src, tgt string, weight int, prov ...string) schema.GoCallEdge { + return schema.GoCallEdge{Source: src, Target: tgt, Weight: weight, Provenance: prov} +} + +// ── MergeEdges ──────────────────────────────────────────────────────────────── + +func TestMergeEdges_EmptyBoth(t *testing.T) { + result := semantic_analysis.MergeEdges(nil, nil) + if len(result) != 0 { + t.Errorf("got %d edges, want 0", len(result)) + } +} + +func TestMergeEdges_PrimaryOnly(t *testing.T) { + primary := []schema.GoCallEdge{edge("a", "b", 1.0, "resolver")} + result := semantic_analysis.MergeEdges(primary, nil) + if len(result) != 1 { + t.Fatalf("got %d edges, want 1", len(result)) + } + if result[0].Source != "a" || result[0].Target != "b" { + t.Errorf("unexpected edge: %+v", result[0]) + } +} + +func TestMergeEdges_SecondaryOnly(t *testing.T) { + secondary := []schema.GoCallEdge{edge("x", "y", 2.0, "codeql")} + result := semantic_analysis.MergeEdges(nil, secondary) + if len(result) != 1 { + t.Fatalf("got %d edges, want 1", len(result)) + } + if result[0].Source != "x" || result[0].Target != "y" { + t.Errorf("unexpected edge: %+v", result[0]) + } +} + +func TestMergeEdges_DisjointEdges(t *testing.T) { + primary := []schema.GoCallEdge{edge("a", "b", 1.0, "resolver")} + secondary := []schema.GoCallEdge{edge("c", "d", 1.0, "codeql")} + result := semantic_analysis.MergeEdges(primary, secondary) + if len(result) != 2 { + t.Errorf("got %d edges, want 2", len(result)) + } +} + +func TestMergeEdges_DuplicateAccumulatesWeight(t *testing.T) { + primary := []schema.GoCallEdge{edge("a", "b", 3, "resolver")} + secondary := []schema.GoCallEdge{edge("a", "b", 5, "codeql")} + result := semantic_analysis.MergeEdges(primary, secondary) + if len(result) != 1 { + t.Fatalf("duplicate (a→b) should collapse to 1 edge; got %d", len(result)) + } + if result[0].Weight != 8 { + t.Errorf("weight: got %v, want 8", result[0].Weight) + } +} + +func TestMergeEdges_DuplicateUnionsProvenance(t *testing.T) { + primary := []schema.GoCallEdge{edge("a", "b", 1.0, "resolver")} + secondary := []schema.GoCallEdge{edge("a", "b", 1.0, "codeql")} + result := semantic_analysis.MergeEdges(primary, secondary) + if len(result) != 1 { + t.Fatalf("got %d edges, want 1", len(result)) + } + provSet := map[string]bool{} + for _, p := range result[0].Provenance { + provSet[p] = true + } + if !provSet["resolver"] || !provSet["codeql"] { + t.Errorf("provenance union failed; got %v", result[0].Provenance) + } +} + +func TestMergeEdges_DuplicateProvenanceNotDuplicated(t *testing.T) { + primary := []schema.GoCallEdge{edge("a", "b", 1.0, "resolver")} + secondary := []schema.GoCallEdge{edge("a", "b", 1.0, "resolver")} + result := semantic_analysis.MergeEdges(primary, secondary) + if len(result) != 1 { + t.Fatalf("got %d edges, want 1", len(result)) + } + count := 0 + for _, p := range result[0].Provenance { + if p == "resolver" { + count++ + } + } + if count != 1 { + t.Errorf("duplicate provenance should appear once; got %d times", count) + } +} + +func TestMergeEdges_OrderPreserved(t *testing.T) { + primary := []schema.GoCallEdge{ + edge("a", "b", 1.0), + edge("c", "d", 1.0), + } + secondary := []schema.GoCallEdge{ + edge("e", "f", 1.0), + } + result := semantic_analysis.MergeEdges(primary, secondary) + if len(result) != 3 { + t.Fatalf("got %d edges, want 3", len(result)) + } + // Primary edges come first, then secondary. + if result[0].Source != "a" || result[1].Source != "c" || result[2].Source != "e" { + t.Errorf("order not preserved: %v", result) + } +} diff --git a/internal/syntactic_analysis/symbol_table.go b/internal/syntactic_analysis/symbol_table.go index f6a7838..c7f48ce 100644 --- a/internal/syntactic_analysis/symbol_table.go +++ b/internal/syntactic_analysis/symbol_table.go @@ -50,6 +50,9 @@ func (b *SymbolTableBuilder) Build(targetFiles []string, skipTests bool) (map[st packages.NeedDeps, Dir: b.projectDir, Fset: b.fset, + // Include test packages when the caller wants test files. Without this, + // go/packages never presents *_test.go files to the loader. + Tests: !skipTests, // Silence go vet; we only need type info, not a full build. BuildFlags: []string{}, } @@ -90,6 +93,11 @@ func (b *SymbolTableBuilder) Build(targetFiles []string, skipTests bool) (map[st } filePath := pkg.GoFiles[i] relPath := utils.RelativePath(b.projectDir, filePath) + // Paths that escape the project root (generated test runners in the + // Go build cache, stdlib, etc.) are never project files. + if strings.HasPrefix(relPath, "..") { + continue + } if skipTests && utils.IsTestFile(relPath) { continue } @@ -129,6 +137,9 @@ func (b *SymbolTableBuilder) reconcileCrossFileMethods(symbolTable map[string]sc } filePath := pkg.GoFiles[i] relPath := utils.RelativePath(b.projectDir, filePath) + if strings.HasPrefix(relPath, "..") { + continue + } for _, decl := range astFile.Decls { fd, ok := decl.(*ast.FuncDecl) @@ -542,6 +553,7 @@ func (b *SymbolTableBuilder) buildCallable( callable.Code = b.nodeSource(decl) callable.CallSites = b.buildCallSites(pkg, decl.Body) callable.LocalVariables = b.buildLocalVars(pkg, decl.Body) + callable.InnerCallables = b.buildInnerCallables(pkg, sig, decl.Body) } return callable @@ -634,6 +646,9 @@ func (b *SymbolTableBuilder) buildCallSites(pkg *packages.Package, body *ast.Blo sites = append(sites, *site) } return false + case *ast.FuncLit: + // Closure bodies are handled by buildInnerCallables; don't double-count. + return false case *ast.CallExpr: site := b.callExprToSite(pkg, node, false) if site != nil { @@ -762,6 +777,49 @@ func (b *SymbolTableBuilder) buildLocalVars(pkg *packages.Package, body *ast.Blo return vars } +// buildInnerCallables walks body and collects each FuncLit as a named closure. +// Only the top level is captured; nested closures appear in the closure's own +// InnerCallables (populated when buildCallSites recurses into lit.Body). +func (b *SymbolTableBuilder) buildInnerCallables(pkg *packages.Package, outerSig string, body *ast.BlockStmt) map[string]schema.GoCallable { + inner := map[string]schema.GoCallable{} + n := 0 + ast.Inspect(body, func(node ast.Node) bool { + lit, ok := node.(*ast.FuncLit) + if !ok { + return true + } + n++ + name := fmt.Sprintf("closure_%d", n) + sig := outerSig + "." + name + pos := b.fset.Position(lit.Pos()) + endPos := b.fset.Position(lit.End()) + _, retTypes := b.buildReturnTypes(pkg, lit.Type) + retType := b.joinReturnTypes(retTypes) + + ic := schema.GoCallable{ + Name: name, + Signature: sig, + Parameters: b.buildParams(pkg, lit.Type), + ReturnType: retType, + ReturnTypes: retTypes, + CallSites: []schema.GoCallsite{}, + InnerCallables: map[string]schema.GoCallable{}, + LocalVariables: []schema.GoVariableDeclaration{}, + StartLine: pos.Line, + EndLine: endPos.Line, + } + if lit.Body != nil { + ic.Code = b.nodeSource(lit) + ic.CallSites = b.buildCallSites(pkg, lit.Body) + ic.LocalVariables = b.buildLocalVars(pkg, lit.Body) + ic.InnerCallables = b.buildInnerCallables(pkg, sig, lit.Body) + } + inner[name] = ic + return false // don't recurse; nested closures are handled above + }) + return inner +} + // ─── Package-level variables ────────────────────────────────────────────────── func (b *SymbolTableBuilder) buildPackageVars(pkg *packages.Package, astFile *ast.File) []schema.GoVariableDeclaration { @@ -826,6 +884,8 @@ func (b *SymbolTableBuilder) cyclomaticComplexity(decl *ast.FuncDecl) int { // ─── Helpers ────────────────────────────────────────────────────────────────── // receiverTypeName extracts the base type name from a receiver field list. +// It handles pointer receivers (*T), generic single-param receivers (T[A]), +// pointer-to-generic (*T[A]), and multi-param generic receivers (*T[A, B]). func (b *SymbolTableBuilder) receiverTypeName(recv *ast.FieldList) string { if recv == nil || len(recv.List) == 0 { return "" @@ -835,6 +895,14 @@ func (b *SymbolTableBuilder) receiverTypeName(recv *ast.FieldList) string { if star, ok := expr.(*ast.StarExpr); ok { expr = star.X } + // Generic single type param: Set[T] → IndexExpr{X: Ident("Set")} + if idx, ok := expr.(*ast.IndexExpr); ok { + expr = idx.X + } + // Generic multi type param: Map[K, V] → IndexListExpr{X: Ident("Map")} + if idx, ok := expr.(*ast.IndexListExpr); ok { + expr = idx.X + } if ident, ok := expr.(*ast.Ident); ok { return ident.Name } diff --git a/internal/utils/fs_test.go b/internal/utils/fs_test.go new file mode 100644 index 0000000..a349d07 --- /dev/null +++ b/internal/utils/fs_test.go @@ -0,0 +1,206 @@ +package utils_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/codellm-devkit/codeanalyzer-go/internal/utils" +) + +// ── IsTestFile ──────────────────────────────────────────────────────────────── + +func TestIsTestFile(t *testing.T) { + tests := []struct { + path string + want bool + }{ + {"foo_test.go", true}, + {"server_test.go", true}, + {"foo.go", false}, + {"test.go", false}, // doesn't end with _test.go + {"_test.go", true}, // edge case: file is literally "_test.go" + {"", false}, + {"foo_test.go.bak", false}, + } + for _, tc := range tests { + if got := utils.IsTestFile(tc.path); got != tc.want { + t.Errorf("IsTestFile(%q) = %v, want %v", tc.path, got, tc.want) + } + } +} + +// ── IsVendored ──────────────────────────────────────────────────────────────── + +func TestIsVendored(t *testing.T) { + tests := []struct { + path string + want bool + }{ + {"vendor/github.com/foo/bar/baz.go", true}, + {"pkg/vendor/something.go", true}, + {"testdata/greeter/main.go", true}, + {".git/config", true}, + {"internal/core/analyzer.go", false}, + {"main.go", false}, + {"", false}, + {"vendored/not-vendor.go", false}, // "vendored" ≠ "vendor" + } + for _, tc := range tests { + if got := utils.IsVendored(tc.path); got != tc.want { + t.Errorf("IsVendored(%q) = %v, want %v", tc.path, got, tc.want) + } + } +} + +// ── FileHash ────────────────────────────────────────────────────────────────── + +func TestFileHash_Deterministic(t *testing.T) { + dir := t.TempDir() + f := filepath.Join(dir, "file.txt") + if err := os.WriteFile(f, []byte("hello world"), 0o644); err != nil { + t.Fatal(err) + } + + h1, err := utils.FileHash(f) + if err != nil { + t.Fatalf("FileHash: %v", err) + } + h2, err := utils.FileHash(f) + if err != nil { + t.Fatalf("FileHash second call: %v", err) + } + if h1 != h2 { + t.Errorf("FileHash is not deterministic: %q != %q", h1, h2) + } + if len(h1) != 64 { + t.Errorf("FileHash should return 64-char hex SHA-256; got len %d: %q", len(h1), h1) + } +} + +func TestFileHash_DifferentContent(t *testing.T) { + dir := t.TempDir() + a := filepath.Join(dir, "a.txt") + b := filepath.Join(dir, "b.txt") + os.WriteFile(a, []byte("aaa"), 0o644) + os.WriteFile(b, []byte("bbb"), 0o644) + + ha, _ := utils.FileHash(a) + hb, _ := utils.FileHash(b) + if ha == hb { + t.Error("different files should have different hashes") + } +} + +func TestFileHash_NonExistentFile(t *testing.T) { + _, err := utils.FileHash(filepath.Join(t.TempDir(), "no-such-file")) + if err == nil { + t.Error("expected error for non-existent file, got nil") + } +} + +// ── EnsureDir ───────────────────────────────────────────────────────────────── + +func TestEnsureDir_CreatesDirectory(t *testing.T) { + dir := filepath.Join(t.TempDir(), "a", "b", "c") + if err := utils.EnsureDir(dir); err != nil { + t.Fatalf("EnsureDir: %v", err) + } + if fi, err := os.Stat(dir); err != nil || !fi.IsDir() { + t.Errorf("directory %s was not created", dir) + } +} + +func TestEnsureDir_Idempotent(t *testing.T) { + dir := t.TempDir() + if err := utils.EnsureDir(dir); err != nil { + t.Errorf("EnsureDir on existing dir: %v", err) + } +} + +// ── DiscoverGoFiles ─────────────────────────────────────────────────────────── + +func TestDiscoverGoFiles_FindsGoFiles(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "main.go", "package main") + writeFile(t, dir, "util.go", "package main") + + files, err := utils.DiscoverGoFiles(dir, true) + if err != nil { + t.Fatalf("DiscoverGoFiles: %v", err) + } + if len(files) != 2 { + t.Errorf("got %d files, want 2: %v", len(files), files) + } +} + +func TestDiscoverGoFiles_SkipsTestFiles(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "main.go", "package main") + writeFile(t, dir, "main_test.go", "package main") + + files, err := utils.DiscoverGoFiles(dir, true) + if err != nil { + t.Fatalf("DiscoverGoFiles: %v", err) + } + for _, f := range files { + if utils.IsTestFile(f) { + t.Errorf("skipTests=true: found test file: %s", f) + } + } +} + +func TestDiscoverGoFiles_IncludesTestFilesWhenNotSkipped(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "main.go", "package main") + writeFile(t, dir, "main_test.go", "package main") + + files, err := utils.DiscoverGoFiles(dir, false) + if err != nil { + t.Fatalf("DiscoverGoFiles: %v", err) + } + if len(files) != 2 { + t.Errorf("got %d files, want 2: %v", len(files), files) + } +} + +func TestDiscoverGoFiles_SkipsVendorDir(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "main.go", "package main") + vendor := filepath.Join(dir, "vendor", "pkg") + os.MkdirAll(vendor, 0o755) + writeFile(t, vendor, "lib.go", "package pkg") + + files, err := utils.DiscoverGoFiles(dir, true) + if err != nil { + t.Fatalf("DiscoverGoFiles: %v", err) + } + if len(files) != 1 { + t.Errorf("got %d files, want 1 (vendor should be skipped): %v", len(files), files) + } +} + +func TestDiscoverGoFiles_IgnoresNonGoFiles(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "main.go", "package main") + writeFile(t, dir, "readme.md", "# readme") + writeFile(t, dir, "config.yaml", "key: val") + + files, err := utils.DiscoverGoFiles(dir, true) + if err != nil { + t.Fatalf("DiscoverGoFiles: %v", err) + } + if len(files) != 1 { + t.Errorf("got %d files, want 1: %v", len(files), files) + } +} + +// ── helpers ─────────────────────────────────────────────────────────────────── + +func writeFile(t *testing.T, dir, name, content string) { + t.Helper() + path := filepath.Join(dir, name) + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("writeFile %s: %v", name, err) + } +} diff --git a/testdata/chi/chain.go b/testdata/chi/chain.go new file mode 100644 index 0000000..a227841 --- /dev/null +++ b/testdata/chi/chain.go @@ -0,0 +1,49 @@ +package chi + +import "net/http" + +// Chain returns a Middlewares type from a slice of middleware handlers. +func Chain(middlewares ...func(http.Handler) http.Handler) Middlewares { + return Middlewares(middlewares) +} + +// Handler builds and returns a http.Handler from the chain of middlewares, +// with `h http.Handler` as the final handler. +func (mws Middlewares) Handler(h http.Handler) http.Handler { + return &ChainHandler{h, chain(mws, h), mws} +} + +// HandlerFunc builds and returns a http.Handler from the chain of middlewares, +// with `h http.Handler` as the final handler. +func (mws Middlewares) HandlerFunc(h http.HandlerFunc) http.Handler { + return &ChainHandler{h, chain(mws, h), mws} +} + +// ChainHandler is a http.Handler with support for handler composition and +// execution. +type ChainHandler struct { + Endpoint http.Handler + chain http.Handler + Middlewares Middlewares +} + +func (c *ChainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c.chain.ServeHTTP(w, r) +} + +// chain builds a http.Handler composed of an inline middleware stack and endpoint +// handler in the order they are passed. +func chain(middlewares []func(http.Handler) http.Handler, endpoint http.Handler) http.Handler { + // Return ahead of time if there aren't any middlewares for the chain + if len(middlewares) == 0 { + return endpoint + } + + // Wrap the end handler with the middleware chain + h := middlewares[len(middlewares)-1](endpoint) + for i := len(middlewares) - 2; i >= 0; i-- { + h = middlewares[i](h) + } + + return h +} diff --git a/testdata/chi/chi.go b/testdata/chi/chi.go new file mode 100644 index 0000000..ad0ca74 --- /dev/null +++ b/testdata/chi/chi.go @@ -0,0 +1,137 @@ +// Package chi is a small, idiomatic and composable router for building HTTP services. +// +// chi supports the four most recent major versions of Go. +// +// Example: +// +// package main +// +// import ( +// "net/http" +// +// "github.com/go-chi/chi/v5" +// "github.com/go-chi/chi/v5/middleware" +// ) +// +// func main() { +// r := chi.NewRouter() +// r.Use(middleware.Logger) +// r.Use(middleware.Recoverer) +// +// r.Get("/", func(w http.ResponseWriter, r *http.Request) { +// w.Write([]byte("root.")) +// }) +// +// http.ListenAndServe(":3333", r) +// } +// +// See github.com/go-chi/chi/_examples/ for more in-depth examples. +// +// URL patterns allow for easy matching of path components in HTTP +// requests. The matching components can then be accessed using +// chi.URLParam(). All patterns must begin with a slash. +// +// A simple named placeholder {name} matches any sequence of characters +// up to the next / or the end of the URL. Trailing slashes on paths must +// be handled explicitly. +// +// A placeholder with a name followed by a colon allows a regular +// expression match, for example {number:\\d+}. The regular expression +// syntax is Go's normal regexp RE2 syntax, except that / will never be +// matched. An anonymous regexp pattern is allowed, using an empty string +// before the colon in the placeholder, such as {:\\d+} +// +// The special placeholder of asterisk matches the rest of the requested +// URL. Any trailing characters in the pattern are ignored. This is the only +// placeholder which will match / characters. +// +// Examples: +// +// "/user/{name}" matches "/user/jsmith" but not "/user/jsmith/info" or "/user/jsmith/" +// "/user/{name}/info" matches "/user/jsmith/info" +// "/page/*" matches "/page/intro/latest" +// "/page/{other}/latest" also matches "/page/intro/latest" +// "/date/{yyyy:\\d\\d\\d\\d}/{mm:\\d\\d}/{dd:\\d\\d}" matches "/date/2017/04/01" +package chi + +import "net/http" + +// NewRouter returns a new Mux object that implements the Router interface. +func NewRouter() *Mux { + return NewMux() +} + +// Router consisting of the core routing methods used by chi's Mux, +// using only the standard net/http. +type Router interface { + http.Handler + Routes + + // Use appends one or more middlewares onto the Router stack. + Use(middlewares ...func(http.Handler) http.Handler) + + // With adds inline middlewares for an endpoint handler. + With(middlewares ...func(http.Handler) http.Handler) Router + + // Group adds a new inline-Router along the current routing + // path, with a fresh middleware stack for the inline-Router. + Group(fn func(r Router)) Router + + // Route mounts a sub-Router along a `pattern` string. + Route(pattern string, fn func(r Router)) Router + + // Mount attaches another http.Handler along ./pattern/* + Mount(pattern string, h http.Handler) + + // Handle and HandleFunc adds routes for `pattern` that matches + // all HTTP methods. + Handle(pattern string, h http.Handler) + HandleFunc(pattern string, h http.HandlerFunc) + + // Method and MethodFunc adds routes for `pattern` that matches + // the `method` HTTP method. + Method(method, pattern string, h http.Handler) + MethodFunc(method, pattern string, h http.HandlerFunc) + + // HTTP-method routing along `pattern` + Connect(pattern string, h http.HandlerFunc) + Delete(pattern string, h http.HandlerFunc) + Get(pattern string, h http.HandlerFunc) + Head(pattern string, h http.HandlerFunc) + Options(pattern string, h http.HandlerFunc) + Patch(pattern string, h http.HandlerFunc) + Post(pattern string, h http.HandlerFunc) + Put(pattern string, h http.HandlerFunc) + Trace(pattern string, h http.HandlerFunc) + + // NotFound defines a handler to respond whenever a route could + // not be found. + NotFound(h http.HandlerFunc) + + // MethodNotAllowed defines a handler to respond whenever a method is + // not allowed. + MethodNotAllowed(h http.HandlerFunc) +} + +// Routes interface adds two methods for router traversal, which is also +// used by the `docgen` subpackage to generation documentation for Routers. +type Routes interface { + // Routes returns the routing tree in an easily traversable structure. + Routes() []Route + + // Middlewares returns the list of middlewares in use by the router. + Middlewares() Middlewares + + // Match searches the routing tree for a handler that matches + // the method/path - similar to routing a http request, but without + // executing the handler thereafter. + Match(rctx *Context, method, path string) bool + + // Find searches the routing tree for the pattern that matches + // the method/path. + Find(rctx *Context, method, path string) string +} + +// Middlewares type is a slice of standard middleware handlers with methods +// to compose middleware chains and http.Handler's. +type Middlewares []func(http.Handler) http.Handler diff --git a/testdata/chi/context.go b/testdata/chi/context.go new file mode 100644 index 0000000..8222073 --- /dev/null +++ b/testdata/chi/context.go @@ -0,0 +1,166 @@ +package chi + +import ( + "context" + "net/http" + "strings" +) + +// URLParam returns the url parameter from a http.Request object. +func URLParam(r *http.Request, key string) string { + if rctx := RouteContext(r.Context()); rctx != nil { + return rctx.URLParam(key) + } + return "" +} + +// URLParamFromCtx returns the url parameter from a http.Request Context. +func URLParamFromCtx(ctx context.Context, key string) string { + if rctx := RouteContext(ctx); rctx != nil { + return rctx.URLParam(key) + } + return "" +} + +// RouteContext returns chi's routing Context object from a +// http.Request Context. +func RouteContext(ctx context.Context) *Context { + val, _ := ctx.Value(RouteCtxKey).(*Context) + return val +} + +// NewRouteContext returns a new routing Context object. +func NewRouteContext() *Context { + return &Context{} +} + +var ( + // RouteCtxKey is the context.Context key to store the request context. + RouteCtxKey = &contextKey{"RouteContext"} +) + +// Context is the default routing context set on the root node of a +// request context to track route patterns, URL parameters and +// an optional routing path. +type Context struct { + Routes Routes + + // parentCtx is the parent of this one, for using Context as a + // context.Context directly. This is an optimization that saves + // 1 allocation. + parentCtx context.Context + + // Routing path/method override used during the route search. + // See Mux#routeHTTP method. + RoutePath string + RouteMethod string + + // URLParams are the stack of routeParams captured during the + // routing lifecycle across a stack of sub-routers. + URLParams RouteParams + + // Route parameters matched for the current sub-router. It is + // intentionally unexported so it can't be tampered. + routeParams RouteParams + + // The endpoint routing pattern that matched the request URI path + // or `RoutePath` of the current sub-router. This value will update + // during the lifecycle of a request passing through a stack of + // sub-routers. + routePattern string + + // Routing pattern stack throughout the lifecycle of the request, + // across all connected routers. It is a record of all matching + // patterns across a stack of sub-routers. + RoutePatterns []string + + methodsAllowed []methodTyp // allowed methods in case of a 405 + methodNotAllowed bool +} + +// Reset a routing context to its initial state. +func (x *Context) Reset() { + x.Routes = nil + x.RoutePath = "" + x.RouteMethod = "" + x.RoutePatterns = x.RoutePatterns[:0] + x.URLParams.Keys = x.URLParams.Keys[:0] + x.URLParams.Values = x.URLParams.Values[:0] + + x.routePattern = "" + x.routeParams.Keys = x.routeParams.Keys[:0] + x.routeParams.Values = x.routeParams.Values[:0] + x.methodNotAllowed = false + x.methodsAllowed = x.methodsAllowed[:0] + x.parentCtx = nil +} + +// URLParam returns the corresponding URL parameter value from the request +// routing context. +func (x *Context) URLParam(key string) string { + for k := len(x.URLParams.Keys) - 1; k >= 0; k-- { + if x.URLParams.Keys[k] == key { + return x.URLParams.Values[k] + } + } + return "" +} + +// RoutePattern builds the routing pattern string for the particular +// request, at the particular point during routing. This means, the value +// will change throughout the execution of a request in a router. That is +// why it's advised to only use this value after calling the next handler. +// +// For example, +// +// func Instrument(next http.Handler) http.Handler { +// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// next.ServeHTTP(w, r) +// routePattern := chi.RouteContext(r.Context()).RoutePattern() +// measure(w, r, routePattern) +// }) +// } +func (x *Context) RoutePattern() string { + if x == nil { + return "" + } + routePattern := strings.Join(x.RoutePatterns, "") + routePattern = replaceWildcards(routePattern) + if routePattern != "/" { + routePattern = strings.TrimSuffix(routePattern, "//") + routePattern = strings.TrimSuffix(routePattern, "/") + } + return routePattern +} + +// replaceWildcards takes a route pattern and replaces all occurrences of +// "/*/" with "/". It iteratively runs until no wildcards remain to +// correctly handle consecutive wildcards. +func replaceWildcards(p string) string { + for strings.Contains(p, "/*/") { + p = strings.ReplaceAll(p, "/*/", "/") + } + return p +} + +// RouteParams is a structure to track URL routing parameters efficiently. +type RouteParams struct { + Keys, Values []string +} + +// Add will append a URL parameter to the end of the route param +func (s *RouteParams) Add(key, value string) { + s.Keys = append(s.Keys, key) + s.Values = append(s.Values, value) +} + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. This technique +// for defining context keys was copied from Go 1.7's new use of context in net/http. +type contextKey struct { + name string +} + +func (k *contextKey) String() string { + return "chi context value " + k.name +} diff --git a/testdata/chi/go.mod b/testdata/chi/go.mod new file mode 100644 index 0000000..4c47b09 --- /dev/null +++ b/testdata/chi/go.mod @@ -0,0 +1,3 @@ +module github.com/go-chi/chi/v5 + +go 1.23 diff --git a/testdata/chi/middleware/basic_auth.go b/testdata/chi/middleware/basic_auth.go new file mode 100644 index 0000000..a546c9e --- /dev/null +++ b/testdata/chi/middleware/basic_auth.go @@ -0,0 +1,33 @@ +package middleware + +import ( + "crypto/subtle" + "fmt" + "net/http" +) + +// BasicAuth implements a simple middleware handler for adding basic http auth to a route. +func BasicAuth(realm string, creds map[string]string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user, pass, ok := r.BasicAuth() + if !ok { + basicAuthFailed(w, realm) + return + } + + credPass, credUserOk := creds[user] + if !credUserOk || subtle.ConstantTimeCompare([]byte(pass), []byte(credPass)) != 1 { + basicAuthFailed(w, realm) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +func basicAuthFailed(w http.ResponseWriter, realm string) { + w.Header().Add("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, realm)) + w.WriteHeader(http.StatusUnauthorized) +} diff --git a/testdata/chi/middleware/clean_path.go b/testdata/chi/middleware/clean_path.go new file mode 100644 index 0000000..adeba42 --- /dev/null +++ b/testdata/chi/middleware/clean_path.go @@ -0,0 +1,28 @@ +package middleware + +import ( + "net/http" + "path" + + "github.com/go-chi/chi/v5" +) + +// CleanPath middleware will clean out double slash mistakes from a user's request path. +// For example, if a user requests /users//1 or //users////1 will both be treated as: /users/1 +func CleanPath(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rctx := chi.RouteContext(r.Context()) + + routePath := rctx.RoutePath + if routePath == "" { + if r.URL.RawPath != "" { + routePath = r.URL.RawPath + } else { + routePath = r.URL.Path + } + rctx.RoutePath = path.Clean(routePath) + } + + next.ServeHTTP(w, r) + }) +} diff --git a/testdata/chi/middleware/client_ip.go b/testdata/chi/middleware/client_ip.go new file mode 100644 index 0000000..1495a86 --- /dev/null +++ b/testdata/chi/middleware/client_ip.go @@ -0,0 +1,263 @@ +package middleware + +import ( + "context" + "net" + "net/http" + "net/netip" + "strings" +) + +// clientIPCtxKey stores the client IP set by any of the ClientIPFrom* middlewares. +var clientIPCtxKey = &contextKey{"clientIP"} + +// xForwardedForHeader is the canonical form of the X-Forwarded-For header +// name, used by the XFF-based middlewares. +const xForwardedForHeader = "X-Forwarded-For" + +// ClientIPFromHeader stores the client IP from a single-IP header set by +// your reverse proxy. Read it with [GetClientIP]. +// +// Only safe with headers your proxy unconditionally OVERWRITES on every +// request, e.g.: +// +// - X-Real-IP — Nginx with ngx_http_realip_module +// - X-Client-IP — Apache with mod_remoteip +// - CF-Connecting-IP — Cloudflare +// +// True-Client-IP, X-Azure-ClientIP, and Fastly-Client-IP look similar but +// pass through from the client by default in those products; don't use them +// unless your edge strips the inbound value. +// +// If the header reaches us with multiple values (misconfigured proxy that +// appends, or a downstream proxy not stripping a client-supplied value), +// the LAST value wins — that's the one set by the hop closest to us, and +// therefore the most trusted. Fail-closed if the last value doesn't parse: +// no client IP is set rather than falling back to earlier (less-trusted) +// values. +// +// v4-mapped IPv6 (::ffff:a.b.c.d) folds to plain v4 and IPv6 zones are +// stripped before storage. +func ClientIPFromHeader(trustedHeader string) func(http.Handler) http.Handler { + header := http.CanonicalHeaderKey(trustedHeader) + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + values := r.Header.Values(header) + if len(values) > 0 { + if ip, ok := parseHeaderAddr(values[len(values)-1]); ok { + r = r.WithContext(context.WithValue(r.Context(), clientIPCtxKey, ip)) + } + } + h.ServeHTTP(w, r) + }) + } +} + +// ClientIPFromXFF stores the client IP read from the X-Forwarded-For header, +// walking the chain right-to-left and skipping any IP that falls within one +// of the given trusted CIDR prefixes. The first IP that is not trusted is +// the client. Read it with [GetClientIP]. +// +// An unparseable entry mid-chain aborts the walk and leaves no client IP +// set (fail-closed) — we can't safely trust anything left of garbage. +// +// Use this when you sit behind one or more reverse proxies whose IP ranges +// you can enumerate as CIDRs: +// +// r.Use(middleware.ClientIPFromXFF( +// "13.32.0.0/15", // CloudFront IPv4 +// "52.46.0.0/18", // CloudFront IPv4 +// "2600:9000::/28", // CloudFront IPv6 +// )) +// +// Calling with no arguments returns the rightmost XFF entry, or no IP if +// that entry doesn't parse (fail-closed) — safe only if you have exactly +// one trusted hop directly in front of this server (e.g., nginx on localhost). +// +// v4-mapped IPv6 (::ffff:a.b.c.d) folds to plain v4 and IPv6 zones are +// stripped before the prefix check and storage; otherwise an attacker +// could use either notation to alias a trusted IP past the check. +// +// If you know the number of trusted proxies but not their IPs, use +// [ClientIPFromXFFTrustedProxies] instead. +// +// Panics at startup if any prefix is invalid. +func ClientIPFromXFF(trustedIPPrefixes ...string) func(http.Handler) http.Handler { + prefixes := make([]netip.Prefix, len(trustedIPPrefixes)) + for i, p := range trustedIPPrefixes { + prefixes[i] = netip.MustParsePrefix(p) + } + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var found netip.Addr + walkXFF(r.Header[xForwardedForHeader], func(v string) bool { + ip, ok := parseHeaderAddr(v) + if !ok { + return true // fail-closed; leave found unset + } + if inAnyPrefix(ip, prefixes) { + return false // trusted hop; keep walking left + } + found = ip + return true + }) + if found.IsValid() { + r = r.WithContext(context.WithValue(r.Context(), clientIPCtxKey, found)) + } + h.ServeHTTP(w, r) + }) + } +} + +// ClientIPFromXFFTrustedProxies stores the client IP read from the +// X-Forwarded-For header, given the exact number of trusted reverse proxies +// between this server and the public internet. It returns the IP at position +// len(xff) - numTrustedProxies in the merged X-Forwarded-For list — the IP +// added by the outermost of your trusted proxies, the only IP in the chain +// that none of your proxies have allowed an attacker to forge. Read it with +// [GetClientIP]. +// +// Use this when: +// - You know exactly how many proxies you sit behind, AND +// - Their IP addresses are dynamic (autoscaling proxy pools, ephemeral +// containers, dynamic CDN edges) so listing CIDRs with [ClientIPFromXFF] +// is impractical. +// +// WARNING: This variant is brittle to network architecture changes. If you +// add or remove a proxy level, numTrustedProxies silently becomes wrong and +// you may start trusting an attacker-supplied IP. Prefer [ClientIPFromXFF] +// with explicit trusted CIDRs whenever you can. +// +// If the XFF chain has fewer than numTrustedProxies entries (header missing +// or architecture changed), no client IP is set and [GetClientIP] returns "". +// +// Like [ClientIPFromXFF], v4-mapped IPv6 folds to plain v4 and IPv6 zones +// are stripped before storage. +// +// Panics at startup if numTrustedProxies < 1. +func ClientIPFromXFFTrustedProxies(numTrustedProxies int) func(http.Handler) http.Handler { + if numTrustedProxies < 1 { + panic("middleware.ClientIPFromXFFTrustedProxies: numTrustedProxies must be >= 1") + } + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := numTrustedProxies + var entry string + walkXFF(r.Header[xForwardedForHeader], func(v string) bool { + n-- + if n == 0 { + entry = v + return true + } + return false + }) + if entry != "" { + if ip, ok := parseHeaderAddr(entry); ok { + r = r.WithContext(context.WithValue(r.Context(), clientIPCtxKey, ip)) + } + } + h.ServeHTTP(w, r) + }) + } +} + +// ClientIPFromRemoteAddr stores the client IP read from the TCP RemoteAddr +// of the incoming request — the IP address of whoever opened the connection +// to this server. Read it with [GetClientIP]. +// +// Use this when this server is directly connected to the public internet +// with NO reverse proxy in front of it. Behind a reverse proxy, RemoteAddr +// is the proxy's IP, not the client's — use [ClientIPFromHeader] or +// [ClientIPFromXFF] instead. +// +// IPv4 clients on a dual-stack listener surface as ::ffff:a.b.c.d; they +// fold to plain v4 before storage so one logical client maps to one key. +// IPv6 zones are preserved (link-local connections may legitimately have one). +func ClientIPFromRemoteAddr(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + host = r.RemoteAddr // RemoteAddr may already be a bare IP (e.g. in tests). + } + if ip, err := netip.ParseAddr(host); err == nil { + r = r.WithContext(context.WithValue(r.Context(), clientIPCtxKey, ip.Unmap())) + } + h.ServeHTTP(w, r) + }) +} + +// GetClientIP returns the client IP as a string, as set by one of the +// ClientIPFrom* middlewares. Returns "" if no valid IP was set. +// Convenient for logging, rate-limit keys, etc. +func GetClientIP(ctx context.Context) string { + ip := GetClientIPAddr(ctx) + if !ip.IsValid() { + return "" + } + return ip.String() +} + +// GetClientIPAddr returns the client IP as a [netip.Addr], as set by one of +// the ClientIPFrom* middlewares. The returned Addr is the zero value if not +// set; use [netip.Addr.IsValid] to check. Useful when you need typed work — +// prefix containment, Is4/Is6, etc. — without re-parsing the string. +func GetClientIPAddr(ctx context.Context) netip.Addr { + ip, _ := ctx.Value(clientIPCtxKey).(netip.Addr) + return ip +} + +// walkXFF walks the entries of the merged X-Forwarded-For chain +// RIGHT-TO-LEFT, invoking visit on each trimmed non-empty entry. visit +// returns true to stop the walk. Lazy walk, zero allocations (entries +// are substrings of the input headers). +// +// Multiple XFF headers are merged per RFC 2616 — each header's +// comma-separated entries in order received — so an attacker cannot pick +// which value security logic sees by sending a duplicate header. +func walkXFF(headers []string, visit func(entry string) bool) { + for hi := len(headers) - 1; hi >= 0; hi-- { + h := headers[hi] + for h != "" { + var v string + if i := strings.LastIndexByte(h, ','); i >= 0 { + v, h = h[i+1:], h[:i] + } else { + v, h = h, "" + } + v = strings.TrimSpace(v) + if v == "" { + continue + } + if visit(v) { + return + } + } + } +} + +// inAnyPrefix reports whether ip falls within any of the given prefixes. +func inAnyPrefix(ip netip.Addr, prefixes []netip.Prefix) bool { + for _, p := range prefixes { + if p.Contains(ip) { + return true + } + } + return false +} + +// parseHeaderAddr parses s and normalizes for storage: v4-mapped IPv6 +// (::ffff:a.b.c.d) folds to plain v4, IPv6 zone is stripped. Both defend the +// trust-prefix check against attacker-injected aliases — [netip.Prefix.Contains] +// returns false for v4-mapped addresses vs v4 prefixes and for any zoned +// address, so without folding/stripping an attacker could escape an +// otherwise valid trust list. +// +// Header-sourced IPs only. [ClientIPFromRemoteAddr] normalizes inline +// (Unmap, but zone preserved for legitimate link-local connections). +func parseHeaderAddr(s string) (netip.Addr, bool) { + ip, err := netip.ParseAddr(s) + if err != nil { + return netip.Addr{}, false + } + return ip.Unmap().WithZone(""), true +} diff --git a/testdata/chi/middleware/compress.go b/testdata/chi/middleware/compress.go new file mode 100644 index 0000000..4e46f70 --- /dev/null +++ b/testdata/chi/middleware/compress.go @@ -0,0 +1,392 @@ +package middleware + +import ( + "bufio" + "compress/flate" + "compress/gzip" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "sync" +) + +var defaultCompressibleContentTypes = []string{ + "text/html", + "text/css", + "text/plain", + "text/javascript", + "application/javascript", + "application/x-javascript", + "application/json", + "application/atom+xml", + "application/rss+xml", + "image/svg+xml", +} + +// Compress is a middleware that compresses response +// body of a given content types to a data format based +// on Accept-Encoding request header. It uses a given +// compression level. +// +// NOTE: make sure to set the Content-Type header on your response +// otherwise this middleware will not compress the response body. For ex, in +// your handler you should set w.Header().Set("Content-Type", http.DetectContentType(yourBody)) +// or set it manually. +// +// Passing a compression level of 5 is sensible value +func Compress(level int, types ...string) func(next http.Handler) http.Handler { + compressor := NewCompressor(level, types...) + return compressor.Handler +} + +// Compressor represents a set of encoding configurations. +type Compressor struct { + // The mapping of encoder names to encoder functions. + encoders map[string]EncoderFunc + // The mapping of pooled encoders to pools. + pooledEncoders map[string]*sync.Pool + // The set of content types allowed to be compressed. + allowedTypes map[string]struct{} + allowedWildcards map[string]struct{} + // The list of encoders in order of decreasing precedence. + encodingPrecedence []string + level int // The compression level. +} + +// NewCompressor creates a new Compressor that will handle encoding responses. +// +// The level should be one of the ones defined in the flate package. +// The types are the content types that are allowed to be compressed. +func NewCompressor(level int, types ...string) *Compressor { + // If types are provided, set those as the allowed types. If none are + // provided, use the default list. + allowedTypes := make(map[string]struct{}) + allowedWildcards := make(map[string]struct{}) + if len(types) > 0 { + for _, t := range types { + if strings.Contains(strings.TrimSuffix(t, "/*"), "*") { + panic(fmt.Sprintf("middleware/compress: Unsupported content-type wildcard pattern '%s'. Only '/*' supported", t)) + } + if before, ok := strings.CutSuffix(t, "/*"); ok { + allowedWildcards[before] = struct{}{} + } else { + allowedTypes[t] = struct{}{} + } + } + } else { + for _, t := range defaultCompressibleContentTypes { + allowedTypes[t] = struct{}{} + } + } + + c := &Compressor{ + level: level, + encoders: make(map[string]EncoderFunc), + pooledEncoders: make(map[string]*sync.Pool), + allowedTypes: allowedTypes, + allowedWildcards: allowedWildcards, + } + + // Set the default encoders. The precedence order uses the reverse + // ordering that the encoders were added. This means adding new encoders + // will move them to the front of the order. + // + // TODO: + // lzma: Opera. + // sdch: Chrome, Android. Gzip output + dictionary header. + // br: Brotli, see https://github.com/go-chi/chi/pull/326 + + // HTTP 1.1 "deflate" (RFC 2616) stands for DEFLATE data (RFC 1951) + // wrapped with zlib (RFC 1950). The zlib wrapper uses Adler-32 + // checksum compared to CRC-32 used in "gzip" and thus is faster. + // + // But.. some old browsers (MSIE, Safari 5.1) incorrectly expect + // raw DEFLATE data only, without the mentioned zlib wrapper. + // Because of this major confusion, most modern browsers try it + // both ways, first looking for zlib headers. + // Quote by Mark Adler: http://stackoverflow.com/a/9186091/385548 + // + // The list of browsers having problems is quite big, see: + // http://zoompf.com/blog/2012/02/lose-the-wait-http-compression + // https://web.archive.org/web/20120321182910/http://www.vervestudios.co/projects/compression-tests/results + // + // That's why we prefer gzip over deflate. It's just more reliable + // and not significantly slower than deflate. + c.SetEncoder("deflate", encoderDeflate) + + // TODO: Exception for old MSIE browsers that can't handle non-HTML? + // https://zoompf.com/blog/2012/02/lose-the-wait-http-compression + c.SetEncoder("gzip", encoderGzip) + + // NOTE: Not implemented, intentionally: + // case "compress": // LZW. Deprecated. + // case "bzip2": // Too slow on-the-fly. + // case "zopfli": // Too slow on-the-fly. + // case "xz": // Too slow on-the-fly. + return c +} + +// SetEncoder can be used to set the implementation of a compression algorithm. +// +// The encoding should be a standardised identifier. See: +// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding +// +// For example, add the Brotli algorithm: +// +// import brotli_enc "gopkg.in/kothar/brotli-go.v0/enc" +// +// compressor := middleware.NewCompressor(5, "text/html") +// compressor.SetEncoder("br", func(w io.Writer, level int) io.Writer { +// params := brotli_enc.NewBrotliParams() +// params.SetQuality(level) +// return brotli_enc.NewBrotliWriter(params, w) +// }) +func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) { + encoding = strings.ToLower(encoding) + if encoding == "" { + panic("the encoding can not be empty") + } + if fn == nil { + panic("attempted to set a nil encoder function") + } + + // If we are adding a new encoder that is already registered, we have to + // clear that one out first. + delete(c.pooledEncoders, encoding) + delete(c.encoders, encoding) + + // If the encoder supports Resetting (IoReseterWriter), then it can be pooled. + encoder := fn(io.Discard, c.level) + if _, ok := encoder.(ioResetterWriter); ok { + pool := &sync.Pool{ + New: func() interface{} { + return fn(io.Discard, c.level) + }, + } + c.pooledEncoders[encoding] = pool + } + // If the encoder is not in the pooledEncoders, add it to the normal encoders. + if _, ok := c.pooledEncoders[encoding]; !ok { + c.encoders[encoding] = fn + } + + for i, v := range c.encodingPrecedence { + if v == encoding { + c.encodingPrecedence = append(c.encodingPrecedence[:i], c.encodingPrecedence[i+1:]...) + } + } + + c.encodingPrecedence = append([]string{encoding}, c.encodingPrecedence...) +} + +// Handler returns a new middleware that will compress the response based on the +// current Compressor. +func (c *Compressor) Handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + encoder, encoding, cleanup := c.selectEncoder(r.Header, w) + + cw := &compressResponseWriter{ + ResponseWriter: w, + w: w, + contentTypes: c.allowedTypes, + contentWildcards: c.allowedWildcards, + encoding: encoding, + compressible: false, // determined in post-handler + } + if encoder != nil { + cw.w = encoder + } + // Re-add the encoder to the pool if applicable. + defer cleanup() + defer cw.Close() + + next.ServeHTTP(cw, r) + }) +} + +// selectEncoder returns the encoder, the name of the encoder, and a closer function. +func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (io.Writer, string, func()) { + header := h.Get("Accept-Encoding") + + // Parse the names of all accepted algorithms from the header. + accepted := strings.Split(strings.ToLower(header), ",") + + // Find supported encoder by accepted list by precedence + for _, name := range c.encodingPrecedence { + if matchAcceptEncoding(accepted, name) { + if pool, ok := c.pooledEncoders[name]; ok { + encoder := pool.Get().(ioResetterWriter) + cleanup := func() { + pool.Put(encoder) + } + encoder.Reset(w) + return encoder, name, cleanup + + } + if fn, ok := c.encoders[name]; ok { + return fn(w, c.level), name, func() {} + } + } + + } + + // No encoder found to match the accepted encoding + return nil, "", func() {} +} + +func matchAcceptEncoding(accepted []string, encoding string) bool { + for _, v := range accepted { + if strings.Contains(v, encoding) { + return true + } + } + return false +} + +// An EncoderFunc is a function that wraps the provided io.Writer with a +// streaming compression algorithm and returns it. +// +// In case of failure, the function should return nil. +type EncoderFunc func(w io.Writer, level int) io.Writer + +// Interface for types that allow resetting io.Writers. +type ioResetterWriter interface { + io.Writer + Reset(w io.Writer) +} + +type compressResponseWriter struct { + http.ResponseWriter + + // The streaming encoder writer to be used if there is one. Otherwise, + // this is just the normal writer. + w io.Writer + contentTypes map[string]struct{} + contentWildcards map[string]struct{} + encoding string + wroteHeader bool + compressible bool +} + +func (cw *compressResponseWriter) isCompressible() bool { + // Parse the first part of the Content-Type response header. + contentType := cw.Header().Get("Content-Type") + contentType, _, _ = strings.Cut(contentType, ";") + + // Is the content type compressible? + if _, ok := cw.contentTypes[contentType]; ok { + return true + } + if contentType, _, hadSlash := strings.Cut(contentType, "/"); hadSlash { + _, ok := cw.contentWildcards[contentType] + return ok + } + return false +} + +func (cw *compressResponseWriter) WriteHeader(code int) { + if cw.wroteHeader { + cw.ResponseWriter.WriteHeader(code) // Allow multiple calls to propagate. + return + } + cw.wroteHeader = true + defer cw.ResponseWriter.WriteHeader(code) + + // Already compressed data? + if cw.Header().Get("Content-Encoding") != "" { + return + } + + if !cw.isCompressible() { + cw.compressible = false + return + } + + if cw.encoding != "" { + cw.compressible = true + cw.Header().Set("Content-Encoding", cw.encoding) + cw.Header().Add("Vary", "Accept-Encoding") + + // The content-length after compression is unknown + cw.Header().Del("Content-Length") + } +} + +func (cw *compressResponseWriter) Write(p []byte) (int, error) { + if !cw.wroteHeader { + cw.WriteHeader(http.StatusOK) + } + + return cw.writer().Write(p) +} + +func (cw *compressResponseWriter) writer() io.Writer { + if cw.compressible { + return cw.w + } + return cw.ResponseWriter +} + +type compressFlusher interface { + Flush() error +} + +func (cw *compressResponseWriter) Flush() { + if f, ok := cw.writer().(http.Flusher); ok { + f.Flush() + } + // If the underlying writer has a compression flush signature, + // call this Flush() method instead + if f, ok := cw.writer().(compressFlusher); ok { + f.Flush() + + // Also flush the underlying response writer + if f, ok := cw.ResponseWriter.(http.Flusher); ok { + f.Flush() + } + } +} + +func (cw *compressResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hj, ok := cw.writer().(http.Hijacker); ok { + return hj.Hijack() + } + return nil, nil, errors.New("chi/middleware: http.Hijacker is unavailable on the writer") +} + +func (cw *compressResponseWriter) Push(target string, opts *http.PushOptions) error { + if ps, ok := cw.writer().(http.Pusher); ok { + return ps.Push(target, opts) + } + return errors.New("chi/middleware: http.Pusher is unavailable on the writer") +} + +func (cw *compressResponseWriter) Close() error { + if c, ok := cw.writer().(io.WriteCloser); ok { + return c.Close() + } + return errors.New("chi/middleware: io.WriteCloser is unavailable on the writer") +} + +func (cw *compressResponseWriter) Unwrap() http.ResponseWriter { + return cw.ResponseWriter +} + +func encoderGzip(w io.Writer, level int) io.Writer { + gw, err := gzip.NewWriterLevel(w, level) + if err != nil { + return nil + } + return gw +} + +func encoderDeflate(w io.Writer, level int) io.Writer { + dw, err := flate.NewWriter(w, level) + if err != nil { + return nil + } + return dw +} diff --git a/testdata/chi/middleware/content_charset.go b/testdata/chi/middleware/content_charset.go new file mode 100644 index 0000000..8e75fe8 --- /dev/null +++ b/testdata/chi/middleware/content_charset.go @@ -0,0 +1,45 @@ +package middleware + +import ( + "net/http" + "slices" + "strings" +) + +// ContentCharset generates a handler that writes a 415 Unsupported Media Type response if none of the charsets match. +// An empty charset will allow requests with no Content-Type header or no specified charset. +func ContentCharset(charsets ...string) func(next http.Handler) http.Handler { + for i, c := range charsets { + charsets[i] = strings.ToLower(c) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !contentEncoding(r.Header.Get("Content-Type"), charsets...) { + w.WriteHeader(http.StatusUnsupportedMediaType) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// Check the content encoding against a list of acceptable values. +func contentEncoding(ce string, charsets ...string) bool { + _, ce = split(strings.ToLower(ce), ";") + _, ce = split(ce, "charset=") + ce, _ = split(ce, ";") + return slices.Contains(charsets, ce) +} + +// Split a string in two parts, cleaning any whitespace. +func split(str, sep string) (string, string) { + a, b, found := strings.Cut(str, sep) + a = strings.TrimSpace(a) + if found { + b = strings.TrimSpace(b) + } + + return a, b +} diff --git a/testdata/chi/middleware/content_encoding.go b/testdata/chi/middleware/content_encoding.go new file mode 100644 index 0000000..e0b9ccc --- /dev/null +++ b/testdata/chi/middleware/content_encoding.go @@ -0,0 +1,34 @@ +package middleware + +import ( + "net/http" + "strings" +) + +// AllowContentEncoding enforces a whitelist of request Content-Encoding otherwise responds +// with a 415 Unsupported Media Type status. +func AllowContentEncoding(contentEncoding ...string) func(next http.Handler) http.Handler { + allowedEncodings := make(map[string]struct{}, len(contentEncoding)) + for _, encoding := range contentEncoding { + allowedEncodings[strings.TrimSpace(strings.ToLower(encoding))] = struct{}{} + } + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + requestEncodings := r.Header["Content-Encoding"] + // skip check for empty content body or no Content-Encoding + if r.ContentLength == 0 { + next.ServeHTTP(w, r) + return + } + // All encodings in the request must be allowed + for _, encoding := range requestEncodings { + if _, ok := allowedEncodings[strings.TrimSpace(strings.ToLower(encoding))]; !ok { + w.WriteHeader(http.StatusUnsupportedMediaType) + return + } + } + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } +} diff --git a/testdata/chi/middleware/content_type.go b/testdata/chi/middleware/content_type.go new file mode 100644 index 0000000..cdfc21e --- /dev/null +++ b/testdata/chi/middleware/content_type.go @@ -0,0 +1,45 @@ +package middleware + +import ( + "net/http" + "strings" +) + +// SetHeader is a convenience handler to set a response header key/value +func SetHeader(key, value string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(key, value) + next.ServeHTTP(w, r) + }) + } +} + +// AllowContentType enforces a whitelist of request Content-Types otherwise responds +// with a 415 Unsupported Media Type status. +func AllowContentType(contentTypes ...string) func(http.Handler) http.Handler { + allowedContentTypes := make(map[string]struct{}, len(contentTypes)) + for _, ctype := range contentTypes { + allowedContentTypes[strings.TrimSpace(strings.ToLower(ctype))] = struct{}{} + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ContentLength == 0 { + // Skip check for empty content body + next.ServeHTTP(w, r) + return + } + + s, _, _ := strings.Cut(r.Header.Get("Content-Type"), ";") + s = strings.ToLower(strings.TrimSpace(s)) + + if _, ok := allowedContentTypes[s]; ok { + next.ServeHTTP(w, r) + return + } + + w.WriteHeader(http.StatusUnsupportedMediaType) + }) + } +} diff --git a/testdata/chi/middleware/get_head.go b/testdata/chi/middleware/get_head.go new file mode 100644 index 0000000..d4606d8 --- /dev/null +++ b/testdata/chi/middleware/get_head.go @@ -0,0 +1,39 @@ +package middleware + +import ( + "net/http" + + "github.com/go-chi/chi/v5" +) + +// GetHead automatically route undefined HEAD requests to GET handlers. +func GetHead(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + rctx := chi.RouteContext(r.Context()) + routePath := rctx.RoutePath + if routePath == "" { + if r.URL.RawPath != "" { + routePath = r.URL.RawPath + } else { + routePath = r.URL.Path + } + } + + // Temporary routing context to look-ahead before routing the request + tctx := chi.NewRouteContext() + + // Attempt to find a HEAD handler for the routing path, if not found, traverse + // the router as through its a GET route, but proceed with the request + // with the HEAD method. + if !rctx.Routes.Match(tctx, "HEAD", routePath) { + rctx.RouteMethod = "GET" + rctx.RoutePath = routePath + next.ServeHTTP(w, r) + return + } + } + + next.ServeHTTP(w, r) + }) +} diff --git a/testdata/chi/middleware/heartbeat.go b/testdata/chi/middleware/heartbeat.go new file mode 100644 index 0000000..f36e8cc --- /dev/null +++ b/testdata/chi/middleware/heartbeat.go @@ -0,0 +1,26 @@ +package middleware + +import ( + "net/http" + "strings" +) + +// Heartbeat endpoint middleware useful to setting up a path like +// `/ping` that load balancers or uptime testing external services +// can make a request before hitting any routes. It's also convenient +// to place this above ACL middlewares as well. +func Heartbeat(endpoint string) func(http.Handler) http.Handler { + f := func(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + if (r.Method == "GET" || r.Method == "HEAD") && strings.EqualFold(r.URL.Path, endpoint) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte(".")) + return + } + h.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } + return f +} diff --git a/testdata/chi/middleware/logger.go b/testdata/chi/middleware/logger.go new file mode 100644 index 0000000..4d30a9a --- /dev/null +++ b/testdata/chi/middleware/logger.go @@ -0,0 +1,178 @@ +package middleware + +import ( + "bytes" + "context" + "log" + "net/http" + "os" + "runtime" + "time" +) + +var ( + // LogEntryCtxKey is the context.Context key to store the request log entry. + LogEntryCtxKey = &contextKey{"LogEntry"} + + // DefaultLogger is called by the Logger middleware handler to log each request. + // Its made a package-level variable so that it can be reconfigured for custom + // logging configurations. + DefaultLogger func(next http.Handler) http.Handler +) + +// Logger is a middleware that logs the start and end of each request, along +// with some useful data about what was requested, what the response status was, +// and how long it took to return. When standard output is a TTY, Logger will +// print in color, otherwise it will print in black and white. Logger prints a +// request ID if one is provided. +// +// Alternatively, look at https://github.com/goware/httplog for a more in-depth +// http logger with structured logging support. +// +// IMPORTANT NOTE: Logger should go before any other middleware that may change +// the response, such as middleware.Recoverer. Example: +// +// r := chi.NewRouter() +// r.Use(middleware.Logger) // <--<< Logger should come before Recoverer +// r.Use(middleware.Recoverer) +// r.Get("/", handler) +func Logger(next http.Handler) http.Handler { + return DefaultLogger(next) +} + +// RequestLogger returns a logger handler using a custom LogFormatter. +func RequestLogger(f LogFormatter) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + entry := f.NewLogEntry(r) + ww := NewWrapResponseWriter(w, r.ProtoMajor) + + t1 := time.Now() + defer func() { + entry.Write(ww.Status(), ww.BytesWritten(), ww.Header(), time.Since(t1), nil) + }() + + next.ServeHTTP(ww, WithLogEntry(r, entry)) + } + return http.HandlerFunc(fn) + } +} + +// LogFormatter initiates the beginning of a new LogEntry per request. +// See DefaultLogFormatter for an example implementation. +type LogFormatter interface { + NewLogEntry(r *http.Request) LogEntry +} + +// LogEntry records the final log when a request completes. +// See defaultLogEntry for an example implementation. +type LogEntry interface { + Write(status, bytes int, header http.Header, elapsed time.Duration, extra interface{}) + Panic(v interface{}, stack []byte) +} + +// GetLogEntry returns the in-context LogEntry for a request. +func GetLogEntry(r *http.Request) LogEntry { + entry, _ := r.Context().Value(LogEntryCtxKey).(LogEntry) + return entry +} + +// WithLogEntry sets the in-context LogEntry for a request. +func WithLogEntry(r *http.Request, entry LogEntry) *http.Request { + r = r.WithContext(context.WithValue(r.Context(), LogEntryCtxKey, entry)) + return r +} + +// LoggerInterface accepts printing to stdlib logger or compatible logger. +type LoggerInterface interface { + Print(v ...interface{}) +} + +// DefaultLogFormatter is a simple logger that implements a LogFormatter. +type DefaultLogFormatter struct { + Logger LoggerInterface + NoColor bool +} + +// NewLogEntry creates a new LogEntry for the request. +func (l *DefaultLogFormatter) NewLogEntry(r *http.Request) LogEntry { + ctx := r.Context() + + useColor := !l.NoColor + entry := &defaultLogEntry{ + DefaultLogFormatter: l, + request: r, + buf: &bytes.Buffer{}, + useColor: useColor, + } + + reqID := GetReqID(ctx) + if reqID != "" { + cW(entry.buf, useColor, nYellow, "[%s] ", reqID) + } + cW(entry.buf, useColor, nCyan, "\"") + cW(entry.buf, useColor, bMagenta, "%s ", r.Method) + + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + cW(entry.buf, useColor, nCyan, "%s://%s%s %s\" ", scheme, r.Host, r.RequestURI, r.Proto) + + entry.buf.WriteString("from ") + clientIP := GetClientIP(ctx) + if clientIP == "" { + clientIP = r.RemoteAddr + } + entry.buf.WriteString(clientIP) + entry.buf.WriteString(" - ") + + return entry +} + +type defaultLogEntry struct { + *DefaultLogFormatter + request *http.Request + buf *bytes.Buffer + useColor bool +} + +func (l *defaultLogEntry) Write(status, bytes int, header http.Header, elapsed time.Duration, extra interface{}) { + switch { + case status < 200: + cW(l.buf, l.useColor, bBlue, "%03d", status) + case status < 300: + cW(l.buf, l.useColor, bGreen, "%03d", status) + case status < 400: + cW(l.buf, l.useColor, bCyan, "%03d", status) + case status < 500: + cW(l.buf, l.useColor, bYellow, "%03d", status) + default: + cW(l.buf, l.useColor, bRed, "%03d", status) + } + + cW(l.buf, l.useColor, bBlue, " %dB", bytes) + + l.buf.WriteString(" in ") + if elapsed < 500*time.Millisecond { + cW(l.buf, l.useColor, nGreen, "%s", elapsed) + } else if elapsed < 5*time.Second { + cW(l.buf, l.useColor, nYellow, "%s", elapsed) + } else { + cW(l.buf, l.useColor, nRed, "%s", elapsed) + } + + l.Logger.Print(l.buf.String()) +} + +func (l *defaultLogEntry) Panic(v interface{}, stack []byte) { + PrintPrettyStack(v) +} + +func init() { + color := true + if runtime.GOOS == "windows" { + color = false + } + DefaultLogger = RequestLogger(&DefaultLogFormatter{Logger: log.New(os.Stdout, "", log.LstdFlags), NoColor: !color}) +} diff --git a/testdata/chi/middleware/maybe.go b/testdata/chi/middleware/maybe.go new file mode 100644 index 0000000..eabca00 --- /dev/null +++ b/testdata/chi/middleware/maybe.go @@ -0,0 +1,18 @@ +package middleware + +import "net/http" + +// Maybe middleware will allow you to change the flow of the middleware stack execution depending on return +// value of maybeFn(request). This is useful for example if you'd like to skip a middleware handler if +// a request does not satisfy the maybeFn logic. +func Maybe(mw func(http.Handler) http.Handler, maybeFn func(r *http.Request) bool) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if maybeFn(r) { + mw(next).ServeHTTP(w, r) + } else { + next.ServeHTTP(w, r) + } + }) + } +} diff --git a/testdata/chi/middleware/middleware.go b/testdata/chi/middleware/middleware.go new file mode 100644 index 0000000..cc371e0 --- /dev/null +++ b/testdata/chi/middleware/middleware.go @@ -0,0 +1,23 @@ +package middleware + +import "net/http" + +// New will create a new middleware handler from a http.Handler. +func New(h http.Handler) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h.ServeHTTP(w, r) + }) + } +} + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. This technique +// for defining context keys was copied from Go 1.7's new use of context in net/http. +type contextKey struct { + name string +} + +func (k *contextKey) String() string { + return "chi/middleware context value " + k.name +} diff --git a/testdata/chi/middleware/nocache.go b/testdata/chi/middleware/nocache.go new file mode 100644 index 0000000..9308d40 --- /dev/null +++ b/testdata/chi/middleware/nocache.go @@ -0,0 +1,59 @@ +package middleware + +// Ported from Goji's middleware, source: +// https://github.com/zenazn/goji/tree/master/web/middleware + +import ( + "net/http" + "time" +) + +// Unix epoch time +var epoch = time.Unix(0, 0).UTC().Format(http.TimeFormat) + +// Taken from https://github.com/mytrile/nocache +var noCacheHeaders = map[string]string{ + "Expires": epoch, + "Cache-Control": "no-cache, no-store, no-transform, must-revalidate, private, max-age=0", + "Pragma": "no-cache", + "X-Accel-Expires": "0", +} + +var etagHeaders = []string{ + "ETag", + "If-Modified-Since", + "If-Match", + "If-None-Match", + "If-Range", + "If-Unmodified-Since", +} + +// NoCache is a simple piece of middleware that sets a number of HTTP headers to prevent +// a router (or subrouter) from being cached by an upstream proxy and/or client. +// +// As per http://wiki.nginx.org/HttpProxyModule - NoCache sets: +// +// Expires: Thu, 01 Jan 1970 00:00:00 UTC +// Cache-Control: no-cache, private, max-age=0 +// X-Accel-Expires: 0 +// Pragma: no-cache (for HTTP/1.0 proxies/clients) +func NoCache(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + + // Delete any ETag headers that may have been set + for _, v := range etagHeaders { + if r.Header.Get(v) != "" { + r.Header.Del(v) + } + } + + // Set our NoCache headers + for k, v := range noCacheHeaders { + w.Header().Set(k, v) + } + + h.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) +} diff --git a/testdata/chi/middleware/page_route.go b/testdata/chi/middleware/page_route.go new file mode 100644 index 0000000..32871b7 --- /dev/null +++ b/testdata/chi/middleware/page_route.go @@ -0,0 +1,20 @@ +package middleware + +import ( + "net/http" + "strings" +) + +// PageRoute is a simple middleware which allows you to route a static GET request +// at the middleware stack level. +func PageRoute(path string, handler http.Handler) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" && strings.EqualFold(r.URL.Path, path) { + handler.ServeHTTP(w, r) + return + } + next.ServeHTTP(w, r) + }) + } +} diff --git a/testdata/chi/middleware/path_rewrite.go b/testdata/chi/middleware/path_rewrite.go new file mode 100644 index 0000000..99af62c --- /dev/null +++ b/testdata/chi/middleware/path_rewrite.go @@ -0,0 +1,16 @@ +package middleware + +import ( + "net/http" + "strings" +) + +// PathRewrite is a simple middleware which allows you to rewrite the request URL path. +func PathRewrite(old, new string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.URL.Path = strings.Replace(r.URL.Path, old, new, 1) + next.ServeHTTP(w, r) + }) + } +} diff --git a/testdata/chi/middleware/profiler.go b/testdata/chi/middleware/profiler.go new file mode 100644 index 0000000..0ad6a99 --- /dev/null +++ b/testdata/chi/middleware/profiler.go @@ -0,0 +1,49 @@ +//go:build !tinygo +// +build !tinygo + +package middleware + +import ( + "expvar" + "net/http" + "net/http/pprof" + + "github.com/go-chi/chi/v5" +) + +// Profiler is a convenient subrouter used for mounting net/http/pprof. ie. +// +// func MyService() http.Handler { +// r := chi.NewRouter() +// // ..middlewares +// r.Mount("/debug", middleware.Profiler()) +// // ..routes +// return r +// } +func Profiler() http.Handler { + r := chi.NewRouter() + r.Use(NoCache) + + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, r.RequestURI+"/pprof/", http.StatusMovedPermanently) + }) + r.HandleFunc("/pprof", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, r.RequestURI+"/", http.StatusMovedPermanently) + }) + + r.HandleFunc("/pprof/*", pprof.Index) + r.HandleFunc("/pprof/cmdline", pprof.Cmdline) + r.HandleFunc("/pprof/profile", pprof.Profile) + r.HandleFunc("/pprof/symbol", pprof.Symbol) + r.HandleFunc("/pprof/trace", pprof.Trace) + r.Handle("/vars", expvar.Handler()) + + r.Handle("/pprof/goroutine", pprof.Handler("goroutine")) + r.Handle("/pprof/threadcreate", pprof.Handler("threadcreate")) + r.Handle("/pprof/mutex", pprof.Handler("mutex")) + r.Handle("/pprof/heap", pprof.Handler("heap")) + r.Handle("/pprof/block", pprof.Handler("block")) + r.Handle("/pprof/allocs", pprof.Handler("allocs")) + + return r +} diff --git a/testdata/chi/middleware/realip.go b/testdata/chi/middleware/realip.go new file mode 100644 index 0000000..349f168 --- /dev/null +++ b/testdata/chi/middleware/realip.go @@ -0,0 +1,53 @@ +package middleware + +// Ported from Goji's middleware, source: +// https://github.com/zenazn/goji/tree/master/web/middleware + +import ( + "net" + "net/http" + "strings" +) + +var trueClientIP = http.CanonicalHeaderKey("True-Client-IP") +var xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For") +var xRealIP = http.CanonicalHeaderKey("X-Real-IP") + +// RealIP is a middleware that sets a http.Request's RemoteAddr to the results +// of parsing either the True-Client-IP, X-Real-IP or the X-Forwarded-For headers +// (in that order). +// +// Deprecated: RealIP is vulnerable to IP spoofing — it mutates r.RemoteAddr +// to the leftmost X-Forwarded-For value, or to True-Client-IP / X-Real-IP +// whether or not your infrastructure actually sets them. See +// GHSA-3fxj-6jh8-hvhx, GHSA-rjr7-jggh-pgcp, GHSA-9g5q-2w5x-hmxf. +// +// Use [ClientIPFromHeader], [ClientIPFromXFF], [ClientIPFromXFFTrustedProxies] +// or [ClientIPFromRemoteAddr] and read the IP with [GetClientIP] instead. +// These never mutate r.RemoteAddr. +func RealIP(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + if rip := realIP(r); rip != "" { + r.RemoteAddr = rip + } + h.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) +} + +func realIP(r *http.Request) string { + var ip string + + if tcip := r.Header.Get(trueClientIP); tcip != "" { + ip = tcip + } else if xrip := r.Header.Get(xRealIP); xrip != "" { + ip = xrip + } else if xff := r.Header.Get(xForwardedFor); xff != "" { + ip, _, _ = strings.Cut(xff, ",") + } + if ip == "" || net.ParseIP(ip) == nil { + return "" + } + return ip +} diff --git a/testdata/chi/middleware/recoverer.go b/testdata/chi/middleware/recoverer.go new file mode 100644 index 0000000..81342df --- /dev/null +++ b/testdata/chi/middleware/recoverer.go @@ -0,0 +1,203 @@ +package middleware + +// The original work was derived from Goji's middleware, source: +// https://github.com/zenazn/goji/tree/master/web/middleware + +import ( + "bytes" + "errors" + "fmt" + "io" + "net/http" + "os" + "runtime/debug" + "strings" +) + +// Recoverer is a middleware that recovers from panics, logs the panic (and a +// backtrace), and returns a HTTP 500 (Internal Server Error) status if +// possible. Recoverer prints a request ID if one is provided. +// +// Alternatively, look at https://github.com/go-chi/httplog middleware pkgs. +func Recoverer(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + defer func() { + if rvr := recover(); rvr != nil { + if rvr == http.ErrAbortHandler { + // we don't recover http.ErrAbortHandler so the response + // to the client is aborted, this should not be logged + panic(rvr) + } + + logEntry := GetLogEntry(r) + if logEntry != nil { + logEntry.Panic(rvr, debug.Stack()) + } else { + PrintPrettyStack(rvr) + } + + if r.Header.Get("Connection") != "Upgrade" { + w.WriteHeader(http.StatusInternalServerError) + } + } + }() + + next.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) +} + +// for ability to test the PrintPrettyStack function +var recovererErrorWriter io.Writer = os.Stderr + +func PrintPrettyStack(rvr interface{}) { + debugStack := debug.Stack() + s := prettyStack{} + out, err := s.parse(debugStack, rvr) + if err == nil { + recovererErrorWriter.Write(out) + } else { + // print stdlib output as a fallback + os.Stderr.Write(debugStack) + } +} + +type prettyStack struct { +} + +func (s prettyStack) parse(debugStack []byte, rvr interface{}) ([]byte, error) { + var err error + useColor := true + buf := &bytes.Buffer{} + + cW(buf, false, bRed, "\n") + cW(buf, useColor, bCyan, " panic: ") + cW(buf, useColor, bBlue, "%v", rvr) + cW(buf, false, bWhite, "\n \n") + + // process debug stack info + stack := strings.Split(string(debugStack), "\n") + lines := []string{} + + // locate panic line, as we may have nested panics + for i := len(stack) - 1; i > 0; i-- { + lines = append(lines, stack[i]) + if strings.HasPrefix(stack[i], "panic(") { + lines = lines[0 : len(lines)-2] // remove boilerplate + break + } + } + + // reverse + for i := len(lines)/2 - 1; i >= 0; i-- { + opp := len(lines) - 1 - i + lines[i], lines[opp] = lines[opp], lines[i] + } + + // decorate + for i, line := range lines { + lines[i], err = s.decorateLine(line, useColor, i) + if err != nil { + return nil, err + } + } + + for _, l := range lines { + fmt.Fprintf(buf, "%s", l) + } + return buf.Bytes(), nil +} + +func (s prettyStack) decorateLine(line string, useColor bool, num int) (string, error) { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "\t") || strings.Contains(line, ".go:") { + return s.decorateSourceLine(line, useColor, num) + } + if strings.HasSuffix(line, ")") { + return s.decorateFuncCallLine(line, useColor, num) + } + if strings.HasPrefix(line, "\t") { + return strings.Replace(line, "\t", " ", 1), nil + } + return fmt.Sprintf(" %s\n", line), nil +} + +func (s prettyStack) decorateFuncCallLine(line string, useColor bool, num int) (string, error) { + idx := strings.LastIndex(line, "(") + if idx < 0 { + return "", errors.New("not a func call line") + } + + buf := &bytes.Buffer{} + pkg := line[0:idx] + // addr := line[idx:] + method := "" + + if idx := strings.LastIndex(pkg, string(os.PathSeparator)); idx < 0 { + if idx := strings.Index(pkg, "."); idx > 0 { + method = pkg[idx:] + pkg = pkg[0:idx] + } + } else { + method = pkg[idx+1:] + pkg = pkg[0 : idx+1] + if idx := strings.Index(method, "."); idx > 0 { + pkg += method[0:idx] + method = method[idx:] + } + } + pkgColor := nYellow + methodColor := bGreen + + if num == 0 { + cW(buf, useColor, bRed, " -> ") + pkgColor = bMagenta + methodColor = bRed + } else { + cW(buf, useColor, bWhite, " ") + } + cW(buf, useColor, pkgColor, "%s", pkg) + cW(buf, useColor, methodColor, "%s\n", method) + // cW(buf, useColor, nBlack, "%s", addr) + return buf.String(), nil +} + +func (s prettyStack) decorateSourceLine(line string, useColor bool, num int) (string, error) { + idx := strings.LastIndex(line, ".go:") + if idx < 0 { + return "", errors.New("not a source line") + } + + buf := &bytes.Buffer{} + path := line[0 : idx+3] + lineno := line[idx+3:] + + idx = strings.LastIndex(path, string(os.PathSeparator)) + dir := path[0 : idx+1] + file := path[idx+1:] + + idx = strings.Index(lineno, " ") + if idx > 0 { + lineno = lineno[0:idx] + } + fileColor := bCyan + lineColor := bGreen + + if num == 1 { + cW(buf, useColor, bRed, " -> ") + fileColor = bRed + lineColor = bMagenta + } else { + cW(buf, false, bWhite, " ") + } + cW(buf, useColor, bWhite, "%s", dir) + cW(buf, useColor, fileColor, "%s", file) + cW(buf, useColor, lineColor, "%s", lineno) + if num == 1 { + cW(buf, false, bWhite, "\n") + } + cW(buf, false, bWhite, "\n") + + return buf.String(), nil +} diff --git a/testdata/chi/middleware/request_id.go b/testdata/chi/middleware/request_id.go new file mode 100644 index 0000000..e1d4ccb --- /dev/null +++ b/testdata/chi/middleware/request_id.go @@ -0,0 +1,96 @@ +package middleware + +// Ported from Goji's middleware, source: +// https://github.com/zenazn/goji/tree/master/web/middleware + +import ( + "context" + "crypto/rand" + "encoding/base64" + "fmt" + "net/http" + "os" + "strings" + "sync/atomic" +) + +// Key to use when setting the request ID. +type ctxKeyRequestID int + +// RequestIDKey is the key that holds the unique request ID in a request context. +const RequestIDKey ctxKeyRequestID = 0 + +// RequestIDHeader is the name of the HTTP Header which contains the request id. +// Exported so that it can be changed by developers +var RequestIDHeader = "X-Request-Id" + +var prefix string +var reqid atomic.Uint64 + +// A quick note on the statistics here: we're trying to calculate the chance that +// two randomly generated base62 prefixes will collide. We use the formula from +// http://en.wikipedia.org/wiki/Birthday_problem +// +// P[m, n] \approx 1 - e^{-m^2/2n} +// +// We ballpark an upper bound for $m$ by imagining (for whatever reason) a server +// that restarts every second over 10 years, for $m = 86400 * 365 * 10 = 315360000$ +// +// For a $k$ character base-62 identifier, we have $n(k) = 62^k$ +// +// Plugging this in, we find $P[m, n(10)] \approx 5.75%$, which is good enough for +// our purposes, and is surely more than anyone would ever need in practice -- a +// process that is rebooted a handful of times a day for a hundred years has less +// than a millionth of a percent chance of generating two colliding IDs. + +func init() { + hostname, err := os.Hostname() + if hostname == "" || err != nil { + hostname = "localhost" + } + var buf [12]byte + var b64 string + for len(b64) < 10 { + rand.Read(buf[:]) + b64 = base64.StdEncoding.EncodeToString(buf[:]) + b64 = strings.NewReplacer("+", "", "/", "").Replace(b64) + } + + prefix = fmt.Sprintf("%s/%s", hostname, b64[0:10]) +} + +// RequestID is a middleware that injects a request ID into the context of each +// request. A request ID is a string of the form "host.example.com/random-0001", +// where "random" is a base62 random string that uniquely identifies this go +// process, and where the last number is an atomically incremented request +// counter. +func RequestID(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + requestID := r.Header.Get(RequestIDHeader) + if requestID == "" { + myid := reqid.Add(1) + requestID = fmt.Sprintf("%s-%06d", prefix, myid) + } + ctx = context.WithValue(ctx, RequestIDKey, requestID) + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(fn) +} + +// GetReqID returns a request ID from the given context if one is present. +// Returns the empty string if a request ID cannot be found. +func GetReqID(ctx context.Context) string { + if ctx == nil { + return "" + } + if reqID, ok := ctx.Value(RequestIDKey).(string); ok { + return reqID + } + return "" +} + +// NextRequestID generates the next request ID in the sequence. +func NextRequestID() uint64 { + return reqid.Add(1) +} diff --git a/testdata/chi/middleware/request_size.go b/testdata/chi/middleware/request_size.go new file mode 100644 index 0000000..678248c --- /dev/null +++ b/testdata/chi/middleware/request_size.go @@ -0,0 +1,18 @@ +package middleware + +import ( + "net/http" +) + +// RequestSize is a middleware that will limit request sizes to a specified +// number of bytes. It uses MaxBytesReader to do so. +func RequestSize(bytes int64) func(http.Handler) http.Handler { + f := func(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, bytes) + h.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } + return f +} diff --git a/testdata/chi/middleware/route_headers.go b/testdata/chi/middleware/route_headers.go new file mode 100644 index 0000000..1c3334d --- /dev/null +++ b/testdata/chi/middleware/route_headers.go @@ -0,0 +1,146 @@ +package middleware + +import ( + "net/http" + "strings" +) + +// RouteHeaders is a neat little header-based router that allows you to direct +// the flow of a request through a middleware stack based on a request header. +// +// For example, lets say you'd like to setup multiple routers depending on the +// request Host header, you could then do something as so: +// +// r := chi.NewRouter() +// rSubdomain := chi.NewRouter() +// r.Use(middleware.RouteHeaders(). +// Route("Host", "example.com", middleware.New(r)). +// Route("Host", "*.example.com", middleware.New(rSubdomain)). +// Handler) +// r.Get("/", h) +// rSubdomain.Get("/", h2) +// +// Another example, imagine you want to setup multiple CORS handlers, where for +// your origin servers you allow authorized requests, but for third-party public +// requests, authorization is disabled. +// +// r := chi.NewRouter() +// r.Use(middleware.RouteHeaders(). +// Route("Origin", "https://app.skyweaver.net", cors.Handler(cors.Options{ +// AllowedOrigins: []string{"https://api.skyweaver.net"}, +// AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, +// AllowedHeaders: []string{"Accept", "Authorization", "Content-Type"}, +// AllowCredentials: true, // <----------<<< allow credentials +// })). +// Route("Origin", "*", cors.Handler(cors.Options{ +// AllowedOrigins: []string{"*"}, +// AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, +// AllowedHeaders: []string{"Accept", "Content-Type"}, +// AllowCredentials: false, // <----------<<< do not allow credentials +// })). +// Handler) +func RouteHeaders() HeaderRouter { + return HeaderRouter{} +} + +type HeaderRouter map[string][]HeaderRoute + +func (hr HeaderRouter) Route(header, match string, middlewareHandler func(next http.Handler) http.Handler) HeaderRouter { + header = strings.ToLower(header) + k := hr[header] + if k == nil { + hr[header] = []HeaderRoute{} + } + hr[header] = append(hr[header], HeaderRoute{MatchOne: NewPattern(match), Middleware: middlewareHandler}) + return hr +} + +func (hr HeaderRouter) RouteAny(header string, match []string, middlewareHandler func(next http.Handler) http.Handler) HeaderRouter { + header = strings.ToLower(header) + k := hr[header] + if k == nil { + hr[header] = []HeaderRoute{} + } + patterns := []Pattern{} + for _, m := range match { + patterns = append(patterns, NewPattern(m)) + } + hr[header] = append(hr[header], HeaderRoute{MatchAny: patterns, Middleware: middlewareHandler}) + return hr +} + +func (hr HeaderRouter) RouteDefault(handler func(next http.Handler) http.Handler) HeaderRouter { + hr["*"] = []HeaderRoute{{Middleware: handler}} + return hr +} + +func (hr HeaderRouter) Handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if len(hr) == 0 { + // skip if no routes set + next.ServeHTTP(w, r) + return + } + + // find first matching header route, and continue + for header, matchers := range hr { + headerValue := r.Header.Get(header) + if headerValue == "" { + continue + } + headerValue = strings.ToLower(headerValue) + for _, matcher := range matchers { + if matcher.IsMatch(headerValue) { + matcher.Middleware(next).ServeHTTP(w, r) + return + } + } + } + + // if no match, check for "*" default route + matcher, ok := hr["*"] + if !ok || matcher[0].Middleware == nil { + next.ServeHTTP(w, r) + return + } + matcher[0].Middleware(next).ServeHTTP(w, r) + }) +} + +type HeaderRoute struct { + Middleware func(next http.Handler) http.Handler + MatchOne Pattern + MatchAny []Pattern +} + +func (r HeaderRoute) IsMatch(value string) bool { + if len(r.MatchAny) > 0 { + for _, m := range r.MatchAny { + if m.Match(value) { + return true + } + } + } else if r.MatchOne.Match(value) { + return true + } + return false +} + +type Pattern struct { + prefix string + suffix string + wildcard bool +} + +func NewPattern(value string) Pattern { + p := Pattern{} + p.prefix, p.suffix, p.wildcard = strings.Cut(value, "*") + return p +} + +func (p Pattern) Match(v string) bool { + if !p.wildcard { + return p.prefix == v + } + return len(v) >= len(p.prefix+p.suffix) && strings.HasPrefix(v, p.prefix) && strings.HasSuffix(v, p.suffix) +} diff --git a/testdata/chi/middleware/strip.go b/testdata/chi/middleware/strip.go new file mode 100644 index 0000000..32d21e9 --- /dev/null +++ b/testdata/chi/middleware/strip.go @@ -0,0 +1,77 @@ +package middleware + +import ( + "fmt" + "net/http" + "strings" + + "github.com/go-chi/chi/v5" +) + +// StripSlashes is a middleware that will match request paths with a trailing +// slash, strip it from the path and continue routing through the mux, if a route +// matches, then it will serve the handler. +func StripSlashes(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + var path string + rctx := chi.RouteContext(r.Context()) + if rctx != nil && rctx.RoutePath != "" { + path = rctx.RoutePath + } else { + path = r.URL.Path + } + if len(path) > 1 && path[len(path)-1] == '/' { + newPath := path[:len(path)-1] + if rctx == nil { + r.URL.Path = newPath + } else { + rctx.RoutePath = newPath + } + } + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) +} + +// RedirectSlashes is a middleware that will match request paths with a trailing +// slash and redirect to the same path, less the trailing slash. +// +// NOTE: RedirectSlashes middleware is *incompatible* with http.FileServer, +// see https://github.com/go-chi/chi/issues/343 +func RedirectSlashes(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + var path string + rctx := chi.RouteContext(r.Context()) + if rctx != nil && rctx.RoutePath != "" { + path = rctx.RoutePath + } else { + path = r.URL.Path + } + + if len(path) > 1 && path[len(path)-1] == '/' { + // Normalize backslashes to forward slashes to prevent "/\evil.com" style redirects + // that some clients may interpret as protocol-relative. + path = strings.ReplaceAll(path, `\`, `/`) + + // Collapse leading/trailing slashes and force a single leading slash. + path := "/" + strings.Trim(path, "/") + + if r.URL.RawQuery != "" { + path = fmt.Sprintf("%s?%s", path, r.URL.RawQuery) + } + http.Redirect(w, r, path, 301) + return + } + + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) +} + +// StripPrefix is a middleware that will strip the provided prefix from the +// request path before handing the request over to the next handler. +func StripPrefix(prefix string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.StripPrefix(prefix, next) + } +} diff --git a/testdata/chi/middleware/sunset.go b/testdata/chi/middleware/sunset.go new file mode 100644 index 0000000..18815d5 --- /dev/null +++ b/testdata/chi/middleware/sunset.go @@ -0,0 +1,25 @@ +package middleware + +import ( + "net/http" + "time" +) + +// Sunset set Deprecation/Sunset header to response +// This can be used to enable Sunset in a route or a route group +// For more: https://www.rfc-editor.org/rfc/rfc8594.html +func Sunset(sunsetAt time.Time, links ...string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !sunsetAt.IsZero() { + w.Header().Set("Sunset", sunsetAt.Format(http.TimeFormat)) + w.Header().Set("Deprecation", sunsetAt.Format(http.TimeFormat)) + + for _, link := range links { + w.Header().Add("Link", link) + } + } + next.ServeHTTP(w, r) + }) + } +} diff --git a/testdata/chi/middleware/supress_notfound.go b/testdata/chi/middleware/supress_notfound.go new file mode 100644 index 0000000..83a8a87 --- /dev/null +++ b/testdata/chi/middleware/supress_notfound.go @@ -0,0 +1,27 @@ +package middleware + +import ( + "net/http" + + "github.com/go-chi/chi/v5" +) + +// SupressNotFound will quickly respond with a 404 if the route is not found +// and will not continue to the next middleware handler. +// +// This is handy to put at the top of your middleware stack to avoid unnecessary +// processing of requests that are not going to match any routes anyway. For +// example its super annoying to see a bunch of 404's in your logs from bots. +func SupressNotFound(router *chi.Mux) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rctx := chi.RouteContext(r.Context()) + match := rctx.Routes.Match(rctx, r.Method, r.URL.Path) + if !match { + router.NotFoundHandler().ServeHTTP(w, r) + return + } + next.ServeHTTP(w, r) + }) + } +} diff --git a/testdata/chi/middleware/terminal.go b/testdata/chi/middleware/terminal.go new file mode 100644 index 0000000..5ead7b9 --- /dev/null +++ b/testdata/chi/middleware/terminal.go @@ -0,0 +1,63 @@ +package middleware + +// Ported from Goji's middleware, source: +// https://github.com/zenazn/goji/tree/master/web/middleware + +import ( + "fmt" + "io" + "os" +) + +var ( + // Normal colors + nBlack = []byte{'\033', '[', '3', '0', 'm'} + nRed = []byte{'\033', '[', '3', '1', 'm'} + nGreen = []byte{'\033', '[', '3', '2', 'm'} + nYellow = []byte{'\033', '[', '3', '3', 'm'} + nBlue = []byte{'\033', '[', '3', '4', 'm'} + nMagenta = []byte{'\033', '[', '3', '5', 'm'} + nCyan = []byte{'\033', '[', '3', '6', 'm'} + nWhite = []byte{'\033', '[', '3', '7', 'm'} + // Bright colors + bBlack = []byte{'\033', '[', '3', '0', ';', '1', 'm'} + bRed = []byte{'\033', '[', '3', '1', ';', '1', 'm'} + bGreen = []byte{'\033', '[', '3', '2', ';', '1', 'm'} + bYellow = []byte{'\033', '[', '3', '3', ';', '1', 'm'} + bBlue = []byte{'\033', '[', '3', '4', ';', '1', 'm'} + bMagenta = []byte{'\033', '[', '3', '5', ';', '1', 'm'} + bCyan = []byte{'\033', '[', '3', '6', ';', '1', 'm'} + bWhite = []byte{'\033', '[', '3', '7', ';', '1', 'm'} + + reset = []byte{'\033', '[', '0', 'm'} +) + +var IsTTY bool + +func init() { + // This is sort of cheating: if stdout is a character device, we assume + // that means it's a TTY. Unfortunately, there are many non-TTY + // character devices, but fortunately stdout is rarely set to any of + // them. + // + // We could solve this properly by pulling in a dependency on + // code.google.com/p/go.crypto/ssh/terminal, for instance, but as a + // heuristic for whether to print in color or in black-and-white, I'd + // really rather not. + fi, err := os.Stdout.Stat() + if err == nil { + m := os.ModeDevice | os.ModeCharDevice + IsTTY = fi.Mode()&m == m + } +} + +// colorWrite +func cW(w io.Writer, useColor bool, color []byte, s string, args ...interface{}) { + if IsTTY && useColor { + w.Write(color) + } + fmt.Fprintf(w, s, args...) + if IsTTY && useColor { + w.Write(reset) + } +} diff --git a/testdata/chi/middleware/throttle.go b/testdata/chi/middleware/throttle.go new file mode 100644 index 0000000..7ea482b --- /dev/null +++ b/testdata/chi/middleware/throttle.go @@ -0,0 +1,151 @@ +package middleware + +import ( + "net/http" + "strconv" + "time" +) + +const ( + errCapacityExceeded = "Server capacity exceeded." + errTimedOut = "Timed out while waiting for a pending request to complete." + errContextCanceled = "Context was canceled." +) + +var ( + defaultBacklogTimeout = time.Second * 60 +) + +// ThrottleOpts represents a set of throttling options. +type ThrottleOpts struct { + RetryAfterFn func(ctxDone bool) time.Duration + Limit int + BacklogLimit int + BacklogTimeout time.Duration + StatusCode int +} + +// Throttle is a middleware that limits number of currently processed requests +// at a time across all users. Note: Throttle is not a rate-limiter per user, +// instead it just puts a ceiling on the number of current in-flight requests +// being processed from the point from where the Throttle middleware is mounted. +func Throttle(limit int) func(http.Handler) http.Handler { + return ThrottleWithOpts(ThrottleOpts{Limit: limit, BacklogTimeout: defaultBacklogTimeout}) +} + +// ThrottleBacklog is a middleware that limits number of currently processed +// requests at a time and provides a backlog for holding a finite number of +// pending requests. +func ThrottleBacklog(limit, backlogLimit int, backlogTimeout time.Duration) func(http.Handler) http.Handler { + return ThrottleWithOpts(ThrottleOpts{Limit: limit, BacklogLimit: backlogLimit, BacklogTimeout: backlogTimeout}) +} + +// ThrottleWithOpts is a middleware that limits number of currently processed requests using passed ThrottleOpts. +func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler { + if opts.Limit < 1 { + panic("chi/middleware: Throttle expects limit > 0") + } + + if opts.BacklogLimit < 0 { + panic("chi/middleware: Throttle expects backlogLimit to be positive") + } + + statusCode := opts.StatusCode + if statusCode == 0 { + statusCode = http.StatusTooManyRequests + } + + t := throttler{ + tokens: make(chan token, opts.Limit), + backlogTokens: make(chan token, opts.Limit+opts.BacklogLimit), + backlogTimeout: opts.BacklogTimeout, + statusCode: statusCode, + retryAfterFn: opts.RetryAfterFn, + } + + // Filling tokens. + for i := 0; i < opts.Limit+opts.BacklogLimit; i++ { + if i < opts.Limit { + t.tokens <- token{} + } + t.backlogTokens <- token{} + } + + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + select { + + case <-ctx.Done(): + t.setRetryAfterHeaderIfNeeded(w, true) + http.Error(w, errContextCanceled, t.statusCode) + return + + case btok := <-t.backlogTokens: + defer func() { + t.backlogTokens <- btok + }() + + // Try to get a processing token immediately first + select { + case tok := <-t.tokens: + defer func() { + t.tokens <- tok + }() + next.ServeHTTP(w, r) + return + default: + // No immediate token available, need to wait with timer + } + + timer := time.NewTimer(t.backlogTimeout) + select { + case <-timer.C: + t.setRetryAfterHeaderIfNeeded(w, false) + http.Error(w, errTimedOut, t.statusCode) + return + case <-ctx.Done(): + timer.Stop() + t.setRetryAfterHeaderIfNeeded(w, true) + http.Error(w, errContextCanceled, t.statusCode) + return + case tok := <-t.tokens: + defer func() { + timer.Stop() + t.tokens <- tok + }() + next.ServeHTTP(w, r) + } + return + + default: + t.setRetryAfterHeaderIfNeeded(w, false) + http.Error(w, errCapacityExceeded, t.statusCode) + return + } + } + + return http.HandlerFunc(fn) + } +} + +// token represents a request that is being processed. +type token struct{} + +// throttler limits number of currently processed requests at a time. +type throttler struct { + tokens chan token + backlogTokens chan token + retryAfterFn func(ctxDone bool) time.Duration + backlogTimeout time.Duration + statusCode int +} + +// setRetryAfterHeaderIfNeeded sets Retry-After HTTP header if corresponding retryAfterFn option of throttler is initialized. +func (t throttler) setRetryAfterHeaderIfNeeded(w http.ResponseWriter, ctxDone bool) { + if t.retryAfterFn == nil { + return + } + w.Header().Set("Retry-After", strconv.Itoa(int(t.retryAfterFn(ctxDone).Seconds()))) +} diff --git a/testdata/chi/middleware/timeout.go b/testdata/chi/middleware/timeout.go new file mode 100644 index 0000000..add596d --- /dev/null +++ b/testdata/chi/middleware/timeout.go @@ -0,0 +1,48 @@ +package middleware + +import ( + "context" + "net/http" + "time" +) + +// Timeout is a middleware that cancels ctx after a given timeout and return +// a 504 Gateway Timeout error to the client. +// +// It's required that you select the ctx.Done() channel to check for the signal +// if the context has reached its deadline and return, otherwise the timeout +// signal will be just ignored. +// +// ie. a route/handler may look like: +// +// r.Get("/long", func(w http.ResponseWriter, r *http.Request) { +// ctx := r.Context() +// processTime := time.Duration(rand.Intn(4)+1) * time.Second +// +// select { +// case <-ctx.Done(): +// return +// +// case <-time.After(processTime): +// // The above channel simulates some hard work. +// } +// +// w.Write([]byte("done")) +// }) +func Timeout(timeout time.Duration) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), timeout) + defer func() { + cancel() + if ctx.Err() == context.DeadlineExceeded { + w.WriteHeader(http.StatusGatewayTimeout) + } + }() + + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } +} diff --git a/testdata/chi/middleware/url_format.go b/testdata/chi/middleware/url_format.go new file mode 100644 index 0000000..2ec6657 --- /dev/null +++ b/testdata/chi/middleware/url_format.go @@ -0,0 +1,77 @@ +package middleware + +import ( + "context" + "net/http" + "strings" + + "github.com/go-chi/chi/v5" +) + +var ( + // URLFormatCtxKey is the context.Context key to store the URL format data + // for a request. + URLFormatCtxKey = &contextKey{"URLFormat"} +) + +// URLFormat is a middleware that parses the url extension from a request path and stores it +// on the context as a string under the key `middleware.URLFormatCtxKey`. The middleware will +// trim the suffix from the routing path and continue routing. +// +// Routers should not include a url parameter for the suffix when using this middleware. +// +// Sample usage for url paths `/articles/1`, `/articles/1.json` and `/articles/1.xml`: +// +// func routes() http.Handler { +// r := chi.NewRouter() +// r.Use(middleware.URLFormat) +// +// r.Get("/articles/{id}", ListArticles) +// +// return r +// } +// +// func ListArticles(w http.ResponseWriter, r *http.Request) { +// urlFormat, _ := r.Context().Value(middleware.URLFormatCtxKey).(string) +// +// switch urlFormat { +// case "json": +// render.JSON(w, r, articles) +// case "xml:" +// render.XML(w, r, articles) +// default: +// render.JSON(w, r, articles) +// } +// } +func URLFormat(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var format string + path := r.URL.Path + + rctx := chi.RouteContext(r.Context()) + if rctx != nil && rctx.RoutePath != "" { + path = rctx.RoutePath + } + + if strings.Index(path, ".") > 0 { + base := strings.LastIndex(path, "/") + idx := strings.LastIndex(path[base:], ".") + + if idx > 0 { + idx += base + format = path[idx+1:] + + if rctx != nil { + rctx.RoutePath = path[:idx] + } + } + } + + r = r.WithContext(context.WithValue(ctx, URLFormatCtxKey, format)) + + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) +} diff --git a/testdata/chi/middleware/value.go b/testdata/chi/middleware/value.go new file mode 100644 index 0000000..a9dfd43 --- /dev/null +++ b/testdata/chi/middleware/value.go @@ -0,0 +1,17 @@ +package middleware + +import ( + "context" + "net/http" +) + +// WithValue is a middleware that sets a given key/value in a context chain. +func WithValue(key, val interface{}) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + r = r.WithContext(context.WithValue(r.Context(), key, val)) + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } +} diff --git a/testdata/chi/middleware/wrap_writer.go b/testdata/chi/middleware/wrap_writer.go new file mode 100644 index 0000000..b2de875 --- /dev/null +++ b/testdata/chi/middleware/wrap_writer.go @@ -0,0 +1,243 @@ +package middleware + +// The original work was derived from Goji's middleware, source: +// https://github.com/zenazn/goji/tree/master/web/middleware + +import ( + "bufio" + "io" + "net" + "net/http" +) + +// NewWrapResponseWriter wraps an http.ResponseWriter, returning a proxy that allows you to +// hook into various parts of the response process. +func NewWrapResponseWriter(w http.ResponseWriter, protoMajor int) WrapResponseWriter { + _, fl := w.(http.Flusher) + + bw := basicWriter{ResponseWriter: w} + + if protoMajor == 2 { + _, ps := w.(http.Pusher) + if fl && ps { + return &http2FancyWriter{bw} + } + } else { + _, hj := w.(http.Hijacker) + _, rf := w.(io.ReaderFrom) + if fl && hj && rf { + return &httpFancyWriter{bw} + } + if fl && hj { + return &flushHijackWriter{bw} + } + if hj { + return &hijackWriter{bw} + } + } + + if fl { + return &flushWriter{bw} + } + + return &bw +} + +// WrapResponseWriter is a proxy around an http.ResponseWriter that allows you to hook +// into various parts of the response process. +type WrapResponseWriter interface { + http.ResponseWriter + // Status returns the HTTP status of the request, or 0 if one has not + // yet been sent. + Status() int + // BytesWritten returns the total number of bytes sent to the client. + BytesWritten() int + // Tee causes the response body to be written to the given io.Writer in + // addition to proxying the writes through. Only one io.Writer can be + // tee'd to at once: setting a second one will overwrite the first. + // Writes will be sent to the proxy before being written to this + // io.Writer. It is illegal for the tee'd writer to be modified + // concurrently with writes. + Tee(io.Writer) + // Unwrap returns the original proxied target. + Unwrap() http.ResponseWriter + // Discard causes all writes to the original ResponseWriter be discarded, + // instead writing only to the tee'd writer if it's set. + // The caller is responsible for calling WriteHeader and Write on the + // original ResponseWriter once the processing is done. + Discard() +} + +// basicWriter wraps a http.ResponseWriter that implements the minimal +// http.ResponseWriter interface. +type basicWriter struct { + http.ResponseWriter + tee io.Writer + code int + bytes int + wroteHeader bool + discard bool +} + +func (b *basicWriter) WriteHeader(code int) { + if code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols { + if !b.discard { + b.ResponseWriter.WriteHeader(code) + } + } else if !b.wroteHeader { + b.code = code + b.wroteHeader = true + if !b.discard { + b.ResponseWriter.WriteHeader(code) + } + } +} + +func (b *basicWriter) Write(buf []byte) (n int, err error) { + b.maybeWriteHeader() + if !b.discard { + n, err = b.ResponseWriter.Write(buf) + if b.tee != nil { + _, err2 := b.tee.Write(buf[:n]) + // Prefer errors generated by the proxied writer. + if err == nil { + err = err2 + } + } + } else if b.tee != nil { + n, err = b.tee.Write(buf) + } else { + n, err = io.Discard.Write(buf) + } + b.bytes += n + return n, err +} + +func (b *basicWriter) maybeWriteHeader() { + if !b.wroteHeader { + b.WriteHeader(http.StatusOK) + } +} + +func (b *basicWriter) Status() int { + return b.code +} + +func (b *basicWriter) BytesWritten() int { + return b.bytes +} + +func (b *basicWriter) Tee(w io.Writer) { + b.tee = w +} + +func (b *basicWriter) Unwrap() http.ResponseWriter { + return b.ResponseWriter +} + +func (b *basicWriter) Discard() { + b.discard = true +} + +// flushWriter ... +type flushWriter struct { + basicWriter +} + +func (f *flushWriter) Flush() { + f.wroteHeader = true + fl := f.basicWriter.ResponseWriter.(http.Flusher) + fl.Flush() +} + +var _ http.Flusher = &flushWriter{} + +// hijackWriter ... +type hijackWriter struct { + basicWriter +} + +func (f *hijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj := f.basicWriter.ResponseWriter.(http.Hijacker) + return hj.Hijack() +} + +var _ http.Hijacker = &hijackWriter{} + +// flushHijackWriter ... +type flushHijackWriter struct { + basicWriter +} + +func (f *flushHijackWriter) Flush() { + f.wroteHeader = true + fl := f.basicWriter.ResponseWriter.(http.Flusher) + fl.Flush() +} + +func (f *flushHijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj := f.basicWriter.ResponseWriter.(http.Hijacker) + return hj.Hijack() +} + +var _ http.Flusher = &flushHijackWriter{} +var _ http.Hijacker = &flushHijackWriter{} + +// httpFancyWriter is a HTTP writer that additionally satisfies +// http.Flusher, http.Hijacker, and io.ReaderFrom. It exists for the common case +// of wrapping the http.ResponseWriter that package http gives you, in order to +// make the proxied object support the full method set of the proxied object. +type httpFancyWriter struct { + basicWriter +} + +func (f *httpFancyWriter) Flush() { + f.wroteHeader = true + fl := f.basicWriter.ResponseWriter.(http.Flusher) + fl.Flush() +} + +func (f *httpFancyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj := f.basicWriter.ResponseWriter.(http.Hijacker) + return hj.Hijack() +} + +func (f *http2FancyWriter) Push(target string, opts *http.PushOptions) error { + return f.basicWriter.ResponseWriter.(http.Pusher).Push(target, opts) +} + +func (f *httpFancyWriter) ReadFrom(r io.Reader) (int64, error) { + if f.basicWriter.tee != nil { + // Route through basicWriter.Write so that data is also written to the + // tee writer. basicWriter.Write already increments basicWriter.bytes, + // so we must NOT add n again here (that would double-count). + n, err := io.Copy(&f.basicWriter, r) + return n, err + } + rf := f.basicWriter.ResponseWriter.(io.ReaderFrom) + f.basicWriter.maybeWriteHeader() + n, err := rf.ReadFrom(r) + f.basicWriter.bytes += int(n) + return n, err +} + +var _ http.Flusher = &httpFancyWriter{} +var _ http.Hijacker = &httpFancyWriter{} +var _ http.Pusher = &http2FancyWriter{} +var _ io.ReaderFrom = &httpFancyWriter{} + +// http2FancyWriter is a HTTP2 writer that additionally satisfies +// http.Flusher, and io.ReaderFrom. It exists for the common case +// of wrapping the http.ResponseWriter that package http gives you, in order to +// make the proxied object support the full method set of the proxied object. +type http2FancyWriter struct { + basicWriter +} + +func (f *http2FancyWriter) Flush() { + f.wroteHeader = true + fl := f.basicWriter.ResponseWriter.(http.Flusher) + fl.Flush() +} + +var _ http.Flusher = &http2FancyWriter{} diff --git a/testdata/chi/mux.go b/testdata/chi/mux.go new file mode 100644 index 0000000..3da7f3f --- /dev/null +++ b/testdata/chi/mux.go @@ -0,0 +1,526 @@ +package chi + +import ( + "context" + "fmt" + "net/http" + "strings" + "sync" +) + +var _ Router = &Mux{} + +// Mux is a simple HTTP route multiplexer that parses a request path, +// records any URL params, and executes an end handler. It implements +// the http.Handler interface and is friendly with the standard library. +// +// Mux is designed to be fast, minimal and offer a powerful API for building +// modular and composable HTTP services with a large set of handlers. It's +// particularly useful for writing large REST API services that break a handler +// into many smaller parts composed of middlewares and end handlers. +type Mux struct { + // The computed mux handler made of the chained middleware stack and + // the tree router + handler http.Handler + + // The radix trie router + tree *node + + // Custom method not allowed handler + methodNotAllowedHandler http.HandlerFunc + + // A reference to the parent mux used by subrouters when mounting + // to a parent mux + parent *Mux + + // Routing context pool + pool *sync.Pool + + // Custom route not found handler + notFoundHandler http.HandlerFunc + + // The middleware stack + middlewares []func(http.Handler) http.Handler + + // Controls the behaviour of middleware chain generation when a mux + // is registered as an inline group inside another mux. + inline bool +} + +// NewMux returns a newly initialized Mux object that implements the Router +// interface. +func NewMux() *Mux { + mux := &Mux{tree: &node{}, pool: &sync.Pool{}} + mux.pool.New = func() interface{} { + return NewRouteContext() + } + return mux +} + +// ServeHTTP is the single method of the http.Handler interface that makes +// Mux interoperable with the standard library. It uses a sync.Pool to get and +// reuse routing contexts for each request. +func (mx *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Ensure the mux has some routes defined on the mux + if mx.handler == nil { + mx.NotFoundHandler().ServeHTTP(w, r) + return + } + + // Check if a routing context already exists from a parent router. + rctx, _ := r.Context().Value(RouteCtxKey).(*Context) + if rctx != nil { + mx.handler.ServeHTTP(w, r) + return + } + + // Fetch a RouteContext object from the sync pool, and call the computed + // mx.handler that is comprised of mx.middlewares + mx.routeHTTP. + // Once the request is finished, reset the routing context and put it back + // into the pool for reuse from another request. + rctx = mx.pool.Get().(*Context) + rctx.Reset() + rctx.Routes = mx + rctx.parentCtx = r.Context() + + // NOTE: r.WithContext() causes 2 allocations and context.WithValue() causes 1 allocation + r = r.WithContext(context.WithValue(r.Context(), RouteCtxKey, rctx)) + + // Serve the request and once its done, put the request context back in the sync pool + mx.handler.ServeHTTP(w, r) + mx.pool.Put(rctx) +} + +// Use appends a middleware handler to the Mux middleware stack. +// +// The middleware stack for any Mux will execute before searching for a matching +// route to a specific handler, which provides opportunity to respond early, +// change the course of the request execution, or set request-scoped values for +// the next http.Handler. +func (mx *Mux) Use(middlewares ...func(http.Handler) http.Handler) { + if mx.handler != nil { + panic("chi: all middlewares must be defined before routes on a mux") + } + mx.middlewares = append(mx.middlewares, middlewares...) +} + +// Handle adds the route `pattern` that matches any http method to +// execute the `handler` http.Handler. +func (mx *Mux) Handle(pattern string, handler http.Handler) { + if i := strings.IndexAny(pattern, " \t"); i >= 0 { + method, rest := pattern[:i], strings.TrimLeft(pattern[i+1:], " \t") + mx.Method(method, rest, handler) + return + } + + mx.handle(mALL, pattern, handler) +} + +// HandleFunc adds the route `pattern` that matches any http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) HandleFunc(pattern string, handlerFn http.HandlerFunc) { + mx.Handle(pattern, handlerFn) +} + +// Method adds the route `pattern` that matches `method` http method to +// execute the `handler` http.Handler. +func (mx *Mux) Method(method, pattern string, handler http.Handler) { + m, ok := methodMap[strings.ToUpper(method)] + if !ok { + panic(fmt.Sprintf("chi: '%s' http method is not supported.", method)) + } + mx.handle(m, pattern, handler) +} + +// MethodFunc adds the route `pattern` that matches `method` http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) MethodFunc(method, pattern string, handlerFn http.HandlerFunc) { + mx.Method(method, pattern, handlerFn) +} + +// Connect adds the route `pattern` that matches a CONNECT http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) Connect(pattern string, handlerFn http.HandlerFunc) { + mx.handle(mCONNECT, pattern, handlerFn) +} + +// Delete adds the route `pattern` that matches a DELETE http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) Delete(pattern string, handlerFn http.HandlerFunc) { + mx.handle(mDELETE, pattern, handlerFn) +} + +// Get adds the route `pattern` that matches a GET http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) Get(pattern string, handlerFn http.HandlerFunc) { + mx.handle(mGET, pattern, handlerFn) +} + +// Head adds the route `pattern` that matches a HEAD http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) Head(pattern string, handlerFn http.HandlerFunc) { + mx.handle(mHEAD, pattern, handlerFn) +} + +// Options adds the route `pattern` that matches an OPTIONS http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) Options(pattern string, handlerFn http.HandlerFunc) { + mx.handle(mOPTIONS, pattern, handlerFn) +} + +// Patch adds the route `pattern` that matches a PATCH http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) Patch(pattern string, handlerFn http.HandlerFunc) { + mx.handle(mPATCH, pattern, handlerFn) +} + +// Post adds the route `pattern` that matches a POST http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) Post(pattern string, handlerFn http.HandlerFunc) { + mx.handle(mPOST, pattern, handlerFn) +} + +// Put adds the route `pattern` that matches a PUT http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) Put(pattern string, handlerFn http.HandlerFunc) { + mx.handle(mPUT, pattern, handlerFn) +} + +// Trace adds the route `pattern` that matches a TRACE http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) Trace(pattern string, handlerFn http.HandlerFunc) { + mx.handle(mTRACE, pattern, handlerFn) +} + +// NotFound sets a custom http.HandlerFunc for routing paths that could +// not be found. The default 404 handler is `http.NotFound`. +func (mx *Mux) NotFound(handlerFn http.HandlerFunc) { + // Build NotFound handler chain + m := mx + hFn := handlerFn + if mx.inline && mx.parent != nil { + m = mx.parent + hFn = Chain(mx.middlewares...).HandlerFunc(hFn).ServeHTTP + } + + // Update the notFoundHandler from this point forward + m.notFoundHandler = hFn + m.updateSubRoutes(func(subMux *Mux) { + if subMux.notFoundHandler == nil { + subMux.NotFound(hFn) + } + }) +} + +// MethodNotAllowed sets a custom http.HandlerFunc for routing paths where the +// method is unresolved. The default handler returns a 405 with an empty body. +func (mx *Mux) MethodNotAllowed(handlerFn http.HandlerFunc) { + // Build MethodNotAllowed handler chain + m := mx + hFn := handlerFn + if mx.inline && mx.parent != nil { + m = mx.parent + hFn = Chain(mx.middlewares...).HandlerFunc(hFn).ServeHTTP + } + + // Update the methodNotAllowedHandler from this point forward + m.methodNotAllowedHandler = hFn + m.updateSubRoutes(func(subMux *Mux) { + if subMux.methodNotAllowedHandler == nil { + subMux.MethodNotAllowed(hFn) + } + }) +} + +// With adds inline middlewares for an endpoint handler. +func (mx *Mux) With(middlewares ...func(http.Handler) http.Handler) Router { + // Similarly as in handle(), we must build the mux handler once additional + // middleware registration isn't allowed for this stack, like now. + if !mx.inline && mx.handler == nil { + mx.updateRouteHandler() + } + + // Copy middlewares from parent inline muxs + var mws Middlewares + if mx.inline { + mws = make(Middlewares, len(mx.middlewares)) + copy(mws, mx.middlewares) + } + mws = append(mws, middlewares...) + + im := &Mux{ + pool: mx.pool, inline: true, parent: mx, tree: mx.tree, middlewares: mws, + notFoundHandler: mx.notFoundHandler, methodNotAllowedHandler: mx.methodNotAllowedHandler, + } + + return im +} + +// Group creates a new inline-Mux with a copy of middleware stack. It's useful +// for a group of handlers along the same routing path that use an additional +// set of middlewares. See _examples/. +func (mx *Mux) Group(fn func(r Router)) Router { + im := mx.With() + if fn != nil { + fn(im) + } + return im +} + +// Route creates a new Mux and mounts it along the `pattern` as a subrouter. +// Effectively, this is a short-hand call to Mount. See _examples/. +func (mx *Mux) Route(pattern string, fn func(r Router)) Router { + if fn == nil { + panic(fmt.Sprintf("chi: attempting to Route() a nil subrouter on '%s'", pattern)) + } + subRouter := NewRouter() + fn(subRouter) + mx.Mount(pattern, subRouter) + return subRouter +} + +// Mount attaches another http.Handler or chi Router as a subrouter along a routing +// path. It's very useful to split up a large API as many independent routers and +// compose them as a single service using Mount. See _examples/. +// +// Note that Mount() simply sets a wildcard along the `pattern` that will continue +// routing at the `handler`, which in most cases is another chi.Router. As a result, +// if you define two Mount() routes on the exact same pattern the mount will panic. +func (mx *Mux) Mount(pattern string, handler http.Handler) { + if handler == nil { + panic(fmt.Sprintf("chi: attempting to Mount() a nil handler on '%s'", pattern)) + } + + // Provide runtime safety for ensuring a pattern isn't mounted on an existing + // routing pattern. + if mx.tree.findPattern(pattern+"*") || mx.tree.findPattern(pattern+"/*") { + panic(fmt.Sprintf("chi: attempting to Mount() a handler on an existing path, '%s'", pattern)) + } + + // Assign sub-Router's with the parent not found & method not allowed handler if not specified. + subr, ok := handler.(*Mux) + if ok && subr.notFoundHandler == nil && mx.notFoundHandler != nil { + subr.NotFound(mx.notFoundHandler) + } + if ok && subr.methodNotAllowedHandler == nil && mx.methodNotAllowedHandler != nil { + subr.MethodNotAllowed(mx.methodNotAllowedHandler) + } + + mountHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rctx := RouteContext(r.Context()) + + // shift the url path past the previous subrouter + rctx.RoutePath = mx.nextRoutePath(rctx) + + // reset the wildcard URLParam which connects the subrouter + n := len(rctx.URLParams.Keys) - 1 + if n >= 0 && rctx.URLParams.Keys[n] == "*" && len(rctx.URLParams.Values) > n { + rctx.URLParams.Values[n] = "" + } + + handler.ServeHTTP(w, r) + }) + + if pattern == "" || pattern[len(pattern)-1] != '/' { + mx.handle(mALL|mSTUB, pattern, mountHandler) + mx.handle(mALL|mSTUB, pattern+"/", mountHandler) + pattern += "/" + } + + method := mALL + subroutes, _ := handler.(Routes) + if subroutes != nil { + method |= mSTUB + } + n := mx.handle(method, pattern+"*", mountHandler) + + if subroutes != nil { + n.subroutes = subroutes + } +} + +// Routes returns a slice of routing information from the tree, +// useful for traversing available routes of a router. +func (mx *Mux) Routes() []Route { + return mx.tree.routes() +} + +// Middlewares returns a slice of middleware handler functions. +func (mx *Mux) Middlewares() Middlewares { + return mx.middlewares +} + +// Match searches the routing tree for a handler that matches the method/path. +// It's similar to routing a http request, but without executing the handler +// thereafter. +// +// Note: the *Context state is updated during execution, so manage +// the state carefully or make a NewRouteContext(). +func (mx *Mux) Match(rctx *Context, method, path string) bool { + return mx.Find(rctx, method, path) != "" +} + +// Find searches the routing tree for the pattern that matches +// the method/path. +// +// Note: the *Context state is updated during execution, so manage +// the state carefully or make a NewRouteContext(). +func (mx *Mux) Find(rctx *Context, method, path string) string { + m, ok := methodMap[method] + if !ok { + return "" + } + + node, _, _ := mx.tree.FindRoute(rctx, m, path) + pattern := rctx.routePattern + + if node != nil { + if node.subroutes == nil { + e := node.endpoints[m] + return e.pattern + } + + rctx.RoutePath = mx.nextRoutePath(rctx) + subPattern := node.subroutes.Find(rctx, method, rctx.RoutePath) + if subPattern == "" { + return "" + } + + pattern = strings.TrimSuffix(pattern, "/*") + pattern += subPattern + } + + return pattern +} + +// NotFoundHandler returns the default Mux 404 responder whenever a route +// cannot be found. +func (mx *Mux) NotFoundHandler() http.HandlerFunc { + if mx.notFoundHandler != nil { + return mx.notFoundHandler + } + return http.NotFound +} + +// MethodNotAllowedHandler returns the default Mux 405 responder whenever +// a method cannot be resolved for a route. +func (mx *Mux) MethodNotAllowedHandler(methodsAllowed ...methodTyp) http.HandlerFunc { + if mx.methodNotAllowedHandler != nil { + return mx.methodNotAllowedHandler + } + return methodNotAllowedHandler(methodsAllowed...) +} + +// handle registers a http.Handler in the routing tree for a particular http method +// and routing pattern. +func (mx *Mux) handle(method methodTyp, pattern string, handler http.Handler) *node { + if len(pattern) == 0 || pattern[0] != '/' { + panic(fmt.Sprintf("chi: routing pattern must begin with '/' in '%s'", pattern)) + } + + // Build the computed routing handler for this routing pattern. + if !mx.inline && mx.handler == nil { + mx.updateRouteHandler() + } + + // Build endpoint handler with inline middlewares for the route + var h http.Handler + if mx.inline { + mx.handler = http.HandlerFunc(mx.routeHTTP) + h = Chain(mx.middlewares...).Handler(handler) + } else { + h = handler + } + + // Add the endpoint to the tree and return the node + return mx.tree.InsertRoute(method, pattern, h) +} + +// routeHTTP routes a http.Request through the Mux routing tree to serve +// the matching handler for a particular http method. +func (mx *Mux) routeHTTP(w http.ResponseWriter, r *http.Request) { + // Grab the route context object + rctx := r.Context().Value(RouteCtxKey).(*Context) + + // The request routing path + routePath := rctx.RoutePath + if routePath == "" { + if r.URL.RawPath != "" { + routePath = r.URL.RawPath + } else { + routePath = r.URL.Path + } + if routePath == "" { + routePath = "/" + } + } + + // Check if method is supported by chi + if rctx.RouteMethod == "" { + rctx.RouteMethod = r.Method + } + method, ok := methodMap[rctx.RouteMethod] + if !ok { + mx.MethodNotAllowedHandler().ServeHTTP(w, r) + return + } + + // Find the route + if _, _, h := mx.tree.FindRoute(rctx, method, routePath); h != nil { + // Set http.Request path values from our request context + for i, key := range rctx.URLParams.Keys { + value := rctx.URLParams.Values[i] + r.SetPathValue(key, value) + } + r.Pattern = rctx.RoutePattern() + + h.ServeHTTP(w, r) + return + } + if rctx.methodNotAllowed { + mx.MethodNotAllowedHandler(rctx.methodsAllowed...).ServeHTTP(w, r) + } else { + mx.NotFoundHandler().ServeHTTP(w, r) + } +} + +func (mx *Mux) nextRoutePath(rctx *Context) string { + routePath := "/" + nx := len(rctx.routeParams.Keys) - 1 // index of last param in list + if nx >= 0 && rctx.routeParams.Keys[nx] == "*" && len(rctx.routeParams.Values) > nx { + routePath = "/" + rctx.routeParams.Values[nx] + } + return routePath +} + +// Recursively update data on child routers. +func (mx *Mux) updateSubRoutes(fn func(subMux *Mux)) { + for _, r := range mx.tree.routes() { + subMux, ok := r.SubRoutes.(*Mux) + if !ok { + continue + } + fn(subMux) + } +} + +// updateRouteHandler builds the single mux handler that is a chain of the middleware +// stack, as defined by calls to Use(), and the tree router (Mux) itself. After this +// point, no other middlewares can be registered on this Mux's stack. But you can still +// compose additional middlewares via Group()'s or using a chained middleware handler. +func (mx *Mux) updateRouteHandler() { + mx.handler = chain(mx.middlewares, http.HandlerFunc(mx.routeHTTP)) +} + +// methodNotAllowedHandler is a helper function to respond with a 405, +// method not allowed. It sets the Allow header with the list of allowed +// methods for the route. +func methodNotAllowedHandler(methodsAllowed ...methodTyp) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + for _, m := range methodsAllowed { + w.Header().Add("Allow", reverseMethodMap[m]) + } + w.WriteHeader(405) + w.Write(nil) + } +} diff --git a/testdata/chi/tree.go b/testdata/chi/tree.go new file mode 100644 index 0000000..95f31d4 --- /dev/null +++ b/testdata/chi/tree.go @@ -0,0 +1,877 @@ +package chi + +// Radix tree implementation below is a based on the original work by +// Armon Dadgar in https://github.com/armon/go-radix/blob/master/radix.go +// (MIT licensed). It's been heavily modified for use as a HTTP routing tree. + +import ( + "fmt" + "net/http" + "regexp" + "slices" + "sort" + "strconv" + "strings" +) + +type methodTyp uint + +const ( + mSTUB methodTyp = 1 << iota + mCONNECT + mDELETE + mGET + mHEAD + mOPTIONS + mPATCH + mPOST + mPUT + mTRACE +) + +var mALL = mCONNECT | mDELETE | mGET | mHEAD | + mOPTIONS | mPATCH | mPOST | mPUT | mTRACE + +var methodMap = map[string]methodTyp{ + http.MethodConnect: mCONNECT, + http.MethodDelete: mDELETE, + http.MethodGet: mGET, + http.MethodHead: mHEAD, + http.MethodOptions: mOPTIONS, + http.MethodPatch: mPATCH, + http.MethodPost: mPOST, + http.MethodPut: mPUT, + http.MethodTrace: mTRACE, +} + +var reverseMethodMap = map[methodTyp]string{ + mCONNECT: http.MethodConnect, + mDELETE: http.MethodDelete, + mGET: http.MethodGet, + mHEAD: http.MethodHead, + mOPTIONS: http.MethodOptions, + mPATCH: http.MethodPatch, + mPOST: http.MethodPost, + mPUT: http.MethodPut, + mTRACE: http.MethodTrace, +} + +// RegisterMethod adds support for custom HTTP method handlers, available +// via Router#Method and Router#MethodFunc +func RegisterMethod(method string) { + if method == "" { + return + } + method = strings.ToUpper(method) + if _, ok := methodMap[method]; ok { + return + } + n := len(methodMap) + if n > strconv.IntSize-2 { + panic(fmt.Sprintf("chi: max number of methods reached (%d)", strconv.IntSize)) + } + mt := methodTyp(2 << n) + methodMap[method] = mt + reverseMethodMap[mt] = method + mALL |= mt +} + +type nodeTyp uint8 + +const ( + ntStatic nodeTyp = iota // /home + ntRegexp // /{id:[0-9]+} + ntParam // /{user} + ntCatchAll // /api/v1/* +) + +type node struct { + // subroutes on the leaf node + subroutes Routes + + // regexp matcher for regexp nodes + rex *regexp.Regexp + + // HTTP handler endpoints on the leaf node + endpoints endpoints + + // prefix is the common prefix we ignore + prefix string + + // child nodes should be stored in-order for iteration, + // in groups of the node type. + children [ntCatchAll + 1]nodes + + // first byte of the child prefix + tail byte + + // node type: static, regexp, param, catchAll + typ nodeTyp + + // first byte of the prefix + label byte +} + +// endpoints is a mapping of http method constants to handlers +// for a given route. +type endpoints map[methodTyp]*endpoint + +type endpoint struct { + // endpoint handler + handler http.Handler + + // pattern is the routing pattern for handler nodes + pattern string + + // parameter keys recorded on handler nodes + paramKeys []string +} + +func (s endpoints) Value(method methodTyp) *endpoint { + mh, ok := s[method] + if !ok { + mh = &endpoint{} + s[method] = mh + } + return mh +} + +func (n *node) InsertRoute(method methodTyp, pattern string, handler http.Handler) *node { + var parent *node + search := pattern + + for { + // Handle key exhaustion + if len(search) == 0 { + // Insert or update the node's leaf handler + n.setEndpoint(method, handler, pattern) + return n + } + + // We're going to be searching for a wild node next, + // in this case, we need to get the tail + var label = search[0] + var segTail byte + var segEndIdx int + var segTyp nodeTyp + var segRexpat string + if label == '{' || label == '*' { + segTyp, _, segRexpat, segTail, _, segEndIdx = patNextSegment(search) + } + + var prefix string + if segTyp == ntRegexp { + prefix = segRexpat + } + + // Look for the edge to attach to + parent = n + n = n.getEdge(segTyp, label, segTail, prefix) + + // No edge, create one + if n == nil { + child := &node{label: label, tail: segTail, prefix: search} + hn := parent.addChild(child, search) + hn.setEndpoint(method, handler, pattern) + + return hn + } + + // Found an edge to match the pattern + + if n.typ > ntStatic { + // We found a param node, trim the param from the search path and continue. + // This param/wild pattern segment would already be on the tree from a previous + // call to addChild when creating a new node. + search = search[segEndIdx:] + continue + } + + // Static nodes fall below here. + // Determine longest prefix of the search key on match. + commonPrefix := longestPrefix(search, n.prefix) + if commonPrefix == len(n.prefix) { + // the common prefix is as long as the current node's prefix we're attempting to insert. + // keep the search going. + search = search[commonPrefix:] + continue + } + + // Split the node + child := &node{ + typ: ntStatic, + prefix: search[:commonPrefix], + } + parent.replaceChild(search[0], segTail, child) + + // Restore the existing node + n.label = n.prefix[commonPrefix] + n.prefix = n.prefix[commonPrefix:] + child.addChild(n, n.prefix) + + // If the new key is a subset, set the method/handler on this node and finish. + search = search[commonPrefix:] + if len(search) == 0 { + child.setEndpoint(method, handler, pattern) + return child + } + + // Create a new edge for the node + subchild := &node{ + typ: ntStatic, + label: search[0], + prefix: search, + } + hn := child.addChild(subchild, search) + hn.setEndpoint(method, handler, pattern) + return hn + } +} + +// addChild appends the new `child` node to the tree using the `pattern` as the trie key. +// For a URL router like chi's, we split the static, param, regexp and wildcard segments +// into different nodes. In addition, addChild will recursively call itself until every +// pattern segment is added to the url pattern tree as individual nodes, depending on type. +func (n *node) addChild(child *node, prefix string) *node { + search := prefix + + // handler leaf node added to the tree is the child. + // this may be overridden later down the flow + hn := child + + // Parse next segment + segTyp, _, segRexpat, segTail, segStartIdx, segEndIdx := patNextSegment(search) + + // Add child depending on next up segment + switch segTyp { + + case ntStatic: + // Search prefix is all static (that is, has no params in path) + // noop + + default: + // Search prefix contains a param, regexp or wildcard + + if segTyp == ntRegexp { + rex, err := regexp.Compile(segRexpat) + if err != nil { + panic(fmt.Sprintf("chi: invalid regexp pattern '%s' in route param", segRexpat)) + } + child.prefix = segRexpat + child.rex = rex + } + + if segStartIdx == 0 { + // Route starts with a param + child.typ = segTyp + + if segTyp == ntCatchAll { + segStartIdx = -1 + } else { + segStartIdx = segEndIdx + } + if segStartIdx < 0 { + segStartIdx = len(search) + } + child.tail = segTail // for params, we set the tail + + if segStartIdx != len(search) { + // add static edge for the remaining part, split the end. + // its not possible to have adjacent param nodes, so its certainly + // going to be a static node next. + + search = search[segStartIdx:] // advance search position + + nn := &node{ + typ: ntStatic, + label: search[0], + prefix: search, + } + hn = child.addChild(nn, search) + } + + } else if segStartIdx > 0 { + // Route has some param + + // starts with a static segment + child.typ = ntStatic + child.prefix = search[:segStartIdx] + child.rex = nil + + // add the param edge node + search = search[segStartIdx:] + + nn := &node{ + typ: segTyp, + label: search[0], + tail: segTail, + } + hn = child.addChild(nn, search) + + } + } + + n.children[child.typ] = append(n.children[child.typ], child) + n.children[child.typ].Sort() + return hn +} + +func (n *node) replaceChild(label, tail byte, child *node) { + for i := 0; i < len(n.children[child.typ]); i++ { + if n.children[child.typ][i].label == label && n.children[child.typ][i].tail == tail { + n.children[child.typ][i] = child + n.children[child.typ][i].label = label + n.children[child.typ][i].tail = tail + return + } + } + panic("chi: replacing missing child") +} + +func (n *node) getEdge(ntyp nodeTyp, label, tail byte, prefix string) *node { + nds := n.children[ntyp] + for i := range nds { + if nds[i].label == label && nds[i].tail == tail { + if ntyp == ntRegexp && nds[i].prefix != prefix { + continue + } + return nds[i] + } + } + return nil +} + +func (n *node) setEndpoint(method methodTyp, handler http.Handler, pattern string) { + // Set the handler for the method type on the node + if n.endpoints == nil { + n.endpoints = make(endpoints) + } + + paramKeys := patParamKeys(pattern) + + if method&mSTUB == mSTUB { + n.endpoints.Value(mSTUB).handler = handler + } + if method&mALL == mALL { + h := n.endpoints.Value(mALL) + h.handler = handler + h.pattern = pattern + h.paramKeys = paramKeys + for _, m := range methodMap { + h := n.endpoints.Value(m) + h.handler = handler + h.pattern = pattern + h.paramKeys = paramKeys + } + } else { + h := n.endpoints.Value(method) + h.handler = handler + h.pattern = pattern + h.paramKeys = paramKeys + } +} + +func (n *node) FindRoute(rctx *Context, method methodTyp, path string) (*node, endpoints, http.Handler) { + // Reset the context routing pattern and params + rctx.routePattern = "" + rctx.routeParams.Keys = rctx.routeParams.Keys[:0] + rctx.routeParams.Values = rctx.routeParams.Values[:0] + + // Find the routing handlers for the path + rn := n.findRoute(rctx, method, path) + if rn == nil { + return nil, nil, nil + } + + // Record the routing params in the request lifecycle + rctx.URLParams.Keys = append(rctx.URLParams.Keys, rctx.routeParams.Keys...) + rctx.URLParams.Values = append(rctx.URLParams.Values, rctx.routeParams.Values...) + + // Record the routing pattern in the request lifecycle + if rn.endpoints[method].pattern != "" { + rctx.routePattern = rn.endpoints[method].pattern + rctx.RoutePatterns = append(rctx.RoutePatterns, rctx.routePattern) + } + + return rn, rn.endpoints, rn.endpoints[method].handler +} + +// Recursive edge traversal by checking all nodeTyp groups along the way. +// It's like searching through a multi-dimensional radix trie. +func (n *node) findRoute(rctx *Context, method methodTyp, path string) *node { + nn := n + search := path + + for t, nds := range nn.children { + ntyp := nodeTyp(t) + if len(nds) == 0 { + continue + } + + var xn *node + xsearch := search + + var label byte + if search != "" { + label = search[0] + } + + switch ntyp { + case ntStatic: + xn = nds.findEdge(label) + if xn == nil || !strings.HasPrefix(xsearch, xn.prefix) { + continue + } + xsearch = xsearch[len(xn.prefix):] + + case ntParam, ntRegexp: + // short-circuit and return no matching route for empty param values + if xsearch == "" { + continue + } + + // serially loop through each node grouped by the tail delimiter + for _, xn = range nds { + // label for param nodes is the delimiter byte + p := strings.IndexByte(xsearch, xn.tail) + + if p < 0 { + if xn.tail == '/' { + p = len(xsearch) + } else { + continue + } + } else if ntyp == ntRegexp && p == 0 { + continue + } + + if ntyp == ntRegexp && xn.rex != nil { + if !xn.rex.MatchString(xsearch[:p]) { + continue + } + } else if strings.IndexByte(xsearch[:p], '/') != -1 { + // avoid a match across path segments + continue + } + + prevlen := len(rctx.routeParams.Values) + rctx.routeParams.Values = append(rctx.routeParams.Values, xsearch[:p]) + xsearch = xsearch[p:] + + if len(xsearch) == 0 { + if xn.isLeaf() { + h := xn.endpoints[method] + if h != nil && h.handler != nil { + rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...) + return xn + } + + for endpoints := range xn.endpoints { + if endpoints == mALL || endpoints == mSTUB { + continue + } + rctx.methodsAllowed = append(rctx.methodsAllowed, endpoints) + } + + // flag that the routing context found a route, but not a corresponding + // supported method + rctx.methodNotAllowed = true + } + } + + // recursively find the next node on this branch + fin := xn.findRoute(rctx, method, xsearch) + if fin != nil { + return fin + } + + // not found on this branch, reset vars + rctx.routeParams.Values = rctx.routeParams.Values[:prevlen] + xsearch = search + } + + rctx.routeParams.Values = append(rctx.routeParams.Values, "") + + default: + // catch-all nodes + rctx.routeParams.Values = append(rctx.routeParams.Values, search) + xn = nds[0] + xsearch = "" + } + + if xn == nil { + continue + } + + // did we find it yet? + if len(xsearch) == 0 { + if xn.isLeaf() { + h := xn.endpoints[method] + if h != nil && h.handler != nil { + rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...) + return xn + } + + for endpoints := range xn.endpoints { + if endpoints == mALL || endpoints == mSTUB { + continue + } + rctx.methodsAllowed = append(rctx.methodsAllowed, endpoints) + } + + // flag that the routing context found a route, but not a corresponding + // supported method + rctx.methodNotAllowed = true + } + } + + // recursively find the next node.. + fin := xn.findRoute(rctx, method, xsearch) + if fin != nil { + return fin + } + + // Did not find final handler, let's remove the param here if it was set + if xn.typ > ntStatic { + if len(rctx.routeParams.Values) > 0 { + rctx.routeParams.Values = rctx.routeParams.Values[:len(rctx.routeParams.Values)-1] + } + } + + } + + return nil +} + +func (n *node) findEdge(ntyp nodeTyp, label byte) *node { + nds := n.children[ntyp] + num := len(nds) + idx := 0 + + switch ntyp { + case ntStatic, ntParam, ntRegexp: + i, j := 0, num-1 + for i <= j { + idx = i + (j-i)/2 + if label > nds[idx].label { + i = idx + 1 + } else if label < nds[idx].label { + j = idx - 1 + } else { + i = num // breaks cond + } + } + if nds[idx].label != label { + return nil + } + return nds[idx] + + default: // catch all + return nds[idx] + } +} + +func (n *node) isLeaf() bool { + return n.endpoints != nil +} + +func (n *node) findPattern(pattern string) bool { + nn := n + for _, nds := range nn.children { + if len(nds) == 0 { + continue + } + + n = nn.findEdge(nds[0].typ, pattern[0]) + if n == nil { + continue + } + + var idx int + var xpattern string + + switch n.typ { + case ntStatic: + idx = longestPrefix(pattern, n.prefix) + if idx < len(n.prefix) { + continue + } + + case ntParam, ntRegexp: + idx = strings.IndexByte(pattern, '}') + 1 + + case ntCatchAll: + idx = longestPrefix(pattern, "*") + + default: + panic("chi: unknown node type") + } + + xpattern = pattern[idx:] + if len(xpattern) == 0 { + return true + } + + return n.findPattern(xpattern) + } + return false +} + +func (n *node) routes() []Route { + rts := []Route{} + + n.walk(func(eps endpoints, subroutes Routes) bool { + if eps[mSTUB] != nil && eps[mSTUB].handler != nil && subroutes == nil { + return false + } + + // Group methodHandlers by unique patterns + pats := make(map[string]endpoints) + + for mt, h := range eps { + if h.pattern == "" { + continue + } + p, ok := pats[h.pattern] + if !ok { + p = endpoints{} + pats[h.pattern] = p + } + p[mt] = h + } + + for p, mh := range pats { + hs := make(map[string]http.Handler) + if mh[mALL] != nil && mh[mALL].handler != nil { + hs["*"] = mh[mALL].handler + } + + for mt, h := range mh { + if h.handler == nil { + continue + } + if m, ok := reverseMethodMap[mt]; ok { + hs[m] = h.handler + } + } + + rt := Route{subroutes, hs, p} + rts = append(rts, rt) + } + + return false + }) + + return rts +} + +func (n *node) walk(fn func(eps endpoints, subroutes Routes) bool) bool { + // Visit the leaf values if any + if (n.endpoints != nil || n.subroutes != nil) && fn(n.endpoints, n.subroutes) { + return true + } + + // Recurse on the children + for _, ns := range n.children { + for _, cn := range ns { + if cn.walk(fn) { + return true + } + } + } + return false +} + +// patNextSegment returns the next segment details from a pattern: +// node type, param key, regexp string, param tail byte, param starting index, param ending index +func patNextSegment(pattern string) (nodeTyp, string, string, byte, int, int) { + ps := strings.Index(pattern, "{") + ws := strings.Index(pattern, "*") + + if ps < 0 && ws < 0 { + return ntStatic, "", "", 0, 0, len(pattern) // we return the entire thing + } + + // Sanity check + if ps >= 0 && ws >= 0 && ws < ps { + panic("chi: wildcard '*' must be the last pattern in a route, otherwise use a '{param}'") + } + + var tail byte = '/' // Default endpoint tail to / byte + + if ps >= 0 { + // Param/Regexp pattern is next + nt := ntParam + + // Read to closing } taking into account opens and closes in curl count (cc) + cc := 0 + pe := ps + for i, c := range pattern[ps:] { + if c == '{' { + cc++ + } else if c == '}' { + cc-- + if cc == 0 { + pe = ps + i + break + } + } + } + if pe == ps { + panic("chi: route param closing delimiter '}' is missing") + } + + key := pattern[ps+1 : pe] + pe++ // set end to next position + + if pe < len(pattern) { + tail = pattern[pe] + } + + key, rexpat, isRegexp := strings.Cut(key, ":") + if isRegexp { + nt = ntRegexp + } + + if len(rexpat) > 0 { + if rexpat[0] != '^' { + rexpat = "^" + rexpat + } + if rexpat[len(rexpat)-1] != '$' { + rexpat += "$" + } + } + + return nt, key, rexpat, tail, ps, pe + } + + // Wildcard pattern as finale + if ws < len(pattern)-1 { + panic("chi: wildcard '*' must be the last value in a route. trim trailing text or use a '{param}' instead") + } + return ntCatchAll, "*", "", 0, ws, len(pattern) +} + +func patParamKeys(pattern string) []string { + pat := pattern + paramKeys := []string{} + for { + ptyp, paramKey, _, _, _, e := patNextSegment(pat) + if ptyp == ntStatic { + return paramKeys + } + for i := 0; i < len(paramKeys); i++ { + if paramKeys[i] == paramKey { + panic(fmt.Sprintf("chi: routing pattern '%s' contains duplicate param key, '%s'", pattern, paramKey)) + } + } + paramKeys = append(paramKeys, paramKey) + pat = pat[e:] + } +} + +// longestPrefix finds the length of the shared prefix of two strings +func longestPrefix(k1, k2 string) (i int) { + for i = 0; i < min(len(k1), len(k2)); i++ { + if k1[i] != k2[i] { + break + } + } + return +} + +type nodes []*node + +// Sort the list of nodes by label +func (ns nodes) Sort() { sort.Sort(ns); ns.tailSort() } +func (ns nodes) Len() int { return len(ns) } +func (ns nodes) Swap(i, j int) { ns[i], ns[j] = ns[j], ns[i] } +func (ns nodes) Less(i, j int) bool { return ns[i].label < ns[j].label } + +// tailSort pushes nodes with '/' as the tail to the end of the list for param nodes. +// The list order determines the traversal order. +func (ns nodes) tailSort() { + for i := len(ns) - 1; i >= 0; i-- { + if ns[i].typ > ntStatic && ns[i].tail == '/' { + ns.Swap(i, len(ns)-1) + return + } + } +} + +func (ns nodes) findEdge(label byte) *node { + num := len(ns) + idx := 0 + i, j := 0, num-1 + for i <= j { + idx = i + (j-i)/2 + if label > ns[idx].label { + i = idx + 1 + } else if label < ns[idx].label { + j = idx - 1 + } else { + i = num // breaks cond + } + } + if ns[idx].label != label { + return nil + } + return ns[idx] +} + +// Route describes the details of a routing handler. +// Handlers map key is an HTTP method +type Route struct { + SubRoutes Routes + Handlers map[string]http.Handler + Pattern string +} + +// WalkFunc is the type of the function called for each method and route visited by Walk. +type WalkFunc func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error + +// Walk walks any router tree that implements Routes interface. +func Walk(r Routes, walkFn WalkFunc) error { + return walk(r, walkFn, "") +} + +func walk(r Routes, walkFn WalkFunc, parentRoute string, parentMw ...func(http.Handler) http.Handler) error { + for _, route := range r.Routes() { + mws := slices.Concat(parentMw, r.Middlewares()) + + if route.SubRoutes != nil { + if handler, ok := route.Handlers["*"]; ok { + if chain, ok := handler.(*ChainHandler); ok { + mws = append(mws, chain.Middlewares...) + } + } + + if err := walk(route.SubRoutes, walkFn, parentRoute+route.Pattern, mws...); err != nil { + return err + } + continue + } + + for method, handler := range route.Handlers { + if method == "*" { + // Ignore a "catchAll" method, since we pass down all the specific methods for each route. + continue + } + + fullRoute := parentRoute + route.Pattern + fullRoute = strings.ReplaceAll(fullRoute, "/*/", "/") + + if chain, ok := handler.(*ChainHandler); ok { + if err := walkFn(method, fullRoute, chain.Endpoint, append(mws, chain.Middlewares...)...); err != nil { + return err + } + } else { + if err := walkFn(method, fullRoute, handler, mws...); err != nil { + return err + } + } + } + } + + return nil +} diff --git a/testdata/fixture/go.mod b/testdata/fixture/go.mod deleted file mode 100644 index 30818c1..0000000 --- a/testdata/fixture/go.mod +++ /dev/null @@ -1,3 +0,0 @@ -module example.com/fixture - -go 1.21 diff --git a/testdata/generics/fn/fn.go b/testdata/generics/fn/fn.go new file mode 100644 index 0000000..a18a43b --- /dev/null +++ b/testdata/generics/fn/fn.go @@ -0,0 +1,46 @@ +package fn + +// Ordered is a union-constraint interface for comparable ordered primitives. +type Ordered interface { + ~int | ~int64 | ~float64 | ~string +} + +// Numeric extends Ordered with unsigned integer kinds. +type Numeric interface { + Ordered + ~uint | ~uint64 +} + +func Min[T Ordered](a, b T) T { + if a < b { + return a + } + return b +} + +func Max[T Ordered](a, b T) T { + if a > b { + return a + } + return b +} + +// Map applies f to every element of in and returns the results. +func Map[T, U any](in []T, f func(T) U) []U { + out := make([]U, len(in)) + for i, v := range in { + out[i] = f(v) + } + return out +} + +// Filter returns the elements of in for which keep returns true. +func Filter[T any](in []T, keep func(T) bool) []T { + var out []T + for _, v := range in { + if keep(v) { + out = append(out, v) + } + } + return out +} diff --git a/testdata/generics/go.mod b/testdata/generics/go.mod new file mode 100644 index 0000000..3ad2259 --- /dev/null +++ b/testdata/generics/go.mod @@ -0,0 +1,3 @@ +module example.com/generics + +go 1.21 diff --git a/testdata/generics/main.go b/testdata/generics/main.go new file mode 100644 index 0000000..0ba0512 --- /dev/null +++ b/testdata/generics/main.go @@ -0,0 +1,22 @@ +package main + +import ( + "fmt" + + "example.com/generics/fn" + "example.com/generics/set" +) + +func main() { + s := set.New[string]() + s.Add("hello") + s.Add("world") + fmt.Println(s.Contains("hello"), s.Len()) + + fmt.Println(fn.Min(3, 7)) + fmt.Println(fn.Max(3.14, 2.72)) + + nums := fn.Map([]int{1, 2, 3}, func(x int) string { return fmt.Sprintf("%d", x) }) + evens := fn.Filter([]int{1, 2, 3, 4}, func(x int) bool { return x%2 == 0 }) + fmt.Println(nums, evens) +} diff --git a/testdata/generics/set/set.go b/testdata/generics/set/set.go new file mode 100644 index 0000000..c5039d4 --- /dev/null +++ b/testdata/generics/set/set.go @@ -0,0 +1,36 @@ +package set + +// Set is a generic hash set. +type Set[T comparable] struct { + items map[T]struct{} +} + +func New[T comparable]() *Set[T] { + return &Set[T]{items: make(map[T]struct{})} +} + +func (s *Set[T]) Add(v T) { + s.items[v] = struct{}{} +} + +func (s *Set[T]) Remove(v T) { + delete(s.items, v) +} + +func (s *Set[T]) Contains(v T) bool { + _, ok := s.items[v] + return ok +} + +func (s *Set[T]) Len() int { + return len(s.items) +} + +// Snapshot returns all elements as a slice. Unexported helper for internal use. +func (s *Set[T]) snapshot() []T { + out := make([]T, 0, len(s.items)) + for k := range s.items { + out = append(out, k) + } + return out +} diff --git a/testdata/greeter/go.mod b/testdata/greeter/go.mod new file mode 100644 index 0000000..162252e --- /dev/null +++ b/testdata/greeter/go.mod @@ -0,0 +1,3 @@ +module example.com/greeter + +go 1.21 diff --git a/testdata/fixture/main.go b/testdata/greeter/main.go similarity index 82% rename from testdata/fixture/main.go rename to testdata/greeter/main.go index f20e3f9..1436fd2 100644 --- a/testdata/fixture/main.go +++ b/testdata/greeter/main.go @@ -3,7 +3,7 @@ package main import ( "fmt" - "example.com/fixture/pkg/greeter" + "example.com/greeter/pkg/greeter" ) func main() { diff --git a/testdata/fixture/pkg/greeter/greeter.go b/testdata/greeter/pkg/greeter/greeter.go similarity index 100% rename from testdata/fixture/pkg/greeter/greeter.go rename to testdata/greeter/pkg/greeter/greeter.go diff --git a/testdata/multipackage/go.mod b/testdata/multipackage/go.mod new file mode 100644 index 0000000..6a64ad1 --- /dev/null +++ b/testdata/multipackage/go.mod @@ -0,0 +1,3 @@ +module example.com/multipackage + +go 1.21 diff --git a/testdata/realistic/main.go b/testdata/multipackage/main.go similarity index 87% rename from testdata/realistic/main.go rename to testdata/multipackage/main.go index 81ae136..d3d6f89 100644 --- a/testdata/realistic/main.go +++ b/testdata/multipackage/main.go @@ -4,8 +4,8 @@ import ( "fmt" "log" - "example.com/realistic/server" - "example.com/realistic/worker" + "example.com/multipackage/server" + "example.com/multipackage/worker" ) func main() { diff --git a/testdata/realistic/server/middleware.go b/testdata/multipackage/server/middleware.go similarity index 100% rename from testdata/realistic/server/middleware.go rename to testdata/multipackage/server/middleware.go diff --git a/testdata/realistic/server/server.go b/testdata/multipackage/server/server.go similarity index 100% rename from testdata/realistic/server/server.go rename to testdata/multipackage/server/server.go diff --git a/testdata/multipackage/server/server_test.go b/testdata/multipackage/server/server_test.go new file mode 100644 index 0000000..351ae34 --- /dev/null +++ b/testdata/multipackage/server/server_test.go @@ -0,0 +1,16 @@ +package server + +import "testing" + +// TestServer_Addr is a minimal test in the realistic fixture's server package. +// Its only purpose is to give the --skip-tests=false integration test a +// _test.go file to look for in the symbol table. +func TestServer_Addr(t *testing.T) { + s, err := New(Config{Host: "localhost", Port: 8080}) + if err != nil { + t.Fatalf("New: %v", err) + } + if s.Addr() == "" { + t.Error("Addr() returned empty string") + } +} diff --git a/testdata/realistic/worker/worker.go b/testdata/multipackage/worker/worker.go similarity index 100% rename from testdata/realistic/worker/worker.go rename to testdata/multipackage/worker/worker.go diff --git a/testdata/realistic/go.mod b/testdata/realistic/go.mod deleted file mode 100644 index f1da8db..0000000 --- a/testdata/realistic/go.mod +++ /dev/null @@ -1,3 +0,0 @@ -module example.com/realistic - -go 1.21 From 4c0edaae238b46a6d1702a524b7f9f81d97488a9 Mon Sep 17 00:00:00 2001 From: Saurabh Sinha Date: Fri, 3 Jul 2026 13:24:19 -0400 Subject: [PATCH 4/6] chore: add CI workflow, fix README fixture names, and gitignore output/ MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add .github/workflows/ci.yml: test (go build + go test) and lint (golangci-lint) jobs triggered on push to main/feat/** and PRs to main - Fix stale testdata path references in README (fixture/→greeter/, realistic/→multipackage/) and update test count (57→105) - Add /output/ to .gitignore to exclude generated analysis artifacts Signed-off-by: Saurabh Sinha --- .github/workflows/ci.yml | 42 ++++++++++++++++++++++++++++++++++++++++ .gitignore | 3 +++ README.md | 6 +++--- 3 files changed, 48 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..e716ec8 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,42 @@ +name: CI + +on: + push: + branches: [main, "feat/**"] + pull_request: + branches: [main] + +jobs: + test: + name: Test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + + - name: Build + run: go build ./... + + - name: Test + run: go test ./... -count=1 + + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: latest + args: --timeout=5m diff --git a/.gitignore b/.gitignore index 08088c3..ba10dd8 100644 --- a/.gitignore +++ b/.gitignore @@ -35,3 +35,6 @@ go.work.sum # macOS .DS_Store + +# Generated analysis output +/output/ diff --git a/README.md b/README.md index 9310f1c..23c87a2 100644 --- a/README.md +++ b/README.md @@ -188,8 +188,8 @@ codeanalyzer-go/ │ ├── frameworks/ # BaseEntrypointFinder — extension seam for framework passes │ └── utils/ # DiscoverGoFiles, IsVendored, IsTestFile, logging ├── testdata/ -│ ├── fixture/ # Minimal two-package fixture (basic struct/interface/call sites) -│ ├── realistic/ # Richer fixture covering embedded fields, variadic params, goroutines, … +│ ├── greeter/ # Minimal two-package fixture (basic struct/interface/call sites) +│ ├── multipackage/ # Richer fixture covering embedded fields, variadic params, goroutines, … │ ├── generics/ # Go 1.18+ generics fixture (Set[T], union-constraint interfaces, Map[T,U]) │ └── chi/ # External-dep fixture (chi v5, vendored) for HTTP handler patterns ``` @@ -204,7 +204,7 @@ The `core` package is a pure orchestrator: it calls `syntactic_analysis` → `se go test ./... ``` -Tests run against four fixtures: `testdata/fixture/` (basic), `testdata/realistic/` (multi-file packages, goroutines, variadic params), `testdata/generics/` (Go 1.18+ generics — `Set[T]`, union constraints, multi-type-param functions), and `testdata/chi/` (external dependency via vendored chi v5, HTTP handler patterns). All 57 tests cover symbol table correctness, generic receiver attribution, call graph edges, JSON round-trip, output format validation, caching behaviour, and error paths. +Tests run against four fixtures: `testdata/greeter/` (basic), `testdata/multipackage/` (multi-file packages, goroutines, variadic params), `testdata/generics/` (Go 1.18+ generics — `Set[T]`, union constraints, multi-type-param functions), and `testdata/chi/` (external dependency via vendored chi v5, HTTP handler patterns). All 105 tests cover symbol table correctness, generic receiver attribution, call graph edges, JSON round-trip, output format validation, caching behaviour, and error paths. `go test` caches passing results by source hash. To force a full re-run: From 8d4f35b5682d48951b8a737f0351e6f26bc128e4 Mon Sep 17 00:00:00 2001 From: Saurabh Sinha Date: Fri, 3 Jul 2026 13:45:27 -0400 Subject: [PATCH 5/6] fix: install golangci-lint via go install to match Go 1.25 toolchain golangci-lint-action@v6 resolved to v1.64.8 (built with Go 1.24) which refuses to analyze go 1.25.0 modules. Installing via go install compiles the tool from source using the project's own toolchain, so the version check always passes. Signed-off-by: Saurabh Sinha --- .github/workflows/ci.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e716ec8..6cdb60c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,8 +35,8 @@ jobs: with: go-version-file: go.mod + - name: Install golangci-lint + run: go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest + - name: golangci-lint - uses: golangci/golangci-lint-action@v6 - with: - version: latest - args: --timeout=5m + run: golangci-lint run --timeout=5m From 02505d45dad4262dea0919936279758e8c9d20b7 Mon Sep 17 00:00:00 2001 From: Saurabh Sinha Date: Fri, 3 Jul 2026 13:53:39 -0400 Subject: [PATCH 6/6] fix: resolve golangci-lint errcheck and unused findings - fs_test.go: use writeFile helper for TestFileHash_DifferentContent and check os.MkdirAll error in TestDiscoverGoFiles_SkipsVendorDir - signature.go: remove unused signatureOfNamed, signatureForCall, normalizeReturnType functions and their unused strings import Signed-off-by: Saurabh Sinha --- internal/syntactic_analysis/signature.go | 35 ------------------------ internal/utils/fs_test.go | 8 ++++-- 2 files changed, 5 insertions(+), 38 deletions(-) diff --git a/internal/syntactic_analysis/signature.go b/internal/syntactic_analysis/signature.go index 6f538bf..cdfd71e 100644 --- a/internal/syntactic_analysis/signature.go +++ b/internal/syntactic_analysis/signature.go @@ -4,7 +4,6 @@ package syntactic_analysis import ( "fmt" "go/types" - "strings" ) // signatureOf is the single canonicalizer for all signature strings in the analyzer. @@ -50,37 +49,3 @@ func signatureOf(obj types.Object) string { } } -// signatureOfNamed builds a type signature from a *types.Named directly. -// Used when we have the named type but not a types.Object. -func signatureOfNamed(named *types.Named) string { - if named == nil { - return "" - } - obj := named.Obj() - pkgPath := "" - if obj.Pkg() != nil { - pkgPath = obj.Pkg().Path() - } - return fmt.Sprintf("%s.%s", pkgPath, obj.Name()) -} - -// signatureForCall builds a callee signature from a *types.Func resolved at a call site. -func signatureForCall(fn *types.Func) string { - return signatureOf(fn) -} - -// normalizeReturnType joins multiple return types into a single parenthesized string. -// Single non-error returns are returned as-is; multiple returns become "(t1, t2, ...)". -func normalizeReturnType(results *types.Tuple) (joined string, parts []string) { - if results == nil || results.Len() == 0 { - return "", nil - } - parts = make([]string, results.Len()) - for i := 0; i < results.Len(); i++ { - parts[i] = results.At(i).Type().String() - } - if len(parts) == 1 { - return parts[0], parts - } - return "(" + strings.Join(parts, ", ") + ")", parts -} diff --git a/internal/utils/fs_test.go b/internal/utils/fs_test.go index a349d07..91df3d4 100644 --- a/internal/utils/fs_test.go +++ b/internal/utils/fs_test.go @@ -80,10 +80,10 @@ func TestFileHash_Deterministic(t *testing.T) { func TestFileHash_DifferentContent(t *testing.T) { dir := t.TempDir() + writeFile(t, dir, "a.txt", "aaa") + writeFile(t, dir, "b.txt", "bbb") a := filepath.Join(dir, "a.txt") b := filepath.Join(dir, "b.txt") - os.WriteFile(a, []byte("aaa"), 0o644) - os.WriteFile(b, []byte("bbb"), 0o644) ha, _ := utils.FileHash(a) hb, _ := utils.FileHash(b) @@ -168,7 +168,9 @@ func TestDiscoverGoFiles_SkipsVendorDir(t *testing.T) { dir := t.TempDir() writeFile(t, dir, "main.go", "package main") vendor := filepath.Join(dir, "vendor", "pkg") - os.MkdirAll(vendor, 0o755) + if err := os.MkdirAll(vendor, 0o755); err != nil { + t.Fatal(err) + } writeFile(t, vendor, "lib.go", "package pkg") files, err := utils.DiscoverGoFiles(dir, true)