A community based topic aggregation platform built on atproto

Merge branch 'feat/aggregator-api-keys'

+5261 -123
+7 -4
.env.dev
··· 38 38 PDS_ADMIN_PASSWORD=admin 39 39 40 40 # Handle domains (users will get handles like alice.local.coves.dev) 41 - # Communities will use .community.coves.social (singular per atProto conventions) 42 - PDS_SERVICE_HANDLE_DOMAINS=.local.coves.dev,.community.coves.social 41 + # Communities will use c-{name}.coves.social (3-level format with c- prefix) 42 + PDS_SERVICE_HANDLE_DOMAINS=.local.coves.dev,.coves.social 43 43 44 44 # PLC Rotation Key (k256 private key in hex format - for local dev only) 45 45 # This is a randomly generated key for testing - DO NOT use in production ··· 133 133 PDS_INSTANCE_HANDLE=testuser123.local.coves.dev 134 134 PDS_INSTANCE_PASSWORD=test-password-123 135 135 136 - # Kagi News Aggregator DID (for trusted thumbnail URLs) 137 - KAGI_AGGREGATOR_DID=did:plc:yyf34padpfjknejyutxtionr 136 + # Trusted Aggregator DIDs (bypasses community authorization check) 137 + # Comma-separated list of DIDs 138 + # - did:plc:yyf34padpfjknejyutxtionr = kagi-news.coves.social (production) 139 + # - did:plc:igjbg5cex7poojsniebvmafb = test-aggregator.local.coves.dev (dev) 140 + TRUSTED_AGGREGATOR_DIDS=did:plc:yyf34padpfjknejyutxtionr,did:plc:igjbg5cex7poojsniebvmafb 138 141 139 142 # ============================================================================= 140 143 # Development Settings
+1 -1
.env.dev.example
··· 46 46 PDS_DID_PLC_URL=http://plc-directory:3000 47 47 PDS_JWT_SECRET=local-dev-jwt-secret-change-in-production 48 48 PDS_ADMIN_PASSWORD=admin 49 - PDS_SERVICE_HANDLE_DOMAINS=.local.coves.dev,.community.coves.social 49 + PDS_SERVICE_HANDLE_DOMAINS=.local.coves.dev,.coves.social 50 50 PDS_PLC_ROTATION_KEY=<generate-a-random-hex-key> 51 51 52 52 # =============================================================================
+2 -3
aggregators/kagi-news/.env.example
··· 1 - # Aggregator Identity (pre-created account credentials) 2 - AGGREGATOR_HANDLE=kagi-news.local.coves.dev 3 - AGGREGATOR_PASSWORD=your-secure-password-here 1 + # Coves API Key (get from https://coves.social after OAuth login) 2 + COVES_API_KEY=ckapi_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx 4 3 5 4 # Optional: Override Coves API URL (defaults to config.yaml) 6 5 # COVES_API_URL=http://localhost:3001
+2
aggregators/kagi-news/config.example.yaml
··· 2 2 3 3 # Coves API endpoint 4 4 coves_api_url: "https://coves.social" 5 + # API key is loaded from COVES_API_KEY environment variable 6 + # Get your API key from https://coves.social after OAuth login 5 7 6 8 # Feed-to-community mappings 7 9 # Handle format: c-{name}.{instance} (e.g., c-worldnews.coves.social)
-1
aggregators/kagi-news/requirements.txt
··· 2 2 feedparser==6.0.11 3 3 beautifulsoup4==4.12.3 4 4 requests==2.31.0 5 - atproto==0.0.55 6 5 pyyaml==6.0.1 7 6 8 7 # Testing
+133 -55
aggregators/kagi-news/src/coves_client.py
··· 1 1 """ 2 2 Coves API Client for posting to communities. 3 3 4 - Handles authentication and posting via XRPC. 4 + Handles API key authentication and posting via XRPC. 5 5 """ 6 6 import logging 7 7 import requests 8 8 from typing import Dict, List, Optional 9 - from atproto import Client 10 9 11 10 logger = logging.getLogger(__name__) 12 11 13 12 13 + class CovesAPIError(Exception): 14 + """Base exception for Coves API errors.""" 15 + 16 + def __init__(self, message: str, status_code: int = None, response_body: str = None): 17 + super().__init__(message) 18 + self.status_code = status_code 19 + self.response_body = response_body 20 + 21 + 22 + class CovesAuthenticationError(CovesAPIError): 23 + """Raised when authentication fails (401 Unauthorized).""" 24 + pass 25 + 26 + 27 + class CovesNotFoundError(CovesAPIError): 28 + """Raised when a resource is not found (404 Not Found).""" 29 + pass 30 + 31 + 32 + class CovesRateLimitError(CovesAPIError): 33 + """Raised when rate limit is exceeded (429 Too Many Requests).""" 34 + pass 35 + 36 + 37 + class CovesForbiddenError(CovesAPIError): 38 + """Raised when access is forbidden (403 Forbidden).""" 39 + pass 40 + 41 + 14 42 class CovesClient: 15 43 """ 16 44 Client for posting to Coves communities via XRPC. 17 45 18 46 Handles: 19 - - Authentication with aggregator credentials 47 + - API key authentication 20 48 - Creating posts in communities (social.coves.community.post.create) 21 49 - External embed formatting 22 50 """ 23 51 24 - def __init__(self, api_url: str, handle: str, password: str, pds_url: Optional[str] = None): 25 - """ 26 - Initialize Coves client. 27 - 28 - Args: 29 - api_url: Coves AppView URL for posting (e.g., "http://localhost:8081") 30 - handle: Aggregator handle (e.g., "kagi-news.coves.social") 31 - password: Aggregator password/app password 32 - pds_url: Optional PDS URL for authentication (defaults to api_url) 33 - """ 34 - self.api_url = api_url 35 - self.pds_url = pds_url or api_url # Auth through PDS, post through AppView 36 - self.handle = handle 37 - self.password = password 38 - self.client = Client(base_url=self.pds_url) # Use PDS for auth 39 - self._authenticated = False 52 + # API key format constants (must match Go constants in apikey_service.go) 53 + API_KEY_PREFIX = "ckapi_" 54 + API_KEY_TOTAL_LENGTH = 70 # 6 (prefix) + 64 (32 bytes hex-encoded) 40 55 41 - def authenticate(self): 56 + def __init__(self, api_url: str, api_key: str): 42 57 """ 43 - Authenticate with Coves API. 58 + Initialize Coves client with API key authentication. 44 59 45 - Uses com.atproto.server.createSession directly to avoid 46 - Bluesky-specific endpoints that don't exist on Coves PDS. 60 + Args: 61 + api_url: Coves API URL for posting (e.g., "https://coves.social") 62 + api_key: Coves API key (e.g., "ckapi_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx") 47 63 48 64 Raises: 49 - Exception: If authentication fails 65 + ValueError: If api_key format is invalid 50 66 """ 51 - try: 52 - logger.info(f"Authenticating as {self.handle}") 53 - 54 - # Use createSession directly (avoid app.bsky.actor.getProfile) 55 - session = self.client.com.atproto.server.create_session( 56 - {"identifier": self.handle, "password": self.password} 67 + # Validate API key format for early failure with clear error 68 + if not api_key: 69 + raise ValueError("API key cannot be empty") 70 + if not api_key.startswith(self.API_KEY_PREFIX): 71 + raise ValueError(f"API key must start with '{self.API_KEY_PREFIX}'") 72 + if len(api_key) != self.API_KEY_TOTAL_LENGTH: 73 + raise ValueError( 74 + f"API key must be {self.API_KEY_TOTAL_LENGTH} characters " 75 + f"(got {len(api_key)})" 57 76 ) 58 77 59 - # Manually set session (skip profile fetch) 60 - self.client._session = session 61 - self._authenticated = True 62 - self.did = session.did 78 + self.api_url = api_url.rstrip('/') 79 + self.api_key = api_key 80 + self.session = requests.Session() 81 + self.session.headers['Authorization'] = f'Bearer {api_key}' 82 + self.session.headers['Content-Type'] = 'application/json' 63 83 64 - logger.info(f"Authentication successful (DID: {self.did})") 65 - except Exception as e: 66 - logger.error(f"Authentication failed: {e}") 67 - raise 84 + def authenticate(self): 85 + """ 86 + No-op for API key authentication. 87 + 88 + API key is set in the session headers during initialization. 89 + This method is kept for backward compatibility with existing code 90 + that calls authenticate() before making requests. 91 + """ 92 + logger.info("Using API key authentication (no session creation needed)") 68 93 69 94 def create_post( 70 95 self, ··· 90 115 AT Proto URI of created post (e.g., "at://did:plc:.../social.coves.post/...") 91 116 92 117 Raises: 93 - Exception: If post creation fails 118 + requests.HTTPError: If post creation fails 94 119 """ 95 - if not self._authenticated: 96 - self.authenticate() 97 - 98 120 try: 99 121 # Prepare post data for social.coves.community.post.create endpoint 100 122 post_data = { ··· 119 141 # This provides validation, authorization, and business logic 120 142 logger.info(f"Creating post in community: {community_handle}") 121 143 122 - # Make direct HTTP request to XRPC endpoint 144 + # Make HTTP request to XRPC endpoint using session with API key 123 145 url = f"{self.api_url}/xrpc/social.coves.community.post.create" 124 - headers = { 125 - "Authorization": f"Bearer {self.client._session.access_jwt}", 126 - "Content-Type": "application/json" 127 - } 128 - 129 - response = requests.post(url, json=post_data, headers=headers, timeout=30) 146 + response = self.session.post(url, json=post_data, timeout=30) 130 147 131 - # Log detailed error if request fails 148 + # Handle specific error cases 132 149 if not response.ok: 133 150 error_body = response.text 134 151 logger.error(f"Post creation failed ({response.status_code}): {error_body}") 135 - response.raise_for_status() 152 + self._raise_for_status(response) 153 + 154 + try: 155 + result = response.json() 156 + post_uri = result["uri"] 157 + except (ValueError, KeyError) as e: 158 + # ValueError for invalid JSON, KeyError for missing 'uri' field 159 + logger.error(f"Failed to parse post creation response: {e}") 160 + raise CovesAPIError( 161 + f"Invalid response from server: {e}", 162 + status_code=response.status_code, 163 + response_body=response.text 164 + ) 136 165 137 - result = response.json() 138 - post_uri = result["uri"] 139 166 logger.info(f"Post created: {post_uri}") 140 167 return post_uri 141 168 142 - except Exception as e: 143 - logger.error(f"Failed to create post: {e}") 169 + except requests.RequestException as e: 170 + # Network errors, timeouts, etc. 171 + logger.error(f"Network error creating post: {e}") 172 + raise 173 + except CovesAPIError: 174 + # Re-raise our custom exceptions as-is 144 175 raise 145 176 146 177 def create_external_embed( ··· 175 206 "$type": "social.coves.embed.external", 176 207 "external": external 177 208 } 209 + 210 + def _raise_for_status(self, response: requests.Response) -> None: 211 + """ 212 + Raise specific exceptions based on HTTP status code. 213 + 214 + Args: 215 + response: The HTTP response object 216 + 217 + Raises: 218 + CovesAuthenticationError: For 401 Unauthorized 219 + CovesNotFoundError: For 404 Not Found 220 + CovesRateLimitError: For 429 Too Many Requests 221 + CovesAPIError: For other 4xx/5xx errors 222 + """ 223 + status_code = response.status_code 224 + error_body = response.text 225 + 226 + if status_code == 401: 227 + raise CovesAuthenticationError( 228 + f"Authentication failed: {error_body}", 229 + status_code=status_code, 230 + response_body=error_body 231 + ) 232 + elif status_code == 403: 233 + raise CovesForbiddenError( 234 + f"Access forbidden: {error_body}", 235 + status_code=status_code, 236 + response_body=error_body 237 + ) 238 + elif status_code == 404: 239 + raise CovesNotFoundError( 240 + f"Resource not found: {error_body}", 241 + status_code=status_code, 242 + response_body=error_body 243 + ) 244 + elif status_code == 429: 245 + raise CovesRateLimitError( 246 + f"Rate limit exceeded: {error_body}", 247 + status_code=status_code, 248 + response_body=error_body 249 + ) 250 + else: 251 + raise CovesAPIError( 252 + f"API request failed ({status_code}): {error_body}", 253 + status_code=status_code, 254 + response_body=error_body 255 + ) 178 256 179 257 def _get_timestamp(self) -> str: 180 258 """
+5 -9
aggregators/kagi-news/src/main.py
··· 71 71 if coves_client: 72 72 self.coves_client = coves_client 73 73 else: 74 - # Get credentials from environment 75 - aggregator_handle = os.getenv('AGGREGATOR_HANDLE') 76 - aggregator_password = os.getenv('AGGREGATOR_PASSWORD') 77 - pds_url = os.getenv('PDS_URL') # Optional: separate PDS for auth 74 + # Get API key from environment 75 + api_key = os.getenv('COVES_API_KEY') 78 76 79 - if not aggregator_handle or not aggregator_password: 77 + if not api_key: 80 78 raise ValueError( 81 - "Missing AGGREGATOR_HANDLE or AGGREGATOR_PASSWORD environment variables" 79 + "COVES_API_KEY environment variable required" 82 80 ) 83 81 84 82 self.coves_client = CovesClient( 85 83 api_url=self.config.coves_api_url, 86 - handle=aggregator_handle, 87 - password=aggregator_password, 88 - pds_url=pds_url # Auth through PDS if specified 84 + api_key=api_key 89 85 ) 90 86 91 87 def run(self):
+127 -3
aggregators/kagi-news/tests/test_coves_client.py
··· 4 4 Tests the client's local functionality without requiring live infrastructure. 5 5 """ 6 6 import pytest 7 - from src.coves_client import CovesClient 7 + from unittest.mock import Mock 8 + from src.coves_client import ( 9 + CovesClient, 10 + CovesAPIError, 11 + CovesAuthenticationError, 12 + CovesForbiddenError, 13 + CovesNotFoundError, 14 + CovesRateLimitError, 15 + ) 16 + 17 + 18 + # Valid test API key (70 chars total: 6 prefix + 64 hex chars) 19 + VALID_TEST_API_KEY = "ckapi_" + "a" * 64 20 + 21 + 22 + class TestAPIKeyValidation: 23 + """Tests for API key format validation in constructor.""" 24 + 25 + def test_rejects_empty_api_key(self): 26 + """Empty API key should raise ValueError.""" 27 + with pytest.raises(ValueError, match="cannot be empty"): 28 + CovesClient(api_url="http://localhost", api_key="") 29 + 30 + def test_rejects_wrong_prefix(self): 31 + """API key with wrong prefix should raise ValueError.""" 32 + wrong_prefix_key = "wrong_" + "a" * 64 33 + with pytest.raises(ValueError, match="must start with 'ckapi_'"): 34 + CovesClient(api_url="http://localhost", api_key=wrong_prefix_key) 35 + 36 + def test_rejects_short_api_key(self): 37 + """API key that is too short should raise ValueError.""" 38 + short_key = "ckapi_tooshort" 39 + with pytest.raises(ValueError, match="must be 70 characters"): 40 + CovesClient(api_url="http://localhost", api_key=short_key) 41 + 42 + def test_rejects_long_api_key(self): 43 + """API key that is too long should raise ValueError.""" 44 + long_key = "ckapi_" + "a" * 100 45 + with pytest.raises(ValueError, match="must be 70 characters"): 46 + CovesClient(api_url="http://localhost", api_key=long_key) 47 + 48 + def test_accepts_valid_api_key(self): 49 + """Valid API key format should be accepted.""" 50 + client = CovesClient(api_url="http://localhost", api_key=VALID_TEST_API_KEY) 51 + assert client.api_key == VALID_TEST_API_KEY 52 + 53 + 54 + class TestRaiseForStatus: 55 + """Tests for _raise_for_status method.""" 56 + 57 + @pytest.fixture 58 + def client(self): 59 + """Create a CovesClient instance for testing.""" 60 + return CovesClient(api_url="http://localhost", api_key=VALID_TEST_API_KEY) 61 + 62 + def test_raises_authentication_error_for_401(self, client): 63 + """401 response should raise CovesAuthenticationError.""" 64 + mock_response = Mock() 65 + mock_response.status_code = 401 66 + mock_response.text = "Invalid API key" 67 + 68 + with pytest.raises(CovesAuthenticationError) as exc_info: 69 + client._raise_for_status(mock_response) 70 + 71 + assert exc_info.value.status_code == 401 72 + assert "Authentication failed" in str(exc_info.value) 73 + 74 + def test_raises_forbidden_error_for_403(self, client): 75 + """403 response should raise CovesForbiddenError.""" 76 + mock_response = Mock() 77 + mock_response.status_code = 403 78 + mock_response.text = "Not authorized for this community" 79 + 80 + with pytest.raises(CovesForbiddenError) as exc_info: 81 + client._raise_for_status(mock_response) 82 + 83 + assert exc_info.value.status_code == 403 84 + assert "Access forbidden" in str(exc_info.value) 85 + 86 + def test_raises_not_found_error_for_404(self, client): 87 + """404 response should raise CovesNotFoundError.""" 88 + mock_response = Mock() 89 + mock_response.status_code = 404 90 + mock_response.text = "Community not found" 91 + 92 + with pytest.raises(CovesNotFoundError) as exc_info: 93 + client._raise_for_status(mock_response) 94 + 95 + assert exc_info.value.status_code == 404 96 + assert "Resource not found" in str(exc_info.value) 97 + 98 + def test_raises_rate_limit_error_for_429(self, client): 99 + """429 response should raise CovesRateLimitError.""" 100 + mock_response = Mock() 101 + mock_response.status_code = 429 102 + mock_response.text = "Rate limit exceeded" 103 + 104 + with pytest.raises(CovesRateLimitError) as exc_info: 105 + client._raise_for_status(mock_response) 106 + 107 + assert exc_info.value.status_code == 429 108 + assert "Rate limit exceeded" in str(exc_info.value) 109 + 110 + def test_raises_generic_api_error_for_500(self, client): 111 + """500 response should raise generic CovesAPIError.""" 112 + mock_response = Mock() 113 + mock_response.status_code = 500 114 + mock_response.text = "Internal server error" 115 + 116 + with pytest.raises(CovesAPIError) as exc_info: 117 + client._raise_for_status(mock_response) 118 + 119 + assert exc_info.value.status_code == 500 120 + assert not isinstance(exc_info.value, CovesAuthenticationError) 121 + assert not isinstance(exc_info.value, CovesNotFoundError) 122 + 123 + def test_exception_includes_response_body(self, client): 124 + """Exception should include the response body.""" 125 + mock_response = Mock() 126 + mock_response.status_code = 400 127 + mock_response.text = '{"error": "Bad request details"}' 128 + 129 + with pytest.raises(CovesAPIError) as exc_info: 130 + client._raise_for_status(mock_response) 131 + 132 + assert exc_info.value.response_body == '{"error": "Bad request details"}' 8 133 9 134 10 135 class TestCreateExternalEmbed: ··· 15 140 """Create a CovesClient instance for testing.""" 16 141 return CovesClient( 17 142 api_url="http://localhost:8081", 18 - handle="test.handle", 19 - password="test_password" 143 + api_key=VALID_TEST_API_KEY 20 144 ) 21 145 22 146 def test_creates_embed_without_sources(self, client):
+16 -3
cmd/server/main.go
··· 390 390 aggregatorService := aggregators.NewAggregatorService(aggregatorRepo, communityService) 391 391 log.Println("✅ Aggregator service initialized") 392 392 393 + // Initialize API key service for aggregator authentication 394 + apiKeyService := aggregators.NewAPIKeyService(aggregatorRepo, oauthClient.ClientApp) 395 + log.Println("✅ API key service initialized") 396 + 393 397 // Get instance DID for service auth validator audience 394 398 serviceDID := instanceDID // Use instance DID as the service audience 395 399 ··· 402 406 } 403 407 log.Printf("✅ Service auth validator initialized (audience: %s)", serviceDID) 404 408 405 - // Create DualAuthMiddleware that supports both OAuth and service JWT 409 + // Create DualAuthMiddleware that supports OAuth, service JWT, and API keys 406 410 // OAuth tokens are for user authentication (sealed session tokens) 407 411 // Service JWTs are for aggregator authentication (PDS-signed tokens) 412 + // API keys are for aggregator bot authentication (stateless, cryptographic) 413 + apiKeyValidator := middleware.NewAPIKeyValidatorAdapter(apiKeyService) 408 414 dualAuth := middleware.NewDualAuthMiddleware( 409 415 oauthClient, // SessionUnsealer for OAuth 410 416 oauthStore, // ClientAuthStore for OAuth sessions 411 417 serviceValidator, // ServiceAuthValidator for JWT validation 412 418 aggregatorRepo, // AggregatorChecker - uses repo directly since it implements the interface 413 - ) 414 - log.Println("✅ Dual auth middleware initialized (OAuth + service JWT)") 419 + ).WithAPIKeyValidator(apiKeyValidator) 420 + log.Println("✅ Dual auth middleware initialized (OAuth + service JWT + API keys)") 415 421 416 422 // Initialize unfurl cache repository 417 423 unfurlRepo := unfurl.NewRepository(db) ··· 622 628 623 629 routes.RegisterAggregatorRoutes(r, aggregatorService, communityService, userService, identityResolver) 624 630 log.Println("Aggregator XRPC endpoints registered (query endpoints public, registration endpoint public)") 631 + 632 + routes.RegisterAggregatorAPIKeyRoutes(r, authMiddleware, apiKeyService, aggregatorService) 633 + log.Println("✅ Aggregator API key endpoints registered") 634 + log.Println(" - POST /xrpc/social.coves.aggregator.createApiKey (requires OAuth)") 635 + log.Println(" - GET /xrpc/social.coves.aggregator.getApiKey (requires OAuth)") 636 + log.Println(" - POST /xrpc/social.coves.aggregator.revokeApiKey (requires OAuth)") 637 + log.Println(" - GET /xrpc/social.coves.aggregator.getMetrics (public)") 625 638 626 639 // Comment query API - supports optional authentication for viewer state 627 640 // Stricter rate limiting for expensive nested comment queries
+1159
internal/api/handlers/aggregator/apikey_handlers_test.go
··· 1 + package aggregator 2 + 3 + import ( 4 + "Coves/internal/api/middleware" 5 + "Coves/internal/core/aggregators" 6 + "context" 7 + "encoding/json" 8 + "errors" 9 + "net/http" 10 + "net/http/httptest" 11 + "testing" 12 + "time" 13 + 14 + oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth" 15 + "github.com/bluesky-social/indigo/atproto/syntax" 16 + ) 17 + 18 + // mockAggregatorService implements aggregators.Service for testing 19 + type mockAggregatorService struct { 20 + isAggregatorFunc func(ctx context.Context, did string) (bool, error) 21 + } 22 + 23 + func (m *mockAggregatorService) IsAggregator(ctx context.Context, did string) (bool, error) { 24 + if m.isAggregatorFunc != nil { 25 + return m.isAggregatorFunc(ctx, did) 26 + } 27 + return true, nil 28 + } 29 + 30 + // Stub implementations for Service interface methods we don't test 31 + func (m *mockAggregatorService) GetAggregator(ctx context.Context, did string) (*aggregators.Aggregator, error) { 32 + return nil, nil 33 + } 34 + 35 + func (m *mockAggregatorService) GetAggregators(ctx context.Context, dids []string) ([]*aggregators.Aggregator, error) { 36 + return nil, nil 37 + } 38 + 39 + func (m *mockAggregatorService) ListAggregators(ctx context.Context, limit, offset int) ([]*aggregators.Aggregator, error) { 40 + return nil, nil 41 + } 42 + 43 + func (m *mockAggregatorService) GetAuthorizationsForAggregator(ctx context.Context, req aggregators.GetAuthorizationsRequest) ([]*aggregators.Authorization, error) { 44 + return nil, nil 45 + } 46 + 47 + func (m *mockAggregatorService) ListAggregatorsForCommunity(ctx context.Context, req aggregators.ListForCommunityRequest) ([]*aggregators.Authorization, error) { 48 + return nil, nil 49 + } 50 + 51 + func (m *mockAggregatorService) EnableAggregator(ctx context.Context, req aggregators.EnableAggregatorRequest) (*aggregators.Authorization, error) { 52 + return nil, nil 53 + } 54 + 55 + func (m *mockAggregatorService) DisableAggregator(ctx context.Context, req aggregators.DisableAggregatorRequest) (*aggregators.Authorization, error) { 56 + return nil, nil 57 + } 58 + 59 + func (m *mockAggregatorService) UpdateAggregatorConfig(ctx context.Context, req aggregators.UpdateConfigRequest) (*aggregators.Authorization, error) { 60 + return nil, nil 61 + } 62 + 63 + func (m *mockAggregatorService) ValidateAggregatorPost(ctx context.Context, aggregatorDID, communityDID string) error { 64 + return nil 65 + } 66 + 67 + func (m *mockAggregatorService) RecordAggregatorPost(ctx context.Context, aggregatorDID, communityDID, postURI, postCID string) error { 68 + return nil 69 + } 70 + 71 + // XRPCError represents an XRPC error response for testing 72 + type XRPCError struct { 73 + Error string `json:"error"` 74 + Message string `json:"message"` 75 + } 76 + 77 + // Helper to create authenticated request context with OAuth session 78 + func createAuthenticatedContext(t *testing.T, didStr string) context.Context { 79 + t.Helper() 80 + did, err := syntax.ParseDID(didStr) 81 + if err != nil { 82 + t.Fatalf("Failed to parse DID: %v", err) 83 + } 84 + session := &oauthlib.ClientSessionData{ 85 + AccountDID: did, 86 + AccessToken: "test_access_token", 87 + SessionID: "test_session", 88 + } 89 + ctx := context.WithValue(context.Background(), middleware.OAuthSessionKey, session) 90 + ctx = context.WithValue(ctx, middleware.UserDIDKey, didStr) 91 + return ctx 92 + } 93 + 94 + // Helper to create context with just UserDID (no OAuth session) 95 + func createUserDIDContext(didStr string) context.Context { 96 + return context.WithValue(context.Background(), middleware.UserDIDKey, didStr) 97 + } 98 + 99 + // ============================================================================= 100 + // CreateAPIKey Handler Tests 101 + // ============================================================================= 102 + 103 + func TestCreateAPIKeyHandler_Success(t *testing.T) { 104 + // Create mock services 105 + mockAggSvc := &mockAggregatorService{ 106 + isAggregatorFunc: func(ctx context.Context, did string) (bool, error) { 107 + return true, nil // Is an aggregator 108 + }, 109 + } 110 + 111 + mockAPIKeySvc := &mockAPIKeyService{ 112 + generateKeyFunc: func(ctx context.Context, aggregatorDID string, oauthSession *oauthlib.ClientSessionData) (string, string, error) { 113 + return "ckapi_0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", "ckapi_012345", nil 114 + }, 115 + } 116 + 117 + handler := NewCreateAPIKeyHandler(mockAPIKeySvc, mockAggSvc) 118 + 119 + // Create request with full auth context (including OAuth session) 120 + req := httptest.NewRequest(http.MethodPost, "/xrpc/social.coves.aggregator.createApiKey", nil) 121 + req.Header.Set("Content-Type", "application/json") 122 + ctx := createAuthenticatedContext(t, "did:plc:aggregator123") 123 + req = req.WithContext(ctx) 124 + 125 + // Execute handler 126 + w := httptest.NewRecorder() 127 + handler.HandleCreateAPIKey(w, req) 128 + 129 + // Check status code 130 + if w.Code != http.StatusOK { 131 + t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String()) 132 + } 133 + 134 + // Check response format 135 + var response CreateAPIKeyResponse 136 + if err := json.NewDecoder(w.Body).Decode(&response); err != nil { 137 + t.Fatalf("Failed to decode response: %v", err) 138 + } 139 + 140 + if response.Key != "ckapi_0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" { 141 + t.Errorf("Expected key to match generated key, got %s", response.Key) 142 + } 143 + if response.KeyPrefix != "ckapi_012345" { 144 + t.Errorf("Expected keyPrefix to match, got %s", response.KeyPrefix) 145 + } 146 + if response.DID != "did:plc:aggregator123" { 147 + t.Errorf("Expected DID to match authenticated user, got %s", response.DID) 148 + } 149 + if response.CreatedAt == "" { 150 + t.Error("Expected createdAt to be set") 151 + } 152 + } 153 + 154 + func TestCreateAPIKeyHandler_RequiresAuth(t *testing.T) { 155 + mockAggSvc := &mockAggregatorService{} 156 + handler := NewCreateAPIKeyHandler(nil, mockAggSvc) 157 + 158 + // Create HTTP request without auth context 159 + req := httptest.NewRequest(http.MethodPost, "/xrpc/social.coves.aggregator.createApiKey", nil) 160 + req.Header.Set("Content-Type", "application/json") 161 + // No OAuth session in context 162 + 163 + // Execute handler 164 + w := httptest.NewRecorder() 165 + handler.HandleCreateAPIKey(w, req) 166 + 167 + // Check status code 168 + if w.Code != http.StatusUnauthorized { 169 + t.Errorf("Expected status 401, got %d. Body: %s", w.Code, w.Body.String()) 170 + } 171 + 172 + // Check error response 173 + var errResp XRPCError 174 + if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil { 175 + t.Fatalf("Failed to decode error response: %v", err) 176 + } 177 + if errResp.Error != "AuthenticationRequired" { 178 + t.Errorf("Expected error AuthenticationRequired, got %s", errResp.Error) 179 + } 180 + } 181 + 182 + func TestCreateAPIKeyHandler_MethodNotAllowed(t *testing.T) { 183 + mockAggSvc := &mockAggregatorService{} 184 + handler := NewCreateAPIKeyHandler(nil, mockAggSvc) 185 + 186 + // Create GET request (should only accept POST) 187 + req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.aggregator.createApiKey", nil) 188 + 189 + // Execute handler 190 + w := httptest.NewRecorder() 191 + handler.HandleCreateAPIKey(w, req) 192 + 193 + // Check status code 194 + if w.Code != http.StatusMethodNotAllowed { 195 + t.Errorf("Expected status 405, got %d", w.Code) 196 + } 197 + } 198 + 199 + func TestCreateAPIKeyHandler_NotAggregator(t *testing.T) { 200 + mockAggSvc := &mockAggregatorService{ 201 + isAggregatorFunc: func(ctx context.Context, did string) (bool, error) { 202 + return false, nil // Not an aggregator 203 + }, 204 + } 205 + handler := NewCreateAPIKeyHandler(nil, mockAggSvc) 206 + 207 + // Create request with auth 208 + req := httptest.NewRequest(http.MethodPost, "/xrpc/social.coves.aggregator.createApiKey", nil) 209 + req.Header.Set("Content-Type", "application/json") 210 + ctx := createAuthenticatedContext(t, "did:plc:user123") 211 + req = req.WithContext(ctx) 212 + 213 + // Execute handler 214 + w := httptest.NewRecorder() 215 + handler.HandleCreateAPIKey(w, req) 216 + 217 + // Check status code 218 + if w.Code != http.StatusForbidden { 219 + t.Errorf("Expected status 403, got %d. Body: %s", w.Code, w.Body.String()) 220 + } 221 + 222 + // Check error response 223 + var errResp XRPCError 224 + if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil { 225 + t.Fatalf("Failed to decode error response: %v", err) 226 + } 227 + if errResp.Error != "AggregatorRequired" { 228 + t.Errorf("Expected error AggregatorRequired, got %s", errResp.Error) 229 + } 230 + } 231 + 232 + func TestCreateAPIKeyHandler_AggregatorCheckError(t *testing.T) { 233 + mockAggSvc := &mockAggregatorService{ 234 + isAggregatorFunc: func(ctx context.Context, did string) (bool, error) { 235 + return false, errors.New("database error") 236 + }, 237 + } 238 + handler := NewCreateAPIKeyHandler(nil, mockAggSvc) 239 + 240 + // Create request with auth 241 + req := httptest.NewRequest(http.MethodPost, "/xrpc/social.coves.aggregator.createApiKey", nil) 242 + req.Header.Set("Content-Type", "application/json") 243 + ctx := createAuthenticatedContext(t, "did:plc:user123") 244 + req = req.WithContext(ctx) 245 + 246 + // Execute handler 247 + w := httptest.NewRecorder() 248 + handler.HandleCreateAPIKey(w, req) 249 + 250 + // Check status code 251 + if w.Code != http.StatusInternalServerError { 252 + t.Errorf("Expected status 500, got %d. Body: %s", w.Code, w.Body.String()) 253 + } 254 + 255 + // Check error response 256 + var errResp XRPCError 257 + if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil { 258 + t.Fatalf("Failed to decode error response: %v", err) 259 + } 260 + if errResp.Error != "InternalServerError" { 261 + t.Errorf("Expected error InternalServerError, got %s", errResp.Error) 262 + } 263 + } 264 + 265 + func TestCreateAPIKeyHandler_MissingOAuthSession(t *testing.T) { 266 + mockAggSvc := &mockAggregatorService{ 267 + isAggregatorFunc: func(ctx context.Context, did string) (bool, error) { 268 + return true, nil // Is an aggregator 269 + }, 270 + } 271 + handler := NewCreateAPIKeyHandler(nil, mockAggSvc) 272 + 273 + // Create request with UserDID but no OAuth session 274 + req := httptest.NewRequest(http.MethodPost, "/xrpc/social.coves.aggregator.createApiKey", nil) 275 + req.Header.Set("Content-Type", "application/json") 276 + ctx := createUserDIDContext("did:plc:aggregator123") 277 + req = req.WithContext(ctx) 278 + 279 + // Execute handler 280 + w := httptest.NewRecorder() 281 + handler.HandleCreateAPIKey(w, req) 282 + 283 + // Check status code - should fail because OAuth session is required 284 + if w.Code != http.StatusUnauthorized { 285 + t.Errorf("Expected status 401, got %d. Body: %s", w.Code, w.Body.String()) 286 + } 287 + 288 + // Check error response 289 + var errResp XRPCError 290 + if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil { 291 + t.Fatalf("Failed to decode error response: %v", err) 292 + } 293 + if errResp.Error != "OAuthSessionRequired" { 294 + t.Errorf("Expected error OAuthSessionRequired, got %s", errResp.Error) 295 + } 296 + } 297 + 298 + // ============================================================================= 299 + // GetAPIKey Handler Tests 300 + // ============================================================================= 301 + 302 + func TestGetAPIKeyHandler_Success(t *testing.T) { 303 + createdAt := time.Now().Add(-24 * time.Hour) 304 + lastUsed := time.Now().Add(-1 * time.Hour) 305 + 306 + mockAggSvc := &mockAggregatorService{ 307 + isAggregatorFunc: func(ctx context.Context, did string) (bool, error) { 308 + return true, nil // Is an aggregator 309 + }, 310 + } 311 + 312 + mockAPIKeySvc := &mockAPIKeyService{ 313 + getAPIKeyInfoFunc: func(ctx context.Context, aggregatorDID string) (*aggregators.APIKeyInfo, error) { 314 + return &aggregators.APIKeyInfo{ 315 + HasKey: true, 316 + KeyPrefix: "ckapi_test12", 317 + CreatedAt: &createdAt, 318 + LastUsedAt: &lastUsed, 319 + IsRevoked: false, 320 + }, nil 321 + }, 322 + } 323 + 324 + handler := NewGetAPIKeyHandler(mockAPIKeySvc, mockAggSvc) 325 + 326 + // Create request with auth 327 + req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.aggregator.getApiKey", nil) 328 + ctx := createUserDIDContext("did:plc:aggregator123") 329 + req = req.WithContext(ctx) 330 + 331 + // Execute handler 332 + w := httptest.NewRecorder() 333 + handler.HandleGetAPIKey(w, req) 334 + 335 + // Check status code 336 + if w.Code != http.StatusOK { 337 + t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String()) 338 + } 339 + 340 + // Check response format 341 + var response GetAPIKeyResponse 342 + if err := json.NewDecoder(w.Body).Decode(&response); err != nil { 343 + t.Fatalf("Failed to decode response: %v", err) 344 + } 345 + 346 + if !response.HasKey { 347 + t.Error("Expected hasKey to be true") 348 + } 349 + if response.KeyInfo == nil { 350 + t.Fatal("Expected keyInfo to be present") 351 + } 352 + if response.KeyInfo.Prefix != "ckapi_test12" { 353 + t.Errorf("Expected prefix 'ckapi_test12', got %s", response.KeyInfo.Prefix) 354 + } 355 + if response.KeyInfo.IsRevoked { 356 + t.Error("Expected isRevoked to be false") 357 + } 358 + if response.KeyInfo.CreatedAt == "" { 359 + t.Error("Expected createdAt to be set") 360 + } 361 + if response.KeyInfo.LastUsedAt == nil { 362 + t.Error("Expected lastUsedAt to be set") 363 + } 364 + } 365 + 366 + func TestGetAPIKeyHandler_Success_NoKey(t *testing.T) { 367 + mockAggSvc := &mockAggregatorService{ 368 + isAggregatorFunc: func(ctx context.Context, did string) (bool, error) { 369 + return true, nil // Is an aggregator 370 + }, 371 + } 372 + 373 + mockAPIKeySvc := &mockAPIKeyService{ 374 + getAPIKeyInfoFunc: func(ctx context.Context, aggregatorDID string) (*aggregators.APIKeyInfo, error) { 375 + return &aggregators.APIKeyInfo{ 376 + HasKey: false, 377 + }, nil 378 + }, 379 + } 380 + 381 + handler := NewGetAPIKeyHandler(mockAPIKeySvc, mockAggSvc) 382 + 383 + // Create request with auth 384 + req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.aggregator.getApiKey", nil) 385 + ctx := createUserDIDContext("did:plc:aggregator123") 386 + req = req.WithContext(ctx) 387 + 388 + // Execute handler 389 + w := httptest.NewRecorder() 390 + handler.HandleGetAPIKey(w, req) 391 + 392 + // Check status code 393 + if w.Code != http.StatusOK { 394 + t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String()) 395 + } 396 + 397 + // Check response format 398 + var response GetAPIKeyResponse 399 + if err := json.NewDecoder(w.Body).Decode(&response); err != nil { 400 + t.Fatalf("Failed to decode response: %v", err) 401 + } 402 + 403 + if response.HasKey { 404 + t.Error("Expected hasKey to be false") 405 + } 406 + if response.KeyInfo != nil { 407 + t.Error("Expected keyInfo to be nil when hasKey is false") 408 + } 409 + } 410 + 411 + func TestGetAPIKeyHandler_RequiresAuth(t *testing.T) { 412 + mockAggSvc := &mockAggregatorService{} 413 + handler := NewGetAPIKeyHandler(nil, mockAggSvc) 414 + 415 + // Create HTTP request without auth context 416 + req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.aggregator.getApiKey", nil) 417 + // No auth context 418 + 419 + // Execute handler 420 + w := httptest.NewRecorder() 421 + handler.HandleGetAPIKey(w, req) 422 + 423 + // Check status code 424 + if w.Code != http.StatusUnauthorized { 425 + t.Errorf("Expected status 401, got %d. Body: %s", w.Code, w.Body.String()) 426 + } 427 + 428 + // Check error response 429 + var errResp XRPCError 430 + if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil { 431 + t.Fatalf("Failed to decode error response: %v", err) 432 + } 433 + if errResp.Error != "AuthenticationRequired" { 434 + t.Errorf("Expected error AuthenticationRequired, got %s", errResp.Error) 435 + } 436 + } 437 + 438 + func TestGetAPIKeyHandler_MethodNotAllowed(t *testing.T) { 439 + mockAggSvc := &mockAggregatorService{} 440 + handler := NewGetAPIKeyHandler(nil, mockAggSvc) 441 + 442 + // Create POST request (should only accept GET) 443 + req := httptest.NewRequest(http.MethodPost, "/xrpc/social.coves.aggregator.getApiKey", nil) 444 + 445 + // Execute handler 446 + w := httptest.NewRecorder() 447 + handler.HandleGetAPIKey(w, req) 448 + 449 + // Check status code 450 + if w.Code != http.StatusMethodNotAllowed { 451 + t.Errorf("Expected status 405, got %d", w.Code) 452 + } 453 + } 454 + 455 + func TestGetAPIKeyHandler_NotAggregator(t *testing.T) { 456 + mockAggSvc := &mockAggregatorService{ 457 + isAggregatorFunc: func(ctx context.Context, did string) (bool, error) { 458 + return false, nil // Not an aggregator 459 + }, 460 + } 461 + handler := NewGetAPIKeyHandler(nil, mockAggSvc) 462 + 463 + // Create request with auth 464 + req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.aggregator.getApiKey", nil) 465 + ctx := createUserDIDContext("did:plc:user123") 466 + req = req.WithContext(ctx) 467 + 468 + // Execute handler 469 + w := httptest.NewRecorder() 470 + handler.HandleGetAPIKey(w, req) 471 + 472 + // Check status code 473 + if w.Code != http.StatusForbidden { 474 + t.Errorf("Expected status 403, got %d. Body: %s", w.Code, w.Body.String()) 475 + } 476 + 477 + // Check error response 478 + var errResp XRPCError 479 + if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil { 480 + t.Fatalf("Failed to decode error response: %v", err) 481 + } 482 + if errResp.Error != "AggregatorRequired" { 483 + t.Errorf("Expected error AggregatorRequired, got %s", errResp.Error) 484 + } 485 + } 486 + 487 + func TestGetAPIKeyHandler_AggregatorCheckError(t *testing.T) { 488 + mockAggSvc := &mockAggregatorService{ 489 + isAggregatorFunc: func(ctx context.Context, did string) (bool, error) { 490 + return false, errors.New("database error") 491 + }, 492 + } 493 + handler := NewGetAPIKeyHandler(nil, mockAggSvc) 494 + 495 + // Create request with auth 496 + req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.aggregator.getApiKey", nil) 497 + ctx := createUserDIDContext("did:plc:user123") 498 + req = req.WithContext(ctx) 499 + 500 + // Execute handler 501 + w := httptest.NewRecorder() 502 + handler.HandleGetAPIKey(w, req) 503 + 504 + // Check status code 505 + if w.Code != http.StatusInternalServerError { 506 + t.Errorf("Expected status 500, got %d. Body: %s", w.Code, w.Body.String()) 507 + } 508 + 509 + // Check error response 510 + var errResp XRPCError 511 + if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil { 512 + t.Fatalf("Failed to decode error response: %v", err) 513 + } 514 + if errResp.Error != "InternalServerError" { 515 + t.Errorf("Expected error InternalServerError, got %s", errResp.Error) 516 + } 517 + } 518 + 519 + // ============================================================================= 520 + // RevokeAPIKey Handler Tests 521 + // ============================================================================= 522 + 523 + func TestRevokeAPIKeyHandler_Success(t *testing.T) { 524 + mockAggSvc := &mockAggregatorService{ 525 + isAggregatorFunc: func(ctx context.Context, did string) (bool, error) { 526 + return true, nil // Is an aggregator 527 + }, 528 + } 529 + 530 + revokeKeyCalled := false 531 + mockAPIKeySvc := &mockAPIKeyService{ 532 + getAPIKeyInfoFunc: func(ctx context.Context, aggregatorDID string) (*aggregators.APIKeyInfo, error) { 533 + return &aggregators.APIKeyInfo{ 534 + HasKey: true, 535 + KeyPrefix: "ckapi_test12", 536 + IsRevoked: false, 537 + }, nil 538 + }, 539 + revokeKeyFunc: func(ctx context.Context, aggregatorDID string) error { 540 + revokeKeyCalled = true 541 + return nil 542 + }, 543 + } 544 + 545 + handler := NewRevokeAPIKeyHandler(mockAPIKeySvc, mockAggSvc) 546 + 547 + // Create request with auth 548 + req := httptest.NewRequest(http.MethodPost, "/xrpc/social.coves.aggregator.revokeApiKey", nil) 549 + req.Header.Set("Content-Type", "application/json") 550 + ctx := createUserDIDContext("did:plc:aggregator123") 551 + req = req.WithContext(ctx) 552 + 553 + // Execute handler 554 + w := httptest.NewRecorder() 555 + handler.HandleRevokeAPIKey(w, req) 556 + 557 + // Check status code 558 + if w.Code != http.StatusOK { 559 + t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String()) 560 + } 561 + 562 + // Check that RevokeKey was called 563 + if !revokeKeyCalled { 564 + t.Error("Expected RevokeKey to be called") 565 + } 566 + 567 + // Check response format 568 + var response RevokeAPIKeyResponse 569 + if err := json.NewDecoder(w.Body).Decode(&response); err != nil { 570 + t.Fatalf("Failed to decode response: %v", err) 571 + } 572 + 573 + if response.RevokedAt == "" { 574 + t.Error("Expected revokedAt to be set") 575 + } 576 + 577 + // Verify timestamp format 578 + _, err := time.Parse("2006-01-02T15:04:05.000Z", response.RevokedAt) 579 + if err != nil { 580 + t.Errorf("Expected revokedAt to be valid ISO8601 timestamp: %v", err) 581 + } 582 + } 583 + 584 + func TestRevokeAPIKeyHandler_NoKeyToRevoke(t *testing.T) { 585 + mockAggSvc := &mockAggregatorService{ 586 + isAggregatorFunc: func(ctx context.Context, did string) (bool, error) { 587 + return true, nil // Is an aggregator 588 + }, 589 + } 590 + 591 + mockAPIKeySvc := &mockAPIKeyService{ 592 + getAPIKeyInfoFunc: func(ctx context.Context, aggregatorDID string) (*aggregators.APIKeyInfo, error) { 593 + return &aggregators.APIKeyInfo{ 594 + HasKey: false, // No key exists 595 + }, nil 596 + }, 597 + } 598 + 599 + handler := NewRevokeAPIKeyHandler(mockAPIKeySvc, mockAggSvc) 600 + 601 + // Create request with auth 602 + req := httptest.NewRequest(http.MethodPost, "/xrpc/social.coves.aggregator.revokeApiKey", nil) 603 + req.Header.Set("Content-Type", "application/json") 604 + ctx := createUserDIDContext("did:plc:aggregator123") 605 + req = req.WithContext(ctx) 606 + 607 + // Execute handler 608 + w := httptest.NewRecorder() 609 + handler.HandleRevokeAPIKey(w, req) 610 + 611 + // Check status code 612 + if w.Code != http.StatusBadRequest { 613 + t.Errorf("Expected status 400, got %d. Body: %s", w.Code, w.Body.String()) 614 + } 615 + 616 + // Check error response 617 + var errResp XRPCError 618 + if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil { 619 + t.Fatalf("Failed to decode error response: %v", err) 620 + } 621 + if errResp.Error != "ApiKeyNotFound" { 622 + t.Errorf("Expected error ApiKeyNotFound, got %s", errResp.Error) 623 + } 624 + } 625 + 626 + func TestRevokeAPIKeyHandler_AlreadyRevoked(t *testing.T) { 627 + mockAggSvc := &mockAggregatorService{ 628 + isAggregatorFunc: func(ctx context.Context, did string) (bool, error) { 629 + return true, nil // Is an aggregator 630 + }, 631 + } 632 + 633 + mockAPIKeySvc := &mockAPIKeyService{ 634 + getAPIKeyInfoFunc: func(ctx context.Context, aggregatorDID string) (*aggregators.APIKeyInfo, error) { 635 + return &aggregators.APIKeyInfo{ 636 + HasKey: true, 637 + KeyPrefix: "ckapi_test12", 638 + IsRevoked: true, // Already revoked 639 + }, nil 640 + }, 641 + } 642 + 643 + handler := NewRevokeAPIKeyHandler(mockAPIKeySvc, mockAggSvc) 644 + 645 + // Create request with auth 646 + req := httptest.NewRequest(http.MethodPost, "/xrpc/social.coves.aggregator.revokeApiKey", nil) 647 + req.Header.Set("Content-Type", "application/json") 648 + ctx := createUserDIDContext("did:plc:aggregator123") 649 + req = req.WithContext(ctx) 650 + 651 + // Execute handler 652 + w := httptest.NewRecorder() 653 + handler.HandleRevokeAPIKey(w, req) 654 + 655 + // Check status code 656 + if w.Code != http.StatusBadRequest { 657 + t.Errorf("Expected status 400, got %d. Body: %s", w.Code, w.Body.String()) 658 + } 659 + 660 + // Check error response 661 + var errResp XRPCError 662 + if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil { 663 + t.Fatalf("Failed to decode error response: %v", err) 664 + } 665 + if errResp.Error != "ApiKeyAlreadyRevoked" { 666 + t.Errorf("Expected error ApiKeyAlreadyRevoked, got %s", errResp.Error) 667 + } 668 + } 669 + 670 + func TestRevokeAPIKeyHandler_RequiresAuth(t *testing.T) { 671 + mockAggSvc := &mockAggregatorService{} 672 + handler := NewRevokeAPIKeyHandler(nil, mockAggSvc) 673 + 674 + // Create HTTP request without auth context 675 + req := httptest.NewRequest(http.MethodPost, "/xrpc/social.coves.aggregator.revokeApiKey", nil) 676 + req.Header.Set("Content-Type", "application/json") 677 + // No auth context 678 + 679 + // Execute handler 680 + w := httptest.NewRecorder() 681 + handler.HandleRevokeAPIKey(w, req) 682 + 683 + // Check status code 684 + if w.Code != http.StatusUnauthorized { 685 + t.Errorf("Expected status 401, got %d. Body: %s", w.Code, w.Body.String()) 686 + } 687 + 688 + // Check error response 689 + var errResp XRPCError 690 + if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil { 691 + t.Fatalf("Failed to decode error response: %v", err) 692 + } 693 + if errResp.Error != "AuthenticationRequired" { 694 + t.Errorf("Expected error AuthenticationRequired, got %s", errResp.Error) 695 + } 696 + } 697 + 698 + func TestRevokeAPIKeyHandler_MethodNotAllowed(t *testing.T) { 699 + mockAggSvc := &mockAggregatorService{} 700 + handler := NewRevokeAPIKeyHandler(nil, mockAggSvc) 701 + 702 + // Create GET request (should only accept POST) 703 + req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.aggregator.revokeApiKey", nil) 704 + 705 + // Execute handler 706 + w := httptest.NewRecorder() 707 + handler.HandleRevokeAPIKey(w, req) 708 + 709 + // Check status code 710 + if w.Code != http.StatusMethodNotAllowed { 711 + t.Errorf("Expected status 405, got %d", w.Code) 712 + } 713 + } 714 + 715 + func TestRevokeAPIKeyHandler_NotAggregator(t *testing.T) { 716 + mockAggSvc := &mockAggregatorService{ 717 + isAggregatorFunc: func(ctx context.Context, did string) (bool, error) { 718 + return false, nil // Not an aggregator 719 + }, 720 + } 721 + handler := NewRevokeAPIKeyHandler(nil, mockAggSvc) 722 + 723 + // Create request with auth 724 + req := httptest.NewRequest(http.MethodPost, "/xrpc/social.coves.aggregator.revokeApiKey", nil) 725 + req.Header.Set("Content-Type", "application/json") 726 + ctx := createUserDIDContext("did:plc:user123") 727 + req = req.WithContext(ctx) 728 + 729 + // Execute handler 730 + w := httptest.NewRecorder() 731 + handler.HandleRevokeAPIKey(w, req) 732 + 733 + // Check status code 734 + if w.Code != http.StatusForbidden { 735 + t.Errorf("Expected status 403, got %d. Body: %s", w.Code, w.Body.String()) 736 + } 737 + 738 + // Check error response 739 + var errResp XRPCError 740 + if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil { 741 + t.Fatalf("Failed to decode error response: %v", err) 742 + } 743 + if errResp.Error != "AggregatorRequired" { 744 + t.Errorf("Expected error AggregatorRequired, got %s", errResp.Error) 745 + } 746 + } 747 + 748 + func TestRevokeAPIKeyHandler_AggregatorCheckError(t *testing.T) { 749 + mockAggSvc := &mockAggregatorService{ 750 + isAggregatorFunc: func(ctx context.Context, did string) (bool, error) { 751 + return false, errors.New("database error") 752 + }, 753 + } 754 + handler := NewRevokeAPIKeyHandler(nil, mockAggSvc) 755 + 756 + // Create request with auth 757 + req := httptest.NewRequest(http.MethodPost, "/xrpc/social.coves.aggregator.revokeApiKey", nil) 758 + req.Header.Set("Content-Type", "application/json") 759 + ctx := createUserDIDContext("did:plc:user123") 760 + req = req.WithContext(ctx) 761 + 762 + // Execute handler 763 + w := httptest.NewRecorder() 764 + handler.HandleRevokeAPIKey(w, req) 765 + 766 + // Check status code 767 + if w.Code != http.StatusInternalServerError { 768 + t.Errorf("Expected status 500, got %d. Body: %s", w.Code, w.Body.String()) 769 + } 770 + 771 + // Check error response 772 + var errResp XRPCError 773 + if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil { 774 + t.Fatalf("Failed to decode error response: %v", err) 775 + } 776 + if errResp.Error != "InternalServerError" { 777 + t.Errorf("Expected error InternalServerError, got %s", errResp.Error) 778 + } 779 + } 780 + 781 + // ============================================================================= 782 + // Response Format Tests 783 + // ============================================================================= 784 + 785 + func TestRevokeAPIKeyResponse_ContainsRequiredFields(t *testing.T) { 786 + // Verify RevokeAPIKeyResponse has the required fields per lexicon 787 + response := RevokeAPIKeyResponse{ 788 + RevokedAt: time.Now().UTC().Format("2006-01-02T15:04:05.000Z"), 789 + } 790 + 791 + data, err := json.Marshal(response) 792 + if err != nil { 793 + t.Fatalf("Failed to marshal response: %v", err) 794 + } 795 + 796 + var decoded map[string]interface{} 797 + if err := json.Unmarshal(data, &decoded); err != nil { 798 + t.Fatalf("Failed to unmarshal response: %v", err) 799 + } 800 + 801 + // Check required fields per lexicon (success field removed per AT Protocol best practices) 802 + if _, ok := decoded["revokedAt"]; !ok { 803 + t.Error("Response missing required 'revokedAt' field") 804 + } 805 + } 806 + 807 + func TestCreateAPIKeyResponse_ContainsRequiredFields(t *testing.T) { 808 + response := CreateAPIKeyResponse{ 809 + Key: "ckapi_test1234567890123456789012345678", 810 + KeyPrefix: "ckapi_test12", 811 + DID: "did:plc:aggregator123", 812 + CreatedAt: time.Now().UTC().Format("2006-01-02T15:04:05.000Z"), 813 + } 814 + 815 + data, err := json.Marshal(response) 816 + if err != nil { 817 + t.Fatalf("Failed to marshal response: %v", err) 818 + } 819 + 820 + var decoded map[string]interface{} 821 + if err := json.Unmarshal(data, &decoded); err != nil { 822 + t.Fatalf("Failed to unmarshal response: %v", err) 823 + } 824 + 825 + // Check required fields 826 + requiredFields := []string{"key", "keyPrefix", "did", "createdAt"} 827 + for _, field := range requiredFields { 828 + if _, ok := decoded[field]; !ok { 829 + t.Errorf("Response missing required '%s' field", field) 830 + } 831 + } 832 + } 833 + 834 + func TestGetAPIKeyResponse_ContainsRequiredFields(t *testing.T) { 835 + response := GetAPIKeyResponse{ 836 + HasKey: true, 837 + KeyInfo: &APIKeyView{ 838 + Prefix: "ckapi_test12", 839 + CreatedAt: time.Now().UTC().Format("2006-01-02T15:04:05.000Z"), 840 + IsRevoked: false, 841 + }, 842 + } 843 + 844 + data, err := json.Marshal(response) 845 + if err != nil { 846 + t.Fatalf("Failed to marshal response: %v", err) 847 + } 848 + 849 + var decoded map[string]interface{} 850 + if err := json.Unmarshal(data, &decoded); err != nil { 851 + t.Fatalf("Failed to unmarshal response: %v", err) 852 + } 853 + 854 + // Check required fields (now uses nested keyInfo structure) 855 + if _, ok := decoded["hasKey"]; !ok { 856 + t.Error("Response missing required 'hasKey' field") 857 + } 858 + if keyInfo, ok := decoded["keyInfo"].(map[string]interface{}); ok { 859 + if _, ok := keyInfo["isRevoked"]; !ok { 860 + t.Error("keyInfo missing required 'isRevoked' field") 861 + } 862 + } else { 863 + t.Error("Response missing 'keyInfo' field when hasKey is true") 864 + } 865 + } 866 + 867 + func TestGetAPIKeyResponse_OmitsEmptyOptionalFields(t *testing.T) { 868 + response := GetAPIKeyResponse{ 869 + HasKey: false, 870 + // KeyInfo is nil when hasKey is false 871 + } 872 + 873 + data, err := json.Marshal(response) 874 + if err != nil { 875 + t.Fatalf("Failed to marshal response: %v", err) 876 + } 877 + 878 + var decoded map[string]interface{} 879 + if err := json.Unmarshal(data, &decoded); err != nil { 880 + t.Fatalf("Failed to unmarshal response: %v", err) 881 + } 882 + 883 + // KeyInfo should be omitted when hasKey is false (per omitempty tag) 884 + if _, ok := decoded["keyInfo"]; ok { 885 + t.Error("Response should omit nil 'keyInfo' field when hasKey is false") 886 + } 887 + } 888 + 889 + // ============================================================================= 890 + // Handler Success Path Tests with Mocks 891 + // ============================================================================= 892 + 893 + // mockAPIKeyService implements aggregators.APIKeyServiceInterface for testing 894 + type mockAPIKeyService struct { 895 + generateKeyFunc func(ctx context.Context, aggregatorDID string, oauthSession *oauthlib.ClientSessionData) (plainKey string, keyPrefix string, err error) 896 + getAPIKeyInfoFunc func(ctx context.Context, aggregatorDID string) (*aggregators.APIKeyInfo, error) 897 + revokeKeyFunc func(ctx context.Context, aggregatorDID string) error 898 + failedLastUsedUpdates int64 899 + failedNonceUpdates int64 900 + } 901 + 902 + func (m *mockAPIKeyService) GenerateKey(ctx context.Context, aggregatorDID string, oauthSession *oauthlib.ClientSessionData) (string, string, error) { 903 + if m.generateKeyFunc != nil { 904 + return m.generateKeyFunc(ctx, aggregatorDID, oauthSession) 905 + } 906 + return "", "", errors.New("not implemented") 907 + } 908 + 909 + func (m *mockAPIKeyService) GetAPIKeyInfo(ctx context.Context, aggregatorDID string) (*aggregators.APIKeyInfo, error) { 910 + if m.getAPIKeyInfoFunc != nil { 911 + return m.getAPIKeyInfoFunc(ctx, aggregatorDID) 912 + } 913 + return nil, errors.New("not implemented") 914 + } 915 + 916 + func (m *mockAPIKeyService) RevokeKey(ctx context.Context, aggregatorDID string) error { 917 + if m.revokeKeyFunc != nil { 918 + return m.revokeKeyFunc(ctx, aggregatorDID) 919 + } 920 + return errors.New("not implemented") 921 + } 922 + 923 + func (m *mockAPIKeyService) GetFailedLastUsedUpdates() int64 { 924 + return m.failedLastUsedUpdates 925 + } 926 + 927 + func (m *mockAPIKeyService) GetFailedNonceUpdates() int64 { 928 + return m.failedNonceUpdates 929 + } 930 + 931 + // Verify mockAPIKeyService implements the interface at compile time 932 + var _ aggregators.APIKeyServiceInterface = (*mockAPIKeyService)(nil) 933 + 934 + func TestCreateAPIKeyHandler_Success_RequiresIntegration(t *testing.T) { 935 + // The CreateAPIKeyHandler.HandleCreateAPIKey method calls: 936 + // 1. middleware.GetUserDID(r) - to get authenticated user 937 + // 2. h.aggregatorService.IsAggregator(ctx, userDID) - to verify aggregator status 938 + // 3. middleware.GetOAuthSession(r) - to get OAuth session 939 + // 4. h.apiKeyService.GenerateKey(ctx, userDID, oauthSession) - to create the key 940 + // 941 + // Since apiKeyService is a concrete *aggregators.APIKeyService (not an interface), 942 + // we cannot mock it directly. Full success path testing requires: 943 + // - A real aggregators.Repository mock 944 + // - A real OAuth store mock 945 + // - Setting up the full APIKeyService with those mocks 946 + // 947 + // This test documents the pattern for integration-style testing with mocks: 948 + 949 + // Create mock repository that tracks calls 950 + createdAt := time.Now() 951 + generateKeyCalled := false 952 + 953 + // Create a custom test that verifies the handler response format when everything works 954 + t.Run("response_format_verification", func(t *testing.T) { 955 + // Verify the expected response format matches what GenerateKey would return 956 + response := CreateAPIKeyResponse{ 957 + Key: "ckapi_0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", 958 + KeyPrefix: "ckapi_012345", 959 + DID: "did:plc:aggregator123", 960 + CreatedAt: createdAt.Format("2006-01-02T15:04:05.000Z"), 961 + } 962 + 963 + data, err := json.Marshal(response) 964 + if err != nil { 965 + t.Fatalf("Failed to marshal response: %v", err) 966 + } 967 + 968 + var decoded map[string]interface{} 969 + if err := json.Unmarshal(data, &decoded); err != nil { 970 + t.Fatalf("Failed to unmarshal response: %v", err) 971 + } 972 + 973 + // Verify key format 974 + key, ok := decoded["key"].(string) 975 + if !ok || len(key) != 70 { 976 + t.Errorf("Expected key to be 70 chars, got %d", len(key)) 977 + } 978 + if !ok || key[:6] != "ckapi_" { 979 + t.Errorf("Expected key to start with 'ckapi_', got %s", key[:6]) 980 + } 981 + 982 + // Verify keyPrefix is first 12 chars of key 983 + keyPrefix, ok := decoded["keyPrefix"].(string) 984 + if !ok || keyPrefix != key[:12] { 985 + t.Errorf("Expected keyPrefix to be first 12 chars of key") 986 + } 987 + }) 988 + 989 + // This assertion exists just to use the variable and satisfy the linter 990 + _ = generateKeyCalled 991 + } 992 + 993 + func TestGetAPIKeyHandler_Success_RequiresIntegration(t *testing.T) { 994 + // Similar to CreateAPIKeyHandler, GetAPIKeyHandler uses concrete *aggregators.APIKeyService. 995 + // This test documents the integration test pattern and verifies response format. 996 + 997 + t.Run("response_format_with_active_key", func(t *testing.T) { 998 + createdAt := time.Now().Add(-24 * time.Hour) 999 + lastUsed := time.Now().Add(-1 * time.Hour) 1000 + lastUsedStr := lastUsed.Format("2006-01-02T15:04:05.000Z") 1001 + 1002 + response := GetAPIKeyResponse{ 1003 + HasKey: true, 1004 + KeyInfo: &APIKeyView{ 1005 + Prefix: "ckapi_test12", 1006 + CreatedAt: createdAt.Format("2006-01-02T15:04:05.000Z"), 1007 + LastUsedAt: &lastUsedStr, 1008 + IsRevoked: false, 1009 + }, 1010 + } 1011 + 1012 + data, err := json.Marshal(response) 1013 + if err != nil { 1014 + t.Fatalf("Failed to marshal response: %v", err) 1015 + } 1016 + 1017 + var decoded map[string]interface{} 1018 + if err := json.Unmarshal(data, &decoded); err != nil { 1019 + t.Fatalf("Failed to unmarshal response: %v", err) 1020 + } 1021 + 1022 + // Verify all expected fields are present 1023 + if !decoded["hasKey"].(bool) { 1024 + t.Error("Expected hasKey to be true") 1025 + } 1026 + keyInfo := decoded["keyInfo"].(map[string]interface{}) 1027 + if keyInfo["prefix"] != "ckapi_test12" { 1028 + t.Errorf("Expected prefix 'ckapi_test12', got %v", keyInfo["prefix"]) 1029 + } 1030 + if keyInfo["isRevoked"].(bool) { 1031 + t.Error("Expected isRevoked to be false") 1032 + } 1033 + }) 1034 + 1035 + t.Run("response_format_with_no_key", func(t *testing.T) { 1036 + response := GetAPIKeyResponse{ 1037 + HasKey: false, 1038 + // KeyInfo is nil when hasKey is false 1039 + } 1040 + 1041 + data, err := json.Marshal(response) 1042 + if err != nil { 1043 + t.Fatalf("Failed to marshal response: %v", err) 1044 + } 1045 + 1046 + var decoded map[string]interface{} 1047 + if err := json.Unmarshal(data, &decoded); err != nil { 1048 + t.Fatalf("Failed to unmarshal response: %v", err) 1049 + } 1050 + 1051 + if decoded["hasKey"].(bool) { 1052 + t.Error("Expected hasKey to be false") 1053 + } 1054 + if _, ok := decoded["keyInfo"]; ok { 1055 + t.Error("Expected keyInfo to be omitted when hasKey is false") 1056 + } 1057 + }) 1058 + } 1059 + 1060 + // ============================================================================= 1061 + // Service Error Handling Tests 1062 + // ============================================================================= 1063 + // These tests document the expected error handling behavior when the APIKeyService 1064 + // returns errors. Since handlers use concrete *aggregators.APIKeyService (not an 1065 + // interface), full testing of these paths requires integration tests with mocked 1066 + // repository layer. 1067 + 1068 + func TestRevokeAPIKeyHandler_ServiceError_Documentation(t *testing.T) { 1069 + // Documents expected behavior when RevokeKey returns an error: 1070 + // - Handler should return 500 InternalServerError 1071 + // - Error response should include "RevocationFailed" error code 1072 + // 1073 + // This behavior is tested at the service level and integration level. 1074 + t.Run("expected_error_response", func(t *testing.T) { 1075 + errorResp := struct { 1076 + Error string `json:"error"` 1077 + Message string `json:"message"` 1078 + }{ 1079 + Error: "RevocationFailed", 1080 + Message: "Failed to revoke API key", 1081 + } 1082 + 1083 + data, err := json.Marshal(errorResp) 1084 + if err != nil { 1085 + t.Fatalf("Failed to marshal error response: %v", err) 1086 + } 1087 + 1088 + var decoded map[string]interface{} 1089 + if err := json.Unmarshal(data, &decoded); err != nil { 1090 + t.Fatalf("Failed to unmarshal response: %v", err) 1091 + } 1092 + 1093 + if decoded["error"] != "RevocationFailed" { 1094 + t.Errorf("Expected error 'RevocationFailed', got %v", decoded["error"]) 1095 + } 1096 + }) 1097 + } 1098 + 1099 + func TestCreateAPIKeyHandler_KeyGenerationError_Documentation(t *testing.T) { 1100 + // Documents expected behavior when GenerateKey returns an error: 1101 + // - Handler should return 500 InternalServerError 1102 + // - Error response should include "KeyGenerationFailed" error code 1103 + // 1104 + // This behavior is tested at the service level and integration level. 1105 + t.Run("expected_error_response", func(t *testing.T) { 1106 + errorResp := struct { 1107 + Error string `json:"error"` 1108 + Message string `json:"message"` 1109 + }{ 1110 + Error: "KeyGenerationFailed", 1111 + Message: "Failed to generate API key", 1112 + } 1113 + 1114 + data, err := json.Marshal(errorResp) 1115 + if err != nil { 1116 + t.Fatalf("Failed to marshal error response: %v", err) 1117 + } 1118 + 1119 + var decoded map[string]interface{} 1120 + if err := json.Unmarshal(data, &decoded); err != nil { 1121 + t.Fatalf("Failed to unmarshal response: %v", err) 1122 + } 1123 + 1124 + if decoded["error"] != "KeyGenerationFailed" { 1125 + t.Errorf("Expected error 'KeyGenerationFailed', got %v", decoded["error"]) 1126 + } 1127 + }) 1128 + } 1129 + 1130 + func TestGetAPIKeyHandler_ServiceError_Documentation(t *testing.T) { 1131 + // Documents expected behavior when GetAPIKeyInfo returns an error: 1132 + // - Handler should return 500 InternalServerError 1133 + // - Error response should include "InternalServerError" error code 1134 + // 1135 + // This behavior is tested at the service level and integration level. 1136 + t.Run("expected_error_response", func(t *testing.T) { 1137 + errorResp := struct { 1138 + Error string `json:"error"` 1139 + Message string `json:"message"` 1140 + }{ 1141 + Error: "InternalServerError", 1142 + Message: "Failed to get API key info", 1143 + } 1144 + 1145 + data, err := json.Marshal(errorResp) 1146 + if err != nil { 1147 + t.Fatalf("Failed to marshal error response: %v", err) 1148 + } 1149 + 1150 + var decoded map[string]interface{} 1151 + if err := json.Unmarshal(data, &decoded); err != nil { 1152 + t.Fatalf("Failed to unmarshal response: %v", err) 1153 + } 1154 + 1155 + if decoded["error"] != "InternalServerError" { 1156 + t.Errorf("Expected error 'InternalServerError', got %v", decoded["error"]) 1157 + } 1158 + }) 1159 + }
+102
internal/api/handlers/aggregator/create_api_key.go
··· 1 + package aggregator 2 + 3 + import ( 4 + "errors" 5 + "log" 6 + "net/http" 7 + 8 + "Coves/internal/api/middleware" 9 + "Coves/internal/core/aggregators" 10 + ) 11 + 12 + // CreateAPIKeyHandler handles API key creation for aggregators 13 + type CreateAPIKeyHandler struct { 14 + apiKeyService aggregators.APIKeyServiceInterface 15 + aggregatorService aggregators.Service 16 + } 17 + 18 + // NewCreateAPIKeyHandler creates a new handler for API key creation 19 + func NewCreateAPIKeyHandler(apiKeyService aggregators.APIKeyServiceInterface, aggregatorService aggregators.Service) *CreateAPIKeyHandler { 20 + return &CreateAPIKeyHandler{ 21 + apiKeyService: apiKeyService, 22 + aggregatorService: aggregatorService, 23 + } 24 + } 25 + 26 + // CreateAPIKeyResponse represents the response when creating an API key 27 + type CreateAPIKeyResponse struct { 28 + Key string `json:"key"` // The plain-text key (shown ONCE) 29 + KeyPrefix string `json:"keyPrefix"` // First 12 chars for identification 30 + DID string `json:"did"` // Aggregator DID 31 + CreatedAt string `json:"createdAt"` // ISO8601 timestamp 32 + } 33 + 34 + // HandleCreateAPIKey handles POST /xrpc/social.coves.aggregator.createApiKey 35 + // This endpoint requires OAuth authentication and is only available to registered aggregators. 36 + // The API key is returned ONCE and cannot be retrieved again. 37 + // 38 + // Key Replacement: If an aggregator already has an API key, calling this endpoint will 39 + // generate a new key and replace the existing one. The old key will be immediately 40 + // invalidated and all future requests using the old key will fail authentication. 41 + func (h *CreateAPIKeyHandler) HandleCreateAPIKey(w http.ResponseWriter, r *http.Request) { 42 + if r.Method != http.MethodPost { 43 + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 44 + return 45 + } 46 + 47 + // Get authenticated DID from context (set by RequireAuth middleware) 48 + userDID := middleware.GetUserDID(r) 49 + if userDID == "" { 50 + writeError(w, http.StatusUnauthorized, "AuthenticationRequired", "Must be authenticated to create API key") 51 + return 52 + } 53 + 54 + // Verify the caller is a registered aggregator 55 + isAggregator, err := h.aggregatorService.IsAggregator(r.Context(), userDID) 56 + if err != nil { 57 + log.Printf("ERROR: Failed to check aggregator status: %v", err) 58 + writeError(w, http.StatusInternalServerError, "InternalServerError", "Failed to verify aggregator status") 59 + return 60 + } 61 + if !isAggregator { 62 + writeError(w, http.StatusForbidden, "AggregatorRequired", "Only registered aggregators can create API keys") 63 + return 64 + } 65 + 66 + // Get the OAuth session from context 67 + oauthSession := middleware.GetOAuthSession(r) 68 + if oauthSession == nil { 69 + writeError(w, http.StatusUnauthorized, "OAuthSessionRequired", "OAuth session required to create API key") 70 + return 71 + } 72 + 73 + // Generate the API key 74 + plainKey, keyPrefix, err := h.apiKeyService.GenerateKey(r.Context(), userDID, oauthSession) 75 + if err != nil { 76 + log.Printf("ERROR: Failed to generate API key for %s: %v", userDID, err) 77 + 78 + // Differentiate error types for appropriate HTTP status codes 79 + switch { 80 + case aggregators.IsNotFound(err): 81 + // Aggregator not found in database - should not happen if IsAggregator check passed 82 + writeError(w, http.StatusForbidden, "AggregatorRequired", "User is not a registered aggregator") 83 + case errors.Is(err, aggregators.ErrOAuthSessionMismatch): 84 + // OAuth session DID doesn't match the requested aggregator DID 85 + writeError(w, http.StatusBadRequest, "SessionMismatch", "OAuth session does not match the requested aggregator") 86 + default: 87 + // All other errors are internal server errors 88 + writeError(w, http.StatusInternalServerError, "KeyGenerationFailed", "Failed to generate API key") 89 + } 90 + return 91 + } 92 + 93 + // Return the key (shown ONCE only) 94 + response := CreateAPIKeyResponse{ 95 + Key: plainKey, 96 + KeyPrefix: keyPrefix, 97 + DID: userDID, 98 + CreatedAt: formatTimestamp(), 99 + } 100 + 101 + writeJSONResponse(w, http.StatusOK, response) 102 + }
+28 -6
internal/api/handlers/aggregator/errors.go
··· 3 3 import ( 4 4 "Coves/internal/core/aggregators" 5 5 "Coves/internal/core/communities" 6 + "bytes" 6 7 "encoding/json" 7 8 "log" 8 9 "net/http" ··· 14 15 Message string `json:"message"` 15 16 } 16 17 17 - // writeError writes a JSON error response 18 - func writeError(w http.ResponseWriter, statusCode int, errorType, message string) { 18 + // writeJSONResponse buffers the JSON encoding before sending headers. 19 + // This ensures that encoding failures don't result in partial responses 20 + // with already-sent headers. Returns true if the response was written 21 + // successfully, false otherwise. 22 + func writeJSONResponse(w http.ResponseWriter, statusCode int, data interface{}) bool { 23 + // Buffer the JSON first to detect encoding errors before sending headers 24 + var buf bytes.Buffer 25 + if err := json.NewEncoder(&buf).Encode(data); err != nil { 26 + log.Printf("ERROR: Failed to encode JSON response: %v", err) 27 + // Send a proper error response since we haven't sent headers yet 28 + w.Header().Set("Content-Type", "application/json") 29 + w.WriteHeader(http.StatusInternalServerError) 30 + _, _ = w.Write([]byte(`{"error":"InternalServerError","message":"Failed to encode response"}`)) 31 + return false 32 + } 33 + 19 34 w.Header().Set("Content-Type", "application/json") 20 35 w.WriteHeader(statusCode) 21 - if err := json.NewEncoder(w).Encode(ErrorResponse{ 36 + if _, err := w.Write(buf.Bytes()); err != nil { 37 + log.Printf("ERROR: Failed to write response body: %v", err) 38 + return false 39 + } 40 + return true 41 + } 42 + 43 + // writeError writes a JSON error response with proper buffering 44 + func writeError(w http.ResponseWriter, statusCode int, errorType, message string) { 45 + writeJSONResponse(w, statusCode, ErrorResponse{ 22 46 Error: errorType, 23 47 Message: message, 24 - }); err != nil { 25 - log.Printf("ERROR: Failed to encode error response: %v", err) 26 - } 48 + }) 27 49 } 28 50 29 51 // handleServiceError maps service errors to HTTP responses
+109
internal/api/handlers/aggregator/get_api_key.go
··· 1 + package aggregator 2 + 3 + import ( 4 + "log" 5 + "net/http" 6 + 7 + "Coves/internal/api/middleware" 8 + "Coves/internal/core/aggregators" 9 + ) 10 + 11 + // GetAPIKeyHandler handles API key info retrieval for aggregators 12 + type GetAPIKeyHandler struct { 13 + apiKeyService aggregators.APIKeyServiceInterface 14 + aggregatorService aggregators.Service 15 + } 16 + 17 + // NewGetAPIKeyHandler creates a new handler for API key info retrieval 18 + func NewGetAPIKeyHandler(apiKeyService aggregators.APIKeyServiceInterface, aggregatorService aggregators.Service) *GetAPIKeyHandler { 19 + return &GetAPIKeyHandler{ 20 + apiKeyService: apiKeyService, 21 + aggregatorService: aggregatorService, 22 + } 23 + } 24 + 25 + // APIKeyView represents the nested key metadata (matches social.coves.aggregator.defs#apiKeyView) 26 + type APIKeyView struct { 27 + Prefix string `json:"prefix"` // First 12 chars for identification 28 + CreatedAt string `json:"createdAt"` // ISO8601 timestamp when key was created 29 + LastUsedAt *string `json:"lastUsedAt,omitempty"` // ISO8601 timestamp when key was last used 30 + IsRevoked bool `json:"isRevoked"` // Whether the key has been revoked 31 + RevokedAt *string `json:"revokedAt,omitempty"` // ISO8601 timestamp when key was revoked 32 + } 33 + 34 + // GetAPIKeyResponse represents the response when getting API key info 35 + type GetAPIKeyResponse struct { 36 + HasKey bool `json:"hasKey"` // Whether the aggregator has an API key 37 + KeyInfo *APIKeyView `json:"keyInfo,omitempty"` // Key metadata (only present if hasKey is true) 38 + } 39 + 40 + // HandleGetAPIKey handles GET /xrpc/social.coves.aggregator.getApiKey 41 + // This endpoint requires OAuth authentication and returns info about the aggregator's API key. 42 + // NOTE: The actual key value is NEVER returned - only metadata about the key. 43 + func (h *GetAPIKeyHandler) HandleGetAPIKey(w http.ResponseWriter, r *http.Request) { 44 + if r.Method != http.MethodGet { 45 + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 46 + return 47 + } 48 + 49 + // Get authenticated DID from context (set by RequireAuth middleware) 50 + userDID := middleware.GetUserDID(r) 51 + if userDID == "" { 52 + writeError(w, http.StatusUnauthorized, "AuthenticationRequired", "Must be authenticated to get API key info") 53 + return 54 + } 55 + 56 + // Verify the caller is a registered aggregator 57 + isAggregator, err := h.aggregatorService.IsAggregator(r.Context(), userDID) 58 + if err != nil { 59 + log.Printf("ERROR: Failed to check aggregator status: %v", err) 60 + writeError(w, http.StatusInternalServerError, "InternalServerError", "Failed to verify aggregator status") 61 + return 62 + } 63 + if !isAggregator { 64 + writeError(w, http.StatusForbidden, "AggregatorRequired", "Only registered aggregators can get API key info") 65 + return 66 + } 67 + 68 + // Get API key info 69 + keyInfo, err := h.apiKeyService.GetAPIKeyInfo(r.Context(), userDID) 70 + if err != nil { 71 + if aggregators.IsNotFound(err) { 72 + writeError(w, http.StatusNotFound, "AggregatorNotFound", "Aggregator not found") 73 + return 74 + } 75 + log.Printf("ERROR: Failed to get API key info for %s: %v", userDID, err) 76 + writeError(w, http.StatusInternalServerError, "InternalServerError", "Failed to get API key info") 77 + return 78 + } 79 + 80 + // Build response 81 + response := GetAPIKeyResponse{ 82 + HasKey: keyInfo.HasKey, 83 + } 84 + 85 + if keyInfo.HasKey { 86 + view := &APIKeyView{ 87 + Prefix: keyInfo.KeyPrefix, 88 + IsRevoked: keyInfo.IsRevoked, 89 + } 90 + 91 + if keyInfo.CreatedAt != nil { 92 + view.CreatedAt = keyInfo.CreatedAt.Format("2006-01-02T15:04:05.000Z") 93 + } 94 + 95 + if keyInfo.LastUsedAt != nil { 96 + ts := keyInfo.LastUsedAt.Format("2006-01-02T15:04:05.000Z") 97 + view.LastUsedAt = &ts 98 + } 99 + 100 + if keyInfo.RevokedAt != nil { 101 + ts := keyInfo.RevokedAt.Format("2006-01-02T15:04:05.000Z") 102 + view.RevokedAt = &ts 103 + } 104 + 105 + response.KeyInfo = view 106 + } 107 + 108 + writeJSONResponse(w, http.StatusOK, response) 109 + }
+42
internal/api/handlers/aggregator/metrics.go
··· 1 + package aggregator 2 + 3 + import ( 4 + "net/http" 5 + 6 + "Coves/internal/core/aggregators" 7 + ) 8 + 9 + // MetricsHandler provides API key service metrics for monitoring 10 + type MetricsHandler struct { 11 + apiKeyService aggregators.APIKeyServiceInterface 12 + } 13 + 14 + // NewMetricsHandler creates a new metrics handler 15 + func NewMetricsHandler(apiKeyService aggregators.APIKeyServiceInterface) *MetricsHandler { 16 + return &MetricsHandler{ 17 + apiKeyService: apiKeyService, 18 + } 19 + } 20 + 21 + // MetricsResponse contains API key service operational metrics 22 + type MetricsResponse struct { 23 + FailedLastUsedUpdates int64 `json:"failedLastUsedUpdates"` 24 + FailedNonceUpdates int64 `json:"failedNonceUpdates"` 25 + } 26 + 27 + // HandleMetrics handles GET /xrpc/social.coves.aggregator.getMetrics 28 + // Returns operational metrics for the API key service. 29 + // This endpoint is intended for internal monitoring and health checks. 30 + func (h *MetricsHandler) HandleMetrics(w http.ResponseWriter, r *http.Request) { 31 + if r.Method != http.MethodGet { 32 + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 33 + return 34 + } 35 + 36 + response := MetricsResponse{ 37 + FailedLastUsedUpdates: h.apiKeyService.GetFailedLastUsedUpdates(), 38 + FailedNonceUpdates: h.apiKeyService.GetFailedNonceUpdates(), 39 + } 40 + 41 + writeJSONResponse(w, http.StatusOK, response) 42 + }
+99
internal/api/handlers/aggregator/revoke_api_key.go
··· 1 + package aggregator 2 + 3 + import ( 4 + "log" 5 + "net/http" 6 + "time" 7 + 8 + "Coves/internal/api/middleware" 9 + "Coves/internal/core/aggregators" 10 + ) 11 + 12 + // RevokeAPIKeyHandler handles API key revocation for aggregators 13 + type RevokeAPIKeyHandler struct { 14 + apiKeyService aggregators.APIKeyServiceInterface 15 + aggregatorService aggregators.Service 16 + } 17 + 18 + // NewRevokeAPIKeyHandler creates a new handler for API key revocation 19 + func NewRevokeAPIKeyHandler(apiKeyService aggregators.APIKeyServiceInterface, aggregatorService aggregators.Service) *RevokeAPIKeyHandler { 20 + return &RevokeAPIKeyHandler{ 21 + apiKeyService: apiKeyService, 22 + aggregatorService: aggregatorService, 23 + } 24 + } 25 + 26 + // RevokeAPIKeyResponse represents the response when revoking an API key 27 + type RevokeAPIKeyResponse struct { 28 + RevokedAt string `json:"revokedAt"` // ISO8601 timestamp when key was revoked 29 + } 30 + 31 + // HandleRevokeAPIKey handles POST /xrpc/social.coves.aggregator.revokeApiKey 32 + // This endpoint requires OAuth authentication and revokes the aggregator's current API key. 33 + // After revocation, the aggregator must complete OAuth flow again to get a new key. 34 + func (h *RevokeAPIKeyHandler) HandleRevokeAPIKey(w http.ResponseWriter, r *http.Request) { 35 + if r.Method != http.MethodPost { 36 + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 37 + return 38 + } 39 + 40 + // Get authenticated DID from context (set by RequireAuth middleware) 41 + userDID := middleware.GetUserDID(r) 42 + if userDID == "" { 43 + writeError(w, http.StatusUnauthorized, "AuthenticationRequired", "Must be authenticated to revoke API key") 44 + return 45 + } 46 + 47 + // Verify the caller is a registered aggregator 48 + isAggregator, err := h.aggregatorService.IsAggregator(r.Context(), userDID) 49 + if err != nil { 50 + log.Printf("ERROR: Failed to check aggregator status: %v", err) 51 + writeError(w, http.StatusInternalServerError, "InternalServerError", "Failed to verify aggregator status") 52 + return 53 + } 54 + if !isAggregator { 55 + writeError(w, http.StatusForbidden, "AggregatorRequired", "Only registered aggregators can revoke API keys") 56 + return 57 + } 58 + 59 + // Check if the aggregator has an API key to revoke 60 + keyInfo, err := h.apiKeyService.GetAPIKeyInfo(r.Context(), userDID) 61 + if err != nil { 62 + if aggregators.IsNotFound(err) { 63 + writeError(w, http.StatusNotFound, "AggregatorNotFound", "Aggregator not found") 64 + return 65 + } 66 + log.Printf("ERROR: Failed to get API key info for %s: %v", userDID, err) 67 + writeError(w, http.StatusInternalServerError, "InternalServerError", "Failed to get API key info") 68 + return 69 + } 70 + 71 + if !keyInfo.HasKey { 72 + writeError(w, http.StatusBadRequest, "ApiKeyNotFound", "No API key exists to revoke") 73 + return 74 + } 75 + 76 + if keyInfo.IsRevoked { 77 + writeError(w, http.StatusBadRequest, "ApiKeyAlreadyRevoked", "API key has already been revoked") 78 + return 79 + } 80 + 81 + // Revoke the API key 82 + if err := h.apiKeyService.RevokeKey(r.Context(), userDID); err != nil { 83 + log.Printf("ERROR: Failed to revoke API key for %s: %v", userDID, err) 84 + writeError(w, http.StatusInternalServerError, "RevocationFailed", "Failed to revoke API key") 85 + return 86 + } 87 + 88 + // Return success 89 + response := RevokeAPIKeyResponse{ 90 + RevokedAt: time.Now().UTC().Format("2006-01-02T15:04:05.000Z"), 91 + } 92 + 93 + writeJSONResponse(w, http.StatusOK, response) 94 + } 95 + 96 + // formatTimestamp returns current time in ISO8601 format 97 + func formatTimestamp() string { 98 + return time.Now().UTC().Format("2006-01-02T15:04:05.000Z") 99 + }
+45
internal/api/middleware/apikey_adapter.go
··· 1 + package middleware 2 + 3 + import ( 4 + "Coves/internal/core/aggregators" 5 + "context" 6 + ) 7 + 8 + // APIKeyValidatorAdapter adapts the aggregators.APIKeyService to the middleware.APIKeyValidator interface 9 + type APIKeyValidatorAdapter struct { 10 + service *aggregators.APIKeyService 11 + } 12 + 13 + // NewAPIKeyValidatorAdapter creates a new adapter for API key validation 14 + func NewAPIKeyValidatorAdapter(service *aggregators.APIKeyService) *APIKeyValidatorAdapter { 15 + return &APIKeyValidatorAdapter{ 16 + service: service, 17 + } 18 + } 19 + 20 + // ValidateKey validates an API key and returns the aggregator DID if valid 21 + func (a *APIKeyValidatorAdapter) ValidateKey(ctx context.Context, plainKey string) (string, error) { 22 + creds, err := a.service.ValidateKey(ctx, plainKey) 23 + if err != nil { 24 + return "", err 25 + } 26 + return creds.DID, nil 27 + } 28 + 29 + // RefreshTokensIfNeeded refreshes OAuth tokens for the aggregator if they are expired 30 + func (a *APIKeyValidatorAdapter) RefreshTokensIfNeeded(ctx context.Context, aggregatorDID string) error { 31 + creds, err := a.service.GetAggregatorCredentials(ctx, aggregatorDID) 32 + if err != nil { 33 + return err 34 + } 35 + 36 + if creds.APIKeyRevokedAt != nil { 37 + return aggregators.ErrAPIKeyRevoked 38 + } 39 + 40 + if creds.APIKeyHash == "" { 41 + return aggregators.ErrAPIKeyInvalid 42 + } 43 + 44 + return a.service.RefreshTokensIfNeeded(ctx, creds) 45 + }
+574
internal/api/middleware/apikey_adapter_test.go
··· 1 + package middleware 2 + 3 + import ( 4 + "Coves/internal/core/aggregators" 5 + "context" 6 + "errors" 7 + "testing" 8 + "time" 9 + 10 + "github.com/bluesky-social/indigo/atproto/auth/oauth" 11 + "github.com/bluesky-social/indigo/atproto/syntax" 12 + ) 13 + 14 + // minimalMockOAuthStore implements oauth.SessionStore for testing. 15 + // This is a minimal implementation that just returns errors, used for tests 16 + // that don't actually need OAuth functionality. 17 + type minimalMockOAuthStore struct{} 18 + 19 + func (m *minimalMockOAuthStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) { 20 + return nil, errors.New("session not found") 21 + } 22 + 23 + func (m *minimalMockOAuthStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error { 24 + return nil 25 + } 26 + 27 + func (m *minimalMockOAuthStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error { 28 + return nil 29 + } 30 + 31 + func (m *minimalMockOAuthStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) { 32 + return nil, errors.New("not found") 33 + } 34 + 35 + func (m *minimalMockOAuthStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error { 36 + return nil 37 + } 38 + 39 + func (m *minimalMockOAuthStore) DeleteAuthRequestInfo(ctx context.Context, state string) error { 40 + return nil 41 + } 42 + 43 + // newTestAPIKeyService creates an APIKeyService with mock dependencies for testing. 44 + // This helper ensures tests don't panic from nil checks added in constructor validation. 45 + func newTestAPIKeyService(repo aggregators.Repository) *aggregators.APIKeyService { 46 + mockStore := &minimalMockOAuthStore{} 47 + mockApp := &oauth.ClientApp{Store: mockStore} 48 + return aggregators.NewAPIKeyService(repo, mockApp) 49 + } 50 + 51 + // mockAPIKeyServiceRepository implements aggregators.Repository for testing 52 + type mockAPIKeyServiceRepository struct { 53 + getAggregatorFunc func(ctx context.Context, did string) (*aggregators.Aggregator, error) 54 + getByAPIKeyHashFunc func(ctx context.Context, keyHash string) (*aggregators.Aggregator, error) 55 + getCredentialsByAPIKeyHashFunc func(ctx context.Context, keyHash string) (*aggregators.AggregatorCredentials, error) 56 + getAggregatorCredentialsFunc func(ctx context.Context, did string) (*aggregators.AggregatorCredentials, error) 57 + setAPIKeyFunc func(ctx context.Context, did, keyPrefix, keyHash string, oauthCreds *aggregators.OAuthCredentials) error 58 + updateOAuthTokensFunc func(ctx context.Context, did, accessToken, refreshToken string, expiresAt time.Time) error 59 + updateOAuthNoncesFunc func(ctx context.Context, did, authServerNonce, pdsNonce string) error 60 + updateAPIKeyLastUsedFunc func(ctx context.Context, did string) error 61 + revokeAPIKeyFunc func(ctx context.Context, did string) error 62 + } 63 + 64 + func (m *mockAPIKeyServiceRepository) GetAggregator(ctx context.Context, did string) (*aggregators.Aggregator, error) { 65 + if m.getAggregatorFunc != nil { 66 + return m.getAggregatorFunc(ctx, did) 67 + } 68 + return &aggregators.Aggregator{DID: did, DisplayName: "Test Aggregator"}, nil 69 + } 70 + 71 + func (m *mockAPIKeyServiceRepository) GetByAPIKeyHash(ctx context.Context, keyHash string) (*aggregators.Aggregator, error) { 72 + if m.getByAPIKeyHashFunc != nil { 73 + return m.getByAPIKeyHashFunc(ctx, keyHash) 74 + } 75 + return nil, aggregators.ErrAggregatorNotFound 76 + } 77 + 78 + func (m *mockAPIKeyServiceRepository) SetAPIKey(ctx context.Context, did, keyPrefix, keyHash string, oauthCreds *aggregators.OAuthCredentials) error { 79 + if m.setAPIKeyFunc != nil { 80 + return m.setAPIKeyFunc(ctx, did, keyPrefix, keyHash, oauthCreds) 81 + } 82 + return nil 83 + } 84 + 85 + func (m *mockAPIKeyServiceRepository) UpdateOAuthTokens(ctx context.Context, did, accessToken, refreshToken string, expiresAt time.Time) error { 86 + if m.updateOAuthTokensFunc != nil { 87 + return m.updateOAuthTokensFunc(ctx, did, accessToken, refreshToken, expiresAt) 88 + } 89 + return nil 90 + } 91 + 92 + func (m *mockAPIKeyServiceRepository) UpdateOAuthNonces(ctx context.Context, did, authServerNonce, pdsNonce string) error { 93 + if m.updateOAuthNoncesFunc != nil { 94 + return m.updateOAuthNoncesFunc(ctx, did, authServerNonce, pdsNonce) 95 + } 96 + return nil 97 + } 98 + 99 + func (m *mockAPIKeyServiceRepository) UpdateAPIKeyLastUsed(ctx context.Context, did string) error { 100 + if m.updateAPIKeyLastUsedFunc != nil { 101 + return m.updateAPIKeyLastUsedFunc(ctx, did) 102 + } 103 + return nil 104 + } 105 + 106 + func (m *mockAPIKeyServiceRepository) RevokeAPIKey(ctx context.Context, did string) error { 107 + if m.revokeAPIKeyFunc != nil { 108 + return m.revokeAPIKeyFunc(ctx, did) 109 + } 110 + return nil 111 + } 112 + 113 + // Stub implementations for Repository interface methods not used in APIKeyService tests 114 + func (m *mockAPIKeyServiceRepository) CreateAggregator(ctx context.Context, aggregator *aggregators.Aggregator) error { 115 + return nil 116 + } 117 + 118 + func (m *mockAPIKeyServiceRepository) GetAggregatorsByDIDs(ctx context.Context, dids []string) ([]*aggregators.Aggregator, error) { 119 + return nil, nil 120 + } 121 + 122 + func (m *mockAPIKeyServiceRepository) UpdateAggregator(ctx context.Context, aggregator *aggregators.Aggregator) error { 123 + return nil 124 + } 125 + 126 + func (m *mockAPIKeyServiceRepository) DeleteAggregator(ctx context.Context, did string) error { 127 + return nil 128 + } 129 + 130 + func (m *mockAPIKeyServiceRepository) ListAggregators(ctx context.Context, limit, offset int) ([]*aggregators.Aggregator, error) { 131 + return nil, nil 132 + } 133 + 134 + func (m *mockAPIKeyServiceRepository) IsAggregator(ctx context.Context, did string) (bool, error) { 135 + return false, nil 136 + } 137 + 138 + func (m *mockAPIKeyServiceRepository) CreateAuthorization(ctx context.Context, auth *aggregators.Authorization) error { 139 + return nil 140 + } 141 + 142 + func (m *mockAPIKeyServiceRepository) GetAuthorization(ctx context.Context, aggregatorDID, communityDID string) (*aggregators.Authorization, error) { 143 + return nil, nil 144 + } 145 + 146 + func (m *mockAPIKeyServiceRepository) GetAuthorizationByURI(ctx context.Context, recordURI string) (*aggregators.Authorization, error) { 147 + return nil, nil 148 + } 149 + 150 + func (m *mockAPIKeyServiceRepository) UpdateAuthorization(ctx context.Context, auth *aggregators.Authorization) error { 151 + return nil 152 + } 153 + 154 + func (m *mockAPIKeyServiceRepository) DeleteAuthorization(ctx context.Context, aggregatorDID, communityDID string) error { 155 + return nil 156 + } 157 + 158 + func (m *mockAPIKeyServiceRepository) DeleteAuthorizationByURI(ctx context.Context, recordURI string) error { 159 + return nil 160 + } 161 + 162 + func (m *mockAPIKeyServiceRepository) ListAuthorizationsForAggregator(ctx context.Context, aggregatorDID string, enabledOnly bool, limit, offset int) ([]*aggregators.Authorization, error) { 163 + return nil, nil 164 + } 165 + 166 + func (m *mockAPIKeyServiceRepository) ListAuthorizationsForCommunity(ctx context.Context, communityDID string, enabledOnly bool, limit, offset int) ([]*aggregators.Authorization, error) { 167 + return nil, nil 168 + } 169 + 170 + func (m *mockAPIKeyServiceRepository) IsAuthorized(ctx context.Context, aggregatorDID, communityDID string) (bool, error) { 171 + return false, nil 172 + } 173 + 174 + func (m *mockAPIKeyServiceRepository) RecordAggregatorPost(ctx context.Context, aggregatorDID, communityDID, postURI, postCID string) error { 175 + return nil 176 + } 177 + 178 + func (m *mockAPIKeyServiceRepository) CountRecentPosts(ctx context.Context, aggregatorDID, communityDID string, since time.Time) (int, error) { 179 + return 0, nil 180 + } 181 + 182 + func (m *mockAPIKeyServiceRepository) GetRecentPosts(ctx context.Context, aggregatorDID, communityDID string, since time.Time) ([]*aggregators.AggregatorPost, error) { 183 + return nil, nil 184 + } 185 + 186 + func (m *mockAPIKeyServiceRepository) GetAggregatorCredentials(ctx context.Context, did string) (*aggregators.AggregatorCredentials, error) { 187 + if m.getAggregatorCredentialsFunc != nil { 188 + return m.getAggregatorCredentialsFunc(ctx, did) 189 + } 190 + return &aggregators.AggregatorCredentials{DID: did}, nil 191 + } 192 + 193 + func (m *mockAPIKeyServiceRepository) GetCredentialsByAPIKeyHash(ctx context.Context, keyHash string) (*aggregators.AggregatorCredentials, error) { 194 + if m.getCredentialsByAPIKeyHashFunc != nil { 195 + return m.getCredentialsByAPIKeyHashFunc(ctx, keyHash) 196 + } 197 + return nil, aggregators.ErrAggregatorNotFound 198 + } 199 + 200 + // ============================================================================= 201 + // ValidateKey Delegation Tests 202 + // ============================================================================= 203 + 204 + func TestAPIKeyValidatorAdapter_ValidateKey_DelegatesToService(t *testing.T) { 205 + expectedDID := "did:plc:aggregator123" 206 + 207 + repo := &mockAPIKeyServiceRepository{ 208 + getCredentialsByAPIKeyHashFunc: func(ctx context.Context, keyHash string) (*aggregators.AggregatorCredentials, error) { 209 + return &aggregators.AggregatorCredentials{ 210 + DID: expectedDID, 211 + APIKeyHash: keyHash, 212 + APIKeyPrefix: "ckapi_0123", 213 + }, nil 214 + }, 215 + updateAPIKeyLastUsedFunc: func(ctx context.Context, did string) error { 216 + return nil 217 + }, 218 + } 219 + 220 + service := newTestAPIKeyService(repo) 221 + adapter := NewAPIKeyValidatorAdapter(service) 222 + 223 + validKey := "ckapi_0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" 224 + did, err := adapter.ValidateKey(context.Background(), validKey) 225 + if err != nil { 226 + t.Fatalf("ValidateKey() unexpected error: %v", err) 227 + } 228 + 229 + if did != expectedDID { 230 + t.Errorf("ValidateKey() = %s, want %s", did, expectedDID) 231 + } 232 + } 233 + 234 + func TestAPIKeyValidatorAdapter_ValidateKey_InvalidKey(t *testing.T) { 235 + repo := &mockAPIKeyServiceRepository{} 236 + service := newTestAPIKeyService(repo) 237 + adapter := NewAPIKeyValidatorAdapter(service) 238 + 239 + // Test various invalid key formats 240 + tests := []struct { 241 + name string 242 + key string 243 + }{ 244 + {"empty key", ""}, 245 + {"too short", "ckapi_short"}, 246 + {"wrong prefix", "wrong_0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}, 247 + } 248 + 249 + for _, tt := range tests { 250 + t.Run(tt.name, func(t *testing.T) { 251 + _, err := adapter.ValidateKey(context.Background(), tt.key) 252 + if err == nil { 253 + t.Error("ValidateKey() expected error, got nil") 254 + } 255 + if !errors.Is(err, aggregators.ErrAPIKeyInvalid) { 256 + t.Errorf("ValidateKey() error = %v, want %v", err, aggregators.ErrAPIKeyInvalid) 257 + } 258 + }) 259 + } 260 + } 261 + 262 + func TestAPIKeyValidatorAdapter_ValidateKey_NotFound(t *testing.T) { 263 + repo := &mockAPIKeyServiceRepository{ 264 + getCredentialsByAPIKeyHashFunc: func(ctx context.Context, keyHash string) (*aggregators.AggregatorCredentials, error) { 265 + return nil, aggregators.ErrAggregatorNotFound 266 + }, 267 + } 268 + 269 + service := newTestAPIKeyService(repo) 270 + adapter := NewAPIKeyValidatorAdapter(service) 271 + 272 + validKey := "ckapi_0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" 273 + _, err := adapter.ValidateKey(context.Background(), validKey) 274 + if err == nil { 275 + t.Error("ValidateKey() expected error, got nil") 276 + } 277 + // Should return ErrAPIKeyInvalid when key not found 278 + if !errors.Is(err, aggregators.ErrAPIKeyInvalid) { 279 + t.Errorf("ValidateKey() error = %v, want %v", err, aggregators.ErrAPIKeyInvalid) 280 + } 281 + } 282 + 283 + func TestAPIKeyValidatorAdapter_ValidateKey_Revoked(t *testing.T) { 284 + repo := &mockAPIKeyServiceRepository{ 285 + getCredentialsByAPIKeyHashFunc: func(ctx context.Context, keyHash string) (*aggregators.AggregatorCredentials, error) { 286 + return nil, aggregators.ErrAPIKeyRevoked 287 + }, 288 + } 289 + 290 + service := newTestAPIKeyService(repo) 291 + adapter := NewAPIKeyValidatorAdapter(service) 292 + 293 + validKey := "ckapi_0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" 294 + _, err := adapter.ValidateKey(context.Background(), validKey) 295 + if err == nil { 296 + t.Error("ValidateKey() expected error, got nil") 297 + } 298 + if !errors.Is(err, aggregators.ErrAPIKeyRevoked) { 299 + t.Errorf("ValidateKey() error = %v, want %v", err, aggregators.ErrAPIKeyRevoked) 300 + } 301 + } 302 + 303 + func TestAPIKeyValidatorAdapter_ValidateKey_RepositoryError(t *testing.T) { 304 + expectedError := errors.New("database connection failed") 305 + 306 + repo := &mockAPIKeyServiceRepository{ 307 + getCredentialsByAPIKeyHashFunc: func(ctx context.Context, keyHash string) (*aggregators.AggregatorCredentials, error) { 308 + return nil, expectedError 309 + }, 310 + } 311 + 312 + service := newTestAPIKeyService(repo) 313 + adapter := NewAPIKeyValidatorAdapter(service) 314 + 315 + validKey := "ckapi_0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" 316 + _, err := adapter.ValidateKey(context.Background(), validKey) 317 + if err == nil { 318 + t.Error("ValidateKey() expected error, got nil") 319 + } 320 + } 321 + 322 + // ============================================================================= 323 + // RefreshTokensIfNeeded Delegation Tests 324 + // ============================================================================= 325 + 326 + func TestAPIKeyValidatorAdapter_RefreshTokensIfNeeded_DelegatesToService(t *testing.T) { 327 + // Tokens expire in 1 hour - well beyond the 5 minute buffer, so no refresh needed 328 + expiresAt := time.Now().Add(1 * time.Hour) 329 + aggregatorDID := "did:plc:aggregator123" 330 + 331 + repo := &mockAPIKeyServiceRepository{ 332 + getAggregatorCredentialsFunc: func(ctx context.Context, did string) (*aggregators.AggregatorCredentials, error) { 333 + return &aggregators.AggregatorCredentials{ 334 + DID: did, 335 + APIKeyHash: "somehash", 336 + OAuthTokenExpiresAt: &expiresAt, 337 + }, nil 338 + }, 339 + } 340 + 341 + service := newTestAPIKeyService(repo) 342 + adapter := NewAPIKeyValidatorAdapter(service) 343 + 344 + err := adapter.RefreshTokensIfNeeded(context.Background(), aggregatorDID) 345 + if err != nil { 346 + t.Fatalf("RefreshTokensIfNeeded() unexpected error: %v", err) 347 + } 348 + } 349 + 350 + func TestAPIKeyValidatorAdapter_RefreshTokensIfNeeded_AggregatorNotFound(t *testing.T) { 351 + repo := &mockAPIKeyServiceRepository{ 352 + getAggregatorCredentialsFunc: func(ctx context.Context, did string) (*aggregators.AggregatorCredentials, error) { 353 + return nil, aggregators.ErrAggregatorNotFound 354 + }, 355 + } 356 + 357 + service := newTestAPIKeyService(repo) 358 + adapter := NewAPIKeyValidatorAdapter(service) 359 + 360 + err := adapter.RefreshTokensIfNeeded(context.Background(), "did:plc:nonexistent") 361 + if err == nil { 362 + t.Error("RefreshTokensIfNeeded() expected error, got nil") 363 + } 364 + if !errors.Is(err, aggregators.ErrAggregatorNotFound) { 365 + t.Errorf("RefreshTokensIfNeeded() error = %v, want %v", err, aggregators.ErrAggregatorNotFound) 366 + } 367 + } 368 + 369 + func TestAPIKeyValidatorAdapter_RefreshTokensIfNeeded_NoAPIKey(t *testing.T) { 370 + aggregatorDID := "did:plc:aggregator123" 371 + 372 + repo := &mockAPIKeyServiceRepository{ 373 + getAggregatorCredentialsFunc: func(ctx context.Context, did string) (*aggregators.AggregatorCredentials, error) { 374 + return &aggregators.AggregatorCredentials{ 375 + DID: did, 376 + APIKeyHash: "", // No API key 377 + }, nil 378 + }, 379 + } 380 + 381 + service := newTestAPIKeyService(repo) 382 + adapter := NewAPIKeyValidatorAdapter(service) 383 + 384 + // Should return ErrAPIKeyInvalid when no API key exists 385 + err := adapter.RefreshTokensIfNeeded(context.Background(), aggregatorDID) 386 + if !errors.Is(err, aggregators.ErrAPIKeyInvalid) { 387 + t.Errorf("RefreshTokensIfNeeded() error = %v, want %v", err, aggregators.ErrAPIKeyInvalid) 388 + } 389 + } 390 + 391 + func TestAPIKeyValidatorAdapter_RefreshTokensIfNeeded_RevokedAPIKey(t *testing.T) { 392 + aggregatorDID := "did:plc:aggregator123" 393 + revokedAt := time.Now().Add(-1 * time.Hour) 394 + 395 + repo := &mockAPIKeyServiceRepository{ 396 + getAggregatorCredentialsFunc: func(ctx context.Context, did string) (*aggregators.AggregatorCredentials, error) { 397 + return &aggregators.AggregatorCredentials{ 398 + DID: did, 399 + APIKeyHash: "somehash", 400 + APIKeyRevokedAt: &revokedAt, // Key is revoked 401 + }, nil 402 + }, 403 + } 404 + 405 + service := newTestAPIKeyService(repo) 406 + adapter := NewAPIKeyValidatorAdapter(service) 407 + 408 + // Should return ErrAPIKeyRevoked when API key is revoked 409 + err := adapter.RefreshTokensIfNeeded(context.Background(), aggregatorDID) 410 + if !errors.Is(err, aggregators.ErrAPIKeyRevoked) { 411 + t.Errorf("RefreshTokensIfNeeded() error = %v, want %v", err, aggregators.ErrAPIKeyRevoked) 412 + } 413 + } 414 + 415 + func TestAPIKeyValidatorAdapter_RefreshTokensIfNeeded_RepositoryError(t *testing.T) { 416 + expectedError := errors.New("database connection failed") 417 + 418 + repo := &mockAPIKeyServiceRepository{ 419 + getAggregatorCredentialsFunc: func(ctx context.Context, did string) (*aggregators.AggregatorCredentials, error) { 420 + return nil, expectedError 421 + }, 422 + } 423 + 424 + service := newTestAPIKeyService(repo) 425 + adapter := NewAPIKeyValidatorAdapter(service) 426 + 427 + err := adapter.RefreshTokensIfNeeded(context.Background(), "did:plc:aggregator123") 428 + if err == nil { 429 + t.Error("RefreshTokensIfNeeded() expected error, got nil") 430 + } 431 + } 432 + 433 + // ============================================================================= 434 + // GetAPIKeyInfo Delegation Tests (via service) 435 + // ============================================================================= 436 + 437 + func TestAPIKeyValidatorAdapter_GetAggregator_DelegatesToService(t *testing.T) { 438 + expectedDID := "did:plc:aggregator123" 439 + expectedDisplayName := "Test Aggregator" 440 + 441 + repo := &mockAPIKeyServiceRepository{ 442 + getAggregatorFunc: func(ctx context.Context, did string) (*aggregators.Aggregator, error) { 443 + return &aggregators.Aggregator{ 444 + DID: expectedDID, 445 + DisplayName: expectedDisplayName, 446 + }, nil 447 + }, 448 + } 449 + 450 + service := newTestAPIKeyService(repo) 451 + 452 + // Test that GetAggregator is properly delegated 453 + aggregator, err := service.GetAggregator(context.Background(), expectedDID) 454 + if err != nil { 455 + t.Fatalf("GetAggregator() unexpected error: %v", err) 456 + } 457 + 458 + if aggregator.DID != expectedDID { 459 + t.Errorf("GetAggregator() DID = %s, want %s", aggregator.DID, expectedDID) 460 + } 461 + if aggregator.DisplayName != expectedDisplayName { 462 + t.Errorf("GetAggregator() DisplayName = %s, want %s", aggregator.DisplayName, expectedDisplayName) 463 + } 464 + } 465 + 466 + func TestAPIKeyValidatorAdapter_GetAggregator_NotFound(t *testing.T) { 467 + repo := &mockAPIKeyServiceRepository{ 468 + getAggregatorFunc: func(ctx context.Context, did string) (*aggregators.Aggregator, error) { 469 + return nil, aggregators.ErrAggregatorNotFound 470 + }, 471 + } 472 + 473 + service := newTestAPIKeyService(repo) 474 + 475 + _, err := service.GetAggregator(context.Background(), "did:plc:nonexistent") 476 + if !errors.Is(err, aggregators.ErrAggregatorNotFound) { 477 + t.Errorf("GetAggregator() error = %v, want %v", err, aggregators.ErrAggregatorNotFound) 478 + } 479 + } 480 + 481 + func TestAPIKeyValidatorAdapter_GetAggregator_RepositoryError(t *testing.T) { 482 + expectedError := errors.New("database error") 483 + 484 + repo := &mockAPIKeyServiceRepository{ 485 + getAggregatorFunc: func(ctx context.Context, did string) (*aggregators.Aggregator, error) { 486 + return nil, expectedError 487 + }, 488 + } 489 + 490 + service := newTestAPIKeyService(repo) 491 + 492 + _, err := service.GetAggregator(context.Background(), "did:plc:aggregator123") 493 + if err == nil { 494 + t.Error("GetAggregator() expected error, got nil") 495 + } 496 + } 497 + 498 + // ============================================================================= 499 + // Constructor and nil handling tests 500 + // ============================================================================= 501 + 502 + func TestNewAPIKeyValidatorAdapter(t *testing.T) { 503 + repo := &mockAPIKeyServiceRepository{} 504 + service := newTestAPIKeyService(repo) 505 + 506 + adapter := NewAPIKeyValidatorAdapter(service) 507 + if adapter == nil { 508 + t.Fatal("NewAPIKeyValidatorAdapter() returned nil") 509 + } 510 + } 511 + 512 + // ============================================================================= 513 + // Integration-style test: Full validation flow 514 + // ============================================================================= 515 + 516 + func TestAPIKeyValidatorAdapter_FullValidationFlow(t *testing.T) { 517 + // This test verifies the complete flow: 518 + // 1. Validate API key 519 + // 2. Check if tokens need refresh 520 + // 3. Return aggregator DID 521 + 522 + aggregatorDID := "did:plc:aggregator123" 523 + expiresAt := time.Now().Add(1 * time.Hour) 524 + validationCount := 0 525 + 526 + repo := &mockAPIKeyServiceRepository{ 527 + getCredentialsByAPIKeyHashFunc: func(ctx context.Context, keyHash string) (*aggregators.AggregatorCredentials, error) { 528 + validationCount++ 529 + return &aggregators.AggregatorCredentials{ 530 + DID: aggregatorDID, 531 + APIKeyHash: keyHash, 532 + APIKeyPrefix: "ckapi_0123", 533 + OAuthTokenExpiresAt: &expiresAt, 534 + }, nil 535 + }, 536 + getAggregatorCredentialsFunc: func(ctx context.Context, did string) (*aggregators.AggregatorCredentials, error) { 537 + return &aggregators.AggregatorCredentials{ 538 + DID: did, 539 + APIKeyHash: "somehash", 540 + OAuthTokenExpiresAt: &expiresAt, 541 + }, nil 542 + }, 543 + updateAPIKeyLastUsedFunc: func(ctx context.Context, did string) error { 544 + return nil 545 + }, 546 + } 547 + 548 + service := newTestAPIKeyService(repo) 549 + adapter := NewAPIKeyValidatorAdapter(service) 550 + 551 + // Step 1: Validate the key 552 + validKey := "ckapi_0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" 553 + did, err := adapter.ValidateKey(context.Background(), validKey) 554 + if err != nil { 555 + t.Fatalf("ValidateKey() unexpected error: %v", err) 556 + } 557 + if did != aggregatorDID { 558 + t.Errorf("ValidateKey() DID = %s, want %s", did, aggregatorDID) 559 + } 560 + 561 + // Step 2: Check/refresh tokens (should succeed without refresh since tokens are valid) 562 + err = adapter.RefreshTokensIfNeeded(context.Background(), did) 563 + if err != nil { 564 + t.Errorf("RefreshTokensIfNeeded() unexpected error: %v", err) 565 + } 566 + 567 + // Verify validation was called 568 + if validationCount != 1 { 569 + t.Errorf("Expected 1 validation call, got %d", validationCount) 570 + } 571 + } 572 + 573 + // Ensure we don't have unused import 574 + var _ = oauth.ClientApp{}
+77 -7
internal/api/middleware/auth.go
··· 33 33 const ( 34 34 AuthMethodOAuth = "oauth" 35 35 AuthMethodServiceJWT = "service_jwt" 36 + AuthMethodAPIKey = "api_key" 36 37 ) 38 + 39 + // API key prefix constant 40 + const APIKeyPrefix = "ckapi_" 37 41 38 42 // SessionUnsealer is an interface for unsealing session tokens 39 43 // This allows for mocking in tests ··· 49 53 // ServiceAuthValidator is an interface for validating service JWTs 50 54 type ServiceAuthValidator interface { 51 55 Validate(ctx context.Context, tokenString string, lexMethod *syntax.NSID) (syntax.DID, error) 56 + } 57 + 58 + // APIKeyValidator is an interface for validating API keys (used by aggregators) 59 + type APIKeyValidator interface { 60 + // ValidateKey validates an API key and returns the aggregator DID if valid 61 + ValidateKey(ctx context.Context, plainKey string) (aggregatorDID string, err error) 62 + // RefreshTokensIfNeeded refreshes OAuth tokens for the aggregator if they are expired 63 + RefreshTokensIfNeeded(ctx context.Context, aggregatorDID string) error 52 64 } 53 65 54 66 // OAuthAuthMiddleware enforces OAuth authentication using sealed session tokens. ··· 329 341 } 330 342 } 331 343 332 - // DualAuthMiddleware enforces authentication using either OAuth sealed tokens (for users) 333 - // or PDS service JWTs (for aggregators only). 344 + // DualAuthMiddleware enforces authentication using either OAuth sealed tokens (for users), 345 + // PDS service JWTs (for aggregators), or API keys (for aggregators). 334 346 type DualAuthMiddleware struct { 335 347 unsealer SessionUnsealer 336 348 store oauthlib.ClientAuthStore 337 349 serviceValidator ServiceAuthValidator 338 350 aggregatorChecker AggregatorChecker 351 + apiKeyValidator APIKeyValidator // Optional: if nil, API key auth is disabled 339 352 } 340 353 341 354 // NewDualAuthMiddleware creates a new dual auth middleware that supports both OAuth and service JWT authentication. ··· 353 366 } 354 367 } 355 368 356 - // RequireAuth middleware ensures the user is authenticated via either OAuth or service JWT. 369 + // WithAPIKeyValidator adds API key validation support to the middleware. 370 + // Returns the middleware for method chaining. 371 + func (m *DualAuthMiddleware) WithAPIKeyValidator(validator APIKeyValidator) *DualAuthMiddleware { 372 + m.apiKeyValidator = validator 373 + return m 374 + } 375 + 376 + // RequireAuth middleware ensures the user is authenticated via either OAuth, service JWT, or API key. 357 377 // Supports: 378 + // - API keys via Authorization: Bearer ckapi_... (aggregators only, checked first) 358 379 // - OAuth sealed session tokens via Authorization: Bearer <sealed_token> or Cookie: coves_session=<sealed_token> 359 380 // - Service JWTs via Authorization: Bearer <jwt> 360 381 // 361 - // SECURITY: Service JWT authentication is RESTRICTED to registered aggregators only. 362 - // Non-aggregator DIDs will be rejected even with valid JWT signatures. 363 - // This enforcement happens in handleServiceAuth() via aggregatorChecker.IsAggregator(). 382 + // SECURITY: Service JWT and API key authentication are RESTRICTED to registered aggregators only. 383 + // Non-aggregator DIDs will be rejected even with valid JWT signatures or API keys. 384 + // This enforcement happens in handleServiceAuth() via aggregatorChecker.IsAggregator() and 385 + // in handleAPIKeyAuth() via apiKeyValidator.ValidateKey(). 364 386 // 365 387 // If not authenticated, returns 401. 366 388 // If authenticated, injects user DID and auth method into context. ··· 398 420 log.Printf("[AUTH_TRACE] ip=%s method=%s path=%s token_source=%s", 399 421 r.RemoteAddr, r.Method, r.URL.Path, tokenSource) 400 422 423 + // Check for API key first (before JWT/OAuth routing) 424 + // API keys start with "ckapi_" prefix 425 + if strings.HasPrefix(token, APIKeyPrefix) { 426 + m.handleAPIKeyAuth(w, r, next, token) 427 + return 428 + } 429 + 401 430 // Detect token type and route to appropriate handler 402 431 if isJWTFormat(token) { 403 432 m.handleServiceAuth(w, r, next, token) ··· 411 440 func (m *DualAuthMiddleware) handleServiceAuth(w http.ResponseWriter, r *http.Request, next http.Handler, token string) { 412 441 // Validate the service JWT 413 442 // Note: lexMethod is nil, which allows any lexicon method (endpoint-agnostic validation). 414 - // The ServiceAuthValidator skips the lexicon method check when nil (see indigo/atproto/auth/jwt.go:86-88). 443 + // The ServiceAuthValidator skips the lexicon method check when lexMethod is nil. 415 444 // This is intentional - we want aggregators to authenticate globally, not per-endpoint. 416 445 did, err := m.serviceValidator.Validate(r.Context(), token, nil) 417 446 if err != nil { ··· 447 476 ctx := context.WithValue(r.Context(), UserDIDKey, didStr) 448 477 ctx = context.WithValue(ctx, IsAggregatorAuthKey, true) 449 478 ctx = context.WithValue(ctx, AuthMethodKey, AuthMethodServiceJWT) 479 + 480 + // Call next handler 481 + next.ServeHTTP(w, r.WithContext(ctx)) 482 + } 483 + 484 + // handleAPIKeyAuth handles authentication using Coves API keys (aggregators only) 485 + func (m *DualAuthMiddleware) handleAPIKeyAuth(w http.ResponseWriter, r *http.Request, next http.Handler, token string) { 486 + // Check if API key validation is enabled 487 + if m.apiKeyValidator == nil { 488 + log.Printf("[AUTH_FAILURE] type=api_key_disabled ip=%s method=%s path=%s", 489 + r.RemoteAddr, r.Method, r.URL.Path) 490 + writeAuthError(w, "API key authentication is not enabled") 491 + return 492 + } 493 + 494 + // Validate the API key 495 + aggregatorDID, err := m.apiKeyValidator.ValidateKey(r.Context(), token) 496 + if err != nil { 497 + log.Printf("[AUTH_FAILURE] type=api_key_invalid ip=%s method=%s path=%s error=%v", 498 + r.RemoteAddr, r.Method, r.URL.Path, err) 499 + writeAuthError(w, "Invalid or revoked API key") 500 + return 501 + } 502 + 503 + // Refresh OAuth tokens if needed (for PDS operations) 504 + if err := m.apiKeyValidator.RefreshTokensIfNeeded(r.Context(), aggregatorDID); err != nil { 505 + log.Printf("[AUTH_FAILURE] type=token_refresh_failed ip=%s method=%s path=%s did=%s error=%v", 506 + r.RemoteAddr, r.Method, r.URL.Path, aggregatorDID, err) 507 + // Token refresh failure means the aggregator cannot perform authenticated PDS operations 508 + // This is a critical failure - reject the request so the aggregator knows to re-authenticate 509 + writeAuthError(w, "API key authentication failed: unable to refresh OAuth tokens. Please re-authenticate.") 510 + return 511 + } 512 + 513 + log.Printf("[AUTH_SUCCESS] type=api_key ip=%s method=%s path=%s did=%s", 514 + r.RemoteAddr, r.Method, r.URL.Path, aggregatorDID) 515 + 516 + // Inject DID and auth method into context 517 + ctx := context.WithValue(r.Context(), UserDIDKey, aggregatorDID) 518 + ctx = context.WithValue(ctx, IsAggregatorAuthKey, true) 519 + ctx = context.WithValue(ctx, AuthMethodKey, AuthMethodAPIKey) 450 520 451 521 // Call next handler 452 522 next.ServeHTTP(w, r.WithContext(ctx))
+206
internal/api/middleware/auth_test.go
··· 1691 1691 }) 1692 1692 } 1693 1693 } 1694 + 1695 + // Mock APIKeyValidator for testing 1696 + type mockAPIKeyValidator struct { 1697 + aggregators map[string]string // key -> DID 1698 + shouldFail bool 1699 + refreshCalled bool 1700 + } 1701 + 1702 + func (m *mockAPIKeyValidator) ValidateKey(ctx context.Context, plainKey string) (string, error) { 1703 + if m.shouldFail { 1704 + return "", fmt.Errorf("invalid API key") 1705 + } 1706 + // Extract DID from key for testing (real implementation would hash and look up) 1707 + // Test format: ckapi_<did_suffix>_rest 1708 + if len(plainKey) < 12 { 1709 + return "", fmt.Errorf("invalid key format") 1710 + } 1711 + // For testing, assume valid keys return a known aggregator DID 1712 + if aggregatorDID, ok := m.aggregators[plainKey]; ok { 1713 + return aggregatorDID, nil 1714 + } 1715 + return "", fmt.Errorf("unknown API key") 1716 + } 1717 + 1718 + func (m *mockAPIKeyValidator) RefreshTokensIfNeeded(ctx context.Context, aggregatorDID string) error { 1719 + m.refreshCalled = true 1720 + return nil 1721 + } 1722 + 1723 + // TestDualAuthMiddleware_APIKey_Valid tests API key authentication 1724 + func TestDualAuthMiddleware_APIKey_Valid(t *testing.T) { 1725 + client := newMockOAuthClient() 1726 + store := newMockOAuthStore() 1727 + validator := &mockServiceAuthValidator{} 1728 + aggregatorChecker := &mockAggregatorChecker{ 1729 + aggregators: make(map[string]bool), 1730 + } 1731 + 1732 + apiKeyValidator := &mockAPIKeyValidator{ 1733 + aggregators: map[string]string{ 1734 + "ckapi_test1234567890123456789012345678": "did:plc:aggregator123", 1735 + }, 1736 + } 1737 + 1738 + middleware := NewDualAuthMiddleware(client, store, validator, aggregatorChecker). 1739 + WithAPIKeyValidator(apiKeyValidator) 1740 + 1741 + handlerCalled := false 1742 + handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1743 + handlerCalled = true 1744 + 1745 + // Verify DID was extracted 1746 + extractedDID := GetUserDID(r) 1747 + if extractedDID != "did:plc:aggregator123" { 1748 + t.Errorf("expected DID 'did:plc:aggregator123', got %s", extractedDID) 1749 + } 1750 + 1751 + // Verify it's marked as aggregator auth 1752 + if !IsAggregatorAuth(r) { 1753 + t.Error("expected IsAggregatorAuth to be true") 1754 + } 1755 + 1756 + // Verify auth method 1757 + authMethod := GetAuthMethod(r) 1758 + if authMethod != AuthMethodAPIKey { 1759 + t.Errorf("expected auth method %s, got %s", AuthMethodAPIKey, authMethod) 1760 + } 1761 + 1762 + w.WriteHeader(http.StatusOK) 1763 + })) 1764 + 1765 + req := httptest.NewRequest("GET", "/test", nil) 1766 + req.Header.Set("Authorization", "Bearer ckapi_test1234567890123456789012345678") 1767 + w := httptest.NewRecorder() 1768 + 1769 + handler.ServeHTTP(w, req) 1770 + 1771 + if !handlerCalled { 1772 + t.Error("handler was not called") 1773 + } 1774 + 1775 + if w.Code != http.StatusOK { 1776 + t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String()) 1777 + } 1778 + 1779 + // Verify token refresh was attempted 1780 + if !apiKeyValidator.refreshCalled { 1781 + t.Error("expected token refresh to be called") 1782 + } 1783 + } 1784 + 1785 + // TestDualAuthMiddleware_APIKey_Invalid tests API key authentication with invalid key 1786 + func TestDualAuthMiddleware_APIKey_Invalid(t *testing.T) { 1787 + client := newMockOAuthClient() 1788 + store := newMockOAuthStore() 1789 + validator := &mockServiceAuthValidator{} 1790 + aggregatorChecker := &mockAggregatorChecker{ 1791 + aggregators: make(map[string]bool), 1792 + } 1793 + 1794 + apiKeyValidator := &mockAPIKeyValidator{ 1795 + shouldFail: true, 1796 + } 1797 + 1798 + middleware := NewDualAuthMiddleware(client, store, validator, aggregatorChecker). 1799 + WithAPIKeyValidator(apiKeyValidator) 1800 + 1801 + handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1802 + t.Error("handler should not be called for invalid API key") 1803 + })) 1804 + 1805 + req := httptest.NewRequest("GET", "/test", nil) 1806 + req.Header.Set("Authorization", "Bearer ckapi_invalid_key_12345678901234567") 1807 + w := httptest.NewRecorder() 1808 + 1809 + handler.ServeHTTP(w, req) 1810 + 1811 + if w.Code != http.StatusUnauthorized { 1812 + t.Errorf("expected status 401, got %d", w.Code) 1813 + } 1814 + 1815 + var response map[string]string 1816 + _ = json.Unmarshal(w.Body.Bytes(), &response) 1817 + if response["message"] != "Invalid or revoked API key" { 1818 + t.Errorf("unexpected error message: %s", response["message"]) 1819 + } 1820 + } 1821 + 1822 + // TestDualAuthMiddleware_APIKey_Disabled tests API key auth when validator is not configured 1823 + func TestDualAuthMiddleware_APIKey_Disabled(t *testing.T) { 1824 + client := newMockOAuthClient() 1825 + store := newMockOAuthStore() 1826 + validator := &mockServiceAuthValidator{} 1827 + aggregatorChecker := &mockAggregatorChecker{ 1828 + aggregators: make(map[string]bool), 1829 + } 1830 + 1831 + // No API key validator configured 1832 + middleware := NewDualAuthMiddleware(client, store, validator, aggregatorChecker) 1833 + 1834 + handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1835 + t.Error("handler should not be called when API key auth is disabled") 1836 + })) 1837 + 1838 + req := httptest.NewRequest("GET", "/test", nil) 1839 + req.Header.Set("Authorization", "Bearer ckapi_test1234567890123456789012345678") 1840 + w := httptest.NewRecorder() 1841 + 1842 + handler.ServeHTTP(w, req) 1843 + 1844 + if w.Code != http.StatusUnauthorized { 1845 + t.Errorf("expected status 401, got %d", w.Code) 1846 + } 1847 + 1848 + var response map[string]string 1849 + _ = json.Unmarshal(w.Body.Bytes(), &response) 1850 + if response["message"] != "API key authentication is not enabled" { 1851 + t.Errorf("unexpected error message: %s", response["message"]) 1852 + } 1853 + } 1854 + 1855 + // TestDualAuthMiddleware_APIKey_PrecedenceOverOAuth tests that API keys are detected before OAuth 1856 + func TestDualAuthMiddleware_APIKey_PrecedenceOverOAuth(t *testing.T) { 1857 + client := newMockOAuthClient() 1858 + store := newMockOAuthStore() 1859 + validator := &mockServiceAuthValidator{} 1860 + aggregatorChecker := &mockAggregatorChecker{ 1861 + aggregators: make(map[string]bool), 1862 + } 1863 + 1864 + apiKeyValidator := &mockAPIKeyValidator{ 1865 + aggregators: map[string]string{ 1866 + "ckapi_test1234567890123456789012345678": "did:plc:apikey_aggregator", 1867 + }, 1868 + } 1869 + 1870 + middleware := NewDualAuthMiddleware(client, store, validator, aggregatorChecker). 1871 + WithAPIKeyValidator(apiKeyValidator) 1872 + 1873 + handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1874 + // Verify API key auth was used 1875 + authMethod := GetAuthMethod(r) 1876 + if authMethod != AuthMethodAPIKey { 1877 + t.Errorf("expected API key auth method, got %s", authMethod) 1878 + } 1879 + 1880 + // Verify DID from API key (not OAuth) 1881 + did := GetUserDID(r) 1882 + if did != "did:plc:apikey_aggregator" { 1883 + t.Errorf("expected API key aggregator DID, got %s", did) 1884 + } 1885 + 1886 + w.WriteHeader(http.StatusOK) 1887 + })) 1888 + 1889 + // Use API key format token (starts with ckapi_) 1890 + req := httptest.NewRequest("GET", "/test", nil) 1891 + req.Header.Set("Authorization", "Bearer ckapi_test1234567890123456789012345678") 1892 + w := httptest.NewRecorder() 1893 + 1894 + handler.ServeHTTP(w, req) 1895 + 1896 + if w.Code != http.StatusOK { 1897 + t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String()) 1898 + } 1899 + }
+37
internal/api/routes/aggregator.go
··· 57 57 // POST /xrpc/social.coves.aggregator.disable (requires auth + moderator) 58 58 // POST /xrpc/social.coves.aggregator.updateConfig (requires auth + moderator) 59 59 } 60 + 61 + // RegisterAggregatorAPIKeyRoutes registers API key management endpoints for aggregators. 62 + // These endpoints require OAuth authentication and are only available to registered aggregators. 63 + // Call this function AFTER setting up the auth middleware. 64 + func RegisterAggregatorAPIKeyRoutes( 65 + r chi.Router, 66 + authMiddleware middleware.AuthMiddleware, 67 + apiKeyService aggregators.APIKeyServiceInterface, 68 + aggregatorService aggregators.Service, 69 + ) { 70 + // Create API key handlers 71 + createAPIKeyHandler := aggregator.NewCreateAPIKeyHandler(apiKeyService, aggregatorService) 72 + getAPIKeyHandler := aggregator.NewGetAPIKeyHandler(apiKeyService, aggregatorService) 73 + revokeAPIKeyHandler := aggregator.NewRevokeAPIKeyHandler(apiKeyService, aggregatorService) 74 + metricsHandler := aggregator.NewMetricsHandler(apiKeyService) 75 + 76 + // API key management endpoints (require OAuth authentication) 77 + // POST /xrpc/social.coves.aggregator.createApiKey 78 + // Creates a new API key for the authenticated aggregator 79 + r.With(authMiddleware.RequireAuth).Post("/xrpc/social.coves.aggregator.createApiKey", 80 + createAPIKeyHandler.HandleCreateAPIKey) 81 + 82 + // GET /xrpc/social.coves.aggregator.getApiKey 83 + // Gets info about the authenticated aggregator's API key (not the key itself) 84 + r.With(authMiddleware.RequireAuth).Get("/xrpc/social.coves.aggregator.getApiKey", 85 + getAPIKeyHandler.HandleGetAPIKey) 86 + 87 + // POST /xrpc/social.coves.aggregator.revokeApiKey 88 + // Revokes the authenticated aggregator's API key 89 + r.With(authMiddleware.RequireAuth).Post("/xrpc/social.coves.aggregator.revokeApiKey", 90 + revokeAPIKeyHandler.HandleRevokeAPIKey) 91 + 92 + // GET /xrpc/social.coves.aggregator.getMetrics 93 + // Returns operational metrics for the API key service (internal monitoring endpoint) 94 + // No authentication required - metrics are non-sensitive operational data 95 + r.Get("/xrpc/social.coves.aggregator.getMetrics", metricsHandler.HandleMetrics) 96 + }
+63
internal/atproto/lexicon/social/coves/aggregator/createApiKey.json
··· 1 + { 2 + "lexicon": 1, 3 + "id": "social.coves.aggregator.createApiKey", 4 + "defs": { 5 + "main": { 6 + "type": "procedure", 7 + "description": "Create an API key for the authenticated aggregator. Requires OAuth authentication. The API key is returned ONCE and cannot be retrieved again. Store it securely.", 8 + "input": { 9 + "encoding": "application/json", 10 + "schema": { 11 + "type": "object", 12 + "description": "No input required. The key is generated server-side for the authenticated aggregator.", 13 + "properties": {} 14 + } 15 + }, 16 + "output": { 17 + "encoding": "application/json", 18 + "schema": { 19 + "type": "object", 20 + "required": ["key", "keyPrefix", "did", "createdAt"], 21 + "properties": { 22 + "key": { 23 + "type": "string", 24 + "description": "The plain-text API key. This is shown ONCE and cannot be retrieved again. Format: ckapi_<64-hex-chars> (32 bytes hex-encoded)" 25 + }, 26 + "keyPrefix": { 27 + "type": "string", 28 + "description": "First 12 characters of the key (e.g., 'ckapi_ab12cd') for identification in logs and UI" 29 + }, 30 + "did": { 31 + "type": "string", 32 + "format": "did", 33 + "description": "DID of the aggregator that owns this key" 34 + }, 35 + "createdAt": { 36 + "type": "string", 37 + "format": "datetime", 38 + "description": "ISO8601 timestamp when the key was created" 39 + } 40 + } 41 + } 42 + }, 43 + "errors": [ 44 + { 45 + "name": "AuthenticationRequired", 46 + "description": "OAuth authentication is required to create an API key" 47 + }, 48 + { 49 + "name": "OAuthSessionRequired", 50 + "description": "OAuth session is required (not service JWT) to create an API key" 51 + }, 52 + { 53 + "name": "AggregatorRequired", 54 + "description": "Only registered aggregators can create API keys" 55 + }, 56 + { 57 + "name": "KeyGenerationFailed", 58 + "description": "Failed to generate the API key" 59 + } 60 + ] 61 + } 62 + } 63 + }
+30
internal/atproto/lexicon/social/coves/aggregator/defs.json
··· 204 204 "format": "at-uri" 205 205 } 206 206 } 207 + }, 208 + "apiKeyView": { 209 + "type": "object", 210 + "description": "View of an API key's metadata. The actual key value is never returned after initial creation.", 211 + "required": ["prefix", "createdAt", "isRevoked"], 212 + "properties": { 213 + "prefix": { 214 + "type": "string", 215 + "description": "First 12 characters of the key (e.g., 'ckapi_ab12cd') for identification in logs and UI" 216 + }, 217 + "createdAt": { 218 + "type": "string", 219 + "format": "datetime", 220 + "description": "When the key was created" 221 + }, 222 + "lastUsedAt": { 223 + "type": "string", 224 + "format": "datetime", 225 + "description": "When the key was last used for authentication" 226 + }, 227 + "isRevoked": { 228 + "type": "boolean", 229 + "description": "Whether the key has been revoked" 230 + }, 231 + "revokedAt": { 232 + "type": "string", 233 + "format": "datetime", 234 + "description": "When the key was revoked" 235 + } 236 + } 207 237 } 208 238 } 209 239 }
+47
internal/atproto/lexicon/social/coves/aggregator/getApiKey.json
··· 1 + { 2 + "lexicon": 1, 3 + "id": "social.coves.aggregator.getApiKey", 4 + "defs": { 5 + "main": { 6 + "type": "query", 7 + "description": "Get information about the authenticated aggregator's API key. Note: The actual key value is NEVER returned - only metadata about the key.", 8 + "parameters": { 9 + "type": "params", 10 + "description": "No parameters required. Returns key info for the authenticated aggregator.", 11 + "properties": {} 12 + }, 13 + "output": { 14 + "encoding": "application/json", 15 + "schema": { 16 + "type": "object", 17 + "required": ["hasKey"], 18 + "properties": { 19 + "hasKey": { 20 + "type": "boolean", 21 + "description": "Whether the aggregator has an API key (active or revoked)" 22 + }, 23 + "keyInfo": { 24 + "type": "ref", 25 + "ref": "social.coves.aggregator.defs#apiKeyView", 26 + "description": "API key metadata. Only present if hasKey is true." 27 + } 28 + } 29 + } 30 + }, 31 + "errors": [ 32 + { 33 + "name": "AuthenticationRequired", 34 + "description": "Authentication is required to get API key info" 35 + }, 36 + { 37 + "name": "AggregatorRequired", 38 + "description": "Only registered aggregators can get API key info" 39 + }, 40 + { 41 + "name": "AggregatorNotFound", 42 + "description": "Aggregator not found" 43 + } 44 + ] 45 + } 46 + } 47 + }
+58
internal/atproto/lexicon/social/coves/aggregator/revokeApiKey.json
··· 1 + { 2 + "lexicon": 1, 3 + "id": "social.coves.aggregator.revokeApiKey", 4 + "defs": { 5 + "main": { 6 + "type": "procedure", 7 + "description": "Revoke the authenticated aggregator's API key. After revocation, the aggregator must complete OAuth flow again to create a new API key. This action cannot be undone.", 8 + "input": { 9 + "encoding": "application/json", 10 + "schema": { 11 + "type": "object", 12 + "description": "No input required. Revokes the key for the authenticated aggregator.", 13 + "properties": {} 14 + } 15 + }, 16 + "output": { 17 + "encoding": "application/json", 18 + "schema": { 19 + "type": "object", 20 + "required": ["revokedAt"], 21 + "properties": { 22 + "revokedAt": { 23 + "type": "string", 24 + "format": "datetime", 25 + "description": "ISO8601 timestamp when the key was revoked" 26 + } 27 + } 28 + } 29 + }, 30 + "errors": [ 31 + { 32 + "name": "AuthenticationRequired", 33 + "description": "Authentication is required to revoke an API key" 34 + }, 35 + { 36 + "name": "AggregatorRequired", 37 + "description": "Only registered aggregators can revoke API keys" 38 + }, 39 + { 40 + "name": "AggregatorNotFound", 41 + "description": "Aggregator not found" 42 + }, 43 + { 44 + "name": "ApiKeyNotFound", 45 + "description": "No API key exists to revoke" 46 + }, 47 + { 48 + "name": "ApiKeyAlreadyRevoked", 49 + "description": "API key has already been revoked" 50 + }, 51 + { 52 + "name": "RevocationFailed", 53 + "description": "Failed to revoke the API key" 54 + } 55 + ] 56 + } 57 + } 58 + }
+102 -13
internal/core/aggregators/aggregator.go
··· 6 6 // Aggregators are autonomous services that can post content to communities after authorization 7 7 // Following Bluesky's pattern: app.bsky.feed.generator and app.bsky.labeler.service 8 8 type Aggregator struct { 9 - CreatedAt time.Time `json:"createdAt" db:"created_at"` 10 - IndexedAt time.Time `json:"indexedAt" db:"indexed_at"` 11 - AvatarURL string `json:"avatarUrl,omitempty" db:"avatar_url"` 12 - DID string `json:"did" db:"did"` 13 - MaintainerDID string `json:"maintainerDid,omitempty" db:"maintainer_did"` 14 - SourceURL string `json:"sourceUrl,omitempty" db:"source_url"` 15 - Description string `json:"description,omitempty" db:"description"` 16 - DisplayName string `json:"displayName" db:"display_name"` 17 - RecordURI string `json:"recordUri,omitempty" db:"record_uri"` 18 - RecordCID string `json:"recordCid,omitempty" db:"record_cid"` 19 - ConfigSchema []byte `json:"configSchema,omitempty" db:"config_schema"` 20 - CommunitiesUsing int `json:"communitiesUsing" db:"communities_using"` 21 - PostsCreated int `json:"postsCreated" db:"posts_created"` 9 + // Core timestamps 10 + CreatedAt time.Time `json:"createdAt" db:"created_at"` 11 + IndexedAt time.Time `json:"indexedAt" db:"indexed_at"` 12 + 13 + // Identity and display 14 + DID string `json:"did" db:"did"` 15 + DisplayName string `json:"displayName" db:"display_name"` 16 + Description string `json:"description,omitempty" db:"description"` 17 + AvatarURL string `json:"avatarUrl,omitempty" db:"avatar_url"` 18 + 19 + // Metadata 20 + MaintainerDID string `json:"maintainerDid,omitempty" db:"maintainer_did"` 21 + SourceURL string `json:"sourceUrl,omitempty" db:"source_url"` 22 + RecordURI string `json:"recordUri,omitempty" db:"record_uri"` 23 + RecordCID string `json:"recordCid,omitempty" db:"record_cid"` 24 + ConfigSchema []byte `json:"configSchema,omitempty" db:"config_schema"` 25 + 26 + // Stats 27 + CommunitiesUsing int `json:"communitiesUsing" db:"communities_using"` 28 + PostsCreated int `json:"postsCreated" db:"posts_created"` 29 + } 30 + 31 + // OAuthCredentials holds OAuth session data for aggregator authentication 32 + // Used when setting up or refreshing API key authentication 33 + type OAuthCredentials struct { 34 + AccessToken string 35 + RefreshToken string 36 + TokenExpiresAt time.Time 37 + PDSURL string 38 + AuthServerIss string 39 + AuthServerTokenEndpoint string 40 + DPoPPrivateKeyMultibase string 41 + DPoPAuthServerNonce string 42 + DPoPPDSNonce string 43 + } 44 + 45 + // Validate checks that all required OAuthCredentials fields are present and valid. 46 + // Returns an error describing the first validation failure, or nil if valid. 47 + func (c *OAuthCredentials) Validate() error { 48 + if c.AccessToken == "" { 49 + return NewValidationError("accessToken", "access token is required") 50 + } 51 + if c.RefreshToken == "" { 52 + return NewValidationError("refreshToken", "refresh token is required") 53 + } 54 + if c.TokenExpiresAt.IsZero() { 55 + return NewValidationError("tokenExpiresAt", "token expiry time is required") 56 + } 57 + if c.PDSURL == "" { 58 + return NewValidationError("pdsUrl", "PDS URL is required") 59 + } 60 + if c.AuthServerIss == "" { 61 + return NewValidationError("authServerIss", "auth server issuer is required") 62 + } 63 + if c.AuthServerTokenEndpoint == "" { 64 + return NewValidationError("authServerTokenEndpoint", "auth server token endpoint is required") 65 + } 66 + if c.DPoPPrivateKeyMultibase == "" { 67 + return NewValidationError("dpopPrivateKey", "DPoP private key is required") 68 + } 69 + return nil 70 + } 71 + 72 + // AggregatorCredentials holds sensitive authentication data for aggregators. 73 + // This is the preferred type for authentication operations - separates concerns 74 + // from the public Aggregator type and prevents credential leakage. 75 + type AggregatorCredentials struct { 76 + DID string `db:"did"` 77 + 78 + // API Key Authentication 79 + APIKeyPrefix string `db:"api_key_prefix"` 80 + APIKeyHash string `db:"api_key_hash"` 81 + APIKeyCreatedAt *time.Time `db:"api_key_created_at"` 82 + APIKeyRevokedAt *time.Time `db:"api_key_revoked_at"` 83 + APIKeyLastUsed *time.Time `db:"api_key_last_used_at"` 84 + 85 + // OAuth Session Credentials 86 + OAuthAccessToken string `db:"oauth_access_token"` 87 + OAuthRefreshToken string `db:"oauth_refresh_token"` 88 + OAuthTokenExpiresAt *time.Time `db:"oauth_token_expires_at"` 89 + OAuthPDSURL string `db:"oauth_pds_url"` 90 + OAuthAuthServerIss string `db:"oauth_auth_server_iss"` 91 + OAuthAuthServerTokenEndpoint string `db:"oauth_auth_server_token_endpoint"` 92 + OAuthDPoPPrivateKeyMultibase string `db:"oauth_dpop_private_key_multibase"` 93 + OAuthDPoPAuthServerNonce string `db:"oauth_dpop_authserver_nonce"` 94 + OAuthDPoPPDSNonce string `db:"oauth_dpop_pds_nonce"` 95 + } 96 + 97 + // HasActiveAPIKey returns true if the credentials have an active (non-revoked) API key. 98 + // An active key has a non-empty hash and has not been revoked. 99 + func (c *AggregatorCredentials) HasActiveAPIKey() bool { 100 + return c.APIKeyHash != "" && c.APIKeyRevokedAt == nil 101 + } 102 + 103 + // IsOAuthTokenExpired returns true if the OAuth access token has expired or will expire soon. 104 + // Uses a 5-minute buffer before actual expiry to allow proactive token refresh, 105 + // accounting for clock skew and network latency during refresh operations. 106 + func (c *AggregatorCredentials) IsOAuthTokenExpired() bool { 107 + if c.OAuthTokenExpiresAt == nil { 108 + return true 109 + } 110 + return time.Now().Add(5 * time.Minute).After(*c.OAuthTokenExpiresAt) 22 111 } 23 112 24 113 // Authorization represents a community's authorization for an aggregator
+373
internal/core/aggregators/apikey_service.go
··· 1 + package aggregators 2 + 3 + import ( 4 + "context" 5 + "crypto/rand" 6 + "crypto/sha256" 7 + "encoding/hex" 8 + "errors" 9 + "fmt" 10 + "log/slog" 11 + "sync/atomic" 12 + "time" 13 + 14 + "github.com/bluesky-social/indigo/atproto/auth/oauth" 15 + "github.com/bluesky-social/indigo/atproto/syntax" 16 + ) 17 + 18 + const ( 19 + // APIKeyPrefix is the prefix for all Coves API keys 20 + APIKeyPrefix = "ckapi_" 21 + // APIKeyRandomBytes is the number of random bytes in the key (32 bytes = 256 bits) 22 + APIKeyRandomBytes = 32 23 + // APIKeyTotalLength is the total length of the API key including prefix 24 + // 6 (prefix "ckapi_") + 64 (32 bytes hex-encoded) = 70 25 + APIKeyTotalLength = 70 26 + // TokenRefreshBuffer is how long before expiry we should refresh tokens 27 + TokenRefreshBuffer = 5 * time.Minute 28 + // DefaultSessionID is used for API key sessions since aggregators have a single session 29 + DefaultSessionID = "apikey" 30 + ) 31 + 32 + // APIKeyService handles API key generation, validation, and OAuth token management 33 + // for aggregator authentication. 34 + type APIKeyService struct { 35 + repo Repository 36 + oauthApp *oauth.ClientApp // For resuming sessions and refreshing tokens 37 + 38 + // failedLastUsedUpdates tracks the number of failed API key last_used timestamp updates. 39 + // This counter provides visibility into persistent DB issues that would otherwise be hidden 40 + // since the update is done asynchronously. Use GetFailedLastUsedUpdates() to read. 41 + failedLastUsedUpdates atomic.Int64 42 + 43 + // failedNonceUpdates tracks the number of failed OAuth nonce updates. 44 + // Nonce failures may indicate DB issues and could lead to DPoP replay protection issues. 45 + // Use GetFailedNonceUpdates() to read. 46 + failedNonceUpdates atomic.Int64 47 + } 48 + 49 + // NewAPIKeyService creates a new API key service. 50 + // Panics if repo or oauthApp are nil, as these are required dependencies. 51 + func NewAPIKeyService(repo Repository, oauthApp *oauth.ClientApp) *APIKeyService { 52 + if repo == nil { 53 + panic("aggregators.NewAPIKeyService: repo cannot be nil") 54 + } 55 + if oauthApp == nil { 56 + panic("aggregators.NewAPIKeyService: oauthApp cannot be nil") 57 + } 58 + return &APIKeyService{ 59 + repo: repo, 60 + oauthApp: oauthApp, 61 + } 62 + } 63 + 64 + // GenerateKey creates a new API key for an aggregator. 65 + // The aggregator must have completed OAuth authentication first. 66 + // Returns the plain-text key (only shown once) and the key prefix for reference. 67 + func (s *APIKeyService) GenerateKey(ctx context.Context, aggregatorDID string, oauthSession *oauth.ClientSessionData) (plainKey string, keyPrefix string, err error) { 68 + // Validate aggregator exists 69 + aggregator, err := s.repo.GetAggregator(ctx, aggregatorDID) 70 + if err != nil { 71 + return "", "", fmt.Errorf("failed to get aggregator: %w", err) 72 + } 73 + 74 + // Validate OAuth session matches the aggregator 75 + if oauthSession.AccountDID.String() != aggregatorDID { 76 + return "", "", ErrOAuthSessionMismatch 77 + } 78 + 79 + // Generate random key 80 + randomBytes := make([]byte, APIKeyRandomBytes) 81 + if _, err := rand.Read(randomBytes); err != nil { 82 + return "", "", fmt.Errorf("failed to generate random key: %w", err) 83 + } 84 + randomHex := hex.EncodeToString(randomBytes) 85 + plainKey = APIKeyPrefix + randomHex 86 + 87 + // Create key prefix (first 12 chars including prefix for identification) 88 + keyPrefix = plainKey[:12] 89 + 90 + // Hash the key for storage (SHA-256) 91 + keyHash := hashAPIKey(plainKey) 92 + 93 + // Extract OAuth credentials from session 94 + // Note: ClientSessionData doesn't store token expiry from the OAuth response. 95 + // We use a 1-hour default which matches typical OAuth access token lifetimes. 96 + // Token refresh happens proactively before expiry via RefreshTokensIfNeeded. 97 + tokenExpiry := time.Now().Add(1 * time.Hour) 98 + oauthCreds := &OAuthCredentials{ 99 + AccessToken: oauthSession.AccessToken, 100 + RefreshToken: oauthSession.RefreshToken, 101 + TokenExpiresAt: tokenExpiry, 102 + PDSURL: oauthSession.HostURL, 103 + AuthServerIss: oauthSession.AuthServerURL, 104 + AuthServerTokenEndpoint: oauthSession.AuthServerTokenEndpoint, 105 + DPoPPrivateKeyMultibase: oauthSession.DPoPPrivateKeyMultibase, 106 + DPoPAuthServerNonce: oauthSession.DPoPAuthServerNonce, 107 + DPoPPDSNonce: oauthSession.DPoPHostNonce, 108 + } 109 + 110 + // Validate OAuth credentials before proceeding 111 + if err := oauthCreds.Validate(); err != nil { 112 + return "", "", fmt.Errorf("invalid OAuth credentials: %w", err) 113 + } 114 + 115 + // Store the OAuth session in the store FIRST (before API key) 116 + // This prevents a race condition where the API key exists but can't refresh tokens. 117 + // Order: OAuth session → API key (if session fails, no dangling API key) 118 + apiKeySession := *oauthSession // Copy session data 119 + apiKeySession.SessionID = DefaultSessionID 120 + if err := s.oauthApp.Store.SaveSession(ctx, apiKeySession); err != nil { 121 + slog.Error("failed to store OAuth session for API key - aborting key creation", 122 + "did", aggregatorDID, 123 + "error", err, 124 + ) 125 + return "", "", fmt.Errorf("failed to store OAuth session for token refresh: %w", err) 126 + } 127 + 128 + // Now store key hash and OAuth credentials in aggregators table 129 + // If this fails, we have an orphaned OAuth session, but that's less problematic 130 + // than having an API key that can't refresh tokens. 131 + if err := s.repo.SetAPIKey(ctx, aggregatorDID, keyPrefix, keyHash, oauthCreds); err != nil { 132 + // Best effort cleanup of the OAuth session we just stored 133 + if deleteErr := s.oauthApp.Store.DeleteSession(ctx, oauthSession.AccountDID, DefaultSessionID); deleteErr != nil { 134 + slog.Warn("failed to cleanup OAuth session after API key storage failure", 135 + "did", aggregatorDID, 136 + "error", deleteErr, 137 + ) 138 + } 139 + return "", "", fmt.Errorf("failed to store API key: %w", err) 140 + } 141 + 142 + slog.Info("API key generated for aggregator", 143 + "did", aggregatorDID, 144 + "display_name", aggregator.DisplayName, 145 + "key_prefix", keyPrefix, 146 + ) 147 + 148 + return plainKey, keyPrefix, nil 149 + } 150 + 151 + // ValidateKey validates an API key and returns the associated aggregator credentials. 152 + // Returns ErrAPIKeyInvalid if the key is not found or revoked. 153 + func (s *APIKeyService) ValidateKey(ctx context.Context, plainKey string) (*AggregatorCredentials, error) { 154 + // Validate key format - log invalid attempts for security monitoring 155 + if len(plainKey) != APIKeyTotalLength || plainKey[:6] != APIKeyPrefix { 156 + // Log for security monitoring (potential brute-force detection) 157 + // Don't log the full key, just metadata about the attempt 158 + slog.Warn("[SECURITY] invalid API key format attempt", 159 + "key_length", len(plainKey), 160 + "has_valid_prefix", len(plainKey) >= 6 && plainKey[:6] == APIKeyPrefix, 161 + ) 162 + return nil, ErrAPIKeyInvalid 163 + } 164 + 165 + // Hash the provided key 166 + keyHash := hashAPIKey(plainKey) 167 + 168 + // Look up aggregator credentials by hash 169 + creds, err := s.repo.GetCredentialsByAPIKeyHash(ctx, keyHash) 170 + if err != nil { 171 + if IsNotFound(err) { 172 + return nil, ErrAPIKeyInvalid 173 + } 174 + // Check for revoked API key (returned by repo when api_key_revoked_at is set) 175 + if errors.Is(err, ErrAPIKeyRevoked) { 176 + slog.Warn("revoked API key used", 177 + "key_hash_prefix", keyHash[:8], 178 + ) 179 + return nil, ErrAPIKeyRevoked 180 + } 181 + return nil, fmt.Errorf("failed to lookup API key: %w", err) 182 + } 183 + 184 + // Update last used timestamp (async, don't block on error) 185 + // Use a bounded timeout to prevent goroutine accumulation if DB is slow/down 186 + // Extract trace info from context before spawning goroutine for log correlation 187 + aggregatorDID := creds.DID // capture for goroutine 188 + go func() { 189 + updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 190 + defer cancel() 191 + 192 + if updateErr := s.repo.UpdateAPIKeyLastUsed(updateCtx, aggregatorDID); updateErr != nil { 193 + // Increment failure counter for monitoring visibility 194 + failCount := s.failedLastUsedUpdates.Add(1) 195 + slog.Error("failed to update API key last used", 196 + "did", aggregatorDID, 197 + "error", updateErr, 198 + "total_failures", failCount, 199 + ) 200 + } 201 + }() 202 + 203 + return creds, nil 204 + } 205 + 206 + // RefreshTokensIfNeeded checks if the OAuth tokens are expired or expiring soon, 207 + // and refreshes them if necessary. 208 + func (s *APIKeyService) RefreshTokensIfNeeded(ctx context.Context, creds *AggregatorCredentials) error { 209 + // Check if tokens need refresh 210 + if creds.OAuthTokenExpiresAt != nil { 211 + if time.Until(*creds.OAuthTokenExpiresAt) > TokenRefreshBuffer { 212 + // Tokens still valid 213 + return nil 214 + } 215 + } 216 + 217 + // Need to refresh tokens 218 + slog.Info("refreshing OAuth tokens for aggregator", 219 + "did", creds.DID, 220 + "expires_at", creds.OAuthTokenExpiresAt, 221 + ) 222 + 223 + // Parse DID 224 + did, err := syntax.ParseDID(creds.DID) 225 + if err != nil { 226 + return fmt.Errorf("failed to parse aggregator DID: %w", err) 227 + } 228 + 229 + // Resume the OAuth session from the store 230 + // The session was stored when the aggregator created their API key 231 + session, err := s.oauthApp.ResumeSession(ctx, did, DefaultSessionID) 232 + if err != nil { 233 + slog.Error("failed to resume OAuth session for token refresh", 234 + "did", creds.DID, 235 + "error", err, 236 + ) 237 + return fmt.Errorf("failed to resume session: %w", err) 238 + } 239 + 240 + // Refresh tokens using indigo's OAuth library 241 + newAccessToken, err := session.RefreshTokens(ctx) 242 + if err != nil { 243 + slog.Error("failed to refresh OAuth tokens", 244 + "did", creds.DID, 245 + "error", err, 246 + ) 247 + return fmt.Errorf("failed to refresh tokens: %w", err) 248 + } 249 + 250 + // Note: ClientSessionData doesn't store token expiry from the OAuth response. 251 + // We use a 1-hour default which matches typical OAuth access token lifetimes. 252 + newExpiry := time.Now().Add(1 * time.Hour) 253 + 254 + // Update tokens in database 255 + if err := s.repo.UpdateOAuthTokens(ctx, creds.DID, newAccessToken, session.Data.RefreshToken, newExpiry); err != nil { 256 + return fmt.Errorf("failed to update tokens: %w", err) 257 + } 258 + 259 + // Update nonces in our database as a secondary copy for visibility/backup. 260 + // The authoritative nonces are in indigo's OAuth store (via SaveSession above). 261 + // Session resumption uses s.oauthApp.ResumeSession which reads from indigo's store, 262 + // so this failure is non-critical - hence warning level, not error. 263 + if err := s.repo.UpdateOAuthNonces(ctx, creds.DID, session.Data.DPoPAuthServerNonce, session.Data.DPoPHostNonce); err != nil { 264 + failCount := s.failedNonceUpdates.Add(1) 265 + slog.Warn("failed to update OAuth nonces in aggregators table", 266 + "did", creds.DID, 267 + "error", err, 268 + "total_failures", failCount, 269 + ) 270 + } 271 + 272 + // Update credentials in memory 273 + creds.OAuthAccessToken = newAccessToken 274 + creds.OAuthRefreshToken = session.Data.RefreshToken 275 + creds.OAuthTokenExpiresAt = &newExpiry 276 + creds.OAuthDPoPAuthServerNonce = session.Data.DPoPAuthServerNonce 277 + creds.OAuthDPoPPDSNonce = session.Data.DPoPHostNonce 278 + 279 + slog.Info("OAuth tokens refreshed for aggregator", 280 + "did", creds.DID, 281 + "new_expires_at", newExpiry, 282 + ) 283 + 284 + return nil 285 + } 286 + 287 + // GetAccessToken returns a valid access token for the aggregator, 288 + // refreshing if necessary. 289 + func (s *APIKeyService) GetAccessToken(ctx context.Context, creds *AggregatorCredentials) (string, error) { 290 + // Ensure tokens are fresh 291 + if err := s.RefreshTokensIfNeeded(ctx, creds); err != nil { 292 + return "", fmt.Errorf("failed to ensure fresh tokens: %w", err) 293 + } 294 + 295 + return creds.OAuthAccessToken, nil 296 + } 297 + 298 + // RevokeKey revokes an API key for an aggregator. 299 + // After revocation, the aggregator must complete OAuth flow again to get a new key. 300 + func (s *APIKeyService) RevokeKey(ctx context.Context, aggregatorDID string) error { 301 + if err := s.repo.RevokeAPIKey(ctx, aggregatorDID); err != nil { 302 + return fmt.Errorf("failed to revoke API key: %w", err) 303 + } 304 + 305 + slog.Info("API key revoked for aggregator", 306 + "did", aggregatorDID, 307 + ) 308 + 309 + return nil 310 + } 311 + 312 + // GetAggregator retrieves the public aggregator information by DID. 313 + // For credential/authentication data, use GetAggregatorCredentials instead. 314 + func (s *APIKeyService) GetAggregator(ctx context.Context, aggregatorDID string) (*Aggregator, error) { 315 + return s.repo.GetAggregator(ctx, aggregatorDID) 316 + } 317 + 318 + // GetAggregatorCredentials retrieves credentials for an aggregator by DID. 319 + func (s *APIKeyService) GetAggregatorCredentials(ctx context.Context, aggregatorDID string) (*AggregatorCredentials, error) { 320 + return s.repo.GetAggregatorCredentials(ctx, aggregatorDID) 321 + } 322 + 323 + // GetAPIKeyInfo returns information about an aggregator's API key (without the actual key). 324 + func (s *APIKeyService) GetAPIKeyInfo(ctx context.Context, aggregatorDID string) (*APIKeyInfo, error) { 325 + creds, err := s.repo.GetAggregatorCredentials(ctx, aggregatorDID) 326 + if err != nil { 327 + return nil, err 328 + } 329 + 330 + if creds.APIKeyHash == "" { 331 + return &APIKeyInfo{ 332 + HasKey: false, 333 + }, nil 334 + } 335 + 336 + return &APIKeyInfo{ 337 + HasKey: true, 338 + KeyPrefix: creds.APIKeyPrefix, 339 + CreatedAt: creds.APIKeyCreatedAt, 340 + LastUsedAt: creds.APIKeyLastUsed, 341 + IsRevoked: creds.APIKeyRevokedAt != nil, 342 + RevokedAt: creds.APIKeyRevokedAt, 343 + }, nil 344 + } 345 + 346 + // APIKeyInfo contains non-sensitive information about an API key 347 + type APIKeyInfo struct { 348 + HasKey bool 349 + KeyPrefix string 350 + CreatedAt *time.Time 351 + LastUsedAt *time.Time 352 + IsRevoked bool 353 + RevokedAt *time.Time 354 + } 355 + 356 + // hashAPIKey creates a SHA-256 hash of the API key for storage 357 + func hashAPIKey(plainKey string) string { 358 + hash := sha256.Sum256([]byte(plainKey)) 359 + return hex.EncodeToString(hash[:]) 360 + } 361 + 362 + // GetFailedLastUsedUpdates returns the count of failed API key last_used timestamp updates. 363 + // This is useful for monitoring and alerting on persistent database issues. 364 + func (s *APIKeyService) GetFailedLastUsedUpdates() int64 { 365 + return s.failedLastUsedUpdates.Load() 366 + } 367 + 368 + // GetFailedNonceUpdates returns the count of failed OAuth nonce updates. 369 + // This is useful for monitoring and alerting on persistent database issues 370 + // that could affect DPoP replay protection. 371 + func (s *APIKeyService) GetFailedNonceUpdates() int64 { 372 + return s.failedNonceUpdates.Load() 373 + }
+1143
internal/core/aggregators/apikey_service_test.go
··· 1 + package aggregators 2 + 3 + import ( 4 + "context" 5 + "crypto/sha256" 6 + "encoding/hex" 7 + "errors" 8 + "testing" 9 + "time" 10 + 11 + "github.com/bluesky-social/indigo/atproto/auth/oauth" 12 + "github.com/bluesky-social/indigo/atproto/syntax" 13 + ) 14 + 15 + // ptrTime returns a pointer to a time.Time (current time) 16 + func ptrTime() *time.Time { 17 + t := time.Now() 18 + return &t 19 + } 20 + 21 + // ptrTimeOffset returns a pointer to a time.Time offset from now 22 + func ptrTimeOffset(d time.Duration) *time.Time { 23 + t := time.Now().Add(d) 24 + return &t 25 + } 26 + 27 + // newTestAPIKeyService creates an APIKeyService with mock dependencies for testing. 28 + // This helper ensures tests don't panic from nil checks added in constructor validation. 29 + func newTestAPIKeyService(repo Repository) *APIKeyService { 30 + mockStore := &mockOAuthStore{} 31 + mockApp := &oauth.ClientApp{Store: mockStore} 32 + return NewAPIKeyService(repo, mockApp) 33 + } 34 + 35 + // mockRepository implements Repository interface for testing 36 + type mockRepository struct { 37 + getAggregatorFunc func(ctx context.Context, did string) (*Aggregator, error) 38 + getByAPIKeyHashFunc func(ctx context.Context, keyHash string) (*Aggregator, error) 39 + getCredentialsByAPIKeyHashFunc func(ctx context.Context, keyHash string) (*AggregatorCredentials, error) 40 + getAggregatorCredentialsFunc func(ctx context.Context, did string) (*AggregatorCredentials, error) 41 + setAPIKeyFunc func(ctx context.Context, did, keyPrefix, keyHash string, oauthCreds *OAuthCredentials) error 42 + updateOAuthTokensFunc func(ctx context.Context, did, accessToken, refreshToken string, expiresAt time.Time) error 43 + updateOAuthNoncesFunc func(ctx context.Context, did, authServerNonce, pdsNonce string) error 44 + updateAPIKeyLastUsedFunc func(ctx context.Context, did string) error 45 + revokeAPIKeyFunc func(ctx context.Context, did string) error 46 + } 47 + 48 + func (m *mockRepository) GetAggregator(ctx context.Context, did string) (*Aggregator, error) { 49 + if m.getAggregatorFunc != nil { 50 + return m.getAggregatorFunc(ctx, did) 51 + } 52 + return &Aggregator{DID: did, DisplayName: "Test Aggregator"}, nil 53 + } 54 + 55 + func (m *mockRepository) GetByAPIKeyHash(ctx context.Context, keyHash string) (*Aggregator, error) { 56 + if m.getByAPIKeyHashFunc != nil { 57 + return m.getByAPIKeyHashFunc(ctx, keyHash) 58 + } 59 + return nil, ErrAggregatorNotFound 60 + } 61 + 62 + func (m *mockRepository) SetAPIKey(ctx context.Context, did, keyPrefix, keyHash string, oauthCreds *OAuthCredentials) error { 63 + if m.setAPIKeyFunc != nil { 64 + return m.setAPIKeyFunc(ctx, did, keyPrefix, keyHash, oauthCreds) 65 + } 66 + return nil 67 + } 68 + 69 + func (m *mockRepository) UpdateOAuthTokens(ctx context.Context, did, accessToken, refreshToken string, expiresAt time.Time) error { 70 + if m.updateOAuthTokensFunc != nil { 71 + return m.updateOAuthTokensFunc(ctx, did, accessToken, refreshToken, expiresAt) 72 + } 73 + return nil 74 + } 75 + 76 + func (m *mockRepository) UpdateOAuthNonces(ctx context.Context, did, authServerNonce, pdsNonce string) error { 77 + if m.updateOAuthNoncesFunc != nil { 78 + return m.updateOAuthNoncesFunc(ctx, did, authServerNonce, pdsNonce) 79 + } 80 + return nil 81 + } 82 + 83 + func (m *mockRepository) UpdateAPIKeyLastUsed(ctx context.Context, did string) error { 84 + if m.updateAPIKeyLastUsedFunc != nil { 85 + return m.updateAPIKeyLastUsedFunc(ctx, did) 86 + } 87 + return nil 88 + } 89 + 90 + func (m *mockRepository) RevokeAPIKey(ctx context.Context, did string) error { 91 + if m.revokeAPIKeyFunc != nil { 92 + return m.revokeAPIKeyFunc(ctx, did) 93 + } 94 + return nil 95 + } 96 + 97 + // Stub implementations for Repository interface methods not used in APIKeyService tests 98 + func (m *mockRepository) CreateAggregator(ctx context.Context, aggregator *Aggregator) error { 99 + return nil 100 + } 101 + 102 + func (m *mockRepository) GetAggregatorsByDIDs(ctx context.Context, dids []string) ([]*Aggregator, error) { 103 + return nil, nil 104 + } 105 + 106 + func (m *mockRepository) UpdateAggregator(ctx context.Context, aggregator *Aggregator) error { 107 + return nil 108 + } 109 + 110 + func (m *mockRepository) DeleteAggregator(ctx context.Context, did string) error { 111 + return nil 112 + } 113 + 114 + func (m *mockRepository) ListAggregators(ctx context.Context, limit, offset int) ([]*Aggregator, error) { 115 + return nil, nil 116 + } 117 + 118 + func (m *mockRepository) IsAggregator(ctx context.Context, did string) (bool, error) { 119 + return false, nil 120 + } 121 + 122 + func (m *mockRepository) CreateAuthorization(ctx context.Context, auth *Authorization) error { 123 + return nil 124 + } 125 + 126 + func (m *mockRepository) GetAuthorization(ctx context.Context, aggregatorDID, communityDID string) (*Authorization, error) { 127 + return nil, nil 128 + } 129 + 130 + func (m *mockRepository) GetAuthorizationByURI(ctx context.Context, recordURI string) (*Authorization, error) { 131 + return nil, nil 132 + } 133 + 134 + func (m *mockRepository) UpdateAuthorization(ctx context.Context, auth *Authorization) error { 135 + return nil 136 + } 137 + 138 + func (m *mockRepository) DeleteAuthorization(ctx context.Context, aggregatorDID, communityDID string) error { 139 + return nil 140 + } 141 + 142 + func (m *mockRepository) DeleteAuthorizationByURI(ctx context.Context, recordURI string) error { 143 + return nil 144 + } 145 + 146 + func (m *mockRepository) ListAuthorizationsForAggregator(ctx context.Context, aggregatorDID string, enabledOnly bool, limit, offset int) ([]*Authorization, error) { 147 + return nil, nil 148 + } 149 + 150 + func (m *mockRepository) ListAuthorizationsForCommunity(ctx context.Context, communityDID string, enabledOnly bool, limit, offset int) ([]*Authorization, error) { 151 + return nil, nil 152 + } 153 + 154 + func (m *mockRepository) IsAuthorized(ctx context.Context, aggregatorDID, communityDID string) (bool, error) { 155 + return false, nil 156 + } 157 + 158 + func (m *mockRepository) RecordAggregatorPost(ctx context.Context, aggregatorDID, communityDID, postURI, postCID string) error { 159 + return nil 160 + } 161 + 162 + func (m *mockRepository) CountRecentPosts(ctx context.Context, aggregatorDID, communityDID string, since time.Time) (int, error) { 163 + return 0, nil 164 + } 165 + 166 + func (m *mockRepository) GetRecentPosts(ctx context.Context, aggregatorDID, communityDID string, since time.Time) ([]*AggregatorPost, error) { 167 + return nil, nil 168 + } 169 + 170 + func (m *mockRepository) GetAggregatorCredentials(ctx context.Context, did string) (*AggregatorCredentials, error) { 171 + if m.getAggregatorCredentialsFunc != nil { 172 + return m.getAggregatorCredentialsFunc(ctx, did) 173 + } 174 + return &AggregatorCredentials{DID: did}, nil 175 + } 176 + 177 + func (m *mockRepository) GetCredentialsByAPIKeyHash(ctx context.Context, keyHash string) (*AggregatorCredentials, error) { 178 + if m.getCredentialsByAPIKeyHashFunc != nil { 179 + return m.getCredentialsByAPIKeyHashFunc(ctx, keyHash) 180 + } 181 + return nil, ErrAggregatorNotFound 182 + } 183 + 184 + func TestHashAPIKey(t *testing.T) { 185 + plainKey := "ckapi_abcdef1234567890abcdef1234567890" 186 + 187 + // Hash the key 188 + hash := hashAPIKey(plainKey) 189 + 190 + // Verify it's a valid hex string 191 + if len(hash) != 64 { 192 + t.Errorf("Expected 64 character hash, got %d", len(hash)) 193 + } 194 + 195 + // Verify it's consistent 196 + hash2 := hashAPIKey(plainKey) 197 + if hash != hash2 { 198 + t.Error("Hash function should be deterministic") 199 + } 200 + 201 + // Verify different keys produce different hashes 202 + differentKey := "ckapi_different1234567890abcdef12" 203 + differentHash := hashAPIKey(differentKey) 204 + if hash == differentHash { 205 + t.Error("Different keys should produce different hashes") 206 + } 207 + 208 + // Verify manually 209 + expectedHash := sha256.Sum256([]byte(plainKey)) 210 + expectedHex := hex.EncodeToString(expectedHash[:]) 211 + if hash != expectedHex { 212 + t.Errorf("Expected %s, got %s", expectedHex, hash) 213 + } 214 + } 215 + 216 + func TestAPIKeyConstants(t *testing.T) { 217 + // Verify the key prefix length assumption 218 + if len(APIKeyPrefix) != 6 { 219 + t.Errorf("Expected APIKeyPrefix to be 6 chars, got %d", len(APIKeyPrefix)) 220 + } 221 + 222 + // Verify total length calculation 223 + // Random bytes are hex-encoded, so they double in length (32 bytes -> 64 chars) 224 + expectedLength := len(APIKeyPrefix) + (APIKeyRandomBytes * 2) 225 + if APIKeyTotalLength != expectedLength { 226 + t.Errorf("APIKeyTotalLength should be %d (prefix + hex-encoded random), got %d", expectedLength, APIKeyTotalLength) 227 + } 228 + 229 + // Verify expected values explicitly 230 + if APIKeyTotalLength != 70 { 231 + t.Errorf("APIKeyTotalLength should be 70 (6 prefix + 64 hex chars), got %d", APIKeyTotalLength) 232 + } 233 + } 234 + 235 + func TestValidateKey_FormatValidation(t *testing.T) { 236 + // We can't test the full ValidateKey without mocking, but we can verify 237 + // the format validation logic by checking the constants 238 + // 32 random bytes hex-encoded = 64 characters 239 + testKey := "ckapi_" + "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" 240 + if len(testKey) != APIKeyTotalLength { 241 + t.Errorf("Test key length mismatch: expected %d, got %d", APIKeyTotalLength, len(testKey)) 242 + } 243 + 244 + // Test key should start with prefix 245 + if testKey[:6] != APIKeyPrefix { 246 + t.Errorf("Test key should start with %s", APIKeyPrefix) 247 + } 248 + 249 + // Verify key length is 70 characters 250 + if len(testKey) != 70 { 251 + t.Errorf("Test key should be 70 characters, got %d", len(testKey)) 252 + } 253 + } 254 + 255 + // ============================================================================= 256 + // AggregatorCredentials Tests 257 + // ============================================================================= 258 + 259 + func TestAggregatorCredentials_HasActiveAPIKey(t *testing.T) { 260 + tests := []struct { 261 + name string 262 + creds AggregatorCredentials 263 + wantActive bool 264 + }{ 265 + { 266 + name: "no key hash", 267 + creds: AggregatorCredentials{}, 268 + wantActive: false, 269 + }, 270 + { 271 + name: "has key hash, not revoked", 272 + creds: AggregatorCredentials{APIKeyHash: "somehash"}, 273 + wantActive: true, 274 + }, 275 + { 276 + name: "has key hash, revoked", 277 + creds: AggregatorCredentials{ 278 + APIKeyHash: "somehash", 279 + APIKeyRevokedAt: ptrTime(), 280 + }, 281 + wantActive: false, 282 + }, 283 + } 284 + 285 + for _, tt := range tests { 286 + t.Run(tt.name, func(t *testing.T) { 287 + got := tt.creds.HasActiveAPIKey() 288 + if got != tt.wantActive { 289 + t.Errorf("HasActiveAPIKey() = %v, want %v", got, tt.wantActive) 290 + } 291 + }) 292 + } 293 + } 294 + 295 + func TestAggregatorCredentials_IsOAuthTokenExpired(t *testing.T) { 296 + tests := []struct { 297 + name string 298 + creds AggregatorCredentials 299 + wantExpired bool 300 + }{ 301 + { 302 + name: "nil expiry", 303 + creds: AggregatorCredentials{}, 304 + wantExpired: true, 305 + }, 306 + { 307 + name: "expired in the past", 308 + creds: AggregatorCredentials{ 309 + OAuthTokenExpiresAt: ptrTimeOffset(-1 * time.Hour), 310 + }, 311 + wantExpired: true, 312 + }, 313 + { 314 + name: "within 5 minute buffer (4 minutes remaining)", 315 + creds: AggregatorCredentials{ 316 + OAuthTokenExpiresAt: ptrTimeOffset(4 * time.Minute), 317 + }, 318 + wantExpired: true, // Should be expired because within buffer 319 + }, 320 + { 321 + name: "exactly at 5 minute buffer", 322 + creds: AggregatorCredentials{ 323 + OAuthTokenExpiresAt: ptrTimeOffset(5 * time.Minute), 324 + }, 325 + wantExpired: true, // Edge case - at exactly buffer time 326 + }, 327 + { 328 + name: "beyond 5 minute buffer (6 minutes remaining)", 329 + creds: AggregatorCredentials{ 330 + OAuthTokenExpiresAt: ptrTimeOffset(6 * time.Minute), 331 + }, 332 + wantExpired: false, // Should not be expired 333 + }, 334 + { 335 + name: "well beyond buffer (1 hour remaining)", 336 + creds: AggregatorCredentials{ 337 + OAuthTokenExpiresAt: ptrTimeOffset(1 * time.Hour), 338 + }, 339 + wantExpired: false, 340 + }, 341 + } 342 + 343 + for _, tt := range tests { 344 + t.Run(tt.name, func(t *testing.T) { 345 + got := tt.creds.IsOAuthTokenExpired() 346 + if got != tt.wantExpired { 347 + t.Errorf("IsOAuthTokenExpired() = %v, want %v", got, tt.wantExpired) 348 + } 349 + }) 350 + } 351 + } 352 + 353 + // ============================================================================= 354 + // ValidateKey Tests 355 + // ============================================================================= 356 + 357 + func TestAPIKeyService_ValidateKey_InvalidFormat(t *testing.T) { 358 + repo := &mockRepository{} 359 + service := newTestAPIKeyService(repo) 360 + 361 + tests := []struct { 362 + name string 363 + key string 364 + wantErr error 365 + }{ 366 + { 367 + name: "empty key", 368 + key: "", 369 + wantErr: ErrAPIKeyInvalid, 370 + }, 371 + { 372 + name: "too short", 373 + key: "ckapi_short", 374 + wantErr: ErrAPIKeyInvalid, 375 + }, 376 + { 377 + name: "wrong prefix", 378 + key: "wrong_0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", 379 + wantErr: ErrAPIKeyInvalid, 380 + }, 381 + { 382 + name: "correct length but wrong prefix", 383 + key: "badpfx0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcd", 384 + wantErr: ErrAPIKeyInvalid, 385 + }, 386 + } 387 + 388 + for _, tt := range tests { 389 + t.Run(tt.name, func(t *testing.T) { 390 + _, err := service.ValidateKey(context.Background(), tt.key) 391 + if !errors.Is(err, tt.wantErr) { 392 + t.Errorf("ValidateKey() error = %v, want %v", err, tt.wantErr) 393 + } 394 + }) 395 + } 396 + } 397 + 398 + func TestAPIKeyService_ValidateKey_NotFound(t *testing.T) { 399 + repo := &mockRepository{ 400 + getCredentialsByAPIKeyHashFunc: func(ctx context.Context, keyHash string) (*AggregatorCredentials, error) { 401 + return nil, ErrAggregatorNotFound 402 + }, 403 + } 404 + service := newTestAPIKeyService(repo) 405 + 406 + // Valid format but key not in database 407 + validKey := "ckapi_0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" 408 + _, err := service.ValidateKey(context.Background(), validKey) 409 + if !errors.Is(err, ErrAPIKeyInvalid) { 410 + t.Errorf("ValidateKey() error = %v, want %v", err, ErrAPIKeyInvalid) 411 + } 412 + } 413 + 414 + func TestAPIKeyService_ValidateKey_Revoked(t *testing.T) { 415 + // The current implementation expects the repository to return ErrAPIKeyRevoked 416 + // when the API key has been revoked. This is done at the repository layer. 417 + repo := &mockRepository{ 418 + getCredentialsByAPIKeyHashFunc: func(ctx context.Context, keyHash string) (*AggregatorCredentials, error) { 419 + // Repository returns error for revoked keys 420 + return nil, ErrAPIKeyRevoked 421 + }, 422 + } 423 + service := newTestAPIKeyService(repo) 424 + 425 + validKey := "ckapi_0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" 426 + _, err := service.ValidateKey(context.Background(), validKey) 427 + if !errors.Is(err, ErrAPIKeyRevoked) { 428 + t.Errorf("ValidateKey() error = %v, want %v", err, ErrAPIKeyRevoked) 429 + } 430 + } 431 + 432 + func TestAPIKeyService_ValidateKey_Success(t *testing.T) { 433 + expectedDID := "did:plc:aggregator123" 434 + lastUsedChan := make(chan struct{}) 435 + 436 + repo := &mockRepository{ 437 + getCredentialsByAPIKeyHashFunc: func(ctx context.Context, keyHash string) (*AggregatorCredentials, error) { 438 + return &AggregatorCredentials{ 439 + DID: expectedDID, 440 + APIKeyHash: keyHash, 441 + APIKeyPrefix: "ckapi_0123", 442 + }, nil 443 + }, 444 + updateAPIKeyLastUsedFunc: func(ctx context.Context, did string) error { 445 + close(lastUsedChan) 446 + return nil 447 + }, 448 + } 449 + service := newTestAPIKeyService(repo) 450 + 451 + validKey := "ckapi_0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" 452 + creds, err := service.ValidateKey(context.Background(), validKey) 453 + if err != nil { 454 + t.Fatalf("ValidateKey() unexpected error: %v", err) 455 + } 456 + 457 + if creds.DID != expectedDID { 458 + t.Errorf("ValidateKey() DID = %s, want %s", creds.DID, expectedDID) 459 + } 460 + 461 + // Wait for async update with timeout using channel-based synchronization 462 + select { 463 + case <-lastUsedChan: 464 + // Success - UpdateAPIKeyLastUsed was called 465 + case <-time.After(1 * time.Second): 466 + t.Error("Expected UpdateAPIKeyLastUsed to be called (timeout)") 467 + } 468 + } 469 + 470 + // ============================================================================= 471 + // GenerateKey Tests 472 + // ============================================================================= 473 + 474 + func TestAPIKeyService_GenerateKey_AggregatorNotFound(t *testing.T) { 475 + repo := &mockRepository{ 476 + getAggregatorFunc: func(ctx context.Context, did string) (*Aggregator, error) { 477 + return nil, ErrAggregatorNotFound 478 + }, 479 + } 480 + service := newTestAPIKeyService(repo) 481 + 482 + did, _ := syntax.ParseDID("did:plc:test123") 483 + session := &oauth.ClientSessionData{ 484 + AccountDID: did, 485 + AccessToken: "test_token", 486 + } 487 + 488 + _, _, err := service.GenerateKey(context.Background(), "did:plc:test123", session) 489 + if err == nil { 490 + t.Error("GenerateKey() expected error, got nil") 491 + } 492 + } 493 + 494 + func TestAPIKeyService_GenerateKey_DIDMismatch(t *testing.T) { 495 + repo := &mockRepository{ 496 + getAggregatorFunc: func(ctx context.Context, did string) (*Aggregator, error) { 497 + return &Aggregator{DID: did}, nil 498 + }, 499 + } 500 + service := newTestAPIKeyService(repo) 501 + 502 + // Session DID doesn't match requested aggregator DID 503 + sessionDID, _ := syntax.ParseDID("did:plc:different") 504 + session := &oauth.ClientSessionData{ 505 + AccountDID: sessionDID, 506 + AccessToken: "test_token", 507 + } 508 + 509 + _, _, err := service.GenerateKey(context.Background(), "did:plc:aggregator123", session) 510 + if err == nil { 511 + t.Error("GenerateKey() expected DID mismatch error, got nil") 512 + } 513 + if !errors.Is(err, nil) && err.Error() == "" { 514 + // Just check there's an error for DID mismatch 515 + } 516 + } 517 + 518 + func TestAPIKeyService_GenerateKey_SetAPIKeyError(t *testing.T) { 519 + expectedError := errors.New("database error") 520 + repo := &mockRepository{ 521 + getAggregatorFunc: func(ctx context.Context, did string) (*Aggregator, error) { 522 + return &Aggregator{DID: did, DisplayName: "Test"}, nil 523 + }, 524 + setAPIKeyFunc: func(ctx context.Context, did, keyPrefix, keyHash string, oauthCreds *OAuthCredentials) error { 525 + return expectedError 526 + }, 527 + } 528 + 529 + // Create a minimal mock OAuth store 530 + mockStore := &mockOAuthStore{} 531 + mockApp := &oauth.ClientApp{Store: mockStore} 532 + 533 + service := NewAPIKeyService(repo, mockApp) 534 + 535 + did, _ := syntax.ParseDID("did:plc:aggregator123") 536 + session := &oauth.ClientSessionData{ 537 + AccountDID: did, 538 + AccessToken: "test_token", 539 + } 540 + 541 + _, _, err := service.GenerateKey(context.Background(), "did:plc:aggregator123", session) 542 + if err == nil { 543 + t.Error("GenerateKey() expected error, got nil") 544 + } 545 + } 546 + 547 + func TestAPIKeyService_GenerateKey_Success(t *testing.T) { 548 + aggregatorDID := "did:plc:aggregator123" 549 + var storedKeyPrefix, storedKeyHash string 550 + var storedOAuthCreds *OAuthCredentials 551 + var savedSession *oauth.ClientSessionData 552 + 553 + repo := &mockRepository{ 554 + getAggregatorFunc: func(ctx context.Context, did string) (*Aggregator, error) { 555 + if did != aggregatorDID { 556 + return nil, ErrAggregatorNotFound 557 + } 558 + return &Aggregator{ 559 + DID: did, 560 + DisplayName: "Test Aggregator", 561 + }, nil 562 + }, 563 + setAPIKeyFunc: func(ctx context.Context, did, keyPrefix, keyHash string, oauthCreds *OAuthCredentials) error { 564 + storedKeyPrefix = keyPrefix 565 + storedKeyHash = keyHash 566 + storedOAuthCreds = oauthCreds 567 + return nil 568 + }, 569 + } 570 + 571 + // Create mock OAuth store that tracks saved sessions 572 + mockStore := &mockOAuthStore{ 573 + saveSessionFunc: func(ctx context.Context, session oauth.ClientSessionData) error { 574 + savedSession = &session 575 + return nil 576 + }, 577 + } 578 + mockApp := &oauth.ClientApp{Store: mockStore} 579 + 580 + service := NewAPIKeyService(repo, mockApp) 581 + 582 + // Create OAuth session 583 + did, _ := syntax.ParseDID(aggregatorDID) 584 + session := &oauth.ClientSessionData{ 585 + AccountDID: did, 586 + SessionID: "original_session", 587 + AccessToken: "test_access_token", 588 + RefreshToken: "test_refresh_token", 589 + HostURL: "https://pds.example.com", 590 + AuthServerURL: "https://auth.example.com", 591 + AuthServerTokenEndpoint: "https://auth.example.com/oauth/token", 592 + DPoPPrivateKeyMultibase: "z1234567890", 593 + DPoPAuthServerNonce: "auth_nonce_123", 594 + DPoPHostNonce: "host_nonce_456", 595 + } 596 + 597 + plainKey, keyPrefix, err := service.GenerateKey(context.Background(), aggregatorDID, session) 598 + if err != nil { 599 + t.Fatalf("GenerateKey() unexpected error: %v", err) 600 + } 601 + 602 + // Verify key format 603 + if len(plainKey) != APIKeyTotalLength { 604 + t.Errorf("GenerateKey() plainKey length = %d, want %d", len(plainKey), APIKeyTotalLength) 605 + } 606 + if plainKey[:6] != APIKeyPrefix { 607 + t.Errorf("GenerateKey() plainKey prefix = %s, want %s", plainKey[:6], APIKeyPrefix) 608 + } 609 + 610 + // Verify key prefix is first 12 chars 611 + if keyPrefix != plainKey[:12] { 612 + t.Errorf("GenerateKey() keyPrefix = %s, want %s", keyPrefix, plainKey[:12]) 613 + } 614 + 615 + // Verify hash was stored (SHA-256 produces 64 hex chars) 616 + if len(storedKeyHash) != 64 { 617 + t.Errorf("GenerateKey() stored hash length = %d, want 64", len(storedKeyHash)) 618 + } 619 + 620 + // Verify hash matches the key 621 + expectedHash := hashAPIKey(plainKey) 622 + if storedKeyHash != expectedHash { 623 + t.Errorf("GenerateKey() stored hash doesn't match key hash") 624 + } 625 + 626 + // Verify stored key prefix matches returned prefix 627 + if storedKeyPrefix != keyPrefix { 628 + t.Errorf("GenerateKey() stored keyPrefix = %s, want %s", storedKeyPrefix, keyPrefix) 629 + } 630 + 631 + // Verify OAuth credentials were saved 632 + if storedOAuthCreds == nil { 633 + t.Fatal("GenerateKey() OAuth credentials not stored") 634 + } 635 + if storedOAuthCreds.AccessToken != session.AccessToken { 636 + t.Errorf("GenerateKey() stored AccessToken = %s, want %s", storedOAuthCreds.AccessToken, session.AccessToken) 637 + } 638 + if storedOAuthCreds.RefreshToken != session.RefreshToken { 639 + t.Errorf("GenerateKey() stored RefreshToken = %s, want %s", storedOAuthCreds.RefreshToken, session.RefreshToken) 640 + } 641 + if storedOAuthCreds.PDSURL != session.HostURL { 642 + t.Errorf("GenerateKey() stored PDSURL = %s, want %s", storedOAuthCreds.PDSURL, session.HostURL) 643 + } 644 + if storedOAuthCreds.AuthServerIss != session.AuthServerURL { 645 + t.Errorf("GenerateKey() stored AuthServerIss = %s, want %s", storedOAuthCreds.AuthServerIss, session.AuthServerURL) 646 + } 647 + if storedOAuthCreds.DPoPPrivateKeyMultibase != session.DPoPPrivateKeyMultibase { 648 + t.Errorf("GenerateKey() stored DPoPPrivateKeyMultibase mismatch") 649 + } 650 + if storedOAuthCreds.DPoPAuthServerNonce != session.DPoPAuthServerNonce { 651 + t.Errorf("GenerateKey() stored DPoPAuthServerNonce = %s, want %s", storedOAuthCreds.DPoPAuthServerNonce, session.DPoPAuthServerNonce) 652 + } 653 + if storedOAuthCreds.DPoPPDSNonce != session.DPoPHostNonce { 654 + t.Errorf("GenerateKey() stored DPoPPDSNonce = %s, want %s", storedOAuthCreds.DPoPPDSNonce, session.DPoPHostNonce) 655 + } 656 + 657 + // Verify session was saved to OAuth store 658 + if savedSession == nil { 659 + t.Fatal("GenerateKey() session not saved to OAuth store") 660 + } 661 + if savedSession.SessionID != DefaultSessionID { 662 + t.Errorf("GenerateKey() saved session ID = %s, want %s", savedSession.SessionID, DefaultSessionID) 663 + } 664 + if savedSession.AccessToken != session.AccessToken { 665 + t.Errorf("GenerateKey() saved session AccessToken mismatch") 666 + } 667 + } 668 + 669 + func TestAPIKeyService_GenerateKey_OAuthStoreSaveError(t *testing.T) { 670 + // Test that OAuth session save failure aborts key creation early 671 + // With the new ordering (OAuth session first, then API key), if OAuth save fails, 672 + // we abort immediately without creating an API key. 673 + aggregatorDID := "did:plc:aggregator123" 674 + setAPIKeyCalled := false 675 + 676 + repo := &mockRepository{ 677 + getAggregatorFunc: func(ctx context.Context, did string) (*Aggregator, error) { 678 + return &Aggregator{DID: did, DisplayName: "Test"}, nil 679 + }, 680 + setAPIKeyFunc: func(ctx context.Context, did, keyPrefix, keyHash string, oauthCreds *OAuthCredentials) error { 681 + setAPIKeyCalled = true 682 + return nil 683 + }, 684 + } 685 + 686 + // Create mock OAuth store that fails on save 687 + mockStore := &mockOAuthStore{ 688 + saveSessionFunc: func(ctx context.Context, session oauth.ClientSessionData) error { 689 + return errors.New("failed to save session") 690 + }, 691 + } 692 + mockApp := &oauth.ClientApp{Store: mockStore} 693 + 694 + service := NewAPIKeyService(repo, mockApp) 695 + 696 + did, _ := syntax.ParseDID(aggregatorDID) 697 + session := &oauth.ClientSessionData{ 698 + AccountDID: did, 699 + AccessToken: "test_token", 700 + } 701 + 702 + _, _, err := service.GenerateKey(context.Background(), aggregatorDID, session) 703 + if err == nil { 704 + t.Error("GenerateKey() expected error when OAuth store save fails, got nil") 705 + } 706 + 707 + // Verify SetAPIKey was NOT called - we should abort before storing the key 708 + // This prevents the race condition where an API key exists but can't refresh tokens 709 + if setAPIKeyCalled { 710 + t.Error("GenerateKey() should NOT call SetAPIKey when OAuth session save fails") 711 + } 712 + } 713 + 714 + // mockOAuthStore implements oauth.ClientAuthStore for testing 715 + type mockOAuthStore struct { 716 + getSessionFunc func(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) 717 + saveSessionFunc func(ctx context.Context, session oauth.ClientSessionData) error 718 + deleteSessionFunc func(ctx context.Context, did syntax.DID, sessionID string) error 719 + getAuthRequestInfoFunc func(ctx context.Context, state string) (*oauth.AuthRequestData, error) 720 + saveAuthRequestInfoFunc func(ctx context.Context, info oauth.AuthRequestData) error 721 + deleteAuthRequestInfoFunc func(ctx context.Context, state string) error 722 + } 723 + 724 + func (m *mockOAuthStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) { 725 + if m.getSessionFunc != nil { 726 + return m.getSessionFunc(ctx, did, sessionID) 727 + } 728 + return nil, errors.New("session not found") 729 + } 730 + 731 + func (m *mockOAuthStore) SaveSession(ctx context.Context, session oauth.ClientSessionData) error { 732 + if m.saveSessionFunc != nil { 733 + return m.saveSessionFunc(ctx, session) 734 + } 735 + return nil 736 + } 737 + 738 + func (m *mockOAuthStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error { 739 + if m.deleteSessionFunc != nil { 740 + return m.deleteSessionFunc(ctx, did, sessionID) 741 + } 742 + return nil 743 + } 744 + 745 + func (m *mockOAuthStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) { 746 + if m.getAuthRequestInfoFunc != nil { 747 + return m.getAuthRequestInfoFunc(ctx, state) 748 + } 749 + return nil, errors.New("not found") 750 + } 751 + 752 + func (m *mockOAuthStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error { 753 + if m.saveAuthRequestInfoFunc != nil { 754 + return m.saveAuthRequestInfoFunc(ctx, info) 755 + } 756 + return nil 757 + } 758 + 759 + func (m *mockOAuthStore) DeleteAuthRequestInfo(ctx context.Context, state string) error { 760 + if m.deleteAuthRequestInfoFunc != nil { 761 + return m.deleteAuthRequestInfoFunc(ctx, state) 762 + } 763 + return nil 764 + } 765 + 766 + // ============================================================================= 767 + // RevokeKey Tests 768 + // ============================================================================= 769 + 770 + func TestAPIKeyService_RevokeKey_Success(t *testing.T) { 771 + revokeCalled := false 772 + revokedDID := "" 773 + 774 + repo := &mockRepository{ 775 + revokeAPIKeyFunc: func(ctx context.Context, did string) error { 776 + revokeCalled = true 777 + revokedDID = did 778 + return nil 779 + }, 780 + } 781 + service := newTestAPIKeyService(repo) 782 + 783 + err := service.RevokeKey(context.Background(), "did:plc:aggregator123") 784 + if err != nil { 785 + t.Fatalf("RevokeKey() unexpected error: %v", err) 786 + } 787 + 788 + if !revokeCalled { 789 + t.Error("Expected RevokeAPIKey to be called on repository") 790 + } 791 + if revokedDID != "did:plc:aggregator123" { 792 + t.Errorf("RevokeKey() called with DID = %s, want did:plc:aggregator123", revokedDID) 793 + } 794 + } 795 + 796 + func TestAPIKeyService_RevokeKey_Error(t *testing.T) { 797 + expectedError := errors.New("database error") 798 + repo := &mockRepository{ 799 + revokeAPIKeyFunc: func(ctx context.Context, did string) error { 800 + return expectedError 801 + }, 802 + } 803 + service := newTestAPIKeyService(repo) 804 + 805 + err := service.RevokeKey(context.Background(), "did:plc:aggregator123") 806 + if err == nil { 807 + t.Error("RevokeKey() expected error, got nil") 808 + } 809 + } 810 + 811 + // ============================================================================= 812 + // GetAPIKeyInfo Tests 813 + // ============================================================================= 814 + 815 + func TestAPIKeyService_GetAPIKeyInfo_NoKey(t *testing.T) { 816 + repo := &mockRepository{ 817 + getAggregatorCredentialsFunc: func(ctx context.Context, did string) (*AggregatorCredentials, error) { 818 + return &AggregatorCredentials{ 819 + DID: did, 820 + APIKeyHash: "", // No key 821 + }, nil 822 + }, 823 + } 824 + service := newTestAPIKeyService(repo) 825 + 826 + info, err := service.GetAPIKeyInfo(context.Background(), "did:plc:aggregator123") 827 + if err != nil { 828 + t.Fatalf("GetAPIKeyInfo() unexpected error: %v", err) 829 + } 830 + 831 + if info.HasKey { 832 + t.Error("GetAPIKeyInfo() HasKey = true, want false") 833 + } 834 + } 835 + 836 + func TestAPIKeyService_GetAPIKeyInfo_HasActiveKey(t *testing.T) { 837 + createdAt := time.Now().Add(-24 * time.Hour) 838 + lastUsed := time.Now().Add(-1 * time.Hour) 839 + 840 + repo := &mockRepository{ 841 + getAggregatorCredentialsFunc: func(ctx context.Context, did string) (*AggregatorCredentials, error) { 842 + return &AggregatorCredentials{ 843 + DID: did, 844 + APIKeyHash: "somehash", 845 + APIKeyPrefix: "ckapi_test12", 846 + APIKeyCreatedAt: &createdAt, 847 + APIKeyLastUsed: &lastUsed, 848 + }, nil 849 + }, 850 + } 851 + service := newTestAPIKeyService(repo) 852 + 853 + info, err := service.GetAPIKeyInfo(context.Background(), "did:plc:aggregator123") 854 + if err != nil { 855 + t.Fatalf("GetAPIKeyInfo() unexpected error: %v", err) 856 + } 857 + 858 + if !info.HasKey { 859 + t.Error("GetAPIKeyInfo() HasKey = false, want true") 860 + } 861 + if info.KeyPrefix != "ckapi_test12" { 862 + t.Errorf("GetAPIKeyInfo() KeyPrefix = %s, want ckapi_test12", info.KeyPrefix) 863 + } 864 + if info.IsRevoked { 865 + t.Error("GetAPIKeyInfo() IsRevoked = true, want false") 866 + } 867 + if info.CreatedAt == nil || !info.CreatedAt.Equal(createdAt) { 868 + t.Error("GetAPIKeyInfo() CreatedAt mismatch") 869 + } 870 + if info.LastUsedAt == nil || !info.LastUsedAt.Equal(lastUsed) { 871 + t.Error("GetAPIKeyInfo() LastUsedAt mismatch") 872 + } 873 + } 874 + 875 + func TestAPIKeyService_GetAPIKeyInfo_RevokedKey(t *testing.T) { 876 + revokedAt := time.Now().Add(-1 * time.Hour) 877 + 878 + repo := &mockRepository{ 879 + getAggregatorCredentialsFunc: func(ctx context.Context, did string) (*AggregatorCredentials, error) { 880 + return &AggregatorCredentials{ 881 + DID: did, 882 + APIKeyHash: "somehash", 883 + APIKeyPrefix: "ckapi_test12", 884 + APIKeyRevokedAt: &revokedAt, 885 + }, nil 886 + }, 887 + } 888 + service := newTestAPIKeyService(repo) 889 + 890 + info, err := service.GetAPIKeyInfo(context.Background(), "did:plc:aggregator123") 891 + if err != nil { 892 + t.Fatalf("GetAPIKeyInfo() unexpected error: %v", err) 893 + } 894 + 895 + if !info.HasKey { 896 + t.Error("GetAPIKeyInfo() HasKey = false, want true (revoked keys still exist)") 897 + } 898 + if !info.IsRevoked { 899 + t.Error("GetAPIKeyInfo() IsRevoked = false, want true") 900 + } 901 + if info.RevokedAt == nil || !info.RevokedAt.Equal(revokedAt) { 902 + t.Error("GetAPIKeyInfo() RevokedAt mismatch") 903 + } 904 + } 905 + 906 + func TestAPIKeyService_GetAPIKeyInfo_NotFound(t *testing.T) { 907 + repo := &mockRepository{ 908 + getAggregatorCredentialsFunc: func(ctx context.Context, did string) (*AggregatorCredentials, error) { 909 + return nil, ErrAggregatorNotFound 910 + }, 911 + } 912 + service := newTestAPIKeyService(repo) 913 + 914 + _, err := service.GetAPIKeyInfo(context.Background(), "did:plc:nonexistent") 915 + if !errors.Is(err, ErrAggregatorNotFound) { 916 + t.Errorf("GetAPIKeyInfo() error = %v, want ErrAggregatorNotFound", err) 917 + } 918 + } 919 + 920 + // ============================================================================= 921 + // RefreshTokensIfNeeded Tests 922 + // ============================================================================= 923 + 924 + func TestAPIKeyService_RefreshTokensIfNeeded_TokensStillValid(t *testing.T) { 925 + // Tokens expire in 1 hour - well beyond the 5 minute buffer 926 + expiresAt := time.Now().Add(1 * time.Hour) 927 + 928 + creds := &AggregatorCredentials{ 929 + DID: "did:plc:aggregator123", 930 + OAuthTokenExpiresAt: &expiresAt, 931 + } 932 + 933 + repo := &mockRepository{} 934 + service := newTestAPIKeyService(repo) 935 + 936 + err := service.RefreshTokensIfNeeded(context.Background(), creds) 937 + if err != nil { 938 + t.Fatalf("RefreshTokensIfNeeded() unexpected error: %v", err) 939 + } 940 + 941 + // No refresh should have happened - we can't easily verify this without 942 + // more complex mocking, but the absence of error is the key indicator 943 + } 944 + 945 + func TestAPIKeyService_RefreshTokensIfNeeded_WithinBuffer(t *testing.T) { 946 + // Token expires in 4 minutes - within the 5 minute buffer, so needs refresh 947 + // This test verifies that when tokens are within the buffer, the service 948 + // attempts to refresh them. 949 + // 950 + // Note: Full integration testing of token refresh requires a real OAuth app. 951 + // This test is intentionally skipped as it would require extensive mocking 952 + // of the indigo OAuth library internals. 953 + t.Skip("RefreshTokensIfNeeded requires fully configured OAuth app - covered by integration tests") 954 + } 955 + 956 + func TestAPIKeyService_RefreshTokensIfNeeded_ExpiredNilTokens(t *testing.T) { 957 + // When OAuthTokenExpiresAt is nil, tokens need refresh 958 + // This should also attempt to refresh (and fail with nil OAuth app) 959 + t.Skip("RefreshTokensIfNeeded requires fully configured OAuth app - covered by integration tests") 960 + } 961 + 962 + // ============================================================================= 963 + // GetAccessToken Tests 964 + // ============================================================================= 965 + 966 + func TestAPIKeyService_GetAccessToken_ValidAggregatorTokensNotExpired(t *testing.T) { 967 + // Tokens expire in 1 hour - well beyond the 5 minute buffer 968 + expiresAt := time.Now().Add(1 * time.Hour) 969 + expectedToken := "valid_access_token_123" 970 + 971 + creds := &AggregatorCredentials{ 972 + DID: "did:plc:aggregator123", 973 + OAuthAccessToken: expectedToken, 974 + OAuthTokenExpiresAt: &expiresAt, 975 + } 976 + 977 + repo := &mockRepository{} 978 + service := newTestAPIKeyService(repo) 979 + 980 + token, err := service.GetAccessToken(context.Background(), creds) 981 + if err != nil { 982 + t.Fatalf("GetAccessToken() unexpected error: %v", err) 983 + } 984 + 985 + if token != expectedToken { 986 + t.Errorf("GetAccessToken() = %s, want %s", token, expectedToken) 987 + } 988 + } 989 + 990 + func TestAPIKeyService_GetAccessToken_ExpiredTokens(t *testing.T) { 991 + // Tokens expired 1 hour ago - requires refresh 992 + // Since refresh requires a real OAuth app, this test verifies the error path 993 + expiresAt := time.Now().Add(-1 * time.Hour) 994 + 995 + creds := &AggregatorCredentials{ 996 + DID: "did:plc:aggregator123", 997 + OAuthAccessToken: "expired_token", 998 + OAuthRefreshToken: "refresh_token", 999 + OAuthTokenExpiresAt: &expiresAt, 1000 + } 1001 + 1002 + repo := &mockRepository{} 1003 + // Service has nil OAuth app, so refresh will fail 1004 + service := newTestAPIKeyService(repo) 1005 + 1006 + _, err := service.GetAccessToken(context.Background(), creds) 1007 + if err == nil { 1008 + t.Error("GetAccessToken() expected error when tokens are expired and no OAuth app configured, got nil") 1009 + } 1010 + } 1011 + 1012 + func TestAPIKeyService_GetAccessToken_NilExpiry(t *testing.T) { 1013 + // Nil expiry means tokens need refresh 1014 + creds := &AggregatorCredentials{ 1015 + DID: "did:plc:aggregator123", 1016 + OAuthAccessToken: "some_token", 1017 + OAuthTokenExpiresAt: nil, // nil means needs refresh 1018 + } 1019 + 1020 + repo := &mockRepository{} 1021 + service := newTestAPIKeyService(repo) 1022 + 1023 + _, err := service.GetAccessToken(context.Background(), creds) 1024 + if err == nil { 1025 + t.Error("GetAccessToken() expected error when expiry is nil and no OAuth app configured, got nil") 1026 + } 1027 + } 1028 + 1029 + func TestAPIKeyService_GetAccessToken_WithinExpiryBuffer(t *testing.T) { 1030 + // Tokens expire in 4 minutes - within the 5 minute buffer, so needs refresh 1031 + expiresAt := time.Now().Add(4 * time.Minute) 1032 + 1033 + creds := &AggregatorCredentials{ 1034 + DID: "did:plc:aggregator123", 1035 + OAuthAccessToken: "soon_to_expire_token", 1036 + OAuthRefreshToken: "refresh_token", 1037 + OAuthTokenExpiresAt: &expiresAt, 1038 + } 1039 + 1040 + repo := &mockRepository{} 1041 + service := newTestAPIKeyService(repo) 1042 + 1043 + // Should attempt refresh and fail since no OAuth app is configured 1044 + _, err := service.GetAccessToken(context.Background(), creds) 1045 + if err == nil { 1046 + t.Error("GetAccessToken() expected error when tokens are within buffer and no OAuth app configured, got nil") 1047 + } 1048 + } 1049 + 1050 + func TestAPIKeyService_GetAccessToken_RevokedKey(t *testing.T) { 1051 + // Test behavior when aggregator has a revoked key 1052 + // The API key check happens in ValidateKey, but GetAccessToken should still work 1053 + // if called directly with a valid aggregator (before revocation is detected) 1054 + expiresAt := time.Now().Add(1 * time.Hour) 1055 + revokedAt := time.Now().Add(-30 * time.Minute) 1056 + expectedToken := "valid_access_token" 1057 + 1058 + creds := &AggregatorCredentials{ 1059 + DID: "did:plc:aggregator123", 1060 + APIKeyRevokedAt: &revokedAt, // Key is revoked 1061 + OAuthAccessToken: expectedToken, 1062 + OAuthTokenExpiresAt: &expiresAt, 1063 + } 1064 + 1065 + repo := &mockRepository{} 1066 + service := newTestAPIKeyService(repo) 1067 + 1068 + // GetAccessToken doesn't check revocation - that's done at ValidateKey level 1069 + // It just returns the token if valid 1070 + token, err := service.GetAccessToken(context.Background(), creds) 1071 + if err != nil { 1072 + t.Fatalf("GetAccessToken() unexpected error: %v", err) 1073 + } 1074 + 1075 + if token != expectedToken { 1076 + t.Errorf("GetAccessToken() = %s, want %s", token, expectedToken) 1077 + } 1078 + } 1079 + 1080 + func TestAPIKeyService_FailureCounters_InitiallyZero(t *testing.T) { 1081 + repo := &mockRepository{} 1082 + service := newTestAPIKeyService(repo) 1083 + 1084 + if got := service.GetFailedLastUsedUpdates(); got != 0 { 1085 + t.Errorf("GetFailedLastUsedUpdates() = %d, want 0", got) 1086 + } 1087 + 1088 + if got := service.GetFailedNonceUpdates(); got != 0 { 1089 + t.Errorf("GetFailedNonceUpdates() = %d, want 0", got) 1090 + } 1091 + } 1092 + 1093 + func TestAPIKeyService_FailedLastUsedUpdates_IncrementsOnError(t *testing.T) { 1094 + // Create a valid API key 1095 + plainKey := APIKeyPrefix + "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789" 1096 + keyHash := hashAPIKey(plainKey) 1097 + 1098 + updateCalled := make(chan struct{}, 1) 1099 + repo := &mockRepository{ 1100 + getCredentialsByAPIKeyHashFunc: func(ctx context.Context, hash string) (*AggregatorCredentials, error) { 1101 + if hash == keyHash { 1102 + return &AggregatorCredentials{ 1103 + DID: "did:plc:aggregator123", 1104 + APIKeyHash: keyHash, 1105 + }, nil 1106 + } 1107 + return nil, ErrAPIKeyInvalid 1108 + }, 1109 + updateAPIKeyLastUsedFunc: func(ctx context.Context, did string) error { 1110 + defer func() { updateCalled <- struct{}{} }() 1111 + return errors.New("database connection failed") 1112 + }, 1113 + } 1114 + 1115 + service := newTestAPIKeyService(repo) 1116 + 1117 + // Initial count should be 0 1118 + if got := service.GetFailedLastUsedUpdates(); got != 0 { 1119 + t.Errorf("GetFailedLastUsedUpdates() initial = %d, want 0", got) 1120 + } 1121 + 1122 + // Validate the key (triggers async last_used update) 1123 + _, err := service.ValidateKey(context.Background(), plainKey) 1124 + if err != nil { 1125 + t.Fatalf("ValidateKey() unexpected error: %v", err) 1126 + } 1127 + 1128 + // Wait for async update to complete 1129 + select { 1130 + case <-updateCalled: 1131 + // Update was called 1132 + case <-time.After(2 * time.Second): 1133 + t.Fatal("timeout waiting for async UpdateAPIKeyLastUsed call") 1134 + } 1135 + 1136 + // Give a moment for the counter to be incremented 1137 + time.Sleep(10 * time.Millisecond) 1138 + 1139 + // Counter should now be 1 1140 + if got := service.GetFailedLastUsedUpdates(); got != 1 { 1141 + t.Errorf("GetFailedLastUsedUpdates() after failure = %d, want 1", got) 1142 + } 1143 + }
+22 -1
internal/core/aggregators/errors.go
··· 16 16 ErrConfigSchemaValidation = errors.New("configuration does not match aggregator's schema") 17 17 ErrNotModerator = errors.New("user is not a moderator of this community") 18 18 ErrNotImplemented = errors.New("feature not yet implemented") // For Phase 2 write-forward operations 19 + 20 + // API Key authentication errors 21 + ErrAPIKeyRevoked = errors.New("API key has been revoked") 22 + ErrAPIKeyInvalid = errors.New("invalid API key") 23 + ErrAPIKeyNotFound = errors.New("API key not found for this aggregator") 24 + ErrOAuthTokenExpired = errors.New("OAuth token has expired and needs refresh") 25 + ErrOAuthRefreshFailed = errors.New("failed to refresh OAuth token") 26 + ErrOAuthSessionMismatch = errors.New("OAuth session DID does not match aggregator DID") 19 27 ) 20 28 21 29 // ValidationError represents a validation error with field details ··· 38 46 39 47 // Error classification helpers for handlers to map to HTTP status codes 40 48 func IsNotFound(err error) bool { 41 - return errors.Is(err, ErrAggregatorNotFound) || errors.Is(err, ErrAuthorizationNotFound) 49 + return errors.Is(err, ErrAggregatorNotFound) || 50 + errors.Is(err, ErrAuthorizationNotFound) || 51 + errors.Is(err, ErrAPIKeyNotFound) 42 52 } 43 53 44 54 func IsValidationError(err error) bool { ··· 61 71 func IsNotImplemented(err error) bool { 62 72 return errors.Is(err, ErrNotImplemented) 63 73 } 74 + 75 + func IsAPIKeyError(err error) bool { 76 + return errors.Is(err, ErrAPIKeyRevoked) || 77 + errors.Is(err, ErrAPIKeyInvalid) || 78 + errors.Is(err, ErrAPIKeyNotFound) 79 + } 80 + 81 + func IsOAuthError(err error) bool { 82 + return errors.Is(err, ErrOAuthTokenExpired) || 83 + errors.Is(err, ErrOAuthRefreshFailed) 84 + }
+43
internal/core/aggregators/interfaces.go
··· 3 3 import ( 4 4 "context" 5 5 "time" 6 + 7 + "github.com/bluesky-social/indigo/atproto/auth/oauth" 6 8 ) 7 9 8 10 // Repository defines the interface for aggregator data persistence ··· 34 36 RecordAggregatorPost(ctx context.Context, aggregatorDID, communityDID, postURI, postCID string) error 35 37 CountRecentPosts(ctx context.Context, aggregatorDID, communityDID string, since time.Time) (int, error) 36 38 GetRecentPosts(ctx context.Context, aggregatorDID, communityDID string, since time.Time) ([]*AggregatorPost, error) 39 + 40 + // API Key Authentication 41 + // GetByAPIKeyHash looks up an aggregator by their API key hash for authentication 42 + GetByAPIKeyHash(ctx context.Context, keyHash string) (*Aggregator, error) 43 + // GetAggregatorCredentials retrieves only the credential fields for an aggregator. 44 + // Used by APIKeyService for authentication operations where full aggregator is not needed. 45 + GetAggregatorCredentials(ctx context.Context, did string) (*AggregatorCredentials, error) 46 + // GetCredentialsByAPIKeyHash looks up aggregator credentials by their API key hash. 47 + // Returns ErrAPIKeyRevoked if the key has been revoked. 48 + // Returns ErrAPIKeyInvalid if no aggregator found with that hash. 49 + GetCredentialsByAPIKeyHash(ctx context.Context, keyHash string) (*AggregatorCredentials, error) 50 + // SetAPIKey stores API key credentials and OAuth session for an aggregator 51 + SetAPIKey(ctx context.Context, did, keyPrefix, keyHash string, oauthCreds *OAuthCredentials) error 52 + // UpdateOAuthTokens updates OAuth tokens after a refresh operation 53 + UpdateOAuthTokens(ctx context.Context, did, accessToken, refreshToken string, expiresAt time.Time) error 54 + // UpdateOAuthNonces updates DPoP nonces after token operations 55 + UpdateOAuthNonces(ctx context.Context, did, authServerNonce, pdsNonce string) error 56 + // UpdateAPIKeyLastUsed updates the last_used_at timestamp for audit purposes 57 + UpdateAPIKeyLastUsed(ctx context.Context, did string) error 58 + // RevokeAPIKey marks an API key as revoked (sets api_key_revoked_at) 59 + RevokeAPIKey(ctx context.Context, did string) error 37 60 } 38 61 39 62 // Service defines the interface for aggregator business logic ··· 60 83 // Post tracking (called after successful post creation) 61 84 RecordAggregatorPost(ctx context.Context, aggregatorDID, communityDID, postURI, postCID string) error 62 85 } 86 + 87 + // APIKeyServiceInterface defines the interface for API key operations used by handlers. 88 + // This interface enables easier testing by allowing mock implementations. 89 + type APIKeyServiceInterface interface { 90 + // GenerateKey creates a new API key for an aggregator. 91 + // Returns the plain-text key (only shown once) and the key prefix for reference. 92 + GenerateKey(ctx context.Context, aggregatorDID string, oauthSession *oauth.ClientSessionData) (plainKey string, keyPrefix string, err error) 93 + 94 + // GetAPIKeyInfo returns information about an aggregator's API key (without the actual key). 95 + GetAPIKeyInfo(ctx context.Context, aggregatorDID string) (*APIKeyInfo, error) 96 + 97 + // RevokeKey revokes an API key for an aggregator. 98 + RevokeKey(ctx context.Context, aggregatorDID string) error 99 + 100 + // GetFailedLastUsedUpdates returns the count of failed last_used timestamp updates. 101 + GetFailedLastUsedUpdates() int64 102 + 103 + // GetFailedNonceUpdates returns the count of failed OAuth nonce updates. 104 + GetFailedNonceUpdates() int64 105 + }
+25 -11
internal/core/posts/service.go
··· 10 10 "log" 11 11 "net/http" 12 12 "os" 13 + "strings" 13 14 "time" 14 15 15 16 "Coves/internal/api/middleware" ··· 83 84 return nil, fmt.Errorf("authenticated DID does not match author DID") 84 85 } 85 86 86 - // 3. Determine actor type: Kagi aggregator, other aggregator, or regular user 87 - kagiAggregatorDID := os.Getenv("KAGI_AGGREGATOR_DID") 88 - isTrustedKagi := kagiAggregatorDID != "" && req.AuthorDID == kagiAggregatorDID 87 + // 3. Determine actor type: trusted aggregator, other aggregator, or regular user 88 + // Check against comma-separated list of trusted aggregator DIDs 89 + trustedDIDs := os.Getenv("TRUSTED_AGGREGATOR_DIDS") 90 + if trustedDIDs == "" { 91 + // Fallback to legacy single DID env var 92 + trustedDIDs = os.Getenv("KAGI_AGGREGATOR_DID") 93 + } 94 + isTrustedAggregator := false 95 + if trustedDIDs != "" { 96 + for _, did := range strings.Split(trustedDIDs, ",") { 97 + if strings.TrimSpace(did) == req.AuthorDID { 98 + isTrustedAggregator = true 99 + break 100 + } 101 + } 102 + } 89 103 90 - // Check if this is a non-Kagi aggregator (requires database lookup) 104 + // Check if this is a non-trusted aggregator (requires database lookup) 91 105 var isOtherAggregator bool 92 106 var err error 93 - if !isTrustedKagi && s.aggregatorService != nil { 107 + if !isTrustedAggregator && s.aggregatorService != nil { 94 108 isOtherAggregator, err = s.aggregatorService.IsAggregator(ctx, req.AuthorDID) 95 109 if err != nil { 96 110 log.Printf("[POST-CREATE] Warning: failed to check if DID is aggregator: %v", err) ··· 138 152 } 139 153 140 154 // 7. Apply validation based on actor type (aggregator vs user) 141 - if isTrustedKagi { 155 + if isTrustedAggregator { 142 156 // TRUSTED AGGREGATOR VALIDATION FLOW 143 - // Kagi aggregator is authorized via KAGI_AGGREGATOR_DID env var (temporary) 157 + // Trusted aggregators are authorized via TRUSTED_AGGREGATOR_DIDS env var (temporary) 144 158 // TODO: Replace with proper XRPC aggregator authorization endpoint 145 - log.Printf("[POST-CREATE] Trusted Kagi aggregator detected: %s posting to community: %s", req.AuthorDID, communityDID) 159 + log.Printf("[POST-CREATE] Trusted aggregator detected: %s posting to community: %s", req.AuthorDID, communityDID) 146 160 // Aggregators skip membership checks and visibility restrictions 147 161 // They are authorized services, not community members 148 162 } else if isOtherAggregator { ··· 219 233 220 234 // TRUSTED AGGREGATOR: Allow Kagi aggregator to provide thumbnail URLs directly 221 235 // This bypasses unfurl for more accurate RSS-sourced thumbnails 222 - if req.ThumbnailURL != nil && *req.ThumbnailURL != "" && isTrustedKagi { 236 + if req.ThumbnailURL != nil && *req.ThumbnailURL != "" && isTrustedAggregator { 223 237 log.Printf("[AGGREGATOR-THUMB] Trusted aggregator provided thumbnail: %s", *req.ThumbnailURL) 224 238 225 239 if s.blobService != nil { ··· 239 253 240 254 // Unfurl enhancement (optional, only if URL is supported) 241 255 // Skip unfurl for trusted aggregators - they provide their own metadata 242 - if !isTrustedKagi { 256 + if !isTrustedAggregator { 243 257 if uri, ok := external["uri"].(string); ok && uri != "" { 244 258 // Check if we support unfurling this URL 245 259 if s.unfurlService != nil && s.unfurlService.IsSupported(uri) { ··· 313 327 314 328 // 13. Return response (AppView will index via Jetstream consumer) 315 329 log.Printf("[POST-CREATE] Author: %s (trustedKagi=%v, otherAggregator=%v), Community: %s, URI: %s", 316 - req.AuthorDID, isTrustedKagi, isOtherAggregator, communityDID, uri) 330 + req.AuthorDID, isTrustedAggregator, isOtherAggregator, communityDID, uri) 317 331 318 332 return &CreatePostResponse{ 319 333 URI: uri,
+77
internal/db/migrations/024_add_aggregator_api_keys.sql
··· 1 + -- +goose Up 2 + -- Add API key authentication and OAuth credential storage for aggregators 3 + -- This enables aggregators to authenticate using API keys backed by OAuth sessions 4 + 5 + -- ============================================================================ 6 + -- Add API key columns to aggregators table 7 + -- ============================================================================ 8 + ALTER TABLE aggregators 9 + -- API key identification (prefix for log correlation, hash for auth) 10 + ADD COLUMN api_key_prefix VARCHAR(12), 11 + ADD COLUMN api_key_hash VARCHAR(64) UNIQUE, 12 + 13 + -- OAuth credentials (encrypted at application layer before storage) 14 + -- SECURITY: These columns contain sensitive OAuth tokens 15 + ADD COLUMN oauth_access_token TEXT, 16 + ADD COLUMN oauth_refresh_token TEXT, 17 + ADD COLUMN oauth_token_expires_at TIMESTAMPTZ, 18 + 19 + -- OAuth session metadata for token refresh 20 + ADD COLUMN oauth_pds_url TEXT, 21 + ADD COLUMN oauth_auth_server_iss TEXT, 22 + ADD COLUMN oauth_auth_server_token_endpoint TEXT, 23 + 24 + -- DPoP keys and nonces for token refresh (multibase encoded) 25 + -- SECURITY: Contains private key material 26 + ADD COLUMN oauth_dpop_private_key_multibase TEXT, 27 + ADD COLUMN oauth_dpop_authserver_nonce TEXT, 28 + ADD COLUMN oauth_dpop_pds_nonce TEXT, 29 + 30 + -- API key lifecycle timestamps 31 + ADD COLUMN api_key_created_at TIMESTAMPTZ, 32 + ADD COLUMN api_key_revoked_at TIMESTAMPTZ, 33 + ADD COLUMN api_key_last_used_at TIMESTAMPTZ; 34 + 35 + -- Index for API key lookup during authentication 36 + -- Partial index excludes NULL values since not all aggregators have API keys 37 + CREATE INDEX idx_aggregators_api_key_hash 38 + ON aggregators(api_key_hash) 39 + WHERE api_key_hash IS NOT NULL; 40 + 41 + -- ============================================================================ 42 + -- Security comments on sensitive columns 43 + -- ============================================================================ 44 + COMMENT ON COLUMN aggregators.api_key_prefix IS 'First 12 characters of API key for identification in logs (not secret)'; 45 + COMMENT ON COLUMN aggregators.api_key_hash IS 'SHA-256 hash of full API key for authentication lookup'; 46 + COMMENT ON COLUMN aggregators.oauth_access_token IS 'SENSITIVE: Encrypted OAuth access token for PDS operations'; 47 + COMMENT ON COLUMN aggregators.oauth_refresh_token IS 'SENSITIVE: Encrypted OAuth refresh token for session renewal'; 48 + COMMENT ON COLUMN aggregators.oauth_token_expires_at IS 'When the OAuth access token expires (triggers refresh)'; 49 + COMMENT ON COLUMN aggregators.oauth_pds_url IS 'PDS URL for this aggregators OAuth session'; 50 + COMMENT ON COLUMN aggregators.oauth_auth_server_iss IS 'OAuth authorization server issuer URL'; 51 + COMMENT ON COLUMN aggregators.oauth_auth_server_token_endpoint IS 'OAuth token refresh endpoint URL'; 52 + COMMENT ON COLUMN aggregators.oauth_dpop_private_key_multibase IS 'SENSITIVE: DPoP private key in multibase format for token refresh'; 53 + COMMENT ON COLUMN aggregators.oauth_dpop_authserver_nonce IS 'Latest DPoP nonce from authorization server'; 54 + COMMENT ON COLUMN aggregators.oauth_dpop_pds_nonce IS 'Latest DPoP nonce from PDS'; 55 + COMMENT ON COLUMN aggregators.api_key_created_at IS 'When the API key was generated'; 56 + COMMENT ON COLUMN aggregators.api_key_revoked_at IS 'When the API key was revoked (NULL = active)'; 57 + COMMENT ON COLUMN aggregators.api_key_last_used_at IS 'Last successful authentication using this API key'; 58 + 59 + -- +goose Down 60 + -- Remove API key columns from aggregators table 61 + DROP INDEX IF EXISTS idx_aggregators_api_key_hash; 62 + 63 + ALTER TABLE aggregators 64 + DROP COLUMN IF EXISTS api_key_prefix, 65 + DROP COLUMN IF EXISTS api_key_hash, 66 + DROP COLUMN IF EXISTS oauth_access_token, 67 + DROP COLUMN IF EXISTS oauth_refresh_token, 68 + DROP COLUMN IF EXISTS oauth_token_expires_at, 69 + DROP COLUMN IF EXISTS oauth_pds_url, 70 + DROP COLUMN IF EXISTS oauth_auth_server_iss, 71 + DROP COLUMN IF EXISTS oauth_auth_server_token_endpoint, 72 + DROP COLUMN IF EXISTS oauth_dpop_private_key_multibase, 73 + DROP COLUMN IF EXISTS oauth_dpop_authserver_nonce, 74 + DROP COLUMN IF EXISTS oauth_dpop_pds_nonce, 75 + DROP COLUMN IF EXISTS api_key_created_at, 76 + DROP COLUMN IF EXISTS api_key_revoked_at, 77 + DROP COLUMN IF EXISTS api_key_last_used_at;
+92
internal/db/migrations/025_encrypt_aggregator_oauth_tokens.sql
··· 1 + -- +goose Up 2 + -- Encrypt aggregator OAuth tokens at rest using pgp_sym_encrypt 3 + -- This addresses the security issue where OAuth tokens were stored in plaintext 4 + -- despite migration 024 claiming "encrypted at application layer before storage" 5 + 6 + -- +goose StatementBegin 7 + 8 + -- Step 1: Add new encrypted columns for OAuth tokens and DPoP private key 9 + ALTER TABLE aggregators 10 + ADD COLUMN oauth_access_token_encrypted BYTEA, 11 + ADD COLUMN oauth_refresh_token_encrypted BYTEA, 12 + ADD COLUMN oauth_dpop_private_key_encrypted BYTEA; 13 + 14 + -- Step 2: Migrate existing plaintext data to encrypted columns 15 + -- Uses the same encryption key table as community credentials (migration 006) 16 + UPDATE aggregators 17 + SET 18 + oauth_access_token_encrypted = CASE 19 + WHEN oauth_access_token IS NOT NULL AND oauth_access_token != '' 20 + THEN pgp_sym_encrypt(oauth_access_token, (SELECT encode(key_data, 'hex') FROM encryption_keys WHERE id = 1)) 21 + ELSE NULL 22 + END, 23 + oauth_refresh_token_encrypted = CASE 24 + WHEN oauth_refresh_token IS NOT NULL AND oauth_refresh_token != '' 25 + THEN pgp_sym_encrypt(oauth_refresh_token, (SELECT encode(key_data, 'hex') FROM encryption_keys WHERE id = 1)) 26 + ELSE NULL 27 + END, 28 + oauth_dpop_private_key_encrypted = CASE 29 + WHEN oauth_dpop_private_key_multibase IS NOT NULL AND oauth_dpop_private_key_multibase != '' 30 + THEN pgp_sym_encrypt(oauth_dpop_private_key_multibase, (SELECT encode(key_data, 'hex') FROM encryption_keys WHERE id = 1)) 31 + ELSE NULL 32 + END 33 + WHERE oauth_access_token IS NOT NULL 34 + OR oauth_refresh_token IS NOT NULL 35 + OR oauth_dpop_private_key_multibase IS NOT NULL; 36 + 37 + -- Step 3: Drop the old plaintext columns 38 + ALTER TABLE aggregators 39 + DROP COLUMN oauth_access_token, 40 + DROP COLUMN oauth_refresh_token, 41 + DROP COLUMN oauth_dpop_private_key_multibase; 42 + 43 + -- Step 4: Add security comments 44 + COMMENT ON COLUMN aggregators.oauth_access_token_encrypted IS 'SENSITIVE: Encrypted OAuth access token (pgp_sym_encrypt) for PDS operations'; 45 + COMMENT ON COLUMN aggregators.oauth_refresh_token_encrypted IS 'SENSITIVE: Encrypted OAuth refresh token (pgp_sym_encrypt) for session renewal'; 46 + COMMENT ON COLUMN aggregators.oauth_dpop_private_key_encrypted IS 'SENSITIVE: Encrypted DPoP private key (pgp_sym_encrypt) for token refresh'; 47 + 48 + -- +goose StatementEnd 49 + 50 + -- +goose Down 51 + -- +goose StatementBegin 52 + 53 + -- Restore plaintext columns 54 + ALTER TABLE aggregators 55 + ADD COLUMN oauth_access_token TEXT, 56 + ADD COLUMN oauth_refresh_token TEXT, 57 + ADD COLUMN oauth_dpop_private_key_multibase TEXT; 58 + 59 + -- Decrypt data back to plaintext (for rollback) 60 + UPDATE aggregators 61 + SET 62 + oauth_access_token = CASE 63 + WHEN oauth_access_token_encrypted IS NOT NULL 64 + THEN pgp_sym_decrypt(oauth_access_token_encrypted, (SELECT encode(key_data, 'hex') FROM encryption_keys WHERE id = 1)) 65 + ELSE NULL 66 + END, 67 + oauth_refresh_token = CASE 68 + WHEN oauth_refresh_token_encrypted IS NOT NULL 69 + THEN pgp_sym_decrypt(oauth_refresh_token_encrypted, (SELECT encode(key_data, 'hex') FROM encryption_keys WHERE id = 1)) 70 + ELSE NULL 71 + END, 72 + oauth_dpop_private_key_multibase = CASE 73 + WHEN oauth_dpop_private_key_encrypted IS NOT NULL 74 + THEN pgp_sym_decrypt(oauth_dpop_private_key_encrypted, (SELECT encode(key_data, 'hex') FROM encryption_keys WHERE id = 1)) 75 + ELSE NULL 76 + END 77 + WHERE oauth_access_token_encrypted IS NOT NULL 78 + OR oauth_refresh_token_encrypted IS NOT NULL 79 + OR oauth_dpop_private_key_encrypted IS NOT NULL; 80 + 81 + -- Drop encrypted columns 82 + ALTER TABLE aggregators 83 + DROP COLUMN oauth_access_token_encrypted, 84 + DROP COLUMN oauth_refresh_token_encrypted, 85 + DROP COLUMN oauth_dpop_private_key_encrypted; 86 + 87 + -- Restore comments 88 + COMMENT ON COLUMN aggregators.oauth_access_token IS 'SENSITIVE: OAuth access token for PDS operations'; 89 + COMMENT ON COLUMN aggregators.oauth_refresh_token IS 'SENSITIVE: OAuth refresh token for session renewal'; 90 + COMMENT ON COLUMN aggregators.oauth_dpop_private_key_multibase IS 'SENSITIVE: DPoP private key in multibase format for token refresh'; 91 + 92 + -- +goose StatementEnd
+415 -6
internal/db/postgres/aggregator_repo.go
··· 69 69 } 70 70 71 71 // GetAggregator retrieves an aggregator by DID 72 + // Returns only public/display fields - use GetAggregatorCredentials for authentication data 72 73 func (r *postgresAggregatorRepo) GetAggregator(ctx context.Context, did string) (*aggregators.Aggregator, error) { 73 74 query := ` 74 75 SELECT ··· 79 80 WHERE did = $1` 80 81 81 82 agg := &aggregators.Aggregator{} 82 - var description, avatarCID, maintainerDID, homepageURL, recordURI, recordCID sql.NullString 83 + var description, avatarURL, maintainerDID, sourceURL, recordURI, recordCID sql.NullString 83 84 var configSchema []byte 84 85 85 86 err := r.db.QueryRowContext(ctx, query, did).Scan( 86 87 &agg.DID, 87 88 &agg.DisplayName, 88 89 &description, 89 - &avatarCID, 90 + &avatarURL, 90 91 &configSchema, 91 92 &maintainerDID, 92 - &homepageURL, 93 + &sourceURL, 93 94 &agg.CommunitiesUsing, 94 95 &agg.PostsCreated, 95 96 &agg.CreatedAt, ··· 105 106 return nil, fmt.Errorf("failed to get aggregator: %w", err) 106 107 } 107 108 108 - // Map nullable fields 109 + // Map nullable string fields 109 110 agg.Description = description.String 110 - agg.AvatarURL = avatarCID.String 111 + agg.AvatarURL = avatarURL.String 111 112 agg.MaintainerDID = maintainerDID.String 112 - agg.SourceURL = homepageURL.String 113 + agg.SourceURL = sourceURL.String 113 114 agg.RecordURI = recordURI.String 114 115 agg.RecordCID = recordCID.String 116 + 115 117 if configSchema != nil { 116 118 agg.ConfigSchema = configSchema 117 119 } ··· 753 755 } 754 756 755 757 return posts, nil 758 + } 759 + 760 + // ===== API Key Authentication Operations ===== 761 + 762 + // GetByAPIKeyHash looks up an aggregator by their API key hash for authentication 763 + // Returns ErrAggregatorNotFound if no aggregator exists with that key hash 764 + // Returns ErrAPIKeyRevoked if the API key has been revoked 765 + // Note: Returns only public Aggregator fields - use GetCredentialsByAPIKeyHash for credentials 766 + func (r *postgresAggregatorRepo) GetByAPIKeyHash(ctx context.Context, keyHash string) (*aggregators.Aggregator, error) { 767 + query := ` 768 + SELECT 769 + did, display_name, description, avatar_url, config_schema, 770 + maintainer_did, source_url, communities_using, posts_created, 771 + created_at, indexed_at, record_uri, record_cid, 772 + api_key_revoked_at 773 + FROM aggregators 774 + WHERE api_key_hash = $1` 775 + 776 + agg := &aggregators.Aggregator{} 777 + var description, avatarURL, maintainerDID, sourceURL, recordURI, recordCID sql.NullString 778 + var configSchema []byte 779 + var apiKeyRevokedAt sql.NullTime 780 + 781 + err := r.db.QueryRowContext(ctx, query, keyHash).Scan( 782 + &agg.DID, 783 + &agg.DisplayName, 784 + &description, 785 + &avatarURL, 786 + &configSchema, 787 + &maintainerDID, 788 + &sourceURL, 789 + &agg.CommunitiesUsing, 790 + &agg.PostsCreated, 791 + &agg.CreatedAt, 792 + &agg.IndexedAt, 793 + &recordURI, 794 + &recordCID, 795 + &apiKeyRevokedAt, 796 + ) 797 + 798 + if err == sql.ErrNoRows { 799 + return nil, aggregators.ErrAggregatorNotFound 800 + } 801 + if err != nil { 802 + return nil, fmt.Errorf("failed to get aggregator by API key hash: %w", err) 803 + } 804 + 805 + // Check if API key is revoked before returning 806 + if apiKeyRevokedAt.Valid { 807 + return nil, aggregators.ErrAPIKeyRevoked 808 + } 809 + 810 + // Map nullable string fields 811 + agg.Description = description.String 812 + agg.AvatarURL = avatarURL.String 813 + agg.MaintainerDID = maintainerDID.String 814 + agg.SourceURL = sourceURL.String 815 + agg.RecordURI = recordURI.String 816 + agg.RecordCID = recordCID.String 817 + 818 + if configSchema != nil { 819 + agg.ConfigSchema = configSchema 820 + } 821 + 822 + return agg, nil 823 + } 824 + 825 + // SetAPIKey stores API key credentials and OAuth session for an aggregator 826 + // This is called after successful OAuth flow to generate the API key 827 + // SECURITY: OAuth tokens and DPoP private key are encrypted at rest using pgp_sym_encrypt 828 + func (r *postgresAggregatorRepo) SetAPIKey(ctx context.Context, did, keyPrefix, keyHash string, oauthCreds *aggregators.OAuthCredentials) error { 829 + query := ` 830 + UPDATE aggregators SET 831 + api_key_prefix = $2, 832 + api_key_hash = $3, 833 + api_key_created_at = NOW(), 834 + api_key_revoked_at = NULL, 835 + oauth_access_token_encrypted = CASE WHEN $4 != '' THEN pgp_sym_encrypt($4, (SELECT encode(key_data, 'hex') FROM encryption_keys WHERE id = 1)) ELSE NULL END, 836 + oauth_refresh_token_encrypted = CASE WHEN $5 != '' THEN pgp_sym_encrypt($5, (SELECT encode(key_data, 'hex') FROM encryption_keys WHERE id = 1)) ELSE NULL END, 837 + oauth_token_expires_at = $6, 838 + oauth_pds_url = $7, 839 + oauth_auth_server_iss = $8, 840 + oauth_auth_server_token_endpoint = $9, 841 + oauth_dpop_private_key_encrypted = CASE WHEN $10 != '' THEN pgp_sym_encrypt($10, (SELECT encode(key_data, 'hex') FROM encryption_keys WHERE id = 1)) ELSE NULL END, 842 + oauth_dpop_authserver_nonce = $11, 843 + oauth_dpop_pds_nonce = $12 844 + WHERE did = $1` 845 + 846 + result, err := r.db.ExecContext(ctx, query, 847 + did, 848 + keyPrefix, 849 + keyHash, 850 + oauthCreds.AccessToken, 851 + oauthCreds.RefreshToken, 852 + oauthCreds.TokenExpiresAt, 853 + oauthCreds.PDSURL, 854 + oauthCreds.AuthServerIss, 855 + oauthCreds.AuthServerTokenEndpoint, 856 + oauthCreds.DPoPPrivateKeyMultibase, 857 + oauthCreds.DPoPAuthServerNonce, 858 + oauthCreds.DPoPPDSNonce, 859 + ) 860 + if err != nil { 861 + return fmt.Errorf("failed to set API key: %w", err) 862 + } 863 + 864 + rows, err := result.RowsAffected() 865 + if err != nil { 866 + return fmt.Errorf("failed to get rows affected: %w", err) 867 + } 868 + if rows == 0 { 869 + return aggregators.ErrAggregatorNotFound 870 + } 871 + 872 + return nil 873 + } 874 + 875 + // UpdateOAuthTokens updates OAuth tokens after a refresh operation 876 + // Called after successfully refreshing an expired access token 877 + // SECURITY: OAuth tokens are encrypted at rest using pgp_sym_encrypt 878 + func (r *postgresAggregatorRepo) UpdateOAuthTokens(ctx context.Context, did, accessToken, refreshToken string, expiresAt time.Time) error { 879 + query := ` 880 + UPDATE aggregators SET 881 + oauth_access_token_encrypted = pgp_sym_encrypt($2, (SELECT encode(key_data, 'hex') FROM encryption_keys WHERE id = 1)), 882 + oauth_refresh_token_encrypted = pgp_sym_encrypt($3, (SELECT encode(key_data, 'hex') FROM encryption_keys WHERE id = 1)), 883 + oauth_token_expires_at = $4 884 + WHERE did = $1` 885 + 886 + result, err := r.db.ExecContext(ctx, query, did, accessToken, refreshToken, expiresAt) 887 + if err != nil { 888 + return fmt.Errorf("failed to update OAuth tokens: %w", err) 889 + } 890 + 891 + rows, err := result.RowsAffected() 892 + if err != nil { 893 + return fmt.Errorf("failed to get rows affected: %w", err) 894 + } 895 + if rows == 0 { 896 + return aggregators.ErrAggregatorNotFound 897 + } 898 + 899 + return nil 900 + } 901 + 902 + // UpdateOAuthNonces updates DPoP nonces after token operations 903 + // Nonces are updated after each request to the auth server or PDS 904 + func (r *postgresAggregatorRepo) UpdateOAuthNonces(ctx context.Context, did, authServerNonce, pdsNonce string) error { 905 + query := ` 906 + UPDATE aggregators SET 907 + oauth_dpop_authserver_nonce = COALESCE(NULLIF($2, ''), oauth_dpop_authserver_nonce), 908 + oauth_dpop_pds_nonce = COALESCE(NULLIF($3, ''), oauth_dpop_pds_nonce) 909 + WHERE did = $1` 910 + 911 + result, err := r.db.ExecContext(ctx, query, did, authServerNonce, pdsNonce) 912 + if err != nil { 913 + return fmt.Errorf("failed to update OAuth nonces: %w", err) 914 + } 915 + 916 + rows, err := result.RowsAffected() 917 + if err != nil { 918 + return fmt.Errorf("failed to get rows affected: %w", err) 919 + } 920 + if rows == 0 { 921 + return aggregators.ErrAggregatorNotFound 922 + } 923 + 924 + return nil 925 + } 926 + 927 + // UpdateAPIKeyLastUsed updates the last_used_at timestamp for audit purposes 928 + // Called on each successful authentication to track API key usage 929 + func (r *postgresAggregatorRepo) UpdateAPIKeyLastUsed(ctx context.Context, did string) error { 930 + query := ` 931 + UPDATE aggregators SET 932 + api_key_last_used_at = NOW() 933 + WHERE did = $1` 934 + 935 + result, err := r.db.ExecContext(ctx, query, did) 936 + if err != nil { 937 + return fmt.Errorf("failed to update API key last used: %w", err) 938 + } 939 + 940 + rows, err := result.RowsAffected() 941 + if err != nil { 942 + return fmt.Errorf("failed to get rows affected: %w", err) 943 + } 944 + if rows == 0 { 945 + return aggregators.ErrAggregatorNotFound 946 + } 947 + 948 + return nil 949 + } 950 + 951 + // RevokeAPIKey marks an API key as revoked (sets api_key_revoked_at) 952 + // After revocation, the aggregator must complete OAuth flow again to get a new key 953 + func (r *postgresAggregatorRepo) RevokeAPIKey(ctx context.Context, did string) error { 954 + query := ` 955 + UPDATE aggregators SET 956 + api_key_revoked_at = NOW() 957 + WHERE did = $1 AND api_key_hash IS NOT NULL` 958 + 959 + result, err := r.db.ExecContext(ctx, query, did) 960 + if err != nil { 961 + return fmt.Errorf("failed to revoke API key: %w", err) 962 + } 963 + 964 + rows, err := result.RowsAffected() 965 + if err != nil { 966 + return fmt.Errorf("failed to get rows affected: %w", err) 967 + } 968 + if rows == 0 { 969 + return aggregators.ErrAggregatorNotFound 970 + } 971 + 972 + return nil 973 + } 974 + 975 + // GetAggregatorCredentials retrieves only credential data for an aggregator 976 + // Used by APIKeyService for authentication operations where full aggregator is not needed 977 + func (r *postgresAggregatorRepo) GetAggregatorCredentials(ctx context.Context, did string) (*aggregators.AggregatorCredentials, error) { 978 + query := ` 979 + SELECT 980 + did, 981 + api_key_prefix, api_key_hash, api_key_created_at, api_key_revoked_at, api_key_last_used_at, 982 + CASE 983 + WHEN oauth_access_token_encrypted IS NOT NULL 984 + THEN pgp_sym_decrypt(oauth_access_token_encrypted, (SELECT encode(key_data, 'hex') FROM encryption_keys WHERE id = 1)) 985 + ELSE NULL 986 + END as oauth_access_token, 987 + CASE 988 + WHEN oauth_refresh_token_encrypted IS NOT NULL 989 + THEN pgp_sym_decrypt(oauth_refresh_token_encrypted, (SELECT encode(key_data, 'hex') FROM encryption_keys WHERE id = 1)) 990 + ELSE NULL 991 + END as oauth_refresh_token, 992 + oauth_token_expires_at, 993 + oauth_pds_url, oauth_auth_server_iss, oauth_auth_server_token_endpoint, 994 + CASE 995 + WHEN oauth_dpop_private_key_encrypted IS NOT NULL 996 + THEN pgp_sym_decrypt(oauth_dpop_private_key_encrypted, (SELECT encode(key_data, 'hex') FROM encryption_keys WHERE id = 1)) 997 + ELSE NULL 998 + END as oauth_dpop_private_key_multibase, 999 + oauth_dpop_authserver_nonce, oauth_dpop_pds_nonce 1000 + FROM aggregators 1001 + WHERE did = $1` 1002 + 1003 + creds := &aggregators.AggregatorCredentials{} 1004 + var apiKeyPrefix, apiKeyHash sql.NullString 1005 + var oauthAccessToken, oauthRefreshToken sql.NullString 1006 + var oauthPDSURL, oauthAuthServerIss, oauthAuthServerTokenEndpoint sql.NullString 1007 + var oauthDPoPPrivateKey, oauthDPoPAuthServerNonce, oauthDPoPPDSNonce sql.NullString 1008 + var apiKeyCreatedAt, apiKeyRevokedAt, apiKeyLastUsed, oauthTokenExpiresAt sql.NullTime 1009 + 1010 + err := r.db.QueryRowContext(ctx, query, did).Scan( 1011 + &creds.DID, 1012 + &apiKeyPrefix, 1013 + &apiKeyHash, 1014 + &apiKeyCreatedAt, 1015 + &apiKeyRevokedAt, 1016 + &apiKeyLastUsed, 1017 + &oauthAccessToken, 1018 + &oauthRefreshToken, 1019 + &oauthTokenExpiresAt, 1020 + &oauthPDSURL, 1021 + &oauthAuthServerIss, 1022 + &oauthAuthServerTokenEndpoint, 1023 + &oauthDPoPPrivateKey, 1024 + &oauthDPoPAuthServerNonce, 1025 + &oauthDPoPPDSNonce, 1026 + ) 1027 + 1028 + if err == sql.ErrNoRows { 1029 + return nil, aggregators.ErrAggregatorNotFound 1030 + } 1031 + if err != nil { 1032 + return nil, fmt.Errorf("failed to get aggregator credentials: %w", err) 1033 + } 1034 + 1035 + // Map nullable string fields 1036 + creds.APIKeyPrefix = apiKeyPrefix.String 1037 + creds.APIKeyHash = apiKeyHash.String 1038 + creds.OAuthAccessToken = oauthAccessToken.String 1039 + creds.OAuthRefreshToken = oauthRefreshToken.String 1040 + creds.OAuthPDSURL = oauthPDSURL.String 1041 + creds.OAuthAuthServerIss = oauthAuthServerIss.String 1042 + creds.OAuthAuthServerTokenEndpoint = oauthAuthServerTokenEndpoint.String 1043 + creds.OAuthDPoPPrivateKeyMultibase = oauthDPoPPrivateKey.String 1044 + creds.OAuthDPoPAuthServerNonce = oauthDPoPAuthServerNonce.String 1045 + creds.OAuthDPoPPDSNonce = oauthDPoPPDSNonce.String 1046 + 1047 + // Map nullable time fields 1048 + if apiKeyCreatedAt.Valid { 1049 + t := apiKeyCreatedAt.Time 1050 + creds.APIKeyCreatedAt = &t 1051 + } 1052 + if apiKeyRevokedAt.Valid { 1053 + t := apiKeyRevokedAt.Time 1054 + creds.APIKeyRevokedAt = &t 1055 + } 1056 + if apiKeyLastUsed.Valid { 1057 + t := apiKeyLastUsed.Time 1058 + creds.APIKeyLastUsed = &t 1059 + } 1060 + if oauthTokenExpiresAt.Valid { 1061 + t := oauthTokenExpiresAt.Time 1062 + creds.OAuthTokenExpiresAt = &t 1063 + } 1064 + 1065 + return creds, nil 1066 + } 1067 + 1068 + // GetCredentialsByAPIKeyHash looks up credentials by API key hash for authentication 1069 + // Returns ErrAPIKeyRevoked if the API key has been revoked 1070 + // Returns ErrAPIKeyInvalid if no aggregator found with that hash 1071 + func (r *postgresAggregatorRepo) GetCredentialsByAPIKeyHash(ctx context.Context, keyHash string) (*aggregators.AggregatorCredentials, error) { 1072 + query := ` 1073 + SELECT 1074 + did, 1075 + api_key_prefix, api_key_hash, api_key_created_at, api_key_revoked_at, api_key_last_used_at, 1076 + CASE 1077 + WHEN oauth_access_token_encrypted IS NOT NULL 1078 + THEN pgp_sym_decrypt(oauth_access_token_encrypted, (SELECT encode(key_data, 'hex') FROM encryption_keys WHERE id = 1)) 1079 + ELSE NULL 1080 + END as oauth_access_token, 1081 + CASE 1082 + WHEN oauth_refresh_token_encrypted IS NOT NULL 1083 + THEN pgp_sym_decrypt(oauth_refresh_token_encrypted, (SELECT encode(key_data, 'hex') FROM encryption_keys WHERE id = 1)) 1084 + ELSE NULL 1085 + END as oauth_refresh_token, 1086 + oauth_token_expires_at, 1087 + oauth_pds_url, oauth_auth_server_iss, oauth_auth_server_token_endpoint, 1088 + CASE 1089 + WHEN oauth_dpop_private_key_encrypted IS NOT NULL 1090 + THEN pgp_sym_decrypt(oauth_dpop_private_key_encrypted, (SELECT encode(key_data, 'hex') FROM encryption_keys WHERE id = 1)) 1091 + ELSE NULL 1092 + END as oauth_dpop_private_key_multibase, 1093 + oauth_dpop_authserver_nonce, oauth_dpop_pds_nonce 1094 + FROM aggregators 1095 + WHERE api_key_hash = $1` 1096 + 1097 + creds := &aggregators.AggregatorCredentials{} 1098 + var apiKeyPrefix, apiKeyHash sql.NullString 1099 + var oauthAccessToken, oauthRefreshToken sql.NullString 1100 + var oauthPDSURL, oauthAuthServerIss, oauthAuthServerTokenEndpoint sql.NullString 1101 + var oauthDPoPPrivateKey, oauthDPoPAuthServerNonce, oauthDPoPPDSNonce sql.NullString 1102 + var apiKeyCreatedAt, apiKeyRevokedAt, apiKeyLastUsed, oauthTokenExpiresAt sql.NullTime 1103 + 1104 + err := r.db.QueryRowContext(ctx, query, keyHash).Scan( 1105 + &creds.DID, 1106 + &apiKeyPrefix, 1107 + &apiKeyHash, 1108 + &apiKeyCreatedAt, 1109 + &apiKeyRevokedAt, 1110 + &apiKeyLastUsed, 1111 + &oauthAccessToken, 1112 + &oauthRefreshToken, 1113 + &oauthTokenExpiresAt, 1114 + &oauthPDSURL, 1115 + &oauthAuthServerIss, 1116 + &oauthAuthServerTokenEndpoint, 1117 + &oauthDPoPPrivateKey, 1118 + &oauthDPoPAuthServerNonce, 1119 + &oauthDPoPPDSNonce, 1120 + ) 1121 + 1122 + if err == sql.ErrNoRows { 1123 + return nil, aggregators.ErrAPIKeyInvalid 1124 + } 1125 + if err != nil { 1126 + return nil, fmt.Errorf("failed to get credentials by API key hash: %w", err) 1127 + } 1128 + 1129 + // Map nullable string fields 1130 + creds.APIKeyPrefix = apiKeyPrefix.String 1131 + creds.APIKeyHash = apiKeyHash.String 1132 + creds.OAuthAccessToken = oauthAccessToken.String 1133 + creds.OAuthRefreshToken = oauthRefreshToken.String 1134 + creds.OAuthPDSURL = oauthPDSURL.String 1135 + creds.OAuthAuthServerIss = oauthAuthServerIss.String 1136 + creds.OAuthAuthServerTokenEndpoint = oauthAuthServerTokenEndpoint.String 1137 + creds.OAuthDPoPPrivateKeyMultibase = oauthDPoPPrivateKey.String 1138 + creds.OAuthDPoPAuthServerNonce = oauthDPoPAuthServerNonce.String 1139 + creds.OAuthDPoPPDSNonce = oauthDPoPPDSNonce.String 1140 + 1141 + // Map nullable time fields 1142 + if apiKeyCreatedAt.Valid { 1143 + t := apiKeyCreatedAt.Time 1144 + creds.APIKeyCreatedAt = &t 1145 + } 1146 + if apiKeyRevokedAt.Valid { 1147 + t := apiKeyRevokedAt.Time 1148 + creds.APIKeyRevokedAt = &t 1149 + } 1150 + if apiKeyLastUsed.Valid { 1151 + t := apiKeyLastUsed.Time 1152 + creds.APIKeyLastUsed = &t 1153 + } 1154 + if oauthTokenExpiresAt.Valid { 1155 + t := oauthTokenExpiresAt.Time 1156 + creds.OAuthTokenExpiresAt = &t 1157 + } 1158 + 1159 + // Check if API key is revoked 1160 + if creds.APIKeyRevokedAt != nil { 1161 + return nil, aggregators.ErrAPIKeyRevoked 1162 + } 1163 + 1164 + return creds, nil 756 1165 } 757 1166 758 1167 // ===== Helper Functions =====