···5};
6use axum::{extract::FromRequestParts, http::StatusCode};
7use base64::Engine as _;
08use sha2::{Digest as _, Sha256};
910-use crate::{AppState, Error, error::ErrorMessage};
1112/// Request extractor for authenticated users.
13/// If specified in an API endpoint, this guarantees the API can only be called
···129130 // Extract subject (DID)
131 if let Some(did) = claims.get("sub").and_then(serde_json::Value::as_str) {
132- let _status = sqlx::query_scalar!(r#"SELECT status FROM accounts WHERE did = ?"#, did)
133- .fetch_one(&state.db)
000000000134 .await
135 .with_context(|| format!("failed to query account {did}"))
136 .context("should fetch account status")?;
···326327 let timestamp = chrono::Utc::now().timestamp();
328000329 // Check if JTI has been used before
330- let jti_used =
331- sqlx::query_scalar!(r#"SELECT COUNT(*) FROM oauth_used_jtis WHERE jti = ?"#, jti)
332- .fetch_one(&state.db)
333- .await
334- .context("failed to check JTI")?;
000000335336 if jti_used > 0 {
337 return Err(Error::with_status(
···347 .and_then(serde_json::Value::as_i64)
348 .unwrap_or_else(|| timestamp.checked_add(60).unwrap_or(timestamp));
349350- _ = sqlx::query!(
351- r#"
352- INSERT INTO oauth_used_jtis (jti, issuer, created_at, expires_at)
353- VALUES (?, ?, ?, ?)
354- "#,
355- jti,
356- calculated_thumbprint, // Use thumbprint as issuer identifier
357- timestamp,
358- exp
359- )
360- .execute(&state.db)
361- .await
362- .context("failed to store JTI")?;
0000363364 // Extract subject (DID) from access token
365- if let Some(did) = claims.get("sub").and_then(|v| v.as_str()) {
366- let _status = sqlx::query_scalar!(r#"SELECT status FROM accounts WHERE did = ?"#, did)
367- .fetch_one(&state.db)
000000000368 .await
369 .with_context(|| format!("failed to query account {did}"))
370 .context("should fetch account status")?;
···5};
6use axum::{extract::FromRequestParts, http::StatusCode};
7use base64::Engine as _;
8+use diesel::prelude::*;
9use sha2::{Digest as _, Sha256};
1011+use crate::{AppState, Error, db::DbConn, error::ErrorMessage};
1213/// Request extractor for authenticated users.
14/// If specified in an API endpoint, this guarantees the API can only be called
···130131 // Extract subject (DID)
132 if let Some(did) = claims.get("sub").and_then(serde_json::Value::as_str) {
133+ // Convert SQLx query to Diesel query
134+ use crate::schema::accounts::dsl as AccountSchema;
135+136+ let _status = state
137+ .db
138+ .run(move |conn| {
139+ AccountSchema::accounts
140+ .filter(AccountSchema::did.eq(did.to_string()))
141+ .select(AccountSchema::status)
142+ .first::<String>(conn)
143+ })
144 .await
145 .with_context(|| format!("failed to query account {did}"))
146 .context("should fetch account status")?;
···336337 let timestamp = chrono::Utc::now().timestamp();
338339+ // Convert SQLx JTI check to Diesel
340+ use crate::schema::oauth_used_jtis::dsl as JtiSchema;
341+342 // Check if JTI has been used before
343+ let jti_string = jti.to_string();
344+ let jti_used = state
345+ .db
346+ .run(move |conn| {
347+ JtiSchema::oauth_used_jtis
348+ .filter(JtiSchema::jti.eq(jti_string))
349+ .count()
350+ .get_result::<i64>(conn)
351+ })
352+ .await
353+ .context("failed to check JTI")?;
354355 if jti_used > 0 {
356 return Err(Error::with_status(
···366 .and_then(serde_json::Value::as_i64)
367 .unwrap_or_else(|| timestamp.checked_add(60).unwrap_or(timestamp));
368369+ // Convert SQLx INSERT to Diesel
370+ let jti_str = jti.to_string();
371+ let thumbprint_str = calculated_thumbprint.to_string();
372+ state
373+ .db
374+ .run(move |conn| {
375+ diesel::insert_into(JtiSchema::oauth_used_jtis)
376+ .values((
377+ JtiSchema::jti.eq(jti_str),
378+ JtiSchema::issuer.eq(thumbprint_str),
379+ JtiSchema::created_at.eq(timestamp),
380+ JtiSchema::expires_at.eq(exp),
381+ ))
382+ .execute(conn)
383+ })
384+ .await
385+ .context("failed to store JTI")?;
386387 // Extract subject (DID) from access token
388+ if let Some(did) = claims.get("sub").and_then(|v| v.as_str) {
389+ // Convert SQLx query to Diesel
390+ use crate::schema::accounts::dsl as AccountSchema;
391+392+ let _status = state
393+ .db
394+ .run(move |conn| {
395+ AccountSchema::accounts
396+ .filter(AccountSchema::did.eq(did.to_string()))
397+ .select(AccountSchema::status)
398+ .first::<String>(conn)
399+ })
400 .await
401 .with_context(|| format!("failed to query account {did}"))
402 .context("should fetch account status")?;
···1use anyhow::Result;
2+use deadpool_diesel::sqlite::{Manager, Pool, Runtime};
000000000000034#[tracing::instrument(skip_all)]
5+/// Establish a connection to the database
6+/// Takes a database URL as an argument (like "sqlite://data/sqlite.db")
7+pub(crate) fn establish_pool(database_url: &str) -> Result<Pool> {
8+ tracing::debug!("Establishing database connection");
9+ let manager = Manager::new(database_url, Runtime::Tokio1);
10+ let pool = Pool::builder(manager)
11+ .max_size(8)
12+ .build()
13+ .expect("should be able to create connection pool");
14+ tracing::debug!("Database connection established");
15+ Ok(pool)
16}
···37use clap::Parser;
38use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter};
39use config::AppConfig;
0040use diesel::prelude::*;
41-use diesel::r2d2::{self, ConnectionManager};
42-use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};
43#[expect(clippy::pub_use, clippy::useless_attribute)]
44pub use error::Error;
45use figment::{Figment, providers::Format as _};
···68pub type Result<T> = std::result::Result<T, Error>;
69/// The reqwest client type with middleware.
70pub type Client = reqwest_middleware::ClientWithMiddleware;
71-/// The database connection pool.
72-pub type Db = r2d2::Pool<ConnectionManager<SqliteConnection>>;
73/// The Azure credential type.
74pub type Cred = Arc<dyn TokenCredential>;
75···132 verbosity: Verbosity<InfoLevel>,
133}
1340000000000000135#[expect(clippy::arbitrary_source_item_ordering, reason = "arbitrary")]
136#[derive(Clone, FromRef)]
137struct AppState {
···139 config: AppConfig,
140 /// The Azure credential.
141 cred: Cred,
142- /// The database connection pool.
143- db: Db,
00144145 /// The HTTP client with middleware.
146 client: Client,
···291#[expect(
292 clippy::cognitive_complexity,
293 clippy::too_many_lines,
0294 reason = "main function has high complexity"
295)]
296async fn run() -> anyhow::Result<()> {
···388 let cred = azure_identity::DefaultAzureCredential::new()
389 .context("failed to create Azure credential")?;
390391- // Create a database connection manager and pool
392- let manager = ConnectionManager::<SqliteConnection>::new(&config.db);
393- let db = r2d2::Pool::builder()
394- .build(manager)
395- .context("failed to create database connection pool")?;
000000000000000003960000000000000000000397 // Apply pending migrations
398- let conn = &mut db
399- .get()
400- .context("failed to get database connection for migrations")?;
401- conn.run_pending_migrations(MIGRATIONS)
402- .expect("should be able to run migrations");
403404 let (_fh, fhp) = firehose::spawn(client.clone(), config.clone());
405···422 .with_state(AppState {
423 cred,
424 config: config.clone(),
425- db: db.clone(),
0426 client: client.clone(),
427 simple_client,
428 firehose: fhp,
···435436 // Determine whether or not this was the first startup (i.e. no accounts exist and no invite codes were created).
437 // If so, create an invite code and share it via the console.
438- let conn = &mut db.get().context("failed to get database connection")?;
439440 #[derive(QueryableByName)]
441 struct TotalCount {
···443 total_count: i32,
444 }
445446- let result = diesel::sql_query(
447- "SELECT (SELECT COUNT(*) FROM accounts) + (SELECT COUNT(*) FROM invites) AS total_count",
448- )
449- .get_result::<TotalCount>(conn)
450- .context("failed to query database")?;
00000000451452 let c = result.total_count;
453···455 if c == 0 {
456 let uuid = Uuid::new_v4().to_string();
457458- diesel::sql_query(
00459 "INSERT INTO invites (id, did, count, created_at) VALUES (?, NULL, 1, datetime('now'))",
460 )
461- .bind::<diesel::sql_types::Text, _>(uuid.clone())
462 .execute(conn)
463- .context("failed to create new invite code")?;
00464465 // N.B: This is a sensitive message, so we're bypassing `tracing` here and
466 // logging it directly to console.
···37use clap::Parser;
38use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter};
39use config::AppConfig;
40+use db::establish_pool;
41+use deadpool_diesel::sqlite::Pool;
42use diesel::prelude::*;
43+use diesel_migrations::{EmbeddedMigrations, embed_migrations};
044#[expect(clippy::pub_use, clippy::useless_attribute)]
45pub use error::Error;
46use figment::{Figment, providers::Format as _};
···69pub type Result<T> = std::result::Result<T, Error>;
70/// The reqwest client type with middleware.
71pub type Client = reqwest_middleware::ClientWithMiddleware;
0072/// The Azure credential type.
73pub type Cred = Arc<dyn TokenCredential>;
74···131 verbosity: Verbosity<InfoLevel>,
132}
133134+struct ActorPools {
135+ repo: Pool,
136+ blob: Pool,
137+}
138+impl Clone for ActorPools {
139+ fn clone(&self) -> Self {
140+ Self {
141+ repo: self.repo.clone(),
142+ blob: self.blob.clone(),
143+ }
144+ }
145+}
146+147#[expect(clippy::arbitrary_source_item_ordering, reason = "arbitrary")]
148#[derive(Clone, FromRef)]
149struct AppState {
···151 config: AppConfig,
152 /// The Azure credential.
153 cred: Cred,
154+ /// The main database connection pool. Used for common PDS data, like invite codes.
155+ db: Pool,
156+ /// Actor-specific database connection pools. Hashed by DID.
157+ db_actors: std::collections::HashMap<String, ActorPools>,
158159 /// The HTTP client with middleware.
160 client: Client,
···305#[expect(
306 clippy::cognitive_complexity,
307 clippy::too_many_lines,
308+ unused_qualifications,
309 reason = "main function has high complexity"
310)]
311async fn run() -> anyhow::Result<()> {
···403 let cred = azure_identity::DefaultAzureCredential::new()
404 .context("failed to create Azure credential")?;
405406+ // Create a database connection manager and pool for the main database.
407+ let pool =
408+ establish_pool(&config.db).context("failed to establish database connection pool")?;
409+ // Create a dictionary of database connection pools for each actor.
410+ let mut actor_pools = std::collections::HashMap::new();
411+ // let mut actor_blob_pools = std::collections::HashMap::new();
412+ // We'll determine actors by looking in the data/repo dir for .db files.
413+ let mut actor_dbs = tokio::fs::read_dir(&config.repo.path)
414+ .await
415+ .context("failed to read repo directory")?;
416+ while let Some(entry) = actor_dbs
417+ .next_entry()
418+ .await
419+ .context("failed to read repo dir")?
420+ {
421+ let path = entry.path();
422+ if path.extension().and_then(|s| s.to_str()) == Some("db") {
423+ let did = path
424+ .file_stem()
425+ .and_then(|s| s.to_str())
426+ .context("failed to get actor DID")?;
427+ let did = Did::from_str(did).expect("should be able to parse actor DID");
428429+ // Create a new database connection manager and pool for the actor.
430+ // The path for the SQLite connection needs to look like "sqlite://data/repo/<actor>.db"
431+ let path_repo = format!("sqlite://{}", path.display());
432+ let actor_repo_pool =
433+ establish_pool(&path_repo).context("failed to create database connection pool")?;
434+ // Create a new database connection manager and pool for the actor blobs.
435+ // The path for the SQLite connection needs to look like "sqlite://data/blob/<actor>.db"
436+ let path_blob = path_repo.replace("repo", "blob");
437+ let actor_blob_pool =
438+ establish_pool(&path_blob).context("failed to create database connection pool")?;
439+ actor_pools.insert(
440+ did.to_string(),
441+ ActorPools {
442+ repo: actor_repo_pool,
443+ blob: actor_blob_pool,
444+ },
445+ );
446+ }
447+ }
448 // Apply pending migrations
449+ // let conn = pool.get().await?;
450+ // conn.run_pending_migrations(MIGRATIONS)
451+ // .expect("should be able to run migrations");
00452453 let (_fh, fhp) = firehose::spawn(client.clone(), config.clone());
454···471 .with_state(AppState {
472 cred,
473 config: config.clone(),
474+ db: pool.clone(),
475+ db_actors: actor_pools.clone(),
476 client: client.clone(),
477 simple_client,
478 firehose: fhp,
···485486 // Determine whether or not this was the first startup (i.e. no accounts exist and no invite codes were created).
487 // If so, create an invite code and share it via the console.
488+ let conn = pool.get().await.context("failed to get db connection")?;
489490 #[derive(QueryableByName)]
491 struct TotalCount {
···493 total_count: i32,
494 }
495496+ // let result = diesel::sql_query(
497+ // "SELECT (SELECT COUNT(*) FROM accounts) + (SELECT COUNT(*) FROM invites) AS total_count",
498+ // )
499+ // .get_result::<TotalCount>(conn)
500+ // .context("failed to query database")?;
501+ let result = conn.interact(move |conn| {
502+ diesel::sql_query(
503+ "SELECT (SELECT COUNT(*) FROM accounts) + (SELECT COUNT(*) FROM invites) AS total_count",
504+ )
505+ .get_result::<TotalCount>(conn)
506+ })
507+ .await
508+ .expect("should be able to query database")?;
509510 let c = result.total_count;
511···513 if c == 0 {
514 let uuid = Uuid::new_v4().to_string();
515516+ let uuid_clone = uuid.clone();
517+ conn.interact(move |conn| {
518+ diesel::sql_query(
519 "INSERT INTO invites (id, did, count, created_at) VALUES (?, NULL, 1, datetime('now'))",
520 )
521+ .bind::<diesel::sql_types::Text, _>(uuid_clone)
522 .execute(conn)
523+ .context("failed to create new invite code")
524+ .expect("should be able to create invite code")
525+ });
526527 // N.B: This is a sensitive message, so we're bypassing `tracing` here and
528 // logging it directly to console.
+1-1
src/tests.rs
···222 let opts = SqliteConnectOptions::from_str(&config.db)
223 .context("failed to parse database options")?
224 .create_if_missing(true);
225- let db = SqlitePool::connect_with(opts).await?;
226227 sqlx::migrate!()
228 .run(&db)
···222 let opts = SqliteConnectOptions::from_str(&config.db)
223 .context("failed to parse database options")?
224 .create_if_missing(true);
225+ let db = SqliteDbConn::connect_with(opts).await?;
226227 sqlx::migrate!()
228 .run(&db)