diff --git a/configs/train/rqvae_train_config.json b/configs/train/rqvae_train_config.json index d5d25308..ba2cfb05 100644 --- a/configs/train/rqvae_train_config.json +++ b/configs/train/rqvae_train_config.json @@ -1,6 +1,6 @@ { "experiment_name": "rqvae_beauty", - "train_steps_num": 6000, + "train_steps_num": 1024, "dataset": { "type": "rqvae", "path_to_data_dir": "../data", @@ -12,7 +12,7 @@ "dataloader": { "train": { "type": "torch", - "batch_size": 128, + "batch_size": 256, "batch_processor": { "type": "embed" }, @@ -36,7 +36,7 @@ "n_iter": 100, "codebook_sizes": [256, 256, 256, 256], "should_init_codebooks": true, - "should_reinit_unused_clusters": false, + "should_reinit_unused_clusters": true, "initializer_range": 0.02 }, "optimizer": { @@ -49,7 +49,7 @@ "scheduler": { "type": "step", "step_size": 100, - "gamma": 0.98 + "gamma": 0.96 } }, "loss": { diff --git a/configs/train/tiger_train_config.json b/configs/train/tiger_train_config.json new file mode 100644 index 00000000..23e7ef6b --- /dev/null +++ b/configs/train/tiger_train_config.json @@ -0,0 +1,72 @@ +{ + "experiment_name": "tiger_beauty", + "train_steps_num": 5000, + "dataset": { + "type": "rqvae", + "path_to_data_dir": "../data", + "name": "Beauty", + "samplers": { + "type": "identity" + } + }, + "dataloader": { + "train": { + "type": "torch", + "batch_size": 128, + "batch_processor": { + "type": "embed" + }, + "drop_last": false, + "shuffle": true + }, + "validation": { + "type": "torch", + "batch_size": 256, + "batch_processor": { + "type": "embed" + }, + "drop_last": false, + "shuffle": false + } + }, + "model": { + "emb_dim": 512, + "n_tokens": 256, + "n_codebooks": 4, + "nhead": 8, + "num_encoder_layers": 6, + "num_decoder_layers": 6, + "dim_feedforward": 2048, + "dropout": 0.1 + }, + "rqvae_checkpoint_path": "../checkpoints/rqvae_beauty_final_state.pth", + "rqvae_train_config_path": "../configs/train/rqvae_train_config.json", + "optimizer": { + "type": "basic", + "optimizer": { + "type": 
"adam", + "lr": 1e-4 + }, + "clip_grad_threshold": 5.0, + "scheduler": { + "type": "step", + "step_size": 100, + "gamma": 0.98 + } + }, + "loss": { + "type": "rqvae_loss", + "beta": 0.25, + "output_prefix": "loss" + }, + "callback": { + "type": "composite", + "callbacks": [ + { + "type": "metric", + "on_step": 1, + "loss_prefix": "loss" + } + ] + } +} diff --git a/modeling/main.ipynb b/modeling/main.ipynb new file mode 100644 index 00000000..7f6194b7 --- /dev/null +++ b/modeling/main.ipynb @@ -0,0 +1,268 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "from rqvae.rqvae_data import get_data\n", + "\n", + "df = get_data()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "embs = torch.stack(df[\"embeddings\"].tolist())" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "from utils import DEVICE\n", + "from models.base import BaseModel\n", + "\n", + "config = json.load(open(\"../configs/train/tiger_train_config.json\"))\n", + "\n", + "rqvae_train_config = json.load(open(config['rqvae_train_config_path']))\n", + "rq_vae_config = rqvae_train_config['model']\n", + "rq_vae_config['should_init_codebooks'] = False\n", + "\n", + "rqvae_model = BaseModel.create_from_config(rq_vae_config).to(DEVICE)\n", + "\n", + "rqvae_model.load_state_dict(torch.load(config['rqvae_checkpoint_path'], weights_only=True))\n", + "rqvae_model.eval()\n", + "\n", + "ids = df.asin_numeric.tolist()\n", + "\n", + "embs_dict = {\"ids\": torch.tensor(ids).to(DEVICE), \"embeddings\": embs.to(DEVICE)}\n", + "\n", + "semantic_ids = list(rqvae_model.forward(embs_dict))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from rqvae.collisions import dedup\n", + "\n", + "items_with_tuples = 
list(zip(df[\"asin\"], df[\"title\"].fillna(\"unknown\"), semantic_ids))\n", + "items_with_tuples = dedup(items_with_tuples)\n", + "\n", + "assert len(df) == len(set(item[-1] for item in items_with_tuples))" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n", + "item='Revlon Beyond Natural Smoothing Primer, Clear, 0.85 Ounce' clust_tuple=(0, 33, 46, 141, 0)\n", + "item='Color Club Magic Attraction 843 Nail Polish' clust_tuple=(0, 228, 117, 102, 0)\n", + "item='Color Club Wild at Heart 871 Nail Polish' clust_tuple=(0, 228, 83, 106, 0)\n", + "1\n", + "item='Natural Beauty White / Brightening Essence Full Face Mask 10 Pcs' clust_tuple=(1, 210, 157, 13, 0)\n", + "item='OGX Conditioner, Nourishing Coconut Milk, 13oz' clust_tuple=(1, 50, 213, 95, 0)\n", + "item='BurnOut Eco-sensitive Zinc Oxide Sunscreen SPF 35 (3 oz)' clust_tuple=(1, 119, 46, 204, 0)\n", + "item='Kerastase Reflection Chroma Riche Luminous Softening Essence, 4.2 Ounce' clust_tuple=(1, 119, 3, 229, 0)\n", + "item='KINeSYS Performance Sunscreen, SPF 30, Spray, Mango Scent, 4-Ounce Bottles' clust_tuple=(1, 242, 54, 95, 0)\n", + "2\n", + "item='Pink Sugar by Aquolina for Women - 3.4 Ounce EDT Spray' clust_tuple=(2, 173, 246, 41, 0)\n", + "item='Guess By Parlux Fragrances For Men. Eau De Toilette Spray 2.5 Oz.' 
clust_tuple=(2, 104, 233, 139, 0)\n", + "item='Black Xs By Paco Rabanne For Men, Eau De Toilette Spray, 3.4-Ounce Bottle' clust_tuple=(2, 177, 233, 95, 0)\n", + "item='Incanto Shine By Salvatore Ferragamo For Women, Eau De Toilette Spray, 3.4-Ounce Bottle' clust_tuple=(2, 128, 185, 76, 0)\n", + "item='Versace Man Eau Fraiche By Gianni Versace For Men Edt Spray 3.4 Oz' clust_tuple=(2, 148, 81, 107, 0)\n", + "3\n", + "item='HDE® Facial Pore Cleanser Cleaner Blackhead Acne Remover' clust_tuple=(3, 106, 111, 172, 0)\n", + "item=\"Best Anti Aging Cream Reduces Wrinkels in Women and Men - Clinical Strength Bio-Peptide Wrinkle Cream Reduces Deep Wrinkles, Smooths Fine Lines and \"Crow's Feet\" - Tighten, Rejuvinate, and Rebuild Youthful, Healthy Skin - Boost Collagen and Ultra-Moisturize with Peptides Made For Your Skin - Great For Face, Under Eyes and Decolletage. You Love It Or We Buy It Back No Hassle Money Back Guarantee. [Reduced Price For Summer! Take 65% Off Automatically at Checkout!]★ 2oz Jar (60ml)\" clust_tuple=(3, 90, 239, 106, 0)\n", + "item='Raw African Black soap Imported From Ghana 4oz' clust_tuple=(3, 145, 185, 168, 0)\n", + "item='Skin Obsession 20% TCA Home Chemical Peel for face and body removes lines, sun damage and signs of aging' clust_tuple=(3, 207, 151, 96, 0)\n", + "item='Skin Obsession 25% TCA Chemical Peel for Home Use 1 fl oz (30 ml)' clust_tuple=(3, 207, 47, 44, 0)\n", + "4\n", + "item='Philosophy When Hope is Not Enough Firming and Lifting Serum for Unisex, 1 Ounce' clust_tuple=(4, 247, 239, 194, 0)\n", + "item='DKNY BE DELICIOUS by Donna Karan Womens EAU DE PARFUM SPRAY 1 OZ' clust_tuple=(4, 182, 111, 5, 0)\n", + "item='Victorinox By Swiss Army For Men 125 Years Eau-de-toilette Spray, 3.4-Ounce' clust_tuple=(4, 128, 239, 5, 0)\n", + "item='Givenchy Play for Men by Givenchy 3.3 oz 100 ml EDT Spray' clust_tuple=(4, 50, 19, 168, 0)\n", + "item='Jilbere Hot Air Brush' clust_tuple=(4, 15, 47, 194, 0)\n" + ] + } + ], + "source": [ + "from 
rqvae.rqvae_data import search_similar_items\n", + "\n", + "\n", + "for i in range(5):\n", + " sim = search_similar_items(items_with_tuples, (i,), 5)\n", + " if len(sim) == 0:\n", + " continue\n", + " print(i)\n", + " for asin, item, clust_tuple in sim:\n", + " # if 'shampoo' in item.lower():\n", + " print(f\"{item=} {clust_tuple=}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjoAAAGgCAYAAACjXc14AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAALGFJREFUeJzt3X9wVfWd//FXzC9CJjkSILncNWJ0shEalrXBDcEf0AIBS0gd3YINe8WWAi4CjcDyY11XdKYJoIJTsyq6VKyiOF2NawuNxBWjFAIYSCsIaNcIQRKC9nITfpjE8Pn+4XK+vSQEgonJ/fh8zJwZ7jnv87mfTz6cOS8+3HMTZowxAgAAsNBl3d0BAACArkLQAQAA1iLoAAAAaxF0AACAtQg6AADAWgQdAABgLYIOAACwFkEHAABYi6ADAACsRdABAADW6nDQeeeddzRx4kR5vV6FhYXptddec481Nzdr0aJFGjJkiGJjY+X1enXnnXfqyJEjQW00NjZqzpw56tevn2JjY5Wbm6vDhw8H1fj9fvl8PjmOI8dx5PP5dPz48aCaQ4cOaeLEiYqNjVW/fv00d+5cNTU1dXRIAADAUhEdPeHkyZMaOnSofvKTn+j2228POnbq1Cnt2rVL999/v4YOHSq/36/8/Hzl5ubqvffec+vy8/P129/+VuvXr1ffvn01f/585eTkqKKiQuHh4ZKkvLw8HT58WCUlJZKkGTNmyOfz6be//a0kqaWlRRMmTFD//v21ZcsWff7555o6daqMMXr88ccvaixnzpzRkSNHFBcXp7CwsI7+KAAAQDcwxqihoUFer1eXXXaBNRvzNUgyxcXF7dbs2LHDSDIHDx40xhhz/PhxExkZadavX+/WfPrpp+ayyy4zJSUlxhhjPvjgAyPJlJeXuzXbtm0zksz+/fuNMcZs3LjRXHbZZebTTz91a1566SUTHR1tAoHARfW/urraSGJjY2NjY2MLwa26uvqC9/oOr+h0VCAQUFhYmC6//HJJUkVFhZqbm5Wdne3WeL1epaena+vWrRo3bpy2bdsmx3GUmZnp1gwfPlyO42jr1q1KS0vTtm3blJ6eLq/X69aMGzdOjY2Nqqio0Pe+971WfWlsbFRjY6P72vzfL26vrq5WfHx8Zw8dAAB0gfr6eiUnJysuLu6CtV0adL744gstXrxYeXl5bpCora1VVFSU+vTpE1SblJSk2tpatyYxMbFVe4mJiUE1SUlJQcf79OmjqKgot+ZchYWFevDBB1vtj4+PJ+gAABBiLuZjJ1321FVzc7PuuOMOnTlzRk888cQF640xQR1uq/OXUvPXlixZokAg4G7V1dUXMxQAABCiuiToNDc3a9KkSaqqqlJpaWnQaonH41FTU5P8fn/QOXV1de4Kjcfj0dGjR1u1e+zYsaCac1du/H6/mpubW630nBUdHe2u3rCKAwCA/To96JwNOR999JHefPNN9e3bN+h4RkaGIiMjVVpa6u6rqanRnj17NGLECElSVlaWAoGAduzY4dZs375dgUAgqGbPn
j2qqalxazZt2qTo6GhlZGR09rAAAEAI6vBndE6cOKE///nP7uuqqipVVlYqISFBXq9X//iP/6hdu3bpd7/7nVpaWtxVl4SEBEVFRclxHE2bNk3z589X3759lZCQoAULFmjIkCEaM2aMJGnQoEEaP368pk+frtWrV0v66vHynJwcpaWlSZKys7M1ePBg+Xw+Pfzww/rLX/6iBQsWaPr06azUAACAr1zUc9h/ZfPmzW0+4jV16lRTVVV13kfANm/e7LZx+vRpM3v2bJOQkGBiYmJMTk6OOXToUND7fP7552bKlCkmLi7OxMXFmSlTphi/3x9Uc/DgQTNhwgQTExNjEhISzOzZs80XX3xx0WMJBAJG0kU/jg4AALpfR+7fYcb83zPW30L19fVyHEeBQIBVIAAAQkRH7t/8risAAGAtgg4AALAWQQcAAFiLoAMAAKxF0AEAANYi6AAAAGsRdAAAgLUIOgAAwFod/hUQuHhXLd7Q3V3osE+WTejuLgAA0GlY0QEAANYi6AAAAGsRdAAAgLUIOgAAwFoEHQAAYC2CDgAAsBZBBwAAWIugAwAArEXQAQAA1iLoAAAAaxF0AACAtQg6AADAWgQdAABgLYIOAACwFkEHAABYi6ADAACsRdABAADWIugAAABrEXQAAIC1CDoAAMBaBB0AAGAtgg4AALAWQQcAAFiLoAMAAKxF0AEAANYi6AAAAGsRdAAAgLUIOgAAwFoEHQAAYC2CDgAAsBZBBwAAWIugAwAArEXQAQAA1iLoAAAAaxF0AACAtQg6AADAWgQdAABgLYIOAACwFkEHAABYi6ADAACsRdABAADW6nDQeeeddzRx4kR5vV6FhYXptddeCzpujNHSpUvl9XoVExOjUaNGae/evUE1jY2NmjNnjvr166fY2Fjl5ubq8OHDQTV+v18+n0+O48hxHPl8Ph0/fjyo5tChQ5o4caJiY2PVr18/zZ07V01NTR0dEgAAsFSHg87Jkyc1dOhQFRUVtXl8xYoVWrlypYqKirRz5055PB6NHTtWDQ0Nbk1+fr6Ki4u1fv16bdmyRSdOnFBOTo5aWlrcmry8PFVWVqqkpEQlJSWqrKyUz+dzj7e0tGjChAk6efKktmzZovXr1+uVV17R/PnzOzokAABgqTBjjLnkk8PCVFxcrFtvvVXSV6s5Xq9X+fn5WrRokaSvVm+SkpK0fPlyzZw5U4FAQP3799fzzz+vyZMnS5KOHDmi5ORkbdy4UePGjdO+ffs0ePBglZeXKzMzU5JUXl6urKws7d+/X2lpafr973+vnJwcVVdXy+v1SpLWr1+vu+66S3V1dYqPj79g/+vr6+U4jgKBwEXVd9RVizd0eptd7ZNlE7q7CwAAtKsj9+9O/YxOVVWVamtrlZ2d7e6Ljo7WyJEjtXXrVklSRUWFmpubg2q8Xq/S09Pdmm3btslxHDfkSNLw4cPlOE5QTXp6uhtyJGncuHFqbGxURUVFm/1rbGxUfX190AYAAOzVqUGntrZWkpSUlBS0PykpyT1WW1urqKgo9enTp92axMTEVu0nJiYG1Zz7Pn369FFUVJRbc67CwkL3Mz+O4yg5OfkSRgkAAEJFlzx1FRYWFvTaGNNq37nOrWmr/lJq/tqSJUsUCATcrbq6ut0+AQCA0NapQcfj8UhSqxWVuro6d/XF4/GoqalJfr+/3ZqjR4+2av/YsWNBNee+j9/vV3Nzc6uVnrOio6MVHx8ftAEAAHt1atBJSUmRx+NRaWmpu6+pqUllZWUaMWKEJCkjI0ORkZFBNTU1NdqzZ49bk5WVpUAgoB07drg127dvVyAQCKrZs2ePampq3JpNmzYpOjpaGRkZnTksAAAQoiI6esKJEyf05z//2X1dVVWlyspKJSQk6Morr1R+fr4KCgqUmpqq1NRUFRQUqHfv3srLy5MkOY6jadOmaf78+erbt68SEhK0YMECDRkyRGPGjJEkDRo0SOPHj9f06dO1evVqSdKMGTOUk5Ojt
LQ0SVJ2drYGDx4sn8+nhx9+WH/5y1+0YMECTZ8+nZUaAAAg6RKCznvvvafvfe977ut58+ZJkqZOnaq1a9dq4cKFOn36tGbNmiW/36/MzExt2rRJcXFx7jmrVq1SRESEJk2apNOnT2v06NFau3atwsPD3Zp169Zp7ty57tNZubm5Qd/dEx4erg0bNmjWrFm64YYbFBMTo7y8PD3yyCMd/ykAAAArfa3v0Ql1fI9Oa3yPDgCgp+u279EBAADoSQg6AADAWgQdAABgLYIOAACwFkEHAABYi6ADAACsRdABAADWIugAAABrEXQAAIC1CDoAAMBaBB0AAGAtgg4AALAWQQcAAFiLoAMAAKxF0AEAANYi6AAAAGsRdAAAgLUIOgAAwFoEHQAAYC2CDgAAsBZBBwAAWIugAwAArEXQAQAA1iLoAAAAaxF0AACAtQg6AADAWgQdAABgLYIOAACwFkEHAABYi6ADAACsRdABAADWIugAAABrEXQAAIC1CDoAAMBaBB0AAGAtgg4AALAWQQcAAFiLoAMAAKxF0AEAANYi6AAAAGsRdAAAgLUIOgAAwFoEHQAAYC2CDgAAsBZBBwAAWIugAwAArEXQAQAA1iLoAAAAaxF0AACAtQg6AADAWp0edL788kv927/9m1JSUhQTE6Orr75aDz30kM6cOePWGGO0dOlSeb1excTEaNSoUdq7d29QO42NjZozZ4769eun2NhY5ebm6vDhw0E1fr9fPp9PjuPIcRz5fD4dP368s4cEAABCVKcHneXLl+upp55SUVGR9u3bpxUrVujhhx/W448/7tasWLFCK1euVFFRkXbu3CmPx6OxY8eqoaHBrcnPz1dxcbHWr1+vLVu26MSJE8rJyVFLS4tbk5eXp8rKSpWUlKikpESVlZXy+XydPSQAABCiwowxpjMbzMnJUVJSktasWePuu/3229W7d289//zzMsbI6/UqPz9fixYtkvTV6k1SUpKWL1+umTNnKhAIqH///nr++ec1efJkSdKRI0eUnJysjRs3aty4cdq3b58GDx6s8vJyZWZmSpLKy8uVlZWl/fv3Ky0t7YJ9ra+vl+M4CgQCio+P78wfgyTpqsUbOr3NrvbJsgnd3QUAANrVkft3p6/o3Hjjjfqf//kfffjhh5KkP/7xj9qyZYt+8IMfSJKqqqpUW1ur7Oxs95zo6GiNHDlSW7dulSRVVFSoubk5qMbr9So9Pd2t2bZtmxzHcUOOJA0fPlyO47g1AADg2y2isxtctGiRAoGArr32WoWHh6ulpUW/+MUv9OMf/1iSVFtbK0lKSkoKOi8pKUkHDx50a6KiotSnT59WNWfPr62tVWJiYqv3T0xMdGvO1djYqMbGRvd1fX39JY4SAACEgk5f0Xn55Zf1wgsv6MUXX9SuXbv03HPP6ZFHHtFzzz0XVBcWFhb02hjTat+5zq1pq769dgoLC90PLjuOo+Tk5IsdFgAACEGdHnT+5V/+RYsXL9Ydd9yhIUOGyOfz6d5771VhYaEkyePxSFKrVZe6ujp3lcfj8aipqUl+v7/dmqNHj7Z6/2PHjrVaLTpryZIlCgQC7lZdXf31BgsAAHq0Tg86p06d0mWXBTcbHh7uPl6ekpIij8ej0tJS93hTU5PKyso0YsQISVJGRoYiIyODampqarRnzx63JisrS4FAQDt27HBrtm/frkAg4NacKzo6WvHx8UEbAACwV6d/RmfixIn6xS9+oSuvvFLf+c53tHv3bq1cuVI//elPJX313035+fkqKChQamqqUlNTVVBQoN69eysvL0+S5DiOpk2bpvnz56tv375KSEjQggULNGTIEI0ZM0aSNGjQII0fP17Tp0/X6tWrJUkzZsxQTk7ORT1xBQAA7NfpQefxxx/X/fffr1mzZqmurk5er1czZ87Uv//7v7s1Cxcu1OnTpzVr1iz5/X5lZmZq06ZNiouLc2tWrVqliIgITZo0SadPn9bo0aO1du1ahYeHuzXr1q3T3
Llz3aezcnNzVVRU1NlDAgAAIarTv0cnlPA9Oq3xPToAgJ6uW79HBwAAoKcg6AAAAGsRdAAAgLUIOgAAwFoEHQAAYC2CDgAAsBZBBwAAWIugAwAArEXQAQAA1iLoAAAAaxF0AACAtQg6AADAWgQdAABgLYIOAACwFkEHAABYi6ADAACsRdABAADWIugAAABrEXQAAIC1CDoAAMBaBB0AAGAtgg4AALAWQQcAAFiLoAMAAKxF0AEAANYi6AAAAGsRdAAAgLUIOgAAwFoEHQAAYC2CDgAAsBZBBwAAWIugAwAArEXQAQAA1iLoAAAAaxF0AACAtQg6AADAWgQdAABgLYIOAACwFkEHAABYi6ADAACsRdABAADWIugAAABrEXQAAIC1CDoAAMBaBB0AAGAtgg4AALAWQQcAAFiLoAMAAKxF0AEAANYi6AAAAGt1SdD59NNP9U//9E/q27evevfurb//+79XRUWFe9wYo6VLl8rr9SomJkajRo3S3r17g9pobGzUnDlz1K9fP8XGxio3N1eHDx8OqvH7/fL5fHIcR47jyOfz6fjx410xJAAAEII6Pej4/X7dcMMNioyM1O9//3t98MEHevTRR3X55Ze7NStWrNDKlStVVFSknTt3yuPxaOzYsWpoaHBr8vPzVVxcrPXr12vLli06ceKEcnJy1NLS4tbk5eWpsrJSJSUlKikpUWVlpXw+X2cPCQAAhKgwY4zpzAYXL16sP/zhD3r33XfbPG6MkdfrVX5+vhYtWiTpq9WbpKQkLV++XDNnzlQgEFD//v31/PPPa/LkyZKkI0eOKDk5WRs3btS4ceO0b98+DR48WOXl5crMzJQklZeXKysrS/v371daWtoF+1pfXy/HcRQIBBQfH99JP4H/76rFGzq9za72ybIJ3d0FAADa1ZH7d6ev6Lz++usaNmyYfvSjHykxMVHXXXednnnmGfd4VVWVamtrlZ2d7e6Ljo7WyJEjtXXrVklSRUWFmpubg2q8Xq/S09Pdmm3btslxHDfkSNLw4cPlOI5bAwAAvt06Peh8/PHHevLJJ5Wamqo33nhDd999t+bOnatf//rXkqTa2lpJUlJSUtB5SUlJ7rHa2lpFRUWpT58+7dYkJia2ev/ExES35lyNjY2qr68P2gAAgL0iOrvBM2fOaNiwYSooKJAkXXfdddq7d6+efPJJ3XnnnW5dWFhY0HnGmFb7znVuTVv17bVTWFioBx988KLHAgAAQlunr+gMGDBAgwcPDto3aNAgHTp0SJLk8XgkqdWqS11dnbvK4/F41NTUJL/f327N0aNHW73/sWPHWq0WnbVkyRIFAgF3q66uvoQRAgCAUNHpQeeGG27QgQMHgvZ9+OGHGjhwoCQpJSVFHo9HpaWl7vGmpiaVlZVpxIgRkqSMjAxFRkYG1dTU1GjPnj1uTVZWlgKBgHbs2OHWbN++XYFAwK05V3R0tOLj44M2AABgr07/r6t7771XI0aMUEFBgSZNmqQdO3bo6aef1tNPPy3pq/9uys/PV0FBgVJTU5WamqqCggL17t1beXl5kiTHcTRt2jTNnz9fffv2VUJCghYsWKAhQ4ZozJgxkr5aJRo/frymT5+u1atXS5JmzJihnJyci3riCgAA2K/Tg87111+v4uJiLVmyRA899JBSUlL02GOPacqUKW7NwoULdfr0ac2aNUt+v1+ZmZnatGmT4uLi3JpVq1YpIiJCkyZN0unTpzV69GitXbtW4eHhbs26des0d+5c9+ms3NxcFRUVdfaQAABAiOr079EJJXyPTmt8jw4AoKfr1u/RAQAA6CkIOgAAwFoEHQAAYC2CDgAAsBZBBwAAWIugAwAArEXQAQAA1iLoAAAAaxF0AACAtQg6AADAWgQdAABgLYIOAACwFkEHAABYi6ADAACsRdABAADWIugAAABrEXQAAIC1CDoAAMBaBB0AAGAtgg4AALAWQQcAAFiLoAMAAKxF0AEAANYi6AAAAGsRdAAAgLUIO
gAAwFoEHQAAYC2CDgAAsBZBBwAAWIugAwAArEXQAQAA1iLoAAAAaxF0AACAtQg6AADAWgQdAABgLYIOAACwFkEHAABYi6ADAACsRdABAADWIugAAABrEXQAAIC1CDoAAMBaBB0AAGAtgg4AALAWQQcAAFiLoAMAAKxF0AEAANYi6AAAAGsRdAAAgLUIOgAAwFoEHQAAYK0uDzqFhYUKCwtTfn6+u88Yo6VLl8rr9SomJkajRo3S3r17g85rbGzUnDlz1K9fP8XGxio3N1eHDx8OqvH7/fL5fHIcR47jyOfz6fjx4109JAAAECK6NOjs3LlTTz/9tP7u7/4uaP+KFSu0cuVKFRUVaefOnfJ4PBo7dqwaGhrcmvz8fBUXF2v9+vXasmWLTpw4oZycHLW0tLg1eXl5qqysVElJiUpKSlRZWSmfz9eVQwIAACGky4LOiRMnNGXKFD3zzDPq06ePu98Yo8cee0z33XefbrvtNqWnp+u5557TqVOn9OKLL0qSAoGA1qxZo0cffVRjxozRddddpxdeeEHvv/++3nzzTUnSvn37VFJSov/8z/9UVlaWsrKy9Mwzz+h3v/udDhw40FXDAgAAIaTLgs4999yjCRMmaMyYMUH7q6qqVFtbq+zsbHdfdHS0Ro4cqa1bt0qSKioq1NzcHFTj9XqVnp7u1mzbtk2O4ygzM9OtGT58uBzHcWvO1djYqPr6+qANAADYK6IrGl2/fr127dqlnTt3tjpWW1srSUpKSgran5SUpIMHD7o1UVFRQStBZ2vOnl9bW6vExMRW7ScmJro15yosLNSDDz7Y8QEBAICQ1OkrOtXV1fr5z3+uF154Qb169TpvXVhYWNBrY0yrfec6t6at+vbaWbJkiQKBgLtVV1e3+34AACC0dXrQqaioUF1dnTIyMhQREaGIiAiVlZXpl7/8pSIiItyVnHNXXerq6txjHo9HTU1N8vv97dYcPXq01fsfO3as1WrRWdHR0YqPjw/aAACAvTo96IwePVrvv/++Kisr3W3YsGGaMmWKKisrdfXVV8vj8ai0tNQ9p6mpSWVlZRoxYoQkKSMjQ5GRkUE1NTU12rNnj1uTlZWlQCCgHTt2uDXbt29XIBBwawAAwLdbp39GJy4uTunp6UH7YmNj1bdvX3d/fn6+CgoKlJqaqtTUVBUUFKh3797Ky8uTJDmOo2nTpmn+/Pnq27evEhIStGDBAg0ZMsT9cPOgQYM0fvx4TZ8+XatXr5YkzZgxQzk5OUpLS+vsYQEAgBDUJR9GvpCFCxfq9OnTmjVrlvx+vzIzM7Vp0ybFxcW5NatWrVJERIQmTZqk06dPa/To0Vq7dq3Cw8PdmnXr1mnu3Lnu01m5ubkqKir6xscDAAB6pjBjjOnuTnSX+vp6OY6jQCDQJZ/XuWrxhk5vs6t9smxCd3cBAIB2deT+ze+6AgAA1iLoAAAAaxF0AACAtQg6AADAWgQdAABgLYIOAACwFkEHAABYi6ADAACsRdABAADWIugAAABrEXQAAIC1CDoAAMBaBB0AAGAtgg4AALAWQQcAAFiLoAMAAKxF0AEAANYi6AAAAGsRdAAAgLUIOgAAwFoEHQAAYC2CDgAAsBZBBwAAWIugAwAArEXQAQAA1iLoAAAAaxF0AACAtQg6AADAWgQdAABgLYIOAACwFkEHAABYi6ADAACsRdABAADWIugAAABrEXQAAIC1CDoAAMBaBB0AAGAtgg4AALAWQQcAAFiLoAMAAKxF0AEAANYi6AAAAGsRdAAAgLUIOgAAwFoEHQAAYC2CDgAAsBZBBwAAWIugAwAArEXQAQAA1iLoAAAAa3V60CksLNT111+vuLg4JSYm6tZbb9WBAweCaowxWrp0qbxer2JiYjRq1Cjt3bs3qKaxsVFz5sxRv379FBsbq9zcXB0+fDioxu/3y+fzyXEcOY4jn8+n48ePd/aQAABAiOr0oFNWVqZ77rlH5eXlKi0t1Zdffqns7GydPHnSrVmxY
oVWrlypoqIi7dy5Ux6PR2PHjlVDQ4Nbk5+fr+LiYq1fv15btmzRiRMnlJOTo5aWFrcmLy9PlZWVKikpUUlJiSorK+Xz+Tp7SAAAIESFGWNMV77BsWPHlJiYqLKyMt18880yxsjr9So/P1+LFi2S9NXqTVJSkpYvX66ZM2cqEAiof//+ev755zV58mRJ0pEjR5ScnKyNGzdq3Lhx2rdvnwYPHqzy8nJlZmZKksrLy5WVlaX9+/crLS3tgn2rr6+X4zgKBAKKj4/v9LFftXhDp7fZ1T5ZNqG7uwAAQLs6cv/u8s/oBAIBSVJCQoIkqaqqSrW1tcrOznZroqOjNXLkSG3dulWSVFFRoebm5qAar9er9PR0t2bbtm1yHMcNOZI0fPhwOY7j1pyrsbFR9fX1QRsAALBXlwYdY4zmzZunG2+8Uenp6ZKk2tpaSVJSUlJQbVJSknustrZWUVFR6tOnT7s1iYmJrd4zMTHRrTlXYWGh+3kex3GUnJz89QYIAAB6tC4NOrNnz9af/vQnvfTSS62OhYWFBb02xrTad65za9qqb6+dJUuWKBAIuFt1dfXFDAMAAISoLgs6c+bM0euvv67NmzfriiuucPd7PB5JarXqUldX567yeDweNTU1ye/3t1tz9OjRVu977NixVqtFZ0VHRys+Pj5oAwAA9ur0oGOM0ezZs/Xqq6/qrbfeUkpKStDxlJQUeTwelZaWuvuamppUVlamESNGSJIyMjIUGRkZVFNTU6M9e/a4NVlZWQoEAtqxY4dbs337dgUCAbcGAAB8u0V0doP33HOPXnzxRf33f/+34uLi3JUbx3EUExOjsLAw5efnq6CgQKmpqUpNTVVBQYF69+6tvLw8t3batGmaP3+++vbtq4SEBC1YsEBDhgzRmDFjJEmDBg3S+PHjNX36dK1evVqSNGPGDOXk5FzUE1cAAMB+nR50nnzySUnSqFGjgvY/++yzuuuuuyRJCxcu1OnTpzVr1iz5/X5lZmZq06ZNiouLc+tXrVqliIgITZo0SadPn9bo0aO1du1ahYeHuzXr1q3T3Llz3aezcnNzVVRU1NlDAgAAIarLv0enJ+N7dFrje3QAAD1dj/oeHQAAgO5C0AEAANYi6AAAAGsRdAAAgLUIOgAAwFoEHQAAYC2CDgAAsBZBBwAAWIugAwAArEXQAQAA1iLoAAAAaxF0AACAtQg6AADAWgQdAABgLYIOAACwFkEHAABYi6ADAACsRdABAADWIugAAABrEXQAAIC1CDoAAMBaBB0AAGAtgg4AALAWQQcAAFiLoAMAAKxF0AEAANYi6AAAAGsRdAAAgLUIOgAAwFoEHQAAYC2CDgAAsBZBBwAAWIugAwAArEXQAQAA1iLoAAAAaxF0AACAtQg6AADAWgQdAABgLYIOAACwFkEHAABYi6ADAACsRdABAADWIugAAABrEXQAAIC1CDoAAMBaBB0AAGAtgg4AALAWQQcAAFiLoAMAAKxF0AEAANYK+aDzxBNPKCUlRb169VJGRobefffd7u4SAADoIUI66Lz88svKz8/Xfffdp927d+umm27SLbfcokOHDnV31wAAQA8Q0kFn5cqVmjZtmn72s59p0KBBeuyxx5ScnKwnn3yyu7sGAAB6gIju7sClampqUkVFhRYvXhy0Pzs7W1u3bm3znMbGRjU2NrqvA4GAJKm+vr5L+nim8VSXtNuVrrz3N93dhQ7b8+C47u4CAOAbdPa+bYy5YG3IBp3PPvtMLS0tSkpKCtqflJSk2traNs8pLCzUgw8+2Gp/cnJyl/QR3wznse7uAQCgOzQ0NMhxnHZrQjbonBUWFhb02hjTat9ZS5Ys0bx589zXx48f18CBA3Xo0KEL/qBCUX19vZKTk1VdXa34+Pju7k6XsH2Mto9Psn+MjC/02T7GUByfMUYNDQ3yer0XrA3ZoNOvXz+Fh4e3Wr2pq6trtcpzVnR0tKKjo1vtdxwnZCb3UsTHx1s9Psn+Mdo+Psn+M
TK+0Gf7GENtfBe7QBGyH0aOiopSRkaGSktLg/aXlpZqxIgR3dQrAADQk4Tsio4kzZs3Tz6fT8OGDVNWVpaefvppHTp0SHfffXd3dw0AAPQAIR10Jk+erM8//1wPPfSQampqlJ6ero0bN2rgwIEXdX50dLQeeOCBNv87ywa2j0+yf4y2j0+yf4yML/TZPkbbxxdmLubZLAAAgBAUsp/RAQAAuBCCDgAAsBZBBwAAWIugAwAArGV90HniiSeUkpKiXr16KSMjQ++++2679WVlZcrIyFCvXr109dVX66mnnvqGetoxhYWFuv766xUXF6fExETdeuutOnDgQLvnvP322woLC2u17d+//xvqdccsXbq0VV89Hk+754TK/EnSVVdd1eZ83HPPPW3Wh8L8vfPOO5o4caK8Xq/CwsL02muvBR03xmjp0qXyer2KiYnRqFGjtHfv3gu2+8orr2jw4MGKjo7W4MGDVVxc3EUjaF9742tubtaiRYs0ZMgQxcbGyuv16s4779SRI0fabXPt2rVtzusXX3zRxaNp7ULzd9ddd7Xq5/Dhwy/Ybk+ZP+nCY2xrLsLCwvTwww+ft82eNIcXc28I9euwo6wOOi+//LLy8/N13333affu3brpppt0yy236NChQ23WV1VV6Qc/+IFuuukm7d69W//6r/+quXPn6pVXXvmGe35hZWVluueee1ReXq7S0lJ9+eWXys7O1smTJy947oEDB1RTU+Nuqamp30CPL813vvOdoL6+//77560NpfmTpJ07dwaN7eyXX/7oRz9q97yePH8nT57U0KFDVVRU1ObxFStWaOXKlSoqKtLOnTvl8Xg0duxYNTQ0nLfNbdu2afLkyfL5fPrjH/8on8+nSZMmafv27V01jPNqb3ynTp3Srl27dP/992vXrl169dVX9eGHHyo3N/eC7cbHxwfNaU1NjXr16tUVQ2jXheZPksaPHx/Uz40bN7bbZk+aP+nCYzx3Hn71q18pLCxMt99+e7vt9pQ5vJh7Q6hfhx1mLPYP//AP5u677w7ad+2115rFixe3Wb9w4UJz7bXXBu2bOXOmGT58eJf1sbPU1dUZSaasrOy8NZs3bzaSjN/v/+Y69jU88MADZujQoRddH8rzZ4wxP//5z80111xjzpw50+bxUJs/Saa4uNh9febMGePxeMyyZcvcfV988YVxHMc89dRT521n0qRJZvz48UH7xo0bZ+64445O73NHnDu+tuzYscNIMgcPHjxvzbPPPmscx+ncznWCtsY3depU88Mf/rBD7fTU+TPm4ubwhz/8ofn+97/fbk1PnUNjWt8bbLsOL4a1KzpNTU2qqKhQdnZ20P7s7Gxt3bq1zXO2bdvWqn7cuHF677331Nzc3GV97QyBQECSlJCQcMHa6667TgMGDNDo0aO1efPmru7a1/LRRx/J6/UqJSVFd9xxhz7++OPz1oby/DU1NemFF17QT3/60/P+UtqzQmn+/lpVVZVqa2uD5ig6OlojR4487zUpnX9e2zunpwgEAgoLC9Pll1/ebt2JEyc0cOBAXXHFFcrJydHu3bu/mQ5egrfffluJiYn627/9W02fPl11dXXt1ofy/B09elQbNmzQtGnTLljbU+fw3HvDt/E6tDbofPbZZ2ppaWn1Cz6TkpJa/SLQs2pra9us//LLL/XZZ591WV+/LmOM5s2bpxtvvFHp6ennrRswYICefvppvfLKK3r11VeVlpam0aNH65133vkGe3vxMjMz9etf/1pvvPGGnnnmGdXW1mrEiBH6/PPP26wP1fmTpNdee03Hjx/XXXfddd6aUJu/c5297jpyTZ49r6Pn9ARffPGFFi9erLy8vHZ/UeK1116rtWvX6vXXX9dLL72kXr166YYbbtBHH330Dfb24txyyy1at26d3nrrLT366KPauXOnvv/976uxsfG854Tq/EnSc889p7i4ON12223t1vXUOWzr3vBtuw6lEP8VEBfj3H8dG2Pa/RdzW/Vt7
e9JZs+erT/96U/asmVLu3VpaWlKS0tzX2dlZam6ulqPPPKIbr755q7uZofdcsst7p+HDBmirKwsXXPNNXruuec0b968Ns8JxfmTpDVr1uiWW26R1+s9b02ozd/5dPSavNRzulNzc7PuuOMOnTlzRk888US7tcOHDw/6QO8NN9yg7373u3r88cf1y1/+squ72iGTJ092/5yenq5hw4Zp4MCB2rBhQ7thINTm76xf/epXmjJlygU/a9NT57C9e8O34To8y9oVnX79+ik8PLxV2qyrq2uVSs/yeDxt1kdERKhv375d1tevY86cOXr99de1efNmXXHFFR0+f/jw4d3+r46LFRsbqyFDhpy3v6E4f5J08OBBvfnmm/rZz37W4XNDaf7OPjHXkWvy7HkdPac7NTc3a9KkSaqqqlJpaWm7qzltueyyy3T99deHxLwOGDBAAwcObLevoTZ/Z7377rs6cODAJV2XPWEOz3dv+LZch3/N2qATFRWljIwM90mWs0pLSzVixIg2z8nKympVv2nTJg0bNkyRkZFd1tdLYYzR7Nmz9eqrr+qtt95SSkrKJbWze/duDRgwoJN71zUaGxu1b9++8/Y3lObvrz377LNKTEzUhAkTOnxuKM1fSkqKPB5P0Bw1NTWprKzsvNekdP55be+c7nI25Hz00Ud68803LylgG2NUWVkZEvP6+eefq7q6ut2+htL8/bU1a9YoIyNDQ4cO7fC53TmHF7o3fBuuw1a64xPQ35T169ebyMhIs2bNGvPBBx+Y/Px8Exsbaz755BNjjDGLFy82Pp/Prf/4449N7969zb333ms++OADs2bNGhMZGWn+67/+q7uGcF7//M//bBzHMW+//bapqalxt1OnTrk1545v1apVpri42Hz44Ydmz549ZvHixUaSeeWVV7pjCBc0f/588/bbb5uPP/7YlJeXm5ycHBMXF2fF/J3V0tJirrzySrNo0aJWx0Jx/hoaGszu3bvN7t27jSSzcuVKs3v3bvepo2XLlhnHccyrr75q3n//ffPjH//YDBgwwNTX17tt+Hy+oCcj//CHP5jw8HCzbNkys2/fPrNs2TITERFhysvLe9T4mpubTW5urrniiitMZWVl0HXZ2Nh43vEtXbrUlJSUmP/93/81u3fvNj/5yU9MRESE2b59e48aX0NDg5k/f77ZunWrqaqqMps3bzZZWVnmb/7mb0Jm/oy58N9RY4wJBAKmd+/e5sknn2yzjZ48hxdzbwj167CjrA46xhjzH//xH2bgwIEmKirKfPe73w16/Hrq1Klm5MiRQfVvv/22ue6660xUVJS56qqrzvsXvbtJanN79tln3Zpzx7d8+XJzzTXXmF69epk+ffqYG2+80WzYsOGb7/xFmjx5shkwYICJjIw0Xq/X3HbbbWbv3r3u8VCev7PeeOMNI8kcOHCg1bFQnL+zj8Cfu02dOtUY89WjrQ888IDxeDwmOjra3Hzzzeb9998PamPkyJFu/Vm/+c1vTFpamomMjDTXXnttt4W79sZXVVV13uty8+bNbhvnji8/P99ceeWVJioqyvTv399kZ2ebrVu3fvODM+2P79SpUyY7O9v079/fREZGmiuvvNJMnTrVHDp0KKiNnjx/xlz476gxxqxevdrExMSY48ePt9lGT57Di7k3hPp12FFhxvzfpzUBAAAsY+1ndAAAAAg6AADAWgQdAABgLYIOAACwFkEHAABYi6ADAACsRdABAADWIugAAABrEXQAAIC1CDoAAMBaBB0AAGAtgg4AALDW/wMhPwLu9rHL3wAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from collections import Counter\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "plt.hist(Counter(item[-1][:-1] for item in items_with_tuples).values())\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "# # raw full ids\n", + "# full_duplicates = Counter(item[-1][:-1] for item in items_with_tuples).items()\n", + "# duplicated = [(semantic_id, amount) for (semantic_id, amount) in full_duplicates if amount > 1]\n", + "# duplicated" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Counter({1: 11195, 2: 316, 3: 44, 4: 12, 5: 3, 6: 3, 9: 2, 7: 2, 8: 1, 21: 1})" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# collison counters - (how many item have same full semantic id): amount of such sets\n", + "vals = Counter(item[-1][:-1] for item in items_with_tuples).values()\n", + "Counter(vals)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Counter({0: 11579,\n", + " 1: 384,\n", + " 2: 68,\n", + " 3: 24,\n", + " 4: 12,\n", + " 5: 9,\n", + " 6: 6,\n", + " 7: 4,\n", + " 8: 3,\n", + " 10: 1,\n", + " 11: 1,\n", + " 17: 1,\n", + " 12: 1,\n", + " 15: 1,\n", + " 19: 1,\n", + " 13: 1,\n", + " 14: 1,\n", + " 9: 1,\n", + " 18: 1,\n", + " 16: 1,\n", + " 20: 1})" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# dedup idxes\n", + "Counter(item[-1][4] for item in items_with_tuples)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# from sklearn import preprocessing\n", + "\n", + "# labels = df['asin']\n", + "\n", + "# le = preprocessing.LabelEncoder()\n", + 
"# targets = le.fit_transform(labels)\n", + "\n", + "# df['asin_numeric'] = targets\n", + "\n", + "# torch.save(df, './all_data.pt')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "gsrec", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/modeling/models/rqvae.py b/modeling/models/rqvae.py index b105d9e3..de4bbf58 100644 --- a/modeling/models/rqvae.py +++ b/modeling/models/rqvae.py @@ -1,18 +1,16 @@ from models.base import TorchModel import torch -import torch.nn as nn import torch -from tqdm import tqdm import faiss class RqVaeModel(TorchModel, config_name='rqvae'): def __init__( self, - all_data, + train_sampler, input_dim: int, hidden_dim: int, n_iter: int, @@ -43,9 +41,11 @@ def __init__( self._init_weights(initializer_range) - embeddings = torch.stack([entry['item.embed'] for entry in all_data._dataset]) - if self.should_init_codebooks: + if train_sampler is None: + raise AttributeError("Train sampler is None") + + embeddings = torch.stack([entry['item.embed'] for entry in train_sampler._dataset]) self.init_codebooks(embeddings) print('Codebooks initialized with Faiss Kmeans') self.should_init_codebooks = False @@ -53,7 +53,7 @@ def __init__( @classmethod def create_from_config(cls, config, **kwargs): return cls( - all_data=kwargs['train_sampler'], + train_sampler=kwargs.get('train_sampler'), input_dim=config['input_dim'], hidden_dim=config['hidden_dim'], n_iter=config['n_iter'], @@ -141,9 +141,12 @@ def train_pass(self, embeddings): def eval_pass(self, embeddings): ind_lists = [] - for cb in self.codebooks: - dist = 
torch.cdist(self.encoder(embeddings), cb) - ind_lists.append(dist.argmin(dim=-1).cpu().numpy()) + remainder = self.encoder(embeddings) + for codebook in self.codebooks: + codebook_indices = self.get_codebook_indices(remainder, codebook) + codebook_vectors = codebook[codebook_indices] + ind_lists.append(codebook_indices.cpu().numpy()) + remainder = remainder - codebook_vectors return zip(*ind_lists) def forward(self, inputs): diff --git a/modeling/models/tiger.py b/modeling/models/tiger.py new file mode 100644 index 00000000..98a852db --- /dev/null +++ b/modeling/models/tiger.py @@ -0,0 +1,96 @@ +import json +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modeling.utils import DEVICE +from models.base import BaseModel, TorchModel + +# TODO finish tiger model +class TigerModel(TorchModel, config_name='tiger'): + def __init__( + self, + rqvae_encoder, + emb_dim, + n_tokens, + n_codebooks, + nhead, + num_encoder_layers, + num_decoder_layers, + dim_feedforward, + dropout + ): + super().__init__() + + self.rqvae_encoder = rqvae_encoder + self.emb_dim = emb_dim + self.n_tokens = n_tokens + + self.position_embeddings = nn.Embedding(n_codebooks, emb_dim) + self.item_embeddings = nn.Embedding(n_tokens, emb_dim) + + self.transformer = nn.Transformer( + d_model=emb_dim, + nhead=nhead, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + dim_feedforward=dim_feedforward, + dropout=dropout + ) + + self.proj = nn.Linear(emb_dim, n_tokens) + + @classmethod + def create_from_config(cls, config, **kwargs): + rqvae_train_config = json.load(open(config['rqvae_train_config_path'])) + + rqvae_model = BaseModel.create_from_config(rqvae_train_config['model']).to(DEVICE) + rqvae_model.load_state_dict(torch.load(config['rqvae_checkpoint_path'], weights_only=True)) + rqvae_model.eval() + + return cls( + rqvae_encoder=rqvae_model, + emb_dim=config['emb_dim'], + n_tokens=config['n_tokens'], + n_codebooks=config['n_codebooks'], + 
nhead=config['nhead'], + num_encoder_layers=config['num_encoder_layers'], + num_decoder_layers=config['num_decoder_layers'], + dim_feedforward=config['dim_feedforward'], + dropout=config['dropout'] + ) + + def forward(self, user_item_history): + # Get item embeddings from RQVAE encoder + item_sequence = self.rqvae_encoder(user_item_history) + + # Convert item sequence to embeddings (embedding size is emb_dim) + item_embs = self.item_embeddings(item_sequence) + + # Add positional embeddings (positions are in the range [0, 3] for each tuple in the sequence) + positions = torch.arange(0, item_embs.size(1), device=item_embs.device).unsqueeze(0) + position_embs = self.position_embeddings(positions) + + # Add position embeddings to item embeddings + embeddings = item_embs + position_embs + + # Transformer expects the input to be in (seq_len, batch, embedding_dim) format + embeddings = embeddings.permute(1, 0, 2) # Convert to (seq_len, batch, emb_dim) + + # Create the target sequence for the transformer decoder + # You can shift the sequence for training as needed (e.g., teacher forcing) + target = embeddings.clone() # Use input embeddings as target for now + + # Pass through the transformer (using embeddings as both input and target) + transformer_output = self.transformer(embeddings, target) + + # Project the output back to token space (256 possible values for each codebook) + logits = self.proj(transformer_output) + + # Apply softmax to get probabilities (for cross-entropy loss) + return logits + + def compute_loss(self, logits, target): + # Compute cross-entropy loss + loss = F.cross_entropy(logits.view(-1, self.n_tokens), target.view(-1)) + return loss diff --git a/modeling/rqvae/__init__.py b/modeling/rqvae/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/collisions.py b/modeling/rqvae/collisions.py similarity index 100% rename from src/collisions.py rename to modeling/rqvae/collisions.py diff --git a/src/rqvae_data.py 
b/modeling/rqvae/rqvae_data.py similarity index 88% rename from src/rqvae_data.py rename to modeling/rqvae/rqvae_data.py index 19fb23d1..170c4e18 100644 --- a/src/rqvae_data.py +++ b/modeling/rqvae/rqvae_data.py @@ -75,18 +75,10 @@ def get_data(cached=True): with torch.no_grad(): df["embeddings"] = df["combined_text"].progress_apply(encode_text) else: - df = torch.load("../data/df_with_embs.pt", weights_only=False) + df = torch.load("../data/Beauty/all_data.pt", weights_only=False) return df -def get_cb_tuples(rqvae, embeddings): - ind_lists = [] - for cb in rqvae.codebooks: - dist = torch.cdist(rqvae.encoder(embeddings), cb) - ind_lists.append(dist.argmin(dim=-1).cpu().numpy()) - - return zip(*ind_lists) - def search_similar_items(items_with_tuples, clust2search, max_cnt=5): random.shuffle(items_with_tuples) diff --git a/review.md b/review.md index 572accc1..5977c185 100644 --- a/review.md +++ b/review.md @@ -2,7 +2,12 @@ ## Todos +- posterior collapse (как будто все сваливается в один индекс в кодбуке) (fixed eval code) +- обязательно использование reinit unused clusters! +- в Amazon датасете пофиг на rating? получается учитываются только implicit действия? +- TODO какой базовый класс использовать для e2e модели? (LastPred?) - TODO backward on mean loss? in `RqVae` +- TODO имя для модели (tiger) ## Links @@ -14,7 +19,7 @@ ## Todo -### Train +### Train full encoder-decoder - На чем обучать? То есть на каких данных запускать backward pass? 
- train model diff --git a/src/main.ipynb b/src/main.ipynb deleted file mode 100644 index f6d7b18b..00000000 --- a/src/main.ipynb +++ /dev/null @@ -1,171 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "\n", - "from rqvae_data import get_data\n", - "\n", - "df = get_data()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "embs = torch.stack(df[\"embeddings\"].tolist())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "embs.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from rqvae import RQVAE\n", - "\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "\n", - "\n", - "rqvae = RQVAE(\n", - " input_dim=embs.shape[1],\n", - " hidden_dim=128,\n", - " beta=0.25,\n", - " codebook_sizes=[256] * 4,\n", - " should_init_codebooks=True,\n", - " should_reinit_unused_clusters=False,\n", - ").to(device)\n", - "\n", - "\n", - "embs_dict = {\"embedding\": embs.to(device)}\n", - "\n", - "rqvae.forward(embs_dict)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "from collisions import dedup\n", - "from rqvae_data import get_cb_tuples\n", - "\n", - "\n", - "cb_tuples = list(get_cb_tuples(rqvae, embs_dict[\"embedding\"]))\n", - "items_with_tuples = list(zip(df[\"asin\"], df[\"title\"].fillna(\"unknown\"), cb_tuples))\n", - "items_with_tuples = dedup(items_with_tuples)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from rqvae_data import search_similar_items\n", - "\n", - "\n", - "for i in range(230, 240):\n", - " sim = search_similar_items(items_with_tuples, (i,), 10)\n", - " if len(sim) == 0:\n", - " continue\n", - " 
print(i)\n", - " for asin, item, clust_tuple in sim:\n", - " if 'nail' in item.lower():\n", - " print(f\"{item=} {clust_tuple=}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from collections import Counter\n", - "import matplotlib.pyplot as plt\n", - "\n", - "\n", - "plt.hist(Counter(item[-1][:-1] for item in items_with_tuples).values(), bins=100)\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "len(set(item[-1] for item in items_with_tuples))" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn import preprocessing\n", - "\n", - "labels = df['asin']\n", - "\n", - "le = preprocessing.LabelEncoder()\n", - "targets = le.fit_transform(labels)\n", - "\n", - "df['asin_numeric'] = targets" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "torch.save(df, './all_data.pt')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "gsrec", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.11" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/src/main.py b/src/main.py deleted file mode 100644 index 21d47550..00000000 --- a/src/main.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch -import typing -import random -import os - -from rqvae import RQVAE - -device = torch.device("cuda") - - -def get_cb_tuples(embeddings): - ind_lists = [] - for cb in rqvae.codebooks: - dist = torch.cdist(rqvae.encoder(embeddings), cb) - 
ind_lists.append(dist.argmin(dim=-1).cpu().numpy()) - - return zip(*ind_lists) - - -def search_similar_items(items_with_tuples): - random.shuffle(items_with_tuples) - clust2search = (585,) - cnt = 0 - for item, clust_tuple in items_with_tuples: - if clust_tuple[: len(clust2search)] == clust2search: - print(item, clust_tuple) - cnt += 1 - if cnt >= 5: - break - - -# TODO: add T5 sentence construction from huggingface - - -embs = {"embedding": []} - -rqvae = RQVAE( - input_dim=200, - hidden_dim=128, - beta=0.25, - codebook_sizes=[256] * 4, - should_init_codebooks=False, - should_reinit_unused_clusters=False, -).to(device) - -rqvae.forward(embs)