import argparse
import gzip
import json
import os
import random
from typing import Optional

import numpy as np

from scripts.utils import get_logger

logger = get_logger(__name__)


def main(
    follow_dir: str,
    output_dir: str,
    val_split: float,
    file_size: int,
    max_accounts: Optional[int],
    exclude_unfetched: bool = False,
):
    if val_split >= 1.0 or val_split <= 0.0:
        raise ValueError("Validation split must be between 0 and 1 exclusive")

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    did_id_map: dict[str, int] = dict()

    files = os.listdir(follow_dir)
    if max_accounts:
        files = files[:max_accounts]

    if exclude_unfetched:
        # Pre-register every fetched account so follows pointing at
        # unfetched targets can be filtered out below.
        for file in files:
            did_id_map[file[:-3]] = len(did_id_map)  # Strip the .gz extension

    logger.info("Reading follows...")
    train_file_idx = 1
    val_file_idx = 1
    train_follows: list[tuple[int, int]] = []
    val_follows: list[tuple[int, int]] = []
    train_files = []
    val_files = []
    for file in files:
        source_did = file[:-3]  # Strip the .gz extension
        if max_accounts and len(did_id_map) >= max_accounts:
            continue
        if source_did not in did_id_map:
            did_id_map[source_did] = len(did_id_map)
        with gzip.open(os.path.join(follow_dir, file), "rt") as follow_file:
            for line in follow_file:
                target_did = json.loads(line)["value"]["subject"]
                if exclude_unfetched and target_did not in did_id_map:
                    continue
                if max_accounts and len(did_id_map) >= max_accounts:
                    continue
                if target_did not in did_id_map:
                    did_id_map[target_did] = len(did_id_map)

                # Route each (source, target) edge into the validation or
                # training split, flushing a memmapped chunk to disk whenever
                # a split's buffer reaches file_size rows.
                if random.random() < val_split:
                    val_follows.append((did_id_map[source_did], did_id_map[target_did]))
                    if len(val_follows) >= file_size:
                        filename = f"val_follows_{val_file_idx}.data"
                        logger.info(f"Saving validation file: {filename}")
                        val_df = np.array(val_follows, dtype="int64")
                        val_file_df = np.memmap(
                            os.path.join(output_dir, filename),
                            dtype="int64",
                            mode="w+",
                            shape=val_df.shape,
                        )
                        val_file_df[:] = val_df[:]
                        val_file_df.flush()
                        val_follows = []
                        val_files.append(
                            {
                                "filename": filename,
                                "shape": val_df.shape,
                                "dtype": "int64",
                            }
                        )
                        val_file_idx += 1
                        del val_df
                        del val_file_df
                else:
                    train_follows.append((did_id_map[source_did], did_id_map[target_did]))
                    if len(train_follows) >= file_size:
                        filename = f"train_follows_{train_file_idx}.data"
                        logger.info(f"Saving training file: {filename}")
                        train_df = np.array(train_follows, dtype="int64")
                        train_file_df = np.memmap(
                            os.path.join(output_dir, filename),
                            dtype="int64",
                            mode="w+",
                            shape=train_df.shape,
                        )
                        train_file_df[:] = train_df[:]
                        train_file_df.flush()
                        train_follows = []
                        train_files.append(
                            {
                                "filename": filename,
                                "shape": train_df.shape,
                                "dtype": "int64",
                            }
                        )
                        train_file_idx += 1
                        del train_df
                        del train_file_df

    # Flush whatever remains in each buffer as a final, smaller chunk.
    logger.info("Saving remnant files...")
    if len(val_follows):
        filename = f"val_follows_{val_file_idx}.data"
        val_df = np.array(val_follows, dtype="int64")
        val_file_df = np.memmap(
            os.path.join(output_dir, filename),
            dtype="int64",
            mode="w+",
            shape=val_df.shape,
        )
        val_file_df[:] = val_df[:]
        val_file_df.flush()
        val_files.append(
            {
                "filename": filename,
                "shape": val_df.shape,
                "dtype": "int64",
            }
        )
    if len(train_follows):
        filename = f"train_follows_{train_file_idx}.data"
        train_df = np.array(train_follows, dtype="int64")
        train_file_df = np.memmap(
            os.path.join(output_dir, filename),
            dtype="int64",
            mode="w+",
            shape=train_df.shape,
        )
        train_file_df[:] = train_df[:]
        train_file_df.flush()
        train_files.append(
            {
                "filename": filename,
                "shape": train_df.shape,
                "dtype": "int64",
            }
        )
    logger.info("Finished reading follows")

    # Write the DID-to-id map out
    logger.info("Writing id map...")
    with open(os.path.join(output_dir, "did_id_map.json"), "w") as out_file:
        out_file.write(json.dumps(did_id_map))
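    # did_id_map.json is a flat JSON object mapping each DID string to its
    # integer id; the keys and values below are illustrative, not real data:
    #   {"did:plc:abc123": 0, "did:plc:def456": 1, ...}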
    # Write out file metadata
    logger.info("Writing dataset metadata")
    with open(os.path.join(output_dir, "metadata.json"), "w") as out_file:
        out_file.write(
            json.dumps(
                {
                    "did_id_map": "did_id_map.json",
                    "splits": {"training": train_files, "validation": val_files},
                }
            )
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--follow-dir",
        dest="follow_dir",
        required=True,
        help="Path to the folder of gzipped per-account follow files",
    )
    parser.add_argument(
        "--output-dir",
        dest="output_dir",
        required=True,
        help="Path to the folder for output",
    )
    parser.add_argument(
        "--val-split",
        dest="val_split",
        required=True,
        type=float,
        help="Fraction of edges (between 0 and 1, exclusive) held out for validation",
    )
    parser.add_argument(
        "--file-size",
        dest="file_size",
        required=True,
        type=int,
        help="Maximum number of rows per output file",
    )
    parser.add_argument(
        "--exclude-unfetched",
        dest="exclude_unfetched",
        action="store_true",
        help="Exclude follows pointing at accounts whose own follows haven't been fetched",
    )
    parser.add_argument(
        "--max-accounts",
        dest="max_accounts",
        required=False,
        type=int,
        help="Maximum number of accounts to include",
    )
    args = parser.parse_args()
    main(
        follow_dir=args.follow_dir,
        output_dir=args.output_dir,
        val_split=args.val_split,
        file_size=args.file_size,
        exclude_unfetched=args.exclude_unfetched,
        max_accounts=args.max_accounts,
    )
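# A minimal sketch of how downstream code might load one saved chunk back,
# assuming metadata.json as written above. The np.memmap shape must match the
# recorded one exactly, so it is restored from the metadata rather than
# inferred:
#
#   import json
#   import os
#   import numpy as np
#
#   with open(os.path.join(output_dir, "metadata.json")) as f:
#       metadata = json.load(f)
#   part = metadata["splits"]["training"][0]
#   follows = np.memmap(
#       os.path.join(output_dir, part["filename"]),
#       dtype=part["dtype"],
#       mode="r",
#       shape=tuple(part["shape"]),  # JSON round-trips the shape as a list
#   )
#   # Column 0 holds source ids, column 1 holds target ids.
#
# Example invocation (the script name and paths here are hypothetical):
#   python prepare_follow_dataset.py \
#       --follow-dir data/follows/ \
#       --output-dir data/follow_dataset/ \
#       --val-split 0.1 \
#       --file-size 10000000 \
#       --exclude-unfetched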