6 changes: 6 additions & 0 deletions .gitignore
@@ -0,0 +1,6 @@

*.pyc
*.ckpt
outputs/
src/
.idea/
7 changes: 7 additions & 0 deletions README.md
@@ -1,3 +1,10 @@
# Automatic memory management

This version can be run standalone, but it is meant more as a proof of concept so that other forks can implement similar changes.

Allows using resolutions that require up to 64x more VRAM than is possible with the default CompVis build.


# Stable Diffusion
*Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:*

201 changes: 178 additions & 23 deletions ldm/modules/attention.py
@@ -1,12 +1,30 @@
import gc
import importlib.util
import importlib.metadata as importlib_metadata
import logging
import math
import os
from inspect import isfunction
from typing import Any, Optional

import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat

from ldm.modules.diffusionmodules.util import checkpoint

logger = logging.getLogger(__name__)

_xformers_available = importlib.util.find_spec("xformers") is not None
try:
    _xformers_version = importlib_metadata.version("xformers")
    logger.debug(f"Successfully imported xformers version {_xformers_version}")
except importlib_metadata.PackageNotFoundError:
    _xformers_available = False

if _xformers_available:
    import xformers
    import xformers.ops

    # the memory-efficient (xformers) attention path is opt-in via an environment variable
    _USE_MEMORY_EFFICIENT_ATTENTION = int(os.environ.get("USE_MEMORY_EFFICIENT_ATTENTION", 0)) == 1
else:
    xformers = None
    _USE_MEMORY_EFFICIENT_ATTENTION = False

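# Illustrative usage sketch (an assumption for this writeup, not part of the diff):
# the xformers path is opted into via the environment variable read above, which must
# be set before this module is imported; otherwise the sliced CrossAttention below is
# used automatically.
#
#   import os
#   os.environ["USE_MEMORY_EFFICIENT_ATTENTION"] = "1"
#   from ldm.modules.attention import BasicTransformerBlock  # now builds MemoryEfficientCrossAttention
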
def exists(val):
    return val is not None
@@ -89,7 +107,7 @@ def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
@@ -148,8 +166,80 @@ def forward(self, x):

        return x+h_

class MemoryEfficientCrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        self.attention_op: Optional[Any] = None

    def _maybe_init(self, x):
        """
        Initialize the attention operator, if required. The head dimension is expected to be
        folded into the batch dimension here, i.e. x has shape (B * heads, M, K).
        """
        if self.attention_op is not None:
            return

        _, M, K = x.shape
        try:
            self.attention_op = xformers.ops.AttentionOpDispatch(
                dtype=x.dtype,
                device=x.device,
                k=K,
                attn_bias_type=type(None),
                has_dropout=False,
                kv_len=M,
                q_len=M,
            ).op

        except NotImplementedError as err:
            raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}")
    def forward(self, x, context=None, mask=None):
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        # fold the head dimension into the batch dimension: (b, n, h*d) -> (b*h, n, d)
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )

        # init the attention op, if required, using the proper dimensions
        self._maybe_init(q)

        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

        # TODO: Use this directly in the attention operation, as a bias
        if exists(mask):
            raise NotImplementedError
        # undo the head folding: (b*h, n, d) -> (b, n, h*d)
        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )
        return self.to_out(out)

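# Minimal sketch (illustrative, not part of the diff) of the call the class above
# delegates to: xformers' fused kernel takes (batch * heads, seq_len, dim_head)
# tensors and never materialises the full seq_len x seq_len attention matrix.
# Shapes and dtype below are assumptions for the example.
#
#   import torch
#   import xformers.ops
#   q = torch.randn(8, 4096, 40, device="cuda", dtype=torch.float16)
#   k = torch.randn(8, 4096, 40, device="cuda", dtype=torch.float16)
#   v = torch.randn(8, 4096, 40, device="cuda", dtype=torch.float16)
#   out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
#   # out: (8, 4096, 40), equal to softmax(q @ k.transpose(1, 2) / sqrt(40)) @ v
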
class CrossAttention(nn.Module):
    MAX_STEPS = 64

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
@@ -161,44 +251,109 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def reshape_heads_to_batch_dim(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

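    # Shape sketch (hypothetical numbers): with heads=8, reshape_heads_to_batch_dim maps a
    # (2, 4096, 320) tensor to (16, 4096, 40); reshape_batch_dim_to_heads inverts this,
    # mapping (16, 4096, 40) back to (2, 4096, 320).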
    def get_mem_free(self, device):
        stats = torch.cuda.memory_stats(device)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
        mem_free_torch = mem_reserved - mem_active
        return mem_free_cuda + mem_free_torch

    def get_mem_required(self, batch_size_attention, sequence_length, element_size, multiplier):
        tensor_size = batch_size_attention * sequence_length**2 * element_size
        return tensor_size * multiplier

    def get_slice_size(self, device, batch_size_attention, sequence_length, element_size):
        multiplier = 3. if element_size == 2 else 2.5
        mem_free = self.get_mem_free(device)
        mem_required = self.get_mem_required(batch_size_attention, sequence_length, element_size, multiplier)
        steps = 1

        if mem_required > mem_free:
            steps = 2**(math.ceil(math.log(mem_required / mem_free, 2)))

        if steps > CrossAttention.MAX_STEPS:
            gb = 1024**3
            max_tensor_elem = mem_free / element_size / batch_size_attention * CrossAttention.MAX_STEPS
            max_res = math.pow(max_tensor_elem / multiplier, 0.25)
            max_res = math.floor(max_res / 8) * 64  # latent res -> pixel res, rounded down to a multiple of 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required/CrossAttention.MAX_STEPS/gb:0.1f}GB free, '
                               f'Have: {mem_free/gb:0.1f}GB free')

        slice_size = sequence_length // steps if (sequence_length % steps) == 0 else sequence_length
        return slice_size

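    # Worked example (illustrative numbers, assuming fp16 and a 512x512 image): the latent is
    # 64x64, so sequence_length = 64 * 64 = 4096 at the highest-resolution attention layer.
    # With batch size 1 and 8 heads, batch_size_attention = 8, element_size = 2 and
    # multiplier = 3, so mem_required = 8 * 4096**2 * 2 * 3 = 805_306_368 bytes (~0.75 GiB).
    # If only ~0.4 GiB is free, steps = 2**ceil(log2(0.75 / 0.4)) = 2 and
    # slice_size = 4096 // 2 = 2048, i.e. the attention matrix is built in two halves.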

    def forward(self, x, context=None, mask=None):
-       h = self.heads
+       batch_size, sequence_length, dim = x.shape

-       q = self.to_q(x)
-       context = default(context, x)
-       k = self.to_k(context)
-       v = self.to_v(context)
+       query_in = self.to_q(x)
+       context = context if context is not None else x
+       key_in = self.to_k(context)
+       value_in = self.to_v(context)

-       q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+       query = self.reshape_heads_to_batch_dim(query_in)
+       key = self.reshape_heads_to_batch_dim(key_in).transpose(1, 2)
+       value = self.reshape_heads_to_batch_dim(value_in)
+       del query_in, key_in, value_in

-       sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+       batch_size_attention = query.shape[0]
+       slice_size = self.get_slice_size(query.device, batch_size_attention, sequence_length, query.element_size())

-       if exists(mask):
-           mask = rearrange(mask, 'b ... -> b (...)')
-           max_neg_value = -torch.finfo(sim.dtype).max
-           mask = repeat(mask, 'b j -> (b h) () j', h=h)
-           sim.masked_fill_(~mask, max_neg_value)
+       hidden_states = torch.zeros(
+           (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+       )

-       # attention, what we cannot get enough of
-       attn = sim.softmax(dim=-1)
+       for i in range(0, sequence_length, slice_size):
+           end = i + slice_size

-       out = einsum('b i j, b j d -> b i d', attn, v)
-       out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
-       return self.to_out(out)
+           s1 = torch.matmul(query[:, i:end], key) * self.scale
+           s2 = s1.softmax(dim=-1, dtype=query.dtype)
+           del s1
+
+           s3 = torch.matmul(s2, value)
+           del s2
+
+           hidden_states[:, i:end] = s3
+           del s3
+
+       del query, key, value
+
+       result = self.reshape_batch_dim_to_heads(hidden_states)
+       del hidden_states
+
+       return self.to_out(result)

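# Equivalence sketch (illustrative, not part of the diff): slicing along the query
# dimension changes only peak memory, not the result, because each query row's softmax
# depends only on that row of scores. Sizes below are assumptions for the example.
#
#   import torch
#   q, k, v = (torch.randn(16, 4096, 40) for _ in range(3))
#   scale = 40 ** -0.5
#   full = torch.softmax(q @ k.transpose(1, 2) * scale, dim=-1) @ v
#   sliced = torch.cat([
#       torch.softmax(q[:, i:i + 1024] @ k.transpose(1, 2) * scale, dim=-1) @ v
#       for i in range(0, 4096, 1024)
#   ], dim=1)
#   assert torch.allclose(full, sliced, atol=1e-5)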

class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
-       self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
+       AttentionBuilder = MemoryEfficientCrossAttention if _USE_MEMORY_EFFICIENT_ATTENTION else CrossAttention
+       self.attn1 = AttentionBuilder(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
-       self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
+       self.attn2 = AttentionBuilder(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
@@ -258,4 +413,4 @@ def forward(self, x, context=None):
            x = block(x, context=context)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        x = self.proj_out(x)
        return x + x_in