···22use axum::{Router, routing::post};
33use constcat::concat;
4455-use crate::AppState;
55+use crate::serve::AppState;
6677pub mod apply_writes;
88// pub mod create_record;
+1-1
src/apis/mod.rs
···77use axum::{Json, Router, routing::get};
88use serde_json::json;
991010-use crate::{AppState, Result};
1010+use crate::serve::{AppState, Result};
11111212/// Health check endpoint. Returns name and version of the service.
1313pub(crate) async fn health() -> Result<Json<serde_json::Value>> {
+4-1
src/auth.rs
···88use diesel::prelude::*;
99use sha2::{Digest as _, Sha256};
10101111-use crate::{AppState, Error, error::ErrorMessage};
1111+use crate::{
1212+ error::{Error, ErrorMessage},
1313+ serve::AppState,
1414+};
12151316/// Request extractor for authenticated users.
1417/// If specified in an API endpoint, this guarantees the API can only be called
···148148149149impl ApiError {
150150 /// Get the appropriate HTTP status code for this error
151151- fn status_code(&self) -> StatusCode {
151151+ const fn status_code(&self) -> StatusCode {
152152 match self {
153153 Self::RuntimeError => StatusCode::INTERNAL_SERVER_ERROR,
154154 Self::InvalidLogin
···190190 Self::BadRequest(error, _) => error,
191191 Self::AuthRequiredError(_) => "AuthRequiredError",
192192 }
193193- .to_string()
193193+ .to_owned()
194194 }
195195196196 /// Get the user-facing error message
···218218 Self::BadRequest(_, msg) => msg,
219219 Self::AuthRequiredError(msg) => msg,
220220 }
221221- .to_string()
221221+ .to_owned()
222222 }
223223}
224224225225impl From<Error> for ApiError {
226226 fn from(_value: Error) -> Self {
227227- ApiError::RuntimeError
227227+ Self::RuntimeError
228228 }
229229}
230230231231impl From<handle::errors::Error> for ApiError {
232232 fn from(value: handle::errors::Error) -> Self {
233233 match value.kind {
234234- ErrorKind::InvalidHandle => ApiError::InvalidHandle,
235235- ErrorKind::HandleNotAvailable => ApiError::HandleNotAvailable,
236236- ErrorKind::UnsupportedDomain => ApiError::UnsupportedDomain,
237237- ErrorKind::InternalError => ApiError::RuntimeError,
234234+ ErrorKind::InvalidHandle => Self::InvalidHandle,
235235+ ErrorKind::HandleNotAvailable => Self::HandleNotAvailable,
236236+ ErrorKind::UnsupportedDomain => Self::UnsupportedDomain,
237237+ ErrorKind::InternalError => Self::RuntimeError,
238238 }
239239 }
240240}
···245245 let error_type = self.error_type();
246246 let message = self.message();
247247248248- // Log the error for debugging
249249- error!("API Error: {}: {}", error_type, message);
248248+ if cfg!(debug_assertions) {
249249+ error!("API Error: {}: {}", error_type, message);
250250+ }
250251251252 // Create the error message and serialize to JSON
252253 let error_message = ErrorMessage::new(error_type, message);
253254 let body = serde_json::to_string(&error_message).unwrap_or_else(|_| {
254254- r#"{"error":"InternalServerError","message":"Error serializing response"}"#.to_string()
255255+ r#"{"error":"InternalServerError","message":"Error serializing response"}"#.to_owned()
255256 });
256257257258 // Build the response
-426
src/firehose.rs
···11-//! The firehose module.
22-use std::{collections::VecDeque, time::Duration};
33-44-use anyhow::{Result, bail};
55-use atrium_api::{
66- com::atproto::sync::{self},
77- types::string::{Datetime, Did, Tid},
88-};
99-use atrium_repo::Cid;
1010-use axum::extract::ws::{Message, WebSocket};
1111-use metrics::{counter, gauge};
1212-use rand::Rng as _;
1313-use serde::{Serialize, ser::SerializeMap as _};
1414-use tracing::{debug, error, info, warn};
1515-1616-use crate::{
1717- Client,
1818- config::AppConfig,
1919- metrics::{FIREHOSE_HISTORY, FIREHOSE_LISTENERS, FIREHOSE_MESSAGES, FIREHOSE_SEQUENCE},
2020-};
2121-2222-enum FirehoseMessage {
2323- Broadcast(sync::subscribe_repos::Message),
2424- Connect(Box<(WebSocket, Option<i64>)>),
2525-}
2626-2727-enum FrameHeader {
2828- Error,
2929- Message(String),
3030-}
3131-3232-impl Serialize for FrameHeader {
3333- #[expect(clippy::question_mark_used, reason = "returns a Result")]
3434- fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
3535- where
3636- S: serde::Serializer,
3737- {
3838- let mut map = serializer.serialize_map(None)?;
3939-4040- match *self {
4141- Self::Message(ref s) => {
4242- map.serialize_key("op")?;
4343- map.serialize_value(&1_i32)?;
4444- map.serialize_key("t")?;
4545- map.serialize_value(s.as_str())?;
4646- }
4747- Self::Error => {
4848- map.serialize_key("op")?;
4949- map.serialize_value(&-1_i32)?;
5050- }
5151- }
5252-5353- map.end()
5454- }
5555-}
5656-5757-/// A repository operation.
5858-pub(crate) enum RepoOp {
5959- /// Create a new record.
6060- Create {
6161- /// The CID of the record.
6262- cid: Cid,
6363- /// The path of the record.
6464- path: String,
6565- },
6666- /// Delete an existing record.
6767- Delete {
6868- /// The path of the record.
6969- path: String,
7070- /// The previous CID of the record.
7171- prev: Cid,
7272- },
7373- /// Update an existing record.
7474- Update {
7575- /// The CID of the record.
7676- cid: Cid,
7777- /// The path of the record.
7878- path: String,
7979- /// The previous CID of the record.
8080- prev: Cid,
8181- },
8282-}
8383-8484-impl From<RepoOp> for sync::subscribe_repos::RepoOp {
8585- fn from(val: RepoOp) -> Self {
8686- let (action, cid, prev, path) = match val {
8787- RepoOp::Create { cid, path } => ("create", Some(cid), None, path),
8888- RepoOp::Update { cid, path, prev } => ("update", Some(cid), Some(prev), path),
8989- RepoOp::Delete { path, prev } => ("delete", None, Some(prev), path),
9090- };
9191-9292- sync::subscribe_repos::RepoOpData {
9393- action: action.to_owned(),
9494- cid: cid.map(atrium_api::types::CidLink),
9595- prev: prev.map(atrium_api::types::CidLink),
9696- path,
9797- }
9898- .into()
9999- }
100100-}
101101-102102-/// A commit to the repository.
103103-pub(crate) struct Commit {
104104- /// Blobs that were created in this commit.
105105- pub blobs: Vec<Cid>,
106106- /// The car file containing the commit blocks.
107107- pub car: Vec<u8>,
108108- /// The CID of the commit.
109109- pub cid: Cid,
110110- /// The DID of the repository changed.
111111- pub did: Did,
112112- /// The operations performed in this commit.
113113- pub ops: Vec<RepoOp>,
114114- /// The previous commit's CID (if applicable).
115115- pub pcid: Option<Cid>,
116116- /// The revision of the commit.
117117- pub rev: String,
118118-}
119119-120120-impl From<Commit> for sync::subscribe_repos::Commit {
121121- fn from(val: Commit) -> Self {
122122- sync::subscribe_repos::CommitData {
123123- blobs: val
124124- .blobs
125125- .into_iter()
126126- .map(atrium_api::types::CidLink)
127127- .collect::<Vec<_>>(),
128128- blocks: val.car,
129129- commit: atrium_api::types::CidLink(val.cid),
130130- ops: val.ops.into_iter().map(Into::into).collect::<Vec<_>>(),
131131- prev_data: val.pcid.map(atrium_api::types::CidLink),
132132- rebase: false,
133133- repo: val.did,
134134- rev: Tid::new(val.rev).expect("should be valid revision"),
135135- seq: 0,
136136- since: None,
137137- time: Datetime::now(),
138138- too_big: false,
139139- }
140140- .into()
141141- }
142142-}
143143-144144-/// A firehose producer. This is used to transmit messages to the firehose for broadcast.
145145-#[derive(Clone, Debug)]
146146-pub(crate) struct FirehoseProducer {
147147- /// The channel to send messages to the firehose.
148148- tx: tokio::sync::mpsc::Sender<FirehoseMessage>,
149149-}
150150-151151-impl FirehoseProducer {
152152- /// Broadcast an `#account` event.
153153- pub(crate) async fn account(&self, account: impl Into<sync::subscribe_repos::Account>) {
154154- drop(
155155- self.tx
156156- .send(FirehoseMessage::Broadcast(
157157- sync::subscribe_repos::Message::Account(Box::new(account.into())),
158158- ))
159159- .await,
160160- );
161161- }
162162- /// Handle client connection.
163163- pub(crate) async fn client_connection(&self, ws: WebSocket, cursor: Option<i64>) {
164164- drop(
165165- self.tx
166166- .send(FirehoseMessage::Connect(Box::new((ws, cursor))))
167167- .await,
168168- );
169169- }
170170- /// Broadcast a `#commit` event.
171171- pub(crate) async fn commit(&self, commit: impl Into<sync::subscribe_repos::Commit>) {
172172- drop(
173173- self.tx
174174- .send(FirehoseMessage::Broadcast(
175175- sync::subscribe_repos::Message::Commit(Box::new(commit.into())),
176176- ))
177177- .await,
178178- );
179179- }
180180- /// Broadcast an `#identity` event.
181181- pub(crate) async fn identity(&self, identity: impl Into<sync::subscribe_repos::Identity>) {
182182- drop(
183183- self.tx
184184- .send(FirehoseMessage::Broadcast(
185185- sync::subscribe_repos::Message::Identity(Box::new(identity.into())),
186186- ))
187187- .await,
188188- );
189189- }
190190-}
191191-192192-#[expect(
193193- clippy::as_conversions,
194194- clippy::cast_possible_truncation,
195195- clippy::cast_sign_loss,
196196- clippy::cast_precision_loss,
197197- clippy::arithmetic_side_effects
198198-)]
199199-/// Convert a `usize` to a `f64`.
200200-const fn convert_usize_f64(x: usize) -> Result<f64, &'static str> {
201201- let result = x as f64;
202202- if result as usize - x > 0 {
203203- return Err("cannot convert");
204204- }
205205- Ok(result)
206206-}
207207-208208-/// Serialize a message.
209209-fn serialize_message(seq: u64, mut msg: sync::subscribe_repos::Message) -> (&'static str, Vec<u8>) {
210210- let mut dummy_seq = 0_i64;
211211- #[expect(clippy::pattern_type_mismatch)]
212212- let (ty, nseq) = match &mut msg {
213213- sync::subscribe_repos::Message::Account(m) => ("#account", &mut m.seq),
214214- sync::subscribe_repos::Message::Commit(m) => ("#commit", &mut m.seq),
215215- sync::subscribe_repos::Message::Identity(m) => ("#identity", &mut m.seq),
216216- sync::subscribe_repos::Message::Sync(m) => ("#sync", &mut m.seq),
217217- sync::subscribe_repos::Message::Info(_m) => ("#info", &mut dummy_seq),
218218- };
219219- // Set the sequence number.
220220- *nseq = i64::try_from(seq).expect("should find seq");
221221-222222- let hdr = FrameHeader::Message(ty.to_owned());
223223-224224- let mut frame = Vec::new();
225225- serde_ipld_dagcbor::to_writer(&mut frame, &hdr).expect("should serialize header");
226226- serde_ipld_dagcbor::to_writer(&mut frame, &msg).expect("should serialize message");
227227-228228- (ty, frame)
229229-}
230230-231231-/// Broadcast a message out to all clients.
232232-async fn broadcast_message(clients: &mut Vec<WebSocket>, msg: Message) -> Result<()> {
233233- counter!(FIREHOSE_MESSAGES).increment(1);
234234-235235- for i in (0..clients.len()).rev() {
236236- let client = clients.get_mut(i).expect("should find client");
237237- if let Err(e) = client.send(msg.clone()).await {
238238- debug!("Firehose client disconnected: {e}");
239239- drop(clients.remove(i));
240240- }
241241- }
242242-243243- gauge!(FIREHOSE_LISTENERS)
244244- .set(convert_usize_f64(clients.len()).expect("should find clients length"));
245245- Ok(())
246246-}
247247-248248-/// Handle a new connection from a websocket client created by subscribeRepos.
249249-async fn handle_connect(
250250- mut ws: WebSocket,
251251- seq: u64,
252252- history: &VecDeque<(u64, &str, sync::subscribe_repos::Message)>,
253253- cursor: Option<i64>,
254254-) -> Result<WebSocket> {
255255- if let Some(cursor) = cursor {
256256- let mut frame = Vec::new();
257257- let cursor = u64::try_from(cursor);
258258- if cursor.is_err() {
259259- tracing::warn!("cursor is not a valid u64");
260260- return Ok(ws);
261261- }
262262- let cursor = cursor.expect("should be valid u64");
263263- // Cursor specified; attempt to backfill the consumer.
264264- if cursor > seq {
265265- let hdr = FrameHeader::Error;
266266- let msg = sync::subscribe_repos::Error::FutureCursor(Some(format!(
267267- "cursor {cursor} is greater than the current sequence number {seq}"
268268- )));
269269- serde_ipld_dagcbor::to_writer(&mut frame, &hdr).expect("should serialize header");
270270- serde_ipld_dagcbor::to_writer(&mut frame, &msg).expect("should serialize message");
271271- // Drop the connection.
272272- drop(ws.send(Message::binary(frame)).await);
273273- bail!(
274274- "connection dropped: cursor {cursor} is greater than the current sequence number {seq}"
275275- );
276276- }
277277-278278- for &(historical_seq, ty, ref msg) in history {
279279- if cursor > historical_seq {
280280- continue;
281281- }
282282- let hdr = FrameHeader::Message(ty.to_owned());
283283- serde_ipld_dagcbor::to_writer(&mut frame, &hdr).expect("should serialize header");
284284- serde_ipld_dagcbor::to_writer(&mut frame, msg).expect("should serialize message");
285285- if let Err(e) = ws.send(Message::binary(frame.clone())).await {
286286- debug!("Firehose client disconnected during backfill: {e}");
287287- break;
288288- }
289289- // Clear out the frame to begin a new one.
290290- frame.clear();
291291- }
292292- }
293293-294294- Ok(ws)
295295-}
296296-297297-/// Reconnect to upstream relays.
298298-pub(crate) async fn reconnect_relays(client: &Client, config: &AppConfig) {
299299- // Avoid connecting to upstream relays in test mode.
300300- if config.test {
301301- return;
302302- }
303303-304304- info!("attempting to reconnect to upstream relays");
305305- for relay in &config.firehose.relays {
306306- let Some(host) = relay.host_str() else {
307307- warn!("relay {} has no host specified", relay);
308308- continue;
309309- };
310310-311311- let r = client
312312- .post(format!("https://{host}/xrpc/com.atproto.sync.requestCrawl"))
313313- .json(&serde_json::json!({
314314- "hostname": format!("https://{}", config.host_name)
315315- }))
316316- .send()
317317- .await;
318318-319319- let r = match r {
320320- Ok(r) => r,
321321- Err(e) => {
322322- error!("failed to hit upstream relay {host}: {e}");
323323- continue;
324324- }
325325- };
326326-327327- let s = r.status();
328328- if let Err(e) = r.error_for_status_ref() {
329329- error!("failed to hit upstream relay {host}: {e}");
330330- }
331331-332332- let b = r.json::<serde_json::Value>().await;
333333- if let Ok(b) = b {
334334- info!("relay {host}: {} {}", s, b);
335335- } else {
336336- info!("relay {host}: {}", s);
337337- }
338338- }
339339-}
340340-341341-/// The main entrypoint for the firehose.
342342-///
343343-/// This will broadcast all updates in this PDS out to anyone who is listening.
344344-///
345345-/// Reference: <https://atproto.com/specs/sync>
346346-pub(crate) fn spawn(
347347- client: Client,
348348- config: AppConfig,
349349-) -> (tokio::task::JoinHandle<()>, FirehoseProducer) {
350350- let (tx, mut rx) = tokio::sync::mpsc::channel(1000);
351351- let handle = tokio::spawn(async move {
352352- fn time_since_inception() -> u64 {
353353- chrono::Utc::now()
354354- .timestamp_micros()
355355- .checked_sub(1_743_442_000_000_000)
356356- .expect("should not wrap")
357357- .unsigned_abs()
358358- }
359359- let mut clients: Vec<WebSocket> = Vec::new();
360360- let mut history = VecDeque::with_capacity(1000);
361361- let mut seq = time_since_inception();
362362-363363- loop {
364364- if let Ok(msg) = tokio::time::timeout(Duration::from_secs(30), rx.recv()).await {
365365- match msg {
366366- Some(FirehoseMessage::Broadcast(msg)) => {
367367- let (ty, by) = serialize_message(seq, msg.clone());
368368-369369- history.push_back((seq, ty, msg));
370370- gauge!(FIREHOSE_HISTORY).set(
371371- convert_usize_f64(history.len()).expect("should find history length"),
372372- );
373373-374374- info!(
375375- "Broadcasting message {} {} to {} clients",
376376- seq,
377377- ty,
378378- clients.len()
379379- );
380380-381381- counter!(FIREHOSE_SEQUENCE).absolute(seq);
382382- let now = time_since_inception();
383383- if now > seq {
384384- seq = now;
385385- } else {
386386- seq = seq.checked_add(1).expect("should not wrap");
387387- }
388388-389389- drop(broadcast_message(&mut clients, Message::binary(by)).await);
390390- }
391391- Some(FirehoseMessage::Connect(ws_cursor)) => {
392392- let (ws, cursor) = *ws_cursor;
393393- match handle_connect(ws, seq, &history, cursor).await {
394394- Ok(r) => {
395395- gauge!(FIREHOSE_LISTENERS).increment(1_i32);
396396- clients.push(r);
397397- }
398398- Err(e) => {
399399- error!("failed to connect new client: {e}");
400400- }
401401- }
402402- }
403403- // All producers have been destroyed.
404404- None => break,
405405- }
406406- } else {
407407- if clients.is_empty() {
408408- reconnect_relays(&client, &config).await;
409409- }
410410-411411- let contents = rand::thread_rng()
412412- .sample_iter(rand::distributions::Alphanumeric)
413413- .take(15)
414414- .map(char::from)
415415- .collect::<String>();
416416-417417- // Send a websocket ping message.
418418- // Reference: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers#pings_and_pongs_the_heartbeat_of_websockets
419419- let message = Message::Ping(axum::body::Bytes::from_owner(contents));
420420- drop(broadcast_message(&mut clients, message).await);
421421- }
422422- }
423423- });
424424-425425- (handle, FirehoseProducer { tx })
426426-}
+3-438
src/lib.rs
···88mod db;
99mod did;
1010pub mod error;
1111-mod firehose;
1211mod metrics;
1313-mod mmap;
1412mod models;
1513mod oauth;
1616-mod plc;
1714mod schema;
1515+mod serve;
1816mod service_proxy;
1919-#[cfg(test)]
2020-mod tests;
21172222-use account_manager::{AccountManager, SharedAccountManager};
2323-use anyhow::{Context as _, anyhow};
2424-use atrium_api::types::string::Did;
2525-use atrium_crypto::keypair::{Export as _, Secp256k1Keypair};
2626-use auth::AuthenticatedUser;
2727-use axum::{
2828- Router,
2929- body::Body,
3030- extract::{FromRef, Request, State},
3131- http::{self, HeaderMap, Response, StatusCode, Uri},
3232- response::IntoResponse,
3333- routing::get,
3434-};
3535-use azure_core::credentials::TokenCredential;
3636-use clap::Parser;
3737-use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter};
3838-use config::AppConfig;
3939-use db::establish_pool;
4040-use deadpool_diesel::sqlite::Pool;
4141-use diesel::prelude::*;
4242-use diesel_migrations::{EmbeddedMigrations, embed_migrations};
4343-pub use error::Error;
4444-use figment::{Figment, providers::Format as _};
4545-use firehose::FirehoseProducer;
4646-use http_cache_reqwest::{CacheMode, HttpCacheOptions, MokaManager};
4747-use rand::Rng as _;
4848-use rsky_pds::{crawlers::Crawlers, sequencer::Sequencer};
4949-use serde::{Deserialize, Serialize};
5050-use service_proxy::service_proxy;
5151-use std::{
5252- net::{IpAddr, Ipv4Addr, SocketAddr},
5353- path::PathBuf,
5454- str::FromStr as _,
5555- sync::Arc,
5656-};
5757-use tokio::{net::TcpListener, sync::RwLock};
5858-use tower_http::{cors::CorsLayer, trace::TraceLayer};
5959-use tracing::{info, warn};
6060-use uuid::Uuid;
6161-6262-/// The application user agent. Concatenates the package name and version. e.g. `bluepds/0.0.0`.
6363-pub const APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
6464-6565-/// Embedded migrations
6666-pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations");
6767-6868-/// The application-wide result type.
6969-pub type Result<T> = std::result::Result<T, Error>;
7070-/// The reqwest client type with middleware.
7171-pub type Client = reqwest_middleware::ClientWithMiddleware;
7272-7373-/// The Shared Sequencer which requests crawls from upstream relays and emits events to the firehose.
7474-pub struct SharedSequencer {
7575- /// The sequencer instance.
7676- pub sequencer: RwLock<Sequencer>,
7777-}
7878-7979-#[expect(
8080- clippy::arbitrary_source_item_ordering,
8181- reason = "serialized data might be structured"
8282-)]
8383-#[derive(Serialize, Deserialize, Debug, Clone)]
8484-/// The key data structure.
8585-struct KeyData {
8686- /// Primary signing key for all repo operations.
8787- skey: Vec<u8>,
8888- /// Primary signing (rotation) key for all PLC operations.
8989- rkey: Vec<u8>,
9090-}
9191-9292-// FIXME: We should use P256Keypair instead. SecP256K1 is primarily used for cryptocurrencies,
9393-// and the implementations of this algorithm are much more limited as compared to P256.
9494-//
9595-// Reference: https://soatok.blog/2022/05/19/guidance-for-choosing-an-elliptic-curve-signature-algorithm-in-2022/
9696-#[derive(Clone)]
9797-/// The signing key for PLC/DID operations.
9898-pub struct SigningKey(Arc<Secp256k1Keypair>);
9999-#[derive(Clone)]
100100-/// The rotation key for PLC operations.
101101-pub struct RotationKey(Arc<Secp256k1Keypair>);
102102-103103-impl std::ops::Deref for SigningKey {
104104- type Target = Secp256k1Keypair;
105105-106106- fn deref(&self) -> &Self::Target {
107107- &self.0
108108- }
109109-}
110110-111111-impl SigningKey {
112112- /// Import from a private key.
113113- pub fn import(key: &[u8]) -> Result<Self> {
114114- let key = Secp256k1Keypair::import(key).context("failed to import signing key")?;
115115- Ok(Self(Arc::new(key)))
116116- }
117117-}
118118-119119-impl std::ops::Deref for RotationKey {
120120- type Target = Secp256k1Keypair;
121121-122122- fn deref(&self) -> &Self::Target {
123123- &self.0
124124- }
125125-}
126126-127127-#[derive(Parser, Debug, Clone)]
128128-/// Command line arguments.
129129-pub struct Args {
130130- /// Path to the configuration file
131131- #[arg(short, long, default_value = "default.toml")]
132132- pub config: PathBuf,
133133- /// The verbosity level.
134134- #[command(flatten)]
135135- pub verbosity: Verbosity<InfoLevel>,
136136-}
137137-138138-/// The actor pools for the database connections.
139139-pub struct ActorPools {
140140- /// The database connection pool for the actor's repository.
141141- pub repo: Pool,
142142- /// The database connection pool for the actor's blobs.
143143- pub blob: Pool,
144144-}
145145-146146-impl Clone for ActorPools {
147147- fn clone(&self) -> Self {
148148- Self {
149149- repo: self.repo.clone(),
150150- blob: self.blob.clone(),
151151- }
152152- }
153153-}
154154-155155-#[expect(clippy::arbitrary_source_item_ordering, reason = "arbitrary")]
156156-#[derive(Clone, FromRef)]
157157-pub struct AppState {
158158- /// The application configuration.
159159- pub config: AppConfig,
160160- /// The main database connection pool. Used for common PDS data, like invite codes.
161161- pub db: Pool,
162162- /// Actor-specific database connection pools. Hashed by DID.
163163- pub db_actors: std::collections::HashMap<String, ActorPools>,
164164-165165- /// The HTTP client with middleware.
166166- pub client: Client,
167167- /// The simple HTTP client.
168168- pub simple_client: reqwest::Client,
169169- /// The firehose producer.
170170- pub sequencer: Arc<SharedSequencer>,
171171- /// The account manager.
172172- pub account_manager: Arc<SharedAccountManager>,
173173-174174- /// The signing key.
175175- pub signing_key: SigningKey,
176176- /// The rotation key.
177177- pub rotation_key: RotationKey,
178178-}
1818+pub use serve::run;
1791918020/// The index (/) route.
181181-async fn index() -> impl IntoResponse {
2121+async fn index() -> impl axum::response::IntoResponse {
18222 r"
18323 __ __
18424 /\ \__ /\ \__
···19939 Protocol: https://atproto.com
20040 "
20141}
202202-203203-/// The main application entry point.
204204-#[expect(
205205- clippy::cognitive_complexity,
206206- clippy::too_many_lines,
207207- unused_qualifications,
208208- reason = "main function has high complexity"
209209-)]
210210-pub async fn run() -> anyhow::Result<()> {
211211- let args = Args::parse();
212212-213213- // Set up trace logging to console and account for the user-provided verbosity flag.
214214- if args.verbosity.log_level_filter() != LevelFilter::Off {
215215- let lvl = match args.verbosity.log_level_filter() {
216216- LevelFilter::Error => tracing::Level::ERROR,
217217- LevelFilter::Warn => tracing::Level::WARN,
218218- LevelFilter::Info | LevelFilter::Off => tracing::Level::INFO,
219219- LevelFilter::Debug => tracing::Level::DEBUG,
220220- LevelFilter::Trace => tracing::Level::TRACE,
221221- };
222222- tracing_subscriber::fmt().with_max_level(lvl).init();
223223- }
224224-225225- if !args.config.exists() {
226226- // Throw up a warning if the config file does not exist.
227227- //
228228- // This is not fatal because users can specify all configuration settings via
229229- // the environment, but the most likely scenario here is that a user accidentally
230230- // omitted the config file for some reason (e.g. forgot to mount it into Docker).
231231- warn!(
232232- "configuration file {} does not exist",
233233- args.config.display()
234234- );
235235- }
236236-237237- // Read and parse the user-provided configuration.
238238- let config: AppConfig = Figment::new()
239239- .admerge(figment::providers::Toml::file(args.config))
240240- .admerge(figment::providers::Env::prefixed("BLUEPDS_"))
241241- .extract()
242242- .context("failed to load configuration")?;
243243-244244- if config.test {
245245- warn!("BluePDS starting up in TEST mode.");
246246- warn!("This means the application will not federate with the rest of the network.");
247247- warn!(
248248- "If you want to turn this off, either set `test` to false in the config or define `BLUEPDS_TEST = false`"
249249- );
250250- }
251251-252252- // Initialize metrics reporting.
253253- metrics::setup(config.metrics.as_ref()).context("failed to set up metrics exporter")?;
254254-255255- // Create a reqwest client that will be used for all outbound requests.
256256- let simple_client = reqwest::Client::builder()
257257- .user_agent(APP_USER_AGENT)
258258- .build()
259259- .context("failed to build requester client")?;
260260- let client = reqwest_middleware::ClientBuilder::new(simple_client.clone())
261261- .with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache {
262262- mode: CacheMode::Default,
263263- manager: MokaManager::default(),
264264- options: HttpCacheOptions::default(),
265265- }))
266266- .build();
267267-268268- tokio::fs::create_dir_all(&config.key.parent().context("should have parent")?)
269269- .await
270270- .context("failed to create key directory")?;
271271-272272- // Check if crypto keys exist. If not, create new ones.
273273- let (skey, rkey) = if let Ok(f) = std::fs::File::open(&config.key) {
274274- let keys: KeyData = serde_ipld_dagcbor::from_reader(std::io::BufReader::new(f))
275275- .context("failed to deserialize crypto keys")?;
276276-277277- let skey = Secp256k1Keypair::import(&keys.skey).context("failed to import signing key")?;
278278- let rkey = Secp256k1Keypair::import(&keys.rkey).context("failed to import rotation key")?;
279279-280280- (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey)))
281281- } else {
282282- info!("signing keys not found, generating new ones");
283283-284284- let skey = Secp256k1Keypair::create(&mut rand::thread_rng());
285285- let rkey = Secp256k1Keypair::create(&mut rand::thread_rng());
286286-287287- let keys = KeyData {
288288- skey: skey.export(),
289289- rkey: rkey.export(),
290290- };
291291-292292- let mut f = std::fs::File::create(&config.key).context("failed to create key file")?;
293293- serde_ipld_dagcbor::to_writer(&mut f, &keys).context("failed to serialize crypto keys")?;
294294-295295- (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey)))
296296- };
297297-298298- tokio::fs::create_dir_all(&config.repo.path).await?;
299299- tokio::fs::create_dir_all(&config.plc.path).await?;
300300- tokio::fs::create_dir_all(&config.blob.path).await?;
301301-302302- // Create a database connection manager and pool for the main database.
303303- let pool =
304304- establish_pool(&config.db).context("failed to establish database connection pool")?;
305305- // Create a dictionary of database connection pools for each actor.
306306- let mut actor_pools = std::collections::HashMap::new();
307307- // let mut actor_blob_pools = std::collections::HashMap::new();
308308- // We'll determine actors by looking in the data/repo dir for .db files.
309309- let mut actor_dbs = tokio::fs::read_dir(&config.repo.path)
310310- .await
311311- .context("failed to read repo directory")?;
312312- while let Some(entry) = actor_dbs
313313- .next_entry()
314314- .await
315315- .context("failed to read repo dir")?
316316- {
317317- let path = entry.path();
318318- if path.extension().and_then(|s| s.to_str()) == Some("db") {
319319- let did_path = path
320320- .file_stem()
321321- .and_then(|s| s.to_str())
322322- .context("failed to get actor DID")?;
323323- let did = Did::from_str(&format!("did:plc:{}", did_path))
324324- .expect("should be able to parse actor DID");
325325-326326- // Create a new database connection manager and pool for the actor.
327327- // The path for the SQLite connection needs to look like "sqlite://data/repo/<actor>.db"
328328- let path_repo = format!("sqlite://{}", did_path);
329329- let actor_repo_pool =
330330- establish_pool(&path_repo).context("failed to create database connection pool")?;
331331- // Create a new database connection manager and pool for the actor blobs.
332332- // The path for the SQLite connection needs to look like "sqlite://data/blob/<actor>.db"
333333- let path_blob = path_repo.replace("repo", "blob");
334334- let actor_blob_pool =
335335- establish_pool(&path_blob).context("failed to create database connection pool")?;
336336- drop(actor_pools.insert(
337337- did.to_string(),
338338- ActorPools {
339339- repo: actor_repo_pool,
340340- blob: actor_blob_pool,
341341- },
342342- ));
343343- }
344344- }
345345- // Apply pending migrations
346346- // let conn = pool.get().await?;
347347- // conn.run_pending_migrations(MIGRATIONS)
348348- // .expect("should be able to run migrations");
349349-350350- let hostname = config.host_name.clone();
351351- let crawlers: Vec<String> = config
352352- .firehose
353353- .relays
354354- .iter()
355355- .map(|s| s.to_string())
356356- .collect();
357357- let sequencer = Arc::new(SharedSequencer {
358358- sequencer: RwLock::new(Sequencer::new(
359359- Crawlers::new(hostname, crawlers.clone()),
360360- None,
361361- )),
362362- });
363363- let account_manager = SharedAccountManager {
364364- account_manager: RwLock::new(AccountManager::new(pool.clone())),
365365- };
366366-367367- let addr = config
368368- .listen_address
369369- .unwrap_or(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000));
370370-371371- let app = Router::new()
372372- .route("/", get(index))
373373- .merge(oauth::routes())
374374- .nest(
375375- "/xrpc",
376376- apis::routes()
377377- .merge(actor_endpoints::routes())
378378- .fallback(service_proxy),
379379- )
380380- // .layer(RateLimitLayer::new(30, Duration::from_secs(30)))
381381- .layer(CorsLayer::permissive())
382382- .layer(TraceLayer::new_for_http())
383383- .with_state(AppState {
384384- config: config.clone(),
385385- db: pool.clone(),
386386- db_actors: actor_pools.clone(),
387387- client: client.clone(),
388388- simple_client,
389389- sequencer: sequencer.clone(),
390390- account_manager: Arc::new(account_manager),
391391- signing_key: skey,
392392- rotation_key: rkey,
393393- });
394394-395395- info!("listening on {addr}");
396396- info!("connect to: http://127.0.0.1:{}", addr.port());
397397-398398- // Determine whether or not this was the first startup (i.e. no accounts exist and no invite codes were created).
399399- // If so, create an invite code and share it via the console.
400400- let conn = pool.get().await.context("failed to get db connection")?;
401401-402402- #[derive(QueryableByName)]
403403- struct TotalCount {
404404- #[diesel(sql_type = diesel::sql_types::Integer)]
405405- total_count: i32,
406406- }
407407-408408- let result = conn.interact(move |conn| {
409409- diesel::sql_query(
410410- "SELECT (SELECT COUNT(*) FROM account) + (SELECT COUNT(*) FROM invite_code) AS total_count",
411411- )
412412- .get_result::<TotalCount>(conn)
413413- })
414414- .await
415415- .expect("should be able to query database")?;
416416-417417- let c = result.total_count;
418418-419419- #[expect(clippy::print_stdout)]
420420- if c == 0 {
421421- let uuid = Uuid::new_v4().to_string();
422422-423423- use crate::models::pds as models;
424424- use crate::schema::pds::invite_code::dsl as InviteCode;
425425- let uuid_clone = uuid.clone();
426426- drop(
427427- conn.interact(move |conn| {
428428- diesel::insert_into(InviteCode::invite_code)
429429- .values(models::InviteCode {
430430- code: uuid_clone,
431431- available_uses: 1,
432432- disabled: 0,
433433- for_account: "None".to_owned(),
434434- created_by: "None".to_owned(),
435435- created_at: "None".to_owned(),
436436- })
437437- .execute(conn)
438438- .context("failed to create new invite code")
439439- })
440440- .await
441441- .expect("should be able to create invite code"),
442442- );
443443-444444- // N.B: This is a sensitive message, so we're bypassing `tracing` here and
445445- // logging it directly to console.
446446- println!("=====================================");
447447- println!(" FIRST STARTUP ");
448448- println!("=====================================");
449449- println!("Use this code to create an account:");
450450- println!("{uuid}");
451451- println!("=====================================");
452452- }
453453-454454- let listener = TcpListener::bind(&addr)
455455- .await
456456- .context("failed to bind address")?;
457457-458458- // Serve the app, and request crawling from upstream relays.
459459- let serve = tokio::spawn(async move {
460460- axum::serve(listener, app.into_make_service())
461461- .await
462462- .context("failed to serve app")
463463- });
464464-465465- // Now that the app is live, request a crawl from upstream relays.
466466- let mut background_sequencer = sequencer.sequencer.write().await.clone();
467467- drop(tokio::spawn(
468468- async move { background_sequencer.start().await },
469469- ));
470470-471471- serve
472472- .await
473473- .map_err(Into::into)
474474- .and_then(|r| r)
475475- .context("failed to serve app")
476476-}
+1-3
src/main.rs
···11//! BluePDS binary entry point.
2233use anyhow::Context as _;
44-use clap::Parser;
5465#[tokio::main(flavor = "multi_thread")]
76async fn main() -> anyhow::Result<()> {
88- // Parse command line arguments and call into the library's run function
97 bluepds::run().await.context("failed to run application")
1010-}88+}
-274
src/mmap.rs
···11-#![allow(clippy::arbitrary_source_item_ordering)]
22-use std::io::{ErrorKind, Read as _, Seek as _, Write as _};
33-44-#[cfg(unix)]
55-use std::os::fd::AsRawFd as _;
66-#[cfg(windows)]
77-use std::os::windows::io::AsRawHandle;
88-99-use memmap2::{MmapMut, MmapOptions};
1010-1111-pub(crate) struct MappedFile {
1212- /// The underlying file handle.
1313- file: std::fs::File,
1414- /// The length of the file.
1515- len: u64,
1616- /// The mapped memory region.
1717- map: MmapMut,
1818- /// Our current offset into the file.
1919- off: u64,
2020-}
2121-2222-impl MappedFile {
2323- pub(crate) fn new(mut f: std::fs::File) -> std::io::Result<Self> {
2424- let len = f.seek(std::io::SeekFrom::End(0))?;
2525-2626- #[cfg(windows)]
2727- let raw = f.as_raw_handle();
2828- #[cfg(unix)]
2929- let raw = f.as_raw_fd();
3030-3131- #[expect(unsafe_code)]
3232- Ok(Self {
3333- // SAFETY:
3434- // All file-backed memory map constructors are marked \
3535- // unsafe because of the potential for Undefined Behavior (UB) \
3636- // using the map if the underlying file is subsequently modified, in or out of process.
3737- map: unsafe { MmapOptions::new().map_mut(raw)? },
3838- file: f,
3939- len,
4040- off: 0,
4141- })
4242- }
4343-4444- /// Resize the memory-mapped file. This will reallocate the memory mapping.
4545- #[expect(unsafe_code)]
4646- fn resize(&mut self, len: u64) -> std::io::Result<()> {
4747- // Resize the file.
4848- self.file.set_len(len)?;
4949-5050- #[cfg(windows)]
5151- let raw = self.file.as_raw_handle();
5252- #[cfg(unix)]
5353- let raw = self.file.as_raw_fd();
5454-5555- // SAFETY:
5656- // All file-backed memory map constructors are marked \
5757- // unsafe because of the potential for Undefined Behavior (UB) \
5858- // using the map if the underlying file is subsequently modified, in or out of process.
5959- self.map = unsafe { MmapOptions::new().map_mut(raw)? };
6060- self.len = len;
6161-6262- Ok(())
6363- }
6464-}
6565-6666-impl std::io::Read for MappedFile {
6767- fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
6868- if self.off == self.len {
6969- // If we're at EOF, return an EOF error code. `Ok(0)` tends to trip up some implementations.
7070- return Err(std::io::Error::new(ErrorKind::UnexpectedEof, "eof"));
7171- }
7272-7373- // Calculate the number of bytes we're going to read.
7474- let remaining_bytes = self.len.saturating_sub(self.off);
7575- let buf_len = u64::try_from(buf.len()).unwrap_or(u64::MAX);
7676- let len = usize::try_from(std::cmp::min(remaining_bytes, buf_len)).unwrap_or(usize::MAX);
7777-7878- let off = usize::try_from(self.off).map_err(|e| {
7979- std::io::Error::new(
8080- ErrorKind::InvalidInput,
8181- format!("offset too large for this platform: {e}"),
8282- )
8383- })?;
8484-8585- if let (Some(dest), Some(src)) = (
8686- buf.get_mut(..len),
8787- self.map.get(off..off.saturating_add(len)),
8888- ) {
8989- dest.copy_from_slice(src);
9090- self.off = self.off.saturating_add(u64::try_from(len).unwrap_or(0));
9191- Ok(len)
9292- } else {
9393- Err(std::io::Error::new(
9494- ErrorKind::InvalidInput,
9595- "invalid buffer range",
9696- ))
9797- }
9898- }
9999-}
100100-101101-impl std::io::Write for MappedFile {
102102- fn flush(&mut self) -> std::io::Result<()> {
103103- // This is done by the system.
104104- Ok(())
105105- }
106106- fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
107107- // Determine if we need to resize the file.
108108- let buf_len = u64::try_from(buf.len()).map_err(|e| {
109109- std::io::Error::new(
110110- ErrorKind::InvalidInput,
111111- format!("buffer length too large for this platform: {e}"),
112112- )
113113- })?;
114114-115115- if self.off.saturating_add(buf_len) >= self.len {
116116- self.resize(self.off.saturating_add(buf_len))?;
117117- }
118118-119119- let off = usize::try_from(self.off).map_err(|e| {
120120- std::io::Error::new(
121121- ErrorKind::InvalidInput,
122122- format!("offset too large for this platform: {e}"),
123123- )
124124- })?;
125125- let len = buf.len();
126126-127127- if let Some(dest) = self.map.get_mut(off..off.saturating_add(len)) {
128128- dest.copy_from_slice(buf);
129129- self.off = self.off.saturating_add(buf_len);
130130- Ok(len)
131131- } else {
132132- Err(std::io::Error::new(
133133- ErrorKind::InvalidInput,
134134- "invalid buffer range",
135135- ))
136136- }
137137- }
138138-}
139139-140140-impl std::io::Seek for MappedFile {
141141- fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
142142- let off = match pos {
143143- std::io::SeekFrom::Start(i) => i,
144144- std::io::SeekFrom::End(i) => {
145145- if i <= 0 {
146146- // If i is negative or zero, we're seeking backwards from the end
147147- // or exactly at the end
148148- self.len.saturating_sub(i.unsigned_abs())
149149- } else {
150150- // If i is positive, we're seeking beyond the end, which is allowed
151151- // but requires extending the file
152152- self.len.saturating_add(i.unsigned_abs())
153153- }
154154- }
155155- std::io::SeekFrom::Current(i) => {
156156- if i >= 0 {
157157- self.off.saturating_add(i.unsigned_abs())
158158- } else {
159159- self.off.saturating_sub(i.unsigned_abs())
160160- }
161161- }
162162- };
163163-164164- // If the offset is beyond EOF, extend the file to the new size.
165165- if off > self.len {
166166- self.resize(off)?;
167167- }
168168-169169- self.off = off;
170170- Ok(off)
171171- }
172172-}
173173-174174-impl tokio::io::AsyncRead for MappedFile {
175175- fn poll_read(
176176- mut self: std::pin::Pin<&mut Self>,
177177- _cx: &mut std::task::Context<'_>,
178178- buf: &mut tokio::io::ReadBuf<'_>,
179179- ) -> std::task::Poll<std::io::Result<()>> {
180180- let wbuf = buf.initialize_unfilled();
181181- let len = wbuf.len();
182182-183183- std::task::Poll::Ready(match self.read(wbuf) {
184184- Ok(_) => {
185185- buf.advance(len);
186186- Ok(())
187187- }
188188- Err(e) => Err(e),
189189- })
190190- }
191191-}
192192-193193-impl tokio::io::AsyncWrite for MappedFile {
194194- fn poll_flush(
195195- self: std::pin::Pin<&mut Self>,
196196- _cx: &mut std::task::Context<'_>,
197197- ) -> std::task::Poll<Result<(), std::io::Error>> {
198198- std::task::Poll::Ready(Ok(()))
199199- }
200200-201201- fn poll_shutdown(
202202- self: std::pin::Pin<&mut Self>,
203203- _cx: &mut std::task::Context<'_>,
204204- ) -> std::task::Poll<Result<(), std::io::Error>> {
205205- std::task::Poll::Ready(Ok(()))
206206- }
207207-208208- fn poll_write(
209209- mut self: std::pin::Pin<&mut Self>,
210210- _cx: &mut std::task::Context<'_>,
211211- buf: &[u8],
212212- ) -> std::task::Poll<Result<usize, std::io::Error>> {
213213- std::task::Poll::Ready(self.write(buf))
214214- }
215215-}
216216-217217-impl tokio::io::AsyncSeek for MappedFile {
218218- fn poll_complete(
219219- self: std::pin::Pin<&mut Self>,
220220- _cx: &mut std::task::Context<'_>,
221221- ) -> std::task::Poll<std::io::Result<u64>> {
222222- std::task::Poll::Ready(Ok(self.off))
223223- }
224224-225225- fn start_seek(
226226- mut self: std::pin::Pin<&mut Self>,
227227- position: std::io::SeekFrom,
228228- ) -> std::io::Result<()> {
229229- self.seek(position).map(|_p| ())
230230- }
231231-}
#[cfg(test)]
mod test {
    use rand::Rng as _;
    use std::io::Write as _;

    use super::*;

    /// Round-trip a small payload through write / seek / read.
    #[test]
    fn basic_rw() {
        // Random file name so concurrent test runs don't collide in temp_dir.
        let name: String = rand::thread_rng()
            .sample_iter(rand::distributions::Alphanumeric)
            .take(10)
            .map(char::from)
            .collect();
        let tmp = std::env::temp_dir().join(name);

        let file = std::fs::File::options()
            .create(true)
            .truncate(true)
            .read(true)
            .write(true)
            .open(&tmp)
            .expect("Failed to open temporary file");
        let mut m = MappedFile::new(file).expect("Failed to create MappedFile");

        m.write_all(b"abcd123").expect("Failed to write data");
        let _: u64 = m
            .seek(std::io::SeekFrom::Start(0))
            .expect("Failed to seek to start");

        let mut buf = [0_u8; 7];
        m.read_exact(&mut buf).expect("Failed to read data");
        assert_eq!(&buf, b"abcd123");

        drop(m);
        std::fs::remove_file(tmp).expect("Failed to remove temporary file");
    }
}