Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/diffusers/utils/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,17 @@ def __post_init__(self):
if not len(class_fields):
raise ValueError(f"{self.__class__.__name__} has no fields.")

for field in class_fields:
v = getattr(self, field.name)
if v is not None:
self[field.name] = v
first_field = getattr(self, class_fields[0].name)
other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

if other_fields_are_none and isinstance(first_field, dict):
for key, value in first_field.items():
self[key] = value
else:
for field in class_fields:
v = getattr(self, field.name)
if v is not None:
self[field.name] = v

def __delitem__(self, *args, **kwargs):
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
Expand Down
60 changes: 60 additions & 0 deletions tests/test_outputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import unittest
from dataclasses import dataclass
from typing import List, Union

import numpy as np

import PIL.Image
from diffusers.utils.outputs import BaseOutput


@dataclass
class CustomOutput(BaseOutput):
images: Union[List[PIL.Image.Image], np.ndarray]


class ConfigTester(unittest.TestCase):
def test_outputs_single_attribute(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice!

outputs = CustomOutput(images=np.random.rand(1, 3, 4, 4))

# check every way of getting the attribute
assert isinstance(outputs.images, np.ndarray)
assert outputs.images.shape == (1, 3, 4, 4)
assert isinstance(outputs["images"], np.ndarray)
Copy link
Member Author

@anton-l anton-l Sep 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The transformers version (initial commit from this PR) was failing on this line
cc @patrickvonplaten

assert outputs["images"].shape == (1, 3, 4, 4)
assert isinstance(outputs[0], np.ndarray)
assert outputs[0].shape == (1, 3, 4, 4)

# test with a non-tensor attribute
outputs = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))])

# check every way of getting the attribute
assert isinstance(outputs.images, list)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool that this is all tested now!

assert isinstance(outputs.images[0], PIL.Image.Image)
assert isinstance(outputs["images"], list)
assert isinstance(outputs["images"][0], PIL.Image.Image)
assert isinstance(outputs[0], list)
assert isinstance(outputs[0][0], PIL.Image.Image)

def test_outputs_dict_init(self):
# test output reinitialization with a `dict` for compatibility with `accelerate`
outputs = CustomOutput({"images": np.random.rand(1, 3, 4, 4)})

# check every way of getting the attribute
assert isinstance(outputs.images, np.ndarray)
Copy link
Member Author

@anton-l anton-l Sep 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the pre-PR diffusers version was failing on this line

assert outputs.images.shape == (1, 3, 4, 4)
assert isinstance(outputs["images"], np.ndarray)
assert outputs["images"].shape == (1, 3, 4, 4)
assert isinstance(outputs[0], np.ndarray)
assert outputs[0].shape == (1, 3, 4, 4)

# test with a non-tensor attribute
outputs = CustomOutput({"images": [PIL.Image.new("RGB", (4, 4))]})

# check every way of getting the attribute
assert isinstance(outputs.images, list)
assert isinstance(outputs.images[0], PIL.Image.Image)
assert isinstance(outputs["images"], list)
assert isinstance(outputs["images"][0], PIL.Image.Image)
assert isinstance(outputs[0], list)
assert isinstance(outputs[0][0], PIL.Image.Image)