ML-based recommendation feed for Bluesky posts
"""Build chunked train/validation follow-edge datasets from gzipped Bluesky follow records."""
import argparse
import gzip
import json
import logging
import os
import random
import sys
from typing import Optional

import numpy as np

from scripts.utils import get_logger

logger = get_logger(__name__)


def _write_chunk(follows: list[tuple[int, int]], output_dir: str, filename: str) -> dict:
    """Persist a list of (source_id, target_id) pairs as a memmapped int64 file.

    Returns the metadata entry ({"filename", "shape", "dtype"}) describing the
    written file, for inclusion in metadata.json.
    """
    arr = np.array(follows, dtype="int64")
    out = np.memmap(
        os.path.join(output_dir, filename),
        dtype="int64",
        mode="w+",
        shape=arr.shape,
    )
    out[:] = arr[:]
    out.flush()  # make sure data hits disk before the memmap goes out of scope
    return {"filename": filename, "shape": arr.shape, "dtype": "int64"}


def main(
    follow_dir: str,
    output_dir: str,
    val_split: float,
    file_size: int,
    max_accounts: Optional[int],
    exclude_unfetched: bool = False,
):
    """Read gzipped follow records and emit chunked train/val edge files.

    Each input file is named ``<did>.gz`` and contains one JSON record per
    line with the followed account at ``["value"]["subject"]``. Edges are
    integer-encoded via ``did_id_map`` and split randomly into training and
    validation chunks of at most ``file_size`` rows.

    Args:
        follow_dir: Directory of ``<did>.gz`` follow files.
        output_dir: Destination for ``*.data`` chunks, ``did_id_map.json``
            and ``metadata.json``. Created if missing.
        val_split: Probability in (0, 1) that an edge goes to validation.
        file_size: Maximum number of edges per output chunk file.
        max_accounts: Optional cap on the number of distinct accounts (DIDs).
        exclude_unfetched: If True, drop edges whose target account has no
            follow file of its own in ``follow_dir``.

    Raises:
        ValueError: If ``val_split`` is not strictly between 0 and 1.
    """
    if val_split >= 1.0 or val_split <= 0.0:
        raise ValueError("Validation split must be between 0 and 1 exclusive")

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    did_id_map: dict[str, int] = dict()

    files = os.listdir(follow_dir)
    if max_accounts:
        files = files[:max_accounts]

    if exclude_unfetched:
        # Pre-register every account that has its own follow file so targets
        # without one can be filtered out in the loop below.
        # BUG FIX: was `file[-3]` (a single character), which made the filter
        # reject almost every target; must strip ".gz" like source_did below.
        for file in files:
            did_id_map[file[:-3]] = len(did_id_map)

    logger.info("Reading follows...")
    train_file_idx = 1
    val_file_idx = 1
    train_follows: list[tuple[int, int]] = []
    val_follows: list[tuple[int, int]] = []
    train_files = []
    val_files = []
    for file in files:
        source_did = file[:-3]  # Remove .gz extension
        # NOTE(review): once the cap is hit this skips the whole file even if
        # source_did is already mapped, dropping edges between known accounts.
        # Behavior preserved as-is — confirm intent before changing.
        if max_accounts and len(did_id_map) >= max_accounts:
            continue

        if source_did not in did_id_map:
            did_id_map[source_did] = len(did_id_map)

        # BUG FIX: join paths robustly and close the gzip handle (was an
        # unclosed `gzip.open(follow_dir + file)`).
        with gzip.open(os.path.join(follow_dir, file), "rt") as follow_file:
            for line in follow_file:
                target_did = json.loads(line)["value"]["subject"]
                if exclude_unfetched and target_did not in did_id_map:
                    continue

                # NOTE(review): as above, this also drops edges whose target is
                # already mapped once the cap is reached. Preserved as-is.
                if max_accounts and len(did_id_map) >= max_accounts:
                    continue

                if target_did not in did_id_map:
                    did_id_map[target_did] = len(did_id_map)

                if random.random() < val_split:
                    val_follows.append((did_id_map[source_did], did_id_map[target_did]))
                    if len(val_follows) >= file_size:
                        filename = f"val_follows_{val_file_idx}.data"
                        # BUG FIX: log message was missing the filename.
                        logger.info("Saving validation file: %s", filename)
                        val_files.append(_write_chunk(val_follows, output_dir, filename))
                        val_follows = []
                        val_file_idx += 1
                else:
                    train_follows.append((did_id_map[source_did], did_id_map[target_did]))
                    if len(train_follows) >= file_size:
                        filename = f"train_follows_{train_file_idx}.data"
                        # BUG FIX: log message was missing the filename.
                        logger.info("Saving training file: %s", filename)
                        train_files.append(_write_chunk(train_follows, output_dir, filename))
                        train_follows = []
                        train_file_idx += 1

    logger.info("Saving remnant files...")
    if val_follows:
        filename = f"val_follows_{val_file_idx}.data"
        val_files.append(_write_chunk(val_follows, output_dir, filename))

    if train_follows:
        filename = f"train_follows_{train_file_idx}.data"
        train_files.append(_write_chunk(train_follows, output_dir, filename))

    logger.info("Finished reading follows")

    # Write did_id_map out
    logger.info("Writing id map...")
    with open(os.path.join(output_dir, "did_id_map.json"), "w") as out_file:
        out_file.write(json.dumps(did_id_map))

    # Write out file metadata
    logger.info("Writing dataset metadata")
    with open(os.path.join(output_dir, "metadata.json"), "w") as out_file:
        out_file.write(
            json.dumps(
                {
                    "did_id_map": "did_id_map.json",
                    "splits": {"training": train_files, "validation": val_files},
                }
            )
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--follow-dir",
        dest="follow_dir",
        required=True,
        help="Path to folder with files",
    )

    parser.add_argument(
        "--output-dir",
        dest="output_dir",
        required=True,
        help="Path to folder for output",
    )

    parser.add_argument(
        "--val-split",
        dest="val_split",
        required=True,
        type=float,
        help="Percent of data for validation",
    )

    parser.add_argument(
        "--file-size",
        dest="file_size",
        required=True,
        type=int,
        help="Max rows per output file",
    )
    # BUG FIX: was `type=bool`, which parses ANY non-empty string (including
    # "False") as True and yields None when omitted. A boolean CLI flag should
    # be store_true with an explicit False default.
    parser.add_argument(
        "--exclude-unfetched",
        dest="exclude_unfetched",
        action="store_true",
        default=False,
        help="Whether to exclude accounts whose follows haven't been retrieved",
    )
    parser.add_argument(
        "--max-accounts",
        dest="max_accounts",
        required=False,
        type=int,
        help="Maximum number of accounts to include",
    )
    args = parser.parse_args()

    main(
        follow_dir=args.follow_dir,
        output_dir=args.output_dir,
        val_split=args.val_split,
        file_size=args.file_size,
        exclude_unfetched=args.exclude_unfetched,
        max_accounts=args.max_accounts,
    )