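"""Kafka -> ClickHouse indexer for Bluesky follow-graph events.

Consumes tap events from Kafka, keeps `app.bsky.graph.follow` operations,
and batch-inserts them into ClickHouse: follows land in two tables that
differ only in sort key, and unfollows are recorded append-only so reads
can filter them out with an anti-join.
"""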

import asyncio
import json
import logging
from datetime import datetime
from threading import Lock
from time import time
from typing import Any, Callable, List, Optional

import click
from aiokafka import AIOKafkaConsumer, ConsumerRecord
from clickhouse_connect import create_client
from dateutil.parser import isoparse

from config import CONFIG
from metrics import prom_metrics
from models import Follow, FollowRecord, TapEvent, Unfollow

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

logger = logging.getLogger(__name__)


class FollowIndexer:
    def __init__(
        self,
        clickhouse_host: str,
        clickhouse_port: int,
        clickhouse_user: str,
        clickhouse_pass: str,
        batch_size: int,
    ):
        self.client = create_client(
            host=clickhouse_host,
            port=clickhouse_port,
            username=clickhouse_user,
            password=clickhouse_pass,
        )

        self.batch_size = batch_size

        self._follow_batch: List[Follow] = []
        self._unfollow_batch: List[Unfollow] = []

        self._follow_lock = Lock()
        self._unfollow_lock = Lock()

        self._last_follow_flush = datetime.now()
        self._last_unfollow_flush = datetime.now()

        logger.info(f"Connected to ClickHouse: {clickhouse_host}:{clickhouse_port}")

    def init_schema(self):
        queries = [
            """
            CREATE TABLE IF NOT EXISTS follows (
                did String,
                subject String,
                uri String,
                created_at DateTime
            ) ENGINE = MergeTree()
            PARTITION BY toYYYYMM(created_at)
            ORDER BY (did, subject)
            """,
            """
            CREATE TABLE IF NOT EXISTS follows_reverse (
                did String,
                subject String,
                uri String,
                created_at DateTime
            ) ENGINE = MergeTree()
            PARTITION BY toYYYYMM(created_at)
            ORDER BY (subject, did)
            """,
            """
            CREATE TABLE IF NOT EXISTS unfollows (
                uri String PRIMARY KEY,
                created_at DateTime
            ) ENGINE = MergeTree()
            """,
        ]

        for query in queries:
            try:
                self.client.command(query)
            except Exception as e:
                logger.error(f"Failed to execute schema query: {e}")
                raise

        logger.info("Follow schemas initialized successfully")
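
    # The same rows are written to both follow tables; only the MergeTree sort
    # key differs. `follows` ORDER BY (did, subject) serves "who does X
    # follow", while `follows_reverse` ORDER BY (subject, did) serves "who
    # follows X". Illustrative queries (placeholder DID, not part of this
    # module):
    #
    #   SELECT subject FROM follows WHERE did = 'did:plc:example'
    #   SELECT did FROM follows_reverse WHERE subject = 'did:plc:example'
    #
    # Unfollows never mutate these tables; they are appended to `unfollows`
    # and filtered out with a LEFT ANTI JOIN at read time (stream_follows).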

    def insert_follow(self, follow: Follow):
        to_insert: Optional[List[Follow]] = None

        with self._follow_lock:
            self._follow_batch.append(follow)

            if len(self._follow_batch) >= self.batch_size:
                to_insert = self._follow_batch.copy()
                self._follow_batch = []

        if not to_insert:
            return

        self._flush_follows(to_insert)

    def _flush_follows(self, follows: List[Follow]):
        status = "error"
        start_time = time()
        try:
            follows_data = [[f.did, f.subject, f.uri, f.created_at] for f in follows]

            self.client.insert(
                "follows",
                follows_data,
                column_names=["did", "subject", "uri", "created_at"],
            )

            self.client.insert(
                "follows_reverse",
                follows_data,
                column_names=["did", "subject", "uri", "created_at"],
            )

            status = "ok"
        except Exception as e:
            # TODO: handle errors gracefully
            logger.error(f"Error inserting batch: {e}")
        finally:
            prom_metrics.insert_duration.labels(kind="follow", status=status).observe(
                time() - start_time
            )
            prom_metrics.inserted.labels(kind="follow", status=status).inc(len(follows))

    def insert_unfollow(self, unfollow: Unfollow):
        to_insert: Optional[List[Unfollow]] = None

        with self._unfollow_lock:
            self._unfollow_batch.append(unfollow)

            if len(self._unfollow_batch) >= self.batch_size:
                to_insert = self._unfollow_batch.copy()
                self._unfollow_batch = []

        if not to_insert:
            return

        self._flush_unfollows(to_insert)

    def _flush_unfollows(self, unfollows: List[Unfollow]):
        status = "error"
        start_time = time()
        try:
            unfollows_data = [[f.uri, f.created_at] for f in unfollows]

            self.client.insert(
                "unfollows",
                unfollows_data,
                column_names=["uri", "created_at"],
            )

            status = "ok"
        except Exception as e:
            # TODO: handle errors gracefully
            logger.error(f"Error inserting batch: {e}")
        finally:
            prom_metrics.insert_duration.labels(kind="unfollow", status=status).observe(
                time() - start_time
            )
            prom_metrics.inserted.labels(kind="unfollow", status=status).inc(
                len(unfollows)
            )

    def flush_all(self):
        with self._follow_lock:
            if self._follow_batch:
                batch_to_flush = self._follow_batch.copy()
                self._follow_batch = []
                self._flush_follows(batch_to_flush)

        with self._unfollow_lock:
            if self._unfollow_batch:
                batch_to_flush = self._unfollow_batch.copy()
                self._unfollow_batch = []
                self._flush_unfollows(batch_to_flush)

    def stream_follows(self, cb: Callable[[str, str], None], batch_size: int = 100_000):
        query = """
            SELECT f.did, f.subject
            FROM follows f
            LEFT ANTI JOIN unfollows u ON f.uri = u.uri
        """

        try:
            with self.client.query_row_block_stream(
                query, settings={"max_block_size": batch_size}
            ) as stream:
                total_handled = 0
                for block in stream:
                    for row in block:
                        cb(row[0], row[1])
                        total_handled += 1

                        if total_handled % 1_000_000 == 0:
                            logger.info(f"Handled {total_handled:,} follows so far")
                logger.info(f"Finished streaming {total_handled:,} follows")
        except Exception as e:
            logger.error(f"Error streaming follows: {e}")
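
# Example use of stream_follows -- a minimal sketch, assuming `indexer` was
# built as in main() below; the follower-count logic is illustrative only:
#
#     from collections import Counter
#
#     follower_counts: Counter[str] = Counter()
#
#     def on_follow(did: str, subject: str) -> None:
#         follower_counts[subject] += 1
#
#     indexer.stream_follows(on_follow)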

class Consumer:
    """Kafka consumer that fans messages out to bounded asyncio tasks.

    Each message is handled in its own task, gated by a semaphore of
    max_concurrent_tasks; a background task flushes partial batches every
    5 seconds so a quiet topic cannot strand a batch indefinitely.
    """

    def __init__(
        self,
        indexer: FollowIndexer,
        bootstrap_servers: List[str],
        input_topic: str,
        group_id: str,
        max_concurrent_tasks: int = 100,
    ):
        self.indexer = indexer
        self.bootstrap_servers = bootstrap_servers
        self.input_topic = input_topic
        self.group_id = group_id
        self.max_concurrent_tasks = max_concurrent_tasks
        self.consumer: Optional[AIOKafkaConsumer] = None
        self._flush_task: Optional[asyncio.Task[Any]] = None
        self._semaphore: Optional[asyncio.Semaphore] = None
        self._shutdown_event: Optional[asyncio.Event] = None

    async def stop(self):
        if self._shutdown_event:
            self._shutdown_event.set()

        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass

        self.indexer.flush_all()

        if self.consumer:
            await self.consumer.stop()
            logger.info("Stopped Kafka consumer")

    async def _periodic_flush(self):
        try:
            while True:
                await asyncio.sleep(5)
                self.indexer.flush_all()
        except asyncio.CancelledError:
            logger.info("Periodic flush task cancelled")
            raise

    async def _handle_event(self, message: ConsumerRecord[Any, Any]):
        status = "error"
        kind = "unk"

        try:
            evt = TapEvent.model_validate(message.value)

            if not evt.record or evt.record.collection != "app.bsky.graph.follow":
                kind = "skipped"
                status = "ok"
                return

            op = evt.record
            uri = f"at://{op.did}/{op.collection}/{op.rkey}"

            if op.action == "update":
                # "update" ops are counted in metrics but not re-indexed
                kind = "update"
            elif op.action == "create":
                kind = "create"
                rec = FollowRecord.model_validate(op.record)
                created_at = isoparse(rec.created_at)
                follow = Follow(
                    uri=uri, did=op.did, subject=rec.subject, created_at=created_at
                )
                self.indexer.insert_follow(follow)
            else:
                # deletes become append-only tombstones in `unfollows`
                kind = "delete"

                unfollow = Unfollow(uri=uri, created_at=datetime.now())

                self.indexer.insert_unfollow(unfollow)

            status = "ok"
        except Exception as e:
            logger.error(f"Failed to handle event: {e}")
        finally:
            prom_metrics.events_handled.labels(kind=kind, status=status).inc()

    async def _handle_event_with_semaphore(self, message: ConsumerRecord[Any, Any]):
        assert self._semaphore is not None
        async with self._semaphore:
            await self._handle_event(message)

    async def run(self):
        self._semaphore = asyncio.Semaphore(self.max_concurrent_tasks)
        self._shutdown_event = asyncio.Event()

        # NOTE: offsets are auto-committed on a timer, independent of
        # handler-task completion.
        self.consumer = AIOKafkaConsumer(
            self.input_topic,
            bootstrap_servers=",".join(self.bootstrap_servers),
            group_id=self.group_id,
            auto_offset_reset="earliest",
            enable_auto_commit=True,
            auto_commit_interval_ms=5000,
            session_timeout_ms=30000,
            max_poll_interval_ms=300000,
            value_deserializer=lambda m: json.loads(m.decode("utf-8")),
        )
        await self.consumer.start()
        logger.info(
            f"Started Kafka consumer: servers={self.bootstrap_servers}, "
            f"topic={self.input_topic}"
        )

        self._flush_task = asyncio.create_task(self._periodic_flush())

        pending_tasks: set[asyncio.Task[Any]] = set()

        try:
            async for message in self.consumer:
                prom_metrics.events_received.inc()

                task = asyncio.create_task(self._handle_event_with_semaphore(message))
                pending_tasks.add(task)
                task.add_done_callback(pending_tasks.discard)

                # once the backlog doubles the concurrency limit, reap
                # finished tasks without blocking (timeout=0)
                if len(pending_tasks) >= self.max_concurrent_tasks * 2:
                    done, pending_tasks = await asyncio.wait(
                        pending_tasks, timeout=0, return_when=asyncio.FIRST_COMPLETED
                    )
                    for t in done:
                        if t.exception():
                            logger.error(f"Task failed with exception: {t.exception()}")

        except Exception as e:
            logger.error(f"Error consuming messages: {e}")
            raise
        finally:
            if pending_tasks:
                logger.info(
                    f"Waiting for {len(pending_tasks)} pending tasks to complete..."
                )
                await asyncio.gather(*pending_tasks, return_exceptions=True)
            self.indexer.flush_all()
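
# Example invocation -- a sketch only: the module filename, hosts, and topic
# below are placeholders, and every flag is optional (unset flags fall back to
# the matching fields on config.CONFIG):
#
#   python indexer.py \
#       --ch-host localhost --ch-port 8123 \
#       --ch-user default --ch-pass '' \
#       --batch-size 1000 \
#       --bootstrap-servers kafka-1:9092,kafka-2:9092 \
#       --input-topic tap-events \
#       --group-id follow-indexer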
@click.command()
@click.option("--ch-host")
@click.option("--ch-port", type=int)
@click.option("--ch-user")
@click.option("--ch-pass")
@click.option("--batch-size", type=int)
@click.option(
    "--bootstrap-servers", help="Comma-separated list of Kafka bootstrap servers"
)
@click.option("--input-topic")
@click.option("--group-id")
@click.option("--metrics-host")
@click.option("--metrics-port", type=int)
def main(
    ch_host: Optional[str],
    ch_port: Optional[int],
    ch_user: Optional[str],
    ch_pass: Optional[str],
    batch_size: Optional[int],
    bootstrap_servers: Optional[str],
    input_topic: Optional[str],
    group_id: Optional[str],
    metrics_host: Optional[str],
    metrics_port: Optional[int],
):
    prom_metrics.start_http(
        addr=metrics_host or CONFIG.metrics_host,
        port=metrics_port or CONFIG.metrics_port,
    )

    indexer = FollowIndexer(
        clickhouse_host=ch_host or CONFIG.clickhouse_host,
        clickhouse_port=ch_port or CONFIG.clickhouse_port,
        clickhouse_user=ch_user or CONFIG.clickhouse_user,
        clickhouse_pass=ch_pass or CONFIG.clickhouse_pass,
        batch_size=batch_size or CONFIG.batch_size,
    )
    indexer.init_schema()

    kafka_servers = (
        bootstrap_servers.split(",")
        if bootstrap_servers
        else CONFIG.kafka_bootstrap_servers
    )

    consumer = Consumer(
        indexer=indexer,
        bootstrap_servers=kafka_servers,
        input_topic=input_topic or CONFIG.kafka_input_topic,
        group_id=group_id or CONFIG.kafka_group_id,
    )

    async def run_with_shutdown():
        loop = asyncio.get_running_loop()

        import signal

        def handle_signal():
            logger.info("Received shutdown signal...")
            asyncio.create_task(consumer.stop())

        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, handle_signal)

        try:
            await consumer.run()
        except asyncio.CancelledError:
            pass
        finally:
            await consumer.stop()

    asyncio.run(run_with_shutdown())


if __name__ == "__main__":
    main()