diff --git a/.gitignore b/.gitignore
new file mode 100644
index 000000000..a618fafc7
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+
+*.pyc
+*.ckpt
+outputs/
+src/
+.idea/
diff --git a/README.md b/README.md
index c9e6c3bb1..169399aac 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,10 @@
+# Automatic memory management
+
+This version can be run standalone, but it is intended mainly as a proof of concept so that other forks can implement similar changes.
+
+Allows using resolutions that require up to 64x more VRAM than is possible on the default CompVis build.
+
+
 # Stable Diffusion
 *Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:*
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index f4eff39cc..677b092c8 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -1,12 +1,30 @@
+import gc
+import importlib.util
+import logging
+import os
+from importlib import metadata as importlib_metadata
+from typing import Any, Optional
 from inspect import isfunction
 import math
 import torch
 import torch.nn.functional as F
 from torch import nn, einsum
 from einops import rearrange, repeat

 from ldm.modules.diffusionmodules.util import checkpoint

+logger = logging.getLogger(__name__)
+
+_xformers_available = importlib.util.find_spec("xformers") is not None
+try:
+    _xformers_version = importlib_metadata.version("xformers")
+    logger.debug(f"Successfully imported xformers version {_xformers_version}")
+except importlib_metadata.PackageNotFoundError:
+    _xformers_available = False
+
+if _xformers_available:
+    import xformers
+    import xformers.ops
+    _USE_MEMORY_EFFICIENT_ATTENTION = int(os.environ.get("USE_MEMORY_EFFICIENT_ATTENTION", 0)) == 1
+else:
+    xformers = None
+    _USE_MEMORY_EFFICIENT_ATTENTION = False

 def exists(val):
     return val is not None
@@ -89,7 +107,7 @@ def forward(self, x):
         b, c, h, w = x.shape
         qkv = self.to_qkv(x)
         q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
-        k = k.softmax(dim=-1)
+        k = k.softmax(dim=-1)
         context = torch.einsum('bhdn,bhen->bhde', k, v)
         out = torch.einsum('bhde,bhdn->bhen', context, q)
         out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
@@ -148,8 +166,80 @@ def forward(self, x):
         return x+h_


+class MemoryEfficientCrossAttention(nn.Module):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.attention_op: Optional[Any] = None
+
+    def _maybe_init(self, x):
+        """
+        Initialize the attention operator, if required.
+        We expect the head dimension to be exposed here, meaning that x is of shape: B, Head, Length.
+        """
+        if self.attention_op is not None:
+            return
+
+        _, M, K = x.shape
+        try:
+            self.attention_op = xformers.ops.AttentionOpDispatch(
+                dtype=x.dtype,
+                device=x.device,
+                k=K,
+                attn_bias_type=type(None),
+                has_dropout=False,
+                kv_len=M,
+                q_len=M,
+            ).op
+
+        except NotImplementedError as err:
+            raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}")
+
+    def forward(self, x, context=None, mask=None):
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        b, _, _ = q.shape
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, t.shape[1], self.heads, self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * self.heads, t.shape[1], self.dim_head)
+            .contiguous(),
+            (q, k, v),
+        )
+
+        # init the attention op, if required, using the proper dimensions
+        self._maybe_init(q)
+
+        # actually compute the attention, what we cannot get enough of
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+        # TODO: Use this directly in the attention operation, as a bias
+        if exists(mask):
+            raise NotImplementedError
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, self.heads, out.shape[1], self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, out.shape[1], self.heads * self.dim_head)
+        )
+        return self.to_out(out)
+
+
 class CrossAttention(nn.Module):
+    MAX_STEPS = 64
+
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
         super().__init__()
         inner_dim = dim_head * heads
@@ -161,44 +251,109 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
-
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )

+    def reshape_heads_to_batch_dim(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        return tensor
+
+    def reshape_batch_dim_to_heads(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor
+
+    def get_mem_free(self, device):
+        stats = torch.cuda.memory_stats(device)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_torch = mem_reserved - mem_active
+        return mem_free_cuda + mem_free_torch
+
+    def get_mem_required(self, batch_size_attention, sequence_length, element_size, multiplier):
+        tensor_size = batch_size_attention * sequence_length**2 * element_size
+        return tensor_size * multiplier
+
+    def get_slice_size(self, device, batch_size_attention, sequence_length, element_size):
+        multiplier = 3. if element_size == 2 else 2.5
+        mem_free = self.get_mem_free(device)
+        mem_required = self.get_mem_required(batch_size_attention, sequence_length, element_size, multiplier)
+        steps = 1
+
+        if mem_required > mem_free:
+            steps = 2**(math.ceil(math.log(mem_required / mem_free, 2)))
+            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
+            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
+
+        if steps > CrossAttention.MAX_STEPS:
+            gb = 1024**3
+            max_tensor_elem = mem_free / element_size / batch_size_attention * CrossAttention.MAX_STEPS
+            max_res = math.pow(max_tensor_elem / multiplier, 0.25)
+            max_res = math.floor(max_res / 8) * 64  # convert latent size to a pixel resolution, rounded down to a multiple of 64
+            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). 
' + f'Need: {mem_required/CrossAttention.MAX_STEPS/gb:0.1f}GB free, ' + f'Have: {mem_free/gb:0.1f}GB free') + + slice_size = sequence_length // steps if (sequence_length % steps) == 0 else sequence_length + return slice_size + + def forward(self, x, context=None, mask=None): - h = self.heads + batch_size, sequence_length, dim = x.shape - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) + query_in = self.to_q(x) + context = context if context is not None else x + key_in = self.to_k(context) + value_in = self.to_v(context) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + query = self.reshape_heads_to_batch_dim(query_in) + key = self.reshape_heads_to_batch_dim(key_in).transpose(1, 2) + value = self.reshape_heads_to_batch_dim(value_in) + del query_in, key_in, value_in - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + batch_size_attention = query.shape[0] + slice_size = self.get_slice_size(query.device, batch_size_attention, sequence_length, query.element_size()) - if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) + for i in range(0, sequence_length, slice_size): + end = i + slice_size - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) + s1 = torch.matmul(query[:, i:end], key) * self.scale + s2 = s1.softmax(dim=-1, dtype=query.dtype) + del s1 + + s3 = torch.matmul(s2, value) + del s2 + + hidden_states[:, i:end] = s3 + del s3 + + del query, key, value + + result = self.reshape_batch_dim_to_heads(hidden_states) + del hidden_states + + return self.to_out(result) class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): super().__init__() - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + AttentionBuilder = MemoryEfficientCrossAttention if _USE_MEMORY_EFFICIENT_ATTENTION else CrossAttention + self.attn1 = AttentionBuilder(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + self.attn2 = AttentionBuilder(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) @@ -258,4 +413,4 @@ def forward(self, x, context=None): x = block(x, context=context) x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) x = self.proj_out(x) - return x + x_in \ No newline at end of file + return x + x_in diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 533e589a2..de3ce38c6 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -1,4 +1,5 @@ # pytorch_diffusion + derived encoder decoder +import gc import math import torch import torch.nn as nn @@ -32,7 +33,11 @@ def get_timestep_embedding(timesteps, embedding_dim): def nonlinearity(x): # swish - return x*torch.sigmoid(x) + t = 
torch.sigmoid(x) + x *= t + del t + + return x def Normalize(in_channels, num_groups=32): @@ -119,18 +124,30 @@ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, padding=0) def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) + h1 = x + h2 = self.norm1(h1) + del h1 + + h3 = nonlinearity(h2) + del h2 + + h4 = self.conv1(h3) + del h3 if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None] - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) + h5 = self.norm2(h4) + del h4 + + h6 = nonlinearity(h5) + del h5 + + h7 = self.dropout(h6) + del h6 + + h8 = self.conv2(h7) + del h7 if self.in_channels != self.out_channels: if self.use_conv_shortcut: @@ -138,7 +155,7 @@ def forward(self, x, temb): else: x = self.nin_shortcut(x) - return x+h + return x + h8 class LinAttnBlock(LinearAttention): @@ -174,32 +191,68 @@ def __init__(self, in_channels): stride=1, padding=0) - def forward(self, x): h_ = x h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) + q1 = self.q(h_) + k1 = self.k(h_) v = self.v(h_) # compute attention - b,c,h,w = q.shape - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) + b, c, h, w = q1.shape + + q2 = q1.reshape(b, c, h*w) + del q1 + + q = q2.permute(0, 2, 1) # b,hw,c + del q2 + + k = k1.reshape(b, c, h*w) # b,c,hw + del k1 + + h_ = torch.zeros_like(k, device=q.device) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + + w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w2 = w1 * (int(c)**(-0.5)) + del w1 + w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) + del w2 - # attend to values - v = v.reshape(b,c,h*w) - w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b,c,h,w) + # attend to values + v1 = v.reshape(b, c, h*w) + w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w3 - h_ = self.proj_out(h_) + h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w4 - return x+h_ + h2 = h_.reshape(b, c, h, w) + del h_ + + h3 = self.proj_out(h2) + del h2 + + h3 += x + + return h3 def make_attn(in_channels, attn_type="vanilla"): @@ -540,31 +593,54 @@ def forward(self, z): temb = None # z to block_in - h = self.conv_in(z) + h1 = self.conv_in(z) # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) + h2 = self.mid.block_1(h1, temb) + del h1 + + h3 = self.mid.attn_1(h2) + del h2 + + h = self.mid.block_2(h3, temb) + del h3 + + # prepare for up sampling + 
gc.collect() + torch.cuda.empty_cache() # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks+1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) + t = h + h = self.up[i_level].attn[i_block](t) + del t + if i_level != 0: - h = self.up[i_level].upsample(h) + t = h + h = self.up[i_level].upsample(t) + del t # end if self.give_pre_end: return h - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) + h1 = self.norm_out(h) + del h + + h2 = nonlinearity(h1) + del h1 + + h = self.conv_out(h2) + del h2 + if self.tanh_out: - h = torch.tanh(h) + t = h + h = torch.tanh(t) + del t + return h diff --git a/scripts/img2img.py b/scripts/img2img.py index 421e2151d..04b88b54f 100644 --- a/scripts/img2img.py +++ b/scripts/img2img.py @@ -196,8 +196,12 @@ def main(): opt = parser.parse_args() seed_everything(opt.seed) + # needed when model is in half mode, remove if not using half mode + torch.set_default_tensor_type(torch.HalfTensor) + config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") + model = model.half() device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) diff --git a/scripts/prompt_parser.py b/scripts/prompt_parser.py new file mode 100644 index 000000000..44ff8d2c1 --- /dev/null +++ b/scripts/prompt_parser.py @@ -0,0 +1,292 @@ +import torch +import numpy as np +import re +from abc import abstractmethod, ABC +from dataclasses import dataclass +from typing import TypeVar, List, Type, Optional +from ldm.models.diffusion.ddpm import LatentDiffusion + +T = TypeVar('T') + + +def parse_float(text: str) -> float: + try: + return float(text) + except ValueError: + return 0. + + +def get_step(weight: float, steps: int) -> int: + step = int(weight) if weight >= 1. 
else int(weight * steps)
+    return max(0, min(steps-1, step))
+
+
+def parse_step(text: str, steps: int) -> int:
+    return get_step(parse_float(text), steps)
+
+
+@dataclass
+class Guidance:
+    scale: Optional[float]
+    cond: Optional[object]
+    uncond: Optional[object]
+    prompt: str
+
+    def apply(self, cond: object, uncond: object, scale: float):
+        if self.cond is not None:
+            cond = self.cond
+
+        if self.scale is not None:
+            scale = self.scale
+
+        if self.uncond is not None:
+            uncond = self.uncond
+
+        return cond, uncond, scale
+
+
+@dataclass
+class Prompt:
+    pos_text: str
+    neg_text: str
+    step: int
+
+
+class Token(ABC):
+    @staticmethod
+    @abstractmethod
+    def starts_with(char: str) -> bool:
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def create(text: str, steps: int):
+        pass
+
+
+class TextToken(Token):
+    @staticmethod
+    def create(text: str, steps: int):
+        return text
+
+    @staticmethod
+    def starts_with(char: str) -> bool:
+        return True
+
+
+@dataclass
+class CommandToken(Token):
+    method: str
+    args: List[str]
+    step: int
+
+    @staticmethod
+    def create(text: str, steps: int):
+        return None
+
+    @staticmethod
+    def starts_with(char: str) -> bool:
+        return char == '@'
+
+
+@dataclass
+class SwapToken(Token):
+    word1: str = ""
+    word2: str = ""
+    step: int = 0
+
+    @staticmethod
+    def create(text: str, steps: int):
+        value = text[1:-1]
+        fields = str.split(value, ':')
+        if len(fields) < 2:
+            return SwapToken(word2=value)
+        if len(fields) == 2:
+            return SwapToken(word2=fields[0], step=parse_step(fields[1], steps))
+        else:
+            return SwapToken(word1=fields[0], word2=fields[1], step=parse_step(fields[2], steps))
+
+    @staticmethod
+    def starts_with(char: str) -> bool:
+        return char == '['
+
+
+@dataclass
+class ScaleToken(Token):
+    scale: float = -1.
+    step: int = 0
+
+    @staticmethod
+    def create(text: str, steps: int):
+        fields = str.split(text[1:-1], ':')
+        if len(fields) != 2:
+            return ScaleToken()
+
+        return ScaleToken(scale=parse_float(fields[0]), step=parse_step(fields[1], steps))
+
+    @staticmethod
+    def starts_with(char: str) -> bool:
+        return char == '{'
+
+
+def filter_type(array: list, dtype: Type[T]) -> List[T]:
+    return [item for item in array if type(item) is dtype]
+
+
+class PromptParser:
+    def __init__(self, model):
+        self.model = model
+        self.regex = re.compile(r'\[.*?]|\{.*?}|.+?(?=[\[{])|.*')
+        self.tokens = [SwapToken, ScaleToken, TextToken]
+
+        # test regex for commands, not used yet
+        # \[.*?]|\{.+?}|@[^\s(]+\(.*?\)|.+?(?=[\[{@])|.+
+
+    def get_prompt_guidance(self, prompt, steps, batch_size) -> List[Guidance]:
+        result: List[Guidance] = list()
+
+        # initialize array
+        for i in range(0, steps):
+            result.append(Guidance(None, None, None, ""))
+
+        cur_pos = ""
+        cur_neg = ""
+        # set prompts
+        print("Used prompts:")
+        for item in self.__parse_prompt(prompt, steps):
+            if item.pos_text != cur_pos:
+                print(f'step {item.step}: "{item.pos_text}"')
+                result[item.step].cond = self.model.get_learned_conditioning(batch_size * [item.pos_text])
+                cur_pos = item.pos_text
+
+            if item.neg_text != cur_neg:
+                print(f'step {item.step}: [negative] "{item.neg_text}"')
+                result[item.step].uncond = self.model.get_learned_conditioning(batch_size * [item.neg_text])
+                cur_neg = item.neg_text
+
+            result[item.step].prompt = cur_pos
+
+        # set scales
+        for scale in self.__get_scales(prompt, steps):
+            result[scale.step].scale = scale.scale
+
+        return result
+
+    def __get_scales(self, prompt: str, steps: int) -> List[ScaleToken]:
+        tokens = self.__get_tokens(prompt, steps)
+        scales = filter_type(tokens, 
ScaleToken) + + return scales + + def __get_word_info(self, word: str) -> (str, bool): + if len(word) == 0: + return word, False + + if word[0] == '-': + return word[1:], False + + return word, True + + def __parse_prompt(self, prompt, steps) -> List[Prompt]: + tokens = self.__get_tokens(prompt, steps) + values = np.array([token.step for token in filter_type(tokens, SwapToken)]) + values = np.concatenate(([0], values)) + values = np.sort(np.unique(values)) + + builders = [(value, list(), list()) for value in values] + + for token in tokens: + if type(token) is SwapToken: + word1, is_pos1 = self.__get_word_info(token.word1) + word2, is_pos2 = self.__get_word_info(token.word2) + + if not len(word2): + is_pos2 = is_pos1 + + for (value, pos_text, neg_text) in builders: + if value < token.step: + is_pos, word = is_pos1, word1 + else: + is_pos, word = is_pos2, word2 + + builder = pos_text if is_pos else neg_text + builder.append(word) + + elif type(token) is str: + for _, pos_text, _ in builders: + pos_text.append(token) + + return [Prompt(pos_text=''.join(pos_text), neg_text=''.join(neg_text), step=int(value)) + for value, pos_text, neg_text in builders] + + def __get_tokens(self, prompt: str, steps: int): + parts = self.regex.findall(prompt) + result = list() + + for part in parts: + if len(part) == 0: + continue + + for token in self.tokens: + if token.starts_with(part[0]): + result.append(token.create(part, steps)) + break + + return result + + +class PromptGuidanceModelWrapper: + def __init__(self, model: LatentDiffusion): + self.model = model + + self.__step: int = None + self.prompt_guidance: List[Guidance] = None + self.scale: float = 0. + self.init_scale: float = 0. + self.c = None + self.uc = None + self.parser = PromptParser(model) + + def __getattr__(self, attr): + return getattr(self.model, attr) + + def apply_model(self, x_noisy, t, cond, return_ids=False): + if self.prompt_guidance is None: + raise RuntimeError("Wrapper not prepared, make sure to call prepare before using the model") + + if self.__step < len(self.prompt_guidance): + self.c, self.uc, self.scale = \ + self.prompt_guidance[self.__step].apply(self.c, self.uc, self.scale) + + has_unconditional = len(cond) == 2 + if has_unconditional: + cond[0] = self.uc + cond[1] = self.c + else: + cond = self.c + + result = self.model.apply_model(x_noisy, t, cond, return_ids) + + if has_unconditional and self.scale != self.init_scale: + e_t_uncond, e_t = result.chunk(2) + e_diff = e_t - e_t_uncond + e_t = e_t_uncond + (self.scale / self.init_scale) * e_diff + result = torch.cat([e_t_uncond, e_t]) + + self.__step += 1 + + return result + + def prepare_prompts(self, prompt: str, scale: float, steps: int, batch_size: int): + self.__step = 0 + + self.prompt_guidance = self.parser.get_prompt_guidance(prompt, steps, batch_size) + + uc = self.model.get_learned_conditioning(batch_size * [""]) + c, uc, scale = self.prompt_guidance[0].apply(uc, uc, scale) + + self.init_scale = scale + self.scale = scale + self.c = c + self.uc = uc + diff --git a/scripts/txt2img.py b/scripts/txt2img.py index 59c16a1db..0998ce429 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -20,6 +20,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from transformers import AutoFeatureExtractor +from prompt_parser import PromptGuidanceModelWrapper # load safety model @@ -151,7 +152,7 @@ def main(): parser.add_argument( "--n_iter", type=int, - default=2, + default=1, help="sample this often", ) parser.add_argument( @@ 
-181,7 +182,7 @@ def main(): parser.add_argument( "--n_samples", type=int, - default=3, + default=1, help="how many samples to produce for each given prompt. A.k.a. batch size", ) parser.add_argument( @@ -236,11 +237,16 @@ def main(): seed_everything(opt.seed) + # needed when model is in half mode, remove if not using half mode + torch.set_default_tensor_type(torch.HalfTensor) + config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") + model = model.half() device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) + model = PromptGuidanceModelWrapper(model) if opt.plms: sampler = PLMSSampler(model) @@ -286,19 +292,19 @@ def main(): for n in trange(opt.n_iter, desc="Sampling"): for prompts in tqdm(data, desc="data"): uc = None - if opt.scale != 1.0: - uc = model.get_learned_conditioning(batch_size * [""]) if isinstance(prompts, tuple): prompts = list(prompts) - c = model.get_learned_conditioning(prompts) + + model.prepare_prompts(prompts[0], opt.scale, opt.ddim_steps, opt.n_samples) + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] samples_ddim, _ = sampler.sample(S=opt.ddim_steps, - conditioning=c, + conditioning=model.c, batch_size=opt.n_samples, shape=shape, verbose=False, unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, + unconditional_conditioning=model.uc, eta=opt.ddim_eta, x_T=start_code)
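
For reference, the core of the automatic memory management added to `ldm/modules/attention.py` is the slice-count heuristic in `CrossAttention.get_slice_size`: measure the CUDA memory that is currently free (driver-reported free memory plus memory PyTorch has reserved but is not actively using), estimate the size of the full attention matrix, and split it into the smallest power-of-two number of slices that fits. The sketch below restates that heuristic as a standalone function; the name `estimate_attention_slices` and its exact signature are illustrative only, not something the patch defines, and it assumes a CUDA device is available.

```python
import math
import torch


def estimate_attention_slices(batch_heads: int, seq_len: int, element_size: int,
                              device: torch.device, max_steps: int = 64) -> int:
    """Illustrative helper mirroring CrossAttention.get_slice_size: pick how many
    slices the (batch_heads, seq_len, seq_len) attention matrix must be split into
    so that it fits in the currently free CUDA memory."""
    stats = torch.cuda.memory_stats(device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(device)
    # Free memory = what the driver reports as free, plus what PyTorch has
    # reserved in its caching allocator but is not actively using.
    mem_free = mem_free_cuda + (mem_reserved - mem_active)

    # Empirical safety multiplier used by the patch: ~3x the raw tensor size
    # for fp16 (element_size == 2), ~2.5x for fp32.
    multiplier = 3.0 if element_size == 2 else 2.5
    mem_required = batch_heads * seq_len ** 2 * element_size * multiplier

    steps = 1
    if mem_required > mem_free:
        # Round the ratio up to the next power of two, matching the patch's
        # heuristic so that typical latent sequence lengths split evenly.
        steps = 2 ** math.ceil(math.log2(mem_required / mem_free))
    if steps > max_steps:
        raise RuntimeError("Not enough memory even with maximum slicing; "
                           "use a lower resolution.")
    return steps
```

With the returned `steps`, the forward pass processes `sequence_length // steps` query rows per iteration, so the peak size of the intermediate attention matrix shrinks by roughly that factor; the `MAX_STEPS = 64` cap is presumably where the README's "up to 64x" figure comes from.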
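
The prompt syntax handled by the new `scripts/prompt_parser.py` is only implicit in the token classes, so a small dry run helps: inside a prompt, `[a:b:t]` uses `a` before step `t` and `b` from step `t` on, `[a:t]` introduces `a` at step `t`, a leading `-` on a word inside brackets routes it to the negative prompt, and `{s:t}` switches the guidance scale to `s` at step `t` (with `t` read as an absolute step when it is >= 1, otherwise as a fraction of the total steps). The snippet below is a hypothetical dry run with a stub in place of the real `LatentDiffusion` model; it assumes `scripts/` is on the import path and that echoing the prompt list back is an acceptable stand-in for `get_learned_conditioning`.

```python
# Hypothetical dry run of the prompt parser; no checkpoint is needed.
from prompt_parser import PromptParser


class StubModel:
    def get_learned_conditioning(self, prompts):
        # Stand-in for the real CLIP text-encoder call; just echo the input.
        return prompts


parser = PromptParser(StubModel())

# "cat" is used for the first half of the 50 steps, "dog" afterwards;
# the guidance scale is switched to 5 at step 25.
guidance = parser.get_prompt_guidance("a photo of a [cat:dog:0.5] {5:0.5}",
                                      steps=50, batch_size=1)

print(guidance[0].prompt)   # "a photo of a cat "
print(guidance[25].scale)   # 5.0
```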