8 changes: 4 additions & 4 deletions configs/train/rqvae_train_config.json
@@ -1,6 +1,6 @@
{
  "experiment_name": "rqvae_beauty",
-  "train_steps_num": 6000,
+  "train_steps_num": 1024,
  "dataset": {
    "type": "rqvae",
    "path_to_data_dir": "../data",
@@ -12,7 +12,7 @@
  "dataloader": {
    "train": {
      "type": "torch",
-      "batch_size": 128,
+      "batch_size": 256,
      "batch_processor": {
        "type": "embed"
      },
@@ -36,7 +36,7 @@
    "n_iter": 100,
    "codebook_sizes": [256, 256, 256, 256],
    "should_init_codebooks": true,
-    "should_reinit_unused_clusters": false,
+    "should_reinit_unused_clusters": true,
    "initializer_range": 0.02
  },
  "optimizer": {
@@ -49,7 +49,7 @@
    "scheduler": {
      "type": "step",
      "step_size": 100,
-      "gamma": 0.98
+      "gamma": 0.96
    }
  },
  "loss": {
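The notable switch here is should_reinit_unused_clusters: true, which ties into the posterior-collapse note in review.md. The repo's actual implementation is not shown in this diff; a minimal sketch of what such a reinit step typically does (reinit_unused_clusters is a hypothetical helper):

import torch

def reinit_unused_clusters(codebook, indices, encoded):
    # Mark codebook rows that received no assignments in this batch
    used = torch.unique(indices)
    dead = torch.ones(codebook.size(0), dtype=torch.bool, device=codebook.device)
    dead[used] = False
    if dead.any():
        # Re-seed dead rows with randomly chosen encoder outputs so they
        # start competing for assignments again instead of staying unused
        sample = encoded[torch.randint(0, encoded.size(0), (int(dead.sum()),))]
        codebook = codebook.clone()
        codebook[dead] = sample
    return codebook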
72 changes: 72 additions & 0 deletions configs/train/tiger_train_config.json
@@ -0,0 +1,72 @@
{
  "experiment_name": "tiger_beauty",
  "train_steps_num": 5000,
  "dataset": {
    "type": "rqvae",
    "path_to_data_dir": "../data",
    "name": "Beauty",
    "samplers": {
      "type": "identity"
    }
  },
  "dataloader": {
    "train": {
      "type": "torch",
      "batch_size": 128,
      "batch_processor": {
        "type": "embed"
      },
      "drop_last": false,
      "shuffle": true
    },
    "validation": {
      "type": "torch",
      "batch_size": 256,
      "batch_processor": {
        "type": "embed"
      },
      "drop_last": false,
      "shuffle": false
    }
  },
  "model": {
    "emb_dim": 512,
    "n_tokens": 256,
    "n_codebooks": 4,
    "nhead": 8,
    "num_encoder_layers": 6,
    "num_decoder_layers": 6,
    "dim_feedforward": 2048,
    "dropout": 0.1
  },
  "rqvae_checkpoint_path": "../checkpoints/rqvae_beauty_final_state.pth",
  "rqvae_train_config_path": "../configs/train/rqvae_train_config.json",
  "optimizer": {
    "type": "basic",
    "optimizer": {
      "type": "adam",
      "lr": 1e-4
    },
    "clip_grad_threshold": 5.0,
    "scheduler": {
      "type": "step",
      "step_size": 100,
      "gamma": 0.98
    }
  },
  "loss": {
    "type": "rqvae_loss",
    "beta": 0.25,
    "output_prefix": "loss"
  },
  "callback": {
    "type": "composite",
    "callbacks": [
      {
        "type": "metric",
        "on_step": 1,
        "loss_prefix": "loss"
      }
    ]
  }
}
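A quick sanity check on the schedule (a sketch, assuming the config is loaded from this path): with step_size 100 and gamma 0.98 over 5000 steps, the learning rate decays to 0.98^50, roughly 0.36 of its initial value.

import json

with open('configs/train/tiger_train_config.json') as f:
    cfg = json.load(f)

sched = cfg['optimizer']['scheduler']
n_decays = cfg['train_steps_num'] // sched['step_size']  # 5000 // 100 = 50
final_factor = sched['gamma'] ** n_decays                # 0.98 ** 50 ~ 0.364
print(f"final lr = {cfg['optimizer']['optimizer']['lr'] * final_factor:.2e}")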
268 changes: 268 additions & 0 deletions modeling/main.ipynb

Large diffs are not rendered by default.

21 changes: 12 additions & 9 deletions modeling/models/rqvae.py
@@ -1,18 +1,16 @@
from models.base import TorchModel

import torch
import torch.nn as nn

-import torch
-
from tqdm import tqdm
import faiss

class RqVaeModel(TorchModel, config_name='rqvae'):

    def __init__(
        self,
-        all_data,
+        train_sampler,
        input_dim: int,
        hidden_dim: int,
        n_iter: int,
@@ -43,17 +41,19 @@ def __init__(

        self._init_weights(initializer_range)

-        embeddings = torch.stack([entry['item.embed'] for entry in all_data._dataset])
-
        if self.should_init_codebooks:
+            if train_sampler is None:
+                raise AttributeError("Train sampler is None")
+
+            embeddings = torch.stack([entry['item.embed'] for entry in train_sampler._dataset])
            self.init_codebooks(embeddings)
            print('Codebooks initialized with Faiss Kmeans')
            self.should_init_codebooks = False

    @classmethod
    def create_from_config(cls, config, **kwargs):
        return cls(
-            all_data=kwargs['train_sampler'],
+            train_sampler=kwargs.get('train_sampler'),
            input_dim=config['input_dim'],
            hidden_dim=config['hidden_dim'],
            n_iter=config['n_iter'],
@@ -141,9 +141,12 @@ def train_pass(self, embeddings):

    def eval_pass(self, embeddings):
        ind_lists = []
-        for cb in self.codebooks:
-            dist = torch.cdist(self.encoder(embeddings), cb)
-            ind_lists.append(dist.argmin(dim=-1).cpu().numpy())
+        remainder = self.encoder(embeddings)
+        for codebook in self.codebooks:
+            codebook_indices = self.get_codebook_indices(remainder, codebook)
+            codebook_vectors = codebook[codebook_indices]
+            ind_lists.append(codebook_indices.cpu().numpy())
+            remainder = remainder - codebook_vectors
        return zip(*ind_lists)

    def forward(self, inputs):
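This is the actual posterior-collapse fix. Residual quantization has to quantize the remainder at each stage; the old eval_pass compared the full encoding against every codebook, so all four stages tended to pick the same index. Subtracting the matched codebook vector before the next lookup restores the behavior of train_pass. get_codebook_indices is not shown in this diff; presumably it is a nearest-neighbor lookup along these lines (a sketch, not the repo's code):

import torch

def get_codebook_indices(remainder, codebook):
    # Index of the nearest codebook row for each vector in the batch
    return torch.cdist(remainder, codebook).argmin(dim=-1)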
96 changes: 96 additions & 0 deletions modeling/models/tiger.py
@@ -0,0 +1,96 @@
import json
import torch
import torch.nn as nn
import torch.nn.functional as F

from modeling.utils import DEVICE
from models.base import BaseModel, TorchModel

# TODO finish tiger model
class TigerModel(TorchModel, config_name='tiger'):
    def __init__(
        self,
        rqvae_encoder,
        emb_dim,
        n_tokens,
        n_codebooks,
        nhead,
        num_encoder_layers,
        num_decoder_layers,
        dim_feedforward,
        dropout
    ):
        super().__init__()

        self.rqvae_encoder = rqvae_encoder
        self.emb_dim = emb_dim
        self.n_tokens = n_tokens

        self.position_embeddings = nn.Embedding(n_codebooks, emb_dim)
        self.item_embeddings = nn.Embedding(n_tokens, emb_dim)

        self.transformer = nn.Transformer(
            d_model=emb_dim,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )

        self.proj = nn.Linear(emb_dim, n_tokens)

    @classmethod
    def create_from_config(cls, config, **kwargs):
        rqvae_train_config = json.load(open(config['rqvae_train_config_path']))

        rqvae_model = BaseModel.create_from_config(rqvae_train_config['model']).to(DEVICE)
        rqvae_model.load_state_dict(torch.load(config['rqvae_checkpoint_path'], weights_only=True))
        rqvae_model.eval()

        return cls(
            rqvae_encoder=rqvae_model,
            emb_dim=config['emb_dim'],
            n_tokens=config['n_tokens'],
            n_codebooks=config['n_codebooks'],
            nhead=config['nhead'],
            num_encoder_layers=config['num_encoder_layers'],
            num_decoder_layers=config['num_decoder_layers'],
            dim_feedforward=config['dim_feedforward'],
            dropout=config['dropout']
        )

    def forward(self, user_item_history):
        # Get semantic-ID tokens for the history from the frozen RQ-VAE encoder
        item_sequence = self.rqvae_encoder(user_item_history)

        # Convert the token sequence to embeddings of size emb_dim
        item_embs = self.item_embeddings(item_sequence)

        # Positional embeddings: positions cycle over [0, n_codebooks - 1] within
        # each semantic-ID tuple (the modulo keeps indices inside the table)
        positions = torch.arange(0, item_embs.size(1), device=item_embs.device).unsqueeze(0)
        positions = positions % self.position_embeddings.num_embeddings
        position_embs = self.position_embeddings(positions)

        # Add position embeddings to item embeddings
        embeddings = item_embs + position_embs

        # nn.Transformer expects (seq_len, batch, emb_dim) by default
        embeddings = embeddings.permute(1, 0, 2)

        # Target sequence for the transformer decoder.
        # Should be shifted for training (teacher forcing); for now the
        # input embeddings are reused as the target.
        target = embeddings.clone()

        # Pass through the transformer (embeddings as both source and target)
        transformer_output = self.transformer(embeddings, target)

        # Project the output back to token space (n_tokens values per codebook)
        logits = self.proj(transformer_output)

        # Return raw logits; cross-entropy applies log-softmax internally
        return logits

    def compute_loss(self, logits, target):
        # Cross-entropy over the flattened token predictions
        loss = F.cross_entropy(logits.view(-1, self.n_tokens), target.view(-1))
        return loss
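As the TODO says, the model is unfinished: target = embeddings.clone() feeds the decoder its own uncorrupted input, so it can learn to copy. The usual fix under teacher forcing is to shift the decoder input right and mask future positions; a sketch (shift_right and bos_id are hypothetical names, not part of this PR):

import torch
import torch.nn as nn

def shift_right(tokens, bos_id=0):
    # Decoder input: prepend BOS and drop the last token, so position t
    # only conditions on targets earlier than t
    bos = torch.full((tokens.size(0), 1), bos_id, dtype=tokens.dtype, device=tokens.device)
    return torch.cat([bos, tokens[:, :-1]], dim=1)

# Combined with a causal mask for decoder self-attention:
# tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len)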
Empty file added modeling/rqvae/__init__.py
Empty file.
File renamed without changes.
10 changes: 1 addition & 9 deletions src/rqvae_data.py → modeling/rqvae/rqvae_data.py
@@ -75,18 +75,10 @@ def get_data(cached=True):
        with torch.no_grad():
            df["embeddings"] = df["combined_text"].progress_apply(encode_text)
    else:
-        df = torch.load("../data/df_with_embs.pt", weights_only=False)
+        df = torch.load("../data/Beauty/all_data.pt", weights_only=False)

    return df

-def get_cb_tuples(rqvae, embeddings):
-    ind_lists = []
-    for cb in rqvae.codebooks:
-        dist = torch.cdist(rqvae.encoder(embeddings), cb)
-        ind_lists.append(dist.argmin(dim=-1).cpu().numpy())
-
-    return zip(*ind_lists)
-

def search_similar_items(items_with_tuples, clust2search, max_cnt=5):
    random.shuffle(items_with_tuples)
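get_cb_tuples duplicated the old (buggy) lookup logic, so removing it keeps a single copy of the now-fixed quantization in RqVaeModel. Callers can use the model directly; a hedged usage sketch, assuming a loaded model and an embeddings tensor:

import torch

# Semantic-ID tuples now come from the model itself
with torch.no_grad():
    items_with_tuples = list(rqvae_model.eval_pass(embeddings))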
7 changes: 6 additions & 1 deletion review.md
@@ -2,7 +2,12 @@

## Todos

- posterior collapse (everything seems to fall into a single codebook index) (fixed eval code)
- reinit of unused clusters is a must!
- does the Amazon dataset ignore the rating? so only implicit actions are taken into account?
- TODO which base class to use for the e2e model? (LastPred?)
- TODO backward on mean loss? in `RqVae`
- TODO name for the model (tiger)

## Links

@@ -14,7 +19,7 @@

## Todo

-### Train
+### Train full encoder-decoder

- What should it be trained on? That is, on which data do we run the backward pass?
- train model