ML-based recommendation feed for Bluesky posts
1import torch
2from torch import nn
3import torch.nn.functional as F
4import lightning as pyl
5
6
class UserEmbedding(nn.Module):
    """Lookup table that maps integer indices to dense embedding vectors.

    Args:
        num_embeds: Number of rows in the embedding table.
        size: Dimension of each embedding vector.
        dtype: Parameter dtype of the table. Defaults to ``torch.bfloat16``.
    """

    def __init__(
        self,
        num_embeds: int,
        size: int,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        # The attribute name ``embed`` is part of the state-dict key, so it
        # must stay stable for existing checkpoints to load.
        self.embed = nn.Embedding(
            num_embeddings=num_embeds, embedding_dim=size, dtype=dtype
        )

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Return the table rows selected by the integer index tensor ``batch``."""
        return self.embed(batch)
16
17
class FollowEmbedModule(pyl.LightningModule):
    """Lightning module that trains separate source and target embedding tables.

    Each example is a pair of integer indices plus a label ``y``. The first
    index is looked up in the source table, the second in the target table,
    and the score is the dot product of the two vectors. The loss is the mean
    of ``-log(sigmoid(score * y))`` over the batch.

    NOTE(review): the loss form suggests ``y`` is +1 / -1 per pair -- confirm
    against the data pipeline.

    Args:
        num_embeds: Number of rows in each of the two embedding tables.
        embed_dim: Dimension of each embedding vector.
        learning_rate: Learning rate passed to the Adagrad optimizer.
    """

    def __init__(self, num_embeds: int, embed_dim: int, learning_rate: float):
        super().__init__()
        self.save_hyperparameters()
        self.source_embed = UserEmbedding(num_embeds, embed_dim)
        self.target_embed = UserEmbedding(num_embeds, embed_dim)
        self.learning_rate = learning_rate

    def _pair_loss(self, batch) -> torch.Tensor:
        """Compute the mean negative log-sigmoid loss for one ``(x, y)`` batch."""
        x, y = batch
        # Flatten everything after the batch dimension; columns 0 and 1 hold
        # the source and target indices. ``reshape`` (unlike ``view``) also
        # accepts non-contiguous input.
        pairs = x.reshape(x.size(0), -1)
        source = self.source_embed(pairs[:, 0])
        target = self.target_embed(pairs[:, 1])
        # Row-wise dot product, done in float32 because the embedding tables
        # may be low precision (UserEmbedding defaults to bfloat16). Reducing
        # over the last dim always yields shape (batch,), even for batch
        # size 1, which a bare ``squeeze()`` does not guarantee.
        score = (source.float() * target.float()).sum(dim=-1)
        # Flatten labels so a (batch, 1) label tensor cannot silently
        # broadcast against (batch,) scores into a (batch, batch) matrix.
        labels = y.reshape(-1)
        # logsigmoid is the numerically stable form of log(sigmoid(z)); the
        # naive form returns -inf once sigmoid underflows to 0.
        return -F.logsigmoid(score * labels).mean()

    def training_step(self, batch, _):
        """Compute, log (as ``train_loss``) and return the training loss."""
        loss = self._pair_loss(batch)
        self.log("train_loss", loss, prog_bar=False)
        return loss

    def validation_step(self, batch, _):
        """Compute, log (as ``val_loss``) and return the validation loss."""
        loss = self._pair_loss(batch)
        self.log("val_loss", loss, prog_bar=False)
        return loss

    def on_validation_epoch_end(self):
        """Print the latest validation and training losses, when available."""
        metrics = self.trainer.callback_metrics
        for label, key in (("Validation", "val_loss"), ("Training", "train_loss")):
            value = metrics.get(key)
            if value is not None:
                print(f"\n{label} Loss: {value:.4f}\n")

    def configure_optimizers(self):
        """Return an Adagrad optimizer over all parameters of this module."""
        return torch.optim.Adagrad(self.parameters(), lr=self.learning_rate)