Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 186 additions & 0 deletions configs/train/letter.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
{
"experiment_name": "letter_data",
"best_metric": "validation/ndcg@20",
"train_epochs_num": 100,
"dataset": {
"type": "letter_full",
"path_to_data_dir": "../data",
"name": "Beauty_letter",
"max_sequence_length": 50,
"samplers": {
"type": "last_item_prediction",
"negative_sampler_type": "random"
},
"beauty_inter_json": "../../LETTER/data/Beauty/Beauty.inter.json"
},
"dataloader": {
"train": {
"type": "torch",
"batch_size": 256,
"batch_processor": {
"type": "letter",
"beauty_index_json": "../../LETTER/data/Beauty/Beauty.index.json",
"semantic_length": 4
},
"drop_last": true,
"shuffle": true
},
"validation": {
"type": "torch",
"batch_size": 256,
"batch_processor": {
"type": "letter",
"beauty_index_json": "../../LETTER/data/Beauty/Beauty.index.json",
"semantic_length": 4
},
"drop_last": false,
"shuffle": false
}
},
"model": {
"type": "tiger",
"rqvae_train_config_path": "../configs/train/rqvae_train_config.json",
"rqvae_checkpoint_path": "../checkpoints/rqvae_beauty_final_state.pth",
"embs_extractor_path": "../data/Beauty/rqvae/data_full.pt",
"sequence_prefix": "item",
"predictions_prefix": "logits",
"positive_prefix": "labels",
"labels_prefix": "labels",
"embedding_dim": 64,
"num_heads": 2,
"num_encoder_layers": 2,
"num_decoder_layers": 2,
"dim_feedforward": 256,
"dropout": 0.3,
"activation": "gelu",
"layer_norm_eps": 1e-9,
"initializer_range": 0.02
},
"optimizer": {
"type": "basic",
"optimizer": {
"type": "adam",
"lr": 0.001
},
"clip_grad_threshold": 5.0
},
"loss": {
"type": "composite",
"losses": [
{
"type": "ce",
"predictions_prefix": "logits",
"labels_prefix": "semantic.labels",
"weight": 1.0,
"output_prefix": "semantic_loss"
},
{
"type": "ce",
"predictions_prefix": "dedup.logits",
"labels_prefix": "dedup.labels",
"weight": 1.0,
"output_prefix": "dedup_loss"
}
],
"output_prefix": "loss"
},
"callback": {
"type": "composite",
"callbacks": [
{
"type": "metric",
"on_step": 1,
"loss_prefix": "loss"
},
{
"type": "validation",
"on_step": 1024,
"pred_prefix": "logits",
"labels_prefix": "labels",
"metrics": {
"ndcg@5": {
"type": "ndcg",
"k": 5
},
"ndcg@10": {
"type": "ndcg",
"k": 10
},
"ndcg@20": {
"type": "ndcg",
"k": 20
},
"recall@5": {
"type": "recall",
"k": 5
},
"recall@10": {
"type": "recall",
"k": 10
},
"recall@20": {
"type": "recall",
"k": 20
},
"coverage@5": {
"type": "coverage",
"k": 5
},
"coverage@10": {
"type": "coverage",
"k": 10
},
"coverage@20": {
"type": "coverage",
"k": 20
}
}
},
{
"type": "eval",
"on_step": 2048,
"pred_prefix": "logits",
"labels_prefix": "labels",
"metrics": {
"ndcg@5": {
"type": "ndcg",
"k": 5
},
"ndcg@10": {
"type": "ndcg",
"k": 10
},
"ndcg@20": {
"type": "ndcg",
"k": 20
},
"recall@5": {
"type": "recall",
"k": 5
},
"recall@10": {
"type": "recall",
"k": 10
},
"recall@20": {
"type": "recall",
"k": 20
},
"coverage@5": {
"type": "coverage",
"k": 5
},
"coverage@10": {
"type": "coverage",
"k": 10
},
"coverage@20": {
"type": "coverage",
"k": 20
}
}
}
]
}
}

60 changes: 60 additions & 0 deletions modeling/dataloader/batch_processors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import json
import re

