From b1ff16aab7704dc954d1c91b6e9e2eb0a9fab8bb Mon Sep 17 00:00:00 2001 From: Charles Bensimon Date: Fri, 29 Sep 2023 12:18:57 +0200 Subject: [PATCH 1/5] Make BaseOutput dataclasses picklable --- src/diffusers/utils/outputs.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index 37b11561d1e1..632175df5508 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -16,7 +16,7 @@ """ from collections import OrderedDict -from dataclasses import fields +from dataclasses import fields, is_dataclass from typing import Any, Tuple import numpy as np @@ -101,6 +101,13 @@ def __setitem__(self, key, value): # Don't call self.__setattr__ to avoid recursion errors super().__setattr__(key, value) + def __reduce__(self): + if not is_dataclass(self): + return super().__reduce__() + callable, _args, state, istate, dstate = super().__reduce__() + args = tuple(getattr(self, field.name) for field in fields(self)) + return callable, args, state, istate, dstate + def to_tuple(self) -> Tuple[Any]: """ Convert self to a tuple containing all the attributes/keys that are not `None`. From 69c4df6c40329b2c4d7cb917a4924a103f72bad1 Mon Sep 17 00:00:00 2001 From: cbensimon Date: Fri, 29 Sep 2023 10:25:37 +0000 Subject: [PATCH 2/5] make style --- src/diffusers/utils/outputs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index 632175df5508..41ae1c4447de 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -107,7 +107,7 @@ def __reduce__(self): callable, _args, state, istate, dstate = super().__reduce__() args = tuple(getattr(self, field.name) for field in fields(self)) return callable, args, state, istate, dstate - + def to_tuple(self) -> Tuple[Any]: """ Convert self to a tuple containing all the attributes/keys that are not `None`. From 0a99c3e65ae7524557225432aadb3807cfc7d9d3 Mon Sep 17 00:00:00 2001 From: cbensimon Date: Fri, 29 Sep 2023 10:40:29 +0000 Subject: [PATCH 3/5] Test --- tests/others/test_outputs.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/others/test_outputs.py b/tests/others/test_outputs.py index 50cbd1d54ee4..492e71f0ba31 100644 --- a/tests/others/test_outputs.py +++ b/tests/others/test_outputs.py @@ -1,3 +1,4 @@ +import pickle as pkl import unittest from dataclasses import dataclass from typing import List, Union @@ -58,3 +59,13 @@ def test_outputs_dict_init(self): assert isinstance(outputs["images"][0], PIL.Image.Image) assert isinstance(outputs[0], list) assert isinstance(outputs[0][0], PIL.Image.Image) + + def test_outputs_serialization(self): + outputs_orig = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))]) + serialized = pkl.dumps(outputs_orig) + outputs_copy = pkl.loads(serialized) + + # Check original and copy are equal + assert dir(outputs_orig) == dir(outputs_copy) + assert dict(outputs_orig) == dict(outputs_copy) + assert vars(outputs_orig) == vars(outputs_copy) From 6b81b4a137f34cba147d0e694b6bd104aafa7dac Mon Sep 17 00:00:00 2001 From: cbensimon Date: Fri, 29 Sep 2023 10:49:19 +0000 Subject: [PATCH 4/5] Empty commit From 85a8f6f33ea83023dfe1616aee4feeceffba60fc Mon Sep 17 00:00:00 2001 From: cbensimon Date: Fri, 29 Sep 2023 11:01:08 +0000 Subject: [PATCH 5/5] Simpler and safer --- src/diffusers/utils/outputs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index 41ae1c4447de..802c699eb9cc 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -104,9 +104,9 @@ def __setitem__(self, key, value): def __reduce__(self): if not is_dataclass(self): return super().__reduce__() - callable, _args, state, istate, dstate = super().__reduce__() + callable, _args, *remaining = super().__reduce__() args = tuple(getattr(self, field.name) for field in fields(self)) - return callable, args, state, istate, dstate + return callable, args, *remaining def to_tuple(self) -> Tuple[Any]: """