ML-based recommendation feed for Bluesky posts
at main 188 lines 5.7 kB view raw
"""Breadth-first crawl of the Bluesky follow graph starting from a given DID.

Fetches each account's follow list via the atproto API (rate limited),
expands the frontier out to ``--num-hops`` hops, and periodically writes
the accumulated follower -> follows map to parquet checkpoints.
"""
import argparse
import asyncio
import decimal
import logging
import os
import sys
import time
from typing import Tuple, List, Dict

from atproto import AsyncClient
from atproto import exceptions as at_exceptions
import pandas as pd

from utils import RateLimit, BSKY_API_LIMIT

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Create formatter
formatter = logging.Formatter(
    "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)

# Console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)


# Number of accounts whose follow lists are fetched concurrently per batch.
BATCH_SIZE = 10
# Write a checkpoint every this-many batches.
CHECKPOINT_THRESHOLD = 100
# Abort the crawl after this many unexpected (non-BadRequest) failures.
MAX_FAILURES = 3
REQUIRED_ENV = ("BSKY_USER", "BSKY_APP_PW")


async def get_all_follows(
    client: AsyncClient, rate_limit: RateLimit, did: str
) -> Tuple[str, List[str]]:
    """Return ``(did, follows)`` where *follows* is every DID the account follows.

    Pages through the full (public) follow list, acquiring a rate-limit
    token before each API request.
    """
    follows: List[str] = []
    cursor = None
    while True:
        await rate_limit.acquire()
        data = await client.get_follows(actor=did, cursor=cursor, limit=100)
        follows.extend(follow.did for follow in data.follows)
        cursor = data.cursor
        # A missing cursor means the final page has been consumed.
        if not cursor:
            break

    return (did, follows)


def save_checkpoint(output_dir: str, job_start: int, f_map: Dict[str, List[str]]):
    """Persist the crawled follow map as a parquet checkpoint.

    The filename embeds the crawl start timestamp and the map size
    (scientific notation, ``.`` and ``+`` munged) so checkpoints order
    sensibly under plain unix sort. Exits the process on write failure,
    since losing a checkpoint means losing crawl progress.
    """
    # Create dataframe for follow map
    follow_df = pd.DataFrame(
        data={
            "follower": list(f_map.keys()),
            "follows": list(f_map.values()),
        }
    )

    # Convert to scientific notation for more pleasant unix sort
    size_string = f"{decimal.Decimal(len(f_map)):.2e}".replace("+", "").replace(
        ".", "_"
    )
    save_path = os.path.join(
        output_dir, f"{job_start}_checkpoint_{size_string}.parquet"
    )

    # Save follow map as parquet file
    logger.info(f"Saving follow map to: \n{save_path}")
    try:
        follow_df.to_parquet(save_path)
    except Exception as e:
        logger.error(f"Failed to save follow map: {e}")
        sys.exit(1)


async def explore(username: str, pw: str, start_did: str, num_hops: int, save_dir: str):
    """BFS the follow graph out to *num_hops* hops from *start_did*.

    Accounts are processed in batches of ``BATCH_SIZE``; a checkpoint is
    written every ``CHECKPOINT_THRESHOLD`` batches and once at the end.
    """
    client = AsyncClient()
    await client.login(username, pw)

    logger.info(f"Starting did: {start_did}")
    # follower DID -> list of DIDs that account follows (explored set).
    follow_map: Dict[str, List[str]] = dict()
    # DID -> hop distance from start; doubles as the "seen" set
    # (explored OR already queued), which prevents duplicate queue entries.
    distance_map = {start_did: 0}
    to_explore = [start_did]

    logger.info(
        f"Starting crawl with:\nStart DID: {start_did}\nNum hops: {num_hops}\nSaving Output to: {save_dir}"
    )

    job_start = int(time.time())

    # Try to only send 10 requests a second
    batch_count = 1
    fail_count = 0
    rate_limiter = RateLimit(BSKY_API_LIMIT)
    while len(to_explore):
        batch = to_explore[:BATCH_SIZE]
        to_explore = to_explore[BATCH_SIZE:]
        logger.info(
            f"Starting batch with size: {len(batch)} remaining to_explore: {len(to_explore)}"
        )
        for result in asyncio.as_completed(
            [get_all_follows(client, rate_limiter, did) for did in batch]
        ):
            try:
                follower, follows = await result
                follow_map[follower] = follows
                logger.info(f"{follower} follows {len(follows)} (public) accounts")
            except at_exceptions.BadRequestError as e:
                # Bad request is probably a profile that's private or deleted
                logger.info(f"Bad Request: {e.response.content.error}")
                continue
            except Exception as e:
                logger.error(f"Failed to get followers: {e}", exc_info=True)
                fail_count += 1
                if fail_count >= MAX_FAILURES:
                    sys.exit(1)
                continue

            # If too far from start, dont add follows to exploration queue
            if distance_map[follower] != num_hops:
                for follow_did in follows:
                    # Checking distance_map (not follow_map) also skips DIDs
                    # that are queued but not yet explored — the old
                    # follow_map check let duplicates into the queue.
                    if follow_did not in distance_map:
                        to_explore.append(follow_did)
                        distance_map[follow_did] = distance_map[follower] + 1

        # Checkpoint once per qualifying batch. (Previously this check lived
        # inside the per-result loop above, re-saving the same checkpoint up
        # to BATCH_SIZE times on every checkpointing batch.)
        if batch_count % CHECKPOINT_THRESHOLD == 0:
            save_checkpoint(save_dir, job_start, follow_map)

        logger.info(
            f"Finished batch {batch_count} | to_explore size: {len(to_explore)}"
        )
        batch_count += 1

    save_checkpoint(save_dir, job_start, follow_map)
    logger.info("Crawl complete")


def main():
    """Validate env credentials, parse CLI args, and run the crawl."""
    for key in REQUIRED_ENV:
        if key not in os.environ:
            raise ValueError(f"Must set '{key}' env var")

    user_name = os.environ["BSKY_USER"]
    app_pw = os.environ["BSKY_APP_PW"]

    parser = argparse.ArgumentParser(
        prog="CrawlFollows",
        description="Crawl social graph using follows starting from provided DID",
    )
    parser.add_argument(
        "--start-did",
        dest="start_did",
        required=True,
        help="DID for account to start crawl at",
    )
    parser.add_argument(
        "--num-hops",
        dest="num_hops",
        type=int,
        default=2,
        help="How many network hops to explore out from the start",
    )
    parser.add_argument(
        "--save-dir",
        dest="save_dir",
        default="data/crawl/",
        help="Where to store crawl data",
    )
    args = parser.parse_args()

    asyncio.run(
        explore(
            user_name,
            app_pw,
            start_did=args.start_did,
            num_hops=args.num_hops,
            save_dir=args.save_dir,
        )
    )


if __name__ == "__main__":
    main()