AT Protocol indexer with flexible filtering, XRPC queries, and a cursor-backed event stream, built on fjall
at-protocol atproto indexer rust fjall
at 8990a2ff5651c71cb6fa29aa3c44fa4212b6a4f6 325 lines 11 kB view raw
1use crate::api::AppState; 2use crate::db::types::DbRkey; 3use crate::db::{self, Db, keys}; 4use axum::extract::FromRequest; 5use axum::response::IntoResponse; 6use axum::{Json, Router, extract::State, http::StatusCode}; 7use futures::TryFutureExt; 8use jacquard_api::com_atproto::repo::{ 9 get_record::{GetRecordError, GetRecordOutput, GetRecordRequest}, 10 list_records::{ListRecordsOutput, ListRecordsRequest, Record as RepoRecord}, 11}; 12use jacquard_common::CowStr; 13use jacquard_common::cowstr::ToCowStr; 14use jacquard_common::types::ident::AtIdentifier; 15use jacquard_common::xrpc::{XrpcEndpoint, XrpcMethod}; 16use jacquard_common::{IntoStatic, xrpc::XrpcRequest}; 17use jacquard_common::{ 18 types::{ 19 string::{AtUri, Cid}, 20 value::Data, 21 }, 22 xrpc::{GenericXrpcError, XrpcError}, 23}; 24use miette::IntoDiagnostic; 25use serde::{Deserialize, Serialize}; 26use smol_str::ToSmolStr; 27use std::{fmt::Display, sync::Arc}; 28use tokio::task::spawn_blocking; 29 30pub fn router() -> Router<Arc<AppState>> { 31 Router::new() 32 .route( 33 GetRecordRequest::PATH, 34 axum::routing::get(handle_get_record), 35 ) 36 .route( 37 ListRecordsRequest::PATH, 38 axum::routing::get(handle_list_records), 39 ) 40 .route(CountRecords::PATH, axum::routing::get(handle_count_records)) 41} 42 43#[derive(Debug)] 44pub struct XrpcErrorResponse<E: IntoStatic + std::error::Error = GenericXrpcError> { 45 pub status: StatusCode, 46 pub error: XrpcError<E>, 47} 48 49impl<E: Serialize + IntoStatic + std::error::Error> IntoResponse for XrpcErrorResponse<E> { 50 fn into_response(self) -> axum::response::Response { 51 (self.status, Json(self.error)).into_response() 52 } 53} 54 55pub type XrpcResult<T, E = GenericXrpcError> = Result<T, XrpcErrorResponse<E>>; 56 57pub struct ExtractXrpc<E: XrpcEndpoint>(pub E::Request<'static>); 58 59impl<S, E> FromRequest<S> for ExtractXrpc<E> 60where 61 S: Send + Sync, 62 E: XrpcEndpoint, 63 E::Request<'static>: Send, 64 for<'de> E::Request<'de>: Deserialize<'de> 
+ IntoStatic<Output = E::Request<'static>>, 65{ 66 type Rejection = XrpcErrorResponse<GenericXrpcError>; 67 68 async fn from_request( 69 req: axum::extract::Request, 70 _state: &S, 71 ) -> Result<Self, Self::Rejection> { 72 let nsid = E::Request::<'static>::NSID; 73 match E::METHOD { 74 XrpcMethod::Query => { 75 let query = req.uri().query().unwrap_or(""); 76 let res: E::Request<'_> = 77 serde_urlencoded::from_str(query).map_err(|e| bad_request(nsid, e))?; 78 Ok(ExtractXrpc(res.into_static())) 79 } 80 XrpcMethod::Procedure(_) => { 81 let body = axum::body::to_bytes(req.into_body(), usize::MAX) 82 .await 83 .map_err(|e| internal_error(nsid, e))?; 84 let res: E::Request<'_> = 85 serde_json::from_slice(&body).map_err(|e| bad_request(nsid, e))?; 86 Ok(ExtractXrpc(res.into_static())) 87 } 88 } 89 } 90} 91 92fn internal_error<E: std::error::Error + IntoStatic>( 93 nsid: &'static str, 94 message: impl Display, 95) -> XrpcErrorResponse<E> { 96 XrpcErrorResponse { 97 status: StatusCode::INTERNAL_SERVER_ERROR, 98 error: XrpcError::Generic(GenericXrpcError { 99 error: "InternalError".into(), 100 message: Some(message.to_smolstr()), 101 nsid, 102 method: "GET", 103 http_status: StatusCode::INTERNAL_SERVER_ERROR, 104 }), 105 } 106} 107 108fn bad_request<E: std::error::Error + IntoStatic>( 109 nsid: &'static str, 110 message: impl Display, 111) -> XrpcErrorResponse<E> { 112 XrpcErrorResponse { 113 status: StatusCode::BAD_REQUEST, 114 error: XrpcError::Generic(GenericXrpcError { 115 error: "InvalidRequest".into(), 116 message: Some(message.to_smolstr()), 117 nsid, 118 method: "GET", 119 http_status: StatusCode::BAD_REQUEST, 120 }), 121 } 122} 123 124pub async fn handle_get_record( 125 State(state): State<Arc<AppState>>, 126 ExtractXrpc(req): ExtractXrpc<GetRecordRequest>, 127) -> Result<Json<GetRecordOutput<'static>>, XrpcErrorResponse<GetRecordError<'static>>> { 128 let db = &state.db; 129 let did = state 130 .resolver 131 .resolve_did(&req.repo) 132 .await 133 .map_err(|e| 
bad_request(GetRecordRequest::PATH, e))?; 134 135 let db_key = keys::record_key( 136 &did, 137 req.collection.as_str(), 138 &DbRkey::new(req.rkey.0.as_str()), 139 ); 140 141 let cid_bytes = Db::get(db.records.clone(), db_key) 142 .await 143 .map_err(|e| internal_error(GetRecordRequest::PATH, e))?; 144 145 if let Some(cid_bytes) = cid_bytes { 146 // lookup block using binary cid 147 let block_bytes = Db::get(db.blocks.clone(), cid_bytes.clone()) 148 .await 149 .map_err(|e| internal_error(GetRecordRequest::PATH, e))? 150 .ok_or_else(|| internal_error(GetRecordRequest::PATH, "not found"))?; 151 152 let value: Data = serde_ipld_dagcbor::from_slice(&block_bytes) 153 .map_err(|e| internal_error(GetRecordRequest::PATH, e))?; 154 155 let cid = Cid::new(&cid_bytes) 156 .map_err(|e| internal_error(GetRecordRequest::PATH, e))? 157 .into_static(); 158 159 Ok(Json(GetRecordOutput { 160 uri: AtUri::from_parts_owned( 161 did.as_str(), 162 req.collection.as_str(), 163 req.rkey.0.as_str(), 164 ) 165 .unwrap(), 166 cid: Some(Cid::Str(cid.to_cowstr()).into_static()), 167 value: value.into_static(), 168 extra_data: Default::default(), 169 })) 170 } else { 171 Err(XrpcErrorResponse { 172 status: StatusCode::NOT_FOUND, 173 error: XrpcError::Xrpc(GetRecordError::RecordNotFound(None)), 174 }) 175 } 176} 177 178pub async fn handle_list_records( 179 State(state): State<Arc<AppState>>, 180 ExtractXrpc(req): ExtractXrpc<ListRecordsRequest>, 181) -> Result<Json<ListRecordsOutput<'static>>, XrpcErrorResponse<GenericXrpcError>> { 182 let db = &state.db; 183 let did = state 184 .resolver 185 .resolve_did(&req.repo) 186 .await 187 .map_err(|e| bad_request(ListRecordsRequest::PATH, e))?; 188 189 let ks = db.records.clone(); 190 191 let prefix = keys::record_prefix_collection(&did, req.collection.as_str()); 192 193 let limit = req.limit.unwrap_or(50).min(100) as usize; 194 let reverse = req.reverse.unwrap_or(false); 195 let blocks_ks = db.blocks.clone(); 196 197 let (results, cursor) = 
tokio::task::spawn_blocking(move || { 198 let mut results = Vec::new(); 199 let mut cursor = None; 200 201 let iter: Box<dyn Iterator<Item = _>> = if !reverse { 202 let mut end_prefix = prefix.clone(); 203 if let Some(last) = end_prefix.last_mut() { 204 *last += 1; 205 } 206 207 let end_key = if let Some(cursor) = &req.cursor { 208 let mut k = prefix.clone(); 209 k.extend_from_slice(cursor.as_bytes()); 210 k 211 } else { 212 end_prefix 213 }; 214 215 Box::new(ks.range(prefix.as_slice()..end_key.as_slice()).rev()) 216 } else { 217 let start_key = if let Some(cursor) = &req.cursor { 218 let mut k = prefix.clone(); 219 k.extend_from_slice(cursor.as_bytes()); 220 k.push(0); 221 k 222 } else { 223 prefix.clone() 224 }; 225 226 Box::new(ks.range(start_key.as_slice()..)) 227 }; 228 229 for item in iter { 230 let (key, cid_bytes) = item.into_inner().into_diagnostic()?; 231 232 if !key.starts_with(prefix.as_slice()) { 233 break; 234 } 235 236 let rkey = keys::parse_rkey(&key[prefix.len()..])?; 237 if results.len() >= limit { 238 cursor = Some(rkey); 239 break; 240 } 241 242 // look up using binary cid bytes from the record 243 if let Ok(Some(block_bytes)) = blocks_ks.get(&cid_bytes) { 244 let val: Data = serde_ipld_dagcbor::from_slice(&block_bytes).unwrap_or(Data::Null); 245 let cid = 246 Cid::Str(Cid::new(&cid_bytes).into_diagnostic()?.to_cowstr()).into_static(); 247 results.push(RepoRecord { 248 uri: AtUri::from_parts_owned( 249 did.as_str(), 250 req.collection.as_str(), 251 rkey.to_smolstr(), 252 ) 253 .into_diagnostic()?, 254 cid, 255 value: val.into_static(), 256 extra_data: Default::default(), 257 }); 258 } 259 } 260 Result::<_, miette::Report>::Ok((results, cursor)) 261 }) 262 .await 263 .map_err(|e| internal_error(ListRecordsRequest::PATH, e))? 
264 .map_err(|e| internal_error(ListRecordsRequest::PATH, e))?; 265 266 Ok(Json(ListRecordsOutput { 267 records: results, 268 cursor: cursor.map(|c| CowStr::Owned(c.to_smolstr())), 269 extra_data: Default::default(), 270 })) 271} 272 273#[derive(Serialize, Deserialize, jacquard_derive::IntoStatic)] 274pub struct CountRecordsOutput { 275 pub count: u64, 276} 277 278pub struct CountRecordsResponse; 279impl jacquard_common::xrpc::XrpcResp for CountRecordsResponse { 280 const NSID: &'static str = "systems.gaze.hydrant.countRecords"; 281 const ENCODING: &'static str = "application/json"; 282 type Output<'de> = CountRecordsOutput; 283 type Err<'de> = GenericXrpcError; 284} 285 286#[derive(Serialize, Deserialize, jacquard_derive::IntoStatic)] 287pub struct CountRecordsRequestData<'i> { 288 #[serde(borrow)] 289 pub identifier: AtIdentifier<'i>, 290 pub collection: String, 291} 292 293impl<'a> jacquard_common::xrpc::XrpcRequest for CountRecordsRequestData<'a> { 294 const NSID: &'static str = "systems.gaze.hydrant.countRecords"; 295 const METHOD: jacquard_common::xrpc::XrpcMethod = jacquard_common::xrpc::XrpcMethod::Query; 296 type Response = CountRecordsResponse; 297} 298 299pub struct CountRecords; 300impl jacquard_common::xrpc::XrpcEndpoint for CountRecords { 301 const PATH: &'static str = "/xrpc/systems.gaze.hydrant.countRecords"; 302 const METHOD: jacquard_common::xrpc::XrpcMethod = jacquard_common::xrpc::XrpcMethod::Query; 303 type Request<'de> = CountRecordsRequestData<'de>; 304 type Response = CountRecordsResponse; 305} 306 307pub async fn handle_count_records( 308 State(state): State<Arc<AppState>>, 309 ExtractXrpc(req): ExtractXrpc<CountRecords>, 310) -> XrpcResult<Json<CountRecordsOutput>> { 311 let did = state 312 .resolver 313 .resolve_did(&req.identifier) 314 .await 315 .map_err(|e| bad_request(CountRecords::PATH, e))?; 316 317 let count = spawn_blocking(move || { 318 db::get_record_count(&state.db, &did, &req.collection) 319 .map_err(|e| 
internal_error(CountRecords::PATH, e)) 320 }) 321 .map_err(|e| internal_error(CountRecords::PATH, e)) 322 .await??; 323 324 Ok(Json(CountRecordsOutput { count })) 325}