"""ML-based recommendation feed for Bluesky posts."""
1from functools import reduce
2import json
3import glob
4from multiprocessing import Pool
5import os
6import random
7
8import numpy as np
9import torch
10from torch.utils.data import Dataset
11
12
13from scripts.utils import get_logger
14
15logger = get_logger(__name__)
16
17
def get_frequencies(array: np.ndarray, user_count: int) -> np.ndarray:
    """Count per-user edge frequencies for an (n, 2) array of follow edges.

    Args:
        array: rows of (source_id, target_id) integer pairs.
        user_count: total number of user ids (rows of the result).

    Returns:
        A (user_count, 2) float array where column 0 holds how often each id
        appears as a source (out-degree) and column 1 how often it appears as
        a target (in-degree).
    """
    frequencies = np.zeros((user_count, 2))
    edges = np.asarray(array)
    if edges.size:
        # Vectorized counting instead of a Python-level loop over every edge.
        frequencies[:, 0] = np.bincount(edges[:, 0], minlength=user_count)
        frequencies[:, 1] = np.bincount(edges[:, 1], minlength=user_count)
    return frequencies
24
25
class FollowDataset(Dataset):
    """Dataset of (source_id, target_id) follow edges stored as memory-mapped
    numpy shards, with frequency-weighted negative sampling in ``collate_rows``.

    Expects ``dataset_path`` to contain a ``metadata.json`` that names the
    DID->id map file and, per split, a list of shard files with their dtype
    and shape.
    """

    def __init__(
        self,
        dataset_path: str,
        split: str,
        negative_sample_chance: float,
        num_workers: int = 7,
    ):
        """
        Args:
            dataset_path: directory holding metadata.json and the edge shards.
            split: split name; must be a key of metadata["splits"].
            negative_sample_chance: probability that a batch row is corrupted
                into a negative edge during collation.
            num_workers: worker processes for the one-off frequency count
                (previously hard-coded to 7).

        Raises:
            ValueError: if ``split`` is not present in the metadata.
        """
        with open(os.path.join(dataset_path, "metadata.json"), "r") as in_file:
            metadata = json.load(in_file)

        with open(os.path.join(dataset_path, metadata["did_id_map"]), "r") as in_file:
            self.did_id_map: dict[str, int] = json.load(in_file)

        if split not in metadata["splits"]:
            raise ValueError(f"Could not find split in metadata file: {split}")
        split_files = metadata["splits"][split]

        # (filename, dtype, shard index, row count) per shard.  The shard
        # index is parsed from filenames shaped like "<a>_<b>_<idx>.<ext>";
        # convert to int so it matches the annotation and sorts numerically.
        self.numpy_files: list[tuple[str, str, int, int]] = []
        for file in split_files:
            file_idx = int(file["filename"].split("_")[2].split(".")[0])
            self.numpy_files.append(
                (file["filename"], file["dtype"], file_idx, file["shape"][0])
            )

        # BUGFIX: sort shards by their numeric shard index (x[2]); the
        # original sorted by the dtype string (x[1]), scrambling shard order.
        self.numpy_files.sort(key=lambda x: x[2])
        self.dataframes: list[np.ndarray] = []
        for filename, dtype, _, row_count in self.numpy_files:
            self.dataframes.append(
                np.memmap(
                    os.path.join(dataset_path, filename),
                    dtype=dtype,
                    mode="r",
                    shape=(row_count, 2),
                )
            )

        logger.info("Calculating node frequency...")
        with Pool(num_workers) as p:
            self.cumulative_freq = reduce(
                np.add,
                p.starmap(
                    get_frequencies,
                    [
                        (dataframe, len(self.did_id_map))
                        for dataframe in self.dataframes
                    ],
                ),
            )
        # BUGFIX: the cumulative sums must run DOWN the columns (per-user
        # source / target counts); the original ``[0]`` / ``[1]`` indexing
        # only cumsummed the first two ROWS, leaving the columns that
        # collate_rows searchsorts non-monotonic.
        self.cumulative_freq[:, 0] = self.cumulative_freq[:, 0].cumsum()
        self.cumulative_freq[:, 1] = self.cumulative_freq[:, 1].cumsum()

        self.negative_sample_chance = negative_sample_chance

    def __len__(self) -> int:
        """Total number of edges across all shards."""
        return sum(row_count for _, _, _, row_count in self.numpy_files)

    def num_users(self) -> int:
        """Number of distinct users (size of the DID->id map)."""
        return len(self.did_id_map)

    def _idx_to_row(self, idx: int) -> tuple[int, int]:
        """Map a global edge index to its (source_id, target_id) pair.

        Raises:
            IndexError: if ``idx`` is outside [0, len(self)).
        """
        if idx < 0 or idx >= len(self):
            raise IndexError(f"Invalid index: {idx}")

        # Walk shards until the running row total passes idx; effective_idx
        # is then the offset inside that shard.
        row_index_total = 0
        effective_idx = idx
        i = 0
        for i, (_, _, _, row_count) in enumerate(self.numpy_files):
            row_index_total += row_count
            if idx < row_index_total:
                break
            effective_idx -= row_count

        row = self.dataframes[i][effective_idx]
        return (row[0].item(), row[1].item())

    def __getitem__(self, idx: int) -> tuple[tuple[int, int], int]:
        """
        Return ((source_id, target_id), 1); every stored edge is a positive
        example — negatives are produced later in ``collate_rows``.
        """
        sample = self._idx_to_row(idx)
        return (sample, 1)

    def collate_rows(
        self,
        batch: list[tuple[tuple[int, int], int]],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Collate a batch, corrupting ~``negative_sample_chance`` of the rows
        into negative edges (label -1) by resampling one endpoint weighted by
        that endpoint's frequency in the dataset.

        Returns:
            (follows, labels): an (N, 2) int32 tensor of edges and an (N,)
            int32 tensor of labels (1 positive, -1 corrupted).
        """
        # Decide which rows to corrupt, and on which side.
        corrupted_sources = []
        corrupted_targets = []
        for i in range(len(batch)):
            if random.random() < self.negative_sample_chance:
                if random.random() < 0.5:
                    corrupted_sources.append(i)
                else:
                    corrupted_targets.append(i)

        # Frequency-weighted sampling by inverse CDF: draw uniformly in
        # [0, total_count) and locate the bucket in the cumulative column.
        # BUGFIX: scale by the TOTAL count (last cumsum entry), not by the
        # number of users — the original scaling broke the weighting.
        num_users = self.cumulative_freq.shape[0]
        new_sources = self.cumulative_freq[:, 0].searchsorted(
            np.random.sample(len(corrupted_sources)) * self.cumulative_freq[-1, 0]
        )
        new_targets = self.cumulative_freq[:, 1].searchsorted(
            np.random.sample(len(corrupted_targets)) * self.cumulative_freq[-1, 1]
        )
        # Clamp to a valid user id in case a draw lands past the last bucket.
        new_sources[new_sources >= num_users] = num_users - 1
        new_targets[new_targets >= num_users] = num_users - 1

        for i, idx in enumerate(corrupted_sources):
            batch[idx] = ((new_sources[i], batch[idx][0][1]), -1)

        for i, idx in enumerate(corrupted_targets):
            batch[idx] = ((batch[idx][0][0], new_targets[i]), -1)

        # Build each tensor in one call instead of N IntTensor + concat.
        follows = torch.tensor([follow for follow, _ in batch], dtype=torch.int32)
        labels = torch.tensor([label for _, label in batch], dtype=torch.int32)
        return (follows, labels)