import torch
from torch import nn
import torch.nn.functional as F
import lightning as pyl


class UserEmbedding(nn.Module):
    """Lookup table mapping integer user ids to bfloat16 embedding vectors.

    Args:
        num_embeds: vocabulary size (number of distinct user ids).
        size: embedding dimension.
    """

    def __init__(self, num_embeds: int, size: int):
        super().__init__()
        self.embed = nn.Embedding(
            num_embeddings=num_embeds,
            embedding_dim=size,
            dtype=torch.bfloat16,
        )

    def forward(self, batch):
        """Return the embedding rows for a batch of integer ids."""
        return self.embed(batch)


class FollowEmbedModule(pyl.LightningModule):
    """Learns source/target user embeddings from (source, target) id pairs.

    Each batch is ``(x, y)`` where ``x`` holds id pairs (reshaped to
    ``(batch, 2)``: column 0 = source id, column 1 = target id) and ``y``
    is the per-pair label multiplier applied to the dot product
    (presumably +/-1 for positive/negative edges — confirm against the
    data loader). The loss is the negative mean log-sigmoid of the
    label-scaled dot product.

    Args:
        num_embeds: vocabulary size shared by both embedding tables.
        embed_dim: embedding dimension.
        learning_rate: Adagrad learning rate.
    """

    def __init__(self, num_embeds: int, embed_dim: int, learning_rate: float):
        super().__init__()
        self.save_hyperparameters()
        self.source_embed = UserEmbedding(num_embeds, embed_dim)
        self.target_embed = UserEmbedding(num_embeds, embed_dim)
        self.learning_rate = learning_rate

    def _pair_loss(self, batch) -> torch.Tensor:
        """Shared loss computation for training and validation steps."""
        x, y = batch
        pairs = x.view(x.size(0), -1)
        source = self.source_embed(pairs[:, 0])
        target = self.target_embed(pairs[:, 1])
        # Row-wise dot product. Summing over the embedding dim keeps the
        # batch dimension even when batch size is 1 — the original
        # bmm(...).squeeze() collapsed a size-1 batch to a 0-d scalar.
        dot = (source * target).sum(dim=1)
        # logsigmoid is the numerically stable form of log(sigmoid(z));
        # the original log(F.sigmoid(z)) underflows to -inf for large
        # negative z, and F.sigmoid is deprecated.
        return -F.logsigmoid(dot * y).mean()

    def training_step(self, batch, _):
        loss = self._pair_loss(batch)
        self.log("train_loss", loss, prog_bar=False)
        return loss

    def validation_step(self, batch, _):
        loss = self._pair_loss(batch)
        self.log("val_loss", loss, prog_bar=False)
        return loss

    def on_validation_epoch_end(self):
        # Print the most recent logged metrics, if Lightning has them.
        avg_loss = self.trainer.callback_metrics.get("val_loss")
        if avg_loss is not None:
            print(f"\nValidation Loss: {avg_loss:.4f}\n")
        avg_loss = self.trainer.callback_metrics.get("train_loss")
        if avg_loss is not None:
            print(f"\nTraining Loss: {avg_loss:.4f}\n")

    def configure_optimizers(self):
        """Single Adagrad optimizer over all parameters."""
        return torch.optim.Adagrad(self.parameters(), lr=self.learning_rate)