
Determinism implementation in nnunet #2871

Open

Luugaaa wants to merge 3 commits into MIC-DKFZ:master from Luugaaa:determinism_implementation
Conversation

Luugaaa commented Jul 18, 2025

Hi MIC-DKFZ team,

This PR introduces updates to enable fully deterministic training in nnU-Net, which is crucial for reproducibility in research. The changes include adding a deterministic flag, implementing a seed_everything function, and ensuring the data augmentation pipeline is correctly seeded. I've been working on a similar fix for batchgenerators and believe these changes will work together to make the entire training process reproducible.

Problem

Achieving deterministic behavior in a multi-process environment can be tricky. Even with seeding, sources of randomness can persist, especially in the data augmentation pipeline. The original nnU-Net trainer used a non-deterministic data loader by default and didn't have a straightforward way to enforce reproducibility across all components, including PyTorch, NumPy, and the data loading workers. This could lead to slight variations in training results, even with the same initial seed.

Solution

To address this, I've implemented the following changes:

  1. deterministic Flag: A deterministic boolean flag has been added to the nnUNetTrainer's __init__ method. When set to True, it activates all the changes needed for reproducible training.
  2. seed_everything Function: A new helper function, seed_everything, is called when the deterministic flag is active. This function sets the seeds for random, numpy, and torch, and also configures cudnn for deterministic behavior to eliminate sources of randomness from the GPU.
  3. Seeded Data Loader: The get_dataloaders method now checks the deterministic flag. If True, it uses MultiThreadedAugmenter and passes a unique, generated seed to each worker process, ensuring the data augmentation pipeline is fully deterministic and produces the same results in the same order on every run. If False, it continues to use the default NonDetMultiThreadedAugmenter. The flag is likewise checked in get_training_transforms to enable or disable benchmarking.
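The seeding described in points 1–3 might look roughly like the sketch below. The function name seed_everything comes from the PR description, but the exact body and the worker_seeds helper are my assumptions, not the PR's actual code:

```python
import random

import numpy as np
import torch


def seed_everything(seed: int = 12345) -> None:
    """Seed all common sources of randomness for reproducible training."""
    random.seed(seed)                  # Python's built-in RNG
    np.random.seed(seed)               # NumPy's global RNG
    torch.manual_seed(seed)            # PyTorch CPU (and CUDA) RNGs
    torch.cuda.manual_seed_all(seed)   # explicit, in case CUDA initializes later
    torch.backends.cudnn.deterministic = True  # force deterministic cudnn kernels
    torch.backends.cudnn.benchmark = False     # disable non-deterministic autotuning


def worker_seeds(base_seed: int, num_workers: int) -> list:
    """Derive one reproducible (but distinct) seed per augmentation worker."""
    rng = np.random.RandomState(base_seed)
    return [int(s) for s in rng.randint(0, 2**31, num_workers)]
```

Passing each element of worker_seeds(...) to its own MultiThreadedAugmenter worker gives every process a fixed, distinct random stream, so runs with the same base seed replay the exact same augmentations in the same order.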

How It's Tested

To validate these changes, I've developed a determinism test pipeline that sets up a dummy dataset and runs the entire preprocessing and training pipeline for two epochs with both the 2d and 3d_fullres configurations. The test works as follows:

  1. A dummy dataset is generated, and nnU-Net's preprocessing is run.
  2. The nnUNetTrainer is instantiated with deterministic=True, and a short training session is run. The final checkpoint is saved.
  3. This process is repeated a second time, with a new trainer instance but the same deterministic settings.
  4. The two resulting checkpoints are then compared byte-for-byte to ensure that the model weights and training/validation losses are identical.
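Step 4 can be sketched as a tensor-by-tensor comparison of the saved checkpoints. This is a minimal illustration, assuming the model weights live under a "network_weights" key as in nnU-Net v2 checkpoints; the actual test in the PR may differ:

```python
import torch


def checkpoints_identical(path_a: str, path_b: str) -> bool:
    """Return True if two checkpoints hold exactly identical model weights."""
    ckpt_a = torch.load(path_a, map_location="cpu", weights_only=False)
    ckpt_b = torch.load(path_b, map_location="cpu", weights_only=False)
    sd_a = ckpt_a["network_weights"]
    sd_b = ckpt_b["network_weights"]
    if sd_a.keys() != sd_b.keys():
        return False
    # torch.equal requires exact element-wise equality, not closeness,
    # which is the right bar for a determinism test.
    return all(torch.equal(sd_a[k], sd_b[k]) for k in sd_a)
```

The same pattern extends to the stored training/validation loss lists; exact equality (rather than an allclose tolerance) is what distinguishes true determinism from mere run-to-run similarity.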

With these changes, the trainer now passes this test, confirming that the training process is fully reproducible when the deterministic flag is enabled. This should be a significant help for researchers who need to ensure their results are perfectly reproducible. The changes are self-contained and don't affect the default behavior of the trainer.

Notes:

  • Unfortunately, I was only able to test on CPU; there is a slight chance CUDA behaves differently, even when seeded.
  • I've run these tests with my deterministic version of batchgeneratorsv2; please see this PR

I hope these changes are helpful. Thanks for maintaining this great project, and I look forward to your feedback!

Post Scriptum

Here is a partial output of the determinism test pipeline:

...
2025-07-18 16:41:32.447456: Unable to plot network architecture:
2025-07-18 16:41:32.447667: module 'torch.onnx' has no attribute '_optimize_trace'
2025-07-18 16:41:32.467535: 
2025-07-18 16:41:32.467720: Epoch 0
2025-07-18 16:41:32.468179: Current learning rate: 0.01
2025-07-18 16:41:36.461223: train_loss -0.3063
2025-07-18 16:41:36.461562: val_loss -0.5951
2025-07-18 16:41:36.461810: Pseudo dice [np.float32(0.9153)]
2025-07-18 16:41:36.461934: Epoch time: 3.99 s
2025-07-18 16:41:36.462014: Yayy! New best EMA pseudo Dice: 0.9153000116348267
2025-07-18 16:41:36.814825: 
2025-07-18 16:41:36.815006: Epoch 1
2025-07-18 16:41:36.815104: Current learning rate: 0.00536
2025-07-18 16:41:40.652067: train_loss -0.4862
2025-07-18 16:41:40.652299: val_loss -0.7087
2025-07-18 16:41:40.652573: Pseudo dice [np.float32(0.9449)]
2025-07-18 16:41:40.652686: Epoch time: 3.84 s
2025-07-18 16:41:40.652763: Yayy! New best EMA pseudo Dice: 0.9182000160217285
2025-07-18 16:41:41.638668: Training done.

...
2025-07-18 16:41:45.915659: Unable to plot network architecture:
2025-07-18 16:41:45.915884: module 'torch.onnx' has no attribute '_optimize_trace'
2025-07-18 16:41:45.934981: 
2025-07-18 16:41:45.935152: Epoch 0
2025-07-18 16:41:45.935414: Current learning rate: 0.01
2025-07-18 16:41:50.681323: train_loss -0.3063
2025-07-18 16:41:50.681636: val_loss -0.5951
2025-07-18 16:41:50.681895: Pseudo dice [np.float32(0.9153)]
2025-07-18 16:41:50.682026: Epoch time: 4.75 s
2025-07-18 16:41:50.682148: Yayy! New best EMA pseudo Dice: 0.9153000116348267
2025-07-18 16:41:51.014654: 
2025-07-18 16:41:51.014840: Epoch 1
2025-07-18 16:41:51.014943: Current learning rate: 0.00536
2025-07-18 16:41:55.595253: train_loss -0.4862
2025-07-18 16:41:55.595432: val_loss -0.7087
2025-07-18 16:41:55.595511: Pseudo dice [np.float32(0.9449)]
2025-07-18 16:41:55.595585: Epoch time: 4.58 s
2025-07-18 16:41:55.595646: Yayy! New best EMA pseudo Dice: 0.9182000160217285
2025-07-18 16:41:56.397161: Training done.
Finished training run: run2_2d

--- Comparing checkpoints ---
Model weights are identical: True
Training losses are identical: True
Validation losses are identical: True
PASSED: The training process for '2d' is deterministic.
========== FINISHED TEST FOR CONFIGURATION: 2D ==========

...
2025-07-18 16:42:01.014618: Unable to plot network architecture:
2025-07-18 16:42:01.014909: module 'torch.onnx' has no attribute '_optimize_trace'
2025-07-18 16:42:01.044617: 
2025-07-18 16:42:01.044771: Epoch 0
2025-07-18 16:42:01.045026: Current learning rate: 0.01
2025-07-18 16:42:46.660177: train_loss -0.1917
2025-07-18 16:42:46.660746: val_loss -0.494
2025-07-18 16:42:46.660841: Pseudo dice [np.float32(0.8838)]
2025-07-18 16:42:46.661161: Epoch time: 45.62 s
2025-07-18 16:42:46.661241: Yayy! New best EMA pseudo Dice: 0.8838000297546387
2025-07-18 16:42:47.174816: 
2025-07-18 16:42:47.174935: Epoch 1
2025-07-18 16:42:47.175024: Current learning rate: 0.00536
2025-07-18 16:43:32.276000: train_loss -0.4466
2025-07-18 16:43:32.277203: val_loss -0.6225
2025-07-18 16:43:32.277357: Pseudo dice [np.float32(0.9225)]
2025-07-18 16:43:32.277469: Epoch time: 45.1 s
2025-07-18 16:43:32.277545: Yayy! New best EMA pseudo Dice: 0.8877000212669373
2025-07-18 16:43:33.464729: Training done.

...
2025-07-18 16:43:38.122881: Unable to plot network architecture:
2025-07-18 16:43:38.123304: module 'torch.onnx' has no attribute '_optimize_trace'
2025-07-18 16:43:38.144184: 
2025-07-18 16:43:38.144306: Epoch 0
2025-07-18 16:43:38.144516: Current learning rate: 0.01
2025-07-18 16:44:25.789200: train_loss -0.1917
2025-07-18 16:44:25.789970: val_loss -0.494
2025-07-18 16:44:25.790100: Pseudo dice [np.float32(0.8838)]
2025-07-18 16:44:25.790242: Epoch time: 47.65 s
2025-07-18 16:44:25.790316: Yayy! New best EMA pseudo Dice: 0.8838000297546387
2025-07-18 16:44:26.444514: 
2025-07-18 16:44:26.444639: Epoch 1
2025-07-18 16:44:26.444731: Current learning rate: 0.00536
2025-07-18 16:45:12.907103: train_loss -0.4466
2025-07-18 16:45:12.909298: val_loss -0.6225
2025-07-18 16:45:12.909446: Pseudo dice [np.float32(0.9225)]
2025-07-18 16:45:12.909584: Epoch time: 46.46 s
2025-07-18 16:45:12.909658: Yayy! New best EMA pseudo Dice: 0.8877000212669373
2025-07-18 16:45:14.147640: Training done.
Finished training run: run2_3d_fullres

--- Comparing checkpoints ---
Model weights are identical: True
Training losses are identical: True
Validation losses are identical: True
PASSED: The training process for '3d_fullres' is deterministic.
========== FINISHED TEST FOR CONFIGURATION: 3D_FULLRES ==========


=========================
  DETERMINISM TEST REPORT
=========================
Configuration '2d': ✅ PASSED
Configuration '3d_fullres': ✅ PASSED
=========================

Luugaaa (Author) commented Jul 28, 2025

Note: I was finally able to test the determinism fix on CUDA. Although reproducibility is greatly improved, the training is not fully deterministic. I've pinpointed the source of the non-determinism to something happening in the backward pass from batch_id = 1. I don't plan to investigate further.

@FabianIsensee FabianIsensee self-assigned this Oct 16, 2025