# ML-based recommendation feed for Bluesky posts — follow-embedding training script.
from argparse import ArgumentParser
from datetime import datetime
import os
from typing import Optional

import lightning as pyl
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
import torch
from torchdata.stateful_dataloader import StatefulDataLoader

from scripts.training.data import FollowDataset
from scripts.training.models import FollowEmbedModule
from scripts.utils import get_logger

logger = get_logger(__name__)

# W&B project name and local artifact locations for this training job.
PROJECT = "goodposts-followers"
SAVE_DIR = "./data/training"
NAME = "following-embedding"


def main(
    run_id: Optional[str],
    model_checkpoint: Optional[str],
    dataset_path: str,
) -> None:
    """Train (or resume training of) the follow-graph embedding model.

    Args:
        run_id: W&B run id to create or resume; when ``None`` a
            timestamped id is generated so each launch gets a unique run.
        model_checkpoint: Optional path to a Lightning checkpoint to
            resume from.
        dataset_path: Directory containing the follow dataset splits.
    """
    if run_id is None:
        # Timestamp suffix keeps auto-generated run ids unique across launches.
        run_id = NAME + str(int(datetime.now().timestamp()))

    logger.info("Loading training dataset")
    train_dataset = FollowDataset(
        dataset_path=dataset_path,
        split="training",
        negative_sample_chance=0.5,
    )

    logger.info("Loading validation dataset")
    # No negative sampling at validation time so val_loss is comparable
    # across runs.
    val_dataset = FollowDataset(
        dataset_path=dataset_path,
        split="validation",
        negative_sample_chance=0.0,
    )

    # StatefulDataLoader allows mid-epoch dataloader state to be
    # checkpointed and restored on resume.
    train_dataloader = StatefulDataLoader(
        train_dataset,
        pin_memory=True,
        collate_fn=train_dataset.collate_rows,
        batch_size=1024,
        num_workers=7,
        shuffle=True,
    )

    val_dataloader = StatefulDataLoader(
        val_dataset,
        pin_memory=True,
        collate_fn=val_dataset.collate_rows,
        batch_size=1024,
        num_workers=7,
    )

    logger.info("Instantiating embedding model")
    wandb_logger = WandbLogger(
        project=PROJECT,
        save_dir=SAVE_DIR,
        name=NAME,
        id=run_id,
        resume="allow",  # reattach to an existing W&B run when the id matches
    )

    # Keep the single best checkpoint by val_loss, plus "last" for resuming.
    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(SAVE_DIR, "checkpoints", run_id),
        filename="{epoch}-{step}",
        save_last=True,
        save_top_k=1,
        monitor="val_loss",
    )

    if model_checkpoint:
        # NOTE(review): trainer.fit(ckpt_path=...) below also restores the
        # model weights (and optimizer/loop state), so this explicit load is
        # redundant but harmless; kept to preserve existing behavior.
        embedding = FollowEmbedModule.load_from_checkpoint(
            checkpoint_path=model_checkpoint
        )
    else:
        embedding = FollowEmbedModule(
            num_embeds=train_dataset.num_users(),
            embed_dim=256,
            learning_rate=1e-1,
        )

    # Allow TF32 matmuls on supporting GPUs: faster at slightly reduced
    # float32 precision.
    torch.set_float32_matmul_precision("medium")
    trainer = pyl.Trainer(
        logger=wandb_logger,
        callbacks=[checkpoint_callback],
        accelerator="gpu",
        devices=1,
        precision="bf16",
        gradient_clip_val=0.5,
        max_epochs=10,
        val_check_interval=0.5,  # validate twice per training epoch
    )
    trainer.fit(
        model=embedding,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
        ckpt_path=model_checkpoint,
    )


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "-r",
        "--run-id",
        dest="run_id",
        help="ID for training run",
    )
    parser.add_argument(
        "-m",
        "--model-checkpoint",
        dest="model_checkpoint",
        help="Path to checkpoint with run data",
    )
    parser.add_argument(
        "-d",
        "--dataset-path",
        dest="dataset_path",
        # Fix: previously optional, so a missing flag passed dataset_path=None
        # into FollowDataset and failed obscurely downstream.
        required=True,
        help="Path to dataset directory",
    )

    args = parser.parse_args()
    main(
        run_id=args.run_id,
        model_checkpoint=args.model_checkpoint,
        dataset_path=args.dataset_path,
    )