From 678863a66408e5111d6898803fa9a250abd3550e Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 27 Sep 2022 17:45:23 +0200 Subject: [PATCH 1/2] lowe tolerance --- tests/test_models_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index 80055c1a10f8..5bc8ec7bc64f 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -232,7 +232,7 @@ def test_gradient_checkpointing(self): # compare the output and parameters gradients self.assertTrue((output_checkpointed == output_not_checkpointed).all()) for name in grad_checkpointed: - self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5)) + self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=8e-3)) # TODO(Patrick) - Re-add this test after having cleaned up LDM From 6b48c809468e1962b495912ba45260dcc7a6cf95 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 29 Sep 2022 11:47:02 +0200 Subject: [PATCH 2/2] put model in eval mode --- tests/test_models_unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index 5bc8ec7bc64f..94a186d1c06a 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -199,7 +199,7 @@ def prepare_init_args_and_inputs_for_common(self): def test_gradient_checkpointing(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) + model = self.model_class(**init_dict).eval() model.to(torch_device) out = model(**inputs_dict).sample @@ -232,7 +232,7 @@ def test_gradient_checkpointing(self): # compare the output and parameters gradients self.assertTrue((output_checkpointed == output_not_checkpointed).all()) for name in grad_checkpointed: - self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=8e-3)) + self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5)) # TODO(Patrick) - Re-add this test after having cleaned up LDM