"""Train the follower-graph embedding model with PyTorch Lightning and W&B."""

from argparse import ArgumentParser
from datetime import datetime
import os
from typing import Optional

import lightning as pyl
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
import torch
from torchdata.stateful_dataloader import StatefulDataLoader

from scripts.training.data import FollowDataset
from scripts.training.models import FollowEmbedModule
from scripts.utils import get_logger

logger = get_logger(__name__)

PROJECT = "goodposts-followers"
SAVE_DIR = "./data/training"
NAME = "following-embedding"


def main(
    run_id: Optional[str],
    model_checkpoint: Optional[str],
    dataset_path: str,
) -> None:
    """Train (or resume training of) the follow-embedding model.

    Args:
        run_id: W&B run id to resume. When None, a fresh id is derived from
            NAME plus the current unix timestamp.
        model_checkpoint: Optional path to a Lightning checkpoint. When given,
            the module is rebuilt from the checkpoint's saved hyperparameters
            and trainer state is restored via ``fit(ckpt_path=...)``.
        dataset_path: Directory containing the follow dataset splits.
    """
    if run_id is None:
        # Unix timestamp keeps ids unique across repeated runs on one machine.
        run_id = NAME + str(int(datetime.now().timestamp()))

    logger.info("Loading training dataset")
    train_dataset = FollowDataset(
        dataset_path=dataset_path,
        split="training",
        negative_sample_chance=0.5,
    )
    logger.info("Loading validation dataset")
    # Validation uses no negative sampling so the metric is deterministic.
    val_dataset = FollowDataset(
        dataset_path=dataset_path,
        split="validation",
        negative_sample_chance=0.0,
    )

    # StatefulDataLoader lets loader position be checkpointed for mid-epoch
    # resume (unlike the stock torch DataLoader).
    train_dataloader = StatefulDataLoader(
        train_dataset,
        pin_memory=True,
        collate_fn=train_dataset.collate_rows,
        batch_size=1024,
        num_workers=7,
        shuffle=True,
    )
    val_dataloader = StatefulDataLoader(
        val_dataset,
        pin_memory=True,
        collate_fn=val_dataset.collate_rows,
        batch_size=1024,
        num_workers=7,
    )

    logger.info("Instantiating embedding model")
    # resume="allow" reattaches to an existing W&B run when run_id matches.
    wandb_logger = WandbLogger(
        project=PROJECT,
        save_dir=SAVE_DIR,
        name=NAME,
        id=run_id,
        resume="allow",
    )
    # Keep the single best checkpoint by validation loss, plus "last" so that
    # interrupted runs can resume from the most recent state.
    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(SAVE_DIR, "checkpoints", run_id),
        filename="{epoch}-{step}",
        save_last=True,
        save_top_k=1,
        monitor="val_loss",
    )

    if model_checkpoint:
        # Rebuilds the module with the hyperparameters saved in the
        # checkpoint. NOTE(review): fit(ckpt_path=model_checkpoint) below also
        # restores the weights, so this load is only needed for the saved
        # hyperparameters — confirm FollowEmbedModule saves them.
        embedding = FollowEmbedModule.load_from_checkpoint(
            checkpoint_path=model_checkpoint
        )
    else:
        embedding = FollowEmbedModule(
            num_embeds=train_dataset.num_users(),
            embed_dim=256,
            learning_rate=1e-1,
        )

    # Allow TF32 matmuls on Ampere+ GPUs; acceptable precision trade-off here.
    torch.set_float32_matmul_precision("medium")
    trainer = pyl.Trainer(
        logger=wandb_logger,
        callbacks=[checkpoint_callback],
        accelerator="gpu",
        devices=1,
        precision="bf16",
        gradient_clip_val=0.5,
        max_epochs=10,
        # Validate twice per epoch.
        val_check_interval=0.5,
    )
    trainer.fit(
        model=embedding,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
        # None when starting fresh; otherwise restores optimizer/trainer
        # state (and weights) from the checkpoint.
        ckpt_path=model_checkpoint,
    )


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "-r",
        "--run-id",
        dest="run_id",
        help="ID for training run",
    )
    parser.add_argument(
        "-m",
        "--model-checkpoint",
        dest="model_checkpoint",
        help="Path to checkpoint with run data",
    )
    parser.add_argument(
        "-d",
        "--dataset-path",
        dest="dataset_path",
        # Required: without it FollowDataset would receive None and fail with
        # an opaque error deep inside dataset loading.
        required=True,
        help="Path to dataset directory",
    )
    args = parser.parse_args()
    main(
        run_id=args.run_id,
        model_checkpoint=args.model_checkpoint,
        dataset_path=args.dataset_path,
    )