import torch
from utils import MetaParent

Expand Down Expand Up @@ -43,3 +46,60 @@ def __call__(self, batch):
processed_batch[part] = torch.tensor(values, dtype=torch.long)

return processed_batch


class LetterBatchProcessor(BaseBatchProcessor, config_name='letter'):
    """Collates samples and augments sequence fields with semantic-id tokens.

    Each item id is expanded into its fixed-length list of semantic ids
    (loaded from the LETTER index JSON), producing extra
    ``semantic_<prefix>.ids`` / ``semantic_<prefix>.length`` tensors alongside
    the regular collated fields.
    """

    def __init__(self, mapping, semantic_length):
        # item id -> list of exactly `semantic_length` semantic ids
        self._mapping: dict[int, list[int]] = mapping
        # sequence fields that receive a semantic-id expansion
        self._prefixes = ['item', 'labels', 'positive', 'negative']
        self._semantic_length = semantic_length

    @classmethod
    def create_from_config(cls, config, **kwargs):
        """Build the processor from a config dict.

        Reads ``beauty_index_json`` (item id -> list of token strings, e.g.
        "<a_123>") and keeps only the integer embedded in every token.
        """
        with open(config["beauty_index_json"], "r") as handle:
            raw_index = json.load(handle)

        semantic_length = config["semantic_length"]

        parsed = {}
        for item_key, token_strings in raw_index.items():
            # Extract the first run of digits from each token string.
            codes = [int(re.search(r'\d+', token).group()) for token in token_strings]
            assert len(codes) == semantic_length
            parsed[int(item_key)] = codes

        return cls(mapping=parsed, semantic_length=semantic_length)

    def __call__(self, batch):
        """Collate a list of sample dicts into long tensors, adding semantic ids."""
        collated = {}

        # Concatenate every `<prefix>.ids` field across the batch; lengths are
        # gathered one entry per sample.
        for field in batch[0].keys():
            if not field.endswith('.ids'):
                continue
            prefix = field.split('.')[0]
            assert '{}.length'.format(prefix) in batch[0]

            ids, lengths = [], []
            for sample in batch:
                ids.extend(sample[f'{prefix}.ids'])
                lengths.append(sample[f'{prefix}.length'])
            collated[f'{prefix}.ids'] = ids
            collated[f'{prefix}.length'] = lengths

        # Expand the known sequence prefixes into their semantic-id form.
        for prefix in self._prefixes:
            if f"{prefix}.ids" not in collated:
                continue
            expanded = []
            for item_id in collated[f"{prefix}.ids"]:
                expanded.extend(self._mapping[item_id])
            collated[f"semantic_{prefix}.ids"] = expanded
            # NOTE(review): assumes each length entry is an int — confirm
            # against the dataset; list-valued lengths would be replicated here.
            collated[f"semantic_{prefix}.length"] = [
                n * self._semantic_length for n in collated[f"{prefix}.length"]
            ]

        return {name: torch.tensor(vals, dtype=torch.long) for name, vals in collated.items()}
75 changes: 73 additions & 2 deletions modeling/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,8 @@ def create_from_config(cls, config, **kwargs):
def flatten_item_sequence(cls, item_ids):
min_history_length = 3 # TODOPK make this configurable
histories = []
for i in range(min_history_length-1, len(item_ids)):
histories.append(item_ids[:i+1])
for i in range(min_history_length, len(item_ids)):
histories.append(item_ids[:i])
return histories

@classmethod
Expand Down Expand Up @@ -791,12 +791,31 @@ def create_from_config(cls, config, **kwargs):
max_item_id = max(max_item_id, max(item_ids))

assert len(item_ids) >= 5

# item_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# prefix_length: 5, 6, 7, 8, 9, 10
for prefix_length in range(5, len(item_ids) + 1):
# prefix = [1, 2, 3, 4, 5]
# prefix = [1, 2, 3, 4, 5, 6]
# prefix = [1, 2, 3, 4, 5, 6, 7]
# prefix = [1, 2, 3, 4, 5, 6, 7, 8]
# prefix = [1, 2, 3, 4, 5, 6, 7, 8, 9]
# prefix = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]



