ML-based recommendation feed for Bluesky posts
Crawler script: retrieves follow records for every account in a provided follow graph (167 lines, 5.1 kB)
"""Crawl follow records for every account in a follow graph.

For each account DID in the input graph, page through its
``app.bsky.graph.follow`` collection (up to a fixed cap) and write the
records to a per-account gzipped JSONL file in the checkpoint directory.
"""
import argparse
import asyncio
import gzip
import logging
import os
import sys
from typing import List, Tuple

from atproto import AsyncClient
from atproto import exceptions as at_exceptions
from atproto_client.models.app.bsky.graph.follow import Record as FollowRecord

from utils import BSKY_API_LIMIT, RateLimit, get_accounts, load_checkpoint

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Console logging: timestamp | level | message
formatter = logging.Formatter(
    "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)


# Number of accounts whose crawls are kicked off concurrently per batch.
BATCH_SIZE = 100
# Cap on follow records retrieved per account.
MAX_FOLLOWS_PER_ACCOUNT = 1000
# Records per list_records request (API page-size maximum).
PAGE_SIZE = 100


async def get_all_follows(
    client: AsyncClient,
    rate_limit: RateLimit,
    account_did: str,
) -> Tuple[List[FollowRecord], str]:
    """Fetch up to MAX_FOLLOWS_PER_ACCOUNT follow records for one account.

    Returns ``(follows, account_did)`` — the DID is echoed back so callers
    awaiting many of these concurrently can tell which account a result
    belongs to.

    Error policy:
      * HTTP 400 (private/deleted profile): return what we have (possibly
        empty) so the account is checkpointed and skipped next time.
      * HTTP 500 or request timeout: transient — retry the same page.
        Retries are done with a loop, not recursion, so a persistently
        flaky endpoint cannot exhaust the stack.
      * Anything else: re-raise for the caller to handle.
    """
    follows: List[FollowRecord] = []
    cursor = None
    while len(follows) < MAX_FOLLOWS_PER_ACCOUNT:
        await rate_limit.acquire()
        params = {
            "collection": "app.bsky.graph.follow",
            "repo": account_did,
            "limit": PAGE_SIZE,
        }
        if cursor is not None:
            params["cursor"] = cursor
        try:
            data = await client.com.atproto.repo.list_records(params)
        except at_exceptions.BadRequestError as e:
            # Inaccessible profile: return an empty/partial list so the
            # account is marked done and skipped on future runs.
            if e.response.status_code == 400:
                return follows, account_did
            logger.info("Error status code: %s", e.response.status_code)
            raise
        except at_exceptions.InvokeTimeoutError:
            # Transient timeout — retry the same page.
            continue
        except at_exceptions.RequestException as e:
            # Transient server error — retry the same page.
            if e.response.status_code == 500:
                continue
            raise

        follows.extend(data.records)
        cursor = data.cursor
        if not cursor:
            # No more pages for this account.
            break

    return follows, account_did


async def retrieve_follows(
    graph_file: str,
    checkpoint_dir: str,
):
    """Crawl follows for every not-yet-completed account in the graph.

    Accounts already present in ``checkpoint_dir`` (per ``load_checkpoint``)
    are skipped. Results are written as one gzipped JSONL file per account
    (``<did>.gz``), which doubles as the checkpoint marker. Exits the
    process after 100 unexpected failures.
    """
    completed_accounts = load_checkpoint(checkpoint_dir)
    accts = get_accounts(graph_file, completed_accounts)
    logger.info("Num of accounts to retrieve follows from: %d", len(accts))

    client = AsyncClient()

    batch_count = 0
    fail_count = 0
    rate_limiter = RateLimit(BSKY_API_LIMIT)
    for i in range(0, len(accts), BATCH_SIZE):
        batch = [acct for acct, _ in accts[i : i + BATCH_SIZE]]
        # Launch the whole batch concurrently; handle results as they finish.
        for result in asyncio.as_completed(
            [get_all_follows(client, rate_limiter, did) for did in batch]
        ):
            try:
                follows, did = await result
                # Persist this account's follows; file existence marks the
                # account as completed for future runs.
                with gzip.open(
                    os.path.join(checkpoint_dir, did + ".gz"), "wt"
                ) as out_file:
                    for follow in follows:
                        out_file.write(follow.model_dump_json() + "\n")
            except at_exceptions.BadRequestError as e:
                # Bad request is probably a profile that's private or deleted
                logger.info("Bad Request: %s", e.response.content.error)
                continue
            except Exception as e:
                logger.error("Failed to get follows: %s", e, exc_info=True)
                fail_count += 1
                # Persistent errors likely mean something systemic (auth,
                # network, API change) — stop rather than spin.
                if fail_count >= 100:
                    logger.error("Hitting error threshold, exiting...")
                    sys.exit(1)
                continue

        batch_count += 1
        if batch_count % 10 == 0:
            logger.info("Completed batch: %d", batch_count)


def main():
    """Parse CLI arguments and run the crawl."""
    parser = argparse.ArgumentParser(
        prog="GetFollows",
        description="Get all follows for accounts in provided follow graph",
    )
    parser.add_argument(
        "--graph-file",
        dest="graph_file",
        required=True,
        help="File with follow graph",
    )
    parser.add_argument(
        "--save-dir",
        dest="save_dir",
        required=True,
        help="Where to store crawl data",
    )
    args = parser.parse_args()
    # Both arguments are required=True, so argparse has already rejected
    # any invocation missing them — no manual None check is needed.
    # (The previous check referenced args.ckpt, which was never defined.)

    asyncio.run(
        retrieve_follows(
            graph_file=args.graph_file,
            checkpoint_dir=args.save_dir,
        )
    )


if __name__ == "__main__":
    main()