Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/ddim/pipeline_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ def __call__(

if self.device.type == "mps":
# randn does not work reproducibly on mps
image = torch.randn(image_shape, generator=generator)
image = torch.randn(image_shape, generator=generator, dtype=self.unet.dtype)
image = image.to(self.device)
else:
image = torch.randn(image_shape, generator=generator, device=self.device)
image = torch.randn(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)

# set step values
self.scheduler.set_timesteps(num_inference_steps)
Expand Down
255 changes: 92 additions & 163 deletions tests/pipelines/altdiffusion/test_alt_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.

import gc
import random
import unittest

import numpy as np
Expand All @@ -25,9 +24,9 @@
RobertaSeriesConfig,
RobertaSeriesModelWithTransformation,
)
from diffusers.utils import floats_tensor, slow, torch_device
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from transformers import XLMRobertaTokenizer
from transformers import CLIPTextConfig, CLIPTextModel, XLMRobertaTokenizer

from ...test_pipelines_common import PipelineTesterMixin

Expand All @@ -36,25 +35,11 @@


class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
pipeline_class = AltDiffusionPipeline

@property
def dummy_image(self):
batch_size = 1
num_channels = 3
sizes = (32, 32)

image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
return image

@property
def dummy_cond_unet(self):
def get_dummy_components(self):
torch.manual_seed(0)
model = UNet2DConditionModel(
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
Expand All @@ -64,202 +49,146 @@ def dummy_cond_unet(self):
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
)
return model

@property
def dummy_cond_unet_inpaint(self):
torch.manual_seed(0)
model = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=9,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
return model

@property
def dummy_vae(self):
torch.manual_seed(0)
model = AutoencoderKL(
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
)
return model

@property
def dummy_text_encoder(self):
# TODO: address the non-deterministic text encoder (fails for save-load tests)
# torch.manual_seed(0)
# text_encoder_config = RobertaSeriesConfig(
# hidden_size=32,
# project_dim=32,
# intermediate_size=37,
# layer_norm_eps=1e-05,
# num_attention_heads=4,
# num_hidden_layers=5,
# vocab_size=5002,
# )
# text_encoder = RobertaSeriesModelWithTransformation(text_encoder_config)

torch.manual_seed(0)
config = RobertaSeriesConfig(
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
Comment on lines +69 to +85
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@patil-suraj the AltDiffusion pipeline produces non-matching outputs if I run it with the same inputs twice. Replacing RobertaSeriesModelWithTransformation with CLIPTextModel helped, so it's probably that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm weird, could you maybe open a seperate issue for this? I can look into it :-) We should test with RobertaSeriesConfig here IMO :-)

hidden_size=32,
project_dim=32,
projection_dim=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=5002,
)
return RobertaSeriesModelWithTransformation(config)
text_encoder = CLIPTextModel(text_encoder_config)

@property
def dummy_extractor(self):
def extract(*args, **kwargs):
class Out:
def __init__(self):
self.pixel_values = torch.ones([0])

def to(self, device):
self.pixel_values.to(device)
return self

return Out()
tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta")
tokenizer.model_max_length = 77

return extract
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
}
return components

def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"output_type": "numpy",
}
return inputs

def test_alt_diffusion_ddim(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)

vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta")
tokenizer.model_max_length = 77

# make sure here that pndm scheduler skips prk
alt_pipe = AltDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
components = self.get_dummy_components()
torch.manual_seed(0)
text_encoder_config = RobertaSeriesConfig(
hidden_size=32,
project_dim=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
vocab_size=5002,
)
# TODO: remove after fixing the non-deterministic text encoder
text_encoder = RobertaSeriesModelWithTransformation(text_encoder_config)
components["text_encoder"] = text_encoder

alt_pipe = AltDiffusionPipeline(**components)
alt_pipe = alt_pipe.to(device)
alt_pipe.set_progress_bar_config(disable=None)

prompt = "A photo of an astronaut"

generator = torch.Generator(device=device).manual_seed(0)
output = alt_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = "A photo of an astronaut"
output = alt_pipe(**inputs)
image = output.images

generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = alt_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
return_dict=False,
)[0]

image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

assert image.shape == (1, 64, 64, 3)
expected_slice = np.array(
[0.5748162, 0.60447145, 0.48821217, 0.50100636, 0.5431185, 0.45763683, 0.49657696, 0.48132733, 0.47573093]
)

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

def test_alt_diffusion_pndm(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta")
tokenizer.model_max_length = 77

# make sure here that pndm scheduler skips prk
alt_pipe = AltDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
components = self.get_dummy_components()
components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
torch.manual_seed(0)
text_encoder_config = RobertaSeriesConfig(
hidden_size=32,
project_dim=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
vocab_size=5002,
)
# TODO: remove after fixing the non-deterministic text encoder
text_encoder = RobertaSeriesModelWithTransformation(text_encoder_config)
components["text_encoder"] = text_encoder
alt_pipe = AltDiffusionPipeline(**components)
alt_pipe = alt_pipe.to(device)
alt_pipe.set_progress_bar_config(disable=None)

prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=device).manual_seed(0)
output = alt_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")

inputs = self.get_dummy_inputs(device)
output = alt_pipe(**inputs)
image = output.images

generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = alt_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
return_dict=False,
)[0]

image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

assert image.shape == (1, 64, 64, 3)
expected_slice = np.array(
[0.51605093, 0.5707241, 0.47365507, 0.50578886, 0.5633877, 0.4642503, 0.5182081, 0.48763484, 0.49084237]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
def test_alt_diffusion_fp16(self):
"""Test that stable diffusion works with fp16"""
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta")
tokenizer.model_max_length = 77

# put models in fp16
unet = unet.half()
vae = vae.half()
bert = bert.half()

# make sure here that pndm scheduler skips prk
alt_pipe = AltDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
alt_pipe = alt_pipe.to(torch_device)
alt_pipe.set_progress_bar_config(disable=None)

prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)
image = alt_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images

assert image.shape == (1, 64, 64, 3)


@slow
Expand Down
4 changes: 1 addition & 3 deletions tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,11 @@
from diffusers.utils.testing_utils import require_torch_gpu
from transformers import XLMRobertaTokenizer

from ...test_pipelines_common import PipelineTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False


class AltDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
Expand Down
Empty file.
Empty file.
Loading