ML-based recommendation feed for Bluesky posts

Move scripts to subdir; add script for crawling likes

+260
crawl_follows.py scripts/crawl_follows.py
+1
data/likes/.gitignore
··· 1 + ./*
get_posts.py scripts/get_posts.py
+259
scripts/get_likes.py
"""Crawl Bluesky "like" records for accounts discovered in a follow graph.

For every account followed by the users in the input parquet follow graph,
download the account's ``app.bsky.feed.like`` records that were created
inside a ``[start, end)`` datetime window and write them, one JSON line per
like, to a gzipped per-account file. The output directory doubles as a
checkpoint: accounts that already have a ``<did>.gz`` file are skipped on
the next run.
"""

import argparse
import asyncio
import gzip
import json
import logging
import os
import sys
from datetime import datetime
from typing import Dict, List, Tuple

import pandas as pd
from atproto import AsyncClient
from atproto import exceptions as at_exceptions
from atproto_client.models.app.bsky.feed.like import Record as LikeRecord

from crawl_follows import RateLimit

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Create formatter
formatter = logging.Formatter(
    "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)

# Console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)


BATCH_SIZE = 1
FOLLOWER_THRESHOLD = 150
REQUIRED_ENV = ("BSKY_USER", "BSKY_APP_PW")

# Format of `created_at` on like records, e.g. "2024-01-02T03:04:05.678Z".
_CREATED_AT_FMT = "%Y-%m-%dT%H:%M:%S.%fZ"
_CHECKPOINT_EXT = ".gz"


def _collect_in_window(
    records, start_dt: datetime, end_dt: datetime, likes: List[LikeRecord]
) -> bool:
    """Append records created within ``[start_dt, end_dt)`` to ``likes``.

    Returns True once any record older than ``start_dt`` is seen, meaning the
    crawl has paged past the window (the API returns records newest-first) and
    no further pages need to be fetched.
    """
    past_window = False
    for like in records:
        created = datetime.strptime(like.value.created_at, _CREATED_AT_FMT)
        if start_dt <= created < end_dt:
            likes.append(like)
        if created < start_dt:
            past_window = True
    return past_window


async def get_all_likes(
    client: AsyncClient,
    rate_limit: RateLimit,
    account_did: str,
    start_dt: datetime,
    end_dt: datetime,
) -> Tuple[List[LikeRecord], str]:
    """Fetch all likes created by ``account_did`` within ``[start_dt, end_dt)``.

    Pages through ``app.bsky.feed.like`` records until the cursor is exhausted
    or a record older than ``start_dt`` appears. Returns the matching likes
    together with the DID so callers using ``asyncio.as_completed`` can
    attribute each result. A 400 response (private/deleted repo) yields an
    empty list so the account is checkpointed and skipped next time; any other
    BadRequestError is re-raised.
    """
    likes: List[LikeRecord] = []
    await rate_limit.acquire()
    try:
        data = await client.com.atproto.repo.list_records(
            {
                "collection": "app.bsky.feed.like",
                "repo": account_did,
                "limit": 100,
            }
        )
    # If user can't be accessed just return an empty list to skip next time
    except at_exceptions.BadRequestError as e:
        if e.response.status_code == 400:
            return [], account_did
        logger.info("Error status code: %s", e.response.status_code)
        raise

    # Checking the first page's window too avoids one needless extra request
    # when the whole window fits on a single page.
    past_window = _collect_in_window(data.records, start_dt, end_dt, likes)

    while data.cursor and not past_window:
        await rate_limit.acquire()
        data = await client.com.atproto.repo.list_records(
            {
                "collection": "app.bsky.feed.like",
                "repo": account_did,
                "cursor": data.cursor,
                "limit": 100,
            }
        )
        past_window = _collect_in_window(data.records, start_dt, end_dt, likes)

    return likes, account_did


async def retrieve_likes(
    user: str,
    app_pw: str,
    graph_file: str,
    checkpoint_dir: str,
    start_dt: datetime,
    end_dt: datetime,
):
    """Download likes for every account referenced by the follow graph.

    Accounts already present as ``<did>.gz`` in ``checkpoint_dir`` are
    skipped; remaining accounts are crawled most-followed-first. Exits the
    process with status 1 on unrecoverable errors and 0 on completion.
    """
    # If checkpoint dir doesn't exist, try to create it
    if not os.path.isdir(checkpoint_dir):
        logger.info("Checkpoint dir doesn't exist, creating...")
        try:
            os.makedirs(checkpoint_dir, exist_ok=True)
        except OSError as e:
            logger.error(f"Failed to create checkpoint dir, {checkpoint_dir}\n{e}")
            sys.exit(1)

    # Checkpoint folders contain one gzipped file per completed account,
    # named after the account DID.
    completed_accounts = set()
    try:
        for file in os.listdir(checkpoint_dir):
            # Only .gz files are checkpoints; strip the extension to recover
            # the DID (the old unconditional file[:-3] mangled other names).
            if file.endswith(_CHECKPOINT_EXT):
                completed_accounts.add(file[: -len(_CHECKPOINT_EXT)])
    except OSError as e:
        logger.error(
            f"Failed to recover from checkpoint dir, {checkpoint_dir}\n{e}",
            exc_info=True,
        )
        sys.exit(1)

    # Load follow graph parquet file
    try:
        logger.info("Parsing follower graph file...")
        follow_df = pd.read_parquet(graph_file)
        # Keep only rows whose "follows" list has between 100 and 1000
        # entries: tiny lists add little signal, huge ones look spammy.
        follow_df = follow_df.loc[follow_df["follows"].str.len().between(100, 1000)]
    except Exception as e:
        logger.error(f"Failed to open follow graph file, {graph_file}\n{e}")
        sys.exit(1)

    # Count how many crawled users follow each not-yet-checkpointed account
    # so that the most-followed accounts are crawled first.
    to_explore: Dict[str, int] = {}
    for _, row in follow_df.iterrows():
        for acct in row["follows"]:
            if acct not in completed_accounts:
                to_explore[acct] = to_explore.get(acct, 0) + 1

    accts = sorted(to_explore.items(), key=lambda item: item[1], reverse=True)

    logger.info(f"Num of accounts to retrieve posts from: {len(accts)}")

    client = AsyncClient()
    await client.login(user, app_pw)

    # Get all likes for accounts, BATCH_SIZE accounts at a time.
    batch_count = 0
    fail_count = 0
    rate_limiter = RateLimit(BATCH_SIZE)
    for i in range(0, len(accts), BATCH_SIZE):
        batch = [acct for acct, _ in accts[i : i + BATCH_SIZE]]
        tasks = [
            get_all_likes(client, rate_limiter, did, start_dt, end_dt)
            for did in batch
        ]
        for result in asyncio.as_completed(tasks):
            try:
                likes, did = await result
                # One gzipped JSON-lines file per account; its existence is
                # also the checkpoint marker for that account.
                with gzip.open(
                    os.path.join(checkpoint_dir, did + _CHECKPOINT_EXT), "wt"
                ) as out_file:
                    for like in likes:
                        out_file.write(like.model_dump_json() + "\n")
            except at_exceptions.BadRequestError as e:
                # Bad request is probably a profile that's private or deleted
                logger.info(f"Bad Request: {e.response.content.error}")
                continue
            except Exception as e:
                logger.error(f"Failed to get likes: {e}", exc_info=True)
                fail_count += 1
                if fail_count >= 100:
                    logger.error("Hitting error threshold, exiting...")
                    sys.exit(1)
                continue

        batch_count += 1
        if batch_count % 10 == 0:
            logger.info(f"Completed batch: {batch_count}")
    sys.exit(0)


def main():
    """CLI entry point: validate env vars and arguments, then run the crawl."""
    for key in REQUIRED_ENV:
        if key not in os.environ:
            raise ValueError(f"Must set '{key}' env var")

    user_name = os.environ["BSKY_USER"]
    app_pw = os.environ["BSKY_APP_PW"]

    parser = argparse.ArgumentParser(
        prog="GetLikes",
        description="Get all likes for accounts in provided follow graph",
    )
    parser.add_argument(
        "--graph-file",
        dest="graph_file",
        required=True,
        help="File with follow graph",
    )
    parser.add_argument(
        "--save-dir",
        dest="save_dir",
        required=True,
        help="Where to store crawl data",
    )
    parser.add_argument(
        "--start",
        dest="start",
        required=True,
        help="Date to start saving posts from (YYYY-MM-DD)",
    )
    parser.add_argument(
        "--end",
        dest="end",
        required=True,
        help="Date to stop (exclusive) saving posts from (YYYY-MM-DD)",
    )
    args = parser.parse_args()

    # NOTE(review): a previous version checked `args.ckpt` here, but argparse
    # never defines that attribute and both options above are required, so the
    # check was dead code and has been removed.

    try:
        start = datetime.strptime(args.start, "%Y-%m-%d")
    except ValueError:
        logger.error("Invalid start date")
        sys.exit(1)

    try:
        end = datetime.strptime(args.end, "%Y-%m-%d")
    except ValueError:
        logger.error("Invalid end date")
        sys.exit(1)

    if end <= start:
        logger.error(
            "Start date has to be before end date, what're you trying to do man..."
        )
        sys.exit(1)

    asyncio.run(
        retrieve_likes(
            user_name,
            app_pw,
            graph_file=args.graph_file,
            checkpoint_dir=args.save_dir,
            start_dt=start,
            end_dt=end,
        )
    )


if __name__ == "__main__":
    main()