Skip to content
7 changes: 6 additions & 1 deletion src/diffusers/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,12 @@ def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwarg
hidden_states = self.conv2(hidden_states)

if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor.contiguous())
# Only use contiguous() during training to avoid DDP gradient stride mismatch warning.
# In inference mode (eval or no_grad), skip contiguous() for better performance, especially on CPU.
# Issue: https://github.com/huggingface/diffusers/issues/12975
if self.training:
input_tensor = input_tensor.contiguous()
input_tensor = self.conv_shortcut(input_tensor)

output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

Expand Down
3 changes: 3 additions & 0 deletions tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,9 @@ def test_inference_batch_single_identical(self):
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=5e-1)

def test_save_load_dduf(self):
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)

@is_flaky()
def test_model_cpu_offload_forward_pass(self):
super().test_inference_batch_single_identical(expected_max_diff=8e-4)
Expand Down
3 changes: 3 additions & 0 deletions tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ def test_float16_inference(self):
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=1e-2)

def test_save_load_dduf(self):
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)


@slow
@require_torch_accelerator
Expand Down
Loading