Plum #33

64 changes: 64 additions & 0 deletions scripts/plum/callbacks.py
@@ -0,0 +1,64 @@
import torch

import irec.callbacks as cb
from irec.runners import TrainingRunner, TrainingRunnerContext

class InitCodebooks(cb.TrainingCallback):
    """Initialize each codebook from residuals of randomly sampled training embeddings."""

    def __init__(self, dataloader):
        super().__init__()
        self._dataloader = dataloader

    @torch.no_grad()
    def before_run(self, runner: TrainingRunner):
        for i in range(len(runner.model.codebooks)):
            # Draw one candidate embedding per centroid of codebook i.
            X = next(iter(self._dataloader))['embedding']
            idx = torch.randperm(X.shape[0], device=X.device)[:len(runner.model.codebooks[i])]
            remainder = runner.model.encoder(X[idx])

            # Subtract the quantized contribution of all earlier codebooks so that
            # codebook i is seeded with exactly the residuals it will quantize.
            for j in range(i):
                codebook_indices = runner.model.get_codebook_indices(remainder, runner.model.codebooks[j])
                codebook_vectors = runner.model.codebooks[j][codebook_indices]
                remainder = remainder - codebook_vectors

            runner.model.codebooks[i].data = remainder.detach()


class FixDeadCentroids(cb.TrainingCallback):
    """After each training step, re-seed centroids that received no assignments."""

    def __init__(self, dataloader):
        super().__init__()
        self._dataloader = dataloader

    def after_step(self, runner: TrainingRunner, context: TrainingRunnerContext):
        for i, num_fixed in enumerate(self.fix_dead_centroids(runner)):
            context.metrics[f'num_dead/{i}'] = num_fixed

    @torch.no_grad()
    def fix_dead_centroids(self, runner: TrainingRunner):
        num_fixed = []
        for codebook_idx, codebook in enumerate(runner.model.codebooks):
            centroid_counts = torch.zeros(codebook.shape[0], dtype=torch.long, device=codebook.device)
            random_batch = next(iter(self._dataloader))['embedding']

            # Count how many samples across the whole dataloader map to each centroid.
            for batch in self._dataloader:
                remainder = runner.model.encoder(batch['embedding'])
                for level in range(codebook_idx):
                    ind = runner.model.get_codebook_indices(remainder, runner.model.codebooks[level])
                    remainder = remainder - runner.model.codebooks[level][ind]

                indices = runner.model.get_codebook_indices(remainder, codebook)
                centroid_counts.scatter_add_(0, indices, torch.ones_like(indices))

            dead_mask = (centroid_counts == 0)
            num_dead = int(dead_mask.sum().item())
            num_fixed.append(num_dead)
            if num_dead == 0:
                continue

            # Re-seed dead centroids with residuals from a random batch
            # (assumes the batch holds at least `num_dead` samples).
            remainder = runner.model.encoder(random_batch)
            for level in range(codebook_idx):
                ind = runner.model.get_codebook_indices(remainder, runner.model.codebooks[level])
                remainder = remainder - runner.model.codebooks[level][ind]
            remainder = remainder[torch.randperm(remainder.shape[0], device=codebook.device)][:num_dead]
            codebook[dead_mask] = remainder.detach()

        return num_fixed
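
For context, a minimal sketch of how these callbacks could be wired into a training run. This is an assumption, not part of the PR: the `TrainingRunner` constructor is taken to mirror the `InferenceRunner` call in `infer_default.py` below, and `model`/`train_dataloader` stand in for objects built as in that script.

# Hypothetical wiring -- assumes TrainingRunner accepts the same
# model/dataset/callbacks arguments as InferenceRunner in infer_default.py.
callbacks = [
    InitCodebooks(train_dataloader),     # seed codebooks before the first step
    FixDeadCentroids(train_dataloader),  # re-seed unused centroids after each step
]
runner = TrainingRunner(model=model, dataset=train_dataloader, callbacks=callbacks)
runner.run()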
96 changes: 96 additions & 0 deletions scripts/plum/cooc_data.py
@@ -0,0 +1,96 @@
import json
from collections import defaultdict, Counter

from loguru import logger


class CoocMappingDataset:
def __init__(
self,
train_sampler,
validation_sampler,
test_sampler,
num_items,
max_sequence_length,
cooccur_counter_mapping=None
):
self._train_sampler = train_sampler
self._validation_sampler = validation_sampler
self._test_sampler = test_sampler
self._num_items = num_items
self._max_sequence_length = max_sequence_length
self._cooccur_counter_mapping = cooccur_counter_mapping

@classmethod
def create(cls, inter_json_path, max_sequence_length, sampler_type, window_size):
max_item_id = 0
train_dataset, validation_dataset, test_dataset = [], [], []

with open(inter_json_path, 'r') as f:
user_interactions = json.load(f)

for user_id_str, item_ids in user_interactions.items():
user_id = int(user_id_str)
if item_ids:
max_item_id = max(max_item_id, max(item_ids))
            assert len(item_ids) >= 5, f'Expected a 5-core dataset, but user {user_id} has only {len(item_ids)} items'
train_dataset.append({
'user.ids': [user_id],
'item.ids': item_ids[:-2],
})
validation_dataset.append({
'user.ids': [user_id],
'item.ids': item_ids[:-1],
})
test_dataset.append({
'user.ids': [user_id],
'item.ids': item_ids,
})

cooccur_counter_mapping = cls.build_cooccur_counter_mapping(train_dataset, window_size=window_size)
        logger.debug(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items (max_item_id = {max_item_id})')

train_sampler = train_dataset
validation_sampler = validation_dataset
test_sampler = test_dataset

return cls(
train_sampler=train_sampler,
validation_sampler=validation_sampler,
test_sampler=test_sampler,
num_items=max_item_id + 1,
max_sequence_length=max_sequence_length,
cooccur_counter_mapping=cooccur_counter_mapping
)

    @staticmethod
    def build_cooccur_counter_mapping(train_dataset, window_size):  # TODO: pass timestamps and build the window from them
        # For every item, count how often each other item appears within
        # `window_size` positions of it in the same training sequence.
        cooccur_counts = defaultdict(Counter)
        for session in train_dataset:
            items = session['item.ids']
            for i in range(len(items)):
                item_i = items[i]
                for j in range(max(0, i - window_size), min(len(items), i + window_size + 1)):
                    if i != j:
                        cooccur_counts[item_i][items[j]] += 1
        return cooccur_counts

def get_datasets(self):
return self._train_sampler, self._validation_sampler, self._test_sampler

@property
def num_items(self):
return self._num_items

@property
def max_sequence_length(self):
return self._max_sequence_length

@property
def cooccur_counter_mapping(self):
return self._cooccur_counter_mapping
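
A tiny worked example of the window counting in `build_cooccur_counter_mapping`, to make the neighbor range concrete (the session contents are illustrative):

sessions = [{'item.ids': [1, 2, 3]}]
counts = CoocMappingDataset.build_cooccur_counter_mapping(sessions, window_size=1)
# Item 2 neighbors both 1 and 3; items 1 and 3 each neighbor only 2:
# counts == {1: Counter({2: 1}), 2: Counter({1: 1, 3: 1}), 3: Counter({2: 1})}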
38 changes: 38 additions & 0 deletions scripts/plum/data.py
@@ -0,0 +1,38 @@
import numpy as np
import pickle

from irec.data.base import BaseDataset
from irec.data.transforms import Transform


class EmbeddingDataset(BaseDataset):
    """Loads a pickled dict with parallel 'item_id' and 'embedding' lists."""

    def __init__(self, data_path):
        self.data_path = data_path
        with open(data_path, 'rb') as f:
            self.data = pickle.load(f)

        self.item_ids = np.array(self.data['item_id'], dtype=np.int64)
        self.embeddings = np.array(self.data['embedding'], dtype=np.float32)

    def __getitem__(self, idx):
        embedding = self.embeddings[idx]
        return {
            'item_id': self.item_ids[idx],
            'embedding': embedding,
            'embedding_dim': len(embedding)
        }

    def __len__(self):
        return len(self.embeddings)


class ProcessEmbeddings(Transform):
    """Reshape flat embedding buffers in a batch to (batch_size, embedding_dim)."""

    def __init__(self, embedding_dim, keys):
        self.embedding_dim = embedding_dim
        self.keys = keys

    def __call__(self, batch):
        for key in self.keys:
            batch[key] = batch[key].reshape(-1, self.embedding_dim)
        return batch
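
For reference, a sketch of the pickle layout `EmbeddingDataset` expects, inferred from the loading code above; the file name and values are placeholders, not part of the PR:

import pickle
import numpy as np

# Parallel lists: one item id per embedding row.
payload = {
    'item_id': [0, 1, 2],
    'embedding': np.random.randn(3, 4096).astype(np.float32),
}
with open('content_embeddings.pkl', 'wb') as f:
    pickle.dump(payload, f)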
152 changes: 152 additions & 0 deletions scripts/plum/infer_default.py
@@ -0,0 +1,152 @@
import json
import os
from collections import defaultdict

import numpy as np
import torch
from loguru import logger

import irec.callbacks as cb
from irec.data.dataloader import DataLoader
from irec.data.transforms import Collate, ToTorch, ToDevice
from irec.runners import InferenceRunner

from irec.utils import fix_random_seed

from data import EmbeddingDataset, ProcessEmbeddings
from models import PlumRQVAE
from transforms import AddWeightedCooccurrenceEmbeddings
from cooc_data import CoocMappingDataset

SEED_VALUE = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

BATCH_SIZE = 1024

INPUT_DIM = 4096
HIDDEN_DIM = 32
CODEBOOK_SIZE = 256
NUM_CODEBOOKS = 3

BETA = 0.25
MODEL_PATH = '/home/jovyan/IRec/checkpoints/test_plum_rqvae_beauty_ws_2_best_0.0054.pth'

WINDOW_SIZE = 2

EXPERIMENT_NAME = f'test_plum_rqvae_beauty_ws_{WINDOW_SIZE}'

IREC_PATH = '/home/jovyan/IRec/'


def main():
fix_random_seed(SEED_VALUE)

data = CoocMappingDataset.create(
inter_json_path=os.path.join(IREC_PATH, 'data/Beauty/inter_new.json'),
max_sequence_length=20,
sampler_type='sasrec',
window_size=WINDOW_SIZE
)

dataset = EmbeddingDataset(
data_path='/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl'
)

item_id_to_embedding = {}
all_item_ids = []
for idx in range(len(dataset)):
sample = dataset[idx]
item_id = int(sample['item_id'])
item_id_to_embedding[item_id] = torch.tensor(sample['embedding'])
all_item_ids.append(item_id)

add_cooc_transform = AddWeightedCooccurrenceEmbeddings(
data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids)

dataloader = DataLoader(
dataset,
batch_size=BATCH_SIZE,
shuffle=False,
drop_last=False,
).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).map(add_cooc_transform)
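    # The .map(...) chain above applies batch transforms in order: collate raw
    # samples, convert them to torch tensors, move them to DEVICE, reshape the
    # embeddings to (batch, INPUT_DIM), then append weighted co-occurrence embeddings.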

model = PlumRQVAE(
input_dim=INPUT_DIM,
num_codebooks=NUM_CODEBOOKS,
codebook_size=CODEBOOK_SIZE,
embedding_dim=HIDDEN_DIM,
beta=BETA,
quant_loss_weight=1.0,
contrastive_loss_weight=1.0,
temperature=1.0
).to(DEVICE)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

logger.debug(f'Overall parameters: {total_params:,}')
logger.debug(f'Trainable parameters: {trainable_params:,}')

callbacks = [
cb.LoadModel(MODEL_PATH),

cb.BatchMetrics(metrics=lambda model_outputs, _: {
'loss': model_outputs['loss'],
'recon_loss': model_outputs['recon_loss'],
'rqvae_loss': model_outputs['rqvae_loss'],
'con_loss': model_outputs['con_loss']
}, name='valid'),

cb.MetricAccumulator(
accumulators={
'valid/loss': cb.MeanAccumulator(),
'valid/recon_loss': cb.MeanAccumulator(),
'valid/rqvae_loss': cb.MeanAccumulator(),
'valid/con_loss': cb.MeanAccumulator(),
},
),

cb.Logger().every_num_steps(len(dataloader)),

cb.InferenceSaver(
metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
save_path=f'/home/jovyan/IRec/results/{EXPERIMENT_NAME}_clusters.json',
format='json'
)
]

    logger.debug('Everything is ready for inference!')

runner = InferenceRunner(
model=model,
dataset=dataloader,
callbacks=callbacks,
)
    runner.run()

    # Post-process: load the per-item cluster assignments saved above.
    with open(f'/home/jovyan/IRec/results/{EXPERIMENT_NAME}_clusters.json', 'r') as f:
mappings = json.load(f)

    # Group items by their semantic ID (the tuple of cluster indices); items that
    # share all cluster indices collide and need an extra disambiguation code.
    inter = {}
    sem_2_ids = defaultdict(list)
    for mapping in mappings:
        item_id = mapping['item_id']
        clusters = mapping['clusters']
        inter[int(item_id)] = clusters
        sem_2_ids[tuple(clusters)].append(int(item_id))

    for semantics, items in sem_2_ids.items():
        assert len(items) <= CODEBOOK_SIZE, str(len(items))
        # Append a unique extra code to each colliding item, then offset every
        # level into its own disjoint id range of size CODEBOOK_SIZE.
        collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist()
        for item_id, collision_solver in zip(items, collision_solvers):
            inter[item_id].append(collision_solver)
            for i in range(len(inter[item_id])):
                inter[item_id][i] += CODEBOOK_SIZE * i

    with open(os.path.join(IREC_PATH, 'results', f'{EXPERIMENT_NAME}_clusters_collisionless.json'), 'w') as f:
        json.dump(inter, f, indent=2)


if __name__ == '__main__':
main()
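
A small check of the level-offset arithmetic in the collision-resolution step above, with illustrative values:

# Three cluster indices plus one collision solver, CODEBOOK_SIZE = 256:
codes = [12, 7, 200, 5]
semantic_id = [c + 256 * i for i, c in enumerate(codes)]
assert semantic_id == [12, 263, 712, 773]  # each level gets a disjoint 256-id range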