at protocol indexer with flexible filtering, xrpc queries, and a cursor-backed event stream, built on fjall
at-protocol
atproto
indexer
rust
fjall
1use crate::api::AppState;
2use crate::db::types::DbRkey;
3use crate::db::{self, Db, keys};
4use axum::extract::FromRequest;
5use axum::response::IntoResponse;
6use axum::{Json, Router, extract::State, http::StatusCode};
7use futures::TryFutureExt;
8use jacquard_api::com_atproto::repo::{
9 get_record::{GetRecordError, GetRecordOutput, GetRecordRequest},
10 list_records::{ListRecordsOutput, ListRecordsRequest, Record as RepoRecord},
11};
12use jacquard_common::CowStr;
13use jacquard_common::cowstr::ToCowStr;
14use jacquard_common::types::ident::AtIdentifier;
15use jacquard_common::xrpc::{XrpcEndpoint, XrpcMethod};
16use jacquard_common::{IntoStatic, xrpc::XrpcRequest};
17use jacquard_common::{
18 types::{
19 string::{AtUri, Cid},
20 value::Data,
21 },
22 xrpc::{GenericXrpcError, XrpcError},
23};
24use miette::IntoDiagnostic;
25use serde::{Deserialize, Serialize};
26use smol_str::ToSmolStr;
27use std::{fmt::Display, sync::Arc};
28use tokio::task::spawn_blocking;
29
30pub fn router() -> Router<Arc<AppState>> {
31 Router::new()
32 .route(
33 GetRecordRequest::PATH,
34 axum::routing::get(handle_get_record),
35 )
36 .route(
37 ListRecordsRequest::PATH,
38 axum::routing::get(handle_list_records),
39 )
40 .route(CountRecords::PATH, axum::routing::get(handle_count_records))
41}
42
43#[derive(Debug)]
44pub struct XrpcErrorResponse<E: IntoStatic + std::error::Error = GenericXrpcError> {
45 pub status: StatusCode,
46 pub error: XrpcError<E>,
47}
48
49impl<E: Serialize + IntoStatic + std::error::Error> IntoResponse for XrpcErrorResponse<E> {
50 fn into_response(self) -> axum::response::Response {
51 (self.status, Json(self.error)).into_response()
52 }
53}
54
55pub type XrpcResult<T, E = GenericXrpcError> = Result<T, XrpcErrorResponse<E>>;
56
57pub struct ExtractXrpc<E: XrpcEndpoint>(pub E::Request<'static>);
58
59impl<S, E> FromRequest<S> for ExtractXrpc<E>
60where
61 S: Send + Sync,
62 E: XrpcEndpoint,
63 E::Request<'static>: Send,
64 for<'de> E::Request<'de>: Deserialize<'de> + IntoStatic<Output = E::Request<'static>>,
65{
66 type Rejection = XrpcErrorResponse<GenericXrpcError>;
67
68 async fn from_request(
69 req: axum::extract::Request,
70 _state: &S,
71 ) -> Result<Self, Self::Rejection> {
72 let nsid = E::Request::<'static>::NSID;
73 match E::METHOD {
74 XrpcMethod::Query => {
75 let query = req.uri().query().unwrap_or("");
76 let res: E::Request<'_> =
77 serde_urlencoded::from_str(query).map_err(|e| bad_request(nsid, e))?;
78 Ok(ExtractXrpc(res.into_static()))
79 }
80 XrpcMethod::Procedure(_) => {
81 let body = axum::body::to_bytes(req.into_body(), usize::MAX)
82 .await
83 .map_err(|e| internal_error(nsid, e))?;
84 let res: E::Request<'_> =
85 serde_json::from_slice(&body).map_err(|e| bad_request(nsid, e))?;
86 Ok(ExtractXrpc(res.into_static()))
87 }
88 }
89 }
90}
91
92fn internal_error<E: std::error::Error + IntoStatic>(
93 nsid: &'static str,
94 message: impl Display,
95) -> XrpcErrorResponse<E> {
96 XrpcErrorResponse {
97 status: StatusCode::INTERNAL_SERVER_ERROR,
98 error: XrpcError::Generic(GenericXrpcError {
99 error: "InternalError".into(),
100 message: Some(message.to_smolstr()),
101 nsid,
102 method: "GET",
103 http_status: StatusCode::INTERNAL_SERVER_ERROR,
104 }),
105 }
106}
107
108fn bad_request<E: std::error::Error + IntoStatic>(
109 nsid: &'static str,
110 message: impl Display,
111) -> XrpcErrorResponse<E> {
112 XrpcErrorResponse {
113 status: StatusCode::BAD_REQUEST,
114 error: XrpcError::Generic(GenericXrpcError {
115 error: "InvalidRequest".into(),
116 message: Some(message.to_smolstr()),
117 nsid,
118 method: "GET",
119 http_status: StatusCode::BAD_REQUEST,
120 }),
121 }
122}
123
124pub async fn handle_get_record(
125 State(state): State<Arc<AppState>>,
126 ExtractXrpc(req): ExtractXrpc<GetRecordRequest>,
127) -> Result<Json<GetRecordOutput<'static>>, XrpcErrorResponse<GetRecordError<'static>>> {
128 let db = &state.db;
129 let did = state
130 .resolver
131 .resolve_did(&req.repo)
132 .await
133 .map_err(|e| bad_request(GetRecordRequest::PATH, e))?;
134
135 let db_key = keys::record_key(
136 &did,
137 req.collection.as_str(),
138 &DbRkey::new(req.rkey.0.as_str()),
139 );
140
141 let cid_bytes = Db::get(db.records.clone(), db_key)
142 .await
143 .map_err(|e| internal_error(GetRecordRequest::PATH, e))?;
144
145 if let Some(cid_bytes) = cid_bytes {
146 // lookup block using binary cid
147 let block_bytes = Db::get(db.blocks.clone(), cid_bytes.clone())
148 .await
149 .map_err(|e| internal_error(GetRecordRequest::PATH, e))?
150 .ok_or_else(|| internal_error(GetRecordRequest::PATH, "not found"))?;
151
152 let value: Data = serde_ipld_dagcbor::from_slice(&block_bytes)
153 .map_err(|e| internal_error(GetRecordRequest::PATH, e))?;
154
155 let cid = Cid::new(&cid_bytes)
156 .map_err(|e| internal_error(GetRecordRequest::PATH, e))?
157 .into_static();
158
159 Ok(Json(GetRecordOutput {
160 uri: AtUri::from_parts_owned(
161 did.as_str(),
162 req.collection.as_str(),
163 req.rkey.0.as_str(),
164 )
165 .unwrap(),
166 cid: Some(Cid::Str(cid.to_cowstr()).into_static()),
167 value: value.into_static(),
168 extra_data: Default::default(),
169 }))
170 } else {
171 Err(XrpcErrorResponse {
172 status: StatusCode::NOT_FOUND,
173 error: XrpcError::Xrpc(GetRecordError::RecordNotFound(None)),
174 })
175 }
176}
177
178pub async fn handle_list_records(
179 State(state): State<Arc<AppState>>,
180 ExtractXrpc(req): ExtractXrpc<ListRecordsRequest>,
181) -> Result<Json<ListRecordsOutput<'static>>, XrpcErrorResponse<GenericXrpcError>> {
182 let db = &state.db;
183 let did = state
184 .resolver
185 .resolve_did(&req.repo)
186 .await
187 .map_err(|e| bad_request(ListRecordsRequest::PATH, e))?;
188
189 let ks = db.records.clone();
190
191 let prefix = keys::record_prefix_collection(&did, req.collection.as_str());
192
193 let limit = req.limit.unwrap_or(50).min(100) as usize;
194 let reverse = req.reverse.unwrap_or(false);
195 let blocks_ks = db.blocks.clone();
196
197 let (results, cursor) = tokio::task::spawn_blocking(move || {
198 let mut results = Vec::new();
199 let mut cursor = None;
200
201 let iter: Box<dyn Iterator<Item = _>> = if !reverse {
202 let mut end_prefix = prefix.clone();
203 if let Some(last) = end_prefix.last_mut() {
204 *last += 1;
205 }
206
207 let end_key = if let Some(cursor) = &req.cursor {
208 let mut k = prefix.clone();
209 k.extend_from_slice(cursor.as_bytes());
210 k
211 } else {
212 end_prefix
213 };
214
215 Box::new(ks.range(prefix.as_slice()..end_key.as_slice()).rev())
216 } else {
217 let start_key = if let Some(cursor) = &req.cursor {
218 let mut k = prefix.clone();
219 k.extend_from_slice(cursor.as_bytes());
220 k.push(0);
221 k
222 } else {
223 prefix.clone()
224 };
225
226 Box::new(ks.range(start_key.as_slice()..))
227 };
228
229 for item in iter {
230 let (key, cid_bytes) = item.into_inner().into_diagnostic()?;
231
232 if !key.starts_with(prefix.as_slice()) {
233 break;
234 }
235
236 let rkey = keys::parse_rkey(&key[prefix.len()..])?;
237 if results.len() >= limit {
238 cursor = Some(rkey);
239 break;
240 }
241
242 // look up using binary cid bytes from the record
243 if let Ok(Some(block_bytes)) = blocks_ks.get(&cid_bytes) {
244 let val: Data = serde_ipld_dagcbor::from_slice(&block_bytes).unwrap_or(Data::Null);
245 let cid =
246 Cid::Str(Cid::new(&cid_bytes).into_diagnostic()?.to_cowstr()).into_static();
247 results.push(RepoRecord {
248 uri: AtUri::from_parts_owned(
249 did.as_str(),
250 req.collection.as_str(),
251 rkey.to_smolstr(),
252 )
253 .into_diagnostic()?,
254 cid,
255 value: val.into_static(),
256 extra_data: Default::default(),
257 });
258 }
259 }
260 Result::<_, miette::Report>::Ok((results, cursor))
261 })
262 .await
263 .map_err(|e| internal_error(ListRecordsRequest::PATH, e))?
264 .map_err(|e| internal_error(ListRecordsRequest::PATH, e))?;
265
266 Ok(Json(ListRecordsOutput {
267 records: results,
268 cursor: cursor.map(|c| CowStr::Owned(c.to_smolstr())),
269 extra_data: Default::default(),
270 }))
271}
272
273#[derive(Serialize, Deserialize, jacquard_derive::IntoStatic)]
274pub struct CountRecordsOutput {
275 pub count: u64,
276}
277
278pub struct CountRecordsResponse;
279impl jacquard_common::xrpc::XrpcResp for CountRecordsResponse {
280 const NSID: &'static str = "systems.gaze.hydrant.countRecords";
281 const ENCODING: &'static str = "application/json";
282 type Output<'de> = CountRecordsOutput;
283 type Err<'de> = GenericXrpcError;
284}
285
286#[derive(Serialize, Deserialize, jacquard_derive::IntoStatic)]
287pub struct CountRecordsRequestData<'i> {
288 #[serde(borrow)]
289 pub identifier: AtIdentifier<'i>,
290 pub collection: String,
291}
292
293impl<'a> jacquard_common::xrpc::XrpcRequest for CountRecordsRequestData<'a> {
294 const NSID: &'static str = "systems.gaze.hydrant.countRecords";
295 const METHOD: jacquard_common::xrpc::XrpcMethod = jacquard_common::xrpc::XrpcMethod::Query;
296 type Response = CountRecordsResponse;
297}
298
299pub struct CountRecords;
300impl jacquard_common::xrpc::XrpcEndpoint for CountRecords {
301 const PATH: &'static str = "/xrpc/systems.gaze.hydrant.countRecords";
302 const METHOD: jacquard_common::xrpc::XrpcMethod = jacquard_common::xrpc::XrpcMethod::Query;
303 type Request<'de> = CountRecordsRequestData<'de>;
304 type Response = CountRecordsResponse;
305}
306
307pub async fn handle_count_records(
308 State(state): State<Arc<AppState>>,
309 ExtractXrpc(req): ExtractXrpc<CountRecords>,
310) -> XrpcResult<Json<CountRecordsOutput>> {
311 let did = state
312 .resolver
313 .resolve_did(&req.identifier)
314 .await
315 .map_err(|e| bad_request(CountRecords::PATH, e))?;
316
317 let count = spawn_blocking(move || {
318 db::get_record_count(&state.db, &did, &req.collection)
319 .map_err(|e| internal_error(CountRecords::PATH, e))
320 })
321 .map_err(|e| internal_error(CountRecords::PATH, e))
322 .await??;
323
324 Ok(Json(CountRecordsOutput { count }))
325}