From afaefd2e5f2e391d8b84fe787916fbf4fb1ee740 Mon Sep 17 00:00:00 2001 From: Doggettx Date: Mon, 5 Sep 2022 09:26:27 +0200 Subject: [PATCH 01/18] Update attention.py Run attention in a loop to allow for much higher resolutions (over 1920x1920 on a 3090) --- ldm/modules/attention.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index f4eff39cc..2b7214c71 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -174,23 +174,34 @@ def forward(self, x, context=None, mask=None): context = default(context, x) k = self.to_k(context) v = self.to_v(context) + del context, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) + # valid values for steps = 2,4,8,16,32,64 + # higher steps is slower but less memory usage + # at 16 can run 1920x1536 on a 3090, at 64 can run over 1920x1920 + # speed seems to be impacted more on 30x series cards + steps = 16 + slice_size = q.shape[1] // steps if q.shape[1] % steps == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) + s1 *= self.scale - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) + s2 = s1.softmax(dim=-1) + del s1 - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) class BasicTransformerBlock(nn.Module): From 5065b41ce12c3b043ba5196283a4907cd2e1df5b Mon Sep 17 00:00:00 2001 From: Doggettx Date: Mon, 5 Sep 2022 10:13:59 +0200 Subject: [PATCH 02/18] Update attention.py Correction to comment --- ldm/modules/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index 2b7214c71..39fc20af0 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -182,7 +182,7 @@ def forward(self, x, context=None, mask=None): # valid values for steps = 2,4,8,16,32,64 # higher steps is slower but less memory usage - # at 16 can run 1920x1536 on a 3090, at 64 can run over 1920x1920 + # at 16 can run 1920x1536 on a 3090, at 32 can run over 1920x1920 # speed seems to be impacted more on 30x series cards steps = 16 slice_size = q.shape[1] // steps if q.shape[1] % steps == 0 else q.shape[1] From 8283bb5b84580487e7a9e25c37816484bf4ed42b Mon Sep 17 00:00:00 2001 From: Doggettx Date: Mon, 5 Sep 2022 12:03:59 +0200 Subject: [PATCH 03/18] Update attention.py --- ldm/modules/attention.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index 39fc20af0..89d8b6db7 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -180,12 +180,18 @@ def forward(self, x, context=None, mask=None): r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - # valid values for steps = 2,4,8,16,32,64 - # higher steps is slower but less memory usage - # at 16 can run 1920x1536 on a 3090, 
at 32 can run over 1920x1920 - # speed seems to be impacted more on 30x series cards - steps = 16 - slice_size = q.shape[1] // steps if q.shape[1] % steps == 0 else q.shape[1] + stats = torch.cuda.memory_stats(q.device) + mem_total = torch.cuda.get_device_properties(0).total_memory + mem_active = stats['active_bytes.all.current'] + mem_free = mem_total - mem_active + + mem_required = q.shape[0] * q.shape[1] * k.shape[1] * 4 * 2.5 + steps = 1 + + if mem_required > mem_free: + steps = 2**(math.ceil(math.log(mem_required / mem_free, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): end = i + slice_size s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) @@ -204,6 +210,7 @@ def forward(self, x, context=None, mask=None): return self.to_out(r2) + class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): super().__init__() From d3c91ec937a4f1d4fc79b68875931bdb5550bb6e Mon Sep 17 00:00:00 2001 From: Doggettx Date: Mon, 5 Sep 2022 19:49:45 +0200 Subject: [PATCH 04/18] Fixed memory handling for model.decode_first_stage Better memory handling for model.decode_first_stage so it doesn't crash anymore after 100% rendering --- ldm/modules/attention.py | 11 +- ldm/modules/diffusionmodules/model.py | 142 +++++++++++++++++++------- 2 files changed, 112 insertions(+), 41 deletions(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index 89d8b6db7..7d3f8c2be 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -170,13 +170,14 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. def forward(self, x, context=None, mask=None): h = self.heads - q = self.to_q(x) + q_in = self.to_q(x) context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) + k_in = self.to_k(context) + v_in = self.to_v(context) del context, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) @@ -203,6 +204,7 @@ def forward(self, x, context=None, mask=None): r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) del s2 + del q, k, v r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) del r1 @@ -210,7 +212,6 @@ def forward(self, x, context=None, mask=None): return self.to_out(r2) - class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): super().__init__() diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 533e589a2..fd16dd50a 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -1,4 +1,5 @@ # pytorch_diffusion + derived encoder decoder +import gc import math import torch import torch.nn as nn @@ -119,18 +120,30 @@ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, padding=0) def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) + h1 = x + h2 = self.norm1(h1) + del h1 + + h3 = nonlinearity(h2) + del h2 + + h4 = self.conv1(h3) + del h3 if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + h4 += self.temb_proj(nonlinearity(temb))[:,:,None,None] - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = 
self.conv2(h) + h5 = self.norm2(h4) + del h4 + + h6 = nonlinearity(h5) + del h5 + + h7 = self.dropout(h6) + del h6 + + h8 = self.conv2(h7) + del h7 if self.in_channels != self.out_channels: if self.use_conv_shortcut: @@ -138,7 +151,8 @@ def forward(self, x, temb): else: x = self.nin_shortcut(x) - return x+h + h8 += x + return h8 class LinAttnBlock(LinearAttention): @@ -178,28 +192,61 @@ def __init__(self, in_channels): def forward(self, x): h_ = x h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) + q1 = self.q(h_) + k1 = self.k(h_) v = self.v(h_) # compute attention - b,c,h,w = q.shape - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) + b, c, h, w = q1.shape + + q2 = q1.reshape(b, c, h*w) + del q1 + + q = q2.permute(0, 2, 1) # b,hw,c + del q2 + + k = k1.reshape(b, c, h*w) # b,c,hw + del k1 + + h_ = torch.zeros_like(k, device=q.device) + + stats = torch.cuda.memory_stats(q.device) + mem_total = torch.cuda.get_device_properties(0).total_memory + mem_active = stats['active_bytes.all.current'] + mem_free = mem_total - mem_active + + mem_required = q.shape[0] * q.shape[1] * k.shape[2] * 4 * 2.5 + steps = 1 + + if mem_required > mem_free: + steps = 2**(math.ceil(math.log(mem_required / mem_free, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + + w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w1 *= (int(c)**(-0.5)) + w2 = torch.nn.functional.softmax(w1, dim=2) + del w1 + + # attend to values + v1 = v.reshape(b, c, h*w) + w3 = w2.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w2 - # attend to values - v = v.reshape(b,c,h*w) - w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b,c,h,w) + h_[:, :, i:end] = torch.bmm(v1, w3) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w3 - h_ = self.proj_out(h_) + h2 = h_.reshape(b, c, h, w) + del h_ - return x+h_ + h3 = self.proj_out(h2) + del h2 + + h3 += x + + return h3 def make_attn(in_channels, attn_type="vanilla"): @@ -540,31 +587,54 @@ def forward(self, z): temb = None # z to block_in - h = self.conv_in(z) + h1 = self.conv_in(z) # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) + h2 = self.mid.block_1(h1, temb) + del h1 + + h3 = self.mid.attn_1(h2) + del h2 + + h = self.mid.block_2(h3, temb) + del h3 + + # prepare for up sampling + gc.collect() + torch.cuda.empty_cache() # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks+1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) + t = h + h = self.up[i_level].attn[i_block](t) + del t + if i_level != 0: - h = self.up[i_level].upsample(h) + t = h + h = self.up[i_level].upsample(t) + del t # end if self.give_pre_end: return h - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) + h1 = self.norm_out(h) + del h + + h2 = nonlinearity(h1) + del h1 + + h = self.conv_out(h2) + del h2 + if self.tanh_out: - h = torch.tanh(h) + t = h + h = torch.tanh(t) + del t + return h From 507ddec578d54ddf7eb39fac5d646c9937526565 Mon Sep 17 00:00:00 2001 From: Doggettx Date: Tue, 6 Sep 2022 
09:08:46 +0200 Subject: [PATCH 05/18] Fixed free memory calculation The old version gave incorrect free memory results, causing crashes in edge cases. --- ldm/modules/attention.py | 16 +++++++++++----- ldm/modules/diffusionmodules/model.py | 13 ++++++++----- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index 7d3f8c2be..e6db2ddfc 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -182,15 +182,21 @@ def forward(self, x, context=None, mask=None): r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) stats = torch.cuda.memory_stats(q.device) - mem_total = torch.cuda.get_device_properties(0).total_memory mem_active = stats['active_bytes.all.current'] - mem_free = mem_total - mem_active + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch - mem_required = q.shape[0] * q.shape[1] * k.shape[1] * 4 * 2.5 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4 + mem_required = tensor_size * 2.5 steps = 1 - if mem_required > mem_free: - steps = 2**(math.ceil(math.log(mem_required / mem_free, 2))) + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + gb = 1024**3 + print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index fd16dd50a..7c78f465a 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -211,15 +211,18 @@ def forward(self, x): h_ = torch.zeros_like(k, device=q.device) stats = torch.cuda.memory_stats(q.device) - mem_total = torch.cuda.get_device_properties(0).total_memory mem_active = stats['active_bytes.all.current'] - mem_free = mem_total - mem_active + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch - mem_required = q.shape[0] * q.shape[1] * k.shape[2] * 4 * 2.5 + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4 + mem_required = tensor_size * 2.5 steps = 1 - if mem_required > mem_free: - steps = 2**(math.ceil(math.log(mem_required / mem_free, 2))) + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): From a1fbe55f85dd6e7e4fdb3c9081f8f272e3233b59 Mon Sep 17 00:00:00 2001 From: Doggettx Date: Tue, 6 Sep 2022 09:09:49 +0200 Subject: [PATCH 06/18] Set model to half Set model to half in txt2img and img2img for less memory usage.
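For reference, the free-VRAM estimate introduced in PATCH 05/18 above (and reused by the later patches) boils down to the following standalone helper. This is a minimal sketch, assuming a CUDA device is available; the helper names are illustrative and not part of the patches themselves.

import math
import torch

def estimate_free_vram(device=None):
    # Illustrative helper (not in the patches): free memory reported by the
    # CUDA driver plus memory PyTorch has reserved but is not actively using.
    if device is None:
        device = torch.cuda.current_device()
    stats = torch.cuda.memory_stats(device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(device)
    return mem_free_cuda + (mem_reserved - mem_active)

def slice_steps(mem_required, mem_free):
    # Number of power-of-two slices so one slice of the attention matrix
    # fits in the estimated free memory, mirroring the patched forward().
    if mem_required <= mem_free:
        return 1
    return 2 ** math.ceil(math.log(mem_required / mem_free, 2))

Both ldm/modules/attention.py and ldm/modules/diffusionmodules/model.py use this same quantity to decide how many slices to split the attention computation into.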
--- scripts/img2img.py | 1 + scripts/txt2img.py | 1 + 2 files changed, 2 insertions(+) diff --git a/scripts/img2img.py b/scripts/img2img.py index 421e2151d..5b4537d4e 100644 --- a/scripts/img2img.py +++ b/scripts/img2img.py @@ -198,6 +198,7 @@ def main(): config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") + model = model.half() device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) diff --git a/scripts/txt2img.py b/scripts/txt2img.py index 59c16a1db..28db4e78a 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -238,6 +238,7 @@ def main(): config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") + model = model.half() device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) From 7a32fd649360aca42c12b411f05cd47f3bbb13ab Mon Sep 17 00:00:00 2001 From: Doggettx Date: Tue, 6 Sep 2022 09:12:09 +0200 Subject: [PATCH 07/18] Commented out debug info Forgot to comment out debug info --- ldm/modules/attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index e6db2ddfc..f82038f67 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -194,9 +194,9 @@ def forward(self, x, context=None, mask=None): if mem_required > mem_free_total: steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - gb = 1024**3 - print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + # gb = 1024**3 + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): From 755bec892369f04bccb430b05a8dcad97dbd6703 Mon Sep 17 00:00:00 2001 From: Doggettx Date: Tue, 6 Sep 2022 09:39:10 +0200 Subject: [PATCH 08/18] Raise error when steps too high Technically you could run at higher steps as long as the resolution is dividable by the steps but you're going to run into memory issues later on anyhow. --- ldm/modules/attention.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index f82038f67..f556c7bc0 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -188,16 +188,20 @@ def forward(self, x, context=None, mask=None): mem_free_torch = mem_reserved - mem_active mem_free_total = mem_free_cuda + mem_free_torch + gb = 1024 ** 3 tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4 mem_required = tensor_size * 2.5 steps = 1 if mem_required > mem_free_total: steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - # gb = 1024**3 # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + if steps > 64: + raise RuntimeError(f'Not enough memory, use lower resolution. 
' + f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): end = i + slice_size From f134e245ba8d53d2fa5e3268050dcc752ae1edf0 Mon Sep 17 00:00:00 2001 From: Doggettx Date: Tue, 6 Sep 2022 10:43:19 +0200 Subject: [PATCH 09/18] Added max. res info to memory exception --- ldm/modules/attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index f556c7bc0..7ad8165c6 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -199,7 +199,8 @@ def forward(self, x, context=None, mask=None): # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") if steps > 64: - raise RuntimeError(f'Not enough memory, use lower resolution. ' + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] From 830f6946f9aa29a7683b1c3b4d92537c0aecc827 Mon Sep 17 00:00:00 2001 From: Doggettx Date: Wed, 7 Sep 2022 08:43:36 +0200 Subject: [PATCH 10/18] Reverted in place tensor functions back to CompVis version Improves performance and is no longer needed. --- ldm/modules/attention.py | 4 ++-- ldm/modules/diffusionmodules/model.py | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index 7ad8165c6..f848a7c75 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -1,3 +1,4 @@ +import gc from inspect import isfunction import math import torch @@ -206,8 +207,7 @@ def forward(self, x, context=None, mask=None): slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): end = i + slice_size - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) - s1 *= self.scale + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale s2 = s1.softmax(dim=-1) del s1 diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 7c78f465a..cd3328cbe 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -188,7 +188,6 @@ def __init__(self, in_channels): stride=1, padding=0) - def forward(self, x): h_ = x h_ = self.norm(h_) @@ -229,17 +228,18 @@ def forward(self, x): end = i + slice_size w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w1 *= (int(c)**(-0.5)) - w2 = torch.nn.functional.softmax(w1, dim=2) + w2 = w1 * (int(c)**(-0.5)) del w1 + w3 = torch.nn.functional.softmax(w2, dim=2) + del w2 # attend to values v1 = v.reshape(b, c, h*w) - w3 = w2.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - del w2 + w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w3 - h_[:, :, i:end] = torch.bmm(v1, w3) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - del v1, w3 + h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w4 h2 = h_.reshape(b, c, h, w) del h_ From c2d72c5c23492343bce6c090dbfd20ae90006deb Mon Sep 17 00:00:00 2001 From: Doggettx Date: Wed, 7 Sep 2022 09:03:19 +0200 Subject: [PATCH 11/18] Missed one function to revert --- ldm/modules/diffusionmodules/model.py | 
5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index cd3328cbe..9be8922e9 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -131,7 +131,7 @@ def forward(self, x, temb): del h3 if temb is not None: - h4 += self.temb_proj(nonlinearity(temb))[:,:,None,None] + h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None] h5 = self.norm2(h4) del h4 @@ -151,8 +151,7 @@ def forward(self, x, temb): else: x = self.nin_shortcut(x) - h8 += x - return h8 + return x + h8 class LinAttnBlock(LinearAttention): From cd3d653f79cedc1849a02323f36b9b33fd089ff3 Mon Sep 17 00:00:00 2001 From: Doggettx <110817577+Doggettx@users.noreply.github.com> Date: Wed, 7 Sep 2022 12:29:15 +0200 Subject: [PATCH 12/18] Update README.md --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index c9e6c3bb1..169399aac 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,10 @@ +# Automatic memory management + +This version can be run standalone, but it is meant more as a proof of concept so that other forks can implement similar changes. + +Allows the use of resolutions that require up to 64x more VRAM than is possible on the default CompVis build + + # Stable Diffusion *Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:* From 5fe97c69c9e3738d722690063c9492b297737cee Mon Sep 17 00:00:00 2001 From: Doggettx Date: Sat, 10 Sep 2022 11:11:05 +0200 Subject: [PATCH 13/18] Performance boost and fix sigmoid for higher resolutions Significant performance boost at higher resolutions when running in auto_cast or half mode; on a 3090 it went from 1.13it/s to 1.63it/s at 1024x1024. Will also allow for higher resolutions due to the sigmoid fix and using half the memory --- ldm/modules/attention.py | 13 +++++++------ ldm/modules/diffusionmodules/model.py | 10 +++++++--- scripts/img2img.py | 3 +++ scripts/txt2img.py | 3 +++ 4 files changed, 20 insertions(+), 9 deletions(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index f848a7c75..a0f4f18b8 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -90,7 +90,7 @@ def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) - k = k.softmax(dim=-1) + k = k.softmax(dim=-1) context = torch.einsum('bhdn,bhen->bhde', k, v) out = torch.einsum('bhde,bhdn->bhen', context, q) out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) @@ -162,7 +162,6 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - self.to_out = nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) @@ -190,14 +189,16 @@ def forward(self, x, context=None, mask=None): mem_free_total = mem_free_cuda + mem_free_torch gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4 - mem_required = tensor_size * 2.5 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier steps = 1 + if mem_required > mem_free_total: steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") if steps > 64: max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 @@ -209,7 +210,7 @@ def forward(self, x, context=None, mask=None): end = i + slice_size s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale - s2 = s1.softmax(dim=-1) + s2 = s1.softmax(dim=-1, dtype=q.dtype) del s1 r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 9be8922e9..de3ce38c6 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -33,7 +33,11 @@ def get_timestep_embedding(timesteps, embedding_dim): def nonlinearity(x): # swish - return x*torch.sigmoid(x) + t = torch.sigmoid(x) + x *= t + del t + + return x def Normalize(in_channels, num_groups=32): @@ -215,7 +219,7 @@ def forward(self, x): mem_free_torch = mem_reserved - mem_active mem_free_total = mem_free_cuda + mem_free_torch - tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4 + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() mem_required = tensor_size * 2.5 steps = 1 @@ -229,7 +233,7 @@ def forward(self, x): w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] w2 = w1 * (int(c)**(-0.5)) del w1 - w3 = torch.nn.functional.softmax(w2, dim=2) + w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) del w2 # attend to values diff --git a/scripts/img2img.py b/scripts/img2img.py index 5b4537d4e..04b88b54f 100644 --- a/scripts/img2img.py +++ b/scripts/img2img.py @@ -196,6 +196,9 @@ def main(): opt = parser.parse_args() seed_everything(opt.seed) + # needed when model is in half mode, remove if not using half mode + torch.set_default_tensor_type(torch.HalfTensor) + config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") model = model.half() diff --git a/scripts/txt2img.py b/scripts/txt2img.py index 28db4e78a..a08f52298 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -236,6 +236,9 @@ def main(): seed_everything(opt.seed) + # needed when model is in half mode, remove if not using half mode + torch.set_default_tensor_type(torch.HalfTensor) + config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") model = model.half() From ccb17b55f2e7acbd1a112b55fb8f8415b4862521 Mon Sep 17 00:00:00 2001 From: Doggettx Date: Thu, 15 Sep 2022 18:28:27 +0200 Subject: [PATCH 14/18] prompt2prompt easier to implement, more options Only need to wrap the model now with 
PromptGuidanceModelWrapper, and call prepare_prompts. No need to change samplers anymore, for example see changes in txt2img.py special format inside prompts: [sentence1:sentence2:step] will swap sentence1 (or sentence) for sentence2 at step [sentence:step] will add sentence at step [:sentence:step] will remove sentence at step [sentence] will add sentence at step 0 (only useful for negative prompts) when a sentence starts with - it will be seen as a negative prompt {scale:step} will switch to defined guidance scale at step, does not work if initial guidance scale was 1.0 --- .gitignore | 6 + scripts/prompt_parser.py | 292 +++++++++++++++++++++++++++++++++++++++ scripts/txt2img.py | 16 ++- 3 files changed, 307 insertions(+), 7 deletions(-) create mode 100644 .gitignore create mode 100644 scripts/prompt_parser.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..a618fafc7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ + +*.pyc +*.ckpt +outputs/ +src/ +.idea/ diff --git a/scripts/prompt_parser.py b/scripts/prompt_parser.py new file mode 100644 index 000000000..44ff8d2c1 --- /dev/null +++ b/scripts/prompt_parser.py @@ -0,0 +1,292 @@ +import torch +import numpy as np +import re +from abc import abstractmethod, ABC +from dataclasses import dataclass +from typing import TypeVar, List, Type, Optional +from ldm.models.diffusion.ddpm import LatentDiffusion + +T = TypeVar('T') + + +def parse_float(text: str) -> float: + try: + return float(text) + except ValueError: + return 0. + + +def get_step(weight: float, steps: int) -> int: + step = int(weight) if weight >= 1. else int(weight * steps) + return max(0, min(steps-1, step)) + + +def parse_step(text: str, steps: int) -> int: + return get_step(parse_float(text), steps) + + +@dataclass +class Guidance: + scale: Optional[float] + cond: Optional[object] + uncond: Optional[object] + prompt: str + + def apply(self, cond: object, uncond: object, scale: float): + if self.cond is not None: + cond = self.cond + + if self.scale is not None: + scale = self.scale + + if self.uncond is not None: + uncond = self.uncond + + return cond, uncond, scale + + +@dataclass +class Prompt: + pos_text: str + neg_text: str + step: int + + +class Token(ABC): + @staticmethod + @abstractmethod + def starts_with(char: str) -> bool: + pass + + @staticmethod + @abstractmethod + def create(text: str, steps: int): + pass + + +class TextToken(Token): + @staticmethod + def create(text: str, steps: int): + return text + + @staticmethod + def starts_with(char: str) -> bool: + return True + + +@dataclass +class CommandToken(Token): + method: str + args: List[str] + step: int + + @staticmethod + def create(text: str, steps: int): + return None + + @staticmethod + def starts_with(char: str) -> bool: + return char == '@' + + +@dataclass +class SwapToken(Token): + word1: str = "" + word2: str = "" + step: int = 0 + + @staticmethod + def create(text: str, steps: int): + value = text[1:-1] + fields = str.split(value, ':') + if len(fields) < 2: + return SwapToken(word2=value) + if len(fields) == 2: + return SwapToken(word2=fields[0], step=parse_step(fields[1], steps)) + else: + return SwapToken(word1=fields[0], word2=fields[1], step=parse_step(fields[2], steps)) + + @staticmethod + def starts_with(char: str) -> bool: + return char == '[' + + +@dataclass +class ScaleToken(Token): + scale: float = -1. 
+ step: int = 0 + + @staticmethod + def create(text: str, steps: int): + fields = str.split(text[1:-1], ':') + if len(fields) != 2: + return ScaleToken() + + return ScaleToken(scale=parse_float(fields[0]), step=parse_step(fields[1], steps)) + + @staticmethod + def starts_with(char: str) -> bool: + return char == '{' + + +def filter_type(array: list, dtype: Type[T]) -> List[T]: + return [item for item in array if type(item) is dtype] + + +class PromptParser: + def __init__(self, model): + self.model = model + self.regex = re.compile(r'\[.*?]|\{.*?}|.+?(?=[\[{])|.*') + self.tokens = [SwapToken, ScaleToken, TextToken] + + # test regex for commands, not used yet + # \[.*?]|\{.+?}|@[^\s(]+\(.*?\)|.+?(?=[\[{@])|.+ + + def get_prompt_guidance(self, prompt, steps, batch_size) -> List[Guidance]: + result: List[Guidance] = list() + + # initialize array + for i in range(0, steps): + result.append(Guidance(None, None, None, "")) + + cur_pos = "" + cur_neg = "" + # set prompts + print("Used prompts:") + for item in self.__parse_prompt(prompt, steps): + if item.pos_text != cur_pos: + print(f'step {item.step}: "{item.pos_text}"') + result[item.step].cond = self.model.get_learned_conditioning(batch_size * item.pos_text) + cur_pos = item.pos_text + + if item.neg_text != cur_neg: + print(f'step {item.step}: [negative] "{item.neg_text}"') + result[item.step].uncond = self.model.get_learned_conditioning(batch_size * item.neg_text) + cur_neg = item.neg_text + + result[item.step].prompt = cur_pos + + # set scales + for scale in self.__get_scales(prompt, steps): + result[scale.step].scale = scale.scale + + return result + + def __get_scales(self, prompt: str, steps: int) -> List[ScaleToken]: + tokens = self.__get_tokens(prompt, steps) + scales = filter_type(tokens, ScaleToken) + + return scales + + def __get_word_info(self, word: str) -> (str, bool): + if len(word) == 0: + return word, False + + if word[0] == '-': + return word[1:], False + + return word, True + + def __parse_prompt(self, prompt, steps) -> List[Prompt]: + tokens = self.__get_tokens(prompt, steps) + values = np.array([token.step for token in filter_type(tokens, SwapToken)]) + values = np.concatenate(([0], values)) + values = np.sort(np.unique(values)) + + builders = [(value, list(), list()) for value in values] + + for token in tokens: + if type(token) is SwapToken: + word1, is_pos1 = self.__get_word_info(token.word1) + word2, is_pos2 = self.__get_word_info(token.word2) + + if not len(word2): + is_pos2 = is_pos1 + + for (value, pos_text, neg_text) in builders: + if value < token.step: + is_pos, word = is_pos1, word1 + else: + is_pos, word = is_pos2, word2 + + builder = pos_text if is_pos else neg_text + builder.append(word) + + elif type(token) is str: + for _, pos_text, _ in builders: + pos_text.append(token) + + return [Prompt(pos_text=''.join(pos_text), neg_text=''.join(neg_text), step=int(value)) + for value, pos_text, neg_text in builders] + + def __get_tokens(self, prompt: str, steps: int): + parts = self.regex.findall(prompt) + result = list() + + for part in parts: + if len(part) == 0: + continue + + for token in self.tokens: + if token.starts_with(part[0]): + result.append(token.create(part, steps)) + break + + return result + + +class PromptGuidanceModelWrapper: + def __init__(self, model: LatentDiffusion): + self.model = model + + self.__step: int = None + self.prompt_guidance: List[Guidance] = None + self.scale: float = 0. + self.init_scale: float = 0. 
+ self.c = None + self.uc = None + self.parser = PromptParser(model) + + def __getattr__(self, attr): + return getattr(self.model, attr) + + def apply_model(self, x_noisy, t, cond, return_ids=False): + if self.prompt_guidance is None: + raise RuntimeError("Wrapper not prepared, make sure to call prepare before using the model") + + if self.__step < len(self.prompt_guidance): + self.c, self.uc, self.scale = \ + self.prompt_guidance[self.__step].apply(self.c, self.uc, self.scale) + + has_unconditional = len(cond) == 2 + if has_unconditional: + cond[0] = self.uc + cond[1] = self.c + else: + cond = self.c + + result = self.model.apply_model(x_noisy, t, cond, return_ids) + + if has_unconditional and self.scale != self.init_scale: + e_t_uncond, e_t = result.chunk(2) + e_diff = e_t - e_t_uncond + e_t = e_t_uncond + (self.scale / self.init_scale) * e_diff + result = torch.cat([e_t_uncond, e_t]) + + self.__step += 1 + + return result + + def prepare_prompts(self, prompt: str, scale: float, steps: int, batch_size: int): + self.__step = 0 + + self.prompt_guidance = self.parser.get_prompt_guidance(prompt, steps, batch_size) + + uc = self.model.get_learned_conditioning(batch_size * [""]) + c, uc, scale = self.prompt_guidance[0].apply(uc, uc, scale) + + self.init_scale = scale + self.scale = scale + self.c = c + self.uc = uc + diff --git a/scripts/txt2img.py b/scripts/txt2img.py index a08f52298..0998ce429 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -20,6 +20,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from transformers import AutoFeatureExtractor +from prompt_parser import PromptGuidanceModelWrapper # load safety model @@ -151,7 +152,7 @@ def main(): parser.add_argument( "--n_iter", type=int, - default=2, + default=1, help="sample this often", ) parser.add_argument( @@ -181,7 +182,7 @@ def main(): parser.add_argument( "--n_samples", type=int, - default=3, + default=1, help="how many samples to produce for each given prompt. A.k.a. 
batch size", ) parser.add_argument( @@ -245,6 +246,7 @@ def main(): device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) + model = PromptGuidanceModelWrapper(model) if opt.plms: sampler = PLMSSampler(model) @@ -290,19 +292,19 @@ def main(): for n in trange(opt.n_iter, desc="Sampling"): for prompts in tqdm(data, desc="data"): uc = None - if opt.scale != 1.0: - uc = model.get_learned_conditioning(batch_size * [""]) if isinstance(prompts, tuple): prompts = list(prompts) - c = model.get_learned_conditioning(prompts) + + model.prepare_prompts(prompts[0], opt.scale, opt.ddim_steps, opt.n_samples) + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] samples_ddim, _ = sampler.sample(S=opt.ddim_steps, - conditioning=c, + conditioning=model.c, batch_size=opt.n_samples, shape=shape, verbose=False, unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, + unconditional_conditioning=model.uc, eta=opt.ddim_eta, x_T=start_code) From 223736adda6c1db3fbe71c45c883aee7b36d272c Mon Sep 17 00:00:00 2001 From: Doggettx Date: Thu, 6 Oct 2022 10:51:35 +0200 Subject: [PATCH 15/18] Update attention.py Changed attention to code like used in diffusers --- ldm/modules/attention.py | 106 ++++++++++++++++++++++++++------------- 1 file changed, 71 insertions(+), 35 deletions(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index a0f4f18b8..6c64ca79f 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -151,6 +151,8 @@ def forward(self, x): class CrossAttention(nn.Module): + MAX_STEPS = 64 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): super().__init__() inner_dim = dim_head * heads @@ -167,61 +169,95 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. 
nn.Dropout(dropout) ) - def forward(self, x, context=None, mask=None): - h = self.heads - - q_in = self.to_q(x) - context = default(context, x) - k_in = self.to_k(context) - v_in = self.to_v(context) - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - - stats = torch.cuda.memory_stats(q.device) + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def get_mem_free(self, device): + stats = torch.cuda.memory_stats(device) mem_active = stats['active_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current'] mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch + return mem_free_cuda + mem_free_torch - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 + def get_mem_required(self, batch_size_attention, sequence_length, element_size, multiplier): + tensor_size = batch_size_attention * sequence_length**2 * element_size + return tensor_size * multiplier + def get_slice_size(self, device, batch_size_attention, sequence_length, element_size): + multiplier = 3. if element_size == 2 else 2.5 + mem_free = self.get_mem_free(device) + mem_required = self.get_mem_required(batch_size_attention, sequence_length, element_size, multiplier) + steps = 1 - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + if mem_required > mem_free: + steps = 2**(math.ceil(math.log(mem_required / mem_free, 2))) # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") - if steps > 64: - max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + if steps > CrossAttention.MAX_STEPS: + gb = 1024**3 + max_tensor_elem = mem_free / element_size / batch_size_attention * CrossAttention.MAX_STEPS + max_res = math.pow(max_tensor_elem / multiplier, 0.25) + max_res = math.floor(max_res / 8) * 64 # round max res to closest 64 raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). 
' - f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') + f'Need: {mem_required/CrossAttention.MAX_STEPS/gb:0.1f}GB free, ' + f'Have: {mem_free/gb:0.1f}GB free') + + slice_size = sequence_length // steps if (sequence_length % steps) == 0 else sequence_length + return slice_size - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): + + def forward(self, x, context=None, mask=None): + batch_size, sequence_length, dim = x.shape + + query_in = self.to_q(x) + context = context if context is not None else x + key_in = self.to_k(context) + value_in = self.to_v(context) + + query = self.reshape_heads_to_batch_dim(query_in) + key = self.reshape_heads_to_batch_dim(key_in).transpose(1, 2) + value = self.reshape_heads_to_batch_dim(value_in) + del query_in, key_in, value_in + + batch_size_attention = query.shape[0] + slice_size = self.get_slice_size(query.device, batch_size_attention, sequence_length, query.element_size()) + + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + + for i in range(0, sequence_length, slice_size): end = i + slice_size - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale - s2 = s1.softmax(dim=-1, dtype=q.dtype) + s1 = torch.matmul(query[:, i:end], key) * self.scale + s2 = s1.softmax(dim=-1, dtype=query.dtype) del s1 - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + s3 = torch.matmul(s2, value) del s2 - del q, k, v + hidden_states[:, i:end] = s3 + del s3 + + del query, key, value - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 + result = self.reshape_batch_dim_to_heads(hidden_states) + del hidden_states - return self.to_out(r2) + return self.to_out(result) class BasicTransformerBlock(nn.Module): From e6bb37f3ebfc8b72f6aefcafd5eadb828b130fe2 Mon Sep 17 00:00:00 2001 From: Martino Bettucci Date: Tue, 25 Oct 2022 14:25:59 +0200 Subject: [PATCH 16/18] Merge x-attentions from Doggettx fork (#1) * Update attention.py Run attention in a loop to allow for much higher resolutions (over 1920x1920 on a 3090) * Update attention.py Correction to comment * Update attention.py * Fixed memory handling for model.decode_first_stage Better memory handling for model.decode_first_stage so it doesn't crash anymore after 100% rendering * Fixed free memory calculation Old version gave incorrect free memory results causing in crashes on edge cases. * Set model to half Set model to half in txt2img and img2img for less memory usage. * Commented out debug info Forgot to comment out debug info * Raise error when steps too high Technically you could run at higher steps as long as the resolution is dividable by the steps but you're going to run into memory issues later on anyhow. * Added max. res info to memory exception * Reverted in place tensor functions back to CompVis version Improves performance and is no longer needed. 
* Missed one function to revert * Update README.md * Performance boost and fix sigmoid for higher resolutions Significant performance boost at higher resolutions when running in auto_cast or half mode on 3090 went from 1.13it/s to 1.63it/s at 1024x1024 Will also allow for higher resolutions due to sigmoid fix and using half memory Co-authored-by: Doggettx Co-authored-by: Doggettx <110817577+Doggettx@users.noreply.github.com> --- README.md | 7 ++ ldm/modules/attention.py | 65 ++++++++--- ldm/modules/diffusionmodules/model.py | 152 +++++++++++++++++++------- scripts/img2img.py | 4 + scripts/txt2img.py | 4 + 5 files changed, 177 insertions(+), 55 deletions(-) diff --git a/README.md b/README.md index c9e6c3bb1..169399aac 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,10 @@ +# Automatic memory management + +This version can be run stand alone, but it's more meant as proof of concept so other forks can implement similar changes. + +Allows to use resolutions that require up to 64x more VRAM than possible on the default CompVis build + + # Stable Diffusion *Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:* diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index f4eff39cc..a0f4f18b8 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -1,3 +1,4 @@ +import gc from inspect import isfunction import math import torch @@ -89,7 +90,7 @@ def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) - k = k.softmax(dim=-1) + k = k.softmax(dim=-1) context = torch.einsum('bhdn,bhen->bhde', k, v) out = torch.einsum('bhde,bhdn->bhen', context, q) out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) @@ -161,7 +162,6 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - self.to_out = nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) @@ -170,27 +170,58 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. def forward(self, x, context=None, mask=None): h = self.heads - q = self.to_q(x) + q_in = self.to_q(x) context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) + k_in = self.to_k(context) + v_in = self.to_v(context) + del context, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - if exists(mask): - mask = rearrange(mask, 'b ... 
-> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' + f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale + + s2 = s1.softmax(dim=-1, dtype=q.dtype) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) class BasicTransformerBlock(nn.Module): diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 533e589a2..de3ce38c6 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -1,4 +1,5 @@ # pytorch_diffusion + derived encoder decoder +import gc import math import torch import torch.nn as nn @@ -32,7 +33,11 @@ def get_timestep_embedding(timesteps, embedding_dim): def nonlinearity(x): # swish - return x*torch.sigmoid(x) + t = torch.sigmoid(x) + x *= t + del t + + return x def Normalize(in_channels, num_groups=32): @@ -119,18 +124,30 @@ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, padding=0) def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) + h1 = x + h2 = self.norm1(h1) + del h1 + + h3 = nonlinearity(h2) + del h2 + + h4 = self.conv1(h3) + del h3 if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None] - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) + h5 = self.norm2(h4) + del h4 + + h6 = nonlinearity(h5) + del h5 + + h7 = self.dropout(h6) + del h6 + + h8 = self.conv2(h7) + del h7 if self.in_channels != self.out_channels: if self.use_conv_shortcut: @@ -138,7 +155,7 @@ def forward(self, x, temb): else: x = self.nin_shortcut(x) - return x+h + return x + h8 class LinAttnBlock(LinearAttention): @@ -174,32 +191,68 @@ def __init__(self, in_channels): stride=1, padding=0) - def forward(self, x): h_ = x h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) + q1 = self.q(h_) + k1 = self.k(h_) v 
= self.v(h_) # compute attention - b,c,h,w = q.shape - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) + b, c, h, w = q1.shape + + q2 = q1.reshape(b, c, h*w) + del q1 + + q = q2.permute(0, 2, 1) # b,hw,c + del q2 + + k = k1.reshape(b, c, h*w) # b,c,hw + del k1 + + h_ = torch.zeros_like(k, device=q.device) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + + w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w2 = w1 * (int(c)**(-0.5)) + del w1 + w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) + del w2 - # attend to values - v = v.reshape(b,c,h*w) - w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b,c,h,w) + # attend to values + v1 = v.reshape(b, c, h*w) + w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w3 - h_ = self.proj_out(h_) + h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w4 - return x+h_ + h2 = h_.reshape(b, c, h, w) + del h_ + + h3 = self.proj_out(h2) + del h2 + + h3 += x + + return h3 def make_attn(in_channels, attn_type="vanilla"): @@ -540,31 +593,54 @@ def forward(self, z): temb = None # z to block_in - h = self.conv_in(z) + h1 = self.conv_in(z) # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) + h2 = self.mid.block_1(h1, temb) + del h1 + + h3 = self.mid.attn_1(h2) + del h2 + + h = self.mid.block_2(h3, temb) + del h3 + + # prepare for up sampling + gc.collect() + torch.cuda.empty_cache() # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks+1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) + t = h + h = self.up[i_level].attn[i_block](t) + del t + if i_level != 0: - h = self.up[i_level].upsample(h) + t = h + h = self.up[i_level].upsample(t) + del t # end if self.give_pre_end: return h - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) + h1 = self.norm_out(h) + del h + + h2 = nonlinearity(h1) + del h1 + + h = self.conv_out(h2) + del h2 + if self.tanh_out: - h = torch.tanh(h) + t = h + h = torch.tanh(t) + del t + return h diff --git a/scripts/img2img.py b/scripts/img2img.py index 421e2151d..04b88b54f 100644 --- a/scripts/img2img.py +++ b/scripts/img2img.py @@ -196,8 +196,12 @@ def main(): opt = parser.parse_args() seed_everything(opt.seed) + # needed when model is in half mode, remove if not using half mode + torch.set_default_tensor_type(torch.HalfTensor) + config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") + model = 
model.half() device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) diff --git a/scripts/txt2img.py b/scripts/txt2img.py index 59c16a1db..a08f52298 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -236,8 +236,12 @@ def main(): seed_everything(opt.seed) + # needed when model is in half mode, remove if not using half mode + torch.set_default_tensor_type(torch.HalfTensor) + config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") + model = model.half() device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) From ac6b75dfcbc36065aec3bdc80f5f92d920061166 Mon Sep 17 00:00:00 2001 From: Martino Bettucci Date: Wed, 26 Oct 2022 18:43:32 +0200 Subject: [PATCH 17/18] Update attention.py Signed-off-by: Martino Bettucci --- ldm/modules/attention.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index 6c64ca79f..33dafb0cd 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -4,10 +4,20 @@ import torch import torch.nn.functional as F from torch import nn, einsum + +from diffusers.utils.import_utils import is_xformers_available + from einops import rearrange, repeat from ldm.modules.diffusionmodules.util import checkpoint +if is_xformers_available(): + import xformers + import xformers.ops + _USE_MEMORY_EFFICIENT_ATTENTION = int(os.environ.get("USE_MEMORY_EFFICIENT_ATTENTION", 0)) == 1 +else: + xformers = None + _USE_MEMORY_EFFICIENT_ATTENTION = False def exists(val): return val is not None @@ -325,4 +335,4 @@ def forward(self, x, context=None): x = block(x, context=context) x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) x = self.proj_out(x) - return x + x_in \ No newline at end of file + return x + x_in From 93b491f9ed9979205cb2fce7a34e16ef0fa23287 Mon Sep 17 00:00:00 2001 From: Martino Bettucci Date: Wed, 26 Oct 2022 19:14:29 +0200 Subject: [PATCH 18/18] Update attention.py Signed-off-by: Martino Bettucci --- ldm/modules/attention.py | 88 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 83 insertions(+), 5 deletions(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index 33dafb0cd..677b092c8 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -5,13 +5,20 @@ import torch.nn.functional as F from torch import nn, einsum -from diffusers.utils.import_utils import is_xformers_available +from ldm.modules.diffusionmodules.util import is_xformers_available -from einops import rearrange, repeat +from einops import repeat from ldm.modules.diffusionmodules.util import checkpoint -if is_xformers_available(): +_xformers_available = importlib.util.find_spec("xformers") is not None +try: + _xformers_version = importlib_metadata.version("xformers") + logger.debug(f"Successfully imported xformers version {_xformers_version}") +except importlib_metadata.PackageNotFoundError: + _xformers_available = False + +if _xformers_available: import xformers import xformers.ops _USE_MEMORY_EFFICIENT_ATTENTION = int(os.environ.get("USE_MEMORY_EFFICIENT_ATTENTION", 0)) == 1 @@ -159,6 +166,76 @@ def forward(self, x): return x+h_ +class MemoryEfficientCrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = 
nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def _maybe_init(self, x): + """ + Initialize the attention operator, if required We expect the head dimension to be exposed here, meaning that x + : B, Head, Length + """ + if self.attention_op is not None: + return + + _, M, K = x.shape + try: + self.attention_op = xformers.ops.AttentionOpDispatch( + dtype=x.dtype, + device=x.device, + k=K, + attn_bias_type=type(None), + has_dropout=False, + kv_len=M, + q_len=M, + ).op + + except NotImplementedError as err: + raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}") + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # init the attention op, if required, using the proper dimensions + self._maybe_init(q) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + # TODO: Use this directly in the attention operation, as a bias + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) class CrossAttention(nn.Module): MAX_STEPS = 64 @@ -273,9 +350,10 @@ def forward(self, x, context=None, mask=None): class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): super().__init__() - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + AttentionBuilder = MemoryEfficientCrossAttention if _USE_MEMORY_EFFICIENT_ATTENTION else CrossAttention + self.attn1 = AttentionBuilder(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + self.attn2 = AttentionBuilder(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim)
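One practical note on the xformers path added in PATCH 18/18: the USE_MEMORY_EFFICIENT_ATTENTION flag is read once at the top of ldm/modules/attention.py, so it has to be set before that module is imported. A minimal usage sketch, assuming xformers is installed (the import path is the repository's own module):

import os

# Must be set before ldm.modules.attention is imported, because the module
# reads the flag once at import time to pick the attention implementation.
os.environ["USE_MEMORY_EFFICIENT_ATTENTION"] = "1"

import ldm.modules.attention  # BasicTransformerBlock now builds MemoryEfficientCrossAttention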