From 5e4cdb8c6576438a27f51222227d2e23c2c808f6 Mon Sep 17 00:00:00 2001 From: Lukas Wuttke Date: Sat, 20 Jun 2026 11:24:29 +0200 Subject: [PATCH 1/3] feat(cli): client create / list / use commands (#84) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit RFC-0001 P1. Adds the `client` subtree the login flow hands off to: provision a tracebloc client for this machine, list the account's clients, and attach this machine to an existing one. - internal/slug: Go port of RFC-0001 Appendix B (backend common/utils/slug.py) — DNS-1123 slugify (NFKD via x/text) + collision suffix + empty-slug guard, kept in lock-step with the backend that validates the result. - internal/api: CreateClient / ListClients / ListClientAdmins against /edge-device/, Bearer-authed (backend#836). - internal/cli/client.go: create (--name / --location / --yes), list (ls), use . The derived namespace is shown for confirmation; location is required (never silent-empty); a 403 surfaces the ask-an-admin path (backend#836); the generated machine credential is printed once. Location auto-detect (cloud-metadata / GeoIP suggested default) is a fast-follow — this PR takes --location or prompts for it. Co-Authored-By: Claude Opus 4.8 (1M context) --- internal/api/client.go | 90 +++++++++++++ internal/cli/client.go | 250 ++++++++++++++++++++++++++++++++---- internal/cli/client_test.go | 126 ++++++++++++++++++ internal/slug/slug.go | 89 +++++++++++++ internal/slug/slug_test.go | 45 +++++++ 5 files changed, 575 insertions(+), 25 deletions(-) create mode 100644 internal/cli/client_test.go create mode 100644 internal/slug/slug.go create mode 100644 internal/slug/slug_test.go diff --git a/internal/api/client.go b/internal/api/client.go index 0bb22d3..be9c532 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -249,3 +249,93 @@ func (c *Client) WhoAmI(ctx context.Context) (*Identity, error) { } return &id, nil } + +// ── Client provisioning (Bearer-authed) — backend#836, /edge-device/ ── + +// ProvisionedClient is a tracebloc client (machine), as returned by the +// EdgeDevice endpoints. +type ProvisionedClient struct { + ID int `json:"id"` + Name string `json:"first_name"` + Username string `json:"username"` + Namespace string `json:"namespace"` + Location string `json:"location"` + Status int `json:"status"` +} + +// CreateClientRequest is the POST /edge-device/ body. The account is stamped +// server-side from the token; password is the machine credential the caller +// generates (write-only on the backend). +type CreateClientRequest struct { + Name string `json:"first_name"` + Namespace string `json:"namespace"` + Location string `json:"location"` + Password string `json:"password"` +} + +// AdminContact is one "ask an admin" entry from GET /edge-device/admins/. +type AdminContact struct { + Name string `json:"name"` + Email string `json:"email"` +} + +// CreateClient provisions a client. A 403 *APIError means the caller lacks +// CLIENT_WRITE — callers fall back to ListClientAdmins for the ask-an-admin +// path (backend#836 Q4). +func (c *Client) CreateClient(ctx context.Context, req CreateClientRequest) (*ProvisionedClient, error) { + url := c.BaseURL + "/edge-device/" + status, raw, err := c.post(ctx, "/edge-device/", req) + if err != nil { + return nil, err + } + if status < 200 || status >= 300 { + return nil, &APIError{StatusCode: status, Body: string(raw), URL: url} + } + var out ProvisionedClient + if err := json.Unmarshal(raw, &out); err != nil { + return nil, fmt.Errorf("decoding create-client response: %w", err) + } + return &out, nil +} + +// ListClients returns the clients in the caller's account (GET /edge-device/). +// Tolerates both a DRF-paginated body and a bare list. +func (c *Client) ListClients(ctx context.Context) ([]ProvisionedClient, error) { + url := c.BaseURL + "/edge-device/" + status, raw, err := c.get(ctx, "/edge-device/") + if err != nil { + return nil, err + } + if status < 200 || status >= 300 { + return nil, &APIError{StatusCode: status, Body: string(raw), URL: url} + } + var list []ProvisionedClient + if err := json.Unmarshal(raw, &list); err == nil { + return list, nil + } + var paged struct { + Results []ProvisionedClient `json:"results"` + } + if err := json.Unmarshal(raw, &paged); err != nil { + return nil, fmt.Errorf("decoding client list: %w", err) + } + return paged.Results, nil +} + +// ListClientAdmins returns who in the account can provision (the ask-an-admin +// path), from GET /edge-device/admins/ (backend#836 Q4). +func (c *Client) ListClientAdmins(ctx context.Context) ([]AdminContact, error) { + url := c.BaseURL + "/edge-device/admins/" + status, raw, err := c.get(ctx, "/edge-device/admins/") + if err != nil { + return nil, err + } + if status < 200 || status >= 300 { + return nil, &APIError{StatusCode: status, Body: string(raw), URL: url} + } + var out []AdminContact + if err := json.Unmarshal(raw, &out); err != nil { + return nil, fmt.Errorf("decoding admins response: %w", err) + } + return out, nil +} diff --git a/internal/cli/client.go b/internal/cli/client.go index ddda04f..d607ac9 100644 --- a/internal/cli/client.go +++ b/internal/cli/client.go @@ -1,53 +1,53 @@ package cli import ( + "context" + "crypto/rand" + "encoding/hex" "errors" + "fmt" + "net/http" + "strconv" + "strings" "github.com/spf13/cobra" + + "github.com/tracebloc/cli/internal/api" + "github.com/tracebloc/cli/internal/config" + "github.com/tracebloc/cli/internal/slug" + "github.com/tracebloc/cli/internal/ui" ) -// newClientCmd wires the `tracebloc client` subtree — provisioning and -// selecting the client (machine) this host enrolls as. The verbs are stubbed -// here: the implementation is cli#84 and depends on the backend device-grant -// (backend#835, for the user token from `tracebloc login`) and provisioning -// (backend#836). The command shape is in place now so the tree + help are -// stable and `--name`/`--location` are pinned. +// newClientCmd wires the `tracebloc client` subtree — provisioning + selecting +// the client (machine) this host enrolls as. Consumes the backend provisioning +// endpoints (backend#836) with the user token from `tracebloc login`. func newClientCmd() *cobra.Command { cmd := &cobra.Command{ Use: "client", Short: "Provision and manage tracebloc clients (machines)", Long: `Provision a tracebloc client for this machine and list/select clients -in your account. - -Requires sign-in first (` + "`tracebloc login`" + `). Implemented in cli#84; -the backend it calls lands in backend#835 / #836.`, +in your account. Requires sign-in first (` + "`tracebloc login`" + `).`, } cmd.AddCommand(newClientCreateCmd(), newClientListCmd(), newClientUseCmd()) return cmd } -// errClientNotYet is the shared "this lands in cli#84" stub error. -func errClientNotYet() error { - return &exitError{code: 1, err: errors.New( - "`tracebloc client` is not implemented yet — it lands in cli#84 and needs the " + - "backend device-grant (backend#835) + provisioning (backend#836). " + - "`tracebloc login` is the first piece.")} -} - func newClientCreateCmd() *cobra.Command { var name, location string + var yes bool cmd := &cobra.Command{ Use: "create", Short: "Provision a new client for this machine (--name, --location)", Args: cobra.NoArgs, - RunE: func(_ *cobra.Command, _ []string) error { - return errClientNotYet() + RunE: func(cmd *cobra.Command, _ []string) error { + return runClientCreate(cmd.Context(), printerFor(cmd), clientPrompter(), name, location, yes) }, } cmd.Flags().StringVar(&name, "name", "", "human-readable client name (shown on your dashboard + carbon reports)") cmd.Flags().StringVar(&location, "location", "", - "physical location zone for carbon footprint (e.g. DE); auto-detected + confirmed if omitted") + "location zone for carbon footprint (e.g. DE); prompted if omitted") + cmd.Flags().BoolVar(&yes, "yes", false, "skip the confirmation prompt") return cmd } @@ -57,8 +57,8 @@ func newClientListCmd() *cobra.Command { Aliases: []string{"ls"}, Short: "List the clients in your account", Args: cobra.NoArgs, - RunE: func(_ *cobra.Command, _ []string) error { - return errClientNotYet() + RunE: func(cmd *cobra.Command, _ []string) error { + return runClientList(cmd.Context(), printerFor(cmd)) }, } } @@ -68,8 +68,208 @@ func newClientUseCmd() *cobra.Command { Use: "use ", Short: "Enroll this machine as an existing client", Args: cobra.ExactArgs(1), - RunE: func(_ *cobra.Command, _ []string) error { - return errClientNotYet() + RunE: func(cmd *cobra.Command, args []string) error { + return runClientUse(cmd.Context(), printerFor(cmd), args[0]) }, } } + +// clientPrompter returns the interactive prompter on a TTY, else nil (so +// commands fall back to flags-only and never block on a pipe / in CI). +func clientPrompter() prompter { + if isInteractiveTTY() { + return surveyPrompter{} + } + return nil +} + +// authedClient loads the signed-in config and returns a token-bearing API +// client, or an error telling the user to log in. +func authedClient() (*api.Client, *config.Config, error) { + cfg, err := config.Load() + if err != nil { + return nil, nil, err + } + if !cfg.SignedIn() { + return nil, nil, errors.New("not signed in — run `tracebloc login` first") + } + env := cfg.Env + if env == "" { + env = api.ResolveEnv("") + } + client := newAPIClient(env) + client.Token = cfg.Token + return client, cfg, nil +} + +func runClientCreate(ctx context.Context, p *ui.Printer, pr prompter, name, location string, yes bool) error { + client, cfg, err := authedClient() + if err != nil { + return &exitError{code: 1, err: err} + } + + if name == "" { + if pr == nil { + return &exitError{code: 1, err: errors.New("--name is required (non-interactive)")} + } + if name, err = pr.Input("Client name", "shown on your dashboard + carbon reports", "", validateNonEmpty); err != nil { + return mapClientErr(err) + } + } + + // Derive the namespace slug from the name, avoiding collisions with existing + // clients (best-effort: if the list call fails we still derive a base slug). + var existing []string + if clients, lerr := client.ListClients(ctx); lerr == nil { + for _, c := range clients { + if c.Namespace != "" { + existing = append(existing, c.Namespace) + } + } + } + namespace, err := slug.Derive(name, existing, "client-"+randHex(4)) + if err != nil { + return &exitError{code: 1, err: err} + } + + if pr != nil && !yes { + ok, cerr := pr.Confirm(fmt.Sprintf("Provision client %q with namespace %q?", name, namespace), true) + if cerr != nil { + return mapClientErr(cerr) + } + if !ok { + p.Hintf("Cancelled.") + return nil + } + } + + if location == "" { + if pr == nil { + return &exitError{code: 1, err: errors.New("--location is required (non-interactive)")} + } + // Never silent-empty: the prompt requires a non-empty zone. (Cloud / + // GeoIP auto-detect of a suggested default is a fast-follow.) + if location, err = pr.Input("Location zone (e.g. DE)", "physical zone, for the carbon footprint", "", validateNonEmpty); err != nil { + return mapClientErr(err) + } + } + + // The machine credential: the CLI generates the password, the backend stores + // it (write-only). The client-runtime authenticates with username+password. + password := randHex(24) + pc, err := client.CreateClient(ctx, api.CreateClientRequest{ + Name: name, + Namespace: namespace, + Location: location, + Password: password, + }) + if err != nil { + var ae *api.APIError + if errors.As(err, &ae) && ae.StatusCode == http.StatusForbidden { + return askAnAdmin(ctx, p, client) + } + return &exitError{code: 1, err: err} + } + + cfg.ActiveClientID = strconv.Itoa(pc.ID) + if serr := cfg.Save(); serr != nil { + return &exitError{code: 1, err: serr} + } + + p.Newline() + p.Successf("Provisioned client %q (namespace %s).", pc.Name, pc.Namespace) + p.Section("Machine credential — needed by the installer to connect this client") + p.Field("client id", strconv.Itoa(pc.ID)) + p.Field("username", pc.Username) + p.Field("password", password) + return nil +} + +// askAnAdmin renders the "you can't provision — here's who can" path (a 403 from +// the backend means no CLIENT_WRITE; backend#836 Q4). +func askAnAdmin(ctx context.Context, p *ui.Printer, client *api.Client) error { + p.Newline() + p.Hintf("You don't have permission to provision a client in this account.") + if admins, err := client.ListClientAdmins(ctx); err == nil && len(admins) > 0 { + p.Section("Ask one of these admins to provision it (or grant you access)") + for _, a := range admins { + label := a.Name + if label == "" { + label = a.Email + } + p.Field(label, a.Email) + } + } + return &exitError{code: 1, err: errors.New("provisioning requires CLIENT_WRITE permission")} +} + +func runClientList(ctx context.Context, p *ui.Printer) error { + client, cfg, err := authedClient() + if err != nil { + return &exitError{code: 1, err: err} + } + clients, err := client.ListClients(ctx) + if err != nil { + return &exitError{code: 1, err: err} + } + if len(clients) == 0 { + p.Hintf("No clients yet. Run `tracebloc client create`.") + return nil + } + p.Section("Clients in your account") + for _, c := range clients { + marker := "" + if strconv.Itoa(c.ID) == cfg.ActiveClientID { + marker = " (active)" + } + p.Field(strconv.Itoa(c.ID)+marker, + fmt.Sprintf("%s namespace=%s location=%s", c.Name, c.Namespace, c.Location)) + } + return nil +} + +func runClientUse(ctx context.Context, p *ui.Printer, id string) error { + client, cfg, err := authedClient() + if err != nil { + return &exitError{code: 1, err: err} + } + clients, err := client.ListClients(ctx) + if err != nil { + return &exitError{code: 1, err: err} + } + for _, c := range clients { + if strconv.Itoa(c.ID) == id { + cfg.ActiveClientID = id + if serr := cfg.Save(); serr != nil { + return &exitError{code: 1, err: serr} + } + p.Successf("This machine is now set to enroll as client %s (%s).", id, c.Name) + return nil + } + } + return &exitError{code: 1, err: fmt.Errorf( + "no client %s in your account — run `tracebloc client list` to see the ids", id)} +} + +// validateNonEmpty rejects blank prompt input. +func validateNonEmpty(s string) error { + if strings.TrimSpace(s) == "" { + return errors.New("required") + } + return nil +} + +// mapClientErr turns a cancelled interactive prompt into a clean exit. +func mapClientErr(err error) error { + if errors.Is(err, errInteractiveCancelled) { + return nil + } + return &exitError{code: 1, err: err} +} + +// randHex returns nbytes of crypto-random data hex-encoded. +func randHex(nbytes int) string { + b := make([]byte, nbytes) + _, _ = rand.Read(b) // crypto/rand.Read does not fail on a healthy system + return hex.EncodeToString(b) +} diff --git a/internal/cli/client_test.go b/internal/cli/client_test.go new file mode 100644 index 0000000..767fe9c --- /dev/null +++ b/internal/cli/client_test.go @@ -0,0 +1,126 @@ +package cli + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/tracebloc/cli/internal/api" + "github.com/tracebloc/cli/internal/config" + "github.com/tracebloc/cli/internal/ui" +) + +// withClientBackend points the client commands at an httptest server (via the +// newAPIClient seam) and writes a signed-in config to a temp dir. +func withClientBackend(t *testing.T, h http.HandlerFunc) { + t.Helper() + srv := httptest.NewServer(h) + t.Cleanup(srv.Close) + t.Setenv("TRACEBLOC_CONFIG_DIR", t.TempDir()) + if err := (&config.Config{Env: "dev", Token: "tok"}).Save(); err != nil { + t.Fatal(err) + } + orig := newAPIClient + newAPIClient = func(string) *api.Client { + return &api.Client{BaseURL: srv.URL, HTTP: srv.Client()} + } + t.Cleanup(func() { newAPIClient = orig }) +} + +func TestClientCreate_Success(t *testing.T) { + var body api.CreateClientRequest + withClientBackend(t, func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodGet && r.URL.Path == "/edge-device/": + _, _ = w.Write([]byte(`[]`)) // no existing clients + case r.Method == http.MethodPost && r.URL.Path == "/edge-device/": + if got := r.Header.Get("Authorization"); got != "Bearer tok" { + t.Errorf("auth header = %q, want Bearer tok", got) + } + _ = json.NewDecoder(r.Body).Decode(&body) + w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte(`{"id":5,"first_name":"my-client","username":"u-123","namespace":"my-client","location":"DE"}`)) + default: + t.Errorf("unexpected %s %s", r.Method, r.URL.Path) + } + }) + var out bytes.Buffer + if err := runClientCreate(context.Background(), ui.New(&out), nil, "my-client", "DE", true); err != nil { + t.Fatalf("create: %v", err) + } + if body.Namespace != "my-client" || body.Location != "DE" || body.Password == "" { + t.Errorf("create body = %+v", body) + } + cfg, _ := config.Load() + if cfg.ActiveClientID != "5" { + t.Errorf("active client = %q, want 5", cfg.ActiveClientID) + } + if !strings.Contains(out.String(), "u-123") { + t.Errorf("output missing username:\n%s", out.String()) + } +} + +func TestClientCreate_AskAnAdmin(t *testing.T) { + withClientBackend(t, func(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/edge-device/" && r.Method == http.MethodGet: + _, _ = w.Write([]byte(`[]`)) + case r.URL.Path == "/edge-device/" && r.Method == http.MethodPost: + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"detail":"no permission"}`)) + case r.URL.Path == "/edge-device/admins/": + _, _ = w.Write([]byte(`[{"name":"Ada","email":"ada@co.io"}]`)) + } + }) + var out bytes.Buffer + err := runClientCreate(context.Background(), ui.New(&out), nil, "my-client", "DE", true) + if err == nil || !strings.Contains(err.Error(), "CLIENT_WRITE") { + t.Errorf("want permission error, got %v", err) + } + if !strings.Contains(out.String(), "ada@co.io") { + t.Errorf("expected admins shown, got:\n%s", out.String()) + } +} + +func TestClientCreate_RequiresLogin(t *testing.T) { + t.Setenv("TRACEBLOC_CONFIG_DIR", t.TempDir()) // no config → not signed in + err := runClientCreate(context.Background(), ui.New(&bytes.Buffer{}), nil, "x", "DE", true) + if err == nil || !strings.Contains(err.Error(), "login") { + t.Errorf("want not-signed-in error, got %v", err) + } +} + +func TestClientList(t *testing.T) { + withClientBackend(t, func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte(`[{"id":1,"first_name":"alpha","namespace":"alpha","location":"DE"},{"id":2,"first_name":"beta","namespace":"beta","location":"US"}]`)) + }) + var out bytes.Buffer + if err := runClientList(context.Background(), ui.New(&out)); err != nil { + t.Fatal(err) + } + for _, want := range []string{"alpha", "beta"} { + if !strings.Contains(out.String(), want) { + t.Errorf("list missing %q:\n%s", want, out.String()) + } + } +} + +func TestClientUse(t *testing.T) { + withClientBackend(t, func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte(`[{"id":7,"first_name":"gamma","namespace":"gamma"}]`)) + }) + if err := runClientUse(context.Background(), ui.New(&bytes.Buffer{}), "7"); err != nil { + t.Fatal(err) + } + cfg, _ := config.Load() + if cfg.ActiveClientID != "7" { + t.Errorf("active = %q, want 7", cfg.ActiveClientID) + } + if err := runClientUse(context.Background(), ui.New(&bytes.Buffer{}), "99"); err == nil { + t.Error("expected an error for an unknown client id") + } +} diff --git a/internal/slug/slug.go b/internal/slug/slug.go new file mode 100644 index 0000000..15e75aa --- /dev/null +++ b/internal/slug/slug.go @@ -0,0 +1,89 @@ +// Package slug ports the RFC-0001 namespace-slug rule (backend +// common/utils/slug.py, RFC-0001 Appendix B) to Go. A client's display name is +// slugified ONCE at provisioning into the immutable Kubernetes namespace, so +// this MUST stay in lock-step with the Python definition — the backend +// validates exactly what this produces (RFC-0001 backend#830; provisioning + +// namespace validation in backend#836). +package slug + +import ( + "fmt" + "regexp" + "strings" + + "golang.org/x/text/unicode/norm" +) + +// MaxLabelLength is the DNS-1123 label cap. +const MaxLabelLength = 63 + +var ( + nonAlnum = regexp.MustCompile(`[^a-z0-9]+`) + multiDash = regexp.MustCompile(`-+`) +) + +// Slugify maps name to a DNS-1123 label, or "" if nothing survives. Mirrors +// slug.slugify_dns1123: NFKD-transliterate to ASCII, lowercase, map every run +// of non-alphanumerics to a single "-", trim leading/trailing "-", cap at 63. +func Slugify(name string) string { + if name == "" { + return "" + } + s := strings.ToLower(toASCII(name)) + s = nonAlnum.ReplaceAllString(s, "-") + s = multiDash.ReplaceAllString(s, "-") + s = strings.Trim(s, "-") + if len(s) > MaxLabelLength { + s = s[:MaxLabelLength] + } + return strings.TrimRight(s, "-") +} + +// Derive returns a UNIQUE DNS-1123 slug for name, avoiding taken. On collision +// it appends -2, -3, … within the 63-char cap. If name slugifies to empty it +// falls back to fallback; with an empty fallback it errors (an empty slug must +// never silently become a namespace). Mirrors slug.derive_slug. +func Derive(name string, taken []string, fallback string) (string, error) { + base := Slugify(name) + if base == "" { + if fallback == "" { + return "", fmt.Errorf("name %q slugifies to empty; a fallback is required", name) + } + if base = Slugify(fallback); base == "" { + base = fallback + } + } + set := make(map[string]struct{}, len(taken)) + for _, t := range taken { + set[t] = struct{}{} + } + if _, clash := set[base]; !clash { + return base, nil + } + for n := 2; ; n++ { + suffix := fmt.Sprintf("-%d", n) + end := MaxLabelLength - len(suffix) + if end > len(base) { + end = len(base) + } + if end < 0 { + end = 0 + } + cand := strings.TrimRight(base[:end], "-") + suffix + if _, clash := set[cand]; !clash { + return cand, nil + } + } +} + +func toASCII(s string) string { + // NFKD decompose then drop non-ASCII — matches Python's + // unicodedata.normalize("NFKD", s).encode("ascii", "ignore"). + var b strings.Builder + for _, r := range norm.NFKD.String(s) { + if r < 128 { + b.WriteByte(byte(r)) + } + } + return b.String() +} diff --git a/internal/slug/slug_test.go b/internal/slug/slug_test.go new file mode 100644 index 0000000..ac5c516 --- /dev/null +++ b/internal/slug/slug_test.go @@ -0,0 +1,45 @@ +package slug + +import ( + "strings" + "testing" +) + +func TestSlugify(t *testing.T) { + cases := map[string]string{ + "My Client": "my-client", + "café": "cafe", // NFKD transliteration, matches slug.py + " spaces ": "spaces", + "UPPER_Case": "upper-case", + "a--b": "a-b", // collapse consecutive dashes + "--trim--": "trim", + "": "", + "!!!": "", + strings.Repeat("a", 70): strings.Repeat("a", 63), // 63-char cap + } + for in, want := range cases { + if got := Slugify(in); got != want { + t.Errorf("Slugify(%q) = %q, want %q", in, got, want) + } + } +} + +func TestDerive(t *testing.T) { + if got, err := Derive("My Client", nil, ""); err != nil || got != "my-client" { + t.Errorf("no collision: got %q, err %v", got, err) + } + if got, _ := Derive("My Client", []string{"my-client"}, ""); got != "my-client-2" { + t.Errorf("collision: got %q, want my-client-2", got) + } + if got, _ := Derive("My Client", []string{"my-client", "my-client-2"}, ""); got != "my-client-3" { + t.Errorf("double collision: got %q, want my-client-3", got) + } + // CJK slugifies to empty → uses fallback + if got, _ := Derive("世界", nil, "client-abc"); got != "client-abc" { + t.Errorf("fallback: got %q, want client-abc", got) + } + // empty slug + no fallback → error (never a silent-empty namespace) + if _, err := Derive("!!!", nil, ""); err == nil { + t.Error("expected an error for empty slug with no fallback") + } +} From 8058b1b2b63f9c001fab68f0a74002e742d3a97b Mon Sep 17 00:00:00 2001 From: Lukas Wuttke Date: Sun, 21 Jun 2026 12:03:38 +0200 Subject: [PATCH 2/3] refactor(cli): paginate client list + gather-then-review + parity/interactive tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Self-review follow-up on the client commands: - api: ListClients now follows DRF `next` to the end (was page-1 only), so `list`, `use `, and create-time collision detection see every client in the account, not just the first page. - cli: `create` gathers name + location first, then shows one review + a single confirm (was confirm-mid-flow) — matches the dataset-push interactive flow. - tests: committed slug golden-parity test (24 pairs verified byte-identical against the Python slugify_dns1123, incl. NFKD ligatures/fractions/roman/ fullwidth); interactive create + cancel via the prompter seam; paginated list; collision-suffix end-to-end. - slug: doc-note the redundant dash-collapse (mirrors slug.py) and the ""/None fallback divergence. Co-Authored-By: Claude Opus 4.8 (1M context) --- internal/api/client.go | 71 +++++++++++++++----- internal/cli/client.go | 43 ++++++++---- internal/cli/client_test.go | 107 ++++++++++++++++++++++++++++++ internal/slug/slug.go | 7 +- internal/slug/slug_golden_test.go | 44 ++++++++++++ 5 files changed, 240 insertions(+), 32 deletions(-) create mode 100644 internal/slug/slug_golden_test.go diff --git a/internal/api/client.go b/internal/api/client.go index be9c532..a636620 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -14,6 +14,7 @@ import ( "fmt" "io" "net/http" + "net/url" "os" "strings" "time" @@ -298,28 +299,62 @@ func (c *Client) CreateClient(ctx context.Context, req CreateClientRequest) (*Pr return &out, nil } -// ListClients returns the clients in the caller's account (GET /edge-device/). -// Tolerates both a DRF-paginated body and a bare list. +// maxListPages bounds how many pages ListClients will follow — a backstop +// against a misbehaving `next` chain, set well above any real account. +const maxListPages = 100 + +// ListClients returns ALL clients in the caller's account (GET /edge-device/). +// The endpoint is DRF-paginated, so this follows `next` to the end — list, +// `use `, and create-time collision detection must see every client, not +// just the first page. Also tolerates a bare (unpaginated) list body. func (c *Client) ListClients(ctx context.Context) ([]ProvisionedClient, error) { - url := c.BaseURL + "/edge-device/" - status, raw, err := c.get(ctx, "/edge-device/") - if err != nil { - return nil, err - } - if status < 200 || status >= 300 { - return nil, &APIError{StatusCode: status, Body: string(raw), URL: url} - } - var list []ProvisionedClient - if err := json.Unmarshal(raw, &list); err == nil { - return list, nil + var all []ProvisionedClient + path := "/edge-device/" + for pageNum := 0; path != ""; pageNum++ { + if pageNum >= maxListPages { + return nil, fmt.Errorf("client list exceeded %d pages — aborting", maxListPages) + } + reqURL := c.BaseURL + path + status, raw, err := c.get(ctx, path) + if err != nil { + return nil, err + } + if status < 200 || status >= 300 { + return nil, &APIError{StatusCode: status, Body: string(raw), URL: reqURL} + } + // Unpaginated deployment → a bare array; return it as-is. + var bare []ProvisionedClient + if err := json.Unmarshal(raw, &bare); err == nil { + return append(all, bare...), nil + } + var body struct { + Next string `json:"next"` + Results []ProvisionedClient `json:"results"` + } + if err := json.Unmarshal(raw, &body); err != nil { + return nil, fmt.Errorf("decoding client list: %w", err) + } + all = append(all, body.Results...) + path = nextPath(body.Next) } - var paged struct { - Results []ProvisionedClient `json:"results"` + return all, nil +} + +// nextPath reduces a DRF `next` link (an absolute URL) to the path+query this +// client appends to BaseURL. Returns "" for an empty/unparseable link, which +// ends the pagination loop. +func nextPath(next string) string { + if next == "" { + return "" + } + u, err := url.Parse(next) + if err != nil { + return "" } - if err := json.Unmarshal(raw, &paged); err != nil { - return nil, fmt.Errorf("decoding client list: %w", err) + if u.RawQuery != "" { + return u.Path + "?" + u.RawQuery } - return paged.Results, nil + return u.Path } // ListClientAdmins returns who in the account can provision (the ask-an-admin diff --git a/internal/cli/client.go b/internal/cli/client.go index d607ac9..be66b41 100644 --- a/internal/cli/client.go +++ b/internal/cli/client.go @@ -108,14 +108,26 @@ func runClientCreate(ctx context.Context, p *ui.Printer, pr prompter, name, loca return &exitError{code: 1, err: err} } + // Gather inputs first (flags win; prompt only what's missing, and only on a + // TTY), then show one review + confirm — matching the dataset-push flow. if name == "" { if pr == nil { - return &exitError{code: 1, err: errors.New("--name is required (non-interactive)")} + return errMissingFlag("--name") } if name, err = pr.Input("Client name", "shown on your dashboard + carbon reports", "", validateNonEmpty); err != nil { return mapClientErr(err) } } + if location == "" { + if pr == nil { + return errMissingFlag("--location") + } + // Never silent-empty: the prompt requires a non-empty zone. (Cloud / + // GeoIP auto-detect of a suggested default is a fast-follow.) + if location, err = pr.Input("Location zone (e.g. DE)", "physical zone, for the carbon footprint", "", validateNonEmpty); err != nil { + return mapClientErr(err) + } + } // Derive the namespace slug from the name, avoiding collisions with existing // clients (best-effort: if the list call fails we still derive a base slug). @@ -133,7 +145,8 @@ func runClientCreate(ctx context.Context, p *ui.Printer, pr prompter, name, loca } if pr != nil && !yes { - ok, cerr := pr.Confirm(fmt.Sprintf("Provision client %q with namespace %q?", name, namespace), true) + renderClientReview(p, name, namespace, location) + ok, cerr := pr.Confirm("Provision this client?", true) if cerr != nil { return mapClientErr(cerr) } @@ -143,17 +156,6 @@ func runClientCreate(ctx context.Context, p *ui.Printer, pr prompter, name, loca } } - if location == "" { - if pr == nil { - return &exitError{code: 1, err: errors.New("--location is required (non-interactive)")} - } - // Never silent-empty: the prompt requires a non-empty zone. (Cloud / - // GeoIP auto-detect of a suggested default is a fast-follow.) - if location, err = pr.Input("Location zone (e.g. DE)", "physical zone, for the carbon footprint", "", validateNonEmpty); err != nil { - return mapClientErr(err) - } - } - // The machine credential: the CLI generates the password, the backend stores // it (write-only). The client-runtime authenticates with username+password. password := randHex(24) @@ -251,6 +253,21 @@ func runClientUse(ctx context.Context, p *ui.Printer, id string) error { "no client %s in your account — run `tracebloc client list` to see the ids", id)} } +// renderClientReview shows the assembled inputs before the confirm prompt, so +// the user sees the derived namespace and location before anything is created. +func renderClientReview(p *ui.Printer, name, namespace, location string) { + p.Section("Review") + p.Field("name", name) + p.Field("namespace", namespace) + p.Field("location", location) +} + +// errMissingFlag reports a required flag absent in a non-interactive run (no TTY +// to prompt — CI, a pipe, or output redirected). +func errMissingFlag(flag string) error { + return &exitError{code: 1, err: fmt.Errorf("%s is required (non-interactive — no TTY to prompt)", flag)} +} + // validateNonEmpty rejects blank prompt input. func validateNonEmpty(s string) error { if strings.TrimSpace(s) == "" { diff --git a/internal/cli/client_test.go b/internal/cli/client_test.go index 767fe9c..d4ac0e4 100644 --- a/internal/cli/client_test.go +++ b/internal/cli/client_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "strings" @@ -124,3 +125,109 @@ func TestClientUse(t *testing.T) { t.Error("expected an error for an unknown client id") } } + +func TestClientCreate_Interactive(t *testing.T) { + var body api.CreateClientRequest + posted := false + withClientBackend(t, func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodGet && r.URL.Path == "/edge-device/": + _, _ = w.Write([]byte(`[]`)) + case r.Method == http.MethodPost && r.URL.Path == "/edge-device/": + posted = true + _ = json.NewDecoder(r.Body).Decode(&body) + w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte(`{"id":9,"first_name":"Lab One","username":"u-9","namespace":"lab-one","location":"DE"}`)) + } + }) + confirmYes := true + pr := &fakePrompter{ + answers: map[string]string{ + "Client name": "Lab One", + "Location zone (e.g. DE)": "DE", + }, + confirm: &confirmYes, + } + var out bytes.Buffer + if err := runClientCreate(context.Background(), ui.New(&out), pr, "", "", false); err != nil { + t.Fatalf("interactive create: %v", err) + } + if !posted { + t.Fatal("expected a POST after the user confirmed") + } + if body.Name != "Lab One" || body.Namespace != "lab-one" || body.Location != "DE" { + t.Errorf("create body = %+v", body) + } + if !strings.Contains(out.String(), "Review") { + t.Errorf("expected a review section before the confirm, got:\n%s", out.String()) + } +} + +func TestClientCreate_InteractiveCancel(t *testing.T) { + posted := false + withClientBackend(t, func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + posted = true + } + _, _ = w.Write([]byte(`[]`)) + }) + confirmNo := false + pr := &fakePrompter{ + answers: map[string]string{ + "Client name": "Lab Two", + "Location zone (e.g. DE)": "US", + }, + confirm: &confirmNo, + } + var out bytes.Buffer + if err := runClientCreate(context.Background(), ui.New(&out), pr, "", "", false); err != nil { + t.Fatalf("declining the confirm should be a clean exit, got: %v", err) + } + if posted { + t.Error("no client should be created when the user declines the confirm") + } + if !strings.Contains(out.String(), "Cancelled") { + t.Errorf("expected a Cancelled note, got:\n%s", out.String()) + } +} + +func TestClientList_Paginated(t *testing.T) { + withClientBackend(t, func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("page") == "2" { + _, _ = w.Write([]byte(`{"count":2,"next":null,"results":[{"id":2,"first_name":"beta","namespace":"beta"}]}`)) + return + } + // page 1: an absolute `next` link, like real DRF pagination + _, _ = fmt.Fprintf(w, `{"count":2,"next":"http://%s/edge-device/?page=2","results":[{"id":1,"first_name":"alpha","namespace":"alpha"}]}`, r.Host) + }) + var out bytes.Buffer + if err := runClientList(context.Background(), ui.New(&out)); err != nil { + t.Fatal(err) + } + for _, want := range []string{"alpha", "beta"} { + if !strings.Contains(out.String(), want) { + t.Errorf("paginated list missing %q (next not followed?):\n%s", want, out.String()) + } + } +} + +func TestClientCreate_CollisionSuffix(t *testing.T) { + var body api.CreateClientRequest + withClientBackend(t, func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodGet: + // an existing client already holds "my-client" + _, _ = w.Write([]byte(`[{"id":1,"first_name":"My Client","namespace":"my-client"}]`)) + case r.Method == http.MethodPost: + _ = json.NewDecoder(r.Body).Decode(&body) + w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte(`{"id":2,"first_name":"My Client","username":"u-2","namespace":"my-client-2","location":"DE"}`)) + } + }) + if err := runClientCreate(context.Background(), ui.New(&bytes.Buffer{}), nil, "My Client", "DE", true); err != nil { + t.Fatal(err) + } + if body.Namespace != "my-client-2" { + t.Errorf("namespace = %q, want my-client-2 (collision suffix not applied)", body.Namespace) + } +} diff --git a/internal/slug/slug.go b/internal/slug/slug.go index 15e75aa..fb3f8a0 100644 --- a/internal/slug/slug.go +++ b/internal/slug/slug.go @@ -31,6 +31,8 @@ func Slugify(name string) string { } s := strings.ToLower(toASCII(name)) s = nonAlnum.ReplaceAllString(s, "-") + // Redundant after the run-collapse above, but slug.py runs the identical + // second pass — mirror it so the two stay structurally in lock-step. s = multiDash.ReplaceAllString(s, "-") s = strings.Trim(s, "-") if len(s) > MaxLabelLength { @@ -42,7 +44,10 @@ func Slugify(name string) string { // Derive returns a UNIQUE DNS-1123 slug for name, avoiding taken. On collision // it appends -2, -3, … within the 63-char cap. If name slugifies to empty it // falls back to fallback; with an empty fallback it errors (an empty slug must -// never silently become a namespace). Mirrors slug.derive_slug. +// never silently become a namespace). Mirrors slug.derive_slug — except Go has +// no nil string, so fallback=="" is treated as "no fallback" (erroring); the +// Python distinguishes None from "" but no caller passes "", and erroring is +// the safer side of that edge. func Derive(name string, taken []string, fallback string) (string, error) { base := Slugify(name) if base == "" { diff --git a/internal/slug/slug_golden_test.go b/internal/slug/slug_golden_test.go new file mode 100644 index 0000000..04e25c2 --- /dev/null +++ b/internal/slug/slug_golden_test.go @@ -0,0 +1,44 @@ +package slug + +import "testing" + +// goldenPairs are (display name → DNS-1123 slug) pairs verified BYTE-IDENTICAL +// against the canonical Python slugify_dns1123 (backend common/utils/slug.py) +// with a cross-language harness. They lock the NFKD transliteration — the one +// place a Go port can silently drift from the backend validator that rejects +// what this produces. If slug.py ever changes, re-run the harness and update +// these rather than hand-editing. +var goldenPairs = []struct{ name, want string }{ + {"My Client", "my-client"}, + {"café", "cafe"}, // canonical NFKD: é → e + accent (dropped) + {"CAFÉ", "cafe"}, // + lowercase + {"Müller GmbH", "muller-gmbh"}, + {"Straße 42", "strae-42"}, // ß does not decompose → dropped (matches Python) + {"naïve", "naive"}, + {"São Paulo", "sao-paulo"}, + {"Zürich-Lab", "zurich-lab"}, + {"北京 client", "client"}, // CJK dropped + {"client🚀rocket", "clientrocket"}, // emoji dropped + {"über_cool", "uber-cool"}, + {"piñata", "pinata"}, + {"Ω omega", "omega"}, + {"fi-ligature", "fi-ligature"}, // compat NFKD: fi ligature → "fi" + {"①②③ circled", "123-circled"}, // circled digits → "123" + {"½ half", "12-half"}, // vulgar fraction → "1" + fraction-slash(dropped) + "2" + {"FULL width", "full-width"}, // fullwidth → ASCII + {"trailing---dashes---", "trailing-dashes"}, + {"UPPER MixED", "upper-mixed"}, + {"a.b.c-d_e f", "a-b-c-d-e-f"}, + {"2001: A Space Odyssey", "2001-a-space-odyssey"}, + {"Ⅻ roman", "xii-roman"}, // roman numeral → "XII" → "xii" + {"²cubed³", "2cubed3"}, // super/subscripts → digits + {"Hello @ World #1", "hello-world-1"}, +} + +func TestSlugifyGoldenParity(t *testing.T) { + for _, p := range goldenPairs { + if got := Slugify(p.name); got != p.want { + t.Errorf("Slugify(%q) = %q, want %q (golden parity with slug.py)", p.name, got, p.want) + } + } +} From 7efdd9c2e07975cbd810c8d69d2fdd287f4898eb Mon Sep 17 00:00:00 2001 From: Lukas Wuttke Date: Sun, 21 Jun 2026 14:44:17 +0200 Subject: [PATCH 3/3] feat(cli): location auto-detect for client create (#84) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The deferred fast-follow from cli#92: pre-fill `client create`'s location prompt with a detected electricityMaps zone (backend ZONE_CHOICES), so a cloud-hosted client doesn't have to look up its own zone. - internal/geo: Detect() probes cloud instance metadata first (AWS IMDSv2/v1, GCP, Azure — concurrently under one short deadline, first wins → high confidence), then Cloudflare IP geolocation (low confidence, flagged). Returns an ISO country code, always a valid top-level zone; cloud regions map via a curated AWS/GCP/Azure table, and an unmapped region falls through to GeoIP for a valid zone rather than suggest something the backend would reject. - client create: the detected zone pre-fills the prompt default (the user confirms with Enter or overrides) — still never silent, never empty. A detectZone seam keeps the command tests hermetic. - Best-effort: offline / egress-restricted / bare metal → empty default (the prior behavior). Only runs interactively when --location is omitted. Tests: geo per-provider + GeoIP fallback + unmapped-region fallthrough + nothing-detected + region→country table; cli accepts-detected-zone end-to-end. go build/vet/test green. Co-Authored-By: Claude Opus 4.8 (1M context) --- internal/cli/client.go | 19 +++- internal/cli/client_test.go | 40 ++++++++ internal/geo/geo.go | 192 ++++++++++++++++++++++++++++++++++++ internal/geo/geo_test.go | 150 ++++++++++++++++++++++++++++ internal/geo/regions.go | 67 +++++++++++++ 5 files changed, 465 insertions(+), 3 deletions(-) create mode 100644 internal/geo/geo.go create mode 100644 internal/geo/geo_test.go create mode 100644 internal/geo/regions.go diff --git a/internal/cli/client.go b/internal/cli/client.go index be66b41..1a207d4 100644 --- a/internal/cli/client.go +++ b/internal/cli/client.go @@ -14,6 +14,7 @@ import ( "github.com/tracebloc/cli/internal/api" "github.com/tracebloc/cli/internal/config" + "github.com/tracebloc/cli/internal/geo" "github.com/tracebloc/cli/internal/slug" "github.com/tracebloc/cli/internal/ui" ) @@ -102,6 +103,10 @@ func authedClient() (*api.Client, *config.Config, error) { return client, cfg, nil } +// detectZone suggests a location zone (cloud metadata → GeoIP). A seam so tests +// stay hermetic (no network). +var detectZone = geo.Detect + func runClientCreate(ctx context.Context, p *ui.Printer, pr prompter, name, location string, yes bool) error { client, cfg, err := authedClient() if err != nil { @@ -122,9 +127,17 @@ func runClientCreate(ctx context.Context, p *ui.Printer, pr prompter, name, loca if pr == nil { return errMissingFlag("--location") } - // Never silent-empty: the prompt requires a non-empty zone. (Cloud / - // GeoIP auto-detect of a suggested default is a fast-follow.) - if location, err = pr.Input("Location zone (e.g. DE)", "physical zone, for the carbon footprint", "", validateNonEmpty); err != nil { + // Auto-detect a suggested zone (cloud metadata → IP geolocation) and + // pre-fill it as the prompt default; the user confirms with Enter or + // overrides. Never silent (it's a prompt), never empty (validateNonEmpty). + suggested := "" + help := "electricityMaps zone for the carbon footprint (e.g. DE)" + if z := detectZone(ctx); z != nil { + suggested = z.Code + help = fmt.Sprintf("detected %s via %s (%s confidence) — Enter to accept, or type your zone", + z.Code, z.Source, z.Confidence) + } + if location, err = pr.Input("Location zone (e.g. DE)", help, suggested, validateNonEmpty); err != nil { return mapClientErr(err) } } diff --git a/internal/cli/client_test.go b/internal/cli/client_test.go index d4ac0e4..176728c 100644 --- a/internal/cli/client_test.go +++ b/internal/cli/client_test.go @@ -12,6 +12,7 @@ import ( "github.com/tracebloc/cli/internal/api" "github.com/tracebloc/cli/internal/config" + "github.com/tracebloc/cli/internal/geo" "github.com/tracebloc/cli/internal/ui" ) @@ -32,6 +33,15 @@ func withClientBackend(t *testing.T, h http.HandlerFunc) { t.Cleanup(func() { newAPIClient = orig }) } +// stubDetect replaces the location auto-detector so command tests stay hermetic +// (no real cloud-metadata / GeoIP probes). +func stubDetect(t *testing.T, z *geo.Zone) { + t.Helper() + orig := detectZone + detectZone = func(context.Context) *geo.Zone { return z } + t.Cleanup(func() { detectZone = orig }) +} + func TestClientCreate_Success(t *testing.T) { var body api.CreateClientRequest withClientBackend(t, func(w http.ResponseWriter, r *http.Request) { @@ -140,6 +150,7 @@ func TestClientCreate_Interactive(t *testing.T) { _, _ = w.Write([]byte(`{"id":9,"first_name":"Lab One","username":"u-9","namespace":"lab-one","location":"DE"}`)) } }) + stubDetect(t, nil) // hermetic: no real cloud/GeoIP probes confirmYes := true pr := &fakePrompter{ answers: map[string]string{ @@ -171,6 +182,7 @@ func TestClientCreate_InteractiveCancel(t *testing.T) { } _, _ = w.Write([]byte(`[]`)) }) + stubDetect(t, nil) confirmNo := false pr := &fakePrompter{ answers: map[string]string{ @@ -231,3 +243,31 @@ func TestClientCreate_CollisionSuffix(t *testing.T) { t.Errorf("namespace = %q, want my-client-2 (collision suffix not applied)", body.Namespace) } } + +func TestClientCreate_AcceptsDetectedZone(t *testing.T) { + var body api.CreateClientRequest + withClientBackend(t, func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodGet && r.URL.Path == "/edge-device/": + _, _ = w.Write([]byte(`[]`)) + case r.Method == http.MethodPost && r.URL.Path == "/edge-device/": + _ = json.NewDecoder(r.Body).Decode(&body) + w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte(`{"id":3,"first_name":"Edge","username":"u-3","namespace":"edge","location":"FR"}`)) + } + }) + // Detector suggests FR; the user accepts it — no scripted answer for the + // location prompt, so the fake returns the pre-filled default. + stubDetect(t, &geo.Zone{Code: "FR", Source: "aws", Confidence: geo.High}) + confirmYes := true + pr := &fakePrompter{ + answers: map[string]string{"Client name": "Edge"}, + confirm: &confirmYes, + } + if err := runClientCreate(context.Background(), ui.New(&bytes.Buffer{}), pr, "", "", false); err != nil { + t.Fatalf("create: %v", err) + } + if body.Location != "FR" { + t.Errorf("location = %q, want FR (detected zone accepted as the default)", body.Location) + } +} diff --git a/internal/geo/geo.go b/internal/geo/geo.go new file mode 100644 index 0000000..6079f78 --- /dev/null +++ b/internal/geo/geo.go @@ -0,0 +1,192 @@ +// Package geo best-effort detects the host's electricityMaps zone (backend +// ZONE_CHOICES) to pre-fill `client create`'s location prompt: cloud instance +// metadata first (high confidence — the VM reports its own region), then IP +// geolocation (low confidence — flagged). The result is only ever a SUGGESTED +// default the user confirms or overrides; detection failing just means an empty +// default (RFC-0001 location auto-detect, cli#84). +package geo + +import ( + "context" + "io" + "net/http" + "strings" + "time" +) + +// Confidence levels for a detected zone. +const ( + High = "high" // cloud instance metadata — the host runs in this region + Low = "low" // IP geolocation — can be wrong behind VPN / proxy / egress NAT +) + +// Zone is a best-effort location guess. Code is an ISO 3166-1 alpha-2 country +// (always a valid top-level electricityMaps zone); Source names how it was found. +type Zone struct { + Code string + Source string + Confidence string +} + +// Metadata / GeoIP endpoints — package vars so tests can point them at httptest. +var ( + awsIMDSBase = "http://169.254.169.254" + gcpMetaBase = "http://metadata.google.internal" + azureIMDSBase = "http://169.254.169.254" + geoIPURL = "https://www.cloudflare.com/cdn-cgi/trace" +) + +const ( + cloudProbeTimeout = 1500 * time.Millisecond + geoIPTimeout = 3 * time.Second +) + +var ( + // Metadata endpoints are link-local — never via a proxy, and fail fast. + metadataClient = &http.Client{Transport: &http.Transport{Proxy: nil}} + // GeoIP is a public host — honor the corporate proxy like the API client. + geoIPClient = &http.Client{Transport: &http.Transport{Proxy: http.ProxyFromEnvironment}} +) + +// Detect returns a best-effort zone, or nil if nothing could be determined +// (offline, egress-restricted, or bare metal with no usable IP geolocation). It +// never blocks long: the cloud probes share one short deadline and run +// concurrently; GeoIP is a single call only reached when the host isn't a +// recognized cloud region. +func Detect(ctx context.Context) *Zone { + if region, provider := probeCloud(ctx); region != "" { + if cc, ok := regionCountry(region); ok { + return &Zone{Code: cc, Source: provider, Confidence: High} + } + // A cloud host whose region isn't in the map — fall through to GeoIP for + // a VALID zone rather than suggest an unknown string the backend rejects. + } + if cc := probeGeoIP(ctx); cc != "" { + return &Zone{Code: cc, Source: "geoip", Confidence: Low} + } + return nil +} + +// probeCloud runs the three cloud probes concurrently under one deadline and +// returns the first that reports a region (so a real cloud host answers in one +// round-trip instead of waiting through the others' timeouts). +func probeCloud(ctx context.Context) (region, provider string) { + ctx, cancel := context.WithTimeout(ctx, cloudProbeTimeout) + defer cancel() + type res struct{ region, provider string } + probes := []struct { + name string + fn func(context.Context) string + }{ + {"aws", detectAWS}, + {"gcp", detectGCP}, + {"azure", detectAzure}, + } + ch := make(chan res, len(probes)) + for _, p := range probes { + p := p + go func() { ch <- res{p.fn(ctx), p.name} }() + } + for range probes { + if r := <-ch; r.region != "" { + return r.region, r.provider + } + } + return "", "" +} + +// detectAWS reads the region from EC2 IMDS, preferring IMDSv2 (token) and +// falling back to IMDSv1 (no token) if the token PUT is refused. +func detectAWS(ctx context.Context) string { + var token string + if req, err := http.NewRequestWithContext(ctx, http.MethodPut, awsIMDSBase+"/latest/api/token", nil); err == nil { + req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "60") + if t, ok := doText(metadataClient, req); ok { + token = t + } + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, awsIMDSBase+"/latest/meta-data/placement/region", nil) + if err != nil { + return "" + } + if token != "" { + req.Header.Set("X-aws-ec2-metadata-token", token) + } + region, _ := doText(metadataClient, req) + return region +} + +// detectGCP reads the instance zone and trims the trailing zone letter to a +// region ("projects/N/zones/europe-west3-c" → "europe-west3"). +func detectGCP(ctx context.Context) string { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, gcpMetaBase+"/computeMetadata/v1/instance/zone", nil) + if err != nil { + return "" + } + req.Header.Set("Metadata-Flavor", "Google") + zone, ok := doText(metadataClient, req) + if !ok || zone == "" { + return "" + } + if i := strings.LastIndex(zone, "/"); i >= 0 { + zone = zone[i+1:] + } + if i := strings.LastIndex(zone, "-"); i >= 0 { + zone = zone[:i] + } + return zone +} + +// detectAzure reads the compute location from Azure IMDS (already a region-like +// string, e.g. "germanywestcentral"). +func detectAzure(ctx context.Context) string { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, + azureIMDSBase+"/metadata/instance/compute/location?api-version=2021-02-01&format=text", nil) + if err != nil { + return "" + } + req.Header.Set("Metadata", "true") + loc, _ := doText(metadataClient, req) + return loc +} + +// probeGeoIP reads the ISO country from Cloudflare's trace endpoint (the `loc=` +// line) — HTTPS, no API key, returns a 2-letter country code. +func probeGeoIP(ctx context.Context) string { + ctx, cancel := context.WithTimeout(ctx, geoIPTimeout) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, geoIPURL, nil) + if err != nil { + return "" + } + body, ok := doText(geoIPClient, req) + if !ok { + return "" + } + for _, line := range strings.Split(body, "\n") { + if cc, found := strings.CutPrefix(line, "loc="); found { + cc = strings.TrimSpace(cc) + if len(cc) == 2 { + return strings.ToUpper(cc) + } + } + } + return "" +} + +// doText runs req and returns the trimmed body on a 2xx, else ("", false). +func doText(client *http.Client, req *http.Request) (string, bool) { + resp, err := client.Do(req) + if err != nil { + return "", false + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return "", false + } + b, err := io.ReadAll(io.LimitReader(resp.Body, 4096)) + if err != nil { + return "", false + } + return strings.TrimSpace(string(b)), true +} diff --git a/internal/geo/geo_test.go b/internal/geo/geo_test.go new file mode 100644 index 0000000..a91c4ef --- /dev/null +++ b/internal/geo/geo_test.go @@ -0,0 +1,150 @@ +package geo + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" +) + +// setEndpoints points the metadata / GeoIP endpoints at test servers. +func setEndpoints(t *testing.T, aws, gcp, azure, geoip string) { + t.Helper() + oa, og, oz, ogi := awsIMDSBase, gcpMetaBase, azureIMDSBase, geoIPURL + awsIMDSBase, gcpMetaBase, azureIMDSBase, geoIPURL = aws, gcp, azure, geoip + t.Cleanup(func() { awsIMDSBase, gcpMetaBase, azureIMDSBase, geoIPURL = oa, og, oz, ogi }) +} + +// notFoundServer is a stand-in for an absent provider (every probe 404s fast). +func notFoundServer(t *testing.T) string { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + t.Cleanup(srv.Close) + return srv.URL +} + +func TestDetect_AWS(t *testing.T) { + aws := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPut && r.URL.Path == "/latest/api/token": + _, _ = w.Write([]byte("tok-123")) + case r.Method == http.MethodGet && r.URL.Path == "/latest/meta-data/placement/region": + if r.Header.Get("X-aws-ec2-metadata-token") != "tok-123" { + w.WriteHeader(http.StatusUnauthorized) // enforce the IMDSv2 token + return + } + _, _ = w.Write([]byte("eu-central-1")) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + t.Cleanup(aws.Close) + nf := notFoundServer(t) + setEndpoints(t, aws.URL, nf, nf, nf) + + if z := Detect(context.Background()); z == nil || z.Code != "DE" || z.Source != "aws" || z.Confidence != High { + t.Fatalf("got %+v, want DE/aws/high", z) + } +} + +func TestDetect_GCP(t *testing.T) { + gcp := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/computeMetadata/v1/instance/zone" && r.Header.Get("Metadata-Flavor") == "Google" { + _, _ = w.Write([]byte("projects/123/zones/europe-west3-c")) + return + } + w.WriteHeader(http.StatusNotFound) + })) + t.Cleanup(gcp.Close) + nf := notFoundServer(t) + setEndpoints(t, nf, gcp.URL, nf, nf) + + if z := Detect(context.Background()); z == nil || z.Code != "DE" || z.Source != "gcp" || z.Confidence != High { + t.Fatalf("got %+v, want DE/gcp/high", z) + } +} + +func TestDetect_Azure(t *testing.T) { + az := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/metadata/instance/compute/location" && r.Header.Get("Metadata") == "true" { + _, _ = w.Write([]byte("germanywestcentral")) + return + } + w.WriteHeader(http.StatusNotFound) + })) + t.Cleanup(az.Close) + nf := notFoundServer(t) + setEndpoints(t, nf, nf, az.URL, nf) + + if z := Detect(context.Background()); z == nil || z.Code != "DE" || z.Source != "azure" || z.Confidence != High { + t.Fatalf("got %+v, want DE/azure/high", z) + } +} + +func TestDetect_GeoIPFallback(t *testing.T) { + geoip := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("fl=1f\nip=1.2.3.4\nloc=FR\ncolo=CDG\n")) + })) + t.Cleanup(geoip.Close) + nf := notFoundServer(t) + setEndpoints(t, nf, nf, nf, geoip.URL) + + if z := Detect(context.Background()); z == nil || z.Code != "FR" || z.Source != "geoip" || z.Confidence != Low { + t.Fatalf("got %+v, want FR/geoip/low", z) + } +} + +func TestDetect_UnmappedRegionFallsBackToGeoIP(t *testing.T) { + // A cloud region we don't map must NOT be suggested verbatim (the backend + // would reject it) — Detect falls through to GeoIP for a valid zone. + aws := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/latest/api/token": + _, _ = w.Write([]byte("t")) + case "/latest/meta-data/placement/region": + _, _ = w.Write([]byte("antarctica-south-1")) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + t.Cleanup(aws.Close) + geoip := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("loc=US\n")) + })) + t.Cleanup(geoip.Close) + nf := notFoundServer(t) + setEndpoints(t, aws.URL, nf, nf, geoip.URL) + + if z := Detect(context.Background()); z == nil || z.Code != "US" || z.Source != "geoip" { + t.Fatalf("got %+v, want US/geoip (unmapped region → GeoIP)", z) + } +} + +func TestDetect_Nothing(t *testing.T) { + nf := notFoundServer(t) + setEndpoints(t, nf, nf, nf, nf) + if z := Detect(context.Background()); z != nil { + t.Fatalf("got %+v, want nil", z) + } +} + +func TestRegionCountry(t *testing.T) { + cases := map[string]string{ + "eu-central-1": "DE", // AWS + "europe-west3": "DE", // GCP + "germanywestcentral": "DE", // Azure + "us-east-1": "US", + "ap-southeast-1": "SG", + "EU-WEST-2": "GB", // case-insensitive + } + for region, want := range cases { + if got, ok := regionCountry(region); !ok || got != want { + t.Errorf("regionCountry(%q) = %q,%v; want %q,true", region, got, ok, want) + } + } + if _, ok := regionCountry("mars-north-1"); ok { + t.Error("unmapped region should return ok=false") + } +} diff --git a/internal/geo/regions.go b/internal/geo/regions.go new file mode 100644 index 0000000..e22588d --- /dev/null +++ b/internal/geo/regions.go @@ -0,0 +1,67 @@ +package geo + +import "strings" + +// regionCountry maps a cloud region / location string to its ISO 3166-1 alpha-2 +// country — always a valid top-level electricityMaps zone (backend ZONE_CHOICES). +// It covers the common AWS / GCP / Azure regions; an unmapped region falls +// through to IP geolocation in Detect, so this need not be exhaustive — extend +// as new regions appear. Country = where the region's datacenters physically sit. +func regionCountry(region string) (string, bool) { + cc, ok := regionToCountry[strings.ToLower(strings.TrimSpace(region))] + return cc, ok +} + +var regionToCountry = map[string]string{ + // ── AWS ── + "us-east-1": "US", "us-east-2": "US", "us-west-1": "US", "us-west-2": "US", + "ca-central-1": "CA", "ca-west-1": "CA", + "eu-west-1": "IE", "eu-west-2": "GB", "eu-west-3": "FR", + "eu-central-1": "DE", "eu-central-2": "CH", + "eu-north-1": "SE", "eu-south-1": "IT", "eu-south-2": "ES", + "ap-south-1": "IN", "ap-south-2": "IN", + "ap-southeast-1": "SG", "ap-southeast-2": "AU", "ap-southeast-3": "ID", "ap-southeast-4": "AU", + "ap-northeast-1": "JP", "ap-northeast-2": "KR", "ap-northeast-3": "JP", + "ap-east-1": "HK", + "sa-east-1": "BR", + "me-south-1": "BH", "me-central-1": "AE", + "af-south-1": "ZA", + "il-central-1": "IL", + + // ── GCP ── + "us-central1": "US", "us-east1": "US", "us-east4": "US", "us-east5": "US", + "us-west1": "US", "us-west2": "US", "us-west3": "US", "us-west4": "US", "us-south1": "US", + "northamerica-northeast1": "CA", "northamerica-northeast2": "CA", + "southamerica-east1": "BR", "southamerica-west1": "CL", + "europe-west1": "BE", "europe-west2": "GB", "europe-west3": "DE", "europe-west4": "NL", + "europe-west6": "CH", "europe-west8": "IT", "europe-west9": "FR", "europe-west10": "DE", "europe-west12": "IT", + "europe-central2": "PL", "europe-north1": "FI", "europe-southwest1": "ES", + "asia-east1": "TW", "asia-east2": "HK", + "asia-northeast1": "JP", "asia-northeast2": "JP", "asia-northeast3": "KR", + "asia-south1": "IN", "asia-south2": "IN", + "asia-southeast1": "SG", "asia-southeast2": "ID", + "australia-southeast1": "AU", "australia-southeast2": "AU", + "me-west1": "IL", "me-central1": "QA", "me-central2": "SA", + + // ── Azure ── + "eastus": "US", "eastus2": "US", "centralus": "US", "northcentralus": "US", + "southcentralus": "US", "westus": "US", "westus2": "US", "westus3": "US", "westcentralus": "US", + "canadacentral": "CA", "canadaeast": "CA", + "brazilsouth": "BR", "brazilsoutheast": "BR", + "northeurope": "IE", "westeurope": "NL", + "uksouth": "GB", "ukwest": "GB", + "francecentral": "FR", "francesouth": "FR", + "germanywestcentral": "DE", "germanynorth": "DE", + "switzerlandnorth": "CH", "switzerlandwest": "CH", + "norwayeast": "NO", "norwaywest": "NO", + "swedencentral": "SE", "polandcentral": "PL", "italynorth": "IT", "spaincentral": "ES", + "eastasia": "HK", "southeastasia": "SG", + "japaneast": "JP", "japanwest": "JP", + "koreacentral": "KR", "koreasouth": "KR", + "centralindia": "IN", "southindia": "IN", "westindia": "IN", + "australiaeast": "AU", "australiasoutheast": "AU", "australiacentral": "AU", + "uaenorth": "AE", "uaecentral": "AE", + "qatarcentral": "QA", + "southafricanorth": "ZA", + "israelcentral": "IL", +}