diff --git a/api/oauth.go b/api/oauth.go
new file mode 100644
index 0000000..60c6406
--- /dev/null
+++ b/api/oauth.go
@@ -0,0 +1,653 @@
+package api
+
+import (
+ "bytes"
+ "context"
+ "crypto/rand"
+ "crypto/sha256"
+ "embed"
+ "encoding/base64"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "html/template"
+ "io"
+ "net"
+ "net/http"
+ "net/url"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "sync"
+ "time"
+)
+
+// OAuth 2.0 + PKCE (S256) client for Hostinger SSO. It runs an interactive
+// browser sign-in against a local loopback callback, persists the resulting
+// tokens to disk, and refreshes them on demand so callers can obtain a bearer
+// token without a statically configured API token.
+const (
+ defaultIssuer = "https://auth.hostinger.com"
+ defaultClientName = "hostinger-cli"
+
+ registerPath = "/api/external/v1/oauth-server/register"
+ authorizePath = "/api/external/v1/oauth-server/authorize"
+ tokenPath = "/api/external/v1/oauth-server/token"
+ revokePath = "/api/external/v1/oauth-server/token/revoke"
+
+ callbackPath = "/oauth/callback"
+
+ // httpTimeout bounds a single auth-server round trip; loginTimeout bounds
+ // the interactive wait for the user to finish in their browser. Both are
+ // applied only when the caller's context carries no deadline of its own.
+ httpTimeout = 30 * time.Second
+ loginTimeout = 5 * time.Minute
+
+ // expiryBufferSeconds is subtracted from the server-reported lifetime so we
+ // treat a token as expired slightly early and avoid a mid-request expiry.
+ expiryBufferSeconds = 60
+)
+
+// errRefreshTokenDead marks a refresh grant rejected with a 4xx — the refresh
+// token is definitively dead and the caller must fall back to a full login.
+// A 5xx or network failure is returned as a plain error instead, so transient
+// outages never spuriously launch the browser.
+var errRefreshTokenDead = errors.New("refresh token rejected")
+
+// openBrowser launches the user's default browser. It is a package var so tests
+// can substitute it; in production it is fire-and-forget (see Login).
+var openBrowser = func(rawURL string) error {
+ switch runtime.GOOS {
+ case "darwin":
+ return exec.Command("open", rawURL).Start()
+ case "windows":
+ // The empty quoted string is cmd.exe's required first "title" arg so
+ // that it interprets the URL as the target rather than the title.
+ return exec.Command("cmd", "/c", "start", "", rawURL).Start()
+ default:
+ return exec.Command("xdg-open", rawURL).Start()
+ }
+}
+
+//go:embed templates/*.html
+var templateFS embed.FS
+
+// Callback pages, parsed once at load and reused. The error page interpolates
+// the untrusted auth-server error code via {{ . }}, which html/template escapes.
+var (
+ successTemplate = template.Must(template.ParseFS(templateFS, "templates/success.html"))
+ errorTemplate = template.Must(template.ParseFS(templateFS, "templates/error.html"))
+)
+
+// OAuthService obtains and maintains a bearer token via the OAuth flow.
+// The zero value is not usable; construct one with NewOAuthService.
+type OAuthService struct {
+ issuer string
+ clientName string
+ credPath string
+ httpClient *http.Client
+
+ // mu serializes the public methods so concurrent callers in one process
+ // share a single auth-server round trip rather than racing on disk.
+ mu sync.Mutex
+}
+
+// credentials is the on-disk token state. Absent fields imply "not yet known".
+type credentials struct {
+ ClientID string `json:"client_id,omitempty"`
+ AccessToken string `json:"access_token,omitempty"`
+ RefreshToken string `json:"refresh_token,omitempty"`
+ ExpiresAt int64 `json:"expires_at,omitempty"` // local expiry, ms epoch
+}
+
+type tokenResponse struct {
+ AccessToken string `json:"access_token"`
+ RefreshToken string `json:"refresh_token"`
+ ExpiresIn int64 `json:"expires_in"`
+}
+
+// NewOAuthService builds a service with the default issuer (overridable via the
+// HOSTINGER_OAUTH_ISSUER env var) and the standard on-disk credentials path.
+func NewOAuthService() (*OAuthService, error) {
+ credPath, err := defaultCredPath()
+ if err != nil {
+ return nil, err
+ }
+
+ issuer := defaultIssuer
+ if v := strings.TrimSpace(os.Getenv("HOSTINGER_OAUTH_ISSUER")); v != "" {
+ issuer = v
+ }
+
+ return &OAuthService{
+ issuer: strings.TrimRight(issuer, "/"),
+ clientName: defaultClientName,
+ credPath: credPath,
+ httpClient: &http.Client{},
+ }, nil
+}
+
+// Token returns a usable access token: the cached one if still valid, otherwise
+// a proactive refresh, otherwise a full interactive login.
+func (s *OAuthService) Token(ctx context.Context) (string, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return s.resolve(ctx)
+}
+
+// Login runs the full interactive PKCE flow regardless of cache state.
+func (s *OAuthService) Login(ctx context.Context) (string, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return s.login(ctx)
+}
+
+// Refresh exchanges the stored refresh token for a fresh access token.
+func (s *OAuthService) Refresh(ctx context.Context) (string, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ creds, err := s.loadCredentials()
+ if err != nil {
+ return "", err
+ }
+ if creds.RefreshToken == "" || creds.ClientID == "" {
+ return "", errors.New("no refresh token available; run an interactive login first")
+ }
+ return s.refresh(ctx, creds)
+}
+
+// Reauthenticate forces a fresh token, bypassing the cached-token fast path. It
+// mirrors Token's refresh-then-login fallback and is meant for the reactive 401
+// recovery path, where the cached token is already known to be dead.
+func (s *OAuthService) Reauthenticate(ctx context.Context) (string, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ creds, _ := s.loadCredentials()
+ return s.refreshOrLogin(ctx, creds)
+}
+
+// Logout best-effort revokes the access token and wipes local tokens, keeping
+// the registered client_id so a later login can skip dynamic registration.
+func (s *OAuthService) Logout(ctx context.Context) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ creds, _ := s.loadCredentials()
+ if creds.ClientID == "" && creds.AccessToken == "" {
+ return nil // nothing stored
+ }
+
+ if creds.AccessToken != "" {
+ form := url.Values{}
+ form.Set("token", creds.AccessToken)
+ form.Set("client_id", creds.ClientID)
+ // Revocation is best-effort: logout must succeed locally even if the
+ // auth server is unreachable.
+ _, _, _ = s.postForm(ctx, revokePath, form)
+ }
+
+ return s.saveCredentials(credentials{ClientID: creds.ClientID})
+}
+
+// resolve is the token decision tree: cached -> refresh -> login. Callers must
+// hold s.mu.
+func (s *OAuthService) resolve(ctx context.Context) (string, error) {
+ creds, _ := s.loadCredentials()
+
+ if creds.AccessToken != "" && creds.ExpiresAt != 0 && nowMillis() < creds.ExpiresAt {
+ return creds.AccessToken, nil
+ }
+ return s.refreshOrLogin(ctx, creds)
+}
+
+// refreshOrLogin attempts a proactive refresh and falls back to a full
+// interactive login when there is no usable refresh token or the refresh token
+// is dead (4xx). A transient refresh failure (5xx/network) is surfaced as-is so
+// it never spuriously launches a browser. Callers must hold s.mu.
+func (s *OAuthService) refreshOrLogin(ctx context.Context, creds credentials) (string, error) {
+ if creds.RefreshToken != "" && creds.ClientID != "" {
+ token, err := s.refresh(ctx, creds)
+ if err == nil {
+ return token, nil
+ }
+ if !errors.Is(err, errRefreshTokenDead) {
+ return "", err
+ }
+ // 4xx: refresh token is dead, fall through to a full login.
+ }
+ return s.login(ctx)
+}
+
+// login runs the PKCE flow end to end and returns the issued access token.
+// Callers must hold s.mu.
+func (s *OAuthService) login(ctx context.Context) (string, error) {
+ // Bind the loopback callback before launching the browser so the redirect
+ // can never race ahead of the listener.
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ return "", fmt.Errorf("failed to start callback listener: %w", err)
+ }
+ defer listener.Close()
+
+ port := listener.Addr().(*net.TCPAddr).Port
+ redirectURI := fmt.Sprintf("http://127.0.0.1:%d%s", port, callbackPath)
+
+ creds, _ := s.loadCredentials()
+ if creds.ClientID == "" {
+ clientID, err := s.register(ctx, redirectURI)
+ if err != nil {
+ return "", err
+ }
+ creds.ClientID = clientID
+ if err := s.saveCredentials(creds); err != nil {
+ return "", err
+ }
+ }
+
+ verifier, challenge, err := pkce()
+ if err != nil {
+ return "", err
+ }
+ state, err := randomState()
+ if err != nil {
+ return "", err
+ }
+
+ authURL := s.authorizeURL(creds.ClientID, redirectURI, state, challenge)
+ if err := openBrowser(authURL); err != nil {
+ fmt.Fprintf(os.Stderr, "Could not open a browser automatically.\nOpen this URL to continue signing in:\n\n%s\n\n", authURL)
+ }
+
+ waitCtx, cancel := withTimeout(ctx, loginTimeout)
+ defer cancel()
+ code, err := s.awaitCallback(waitCtx, listener, state)
+ if err != nil {
+ return "", err
+ }
+
+ return s.exchangeCode(ctx, creds.ClientID, code, verifier, redirectURI)
+}
+
+// register performs RFC 7591 dynamic client registration and returns the new
+// client_id. It is a public (PKCE-only) client, so no client_secret is expected.
+func (s *OAuthService) register(ctx context.Context, redirectURI string) (string, error) {
+ payload, err := json.Marshal(map[string]any{
+ "client_name": s.clientName,
+ "redirect_uris": []string{redirectURI},
+ })
+ if err != nil {
+ return "", err
+ }
+
+ status, body, err := s.postJSON(ctx, registerPath, payload)
+ if err != nil {
+ return "", err
+ }
+ if !is2xx(status) {
+ return "", fmt.Errorf("client registration failed (status %d): %s", status, snippet(body))
+ }
+
+ var out struct {
+ ClientID string `json:"client_id"`
+ }
+ if err := json.Unmarshal(body, &out); err != nil {
+ return "", fmt.Errorf("client registration: invalid response: %w", err)
+ }
+ if out.ClientID == "" {
+ return "", fmt.Errorf("client registration response missing client_id: %s", snippet(body))
+ }
+ return out.ClientID, nil
+}
+
+// exchangeCode swaps the authorization code for tokens and persists them.
+func (s *OAuthService) exchangeCode(ctx context.Context, clientID, code, verifier, redirectURI string) (string, error) {
+ form := url.Values{}
+ form.Set("grant_type", "authorization_code")
+ form.Set("code", code)
+ form.Set("code_verifier", verifier)
+ form.Set("redirect_uri", redirectURI)
+ form.Set("client_id", clientID)
+
+ status, body, err := s.postForm(ctx, tokenPath, form)
+ if err != nil {
+ return "", err
+ }
+ if !is2xx(status) {
+ return "", fmt.Errorf("token exchange failed (status %d): %s", status, snippet(body))
+ }
+
+ tr, err := parseTokenResponse(body)
+ if err != nil {
+ return "", err
+ }
+
+ creds, _ := s.loadCredentials()
+ creds.ClientID = clientID
+ creds.AccessToken = tr.AccessToken
+ creds.RefreshToken = tr.RefreshToken
+ creds.ExpiresAt = expiresAtMillis(tr.ExpiresIn)
+ if err := s.saveCredentials(creds); err != nil {
+ return "", err
+ }
+ return creds.AccessToken, nil
+}
+
+// refresh runs a refresh-token grant. A 4xx is reported as errRefreshTokenDead;
+// a 5xx or network error is returned verbatim so the caller can treat it as
+// transient.
+func (s *OAuthService) refresh(ctx context.Context, creds credentials) (string, error) {
+ form := url.Values{}
+ form.Set("grant_type", "refresh_token")
+ form.Set("refresh_token", creds.RefreshToken)
+ form.Set("client_id", creds.ClientID)
+
+ status, body, err := s.postForm(ctx, tokenPath, form)
+ if err != nil {
+ return "", err
+ }
+ switch {
+ case is2xx(status):
+ // handled below
+ case status >= 400 && status < 500:
+ return "", fmt.Errorf("%w (status %d)", errRefreshTokenDead, status)
+ default:
+ return "", fmt.Errorf("token refresh failed (status %d): %s", status, snippet(body))
+ }
+
+ tr, err := parseTokenResponse(body)
+ if err != nil {
+ return "", err
+ }
+
+ creds.AccessToken = tr.AccessToken
+ if tr.RefreshToken != "" {
+ creds.RefreshToken = tr.RefreshToken // server may rotate; keep old otherwise
+ }
+ creds.ExpiresAt = expiresAtMillis(tr.ExpiresIn)
+ if err := s.saveCredentials(creds); err != nil {
+ return "", err
+ }
+ return creds.AccessToken, nil
+}
+
+// awaitCallback serves the loopback listener until the OAuth redirect arrives,
+// then returns the authorization code. Only GET /oauth/callback is honored;
+// everything else is 404 and does not affect the flow.
+func (s *OAuthService) awaitCallback(ctx context.Context, listener net.Listener, expectedState string) (string, error) {
+ type result struct {
+ code string
+ err error
+ }
+ // Buffered + non-blocking send: the first callback resolves the flow; any
+ // duplicate hit (browser retry, replay) is answered but its result dropped,
+ // so the handler goroutine never blocks on a full channel.
+ done := make(chan result, 1)
+ reply := func(res result) {
+ select {
+ case done <- res:
+ default:
+ }
+ }
+
+ mux := http.NewServeMux()
+ mux.HandleFunc(callbackPath, func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodGet {
+ http.NotFound(w, r)
+ return
+ }
+ q := r.URL.Query()
+
+ if e := q.Get("error"); e != "" {
+ renderPage(w, http.StatusBadRequest, errorTemplate, e)
+ reply(result{err: fmt.Errorf("authorization failed: %s", e)})
+ return
+ }
+ // CSRF guard: the state must match byte-for-byte what we generated.
+ if q.Get("state") != expectedState {
+ renderPage(w, http.StatusBadRequest, errorTemplate, "state mismatch")
+ reply(result{err: errors.New("oauth state mismatch")})
+ return
+ }
+ code := q.Get("code")
+ if code == "" {
+ renderPage(w, http.StatusBadRequest, errorTemplate, "missing authorization code")
+ reply(result{err: errors.New("oauth callback missing authorization code")})
+ return
+ }
+ renderPage(w, http.StatusOK, successTemplate, nil)
+ reply(result{code: code})
+ })
+
+ srv := &http.Server{Handler: mux}
+ go srv.Serve(listener) //nolint:errcheck // returns ErrServerClosed on shutdown
+ defer func() {
+ shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
+ _ = srv.Shutdown(shutdownCtx) // graceful: lets the final response flush
+ }()
+
+ select {
+ case <-ctx.Done():
+ return "", ctx.Err()
+ case res := <-done:
+ return res.code, res.err
+ }
+}
+
+func (s *OAuthService) authorizeURL(clientID, redirectURI, state, challenge string) string {
+ q := url.Values{}
+ q.Set("client_id", clientID)
+ q.Set("redirect_uri", redirectURI)
+ q.Set("state", state)
+ q.Set("code_challenge", challenge)
+ q.Set("code_challenge_method", "S256")
+ q.Set("response_type", "code")
+ return s.issuer + authorizePath + "?" + q.Encode()
+}
+
+// --- HTTP plumbing ---------------------------------------------------------
+
+func (s *OAuthService) postForm(ctx context.Context, path string, form url.Values) (int, []byte, error) {
+ ctx, cancel := withTimeout(ctx, httpTimeout)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.issuer+path, strings.NewReader(form.Encode()))
+ if err != nil {
+ return 0, nil, err
+ }
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ return s.do(req)
+}
+
+func (s *OAuthService) postJSON(ctx context.Context, path string, payload []byte) (int, []byte, error) {
+ ctx, cancel := withTimeout(ctx, httpTimeout)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.issuer+path, bytes.NewReader(payload))
+ if err != nil {
+ return 0, nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ return s.do(req)
+}
+
+func (s *OAuthService) do(req *http.Request) (int, []byte, error) {
+ resp, err := s.httpClient.Do(req)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return resp.StatusCode, nil, err
+ }
+ return resp.StatusCode, body, nil
+}
+
+// --- credentials persistence -----------------------------------------------
+
+// loadCredentials reads the on-disk state. A missing or corrupt file is treated
+// as "no credentials" (zero value, nil error).
+func (s *OAuthService) loadCredentials() (credentials, error) {
+ data, err := os.ReadFile(s.credPath)
+ if err != nil {
+ if errors.Is(err, os.ErrNotExist) {
+ return credentials{}, nil
+ }
+ return credentials{}, err
+ }
+
+ var creds credentials
+ if err := json.Unmarshal(data, &creds); err != nil {
+ return credentials{}, nil // corrupt: start fresh
+ }
+ return creds, nil
+}
+
+// saveCredentials writes the state atomically (temp file + rename) with mode
+// 0600, creating the parent directory on demand.
+func (s *OAuthService) saveCredentials(creds credentials) error {
+ data, err := json.MarshalIndent(creds, "", " ")
+ if err != nil {
+ return err
+ }
+
+ dir := filepath.Dir(s.credPath)
+ if err := os.MkdirAll(dir, 0700); err != nil {
+ return err
+ }
+
+ tmp, err := os.CreateTemp(dir, "credentials-*.tmp")
+ if err != nil {
+ return err
+ }
+ tmpName := tmp.Name()
+ defer os.Remove(tmpName) // no-op once the rename succeeds
+
+ if err := tmp.Chmod(0600); err != nil {
+ tmp.Close()
+ return err
+ }
+ if _, err := tmp.Write(data); err != nil {
+ tmp.Close()
+ return err
+ }
+ if err := tmp.Close(); err != nil {
+ return err
+ }
+ return os.Rename(tmpName, s.credPath)
+}
+
+// --- helpers ---------------------------------------------------------------
+
+// defaultCredPath returns ~/.config/hostinger/api-cli/credentials.json on POSIX
+// and %APPDATA%\hostinger\api-cli\credentials.json on Windows.
+func defaultCredPath() (string, error) {
+ if runtime.GOOS == "windows" {
+ base := os.Getenv("APPDATA")
+ if base == "" {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return "", err
+ }
+ base = filepath.Join(home, "AppData", "Roaming")
+ }
+ return filepath.Join(base, "hostinger", "api-cli", "credentials.json"), nil
+ }
+
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return "", err
+ }
+ return filepath.Join(home, ".config", "hostinger", "api-cli", "credentials.json"), nil
+}
+
+// withTimeout applies d only when ctx carries no deadline of its own.
+func withTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
+ if _, ok := ctx.Deadline(); ok {
+ return ctx, func() {}
+ }
+ return context.WithTimeout(ctx, d)
+}
+
+// pkce generates an S256 verifier/challenge pair from 32 bytes of CSPRNG output.
+func pkce() (verifier, challenge string, err error) {
+ b := make([]byte, 32)
+ if _, err := rand.Read(b); err != nil {
+ return "", "", err
+ }
+ verifier = base64.RawURLEncoding.EncodeToString(b)
+ sum := sha256.Sum256([]byte(verifier))
+ challenge = base64.RawURLEncoding.EncodeToString(sum[:])
+ return verifier, challenge, nil
+}
+
+// randomState returns 16 bytes of CSPRNG output as lowercase hex (the CSRF token).
+func randomState() (string, error) {
+ b := make([]byte, 16)
+ if _, err := rand.Read(b); err != nil {
+ return "", err
+ }
+ return hex.EncodeToString(b), nil
+}
+
+func parseTokenResponse(body []byte) (tokenResponse, error) {
+ var tr tokenResponse
+ if err := json.Unmarshal(body, &tr); err != nil {
+ return tokenResponse{}, fmt.Errorf("invalid token response: %w", err)
+ }
+ if tr.AccessToken == "" {
+ return tokenResponse{}, errors.New("token response missing access_token")
+ }
+ return tr, nil
+}
+
+func expiresAtMillis(expiresIn int64) int64 {
+ // Clamp so a short-lived (or zero) server lifetime can't underflow into the
+ // past; the token is simply treated as already expired on the next call.
+ effective := expiresIn - expiryBufferSeconds
+ if effective < 0 {
+ effective = 0
+ }
+ return time.Now().Add(time.Duration(effective) * time.Second).UnixMilli()
+}
+
+func nowMillis() int64 {
+ return time.Now().UnixMilli()
+}
+
+func is2xx(status int) bool {
+ return status >= 200 && status < 300
+}
+
+// renderPage executes a callback page template into the response. It renders to
+// a buffer first so a template failure can't leave a half-written body, and
+// flushes so the page reaches the browser before the listener is shut down.
+func renderPage(w http.ResponseWriter, status int, tmpl *template.Template, data any) {
+ var buf bytes.Buffer
+ if err := tmpl.Execute(&buf, data); err != nil {
+ http.Error(w, "internal error", http.StatusInternalServerError)
+ return
+ }
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ w.WriteHeader(status)
+ _, _ = buf.WriteTo(w)
+ if f, ok := w.(http.Flusher); ok {
+ f.Flush()
+ }
+}
+
+// snippet trims and bounds an untrusted response body for inclusion in errors.
+func snippet(body []byte) string {
+ s := strings.TrimSpace(string(body))
+ if len(s) > 512 {
+ s = s[:512] + "…"
+ }
+ return s
+}
diff --git a/api/oauth_test.go b/api/oauth_test.go
new file mode 100644
index 0000000..69d3aec
--- /dev/null
+++ b/api/oauth_test.go
@@ -0,0 +1,598 @@
+package api
+
+import (
+ "bytes"
+ "context"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+// newTestService builds a white-box service pointed at the given issuer with an
+// isolated temp credentials file, so tests never touch the real ~/.config path.
+func newTestService(t *testing.T, issuer string) *OAuthService {
+ t.Helper()
+ return &OAuthService{
+ issuer: strings.TrimRight(issuer, "/"),
+ clientName: defaultClientName,
+ credPath: filepath.Join(t.TempDir(), "credentials.json"),
+ httpClient: &http.Client{},
+ }
+}
+
+func TestPKCE_S256Relationship(t *testing.T) {
+ verifier, challenge, err := pkce()
+ if err != nil {
+ t.Fatalf("pkce: %v", err)
+ }
+ // 32 raw bytes base64url-encoded is 43 chars.
+ if len(verifier) != 43 {
+ t.Errorf("verifier length = %d, want 43", len(verifier))
+ }
+ sum := sha256.Sum256([]byte(verifier))
+ want := base64.RawURLEncoding.EncodeToString(sum[:])
+ if challenge != want {
+ t.Errorf("challenge = %q, want sha256(verifier) = %q", challenge, want)
+ }
+ // Distinct each call.
+ v2, _, _ := pkce()
+ if verifier == v2 {
+ t.Error("two pkce verifiers were identical")
+ }
+}
+
+func TestRandomState(t *testing.T) {
+ s1, err := randomState()
+ if err != nil {
+ t.Fatalf("randomState: %v", err)
+ }
+ if len(s1) != 32 { // 16 bytes hex
+ t.Errorf("state length = %d, want 32", len(s1))
+ }
+ s2, _ := randomState()
+ if s1 == s2 {
+ t.Error("two states were identical")
+ }
+}
+
+func TestExpiresAtMillis(t *testing.T) {
+ before := time.Now().Add((3600 - expiryBufferSeconds) * time.Second).UnixMilli()
+ got := expiresAtMillis(3600)
+ after := time.Now().Add((3600 - expiryBufferSeconds) * time.Second).UnixMilli()
+ if got < before || got > after {
+ t.Errorf("expiresAtMillis(3600) = %d, want within [%d,%d]", got, before, after)
+ }
+
+ // A lifetime shorter than the buffer must clamp to ~now, never to the past.
+ lower := time.Now().UnixMilli()
+ clamped := expiresAtMillis(10)
+ upper := time.Now().Add(time.Second).UnixMilli()
+ if clamped < lower || clamped > upper {
+ t.Errorf("expiresAtMillis(10) = %d, want clamped to ~now [%d,%d]", clamped, lower, upper)
+ }
+}
+
+func TestSaveAndLoadCredentials(t *testing.T) {
+ s := newTestService(t, "http://unused")
+ in := credentials{ClientID: "cid", AccessToken: "at", RefreshToken: "rt", ExpiresAt: 123}
+
+ if err := s.saveCredentials(in); err != nil {
+ t.Fatalf("saveCredentials: %v", err)
+ }
+ got, err := s.loadCredentials()
+ if err != nil {
+ t.Fatalf("loadCredentials: %v", err)
+ }
+ if got != in {
+ t.Errorf("round-trip = %+v, want %+v", got, in)
+ }
+
+ if runtime.GOOS != "windows" {
+ info, err := os.Stat(s.credPath)
+ if err != nil {
+ t.Fatalf("stat: %v", err)
+ }
+ if perm := info.Mode().Perm(); perm != 0600 {
+ t.Errorf("file mode = %o, want 0600", perm)
+ }
+ }
+}
+
+func TestLoadCredentials_MissingAndCorrupt(t *testing.T) {
+ s := newTestService(t, "http://unused")
+
+ // Missing file -> empty, no error.
+ got, err := s.loadCredentials()
+ if err != nil || got != (credentials{}) {
+ t.Fatalf("missing file: got (%+v, %v), want (empty, nil)", got, err)
+ }
+
+ // Corrupt JSON -> empty, no error.
+ if err := os.WriteFile(s.credPath, []byte("{not json"), 0600); err != nil {
+ t.Fatal(err)
+ }
+ got, err = s.loadCredentials()
+ if err != nil || got != (credentials{}) {
+ t.Fatalf("corrupt file: got (%+v, %v), want (empty, nil)", got, err)
+ }
+}
+
+func TestAuthorizeURL(t *testing.T) {
+ s := newTestService(t, "https://auth.example.com")
+ raw := s.authorizeURL("cid", "http://127.0.0.1:5000/oauth/callback", "state123", "chal")
+
+ u, err := url.Parse(raw)
+ if err != nil {
+ t.Fatalf("parse: %v", err)
+ }
+ if u.Host != "auth.example.com" || u.Path != authorizePath {
+ t.Errorf("host/path = %s%s", u.Host, u.Path)
+ }
+ q := u.Query()
+ checks := map[string]string{
+ "client_id": "cid",
+ "redirect_uri": "http://127.0.0.1:5000/oauth/callback",
+ "state": "state123",
+ "code_challenge": "chal",
+ "code_challenge_method": "S256",
+ "response_type": "code",
+ }
+ for k, want := range checks {
+ if got := q.Get(k); got != want {
+ t.Errorf("query %q = %q, want %q", k, got, want)
+ }
+ }
+}
+
+func TestRefresh_Success_RotatesToken(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != tokenPath {
+ t.Errorf("unexpected path %s", r.URL.Path)
+ }
+ _ = r.ParseForm()
+ if r.Form.Get("grant_type") != "refresh_token" {
+ t.Errorf("grant_type = %q", r.Form.Get("grant_type"))
+ }
+ writeJSON(w, 200, tokenResponse{AccessToken: "new-at", RefreshToken: "new-rt", ExpiresIn: 3600})
+ }))
+ defer srv.Close()
+
+ s := newTestService(t, srv.URL)
+ token, err := s.refresh(context.Background(), credentials{ClientID: "cid", RefreshToken: "old-rt"})
+ if err != nil {
+ t.Fatalf("refresh: %v", err)
+ }
+ if token != "new-at" {
+ t.Errorf("token = %q, want new-at", token)
+ }
+
+ got, _ := s.loadCredentials()
+ if got.AccessToken != "new-at" || got.RefreshToken != "new-rt" {
+ t.Errorf("persisted = %+v, want at=new-at rt=new-rt", got)
+ }
+ if got.ExpiresAt <= nowMillis() {
+ t.Errorf("expires_at = %d not in the future", got.ExpiresAt)
+ }
+}
+
+func TestRefresh_KeepsOldRefreshTokenWhenOmitted(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ writeJSON(w, 200, tokenResponse{AccessToken: "new-at", ExpiresIn: 3600}) // no refresh_token
+ }))
+ defer srv.Close()
+
+ s := newTestService(t, srv.URL)
+ if _, err := s.refresh(context.Background(), credentials{ClientID: "cid", RefreshToken: "keep-me"}); err != nil {
+ t.Fatalf("refresh: %v", err)
+ }
+ got, _ := s.loadCredentials()
+ if got.RefreshToken != "keep-me" {
+ t.Errorf("refresh token = %q, want keep-me", got.RefreshToken)
+ }
+}
+
+func TestRefresh_4xxIsDead(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid_grant"})
+ }))
+ defer srv.Close()
+
+ s := newTestService(t, srv.URL)
+ _, err := s.refresh(context.Background(), credentials{ClientID: "cid", RefreshToken: "rt"})
+ if !errors.Is(err, errRefreshTokenDead) {
+ t.Fatalf("err = %v, want errRefreshTokenDead", err)
+ }
+}
+
+func TestRefresh_5xxIsTransient(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ http.Error(w, "boom", http.StatusInternalServerError)
+ }))
+ defer srv.Close()
+
+ s := newTestService(t, srv.URL)
+ _, err := s.refresh(context.Background(), credentials{ClientID: "cid", RefreshToken: "rt"})
+ if err == nil {
+ t.Fatal("expected error on 5xx")
+ }
+ if errors.Is(err, errRefreshTokenDead) {
+ t.Error("5xx must not be classified as dead refresh token")
+ }
+}
+
+func TestRegister(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != registerPath {
+ t.Errorf("unexpected path %s", r.URL.Path)
+ }
+ var body map[string]any
+ _ = json.NewDecoder(r.Body).Decode(&body)
+ if body["client_name"] != defaultClientName {
+ t.Errorf("client_name = %v", body["client_name"])
+ }
+ writeJSON(w, 200, map[string]any{"client_id": "client-xyz"})
+ }))
+ defer srv.Close()
+
+ s := newTestService(t, srv.URL)
+ id, err := s.register(context.Background(), "http://127.0.0.1:1/oauth/callback")
+ if err != nil {
+ t.Fatalf("register: %v", err)
+ }
+ if id != "client-xyz" {
+ t.Errorf("client_id = %q, want client-xyz", id)
+ }
+}
+
+func TestRegister_MissingClientID(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ writeJSON(w, 200, map[string]any{}) // no client_id
+ }))
+ defer srv.Close()
+
+ s := newTestService(t, srv.URL)
+ if _, err := s.register(context.Background(), "http://127.0.0.1:1/oauth/callback"); err == nil {
+ t.Fatal("expected error when client_id is missing")
+ }
+}
+
+func TestRegister_4xx(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid_redirect_uri"})
+ }))
+ defer srv.Close()
+
+ s := newTestService(t, srv.URL)
+ if _, err := s.register(context.Background(), "http://127.0.0.1:1/oauth/callback"); err == nil {
+ t.Fatal("expected error on 4xx registration")
+ }
+}
+
+func TestExchangeCode_PersistsCreds(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _ = r.ParseForm()
+ if r.Form.Get("grant_type") != "authorization_code" {
+ t.Errorf("grant_type = %q", r.Form.Get("grant_type"))
+ }
+ if r.Form.Get("code") != "the-code" || r.Form.Get("code_verifier") == "" {
+ t.Errorf("missing code/verifier: %v", r.Form)
+ }
+ writeJSON(w, 200, tokenResponse{AccessToken: "at", RefreshToken: "rt", ExpiresIn: 3600})
+ }))
+ defer srv.Close()
+
+ s := newTestService(t, srv.URL)
+ token, err := s.exchangeCode(context.Background(), "cid", "the-code", "verifier", "http://127.0.0.1:1/oauth/callback")
+ if err != nil {
+ t.Fatalf("exchangeCode: %v", err)
+ }
+ if token != "at" {
+ t.Errorf("token = %q, want at", token)
+ }
+ got, _ := s.loadCredentials()
+ if got.ClientID != "cid" || got.AccessToken != "at" || got.RefreshToken != "rt" {
+ t.Errorf("persisted creds = %+v", got)
+ }
+}
+
+func TestResolve_HappyPathMakesNoHTTPCalls(t *testing.T) {
+ var hits int32
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ atomic.AddInt32(&hits, 1)
+ }))
+ defer srv.Close()
+
+ s := newTestService(t, srv.URL)
+ // Valid, unexpired cached token.
+ _ = s.saveCredentials(credentials{
+ ClientID: "cid",
+ AccessToken: "cached-at",
+ RefreshToken: "rt",
+ ExpiresAt: nowMillis() + 60_000,
+ })
+
+ token, err := s.resolve(context.Background())
+ if err != nil {
+ t.Fatalf("resolve: %v", err)
+ }
+ if token != "cached-at" {
+ t.Errorf("token = %q, want cached-at", token)
+ }
+ if n := atomic.LoadInt32(&hits); n != 0 {
+ t.Errorf("made %d HTTP calls on the happy path, want 0", n)
+ }
+}
+
+func TestResolve_ExpiredTokenTriggersRefresh(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _ = r.ParseForm()
+ if r.Form.Get("grant_type") != "refresh_token" {
+ t.Errorf("expected refresh grant, got %q", r.Form.Get("grant_type"))
+ }
+ writeJSON(w, 200, tokenResponse{AccessToken: "refreshed-at", RefreshToken: "rt2", ExpiresIn: 3600})
+ }))
+ defer srv.Close()
+
+ s := newTestService(t, srv.URL)
+ _ = s.saveCredentials(credentials{
+ ClientID: "cid",
+ AccessToken: "stale-at",
+ RefreshToken: "rt",
+ ExpiresAt: nowMillis() - 1000, // expired
+ })
+
+ token, err := s.resolve(context.Background())
+ if err != nil {
+ t.Fatalf("resolve: %v", err)
+ }
+ if token != "refreshed-at" {
+ t.Errorf("token = %q, want refreshed-at", token)
+ }
+}
+
+func TestAwaitCallback_Success(t *testing.T) {
+ s := newTestService(t, "http://unused")
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ port := ln.Addr().(*net.TCPAddr).Port
+
+ type res struct {
+ code string
+ err error
+ }
+ out := make(chan res, 1)
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ go func() {
+ code, err := s.awaitCallback(ctx, ln, "state-ok")
+ out <- res{code, err}
+ }()
+
+ // A non-callback path must 404 and leave the flow untouched.
+ resp404, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/whatever", port))
+ if err != nil {
+ t.Fatalf("get /whatever: %v", err)
+ }
+ if resp404.StatusCode != http.StatusNotFound {
+ t.Errorf("wrong-path status = %d, want 404", resp404.StatusCode)
+ }
+ resp404.Body.Close()
+
+ // The real callback resolves the flow.
+ resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/oauth/callback?code=abc&state=state-ok", port))
+ if err != nil {
+ t.Fatalf("get callback: %v", err)
+ }
+ body, _ := io.ReadAll(resp.Body)
+ resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ t.Errorf("callback status = %d, want 200", resp.StatusCode)
+ }
+ // The embedded success template should have rendered into the response.
+ if !strings.Contains(strings.ToLower(string(body)), "signed in") {
+ t.Errorf("success page body missing expected content: %q", string(body))
+ }
+
+ r := <-out
+ if r.err != nil || r.code != "abc" {
+ t.Errorf("awaitCallback = (%q, %v), want (abc, nil)", r.code, r.err)
+ }
+}
+
+func TestAwaitCallback_StateMismatch(t *testing.T) {
+ code, err := runCallback(t, "expected-state", "/oauth/callback?code=abc&state=wrong")
+ if err == nil || !strings.Contains(err.Error(), "state mismatch") {
+ t.Fatalf("err = %v, want state mismatch", err)
+ }
+ if code != "" {
+ t.Errorf("code = %q, want empty", code)
+ }
+}
+
+func TestAwaitCallback_AuthServerError(t *testing.T) {
+ code, err := runCallback(t, "s", "/oauth/callback?error=access_denied&state=s")
+ if err == nil || !strings.Contains(err.Error(), "access_denied") {
+ t.Fatalf("err = %v, want access_denied", err)
+ }
+ if code != "" {
+ t.Errorf("code = %q, want empty", code)
+ }
+}
+
+// runCallback drives awaitCallback through a single GET and returns its result.
+func runCallback(t *testing.T, expectedState, requestPath string) (string, error) {
+ t.Helper()
+ s := newTestService(t, "http://unused")
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ port := ln.Addr().(*net.TCPAddr).Port
+
+ type res struct {
+ code string
+ err error
+ }
+ out := make(chan res, 1)
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ go func() {
+ code, err := s.awaitCallback(ctx, ln, expectedState)
+ out <- res{code, err}
+ }()
+
+ resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d%s", port, requestPath))
+ if err != nil {
+ t.Fatalf("get: %v", err)
+ }
+ io.Copy(io.Discard, resp.Body)
+ resp.Body.Close()
+
+ r := <-out
+ return r.code, r.err
+}
+
+func TestErrorTemplate_EscapesUntrustedInput(t *testing.T) {
+ var buf bytes.Buffer
+ if err := errorTemplate.Execute(&buf, ``); err != nil {
+ t.Fatalf("execute error template: %v", err)
+ }
+ page := buf.String()
+ if strings.Contains(page, "