prefix = item_ids[
:prefix_length
] # TODOPK no sliding window, only incrmenting sequence from last 50 items

# prefix[:-2] = [1, 2, 3]
# prefix[:-2] = [1, 2, 3, 4]
# prefix[:-2] = [1, 2, 3, 4, 5]
# prefix[:-2] = [1, 2, 3, 4, 5, 6]
# prefix[:-2] = [1, 2, 3, 4, 5, 6, 7]
# prefix[:-2] = [1, 2, 3, 4, 5, 6, 7, 8]

train_dataset.append(
{
"user.ids": [user_id],
Expand All @@ -809,6 +828,7 @@ def create_from_config(cls, config, **kwargs):
set(prefix[:-2][-max_sequence_length:])
)

# item_ids[:-1] = [1, 2, 3, 4, 5, 6, 7, 8, 9]
validation_dataset.append(
{
"user.ids": [user_id],
Expand All @@ -820,6 +840,8 @@ def create_from_config(cls, config, **kwargs):
assert len(item_ids[:-1][-max_sequence_length:]) == len(
set(item_ids[:-1][-max_sequence_length:])
)

# item_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
test_dataset.append(
{
"user.ids": [user_id],
Expand Down Expand Up @@ -874,6 +896,55 @@ def create_from_config(cls, config, **kwargs):
)


class LetterFullDataset(ScientificFullDataset, config_name="letter_full"):
    """ScientificFullDataset variant that ingests LETTER-format interactions.

    Converts the LETTER ``*.inter.json`` file (user id -> list of item ids)
    into the whitespace-separated ``all_data.txt`` layout the parent dataset
    expects, then delegates dataset construction to the parent.
    """

    def __init__(
        self,
        train_sampler,
        validation_sampler,
        test_sampler,
        num_users,
        num_items,
        max_sequence_length,
    ):
        # NOTE(review): mirrors the parent's private attributes directly
        # instead of calling super().__init__ — keep in sync with
        # ScientificFullDataset if its attribute set changes.
        self._train_sampler = train_sampler
        self._validation_sampler = validation_sampler
        self._test_sampler = test_sampler
        self._num_users = num_users
        self._num_items = num_items
        self._max_sequence_length = max_sequence_length

    @classmethod
    def create_from_config(cls, config, **kwargs):
        """Materialize LETTER interactions as ``all_data.txt`` and build the dataset.

        Expects ``beauty_inter_json`` (path to the LETTER interaction file),
        ``path_to_data_dir`` and ``name`` in *config*. Writes one line per
        user: ``<user_id> <item_id_1> <item_id_2> ...``.
        """
        # `beauty_inter_json` is already a complete path — the previous
        # single-argument os.path.join was a no-op.
        user_interactions_path = config["beauty_inter_json"]
        with open(user_interactions_path, "r") as f:
            user_interactions = json.load(f)

        dir_path = os.path.join(config["path_to_data_dir"], config["name"])
        os.makedirs(dir_path, exist_ok=True)
        dataset_path = os.path.join(dir_path, "all_data.txt")

        logger.info(f"Saving data to {dataset_path}")

        # Map from LETTER format to Our format
        with open(dataset_path, "w") as f:
            for user_id, item_ids in user_interactions.items():
                items_repr = map(str, item_ids)
                f.write(f"{user_id} {' '.join(items_repr)}\n")

        # Parent reads the file just written and builds samplers/counters.
        dataset = ScientificFullDataset.create_from_config(config, **kwargs)

        return cls(
            train_sampler=dataset._train_sampler,
            validation_sampler=dataset._validation_sampler,
            test_sampler=dataset._test_sampler,
            num_users=dataset._num_users,
            num_items=dataset._num_items,
            max_sequence_length=dataset._max_sequence_length,
        )



class RqVaeDataset(BaseDataset, config_name='rqvae'):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ dependencies = [
"tensorboard>=2",
"torch>=2.7",
"transformers>=4.51",
"tqdm>=4",
"jupyter>=1",
]

[tool.uv.sources]
Expand Down