diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 64b38e2646..7ab3a47eba 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -282,7 +282,7 @@ def _iteration( if batchdata is None: raise ValueError("must provide batch data for current iteration.") - d_input = self.prepare_batch(batchdata, engine.state.device) + d_input = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) batch_size = self.data_loader.batch_size g_input = self.g_prepare_batch(batch_size, self.latent_shape, engine.state.device, engine.non_blocking) g_output = self.g_inferer(g_input, self.g_network)