---
title: "Training Example: Matrix Factorization"
keywords: fastai
sidebar: home_sidebar
nb_path: "examples/matrix_factorization.ipynb"
---
!pip install pytorch-lightning recsys_slates_dataset -q
import sys
sys.path.append("../")
import os
os.getcwd()
import torch
import pytorch_lightning as pl
# Define parameters for this run in a dictionary
param = {
'dim' : 9,
'batch_size' : int(1e5),
'effective_batch_size' : int(2e6),
    'sample_candidate_items' : 4, # If set, the dataloader adds an extra field "allitem" to each batch: randomly sampled items used as negative feedback
'num_epochs': 100,
'overfit_batches' : False,
'name' : 'MatrixFactorization-CategoricalLoss'
}
from recsys_slates_dataset import lightning_helper
dm = lightning_helper.SlateDataModule(num_workers=0, **param)
dm.setup()
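Before defining the model, it can be useful to inspect a single batch from the datamodule. The snippet below is only a sanity check and assumes the batch is a dictionary of tensors; the field names used later in this notebook (such as `userId`, `click` and `allitem`) should appear among the keys:
# Optional sanity check: print the keys, shapes and dtypes of one training batch.
batch_example = next(iter(dm.train_dataloader()))
for key, val in batch_example.items():
    print(key, tuple(val.shape), val.dtype)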
We implement a simple Matrix Factorization model with a categorical likelihood (instead of the traditional Gaussian loss). Given a slate $S$ shown to user $u$, the likelihood of clicking a specific item $c \in S$ is:

$$ \frac{e^{z_u * v_c}}{\sum_{i \in S} e^{z_u * v_i}} $$

where
$z_u$ is a parameter vector for user $u$,
$v_i$ is a parameter vector for item $i$,
and $x*y$ is the inner product between $x$ and $y$.
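This is simply a softmax over the user-item inner products within the slate. As a small standalone illustration (the tensors below are made up for this sketch and are not used by the model), the click probabilities for one user and one slate can be computed as:
# Numeric sketch of the categorical click likelihood above (illustration only):
z_u = torch.randn(9)         # user parameter vector, dim=9 as in `param`
V_S = torch.randn(5, 9)      # parameter vectors for the 5 items in slate S
scores = V_S @ z_u           # inner products z_u * v_i for every item i in S
click_prob = torch.softmax(scores, dim=0)  # probability of clicking each item in S
print(click_prob, click_prob.sum())        # probabilities sum to 1 over the slate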
from typing import *
import torch.nn as nn
import torch.distributions as dist
class SimilarityDot(pl.LightningModule):
def __init__(self):
super().__init__()
def forward(self, Z, V):
return (Z * V).sum(-1)
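# Quick shape check (illustration only, not used by the model): SimilarityDot relies on
# broadcasting, so user vectors of shape [batch, 1, 1, dim] and item vectors of shape
# [batch, interaction, slate, dim] give one score per item, i.e. [batch, interaction, slate].
_z = torch.randn(2, 1, 1, 9)
_v = torch.randn(2, 20, 5, 9)
print(SimilarityDot()(_z, _v).shape)  # torch.Size([2, 20, 5])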
def dict_chunker(dict_of_seqs, size):
"Iterates over the first dimension of a dict of sequences"
    length = len(next(iter(dict_of_seqs.values())))  # length along the first dimension (shared by all sequences)
return ( {key : seq[pos:pos + size] for key, seq in dict_of_seqs.items()} for pos in range(0, length, size))
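# Small usage example of dict_chunker (toy tensors, illustration only): the dict is split
# along the first dimension into chunks of at most `size` rows, keeping all keys aligned.
_toy = {'a': torch.arange(5), 'b': torch.arange(10).reshape(5, 2)}
for _chunk in dict_chunker(_toy, size=2):
    print({key: val.shape for key, val in _chunk.items()})  # chunks of 2, 2 and 1 rows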
class MatrixFactorization(pl.LightningModule):
def __init__(
self,
num_users,
num_items,
dim=2,
lr_start=1e-3,
optim="adam",
*args, **kwargs):
super().__init__()
self.save_hyperparameters()
self.score_func = SimilarityDot()
# Initialize parameters
torch.manual_seed(1)
self.itemvec = nn.Embedding(self.hparams.num_items, self.hparams.dim)
nn.init.uniform_(self.itemvec.weight, a=-0.05, b=0.05)
self.uservec = nn.Embedding(self.hparams.num_users, self.hparams.dim)
nn.init.uniform_(self.uservec.weight, a=-0.05, b=0.05)
def loglik(self, batch):
# Get user and item parameters:
# Dimensions of tensors: [user/batch, interaction/step, item/slate, dim]
zetas = self.uservec(batch['userId']).unsqueeze(1).unsqueeze(1)
# Concatenate positive and negative items (first element is the positive one)
items = torch.cat((batch['click'].unsqueeze(-1), batch['allitem']), dim=-1)
# find the parameters vector corresponding to each item in the batch:
itemvecs_batch = self.itemvec(items)
# Compute the similarity (dot product) between the users and items for all items in all slates:
scores = self.score_func(zetas, itemvecs_batch)
# Set effectively zero probability for special Ids (0 is pad and 2 is UNK).
# These scores are log, so -100 is effectively 0: exp(-100)=4e-44
#scores[(batch['slate'] == 2) | (batch['slate'] == 0)] = -100
# Flatten all Tensors to [user, slatelength] (This simplifies the computation of the loss)
# We flatten by using a masking tensor that also selects the relevant data.
# Mask out data that are in a different phase AND datapoints that did not result in any clicks:
mask = (batch['phase_mask']*(batch['click']>=3)).bool()
scores_flat = scores[mask]
# Compute the "allitem" log likelihood of the observations:
# We use a categorical loss where all our positive signals are in the first dimension:
        click_idx_flat = torch.zeros(scores_flat.size(0), device=self.device, dtype=torch.long)  # the clicked item is placed first, so its index is always 0
loglik = dist.Categorical(logits=scores_flat).log_prob(click_idx_flat).sum()
return loglik
# TRAINING FUNCTIONS
def step(self, batch, batch_idx, phase):
stats = {}
stats['loglik'] = self.loglik(batch)
        # Since we are doing stochastic gradient descent,
# multiply with the data factor to get estimate of the loss for the whole dataset:
data_factor = (self.hparams.num_users / batch['click'].size(0))
stats['loss'] = -(stats['loglik']*data_factor)
# Report loss and loglik:
with torch.no_grad():
for key, val in stats.items():
self.log(f"{phase}/{key}", val, on_step=False, on_epoch=True, sync_dist= (phase!="train"))
return stats['loss']
@torch.no_grad()
def validation_epoch_end(self, outputs):
# Report mean absolute values of parameters:
for key, par in self.named_parameters():
self.log(f"param/{key}-L1", par.data.abs().mean(), on_step=False, sync_dist=True)
def training_step(self, batch, batch_idx):
return self.step(batch, batch_idx, phase="train")
def validation_step(self, batch, batch_idx):
return self.step(batch, batch_idx, phase="valid")
def configure_optimizers(self):
pars = self.parameters()
optimizer = torch.optim.AdamW(pars, lr=self.hparams.lr_start)
return optimizer
# PREDICT FUNCTIONS BELOW HERE
@torch.jit.export
def forward_items(
self,
batch : Dict[str, torch.Tensor],
targets: Optional[torch.Tensor]=None,
t_rec: int=-1):
"""
Given a batch of data, estimate scores for all items in target.
If target is None, use all items.
NB: This function is very memory intensive. Need small batch sizes.
"""
if targets is None:
targets = torch.arange(self.hparams.num_items,device=self.device)
target_vecs = self.itemvec(targets).unsqueeze(-2)
zetas = self.uservec(batch['userId']).unsqueeze(1).unsqueeze(1)
scores = self.score_func(zetas,target_vecs).squeeze(-1)
return scores
@torch.no_grad()
def recommend_batch(self, batch: Dict[str, torch.Tensor], num_rec=1, chunksize=3, t_rec=-1, **kwargs):
topk = torch.zeros((len(batch['click']), num_rec), device=self.device)
i = 0
for batch_chunk in dict_chunker(batch, chunksize):
            pred = self.forward_items(batch=batch_chunk, t_rec=t_rec)
            # Exclude the reserved special item ids (e.g. 0=pad, 2=UNK) from the ranking, then shift the indices back:
            vals, topk_chunk = pred[:,3:].topk(num_rec, dim=1)
            topk_chunk = 3 + topk_chunk
topk[i:(i + len(pred))] = topk_chunk
i += len(pred)
return topk
model = MatrixFactorization(num_items = dm.num_items, num_users = dm.num_users, **param)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
monitor="valid/loglik",
mode="max"
)
cb = [
checkpoint_callback,
# pl.callbacks.LearningRateMonitor(),
lightning_helper.CallbackPrintRecommendedCategory(dm)
]
trainer = pl.Trainer(
overfit_batches=param.get('overfit_batches', False), # for fast dry-runs
callbacks=cb,
    logger = pl.loggers.TensorBoardLogger("logs", name=param['name']),
max_epochs=param['num_epochs'],
gpus= -1 if torch.cuda.is_available() else 0,
accumulate_grad_batches= int(param['effective_batch_size']/param['batch_size']),
weights_summary='full',
)
#%% TRAIN
trainer.fit(model, dm)
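After training, `checkpoint_callback` keeps the weights with the highest validation log-likelihood. As a possible follow-up, sketched here under the assumption that the variables above are still in scope, that checkpoint can be reloaded and used to produce top-k recommendations for a validation batch via `recommend_batch`:
# Reload the best checkpoint (selected on valid/loglik) and recommend items for one validation batch.
best_model = MatrixFactorization.load_from_checkpoint(checkpoint_callback.best_model_path)
best_model.eval()
val_batch = next(iter(dm.val_dataloader()))
topk_items = best_model.recommend_batch(val_batch, num_rec=5)
topk_items[:5]  # recommended item ids for the first five users in the batch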