"""ML-based recommendation feed for Bluesky posts — social-graph follow crawler."""
1import argparse
2import asyncio
3import decimal
4import logging
5import os
6import sys
7import time
8from typing import Tuple, List, Dict
9
10from atproto import AsyncClient
11from atproto import exceptions as at_exceptions
12import pandas as pd
13
14from utils import RateLimit, BSKY_API_LIMIT
15
# Module logger: INFO-level, timestamped messages written to stdout.
# NOTE(review): handler is attached at import time; importing this module more
# than once in the same process would duplicate log lines — confirm this file
# is only ever run as a script.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Create formatter
formatter = logging.Formatter(
    "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)

# Console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)


BATCH_SIZE = 10  # DIDs whose follow lists are fetched concurrently per batch
CHECKPOINT_THRESHOLD = 100  # batches between parquet checkpoints of the follow map
REQUIRED_ENV = ("BSKY_USER", "BSKY_APP_PW")  # env vars holding Bluesky credentials
33
34
async def get_all_follows(
    client: AsyncClient, rate_limit: RateLimit, did: str
) -> Tuple[str, List[str]]:
    """Fetch the complete list of DIDs that *did* follows.

    Pages through the follows endpoint (100 entries per request), acquiring a
    rate-limit slot before every call. The original wrote the fetch/accumulate
    logic twice (first page, then the cursor loop); this folds it into one loop.

    Args:
        client: Logged-in atproto AsyncClient.
        rate_limit: Shared limiter; ``acquire`` is awaited before each request.
        did: DID of the account whose follows are listed.

    Returns:
        Tuple of (did, list of followed DIDs) so callers awaiting results out
        of order can match each list back to its account.
    """
    follows: List[str] = []
    cursor = None
    while True:
        await rate_limit.acquire()
        # The first request must omit the cursor; later requests resume from it.
        if cursor is None:
            data = await client.get_follows(actor=did, limit=100)
        else:
            data = await client.get_follows(actor=did, cursor=cursor, limit=100)
        follows.extend(follow.did for follow in data.follows)
        cursor = data.cursor
        if not cursor:
            return (did, follows)
53
54
def save_checkpoint(output_dir: str, job_start: int, f_map: Dict[str, List[str]]):
    """Persist the crawled follow map as a parquet checkpoint file.

    The filename embeds the job start timestamp plus the map size rendered in
    scientific notation (``.`` -> ``_``, ``+`` stripped) so checkpoint files
    order sensibly under unix sort. Exits the process if the write fails.
    """
    # One row per follower: dict views preserve insertion order, so the two
    # columns line up exactly as the original append loop did.
    follow_df = pd.DataFrame(
        data={
            "follower": list(f_map.keys()),
            "follows": list(f_map.values()),
        }
    )

    # Convert to scientific notation for a more pleasant unix sort
    size_string = (
        f"{decimal.Decimal(len(f_map)):.2e}".replace("+", "").replace(".", "_")
    )
    save_path = os.path.join(
        output_dir, f"{job_start}_checkpoint_{size_string}.parquet"
    )

    # Save follow map as parquet file
    logger.info(f"Saving follow map to: {save_path}")
    try:
        follow_df.to_parquet(save_path)
    except Exception as e:
        logger.error(f"Failed to save follow map: {e}")
        sys.exit(1)
78
79
async def explore(username: str, pw: str, start_did: str, num_hops: int, save_dir: str):
    """Breadth-first crawl of the Bluesky follow graph starting at ``start_did``.

    Explores accounts up to ``num_hops`` hops from the start, fetching follow
    lists ``BATCH_SIZE`` accounts at a time, and checkpoints the accumulated
    follow map to ``save_dir`` every ``CHECKPOINT_THRESHOLD`` batches and once
    at the end. Exits the process after 3 unexpected fetch failures.

    Args:
        username: Bluesky handle used to log in.
        pw: App password for that account.
        start_did: DID of the account the crawl starts from.
        num_hops: Maximum graph distance from the start to explore.
        save_dir: Directory where checkpoint parquet files are written.
    """
    client = AsyncClient()
    await client.login(username, pw)

    logger.info(f"Starting did: {start_did}")
    follow_map = dict()  # did -> list of DIDs it follows (explored accounts)
    distance_map = {start_did: 0}  # did -> hop distance; doubles as "seen" set
    to_explore = [start_did]

    logger.info(
        f"Starting crawl with:\nStart DID: {start_did}\nNum hops: {num_hops}\nSaving Output to: {save_dir}"
    )

    job_start = int(time.time())

    # Rate limiter keeps us under the Bluesky API request budget.
    batch_count = 1
    fail_count = 0
    rate_limiter = RateLimit(BSKY_API_LIMIT)
    while len(to_explore):
        batch = to_explore[:BATCH_SIZE]
        to_explore = to_explore[BATCH_SIZE:]
        logger.info(
            f"Starting batch with size: {len(batch)} remaining to_explore: {len(to_explore)}"
        )
        for result in asyncio.as_completed(
            [get_all_follows(client, rate_limiter, did) for did in batch]
        ):
            try:
                follower, follows = await result
                follow_map[follower] = follows
                logger.info(f"{follower} follows {len(follows)} (public) accounts")
            except at_exceptions.BadRequestError as e:
                # Bad request is probably a profile that's private or deleted
                logger.info(f"Bad Request: {e.response.content.error}")
                continue
            except Exception as e:
                logger.error(f"Failed to get followers: {e}", exc_info=True)
                fail_count += 1
                if fail_count >= 3:
                    sys.exit(1)
                continue

            # If too far from start, don't add follows to exploration queue
            if distance_map[follower] < num_hops:
                for follow_did in follows:
                    # Guard on distance_map (every DID ever queued), not
                    # follow_map (only explored DIDs): otherwise a DID that is
                    # queued but not yet fetched gets enqueued — and crawled —
                    # multiple times. First assignment also keeps the shortest
                    # (BFS) distance.
                    if follow_did not in distance_map:
                        to_explore.append(follow_did)
                        distance_map[follow_did] = distance_map[follower] + 1

        # Checkpoint once per CHECKPOINT_THRESHOLD batches. (Previously this
        # check lived inside the per-result loop, writing up to BATCH_SIZE
        # redundant checkpoint files per qualifying batch.)
        if batch_count % CHECKPOINT_THRESHOLD == 0:
            save_checkpoint(save_dir, job_start, follow_map)

        logger.info(
            f"Finished batch {batch_count} | to_explore size: {len(to_explore)}"
        )
        batch_count += 1

    save_checkpoint(save_dir, job_start, follow_map)
    logger.info("Crawl complete")
141
142
def main():
    """CLI entry point: validate credential env vars, parse args, run the crawl."""
    # Fail fast if either required credential is missing from the environment.
    for env_key in REQUIRED_ENV:
        if env_key not in os.environ:
            raise ValueError(f"Must set '{env_key}' env var")

    bsky_user = os.environ["BSKY_USER"]
    bsky_app_pw = os.environ["BSKY_APP_PW"]

    arg_parser = argparse.ArgumentParser(
        prog="CrawlFollows",
        description="Crawl social graph using follows starting from provided DID",
    )
    arg_parser.add_argument(
        "--start-did",
        dest="start_did",
        required=True,
        help="DID for account to start crawl at",
    )
    arg_parser.add_argument(
        "--num-hops",
        dest="num_hops",
        type=int,
        default=2,
        help="How many network hops to explore out from the start",
    )
    arg_parser.add_argument(
        "--save-dir",
        dest="save_dir",
        default="data/crawl/",
        help="Where to store crawl data",
    )
    parsed = arg_parser.parse_args()

    asyncio.run(
        explore(
            bsky_user,
            bsky_app_pw,
            start_did=parsed.start_did,
            num_hops=parsed.num_hops,
            save_dir=parsed.save_dir,
        )
    )


if __name__ == "__main__":
    main()