
New batching #308

Open

michaelteboul wants to merge 6 commits into main from new_batching

Conversation

@michaelteboul
Contributor

Description of the goal of the PR

Changes this PR introduces:

  • New method to construct the batch, to improve performance on GPU
  • Add a 4th iter_batch method for Item2vec/Prod2vec models

@gemini-code-assist
Contributor

Summary of Changes

Hello @michaelteboul, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a substantial improvement to the data handling and training pipeline by integrating TensorFlow's tf.data.Dataset for more efficient batch processing, particularly beneficial for GPU utilization. It also expands the data generation capabilities to support new model types like Item2vec/Prod2vec by providing a dedicated method for creating pair-wised data. These changes collectively aim to enhance the performance and versatility of the basket models.

Highlights

  • Refactored Batching Mechanism: The data loading and batching process within the fit method has been significantly overhauled to leverage TensorFlow's tf.data.Dataset API. This change introduces more efficient data pipelining, including unbatching, shuffling, batching, and prefetching, which is expected to improve GPU performance.
  • New Data Generation Method for Item2vec/Prod2vec: A new get_pair_wised_data_from_trip_index method has been added to basket_dataset.py. This method generates pair-wised data from trips, suitable for models like Item2vec or Prod2vec, by creating item pairs from the same basket.
  • Updated Model Instantiation: The instantiate method in base_basket_model.py now accepts n_users as an argument, allowing models to be initialized with user-specific information.
  • Streamlined Training and Evaluation Loops: The training and validation loops in fit have been adjusted to integrate with the new tf.data.Dataset pipeline. The loss calculation and callback logging were updated to reflect average loss per batch, and the evaluate method now directly uses self.evaluate for metrics when available.
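The batching refactor above chains tf.data stages (from_generator → unbatch → shuffle → batch → prefetch). As a rough mental model of what those stages do to the sample stream, here is a pure-Python sketch; all names and the toy generator are illustrative stand-ins, not the PR's actual code, and prefetch is omitted since it only overlaps data preparation with compute.

```python
import random
from itertools import islice

def trip_generator():
    # Each yielded list stands in for the samples one trip contributes.
    yield [1, 2, 3]
    yield [4, 5]
    yield [6, 7, 8, 9]

def unbatch(trips):
    # Like tf.data's unbatch(): flatten per-trip lists into one sample stream.
    for trip in trips:
        yield from trip

def shuffle(samples, buffer_size, seed=0):
    # Like tf.data's shuffle(): keep a bounded buffer, emit random elements.
    rng = random.Random(seed)
    buffer = []
    for sample in samples:
        buffer.append(sample)
        if len(buffer) >= buffer_size:
            yield buffer.pop(rng.randrange(len(buffer)))
    while buffer:
        yield buffer.pop(rng.randrange(len(buffer)))

def batch(samples, batch_size):
    # Like tf.data's batch(): regroup the stream into fixed-size batches.
    it = iter(samples)
    while chunk := list(islice(it, batch_size)):
        yield chunk

batches = list(batch(shuffle(unbatch(trip_generator()), buffer_size=4), batch_size=3))
print(len(batches))  # 3 batches covering all 9 samples
```

Note that shuffling happens after unbatching, so samples are mixed across trip boundaries before being regrouped into GPU-sized batches.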


Changelog
  • choice_learn/basket_models/base_basket_model.py
    • Added n_users parameter to the instantiate method call.
    • Replaced manual batch iteration with a tf.data.Dataset pipeline for training, incorporating from_generator, unbatch, shuffle, batch, and prefetch.
    • Modified training loop to calculate and log average batch loss for callbacks.
    • Adjusted validation loop to conditionally use self.evaluate for metrics or iterate batches for loss calculation.
    • Changed average_on_batch to average_on_trip in the evaluate method for the basketwise-nll metric.
  • choice_learn/basket_models/data/basket_dataset.py
    • Updated get_one_vs_all_augmented_data_from_trip_index to initialize 'Future purchases' with zeros instead of an empty array.
    • Refactored iter_batch to yield data per trip based on data_method, removing the internal batching and buffering logic.
    • Introduced a new method get_pair_wised_data_from_trip_index to generate item-pair data for models like Prod2vec.
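To illustrate the kind of output a pair-wise generator produces, here is a minimal sketch. `pairs_from_basket` is a hypothetical stand-in for `get_pair_wised_data_from_trip_index` (not its actual signature): it pairs each item of a basket with every other item in the same basket, optionally subsampling contexts per item.

```python
import numpy as np

def pairs_from_basket(basket, max_pairs_per_item=None, rng=None):
    """Emit (target, context) pairs of distinct items from one basket."""
    rng = rng if rng is not None else np.random.default_rng(0)
    pairs = []
    n = len(basket)
    for i in range(n):
        # Every other position in the basket is a candidate context.
        context = [j for j in range(n) if j != i]
        if max_pairs_per_item is not None and len(context) > max_pairs_per_item:
            # Subsample contexts to bound the number of pairs per item.
            context = rng.choice(context, size=max_pairs_per_item, replace=False)
        for j in context:
            pairs.append((basket[i], basket[j]))
    return pairs

print(pairs_from_basket([10, 20, 30]))
# [(10, 20), (10, 30), (20, 10), (20, 30), (30, 10), (30, 20)]
```

Without subsampling, a basket of n items yields n·(n-1) ordered pairs, which is exactly the training signal Item2vec/Prod2vec-style models consume.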

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a new data batching mechanism using tf.data.Dataset, which is a great step towards improving performance and simplifying the data pipeline. It also adds a new data preparation method for prod2vec models. The overall changes are positive, but I've found a critical issue that would cause a runtime error, and a few areas for improvement regarding code consistency and logging behavior during training. Please see my detailed comments.

from_logits=False,
epsilon=epsilon_eval,
- average_on_batch=True,
+ average_on_trip=True,

critical

The NegativeLogLikeliHood metric does not accept an average_on_trip argument. The correct argument name is average_on_batch. This change will cause a TypeError at runtime.

Suggested change:
- average_on_trip=True,
+ average_on_batch=True,
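The failure mode flagged here is easy to reproduce: Python raises a TypeError at call time as soon as a call supplies a keyword the callee does not declare. The class below is a toy stand-in, not choice_learn's actual NegativeLogLikeliHood.

```python
class NegLogLik:
    # Toy metric: declares average_on_batch (as the review says the real
    # metric does) but no average_on_trip parameter.
    def __init__(self, from_logits=False, epsilon=1e-10, average_on_batch=True):
        self.average_on_batch = average_on_batch

NegLogLik(average_on_batch=True)      # accepted: the keyword is declared
try:
    NegLogLik(average_on_trip=True)   # rejected before __init__ even runs
except TypeError as err:
    print(type(err).__name__)  # TypeError
```

Because the error only surfaces when the evaluate path is exercised, it can slip past a test suite that never calls it, which is why the reviewer flags it as critical.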

Comment on lines +721 to +726
inner_range = tqdm.tqdm(
inner_range,
total=int(trip_dataset.n_samples / np.max([batch_size, 1])),
position=1,
leave=False,
)

medium

The new data loading pipeline is a great improvement! However, it seems to have introduced a couple of minor regressions in user feedback during training compared to the previous implementation:

  1. The training progress bar is now displayed unconditionally, whereas before it was only shown if verbose > 0.
  2. The progress bar's description is no longer updated with the current training loss during the epoch, which was helpful for monitoring.

Consider reintroducing the verbose check and updating the description to provide better feedback during training. You could do something like this:

# ... after inner_range tf.data.Dataset is defined
if verbose > 0:
    inner_range = tqdm.tqdm(
        inner_range,
        total=int(trip_dataset.n_samples / np.max([batch_size, 1])),
        position=1,
        leave=False,
    )

# ... inside the training loop, after current_avg_loss is calculated
if verbose > 0 and isinstance(inner_range, tqdm.tqdm):
    inner_range.set_description(f"Train Loss: {current_avg_loss:.4f}")

for i in range(length_trip):
    context_indices = [j for j in range(length_trip) if j != i]
    if len(context_indices) > max_pairs_per_item:
        context_indices = random.sample(context_indices, max_pairs_per_item)  # nosec

medium

For consistency with the rest of the codebase which uses numpy.random for random operations (e.g., np.random.permutation), it would be better to use np.random.choice here instead of random.sample. This also helps with reproducibility, as you can control the seed via np.random.seed() for the entire process.

Suggested change:
- context_indices = random.sample(context_indices, max_pairs_per_item)  # nosec
+ context_indices = np.random.choice(context_indices, size=max_pairs_per_item, replace=False)
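As a sketch of the suggested swap: with a single np.random.seed call, the sampled context indices become reproducible across the whole process. Variable names mirror the snippet above; the seed value is arbitrary.

```python
import numpy as np

context_indices = list(range(10))
max_pairs_per_item = 3

np.random.seed(42)  # one global seed controls every np.random call
first = np.random.choice(context_indices, size=max_pairs_per_item, replace=False)
np.random.seed(42)  # reseeding reproduces the exact same draw
second = np.random.choice(context_indices, size=max_pairs_per_item, replace=False)
print(first.tolist() == second.tolist())  # True
```

With `random.sample`, the same reproducibility would require seeding Python's `random` module separately, which is the consistency point the review makes.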

@github-actions

github-actions bot commented Feb 19, 2026

Coverage

Coverage Report for Python 3.9
File                              Stmts  Miss  Cover  Missing
choice_learn
   __init__.py                        2     0   100%
   tf_ops.py                         62     1    98%  283
choice_learn/basket_models
   __init__.py                        4     0   100%
   alea_carta.py                    148    22    85%  86–90, 92–96, 98–102, 106, 109, 131, 159, 308, 431–455
   base_basket_model.py             248    29    88%  115–116, 127, 145, 189, 259, 381, 489, 589–591, 680, 785, 793, 803, 851, 854–864, 925–928, 968–969
   basic_attention_model.py          89     4    96%  424, 427, 433, 440
   self_attention_model.py          133     9    93%  71, 73, 75, 450–454, 651
   shopper.py                       184     9    95%  130, 159, 325, 345, 360, 363, 377, 489, 618
choice_learn/basket_models/data
   __init__.py                        2     0   100%
   basket_dataset.py                193    55    72%  74–77, 295–297, 407, 540–576, 604–644, 667, 674–679, 749–760, 808
   preprocessing.py                  94    78    17%  43–45, 128–364
choice_learn/basket_models/datasets
   __init__.py                        3     0   100%
   badminton.py                      81     6    93%  62, 194–199, 247
   bakery.py                         38     3    92%  47, 51, 61
choice_learn/basket_models/utils
   __init__.py                        0     0   100%
   permutation.py                    22     1    95%  37
choice_learn/data
   __init__.py                        3     0   100%
   choice_dataset.py                649    33    95%  198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py                       241    23    90%  20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 582
   storage.py                       161     6    96%  22, 33, 51, 56, 61, 71
   store.py                          72    72     0%  3–275
choice_learn/datasets
   __init__.py                        4     0   100%
   base.py                          400    27    93%  42–43, 154, 714, 1147–1205
   expedia.py                       102    83    19%  37–301
   tafeng.py                         49     0   100%
choice_learn/datasets/data
   __init__.py                        0     0   100%
choice_learn/models
   __init__.py                       14     2    86%  15–16
   base_model.py                    335    35    90%  145, 187, 289, 297, 303, 312, 352, 356–357, 362, 391, 395–396, 413, 426, 434, 475–476, 485–486, 587, 589, 605, 609, 611, 734–735, 935, 939–953
   baseline_models.py                49     0   100%
   conditional_logit.py             269    26    90%  49, 52, 54, 85, 88, 91–95, 98–102, 136, 206, 212–216, 351, 388, 445, 520–526, 651, 685, 822, 826
   halo_mnl.py                      124     2    98%  186, 374
   latent_class_base_model.py       286    39    86%  55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py               62     6    90%  257–261, 296
   learning_mnl.py                   67     3    96%  157, 182, 188
   nested_logit.py                  291    12    96%  55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py                      132     6    95%  285, 360, 369, 374, 382, 432
   rumnet.py                        236     3    99%  748–751, 982
   simple_mnl.py                    139     6    96%  167, 275, 347, 355, 357, 359
   tastenet.py                       94     3    97%  142, 180, 188
choice_learn/toolbox
   __init__.py                        0     0   100%
   assortment_optimizer.py           27     6    78%  28–30, 93–95, 160–162
   gurobi_opt.py                    236   236     0%  3–675
   or_tools_opt.py                  230    11    95%  103, 107, 296–305, 315, 319, 607, 611
choice_learn/utils
   metrics.py                        85    43    49%  74, 126–130, 147–166, 176, 190–199, 211–232, 242
TOTAL                              5660   900    84%

Tests  Skipped  Failures  Errors  Time
222    0 💤     2 ❌      0 🔥    6m 37s ⏱️

@github-actions

github-actions bot commented Feb 19, 2026

Coverage

Coverage Report for Python 3.10
File                              Stmts  Miss  Cover  Missing
choice_learn
   __init__.py                        2     0   100%
   tf_ops.py                         62     1    98%  283
choice_learn/basket_models
   __init__.py                        4     0   100%
   alea_carta.py                    148    22    85%  86–90, 92–96, 98–102, 106, 109, 131, 159, 308, 431–455
   base_basket_model.py             248    29    88%  115–116, 127, 145, 189, 259, 381, 489, 589–591, 680, 785, 793, 803, 851, 854–864, 925–928, 968–969
   basic_attention_model.py          89     4    96%  424, 427, 433, 440
   self_attention_model.py          133     9    93%  71, 73, 75, 450–454, 651
   shopper.py                       184     9    95%  130, 159, 325, 345, 360, 363, 377, 489, 618
choice_learn/basket_models/data
   __init__.py                        2     0   100%
   basket_dataset.py                193    55    72%  74–77, 295–297, 407, 540–576, 604–644, 667, 674–679, 749–760, 808
   preprocessing.py                  94    78    17%  43–45, 128–364
choice_learn/basket_models/datasets
   __init__.py                        3     0   100%
   badminton.py                      81     6    93%  62, 194–199, 247
   bakery.py                         38     3    92%  47, 51, 61
choice_learn/basket_models/utils
   __init__.py                        0     0   100%
   permutation.py                    22     1    95%  37
choice_learn/data
   __init__.py                        3     0   100%
   choice_dataset.py                649    33    95%  198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py                       241    23    90%  20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 582
   storage.py                       161     6    96%  22, 33, 51, 56, 61, 71
   store.py                          72    72     0%  3–275
choice_learn/datasets
   __init__.py                        4     0   100%
   base.py                          400     5    99%  42–43, 153–154, 714
   expedia.py                       102    83    19%  37–301
   tafeng.py                         49     0   100%
choice_learn/datasets/data
   __init__.py                        0     0   100%
choice_learn/models
   __init__.py                       14     2    86%  15–16
   base_model.py                    335    36    89%  145, 187, 289, 297, 303, 312, 352, 356–357, 362, 391, 395–396, 413, 426, 434, 475–476, 485–486, 587, 589, 605, 609, 611, 734–735, 908, 935, 939–953
   baseline_models.py                49     0   100%
   conditional_logit.py             269    26    90%  49, 52, 54, 85, 88, 91–95, 98–102, 136, 206, 212–216, 351, 388, 445, 520–526, 651, 685, 822, 826
   halo_mnl.py                      124     2    98%  186, 374
   latent_class_base_model.py       286    39    86%  55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py               62     6    90%  257–261, 296
   learning_mnl.py                   67     3    96%  157, 182, 188
   nested_logit.py                  291    12    96%  55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py                      132     6    95%  285, 360, 369, 374, 382, 432
   rumnet.py                        236     3    99%  748–751, 982
   simple_mnl.py                    139     6    96%  167, 275, 347, 355, 357, 359
   tastenet.py                       94     3    97%  142, 180, 188
choice_learn/toolbox
   __init__.py                        0     0   100%
   assortment_optimizer.py           27     6    78%  28–30, 93–95, 160–162
   gurobi_opt.py                    238   238     0%  3–675
   or_tools_opt.py                  230    11    95%  103, 107, 296–305, 315, 319, 607, 611
choice_learn/utils
   metrics.py                        85    43    49%  74, 126–130, 147–166, 176, 190–199, 211–232, 242
TOTAL                              5662   881    84%

Tests  Skipped  Failures  Errors  Time
222    0 💤     0 ❌      0 🔥    5m 46s ⏱️

@github-actions

github-actions bot commented Feb 19, 2026

Coverage

Coverage Report for Python 3.11
File                              Stmts  Miss  Cover  Missing
choice_learn
   __init__.py                        2     0   100%
   tf_ops.py                         62     1    98%  283
choice_learn/basket_models
   __init__.py                        4     0   100%
   alea_carta.py                    148    22    85%  86–90, 92–96, 98–102, 106, 109, 131, 159, 308, 431–455
   base_basket_model.py             248    29    88%  115–116, 127, 145, 189, 259, 381, 489, 589–591, 680, 785, 793, 803, 851, 854–864, 925–928, 968–969
   basic_attention_model.py          89     4    96%  424, 427, 433, 440
   self_attention_model.py          133     9    93%  71, 73, 75, 450–454, 651
   shopper.py                       184     9    95%  130, 159, 325, 345, 360, 363, 377, 489, 618
choice_learn/basket_models/data
   __init__.py                        2     0   100%
   basket_dataset.py                193    55    72%  74–77, 295–297, 407, 540–576, 604–644, 667, 674–679, 749–760, 808
   preprocessing.py                  94    78    17%  43–45, 128–364
choice_learn/basket_models/datasets
   __init__.py                        3     0   100%
   badminton.py                      81     6    93%  62, 194–199, 247
   bakery.py                         38     3    92%  47, 51, 61
choice_learn/basket_models/utils
   __init__.py                        0     0   100%
   permutation.py                    22     1    95%  37
choice_learn/data
   __init__.py                        3     0   100%
   choice_dataset.py                649    33    95%  198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py                       241    23    90%  20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 582
   storage.py                       161     6    96%  22, 33, 51, 56, 61, 71
   store.py                          72    72     0%  3–275
choice_learn/datasets
   __init__.py                        4     0   100%
   base.py                          400     5    99%  42–43, 153–154, 714
   expedia.py                       102    83    19%  37–301
   tafeng.py                         49     0   100%
choice_learn/datasets/data
   __init__.py                        0     0   100%
choice_learn/models
   __init__.py                       14     2    86%  15–16
   base_model.py                    335    35    90%  145, 187, 289, 297, 303, 312, 352, 356–357, 362, 391, 395–396, 413, 426, 434, 475–476, 485–486, 587, 589, 605, 609, 611, 734–735, 935, 939–953
   baseline_models.py                49     0   100%
   conditional_logit.py             269    26    90%  49, 52, 54, 85, 88, 91–95, 98–102, 136, 206, 212–216, 351, 388, 445, 520–526, 651, 685, 822, 826
   halo_mnl.py                      124     2    98%  186, 374
   latent_class_base_model.py       286    39    86%  55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py               62     6    90%  257–261, 296
   learning_mnl.py                   67     3    96%  157, 182, 188
   nested_logit.py                  291    12    96%  55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py                      132     6    95%  285, 360, 369, 374, 382, 432
   rumnet.py                        236     3    99%  748–751, 982
   simple_mnl.py                    139     6    96%  167, 275, 347, 355, 357, 359
   tastenet.py                       94     3    97%  142, 180, 188
choice_learn/toolbox
   __init__.py                        0     0   100%
   assortment_optimizer.py           27     6    78%  28–30, 93–95, 160–162
   gurobi_opt.py                    238   238     0%  3–675
   or_tools_opt.py                  230    11    95%  103, 107, 296–305, 315, 319, 607, 611
choice_learn/utils
   metrics.py                        85    43    49%  74, 126–130, 147–166, 176, 190–199, 211–232, 242
TOTAL                              5662   880    84%

Tests  Skipped  Failures  Errors  Time
222    0 💤     0 ❌      0 🔥    7m 31s ⏱️

@github-actions

github-actions bot commented Feb 19, 2026

Coverage

Coverage Report for Python 3.12
File                              Stmts  Miss  Cover  Missing
choice_learn
   __init__.py                        2     0   100%
   tf_ops.py                         62     1    98%  283
choice_learn/basket_models
   __init__.py                        4     0   100%
   alea_carta.py                    148    22    85%  86–90, 92–96, 98–102, 106, 109, 131, 159, 308, 431–455
   base_basket_model.py             248    29    88%  115–116, 127, 145, 189, 259, 381, 489, 589–591, 680, 785, 793, 803, 851, 854–864, 925–928, 968–969
   basic_attention_model.py          89     4    96%  424, 427, 433, 440
   self_attention_model.py          133     9    93%  71, 73, 75, 450–454, 651
   shopper.py                       184     9    95%  130, 159, 325, 345, 360, 363, 377, 489, 618
choice_learn/basket_models/data
   __init__.py                        2     0   100%
   basket_dataset.py                193    55    72%  74–77, 295–297, 407, 540–576, 604–644, 667, 674–679, 749–760, 808
   preprocessing.py                  94    78    17%  43–45, 128–364
choice_learn/basket_models/datasets
   __init__.py                        3     0   100%
   badminton.py                      81     6    93%  62, 194–199, 247
   bakery.py                         38     3    92%  47, 53, 61
choice_learn/basket_models/utils
   __init__.py                        0     0   100%
   permutation.py                    22     1    95%  37
choice_learn/data
   __init__.py                        3     0   100%
   choice_dataset.py                649    33    95%  198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py                       241    23    90%  20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 582
   storage.py                       161     6    96%  22, 33, 51, 56, 61, 71
   store.py                          72    72     0%  3–275
choice_learn/datasets
   __init__.py                        4     0   100%
   base.py                          400     5    99%  42–43, 153–154, 714
   expedia.py                       102    83    19%  37–301
   tafeng.py                         49     0   100%
choice_learn/datasets/data
   __init__.py                        0     0   100%
choice_learn/models
   __init__.py                       14     2    86%  15–16
   base_model.py                    335    35    90%  145, 187, 289, 297, 303, 312, 352, 356–357, 362, 391, 395–396, 413, 426, 434, 475–476, 485–486, 587, 589, 605, 609, 611, 734–735, 935, 939–953
   baseline_models.py                49     0   100%
   conditional_logit.py             269    26    90%  49, 52, 54, 85, 88, 91–95, 98–102, 136, 206, 212–216, 351, 388, 445, 520–526, 651, 685, 822, 826
   halo_mnl.py                      124     2    98%  186, 374
   latent_class_base_model.py       286    39    86%  55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py               62     6    90%  257–261, 296
   learning_mnl.py                   67     3    96%  157, 182, 188
   nested_logit.py                  291    12    96%  55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py                      132     6    95%  285, 360, 369, 374, 382, 432
   rumnet.py                        236     3    99%  748–751, 982
   simple_mnl.py                    139     6    96%  167, 275, 347, 355, 357, 359
   tastenet.py                       94     3    97%  142, 180, 188
choice_learn/toolbox
   __init__.py                        0     0   100%
   assortment_optimizer.py           27     6    78%  28–30, 93–95, 160–162
   gurobi_opt.py                    238   238     0%  3–675
   or_tools_opt.py                  230    11    95%  103, 107, 296–305, 315, 319, 607, 611
choice_learn/utils
   metrics.py                        85    43    49%  74, 126–130, 147–166, 176, 190–199, 211–232, 242
TOTAL                              5662   880    84%

Tests  Skipped  Failures  Errors  Time
222    0 💤     0 ❌      0 🔥    7m 54s ⏱️


Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants