ML-based recommendation feed for Bluesky posts
(main branch · 143 lines · 5.0 kB)
"""Follow-edge dataset for ML-based recommendation training.

Serves memmapped (source_id, target_id) follow edges as positive examples
and corrupts a configurable fraction into frequency-weighted negative
samples at collate time.
"""
from functools import reduce
import json
import glob
from multiprocessing import Pool
import os
import random

import numpy as np
import torch
from torch.utils.data import Dataset

# Fall back to stdlib logging when the project package is unavailable so the
# module stays importable (and testable) outside the full repository.
try:
    from scripts.utils import get_logger

    logger = get_logger(__name__)
except ImportError:
    import logging

    logger = logging.getLogger(__name__)


def get_frequencies(array: np.ndarray, user_count: int) -> np.ndarray:
    """Count per-node edge frequencies for one shard.

    Args:
        array: (n, 2) integer array of (source_id, target_id) follow edges.
        user_count: total number of node ids (rows of the result).

    Returns:
        (user_count, 2) float array; column 0 counts appearances as a
        source, column 1 counts appearances as a target.
    """
    frequencies = np.zeros((user_count, 2))
    # bincount replaces the original per-row Python loop: one C pass per column.
    frequencies[:, 0] = np.bincount(array[:, 0], minlength=user_count)
    frequencies[:, 1] = np.bincount(array[:, 1], minlength=user_count)
    return frequencies


class FollowDataset(Dataset):
    """Positive follow edges with frequency-weighted negative sampling.

    Each item is one (source, target) follow edge labelled +1; during
    batching (`collate_rows`) a fraction of edges is corrupted into
    negative (-1) samples by resampling one endpoint proportionally to its
    frequency in the dataset.
    """

    def __init__(
        self,
        dataset_path: str,
        split: str,
        negative_sample_chance: float,
        num_workers: int = 7,
    ):
        """Load one split of a prepared follow-edge dataset.

        Args:
            dataset_path: directory containing metadata.json, the DID->id
                map, and the numpy shard files it references.
            split: split name; must appear in metadata["splits"].
            negative_sample_chance: per-row probability of corrupting an
                edge into a negative sample at collate time.
            num_workers: process-pool size for the frequency scan
                (defaults to the original hard-coded 7).

        Raises:
            ValueError: if `split` is not present in the metadata file.
        """
        with open(os.path.join(dataset_path, "metadata.json"), "r") as in_file:
            metadata = json.load(in_file)

        with open(os.path.join(dataset_path, metadata["did_id_map"]), "r") as in_file:
            self.did_id_map: dict[str, int] = json.load(in_file)

        if split not in metadata["splits"]:
            raise ValueError(f"Could not find split in metadata file: {split}")
        split_files = metadata["splits"][split]

        self.numpy_files: list[tuple[str, str, int, int]] = []
        for file in split_files:
            # Filenames look like "<a>_<b>_<idx>.<ext>"; parse the shard
            # index as an int so shards order numerically, not lexically.
            file_idx = int(file["filename"].split("_")[2].split(".")[0])
            self.numpy_files.append(
                (file["filename"], file["dtype"], file_idx, file["shape"][0])
            )

        # BUG FIX: previously sorted on x[1] (the dtype string), leaving the
        # shards in arbitrary order; sort on the numeric shard index so
        # __getitem__ indices map to a stable row order.
        self.numpy_files.sort(key=lambda entry: entry[2])

        # Memmap every shard read-only; rows are (source_id, target_id).
        self.dataframes: list[np.ndarray] = []
        for filename, dtype, _, row_count in self.numpy_files:
            self.dataframes.append(
                np.memmap(
                    os.path.join(dataset_path, filename),
                    dtype=dtype,
                    mode="r",
                    shape=(row_count, 2),
                )
            )

        logger.info("Calculating node frequency...")
        with Pool(num_workers) as p:
            self.cumulative_freq = reduce(
                np.add,
                p.starmap(
                    get_frequencies,
                    [
                        (dataframe, len(self.did_id_map))
                        for dataframe in self.dataframes
                    ],
                ),
            )
        # BUG FIX: build the per-column CDF over all users. The original
        # `cumulative_freq[0] = cumulative_freq[0].cumsum()` cumsummed the
        # first two ROWS (each shape (2,)), which left the columns that
        # collate_rows searchsorts completely un-accumulated.
        self.cumulative_freq[:, 0] = self.cumulative_freq[:, 0].cumsum()
        self.cumulative_freq[:, 1] = self.cumulative_freq[:, 1].cumsum()

        self.negative_sample_chance = negative_sample_chance

    def __len__(self) -> int:
        """Total number of edges across all shards of this split."""
        return sum(row_count for _, _, _, row_count in self.numpy_files)

    def num_users(self) -> int:
        """Number of distinct node ids (size of the DID->id map)."""
        return len(self.did_id_map)

    def _idx_to_row(self, idx: int) -> tuple[int, int]:
        """Map a global edge index to its (source_id, target_id) pair.

        Raises:
            IndexError: if `idx` is outside [0, len(self)).
        """
        if idx < 0 or idx >= len(self):
            raise IndexError(f"Invalid index: {idx}")

        # Walk the shards until the cumulative row count covers idx;
        # effective_idx is the offset within the matched shard.
        row_index_total = 0
        effective_idx = idx
        i = 0
        for i, (_, _, _, row_count) in enumerate(self.numpy_files):
            row_index_total += row_count
            if idx < row_index_total:
                break
            effective_idx -= row_count

        row = self.dataframes[i][effective_idx]
        return (row[0].item(), row[1].item())

    def __getitem__(self, idx: int) -> tuple[tuple[int, int], int]:
        """Return the idx-th follow edge with a positive (+1) label.

        Negative-sample corruption is applied batch-wise in
        `collate_rows`, not here.
        """
        sample = self._idx_to_row(idx)
        return (sample, 1)

    def collate_rows(
        self,
        batch: list[tuple[tuple[int, int], int]],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Collate rows into tensors, corrupting some into negative edges.

        With probability `negative_sample_chance` a row has one endpoint
        (source or target, 50/50) resampled proportionally to that
        endpoint's frequency in the dataset, and its label set to -1.

        Returns:
            (follows, labels): int32 tensors of shape (B, 2) and (B,).
        """
        # Decide which rows to corrupt, and on which side.
        corrupted_sources = []
        corrupted_targets = []
        for i in range(len(batch)):
            if random.random() < self.negative_sample_chance:
                if random.random() < 0.5:
                    corrupted_sources.append(i)
                else:
                    corrupted_targets.append(i)

        # Inverse-CDF sampling: draw u uniform in [0, total_count) and
        # locate it in the per-column cumulative frequency.
        # BUG FIX: scale by the CDF total (last cumsum entry), not by the
        # number of users (`shape[0]`); side="right" gives each id a
        # probability exactly proportional to its frequency.
        new_sources = self.cumulative_freq[:, 0].searchsorted(
            np.random.sample(len(corrupted_sources)) * self.cumulative_freq[-1, 0],
            side="right",
        )
        new_targets = self.cumulative_freq[:, 1].searchsorted(
            np.random.sample(len(corrupted_targets)) * self.cumulative_freq[-1, 1],
            side="right",
        )
        # Defensive clamp against any out-of-range searchsorted result.
        new_sources[new_sources >= self.cumulative_freq.shape[0]] = (
            self.cumulative_freq.shape[0] - 1
        )
        new_targets[new_targets >= self.cumulative_freq.shape[0]] = (
            self.cumulative_freq.shape[0] - 1
        )

        for i, idx in enumerate(corrupted_sources):
            batch[idx] = ((int(new_sources[i]), batch[idx][0][1]), -1)

        for i, idx in enumerate(corrupted_targets):
            batch[idx] = ((batch[idx][0][0], int(new_targets[i])), -1)

        follows = torch.tensor([edge for edge, _ in batch], dtype=torch.int32)
        labels = torch.tensor([label for _, label in batch], dtype=torch.int32)
        return (follows, labels)