src/diffusers/models/attention2d.py (0 additions, 56 deletions)

@@ -32,62 +32,6 @@ def forward(self, x):
         return self.to_out(out)
 
 
-# unet.py
-class AttnBlock(nn.Module):
-    def __init__(self, in_channels):
-        super().__init__()
-        self.in_channels = in_channels
-
-        self.norm = normalization(in_channels, swish=None, eps=1e-6)
-        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-
-    def forward(self, x):
-        print("x", x.abs().sum())
-        h_ = x
-        h_ = self.norm(h_)
-
-        print("hid_states shape", h_.shape)
-        print("hid_states", h_.abs().sum())
-        print("hid_states - 3 - 3", h_.view(h_.shape[0], h_.shape[1], -1)[:, :3, -3:])
-
-        q = self.q(h_)
-        k = self.k(h_)
-        v = self.v(h_)
-
-        print(self.q)
-        print("q_shape", q.shape)
-        print("q", q.abs().sum())
-        # print("k_shape", k.shape)
-        # print("k", k.abs().sum())
-        # print("v_shape", v.shape)
-        # print("v", v.abs().sum())
-
-        # compute attention
-        b, c, h, w = q.shape
-        q = q.reshape(b, c, h * w)
-        q = q.permute(0, 2, 1)  # b,hw,c
-        k = k.reshape(b, c, h * w)  # b,c,hw
-
-        w_ = torch.bmm(q, k)  # b,hw,hw   w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
-        w_ = w_ * (int(c) ** (-0.5))
-        w_ = torch.nn.functional.softmax(w_, dim=2)
-        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
-
-        print("weight", w_.abs().sum())
-
-        # attend to values
-        v = v.reshape(b, c, h * w)
-        h_ = torch.bmm(v, w_)  # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
-        h_ = h_.reshape(b, c, h, w)
-
-        h_ = self.proj_out(h_)
-
-        return x + h_
-
-
 # unet_glide.py & unet_ldm.py
 class AttentionBlock(nn.Module):
     """
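Aside from its debug prints, the deleted `AttnBlock` is plain single-head scaled dot-product self-attention over the `h * w` spatial positions. Below is a minimal sketch of the same computation, rewritten with `torch.einsum` for readability; the helper name `spatial_self_attention` is ours, not part of the library:

```python
import torch

def spatial_self_attention(q, k, v):
    # q, k, v: (b, c, h, w) feature maps produced by 1x1 conv projections.
    b, c, h, w = q.shape
    # Flatten the spatial grid so each pixel becomes one token of width c.
    q = q.reshape(b, c, h * w)
    k = k.reshape(b, c, h * w)
    v = v.reshape(b, c, h * w)
    # attn[b, i, j] = softmax_j( sum_c q[b, c, i] * k[b, c, j] / sqrt(c) )
    attn = torch.softmax(torch.einsum("bci,bcj->bij", q, k) * c**-0.5, dim=2)
    # Each output pixel i is the attention-weighted sum of value pixels j.
    out = torch.einsum("bcj,bij->bci", v, attn)
    return out.reshape(b, c, h, w)
```

On the same inputs this should match the deleted forward pass (before `proj_out` and the residual `x + h_`) up to floating-point noise, which is presumably what made the class redundant next to the shared `AttentionBlock`.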
src/diffusers/models/unet.py (1 addition, 42 deletions)

@@ -32,7 +32,7 @@
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
-from .attention2d import AttnBlock, AttentionBlock
+from .attention2d import AttentionBlock
 
 
 def nonlinearity(x):

@@ -86,44 +86,6 @@ def forward(self, x, temb):
         return x + h
 
 
-# class AttnBlock(nn.Module):
-#     def __init__(self, in_channels):
-#         super().__init__()
-#         self.in_channels = in_channels
-#
-#         self.norm = Normalize(in_channels)
-#         self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-#         self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-#         self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-#         self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-#
-#     def forward(self, x):
-#         h_ = x
-#         h_ = self.norm(h_)
-#         q = self.q(h_)
-#         k = self.k(h_)
-#         v = self.v(h_)
-#
-#         # compute attention
-#         b, c, h, w = q.shape
-#         q = q.reshape(b, c, h * w)
-#         q = q.permute(0, 2, 1)  # b,hw,c
-#         k = k.reshape(b, c, h * w)  # b,c,hw
-#         w_ = torch.bmm(q, k)  # b,hw,hw   w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
-#         w_ = w_ * (int(c) ** (-0.5))
-#         w_ = torch.nn.functional.softmax(w_, dim=2)
-#
-#         # attend to values
-#         v = v.reshape(b, c, h * w)
-#         w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
-#         h_ = torch.bmm(v, w_)  # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
-#         h_ = h_.reshape(b, c, h, w)
-#
-#         h_ = self.proj_out(h_)
-#
-#         return x + h_
-
-
 class UNetModel(ModelMixin, ConfigMixin):
     def __init__(
         self,

@@ -186,7 +148,6 @@ def __init__(
                 )
                 block_in = block_out
                 if curr_res in attn_resolutions:
-                    # attn.append(AttnBlock(block_in))
                     attn.append(AttentionBlock(block_in, overwrite_qkv=True))
             down = nn.Module()
             down.block = block

@@ -202,7 +163,6 @@ def __init__(
         self.mid.block_1 = ResnetBlock(
             in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
         )
-        # self.mid.attn_1 = AttnBlock(block_in)
         self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
         self.mid.block_2 = ResnetBlock(
             in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout

@@ -228,7 +188,6 @@ def __init__(
                 )
                 block_in = block_out
                 if curr_res in attn_resolutions:
-                    # attn.append(AttnBlock(block_in))
                     attn.append(AttentionBlock(block_in, overwrite_qkv=True))
             up = nn.Module()
             up.block = block
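What makes the swap possible is the `overwrite_qkv=True` argument: checkpoints trained against `AttnBlock` store separate 1x1 conv weights for `q`, `k`, and `v`, so `AttentionBlock` has to fold them into its fused qkv projection when loading. The sketch below is only a guess at the shape bookkeeping involved; the helper name `fuse_qkv_weights` and the fused `(3c, c)` layout are assumptions, not the library's confirmed internals:

```python
import torch
import torch.nn as nn

def fuse_qkv_weights(q_conv: nn.Conv2d, k_conv: nn.Conv2d, v_conv: nn.Conv2d):
    # Hypothetical: stack three (c, c, 1, 1) conv kernels into one fused
    # (3c, c) projection of the kind a combined qkv layer would apply.
    weight = torch.cat([q_conv.weight, k_conv.weight, v_conv.weight], dim=0)
    bias = torch.cat([q_conv.bias, k_conv.bias, v_conv.bias], dim=0)
    return weight.flatten(1), bias  # shapes (3c, c) and (3c,)
```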
tests/test_modeling_utils.py (8 additions, 9 deletions)

@@ -858,25 +858,26 @@ def test_ddpm_cifar10(self):
         image_slice = image[0, -1, -3:, -3:].cpu()
 
         assert image.shape == (1, 3, 32, 32)
-        expected_slice = torch.tensor([0.2250, 0.3375, 0.2360, 0.0930, 0.3440, 0.3156, 0.1937, 0.3585, 0.1761])
+        expected_slice = torch.tensor([0.2249, 0.3375, 0.2359, 0.0929, 0.3439, 0.3156, 0.1937, 0.3585, 0.1761])
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
 
     @slow
     def test_ddim_cifar10(self):
-        generator = torch.manual_seed(0)
         model_id = "fusing/ddpm-cifar10"
 
         unet = UNetModel.from_pretrained(model_id)
         noise_scheduler = DDIMScheduler(tensor_format="pt")
 
         ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)
 
+        generator = torch.manual_seed(0)
         image = ddim(generator=generator, eta=0.0)
 
         image_slice = image[0, -1, -3:, -3:].cpu()
 
         assert image.shape == (1, 3, 32, 32)
         expected_slice = torch.tensor(
-            [-0.7383, -0.7385, -0.7298, -0.7364, -0.7414, -0.7239, -0.6737, -0.6813, -0.7068]
+            [-0.6553, -0.6765, -0.6799, -0.6749, -0.7006, -0.6974, -0.6991, -0.7116, -0.7094]
         )
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

@@ -895,7 +896,7 @@ def test_pndm_cifar10(self):

         assert image.shape == (1, 3, 32, 32)
         expected_slice = torch.tensor(
-            [-0.7888, -0.7870, -0.7759, -0.7823, -0.8014, -0.7608, -0.6818, -0.7130, -0.7471]
+            [-0.7925, -0.7902, -0.7789, -0.7796, -0.8000, -0.7596, -0.6852, -0.7125, -0.7494]
         )
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

@@ -966,24 +967,22 @@ def test_grad_tts(self):

     @slow
     def test_score_sde_ve_pipeline(self):
-        torch.manual_seed(0)
-
         model = NCSNpp.from_pretrained("fusing/ffhq_ncsnpp")
         scheduler = ScoreSdeVeScheduler.from_config("fusing/ffhq_ncsnpp")
 
         sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
 
+        torch.manual_seed(0)
         image = sde_ve(num_inference_steps=2)
 
-        expected_image_sum = 3382810112.0
-        expected_image_mean = 1075.366455078125
+        expected_image_sum = 3382849024.0
+        expected_image_mean = 1075.3788
 
         assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
         assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
 
     @slow
     def test_score_sde_vp_pipeline(self):
-
         model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp")
         scheduler = ScoreSdeVpScheduler.from_config("fusing/cifar10-ddpmpp-vp")
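All of the test changes above follow one pattern: `torch.manual_seed(0)` moves from the top of the test to immediately before the sampling call. Presumably this is so that weight loading and pipeline construction can no longer perturb the RNG state the sampler consumes, which would also explain why the expected slices and sums changed. A minimal sketch of the pattern, using the same names as `test_ddim_cifar10` (the `from diffusers import ...` line is an assumption about the import path at this point in the library's history):

```python
import torch
from diffusers import DDIMPipeline, DDIMScheduler, UNetModel

unet = UNetModel.from_pretrained("fusing/ddpm-cifar10")
noise_scheduler = DDIMScheduler(tensor_format="pt")
ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)

# Seed right before sampling: anything random during loading or setup
# above can no longer shift the generator state used here.
generator = torch.manual_seed(0)
image = ddim(generator=generator, eta=0.0)
```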