diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index b02f62d02d03..45d483ce7b1d 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -59,10 +59,17 @@ def __post_init__(self): if not len(class_fields): raise ValueError(f"{self.__class__.__name__} has no fields.") - for field in class_fields: - v = getattr(self, field.name) - if v is not None: - self[field.name] = v + first_field = getattr(self, class_fields[0].name) + other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) + + if other_fields_are_none and isinstance(first_field, dict): + for key, value in first_field.items(): + self[key] = value + else: + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v def __delitem__(self, *args, **kwargs): raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") diff --git a/tests/test_outputs.py b/tests/test_outputs.py new file mode 100644 index 000000000000..3c3054c885a1 --- /dev/null +++ b/tests/test_outputs.py @@ -0,0 +1,60 @@ +import unittest +from dataclasses import dataclass +from typing import List, Union + +import numpy as np + +import PIL.Image +from diffusers.utils.outputs import BaseOutput + + +@dataclass +class CustomOutput(BaseOutput): + images: Union[List[PIL.Image.Image], np.ndarray] + + +class ConfigTester(unittest.TestCase): + def test_outputs_single_attribute(self): + outputs = CustomOutput(images=np.random.rand(1, 3, 4, 4)) + + # check every way of getting the attribute + assert isinstance(outputs.images, np.ndarray) + assert outputs.images.shape == (1, 3, 4, 4) + assert isinstance(outputs["images"], np.ndarray) + assert outputs["images"].shape == (1, 3, 4, 4) + assert isinstance(outputs[0], np.ndarray) + assert outputs[0].shape == (1, 3, 4, 4) + + # test with a non-tensor attribute + outputs = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))]) + + # check every way of getting the attribute + assert isinstance(outputs.images, list) + assert isinstance(outputs.images[0], PIL.Image.Image) + assert isinstance(outputs["images"], list) + assert isinstance(outputs["images"][0], PIL.Image.Image) + assert isinstance(outputs[0], list) + assert isinstance(outputs[0][0], PIL.Image.Image) + + def test_outputs_dict_init(self): + # test output reinitialization with a `dict` for compatibility with `accelerate` + outputs = CustomOutput({"images": np.random.rand(1, 3, 4, 4)}) + + # check every way of getting the attribute + assert isinstance(outputs.images, np.ndarray) + assert outputs.images.shape == (1, 3, 4, 4) + assert isinstance(outputs["images"], np.ndarray) + assert outputs["images"].shape == (1, 3, 4, 4) + assert isinstance(outputs[0], np.ndarray) + assert outputs[0].shape == (1, 3, 4, 4) + + # test with a non-tensor attribute + outputs = CustomOutput({"images": [PIL.Image.new("RGB", (4, 4))]}) + + # check every way of getting the attribute + assert isinstance(outputs.images, list) + assert isinstance(outputs.images[0], PIL.Image.Image) + assert isinstance(outputs["images"], list) + assert isinstance(outputs["images"][0], PIL.Image.Image) + assert isinstance(outputs[0], list) + assert isinstance(outputs[0][0], PIL.Image.Image)