"""ML-based recommendation feed for Bluesky posts.

Converts per-account gzipped follow records into train/validation shards of
(source_id, target_id) pairs for model training.
"""
import argparse
import gzip
import json
import os
import random
from typing import Optional

import numpy as np

from scripts.utils import get_logger

logger = get_logger(__name__)


def save_shard(follows: list[tuple[int, int]], filename: str, output_dir: str) -> dict:
    """Persist a buffer of (source_id, target_id) pairs as an int64 memmap shard."""
    logger.info(f"Saving shard: {filename}")
    data = np.array(follows, dtype="int64")
    shard = np.memmap(
        os.path.join(output_dir, filename),
        dtype="int64",
        mode="w+",
        shape=data.shape,
    )
    shard[:] = data[:]
    shard.flush()
    return {"filename": filename, "shape": data.shape, "dtype": "int64"}


def main(
    follow_dir: str,
    output_dir: str,
    val_split: float,
    file_size: int,
    max_accounts: Optional[int],
    exclude_unfetched: bool = False,
):
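    """Split follow edges into train/validation memmap shards.

    Each file in follow_dir is named `<did>.gz` and holds one JSON follow
    record per line, whose `value.subject` field is the followed DID.
    """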
    if val_split >= 1.0 or val_split <= 0.0:
        raise ValueError("Validation split must be between 0 and 1 exclusive")

    os.makedirs(output_dir, exist_ok=True)

    did_id_map: dict[str, int] = {}

    files = os.listdir(follow_dir)
    if max_accounts:
        files = files[:max_accounts]

    if exclude_unfetched:
        # Pre-register every fetched account so that only accounts whose
        # follows were actually retrieved can be assigned ids below.
        for file in files:
            did_id_map[file[:-3]] = len(did_id_map)

    logger.info("Reading follows...")
    train_file_idx = 1
    val_file_idx = 1
    train_follows: list[tuple[int, int]] = []
    val_follows: list[tuple[int, int]] = []
    train_files = []
    val_files = []
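    # Stream every follow record, map each DID to a dense integer id, and
    # buffer (source_id, target_id) edges per split, flushing a buffer to a
    # memmap shard whenever it reaches file_size rows.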
    for file in files:
        source_did = file[:-3]  # Remove the .gz extension
        if max_accounts and len(did_id_map) >= max_accounts:
            continue

        if source_did not in did_id_map:
            did_id_map[source_did] = len(did_id_map)

        with gzip.open(os.path.join(follow_dir, file), "rt") as records:
            for line in records:
                target_did = json.loads(line)["value"]["subject"]
                if exclude_unfetched and target_did not in did_id_map:
                    continue

                if max_accounts and len(did_id_map) >= max_accounts:
                    continue

                if target_did not in did_id_map:
                    did_id_map[target_did] = len(did_id_map)

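                # Route this edge to the validation split with probability val_split.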
                if random.random() < val_split:
                    val_follows.append(
                        (did_id_map[source_did], did_id_map[target_did])
                    )
                    if len(val_follows) >= file_size:
                        val_files.append(
                            save_shard(
                                val_follows,
                                f"val_follows_{val_file_idx}.data",
                                output_dir,
                            )
                        )
                        val_follows = []
                        val_file_idx += 1
                else:
                    train_follows.append(
                        (did_id_map[source_did], did_id_map[target_did])
                    )
                    if len(train_follows) >= file_size:
                        train_files.append(
                            save_shard(
                                train_follows,
                                f"train_follows_{train_file_idx}.data",
                                output_dir,
                            )
                        )
                        train_follows = []
                        train_file_idx += 1

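    # Flush whatever is left in the buffers as final, partially filled shards.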
    logger.info("Saving remnant files...")
    if val_follows:
        val_files.append(
            save_shard(val_follows, f"val_follows_{val_file_idx}.data", output_dir)
        )

    if train_follows:
        train_files.append(
            save_shard(train_follows, f"train_follows_{train_file_idx}.data", output_dir)
        )

    logger.info("Finished reading follows")

    # Write did_id_map out
    logger.info("Writing id map...")
    with open(os.path.join(output_dir, "did_id_map.json"), "w") as out_file:
        out_file.write(json.dumps(did_id_map))

    # Write out file metadata
    logger.info("Writing dataset metadata")
    with open(os.path.join(output_dir, "metadata.json"), "w") as out_file:
        out_file.write(
            json.dumps(
                {
                    "did_id_map": "did_id_map.json",
                    "splits": {"training": train_files, "validation": val_files},
                }
            )
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--follow-dir",
        dest="follow_dir",
        required=True,
        help="Path to the folder of gzipped follow files",
    )

    parser.add_argument(
        "--output-dir",
        dest="output_dir",
        required=True,
        help="Path to the folder for output files",
    )

    parser.add_argument(
        "--val-split",
        dest="val_split",
        required=True,
        type=float,
        help="Fraction of the data held out for validation (between 0 and 1)",
    )

    parser.add_argument(
        "--file-size",
        dest="file_size",
        required=True,
        type=int,
        help="Maximum number of rows per output file",
    )
    parser.add_argument(
        "--exclude-unfetched",
        dest="exclude_unfetched",
        action="store_true",
        help="Exclude accounts whose follows haven't been retrieved",
    )
    parser.add_argument(
        "--max-accounts",
        dest="max_accounts",
        required=False,
        type=int,
        help="Maximum number of accounts to include",
    )
    args = parser.parse_args()

    main(
        follow_dir=args.follow_dir,
        output_dir=args.output_dir,
        val_split=args.val_split,
        file_size=args.file_size,
        exclude_unfetched=args.exclude_unfetched,
        max_accounts=args.max_accounts,
    )
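
    # Example invocation (script name and paths are illustrative placeholders):
    #   python build_follow_dataset.py --follow-dir follows/ --output-dir out/ \
    #       --val-split 0.1 --file-size 1000000 --exclude-unfetched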