"""ML-based recommendation feed for Bluesky posts."""
1import argparse
2import asyncio
3import gzip
4import logging
5import os
6import sys
7from typing import Tuple, List
8
9from atproto import AsyncClient
10from atproto import exceptions as at_exceptions
11from atproto_client.models.app.bsky.graph.follow import Record as FollowRecord
12
13from utils import load_checkpoint, get_accounts, RateLimit, BSKY_API_LIMIT
14
# Module-level logger: INFO-level, timestamped records written to stdout so
# crawl progress is visible when run under a scheduler or in a container.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Create formatter
formatter = logging.Formatter(
    "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)

# Console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)


# Number of accounts whose follow lists are fetched concurrently per batch.
BATCH_SIZE = 100
30
31
async def get_all_follows(
    client: AsyncClient,
    rate_limit: RateLimit,
    account_did: str,
) -> Tuple[List[FollowRecord], str]:
    """Fetch up to 1000 follow records for one account.

    Pages through the account's ``app.bsky.graph.follow`` collection,
    retrying transient failures (timeouts and HTTP 500s) and treating an
    HTTP 400 as "account inaccessible": an empty list is returned so the
    caller checkpoints the account and skips it on future runs.

    Args:
        client: atproto API client shared across all tasks.
        rate_limit: Shared limiter; acquired before every API request.
        account_did: DID of the repo whose follows are listed.

    Returns:
        Tuple of (follow records, the same ``account_did``) so callers
        using ``asyncio.as_completed`` can match a result to its account.

    Raises:
        atproto request exceptions other than the handled
        400 / 500 / timeout cases.
    """
    follows: List[FollowRecord] = []

    # First page. Retry in a loop rather than recursing (the original
    # recursed on timeouts/500s, which grows the call stack without bound
    # if the server keeps failing).
    while True:
        await rate_limit.acquire()
        try:
            data = await client.com.atproto.repo.list_records(
                {
                    "collection": "app.bsky.graph.follow",
                    "repo": account_did,
                    "limit": 100,
                }
            )
            break
        # If user can't be accessed just return an empty list to skip next time
        except at_exceptions.BadRequestError as e:
            if e.response.status_code == 400:
                return [], account_did
            logger.info(f"Error status code: {e.response.status_code}")
            raise
        except at_exceptions.RequestException as e:
            # Transient server error: retry the same request.
            if e.response.status_code == 500:
                continue
            raise
        # If we timeout just try again
        except at_exceptions.InvokeTimeoutError:
            continue

    follows.extend(data.records)

    # Follow the pagination cursor; limit to 1000 follows per account.
    while data.cursor and len(follows) < 1000:
        await rate_limit.acquire()
        try:
            data = await client.com.atproto.repo.list_records(
                {
                    "collection": "app.bsky.graph.follow",
                    "repo": account_did,
                    "cursor": data.cursor,
                    "limit": 100,
                }
            )
        # If we timeout just try again
        except at_exceptions.InvokeTimeoutError:
            continue
        except at_exceptions.RequestException as e:
            # Transient server error: retry this page.
            if e.response.status_code == 500:
                continue
            raise
        follows.extend(data.records)

    return follows, account_did
89
90
async def retrieve_follows(
    graph_file: str,
    checkpoint_dir: str,
) -> None:
    """Crawl follow lists for every not-yet-completed account in the graph.

    Loads the checkpoint to skip already-crawled accounts, then fetches
    follows in concurrent batches of ``BATCH_SIZE`` and writes each
    account's records to ``<checkpoint_dir>/<did>.gz`` (one JSON record
    per line). Exits the process after 100 unexpected failures.

    Args:
        graph_file: File containing the follow graph to crawl.
        checkpoint_dir: Directory for per-account .gz output; its contents
            double as the checkpoint of completed accounts.
    """
    completed_accounts = load_checkpoint(checkpoint_dir)
    accts = get_accounts(graph_file, completed_accounts)
    logger.info(f"Num of accounts to retrieve follows from: {len(accts)}")

    client = AsyncClient()

    # Get all follows for accounts
    batch_count = 0
    fail_count = 0
    rate_limiter = RateLimit(BSKY_API_LIMIT)
    for i in range(0, len(accts), BATCH_SIZE):
        batch = [acct for acct, _ in accts[i : i + BATCH_SIZE]]
        # as_completed lets us checkpoint each account as soon as its
        # fetch finishes instead of waiting for the slowest in the batch.
        for result in asyncio.as_completed(
            [get_all_follows(client, rate_limiter, did) for did in batch]
        ):
            try:
                follows, did = await result
                # Save follows: one JSON line per follow record, gzipped.
                with gzip.open(
                    os.path.join(checkpoint_dir, did + ".gz"), "wt"
                ) as out_file:
                    for follow in follows:
                        out_file.write(follow.model_dump_json() + "\n")
            except at_exceptions.BadRequestError as e:
                # Bad request is probably a profile that's private or deleted
                logger.info(f"Bad Request: {e.response.content.error}")
                continue
            except Exception as e:
                # Was exc_info=1; True is the documented form.
                logger.error(f"Failed to get follows: {e}", exc_info=True)
                fail_count += 1
                # Too many unexpected errors likely means something is
                # systematically wrong (auth, network) — stop the crawl.
                if fail_count >= 100:
                    logger.error("Hitting error threshold, exiting...")
                    sys.exit(1)
                continue

        batch_count += 1
        if batch_count % 10 == 0:
            logger.info(f"Completed batch: {batch_count}")
133
134
def main() -> None:
    """Parse CLI arguments and run the follow crawl.

    Both ``--graph-file`` and ``--save-dir`` are enforced by argparse
    (``required=True``), so no manual presence check is needed. The old
    ``args.save_dir is None and args.ckpt is None`` check was removed:
    ``ckpt`` was never defined on the namespace (it would have raised
    AttributeError) and the check was unreachable because ``save_dir``
    can never be None.
    """
    parser = argparse.ArgumentParser(
        prog="GetFollows",
        description="Get all follows for accounts in provided follow graph",
    )
    parser.add_argument(
        "--graph-file",
        dest="graph_file",
        required=True,
        help="File with follow graph",
    )
    parser.add_argument(
        "--save-dir",
        dest="save_dir",
        required=True,
        help="Where to store crawl data",
    )
    args = parser.parse_args()

    asyncio.run(
        retrieve_follows(
            graph_file=args.graph_file,
            checkpoint_dir=args.save_dir,
        )
    )


if __name__ == "__main__":
    main()