ML-based recommendation feed for Bluesky posts
at main 56 lines 2.1 kB view raw
import torch
from torch import nn
import torch.nn.functional as F
import lightning as pyl


class UserEmbedding(nn.Module):
    """Lookup table mapping integer user ids to learned embedding vectors.

    Args:
        num_embeds: Number of rows in the table; ids must lie in
            ``[0, num_embeds)``.
        size: Dimension of each embedding vector.
    """

    def __init__(self, num_embeds: int, size: int):
        super().__init__()
        # Parameters are stored in bfloat16 (presumably to save memory on a
        # large user table). NOTE(review): optimizer state such as Adagrad's
        # accumulator is likely created in the same dtype, so very small
        # updates may round away -- confirm this precision is intended.
        self.embed = nn.Embedding(
            num_embeddings=num_embeds, embedding_dim=size, dtype=torch.bfloat16
        )

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Return embeddings for the ids in ``batch``, shape ``(*batch.shape, size)``."""
        return self.embed(batch)


class FollowEmbedModule(pyl.LightningModule):
    """Learn separate "source" and "target" embeddings for user pairs.

    A pair ``(source_id, target_id)`` is scored by the dot product of
    ``source_embed(source_id)`` and ``target_embed(target_id)``. Training
    minimizes the logistic loss ``-mean(log(sigmoid(score * label)))`` with
    Adagrad.

    Args:
        num_embeds: Number of user ids in each embedding table.
        embed_dim: Embedding dimension.
        learning_rate: Adagrad learning rate.
    """

    def __init__(self, num_embeds: int, embed_dim: int, learning_rate: float):
        super().__init__()
        self.save_hyperparameters()
        self.source_embed = UserEmbedding(num_embeds, embed_dim)
        self.target_embed = UserEmbedding(num_embeds, embed_dim)
        self.learning_rate = learning_rate

    def _pair_loss(self, batch) -> torch.Tensor:
        """Return the mean logistic loss for one ``(pairs, labels)`` batch.

        ``pairs`` is flattened to ``(batch, -1)``. Column 0 holds source ids
        and column 1 holds target ids (presumably ``pairs`` is ``(batch, 2)``;
        confirm against the DataLoader). Because the loss is
        ``-log(sigmoid(score * label))``, ``labels`` is assumed to be +1 for
        positive pairs and -1 for negatives (TODO confirm with the data
        pipeline).
        """
        pairs, labels = batch
        # reshape (rather than view) also accepts non-contiguous inputs.
        pairs = pairs.reshape(pairs.size(0), -1)
        source = self.source_embed(pairs[:, 0])
        target = self.target_embed(pairs[:, 1])
        # Row-wise dot product: (B, 1, D) @ (B, D, 1) -> (B, 1, 1) -> (B,).
        # Flatten explicitly; a bare .squeeze() yields a 0-d tensor when B == 1.
        scores = torch.bmm(source.unsqueeze(1), target.unsqueeze(2)).reshape(-1)
        # Match label shape to scores so a (B, 1) label tensor cannot silently
        # broadcast the product to (B, B).
        labels = labels.reshape_as(scores)
        # logsigmoid is the numerically stable form of log(sigmoid(z)); the
        # naive composition underflows to log(0) = -inf for large negative z.
        return -F.logsigmoid(scores * labels).mean()

    def training_step(self, batch, _):
        """Compute, log (as ``train_loss``) and return the training loss."""
        loss = self._pair_loss(batch)
        self.log("train_loss", loss, prog_bar=False)
        return loss

    def validation_step(self, batch, _):
        """Compute, log (as ``val_loss``) and return the validation loss."""
        loss = self._pair_loss(batch)
        self.log("val_loss", loss, prog_bar=False)
        return loss

    def on_validation_epoch_end(self):
        """Print the current validation and training losses, when available.

        ``val_loss`` is aggregated over the validation epoch under Lightning's
        default validation logging. ``train_loss`` is logged per step, so the
        value printed here is presumably the most recent step's loss.
        """
        metrics = self.trainer.callback_metrics
        for label, key in (("Validation", "val_loss"), ("Training", "train_loss")):
            value = metrics.get(key)
            if value is not None:
                # self.print emits only on the global-zero process and routes
                # through the progress bar, so distributed runs print once.
                self.print(f"\n{label} Loss: {float(value):.4f}\n")

    def configure_optimizers(self):
        """Return an Adagrad optimizer over all parameters."""
        return torch.optim.Adagrad(self.parameters(), lr=self.learning_rate)