From 03740de3221d6ee87171cb314e61fc874016e0c2 Mon Sep 17 00:00:00 2001 From: William Berman Date: Sun, 12 Feb 2023 18:51:07 -0800 Subject: [PATCH 1/4] unet check length input --- src/diffusers/models/unet_2d.py | 13 +++++++++- src/diffusers/models/unet_2d_condition.py | 21 ++++++++++++++++ .../versatile_diffusion/modeling_text_unet.py | 25 +++++++++++++++++++ .../test_stable_diffusion.py | 2 +- .../test_stable_diffusion_depth.py | 2 +- .../test_stable_diffusion_inpaint.py | 2 +- .../test_stable_diffusion_v_pred.py | 2 +- 7 files changed, 62 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 35f5dc34574c..6b67990dac06 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -94,7 +94,7 @@ def __init__( mid_block_scale_factor: float = 1, downsample_padding: int = 1, act_fn: str = "silu", - attention_head_dim: int = 8, + attention_head_dim: Optional[int] = 8, norm_num_groups: int = 32, norm_eps: float = 1e-5, resnet_time_scale_shift: str = "default", @@ -107,6 +107,17 @@ def __init__( self.sample_size = sample_size time_embed_dim = block_out_channels[0] * 4 + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + # input self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index ba2c09b297b9..daffef765077 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -150,6 +150,27 @@ def __init__( self.sample_size = sample_size + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 806875a9a507..2c5b717ac861 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -236,6 +236,31 @@ def __init__( self.sample_size = sample_size + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + "Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`:" + f" {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + "Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`:" + f" {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + "Must provide the same number of `only_cross_attention` as `down_block_types`." + f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + "Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`:" + f" {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = LinearMultiDim( diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py index 6db22626c9e0..9c9a0a018629 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -56,7 +56,7 @@ def get_dummy_components(self): up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), cross_attention_dim=32, # SD2-specific config below - attention_head_dim=(2, 4, 8, 8), + attention_head_dim=(2, 4), use_linear_projection=True, ) scheduler = DDIMScheduler( diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index 49d0e4dbdca2..a36b7925b624 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -65,7 +65,7 @@ def get_dummy_components(self): down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), cross_attention_dim=32, - attention_head_dim=(2, 4, 8, 8), + attention_head_dim=(2, 4), use_linear_projection=True, ) scheduler = PNDMScheduler(skip_prk_steps=True) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py index 0ec47a510b50..58bdd465e422 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py @@ -47,7 +47,7 @@ def get_dummy_components(self): up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), cross_attention_dim=32, # SD2-specific config below - attention_head_dim=(2, 4, 8, 8), + attention_head_dim=(2, 4), use_linear_projection=True, ) scheduler = PNDMScheduler(skip_prk_steps=True) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py index a06f13632a3a..39cc546f6774 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py @@ -56,7 +56,7 @@ def dummy_cond_unet(self): up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), cross_attention_dim=32, # SD2-specific config below - attention_head_dim=(2, 4, 8, 8), + attention_head_dim=(2, 4), use_linear_projection=True, ) return model From ec6eb488fd7b5c5f9a802f2ca317f3e5ff17e660 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 13 Feb 2023 12:31:40 +0200 Subject: [PATCH 2/4] prep test file for changes --- .../test_stable_diffusion_depth.py | 40 ++++++++++--------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index a36b7925b624..666d094891f9 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -41,7 +41,7 @@ ) from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device from diffusers.utils.import_utils import is_accelerate_available -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils.testing_utils import require_torch_gpu, print_tensor_test from ...test_pipelines_common import PipelineTesterMixin @@ -284,11 +284,12 @@ def test_stable_diffusion_depth2img_default_case(self): image = pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 32, 32, 3) - if torch_device == "mps": - expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546]) - else: - expected_slice = np.array([0.6374, 0.5039, 0.4199, 0.4819, 0.5563, 0.4617, 0.4028, 0.5381, 0.4711]) +# assert image.shape == (1, 32, 32, 3) +# if torch_device == "mps": +# expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546]) +# else: + print_tensor_test(image_slice) + expected_slice = np.array([0.6374, 0.5039, 0.4199, 0.4819, 0.5563, 0.4617, 0.4028, 0.5381, 0.4711]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -306,10 +307,11 @@ def test_stable_diffusion_depth2img_negative_prompt(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 32, 32, 3) - if torch_device == "mps": - expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335]) - else: - expected_slice = np.array([0.6332, 0.5167, 0.3911, 0.4446, 0.5971, 0.4619, 0.3821, 0.5323, 0.4621]) +# if torch_device == "mps": +# expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335]) +# else: + print_tensor_test(image_slice) + expected_slice = np.array([0.6332, 0.5167, 0.3911, 0.4446, 0.5971, 0.4619, 0.3821, 0.5323, 0.4621]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -328,10 +330,11 @@ def test_stable_diffusion_depth2img_multiple_init_images(self): assert image.shape == (2, 32, 32, 3) - if torch_device == "mps": - expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551]) - else: - expected_slice = np.array([0.6248, 0.5206, 0.6007, 0.6749, 0.5022, 0.6442, 0.5352, 0.4140, 0.4681]) +# if torch_device == "mps": +# expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551]) +# else: + print_tensor_test(image_slice) + expected_slice = np.array([0.6248, 0.5206, 0.6007, 0.6749, 0.5022, 0.6442, 0.5352, 0.4140, 0.4681]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -383,10 +386,11 @@ def test_stable_diffusion_depth2img_pil(self): image = pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] - if torch_device == "mps": - expected_slice = np.array([0.53232, 0.47015, 0.40868, 0.45651, 0.4891, 0.4668, 0.4287, 0.48822, 0.47439]) - else: - expected_slice = np.array([0.6374, 0.5039, 0.4199, 0.4819, 0.5563, 0.4617, 0.4028, 0.5381, 0.4711]) +# if torch_device == "mps": +# expected_slice = np.array([0.53232, 0.47015, 0.40868, 0.45651, 0.4891, 0.4668, 0.4287, 0.48822, 0.47439]) +# else: + expected_slice = np.array([0.6374, 0.5039, 0.4199, 0.4819, 0.5563, 0.4617, 0.4028, 0.5381, 0.4711]) + print_tensor_test(image_slice) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 From 64d1727c264408167c0e3eb68999da5de96d5e56 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 13 Feb 2023 12:37:20 +0200 Subject: [PATCH 3/4] correct all tests --- test_corrections.txt | 4 ++++ .../stable_diffusion_2/test_stable_diffusion_depth.py | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) create mode 100644 test_corrections.txt diff --git a/test_corrections.txt b/test_corrections.txt new file mode 100644 index 000000000000..227ed1b8bed2 --- /dev/null +++ b/test_corrections.txt @@ -0,0 +1,4 @@ +tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py;StableDiffusionDepth2ImgPipelineFastTests;test_stable_diffusion_depth2img_default_case;expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716]) +tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py;StableDiffusionDepth2ImgPipelineFastTests;test_stable_diffusion_depth2img_multiple_init_images;expected_slice = np.array([0.6267, 0.5232, 0.6001, 0.6738, 0.5029, 0.6429, 0.5364, 0.4159, 0.4674]) +tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py;StableDiffusionDepth2ImgPipelineFastTests;test_stable_diffusion_depth2img_negative_prompt;expected_slice = np.array([0.6296, 0.5125, 0.3890, 0.4456, 0.5955, 0.4621, 0.3810, 0.5310, 0.4626]) +tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py;StableDiffusionDepth2ImgPipelineFastTests;test_stable_diffusion_depth2img_pil;expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716]) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index 666d094891f9..61d314442210 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -289,7 +289,7 @@ def test_stable_diffusion_depth2img_default_case(self): # expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546]) # else: print_tensor_test(image_slice) - expected_slice = np.array([0.6374, 0.5039, 0.4199, 0.4819, 0.5563, 0.4617, 0.4028, 0.5381, 0.4711]) + expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -311,7 +311,7 @@ def test_stable_diffusion_depth2img_negative_prompt(self): # expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335]) # else: print_tensor_test(image_slice) - expected_slice = np.array([0.6332, 0.5167, 0.3911, 0.4446, 0.5971, 0.4619, 0.3821, 0.5323, 0.4621]) + expected_slice = np.array([0.6296, 0.5125, 0.3890, 0.4456, 0.5955, 0.4621, 0.3810, 0.5310, 0.4626]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -334,7 +334,7 @@ def test_stable_diffusion_depth2img_multiple_init_images(self): # expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551]) # else: print_tensor_test(image_slice) - expected_slice = np.array([0.6248, 0.5206, 0.6007, 0.6749, 0.5022, 0.6442, 0.5352, 0.4140, 0.4681]) + expected_slice = np.array([0.6267, 0.5232, 0.6001, 0.6738, 0.5029, 0.6429, 0.5364, 0.4159, 0.4674]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -389,7 +389,7 @@ def test_stable_diffusion_depth2img_pil(self): # if torch_device == "mps": # expected_slice = np.array([0.53232, 0.47015, 0.40868, 0.45651, 0.4891, 0.4668, 0.4287, 0.48822, 0.47439]) # else: - expected_slice = np.array([0.6374, 0.5039, 0.4199, 0.4819, 0.5563, 0.4617, 0.4028, 0.5381, 0.4711]) + expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716]) print_tensor_test(image_slice) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 From dbf78f515b51112abc79f88e1aecc5afa435ec3c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 13 Feb 2023 12:41:19 +0200 Subject: [PATCH 4/4] clean up --- test_corrections.txt | 4 -- .../test_stable_diffusion_depth.py | 40 +++++++++---------- 2 files changed, 18 insertions(+), 26 deletions(-) delete mode 100644 test_corrections.txt diff --git a/test_corrections.txt b/test_corrections.txt deleted file mode 100644 index 227ed1b8bed2..000000000000 --- a/test_corrections.txt +++ /dev/null @@ -1,4 +0,0 @@ -tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py;StableDiffusionDepth2ImgPipelineFastTests;test_stable_diffusion_depth2img_default_case;expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716]) -tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py;StableDiffusionDepth2ImgPipelineFastTests;test_stable_diffusion_depth2img_multiple_init_images;expected_slice = np.array([0.6267, 0.5232, 0.6001, 0.6738, 0.5029, 0.6429, 0.5364, 0.4159, 0.4674]) -tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py;StableDiffusionDepth2ImgPipelineFastTests;test_stable_diffusion_depth2img_negative_prompt;expected_slice = np.array([0.6296, 0.5125, 0.3890, 0.4456, 0.5955, 0.4621, 0.3810, 0.5310, 0.4626]) -tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py;StableDiffusionDepth2ImgPipelineFastTests;test_stable_diffusion_depth2img_pil;expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716]) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index 61d314442210..14fcbef2d164 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -41,7 +41,7 @@ ) from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device from diffusers.utils.import_utils import is_accelerate_available -from diffusers.utils.testing_utils import require_torch_gpu, print_tensor_test +from diffusers.utils.testing_utils import require_torch_gpu from ...test_pipelines_common import PipelineTesterMixin @@ -284,12 +284,11 @@ def test_stable_diffusion_depth2img_default_case(self): image = pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] -# assert image.shape == (1, 32, 32, 3) -# if torch_device == "mps": -# expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546]) -# else: - print_tensor_test(image_slice) - expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716]) + assert image.shape == (1, 32, 32, 3) + if torch_device == "mps": + expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546]) + else: + expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -307,11 +306,10 @@ def test_stable_diffusion_depth2img_negative_prompt(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 32, 32, 3) -# if torch_device == "mps": -# expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335]) -# else: - print_tensor_test(image_slice) - expected_slice = np.array([0.6296, 0.5125, 0.3890, 0.4456, 0.5955, 0.4621, 0.3810, 0.5310, 0.4626]) + if torch_device == "mps": + expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335]) + else: + expected_slice = np.array([0.6296, 0.5125, 0.3890, 0.4456, 0.5955, 0.4621, 0.3810, 0.5310, 0.4626]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -330,11 +328,10 @@ def test_stable_diffusion_depth2img_multiple_init_images(self): assert image.shape == (2, 32, 32, 3) -# if torch_device == "mps": -# expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551]) -# else: - print_tensor_test(image_slice) - expected_slice = np.array([0.6267, 0.5232, 0.6001, 0.6738, 0.5029, 0.6429, 0.5364, 0.4159, 0.4674]) + if torch_device == "mps": + expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551]) + else: + expected_slice = np.array([0.6267, 0.5232, 0.6001, 0.6738, 0.5029, 0.6429, 0.5364, 0.4159, 0.4674]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -386,11 +383,10 @@ def test_stable_diffusion_depth2img_pil(self): image = pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] -# if torch_device == "mps": -# expected_slice = np.array([0.53232, 0.47015, 0.40868, 0.45651, 0.4891, 0.4668, 0.4287, 0.48822, 0.47439]) -# else: - expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716]) - print_tensor_test(image_slice) + if torch_device == "mps": + expected_slice = np.array([0.53232, 0.47015, 0.40868, 0.45651, 0.4891, 0.4668, 0.4287, 0.48822, 0.47439]) + else: + expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3