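"""Kafka -> ClickHouse indexer for the Bluesky follow graph.

Consumes TapEvent messages from a Kafka topic, keeps only
``app.bsky.graph.follow`` operations, and writes batched follow/unfollow
rows to ClickHouse. Metrics are exported over HTTP via ``prom_metrics``.
"""
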
import asyncio
import json
import logging
import signal
from datetime import datetime
from threading import Lock
from time import time
from typing import Any, Callable, List, Optional

import click
from aiokafka import AIOKafkaConsumer, ConsumerRecord
from clickhouse_connect import create_client
from dateutil.parser import isoparse

from config import CONFIG
from metrics import prom_metrics
from models import Follow, FollowRecord, TapEvent, Unfollow

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

logger = logging.getLogger(__name__)


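# Follows are stored twice with different sort keys: `follows` is ordered by
# (did, subject) for "who does X follow?" lookups, and `follows_reverse` by
# (subject, did) for "who follows X?". Unfollows are recorded as tombstones
# in a separate table and filtered out at read time (see stream_follows).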
class FollowIndexer:
    def __init__(
        self,
        clickhouse_host: str,
        clickhouse_port: int,
        clickhouse_user: str,
        clickhouse_pass: str,
        batch_size: int,
    ):
        self.client = create_client(
            host=clickhouse_host,
            port=clickhouse_port,
            username=clickhouse_user,
            password=clickhouse_pass,
        )

        self.batch_size = batch_size

        self._follow_batch: List[Follow] = []
        self._unfollow_batch: List[Unfollow] = []

        self._follow_lock = Lock()
        self._unfollow_lock = Lock()

        self._last_follow_flush = datetime.now()
        self._last_unfollow_flush = datetime.now()

        logger.info(f"Connected to ClickHouse: {clickhouse_host}:{clickhouse_port}")

    def init_schema(self):
        queries = [
            """
            CREATE TABLE IF NOT EXISTS follows (
                did String,
                subject String,
                uri String,
                created_at DateTime
            ) ENGINE = MergeTree()
            PARTITION BY toYYYYMM(created_at)
            ORDER BY (did, subject)
            """,
            """
            CREATE TABLE IF NOT EXISTS follows_reverse (
                did String,
                subject String,
                uri String,
                created_at DateTime
            ) ENGINE = MergeTree()
            PARTITION BY toYYYYMM(created_at)
            ORDER BY (subject, did)
            """,
            """
            CREATE TABLE IF NOT EXISTS unfollows (
                uri String PRIMARY KEY,
                created_at DateTime
            ) ENGINE = MergeTree()
            """,
        ]

        for query in queries:
            try:
                self.client.command(query)
            except Exception as e:
                logger.error(f"Failed to execute schema query: {e}")
                raise

        logger.info("Follow schemas initialized successfully")

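    # Appends go into an in-memory batch under a lock; when the batch reaches
    # batch_size it is swapped out and flushed outside the lock, so a slow
    # ClickHouse insert never blocks callers appending new rows.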
    def insert_follow(self, follow: Follow):
        to_insert: Optional[List[Follow]] = None

        with self._follow_lock:
            self._follow_batch.append(follow)

            if len(self._follow_batch) >= self.batch_size:
                to_insert = self._follow_batch.copy()
                self._follow_batch = []

        if not to_insert:
            return

        self._flush_follows(to_insert)

    def _flush_follows(self, follows: List[Follow]):
        status = "error"
        start_time = time()
        try:
            follows_data = [[f.did, f.subject, f.uri, f.created_at] for f in follows]

            self.client.insert(
                "follows",
                follows_data,
                column_names=["did", "subject", "uri", "created_at"],
            )

            self.client.insert(
                "follows_reverse",
                follows_data,
                column_names=["did", "subject", "uri", "created_at"],
            )

            status = "ok"
        except Exception as e:
            # TODO: handle errors gracefully
            logger.error(f"Error inserting follow batch: {e}")
        finally:
            prom_metrics.insert_duration.labels(kind="follow", status=status).observe(
                time() - start_time
            )
            prom_metrics.inserted.labels(kind="follow", status=status).inc(len(follows))

    def insert_unfollow(self, unfollow: Unfollow):
        to_insert: Optional[List[Unfollow]] = None

        with self._unfollow_lock:
            self._unfollow_batch.append(unfollow)

            if len(self._unfollow_batch) >= self.batch_size:
                to_insert = self._unfollow_batch.copy()
                self._unfollow_batch = []

        if not to_insert:
            return

        self._flush_unfollows(to_insert)

    def _flush_unfollows(self, unfollows: List[Unfollow]):
        status = "error"
        start_time = time()
        try:
            unfollows_data = [[u.uri, u.created_at] for u in unfollows]

            self.client.insert(
                "unfollows",
                unfollows_data,
                column_names=["uri", "created_at"],
            )

            status = "ok"
        except Exception as e:
            # TODO: handle errors gracefully
            logger.error(f"Error inserting unfollow batch: {e}")
        finally:
            prom_metrics.insert_duration.labels(kind="unfollow", status=status).observe(
                time() - start_time
            )
            prom_metrics.inserted.labels(kind="unfollow", status=status).inc(
                len(unfollows)
            )

    def flush_all(self):
        with self._follow_lock:
            if self._follow_batch:
                batch_to_flush = self._follow_batch.copy()
                self._follow_batch = []
                self._flush_follows(batch_to_flush)

        with self._unfollow_lock:
            if self._unfollow_batch:
                batch_to_flush = self._unfollow_batch.copy()
                self._unfollow_batch = []
                self._flush_unfollows(batch_to_flush)

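    # Active follows are rows in `follows` with no matching tombstone in
    # `unfollows`. The same anti-join works for ad-hoc queries; for example
    # (illustrative only, not used by this service):
    #
    #   SELECT subject, count() AS followers
    #   FROM follows_reverse f
    #   LEFT ANTI JOIN unfollows u ON f.uri = u.uri
    #   GROUP BY subject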
    def stream_follows(self, cb: Callable[[str, str], None], batch_size: int = 100_000):
        query = """
        SELECT f.did, f.subject
        FROM follows f
        LEFT ANTI JOIN unfollows u ON f.uri = u.uri
        """

        try:
            with self.client.query_row_block_stream(
                query, settings={"max_block_size": batch_size}
            ) as stream:
                total_handled = 0
                for block in stream:
                    for row in block:
                        cb(row[0], row[1])
                        total_handled += 1

                        if total_handled % 1_000_000 == 0:
                            logger.info(f"Handled {total_handled:,} follows so far")
                logger.info(f"Finished streaming {total_handled:,} follows")
        except Exception as e:
            logger.error(f"Error streaming follows: {e}")


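# Consumes TapEvent messages from Kafka and fans them out to handler tasks,
# bounded by a semaphore. A background task flushes partial batches every few
# seconds so rows reach ClickHouse even when the topic is quiet.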
class Consumer:
    def __init__(
        self,
        indexer: FollowIndexer,
        bootstrap_servers: List[str],
        input_topic: str,
        group_id: str,
        max_concurrent_tasks: int = 100,
    ):
        self.indexer = indexer
        self.bootstrap_servers = bootstrap_servers
        self.input_topic = input_topic
        self.group_id = group_id
        self.max_concurrent_tasks = max_concurrent_tasks
        self.consumer: Optional[AIOKafkaConsumer] = None
        self._flush_task: Optional[asyncio.Task[Any]] = None
        self._semaphore: Optional[asyncio.Semaphore] = None
        self._shutdown_event: Optional[asyncio.Event] = None

    async def stop(self):
        if self._shutdown_event:
            self._shutdown_event.set()

        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass

        self.indexer.flush_all()

        if self.consumer:
            await self.consumer.stop()
            logger.info("Stopped Kafka consumer")

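    # Flush partial batches on a fixed interval so low-traffic periods don't
    # leave rows sitting in memory until batch_size is reached.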
    async def _periodic_flush(self):
        try:
            while True:
                await asyncio.sleep(5)
                self.indexer.flush_all()
        except asyncio.CancelledError:
            logger.info("Periodic flush task cancelled")
            raise

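    # Creates become Follow rows timestamped from the record's own createdAt;
    # deletes become Unfollow tombstones timestamped at ingestion, since a
    # delete op carries no timestamp of its own. Updates are counted in
    # metrics but not re-indexed.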
    async def _handle_event(self, message: ConsumerRecord[Any, Any]):
        status = "error"
        kind = "unk"

        try:
            evt = TapEvent.model_validate(message.value)

            if not evt.record or evt.record.collection != "app.bsky.graph.follow":
                kind = "skipped"
                status = "ok"
                return

            op = evt.record
            uri = f"at://{op.did}/{op.collection}/{op.rkey}"

            if op.action == "update":
                kind = "update"
            elif op.action == "create":
                kind = "create"
                rec = FollowRecord.model_validate(op.record)
                created_at = isoparse(rec.created_at)
                follow = Follow(
                    uri=uri, did=op.did, subject=rec.subject, created_at=created_at
                )
                self.indexer.insert_follow(follow)
            else:
                kind = "delete"

                unfollow = Unfollow(uri=uri, created_at=datetime.now())

                self.indexer.insert_unfollow(unfollow)

            status = "ok"
        except Exception as e:
            logger.error(f"Failed to handle event: {e}")
        finally:
            prom_metrics.events_handled.labels(kind=kind, status=status).inc()

    async def _handle_event_with_semaphore(self, message: ConsumerRecord[Any, Any]):
        assert self._semaphore is not None
        async with self._semaphore:
            await self._handle_event(message)

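    # The semaphore caps how many handlers run at once; the task set is
    # additionally capped at twice that, applying backpressure to the Kafka
    # iterator when handlers fall behind.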
    async def run(self):
        self._semaphore = asyncio.Semaphore(self.max_concurrent_tasks)
        self._shutdown_event = asyncio.Event()

        self.consumer = AIOKafkaConsumer(
            self.input_topic,
            bootstrap_servers=",".join(self.bootstrap_servers),
            group_id=self.group_id,
            auto_offset_reset="earliest",
            enable_auto_commit=True,
            auto_commit_interval_ms=5000,
            session_timeout_ms=30000,
            max_poll_interval_ms=300000,
            value_deserializer=lambda m: json.loads(m.decode("utf-8")),
        )
        await self.consumer.start()
        logger.info(
            f"Started Kafka consumer for topic {self.input_topic} on {','.join(self.bootstrap_servers)}"
        )

        self._flush_task = asyncio.create_task(self._periodic_flush())

        pending_tasks: set[asyncio.Task[Any]] = set()

        try:
            async for message in self.consumer:
                prom_metrics.events_received.inc()

                task = asyncio.create_task(self._handle_event_with_semaphore(message))
                pending_tasks.add(task)
                task.add_done_callback(pending_tasks.discard)

                if len(pending_tasks) >= self.max_concurrent_tasks * 2:
                    # Wait for at least one task to finish before consuming more.
                    done, pending_tasks = await asyncio.wait(
                        pending_tasks, return_when=asyncio.FIRST_COMPLETED
                    )
                    for t in done:
                        if t.exception():
                            logger.error(f"Task failed with exception: {t.exception()}")

        except Exception as e:
            logger.error(f"Error consuming messages: {e}")
            raise
        finally:
            if pending_tasks:
                logger.info(
                    f"Waiting for {len(pending_tasks)} pending tasks to complete..."
                )
                await asyncio.gather(*pending_tasks, return_exceptions=True)
            self.indexer.flush_all()


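# CLI flags are optional; anything omitted falls back to config.CONFIG.
# Example invocation (all values below are placeholders):
#
#   python main.py --ch-host localhost --ch-port 8123 \
#       --bootstrap-servers kafka1:9092,kafka2:9092 \
#       --input-topic follows --group-id follow-indexer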
@click.command()
@click.option("--ch-host", help="ClickHouse host")
@click.option("--ch-port", type=int, help="ClickHouse port")
@click.option("--ch-user", help="ClickHouse username")
@click.option("--ch-pass", help="ClickHouse password")
@click.option("--batch-size", type=int, help="Rows per ClickHouse insert batch")
@click.option(
    "--bootstrap-servers", help="Comma-separated list of Kafka bootstrap servers"
)
@click.option("--input-topic", help="Kafka topic to consume from")
@click.option("--group-id", help="Kafka consumer group id")
@click.option("--metrics-host", help="Metrics HTTP listen address")
@click.option("--metrics-port", type=int, help="Metrics HTTP listen port")
def main(
    ch_host: Optional[str],
    ch_port: Optional[int],
    ch_user: Optional[str],
    ch_pass: Optional[str],
    batch_size: Optional[int],
    bootstrap_servers: Optional[str],
    input_topic: Optional[str],
    group_id: Optional[str],
    metrics_host: Optional[str],
    metrics_port: Optional[int],
):
    prom_metrics.start_http(
        addr=metrics_host or CONFIG.metrics_host,
        port=metrics_port or CONFIG.metrics_port,
    )

    indexer = FollowIndexer(
        clickhouse_host=ch_host or CONFIG.clickhouse_host,
        clickhouse_port=ch_port or CONFIG.clickhouse_port,
        clickhouse_user=ch_user or CONFIG.clickhouse_user,
        clickhouse_pass=ch_pass or CONFIG.clickhouse_pass,
        batch_size=batch_size or CONFIG.batch_size,
    )
    indexer.init_schema()

    kafka_servers = (
        bootstrap_servers.split(",")
        if bootstrap_servers
        else CONFIG.kafka_bootstrap_servers
    )

    consumer = Consumer(
        indexer=indexer,
        bootstrap_servers=kafka_servers,
        input_topic=input_topic or CONFIG.kafka_input_topic,
        group_id=group_id or CONFIG.kafka_group_id,
    )

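    # SIGTERM/SIGINT trigger a graceful stop: the periodic flush task is
    # cancelled, remaining batches are flushed, and the Kafka consumer is
    # stopped, which ends the consume loop in Consumer.run().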
    async def run_with_shutdown():
        loop = asyncio.get_running_loop()

        def handle_signal():
            logger.info("Received shutdown signal...")
            asyncio.create_task(consumer.stop())

        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, handle_signal)

        try:
            await consumer.run()
        except asyncio.CancelledError:
            pass
        finally:
            await consumer.stop()

    asyncio.run(run_with_shutdown())


if __name__ == "__main__":
    main()