From d57695bd6c6a5ff5c5976df8dfb24a6cc4ffabf6 Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 19 Sep 2022 16:00:58 +0200 Subject: [PATCH 1/4] Fix BaseOutput initialization from dict --- src/diffusers/utils/outputs.py | 38 ++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index b02f62d02d03..708271bf9d86 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -59,10 +59,40 @@ def __post_init__(self): if not len(class_fields): raise ValueError(f"{self.__class__.__name__} has no fields.") - for field in class_fields: - v = getattr(self, field.name) - if v is not None: - self[field.name] = v + first_field = getattr(self, class_fields[0].name) + other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) + + if other_fields_are_none and not is_tensor(first_field): + if isinstance(first_field, dict): + iterator = first_field.items() + first_field_iterator = True + else: + try: + iterator = iter(first_field) + first_field_iterator = True + except TypeError: + first_field_iterator = False + + # if we provided an iterator as first field and the iterator is a (key, value) iterator + # set the associated fields + if first_field_iterator: + for element in iterator: + if ( + not isinstance(element, (list, tuple)) + or not len(element) == 2 + or not isinstance(element[0], str) + ): + break + setattr(self, element[0], element[1]) + if element[1] is not None: + self[element[0]] = element[1] + elif first_field is not None: + self[class_fields[0].name] = first_field + else: + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v def __delitem__(self, *args, **kwargs): raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") From 9f04ca8db0973feb6c385cd530e192dee252c656 Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 19 Sep 2022 16:01:26 +0200 Subject: [PATCH 2/4] style --- src/diffusers/utils/outputs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index 708271bf9d86..d8e695db59b0 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -78,9 +78,9 @@ def __post_init__(self): if first_field_iterator: for element in iterator: if ( - not isinstance(element, (list, tuple)) - or not len(element) == 2 - or not isinstance(element[0], str) + not isinstance(element, (list, tuple)) + or not len(element) == 2 + or not isinstance(element[0], str) ): break setattr(self, element[0], element[1]) From 690ac0516fa446e69a8c25fcde3a682f7cb9683f Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 19 Sep 2022 18:56:33 +0200 Subject: [PATCH 3/4] Simplify post-init, add tests --- src/diffusers/utils/outputs.py | 29 ++-------------- tests/test_outputs.py | 60 ++++++++++++++++++++++++++++++++++ tests/test_pipelines.py | 9 +++++ 3 files changed, 72 insertions(+), 26 deletions(-) create mode 100644 tests/test_outputs.py diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index d8e695db59b0..45d483ce7b1d 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -62,32 +62,9 @@ def __post_init__(self): first_field = getattr(self, class_fields[0].name) other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) - if other_fields_are_none and not is_tensor(first_field): - if isinstance(first_field, dict): - iterator = first_field.items() - first_field_iterator = True - else: - try: - iterator = iter(first_field) - first_field_iterator = True - except TypeError: - first_field_iterator = False - - # if we provided an iterator as first field and the iterator is a (key, value) iterator - # set the associated fields - if first_field_iterator: - for element in iterator: - if ( - not isinstance(element, (list, tuple)) - or not len(element) == 2 - or not isinstance(element[0], str) - ): - break - setattr(self, element[0], element[1]) - if element[1] is not None: - self[element[0]] = element[1] - elif first_field is not None: - self[class_fields[0].name] = first_field + if other_fields_are_none and isinstance(first_field, dict): + for key, value in first_field.items(): + self[key] = value else: for field in class_fields: v = getattr(self, field.name) diff --git a/tests/test_outputs.py b/tests/test_outputs.py new file mode 100644 index 000000000000..3c3054c885a1 --- /dev/null +++ b/tests/test_outputs.py @@ -0,0 +1,60 @@ +import unittest +from dataclasses import dataclass +from typing import List, Union + +import numpy as np + +import PIL.Image +from diffusers.utils.outputs import BaseOutput + + +@dataclass +class CustomOutput(BaseOutput): + images: Union[List[PIL.Image.Image], np.ndarray] + + +class ConfigTester(unittest.TestCase): + def test_outputs_single_attribute(self): + outputs = CustomOutput(images=np.random.rand(1, 3, 4, 4)) + + # check every way of getting the attribute + assert isinstance(outputs.images, np.ndarray) + assert outputs.images.shape == (1, 3, 4, 4) + assert isinstance(outputs["images"], np.ndarray) + assert outputs["images"].shape == (1, 3, 4, 4) + assert isinstance(outputs[0], np.ndarray) + assert outputs[0].shape == (1, 3, 4, 4) + + # test with a non-tensor attribute + outputs = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))]) + + # check every way of getting the attribute + assert isinstance(outputs.images, list) + assert isinstance(outputs.images[0], PIL.Image.Image) + assert isinstance(outputs["images"], list) + assert isinstance(outputs["images"][0], PIL.Image.Image) + assert isinstance(outputs[0], list) + assert isinstance(outputs[0][0], PIL.Image.Image) + + def test_outputs_dict_init(self): + # test output reinitialization with a `dict` for compatibility with `accelerate` + outputs = CustomOutput({"images": np.random.rand(1, 3, 4, 4)}) + + # check every way of getting the attribute + assert isinstance(outputs.images, np.ndarray) + assert outputs.images.shape == (1, 3, 4, 4) + assert isinstance(outputs["images"], np.ndarray) + assert outputs["images"].shape == (1, 3, 4, 4) + assert isinstance(outputs[0], np.ndarray) + assert outputs[0].shape == (1, 3, 4, 4) + + # test with a non-tensor attribute + outputs = CustomOutput({"images": [PIL.Image.new("RGB", (4, 4))]}) + + # check every way of getting the attribute + assert isinstance(outputs.images, list) + assert isinstance(outputs.images[0], PIL.Image.Image) + assert isinstance(outputs["images"], list) + assert isinstance(outputs["images"][0], PIL.Image.Image) + assert isinstance(outputs[0], list) + assert isinstance(outputs[0][0], PIL.Image.Image) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 102a55a93e4b..931b01137b08 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -837,6 +837,15 @@ def test_output_format(self): assert isinstance(images, list) assert isinstance(images[0], PIL.Image.Image) + @slow + def test_output_dict(self): + from diffusers import DDIMPipeline + + pipeline = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32") + outputs = pipeline(num_inference_steps=2) + + print(outputs["images"]) + @slow def test_ddpm_cifar10(self): model_id = "google/ddpm-cifar10-32" From 8f37ca5634b4f1333ddbaa28523d90255cd63fb5 Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 19 Sep 2022 18:56:53 +0200 Subject: [PATCH 4/4] remove debug --- tests/test_pipelines.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 931b01137b08..102a55a93e4b 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -837,15 +837,6 @@ def test_output_format(self): assert isinstance(images, list) assert isinstance(images[0], PIL.Image.Image) - @slow - def test_output_dict(self): - from diffusers import DDIMPipeline - - pipeline = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32") - outputs = pipeline(num_inference_steps=2) - - print(outputs["images"]) - @slow def test_ddpm_cifar10(self): model_id = "google/ddpm-cifar10-32"