Merged
43 commits
92818ef
make tests deterministic
patrickvonplaten Jan 5, 2023
745b167
run slow tests
patrickvonplaten Jan 5, 2023
ec685c2
Merge branch 'main' into make_tests_deterministic
patrickvonplaten Jan 23, 2023
ae219a8
prepare for testing
patrickvonplaten Jan 23, 2023
7256866
Merge branch 'main' of https://github.com/huggingface/diffusers into …
patrickvonplaten Jan 24, 2023
d5d80b8
finish
patrickvonplaten Jan 24, 2023
2458069
refactor
patrickvonplaten Jan 24, 2023
d81dd8f
Merge branch 'main' of https://github.com/huggingface/diffusers into …
patrickvonplaten Jan 24, 2023
db9dc5d
add print statements
patrickvonplaten Jan 24, 2023
9279207
finish more
patrickvonplaten Jan 24, 2023
e4c7013
correct some test failures
patrickvonplaten Jan 24, 2023
7ed509a
more fixes
patrickvonplaten Jan 24, 2023
b62aa7e
set up to correct tests
patrickvonplaten Jan 24, 2023
87269e0
more corrections
patrickvonplaten Jan 24, 2023
b54dd9e
up
patrickvonplaten Jan 24, 2023
37cab0c
fix more
patrickvonplaten Jan 24, 2023
70786f5
more prints
patrickvonplaten Jan 24, 2023
32db721
add
patrickvonplaten Jan 24, 2023
5d7b5d8
up
patrickvonplaten Jan 24, 2023
a70197b
up
patrickvonplaten Jan 24, 2023
3ada1ae
up
patrickvonplaten Jan 24, 2023
93ea032
uP
patrickvonplaten Jan 24, 2023
ef45927
uP
patrickvonplaten Jan 24, 2023
2d26ab5
more fixes
patrickvonplaten Jan 25, 2023
6ca6e11
uP
patrickvonplaten Jan 25, 2023
4fffb40
up
patrickvonplaten Jan 25, 2023
4cb48c8
up
patrickvonplaten Jan 25, 2023
6971909
up
patrickvonplaten Jan 25, 2023
d133442
up
patrickvonplaten Jan 25, 2023
bfd3a3f
fix more
patrickvonplaten Jan 25, 2023
44a5f60
up
patrickvonplaten Jan 25, 2023
ecf9dd5
up
patrickvonplaten Jan 25, 2023
8722a34
clean tests
patrickvonplaten Jan 25, 2023
ab440a6
up
patrickvonplaten Jan 25, 2023
3def149
up
patrickvonplaten Jan 25, 2023
5f312d8
up
patrickvonplaten Jan 25, 2023
7c212ea
more fixes
patrickvonplaten Jan 25, 2023
c285a67
Apply suggestions from code review
patrickvonplaten Jan 25, 2023
a2802e5
make
patrickvonplaten Jan 25, 2023
35ee141
correct
patrickvonplaten Jan 25, 2023
87e172f
finish
patrickvonplaten Jan 25, 2023
b77cca9
Merge branch 'make_tests_deterministic' of https://github.com/hugging…
patrickvonplaten Jan 25, 2023
716d496
finish
patrickvonplaten Jan 25, 2023
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -32,6 +32,8 @@
title: Text-Guided Depth-to-Image
- local: using-diffusers/reusing_seeds
title: Reusing seeds for deterministic generation
- local: using-diffusers/reproducibility
title: Reproducibility
- local: using-diffusers/custom_pipeline_examples
title: Community Pipelines
- local: using-diffusers/contribute_pipeline
159 changes: 159 additions & 0 deletions docs/source/en/using-diffusers/reproducibility.mdx
@@ -0,0 +1,159 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Reproducibility

Before reading about reproducibility for Diffusers, it is strongly recommended to take a look at
[PyTorch's statement about reproducibility](https://pytorch.org/docs/stable/notes/randomness.html).

PyTorch states that
> *completely reproducible results are not guaranteed across PyTorch releases, individual commits, or different platforms.*
While one cannot expect bit-identical results across platforms, one can expect results to be reproducible
across releases and platforms within a certain tolerance. However, this tolerance varies strongly
depending on the diffusion pipeline and checkpoint.
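
Concretely, "reproducible within a tolerance" means comparing outputs up to a small maximum difference rather than exactly, as the tests in this PR do. A minimal sketch (the values and the tolerance here are made up for illustration):

```python
import numpy as np

# Two runs of the same pipeline on different hardware may differ slightly
image_a = np.array([0.1010, 0.0800, 0.0794])
image_b = np.array([0.1012, 0.0799, 0.0794])

# Reproducibility is therefore checked up to a tolerance rather than exactly
assert np.abs(image_a - image_b).max() < 1e-2
```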

In the following, we show how to best control sources of randomness for diffusion models.

## Inference

During inference, diffusion pipelines heavily rely on random sampling operations, such as creating the
Gaussian noise tensors to be denoised and adding noise to the scheduling step.

Let's have a look at an example. We run the [DDIM pipeline](./api/pipelines/ddim.mdx)
for just two inference steps and return a NumPy tensor to inspect the numerical values of the output.

```python
from diffusers import DDIMPipeline
import numpy as np

model_id = "google/ddpm-cifar10-32"

# load model and scheduler
ddim = DDIMPipeline.from_pretrained(model_id)

# run pipeline for just two steps and return numpy tensor
image = ddim(num_inference_steps=2, output_type="np").images
print(np.abs(image).sum())
```

Running the above prints a value of 1464.2076, but running it again prints a different
value of 1495.1768. What is going on here? Every time the pipeline is run, Gaussian noise
is created and step-wise denoised. To create the Gaussian noise with [`torch.randn`](https://pytorch.org/docs/stable/generated/torch.randn.html), a different random seed is taken every time, thus leading to a different result.
This is a desired property of diffusion pipelines, as it means that the pipeline can create a different random image every time it is run. In many cases, however, one would like to reproduce the exact same image from a certain
run, in which case an instance of a [PyTorch Generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) has to be passed:

```python
import torch
from diffusers import DDIMPipeline
import numpy as np

model_id = "google/ddpm-cifar10-32"

# load model and scheduler
ddim = DDIMPipeline.from_pretrained(model_id)

# create a generator for reproducibility
generator = torch.Generator(device="cpu").manual_seed(0)

# run pipeline for just two steps and return numpy tensor
image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
print(np.abs(image).sum())
```

Running the above always prints a value of 1491.1711, no matter how often the snippet is re-run, because we
define a generator object that is passed to all random functions of the pipeline.

If you run this code snippet on your specific hardware and PyTorch version, you should get a similar, if not identical, result.

<Tip>

It might be a bit unintuitive at first to pass `generator` objects to the pipelines instead of
just integer values representing the seed, but this is the recommended design when dealing with
probabilistic models in PyTorch, as generators are *random states* that advance as they are consumed and can thus be
passed to multiple pipelines in a sequence.

</Tip>
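
To make this concrete, here is a minimal sketch (reusing the DDIM checkpoint from above) of how a single generator's state advances across sequential pipeline calls, and how re-seeding restores the original result:

```python
import torch
from diffusers import DDIMPipeline

model_id = "google/ddpm-cifar10-32"
ddim = DDIMPipeline.from_pretrained(model_id)

generator = torch.Generator(device="cpu").manual_seed(0)

# The generator state advances with every random operation, so the two calls
# below consume different parts of one deterministic random sequence
first = ddim(num_inference_steps=2, output_type="np", generator=generator).images
second = ddim(num_inference_steps=2, output_type="np", generator=generator).images

# Re-seeding the generator reproduces the first image exactly
generator = torch.Generator(device="cpu").manual_seed(0)
first_again = ddim(num_inference_steps=2, output_type="np", generator=generator).images
```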

Great! Now we know how to write reproducible pipelines, but it gets a bit trickier since the above example only runs on the CPU. How do we also achieve reproducibility on GPU?
In short, one should not expect full reproducibility across different hardware when running pipelines on GPU,
as matrix multiplications are less deterministic on GPU than on CPU and diffusion pipelines tend to require
a lot of them. Let's see what we can do to keep the randomness within limits across
different GPU hardware.

To achieve maximum speed, it is recommended to create the generator directly on the GPU when running
the pipeline on the GPU:

```python
import torch
from diffusers import DDIMPipeline
import numpy as np

model_id = "google/ddpm-cifar10-32"

# load model and scheduler
ddim = DDIMPipeline.from_pretrained(model_id)
ddim.to("cuda")

# create a generator for reproducibility
generator = torch.Generator(device="cuda").manual_seed(0)

# run pipeline for just two steps and return numpy tensor
image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
print(np.abs(image).sum())
```

Running the above now prints a value of 1389.8634, even though we're using the exact same seed!
This is unfortunate, as it means the results we achieve on GPU cannot be reproduced on CPU.
Nevertheless, it is to be expected, since the GPU uses a different random number generator than the CPU.

To circumvent this problem, we created a [`randn_tensor`](#diffusers.utils.randn_tensor) function, which creates the random noise
on the CPU and then moves the tensor to GPU if necessary. The function is used everywhere inside the pipelines, allowing the user to **always** pass a CPU generator even if the pipeline is run on GPU:

```python
import torch
from diffusers import DDIMPipeline
import numpy as np

model_id = "google/ddpm-cifar10-32"

# load model and scheduler
ddim = DDIMPipeline.from_pretrained(model_id)
ddim.to("cuda")

# create a generator for reproducibility
generator = torch.manual_seed(0)

# run pipeline for just two steps and return numpy tensor
image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
print(np.abs(image).sum())
```

Running the above now prints a value of 1491.1713, much closer to the value of 1491.1711 when
the pipeline is fully run on the CPU.

<Tip>

As a consequence, we recommend always passing a CPU generator if reproducibility is important.
The loss of performance is often negligible, and the generated values will be much closer to those obtained
when the pipeline is fully run on the CPU.

</Tip>
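
For illustration, here is a minimal sketch of calling `randn_tensor` directly, assuming it accepts a shape tuple, a `generator`, and a `torch.device`, as documented under [Randomness utilities](#randn_tensor) below:

```python
import torch
from diffusers.utils import randn_tensor

shape = (1, 3, 32, 32)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Noise is sampled from the CPU generator and only then moved to `device`,
# so the values do not depend on the target device
generator = torch.Generator(device="cpu").manual_seed(0)
noise = randn_tensor(shape, generator=generator, device=device)

generator = torch.Generator(device="cpu").manual_seed(0)
noise_cpu = randn_tensor(shape, generator=generator, device=torch.device("cpu"))

assert torch.allclose(noise.cpu(), noise_cpu)
```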

Finally, we have noticed that more complex pipelines, such as [`UnCLIPPipeline`], are often extremely
susceptible to precision error propagation, and thus one cannot expect even similar results across
different GPU hardware or PyTorch versions. In such cases, one has to make sure to run
on exactly the same hardware and PyTorch version for full reproducibility.

## Randomness utilities

### randn_tensor
[[autodoc]] diffusers.utils.randn_tensor
20 changes: 1 addition & 19 deletions src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -17,7 +17,7 @@
import torch

from ...schedulers import DDIMScheduler
from ...utils import deprecate, randn_tensor
from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


@@ -78,24 +78,6 @@ def __call__(
True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
"""

if (
generator is not None
and isinstance(generator, torch.Generator)
and generator.device.type != self.device.type
and self.device.type != "mps"
):
message = (
f"The `generator` device is `{generator.device}` and does not match the pipeline "
f"device `{self.device}`, so the `generator` will be ignored. "
f'Please use `generator=torch.Generator(device="{self.device}")` instead.'
)
deprecate(
"generator.device == 'cpu'",
"0.13.0",
message,
)
generator = None

# Sample gaussian noise to begin loop
if isinstance(self.unet.sample_size, int):
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
@@ -76,6 +76,7 @@
load_numpy,
nightly,
parse_flag_from_env,
print_tensor_test,
require_torch_gpu,
slow,
torch_all_close,
25 changes: 22 additions & 3 deletions src/diffusers/utils/testing_utils.py
@@ -8,7 +8,7 @@
from distutils.util import strtobool
from io import BytesIO, StringIO
from pathlib import Path
from typing import Union
from typing import Optional, Union

import numpy as np

@@ -45,6 +45,21 @@ def torch_all_close(a, b, *args, **kwargs):
return True


def print_tensor_test(tensor, filename="test_corrections.txt", expected_tensor_name="expected_slice"):
Contributor Author:
BTW @pcuenca @patil-suraj @yiyixuxu @williamberman

I added some functionality that helps to automatically correct slice-comparison tests.
All you need to do here is to add:

print_tensor_test(image_slice)

and it'll write the correct tensor to the file given by `filename` under the name `expected_tensor_name`.

The utils file here: https://github.com/huggingface/diffusers/pull/1924/files#diff-bd9ba404f47129f2218fc94c5e528cc81bfd10cef1eff1d6d543f6ffa0143367 can then take this file and will automatically overwrite the test with the correct results. This should save quite some time going forward :-)
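
For example, a hypothetical slice test might use it like this (the test class, checkpoint, and workflow comments are made up for illustration; run under pytest so that `PYTEST_CURRENT_TEST` is set):

```python
import unittest

import numpy as np
import torch

from diffusers import DDIMPipeline
from diffusers.utils import print_tensor_test


class HypotheticalSliceTest(unittest.TestCase):
    def test_ddim_slice(self):
        pipe = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")
        generator = torch.manual_seed(0)
        image = pipe(num_inference_steps=2, output_type="np", generator=generator).images

        image_slice = image[0, -3:, -3:, -1]

        # Appends "<test_file>;<test_class>;<test_fn>;expected_slice = np.array([...])"
        # to test_corrections.txt, from which the test can then be updated
        print_tensor_test(image_slice)

        # After one run, paste the recorded values here:
        # expected_slice = np.array([...])
        # assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
```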

Contributor:
That's very cool, thanks a lot !

    test_name = os.environ.get("PYTEST_CURRENT_TEST")
    if not torch.is_tensor(tensor):
        tensor = torch.from_numpy(tensor)

    tensor_str = str(tensor.detach().cpu().flatten().to(torch.float32)).replace("\n", "")
    # format is usually:
    # expected_slice = np.array([-0.5713, -0.3018, -0.9814, 0.04663, -0.879, 0.76, -1.734, 0.1044, 1.161])
    output_str = tensor_str.replace("tensor", f"{expected_tensor_name} = np.array")
    test_file, test_class, test_fn = test_name.split("::")
    test_fn = test_fn.split()[0]
    with open(filename, "a") as f:
        print(";".join([test_file, test_class, test_fn, output_str]), file=f)


def get_tests_dir(append_path=None):
"""
Args:
@@ -150,9 +165,13 @@ def require_onnxruntime(test_case):
return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case)


def load_numpy(arry: Union[str, np.ndarray]) -> np.ndarray:
def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray:
if isinstance(arry, str):
if arry.startswith("http://") or arry.startswith("https://"):
# local_path = "/home/patrick_huggingface_co/"
if local_path is not None:
# local_path can be passed to correct images of tests
return os.path.join(local_path, "/".join([arry.split("/")[-5], arry.split("/")[-2], arry.split("/")[-1]]))
elif arry.startswith("http://") or arry.startswith("https://"):
response = requests.get(arry)
response.raise_for_status()
arry = np.load(BytesIO(response.content))
2 changes: 1 addition & 1 deletion tests/models/test_models_vae.py
@@ -166,7 +166,7 @@ def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False)

def get_generator(self, seed=0):
if torch_device == "mps":
return torch.Generator().manual_seed(seed)
return torch.manual_seed(seed)
return torch.Generator(device=torch_device).manual_seed(seed)

@parameterized.expand(
51 changes: 9 additions & 42 deletions tests/pipelines/altdiffusion/test_alt_diffusion.py
@@ -188,6 +188,7 @@ def test_alt_diffusion_pndm(self):
expected_slice = np.array(
[0.51605093, 0.5707241, 0.47365507, 0.50578886, 0.5633877, 0.4642503, 0.5182081, 0.48763484, 0.49084237]
)

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2


@@ -207,20 +208,16 @@ def test_alt_diffusion(self):
alt_pipe.set_progress_bar_config(disable=None)

prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast("cuda"):
output = alt_pipe(
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np"
)
generator = torch.manual_seed(0)
output = alt_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np")

image = output.images

image_slice = image[0, -3:, -3:, -1]

assert image.shape == (1, 512, 512, 3)
expected_slice = np.array(
[0.8720703, 0.87109375, 0.87402344, 0.87109375, 0.8779297, 0.8925781, 0.8823242, 0.8808594, 0.8613281]
)
expected_slice = np.array([0.1010, 0.0800, 0.0794, 0.0885, 0.0843, 0.0762, 0.0769, 0.0729, 0.0586])

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_alt_diffusion_fast_ddim(self):
@@ -231,44 +228,14 @@ def test_alt_diffusion_fast_ddim(self):
alt_pipe.set_progress_bar_config(disable=None)

prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)
generator = torch.manual_seed(0)

with torch.autocast("cuda"):
output = alt_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
output = alt_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
image = output.images

image_slice = image[0, -3:, -3:, -1]

assert image.shape == (1, 512, 512, 3)
expected_slice = np.array(
[0.9267578, 0.9301758, 0.9013672, 0.9345703, 0.92578125, 0.94433594, 0.9423828, 0.9423828, 0.9160156]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_alt_diffusion_text2img_pipeline_fp16(self):
torch.cuda.reset_peak_memory_stats()
model_id = "BAAI/AltDiffusion"
pipe = AltDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

prompt = "a photograph of an astronaut riding a horse"
expected_slice = np.array([0.4019, 0.4052, 0.3810, 0.4119, 0.3916, 0.3982, 0.4651, 0.4195, 0.5323])

generator = torch.Generator(device=torch_device).manual_seed(0)
output_chunked = pipe(
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
)
image_chunked = output_chunked.images

generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
output = pipe(
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
)
image = output.images

# Make sure results are close enough
diff = np.abs(image_chunked.flatten() - image.flatten())
# They ARE different since ops are not run always at the same precision
# however, they should be extremely close.
assert diff.mean() < 2e-2
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
10 changes: 6 additions & 4 deletions tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py
@@ -162,6 +162,7 @@ def test_stable_diffusion_img2img_default_case(self):
expected_slice = np.array(
[0.41293705, 0.38656747, 0.40876025, 0.4782187, 0.4656803, 0.41394007, 0.4142093, 0.47150758, 0.4570448]
)

assert np.abs(image_slice.flatten() - expected_slice).max() < 1.5e-3
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1.5e-3

@@ -196,7 +197,7 @@ def test_stable_diffusion_img2img_fp16(self):
alt_pipe.set_progress_bar_config(disable=None)

prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)
generator = torch.manual_seed(0)
image = alt_pipe(
[prompt],
generator=generator,
@@ -227,7 +228,7 @@ def test_stable_diffusion_img2img_pipeline_multiple_of_8(self):

prompt = "A fantasy landscape, trending on artstation"

generator = torch.Generator(device=torch_device).manual_seed(0)
generator = torch.manual_seed(0)
output = pipe(
prompt=prompt,
image=init_image,
@@ -241,7 +242,8 @@ def test_stable_diffusion_img2img_pipeline_multiple_of_8(self):
image_slice = image[255:258, 383:386, -1]

assert image.shape == (504, 760, 3)
expected_slice = np.array([0.3252, 0.3340, 0.3418, 0.3263, 0.3346, 0.3300, 0.3163, 0.3470, 0.3427])
expected_slice = np.array([0.9358, 0.9397, 0.9599, 0.9901, 1.0000, 1.0000, 0.9882, 1.0000, 1.0000])

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3


Expand Down Expand Up @@ -275,7 +277,7 @@ def test_stable_diffusion_img2img_pipeline_default(self):

prompt = "A fantasy landscape, trending on artstation"

generator = torch.Generator(device=torch_device).manual_seed(0)
generator = torch.manual_seed(0)
output = pipe(
prompt=prompt,
image=init_image,