A community based topic aggregation platform built on atproto

feat(oauth): add OAuth system with mobile Universal Links support

- OAuth client for atproto authentication flow
- Session store with CSRF protection and secure token sealing
- Mobile-specific handlers with Universal Links redirect
- Database migrations for OAuth sessions and CSRF tokens

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

+3786
+198
internal/atproto/oauth/client.go
··· 1 + package oauth 2 + 3 + import ( 4 + "encoding/base64" 5 + "fmt" 6 + "net/url" 7 + "time" 8 + 9 + "github.com/bluesky-social/indigo/atproto/atcrypto" 10 + "github.com/bluesky-social/indigo/atproto/auth/oauth" 11 + "github.com/bluesky-social/indigo/atproto/identity" 12 + ) 13 + 14 + // OAuthClient wraps indigo's OAuth ClientApp with Coves-specific configuration 15 + type OAuthClient struct { 16 + ClientApp *oauth.ClientApp 17 + Config *OAuthConfig 18 + SealSecret []byte // For sealing mobile tokens 19 + } 20 + 21 + // OAuthConfig holds Coves OAuth client configuration 22 + type OAuthConfig struct { 23 + PublicURL string 24 + ClientSecret string 25 + ClientKID string 26 + SealSecret string 27 + PLCURL string 28 + Scopes []string 29 + SessionTTL time.Duration 30 + SealedTokenTTL time.Duration 31 + DevMode bool 32 + AllowPrivateIPs bool 33 + } 34 + 35 + // NewOAuthClient creates a new OAuth client for Coves 36 + func NewOAuthClient(config *OAuthConfig, store oauth.ClientAuthStore) (*OAuthClient, error) { 37 + if config == nil { 38 + return nil, fmt.Errorf("config is required") 39 + } 40 + 41 + // Validate seal secret 42 + var sealSecret []byte 43 + if config.SealSecret != "" { 44 + decoded, err := base64.StdEncoding.DecodeString(config.SealSecret) 45 + if err != nil { 46 + return nil, fmt.Errorf("failed to decode seal secret: %w", err) 47 + } 48 + if len(decoded) != 32 { 49 + return nil, fmt.Errorf("seal secret must be 32 bytes, got %d", len(decoded)) 50 + } 51 + sealSecret = decoded 52 + } 53 + 54 + // Validate scopes 55 + if len(config.Scopes) == 0 { 56 + return nil, fmt.Errorf("scopes are required") 57 + } 58 + hasAtproto := false 59 + for _, scope := range config.Scopes { 60 + if scope == "atproto" { 61 + hasAtproto = true 62 + break 63 + } 64 + } 65 + if !hasAtproto { 66 + return nil, fmt.Errorf("scopes must include 'atproto'") 67 + } 68 + 69 + // Set default TTL values if not specified 70 + // Per atproto OAuth spec: 71 + // - Public clients: 2-week (14 day) maximum session lifetime 72 + // - Confidential clients: 180-day maximum session lifetime 73 + if config.SessionTTL == 0 { 74 + config.SessionTTL = 7 * 24 * time.Hour // 7 days default 75 + } 76 + if config.SealedTokenTTL == 0 { 77 + config.SealedTokenTTL = 14 * 24 * time.Hour // 14 days (public client limit) 78 + } 79 + 80 + // Create indigo client config 81 + var clientConfig oauth.ClientConfig 82 + if config.DevMode { 83 + // Dev mode: localhost with HTTP 84 + callbackURL := "http://localhost:3000/oauth/callback" 85 + clientConfig = oauth.NewLocalhostConfig(callbackURL, config.Scopes) 86 + } else { 87 + // Production mode: HTTPS with client secret 88 + callbackURL := config.PublicURL + "/oauth/callback" 89 + clientConfig = oauth.NewPublicConfig(config.PublicURL, callbackURL, config.Scopes) 90 + 91 + // Set up confidential client if client secret is provided 92 + if config.ClientSecret != "" && config.ClientKID != "" { 93 + privKey, err := atcrypto.ParsePrivateMultibase(config.ClientSecret) 94 + if err != nil { 95 + return nil, fmt.Errorf("failed to parse client secret: %w", err) 96 + } 97 + 98 + if err := clientConfig.SetClientSecret(privKey, config.ClientKID); err != nil { 99 + return nil, fmt.Errorf("failed to set client secret: %w", err) 100 + } 101 + } 102 + } 103 + 104 + // Set user agent 105 + clientConfig.UserAgent = "Coves/1.0" 106 + 107 + // Create the indigo OAuth ClientApp 108 + clientApp := oauth.NewClientApp(&clientConfig, store) 109 + 110 + // Override the default HTTP client with our SSRF-safe client 111 + // This protects against SSRF attacks via malicious PDS URLs, DID documents, and JWKS URIs 112 + clientApp.Client = NewSSRFSafeHTTPClient(config.AllowPrivateIPs) 113 + 114 + // Override the directory if a custom PLC URL is configured 115 + // This is necessary for local development with a local PLC directory 116 + if config.PLCURL != "" { 117 + // Use SSRF-safe HTTP client for PLC directory requests 118 + httpClient := NewSSRFSafeHTTPClient(config.AllowPrivateIPs) 119 + baseDir := &identity.BaseDirectory{ 120 + PLCURL: config.PLCURL, 121 + HTTPClient: *httpClient, 122 + UserAgent: "Coves/1.0", 123 + } 124 + // Wrap in cache directory for better performance 125 + // Use pointer since CacheDirectory methods have pointer receivers 126 + cacheDir := identity.NewCacheDirectory(baseDir, 100_000, time.Hour*24, time.Minute*2, time.Minute*5) 127 + clientApp.Dir = &cacheDir 128 + } 129 + 130 + return &OAuthClient{ 131 + ClientApp: clientApp, 132 + Config: config, 133 + SealSecret: sealSecret, 134 + }, nil 135 + } 136 + 137 + // ClientMetadata returns the OAuth client metadata document 138 + func (c *OAuthClient) ClientMetadata() oauth.ClientMetadata { 139 + metadata := c.ClientApp.Config.ClientMetadata() 140 + 141 + // Add additional metadata for Coves 142 + metadata.ClientName = strPtr("Coves") 143 + if !c.Config.DevMode { 144 + metadata.ClientURI = strPtr(c.Config.PublicURL) 145 + } 146 + 147 + // For confidential clients, set JWKS URI 148 + if c.ClientApp.Config.IsConfidential() && !c.Config.DevMode { 149 + jwksURI := c.Config.PublicURL + "/.well-known/oauth-jwks.json" 150 + metadata.JWKSURI = &jwksURI 151 + } 152 + 153 + return metadata 154 + } 155 + 156 + // PublicJWKS returns the public JWKS for this client (for confidential clients) 157 + func (c *OAuthClient) PublicJWKS() oauth.JWKS { 158 + return c.ClientApp.Config.PublicJWKS() 159 + } 160 + 161 + // IsConfidential returns true if this is a confidential OAuth client 162 + func (c *OAuthClient) IsConfidential() bool { 163 + return c.ClientApp.Config.IsConfidential() 164 + } 165 + 166 + // strPtr is a helper to get a pointer to a string 167 + func strPtr(s string) *string { 168 + return &s 169 + } 170 + 171 + // ValidateCallbackURL validates that a callback URL matches the expected callback URL 172 + func (c *OAuthClient) ValidateCallbackURL(callbackURL string) error { 173 + expectedCallback := c.ClientApp.Config.CallbackURL 174 + 175 + // Parse both URLs 176 + expected, err := url.Parse(expectedCallback) 177 + if err != nil { 178 + return fmt.Errorf("invalid expected callback URL: %w", err) 179 + } 180 + 181 + actual, err := url.Parse(callbackURL) 182 + if err != nil { 183 + return fmt.Errorf("invalid callback URL: %w", err) 184 + } 185 + 186 + // Compare scheme, host, and path (ignore query params) 187 + if expected.Scheme != actual.Scheme { 188 + return fmt.Errorf("callback URL scheme mismatch: expected %s, got %s", expected.Scheme, actual.Scheme) 189 + } 190 + if expected.Host != actual.Host { 191 + return fmt.Errorf("callback URL host mismatch: expected %s, got %s", expected.Host, actual.Host) 192 + } 193 + if expected.Path != actual.Path { 194 + return fmt.Errorf("callback URL path mismatch: expected %s, got %s", expected.Path, actual.Path) 195 + } 196 + 197 + return nil 198 + }
+709
internal/atproto/oauth/handlers.go
··· 1 + package oauth 2 + 3 + import ( 4 + "context" 5 + "encoding/json" 6 + "fmt" 7 + "log/slog" 8 + "net/http" 9 + "net/url" 10 + 11 + "github.com/bluesky-social/indigo/atproto/auth/oauth" 12 + "github.com/bluesky-social/indigo/atproto/syntax" 13 + ) 14 + 15 + // MobileOAuthStore interface for mobile-specific OAuth operations 16 + // This extends the base OAuth store with mobile CSRF tracking 17 + type MobileOAuthStore interface { 18 + SaveMobileOAuthData(ctx context.Context, state string, data MobileOAuthData) error 19 + GetMobileOAuthData(ctx context.Context, state string) (*MobileOAuthData, error) 20 + } 21 + 22 + // OAuthHandler handles OAuth-related HTTP endpoints 23 + type OAuthHandler struct { 24 + client *OAuthClient 25 + store oauth.ClientAuthStore 26 + mobileStore MobileOAuthStore // For server-side CSRF validation 27 + } 28 + 29 + // NewOAuthHandler creates a new OAuth handler 30 + func NewOAuthHandler(client *OAuthClient, store oauth.ClientAuthStore) *OAuthHandler { 31 + handler := &OAuthHandler{ 32 + client: client, 33 + store: store, 34 + } 35 + 36 + // Check if the store implements MobileOAuthStore for server-side CSRF 37 + if mobileStore, ok := store.(MobileOAuthStore); ok { 38 + handler.mobileStore = mobileStore 39 + } 40 + 41 + return handler 42 + } 43 + 44 + // HandleClientMetadata serves the OAuth client metadata document 45 + // GET /oauth/client-metadata.json 46 + func (h *OAuthHandler) HandleClientMetadata(w http.ResponseWriter, r *http.Request) { 47 + metadata := h.client.ClientMetadata() 48 + 49 + // For confidential clients in production, set JWKS URI based on request host 50 + if h.client.IsConfidential() && !h.client.Config.DevMode { 51 + jwksURI := fmt.Sprintf("https://%s/oauth/jwks.json", r.Host) 52 + metadata.JWKSURI = &jwksURI 53 + } 54 + 55 + // Validate metadata before returning (skip in dev mode - localhost doesn't need https validation) 56 + if !h.client.Config.DevMode { 57 + if err := metadata.Validate(h.client.ClientApp.Config.ClientID); err != nil { 58 + slog.Error("client metadata validation failed", "error", err) 59 + http.Error(w, "internal server error", http.StatusInternalServerError) 60 + return 61 + } 62 + } 63 + 64 + w.Header().Set("Content-Type", "application/json") 65 + if err := json.NewEncoder(w).Encode(metadata); err != nil { 66 + slog.Error("failed to encode client metadata", "error", err) 67 + http.Error(w, "internal server error", http.StatusInternalServerError) 68 + return 69 + } 70 + } 71 + 72 + // HandleJWKS serves the public JWKS for confidential clients 73 + // GET /oauth/jwks.json 74 + func (h *OAuthHandler) HandleJWKS(w http.ResponseWriter, r *http.Request) { 75 + jwks := h.client.PublicJWKS() 76 + 77 + w.Header().Set("Content-Type", "application/json") 78 + if err := json.NewEncoder(w).Encode(jwks); err != nil { 79 + slog.Error("failed to encode JWKS", "error", err) 80 + http.Error(w, "internal server error", http.StatusInternalServerError) 81 + return 82 + } 83 + } 84 + 85 + // HandleLogin starts the OAuth flow (web version) 86 + // GET /oauth/login?handle=user.bsky.social 87 + func (h *OAuthHandler) HandleLogin(w http.ResponseWriter, r *http.Request) { 88 + ctx := r.Context() 89 + 90 + // Get handle or DID from query params 91 + identifier := r.URL.Query().Get("handle") 92 + if identifier == "" { 93 + identifier = r.URL.Query().Get("did") 94 + } 95 + if identifier == "" { 96 + http.Error(w, "missing handle or did parameter", http.StatusBadRequest) 97 + return 98 + } 99 + 100 + // Start OAuth flow 101 + redirectURL, err := h.client.ClientApp.StartAuthFlow(ctx, identifier) 102 + if err != nil { 103 + slog.Error("failed to start OAuth flow", "error", err, "identifier", identifier) 104 + http.Error(w, fmt.Sprintf("failed to start OAuth flow: %v", err), http.StatusBadRequest) 105 + return 106 + } 107 + 108 + // Log OAuth flow initiation (sanitized - no full URL to avoid leaking state) 109 + slog.Info("redirecting to PDS for OAuth", "identifier", identifier) 110 + 111 + // Redirect to PDS 112 + http.Redirect(w, r, redirectURL, http.StatusFound) 113 + } 114 + 115 + // HandleMobileLogin starts the OAuth flow for mobile apps 116 + // GET /oauth/mobile/login?handle=user.bsky.social&redirect_uri=coves-app://callback 117 + func (h *OAuthHandler) HandleMobileLogin(w http.ResponseWriter, r *http.Request) { 118 + ctx := r.Context() 119 + 120 + // Get handle or DID from query params 121 + identifier := r.URL.Query().Get("handle") 122 + if identifier == "" { 123 + identifier = r.URL.Query().Get("did") 124 + } 125 + if identifier == "" { 126 + http.Error(w, "missing handle or did parameter", http.StatusBadRequest) 127 + return 128 + } 129 + 130 + // Get mobile redirect URI (deep link) 131 + mobileRedirectURI := r.URL.Query().Get("redirect_uri") 132 + if mobileRedirectURI == "" { 133 + http.Error(w, "missing redirect_uri parameter", http.StatusBadRequest) 134 + return 135 + } 136 + 137 + // SECURITY FIX 1: Validate redirect_uri against allowlist 138 + if !isAllowedMobileRedirectURI(mobileRedirectURI) { 139 + slog.Warn("rejected unauthorized mobile redirect URI", "scheme", extractScheme(mobileRedirectURI)) 140 + http.Error(w, "invalid redirect_uri: scheme not allowed", http.StatusBadRequest) 141 + return 142 + } 143 + 144 + // SECURITY: Verify store is properly configured for mobile OAuth 145 + // A plain PostgresOAuthStore implements MobileOAuthStore (has Save/GetMobileOAuthData), 146 + // but without the MobileAwareStoreWrapper, SaveMobileOAuthData is never called during 147 + // StartAuthFlow, so server-side CSRF data is never stored. This causes mobile callbacks 148 + // to silently fall back to web flow. Fail fast here instead of silent breakage. 149 + if _, ok := h.store.(MobileAwareClientStore); !ok { 150 + slog.Error("mobile OAuth not supported: store is not wrapped with MobileAwareStoreWrapper", 151 + "store_type", fmt.Sprintf("%T", h.store)) 152 + http.Error(w, "mobile OAuth not configured on this server", http.StatusInternalServerError) 153 + return 154 + } 155 + 156 + // SECURITY FIX 2: Generate CSRF token 157 + csrfToken, err := generateCSRFToken() 158 + if err != nil { 159 + http.Error(w, "internal server error", http.StatusInternalServerError) 160 + return 161 + } 162 + 163 + // SECURITY FIX 4: Store CSRF server-side tied to OAuth state 164 + // Add mobile data to context so the store wrapper can capture it when 165 + // SaveAuthRequestInfo is called by indigo's StartAuthFlow. 166 + // This is necessary because PAR redirects don't include the state in the URL, 167 + // so we can't extract it after StartAuthFlow returns. 168 + mobileCtx := ContextWithMobileFlowData(ctx, MobileOAuthData{ 169 + CSRFToken: csrfToken, 170 + RedirectURI: mobileRedirectURI, 171 + }) 172 + 173 + // Start OAuth flow (the store wrapper will save mobile data when auth request is saved) 174 + redirectURL, err := h.client.ClientApp.StartAuthFlow(mobileCtx, identifier) 175 + if err != nil { 176 + slog.Error("failed to start OAuth flow", "error", err, "identifier", identifier) 177 + http.Error(w, fmt.Sprintf("failed to start OAuth flow: %v", err), http.StatusBadRequest) 178 + return 179 + } 180 + 181 + // Log mobile OAuth flow initiation (sanitized - no full URLs or sensitive params) 182 + slog.Info("redirecting to PDS for mobile OAuth", "identifier", identifier) 183 + 184 + // SECURITY FIX 2: Store CSRF token in cookie 185 + http.SetCookie(w, &http.Cookie{ 186 + Name: "oauth_csrf", 187 + Value: csrfToken, 188 + Path: "/oauth", 189 + MaxAge: 600, // 10 minutes 190 + HttpOnly: true, 191 + Secure: !h.client.Config.DevMode, 192 + SameSite: http.SameSiteLaxMode, 193 + }) 194 + 195 + // SECURITY FIX 3: Generate binding token to tie CSRF token + mobile redirect to this OAuth flow 196 + // This prevents session fixation attacks where an attacker plants a mobile_redirect_uri 197 + // cookie, then the user does a web login, and credentials get sent to attacker's deep link. 198 + // The binding includes the CSRF token so we validate its VALUE (not just presence) on callback. 199 + mobileBinding := generateMobileRedirectBinding(csrfToken, mobileRedirectURI) 200 + 201 + // Set cookie with mobile redirect URI for use in callback 202 + http.SetCookie(w, &http.Cookie{ 203 + Name: "mobile_redirect_uri", 204 + Value: url.QueryEscape(mobileRedirectURI), 205 + Path: "/oauth", 206 + HttpOnly: true, 207 + Secure: !h.client.Config.DevMode, 208 + SameSite: http.SameSiteLaxMode, 209 + MaxAge: 600, // 10 minutes 210 + }) 211 + 212 + // Set binding cookie to validate mobile redirect in callback 213 + http.SetCookie(w, &http.Cookie{ 214 + Name: "mobile_redirect_binding", 215 + Value: mobileBinding, 216 + Path: "/oauth", 217 + HttpOnly: true, 218 + Secure: !h.client.Config.DevMode, 219 + SameSite: http.SameSiteLaxMode, 220 + MaxAge: 600, // 10 minutes 221 + }) 222 + 223 + // Redirect to PDS 224 + http.Redirect(w, r, redirectURL, http.StatusFound) 225 + } 226 + 227 + // HandleCallback handles the OAuth callback from the PDS 228 + // GET /oauth/callback?code=...&state=...&iss=... 229 + func (h *OAuthHandler) HandleCallback(w http.ResponseWriter, r *http.Request) { 230 + ctx := r.Context() 231 + 232 + // IMPORTANT: Look up mobile CSRF data BEFORE ProcessCallback 233 + // ProcessCallback deletes the oauth_requests row, so we must retrieve mobile data first. 234 + // We store it in a local variable for validation after ProcessCallback completes. 235 + var serverMobileData *MobileOAuthData 236 + var mobileDataLookupErr error 237 + oauthState := r.URL.Query().Get("state") 238 + 239 + // Check if this might be a mobile callback (mobile_redirect_uri cookie present) 240 + // We do a preliminary check here to decide if we need to fetch mobile data 241 + mobileRedirectCookie, _ := r.Cookie("mobile_redirect_uri") 242 + isMobileFlow := mobileRedirectCookie != nil && mobileRedirectCookie.Value != "" 243 + 244 + if isMobileFlow && h.mobileStore != nil && oauthState != "" { 245 + // Fetch mobile data BEFORE ProcessCallback deletes the row 246 + serverMobileData, mobileDataLookupErr = h.mobileStore.GetMobileOAuthData(ctx, oauthState) 247 + // We'll handle errors after ProcessCallback - for now just capture the result 248 + } 249 + 250 + // Process the callback (this deletes the oauth_requests row) 251 + sessData, err := h.client.ClientApp.ProcessCallback(ctx, r.URL.Query()) 252 + if err != nil { 253 + slog.Error("failed to process OAuth callback", "error", err) 254 + http.Error(w, fmt.Sprintf("OAuth callback failed: %v", err), http.StatusBadRequest) 255 + return 256 + } 257 + 258 + // Ensure sessData is not nil before using it 259 + if sessData == nil { 260 + slog.Error("OAuth callback returned nil session data") 261 + http.Error(w, "OAuth callback failed: no session data", http.StatusInternalServerError) 262 + return 263 + } 264 + 265 + // Bidirectional handle verification: ensure the DID actually controls a valid handle 266 + // This prevents impersonation via compromised PDS that issues tokens with invalid handle mappings 267 + // Per AT Protocol spec: "Bidirectional verification required; confirm DID document claims handle" 268 + if h.client.ClientApp.Dir != nil { 269 + ident, err := h.client.ClientApp.Dir.LookupDID(ctx, sessData.AccountDID) 270 + if err != nil { 271 + // Directory lookup failed - this is a hard error for security 272 + slog.Error("OAuth callback: DID lookup failed during handle verification", 273 + "did", sessData.AccountDID, "error", err) 274 + http.Error(w, "Handle verification failed", http.StatusUnauthorized) 275 + return 276 + } 277 + 278 + // Check if the handle is the special "handle.invalid" value 279 + // This indicates that bidirectional verification failed (DID->handle->DID roundtrip failed) 280 + if ident.Handle.String() == "handle.invalid" { 281 + slog.Warn("OAuth callback: bidirectional handle verification failed", 282 + "did", sessData.AccountDID, 283 + "handle", "handle.invalid", 284 + "reason", "DID document claims a handle that doesn't resolve back to this DID") 285 + http.Error(w, "Handle verification failed: DID/handle mismatch", http.StatusUnauthorized) 286 + return 287 + } 288 + 289 + // Success: handle is valid and bidirectionally verified 290 + slog.Info("OAuth callback successful", "did", sessData.AccountDID, "handle", ident.Handle) 291 + } else { 292 + // No directory client available - log warning but proceed 293 + // This should only happen in testing scenarios 294 + slog.Warn("OAuth callback: directory client not available, skipping handle verification", 295 + "did", sessData.AccountDID) 296 + slog.Info("OAuth callback successful (no handle verification)", "did", sessData.AccountDID) 297 + } 298 + 299 + // Check if this is a mobile callback (check for mobile_redirect_uri cookie) 300 + mobileRedirect, err := r.Cookie("mobile_redirect_uri") 301 + if err == nil && mobileRedirect.Value != "" { 302 + // SECURITY FIX 2: Validate CSRF token for mobile callback 303 + csrfCookie, err := r.Cookie("oauth_csrf") 304 + if err != nil { 305 + slog.Warn("mobile callback missing CSRF token") 306 + clearMobileCookies(w) 307 + http.Error(w, "invalid request: missing CSRF token", http.StatusForbidden) 308 + return 309 + } 310 + 311 + // SECURITY FIX 3: Validate mobile redirect binding 312 + // This prevents session fixation attacks where an attacker plants a mobile_redirect_uri 313 + // cookie, then the user does a web login, and credentials get sent to attacker's deep link 314 + bindingCookie, err := r.Cookie("mobile_redirect_binding") 315 + if err != nil { 316 + slog.Warn("mobile callback missing redirect binding - possible attack attempt") 317 + clearMobileCookies(w) 318 + http.Error(w, "invalid request: missing redirect binding", http.StatusForbidden) 319 + return 320 + } 321 + 322 + // Decode the mobile redirect URI to validate binding 323 + mobileRedirectURI, err := url.QueryUnescape(mobileRedirect.Value) 324 + if err != nil { 325 + slog.Error("failed to decode mobile redirect URI", "error", err) 326 + clearMobileCookies(w) 327 + http.Error(w, "invalid mobile redirect URI", http.StatusBadRequest) 328 + return 329 + } 330 + 331 + // Validate that the binding matches both the CSRF token AND redirect URI 332 + // This is the actual CSRF validation - we verify the token VALUE by checking 333 + // that hash(csrfToken + redirectURI) == binding. This prevents: 334 + // 1. CSRF attacks: attacker can't forge binding without knowing CSRF token 335 + // 2. Session fixation: cookies must all originate from the same /oauth/mobile/login request 336 + if !validateMobileRedirectBinding(csrfCookie.Value, mobileRedirectURI, bindingCookie.Value) { 337 + slog.Warn("mobile redirect binding/CSRF validation failed - possible attack attempt", 338 + "expected_scheme", extractScheme(mobileRedirectURI)) 339 + clearMobileCookies(w) 340 + // Fail closed: treat as web flow instead of mobile 341 + h.handleWebCallback(w, r, sessData) 342 + return 343 + } 344 + 345 + // SECURITY FIX 4: Validate CSRF cookie against server-side state 346 + // This compares the cookie CSRF against a value tied to the OAuth state parameter 347 + // (which comes back through the OAuth response), satisfying the requirement to 348 + // validate against server-side state rather than only against other cookies. 349 + // 350 + // CRITICAL: If mobile cookies are present but server-side mobile data is MISSING, 351 + // this indicates a potential attack where: 352 + // 1. Attacker did a WEB OAuth flow (no mobile data stored) 353 + // 2. Attacker planted mobile cookies via cross-site /oauth/mobile/login 354 + // 3. Attacker sends victim to callback with attacker's web-flow state/code 355 + // We MUST fail closed and use web flow when server-side mobile data is missing. 356 + // 357 + // NOTE: serverMobileData was fetched BEFORE ProcessCallback (which deletes the row) 358 + // at the top of this function. We use the pre-fetched result here. 359 + if h.mobileStore != nil && oauthState != "" { 360 + if mobileDataLookupErr != nil { 361 + // Database error - fail closed, use web flow 362 + slog.Warn("failed to retrieve server-side mobile OAuth data - using web flow", 363 + "error", mobileDataLookupErr, "state", oauthState) 364 + clearMobileCookies(w) 365 + h.handleWebCallback(w, r, sessData) 366 + return 367 + } 368 + if serverMobileData == nil { 369 + // No server-side mobile data for this state - this OAuth flow was NOT started 370 + // via /oauth/mobile/login. Mobile cookies are likely attacker-planted. 371 + // Fail closed: clear cookies and use web flow. 372 + slog.Warn("mobile cookies present but no server-side mobile data for OAuth state - "+ 373 + "possible cross-flow attack, using web flow", "state", oauthState) 374 + clearMobileCookies(w) 375 + h.handleWebCallback(w, r, sessData) 376 + return 377 + } 378 + // Server-side mobile data exists - validate it matches cookies 379 + if !constantTimeCompare(csrfCookie.Value, serverMobileData.CSRFToken) { 380 + slog.Warn("mobile callback CSRF mismatch: cookie differs from server-side state", 381 + "state", oauthState) 382 + clearMobileCookies(w) 383 + h.handleWebCallback(w, r, sessData) 384 + return 385 + } 386 + if serverMobileData.RedirectURI != mobileRedirectURI { 387 + slog.Warn("mobile callback redirect URI mismatch: cookie differs from server-side state", 388 + "cookie_uri", extractScheme(mobileRedirectURI), 389 + "server_uri", extractScheme(serverMobileData.RedirectURI)) 390 + clearMobileCookies(w) 391 + h.handleWebCallback(w, r, sessData) 392 + return 393 + } 394 + slog.Debug("server-side CSRF validation passed", "state", oauthState) 395 + } else if h.mobileStore != nil { 396 + // mobileStore exists but no state in query - shouldn't happen with valid OAuth 397 + slog.Warn("mobile cookies present but no OAuth state in callback - using web flow") 398 + clearMobileCookies(w) 399 + h.handleWebCallback(w, r, sessData) 400 + return 401 + } 402 + // Note: if h.mobileStore is nil (e.g., in tests), we fall back to cookie-only validation 403 + 404 + // All security checks passed - proceed with mobile flow 405 + // Mobile flow: seal the session and redirect to deep link 406 + h.handleMobileCallback(w, r, sessData, mobileRedirect.Value, csrfCookie.Value) 407 + return 408 + } 409 + 410 + // Web flow: set session cookie 411 + h.handleWebCallback(w, r, sessData) 412 + } 413 + 414 + // handleWebCallback handles the web OAuth callback flow 415 + func (h *OAuthHandler) handleWebCallback(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) { 416 + // Use sealed tokens for web flow (same as mobile) per atProto OAuth spec: 417 + // "Access and refresh tokens should never be copied or shared across end devices. 418 + // They should not be stored in session cookies." 419 + 420 + // Seal the session data using AES-GCM encryption 421 + sealedToken, err := h.client.SealSession( 422 + sessData.AccountDID.String(), 423 + sessData.SessionID, 424 + h.client.Config.SealedTokenTTL, 425 + ) 426 + if err != nil { 427 + slog.Error("failed to seal session for web", "error", err) 428 + http.Error(w, "failed to create session", http.StatusInternalServerError) 429 + return 430 + } 431 + 432 + http.SetCookie(w, &http.Cookie{ 433 + Name: "coves_session", 434 + Value: sealedToken, 435 + Path: "/", 436 + HttpOnly: true, 437 + Secure: !h.client.Config.DevMode, 438 + SameSite: http.SameSiteLaxMode, 439 + MaxAge: int(h.client.Config.SealedTokenTTL.Seconds()), 440 + }) 441 + 442 + // Clear all mobile cookies if they exist (defense in depth) 443 + clearMobileCookies(w) 444 + 445 + // Redirect to home or app 446 + redirectURL := "/" 447 + if !h.client.Config.DevMode { 448 + redirectURL = h.client.Config.PublicURL + "/" 449 + } 450 + http.Redirect(w, r, redirectURL, http.StatusFound) 451 + } 452 + 453 + // handleMobileCallback handles the mobile OAuth callback flow 454 + func (h *OAuthHandler) handleMobileCallback(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData, mobileRedirectURIEncoded, csrfToken string) { 455 + // Decode the mobile redirect URI 456 + mobileRedirectURI, err := url.QueryUnescape(mobileRedirectURIEncoded) 457 + if err != nil { 458 + slog.Error("failed to decode mobile redirect URI", "error", err) 459 + http.Error(w, "invalid mobile redirect URI", http.StatusBadRequest) 460 + return 461 + } 462 + 463 + // SECURITY FIX 1: Re-validate redirect URI against allowlist 464 + if !isAllowedMobileRedirectURI(mobileRedirectURI) { 465 + slog.Error("mobile callback attempted with unauthorized redirect URI", "scheme", extractScheme(mobileRedirectURI)) 466 + http.Error(w, "invalid redirect URI", http.StatusBadRequest) 467 + return 468 + } 469 + 470 + // Seal the session data for mobile 471 + sealedToken, err := h.client.SealSession( 472 + sessData.AccountDID.String(), 473 + sessData.SessionID, 474 + h.client.Config.SealedTokenTTL, 475 + ) 476 + if err != nil { 477 + slog.Error("failed to seal session data", "error", err) 478 + http.Error(w, "failed to create session token", http.StatusInternalServerError) 479 + return 480 + } 481 + 482 + // Get account handle for convenience 483 + handle := "" 484 + if ident, err := h.client.ClientApp.Dir.LookupDID(r.Context(), sessData.AccountDID); err == nil { 485 + handle = ident.Handle.String() 486 + } 487 + 488 + // Clear all mobile cookies to prevent reuse (defense in depth) 489 + clearMobileCookies(w) 490 + 491 + // Build deep link with sealed token 492 + deepLink := fmt.Sprintf("%s?token=%s&did=%s&session_id=%s", 493 + mobileRedirectURI, 494 + url.QueryEscape(sealedToken), 495 + url.QueryEscape(sessData.AccountDID.String()), 496 + url.QueryEscape(sessData.SessionID), 497 + ) 498 + if handle != "" { 499 + deepLink += "&handle=" + url.QueryEscape(handle) 500 + } 501 + 502 + // Log mobile redirect (sanitized - no token or session ID to avoid leaking credentials) 503 + slog.Info("redirecting to mobile app", "did", sessData.AccountDID, "handle", handle) 504 + 505 + // Redirect to mobile app deep link 506 + http.Redirect(w, r, deepLink, http.StatusFound) 507 + } 508 + 509 + // HandleLogout revokes the session and clears cookies 510 + // POST /oauth/logout 511 + func (h *OAuthHandler) HandleLogout(w http.ResponseWriter, r *http.Request) { 512 + ctx := r.Context() 513 + 514 + // Get session from cookie (now sealed) 515 + cookie, err := r.Cookie("coves_session") 516 + if err != nil { 517 + // No session, just return success 518 + w.WriteHeader(http.StatusOK) 519 + _ = json.NewEncoder(w).Encode(map[string]string{"status": "logged_out"}) 520 + return 521 + } 522 + 523 + // Unseal the session token 524 + sealed, err := h.client.UnsealSession(cookie.Value) 525 + if err != nil { 526 + // Invalid session, clear cookie and return 527 + h.clearSessionCookie(w) 528 + w.WriteHeader(http.StatusOK) 529 + _ = json.NewEncoder(w).Encode(map[string]string{"status": "logged_out"}) 530 + return 531 + } 532 + 533 + // Parse DID 534 + did, err := syntax.ParseDID(sealed.DID) 535 + if err != nil { 536 + // Invalid DID, clear cookie and return 537 + h.clearSessionCookie(w) 538 + w.WriteHeader(http.StatusOK) 539 + _ = json.NewEncoder(w).Encode(map[string]string{"status": "logged_out"}) 540 + return 541 + } 542 + 543 + // Revoke session on auth server 544 + if err := h.client.ClientApp.Logout(ctx, did, sealed.SessionID); err != nil { 545 + slog.Error("failed to revoke session on auth server", "error", err, "did", did) 546 + // Continue anyway to clear local session 547 + } 548 + 549 + // Clear session cookie 550 + h.clearSessionCookie(w) 551 + 552 + w.Header().Set("Content-Type", "application/json") 553 + w.WriteHeader(http.StatusOK) 554 + _ = json.NewEncoder(w).Encode(map[string]string{"status": "logged_out"}) 555 + } 556 + 557 + // HandleRefresh refreshes the session token (for mobile) 558 + // POST /oauth/refresh 559 + // Body: {"did": "did:plc:...", "session_id": "...", "sealed_token": "..."} 560 + func (h *OAuthHandler) HandleRefresh(w http.ResponseWriter, r *http.Request) { 561 + ctx := r.Context() 562 + 563 + var req struct { 564 + DID string `json:"did"` 565 + SessionID string `json:"session_id"` 566 + SealedToken string `json:"sealed_token,omitempty"` 567 + } 568 + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 569 + http.Error(w, "invalid request body", http.StatusBadRequest) 570 + return 571 + } 572 + 573 + // SECURITY: Require sealed_token for proof of possession 574 + // Without this, anyone who knows a DID + session_id can steal credentials 575 + if req.SealedToken == "" { 576 + slog.Warn("refresh: missing sealed_token", "did", req.DID) 577 + http.Error(w, "sealed_token required for refresh", http.StatusUnauthorized) 578 + return 579 + } 580 + 581 + // SECURITY: Unseal and validate the token 582 + unsealed, err := h.client.UnsealSession(req.SealedToken) 583 + if err != nil { 584 + slog.Warn("refresh: invalid sealed token", "error", err) 585 + http.Error(w, "Invalid or expired token", http.StatusUnauthorized) 586 + return 587 + } 588 + 589 + // SECURITY: Verify the unsealed token matches the claimed DID 590 + if unsealed.DID != req.DID { 591 + slog.Warn("refresh: DID mismatch", "token_did", unsealed.DID, "claimed_did", req.DID) 592 + http.Error(w, "Token DID mismatch", http.StatusUnauthorized) 593 + return 594 + } 595 + 596 + // SECURITY: Verify the unsealed token matches the claimed session_id 597 + if unsealed.SessionID != req.SessionID { 598 + slog.Warn("refresh: session_id mismatch", "token_session", unsealed.SessionID, "claimed_session", req.SessionID) 599 + http.Error(w, "Token session mismatch", http.StatusUnauthorized) 600 + return 601 + } 602 + 603 + // Parse DID after validation 604 + did, err := syntax.ParseDID(req.DID) 605 + if err != nil { 606 + http.Error(w, "invalid DID", http.StatusBadRequest) 607 + return 608 + } 609 + 610 + // Resume session (now authenticated via sealed token) 611 + sess, err := h.client.ClientApp.ResumeSession(ctx, did, req.SessionID) 612 + if err != nil { 613 + slog.Error("failed to resume session", "error", err, "did", did, "session_id", req.SessionID) 614 + http.Error(w, "session not found", http.StatusUnauthorized) 615 + return 616 + } 617 + 618 + // Refresh tokens 619 + newAccessToken, err := sess.RefreshTokens(ctx) 620 + if err != nil { 621 + slog.Error("failed to refresh tokens", "error", err, "did", did) 622 + http.Error(w, "failed to refresh tokens", http.StatusUnauthorized) 623 + return 624 + } 625 + 626 + // Create new sealed token for mobile 627 + sealedToken, err := h.client.SealSession( 628 + sess.Data.AccountDID.String(), 629 + sess.Data.SessionID, 630 + h.client.Config.SealedTokenTTL, 631 + ) 632 + if err != nil { 633 + slog.Error("failed to seal new session data", "error", err) 634 + http.Error(w, "failed to create session token", http.StatusInternalServerError) 635 + return 636 + } 637 + 638 + w.Header().Set("Content-Type", "application/json") 639 + _ = json.NewEncoder(w).Encode(map[string]interface{}{ 640 + "access_token": newAccessToken, 641 + "sealed_token": sealedToken, 642 + }) 643 + } 644 + 645 + // clearSessionCookie clears the session cookie 646 + func (h *OAuthHandler) clearSessionCookie(w http.ResponseWriter) { 647 + http.SetCookie(w, &http.Cookie{ 648 + Name: "coves_session", 649 + Value: "", 650 + Path: "/", 651 + MaxAge: -1, 652 + }) 653 + } 654 + 655 + // GetSessionFromRequest extracts session data from an HTTP request 656 + func (h *OAuthHandler) GetSessionFromRequest(r *http.Request) (*oauth.ClientSessionData, error) { 657 + // Try to get session from cookie (web) - now using sealed tokens 658 + cookie, err := r.Cookie("coves_session") 659 + if err == nil && cookie.Value != "" { 660 + // Unseal the token to get DID and session ID 661 + sealed, err := h.client.UnsealSession(cookie.Value) 662 + if err == nil { 663 + did, err := syntax.ParseDID(sealed.DID) 664 + if err == nil { 665 + return h.store.GetSession(r.Context(), did, sealed.SessionID) 666 + } 667 + } 668 + } 669 + 670 + // Try to get session from Authorization header (mobile) 671 + authHeader := r.Header.Get("Authorization") 672 + if authHeader != "" { 673 + // Expected format: "Bearer <sealed_token>" 674 + const prefix = "Bearer " 675 + if len(authHeader) > len(prefix) && authHeader[:len(prefix)] == prefix { 676 + sealedToken := authHeader[len(prefix):] 677 + sealed, err := h.client.UnsealSession(sealedToken) 678 + if err != nil { 679 + return nil, fmt.Errorf("invalid sealed token: %w", err) 680 + } 681 + did, err := syntax.ParseDID(sealed.DID) 682 + if err != nil { 683 + return nil, fmt.Errorf("invalid DID in sealed token: %w", err) 684 + } 685 + return h.store.GetSession(r.Context(), did, sealed.SessionID) 686 + } 687 + } 688 + 689 + return nil, fmt.Errorf("no session found") 690 + } 691 + 692 + // HandleProtectedResourceMetadata returns OAuth protected resource metadata 693 + // per RFC 9449 and atproto OAuth spec. This endpoint allows third-party OAuth 694 + // clients to discover which authorization server to use for this resource. 695 + // Spec: https://datatracker.ietf.org/doc/html/rfc9449#section-5 696 + func (h *OAuthHandler) HandleProtectedResourceMetadata(w http.ResponseWriter, r *http.Request) { 697 + metadata := map[string]interface{}{ 698 + "resource": h.client.Config.PublicURL, 699 + "authorization_servers": []string{"https://bsky.social"}, 700 + } 701 + 702 + w.Header().Set("Content-Type", "application/json") 703 + w.Header().Set("Cache-Control", "public, max-age=3600") 704 + if err := json.NewEncoder(w).Encode(metadata); err != nil { 705 + slog.Error("failed to encode protected resource metadata", "error", err) 706 + http.Error(w, "internal server error", http.StatusInternalServerError) 707 + return 708 + } 709 + }
+126
internal/atproto/oauth/handlers_security.go
··· 1 + package oauth 2 + 3 + import ( 4 + "crypto/rand" 5 + "crypto/sha256" 6 + "encoding/base64" 7 + "log/slog" 8 + "net/http" 9 + "net/url" 10 + ) 11 + 12 + // allowedMobileRedirectURIs contains the EXACT allowed redirect URIs for mobile apps. 13 + // SECURITY: Only Universal Links (HTTPS) are allowed - cryptographically bound to app. 14 + // 15 + // Universal Links provide strong security guarantees: 16 + // - iOS: Verified via /.well-known/apple-app-site-association 17 + // - Android: Verified via /.well-known/assetlinks.json 18 + // - System verifies domain ownership before routing to app 19 + // - Prevents malicious apps from intercepting OAuth callbacks 20 + // 21 + // Custom URL schemes (coves-app://, coves://) are NOT allowed because: 22 + // - Any app can register the same scheme and intercept tokens 23 + // - No cryptographic binding to app identity 24 + // - Token theft is trivial for malicious apps 25 + // 26 + // See: https://atproto.com/specs/oauth#mobile-clients 27 + var allowedMobileRedirectURIs = map[string]bool{ 28 + // Universal Links only - cryptographically bound to app 29 + "https://coves.social/app/oauth/callback": true, 30 + } 31 + 32 + // isAllowedMobileRedirectURI validates that the redirect URI is in the exact allowlist. 33 + // SECURITY: Exact URI matching prevents token theft by rogue apps that register the same scheme. 34 + // 35 + // Custom URL schemes are NOT cryptographically bound to apps: 36 + // - Any app on the device can register "coves-app://" or "coves://" 37 + // - A malicious app can intercept deep links intended for Coves 38 + // - Without exact URI matching, the attacker receives the sealed token 39 + // 40 + // This function performs EXACT matching (not scheme-only) as a security measure. 41 + // For production, migrate to Universal Links (iOS) or App Links (Android). 42 + func isAllowedMobileRedirectURI(redirectURI string) bool { 43 + // Normalize and check exact match 44 + return allowedMobileRedirectURIs[redirectURI] 45 + } 46 + 47 + // extractScheme extracts the scheme from a URI for logging purposes 48 + func extractScheme(uri string) string { 49 + if u, err := url.Parse(uri); err == nil && u.Scheme != "" { 50 + return u.Scheme 51 + } 52 + return "invalid" 53 + } 54 + 55 + // generateCSRFToken generates a cryptographically secure CSRF token 56 + func generateCSRFToken() (string, error) { 57 + csrfToken := make([]byte, 32) 58 + if _, err := rand.Read(csrfToken); err != nil { 59 + slog.Error("failed to generate CSRF token", "error", err) 60 + return "", err 61 + } 62 + return base64.URLEncoding.EncodeToString(csrfToken), nil 63 + } 64 + 65 + // generateMobileRedirectBinding generates a cryptographically secure binding token 66 + // that ties the CSRF token and mobile redirect URI to this specific OAuth flow. 67 + // SECURITY: This prevents multiple attack vectors: 68 + // 1. Session fixation: attacker plants mobile_redirect_uri cookie, user does web login 69 + // 2. CSRF bypass: attacker manipulates cookies without knowing the CSRF token 70 + // 3. Cookie replay: binding validates both CSRF and redirect URI together 71 + // 72 + // The binding is hash(csrfToken + "|" + mobileRedirectURI) which ensures: 73 + // - CSRF token value is verified (not just presence) 74 + // - Redirect URI is tied to the specific CSRF token that started the flow 75 + // - Cannot forge binding without knowing both values 76 + func generateMobileRedirectBinding(csrfToken, mobileRedirectURI string) string { 77 + // Combine CSRF token and redirect URI with separator to prevent length extension 78 + combined := csrfToken + "|" + mobileRedirectURI 79 + hash := sha256.Sum256([]byte(combined)) 80 + // Use first 16 bytes (128 bits) for the binding - sufficient for this purpose 81 + return base64.URLEncoding.EncodeToString(hash[:16]) 82 + } 83 + 84 + // validateMobileRedirectBinding validates that the CSRF token and mobile redirect URI 85 + // together match the binding token, preventing CSRF attacks and cross-flow token theft. 86 + // This implements a proper double-submit cookie pattern where the CSRF token value 87 + // (not just presence) is cryptographically verified. 88 + func validateMobileRedirectBinding(csrfToken, mobileRedirectURI, binding string) bool { 89 + expectedBinding := generateMobileRedirectBinding(csrfToken, mobileRedirectURI) 90 + // Constant-time comparison to prevent timing attacks 91 + return constantTimeCompare(expectedBinding, binding) 92 + } 93 + 94 + // constantTimeCompare performs a constant-time string comparison to prevent timing attacks 95 + func constantTimeCompare(a, b string) bool { 96 + if len(a) != len(b) { 97 + return false 98 + } 99 + var result byte 100 + for i := 0; i < len(a); i++ { 101 + result |= a[i] ^ b[i] 102 + } 103 + return result == 0 104 + } 105 + 106 + // clearMobileCookies clears all mobile-related cookies to prevent reuse 107 + func clearMobileCookies(w http.ResponseWriter) { 108 + http.SetCookie(w, &http.Cookie{ 109 + Name: "mobile_redirect_uri", 110 + Value: "", 111 + Path: "/oauth", 112 + MaxAge: -1, 113 + }) 114 + http.SetCookie(w, &http.Cookie{ 115 + Name: "mobile_redirect_binding", 116 + Value: "", 117 + Path: "/oauth", 118 + MaxAge: -1, 119 + }) 120 + http.SetCookie(w, &http.Cookie{ 121 + Name: "oauth_csrf", 122 + Value: "", 123 + Path: "/oauth", 124 + MaxAge: -1, 125 + }) 126 + }
+477
internal/atproto/oauth/handlers_security_test.go
··· 1 + package oauth 2 + 3 + import ( 4 + "net/http" 5 + "net/http/httptest" 6 + "testing" 7 + 8 + "github.com/stretchr/testify/assert" 9 + "github.com/stretchr/testify/require" 10 + ) 11 + 12 + // TestIsAllowedMobileRedirectURI tests the mobile redirect URI allowlist with EXACT URI matching 13 + // Only Universal Links (HTTPS) are allowed - custom schemes are blocked for security 14 + func TestIsAllowedMobileRedirectURI(t *testing.T) { 15 + tests := []struct { 16 + name string 17 + uri string 18 + expected bool 19 + }{ 20 + { 21 + name: "allowed - Universal Link", 22 + uri: "https://coves.social/app/oauth/callback", 23 + expected: true, 24 + }, 25 + { 26 + name: "rejected - custom scheme coves-app (vulnerable to interception)", 27 + uri: "coves-app://oauth/callback", 28 + expected: false, 29 + }, 30 + { 31 + name: "rejected - custom scheme coves (vulnerable to interception)", 32 + uri: "coves://oauth/callback", 33 + expected: false, 34 + }, 35 + { 36 + name: "rejected - evil scheme", 37 + uri: "evil://callback", 38 + expected: false, 39 + }, 40 + { 41 + name: "rejected - http (not secure)", 42 + uri: "http://example.com/callback", 43 + expected: false, 44 + }, 45 + { 46 + name: "rejected - https different domain", 47 + uri: "https://example.com/callback", 48 + expected: false, 49 + }, 50 + { 51 + name: "rejected - https coves.social wrong path", 52 + uri: "https://coves.social/wrong/path", 53 + expected: false, 54 + }, 55 + { 56 + name: "rejected - invalid URI", 57 + uri: "not a uri", 58 + expected: false, 59 + }, 60 + { 61 + name: "rejected - empty string", 62 + uri: "", 63 + expected: false, 64 + }, 65 + } 66 + 67 + for _, tt := range tests { 68 + t.Run(tt.name, func(t *testing.T) { 69 + result := isAllowedMobileRedirectURI(tt.uri) 70 + assert.Equal(t, tt.expected, result, 71 + "isAllowedMobileRedirectURI(%q) = %v, want %v", tt.uri, result, tt.expected) 72 + }) 73 + } 74 + } 75 + 76 + // TestExtractScheme tests the scheme extraction function 77 + func TestExtractScheme(t *testing.T) { 78 + tests := []struct { 79 + name string 80 + uri string 81 + expected string 82 + }{ 83 + { 84 + name: "https scheme", 85 + uri: "https://coves.social/app/oauth/callback", 86 + expected: "https", 87 + }, 88 + { 89 + name: "custom scheme", 90 + uri: "coves-app://callback", 91 + expected: "coves-app", 92 + }, 93 + { 94 + name: "invalid URI", 95 + uri: "not a uri", 96 + expected: "invalid", 97 + }, 98 + } 99 + 100 + for _, tt := range tests { 101 + t.Run(tt.name, func(t *testing.T) { 102 + result := extractScheme(tt.uri) 103 + assert.Equal(t, tt.expected, result) 104 + }) 105 + } 106 + } 107 + 108 + // TestGenerateCSRFToken tests CSRF token generation 109 + func TestGenerateCSRFToken(t *testing.T) { 110 + // Generate two tokens and verify they are different (randomness check) 111 + token1, err1 := generateCSRFToken() 112 + require.NoError(t, err1) 113 + require.NotEmpty(t, token1) 114 + 115 + token2, err2 := generateCSRFToken() 116 + require.NoError(t, err2) 117 + require.NotEmpty(t, token2) 118 + 119 + assert.NotEqual(t, token1, token2, "CSRF tokens should be unique") 120 + 121 + // Verify token is base64 encoded (should decode without error) 122 + assert.Greater(t, len(token1), 40, "CSRF token should be reasonably long (32 bytes base64 encoded)") 123 + } 124 + 125 + // TestHandleMobileLogin_RedirectURIValidation tests that HandleMobileLogin validates redirect URIs 126 + func TestHandleMobileLogin_RedirectURIValidation(t *testing.T) { 127 + // Note: This is a unit test for the validation logic only. 128 + // Full integration tests with OAuth flow are in tests/integration/oauth_e2e_test.go 129 + 130 + tests := []struct { 131 + name string 132 + redirectURI string 133 + expectedLog string 134 + expectedStatus int 135 + }{ 136 + { 137 + name: "allowed - Universal Link", 138 + redirectURI: "https://coves.social/app/oauth/callback", 139 + expectedStatus: http.StatusBadRequest, // Will fail at StartAuthFlow (no OAuth client setup) 140 + }, 141 + { 142 + name: "rejected - custom scheme coves-app (insecure)", 143 + redirectURI: "coves-app://oauth/callback", 144 + expectedStatus: http.StatusBadRequest, 145 + expectedLog: "rejected unauthorized mobile redirect URI", 146 + }, 147 + { 148 + name: "rejected evil scheme", 149 + redirectURI: "evil://callback", 150 + expectedStatus: http.StatusBadRequest, 151 + expectedLog: "rejected unauthorized mobile redirect URI", 152 + }, 153 + { 154 + name: "rejected http", 155 + redirectURI: "http://evil.com/callback", 156 + expectedStatus: http.StatusBadRequest, 157 + expectedLog: "scheme not allowed", 158 + }, 159 + } 160 + 161 + for _, tt := range tests { 162 + t.Run(tt.name, func(t *testing.T) { 163 + // Test the validation function directly 164 + result := isAllowedMobileRedirectURI(tt.redirectURI) 165 + if tt.expectedLog != "" { 166 + assert.False(t, result, "Should reject %s", tt.redirectURI) 167 + } 168 + }) 169 + } 170 + } 171 + 172 + // TestHandleCallback_CSRFValidation tests that HandleCallback validates CSRF tokens for mobile flow 173 + func TestHandleCallback_CSRFValidation(t *testing.T) { 174 + // This is a conceptual test structure. Full implementation would require: 175 + // 1. Mock OAuthClient 176 + // 2. Mock OAuth store 177 + // 3. Simulated OAuth callback with cookies 178 + 179 + t.Run("mobile callback requires CSRF token", func(t *testing.T) { 180 + // Setup: Create request with mobile_redirect_uri cookie but NO oauth_csrf cookie 181 + req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test", nil) 182 + req.AddCookie(&http.Cookie{ 183 + Name: "mobile_redirect_uri", 184 + Value: "https%3A%2F%2Fcoves.social%2Fapp%2Foauth%2Fcallback", 185 + }) 186 + // Missing: oauth_csrf cookie 187 + 188 + // This would be rejected with 403 Forbidden in the actual handler 189 + // (Full test in integration tests with real OAuth flow) 190 + 191 + assert.NotNil(t, req) // Placeholder assertion 192 + }) 193 + 194 + t.Run("mobile callback with valid CSRF token", func(t *testing.T) { 195 + // Setup: Create request with both cookies 196 + req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test", nil) 197 + req.AddCookie(&http.Cookie{ 198 + Name: "mobile_redirect_uri", 199 + Value: "https%3A%2F%2Fcoves.social%2Fapp%2Foauth%2Fcallback", 200 + }) 201 + req.AddCookie(&http.Cookie{ 202 + Name: "oauth_csrf", 203 + Value: "valid-csrf-token", 204 + }) 205 + 206 + // This would be accepted (assuming valid OAuth code/state) 207 + // (Full test in integration tests with real OAuth flow) 208 + 209 + assert.NotNil(t, req) // Placeholder assertion 210 + }) 211 + } 212 + 213 + // TestHandleMobileCallback_RevalidatesRedirectURI tests that handleMobileCallback re-validates the redirect URI 214 + func TestHandleMobileCallback_RevalidatesRedirectURI(t *testing.T) { 215 + // This is a critical security test: even if an attacker somehow bypasses the initial check, 216 + // the callback handler should re-validate the redirect URI before redirecting. 217 + 218 + tests := []struct { 219 + name string 220 + redirectURI string 221 + shouldPass bool 222 + }{ 223 + { 224 + name: "allowed - Universal Link", 225 + redirectURI: "https://coves.social/app/oauth/callback", 226 + shouldPass: true, 227 + }, 228 + { 229 + name: "blocked - custom scheme (insecure)", 230 + redirectURI: "coves-app://oauth/callback", 231 + shouldPass: false, 232 + }, 233 + { 234 + name: "blocked - evil scheme", 235 + redirectURI: "evil://callback", 236 + shouldPass: false, 237 + }, 238 + } 239 + 240 + for _, tt := range tests { 241 + t.Run(tt.name, func(t *testing.T) { 242 + result := isAllowedMobileRedirectURI(tt.redirectURI) 243 + assert.Equal(t, tt.shouldPass, result) 244 + }) 245 + } 246 + } 247 + 248 + // TestGenerateMobileRedirectBinding tests the binding token generation 249 + // The binding now includes the CSRF token for proper double-submit validation 250 + func TestGenerateMobileRedirectBinding(t *testing.T) { 251 + csrfToken := "test-csrf-token-12345" 252 + tests := []struct { 253 + name string 254 + redirectURI string 255 + }{ 256 + { 257 + name: "Universal Link", 258 + redirectURI: "https://coves.social/app/oauth/callback", 259 + }, 260 + { 261 + name: "different path", 262 + redirectURI: "https://coves.social/different/path", 263 + }, 264 + } 265 + 266 + for _, tt := range tests { 267 + t.Run(tt.name, func(t *testing.T) { 268 + binding1 := generateMobileRedirectBinding(csrfToken, tt.redirectURI) 269 + binding2 := generateMobileRedirectBinding(csrfToken, tt.redirectURI) 270 + 271 + // Same CSRF token + URI should produce same binding (deterministic) 272 + assert.Equal(t, binding1, binding2, "binding should be deterministic for same inputs") 273 + 274 + // Binding should not be empty 275 + assert.NotEmpty(t, binding1, "binding should not be empty") 276 + 277 + // Binding should be base64 encoded (should decode without error) 278 + assert.Greater(t, len(binding1), 20, "binding should be reasonably long") 279 + }) 280 + } 281 + 282 + // Different URIs should produce different bindings 283 + binding1 := generateMobileRedirectBinding(csrfToken, "https://coves.social/app/oauth/callback") 284 + binding2 := generateMobileRedirectBinding(csrfToken, "https://coves.social/different/path") 285 + assert.NotEqual(t, binding1, binding2, "different URIs should produce different bindings") 286 + 287 + // Different CSRF tokens should produce different bindings 288 + binding3 := generateMobileRedirectBinding("different-csrf-token", "https://coves.social/app/oauth/callback") 289 + assert.NotEqual(t, binding1, binding3, "different CSRF tokens should produce different bindings") 290 + } 291 + 292 + // TestValidateMobileRedirectBinding tests the binding validation 293 + // Now validates both CSRF token and redirect URI together (double-submit pattern) 294 + func TestValidateMobileRedirectBinding(t *testing.T) { 295 + csrfToken := "test-csrf-token-for-validation" 296 + redirectURI := "https://coves.social/app/oauth/callback" 297 + validBinding := generateMobileRedirectBinding(csrfToken, redirectURI) 298 + 299 + tests := []struct { 300 + name string 301 + csrfToken string 302 + redirectURI string 303 + binding string 304 + shouldPass bool 305 + }{ 306 + { 307 + name: "valid - correct CSRF token and redirect URI", 308 + csrfToken: csrfToken, 309 + redirectURI: redirectURI, 310 + binding: validBinding, 311 + shouldPass: true, 312 + }, 313 + { 314 + name: "invalid - wrong redirect URI", 315 + csrfToken: csrfToken, 316 + redirectURI: "https://coves.social/different/path", 317 + binding: validBinding, 318 + shouldPass: false, 319 + }, 320 + { 321 + name: "invalid - wrong CSRF token", 322 + csrfToken: "wrong-csrf-token", 323 + redirectURI: redirectURI, 324 + binding: validBinding, 325 + shouldPass: false, 326 + }, 327 + { 328 + name: "invalid - random binding", 329 + csrfToken: csrfToken, 330 + redirectURI: redirectURI, 331 + binding: "random-invalid-binding", 332 + shouldPass: false, 333 + }, 334 + { 335 + name: "invalid - empty binding", 336 + csrfToken: csrfToken, 337 + redirectURI: redirectURI, 338 + binding: "", 339 + shouldPass: false, 340 + }, 341 + { 342 + name: "invalid - empty CSRF token", 343 + csrfToken: "", 344 + redirectURI: redirectURI, 345 + binding: validBinding, 346 + shouldPass: false, 347 + }, 348 + } 349 + 350 + for _, tt := range tests { 351 + t.Run(tt.name, func(t *testing.T) { 352 + result := validateMobileRedirectBinding(tt.csrfToken, tt.redirectURI, tt.binding) 353 + assert.Equal(t, tt.shouldPass, result) 354 + }) 355 + } 356 + } 357 + 358 + // TestSessionFixationAttackPrevention tests that the binding prevents session fixation 359 + func TestSessionFixationAttackPrevention(t *testing.T) { 360 + // Simulate attack scenario: 361 + // 1. Attacker plants a cookie for evil://steal with binding for evil://steal 362 + // 2. User does a web login (no mobile_redirect_binding cookie) 363 + // 3. Callback should NOT redirect to evil://steal 364 + 365 + attackerCSRF := "attacker-csrf-token" 366 + attackerRedirectURI := "evil://steal" 367 + attackerBinding := generateMobileRedirectBinding(attackerCSRF, attackerRedirectURI) 368 + 369 + // Later, user's legitimate mobile login 370 + userCSRF := "user-csrf-token" 371 + userRedirectURI := "https://coves.social/app/oauth/callback" 372 + userBinding := generateMobileRedirectBinding(userCSRF, userRedirectURI) 373 + 374 + // The attacker's binding should NOT validate for the user's redirect URI 375 + assert.False(t, validateMobileRedirectBinding(userCSRF, userRedirectURI, attackerBinding), 376 + "attacker's binding should not validate for user's CSRF token and redirect URI") 377 + 378 + // The user's binding should validate for the user's CSRF token and redirect URI 379 + assert.True(t, validateMobileRedirectBinding(userCSRF, userRedirectURI, userBinding), 380 + "user's binding should validate for user's CSRF token and redirect URI") 381 + 382 + // Cross-validation should fail 383 + assert.False(t, validateMobileRedirectBinding(attackerCSRF, attackerRedirectURI, userBinding), 384 + "user's binding should not validate for attacker's CSRF token and redirect URI") 385 + } 386 + 387 + // TestCSRFTokenValidation tests that CSRF token VALUE is validated, not just presence 388 + func TestCSRFTokenValidation(t *testing.T) { 389 + // This test verifies the fix for the P1 security issue: 390 + // "The callback never validates the token... the csrfToken argument is ignored entirely" 391 + // 392 + // The fix ensures that the CSRF token VALUE is cryptographically bound to the 393 + // binding token, so changing the CSRF token will invalidate the binding. 394 + 395 + t.Run("CSRF token value must match", func(t *testing.T) { 396 + originalCSRF := "original-csrf-token-from-login" 397 + redirectURI := "https://coves.social/app/oauth/callback" 398 + binding := generateMobileRedirectBinding(originalCSRF, redirectURI) 399 + 400 + // Original CSRF token should validate 401 + assert.True(t, validateMobileRedirectBinding(originalCSRF, redirectURI, binding), 402 + "original CSRF token should validate") 403 + 404 + // Different CSRF token should NOT validate (this is the key security fix) 405 + differentCSRF := "attacker-forged-csrf-token" 406 + assert.False(t, validateMobileRedirectBinding(differentCSRF, redirectURI, binding), 407 + "different CSRF token should NOT validate - this is the security fix") 408 + }) 409 + 410 + t.Run("attacker cannot forge binding without CSRF token", func(t *testing.T) { 411 + // Attacker knows the redirect URI but not the CSRF token 412 + redirectURI := "https://coves.social/app/oauth/callback" 413 + victimCSRF := "victim-secret-csrf-token" 414 + victimBinding := generateMobileRedirectBinding(victimCSRF, redirectURI) 415 + 416 + // Attacker tries various CSRF tokens to forge the binding 417 + attackerGuesses := []string{ 418 + "", 419 + "guess1", 420 + "attacker-csrf", 421 + redirectURI, // trying the redirect URI as CSRF 422 + } 423 + 424 + for _, guess := range attackerGuesses { 425 + assert.False(t, validateMobileRedirectBinding(guess, redirectURI, victimBinding), 426 + "attacker's CSRF guess %q should not validate", guess) 427 + } 428 + }) 429 + } 430 + 431 + // TestConstantTimeCompare tests the timing-safe comparison function 432 + func TestConstantTimeCompare(t *testing.T) { 433 + tests := []struct { 434 + name string 435 + a string 436 + b string 437 + expected bool 438 + }{ 439 + { 440 + name: "equal strings", 441 + a: "abc123", 442 + b: "abc123", 443 + expected: true, 444 + }, 445 + { 446 + name: "different strings same length", 447 + a: "abc123", 448 + b: "xyz789", 449 + expected: false, 450 + }, 451 + { 452 + name: "different lengths", 453 + a: "short", 454 + b: "longer", 455 + expected: false, 456 + }, 457 + { 458 + name: "empty strings", 459 + a: "", 460 + b: "", 461 + expected: true, 462 + }, 463 + { 464 + name: "one empty", 465 + a: "abc", 466 + b: "", 467 + expected: false, 468 + }, 469 + } 470 + 471 + for _, tt := range tests { 472 + t.Run(tt.name, func(t *testing.T) { 473 + result := constantTimeCompare(tt.a, tt.b) 474 + assert.Equal(t, tt.expected, result) 475 + }) 476 + } 477 + }
+279
internal/atproto/oauth/handlers_test.go
··· 1 + package oauth 2 + 3 + import ( 4 + "encoding/json" 5 + "net/http" 6 + "net/http/httptest" 7 + "testing" 8 + "time" 9 + 10 + "github.com/bluesky-social/indigo/atproto/auth/oauth" 11 + "github.com/bluesky-social/indigo/atproto/syntax" 12 + "github.com/stretchr/testify/assert" 13 + "github.com/stretchr/testify/require" 14 + ) 15 + 16 + // TestHandleClientMetadata tests the client metadata endpoint 17 + func TestHandleClientMetadata(t *testing.T) { 18 + // Create a test OAuth client configuration 19 + config := &OAuthConfig{ 20 + PublicURL: "https://coves.social", 21 + Scopes: []string{"atproto"}, 22 + DevMode: false, 23 + AllowPrivateIPs: false, 24 + SealSecret: "MTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTI=", // base64 encoded 32 bytes 25 + } 26 + 27 + // Create OAuth client with memory store 28 + client, err := NewOAuthClient(config, oauth.NewMemStore()) 29 + require.NoError(t, err) 30 + 31 + // Create handler 32 + handler := NewOAuthHandler(client, oauth.NewMemStore()) 33 + 34 + // Create test request 35 + req := httptest.NewRequest(http.MethodGet, "/oauth/client-metadata.json", nil) 36 + req.Host = "coves.social" 37 + rec := httptest.NewRecorder() 38 + 39 + // Call handler 40 + handler.HandleClientMetadata(rec, req) 41 + 42 + // Check response 43 + assert.Equal(t, http.StatusOK, rec.Code) 44 + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) 45 + 46 + // Parse response 47 + var metadata oauth.ClientMetadata 48 + err = json.NewDecoder(rec.Body).Decode(&metadata) 49 + require.NoError(t, err) 50 + 51 + // Validate metadata 52 + assert.Equal(t, "https://coves.social", metadata.ClientID) 53 + assert.Contains(t, metadata.RedirectURIs, "https://coves.social/oauth/callback") 54 + assert.Contains(t, metadata.GrantTypes, "authorization_code") 55 + assert.Contains(t, metadata.GrantTypes, "refresh_token") 56 + assert.True(t, metadata.DPoPBoundAccessTokens) 57 + assert.Contains(t, metadata.Scope, "atproto") 58 + } 59 + 60 + // TestHandleJWKS tests the JWKS endpoint 61 + func TestHandleJWKS(t *testing.T) { 62 + // Create a test OAuth client configuration (public client, no keys) 63 + config := &OAuthConfig{ 64 + PublicURL: "https://coves.social", 65 + Scopes: []string{"atproto"}, 66 + DevMode: false, 67 + AllowPrivateIPs: false, 68 + SealSecret: "MTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTI=", 69 + } 70 + 71 + client, err := NewOAuthClient(config, oauth.NewMemStore()) 72 + require.NoError(t, err) 73 + 74 + handler := NewOAuthHandler(client, oauth.NewMemStore()) 75 + 76 + // Create test request 77 + req := httptest.NewRequest(http.MethodGet, "/oauth/jwks.json", nil) 78 + rec := httptest.NewRecorder() 79 + 80 + // Call handler 81 + handler.HandleJWKS(rec, req) 82 + 83 + // Check response 84 + assert.Equal(t, http.StatusOK, rec.Code) 85 + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) 86 + 87 + // Parse response 88 + var jwks oauth.JWKS 89 + err = json.NewDecoder(rec.Body).Decode(&jwks) 90 + require.NoError(t, err) 91 + 92 + // Public client should have empty JWKS 93 + assert.NotNil(t, jwks.Keys) 94 + assert.Equal(t, 0, len(jwks.Keys)) 95 + } 96 + 97 + // TestHandleLogin tests the login endpoint 98 + func TestHandleLogin(t *testing.T) { 99 + config := &OAuthConfig{ 100 + PublicURL: "https://coves.social", 101 + Scopes: []string{"atproto"}, 102 + DevMode: true, // Use dev mode to avoid real PDS calls 103 + AllowPrivateIPs: true, // Allow private IPs in dev mode 104 + SealSecret: "MTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTI=", 105 + } 106 + 107 + client, err := NewOAuthClient(config, oauth.NewMemStore()) 108 + require.NoError(t, err) 109 + 110 + handler := NewOAuthHandler(client, oauth.NewMemStore()) 111 + 112 + t.Run("missing identifier", func(t *testing.T) { 113 + req := httptest.NewRequest(http.MethodGet, "/oauth/login", nil) 114 + rec := httptest.NewRecorder() 115 + 116 + handler.HandleLogin(rec, req) 117 + 118 + assert.Equal(t, http.StatusBadRequest, rec.Code) 119 + }) 120 + 121 + t.Run("with handle parameter", func(t *testing.T) { 122 + // This test would need a mock PDS server to fully test 123 + // For now, we just verify the endpoint accepts the parameter 124 + req := httptest.NewRequest(http.MethodGet, "/oauth/login?handle=user.bsky.social", nil) 125 + rec := httptest.NewRecorder() 126 + 127 + handler.HandleLogin(rec, req) 128 + 129 + // In dev mode or with a real PDS, this would redirect 130 + // Without a mock, it will fail to resolve the handle 131 + // We're just testing that the handler processes the request 132 + assert.NotEqual(t, http.StatusOK, rec.Code) // Should redirect or error 133 + }) 134 + } 135 + 136 + // TestHandleMobileLogin tests the mobile login endpoint 137 + func TestHandleMobileLogin(t *testing.T) { 138 + config := &OAuthConfig{ 139 + PublicURL: "https://coves.social", 140 + Scopes: []string{"atproto"}, 141 + DevMode: true, 142 + AllowPrivateIPs: true, 143 + SealSecret: "MTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTI=", 144 + } 145 + 146 + client, err := NewOAuthClient(config, oauth.NewMemStore()) 147 + require.NoError(t, err) 148 + 149 + handler := NewOAuthHandler(client, oauth.NewMemStore()) 150 + 151 + t.Run("missing redirect_uri", func(t *testing.T) { 152 + req := httptest.NewRequest(http.MethodGet, "/oauth/mobile/login?handle=user.bsky.social", nil) 153 + rec := httptest.NewRecorder() 154 + 155 + handler.HandleMobileLogin(rec, req) 156 + 157 + assert.Equal(t, http.StatusBadRequest, rec.Code) 158 + assert.Contains(t, rec.Body.String(), "redirect_uri") 159 + }) 160 + 161 + t.Run("invalid redirect_uri (https)", func(t *testing.T) { 162 + req := httptest.NewRequest(http.MethodGet, "/oauth/mobile/login?handle=user.bsky.social&redirect_uri=https://example.com", nil) 163 + rec := httptest.NewRecorder() 164 + 165 + handler.HandleMobileLogin(rec, req) 166 + 167 + assert.Equal(t, http.StatusBadRequest, rec.Code) 168 + assert.Contains(t, rec.Body.String(), "invalid redirect_uri") 169 + }) 170 + 171 + t.Run("invalid redirect_uri (wrong path)", func(t *testing.T) { 172 + req := httptest.NewRequest(http.MethodGet, "/oauth/mobile/login?handle=user.bsky.social&redirect_uri=coves-app://callback", nil) 173 + rec := httptest.NewRecorder() 174 + 175 + handler.HandleMobileLogin(rec, req) 176 + 177 + assert.Equal(t, http.StatusBadRequest, rec.Code) 178 + assert.Contains(t, rec.Body.String(), "invalid redirect_uri") 179 + }) 180 + 181 + t.Run("valid mobile redirect_uri (Universal Link)", func(t *testing.T) { 182 + req := httptest.NewRequest(http.MethodGet, "/oauth/mobile/login?handle=user.bsky.social&redirect_uri=https://coves.social/app/oauth/callback", nil) 183 + rec := httptest.NewRecorder() 184 + 185 + handler.HandleMobileLogin(rec, req) 186 + 187 + // Should fail to resolve handle but accept the parameters 188 + // Check that cookie was set 189 + cookies := rec.Result().Cookies() 190 + var found bool 191 + for _, cookie := range cookies { 192 + if cookie.Name == "mobile_redirect_uri" { 193 + found = true 194 + break 195 + } 196 + } 197 + // May or may not set cookie depending on error handling 198 + _ = found 199 + }) 200 + } 201 + 202 + // TestParseSessionToken tests that we no longer use parseSessionToken 203 + // (removed in favor of sealed tokens) 204 + func TestParseSessionToken(t *testing.T) { 205 + // This test is deprecated - we now use sealed tokens instead of plain "did:sessionID" format 206 + // See TestSealAndUnsealSessionData for the new approach 207 + t.Skip("parseSessionToken removed - we now use sealed tokens for security") 208 + } 209 + 210 + // TestIsMobileRedirectURI tests mobile redirect URI validation with EXACT URI matching 211 + // Only Universal Links (HTTPS) are allowed - custom schemes are blocked for security 212 + func TestIsMobileRedirectURI(t *testing.T) { 213 + tests := []struct { 214 + uri string 215 + expected bool 216 + }{ 217 + {"https://coves.social/app/oauth/callback", true}, // Universal Link - allowed 218 + {"coves-app://oauth/callback", false}, // Custom scheme - blocked (insecure) 219 + {"coves://oauth/callback", false}, // Custom scheme - blocked (insecure) 220 + {"coves-app://callback", false}, // Custom scheme - blocked 221 + {"coves://oauth", false}, // Custom scheme - blocked 222 + {"myapp://oauth", false}, // Not in allowlist 223 + {"https://example.com", false}, // Wrong domain 224 + {"http://localhost", false}, // HTTP not allowed 225 + {"", false}, 226 + {"not-a-uri", false}, 227 + } 228 + 229 + for _, tt := range tests { 230 + t.Run(tt.uri, func(t *testing.T) { 231 + result := isAllowedMobileRedirectURI(tt.uri) 232 + assert.Equal(t, tt.expected, result) 233 + }) 234 + } 235 + } 236 + 237 + // TestSealAndUnsealSessionData tests session data sealing/unsealing 238 + func TestSealAndUnsealSessionData(t *testing.T) { 239 + config := &OAuthConfig{ 240 + PublicURL: "https://coves.social", 241 + Scopes: []string{"atproto"}, 242 + DevMode: false, 243 + AllowPrivateIPs: false, 244 + SealSecret: "MTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTI=", 245 + } 246 + 247 + client, err := NewOAuthClient(config, oauth.NewMemStore()) 248 + require.NoError(t, err) 249 + 250 + // Create test DID 251 + did, err := testDID() 252 + require.NoError(t, err) 253 + 254 + sessionID := "test-session-123" 255 + 256 + // Seal the session using the client method 257 + sealed, err := client.SealSession(did.String(), sessionID, 24*time.Hour) 258 + require.NoError(t, err) 259 + assert.NotEmpty(t, sealed) 260 + 261 + // Unseal the session using the client method 262 + unsealed, err := client.UnsealSession(sealed) 263 + require.NoError(t, err) 264 + require.NotNil(t, unsealed) 265 + 266 + // Verify data matches 267 + assert.Equal(t, did.String(), unsealed.DID) 268 + assert.Equal(t, sessionID, unsealed.SessionID) 269 + assert.Greater(t, unsealed.ExpiresAt, int64(0)) 270 + } 271 + 272 + // testDID creates a test DID for testing 273 + func testDID() (*syntax.DID, error) { 274 + did, err := syntax.ParseDID("did:plc:test123abc456def789") 275 + if err != nil { 276 + return nil, err 277 + } 278 + return &did, nil 279 + }
+152
internal/atproto/oauth/seal.go
··· 1 + package oauth 2 + 3 + import ( 4 + "crypto/aes" 5 + "crypto/cipher" 6 + "crypto/rand" 7 + "encoding/base64" 8 + "encoding/json" 9 + "fmt" 10 + "time" 11 + ) 12 + 13 + // SealedSession represents the data sealed in a mobile session token 14 + type SealedSession struct { 15 + DID string `json:"did"` // User's DID 16 + SessionID string `json:"sid"` // Session identifier 17 + ExpiresAt int64 `json:"exp"` // Unix timestamp when token expires 18 + } 19 + 20 + // SealSession creates an encrypted token containing session information. 21 + // The token is encrypted using AES-256-GCM and encoded as base64url. 22 + // 23 + // Token format: base64url(nonce || ciphertext || tag) 24 + // - nonce: 12 bytes (GCM standard nonce size) 25 + // - ciphertext: encrypted JSON payload 26 + // - tag: 16 bytes (GCM authentication tag) 27 + // 28 + // The sealed token can be safely given to mobile clients and used as 29 + // a reference to the server-side session without exposing sensitive data. 30 + func (c *OAuthClient) SealSession(did, sessionID string, ttl time.Duration) (string, error) { 31 + if len(c.SealSecret) == 0 { 32 + return "", fmt.Errorf("seal secret not configured") 33 + } 34 + 35 + if did == "" { 36 + return "", fmt.Errorf("DID is required") 37 + } 38 + 39 + if sessionID == "" { 40 + return "", fmt.Errorf("session ID is required") 41 + } 42 + 43 + // Create the session data 44 + expiresAt := time.Now().Add(ttl).Unix() 45 + session := SealedSession{ 46 + DID: did, 47 + SessionID: sessionID, 48 + ExpiresAt: expiresAt, 49 + } 50 + 51 + // Marshal to JSON 52 + plaintext, err := json.Marshal(session) 53 + if err != nil { 54 + return "", fmt.Errorf("failed to marshal session: %w", err) 55 + } 56 + 57 + // Create AES cipher 58 + block, err := aes.NewCipher(c.SealSecret) 59 + if err != nil { 60 + return "", fmt.Errorf("failed to create cipher: %w", err) 61 + } 62 + 63 + // Create GCM mode 64 + gcm, err := cipher.NewGCM(block) 65 + if err != nil { 66 + return "", fmt.Errorf("failed to create GCM: %w", err) 67 + } 68 + 69 + // Generate random nonce 70 + nonce := make([]byte, gcm.NonceSize()) 71 + if _, err := rand.Read(nonce); err != nil { 72 + return "", fmt.Errorf("failed to generate nonce: %w", err) 73 + } 74 + 75 + // Encrypt and authenticate 76 + // GCM.Seal appends the ciphertext and tag to the nonce 77 + ciphertext := gcm.Seal(nonce, nonce, plaintext, nil) 78 + 79 + // Encode as base64url (no padding) 80 + token := base64.RawURLEncoding.EncodeToString(ciphertext) 81 + 82 + return token, nil 83 + } 84 + 85 + // UnsealSession decrypts and validates a sealed session token. 86 + // Returns the session information if the token is valid and not expired. 87 + func (c *OAuthClient) UnsealSession(token string) (*SealedSession, error) { 88 + if len(c.SealSecret) == 0 { 89 + return nil, fmt.Errorf("seal secret not configured") 90 + } 91 + 92 + if token == "" { 93 + return nil, fmt.Errorf("token is required") 94 + } 95 + 96 + // Decode from base64url 97 + ciphertext, err := base64.RawURLEncoding.DecodeString(token) 98 + if err != nil { 99 + return nil, fmt.Errorf("invalid token encoding: %w", err) 100 + } 101 + 102 + // Create AES cipher 103 + block, err := aes.NewCipher(c.SealSecret) 104 + if err != nil { 105 + return nil, fmt.Errorf("failed to create cipher: %w", err) 106 + } 107 + 108 + // Create GCM mode 109 + gcm, err := cipher.NewGCM(block) 110 + if err != nil { 111 + return nil, fmt.Errorf("failed to create GCM: %w", err) 112 + } 113 + 114 + // Verify minimum size (nonce + tag) 115 + nonceSize := gcm.NonceSize() 116 + if len(ciphertext) < nonceSize { 117 + return nil, fmt.Errorf("invalid token: too short") 118 + } 119 + 120 + // Extract nonce and ciphertext 121 + nonce := ciphertext[:nonceSize] 122 + ciphertextData := ciphertext[nonceSize:] 123 + 124 + // Decrypt and authenticate 125 + plaintext, err := gcm.Open(nil, nonce, ciphertextData, nil) 126 + if err != nil { 127 + return nil, fmt.Errorf("failed to decrypt token: %w", err) 128 + } 129 + 130 + // Unmarshal JSON 131 + var session SealedSession 132 + if err := json.Unmarshal(plaintext, &session); err != nil { 133 + return nil, fmt.Errorf("failed to unmarshal session: %w", err) 134 + } 135 + 136 + // Validate required fields 137 + if session.DID == "" { 138 + return nil, fmt.Errorf("invalid session: missing DID") 139 + } 140 + 141 + if session.SessionID == "" { 142 + return nil, fmt.Errorf("invalid session: missing session ID") 143 + } 144 + 145 + // Check expiration 146 + now := time.Now().Unix() 147 + if session.ExpiresAt <= now { 148 + return nil, fmt.Errorf("token expired at %v", time.Unix(session.ExpiresAt, 0)) 149 + } 150 + 151 + return &session, nil 152 + }
+331
internal/atproto/oauth/seal_test.go
··· 1 + package oauth 2 + 3 + import ( 4 + "crypto/rand" 5 + "encoding/base64" 6 + "strings" 7 + "testing" 8 + "time" 9 + 10 + "github.com/stretchr/testify/assert" 11 + "github.com/stretchr/testify/require" 12 + ) 13 + 14 + // generateSealSecret generates a random 32-byte seal secret for testing 15 + func generateSealSecret() []byte { 16 + secret := make([]byte, 32) 17 + if _, err := rand.Read(secret); err != nil { 18 + panic(err) 19 + } 20 + return secret 21 + } 22 + 23 + func TestSealSession_RoundTrip(t *testing.T) { 24 + // Create client with seal secret 25 + client := &OAuthClient{ 26 + SealSecret: generateSealSecret(), 27 + } 28 + 29 + did := "did:plc:abc123" 30 + sessionID := "session-xyz" 31 + ttl := 1 * time.Hour 32 + 33 + // Seal the session 34 + token, err := client.SealSession(did, sessionID, ttl) 35 + require.NoError(t, err) 36 + require.NotEmpty(t, token) 37 + 38 + // Token should be base64url encoded 39 + _, err = base64.RawURLEncoding.DecodeString(token) 40 + require.NoError(t, err, "token should be valid base64url") 41 + 42 + // Unseal the session 43 + session, err := client.UnsealSession(token) 44 + require.NoError(t, err) 45 + require.NotNil(t, session) 46 + 47 + // Verify data 48 + assert.Equal(t, did, session.DID) 49 + assert.Equal(t, sessionID, session.SessionID) 50 + 51 + // Verify expiration is approximately correct (within 1 second) 52 + expectedExpiry := time.Now().Add(ttl).Unix() 53 + assert.InDelta(t, expectedExpiry, session.ExpiresAt, 1.0) 54 + } 55 + 56 + func TestSealSession_ExpirationValidation(t *testing.T) { 57 + client := &OAuthClient{ 58 + SealSecret: generateSealSecret(), 59 + } 60 + 61 + did := "did:plc:abc123" 62 + sessionID := "session-xyz" 63 + ttl := 2 * time.Second // Short TTL (must be >= 1 second due to Unix timestamp granularity) 64 + 65 + // Seal the session 66 + token, err := client.SealSession(did, sessionID, ttl) 67 + require.NoError(t, err) 68 + 69 + // Should work immediately 70 + session, err := client.UnsealSession(token) 71 + require.NoError(t, err) 72 + assert.Equal(t, did, session.DID) 73 + 74 + // Wait well past expiration 75 + time.Sleep(2500 * time.Millisecond) 76 + 77 + // Should fail after expiration 78 + session, err = client.UnsealSession(token) 79 + assert.Error(t, err) 80 + assert.Nil(t, session) 81 + assert.Contains(t, err.Error(), "token expired") 82 + } 83 + 84 + func TestSealSession_TamperedTokenDetection(t *testing.T) { 85 + client := &OAuthClient{ 86 + SealSecret: generateSealSecret(), 87 + } 88 + 89 + did := "did:plc:abc123" 90 + sessionID := "session-xyz" 91 + ttl := 1 * time.Hour 92 + 93 + // Seal the session 94 + token, err := client.SealSession(did, sessionID, ttl) 95 + require.NoError(t, err) 96 + 97 + // Tamper with the token by modifying one character 98 + tampered := token[:len(token)-5] + "XXXX" + token[len(token)-1:] 99 + 100 + // Should fail to unseal tampered token 101 + session, err := client.UnsealSession(tampered) 102 + assert.Error(t, err) 103 + assert.Nil(t, session) 104 + assert.Contains(t, err.Error(), "failed to decrypt token") 105 + } 106 + 107 + func TestSealSession_InvalidTokenFormats(t *testing.T) { 108 + client := &OAuthClient{ 109 + SealSecret: generateSealSecret(), 110 + } 111 + 112 + tests := []struct { 113 + name string 114 + token string 115 + }{ 116 + { 117 + name: "empty token", 118 + token: "", 119 + }, 120 + { 121 + name: "invalid base64", 122 + token: "not-valid-base64!@#$", 123 + }, 124 + { 125 + name: "too short", 126 + token: base64.RawURLEncoding.EncodeToString([]byte("short")), 127 + }, 128 + { 129 + name: "random bytes", 130 + token: base64.RawURLEncoding.EncodeToString(make([]byte, 50)), 131 + }, 132 + } 133 + 134 + for _, tt := range tests { 135 + t.Run(tt.name, func(t *testing.T) { 136 + session, err := client.UnsealSession(tt.token) 137 + assert.Error(t, err) 138 + assert.Nil(t, session) 139 + }) 140 + } 141 + } 142 + 143 + func TestSealSession_DifferentSecrets(t *testing.T) { 144 + // Create two clients with different secrets 145 + client1 := &OAuthClient{ 146 + SealSecret: generateSealSecret(), 147 + } 148 + client2 := &OAuthClient{ 149 + SealSecret: generateSealSecret(), 150 + } 151 + 152 + did := "did:plc:abc123" 153 + sessionID := "session-xyz" 154 + ttl := 1 * time.Hour 155 + 156 + // Seal with client1 157 + token, err := client1.SealSession(did, sessionID, ttl) 158 + require.NoError(t, err) 159 + 160 + // Try to unseal with client2 (different secret) 161 + session, err := client2.UnsealSession(token) 162 + assert.Error(t, err) 163 + assert.Nil(t, session) 164 + assert.Contains(t, err.Error(), "failed to decrypt token") 165 + } 166 + 167 + func TestSealSession_NoSecretConfigured(t *testing.T) { 168 + client := &OAuthClient{ 169 + SealSecret: nil, 170 + } 171 + 172 + did := "did:plc:abc123" 173 + sessionID := "session-xyz" 174 + ttl := 1 * time.Hour 175 + 176 + // Should fail to seal without secret 177 + token, err := client.SealSession(did, sessionID, ttl) 178 + assert.Error(t, err) 179 + assert.Empty(t, token) 180 + assert.Contains(t, err.Error(), "seal secret not configured") 181 + 182 + // Should fail to unseal without secret 183 + session, err := client.UnsealSession("dummy-token") 184 + assert.Error(t, err) 185 + assert.Nil(t, session) 186 + assert.Contains(t, err.Error(), "seal secret not configured") 187 + } 188 + 189 + func TestSealSession_MissingRequiredFields(t *testing.T) { 190 + client := &OAuthClient{ 191 + SealSecret: generateSealSecret(), 192 + } 193 + 194 + ttl := 1 * time.Hour 195 + 196 + tests := []struct { 197 + name string 198 + did string 199 + sessionID string 200 + errorMsg string 201 + }{ 202 + { 203 + name: "missing DID", 204 + did: "", 205 + sessionID: "session-123", 206 + errorMsg: "DID is required", 207 + }, 208 + { 209 + name: "missing session ID", 210 + did: "did:plc:abc123", 211 + sessionID: "", 212 + errorMsg: "session ID is required", 213 + }, 214 + } 215 + 216 + for _, tt := range tests { 217 + t.Run(tt.name, func(t *testing.T) { 218 + token, err := client.SealSession(tt.did, tt.sessionID, ttl) 219 + assert.Error(t, err) 220 + assert.Empty(t, token) 221 + assert.Contains(t, err.Error(), tt.errorMsg) 222 + }) 223 + } 224 + } 225 + 226 + func TestSealSession_UniquenessPerCall(t *testing.T) { 227 + client := &OAuthClient{ 228 + SealSecret: generateSealSecret(), 229 + } 230 + 231 + did := "did:plc:abc123" 232 + sessionID := "session-xyz" 233 + ttl := 1 * time.Hour 234 + 235 + // Seal the same session twice 236 + token1, err := client.SealSession(did, sessionID, ttl) 237 + require.NoError(t, err) 238 + 239 + token2, err := client.SealSession(did, sessionID, ttl) 240 + require.NoError(t, err) 241 + 242 + // Tokens should be different (different nonces) 243 + assert.NotEqual(t, token1, token2, "tokens should be unique due to different nonces") 244 + 245 + // But both should unseal to the same session data 246 + session1, err := client.UnsealSession(token1) 247 + require.NoError(t, err) 248 + 249 + session2, err := client.UnsealSession(token2) 250 + require.NoError(t, err) 251 + 252 + assert.Equal(t, session1.DID, session2.DID) 253 + assert.Equal(t, session1.SessionID, session2.SessionID) 254 + } 255 + 256 + func TestSealSession_LongDIDAndSessionID(t *testing.T) { 257 + client := &OAuthClient{ 258 + SealSecret: generateSealSecret(), 259 + } 260 + 261 + // Test with very long DID and session ID 262 + did := "did:plc:" + strings.Repeat("a", 200) 263 + sessionID := "session-" + strings.Repeat("x", 200) 264 + ttl := 1 * time.Hour 265 + 266 + // Should work with long values 267 + token, err := client.SealSession(did, sessionID, ttl) 268 + require.NoError(t, err) 269 + 270 + session, err := client.UnsealSession(token) 271 + require.NoError(t, err) 272 + assert.Equal(t, did, session.DID) 273 + assert.Equal(t, sessionID, session.SessionID) 274 + } 275 + 276 + func TestSealSession_URLSafeEncoding(t *testing.T) { 277 + client := &OAuthClient{ 278 + SealSecret: generateSealSecret(), 279 + } 280 + 281 + did := "did:plc:abc123" 282 + sessionID := "session-xyz" 283 + ttl := 1 * time.Hour 284 + 285 + // Seal multiple times to get different nonces 286 + for i := 0; i < 100; i++ { 287 + token, err := client.SealSession(did, sessionID, ttl) 288 + require.NoError(t, err) 289 + 290 + // Token should not contain URL-unsafe characters 291 + assert.NotContains(t, token, "+", "token should not contain '+'") 292 + assert.NotContains(t, token, "/", "token should not contain '/'") 293 + assert.NotContains(t, token, "=", "token should not contain '='") 294 + 295 + // Should unseal successfully 296 + session, err := client.UnsealSession(token) 297 + require.NoError(t, err) 298 + assert.Equal(t, did, session.DID) 299 + } 300 + } 301 + 302 + func TestSealSession_ConcurrentAccess(t *testing.T) { 303 + client := &OAuthClient{ 304 + SealSecret: generateSealSecret(), 305 + } 306 + 307 + did := "did:plc:abc123" 308 + sessionID := "session-xyz" 309 + ttl := 1 * time.Hour 310 + 311 + // Run concurrent seal/unseal operations 312 + done := make(chan bool) 313 + for i := 0; i < 10; i++ { 314 + go func() { 315 + for j := 0; j < 100; j++ { 316 + token, err := client.SealSession(did, sessionID, ttl) 317 + require.NoError(t, err) 318 + 319 + session, err := client.UnsealSession(token) 320 + require.NoError(t, err) 321 + assert.Equal(t, did, session.DID) 322 + } 323 + done <- true 324 + }() 325 + } 326 + 327 + // Wait for all goroutines 328 + for i := 0; i < 10; i++ { 329 + <-done 330 + } 331 + }
+614
internal/atproto/oauth/store.go
··· 1 + package oauth 2 + 3 + import ( 4 + "context" 5 + "database/sql" 6 + "errors" 7 + "fmt" 8 + "log/slog" 9 + "net/url" 10 + "strings" 11 + "time" 12 + 13 + "github.com/bluesky-social/indigo/atproto/auth/oauth" 14 + "github.com/bluesky-social/indigo/atproto/syntax" 15 + "github.com/lib/pq" 16 + ) 17 + 18 + var ( 19 + ErrSessionNotFound = errors.New("oauth session not found") 20 + ErrAuthRequestNotFound = errors.New("oauth auth request not found") 21 + ) 22 + 23 + // PostgresOAuthStore implements oauth.ClientAuthStore interface using PostgreSQL 24 + type PostgresOAuthStore struct { 25 + db *sql.DB 26 + sessionTTL time.Duration 27 + } 28 + 29 + // NewPostgresOAuthStore creates a new PostgreSQL-backed OAuth store 30 + func NewPostgresOAuthStore(db *sql.DB, sessionTTL time.Duration) oauth.ClientAuthStore { 31 + if sessionTTL == 0 { 32 + sessionTTL = 7 * 24 * time.Hour // Default to 7 days 33 + } 34 + return &PostgresOAuthStore{ 35 + db: db, 36 + sessionTTL: sessionTTL, 37 + } 38 + } 39 + 40 + // GetSession retrieves a session by DID and session ID 41 + func (s *PostgresOAuthStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) { 42 + query := ` 43 + SELECT 44 + did, session_id, host_url, auth_server_iss, 45 + auth_server_token_endpoint, auth_server_revocation_endpoint, 46 + scopes, access_token, refresh_token, 47 + dpop_authserver_nonce, dpop_pds_nonce, dpop_private_key_multibase 48 + FROM oauth_sessions 49 + WHERE did = $1 AND session_id = $2 AND expires_at > NOW() 50 + ` 51 + 52 + var session oauth.ClientSessionData 53 + var authServerIss, authServerTokenEndpoint, authServerRevocationEndpoint sql.NullString 54 + var hostURL, dpopPrivateKeyMultibase sql.NullString 55 + var scopes pq.StringArray 56 + var dpopAuthServerNonce, dpopHostNonce sql.NullString 57 + 58 + err := s.db.QueryRowContext(ctx, query, did.String(), sessionID).Scan( 59 + &session.AccountDID, 60 + &session.SessionID, 61 + &hostURL, 62 + &authServerIss, 63 + &authServerTokenEndpoint, 64 + &authServerRevocationEndpoint, 65 + &scopes, 66 + &session.AccessToken, 67 + &session.RefreshToken, 68 + &dpopAuthServerNonce, 69 + &dpopHostNonce, 70 + &dpopPrivateKeyMultibase, 71 + ) 72 + 73 + if err == sql.ErrNoRows { 74 + return nil, ErrSessionNotFound 75 + } 76 + if err != nil { 77 + return nil, fmt.Errorf("failed to get session: %w", err) 78 + } 79 + 80 + // Convert nullable fields 81 + if hostURL.Valid { 82 + session.HostURL = hostURL.String 83 + } 84 + if authServerIss.Valid { 85 + session.AuthServerURL = authServerIss.String 86 + } 87 + if authServerTokenEndpoint.Valid { 88 + session.AuthServerTokenEndpoint = authServerTokenEndpoint.String 89 + } 90 + if authServerRevocationEndpoint.Valid { 91 + session.AuthServerRevocationEndpoint = authServerRevocationEndpoint.String 92 + } 93 + if dpopAuthServerNonce.Valid { 94 + session.DPoPAuthServerNonce = dpopAuthServerNonce.String 95 + } 96 + if dpopHostNonce.Valid { 97 + session.DPoPHostNonce = dpopHostNonce.String 98 + } 99 + if dpopPrivateKeyMultibase.Valid { 100 + session.DPoPPrivateKeyMultibase = dpopPrivateKeyMultibase.String 101 + } 102 + session.Scopes = scopes 103 + 104 + return &session, nil 105 + } 106 + 107 + // SaveSession saves or updates a session (upsert operation) 108 + func (s *PostgresOAuthStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error { 109 + // Input validation per atProto OAuth security requirements 110 + 111 + // Validate DID format 112 + if _, err := syntax.ParseDID(sess.AccountDID.String()); err != nil { 113 + return fmt.Errorf("invalid DID format: %w", err) 114 + } 115 + 116 + // Validate token lengths (max 10000 chars to prevent memory issues) 117 + const maxTokenLength = 10000 118 + if len(sess.AccessToken) > maxTokenLength { 119 + return fmt.Errorf("access_token exceeds maximum length of %d characters", maxTokenLength) 120 + } 121 + if len(sess.RefreshToken) > maxTokenLength { 122 + return fmt.Errorf("refresh_token exceeds maximum length of %d characters", maxTokenLength) 123 + } 124 + 125 + // Validate session ID is not empty 126 + if sess.SessionID == "" { 127 + return fmt.Errorf("session_id cannot be empty") 128 + } 129 + 130 + // Validate URLs if provided 131 + if sess.HostURL != "" { 132 + if _, err := url.Parse(sess.HostURL); err != nil { 133 + return fmt.Errorf("invalid host_url: %w", err) 134 + } 135 + } 136 + if sess.AuthServerURL != "" { 137 + if _, err := url.Parse(sess.AuthServerURL); err != nil { 138 + return fmt.Errorf("invalid auth_server URL: %w", err) 139 + } 140 + } 141 + if sess.AuthServerTokenEndpoint != "" { 142 + if _, err := url.Parse(sess.AuthServerTokenEndpoint); err != nil { 143 + return fmt.Errorf("invalid auth_server_token_endpoint: %w", err) 144 + } 145 + } 146 + if sess.AuthServerRevocationEndpoint != "" { 147 + if _, err := url.Parse(sess.AuthServerRevocationEndpoint); err != nil { 148 + return fmt.Errorf("invalid auth_server_revocation_endpoint: %w", err) 149 + } 150 + } 151 + 152 + query := ` 153 + INSERT INTO oauth_sessions ( 154 + did, session_id, handle, pds_url, host_url, 155 + access_token, refresh_token, 156 + dpop_private_jwk, dpop_private_key_multibase, 157 + dpop_authserver_nonce, dpop_pds_nonce, 158 + auth_server_iss, auth_server_token_endpoint, auth_server_revocation_endpoint, 159 + scopes, expires_at, created_at, updated_at 160 + ) VALUES ( 161 + $1, $2, $3, $4, $5, 162 + $6, $7, 163 + NULL, $8, 164 + $9, $10, 165 + $11, $12, $13, 166 + $14, $15, NOW(), NOW() 167 + ) 168 + ON CONFLICT (did, session_id) DO UPDATE SET 169 + handle = EXCLUDED.handle, 170 + pds_url = EXCLUDED.pds_url, 171 + host_url = EXCLUDED.host_url, 172 + access_token = EXCLUDED.access_token, 173 + refresh_token = EXCLUDED.refresh_token, 174 + dpop_private_key_multibase = EXCLUDED.dpop_private_key_multibase, 175 + dpop_authserver_nonce = EXCLUDED.dpop_authserver_nonce, 176 + dpop_pds_nonce = EXCLUDED.dpop_pds_nonce, 177 + auth_server_iss = EXCLUDED.auth_server_iss, 178 + auth_server_token_endpoint = EXCLUDED.auth_server_token_endpoint, 179 + auth_server_revocation_endpoint = EXCLUDED.auth_server_revocation_endpoint, 180 + scopes = EXCLUDED.scopes, 181 + expires_at = EXCLUDED.expires_at, 182 + updated_at = NOW() 183 + ` 184 + 185 + // Calculate token expiration using configured TTL 186 + expiresAt := time.Now().Add(s.sessionTTL) 187 + 188 + // Convert empty strings to NULL for optional fields 189 + var authServerRevocationEndpoint sql.NullString 190 + if sess.AuthServerRevocationEndpoint != "" { 191 + authServerRevocationEndpoint.String = sess.AuthServerRevocationEndpoint 192 + authServerRevocationEndpoint.Valid = true 193 + } 194 + 195 + // Extract handle from DID (placeholder - in real implementation, resolve from identity) 196 + // For now, use DID as handle since we don't have the handle in ClientSessionData 197 + handle := sess.AccountDID.String() 198 + 199 + // Use HostURL as PDS URL 200 + pdsURL := sess.HostURL 201 + if pdsURL == "" { 202 + pdsURL = sess.AuthServerURL // Fallback to auth server URL 203 + } 204 + 205 + _, err := s.db.ExecContext( 206 + ctx, query, 207 + sess.AccountDID.String(), 208 + sess.SessionID, 209 + handle, 210 + pdsURL, 211 + sess.HostURL, 212 + sess.AccessToken, 213 + sess.RefreshToken, 214 + sess.DPoPPrivateKeyMultibase, 215 + sess.DPoPAuthServerNonce, 216 + sess.DPoPHostNonce, 217 + sess.AuthServerURL, 218 + sess.AuthServerTokenEndpoint, 219 + authServerRevocationEndpoint, 220 + pq.Array(sess.Scopes), 221 + expiresAt, 222 + ) 223 + if err != nil { 224 + return fmt.Errorf("failed to save session: %w", err) 225 + } 226 + 227 + return nil 228 + } 229 + 230 + // DeleteSession deletes a session by DID and session ID 231 + func (s *PostgresOAuthStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error { 232 + query := `DELETE FROM oauth_sessions WHERE did = $1 AND session_id = $2` 233 + 234 + result, err := s.db.ExecContext(ctx, query, did.String(), sessionID) 235 + if err != nil { 236 + return fmt.Errorf("failed to delete session: %w", err) 237 + } 238 + 239 + rows, err := result.RowsAffected() 240 + if err != nil { 241 + return fmt.Errorf("failed to get rows affected: %w", err) 242 + } 243 + 244 + if rows == 0 { 245 + return ErrSessionNotFound 246 + } 247 + 248 + return nil 249 + } 250 + 251 + // GetAuthRequestInfo retrieves auth request information by state 252 + func (s *PostgresOAuthStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) { 253 + query := ` 254 + SELECT 255 + state, did, handle, pds_url, pkce_verifier, 256 + dpop_private_key_multibase, dpop_authserver_nonce, 257 + auth_server_iss, request_uri, 258 + auth_server_token_endpoint, auth_server_revocation_endpoint, 259 + scopes, created_at 260 + FROM oauth_requests 261 + WHERE state = $1 262 + ` 263 + 264 + var info oauth.AuthRequestData 265 + var did, handle, pdsURL sql.NullString 266 + var dpopPrivateKeyMultibase, dpopAuthServerNonce sql.NullString 267 + var requestURI, authServerTokenEndpoint, authServerRevocationEndpoint sql.NullString 268 + var scopes pq.StringArray 269 + var createdAt time.Time 270 + 271 + err := s.db.QueryRowContext(ctx, query, state).Scan( 272 + &info.State, 273 + &did, 274 + &handle, 275 + &pdsURL, 276 + &info.PKCEVerifier, 277 + &dpopPrivateKeyMultibase, 278 + &dpopAuthServerNonce, 279 + &info.AuthServerURL, 280 + &requestURI, 281 + &authServerTokenEndpoint, 282 + &authServerRevocationEndpoint, 283 + &scopes, 284 + &createdAt, 285 + ) 286 + 287 + if err == sql.ErrNoRows { 288 + return nil, ErrAuthRequestNotFound 289 + } 290 + if err != nil { 291 + return nil, fmt.Errorf("failed to get auth request info: %w", err) 292 + } 293 + 294 + // Parse DID if present 295 + if did.Valid && did.String != "" { 296 + parsedDID, err := syntax.ParseDID(did.String) 297 + if err != nil { 298 + return nil, fmt.Errorf("failed to parse DID: %w", err) 299 + } 300 + info.AccountDID = &parsedDID 301 + } 302 + 303 + // Convert nullable fields 304 + if dpopPrivateKeyMultibase.Valid { 305 + info.DPoPPrivateKeyMultibase = dpopPrivateKeyMultibase.String 306 + } 307 + if dpopAuthServerNonce.Valid { 308 + info.DPoPAuthServerNonce = dpopAuthServerNonce.String 309 + } 310 + if requestURI.Valid { 311 + info.RequestURI = requestURI.String 312 + } 313 + if authServerTokenEndpoint.Valid { 314 + info.AuthServerTokenEndpoint = authServerTokenEndpoint.String 315 + } 316 + if authServerRevocationEndpoint.Valid { 317 + info.AuthServerRevocationEndpoint = authServerRevocationEndpoint.String 318 + } 319 + info.Scopes = scopes 320 + 321 + return &info, nil 322 + } 323 + 324 + // SaveAuthRequestInfo saves auth request information (create only, not upsert) 325 + func (s *PostgresOAuthStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error { 326 + query := ` 327 + INSERT INTO oauth_requests ( 328 + state, did, handle, pds_url, pkce_verifier, 329 + dpop_private_key_multibase, dpop_authserver_nonce, 330 + auth_server_iss, request_uri, 331 + auth_server_token_endpoint, auth_server_revocation_endpoint, 332 + scopes, return_url, created_at 333 + ) VALUES ( 334 + $1, $2, $3, $4, $5, 335 + $6, $7, 336 + $8, $9, 337 + $10, $11, 338 + $12, NULL, NOW() 339 + ) 340 + ` 341 + 342 + // Extract DID string if present 343 + var didStr sql.NullString 344 + if info.AccountDID != nil { 345 + didStr.String = info.AccountDID.String() 346 + didStr.Valid = true 347 + } 348 + 349 + // Convert empty strings to NULL for optional fields 350 + var authServerRevocationEndpoint sql.NullString 351 + if info.AuthServerRevocationEndpoint != "" { 352 + authServerRevocationEndpoint.String = info.AuthServerRevocationEndpoint 353 + authServerRevocationEndpoint.Valid = true 354 + } 355 + 356 + // Placeholder values for handle and pds_url (not in AuthRequestData) 357 + // In production, these would be resolved during the auth flow 358 + handle := "" 359 + pdsURL := "" 360 + if info.AccountDID != nil { 361 + handle = info.AccountDID.String() // Temporary placeholder 362 + pdsURL = info.AuthServerURL // Temporary placeholder 363 + } 364 + 365 + _, err := s.db.ExecContext( 366 + ctx, query, 367 + info.State, 368 + didStr, 369 + handle, 370 + pdsURL, 371 + info.PKCEVerifier, 372 + info.DPoPPrivateKeyMultibase, 373 + info.DPoPAuthServerNonce, 374 + info.AuthServerURL, 375 + info.RequestURI, 376 + info.AuthServerTokenEndpoint, 377 + authServerRevocationEndpoint, 378 + pq.Array(info.Scopes), 379 + ) 380 + if err != nil { 381 + // Check for duplicate state 382 + if strings.Contains(err.Error(), "duplicate key") && strings.Contains(err.Error(), "oauth_requests_state_key") { 383 + return fmt.Errorf("auth request with state already exists: %s", info.State) 384 + } 385 + return fmt.Errorf("failed to save auth request info: %w", err) 386 + } 387 + 388 + return nil 389 + } 390 + 391 + // DeleteAuthRequestInfo deletes auth request information by state 392 + func (s *PostgresOAuthStore) DeleteAuthRequestInfo(ctx context.Context, state string) error { 393 + query := `DELETE FROM oauth_requests WHERE state = $1` 394 + 395 + result, err := s.db.ExecContext(ctx, query, state) 396 + if err != nil { 397 + return fmt.Errorf("failed to delete auth request info: %w", err) 398 + } 399 + 400 + rows, err := result.RowsAffected() 401 + if err != nil { 402 + return fmt.Errorf("failed to get rows affected: %w", err) 403 + } 404 + 405 + if rows == 0 { 406 + return ErrAuthRequestNotFound 407 + } 408 + 409 + return nil 410 + } 411 + 412 + // CleanupExpiredSessions removes sessions that have expired 413 + // Should be called periodically (e.g., via cron job) 414 + func (s *PostgresOAuthStore) CleanupExpiredSessions(ctx context.Context) (int64, error) { 415 + query := `DELETE FROM oauth_sessions WHERE expires_at < NOW()` 416 + 417 + result, err := s.db.ExecContext(ctx, query) 418 + if err != nil { 419 + return 0, fmt.Errorf("failed to cleanup expired sessions: %w", err) 420 + } 421 + 422 + rows, err := result.RowsAffected() 423 + if err != nil { 424 + return 0, fmt.Errorf("failed to get rows affected: %w", err) 425 + } 426 + 427 + return rows, nil 428 + } 429 + 430 + // CleanupExpiredAuthRequests removes auth requests older than 30 minutes 431 + // Should be called periodically (e.g., via cron job) 432 + func (s *PostgresOAuthStore) CleanupExpiredAuthRequests(ctx context.Context) (int64, error) { 433 + query := `DELETE FROM oauth_requests WHERE created_at < NOW() - INTERVAL '30 minutes'` 434 + 435 + result, err := s.db.ExecContext(ctx, query) 436 + if err != nil { 437 + return 0, fmt.Errorf("failed to cleanup expired auth requests: %w", err) 438 + } 439 + 440 + rows, err := result.RowsAffected() 441 + if err != nil { 442 + return 0, fmt.Errorf("failed to get rows affected: %w", err) 443 + } 444 + 445 + return rows, nil 446 + } 447 + 448 + // MobileOAuthData holds mobile-specific OAuth flow data 449 + type MobileOAuthData struct { 450 + CSRFToken string 451 + RedirectURI string 452 + } 453 + 454 + // mobileFlowContextKey is the context key for mobile flow data 455 + type mobileFlowContextKey struct{} 456 + 457 + // ContextWithMobileFlowData adds mobile flow data to a context. 458 + // This is used by HandleMobileLogin to pass mobile data to the store wrapper, 459 + // which will save it when SaveAuthRequestInfo is called by indigo. 460 + func ContextWithMobileFlowData(ctx context.Context, data MobileOAuthData) context.Context { 461 + return context.WithValue(ctx, mobileFlowContextKey{}, data) 462 + } 463 + 464 + // getMobileFlowDataFromContext retrieves mobile flow data from context, if present 465 + func getMobileFlowDataFromContext(ctx context.Context) (MobileOAuthData, bool) { 466 + data, ok := ctx.Value(mobileFlowContextKey{}).(MobileOAuthData) 467 + return data, ok 468 + } 469 + 470 + // MobileAwareClientStore is a marker interface that indicates a store is properly 471 + // configured for mobile OAuth flows. Only stores that intercept SaveAuthRequestInfo 472 + // to save mobile CSRF data should implement this interface. 473 + // This prevents silent mobile OAuth breakage when a plain PostgresOAuthStore is used. 474 + type MobileAwareClientStore interface { 475 + IsMobileAware() bool 476 + } 477 + 478 + // MobileAwareStoreWrapper wraps a ClientAuthStore to automatically save mobile 479 + // CSRF data when SaveAuthRequestInfo is called during a mobile OAuth flow. 480 + // This is necessary because indigo's StartAuthFlow doesn't expose the OAuth state, 481 + // so we intercept the SaveAuthRequestInfo call to capture it. 482 + type MobileAwareStoreWrapper struct { 483 + oauth.ClientAuthStore 484 + mobileStore MobileOAuthStore 485 + } 486 + 487 + // IsMobileAware implements MobileAwareClientStore, indicating this store 488 + // properly saves mobile CSRF data during OAuth flow initiation. 489 + func (w *MobileAwareStoreWrapper) IsMobileAware() bool { 490 + return true 491 + } 492 + 493 + // NewMobileAwareStoreWrapper creates a wrapper that intercepts SaveAuthRequestInfo 494 + // to also save mobile CSRF data when present in context. 495 + func NewMobileAwareStoreWrapper(store oauth.ClientAuthStore) *MobileAwareStoreWrapper { 496 + wrapper := &MobileAwareStoreWrapper{ 497 + ClientAuthStore: store, 498 + } 499 + // Check if the underlying store implements MobileOAuthStore 500 + if ms, ok := store.(MobileOAuthStore); ok { 501 + wrapper.mobileStore = ms 502 + } 503 + return wrapper 504 + } 505 + 506 + // SaveAuthRequestInfo saves the auth request and also saves mobile CSRF data 507 + // if mobile flow data is present in the context. 508 + func (w *MobileAwareStoreWrapper) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error { 509 + // First, save the auth request to the underlying store 510 + if err := w.ClientAuthStore.SaveAuthRequestInfo(ctx, info); err != nil { 511 + return err 512 + } 513 + 514 + // Check if this is a mobile flow (mobile data in context) 515 + if mobileData, ok := getMobileFlowDataFromContext(ctx); ok && w.mobileStore != nil { 516 + // Save mobile CSRF data tied to this OAuth state 517 + // IMPORTANT: If this fails, we MUST propagate the error. Otherwise: 518 + // 1. No server-side CSRF record is stored 519 + // 2. Every mobile callback will "fail closed" to web flow 520 + // 3. Mobile sign-in silently breaks with no indication 521 + // Failing loudly here lets the user retry rather than being confused 522 + // about why they're getting a web flow instead of mobile. 523 + if err := w.mobileStore.SaveMobileOAuthData(ctx, info.State, mobileData); err != nil { 524 + slog.Error("failed to save mobile CSRF data - mobile login will fail", 525 + "state", info.State, "error", err) 526 + return fmt.Errorf("failed to save mobile OAuth data: %w", err) 527 + } 528 + } 529 + 530 + return nil 531 + } 532 + 533 + // GetMobileOAuthData implements MobileOAuthStore interface by delegating to underlying store 534 + func (w *MobileAwareStoreWrapper) GetMobileOAuthData(ctx context.Context, state string) (*MobileOAuthData, error) { 535 + if w.mobileStore != nil { 536 + return w.mobileStore.GetMobileOAuthData(ctx, state) 537 + } 538 + return nil, nil 539 + } 540 + 541 + // SaveMobileOAuthData implements MobileOAuthStore interface by delegating to underlying store 542 + func (w *MobileAwareStoreWrapper) SaveMobileOAuthData(ctx context.Context, state string, data MobileOAuthData) error { 543 + if w.mobileStore != nil { 544 + return w.mobileStore.SaveMobileOAuthData(ctx, state, data) 545 + } 546 + return nil 547 + } 548 + 549 + // UnwrapPostgresStore returns the underlying PostgresOAuthStore if present. 550 + // This is useful for accessing cleanup methods that aren't part of the interface. 551 + func (w *MobileAwareStoreWrapper) UnwrapPostgresStore() *PostgresOAuthStore { 552 + if ps, ok := w.ClientAuthStore.(*PostgresOAuthStore); ok { 553 + return ps 554 + } 555 + return nil 556 + } 557 + 558 + // SaveMobileOAuthData stores mobile CSRF data tied to an OAuth state 559 + // This ties the CSRF token to the OAuth flow via the state parameter, 560 + // which comes back through the OAuth response for server-side validation. 561 + func (s *PostgresOAuthStore) SaveMobileOAuthData(ctx context.Context, state string, data MobileOAuthData) error { 562 + query := ` 563 + UPDATE oauth_requests 564 + SET mobile_csrf_token = $2, mobile_redirect_uri = $3 565 + WHERE state = $1 566 + ` 567 + 568 + result, err := s.db.ExecContext(ctx, query, state, data.CSRFToken, data.RedirectURI) 569 + if err != nil { 570 + return fmt.Errorf("failed to save mobile OAuth data: %w", err) 571 + } 572 + 573 + rows, err := result.RowsAffected() 574 + if err != nil { 575 + return fmt.Errorf("failed to get rows affected: %w", err) 576 + } 577 + 578 + if rows == 0 { 579 + return ErrAuthRequestNotFound 580 + } 581 + 582 + return nil 583 + } 584 + 585 + // GetMobileOAuthData retrieves mobile CSRF data by OAuth state 586 + // This is called during callback to compare the server-side CSRF token 587 + // (retrieved by state from the OAuth response) against the cookie CSRF. 588 + func (s *PostgresOAuthStore) GetMobileOAuthData(ctx context.Context, state string) (*MobileOAuthData, error) { 589 + query := ` 590 + SELECT mobile_csrf_token, mobile_redirect_uri 591 + FROM oauth_requests 592 + WHERE state = $1 593 + ` 594 + 595 + var csrfToken, redirectURI sql.NullString 596 + err := s.db.QueryRowContext(ctx, query, state).Scan(&csrfToken, &redirectURI) 597 + 598 + if err == sql.ErrNoRows { 599 + return nil, ErrAuthRequestNotFound 600 + } 601 + if err != nil { 602 + return nil, fmt.Errorf("failed to get mobile OAuth data: %w", err) 603 + } 604 + 605 + // Return nil if no mobile data was stored (this was a web flow) 606 + if !csrfToken.Valid { 607 + return nil, nil 608 + } 609 + 610 + return &MobileOAuthData{ 611 + CSRFToken: csrfToken.String, 612 + RedirectURI: redirectURI.String, 613 + }, nil 614 + }
+522
internal/atproto/oauth/store_test.go
··· 1 + package oauth 2 + 3 + import ( 4 + "context" 5 + "database/sql" 6 + "os" 7 + "testing" 8 + 9 + "github.com/bluesky-social/indigo/atproto/auth/oauth" 10 + "github.com/bluesky-social/indigo/atproto/syntax" 11 + _ "github.com/lib/pq" 12 + "github.com/pressly/goose/v3" 13 + "github.com/stretchr/testify/assert" 14 + "github.com/stretchr/testify/require" 15 + ) 16 + 17 + // setupTestDB creates a test database connection and runs migrations 18 + func setupTestDB(t *testing.T) *sql.DB { 19 + dsn := os.Getenv("TEST_DATABASE_URL") 20 + if dsn == "" { 21 + dsn = "postgres://test_user:test_password@localhost:5434/coves_test?sslmode=disable" 22 + } 23 + 24 + db, err := sql.Open("postgres", dsn) 25 + require.NoError(t, err, "Failed to connect to test database") 26 + 27 + // Run migrations 28 + require.NoError(t, goose.Up(db, "../../db/migrations"), "Failed to run migrations") 29 + 30 + return db 31 + } 32 + 33 + // cleanupOAuth removes all test OAuth data from the database 34 + func cleanupOAuth(t *testing.T, db *sql.DB) { 35 + _, err := db.Exec("DELETE FROM oauth_sessions WHERE did LIKE 'did:plc:test%'") 36 + require.NoError(t, err, "Failed to cleanup oauth_sessions") 37 + 38 + _, err = db.Exec("DELETE FROM oauth_requests WHERE state LIKE 'test%'") 39 + require.NoError(t, err, "Failed to cleanup oauth_requests") 40 + } 41 + 42 + func TestPostgresOAuthStore_SaveAndGetSession(t *testing.T) { 43 + db := setupTestDB(t) 44 + defer func() { _ = db.Close() }() 45 + defer cleanupOAuth(t, db) 46 + 47 + store := NewPostgresOAuthStore(db, 0) // Use default TTL 48 + ctx := context.Background() 49 + 50 + did, err := syntax.ParseDID("did:plc:test123abc") 51 + require.NoError(t, err) 52 + 53 + session := oauth.ClientSessionData{ 54 + AccountDID: did, 55 + SessionID: "session123", 56 + HostURL: "https://pds.example.com", 57 + AuthServerURL: "https://auth.example.com", 58 + AuthServerTokenEndpoint: "https://auth.example.com/oauth/token", 59 + AuthServerRevocationEndpoint: "https://auth.example.com/oauth/revoke", 60 + Scopes: []string{"atproto"}, 61 + AccessToken: "at_test_token_abc123", 62 + RefreshToken: "rt_test_token_xyz789", 63 + DPoPAuthServerNonce: "nonce_auth_123", 64 + DPoPHostNonce: "nonce_host_456", 65 + DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 66 + } 67 + 68 + // Save session 69 + err = store.SaveSession(ctx, session) 70 + assert.NoError(t, err) 71 + 72 + // Retrieve session 73 + retrieved, err := store.GetSession(ctx, did, "session123") 74 + assert.NoError(t, err) 75 + assert.NotNil(t, retrieved) 76 + assert.Equal(t, session.AccountDID.String(), retrieved.AccountDID.String()) 77 + assert.Equal(t, session.SessionID, retrieved.SessionID) 78 + assert.Equal(t, session.HostURL, retrieved.HostURL) 79 + assert.Equal(t, session.AuthServerURL, retrieved.AuthServerURL) 80 + assert.Equal(t, session.AuthServerTokenEndpoint, retrieved.AuthServerTokenEndpoint) 81 + assert.Equal(t, session.AccessToken, retrieved.AccessToken) 82 + assert.Equal(t, session.RefreshToken, retrieved.RefreshToken) 83 + assert.Equal(t, session.DPoPAuthServerNonce, retrieved.DPoPAuthServerNonce) 84 + assert.Equal(t, session.DPoPHostNonce, retrieved.DPoPHostNonce) 85 + assert.Equal(t, session.DPoPPrivateKeyMultibase, retrieved.DPoPPrivateKeyMultibase) 86 + assert.Equal(t, session.Scopes, retrieved.Scopes) 87 + } 88 + 89 + func TestPostgresOAuthStore_SaveSession_Upsert(t *testing.T) { 90 + db := setupTestDB(t) 91 + defer func() { _ = db.Close() }() 92 + defer cleanupOAuth(t, db) 93 + 94 + store := NewPostgresOAuthStore(db, 0) // Use default TTL 95 + ctx := context.Background() 96 + 97 + did, err := syntax.ParseDID("did:plc:testupsert") 98 + require.NoError(t, err) 99 + 100 + // Initial session 101 + session1 := oauth.ClientSessionData{ 102 + AccountDID: did, 103 + SessionID: "session_upsert", 104 + HostURL: "https://pds1.example.com", 105 + AuthServerURL: "https://auth1.example.com", 106 + AuthServerTokenEndpoint: "https://auth1.example.com/oauth/token", 107 + Scopes: []string{"atproto"}, 108 + AccessToken: "old_access_token", 109 + RefreshToken: "old_refresh_token", 110 + DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 111 + } 112 + 113 + err = store.SaveSession(ctx, session1) 114 + require.NoError(t, err) 115 + 116 + // Updated session (same DID and session ID) 117 + session2 := oauth.ClientSessionData{ 118 + AccountDID: did, 119 + SessionID: "session_upsert", 120 + HostURL: "https://pds2.example.com", 121 + AuthServerURL: "https://auth2.example.com", 122 + AuthServerTokenEndpoint: "https://auth2.example.com/oauth/token", 123 + Scopes: []string{"atproto", "transition:generic"}, 124 + AccessToken: "new_access_token", 125 + RefreshToken: "new_refresh_token", 126 + DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktX", 127 + } 128 + 129 + // Save again - should update 130 + err = store.SaveSession(ctx, session2) 131 + assert.NoError(t, err) 132 + 133 + // Retrieve should get updated values 134 + retrieved, err := store.GetSession(ctx, did, "session_upsert") 135 + assert.NoError(t, err) 136 + assert.Equal(t, "new_access_token", retrieved.AccessToken) 137 + assert.Equal(t, "new_refresh_token", retrieved.RefreshToken) 138 + assert.Equal(t, "https://pds2.example.com", retrieved.HostURL) 139 + assert.Equal(t, []string{"atproto", "transition:generic"}, retrieved.Scopes) 140 + } 141 + 142 + func TestPostgresOAuthStore_GetSession_NotFound(t *testing.T) { 143 + db := setupTestDB(t) 144 + defer func() { _ = db.Close() }() 145 + 146 + store := NewPostgresOAuthStore(db, 0) // Use default TTL 147 + ctx := context.Background() 148 + 149 + did, err := syntax.ParseDID("did:plc:nonexistent") 150 + require.NoError(t, err) 151 + 152 + _, err = store.GetSession(ctx, did, "nonexistent_session") 153 + assert.ErrorIs(t, err, ErrSessionNotFound) 154 + } 155 + 156 + func TestPostgresOAuthStore_DeleteSession(t *testing.T) { 157 + db := setupTestDB(t) 158 + defer func() { _ = db.Close() }() 159 + defer cleanupOAuth(t, db) 160 + 161 + store := NewPostgresOAuthStore(db, 0) // Use default TTL 162 + ctx := context.Background() 163 + 164 + did, err := syntax.ParseDID("did:plc:testdelete") 165 + require.NoError(t, err) 166 + 167 + session := oauth.ClientSessionData{ 168 + AccountDID: did, 169 + SessionID: "session_delete", 170 + HostURL: "https://pds.example.com", 171 + AuthServerURL: "https://auth.example.com", 172 + AuthServerTokenEndpoint: "https://auth.example.com/oauth/token", 173 + Scopes: []string{"atproto"}, 174 + AccessToken: "test_token", 175 + RefreshToken: "test_refresh", 176 + DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 177 + } 178 + 179 + // Save session 180 + err = store.SaveSession(ctx, session) 181 + require.NoError(t, err) 182 + 183 + // Delete session 184 + err = store.DeleteSession(ctx, did, "session_delete") 185 + assert.NoError(t, err) 186 + 187 + // Verify session is gone 188 + _, err = store.GetSession(ctx, did, "session_delete") 189 + assert.ErrorIs(t, err, ErrSessionNotFound) 190 + } 191 + 192 + func TestPostgresOAuthStore_DeleteSession_NotFound(t *testing.T) { 193 + db := setupTestDB(t) 194 + defer func() { _ = db.Close() }() 195 + 196 + store := NewPostgresOAuthStore(db, 0) // Use default TTL 197 + ctx := context.Background() 198 + 199 + did, err := syntax.ParseDID("did:plc:nonexistent") 200 + require.NoError(t, err) 201 + 202 + err = store.DeleteSession(ctx, did, "nonexistent_session") 203 + assert.ErrorIs(t, err, ErrSessionNotFound) 204 + } 205 + 206 + func TestPostgresOAuthStore_SaveAndGetAuthRequestInfo(t *testing.T) { 207 + db := setupTestDB(t) 208 + defer func() { _ = db.Close() }() 209 + defer cleanupOAuth(t, db) 210 + 211 + store := NewPostgresOAuthStore(db, 0) // Use default TTL 212 + ctx := context.Background() 213 + 214 + did, err := syntax.ParseDID("did:plc:testrequest") 215 + require.NoError(t, err) 216 + 217 + info := oauth.AuthRequestData{ 218 + State: "test_state_abc123", 219 + AuthServerURL: "https://auth.example.com", 220 + AccountDID: &did, 221 + Scopes: []string{"atproto"}, 222 + RequestURI: "urn:ietf:params:oauth:request_uri:abc123", 223 + AuthServerTokenEndpoint: "https://auth.example.com/oauth/token", 224 + AuthServerRevocationEndpoint: "https://auth.example.com/oauth/revoke", 225 + PKCEVerifier: "verifier_xyz789", 226 + DPoPAuthServerNonce: "nonce_abc", 227 + DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 228 + } 229 + 230 + // Save auth request info 231 + err = store.SaveAuthRequestInfo(ctx, info) 232 + assert.NoError(t, err) 233 + 234 + // Retrieve auth request info 235 + retrieved, err := store.GetAuthRequestInfo(ctx, "test_state_abc123") 236 + assert.NoError(t, err) 237 + assert.NotNil(t, retrieved) 238 + assert.Equal(t, info.State, retrieved.State) 239 + assert.Equal(t, info.AuthServerURL, retrieved.AuthServerURL) 240 + assert.NotNil(t, retrieved.AccountDID) 241 + assert.Equal(t, info.AccountDID.String(), retrieved.AccountDID.String()) 242 + assert.Equal(t, info.Scopes, retrieved.Scopes) 243 + assert.Equal(t, info.RequestURI, retrieved.RequestURI) 244 + assert.Equal(t, info.AuthServerTokenEndpoint, retrieved.AuthServerTokenEndpoint) 245 + assert.Equal(t, info.PKCEVerifier, retrieved.PKCEVerifier) 246 + assert.Equal(t, info.DPoPAuthServerNonce, retrieved.DPoPAuthServerNonce) 247 + assert.Equal(t, info.DPoPPrivateKeyMultibase, retrieved.DPoPPrivateKeyMultibase) 248 + } 249 + 250 + func TestPostgresOAuthStore_SaveAuthRequestInfo_NoDID(t *testing.T) { 251 + db := setupTestDB(t) 252 + defer func() { _ = db.Close() }() 253 + defer cleanupOAuth(t, db) 254 + 255 + store := NewPostgresOAuthStore(db, 0) // Use default TTL 256 + ctx := context.Background() 257 + 258 + info := oauth.AuthRequestData{ 259 + State: "test_state_nodid", 260 + AuthServerURL: "https://auth.example.com", 261 + AccountDID: nil, // No DID provided 262 + Scopes: []string{"atproto"}, 263 + RequestURI: "urn:ietf:params:oauth:request_uri:nodid", 264 + AuthServerTokenEndpoint: "https://auth.example.com/oauth/token", 265 + PKCEVerifier: "verifier_nodid", 266 + DPoPAuthServerNonce: "nonce_nodid", 267 + DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 268 + } 269 + 270 + // Save auth request info without DID 271 + err := store.SaveAuthRequestInfo(ctx, info) 272 + assert.NoError(t, err) 273 + 274 + // Retrieve and verify DID is nil 275 + retrieved, err := store.GetAuthRequestInfo(ctx, "test_state_nodid") 276 + assert.NoError(t, err) 277 + assert.Nil(t, retrieved.AccountDID) 278 + assert.Equal(t, info.State, retrieved.State) 279 + } 280 + 281 + func TestPostgresOAuthStore_GetAuthRequestInfo_NotFound(t *testing.T) { 282 + db := setupTestDB(t) 283 + defer func() { _ = db.Close() }() 284 + 285 + store := NewPostgresOAuthStore(db, 0) // Use default TTL 286 + ctx := context.Background() 287 + 288 + _, err := store.GetAuthRequestInfo(ctx, "nonexistent_state") 289 + assert.ErrorIs(t, err, ErrAuthRequestNotFound) 290 + } 291 + 292 + func TestPostgresOAuthStore_DeleteAuthRequestInfo(t *testing.T) { 293 + db := setupTestDB(t) 294 + defer func() { _ = db.Close() }() 295 + defer cleanupOAuth(t, db) 296 + 297 + store := NewPostgresOAuthStore(db, 0) // Use default TTL 298 + ctx := context.Background() 299 + 300 + info := oauth.AuthRequestData{ 301 + State: "test_state_delete", 302 + AuthServerURL: "https://auth.example.com", 303 + Scopes: []string{"atproto"}, 304 + RequestURI: "urn:ietf:params:oauth:request_uri:delete", 305 + AuthServerTokenEndpoint: "https://auth.example.com/oauth/token", 306 + PKCEVerifier: "verifier_delete", 307 + DPoPAuthServerNonce: "nonce_delete", 308 + DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 309 + } 310 + 311 + // Save auth request info 312 + err := store.SaveAuthRequestInfo(ctx, info) 313 + require.NoError(t, err) 314 + 315 + // Delete auth request info 316 + err = store.DeleteAuthRequestInfo(ctx, "test_state_delete") 317 + assert.NoError(t, err) 318 + 319 + // Verify it's gone 320 + _, err = store.GetAuthRequestInfo(ctx, "test_state_delete") 321 + assert.ErrorIs(t, err, ErrAuthRequestNotFound) 322 + } 323 + 324 + func TestPostgresOAuthStore_DeleteAuthRequestInfo_NotFound(t *testing.T) { 325 + db := setupTestDB(t) 326 + defer func() { _ = db.Close() }() 327 + 328 + store := NewPostgresOAuthStore(db, 0) // Use default TTL 329 + ctx := context.Background() 330 + 331 + err := store.DeleteAuthRequestInfo(ctx, "nonexistent_state") 332 + assert.ErrorIs(t, err, ErrAuthRequestNotFound) 333 + } 334 + 335 + func TestPostgresOAuthStore_CleanupExpiredSessions(t *testing.T) { 336 + db := setupTestDB(t) 337 + defer func() { _ = db.Close() }() 338 + defer cleanupOAuth(t, db) 339 + 340 + storeInterface := NewPostgresOAuthStore(db, 0) // Use default TTL 341 + store, ok := storeInterface.(*PostgresOAuthStore) 342 + require.True(t, ok, "store should be *PostgresOAuthStore") 343 + ctx := context.Background() 344 + 345 + did1, err := syntax.ParseDID("did:plc:testexpired1") 346 + require.NoError(t, err) 347 + did2, err := syntax.ParseDID("did:plc:testexpired2") 348 + require.NoError(t, err) 349 + 350 + // Create an expired session (manually insert with past expiration) 351 + _, err = db.ExecContext(ctx, ` 352 + INSERT INTO oauth_sessions ( 353 + did, session_id, handle, pds_url, host_url, 354 + access_token, refresh_token, 355 + dpop_private_key_multibase, auth_server_iss, 356 + auth_server_token_endpoint, scopes, 357 + expires_at, created_at 358 + ) VALUES ( 359 + $1, $2, $3, $4, $5, 360 + $6, $7, 361 + $8, $9, 362 + $10, $11, 363 + NOW() - INTERVAL '1 day', NOW() 364 + ) 365 + `, did1.String(), "expired_session", "test.handle", "https://pds.example.com", "https://pds.example.com", 366 + "expired_token", "expired_refresh", 367 + "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", "https://auth.example.com", 368 + "https://auth.example.com/oauth/token", `{"atproto"}`) 369 + require.NoError(t, err) 370 + 371 + // Create a valid session 372 + validSession := oauth.ClientSessionData{ 373 + AccountDID: did2, 374 + SessionID: "valid_session", 375 + HostURL: "https://pds.example.com", 376 + AuthServerURL: "https://auth.example.com", 377 + AuthServerTokenEndpoint: "https://auth.example.com/oauth/token", 378 + Scopes: []string{"atproto"}, 379 + AccessToken: "valid_token", 380 + RefreshToken: "valid_refresh", 381 + DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 382 + } 383 + err = store.SaveSession(ctx, validSession) 384 + require.NoError(t, err) 385 + 386 + // Cleanup expired sessions 387 + count, err := store.CleanupExpiredSessions(ctx) 388 + assert.NoError(t, err) 389 + assert.Equal(t, int64(1), count, "Should delete 1 expired session") 390 + 391 + // Verify expired session is gone 392 + _, err = store.GetSession(ctx, did1, "expired_session") 393 + assert.ErrorIs(t, err, ErrSessionNotFound) 394 + 395 + // Verify valid session still exists 396 + _, err = store.GetSession(ctx, did2, "valid_session") 397 + assert.NoError(t, err) 398 + } 399 + 400 + func TestPostgresOAuthStore_CleanupExpiredAuthRequests(t *testing.T) { 401 + db := setupTestDB(t) 402 + defer func() { _ = db.Close() }() 403 + defer cleanupOAuth(t, db) 404 + 405 + storeInterface := NewPostgresOAuthStore(db, 0) 406 + pgStore, ok := storeInterface.(*PostgresOAuthStore) 407 + require.True(t, ok, "store should be *PostgresOAuthStore") 408 + store := oauth.ClientAuthStore(pgStore) 409 + ctx := context.Background() 410 + 411 + // Create an old auth request (manually insert with old timestamp) 412 + _, err := db.ExecContext(ctx, ` 413 + INSERT INTO oauth_requests ( 414 + state, did, handle, pds_url, pkce_verifier, 415 + dpop_private_key_multibase, dpop_authserver_nonce, 416 + auth_server_iss, request_uri, 417 + auth_server_token_endpoint, scopes, 418 + created_at 419 + ) VALUES ( 420 + $1, $2, $3, $4, $5, 421 + $6, $7, 422 + $8, $9, 423 + $10, $11, 424 + NOW() - INTERVAL '1 hour' 425 + ) 426 + `, "test_old_state", "did:plc:testold", "test.handle", "https://pds.example.com", 427 + "old_verifier", "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 428 + "nonce_old", "https://auth.example.com", "urn:ietf:params:oauth:request_uri:old", 429 + "https://auth.example.com/oauth/token", `{"atproto"}`) 430 + require.NoError(t, err) 431 + 432 + // Create a recent auth request 433 + recentInfo := oauth.AuthRequestData{ 434 + State: "test_recent_state", 435 + AuthServerURL: "https://auth.example.com", 436 + Scopes: []string{"atproto"}, 437 + RequestURI: "urn:ietf:params:oauth:request_uri:recent", 438 + AuthServerTokenEndpoint: "https://auth.example.com/oauth/token", 439 + PKCEVerifier: "recent_verifier", 440 + DPoPAuthServerNonce: "nonce_recent", 441 + DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 442 + } 443 + err = store.SaveAuthRequestInfo(ctx, recentInfo) 444 + require.NoError(t, err) 445 + 446 + // Cleanup expired auth requests (older than 30 minutes) 447 + count, err := pgStore.CleanupExpiredAuthRequests(ctx) 448 + assert.NoError(t, err) 449 + assert.Equal(t, int64(1), count, "Should delete 1 expired auth request") 450 + 451 + // Verify old request is gone 452 + _, err = store.GetAuthRequestInfo(ctx, "test_old_state") 453 + assert.ErrorIs(t, err, ErrAuthRequestNotFound) 454 + 455 + // Verify recent request still exists 456 + _, err = store.GetAuthRequestInfo(ctx, "test_recent_state") 457 + assert.NoError(t, err) 458 + } 459 + 460 + func TestPostgresOAuthStore_MultipleSessions(t *testing.T) { 461 + db := setupTestDB(t) 462 + defer func() { _ = db.Close() }() 463 + defer cleanupOAuth(t, db) 464 + 465 + store := NewPostgresOAuthStore(db, 0) // Use default TTL 466 + ctx := context.Background() 467 + 468 + did, err := syntax.ParseDID("did:plc:testmulti") 469 + require.NoError(t, err) 470 + 471 + // Create multiple sessions for the same DID 472 + session1 := oauth.ClientSessionData{ 473 + AccountDID: did, 474 + SessionID: "browser1", 475 + HostURL: "https://pds.example.com", 476 + AuthServerURL: "https://auth.example.com", 477 + AuthServerTokenEndpoint: "https://auth.example.com/oauth/token", 478 + Scopes: []string{"atproto"}, 479 + AccessToken: "token_browser1", 480 + RefreshToken: "refresh_browser1", 481 + DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 482 + } 483 + 484 + session2 := oauth.ClientSessionData{ 485 + AccountDID: did, 486 + SessionID: "mobile_app", 487 + HostURL: "https://pds.example.com", 488 + AuthServerURL: "https://auth.example.com", 489 + AuthServerTokenEndpoint: "https://auth.example.com/oauth/token", 490 + Scopes: []string{"atproto"}, 491 + AccessToken: "token_mobile", 492 + RefreshToken: "refresh_mobile", 493 + DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktX", 494 + } 495 + 496 + // Save both sessions 497 + err = store.SaveSession(ctx, session1) 498 + require.NoError(t, err) 499 + err = store.SaveSession(ctx, session2) 500 + require.NoError(t, err) 501 + 502 + // Retrieve both sessions 503 + retrieved1, err := store.GetSession(ctx, did, "browser1") 504 + assert.NoError(t, err) 505 + assert.Equal(t, "token_browser1", retrieved1.AccessToken) 506 + 507 + retrieved2, err := store.GetSession(ctx, did, "mobile_app") 508 + assert.NoError(t, err) 509 + assert.Equal(t, "token_mobile", retrieved2.AccessToken) 510 + 511 + // Delete one session 512 + err = store.DeleteSession(ctx, did, "browser1") 513 + assert.NoError(t, err) 514 + 515 + // Verify only browser1 is deleted 516 + _, err = store.GetSession(ctx, did, "browser1") 517 + assert.ErrorIs(t, err, ErrSessionNotFound) 518 + 519 + // mobile_app should still exist 520 + _, err = store.GetSession(ctx, did, "mobile_app") 521 + assert.NoError(t, err) 522 + }
+99
internal/atproto/oauth/transport.go
··· 1 + package oauth 2 + 3 + import ( 4 + "fmt" 5 + "net" 6 + "net/http" 7 + "time" 8 + ) 9 + 10 + // ssrfSafeTransport wraps http.Transport to prevent SSRF attacks 11 + type ssrfSafeTransport struct { 12 + base *http.Transport 13 + allowPrivate bool // For dev/testing only 14 + } 15 + 16 + // isPrivateIP checks if an IP is in a private/reserved range 17 + func isPrivateIP(ip net.IP) bool { 18 + if ip == nil { 19 + return false 20 + } 21 + 22 + // Check for loopback 23 + if ip.IsLoopback() { 24 + return true 25 + } 26 + 27 + // Check for link-local 28 + if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { 29 + return true 30 + } 31 + 32 + // Check for private ranges 33 + privateRanges := []string{ 34 + "10.0.0.0/8", 35 + "172.16.0.0/12", 36 + "192.168.0.0/16", 37 + "169.254.0.0/16", 38 + "::1/128", 39 + "fc00::/7", 40 + "fe80::/10", 41 + } 42 + 43 + for _, cidr := range privateRanges { 44 + _, network, err := net.ParseCIDR(cidr) 45 + if err == nil && network.Contains(ip) { 46 + return true 47 + } 48 + } 49 + 50 + return false 51 + } 52 + 53 + func (t *ssrfSafeTransport) RoundTrip(req *http.Request) (*http.Response, error) { 54 + host := req.URL.Hostname() 55 + 56 + // Resolve hostname to IP 57 + ips, err := net.LookupIP(host) 58 + if err != nil { 59 + return nil, fmt.Errorf("failed to resolve host: %w", err) 60 + } 61 + 62 + // Check all resolved IPs 63 + if !t.allowPrivate { 64 + for _, ip := range ips { 65 + if isPrivateIP(ip) { 66 + return nil, fmt.Errorf("SSRF blocked: %s resolves to private IP %s", host, ip) 67 + } 68 + } 69 + } 70 + 71 + return t.base.RoundTrip(req) 72 + } 73 + 74 + // NewSSRFSafeHTTPClient creates an HTTP client with SSRF protections 75 + func NewSSRFSafeHTTPClient(allowPrivate bool) *http.Client { 76 + transport := &ssrfSafeTransport{ 77 + base: &http.Transport{ 78 + DialContext: (&net.Dialer{ 79 + Timeout: 10 * time.Second, 80 + KeepAlive: 30 * time.Second, 81 + }).DialContext, 82 + MaxIdleConns: 100, 83 + IdleConnTimeout: 90 * time.Second, 84 + TLSHandshakeTimeout: 10 * time.Second, 85 + }, 86 + allowPrivate: allowPrivate, 87 + } 88 + 89 + return &http.Client{ 90 + Timeout: 15 * time.Second, 91 + Transport: transport, 92 + CheckRedirect: func(req *http.Request, via []*http.Request) error { 93 + if len(via) >= 5 { 94 + return fmt.Errorf("too many redirects") 95 + } 96 + return nil 97 + }, 98 + } 99 + }
+132
internal/atproto/oauth/transport_test.go
··· 1 + package oauth 2 + 3 + import ( 4 + "net" 5 + "net/http" 6 + "testing" 7 + ) 8 + 9 + func TestIsPrivateIP(t *testing.T) { 10 + tests := []struct { 11 + name string 12 + ip string 13 + expected bool 14 + }{ 15 + // Loopback addresses 16 + {"IPv4 loopback", "127.0.0.1", true}, 17 + {"IPv6 loopback", "::1", true}, 18 + 19 + // Private IPv4 ranges 20 + {"Private 10.x.x.x", "10.0.0.1", true}, 21 + {"Private 10.x.x.x edge", "10.255.255.255", true}, 22 + {"Private 172.16.x.x", "172.16.0.1", true}, 23 + {"Private 172.31.x.x edge", "172.31.255.255", true}, 24 + {"Private 192.168.x.x", "192.168.1.1", true}, 25 + {"Private 192.168.x.x edge", "192.168.255.255", true}, 26 + 27 + // Link-local addresses 28 + {"Link-local IPv4", "169.254.1.1", true}, 29 + {"Link-local IPv6", "fe80::1", true}, 30 + 31 + // IPv6 private ranges 32 + {"IPv6 unique local fc00", "fc00::1", true}, 33 + {"IPv6 unique local fd00", "fd00::1", true}, 34 + 35 + // Public addresses 36 + {"Public IP 1.1.1.1", "1.1.1.1", false}, 37 + {"Public IP 8.8.8.8", "8.8.8.8", false}, 38 + {"Public IP 172.15.0.1", "172.15.0.1", false}, // Just before 172.16/12 39 + {"Public IP 172.32.0.1", "172.32.0.1", false}, // Just after 172.31/12 40 + {"Public IP 11.0.0.1", "11.0.0.1", false}, // Just after 10/8 41 + {"Public IPv6", "2001:4860:4860::8888", false}, // Google DNS 42 + } 43 + 44 + for _, tt := range tests { 45 + t.Run(tt.name, func(t *testing.T) { 46 + ip := net.ParseIP(tt.ip) 47 + if ip == nil { 48 + t.Fatalf("Failed to parse IP: %s", tt.ip) 49 + } 50 + 51 + result := isPrivateIP(ip) 52 + if result != tt.expected { 53 + t.Errorf("isPrivateIP(%s) = %v, expected %v", tt.ip, result, tt.expected) 54 + } 55 + }) 56 + } 57 + } 58 + 59 + func TestIsPrivateIP_NilIP(t *testing.T) { 60 + result := isPrivateIP(nil) 61 + if result != false { 62 + t.Errorf("isPrivateIP(nil) = %v, expected false", result) 63 + } 64 + } 65 + 66 + func TestNewSSRFSafeHTTPClient(t *testing.T) { 67 + tests := []struct { 68 + name string 69 + allowPrivate bool 70 + }{ 71 + {"Production client (no private IPs)", false}, 72 + {"Development client (allow private IPs)", true}, 73 + } 74 + 75 + for _, tt := range tests { 76 + t.Run(tt.name, func(t *testing.T) { 77 + client := NewSSRFSafeHTTPClient(tt.allowPrivate) 78 + 79 + if client == nil { 80 + t.Fatal("NewSSRFSafeHTTPClient returned nil") 81 + } 82 + 83 + if client.Timeout == 0 { 84 + t.Error("Expected timeout to be set") 85 + } 86 + 87 + if client.Transport == nil { 88 + t.Error("Expected transport to be set") 89 + } 90 + 91 + transport, ok := client.Transport.(*ssrfSafeTransport) 92 + if !ok { 93 + t.Error("Expected ssrfSafeTransport") 94 + } 95 + 96 + if transport.allowPrivate != tt.allowPrivate { 97 + t.Errorf("Expected allowPrivate=%v, got %v", tt.allowPrivate, transport.allowPrivate) 98 + } 99 + }) 100 + } 101 + } 102 + 103 + func TestSSRFSafeHTTPClient_RedirectLimit(t *testing.T) { 104 + client := NewSSRFSafeHTTPClient(false) 105 + 106 + // Simulate checking redirect limit 107 + if client.CheckRedirect == nil { 108 + t.Fatal("Expected CheckRedirect to be set") 109 + } 110 + 111 + // Test redirect limit (5 redirects) 112 + var via []*http.Request 113 + for i := 0; i < 5; i++ { 114 + req := &http.Request{} 115 + via = append(via, req) 116 + } 117 + 118 + err := client.CheckRedirect(nil, via) 119 + if err == nil { 120 + t.Error("Expected error for too many redirects") 121 + } 122 + if err.Error() != "too many redirects" { 123 + t.Errorf("Expected 'too many redirects' error, got: %v", err) 124 + } 125 + 126 + // Test within limit (4 redirects) 127 + via = via[:4] 128 + err = client.CheckRedirect(nil, via) 129 + if err != nil { 130 + t.Errorf("Expected no error for 4 redirects, got: %v", err) 131 + } 132 + }
+124
internal/db/migrations/019_update_oauth_for_indigo.sql
··· 1 + -- +goose Up 2 + -- Update OAuth tables to match indigo's ClientAuthStore interface requirements 3 + -- This migration adds columns needed for OAuth client sessions and auth requests 4 + 5 + -- Update oauth_requests table 6 + -- Add columns for request URI, auth server endpoints, scopes, and DPoP key 7 + ALTER TABLE oauth_requests 8 + ADD COLUMN request_uri TEXT, 9 + ADD COLUMN auth_server_token_endpoint TEXT, 10 + ADD COLUMN auth_server_revocation_endpoint TEXT, 11 + ADD COLUMN scopes TEXT[], 12 + ADD COLUMN dpop_private_key_multibase TEXT; 13 + 14 + -- Make original dpop_private_jwk nullable (we now use dpop_private_key_multibase) 15 + ALTER TABLE oauth_requests ALTER COLUMN dpop_private_jwk DROP NOT NULL; 16 + 17 + -- Make did nullable (indigo's AuthRequestData.AccountDID is a pointer - optional) 18 + ALTER TABLE oauth_requests ALTER COLUMN did DROP NOT NULL; 19 + 20 + -- Make handle and pds_url nullable too (derived from DID resolution, not always available at auth request time) 21 + ALTER TABLE oauth_requests ALTER COLUMN handle DROP NOT NULL; 22 + ALTER TABLE oauth_requests ALTER COLUMN pds_url DROP NOT NULL; 23 + 24 + -- Update existing oauth_requests data 25 + -- Convert dpop_private_jwk (JSONB) to multibase format if needed 26 + -- Note: This will leave the multibase column NULL for now since conversion requires crypto logic 27 + -- The application will need to handle NULL values or regenerate keys on next auth flow 28 + UPDATE oauth_requests 29 + SET 30 + auth_server_token_endpoint = auth_server_iss || '/oauth/token', 31 + scopes = ARRAY['atproto']::TEXT[] 32 + WHERE auth_server_token_endpoint IS NULL; 33 + 34 + -- Add indexes for new columns 35 + CREATE INDEX idx_oauth_requests_request_uri ON oauth_requests(request_uri) WHERE request_uri IS NOT NULL; 36 + 37 + -- Update oauth_sessions table 38 + -- Add session_id column (will become part of composite key) 39 + ALTER TABLE oauth_sessions 40 + ADD COLUMN session_id TEXT, 41 + ADD COLUMN host_url TEXT, 42 + ADD COLUMN auth_server_token_endpoint TEXT, 43 + ADD COLUMN auth_server_revocation_endpoint TEXT, 44 + ADD COLUMN scopes TEXT[], 45 + ADD COLUMN dpop_private_key_multibase TEXT; 46 + 47 + -- Make original dpop_private_jwk nullable (we now use dpop_private_key_multibase) 48 + ALTER TABLE oauth_sessions ALTER COLUMN dpop_private_jwk DROP NOT NULL; 49 + 50 + -- Populate session_id for existing sessions (use DID as default for single-session per account) 51 + -- In production, you may want to generate unique session IDs 52 + UPDATE oauth_sessions 53 + SET 54 + session_id = 'default', 55 + host_url = pds_url, 56 + auth_server_token_endpoint = auth_server_iss || '/oauth/token', 57 + scopes = ARRAY['atproto']::TEXT[] 58 + WHERE session_id IS NULL; 59 + 60 + -- Make session_id NOT NULL after populating existing data 61 + ALTER TABLE oauth_sessions 62 + ALTER COLUMN session_id SET NOT NULL; 63 + 64 + -- Drop old unique constraint on did only 65 + ALTER TABLE oauth_sessions 66 + DROP CONSTRAINT IF EXISTS oauth_sessions_did_key; 67 + 68 + -- Create new composite unique constraint for (did, session_id) 69 + -- This allows multiple sessions per account 70 + -- Note: UNIQUE constraint automatically creates an index, so no separate index needed 71 + ALTER TABLE oauth_sessions 72 + ADD CONSTRAINT oauth_sessions_did_session_id_key UNIQUE (did, session_id); 73 + 74 + -- Add comment explaining the schema change 75 + COMMENT ON COLUMN oauth_sessions.session_id IS 'Session identifier to support multiple concurrent sessions per account'; 76 + COMMENT ON CONSTRAINT oauth_sessions_did_session_id_key ON oauth_sessions IS 'Composite key allowing multiple sessions per DID'; 77 + 78 + -- +goose Down 79 + -- Rollback: Remove added columns and restore original unique constraint 80 + 81 + -- oauth_sessions rollback 82 + -- Drop composite unique constraint (this also drops the associated index) 83 + ALTER TABLE oauth_sessions 84 + DROP CONSTRAINT IF EXISTS oauth_sessions_did_session_id_key; 85 + 86 + -- Delete all but the most recent session per DID before restoring unique constraint 87 + -- This ensures the UNIQUE (did) constraint can be added without conflicts 88 + DELETE FROM oauth_sessions a 89 + USING oauth_sessions b 90 + WHERE a.did = b.did 91 + AND a.created_at < b.created_at; 92 + 93 + -- Restore old unique constraint 94 + ALTER TABLE oauth_sessions 95 + ADD CONSTRAINT oauth_sessions_did_key UNIQUE (did); 96 + 97 + -- Restore NOT NULL constraint on dpop_private_jwk 98 + ALTER TABLE oauth_sessions 99 + ALTER COLUMN dpop_private_jwk SET NOT NULL; 100 + 101 + ALTER TABLE oauth_sessions 102 + DROP COLUMN IF EXISTS dpop_private_key_multibase, 103 + DROP COLUMN IF EXISTS scopes, 104 + DROP COLUMN IF EXISTS auth_server_revocation_endpoint, 105 + DROP COLUMN IF EXISTS auth_server_token_endpoint, 106 + DROP COLUMN IF EXISTS host_url, 107 + DROP COLUMN IF EXISTS session_id; 108 + 109 + -- oauth_requests rollback 110 + DROP INDEX IF EXISTS idx_oauth_requests_request_uri; 111 + 112 + -- Restore NOT NULL constraints 113 + ALTER TABLE oauth_requests 114 + ALTER COLUMN dpop_private_jwk SET NOT NULL, 115 + ALTER COLUMN did SET NOT NULL, 116 + ALTER COLUMN handle SET NOT NULL, 117 + ALTER COLUMN pds_url SET NOT NULL; 118 + 119 + ALTER TABLE oauth_requests 120 + DROP COLUMN IF EXISTS dpop_private_key_multibase, 121 + DROP COLUMN IF EXISTS scopes, 122 + DROP COLUMN IF EXISTS auth_server_revocation_endpoint, 123 + DROP COLUMN IF EXISTS auth_server_token_endpoint, 124 + DROP COLUMN IF EXISTS request_uri;
+23
internal/db/migrations/020_add_mobile_oauth_csrf.sql
··· 1 + -- +goose Up 2 + -- Add columns for mobile OAuth CSRF protection with server-side state 3 + -- This ties the CSRF token to the OAuth state, allowing validation against 4 + -- a value that comes back through the OAuth response (the state parameter) 5 + -- rather than only validating cookies against each other. 6 + 7 + ALTER TABLE oauth_requests 8 + ADD COLUMN mobile_csrf_token TEXT, 9 + ADD COLUMN mobile_redirect_uri TEXT; 10 + 11 + -- Index for quick lookup of mobile data when callback is received 12 + CREATE INDEX idx_oauth_requests_mobile_csrf ON oauth_requests(state) 13 + WHERE mobile_csrf_token IS NOT NULL; 14 + 15 + COMMENT ON COLUMN oauth_requests.mobile_csrf_token IS 'CSRF token for mobile OAuth flows, validated against cookie on callback'; 16 + COMMENT ON COLUMN oauth_requests.mobile_redirect_uri IS 'Mobile redirect URI (Universal Link) for this OAuth flow'; 17 + 18 + -- +goose Down 19 + DROP INDEX IF EXISTS idx_oauth_requests_mobile_csrf; 20 + 21 + ALTER TABLE oauth_requests 22 + DROP COLUMN IF EXISTS mobile_redirect_uri, 23 + DROP COLUMN IF EXISTS mobile_csrf_token;