diff --git a/api/oauth.go b/api/oauth.go new file mode 100644 index 0000000..60c6406 --- /dev/null +++ b/api/oauth.go @@ -0,0 +1,653 @@ +package api + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "embed" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "html/template" + "io" + "net" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "sync" + "time" +) + +// OAuth 2.0 + PKCE (S256) client for Hostinger SSO. It runs an interactive +// browser sign-in against a local loopback callback, persists the resulting +// tokens to disk, and refreshes them on demand so callers can obtain a bearer +// token without a statically configured API token. +const ( + defaultIssuer = "https://auth.hostinger.com" + defaultClientName = "hostinger-cli" + + registerPath = "/api/external/v1/oauth-server/register" + authorizePath = "/api/external/v1/oauth-server/authorize" + tokenPath = "/api/external/v1/oauth-server/token" + revokePath = "/api/external/v1/oauth-server/token/revoke" + + callbackPath = "/oauth/callback" + + // httpTimeout bounds a single auth-server round trip; loginTimeout bounds + // the interactive wait for the user to finish in their browser. Both are + // applied only when the caller's context carries no deadline of its own. + httpTimeout = 30 * time.Second + loginTimeout = 5 * time.Minute + + // expiryBufferSeconds is subtracted from the server-reported lifetime so we + // treat a token as expired slightly early and avoid a mid-request expiry. + expiryBufferSeconds = 60 +) + +// errRefreshTokenDead marks a refresh grant rejected with a 4xx — the refresh +// token is definitively dead and the caller must fall back to a full login. +// A 5xx or network failure is returned as a plain error instead, so transient +// outages never spuriously launch the browser. +var errRefreshTokenDead = errors.New("refresh token rejected") + +// openBrowser launches the user's default browser. It is a package var so tests +// can substitute it; in production it is fire-and-forget (see Login). +var openBrowser = func(rawURL string) error { + switch runtime.GOOS { + case "darwin": + return exec.Command("open", rawURL).Start() + case "windows": + // The empty quoted string is cmd.exe's required first "title" arg so + // that it interprets the URL as the target rather than the title. + return exec.Command("cmd", "/c", "start", "", rawURL).Start() + default: + return exec.Command("xdg-open", rawURL).Start() + } +} + +//go:embed templates/*.html +var templateFS embed.FS + +// Callback pages, parsed once at load and reused. The error page interpolates +// the untrusted auth-server error code via {{ . }}, which html/template escapes. +var ( + successTemplate = template.Must(template.ParseFS(templateFS, "templates/success.html")) + errorTemplate = template.Must(template.ParseFS(templateFS, "templates/error.html")) +) + +// OAuthService obtains and maintains a bearer token via the OAuth flow. +// The zero value is not usable; construct one with NewOAuthService. +type OAuthService struct { + issuer string + clientName string + credPath string + httpClient *http.Client + + // mu serializes the public methods so concurrent callers in one process + // share a single auth-server round trip rather than racing on disk. + mu sync.Mutex +} + +// credentials is the on-disk token state. Absent fields imply "not yet known". +type credentials struct { + ClientID string `json:"client_id,omitempty"` + AccessToken string `json:"access_token,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresAt int64 `json:"expires_at,omitempty"` // local expiry, ms epoch +} + +type tokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` +} + +// NewOAuthService builds a service with the default issuer (overridable via the +// HOSTINGER_OAUTH_ISSUER env var) and the standard on-disk credentials path. +func NewOAuthService() (*OAuthService, error) { + credPath, err := defaultCredPath() + if err != nil { + return nil, err + } + + issuer := defaultIssuer + if v := strings.TrimSpace(os.Getenv("HOSTINGER_OAUTH_ISSUER")); v != "" { + issuer = v + } + + return &OAuthService{ + issuer: strings.TrimRight(issuer, "/"), + clientName: defaultClientName, + credPath: credPath, + httpClient: &http.Client{}, + }, nil +} + +// Token returns a usable access token: the cached one if still valid, otherwise +// a proactive refresh, otherwise a full interactive login. +func (s *OAuthService) Token(ctx context.Context) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + return s.resolve(ctx) +} + +// Login runs the full interactive PKCE flow regardless of cache state. +func (s *OAuthService) Login(ctx context.Context) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + return s.login(ctx) +} + +// Refresh exchanges the stored refresh token for a fresh access token. +func (s *OAuthService) Refresh(ctx context.Context) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + + creds, err := s.loadCredentials() + if err != nil { + return "", err + } + if creds.RefreshToken == "" || creds.ClientID == "" { + return "", errors.New("no refresh token available; run an interactive login first") + } + return s.refresh(ctx, creds) +} + +// Reauthenticate forces a fresh token, bypassing the cached-token fast path. It +// mirrors Token's refresh-then-login fallback and is meant for the reactive 401 +// recovery path, where the cached token is already known to be dead. +func (s *OAuthService) Reauthenticate(ctx context.Context) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + + creds, _ := s.loadCredentials() + return s.refreshOrLogin(ctx, creds) +} + +// Logout best-effort revokes the access token and wipes local tokens, keeping +// the registered client_id so a later login can skip dynamic registration. +func (s *OAuthService) Logout(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + + creds, _ := s.loadCredentials() + if creds.ClientID == "" && creds.AccessToken == "" { + return nil // nothing stored + } + + if creds.AccessToken != "" { + form := url.Values{} + form.Set("token", creds.AccessToken) + form.Set("client_id", creds.ClientID) + // Revocation is best-effort: logout must succeed locally even if the + // auth server is unreachable. + _, _, _ = s.postForm(ctx, revokePath, form) + } + + return s.saveCredentials(credentials{ClientID: creds.ClientID}) +} + +// resolve is the token decision tree: cached -> refresh -> login. Callers must +// hold s.mu. +func (s *OAuthService) resolve(ctx context.Context) (string, error) { + creds, _ := s.loadCredentials() + + if creds.AccessToken != "" && creds.ExpiresAt != 0 && nowMillis() < creds.ExpiresAt { + return creds.AccessToken, nil + } + return s.refreshOrLogin(ctx, creds) +} + +// refreshOrLogin attempts a proactive refresh and falls back to a full +// interactive login when there is no usable refresh token or the refresh token +// is dead (4xx). A transient refresh failure (5xx/network) is surfaced as-is so +// it never spuriously launches a browser. Callers must hold s.mu. +func (s *OAuthService) refreshOrLogin(ctx context.Context, creds credentials) (string, error) { + if creds.RefreshToken != "" && creds.ClientID != "" { + token, err := s.refresh(ctx, creds) + if err == nil { + return token, nil + } + if !errors.Is(err, errRefreshTokenDead) { + return "", err + } + // 4xx: refresh token is dead, fall through to a full login. + } + return s.login(ctx) +} + +// login runs the PKCE flow end to end and returns the issued access token. +// Callers must hold s.mu. +func (s *OAuthService) login(ctx context.Context) (string, error) { + // Bind the loopback callback before launching the browser so the redirect + // can never race ahead of the listener. + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return "", fmt.Errorf("failed to start callback listener: %w", err) + } + defer listener.Close() + + port := listener.Addr().(*net.TCPAddr).Port + redirectURI := fmt.Sprintf("http://127.0.0.1:%d%s", port, callbackPath) + + creds, _ := s.loadCredentials() + if creds.ClientID == "" { + clientID, err := s.register(ctx, redirectURI) + if err != nil { + return "", err + } + creds.ClientID = clientID + if err := s.saveCredentials(creds); err != nil { + return "", err + } + } + + verifier, challenge, err := pkce() + if err != nil { + return "", err + } + state, err := randomState() + if err != nil { + return "", err + } + + authURL := s.authorizeURL(creds.ClientID, redirectURI, state, challenge) + if err := openBrowser(authURL); err != nil { + fmt.Fprintf(os.Stderr, "Could not open a browser automatically.\nOpen this URL to continue signing in:\n\n%s\n\n", authURL) + } + + waitCtx, cancel := withTimeout(ctx, loginTimeout) + defer cancel() + code, err := s.awaitCallback(waitCtx, listener, state) + if err != nil { + return "", err + } + + return s.exchangeCode(ctx, creds.ClientID, code, verifier, redirectURI) +} + +// register performs RFC 7591 dynamic client registration and returns the new +// client_id. It is a public (PKCE-only) client, so no client_secret is expected. +func (s *OAuthService) register(ctx context.Context, redirectURI string) (string, error) { + payload, err := json.Marshal(map[string]any{ + "client_name": s.clientName, + "redirect_uris": []string{redirectURI}, + }) + if err != nil { + return "", err + } + + status, body, err := s.postJSON(ctx, registerPath, payload) + if err != nil { + return "", err + } + if !is2xx(status) { + return "", fmt.Errorf("client registration failed (status %d): %s", status, snippet(body)) + } + + var out struct { + ClientID string `json:"client_id"` + } + if err := json.Unmarshal(body, &out); err != nil { + return "", fmt.Errorf("client registration: invalid response: %w", err) + } + if out.ClientID == "" { + return "", fmt.Errorf("client registration response missing client_id: %s", snippet(body)) + } + return out.ClientID, nil +} + +// exchangeCode swaps the authorization code for tokens and persists them. +func (s *OAuthService) exchangeCode(ctx context.Context, clientID, code, verifier, redirectURI string) (string, error) { + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("code", code) + form.Set("code_verifier", verifier) + form.Set("redirect_uri", redirectURI) + form.Set("client_id", clientID) + + status, body, err := s.postForm(ctx, tokenPath, form) + if err != nil { + return "", err + } + if !is2xx(status) { + return "", fmt.Errorf("token exchange failed (status %d): %s", status, snippet(body)) + } + + tr, err := parseTokenResponse(body) + if err != nil { + return "", err + } + + creds, _ := s.loadCredentials() + creds.ClientID = clientID + creds.AccessToken = tr.AccessToken + creds.RefreshToken = tr.RefreshToken + creds.ExpiresAt = expiresAtMillis(tr.ExpiresIn) + if err := s.saveCredentials(creds); err != nil { + return "", err + } + return creds.AccessToken, nil +} + +// refresh runs a refresh-token grant. A 4xx is reported as errRefreshTokenDead; +// a 5xx or network error is returned verbatim so the caller can treat it as +// transient. +func (s *OAuthService) refresh(ctx context.Context, creds credentials) (string, error) { + form := url.Values{} + form.Set("grant_type", "refresh_token") + form.Set("refresh_token", creds.RefreshToken) + form.Set("client_id", creds.ClientID) + + status, body, err := s.postForm(ctx, tokenPath, form) + if err != nil { + return "", err + } + switch { + case is2xx(status): + // handled below + case status >= 400 && status < 500: + return "", fmt.Errorf("%w (status %d)", errRefreshTokenDead, status) + default: + return "", fmt.Errorf("token refresh failed (status %d): %s", status, snippet(body)) + } + + tr, err := parseTokenResponse(body) + if err != nil { + return "", err + } + + creds.AccessToken = tr.AccessToken + if tr.RefreshToken != "" { + creds.RefreshToken = tr.RefreshToken // server may rotate; keep old otherwise + } + creds.ExpiresAt = expiresAtMillis(tr.ExpiresIn) + if err := s.saveCredentials(creds); err != nil { + return "", err + } + return creds.AccessToken, nil +} + +// awaitCallback serves the loopback listener until the OAuth redirect arrives, +// then returns the authorization code. Only GET /oauth/callback is honored; +// everything else is 404 and does not affect the flow. +func (s *OAuthService) awaitCallback(ctx context.Context, listener net.Listener, expectedState string) (string, error) { + type result struct { + code string + err error + } + // Buffered + non-blocking send: the first callback resolves the flow; any + // duplicate hit (browser retry, replay) is answered but its result dropped, + // so the handler goroutine never blocks on a full channel. + done := make(chan result, 1) + reply := func(res result) { + select { + case done <- res: + default: + } + } + + mux := http.NewServeMux() + mux.HandleFunc(callbackPath, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.NotFound(w, r) + return + } + q := r.URL.Query() + + if e := q.Get("error"); e != "" { + renderPage(w, http.StatusBadRequest, errorTemplate, e) + reply(result{err: fmt.Errorf("authorization failed: %s", e)}) + return + } + // CSRF guard: the state must match byte-for-byte what we generated. + if q.Get("state") != expectedState { + renderPage(w, http.StatusBadRequest, errorTemplate, "state mismatch") + reply(result{err: errors.New("oauth state mismatch")}) + return + } + code := q.Get("code") + if code == "" { + renderPage(w, http.StatusBadRequest, errorTemplate, "missing authorization code") + reply(result{err: errors.New("oauth callback missing authorization code")}) + return + } + renderPage(w, http.StatusOK, successTemplate, nil) + reply(result{code: code}) + }) + + srv := &http.Server{Handler: mux} + go srv.Serve(listener) //nolint:errcheck // returns ErrServerClosed on shutdown + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _ = srv.Shutdown(shutdownCtx) // graceful: lets the final response flush + }() + + select { + case <-ctx.Done(): + return "", ctx.Err() + case res := <-done: + return res.code, res.err + } +} + +func (s *OAuthService) authorizeURL(clientID, redirectURI, state, challenge string) string { + q := url.Values{} + q.Set("client_id", clientID) + q.Set("redirect_uri", redirectURI) + q.Set("state", state) + q.Set("code_challenge", challenge) + q.Set("code_challenge_method", "S256") + q.Set("response_type", "code") + return s.issuer + authorizePath + "?" + q.Encode() +} + +// --- HTTP plumbing --------------------------------------------------------- + +func (s *OAuthService) postForm(ctx context.Context, path string, form url.Values) (int, []byte, error) { + ctx, cancel := withTimeout(ctx, httpTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.issuer+path, strings.NewReader(form.Encode())) + if err != nil { + return 0, nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return s.do(req) +} + +func (s *OAuthService) postJSON(ctx context.Context, path string, payload []byte) (int, []byte, error) { + ctx, cancel := withTimeout(ctx, httpTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.issuer+path, bytes.NewReader(payload)) + if err != nil { + return 0, nil, err + } + req.Header.Set("Content-Type", "application/json") + return s.do(req) +} + +func (s *OAuthService) do(req *http.Request) (int, []byte, error) { + resp, err := s.httpClient.Do(req) + if err != nil { + return 0, nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return resp.StatusCode, nil, err + } + return resp.StatusCode, body, nil +} + +// --- credentials persistence ----------------------------------------------- + +// loadCredentials reads the on-disk state. A missing or corrupt file is treated +// as "no credentials" (zero value, nil error). +func (s *OAuthService) loadCredentials() (credentials, error) { + data, err := os.ReadFile(s.credPath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return credentials{}, nil + } + return credentials{}, err + } + + var creds credentials + if err := json.Unmarshal(data, &creds); err != nil { + return credentials{}, nil // corrupt: start fresh + } + return creds, nil +} + +// saveCredentials writes the state atomically (temp file + rename) with mode +// 0600, creating the parent directory on demand. +func (s *OAuthService) saveCredentials(creds credentials) error { + data, err := json.MarshalIndent(creds, "", " ") + if err != nil { + return err + } + + dir := filepath.Dir(s.credPath) + if err := os.MkdirAll(dir, 0700); err != nil { + return err + } + + tmp, err := os.CreateTemp(dir, "credentials-*.tmp") + if err != nil { + return err + } + tmpName := tmp.Name() + defer os.Remove(tmpName) // no-op once the rename succeeds + + if err := tmp.Chmod(0600); err != nil { + tmp.Close() + return err + } + if _, err := tmp.Write(data); err != nil { + tmp.Close() + return err + } + if err := tmp.Close(); err != nil { + return err + } + return os.Rename(tmpName, s.credPath) +} + +// --- helpers --------------------------------------------------------------- + +// defaultCredPath returns ~/.config/hostinger/api-cli/credentials.json on POSIX +// and %APPDATA%\hostinger\api-cli\credentials.json on Windows. +func defaultCredPath() (string, error) { + if runtime.GOOS == "windows" { + base := os.Getenv("APPDATA") + if base == "" { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + base = filepath.Join(home, "AppData", "Roaming") + } + return filepath.Join(base, "hostinger", "api-cli", "credentials.json"), nil + } + + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + return filepath.Join(home, ".config", "hostinger", "api-cli", "credentials.json"), nil +} + +// withTimeout applies d only when ctx carries no deadline of its own. +func withTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { + if _, ok := ctx.Deadline(); ok { + return ctx, func() {} + } + return context.WithTimeout(ctx, d) +} + +// pkce generates an S256 verifier/challenge pair from 32 bytes of CSPRNG output. +func pkce() (verifier, challenge string, err error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", "", err + } + verifier = base64.RawURLEncoding.EncodeToString(b) + sum := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(sum[:]) + return verifier, challenge, nil +} + +// randomState returns 16 bytes of CSPRNG output as lowercase hex (the CSRF token). +func randomState() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + +func parseTokenResponse(body []byte) (tokenResponse, error) { + var tr tokenResponse + if err := json.Unmarshal(body, &tr); err != nil { + return tokenResponse{}, fmt.Errorf("invalid token response: %w", err) + } + if tr.AccessToken == "" { + return tokenResponse{}, errors.New("token response missing access_token") + } + return tr, nil +} + +func expiresAtMillis(expiresIn int64) int64 { + // Clamp so a short-lived (or zero) server lifetime can't underflow into the + // past; the token is simply treated as already expired on the next call. + effective := expiresIn - expiryBufferSeconds + if effective < 0 { + effective = 0 + } + return time.Now().Add(time.Duration(effective) * time.Second).UnixMilli() +} + +func nowMillis() int64 { + return time.Now().UnixMilli() +} + +func is2xx(status int) bool { + return status >= 200 && status < 300 +} + +// renderPage executes a callback page template into the response. It renders to +// a buffer first so a template failure can't leave a half-written body, and +// flushes so the page reaches the browser before the listener is shut down. +func renderPage(w http.ResponseWriter, status int, tmpl *template.Template, data any) { + var buf bytes.Buffer + if err := tmpl.Execute(&buf, data); err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(status) + _, _ = buf.WriteTo(w) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } +} + +// snippet trims and bounds an untrusted response body for inclusion in errors. +func snippet(body []byte) string { + s := strings.TrimSpace(string(body)) + if len(s) > 512 { + s = s[:512] + "…" + } + return s +} diff --git a/api/oauth_test.go b/api/oauth_test.go new file mode 100644 index 0000000..69d3aec --- /dev/null +++ b/api/oauth_test.go @@ -0,0 +1,598 @@ +package api + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "runtime" + "strings" + "sync/atomic" + "testing" + "time" +) + +// newTestService builds a white-box service pointed at the given issuer with an +// isolated temp credentials file, so tests never touch the real ~/.config path. +func newTestService(t *testing.T, issuer string) *OAuthService { + t.Helper() + return &OAuthService{ + issuer: strings.TrimRight(issuer, "/"), + clientName: defaultClientName, + credPath: filepath.Join(t.TempDir(), "credentials.json"), + httpClient: &http.Client{}, + } +} + +func TestPKCE_S256Relationship(t *testing.T) { + verifier, challenge, err := pkce() + if err != nil { + t.Fatalf("pkce: %v", err) + } + // 32 raw bytes base64url-encoded is 43 chars. + if len(verifier) != 43 { + t.Errorf("verifier length = %d, want 43", len(verifier)) + } + sum := sha256.Sum256([]byte(verifier)) + want := base64.RawURLEncoding.EncodeToString(sum[:]) + if challenge != want { + t.Errorf("challenge = %q, want sha256(verifier) = %q", challenge, want) + } + // Distinct each call. + v2, _, _ := pkce() + if verifier == v2 { + t.Error("two pkce verifiers were identical") + } +} + +func TestRandomState(t *testing.T) { + s1, err := randomState() + if err != nil { + t.Fatalf("randomState: %v", err) + } + if len(s1) != 32 { // 16 bytes hex + t.Errorf("state length = %d, want 32", len(s1)) + } + s2, _ := randomState() + if s1 == s2 { + t.Error("two states were identical") + } +} + +func TestExpiresAtMillis(t *testing.T) { + before := time.Now().Add((3600 - expiryBufferSeconds) * time.Second).UnixMilli() + got := expiresAtMillis(3600) + after := time.Now().Add((3600 - expiryBufferSeconds) * time.Second).UnixMilli() + if got < before || got > after { + t.Errorf("expiresAtMillis(3600) = %d, want within [%d,%d]", got, before, after) + } + + // A lifetime shorter than the buffer must clamp to ~now, never to the past. + lower := time.Now().UnixMilli() + clamped := expiresAtMillis(10) + upper := time.Now().Add(time.Second).UnixMilli() + if clamped < lower || clamped > upper { + t.Errorf("expiresAtMillis(10) = %d, want clamped to ~now [%d,%d]", clamped, lower, upper) + } +} + +func TestSaveAndLoadCredentials(t *testing.T) { + s := newTestService(t, "http://unused") + in := credentials{ClientID: "cid", AccessToken: "at", RefreshToken: "rt", ExpiresAt: 123} + + if err := s.saveCredentials(in); err != nil { + t.Fatalf("saveCredentials: %v", err) + } + got, err := s.loadCredentials() + if err != nil { + t.Fatalf("loadCredentials: %v", err) + } + if got != in { + t.Errorf("round-trip = %+v, want %+v", got, in) + } + + if runtime.GOOS != "windows" { + info, err := os.Stat(s.credPath) + if err != nil { + t.Fatalf("stat: %v", err) + } + if perm := info.Mode().Perm(); perm != 0600 { + t.Errorf("file mode = %o, want 0600", perm) + } + } +} + +func TestLoadCredentials_MissingAndCorrupt(t *testing.T) { + s := newTestService(t, "http://unused") + + // Missing file -> empty, no error. + got, err := s.loadCredentials() + if err != nil || got != (credentials{}) { + t.Fatalf("missing file: got (%+v, %v), want (empty, nil)", got, err) + } + + // Corrupt JSON -> empty, no error. + if err := os.WriteFile(s.credPath, []byte("{not json"), 0600); err != nil { + t.Fatal(err) + } + got, err = s.loadCredentials() + if err != nil || got != (credentials{}) { + t.Fatalf("corrupt file: got (%+v, %v), want (empty, nil)", got, err) + } +} + +func TestAuthorizeURL(t *testing.T) { + s := newTestService(t, "https://auth.example.com") + raw := s.authorizeURL("cid", "http://127.0.0.1:5000/oauth/callback", "state123", "chal") + + u, err := url.Parse(raw) + if err != nil { + t.Fatalf("parse: %v", err) + } + if u.Host != "auth.example.com" || u.Path != authorizePath { + t.Errorf("host/path = %s%s", u.Host, u.Path) + } + q := u.Query() + checks := map[string]string{ + "client_id": "cid", + "redirect_uri": "http://127.0.0.1:5000/oauth/callback", + "state": "state123", + "code_challenge": "chal", + "code_challenge_method": "S256", + "response_type": "code", + } + for k, want := range checks { + if got := q.Get(k); got != want { + t.Errorf("query %q = %q, want %q", k, got, want) + } + } +} + +func TestRefresh_Success_RotatesToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != tokenPath { + t.Errorf("unexpected path %s", r.URL.Path) + } + _ = r.ParseForm() + if r.Form.Get("grant_type") != "refresh_token" { + t.Errorf("grant_type = %q", r.Form.Get("grant_type")) + } + writeJSON(w, 200, tokenResponse{AccessToken: "new-at", RefreshToken: "new-rt", ExpiresIn: 3600}) + })) + defer srv.Close() + + s := newTestService(t, srv.URL) + token, err := s.refresh(context.Background(), credentials{ClientID: "cid", RefreshToken: "old-rt"}) + if err != nil { + t.Fatalf("refresh: %v", err) + } + if token != "new-at" { + t.Errorf("token = %q, want new-at", token) + } + + got, _ := s.loadCredentials() + if got.AccessToken != "new-at" || got.RefreshToken != "new-rt" { + t.Errorf("persisted = %+v, want at=new-at rt=new-rt", got) + } + if got.ExpiresAt <= nowMillis() { + t.Errorf("expires_at = %d not in the future", got.ExpiresAt) + } +} + +func TestRefresh_KeepsOldRefreshTokenWhenOmitted(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + writeJSON(w, 200, tokenResponse{AccessToken: "new-at", ExpiresIn: 3600}) // no refresh_token + })) + defer srv.Close() + + s := newTestService(t, srv.URL) + if _, err := s.refresh(context.Background(), credentials{ClientID: "cid", RefreshToken: "keep-me"}); err != nil { + t.Fatalf("refresh: %v", err) + } + got, _ := s.loadCredentials() + if got.RefreshToken != "keep-me" { + t.Errorf("refresh token = %q, want keep-me", got.RefreshToken) + } +} + +func TestRefresh_4xxIsDead(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid_grant"}) + })) + defer srv.Close() + + s := newTestService(t, srv.URL) + _, err := s.refresh(context.Background(), credentials{ClientID: "cid", RefreshToken: "rt"}) + if !errors.Is(err, errRefreshTokenDead) { + t.Fatalf("err = %v, want errRefreshTokenDead", err) + } +} + +func TestRefresh_5xxIsTransient(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "boom", http.StatusInternalServerError) + })) + defer srv.Close() + + s := newTestService(t, srv.URL) + _, err := s.refresh(context.Background(), credentials{ClientID: "cid", RefreshToken: "rt"}) + if err == nil { + t.Fatal("expected error on 5xx") + } + if errors.Is(err, errRefreshTokenDead) { + t.Error("5xx must not be classified as dead refresh token") + } +} + +func TestRegister(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != registerPath { + t.Errorf("unexpected path %s", r.URL.Path) + } + var body map[string]any + _ = json.NewDecoder(r.Body).Decode(&body) + if body["client_name"] != defaultClientName { + t.Errorf("client_name = %v", body["client_name"]) + } + writeJSON(w, 200, map[string]any{"client_id": "client-xyz"}) + })) + defer srv.Close() + + s := newTestService(t, srv.URL) + id, err := s.register(context.Background(), "http://127.0.0.1:1/oauth/callback") + if err != nil { + t.Fatalf("register: %v", err) + } + if id != "client-xyz" { + t.Errorf("client_id = %q, want client-xyz", id) + } +} + +func TestRegister_MissingClientID(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + writeJSON(w, 200, map[string]any{}) // no client_id + })) + defer srv.Close() + + s := newTestService(t, srv.URL) + if _, err := s.register(context.Background(), "http://127.0.0.1:1/oauth/callback"); err == nil { + t.Fatal("expected error when client_id is missing") + } +} + +func TestRegister_4xx(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid_redirect_uri"}) + })) + defer srv.Close() + + s := newTestService(t, srv.URL) + if _, err := s.register(context.Background(), "http://127.0.0.1:1/oauth/callback"); err == nil { + t.Fatal("expected error on 4xx registration") + } +} + +func TestExchangeCode_PersistsCreds(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseForm() + if r.Form.Get("grant_type") != "authorization_code" { + t.Errorf("grant_type = %q", r.Form.Get("grant_type")) + } + if r.Form.Get("code") != "the-code" || r.Form.Get("code_verifier") == "" { + t.Errorf("missing code/verifier: %v", r.Form) + } + writeJSON(w, 200, tokenResponse{AccessToken: "at", RefreshToken: "rt", ExpiresIn: 3600}) + })) + defer srv.Close() + + s := newTestService(t, srv.URL) + token, err := s.exchangeCode(context.Background(), "cid", "the-code", "verifier", "http://127.0.0.1:1/oauth/callback") + if err != nil { + t.Fatalf("exchangeCode: %v", err) + } + if token != "at" { + t.Errorf("token = %q, want at", token) + } + got, _ := s.loadCredentials() + if got.ClientID != "cid" || got.AccessToken != "at" || got.RefreshToken != "rt" { + t.Errorf("persisted creds = %+v", got) + } +} + +func TestResolve_HappyPathMakesNoHTTPCalls(t *testing.T) { + var hits int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&hits, 1) + })) + defer srv.Close() + + s := newTestService(t, srv.URL) + // Valid, unexpired cached token. + _ = s.saveCredentials(credentials{ + ClientID: "cid", + AccessToken: "cached-at", + RefreshToken: "rt", + ExpiresAt: nowMillis() + 60_000, + }) + + token, err := s.resolve(context.Background()) + if err != nil { + t.Fatalf("resolve: %v", err) + } + if token != "cached-at" { + t.Errorf("token = %q, want cached-at", token) + } + if n := atomic.LoadInt32(&hits); n != 0 { + t.Errorf("made %d HTTP calls on the happy path, want 0", n) + } +} + +func TestResolve_ExpiredTokenTriggersRefresh(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseForm() + if r.Form.Get("grant_type") != "refresh_token" { + t.Errorf("expected refresh grant, got %q", r.Form.Get("grant_type")) + } + writeJSON(w, 200, tokenResponse{AccessToken: "refreshed-at", RefreshToken: "rt2", ExpiresIn: 3600}) + })) + defer srv.Close() + + s := newTestService(t, srv.URL) + _ = s.saveCredentials(credentials{ + ClientID: "cid", + AccessToken: "stale-at", + RefreshToken: "rt", + ExpiresAt: nowMillis() - 1000, // expired + }) + + token, err := s.resolve(context.Background()) + if err != nil { + t.Fatalf("resolve: %v", err) + } + if token != "refreshed-at" { + t.Errorf("token = %q, want refreshed-at", token) + } +} + +func TestAwaitCallback_Success(t *testing.T) { + s := newTestService(t, "http://unused") + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + port := ln.Addr().(*net.TCPAddr).Port + + type res struct { + code string + err error + } + out := make(chan res, 1) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + go func() { + code, err := s.awaitCallback(ctx, ln, "state-ok") + out <- res{code, err} + }() + + // A non-callback path must 404 and leave the flow untouched. + resp404, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/whatever", port)) + if err != nil { + t.Fatalf("get /whatever: %v", err) + } + if resp404.StatusCode != http.StatusNotFound { + t.Errorf("wrong-path status = %d, want 404", resp404.StatusCode) + } + resp404.Body.Close() + + // The real callback resolves the flow. + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/oauth/callback?code=abc&state=state-ok", port)) + if err != nil { + t.Fatalf("get callback: %v", err) + } + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("callback status = %d, want 200", resp.StatusCode) + } + // The embedded success template should have rendered into the response. + if !strings.Contains(strings.ToLower(string(body)), "signed in") { + t.Errorf("success page body missing expected content: %q", string(body)) + } + + r := <-out + if r.err != nil || r.code != "abc" { + t.Errorf("awaitCallback = (%q, %v), want (abc, nil)", r.code, r.err) + } +} + +func TestAwaitCallback_StateMismatch(t *testing.T) { + code, err := runCallback(t, "expected-state", "/oauth/callback?code=abc&state=wrong") + if err == nil || !strings.Contains(err.Error(), "state mismatch") { + t.Fatalf("err = %v, want state mismatch", err) + } + if code != "" { + t.Errorf("code = %q, want empty", code) + } +} + +func TestAwaitCallback_AuthServerError(t *testing.T) { + code, err := runCallback(t, "s", "/oauth/callback?error=access_denied&state=s") + if err == nil || !strings.Contains(err.Error(), "access_denied") { + t.Fatalf("err = %v, want access_denied", err) + } + if code != "" { + t.Errorf("code = %q, want empty", code) + } +} + +// runCallback drives awaitCallback through a single GET and returns its result. +func runCallback(t *testing.T, expectedState, requestPath string) (string, error) { + t.Helper() + s := newTestService(t, "http://unused") + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + port := ln.Addr().(*net.TCPAddr).Port + + type res struct { + code string + err error + } + out := make(chan res, 1) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + go func() { + code, err := s.awaitCallback(ctx, ln, expectedState) + out <- res{code, err} + }() + + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d%s", port, requestPath)) + if err != nil { + t.Fatalf("get: %v", err) + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + + r := <-out + return r.code, r.err +} + +func TestErrorTemplate_EscapesUntrustedInput(t *testing.T) { + var buf bytes.Buffer + if err := errorTemplate.Execute(&buf, ``); err != nil { + t.Fatalf("execute error template: %v", err) + } + page := buf.String() + if strings.Contains(page, "