A community-based topic aggregation platform built on atproto.
at main 434 lines 16 kB view raw
1package aggregators 2 3import ( 4 "context" 5 "crypto/rand" 6 "crypto/sha256" 7 "encoding/hex" 8 "errors" 9 "fmt" 10 "log/slog" 11 "sync/atomic" 12 "time" 13 14 "github.com/bluesky-social/indigo/atproto/auth/oauth" 15 "github.com/bluesky-social/indigo/atproto/syntax" 16) 17 18const ( 19 // APIKeyPrefix is the prefix for all Coves API keys 20 APIKeyPrefix = "ckapi_" 21 // APIKeyRandomBytes is the number of random bytes in the key (32 bytes = 256 bits) 22 APIKeyRandomBytes = 32 23 // APIKeyTotalLength is the total length of the API key including prefix 24 // 6 (prefix "ckapi_") + 64 (32 bytes hex-encoded) = 70 25 APIKeyTotalLength = 70 26 // TokenRefreshBuffer is how long before expiry we should refresh tokens 27 TokenRefreshBuffer = 5 * time.Minute 28 // DefaultSessionID is used for API key sessions since aggregators have a single session 29 DefaultSessionID = "apikey" 30) 31 32// APIKeyService handles API key generation, validation, and OAuth token management 33// for aggregator authentication. 34type APIKeyService struct { 35 repo Repository 36 oauthApp *oauth.ClientApp // For resuming sessions and refreshing tokens 37 38 // failedLastUsedUpdates tracks the number of failed API key last_used timestamp updates. 39 // This counter provides visibility into persistent DB issues that would otherwise be hidden 40 // since the update is done asynchronously. Use GetFailedLastUsedUpdates() to read. 41 failedLastUsedUpdates atomic.Int64 42 43 // failedNonceUpdates tracks the number of failed OAuth nonce updates. 44 // Nonce failures may indicate DB issues and could lead to DPoP replay protection issues. 45 // Use GetFailedNonceUpdates() to read. 46 failedNonceUpdates atomic.Int64 47} 48 49// NewAPIKeyService creates a new API key service. 50// Panics if repo or oauthApp are nil, as these are required dependencies. 
51func NewAPIKeyService(repo Repository, oauthApp *oauth.ClientApp) *APIKeyService { 52 if repo == nil { 53 panic("aggregators.NewAPIKeyService: repo cannot be nil") 54 } 55 if oauthApp == nil { 56 panic("aggregators.NewAPIKeyService: oauthApp cannot be nil") 57 } 58 return &APIKeyService{ 59 repo: repo, 60 oauthApp: oauthApp, 61 } 62} 63 64// GenerateKey creates a new API key for an aggregator. 65// The aggregator must have completed OAuth authentication first. 66// Returns the plain-text key (only shown once) and the key prefix for reference. 67func (s *APIKeyService) GenerateKey(ctx context.Context, aggregatorDID string, oauthSession *oauth.ClientSessionData) (plainKey string, keyPrefix string, err error) { 68 // Validate aggregator exists 69 aggregator, err := s.repo.GetAggregator(ctx, aggregatorDID) 70 if err != nil { 71 return "", "", fmt.Errorf("failed to get aggregator: %w", err) 72 } 73 74 // Validate OAuth session matches the aggregator 75 if oauthSession.AccountDID.String() != aggregatorDID { 76 return "", "", ErrOAuthSessionMismatch 77 } 78 79 // Generate random key 80 randomBytes := make([]byte, APIKeyRandomBytes) 81 if _, err := rand.Read(randomBytes); err != nil { 82 return "", "", fmt.Errorf("failed to generate random key: %w", err) 83 } 84 randomHex := hex.EncodeToString(randomBytes) 85 plainKey = APIKeyPrefix + randomHex 86 87 // Create key prefix (first 12 chars including prefix for identification) 88 keyPrefix = plainKey[:12] 89 90 // Hash the key for storage (SHA-256) 91 keyHash := hashAPIKey(plainKey) 92 93 // Extract OAuth credentials from session 94 // Note: ClientSessionData doesn't store token expiry from the OAuth response. 95 // We use a 1-hour default which matches typical OAuth access token lifetimes. 96 // Token refresh happens proactively before expiry via RefreshTokensIfNeeded. 
97 tokenExpiry := time.Now().Add(1 * time.Hour) 98 oauthCreds := &OAuthCredentials{ 99 AccessToken: oauthSession.AccessToken, 100 RefreshToken: oauthSession.RefreshToken, 101 TokenExpiresAt: tokenExpiry, 102 PDSURL: oauthSession.HostURL, 103 AuthServerIss: oauthSession.AuthServerURL, 104 AuthServerTokenEndpoint: oauthSession.AuthServerTokenEndpoint, 105 DPoPPrivateKeyMultibase: oauthSession.DPoPPrivateKeyMultibase, 106 DPoPAuthServerNonce: oauthSession.DPoPAuthServerNonce, 107 DPoPPDSNonce: oauthSession.DPoPHostNonce, 108 } 109 110 // Validate OAuth credentials before proceeding 111 if err := oauthCreds.Validate(); err != nil { 112 return "", "", fmt.Errorf("invalid OAuth credentials: %w", err) 113 } 114 115 // Store the OAuth session in the store FIRST (before API key) 116 // This prevents a race condition where the API key exists but can't refresh tokens. 117 // Order: OAuth session → API key (if session fails, no dangling API key) 118 apiKeySession := *oauthSession // Copy session data 119 apiKeySession.SessionID = DefaultSessionID 120 if err := s.oauthApp.Store.SaveSession(ctx, apiKeySession); err != nil { 121 slog.Error("failed to store OAuth session for API key - aborting key creation", 122 "did", aggregatorDID, 123 "error", err, 124 ) 125 return "", "", fmt.Errorf("failed to store OAuth session for token refresh: %w", err) 126 } 127 128 // Now store key hash and OAuth credentials in aggregators table 129 // If this fails, we have an orphaned OAuth session, but that's less problematic 130 // than having an API key that can't refresh tokens. 
131 if err := s.repo.SetAPIKey(ctx, aggregatorDID, keyPrefix, keyHash, oauthCreds); err != nil { 132 // Best effort cleanup of the OAuth session we just stored 133 if deleteErr := s.oauthApp.Store.DeleteSession(ctx, oauthSession.AccountDID, DefaultSessionID); deleteErr != nil { 134 slog.Warn("failed to cleanup OAuth session after API key storage failure", 135 "did", aggregatorDID, 136 "error", deleteErr, 137 ) 138 } 139 return "", "", fmt.Errorf("failed to store API key: %w", err) 140 } 141 142 slog.Info("API key generated for aggregator", 143 "did", aggregatorDID, 144 "display_name", aggregator.DisplayName, 145 "key_prefix", keyPrefix, 146 ) 147 148 return plainKey, keyPrefix, nil 149} 150 151// ValidateKey validates an API key and returns the associated aggregator credentials. 152// Returns ErrAPIKeyInvalid if the key is not found or revoked. 153func (s *APIKeyService) ValidateKey(ctx context.Context, plainKey string) (*AggregatorCredentials, error) { 154 // Validate key format - log invalid attempts for security monitoring 155 if len(plainKey) != APIKeyTotalLength || plainKey[:6] != APIKeyPrefix { 156 // Log for security monitoring (potential brute-force detection) 157 // Don't log the full key, just metadata about the attempt 158 slog.Warn("[SECURITY] invalid API key format attempt", 159 "key_length", len(plainKey), 160 "has_valid_prefix", len(plainKey) >= 6 && plainKey[:6] == APIKeyPrefix, 161 ) 162 return nil, ErrAPIKeyInvalid 163 } 164 165 // Hash the provided key 166 keyHash := hashAPIKey(plainKey) 167 168 // Look up aggregator credentials by hash 169 creds, err := s.repo.GetCredentialsByAPIKeyHash(ctx, keyHash) 170 if err != nil { 171 if IsNotFound(err) { 172 return nil, ErrAPIKeyInvalid 173 } 174 // Check for revoked API key (returned by repo when api_key_revoked_at is set) 175 if errors.Is(err, ErrAPIKeyRevoked) { 176 slog.Warn("revoked API key used", 177 "key_hash_prefix", keyHash[:8], 178 ) 179 return nil, ErrAPIKeyRevoked 180 } 181 return nil, 
fmt.Errorf("failed to lookup API key: %w", err) 182 } 183 184 // Update last used timestamp (async, don't block on error) 185 // Use a bounded timeout to prevent goroutine accumulation if DB is slow/down 186 // Extract trace info from context before spawning goroutine for log correlation 187 aggregatorDID := creds.DID // capture for goroutine 188 go func() { 189 updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 190 defer cancel() 191 192 if updateErr := s.repo.UpdateAPIKeyLastUsed(updateCtx, aggregatorDID); updateErr != nil { 193 // Increment failure counter for monitoring visibility 194 failCount := s.failedLastUsedUpdates.Add(1) 195 slog.Error("failed to update API key last used", 196 "did", aggregatorDID, 197 "error", updateErr, 198 "total_failures", failCount, 199 ) 200 } 201 }() 202 203 return creds, nil 204} 205 206// RefreshTokensIfNeeded checks if the OAuth tokens are expired or expiring soon, 207// and refreshes them if necessary. 208func (s *APIKeyService) RefreshTokensIfNeeded(ctx context.Context, creds *AggregatorCredentials) error { 209 // Check if tokens need refresh 210 if creds.OAuthTokenExpiresAt != nil { 211 if time.Until(*creds.OAuthTokenExpiresAt) > TokenRefreshBuffer { 212 // Tokens still valid 213 return nil 214 } 215 } 216 217 // Need to refresh tokens 218 slog.Info("refreshing OAuth tokens for aggregator", 219 "did", creds.DID, 220 "expires_at", creds.OAuthTokenExpiresAt, 221 ) 222 223 // Parse DID 224 did, err := syntax.ParseDID(creds.DID) 225 if err != nil { 226 return fmt.Errorf("failed to parse aggregator DID: %w", err) 227 } 228 229 // Resume the OAuth session from the store 230 // The session was stored when the aggregator created their API key 231 session, err := s.oauthApp.ResumeSession(ctx, did, DefaultSessionID) 232 if err != nil { 233 slog.Error("failed to resume OAuth session for token refresh", 234 "did", creds.DID, 235 "error", err, 236 ) 237 return fmt.Errorf("failed to resume session: %w", err) 238 
} 239 240 // Refresh tokens using indigo's OAuth library 241 newAccessToken, err := session.RefreshTokens(ctx) 242 if err != nil { 243 slog.Error("failed to refresh OAuth tokens", 244 "did", creds.DID, 245 "error", err, 246 ) 247 return fmt.Errorf("failed to refresh tokens: %w", err) 248 } 249 250 // Note: ClientSessionData doesn't store token expiry from the OAuth response. 251 // We use a 1-hour default which matches typical OAuth access token lifetimes. 252 newExpiry := time.Now().Add(1 * time.Hour) 253 254 // Update tokens in database 255 if err := s.repo.UpdateOAuthTokens(ctx, creds.DID, newAccessToken, session.Data.RefreshToken, newExpiry); err != nil { 256 return fmt.Errorf("failed to update tokens: %w", err) 257 } 258 259 // Update nonces in our database as a secondary copy for visibility/backup. 260 // The authoritative nonces are in indigo's OAuth store (via SaveSession above). 261 // Session resumption uses s.oauthApp.ResumeSession which reads from indigo's store, 262 // so this failure is non-critical - hence warning level, not error. 263 if err := s.repo.UpdateOAuthNonces(ctx, creds.DID, session.Data.DPoPAuthServerNonce, session.Data.DPoPHostNonce); err != nil { 264 failCount := s.failedNonceUpdates.Add(1) 265 slog.Warn("failed to update OAuth nonces in aggregators table", 266 "did", creds.DID, 267 "error", err, 268 "total_failures", failCount, 269 ) 270 } 271 272 // Update credentials in memory 273 creds.OAuthAccessToken = newAccessToken 274 creds.OAuthRefreshToken = session.Data.RefreshToken 275 creds.OAuthTokenExpiresAt = &newExpiry 276 creds.OAuthDPoPAuthServerNonce = session.Data.DPoPAuthServerNonce 277 creds.OAuthDPoPPDSNonce = session.Data.DPoPHostNonce 278 279 slog.Info("OAuth tokens refreshed for aggregator", 280 "did", creds.DID, 281 "new_expires_at", newExpiry, 282 ) 283 284 return nil 285} 286 287// GetAccessToken returns a valid access token for the aggregator, 288// refreshing if necessary. 
289func (s *APIKeyService) GetAccessToken(ctx context.Context, creds *AggregatorCredentials) (string, error) { 290 // Ensure tokens are fresh 291 if err := s.RefreshTokensIfNeeded(ctx, creds); err != nil { 292 return "", fmt.Errorf("failed to ensure fresh tokens: %w", err) 293 } 294 295 return creds.OAuthAccessToken, nil 296} 297 298// RevokeKey revokes an API key for an aggregator. 299// After revocation, the aggregator must complete OAuth flow again to get a new key. 300func (s *APIKeyService) RevokeKey(ctx context.Context, aggregatorDID string) error { 301 if err := s.repo.RevokeAPIKey(ctx, aggregatorDID); err != nil { 302 return fmt.Errorf("failed to revoke API key: %w", err) 303 } 304 305 slog.Info("API key revoked for aggregator", 306 "did", aggregatorDID, 307 ) 308 309 return nil 310} 311 312// GetAggregator retrieves the public aggregator information by DID. 313// For credential/authentication data, use GetAggregatorCredentials instead. 314func (s *APIKeyService) GetAggregator(ctx context.Context, aggregatorDID string) (*Aggregator, error) { 315 return s.repo.GetAggregator(ctx, aggregatorDID) 316} 317 318// GetAggregatorCredentials retrieves credentials for an aggregator by DID. 319func (s *APIKeyService) GetAggregatorCredentials(ctx context.Context, aggregatorDID string) (*AggregatorCredentials, error) { 320 return s.repo.GetAggregatorCredentials(ctx, aggregatorDID) 321} 322 323// GetAPIKeyInfo returns information about an aggregator's API key (without the actual key). 
324func (s *APIKeyService) GetAPIKeyInfo(ctx context.Context, aggregatorDID string) (*APIKeyInfo, error) { 325 creds, err := s.repo.GetAggregatorCredentials(ctx, aggregatorDID) 326 if err != nil { 327 return nil, err 328 } 329 330 if creds.APIKeyHash == "" { 331 return &APIKeyInfo{ 332 HasKey: false, 333 }, nil 334 } 335 336 return &APIKeyInfo{ 337 HasKey: true, 338 KeyPrefix: creds.APIKeyPrefix, 339 CreatedAt: creds.APIKeyCreatedAt, 340 LastUsedAt: creds.APIKeyLastUsed, 341 IsRevoked: creds.APIKeyRevokedAt != nil, 342 RevokedAt: creds.APIKeyRevokedAt, 343 }, nil 344} 345 346// APIKeyInfo contains non-sensitive information about an API key 347type APIKeyInfo struct { 348 HasKey bool 349 KeyPrefix string 350 CreatedAt *time.Time 351 LastUsedAt *time.Time 352 IsRevoked bool 353 RevokedAt *time.Time 354} 355 356// hashAPIKey creates a SHA-256 hash of the API key for storage 357func hashAPIKey(plainKey string) string { 358 hash := sha256.Sum256([]byte(plainKey)) 359 return hex.EncodeToString(hash[:]) 360} 361 362// GetFailedLastUsedUpdates returns the count of failed API key last_used timestamp updates. 363// This is useful for monitoring and alerting on persistent database issues. 364func (s *APIKeyService) GetFailedLastUsedUpdates() int64 { 365 return s.failedLastUsedUpdates.Load() 366} 367 368// GetFailedNonceUpdates returns the count of failed OAuth nonce updates. 369// This is useful for monitoring and alerting on persistent database issues 370// that could affect DPoP replay protection. 371func (s *APIKeyService) GetFailedNonceUpdates() int64 { 372 return s.failedNonceUpdates.Load() 373} 374 375// perAggregatorRefreshTimeout is the maximum time allowed for refreshing 376// a single aggregator's tokens. This prevents a slow OAuth server from 377// blocking the entire refresh job. 
378const perAggregatorRefreshTimeout = 30 * time.Second 379 380// RefreshExpiringTokens proactively refreshes tokens for all aggregators 381// whose tokens will expire within the given buffer period. 382// Returns count of successful refreshes and any errors encountered. 383// Each aggregator refresh has a 30-second timeout to prevent slow OAuth servers 384// from blocking the entire job. 385func (s *APIKeyService) RefreshExpiringTokens(ctx context.Context, expiryBuffer time.Duration) (refreshed int, errors []error) { 386 // Get all aggregators with tokens expiring within the buffer period 387 aggregators, err := s.repo.ListAggregatorsNeedingTokenRefresh(ctx, expiryBuffer) 388 if err != nil { 389 slog.Error("[TOKEN-REFRESH] Failed to list aggregators needing token refresh", 390 "error", err, 391 "expiry_buffer", expiryBuffer, 392 ) 393 return 0, []error{fmt.Errorf("failed to list aggregators needing refresh: %w", err)} 394 } 395 396 if len(aggregators) == 0 { 397 return 0, nil 398 } 399 400 slog.Info("[TOKEN-REFRESH] Starting proactive token refresh", 401 "aggregator_count", len(aggregators), 402 "expiry_buffer", expiryBuffer, 403 ) 404 405 // Refresh tokens for each aggregator with per-aggregator timeout 406 for _, creds := range aggregators { 407 slog.Info("[TOKEN-REFRESH] Attempting token refresh for aggregator", 408 "did", creds.DID, 409 "token_expires_at", creds.OAuthTokenExpiresAt, 410 ) 411 412 // Create per-aggregator timeout context to prevent slow OAuth servers 413 // from blocking the entire refresh cycle 414 refreshCtx, cancel := context.WithTimeout(ctx, perAggregatorRefreshTimeout) 415 err := s.RefreshTokensIfNeeded(refreshCtx, creds) 416 cancel() 417 418 if err != nil { 419 slog.Error("[TOKEN-REFRESH] Failed to refresh tokens for aggregator", 420 "did", creds.DID, 421 "error", err, 422 ) 423 errors = append(errors, fmt.Errorf("aggregator %s: %w", creds.DID, err)) 424 } else { 425 slog.Info("[TOKEN-REFRESH] Successfully refreshed tokens for 
aggregator", 426 "did", creds.DID, 427 "new_expires_at", creds.OAuthTokenExpiresAt, 428 ) 429 refreshed++ 430 } 431 } 432 433 return refreshed, errors 434}