//! aturi indexer with listRecords and countRecords endpoints
1use axum::{
2 extract::{Query, State},
3 routing::get,
4 Json, Router,
5};
6use fjall::{Database, Keyspace, KeyspaceCreateOptions};
7use serde::{Deserialize, Serialize};
8use std::{
9 pin,
10 sync::{
11 atomic::{AtomicU64, Ordering},
12 Arc,
13 },
14 time::Duration,
15};
16use tapped::{Event, RecordAction, RecordEvent, TapClient};
17use tokio::signal::unix::SignalKind;
18use tokio_util::sync::CancellationToken;
19use tracing::{error, info, warn};
20
21#[global_allocator]
22static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
23
24#[derive(Clone)]
25struct AppState {
26 db: Database,
27 counts: Keyspace,
28}
29
30#[tokio::main]
31async fn main() -> anyhow::Result<()> {
32 tracing_subscriber::fmt::init();
33
34 let db = Database::builder("aturlist_fjall").open()?;
35 // open keyspaces
36 let counts = db.keyspace("counts", || KeyspaceCreateOptions::default())?;
37
38 let state = AppState {
39 db: db.clone(),
40 counts: counts.clone(),
41 };
42
43 let ops_count = Arc::new(AtomicU64::new(0));
44
45 // start tap consumers
46 let num_consumers = std::env::var("TAP_CONCURRENCY")
47 .ok()
48 .and_then(|s| s.parse().ok())
49 .unwrap_or(20);
50 let closed = CancellationToken::new();
51
52 for i in 0..num_consumers {
53 let db_clone = state.db.clone();
54 let counts_clone = state.counts.clone();
55 let ops_count_clone = ops_count.clone();
56 let closed = closed.child_token();
57 tokio::spawn(async move {
58 info!("starting consumer #{}", i);
59 run_tap_consumer(db_clone, counts_clone, ops_count_clone, closed).await;
60 });
61 }
62
63 // start stats reporter
64 let ops_count_stats = ops_count.clone();
65 tokio::spawn(async move {
66 let mut interval = tokio::time::interval(std::time::Duration::from_secs(60));
67 let mut last_count = 0;
68 // The first tick completes immediately
69 interval.tick().await;
70
71 loop {
72 interval.tick().await;
73 let current_count = ops_count_stats.load(Ordering::Relaxed);
74 let delta = current_count - last_count;
75 let ops_sec = delta as f64 / 60.0;
76 info!(
77 "stats: total_ops={} delta_ops={} ops_sec={:.2}",
78 current_count, delta, ops_sec
79 );
80 last_count = current_count;
81 }
82 });
83
84 let app = Router::new()
85 .route("/xrpc/systems.gaze.aturlist.listRecords", get(list_records))
86 .route(
87 "/xrpc/systems.gaze.aturlist.countRecords",
88 get(count_records),
89 )
90 .with_state(state);
91
92 let listener = tokio::net::TcpListener::bind("0.0.0.0:7155").await?;
93 info!("listening on {}", listener.local_addr()?);
94
95 let mut _sigterm = tokio::signal::unix::signal(SignalKind::terminate())?;
96 let mut _sigint = tokio::signal::unix::signal(SignalKind::interrupt())?;
97 let sigterm = pin::pin!(_sigterm.recv());
98 let sigint = pin::pin!(_sigint.recv());
99 let terminating = futures::future::select(sigterm, sigint);
100
101 tokio::select! {
102 res = axum::serve(listener, app) => res?,
103 _ = terminating => {
104 info!("shutting down!");
105 closed.cancel();
106 }
107 }
108
109 info!("waiting 10 seconds for cleanup...");
110 tokio::time::sleep(Duration::from_secs(10)).await;
111
112 info!("byebye! (_ _*)Zzz");
113
114 Ok(())
115}
116
117async fn run_tap_consumer(
118 db: Database,
119 counts: Keyspace,
120 ops_count: Arc<AtomicU64>,
121 closed: CancellationToken,
122) {
123 let tap_url = "http://localhost:2480";
124
125 'outer: loop {
126 info!("connecting to tap at {}", tap_url);
127 match TapClient::new(tap_url) {
128 Ok(client) => {
129 if let Err(e) = client.health().await {
130 warn!("tap health check failed: {}", e);
131 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
132 continue;
133 }
134
135 match client.channel().await {
136 Ok((mut receiver, mut ack_sender)) => {
137 info!("connected to tap firehose");
138 loop {
139 tokio::select! {
140 ev = receiver.recv() => {
141 match ev {
142 Ok((event, ack_id)) => {
143 let mut handled = true;
144
145 if let Event::Record(rec) = event {
146 match db.keyspace(&rec.collection, KeyspaceCreateOptions::default) {
147 Ok(ks) => {
148 let counts = counts.clone();
149 let ops_count = ops_count.clone();
150 handled =
151 tokio::task::spawn_blocking(move || {
152 if let Err(e) = handle_record(&counts, &ks, rec) {
153 error!("error handling record: {}", e);
154 false
155 } else {
156 ops_count.fetch_add(1, Ordering::Relaxed);
157 true
158 }
159 })
160 .await
161 .expect("couldnt join task");
162 }
163 Err(err) => {
164 error!(
165 "failed to open keyspace for {}: {}",
166 rec.collection, err
167 );
168 }
169 }
170 }
171
172 if handled {
173 if let Err(e) = ack_sender.ack(ack_id).await {
174 warn!("failed to ack event: {}", e);
175 break;
176 }
177 }
178 }
179 Err(err) => {
180 warn!("tap channel closed: {err}");
181 break;
182 }
183 }
184 }
185 _ = closed.cancelled() => break 'outer,
186 }
187 }
188 }
189 Err(e) => {
190 warn!("failed to subscribe to channel: {}", e);
191 }
192 }
193 }
194 Err(e) => {
195 warn!("failed to create tap client: {}", e);
196 }
197 }
198 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
199 }
200}
201
/// Returns `did` without its leading `did:` scheme, or `did` unchanged when
/// it does not start with that prefix.
fn strip_did_prefix(did: &str) -> &str {
    match did.strip_prefix("did:") {
        Some(rest) => rest,
        None => did,
    }
}
205
206fn handle_record(counts: &Keyspace, records: &Keyspace, rec: RecordEvent) -> anyhow::Result<()> {
207 // index everything, no filter.
208 // key: strip_did(did)|rkey
209 let key = make_key(strip_did_prefix(&rec.did), &rec.rkey);
210
211 // logic to maintain counts:
212 // create: insert and increment.
213 // update: insert (overwrite). no count change.
214 // delete: remove and decrement.
215
216 match rec.action {
217 RecordAction::Create => {
218 // info!("creating {} {} {}", rec.did, rec.collection, rec.rkey);
219 records.insert(&key, &[])?;
220 increment_count(counts, strip_did_prefix(&rec.did), &rec.collection)?;
221 }
222 RecordAction::Update => {
223 // info!("updating {} {} {}", rec.did, rec.collection, rec.rkey);
224 // records.insert(&key, &[])?;
225 }
226 RecordAction::Delete => {
227 // info!("deleting {} {} {}", rec.did, rec.collection, rec.rkey);
228 records.remove(&key)?;
229 decrement_count(counts, strip_did_prefix(&rec.did), &rec.collection)?;
230 }
231 _ => {}
232 }
233
234 Ok(())
235}
236
/// Builds the record-index key: the bytes of `repo_stripped`, a `|`
/// separator, then the bytes of `rkey`.
fn make_key(repo_stripped: &str, rkey: &str) -> Vec<u8> {
    let mut key = Vec::with_capacity(repo_stripped.len() + 1 + rkey.len());
    key.extend_from_slice(repo_stripped.as_bytes());
    key.push(b'|');
    key.extend_from_slice(rkey.as_bytes());
    key
}
240
/// Builds the counter key: `repo_stripped` and `collection` joined by `|`,
/// as raw bytes.
fn make_count_key(repo_stripped: &str, collection: &str) -> Vec<u8> {
    [repo_stripped, collection].join("|").into_bytes()
}
244
245fn increment_count(counts: &Keyspace, repo_stripped: &str, collection: &str) -> anyhow::Result<()> {
246 let key = make_count_key(repo_stripped, collection);
247 let mut current = 0u64;
248 if let Some(val) = counts.get(&key)? {
249 if val.len() == 8 {
250 current = u64::from_le_bytes(val[..].try_into().unwrap());
251 }
252 }
253 current += 1;
254 counts.insert(&key, current.to_le_bytes())?;
255 Ok(())
256}
257
258fn decrement_count(counts: &Keyspace, repo_stripped: &str, collection: &str) -> anyhow::Result<()> {
259 let key = make_count_key(repo_stripped, collection);
260 let mut current = 0u64;
261 if let Some(val) = counts.get(&key)? {
262 if val.len() == 8 {
263 current = u64::from_le_bytes(val[..].try_into().unwrap());
264 }
265 }
266 if current > 0 {
267 current -= 1;
268 counts.insert(&key, current.to_le_bytes())?;
269 }
270 Ok(())
271}
272
273// handlers
274
275#[derive(Deserialize)]
276struct ListRecordsParams {
277 repo: String,
278 collection: String,
279 cursor: Option<String>,
280 reverse: Option<bool>,
281 limit: Option<usize>,
282}
283
284#[derive(Serialize)]
285struct ListRecordsResponse {
286 aturis: Vec<String>,
287 // count field is usually empty in listRecords but we can leave it 0
288 count: usize,
289 #[serde(skip_serializing_if = "Option::is_none")]
290 cursor: Option<String>,
291}
292
293async fn list_records(
294 State(state): State<AppState>,
295 Query(params): Query<ListRecordsParams>,
296) -> Json<ListRecordsResponse> {
297 let records = match state
298 .db
299 .keyspace(¶ms.collection, || KeyspaceCreateOptions::default())
300 {
301 Ok(p) => p,
302 Err(_) => {
303 return Json(ListRecordsResponse {
304 aturis: Vec::new(),
305 count: 0,
306 cursor: None,
307 });
308 }
309 };
310
311 let repo_stripped = strip_did_prefix(¶ms.repo);
312 let prefix_str = format!("{}|", repo_stripped);
313 let prefix = prefix_str.as_bytes();
314
315 // default to descending (newest first) -> reverse=false means descending.
316 // reverse=true means ascending.
317 let ascending = params.reverse.unwrap_or(false);
318 let limit = params.limit.unwrap_or(50).min(500);
319
320 let mut aturis = Vec::new();
321 let mut last_rkey = None;
322
323 let start_bound = if ascending {
324 if let Some(c) = ¶ms.cursor {
325 let mut k = make_key(repo_stripped, c);
326 k.push(0); // start after cursor
327 k
328 } else {
329 prefix.to_vec()
330 }
331 } else {
332 // descending
333 prefix.to_vec()
334 };
335
336 let end_bound = if ascending {
337 let mut p = prefix.to_vec();
338 p.push(0xFF);
339 p
340 } else {
341 // descending
342 if let Some(c) = ¶ms.cursor {
343 make_key(repo_stripped, c)
344 } else {
345 let mut p = prefix.to_vec();
346 p.push(0xFF);
347 p
348 }
349 };
350
351 let range = records.range(start_bound..end_bound);
352
353 let mut process_key = |k: &[u8]| {
354 let k_str = String::from_utf8_lossy(k);
355 let parts: Vec<&str> = k_str.split('|').collect();
356 // key format: repo_stripped|rkey
357 if parts.len() == 2 {
358 let rkey = parts[1];
359 aturis.push(format!(
360 "at://{}/{}/{}",
361 params.repo, params.collection, rkey
362 ));
363 last_rkey = Some(rkey.to_string());
364 }
365 };
366
367 if ascending {
368 for item in range.take(limit) {
369 if let Ok(k) = item.key() {
370 process_key(&k);
371 }
372 }
373 } else {
374 for item in range.rev().take(limit) {
375 if let Ok(k) = item.key() {
376 process_key(&k);
377 }
378 }
379 }
380
381 let count = aturis.len();
382
383 Json(ListRecordsResponse {
384 aturis,
385 count,
386 cursor: last_rkey,
387 })
388}
389
390#[derive(Deserialize)]
391struct CountRecordsParams {
392 repo: String,
393 collection: String,
394}
395
396#[derive(Serialize)]
397struct CountRecordsResponse {
398 repo: String,
399 collection: String,
400 count: u64,
401}
402
403async fn count_records(
404 State(state): State<AppState>,
405 Query(params): Query<CountRecordsParams>,
406) -> Json<CountRecordsResponse> {
407 let repo_stripped = strip_did_prefix(¶ms.repo);
408 let key = make_count_key(repo_stripped, ¶ms.collection);
409 let mut count = 0u64;
410
411 if let Ok(Some(val)) = state.counts.get(&key) {
412 if val.len() == 8 {
413 count = u64::from_le_bytes(val[..].try_into().unwrap());
414 }
415 }
416
417 Json(CountRecordsResponse {
418 repo: params.repo,
419 collection: params.collection,
420 count,
421 })
422}