5 changes: 4 additions & 1 deletion README.md
@@ -104,7 +104,9 @@ with autocast("cuda"):
image = pipe(prompt).images[0]
```

If you are limited by GPU memory, you might want to consider using the model in `fp16`.
If you are limited by GPU memory, you might want to consider using the model in `fp16` as
well as chunking the attention computation.
The following snippet should require less than 4GB of VRAM.

```python
pipe = StableDiffusionPipeline.from_pretrained(
@@ -116,6 +118,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_attention_slicing()
with autocast("cuda"):
image = pipe(prompt).images[0]
```
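
The same pipeline object can also be driven with an explicit slice size, or have slicing turned off again. A short sketch reusing `pipe`, `prompt`, and `autocast` from the snippet above (both methods are added by this PR):

```python
# Reuses `pipe`, `prompt`, and `autocast` from the snippet above.
pipe.enable_attention_slicing(slice_size=1)  # smallest slices: lowest memory, slowest
with autocast("cuda"):
    image_low_mem = pipe(prompt).images[0]

pipe.disable_attention_slicing()             # restore full (unsliced) attention
with autocast("cuda"):
    image_full = pipe(prompt).images[0]
```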
60 changes: 42 additions & 18 deletions src/diffusers/models/attention.py
@@ -63,18 +63,19 @@ def forward(self, hidden_states):

# get scores
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))

attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)

# compute attention output
context_states = torch.matmul(attention_probs, value_states)
hidden_states = torch.matmul(attention_probs, value_states)

context_states = context_states.permute(0, 2, 1, 3).contiguous()
new_context_states_shape = context_states.size()[:-2] + (self.channels,)
context_states = context_states.view(new_context_states_shape)
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
hidden_states = hidden_states.view(new_hidden_states_shape)

# compute next hidden_states
hidden_states = self.proj_attn(context_states)
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

# res connect and rescale
@@ -107,6 +108,10 @@ def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_d

self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

def _set_attention_slice(self, slice_size):
for block in self.transformer_blocks:
block._set_attention_slice(slice_size)

def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
@@ -136,6 +141,10 @@ def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint

def _set_attention_slice(self, slice_size):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size

def forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
@@ -151,6 +160,10 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.

self.scale = dim_head**-0.5
self.heads = heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self._slice_size = None

self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
@@ -175,8 +188,6 @@ def reshape_batch_dim_to_heads(self, tensor):
def forward(self, x, context=None, mask=None):
batch_size, sequence_length, dim = x.shape

h = self.heads

q = self.to_q(x)
context = context if context is not None else x
k = self.to_k(context)
@@ -186,20 +197,33 @@
k = self.reshape_heads_to_batch_dim(k)
v = self.reshape_heads_to_batch_dim(v)

sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

if mask is not None:
mask = mask.reshape(batch_size, -1)
max_neg_value = -torch.finfo(sim.dtype).max
mask = mask[:, None, :].repeat(h, 1, 1)
sim.masked_fill_(~mask, max_neg_value)
# TODO(PVP) - mask is currently never used. Remember to re-implement when used

# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
hidden_states = self._attention(q, k, v, sequence_length, dim)

return self.to_out(hidden_states)

out = torch.einsum("b i j, b j d -> b i d", attn, v)
out = self.reshape_batch_dim_to_heads(out)
return self.to_out(out)
def _attention(self, query, key, value, sequence_length, dim):
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
)
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
attn_slice = (
torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
)
attn_slice = attn_slice.softmax(dim=-1)
attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])

hidden_states[start_idx:end_idx] = attn_slice

# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states


class FeedForward(nn.Module):
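To make the batch-axis slicing in `_attention` above concrete, here is a small self-contained sketch (plain `torch`, independent of the diffusers classes) that computes scaled dot-product attention both in one shot and in slices, and checks that the two results agree:

```python
import torch

def full_attention(q, k, v, scale):
    # reference path: one matmul over the whole (batch * heads) axis
    sim = torch.einsum("b i d, b j d -> b i j", q, k) * scale
    return torch.einsum("b i j, b j d -> b i d", sim.softmax(dim=-1), v)

def sliced_attention(q, k, v, scale, slice_size):
    # same computation, split along the batch axis as in `_attention` above;
    # only `slice_size` attention matrices are materialized at a time
    out = torch.zeros_like(q)
    for i in range(q.shape[0] // slice_size):
        s, e = i * slice_size, (i + 1) * slice_size
        sim = torch.einsum("b i d, b j d -> b i j", q[s:e], k[s:e]) * scale
        out[s:e] = torch.einsum("b i j, b j d -> b i d", sim.softmax(dim=-1), v[s:e])
    return out

# toy shapes: (batch * heads, seq_len, dim_per_head)
q, k, v = (torch.randn(8, 16, 40) for _ in range(3))
scale = 40 ** -0.5
assert torch.allclose(
    full_attention(q, k, v, scale),
    sliced_attention(q, k, v, scale, slice_size=2),
    atol=1e-6,
)
```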
22 changes: 22 additions & 0 deletions src/diffusers/models/unet_2d_condition.py
@@ -133,6 +133,28 @@ def __init__(
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

def set_attention_slice(self, slice_size):
if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
raise ValueError(
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.config.attention_head_dim}"
)
if slice_size is not None and slice_size > self.config.attention_head_dim:
raise ValueError(
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.config.attention_head_dim}"
)

for block in self.down_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_attention_slice(slice_size)

self.mid_block.set_attention_slice(slice_size)

for block in self.up_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_attention_slice(slice_size)

def forward(
self,
sample: torch.FloatTensor,
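The checks above constrain `slice_size` to divide `attention_head_dim` and not exceed it. A minimal standalone restatement of that constraint, assuming an example value of `attention_head_dim = 8` (not read from a real config):

```python
# Standalone restatement of the validation in `set_attention_slice` above;
# `attention_head_dim = 8` is an assumed example value.
attention_head_dim = 8

def is_valid_slice_size(slice_size):
    if slice_size is None:                      # None turns slicing off entirely
        return True
    if attention_head_dim % slice_size != 0:    # must divide the head count ...
        return False
    return slice_size <= attention_head_dim     # ... and must not exceed it

print([s for s in (1, 2, 3, 4, 8, 16) if is_valid_slice_size(s)])  # -> [1, 2, 4, 8]
```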
48 changes: 48 additions & 0 deletions src/diffusers/models/unet_blocks.py
@@ -295,6 +295,7 @@ def __init__(
super().__init__()

self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

# there is always at least one resnet
@@ -342,6 +343,21 @@ def __init__(
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)

def set_attention_slice(self, slice_size):
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
raise ValueError(
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
if slice_size is not None and slice_size > self.attn_num_head_channels:
raise ValueError(
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)

for attn in self.attentions:
attn._set_attention_slice(slice_size)

def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
@@ -457,6 +473,7 @@ def __init__(
attentions = []

self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels

for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
@@ -497,6 +514,21 @@ def __init__(
else:
self.downsamplers = None

def set_attention_slice(self, slice_size):
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
raise ValueError(
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
if slice_size is not None and slice_size > self.attn_num_head_channels:
raise ValueError(
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)

for attn in self.attentions:
attn._set_attention_slice(slice_size)

def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()

@@ -989,6 +1021,7 @@ def __init__(
attentions = []

self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels

for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
@@ -1025,6 +1058,21 @@ def __init__(
else:
self.upsamplers = None

def set_attention_slice(self, slice_size):
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
raise ValueError(
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
if slice_size is not None and slice_size > self.attn_num_head_channels:
raise ValueError(
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)

for attn in self.attentions:
attn._set_attention_slice(slice_size)

def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
for resnet, attn in zip(self.resnets, self.attentions):

@@ -36,6 +36,17 @@ def __init__(
feature_extractor=feature_extractor,
)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)

def disable_attention_slicing(self):
# set slice_size = `None` to disable `set_attention_slice`
self.enable_attention_slicing(None)

@torch.no_grad()
def __call__(
self,
@@ -47,6 +47,17 @@ def __init__(
feature_extractor=feature_extractor,
)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)

def disable_attention_slicing(self):
# set slice_size = `None` to disable `set_attention_slice`
self.enable_attention_slicing(None)

@torch.no_grad()
def __call__(
self,
@@ -61,6 +61,17 @@ def __init__(
feature_extractor=feature_extractor,
)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)

def disable_attention_slicing(self):
# set slice_size = `None` to disable `set_attention_slice`
self.enable_attention_slicing(None)

@torch.no_grad()
def __call__(
self,
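All three pipelines above add the same pair of methods. A hedged sketch of the control flow they implement, using a stand-in UNet so it runs without model weights (`DummyUNet`, `DummyPipeline`, and `attention_head_dim = 8` are illustrative assumptions, not diffusers API):

```python
class DummyUNet:
    class config:
        attention_head_dim = 8  # assumed value, for illustration only

    def set_attention_slice(self, slice_size):
        print(f"set_attention_slice({slice_size!r})")


class DummyPipeline:
    def __init__(self):
        self.unet = DummyUNet()

    def enable_attention_slicing(self, slice_size="auto"):
        if slice_size == "auto":
            # "auto" resolves to half the attention head count,
            # the speed/memory trade-off suggested in the PR
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        # disabling is just enabling with slice_size=None
        self.enable_attention_slicing(None)


pipe = DummyPipeline()
pipe.enable_attention_slicing()     # -> set_attention_slice(4)
pipe.enable_attention_slicing(1)    # -> set_attention_slice(1)
pipe.disable_attention_slicing()    # -> set_attention_slice(None)
```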
73 changes: 72 additions & 1 deletion tests/test_pipelines.py
@@ -153,7 +153,6 @@ def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
bos_token_id=0,
chunk_size_feed_forward=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
@@ -410,6 +409,38 @@ def test_stable_diffusion_k_lms(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

def test_stable_diffusion_attention_chunk(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)

prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=device).manual_seed(0)
output_1 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")

# make sure chunking the attention yields the same result
sd_pipe.enable_attention_slicing(slice_size=1)
generator = torch.Generator(device=device).manual_seed(0)
output_2 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")

assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4

def test_score_sde_ve_pipeline(self):
unet = self.dummy_uncond_unet
scheduler = ScoreSdeVeScheduler(tensor_format="pt")
@@ -1045,6 +1076,46 @@ def test_lms_stable_diffusion_pipeline(self):
expected_slice = np.array([0.9077, 0.9254, 0.9181, 0.9227, 0.9213, 0.9367, 0.9399, 0.9406, 0.9024])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_memory_chunking(self):
torch.cuda.reset_peak_memory_stats()
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(
model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
).to(torch_device)
pipe.set_progress_bar_config(disable=None)

prompt = "a photograph of an astronaut riding a horse"

# make attention efficient
pipe.enable_attention_slicing()
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
output_chunked = pipe(
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
)
image_chunked = output_chunked.images

mem_bytes = torch.cuda.max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
# make sure that less than 3.75 GB is allocated
assert mem_bytes < 3.75 * 10**9

# disable chunking
pipe.disable_attention_slicing()
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
output = pipe(
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
)
image = output.images

# make sure that more than 3.75 GB is allocated
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes > 3.75 * 10**9
assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3

@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_img2img_pipeline(self):