From d4081b2e665b682a1d186f4ba9d2615e32efa8c7 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 10 Oct 2025 14:50:27 -0400 Subject: [PATCH 1/2] Add non_blocking to loading and moving tensors --- flux_train_network.py | 32 ++--- library/custom_offloading_utils.py | 2 +- library/flux_models.py | 14 +-- library/strategy_sd.py | 194 ++++++++++++++++++++--------- library/train_util.py | 33 +++-- library/utils.py | 2 +- networks/oft.py | 6 +- train_db.py | 7 +- train_network.py | 38 +++--- 9 files changed, 215 insertions(+), 113 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index cfc617088..1e38cefc1 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -232,21 +232,21 @@ def cache_text_encoder_outputs_if_needed( logger.info("move vae and unet to cpu to save memory") org_vae_device = vae.device org_unet_device = unet.device - vae.to("cpu") - unet.to("cpu") + vae = vae.to("cpu") + unet = unet.to("cpu") clean_memory_on_device(accelerator.device) # When TE is not be trained, it will not be prepared so we need to use explicit autocast logger.info("move text encoders to gpu") - text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 - text_encoders[1].to(accelerator.device) + text_encoders[0] = text_encoders[0].to(accelerator.device, dtype=weight_dtype, non_blocking=True) # always not fp8 + text_encoders[1] = text_encoders[1].to(accelerator.device, non_blocking=True) if text_encoders[1].dtype == torch.float8_e4m3fn: # if we load fp8 weights, the model is already fp8, so we use it as is self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) else: # otherwise, we need to convert it to target dtype - text_encoders[1].to(weight_dtype) + text_encoders[1] = text_encoders[1].to(weight_dtype, non_blocking=True) with accelerator.autocast(): dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) @@ -276,19 +276,19 @@ def cache_text_encoder_outputs_if_needed( # move back to cpu if not self.is_train_text_encoder(args): logger.info("move CLIP-L back to cpu") - text_encoders[0].to("cpu") + text_encoders[0] = text_encoders[0].to("cpu", non_blocking=True) logger.info("move t5XXL back to cpu") - text_encoders[1].to("cpu") + text_encoders[1] = text_encoders[1].to("cpu", non_blocking=True) clean_memory_on_device(accelerator.device) if not args.lowram: logger.info("move vae and unet back to original device") - vae.to(org_vae_device) - unet.to(org_unet_device) + vae = vae.to(org_vae_device, non_blocking=True) + unet = unet.to(org_unet_device, non_blocking=True) else: # Text Encoderから毎回出力を取得するので、GPUに乗せておく - text_encoders[0].to(accelerator.device, dtype=weight_dtype) - text_encoders[1].to(accelerator.device) + text_encoders[0] = text_encoders[0].to(accelerator.device, dtype=weight_dtype, non_blocking=True) + text_encoders[1] = text_encoders[1].to(accelerator.device, non_blocking=True) def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): text_encoders = text_encoder # for compatibility @@ -429,7 +429,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t noisy_model_input[diff_output_pr_indices], sigmas[diff_output_pr_indices] if sigmas is not None else None, ) - target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype, non_blocking=True) return model_pred, target, timesteps, weighting @@ -468,8 +468,8 @@ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): if index == 0: # CLIP-L logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") - text_encoder.to(te_weight_dtype) # fp8 - text_encoder.text_model.embeddings.to(dtype=weight_dtype) + text_encoder = text_encoder.to(te_weight_dtype, non_blocking=True) # fp8 + text_encoder.text_model.embeddings = text_encoder.text_model.embeddings.to(dtype=weight_dtype) else: # T5XXL def prepare_fp8(text_encoder, target_dtype): @@ -488,7 +488,7 @@ def forward(hidden_states): for module in text_encoder.modules(): if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: # print("set", module.__class__.__name__, "to", target_dtype) - module.to(target_dtype) + module = module.to(target_dtype, non_blocking=True) if module.__class__.__name__ in ["T5DenseGatedActDense"]: # print("set", module.__class__.__name__, "hooks") module.forward = forward_hook(module) @@ -497,7 +497,7 @@ def forward(hidden_states): logger.info(f"T5XXL already prepared for fp8") else: logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") - text_encoder.to(te_weight_dtype) # fp8 + text_encoder = text_encoder.to(te_weight_dtype, non_blocking=True) # fp8 prepare_fp8(text_encoder, weight_dtype) def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py index 0681dcdcb..48faca277 100644 --- a/library/custom_offloading_utils.py +++ b/library/custom_offloading_utils.py @@ -53,7 +53,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye # print( # f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device" # ) - module_to_cuda.weight.data = module_to_cuda.weight.data.to(device) + module_to_cuda.weight.data = module_to_cuda.weight.data.to(device, non_blocking=True) torch.cuda.current_stream().synchronize() # this prevents the illegal loss value diff --git a/library/flux_models.py b/library/flux_models.py index d2d7e06c7..84b5aa358 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -307,7 +307,7 @@ def forward(self, z: Tensor) -> Tensor: mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) if self.sample: std = torch.exp(0.5 * logvar) - return mean + std * torch.randn_like(mean) + return mean + std * torch.randn_like(mean, pin_memory=True) else: return mean @@ -532,7 +532,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10 """ t = time_factor * t half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, pin_memory=True) / half).to(t.device, non_blocking=True) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) @@ -600,7 +600,7 @@ def __init__(self, dim: int): def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: q = self.query_norm(q) k = self.key_norm(k) - return q.to(v), k.to(v) + return q.to(v, non_blocking=True), k.to(v, non_blocking=True) class SelfAttention(nn.Module): @@ -997,7 +997,7 @@ def move_to_device_except_swap_blocks(self, device: torch.device): self.double_blocks = None self.single_blocks = None - self.to(device) + self = self.to(device, non_blocking=True) if self.blocks_to_swap: self.double_blocks = save_double_blocks @@ -1081,8 +1081,8 @@ def forward( img = img[:, txt.shape[1] :, ...] if self.training and self.cpu_offload_checkpointing: - img = img.to(self.device) - vec = vec.to(self.device) + img = img.to(self.device, non_blocking=True) + vec = vec.to(self.device, non_blocking=True) img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) @@ -1243,7 +1243,7 @@ def move_to_device_except_swap_blocks(self, device: torch.device): self.double_blocks = nn.ModuleList() self.single_blocks = nn.ModuleList() - self.to(device) + self = self.to(device, non_blocking=True) if self.blocks_to_swap: self.double_blocks = save_double_blocks diff --git a/library/strategy_sd.py b/library/strategy_sd.py index a44fc4092..45a59d72d 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -30,81 +30,171 @@ def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Opt ) else: self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) - if max_length is None: self.max_length = self.tokenizer.model_max_length else: self.max_length = max_length + 2 - + + self.break_separator = "BREAK" + + def _split_on_break(self, text: str) -> List[str]: + """Split text on BREAK separator (case-sensitive), filtering empty segments.""" + segments = text.split(self.break_separator) + # Filter out empty or whitespace-only segments + filtered = [seg.strip() for seg in segments if seg.strip()] + # Return at least one segment to maintain consistency + return filtered if filtered else [""] + + def _tokenize_segments(self, segments: List[str], weighted: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Tokenize multiple segments and concatenate them.""" + if len(segments) == 1: + # No BREAK present, use existing logic + if weighted: + return self._get_input_ids(self.tokenizer, segments[0], self.max_length, weighted=True) + else: + tokens = self._get_input_ids(self.tokenizer, segments[0], self.max_length) + return tokens, None + + # Multiple segments - tokenize each separately + all_tokens = [] + all_weights = [] if weighted else None + + for segment in segments: + if weighted: + seg_tokens, seg_weights = self._get_input_ids(self.tokenizer, segment, self.max_length, weighted=True) + all_tokens.append(seg_tokens) + all_weights.append(seg_weights) + else: + seg_tokens = self._get_input_ids(self.tokenizer, segment, self.max_length) + all_tokens.append(seg_tokens) + + # Concatenate along the sequence dimension (dim=1 for tokens that are [batch, seq_len] or [n_chunks, seq_len]) + combined_tokens = torch.cat(all_tokens, dim=1) if all_tokens[0].dim() == 2 else torch.cat(all_tokens, dim=0) + combined_weights = None + if weighted: + combined_weights = torch.cat(all_weights, dim=1) if all_weights[0].dim() == 2 else torch.cat(all_weights, dim=0) + + return combined_tokens, combined_weights + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: text = [text] if isinstance(text, str) else text - return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)] - + + tokens_list = [] + for t in text: + segments = self._split_on_break(t) + tokens, _ = self._tokenize_segments(segments, weighted=False) + tokens_list.append(tokens) + + # Pad tokens to same length for stacking + max_length = max(t.shape[-1] for t in tokens_list) + padded_tokens = [] + for tokens in tokens_list: + if tokens.shape[-1] < max_length: + # Pad with pad_token_id + pad_size = max_length - tokens.shape[-1] + if tokens.dim() == 2: + padding = torch.full((tokens.shape[0], pad_size), self.tokenizer.pad_token_id, dtype=tokens.dtype) + tokens = torch.cat([tokens, padding], dim=1) + else: + padding = torch.full((pad_size,), self.tokenizer.pad_token_id, dtype=tokens.dtype) + tokens = torch.cat([tokens, padding], dim=0) + padded_tokens.append(tokens) + + return [torch.stack(padded_tokens, dim=0)] + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: text = [text] if isinstance(text, str) else text + tokens_list = [] weights_list = [] for t in text: - tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True) + segments = self._split_on_break(t) + tokens, weights = self._tokenize_segments(segments, weighted=True) tokens_list.append(tokens) weights_list.append(weights) + return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)] class SdTextEncodingStrategy(TextEncodingStrategy): def __init__(self, clip_skip: Optional[int] = None) -> None: self.clip_skip = clip_skip - + + def _encode_with_clip_skip(self, text_encoder: Any, tokens: torch.Tensor) -> torch.Tensor: + """Encode tokens with optional CLIP skip.""" + if self.clip_skip is None: + return text_encoder(tokens)[0] + + enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True) + hidden_states = enc_out["hidden_states"][-self.clip_skip] + return text_encoder.text_model.final_layer_norm(hidden_states) + + def _reconstruct_embeddings(self, encoder_hidden_states: torch.Tensor, tokens: torch.Tensor, + max_token_length: int, model_max_length: int, + tokenizer: Any) -> torch.Tensor: + """Reconstruct embeddings from chunked encoding.""" + v1 = tokenizer.pad_token_id == tokenizer.eos_token_id + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + + if not v1: + # v2: ... ... の三連を ... ... へ戻す + for i in range(1, max_token_length, model_max_length): + chunk = encoder_hidden_states[:, i : i + model_max_length - 2] + if i > 0: + for j in range(len(chunk)): + if tokens[j, 1] == tokenizer.eos_token: + chunk[j, 0] = chunk[j, 1] + states_list.append(chunk) + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) + else: + # v1: ... の三連を ... へ戻す + for i in range(1, max_token_length, model_max_length): + states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) + + return torch.cat(states_list, dim=1) + + def _apply_weights_single_chunk(self, encoder_hidden_states: torch.Tensor, + weights: torch.Tensor) -> torch.Tensor: + """Apply weights for single chunk case (no max_token_length).""" + return encoder_hidden_states * weights.squeeze(1).unsqueeze(2) + + def _apply_weights_multi_chunk(self, encoder_hidden_states: torch.Tensor, + weights: torch.Tensor) -> torch.Tensor: + """Apply weights for multi-chunk case (with max_token_length).""" + for i in range(weights.shape[1]): + start_idx = i * 75 + 1 + end_idx = i * 75 + 76 + encoder_hidden_states[:, start_idx:end_idx] = ( + encoder_hidden_states[:, start_idx:end_idx] * weights[:, i, 1:-1].unsqueeze(-1) + ) + return encoder_hidden_states + def encode_tokens( self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] ) -> List[torch.Tensor]: text_encoder = models[0] tokens = tokens[0] sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy - - # tokens: b,n,77 + b_size = tokens.size()[0] max_token_length = tokens.size()[1] * tokens.size()[2] model_max_length = sd_tokenize_strategy.tokenizer.model_max_length - tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77 - + + tokens = tokens.reshape((-1, model_max_length)) tokens = tokens.to(text_encoder.device) - - if self.clip_skip is None: - encoder_hidden_states = text_encoder(tokens)[0] - else: - enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True) - encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip] - encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) - - # bs*3, 77, 768 or 1024 + + encoder_hidden_states = self._encode_with_clip_skip(text_encoder, tokens) encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) - + if max_token_length != model_max_length: - v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id - if not v1: - # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - for i in range(1, max_token_length, model_max_length): - chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # の後から 最後の前まで - if i > 0: - for j in range(len(chunk)): - if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token: - # 空、つまり ...のパターン - chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする - states_list.append(chunk) # の後から の前まで - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか - encoder_hidden_states = torch.cat(states_list, dim=1) - else: - # v1: ... の三連を ... へ戻す - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - for i in range(1, max_token_length, model_max_length): - states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # の後から の前まで - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # - encoder_hidden_states = torch.cat(states_list, dim=1) - + encoder_hidden_states = self._reconstruct_embeddings( + encoder_hidden_states, tokens, max_token_length, + model_max_length, sd_tokenize_strategy.tokenizer + ) + return [encoder_hidden_states] - + def encode_tokens_with_weights( self, tokenize_strategy: TokenizeStrategy, @@ -113,23 +203,15 @@ def encode_tokens_with_weights( weights_list: List[torch.Tensor], ) -> List[torch.Tensor]: encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0] - weights = weights_list[0].to(encoder_hidden_states.device) - - # apply weights - if weights.shape[1] == 1: # no max_token_length - # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) - encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2) + + if weights.shape[1] == 1: + encoder_hidden_states = self._apply_weights_single_chunk(encoder_hidden_states, weights) else: - # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) - for i in range(weights.shape[1]): - encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[ - :, i, 1:-1 - ].unsqueeze(-1) - + encoder_hidden_states = self._apply_weights_multi_chunk(encoder_hidden_states, weights) + return [encoder_hidden_states] - class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix. # and we keep the old npz for the backward compatibility. diff --git a/library/train_util.py b/library/train_util.py index 756d88b1c..719611f61 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4,6 +4,7 @@ import ast import asyncio from concurrent.futures import Future, ThreadPoolExecutor +from contextlib import nullcontext import datetime import importlib import json @@ -26,6 +27,7 @@ # from concurrent.futures import ThreadPoolExecutor, as_completed +from torch.cuda import Stream from tqdm import tqdm from packaging.version import Version @@ -1415,10 +1417,11 @@ def cache_text_encoder_outputs_common( return # prepare tokenizers and text encoders - for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes): - text_encoder.to(device) + for i, (text_encoder, device, te_dtype) in enumerate(zip(text_encoders, devices, te_dtypes)): + te_kwargs = {} if te_dtype is not None: - text_encoder.to(dtype=te_dtype) + te_kwargs['dtype'] = te_dtype + text_encoders[i] = text_encoder.to(device, non_blocking=True, **te_dtype) # create batch is_sd3 = len(tokenizers) == 1 @@ -1440,6 +1443,8 @@ def cache_text_encoder_outputs_common( if len(batch) > 0: batches.append(batch) + torch.cuda.synchronize() + # iterate batches: call text encoder and cache outputs for memory or disk logger.info("caching text encoder outputs...") if not is_sd3: @@ -3120,7 +3125,10 @@ def cache_batch_latents( images.append(image) img_tensors = torch.stack(images, dim=0) - img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) + + s = Stream() + + img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype, non_blocking=True) with torch.no_grad(): latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") @@ -3156,12 +3164,13 @@ def cache_batch_latents( if not HIGH_VRAM: clean_memory_on_device(vae.device) + torch.cuda.synchronize() def cache_batch_text_encoder_outputs( image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype ): - input_ids1 = input_ids1.to(text_encoders[0].device) - input_ids2 = input_ids2.to(text_encoders[1].device) + input_ids1 = input_ids1.to(text_encoders[0].device, non_blocking=True) + input_ids2 = input_ids2.to(text_encoders[1].device, non_blocking=True) with torch.no_grad(): b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl( @@ -5619,9 +5628,9 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio ) # work on low-ram device if args.lowram: - text_encoder.to(accelerator.device) - unet.to(accelerator.device) - vae.to(accelerator.device) + text_encoder = text_encoder.to(accelerator.device, non_blocking=True) + unet = unet.to(accelerator.device, non_blocking=True) + vae = vae.to(accelerator.device, non_blocking=True) clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -6435,7 +6444,7 @@ def sample_images_common( distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here org_vae_device = vae.device # CPUにいるはず - vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device + vae = vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device # unwrap unet and text_encoder(s) unet = accelerator.unwrap_model(unet_wrapped) @@ -6470,7 +6479,7 @@ def sample_images_common( requires_safety_checker=False, clip_skip=args.clip_skip, ) - pipeline.to(distributed_state.device) + pipeline = pipeline.to(distributed_state.device) save_dir = args.output_dir + "/sample" os.makedirs(save_dir, exist_ok=True) @@ -6521,7 +6530,7 @@ def sample_images_common( torch.set_rng_state(rng_state) if torch.cuda.is_available() and cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) - vae.to(org_vae_device) + vae = vae.to(org_vae_device) clean_memory_on_device(accelerator.device) diff --git a/library/utils.py b/library/utils.py index 296fc4151..7ae03f812 100644 --- a/library/utils.py +++ b/library/utils.py @@ -110,7 +110,7 @@ def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): # cuda to cpu for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: cuda_data_view.record_stream(stream) - module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + module_to_cpu.weight.data = cuda_data_view.data.to("cpu") stream.synchronize() diff --git a/networks/oft.py b/networks/oft.py index 0c3a5393f..cbadf9b70 100644 --- a/networks/oft.py +++ b/networks/oft.py @@ -49,11 +49,11 @@ def __init__( if type(alpha) == torch.Tensor: alpha = alpha.detach().numpy() - + # constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility # original alpha is 1e-5, so we use 1e-2 or 1e-4 for alpha - self.constraint = alpha * out_dim - + self.constraint = alpha * out_dim + self.register_buffer("alpha", torch.tensor(alpha)) self.block_size = out_dim // self.num_blocks diff --git a/train_db.py b/train_db.py index 4bf3b31ce..7209f7dce 100644 --- a/train_db.py +++ b/train_db.py @@ -239,8 +239,8 @@ def train(args): args.mixed_precision == "fp16" ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" accelerator.print("enable full fp16 training.") - unet.to(weight_dtype) - text_encoder.to(weight_dtype) + unet = unet.to(weight_dtype) + text_encoder = text_encoder.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい if args.deepspeed: @@ -335,6 +335,7 @@ def train(args): text_encoder.train() for step, batch in enumerate(train_dataloader): + optimizer.train() current_step.value = global_step # 指定したステップ数でText Encoderの学習を止める if global_step == args.stop_text_encoder_training: @@ -384,7 +385,7 @@ def train(args): else: target = noise - huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler) loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) diff --git a/train_network.py b/train_network.py index 6cebf5fc7..4c6aa1ecd 100644 --- a/train_network.py +++ b/train_network.py @@ -222,8 +222,8 @@ def is_train_text_encoder(self, args): return not args.network_train_unet_only def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype): - for t_enc in text_encoders: - t_enc.to(accelerator.device, dtype=weight_dtype) + for i, t_enc in enumerate(text_encoders): + text_encoders[i] = t_enc.to(accelerator.device, dtype=weight_dtype) def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs): noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample @@ -323,7 +323,7 @@ def get_noise_pred_and_target( indices=diff_output_pr_indices, ) network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step - target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) + target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype, non_blocking=True) return noise_pred, target, timesteps, None @@ -352,7 +352,7 @@ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): text_encoder.text_model.embeddings.requires_grad_(True) def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): - text_encoder.text_model.embeddings.to(dtype=weight_dtype) + text_encoder.text_model.embeddings = text_encoder.text_model.embeddings.to(dtype=weight_dtype) def prepare_unet_with_accelerator( self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module @@ -390,11 +390,11 @@ def process_batch( """ with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) + latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device, non_blocking=True)) else: # latentに変換 if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size: - latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype)) + latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype, non_blocking=True)) else: chunks = [ batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size) @@ -402,7 +402,7 @@ def process_batch( list_latents = [] for chunk in chunks: with torch.no_grad(): - chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype)) + chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype, non_blocking=True)) list_latents.append(chunk) latents = torch.cat(list_latents, dim=0) @@ -431,14 +431,14 @@ def process_batch( weights_list, ) else: - input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + input_ids = [ids.to(accelerator.device, non_blocking=True) for ids in batch["input_ids_list"]] encoded_text_encoder_conds = text_encoding_strategy.encode_tokens( tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids, ) if args.full_fp16: - encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] + encoded_text_encoder_conds = [c.to(weight_dtype, non_blocking=True) for c in encoded_text_encoder_conds] # if text_encoder_conds is not cached, use encoded_text_encoder_conds if len(text_encoder_conds) == 0: @@ -449,6 +449,8 @@ def process_batch( if encoded_text_encoder_conds[i] is not None: text_encoder_conds[i] = encoded_text_encoder_conds[i] + torch.cuda.synchronize() + # sample noise, call unet, get target noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( args, @@ -816,13 +818,13 @@ def train(self, args): args.mixed_precision == "fp16" ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" accelerator.print("enable full fp16 training.") - network.to(weight_dtype) + network = network.to(weight_dtype) elif args.full_bf16: assert ( args.mixed_precision == "bf16" ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" accelerator.print("enable full bf16 training.") - network.to(weight_dtype) + network = network.to(weight_dtype) unet_weight_dtype = te_weight_dtype = weight_dtype # Experimental Feature: Put base model into fp8 to save vram @@ -844,7 +846,7 @@ def train(self, args): # logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}") # unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above logger.info(f"set U-Net weight dtype to {unet_weight_dtype}") - unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator + unet = unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator unet.requires_grad_(False) if self.cast_unet(args): @@ -858,7 +860,7 @@ def train(self, args): # nn.Embedding not support FP8 if te_weight_dtype != weight_dtype: - self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype) + self.prepare_text_encoder_fp8(i, text_encoders[i], te_weight_dtype, weight_dtype) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: @@ -920,7 +922,7 @@ def train(self, args): if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する vae.requires_grad_(False) vae.eval() - vae.to(accelerator.device, dtype=vae_dtype) + vae = vae.to(accelerator.device, dtype=vae_dtype) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: @@ -1398,6 +1400,8 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen torch.cuda.set_rng_state(gpu_rng_state) random.setstate(python_rng_state) + torch.cuda.empty_cache() + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 @@ -1454,6 +1458,12 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen if hasattr(network, "update_norms"): network.update_norms() + torch.cuda.synchronize() # Ensure GPU ops complete before next batch + + # Periodic cleanup + if step % 50 == 0: + torch.cuda.empty_cache() + optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) From 46f9e24b24bdbc4c3d839fc3362c2883d637fb2f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 10 Oct 2025 14:54:37 -0400 Subject: [PATCH 2/2] fix: revert strategy_sd.py and remove latents from huber --- library/strategy_sd.py | 196 ++++++++++++----------------------------- train_db.py | 2 +- 2 files changed, 58 insertions(+), 140 deletions(-) diff --git a/library/strategy_sd.py b/library/strategy_sd.py index 45a59d72d..d0a3a68bf 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -30,171 +30,81 @@ def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Opt ) else: self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + if max_length is None: self.max_length = self.tokenizer.model_max_length else: self.max_length = max_length + 2 - - self.break_separator = "BREAK" - - def _split_on_break(self, text: str) -> List[str]: - """Split text on BREAK separator (case-sensitive), filtering empty segments.""" - segments = text.split(self.break_separator) - # Filter out empty or whitespace-only segments - filtered = [seg.strip() for seg in segments if seg.strip()] - # Return at least one segment to maintain consistency - return filtered if filtered else [""] - - def _tokenize_segments(self, segments: List[str], weighted: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Tokenize multiple segments and concatenate them.""" - if len(segments) == 1: - # No BREAK present, use existing logic - if weighted: - return self._get_input_ids(self.tokenizer, segments[0], self.max_length, weighted=True) - else: - tokens = self._get_input_ids(self.tokenizer, segments[0], self.max_length) - return tokens, None - - # Multiple segments - tokenize each separately - all_tokens = [] - all_weights = [] if weighted else None - - for segment in segments: - if weighted: - seg_tokens, seg_weights = self._get_input_ids(self.tokenizer, segment, self.max_length, weighted=True) - all_tokens.append(seg_tokens) - all_weights.append(seg_weights) - else: - seg_tokens = self._get_input_ids(self.tokenizer, segment, self.max_length) - all_tokens.append(seg_tokens) - - # Concatenate along the sequence dimension (dim=1 for tokens that are [batch, seq_len] or [n_chunks, seq_len]) - combined_tokens = torch.cat(all_tokens, dim=1) if all_tokens[0].dim() == 2 else torch.cat(all_tokens, dim=0) - combined_weights = None - if weighted: - combined_weights = torch.cat(all_weights, dim=1) if all_weights[0].dim() == 2 else torch.cat(all_weights, dim=0) - - return combined_tokens, combined_weights - + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: text = [text] if isinstance(text, str) else text - - tokens_list = [] - for t in text: - segments = self._split_on_break(t) - tokens, _ = self._tokenize_segments(segments, weighted=False) - tokens_list.append(tokens) - - # Pad tokens to same length for stacking - max_length = max(t.shape[-1] for t in tokens_list) - padded_tokens = [] - for tokens in tokens_list: - if tokens.shape[-1] < max_length: - # Pad with pad_token_id - pad_size = max_length - tokens.shape[-1] - if tokens.dim() == 2: - padding = torch.full((tokens.shape[0], pad_size), self.tokenizer.pad_token_id, dtype=tokens.dtype) - tokens = torch.cat([tokens, padding], dim=1) - else: - padding = torch.full((pad_size,), self.tokenizer.pad_token_id, dtype=tokens.dtype) - tokens = torch.cat([tokens, padding], dim=0) - padded_tokens.append(tokens) - - return [torch.stack(padded_tokens, dim=0)] - - def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)] + + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]: text = [text] if isinstance(text, str) else text - tokens_list = [] weights_list = [] for t in text: - segments = self._split_on_break(t) - tokens, weights = self._tokenize_segments(segments, weighted=True) + tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True) tokens_list.append(tokens) weights_list.append(weights) - return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)] class SdTextEncodingStrategy(TextEncodingStrategy): def __init__(self, clip_skip: Optional[int] = None) -> None: self.clip_skip = clip_skip - - def _encode_with_clip_skip(self, text_encoder: Any, tokens: torch.Tensor) -> torch.Tensor: - """Encode tokens with optional CLIP skip.""" - if self.clip_skip is None: - return text_encoder(tokens)[0] - - enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True) - hidden_states = enc_out["hidden_states"][-self.clip_skip] - return text_encoder.text_model.final_layer_norm(hidden_states) - - def _reconstruct_embeddings(self, encoder_hidden_states: torch.Tensor, tokens: torch.Tensor, - max_token_length: int, model_max_length: int, - tokenizer: Any) -> torch.Tensor: - """Reconstruct embeddings from chunked encoding.""" - v1 = tokenizer.pad_token_id == tokenizer.eos_token_id - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - - if not v1: - # v2: ... ... の三連を ... ... へ戻す - for i in range(1, max_token_length, model_max_length): - chunk = encoder_hidden_states[:, i : i + model_max_length - 2] - if i > 0: - for j in range(len(chunk)): - if tokens[j, 1] == tokenizer.eos_token: - chunk[j, 0] = chunk[j, 1] - states_list.append(chunk) - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) - else: - # v1: ... の三連を ... へ戻す - for i in range(1, max_token_length, model_max_length): - states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) - - return torch.cat(states_list, dim=1) - - def _apply_weights_single_chunk(self, encoder_hidden_states: torch.Tensor, - weights: torch.Tensor) -> torch.Tensor: - """Apply weights for single chunk case (no max_token_length).""" - return encoder_hidden_states * weights.squeeze(1).unsqueeze(2) - - def _apply_weights_multi_chunk(self, encoder_hidden_states: torch.Tensor, - weights: torch.Tensor) -> torch.Tensor: - """Apply weights for multi-chunk case (with max_token_length).""" - for i in range(weights.shape[1]): - start_idx = i * 75 + 1 - end_idx = i * 75 + 76 - encoder_hidden_states[:, start_idx:end_idx] = ( - encoder_hidden_states[:, start_idx:end_idx] * weights[:, i, 1:-1].unsqueeze(-1) - ) - return encoder_hidden_states - + def encode_tokens( self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] ) -> List[torch.Tensor]: text_encoder = models[0] tokens = tokens[0] sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy - + + # tokens: b,n,77 b_size = tokens.size()[0] max_token_length = tokens.size()[1] * tokens.size()[2] model_max_length = sd_tokenize_strategy.tokenizer.model_max_length - - tokens = tokens.reshape((-1, model_max_length)) + tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77 + tokens = tokens.to(text_encoder.device) - - encoder_hidden_states = self._encode_with_clip_skip(text_encoder, tokens) + + if self.clip_skip is None: + encoder_hidden_states = text_encoder(tokens)[0] + else: + enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip] + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) + + # bs*3, 77, 768 or 1024 encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) - + if max_token_length != model_max_length: - encoder_hidden_states = self._reconstruct_embeddings( - encoder_hidden_states, tokens, max_token_length, - model_max_length, sd_tokenize_strategy.tokenizer - ) - + v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id + if not v1: + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, model_max_length): + chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token: + # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) + else: + # v1: ... の三連を ... へ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, model_max_length): + states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) + return [encoder_hidden_states] - + def encode_tokens_with_weights( self, tokenize_strategy: TokenizeStrategy, @@ -203,15 +113,23 @@ def encode_tokens_with_weights( weights_list: List[torch.Tensor], ) -> List[torch.Tensor]: encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0] + weights = weights_list[0].to(encoder_hidden_states.device) - - if weights.shape[1] == 1: - encoder_hidden_states = self._apply_weights_single_chunk(encoder_hidden_states, weights) + + # apply weights + if weights.shape[1] == 1: # no max_token_length + # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) + encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2) else: - encoder_hidden_states = self._apply_weights_multi_chunk(encoder_hidden_states, weights) - + # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) + for i in range(weights.shape[1]): + encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[ + :, i, 1:-1 + ].unsqueeze(-1) + return [encoder_hidden_states] + class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix. # and we keep the old npz for the backward compatibility. diff --git a/train_db.py b/train_db.py index 7209f7dce..689d6c970 100644 --- a/train_db.py +++ b/train_db.py @@ -385,7 +385,7 @@ def train(args): else: target = noise - huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch)