"""Reconstructed payload of a git patch that was collapsed onto a few physical
lines.  The patch adds seven files under ``scripts/plum/``; each section below
is delimited by a ``# === file: ... ===`` banner so the original file
boundaries remain visible.

Concrete fixes applied during reconstruction (see per-site comments):
* PlumRQVAE codebooks are registered as ``nn.Parameter`` and initialized
  (the original appended raw, *uninitialized* FloatTensors, so they held
  garbage memory and were never seen by the optimizer).
* ``Adam(fused=True)`` is enabled only on CUDA (it raises on CPU).
* Dead-centroid replacement no longer shape-mismatches when dead centroids
  outnumber the random batch; codebook seeding fails loudly instead of
  silently shrinking the codebook.
* ``contrastive_loss`` rejects a non-positive temperature instead of
  silently producing inf/NaN.
* Duplicate/unused imports removed, function-local imports hoisted,
  leftover debug printing removed, non-English comments translated.
"""

# === file: scripts/plum/callbacks.py ===
import torch

import irec.callbacks as cb
from irec.runners import TrainingRunner, TrainingRunnerContext


class InitCodebooks(cb.TrainingCallback):
    """Data-driven codebook initialization for an RQ-VAE.

    Codebook ``i`` is seeded with encoder residuals that remain after
    quantizing a random sample through codebooks ``0..i-1``, so each level
    starts out modeling what the previous levels left unexplained.
    """

    def __init__(self, dataloader):
        super().__init__()
        self._dataloader = dataloader

    @torch.no_grad()
    def before_run(self, runner: TrainingRunner):
        for i in range(len(runner.model.codebooks)):
            X = next(iter(self._dataloader))['embedding']
            codebook_size = len(runner.model.codebooks[i])
            # BUGFIX: if the batch had fewer rows than the codebook, the
            # ``.data =`` assignment below silently *shrank* the codebook.
            if X.shape[0] < codebook_size:
                raise ValueError(
                    f'batch of {X.shape[0]} rows cannot seed a codebook of size {codebook_size}'
                )
            idx = torch.randperm(X.shape[0], device=X.device)[:codebook_size]
            remainder = runner.model.encoder(X[idx])

            # Subtract the contribution of the already-seeded codebooks.
            for j in range(i):
                codebook_indices = runner.model.get_codebook_indices(remainder, runner.model.codebooks[j])
                codebook_vectors = runner.model.codebooks[j][codebook_indices]
                remainder = remainder - codebook_vectors

            runner.model.codebooks[i].data = remainder.detach()


class FixDeadCentroids(cb.TrainingCallback):
    """Re-seeds codebook centroids that received no assignments ("dead"
    centroids) after every optimization step, and reports per-codebook dead
    counts as ``num_dead/{i}`` metrics."""

    def __init__(self, dataloader):
        super().__init__()
        self._dataloader = dataloader

    def after_step(self, runner: TrainingRunner, context: TrainingRunnerContext):
        for i, num_fixed in enumerate(self.fix_dead_codebooks(runner)):
            context.metrics[f'num_dead/{i}'] = num_fixed

    @torch.no_grad()
    def fix_dead_codebooks(self, runner: TrainingRunner):
        """Counts assignments over one full dataloader pass per codebook and
        replaces unused centroids with random encoder residuals.

        Returns the list of dead-centroid counts, one entry per codebook.
        """
        num_fixed = []
        for codebook_idx, codebook in enumerate(runner.model.codebooks):
            centroid_counts = torch.zeros(codebook.shape[0], dtype=torch.long, device=codebook.device)
            random_batch = next(iter(self._dataloader))['embedding']

            for batch in self._dataloader:
                remainder = runner.model.encoder(batch['embedding'])
                # Residual after quantizing through the preceding codebooks.
                for l in range(codebook_idx):
                    ind = runner.model.get_codebook_indices(remainder, runner.model.codebooks[l])
                    remainder = remainder - runner.model.codebooks[l][ind]

                indices = runner.model.get_codebook_indices(remainder, codebook)
                centroid_counts.scatter_add_(0, indices, torch.ones_like(indices))

            dead_mask = (centroid_counts == 0)
            num_dead = int(dead_mask.sum().item())
            num_fixed.append(num_dead)
            if num_dead == 0:
                continue

            remainder = runner.model.encoder(random_batch)
            for l in range(codebook_idx):
                ind = runner.model.get_codebook_indices(remainder, runner.model.codebooks[l])
                remainder = remainder - runner.model.codebooks[l][ind]
            # BUGFIX: if the random batch holds fewer vectors than there are
            # dead centroids, tile it so the masked assignment below matches
            # the number of dead slots.
            if remainder.shape[0] < num_dead:
                reps = -(-num_dead // remainder.shape[0])  # ceil division
                remainder = remainder.repeat(reps, 1)
            remainder = remainder[torch.randperm(remainder.shape[0], device=codebook.device)][:num_dead]
            codebook[dead_mask] = remainder.detach()

        return num_fixed


# === file: scripts/plum/cooc_data.py ===
# FIX: the original file imported pickle and (defaultdict, Counter) twice and
# imported numpy/pickle without using them; the duplicates were dropped.
import json
from collections import defaultdict, Counter

from loguru import logger


class CoocMappingDataset:
    """Leave-one-out interaction splits plus a window-based item
    co-occurrence counter mapping built from the training split."""

    def __init__(
        self,
        train_sampler,
        validation_sampler,
        test_sampler,
        num_items,
        max_sequence_length,
        cooccur_counter_mapping=None,
    ):
        self._train_sampler = train_sampler
        self._validation_sampler = validation_sampler
        self._test_sampler = test_sampler
        self._num_items = num_items
        self._max_sequence_length = max_sequence_length
        self._cooccur_counter_mapping = cooccur_counter_mapping

    @classmethod
    def create(cls, inter_json_path, max_sequence_length, sampler_type, window_size):
        """Builds the dataset from a ``{user_id: [item_ids]}`` JSON file.

        Split is leave-one-out: train = all but the last 2 items,
        validation = all but the last, test = the full sequence.
        ``sampler_type`` is currently unused; kept for interface
        compatibility with the callers.
        """
        max_item_id = 0
        train_dataset, validation_dataset, test_dataset = [], [], []

        with open(inter_json_path, 'r') as f:
            user_interactions = json.load(f)

        for user_id_str, item_ids in user_interactions.items():
            user_id = int(user_id_str)
            if item_ids:
                max_item_id = max(max_item_id, max(item_ids))
            assert len(item_ids) >= 5, f'Core-5 dataset is used, user {user_id} has only {len(item_ids)} items'
            train_dataset.append({
                'user.ids': [user_id],
                'item.ids': item_ids[:-2],
            })
            validation_dataset.append({
                'user.ids': [user_id],
                'item.ids': item_ids[:-1],
            })
            test_dataset.append({
                'user.ids': [user_id],
                'item.ids': item_ids,
            })

        cooccur_counter_mapping = cls.build_cooccur_counter_mapping(train_dataset, window_size=window_size)
        logger.debug(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items but max_item_id is {max_item_id}')

        return cls(
            train_sampler=train_dataset,
            validation_sampler=validation_dataset,
            test_sampler=test_dataset,
            num_items=max_item_id + 1,
            max_sequence_length=max_sequence_length,
            cooccur_counter_mapping=cooccur_counter_mapping,
        )

    @staticmethod
    def build_cooccur_counter_mapping(train_dataset, window_size):
        # TODO: pass timestamps and build the window from time deltas instead
        # of sequence positions.  (translated from the original Russian note)
        cooccur_counts = defaultdict(Counter)
        for session in train_dataset:
            items = session['item.ids']
            for i in range(len(items)):
                item_i = items[i]
                # Symmetric window of +-window_size positions around item i.
                for j in range(max(0, i - window_size), min(len(items), i + window_size + 1)):
                    if i != j:
                        cooccur_counts[item_i][items[j]] += 1
        return cooccur_counts

    def get_datasets(self):
        """Returns the (train, validation, test) samplers as a tuple."""
        return self._train_sampler, self._validation_sampler, self._test_sampler

    @property
    def num_items(self):
        return self._num_items

    @property
    def max_sequence_length(self):
        return self._max_sequence_length

    @property
    def cooccur_counter_mapping(self):
        return self._cooccur_counter_mapping


# === file: scripts/plum/data.py ===
import pickle

import numpy as np

from irec.data.base import BaseDataset
from irec.data.transforms import Transform


class EmbeddingDataset(BaseDataset):
    """Serves (item_id, content-embedding) pairs from a pickled dict of
    parallel lists: ``{'item_id': [...], 'embedding': [...]}``."""

    def __init__(self, data_path):
        self.data_path = data_path
        # NOTE(review): pickle.load on a trusted local artifact only — never
        # point this at untrusted data.
        with open(data_path, 'rb') as f:
            self.data = pickle.load(f)

        self.item_ids = np.array(self.data['item_id'], dtype=np.int64)
        self.embeddings = np.array(self.data['embedding'], dtype=np.float32)

    def __getitem__(self, idx):
        embedding = self.embeddings[idx]
        return {
            'item_id': self.item_ids[idx],
            'embedding': embedding,
            'embedding_dim': len(embedding),
        }

    def __len__(self):
        return len(self.embeddings)


class ProcessEmbeddings(Transform):
    """Reshapes flat collated embedding buffers back to (batch, embedding_dim)."""

    def __init__(self, embedding_dim, keys):
        self.embedding_dim = embedding_dim
        self.keys = keys

    def __call__(self, batch):
        for key in self.keys:
            batch[key] = batch[key].reshape(-1, self.embedding_dim)
        return batch


# === file: scripts/plum/models.py ===
import torch.nn as nn
import torch.nn.functional as F


class PlumRQVAE(nn.Module):
    """Residual-quantized VAE with an optional co-occurrence contrastive loss.

    Encodes content embeddings, quantizes the latent with ``num_codebooks``
    residual codebooks (straight-through estimator), decodes back, and adds
    an InfoNCE-style loss between an item's quantized latent and the
    quantized latent of a sampled co-occurring item when the batch carries a
    'cooccurrence_embedding' key.
    """

    def __init__(
        self,
        input_dim,
        num_codebooks,
        codebook_size,
        embedding_dim,
        beta=0.25,
        quant_loss_weight=1.0,
        contrastive_loss_weight=1.0,
        temperature=0.0,
    ):
        super().__init__()
        # beta is a buffer so it follows the module across .to(device) moves.
        self.register_buffer('beta', torch.tensor(beta))
        # NOTE(review): the default temperature of 0.0 is unusable for the
        # contrastive loss (division by zero); both training and inference
        # scripts pass 1.0.  contrastive_loss() now fails loudly instead of
        # silently returning inf/NaN.
        self.temperature = temperature

        self.input_dim = input_dim
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        self.embedding_dim = embedding_dim
        self.quant_loss_weight = quant_loss_weight
        self.contrastive_loss_weight = contrastive_loss_weight

        self.encoder = self.make_encoding_tower(input_dim, embedding_dim)
        self.decoder = self.make_encoding_tower(embedding_dim, input_dim)

        self.codebooks = torch.nn.ParameterList()
        for _ in range(num_codebooks):
            # BUGFIX: the original appended a raw, *uninitialized*
            # torch.FloatTensor with nn.init.normal_ commented out, so the
            # codebooks held garbage memory and — not being nn.Parameters —
            # were never registered with the optimizer, which defeats the
            # codebook term of the RQ-VAE loss below.  InitCodebooks may
            # later overwrite these values; the init here is a safe default.
            codebook = nn.Parameter(torch.empty(codebook_size, embedding_dim))
            nn.init.normal_(codebook)
            self.codebooks.append(codebook)

    @staticmethod
    def make_encoding_tower(d1, d2, bias=False):
        """3-layer MLP d1 -> d1 -> d2 -> d2 with ReLU activations."""
        return torch.nn.Sequential(
            nn.Linear(d1, d1),
            nn.ReLU(),
            nn.Linear(d1, d2),
            nn.ReLU(),
            nn.Linear(d2, d2, bias=bias),
        )

    @staticmethod
    def get_codebook_indices(remainder, codebook):
        """Nearest-centroid (Euclidean) assignment of each row of
        ``remainder`` to a row of ``codebook``."""
        dist = torch.cdist(remainder, codebook)
        return dist.argmin(dim=-1)

    def _quantize_representation(self, latent_vector):
        """Quantizes a latent through all codebooks with a straight-through
        estimator and returns the summed quantized representation."""
        latent_restored = 0
        remainder = latent_vector

        for codebook in self.codebooks:
            codebook_indices = self.get_codebook_indices(remainder, codebook)
            quantized = codebook[codebook_indices]
            # Straight-through: forward uses the quantized value, the
            # gradient flows through `remainder`.
            codebook_vectors = remainder + (quantized - remainder).detach()
            latent_restored += codebook_vectors
            remainder = remainder - codebook_vectors

        return latent_restored

    def contrastive_loss(self, p_i, p_i_star):
        """InfoNCE loss pairing row k of ``p_i`` with row k of ``p_i_star``;
        all other rows in the batch act as negatives."""
        if self.temperature <= 0:
            # BUGFIX: float division by 0.0 in torch yields inf, so the loss
            # silently became NaN instead of raising.
            raise ValueError(f'temperature must be positive, got {self.temperature}')

        N_b = p_i.size(0)

        # TODO: evaluate without L2 normalization (translated from the
        # original Russian note).
        p_i = F.normalize(p_i, p=2, dim=-1)
        p_i_star = F.normalize(p_i_star, p=2, dim=-1)

        similarities = torch.matmul(p_i, p_i_star.T) / self.temperature
        labels = torch.arange(N_b, dtype=torch.long, device=p_i.device)
        return F.cross_entropy(similarities, labels)

    def forward(self, inputs):
        """Runs encode -> residual-quantize -> decode.

        Returns ``(loss, stats_dict)`` where the dict carries scalar losses
        plus per-codebook cluster assignments/counts and the reconstruction.
        """
        latent_vector = self.encoder(inputs['embedding'])

        latent_restored = 0
        rqvae_loss = 0
        clusters = []
        remainder = latent_vector

        for codebook in self.codebooks:
            codebook_indices = self.get_codebook_indices(remainder, codebook)
            clusters.append(codebook_indices)

            quantized = codebook[codebook_indices]
            codebook_vectors = remainder + (quantized - remainder).detach()

            # Commitment term (scaled by beta) pulls the encoder toward the
            # codebook; the second term pulls centroids toward the residuals.
            rqvae_loss += self.beta * F.mse_loss(remainder, quantized.detach())
            rqvae_loss += F.mse_loss(quantized, remainder.detach())

            latent_restored += codebook_vectors
            remainder = remainder - codebook_vectors

        embeddings_restored = self.decoder(latent_restored)
        recon_loss = F.mse_loss(embeddings_restored, inputs['embedding'])

        if 'cooccurrence_embedding' in inputs:
            cooccurrence_latent = self.encoder(inputs['cooccurrence_embedding'].to(latent_restored.device))
            cooccurrence_restored = self._quantize_representation(cooccurrence_latent)
            con_loss = self.contrastive_loss(latent_restored, cooccurrence_restored)
        else:
            con_loss = torch.as_tensor(0.0, device=latent_vector.device)

        loss = (
            recon_loss
            + self.quant_loss_weight * rqvae_loss
            + self.contrastive_loss_weight * con_loss
        ).mean()

        clusters_counts = [
            torch.bincount(cluster, minlength=self.codebook_size) for cluster in clusters
        ]

        return loss, {
            'loss': loss.item(),
            'recon_loss': recon_loss.mean().item(),
            'rqvae_loss': rqvae_loss.mean().item(),
            'con_loss': con_loss.item(),

            'clusters_counts': clusters_counts,
            'clusters': torch.stack(clusters).T,
            'embedding_hat': embeddings_restored,
        }


# === file: scripts/plum/transforms.py ===
# FIX: unused imports (pickle, BaseDataset, CoocMappingDataset) dropped;
# leftover debug printing and its call_count bookkeeping removed.


class AddWeightedCooccurrenceEmbeddings:
    """Batch transform: for every item, samples a co-occurring item
    (frequency-weighted) and attaches its content embedding under the
    'cooccurrence_embedding' key.  Items without co-occurrence statistics
    fall back to a uniformly random item."""

    def __init__(self, cooccur_counts, item_id_to_embedding, all_item_ids):
        self.cooccur_counts = cooccur_counts
        self.item_id_to_embedding = item_id_to_embedding
        self.all_item_ids = all_item_ids

    def __call__(self, batch):
        item_ids = batch['item_id']
        cooccurrence_embeddings = []

        for item_id in item_ids:
            item_id_val = int(item_id.item()) if torch.is_tensor(item_id) else int(item_id)

            # .get() (not []) on the defaultdict avoids auto-inserting keys.
            counter = self.cooccur_counts.get(item_id_val)
            if counter:
                cooc_ids, freqs = zip(*counter.items())
                freqs_array = np.array(freqs, dtype=np.float32)
                probs = freqs_array / freqs_array.sum()
                cooc_id = np.random.choice(cooc_ids, p=probs)
            else:
                # Cold-start fallback: uniformly random item.
                cooc_id = np.random.choice(self.all_item_ids)

            # np.random.choice returns np.int64; normalize to int for lookup.
            # NOTE(review): the default of batch['embedding'][0] for an
            # unknown id looks like a placeholder — confirm intent.
            cooc_emb = self.item_id_to_embedding.get(int(cooc_id), batch['embedding'][0])
            cooccurrence_embeddings.append(cooc_emb)

        batch['cooccurrence_embedding'] = torch.stack(cooccurrence_embeddings)
        return batch


# === file: scripts/plum/infer_default.py ===
# FIX: json/defaultdict/numpy imports hoisted out of main() to module level.
import os

from irec.data.dataloader import DataLoader
from irec.data.transforms import Collate, ToTorch, ToDevice
from irec.runners import InferenceRunner
from irec.utils import fix_random_seed

SEED_VALUE = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

BATCH_SIZE = 1024

INPUT_DIM = 4096
HIDDEN_DIM = 32
CODEBOOK_SIZE = 256
NUM_CODEBOOKS = 3

BETA = 0.25
MODEL_PATH = '/home/jovyan/IRec/checkpoints/test_plum_rqvae_beauty_ws_2_best_0.0054.pth'

WINDOW_SIZE = 2

EXPERIMENT_NAME = f'test_plum_rqvae_beauty_ws_{WINDOW_SIZE}'

IREC_PATH = '/home/jovyan/IRec/'


def infer_main():
    """Loads a trained PlumRQVAE checkpoint, writes per-item semantic-ID
    cluster assignments, then post-processes them into collision-free IDs."""
    fix_random_seed(SEED_VALUE)

    data = CoocMappingDataset.create(
        inter_json_path=os.path.join(IREC_PATH, 'data/Beauty/inter_new.json'),
        max_sequence_length=20,
        sampler_type='sasrec',
        window_size=WINDOW_SIZE,
    )

    dataset = EmbeddingDataset(
        data_path='/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl'
    )

    item_id_to_embedding = {}
    all_item_ids = []
    for idx in range(len(dataset)):
        sample = dataset[idx]
        item_id = int(sample['item_id'])
        item_id_to_embedding[item_id] = torch.tensor(sample['embedding'])
        all_item_ids.append(item_id)

    add_cooc_transform = AddWeightedCooccurrenceEmbeddings(
        data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids)

    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        drop_last=False,
    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).map(add_cooc_transform)

    model = PlumRQVAE(
        input_dim=INPUT_DIM,
        num_codebooks=NUM_CODEBOOKS,
        codebook_size=CODEBOOK_SIZE,
        embedding_dim=HIDDEN_DIM,
        beta=BETA,
        quant_loss_weight=1.0,
        contrastive_loss_weight=1.0,
        temperature=1.0,
    ).to(DEVICE)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.debug(f'Overall parameters: {total_params:,}')
    logger.debug(f'Trainable parameters: {trainable_params:,}')

    callbacks = [
        cb.LoadModel(MODEL_PATH),

        cb.BatchMetrics(metrics=lambda model_outputs, _: {
            'loss': model_outputs['loss'],
            'recon_loss': model_outputs['recon_loss'],
            'rqvae_loss': model_outputs['rqvae_loss'],
            'con_loss': model_outputs['con_loss']
        }, name='valid'),

        cb.MetricAccumulator(
            accumulators={
                'valid/loss': cb.MeanAccumulator(),
                'valid/recon_loss': cb.MeanAccumulator(),
                'valid/rqvae_loss': cb.MeanAccumulator(),
                'valid/con_loss': cb.MeanAccumulator(),
            },
        ),

        cb.Logger().every_num_steps(len(dataloader)),

        cb.InferenceSaver(
            metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
            save_path=f'/home/jovyan/IRec/results/{EXPERIMENT_NAME}_clusters.json',
            format='json'
        )
    ]

    logger.debug('Everything is ready for training process!')

    runner = InferenceRunner(
        model=model,
        dataset=dataloader,
        callbacks=callbacks,
    )
    runner.run()

    with open(f'/home/jovyan/IRec/results/{EXPERIMENT_NAME}_clusters.json', 'r') as f:
        mappings = json.load(f)

    # Group items by their semantic-ID prefix to detect collisions.
    inter = {}
    sem_2_ids = defaultdict(list)
    for mapping in mappings:
        item_id = mapping['item_id']
        clusters = mapping['clusters']
        inter[int(item_id)] = clusters
        sem_2_ids[tuple(clusters)].append(int(item_id))

    # Resolve collisions with an extra random-but-distinct code per item,
    # then offset every level by CODEBOOK_SIZE * level so codes from
    # different levels never overlap numerically.
    for semantics, items in sem_2_ids.items():
        assert len(items) <= CODEBOOK_SIZE, str(len(items))
        collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist()
        for item_id, collision_solver in zip(items, collision_solvers):
            inter[item_id].append(collision_solver)
            for i in range(len(inter[item_id])):
                inter[item_id][i] += CODEBOOK_SIZE * i

    with open(os.path.join(IREC_PATH, 'results', f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f:
        json.dump(inter, f, indent=2)


# === file: scripts/plum/train_plum.py ===
from irec.runners import TrainingRunner as _TrainingRunner

NUM_EPOCHS = 500
LR = 1e-4
TRAIN_IREC_PATH = '../../../../../'


def train_main():
    """Trains PlumRQVAE on Beauty content embeddings with the co-occurrence
    contrastive objective, validation, tensorboard logging and early stopping."""
    # FIX: a stray, unused `import pickle` inside the function was removed.
    fix_random_seed(SEED_VALUE)

    data = CoocMappingDataset.create(
        inter_json_path=os.path.join(TRAIN_IREC_PATH, 'data/Beauty/inter_new.json'),
        max_sequence_length=20,
        sampler_type='sasrec',
        window_size=WINDOW_SIZE,
    )

    dataset = EmbeddingDataset(
        data_path='/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl'
    )

    item_id_to_embedding = {}
    all_item_ids = []
    for idx in range(len(dataset)):
        sample = dataset[idx]
        item_id = int(sample['item_id'])
        item_id_to_embedding[item_id] = torch.tensor(sample['embedding'])
        all_item_ids.append(item_id)

    add_cooc_transform = AddWeightedCooccurrenceEmbeddings(
        data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids)

    train_dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
        ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
    ).map(add_cooc_transform).repeat(NUM_EPOCHS)

    valid_dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        drop_last=False,
    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).map(add_cooc_transform)

    LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS)

    model = PlumRQVAE(
        input_dim=INPUT_DIM,
        num_codebooks=NUM_CODEBOOKS,
        codebook_size=CODEBOOK_SIZE,
        embedding_dim=HIDDEN_DIM,
        beta=BETA,
        quant_loss_weight=1.0,
        contrastive_loss_weight=1.0,
        temperature=1.0,
    ).to(DEVICE)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.debug(f'Overall parameters: {total_params:,}')
    logger.debug(f'Trainable parameters: {trainable_params:,}')

    # BUGFIX: fused=True is CUDA-only and raised on the CPU fallback device.
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=(DEVICE.type == 'cuda'))

    callbacks = [
        InitCodebooks(valid_dataloader),

        cb.BatchMetrics(metrics=lambda model_outputs, batch: {
            'loss': model_outputs['loss'],
            'recon_loss': model_outputs['recon_loss'],
            'rqvae_loss': model_outputs['rqvae_loss'],
            'con_loss': model_outputs['con_loss']
        }, name='train'),

        FixDeadCentroids(valid_dataloader),

        cb.MetricAccumulator(
            accumulators={
                'train/loss': cb.MeanAccumulator(),
                'train/recon_loss': cb.MeanAccumulator(),
                'train/rqvae_loss': cb.MeanAccumulator(),
                'train/con_loss': cb.MeanAccumulator(),
                'num_dead/0': cb.MeanAccumulator(),
                'num_dead/1': cb.MeanAccumulator(),
                'num_dead/2': cb.MeanAccumulator(),
            },
            reset_every_num_steps=LOG_EVERY_NUM_STEPS
        ),

        cb.Validation(
            dataset=valid_dataloader,
            callbacks=[
                cb.BatchMetrics(metrics=lambda model_outputs, batch: {
                    'loss': model_outputs['loss'],
                    'recon_loss': model_outputs['recon_loss'],
                    'rqvae_loss': model_outputs['rqvae_loss'],
                    'con_loss': model_outputs['con_loss']
                }, name='valid'),
                cb.MetricAccumulator(
                    accumulators={
                        'valid/loss': cb.MeanAccumulator(),
                        'valid/recon_loss': cb.MeanAccumulator(),
                        'valid/rqvae_loss': cb.MeanAccumulator(),
                        'valid/con_loss': cb.MeanAccumulator()
                    }
                ),
            ],
        ).every_num_steps(LOG_EVERY_NUM_STEPS),

        cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS),
        cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(TRAIN_IREC_PATH, 'tensorboard_logs')),

        cb.EarlyStopping(
            metric='valid/recon_loss',
            patience=40,
            minimize=True,
            model_path=os.path.join(TRAIN_IREC_PATH, 'checkpoints', EXPERIMENT_NAME)
        ).every_num_steps(LOG_EVERY_NUM_STEPS),
    ]

    logger.debug('Everything is ready for training process!')

    runner = _TrainingRunner(
        model=model,
        optimizer=optimizer,
        dataset=train_dataloader,
        callbacks=callbacks,
    )
    runner.run()


if __name__ == '__main__':
    train_main()