from functools import reduce
import json
import glob
from multiprocessing import Pool
import os
import random

import numpy as np
import torch
from torch.utils.data import Dataset

from scripts.utils import get_logger

logger = get_logger(__name__)


def get_frequencies(array: np.ndarray, user_count: int) -> np.ndarray:
    """Count per-user appearances as follow source and follow target.

    Args:
        array: (N, 2) integer array of (source_id, target_id) edges.
        user_count: number of distinct user ids (row count of the result).

    Returns:
        (user_count, 2) array where column 0 counts appearances as a
        source and column 1 counts appearances as a target.
    """
    frequencies = np.zeros((user_count, 2))
    for source_id, target_id in array:
        frequencies[source_id][0] += 1
        frequencies[target_id][1] += 1
    return frequencies


class FollowDataset(Dataset):
    """Memory-mapped dataset of follow edges with frequency-weighted
    negative sampling performed at collate time.

    Expects ``dataset_path`` to contain a ``metadata.json`` describing a
    DID->id map and, per split, a list of numpy edge files with their
    dtype and shape.
    """

    def __init__(
        self,
        dataset_path: str,
        split: str,
        negative_sample_chance: float,
    ):
        """Load split metadata, memory-map the edge files, and precompute
        cumulative source/target frequencies for weighted corruption.

        Args:
            dataset_path: directory holding metadata.json and edge files.
            split: split name; must exist in metadata["splits"].
            negative_sample_chance: probability that a row in a batch is
                corrupted into a negative edge during collation.

        Raises:
            ValueError: if ``split`` is not present in the metadata.
        """
        with open(os.path.join(dataset_path, "metadata.json"), "r") as in_file:
            metadata = json.load(in_file)
        with open(os.path.join(dataset_path, metadata["did_id_map"]), "r") as in_file:
            self.did_id_map: dict[str, int] = json.load(in_file)

        if split not in metadata["splits"]:
            raise ValueError(f"Could not find split in metadata file: {split}")
        split_files = metadata["splits"][split]

        self.numpy_files: list[tuple[str, str, int, int]] = []
        for file in split_files:
            # Assumes filenames shaped like "<a>_<b>_<idx>.<ext>"
            # -- TODO confirm against the dataset writer.
            file_idx = int(file["filename"].split("_")[2].split(".")[0])
            self.numpy_files.append(
                (file["filename"], file["dtype"], file_idx, file["shape"][0])
            )
        # BUGFIX: sort by the extracted file index (tuple slot 2), not the
        # dtype string (slot 1), so global indices map to files in order.
        self.numpy_files.sort(key=lambda x: x[2])

        self.dataframes: list[np.ndarray] = []
        for filename, dtype, _, row_count in self.numpy_files:
            self.dataframes.append(
                np.memmap(
                    os.path.join(dataset_path, filename),
                    dtype=dtype,
                    mode="r",
                    shape=(row_count, 2),
                )
            )

        logger.info("Calculating node frequency...")
        with Pool(7) as p:
            self.cumulative_freq = reduce(
                np.add,
                p.starmap(
                    get_frequencies,
                    [
                        (dataframe, len(self.did_id_map))
                        for dataframe in self.dataframes
                    ],
                ),
            )
        # BUGFIX: build cumulative distributions down each COLUMN
        # (per-user source/target counts); the original cumsum'd rows 0
        # and 1, which collate_rows' column-wise searchsorted never used.
        self.cumulative_freq[:, 0] = self.cumulative_freq[:, 0].cumsum()
        self.cumulative_freq[:, 1] = self.cumulative_freq[:, 1].cumsum()

        self.negative_sample_chance = negative_sample_chance

    def __len__(self) -> int:
        """Total number of edges across all files in this split."""
        return sum(row_count for _, _, _, row_count in self.numpy_files)

    def num_users(self) -> int:
        """Number of distinct user ids in the DID map."""
        return len(self.did_id_map)

    def _idx_to_row(self, idx: int) -> tuple[int, int]:
        """Map a global edge index to a (source_id, target_id) pair.

        Raises:
            IndexError: if ``idx`` is outside [0, len(self)).
        """
        if idx < 0 or idx >= len(self):
            raise IndexError(f"Invalid index: {idx}")
        # Walk the files in order until the cumulative row count passes
        # idx; effective_idx is then the offset within that file.
        row_index_total = 0
        effective_idx = idx
        i = 0
        for i, (_, _, _, row_count) in enumerate(self.numpy_files):
            row_index_total += row_count
            if idx < row_index_total:
                break
            effective_idx -= row_count
        row = self.dataframes[i][effective_idx]
        return (row[0].item(), row[1].item())

    def __getitem__(self, idx: int) -> tuple[tuple[int, int], int]:
        """Return ((source_id, target_id), 1); corruption into negative
        edges happens later in :meth:`collate_rows`, not here."""
        sample = self._idx_to_row(idx)
        return (sample, 1)

    def collate_rows(
        self,
        batch: list[tuple[tuple[int, int], int]],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Collate a batch, corrupting a random subset of rows into
        negative edges (label -1) by resampling either endpoint weighted
        by that id's prevalence in the dataset.

        Returns:
            (follows, labels): an (B, 2) int tensor of edges and a (B,)
            int tensor of +1/-1 labels.
        """
        # Pick which rows to corrupt, and for each whether the source or
        # target endpoint is replaced (50/50).
        corrupted_sources = []
        corrupted_targets = []
        for i in range(len(batch)):
            if random.random() < self.negative_sample_chance:
                if random.random() < 0.5:
                    corrupted_sources.append(i)
                else:
                    corrupted_targets.append(i)

        # BUGFIX: scale uniform samples by each column's TOTAL count (the
        # last cumulative entry), not by the user count, so searchsorted
        # performs frequency-weighted inverse-CDF sampling.
        new_sources = self.cumulative_freq[:, 0].searchsorted(
            np.random.sample(len(corrupted_sources)) * self.cumulative_freq[-1, 0]
        )
        new_targets = self.cumulative_freq[:, 1].searchsorted(
            np.random.sample(len(corrupted_targets)) * self.cumulative_freq[-1, 1]
        )
        # Clamp as a safety net against edge effects at the distribution's
        # right boundary.
        new_sources[new_sources >= self.cumulative_freq.shape[0]] = (
            self.cumulative_freq.shape[0] - 1
        )
        new_targets[new_targets >= self.cumulative_freq.shape[0]] = (
            self.cumulative_freq.shape[0] - 1
        )

        for i, idx in enumerate(corrupted_sources):
            batch[idx] = ((new_sources[i], batch[idx][0][1]), -1)
        for i, idx in enumerate(corrupted_targets):
            batch[idx] = ((batch[idx][0][0], new_targets[i]), -1)

        follows = torch.concat(tuple(torch.IntTensor([follow]) for follow, _ in batch))
        labels = torch.concat(tuple(torch.IntTensor([label]) for _, label in batch))
        return (follows, labels)