"""ML-based recommendation feed for Bluesky posts."""
1from argparse import ArgumentParser
2from datetime import datetime
3import os
4from typing import Optional
5
6import lightning as pyl
7from lightning.pytorch.loggers import WandbLogger
8from lightning.pytorch.callbacks import ModelCheckpoint
9import torch
10from torchdata.stateful_dataloader import StatefulDataLoader
11
12from scripts.training.data import FollowDataset
13from scripts.training.models import FollowEmbedModule
14from scripts.utils import get_logger
15
# Module-level logger, named after this module.
logger = get_logger(__name__)

# Weights & Biases project the training runs are logged under.
PROJECT = "goodposts-followers"
# Local directory for W&B run files and model checkpoints.
SAVE_DIR = "./data/training"
# Base name for runs; also the prefix of auto-generated run ids.
NAME = "following-embedding"
21
22
def main(run_id: Optional[str], model_checkpoint: Optional[str], dataset_path: str):
    """Train the follow-embedding model on the follow-graph dataset.

    Args:
        run_id: W&B run id to create or resume. When ``None``, a fresh id is
            derived from ``NAME`` plus the current unix timestamp.
        model_checkpoint: Optional path to a Lightning checkpoint; when given,
            the module is restored from it and training resumes from it.
        dataset_path: Path to the directory holding the dataset splits.
    """
    if run_id is None:
        # Derive a unique-enough id from the wall clock when none is supplied.
        run_id = NAME + str(int(datetime.now().timestamp()))

    train_dataset = _load_dataset(dataset_path, "training", negative_sample_chance=0.5)
    val_dataset = _load_dataset(dataset_path, "validation", negative_sample_chance=0.0)

    train_dataloader = _make_dataloader(train_dataset, shuffle=True)
    val_dataloader = _make_dataloader(val_dataset, shuffle=False)

    wandb_logger = WandbLogger(
        project=PROJECT,
        save_dir=SAVE_DIR,
        name=NAME,
        id=run_id,
        resume="allow",  # reattach when a run with this id already exists
    )

    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(SAVE_DIR, "checkpoints", run_id),
        filename="{epoch}-{step}",
        save_last=True,
        save_top_k=1,  # keep only the best checkpoint by validation loss
        monitor="val_loss",
    )

    logger.info("Instantiating embedding model")
    if model_checkpoint:
        # NOTE(review): trainer.fit(ckpt_path=...) below restores model and
        # optimizer state from the same file, so this explicit load appears
        # redundant — kept for behavior parity; confirm before removing.
        embedding = FollowEmbedModule.load_from_checkpoint(
            checkpoint_path=model_checkpoint
        )
    else:
        embedding = FollowEmbedModule(
            num_embeds=train_dataset.num_users(),
            embed_dim=256,
            learning_rate=1e-1,
        )

    # Allow TF32 matmuls on supporting GPUs; training itself runs in bf16.
    torch.set_float32_matmul_precision("medium")
    trainer = pyl.Trainer(
        logger=wandb_logger,
        callbacks=[checkpoint_callback],
        accelerator="gpu",
        devices=1,
        precision="bf16",
        gradient_clip_val=0.5,
        max_epochs=10,
        val_check_interval=0.5,  # validate twice per epoch
    )
    trainer.fit(
        model=embedding,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
        ckpt_path=model_checkpoint,  # None starts fresh; a path resumes the run
    )


def _load_dataset(dataset_path: str, split: str, negative_sample_chance: float):
    """Load one dataset split, logging which split is being read."""
    logger.info("Loading %s dataset", split)
    return FollowDataset(
        dataset_path=dataset_path,
        split=split,
        negative_sample_chance=negative_sample_chance,
    )


def _make_dataloader(dataset, shuffle: bool = False):
    """Wrap a dataset in a StatefulDataLoader with the shared loader settings."""
    return StatefulDataLoader(
        dataset,
        pin_memory=True,  # faster host-to-GPU transfer
        collate_fn=dataset.collate_rows,
        batch_size=1024,
        num_workers=7,
        shuffle=shuffle,
    )
103
104
if __name__ == "__main__":
    parser = ArgumentParser(description="Train the follow-embedding model")
    parser.add_argument(
        "-r",
        "--run-id",
        dest="run_id",
        help="ID for training run",
    )
    parser.add_argument(
        "-m",
        "--model-checkpoint",
        dest="model_checkpoint",
        help="Path to checkpoint with run data",
    )
    parser.add_argument(
        "-d",
        "--dataset-path",
        dest="dataset_path",
        # main() annotates this as str and the dataset load needs it; make the
        # CLI fail fast with a clear error instead of passing None through.
        required=True,
        help="Path to dataset directory",
    )

    args = parser.parse_args()
    main(
        run_id=args.run_id,
        model_checkpoint=args.model_checkpoint,
        dataset_path=args.dataset_path,
    )