diff --git a/README.md b/README.md index 4f3f0b8..eacd212 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,8 @@ -# TensorRT Extension for Stable Diffusion +# TensorRT Extension for Stable Diffusion -This extension enables the best performance on NVIDIA RTX GPUs for Stable Diffusion with TensorRT. - -You need to install the extension and generate optimized engines before using the extension. Please follow the instructions below to set everything up. - -Supports Stable Diffusion 1.5 and 2.1. Native SDXL support coming in a future release. Please use the [dev branch](https://github.com/AUTOMATIC1111/stable-diffusion-webui/tree/dev) if you would like to use it today. Note that the Dev branch is not intended for production work and may break other things that you are currently using. +This extension enables the best performance on NVIDIA RTX GPUs for Stable Diffusion with TensorRT. +You need to install the extension and generate optimized engines before using it. Please follow the instructions below to set everything up. +Supports Stable Diffusion 1.5, 2.1, SDXL, SDXL Turbo, and LCM. For SDXL and SDXL Turbo, we recommend using a GPU with 12 GB or more VRAM for best performance due to their size and computational intensity. ## Installation @@ -15,44 +13,50 @@ Example instructions for Automatic1111: 3. Copy the link to this repository and paste it into URL for extension's git repository 4. Click Install + ## How to use -1. Click on the “Generate Default Engines” button. This step takes 2-10 minutes depending on your GPU. You can generate engines for other combinations. +1. Click on the “Generate Default Engines” button. This step takes 2-10 minutes depending on your GPU. You can generate engines for other combinations. 2. Go to Settings → User Interface → Quick Settings List, add sd_unet. Apply these settings, then reload the UI. -3. Back in the main UI, select the TRT model from the sd_unet dropdown menu at the top of the page. -4. 
You can now start generating images accelerated by TRT. If you need to create more Engines, go to the TensorRT tab. +3. Back in the main UI, select “Automatic” from the sd_unet dropdown menu at the top of the page if not already selected. +4. You can now start generating images accelerated by TRT. If you need to create more Engines, go to the TensorRT tab. Happy prompting! +### LoRA + +To use LoRA / LyCORIS checkpoints, they first need to be converted to TensorRT format. This can be done in the TensorRT extension's Export LoRA tab. +1. Select a LoRA checkpoint from the dropdown. +2. Export. (This will not generate an engine, but only convert the weights, which takes ~20s.) +3. You can use the exported LoRAs as usual in your prompt. + + ## More Information TensorRT uses optimized engines for specific resolutions and batch sizes. You can generate as many optimized engines as desired. Types: - -- The "Export Default Engines” selection adds support for resolutions between 512x512 and 768x768 for Stable Diffusion 1.5 and 768x768 to 1024x1024 for SDXL with batch sizes 1 to 4. -- Static engines support a single specific output resolution and batch size. -- Dynamic engines support a range of resolutions and batch sizes, at a small cost in performance. Wider ranges will use more VRAM. +- The “Export Default Engines” selection adds support for resolutions between `512 x 512` and `768 x 768` for Stable Diffusion 1.5 and 2.1 with batch sizes 1 to 4. For SDXL, this selection generates an engine supporting a resolution of `1024 x 1024` with a batch size of `1`. +- Static engines support a single specific output resolution and batch size. +- Dynamic engines support a range of resolutions and batch sizes, at a small cost in performance. Wider ranges will use more VRAM. +- The first engine generated for a checkpoint will take longer to build; additional engines for the same checkpoint will be much faster. Each preset can be adjusted with the “Advanced Settings” option. 
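The static/dynamic distinction above comes down to `(min, opt, max)` shape ranges baked into each engine. As a rough sketch of how a generation request is matched against a dynamic engine's range (modeled on `ModelConfig.is_compatible` in this diff's `datastructures.py`; the standalone `fits_profile` helper below is illustrative, not part of the extension's API):

```python
def fits_profile(profile: dict, width: int, height: int, batch_size: int) -> bool:
    """Check whether a request fits an engine's shape profile.

    `profile` maps "sample" to (min, opt, max) latent-shape tuples, as in
    ModelConfig.is_compatible from this diff. The batch is doubled for
    classifier-free guidance; pixel sizes are divided by 8 (latent space).
    """
    _min, _opt, _max = profile["sample"]
    b, h, w = batch_size * 2, height // 8, width // 8
    return (_min[0] <= b <= _max[0]
            and _min[2] <= h <= _max[2]
            and _min[3] <= w <= _max[3])

# A hypothetical dynamic 512x512-768x768 engine, batch 1-4
# (shapes given as latent dims with the CFG-doubled batch):
profile = {"sample": [(2, 4, 64, 64), (2, 4, 64, 64), (8, 4, 96, 96)]}
print(fits_profile(profile, 640, 640, 2))    # True: inside the range
print(fits_profile(profile, 1024, 1024, 1))  # False: exceeds the 768x768 max
```

This is why a request outside every built engine's range falls back to slower paths or fails: the extension picks the compatible engine with the smallest distance from the `opt` shape.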
More detailed instructions can be found [here](https://nvidia.custhelp.com/app/answers/detail/a_id/5487/~/tensorrt-extension-for-stable-diffusion-web-ui). ### Common Issues/Limitations -**HIRES FIX:** If using the hires.fix option in Automatic1111 you must build engines that match both the starting and ending resolutions. For instance, if initial size is `512 x 512` and hires.fix upscales to `1024 x 1024`, you must either generate two engines, one at 512 and one at 1024, or generate a single dynamic engine that covers the whole range. -Having two seperate engines will heavily impact performance at the moment. Stay tuned for updates. +**HIRES FIX**: If using the hires.fix option in Automatic1111, you must build engines that match both the starting and ending resolutions. For instance, if the initial size is `512 x 512` and hires.fix upscales to `1024 x 1024`, you must generate a single dynamic engine that covers the whole range. -**Resolution:** When generating images the resolution needs to be a multiple of 64. This applies to hires.fix as well, requiring the low and high-res to be divisible by 64. +**Resolution**: When generating images, the resolution needs to be a multiple of 64. This applies to hires.fix as well, requiring both the low- and high-res sizes to be divisible by 64. -**Failing CMD arguments:** +**Failing CMD arguments**: -- `medvram` and `lowvram` Have caused issues when compiling the engine and running it. +- `medvram` and `lowvram` have caused issues when compiling the engine. - `api` Has caused the `model.json` to not be updated. Resulting in SD Unets not appearing after compilation. - -**Failing installation or TensorRT tab not appearing in UI:** This is most likely due to a failed install. To resolve this manually use this [guide](https://github.com/NVIDIA/Stable-Diffusion-WebUI-TensorRT/issues/27#issuecomment-1767570566). +- **Failing installation or TensorRT tab not appearing in UI**: This is most likely due to a failed install. 
To resolve this manually, use this [guide](https://github.com/NVIDIA/Stable-Diffusion-WebUI-TensorRT/issues/27#issuecomment-1767570566). ## Requirements +**Driver**: -**Driver**: - -- Linux: >= 450.80.02 -- Windows: >=452.39 +- Linux: >= 450.80.02 +- Windows: >= 452.39 -We always recommend keeping the driver up-to-date for system wide performance improvments. \ No newline at end of file +We always recommend keeping the driver up-to-date for system-wide performance improvements. diff --git a/datastructures.py b/datastructures.py new file mode 100644 index 0000000..50f0cc7 --- /dev/null +++ b/datastructures.py @@ -0,0 +1,239 @@ +from dataclasses import dataclass +from enum import Enum +from json import JSONEncoder +import torch + + +class SDVersion(Enum): + SD1 = 1 + SD2 = 2 + SDXL = 3 + Unknown = -1 + + def __str__(self): + return self.name + + @classmethod + def from_str(cls, str): + try: + return cls[str] + except KeyError: + return cls.Unknown + + def match(self, sd_model): + if sd_model.is_sd1 and self == SDVersion.SD1: + return True + elif sd_model.is_sd2 and self == SDVersion.SD2: + return True + elif sd_model.is_sdxl and self == SDVersion.SDXL: + return True + elif self == SDVersion.Unknown: + return True + else: + return False + + +class ModelType(Enum): + UNET = 0 + CONTROLNET = 1 + LORA = 2 + UNDEFINED = -1 + + @classmethod + def from_string(cls, s): + return getattr(cls, s.upper(), None) + + def __str__(self): + return self.name.lower() + + +@dataclass +class ModelConfig: + profile: dict + static_shapes: bool + fp32: bool + inpaint: bool + refit: bool + lora: bool + vram: int + unet_hidden_dim: int = 4 + + def is_compatible_from_dict(self, feed_dict: dict): + distance = 0 + for k, v in feed_dict.items(): + _min, _opt, _max = self.profile[k] + v_tensor = torch.Tensor(list(v.shape)) + r_min = torch.Tensor(_max) - v_tensor + r_opt = (torch.Tensor(_opt) - v_tensor).abs() + r_max = v_tensor - torch.Tensor(_min) + if torch.any(r_min < 0) or torch.any(r_max < 0): +
return (False, distance) + distance += r_opt.sum() + 0.5 * (r_max.sum() + 0.5 * r_min.sum()) + return (True, distance) + + def is_compatible( + self, width: int, height: int, batch_size: int, max_embedding: int + ): + distance = 0 + sample = self.profile["sample"] + embedding = self.profile["encoder_hidden_states"] + + batch_size *= 2 + width = width // 8 + height = height // 8 + + _min, _opt, _max = sample + if _min[0] > batch_size or _max[0] < batch_size: + return (False, distance) + if _min[2] > height or _max[2] < height: + return (False, distance) + if _min[3] > width or _max[3] < width: + return (False, distance) + + _min_em, _opt_em, _max_em = embedding + if _min_em[1] > max_embedding or _max_em[1] < max_embedding: + return (False, distance) + + distance = ( + abs(_opt[0] - batch_size) + + abs(_opt[2] - height) + + abs(_opt[3] - width) + + 0.5 * (abs(_max[2] - height) + abs(_max[3] - width)) + ) + + return (True, distance) + + +class ModelConfigEncoder(JSONEncoder): + def default(self, o: ModelConfig): + return o.__dict__ + + +@dataclass +class ProfileSettings: + bs_min: int + bs_opt: int + bs_max: int + h_min: int + h_opt: int + h_max: int + w_min: int + w_opt: int + w_max: int + t_min: int + t_opt: int + t_max: int + static_shape: bool = False + + def __str__(self) -> str: + return "Batch Size: {}-{}-{}\nHeight: {}-{}-{}\nWidth: {}-{}-{}\nToken Count: {}-{}-{}".format( + self.bs_min, + self.bs_opt, + self.bs_max, + self.h_min, + self.h_opt, + self.h_max, + self.w_min, + self.w_opt, + self.w_max, + self.t_min, + self.t_opt, + self.t_max, + ) + + def out(self): + return ( + self.bs_min, + self.bs_opt, + self.bs_max, + self.h_min, + self.h_opt, + self.h_max, + self.w_min, + self.w_opt, + self.w_max, + self.t_min, + self.t_opt, + self.t_max, + ) + + def token_to_dim(self, static_shapes: bool): + self.t_min = (self.t_min // 75) * 77 + self.t_opt = (self.t_opt // 75) * 77 + self.t_max = (self.t_max // 75) * 77 + + if static_shapes: + self.t_min = self.t_max = 
self.t_opt + self.bs_min = self.bs_max = self.bs_opt + self.h_min = self.h_max = self.h_opt + self.w_min = self.w_max = self.w_opt + self.static_shape = True + + def get_latent_dim(self): + return ( + self.h_min // 8, + self.h_opt // 8, + self.h_max // 8, + self.w_min // 8, + self.w_opt // 8, + self.w_max // 8, + ) + + def get_a1111_batch_dim(self): + static_batch = self.bs_min == self.bs_max == self.bs_opt + if self.t_max <= 77: + return (self.bs_min * 2, self.bs_opt * 2, self.bs_max * 2) + elif self.t_max > 77 and static_batch: + return (self.bs_opt, self.bs_opt, self.bs_opt) + elif self.t_max > 77 and not static_batch: + if self.t_opt > 77: + return (self.bs_min, self.bs_opt, self.bs_max * 2) + return (self.bs_min, self.bs_opt * 2, self.bs_max * 2) + else: + raise Exception("Uncovered case in get_batch_dim") + + +class ProfilePrests: + def __init__(self): + self.profile_presets = { + "512x512 | Batch Size 1 (Static)": ProfileSettings( + 1, 1, 1, 512, 512, 512, 512, 512, 512, 75, 75, 75 + ), + "768x768 | Batch Size 1 (Static)": ProfileSettings( + 1, 1, 1, 768, 768, 768, 768, 768, 768, 75, 75, 75 + ), + "1024x1024 | Batch Size 1 (Static)": ProfileSettings( + 1, 1, 1, 1024, 1024, 1024, 1024, 1024, 1024, 75, 75, 75 + ), + "256x256 - 512x512 | Batch Size 1-4": ProfileSettings( + 1, 1, 4, 256, 512, 512, 256, 512, 512, 75, 75, 150 + ), + "512x512 - 768x768 | Batch Size 1-4": ProfileSettings( + 1, 1, 4, 512, 512, 768, 512, 512, 768, 75, 75, 150 + ), + "768x768 - 1024x1024 | Batch Size 1-4": ProfileSettings( + 1, 1, 4, 768, 1024, 1024, 768, 1024, 1024, 75, 75, 150 + ), + } + self.default = ProfileSettings( + 1, 1, 4, 512, 512, 768, 512, 512, 768, 75, 75, 150 + ) + self.default_xl = ProfileSettings( + 1, 1, 1, 1024, 1024, 1024, 1024, 1024, 1024, 75, 75, 75 + ) + + def get_settings_from_version(self, version: str): + static = False + if version == "Default": + return *self.default.out(), static + if "Static" in version: + static = True + return 
*self.profile_presets[version].out(), static + + def get_choices(self): + return list(self.profile_presets.keys()) + ["Default"] + + def get_default(self, is_xl: bool): + if is_xl: + return self.default_xl + return self.default diff --git a/exporter.py b/exporter.py index e4c6c5b..fdc2379 100644 --- a/exporter.py +++ b/exporter.py @@ -1,23 +1,31 @@ +import os +import time +import shutil +import json +from pathlib import Path +from logging import info, error +from collections import OrderedDict +from typing import List, Tuple + import torch import torch.nn.functional as F +import numpy as np import onnx -from logging import info, error -import time -import shutil +from onnx import numpy_helper +from optimum.onnx.utils import ( + _get_onnx_external_data_tensors, + check_model_uses_external_data, +) -from modules import sd_hijack, sd_unet, shared -from utilities import Engine -import os +from modules import shared - -def get_cc(): - cc_major = torch.cuda.get_device_properties(0).major - cc_minor = torch.cuda.get_device_properties(0).minor - return cc_major, cc_minor +from utilities import Engine +from datastructures import ProfileSettings +from model_helper import UNetModel -def apply_lora(model, lora_path, inputs): +def apply_lora(model: torch.nn.Module, lora_path: str, inputs: Tuple[torch.Tensor]) -> torch.nn.Module: try: import sys @@ -34,53 +42,137 @@ def apply_lora(model, lora_path, inputs): lora_name = os.path.splitext(os.path.basename(lora_path))[0] networks.load_networks( [lora_name], [1.0], [1.0], [None] - ) # todo: UI for parameters, multiple loras -> Struct of Arrays? 
+ ) model.forward(*inputs) return model -def export_onnx( - onnx_path, - modelobj=None, - profile=None, - opset=17, - diable_optimizations=False, - lora_path=None, -): - swap_sdpa = hasattr(F, "scaled_dot_product_attention") - old_sdpa = getattr(F, "scaled_dot_product_attention", None) if swap_sdpa else None - if swap_sdpa: - delattr(F, "scaled_dot_product_attention") +def get_refit_weights( + state_dict: dict, onnx_opt_path: str, weight_name_mapping: dict, weight_shape_mapping: dict +) -> dict: + refit_weights = OrderedDict() + onnx_opt_dir = os.path.dirname(onnx_opt_path) + onnx_opt_model = onnx.load(onnx_opt_path) + # Create initializer data hashes + initializer_hash_mapping = {} + onnx_data_mapping = {} + for initializer in onnx_opt_model.graph.initializer: + initializer_data = numpy_helper.to_array( + initializer, base_dir=onnx_opt_dir + ).astype(np.float16) + initializer_hash = hash(initializer_data.data.tobytes()) + initializer_hash_mapping[initializer.name] = initializer_hash + onnx_data_mapping[initializer.name] = initializer_data + + for torch_name, initializer_name in weight_name_mapping.items(): + initializer_hash = initializer_hash_mapping[initializer_name] + wt = state_dict[torch_name] + + # get shape transform info + initializer_shape, is_transpose = weight_shape_mapping[torch_name] + if is_transpose: + wt = torch.transpose(wt, 0, 1) + else: + wt = torch.reshape(wt, initializer_shape) + + # include weight if hashes differ + wt_hash = hash(wt.cpu().detach().numpy().astype(np.float16).data.tobytes()) + if initializer_hash != wt_hash: + delta = wt - torch.tensor(onnx_data_mapping[initializer_name]).to(wt.device) + refit_weights[initializer_name] = delta.contiguous() + + return refit_weights + + +def export_lora( + modelobj: UNetModel, + onnx_path: str, + weights_map_path: str, + lora_name: str, + profile: ProfileSettings, +) -> dict: + info("Exporting to ONNX...") + inputs = modelobj.get_sample_input( + profile.bs_opt * 2, + profile.h_opt // 8, + 
profile.w_opt // 8, + profile.t_opt, + ) + + with open(weights_map_path, "r") as fp_wts: + print(f"[I] Loading weights map: {weights_map_path} ") + [weights_name_mapping, weights_shape_mapping] = json.load(fp_wts) + + with torch.inference_mode(), torch.autocast("cuda"): + modelobj.unet = apply_lora( + modelobj.unet, os.path.splitext(lora_name)[0], inputs + ) - def disable_checkpoint(self): - if getattr(self, "use_checkpoint", False) == True: - self.use_checkpoint = False - if getattr(self, "checkpoint", False) == True: - self.checkpoint = False + refit_dict = get_refit_weights( + modelobj.unet.state_dict(), + onnx_path, + weights_name_mapping, + weights_shape_mapping, + ) - shared.sd_model.model.diffusion_model.apply(disable_checkpoint) - is_xl = shared.sd_model.is_sdxl + return refit_dict - sd_unet.apply_unet("None") - sd_hijack.model_hijack.apply_optimizations("None") - os.makedirs("onnx_tmp", exist_ok=True) - tmp_path = os.path.abspath(os.path.join("onnx_tmp", "tmp.onnx")) +def swap_sdpa(func): + def wrapper(*args, **kwargs): + swap_sdpa = hasattr(F, "scaled_dot_product_attention") + old_sdpa = ( + getattr(F, "scaled_dot_product_attention", None) if swap_sdpa else None + ) + if swap_sdpa: + delattr(F, "scaled_dot_product_attention") + ret = func(*args, **kwargs) + if swap_sdpa and old_sdpa: + setattr(F, "scaled_dot_product_attention", old_sdpa) + return ret + return wrapper + + +@swap_sdpa +def export_onnx( + onnx_path: str, + modelobj: UNetModel, + profile: ProfileSettings, + opset: int = 17, + diable_optimizations: bool = False, +): + info("Exporting to ONNX...") + inputs = modelobj.get_sample_input( + profile.bs_opt * 2, + profile.h_opt // 8, + profile.w_opt // 8, + profile.t_opt, + ) + + if not os.path.exists(onnx_path): + _export_onnx( + modelobj.unet, + inputs, + Path(onnx_path), + opset, + modelobj.get_input_names(), + modelobj.get_output_names(), + modelobj.get_dynamic_axes(), + modelobj.optimize if not diable_optimizations else None, + ) + + +def 
_export_onnx( + model: torch.nn.Module, inputs: Tuple[torch.Tensor], path: str, opset: int, in_names: List[str], out_names: List[str], dyn_axes: dict, optimizer=None +): + tmp_dir = os.path.abspath("onnx_tmp") + os.makedirs(tmp_dir, exist_ok=True) + tmp_path = os.path.join(tmp_dir, "model.onnx") try: info("Exporting to ONNX...") with torch.inference_mode(), torch.autocast("cuda"): - inputs = modelobj.get_sample_input( - profile["sample"][1][0] // 2, - profile["sample"][1][-2] * 8, - profile["sample"][1][-1] * 8, - ) - model = shared.sd_model.model.diffusion_model - - if lora_path: - model = apply_lora(model, lora_path, inputs) - torch.onnx.export( model, inputs, @@ -88,54 +180,47 @@ def disable_checkpoint(self): export_params=True, opset_version=opset, do_constant_folding=True, - input_names=modelobj.get_input_names(), - output_names=modelobj.get_output_names(), - dynamic_axes=modelobj.get_dynamic_axes(), + input_names=in_names, + output_names=out_names, + dynamic_axes=dyn_axes, ) - - info("Optimize ONNX.") - - onnx_graph = onnx.load(tmp_path) - if diable_optimizations: - onnx_opt_graph = onnx_graph - else: - onnx_opt_graph = modelobj.optimize(onnx_graph) - - if onnx_opt_graph.ByteSize() > 2147483648 or is_xl: - onnx.save_model( - onnx_opt_graph, - onnx_path, - save_as_external_data=True, - all_tensors_to_one_file=True, - convert_attribute=False, - ) - else: - try: - onnx.save(onnx_opt_graph, onnx_path) - except Exception as e: - error(e) - error("ONNX file too large. 
Saving as external data.") - onnx.save_model( - onnx_opt_graph, - onnx_path, - save_as_external_data=True, - all_tensors_to_one_file=True, - convert_attribute=False, - ) - info("ONNX export complete.") - del onnx_opt_graph except Exception as e: - error(e) - exit() - - # CleanUp - if swap_sdpa and old_sdpa: - setattr(F, "scaled_dot_product_attention", old_sdpa) - shutil.rmtree(os.path.abspath("onnx_tmp")) - del model - - -def export_trt(trt_path, onnx_path, timing_cache, profile, use_fp16): + error("Exporting to ONNX failed. {}".format(e)) + return + + info("Optimize ONNX.") + os.makedirs(path.parent, exist_ok=True) + onnx_model = onnx.load(tmp_path, load_external_data=False) + model_uses_external_data = check_model_uses_external_data(onnx_model) + + if model_uses_external_data: + info("ONNX model uses external data. Saving as external data.") + tensors_paths = _get_onnx_external_data_tensors(onnx_model) + onnx_model = onnx.load(tmp_path, load_external_data=True) + onnx.save( + onnx_model, + str(path), + save_as_external_data=True, + all_tensors_to_one_file=True, + location=path.name + "_data", + size_threshold=1024, + ) + + if optimizer is not None: + try: + onnx_opt_graph = optimizer("unet", onnx_model) + onnx.save(onnx_opt_graph, path) + except Exception as e: + error("Optimizing ONNX failed. {}".format(e)) + return + + if not model_uses_external_data and optimizer is None: + shutil.move(tmp_path, str(path)) + + shutil.rmtree(tmp_dir) + + +def export_trt(trt_path: str, onnx_path: str, timing_cache: str, profile: dict, use_fp16: bool): engine = Engine(trt_path) # TODO Still approx. 2gb of VRAM unaccounted for... diff --git a/info.md b/info.md index 8ef3797..64f3073 100644 --- a/info.md +++ b/info.md @@ -1,26 +1,32 @@ # TensorRT Extension -Use this extension to generate optimized engines and enable the best performance on NVIDIA RTX GPUs with TensorRT. Please follow the instructions below to set everything up. 
+This extension enables the best performance on NVIDIA RTX GPUs for Stable Diffusion with TensorRT. -## Set Up +## How to use -1. Click on the "Generate Default Engines" button. This step can take 2-10 min depending on your GPU. You can generate engines for other combinations. +1. Click on the “Generate Default Engines” button. This step takes 2-10 minutes depending on your GPU. You can generate engines for other combinations. 2. Go to Settings → User Interface → Quick Settings List, add sd_unet. Apply these settings, then reload the UI. -3. Back in the main UI, select the TRT model from the sd_unet dropdown menu at the top of the page. +3. Back in the main UI, select “Automatic” from the sd_unet dropdown menu at the top of the page if not already selected. 4. You can now start generating images accelerated by TRT. If you need to create more Engines, go to the TensorRT tab. Happy prompting! +### LoRA + +To use LoRA / LyCORIS checkpoints, they first need to be converted to TensorRT format. This can be done in the TensorRT extension's Export LoRA tab. +1. Select a LoRA checkpoint from the dropdown. +2. Export. (This will not generate an engine, but only convert the weights, which takes ~20s.) +3. You can use the exported LoRAs as usual in your prompt. + + ## More Information TensorRT uses optimized engines for specific resolutions and batch sizes. You can generate as many optimized engines as desired. Types: - -- The "Export Default Engines" selection adds support for resolutions between 512x512 and 768x768 for Stable Diffusion 1.5 and 768x768 to 1024x1024 for SDXL with batch sizes 1 to 4. +- The “Export Default Engines” selection adds support for resolutions between `512 x 512` and `768 x 768` for Stable Diffusion 1.5 and 2.1 with batch sizes 1 to 4. For SDXL, this selection generates an engine supporting a resolution of `1024 x 1024` with a batch size of `1`. - Static engines support a single specific output resolution and batch size. 
-- Dynamic engines support a range of resolutions and batch sizes, at a small cost in performance. Wider ranges will use more VRAM. - ---- +- Dynamic engines support a range of resolutions and batch sizes, at a small cost in performance. Wider ranges will use more VRAM. +- The first engine generated for a checkpoint will take longer to build; additional engines for the same checkpoint will be much faster. -Each preset can be adjusted with the "Advanced Settings" option. +Each preset can be adjusted with the “Advanced Settings” option. More detailed instructions can be found [here](https://nvidia.custhelp.com/app/answers/detail/a_id/5487/~/tensorrt-extension-for-stable-diffusion-web-ui). For more information, please visit the TensorRT Extension GitHub page [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui-tensorrt). diff --git a/install.py b/install.py index 7c07285..b02753d 100644 --- a/install.py +++ b/install.py @@ -1,31 +1,67 @@ import launch +import sys from importlib_metadata import version +python = sys.executable + + def install(): if launch.is_installed("tensorrt"): if not version("tensorrt") == "9.0.1.post11.dev4": - launch.run(["python","-m","pip","uninstall","-y","tensorrt"], "removing old version of tensorrt") - - + launch.run( + [python, "-m", "pip", "uninstall", "-y", "tensorrt"], + "removing old version of tensorrt", + ) + if not launch.is_installed("tensorrt"): print("TensorRT is not installed! 
Installing...") - launch.run_pip("install nvidia-cudnn-cu11==8.9.4.25 --no-cache-dir", "nvidia-cudnn-cu11") - launch.run_pip("install --pre --extra-index-url https://pypi.nvidia.com tensorrt==9.0.1.post11.dev4 --no-cache-dir", "tensorrt", live=True) - launch.run(["python","-m","pip","uninstall","-y","nvidia-cudnn-cu11"], "removing nvidia-cudnn-cu11") - + launch.run_pip( + "install nvidia-cudnn-cu11==8.9.4.25 --no-cache-dir", "nvidia-cudnn-cu11" + ) + launch.run_pip( + "install --pre --extra-index-url https://pypi.nvidia.com tensorrt==9.0.1.post11.dev4 --no-cache-dir", + "tensorrt", + live=True, + ) + launch.run( + ["python", "-m", "pip", "uninstall", "-y", "nvidia-cudnn-cu11"], + "removing nvidia-cudnn-cu11", + ) + if launch.is_installed("nvidia-cudnn-cu11"): if version("nvidia-cudnn-cu11") == "8.9.4.25": - launch.run(["python","-m","pip","uninstall","-y","nvidia-cudnn-cu11"], "removing nvidia-cudnn-cu11") + launch.run( + ["python", "-m", "pip", "uninstall", "-y", "nvidia-cudnn-cu11"], + "removing nvidia-cudnn-cu11", + ) - # Polygraphy + # Polygraphy if not launch.is_installed("polygraphy"): print("Polygraphy is not installed! Installing...") - launch.run_pip("install polygraphy --extra-index-url https://pypi.ngc.nvidia.com", "polygraphy", live=True) - + launch.run_pip( + "install polygraphy --extra-index-url https://pypi.ngc.nvidia.com", + "polygraphy", + live=True, + ) + # ONNX GS if not launch.is_installed("onnx_graphsurgeon"): print("GS is not installed! Installing...") launch.run_pip("install protobuf==3.20.2", "protobuf", live=True) - launch.run_pip('install onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com', "onnx-graphsurgeon", live=True) - -install() \ No newline at end of file + launch.run_pip( + "install onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com", + "onnx-graphsurgeon", + live=True, + ) + + # OPTIMUM + if not launch.is_installed("optimum"): + print("Optimum is not installed! 
Installing...") + launch.run_pip( + "install optimum", + "optimum", + live=True, + ) + + +install() diff --git a/model_helper.py b/model_helper.py new file mode 100644 index 0000000..9a68d5d --- /dev/null +++ b/model_helper.py @@ -0,0 +1,309 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import json +import tempfile +from typing import List, Tuple + +import torch +import numpy as np +import onnx +from onnx import shape_inference, numpy_helper +import onnx_graphsurgeon as gs +from polygraphy.backend.onnx.loader import fold_constants + +from modules import sd_hijack, sd_unet + +from datastructures import ProfileSettings + + +class UNetModel(torch.nn.Module): + def __init__( + self, unet, embedding_dim: int, text_minlen: int = 77, is_xl: bool = False + ) -> None: + super().__init__() + self.unet = unet + self.is_xl = is_xl + + self.text_minlen = text_minlen + self.embedding_dim = embedding_dim + self.num_xl_classes = 2816 # Magic number for num_classes + self.emb_chn = 1280 + self.in_channels = self.unet.in_channels + + self.dyn_axes = { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B", 1: "77N"}, + "timesteps": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + "y": {0: "2B"}, + } + + def apply_torch_model(self): + def disable_checkpoint(self): + if getattr(self, 
"use_checkpoint", False) == True: + self.use_checkpoint = False + if getattr(self, "checkpoint", False) == True: + self.checkpoint = False + + self.unet.apply(disable_checkpoint) + self.set_unet("None") + + def set_unet(self, ckpt: str): + # TODO test if using this with TRT works + sd_unet.apply_unet(ckpt) + sd_hijack.model_hijack.apply_optimizations(ckpt) + + def get_input_names(self) -> List[str]: + names = ["sample", "timesteps", "encoder_hidden_states"] + if self.is_xl: + names.append("y") + return names + + def get_output_names(self) -> List[str]: + return ["latent"] + + def get_dynamic_axes(self) -> dict: + io_names = self.get_input_names() + self.get_output_names() + dyn_axes = {name: self.dyn_axes[name] for name in io_names} + return dyn_axes + + def get_sample_input( + self, + batch_size: int, + latent_height: int, + latent_width: int, + text_len: int, + device: str = "cuda", + dtype: torch.dtype = torch.float32, + ) -> Tuple[torch.Tensor]: + return ( + torch.randn( + batch_size, + self.in_channels, + latent_height, + latent_width, + dtype=dtype, + device=device, + ), + torch.randn(batch_size, dtype=dtype, device=device), + torch.randn( + batch_size, + text_len, + self.embedding_dim, + dtype=dtype, + device=device, + ), + torch.randn(batch_size, self.num_xl_classes, dtype=dtype, device=device) + if self.is_xl + else None, + ) + + def get_input_profile(self, profile: ProfileSettings) -> dict: + min_batch, opt_batch, max_batch = profile.get_a1111_batch_dim() + ( + min_latent_height, + latent_height, + max_latent_height, + min_latent_width, + latent_width, + max_latent_width, + ) = profile.get_latent_dim() + + shape_dict = { + "sample": [ + (min_batch, self.unet.in_channels, min_latent_height, min_latent_width), + (opt_batch, self.unet.in_channels, latent_height, latent_width), + (max_batch, self.unet.in_channels, max_latent_height, max_latent_width), + ], + "timesteps": [(min_batch,), (opt_batch,), (max_batch,)], + "encoder_hidden_states": [ + (min_batch, 
profile.t_min, self.embedding_dim), + (opt_batch, profile.t_opt, self.embedding_dim), + (max_batch, profile.t_max, self.embedding_dim), + ], + } + if self.is_xl: + shape_dict["y"] = [ + (min_batch, self.num_xl_classes), + (opt_batch, self.num_xl_classes), + (max_batch, self.num_xl_classes), + ] + + return shape_dict + + # Helper utility for weights map + def export_weights_map(self, onnx_opt_path: str, weights_map_path: dict): + onnx_opt_dir = onnx_opt_path + state_dict = self.unet.state_dict() + onnx_opt_model = onnx.load(onnx_opt_path) + + # Create initializer data hashes + def init_hash_map(onnx_opt_model): + initializer_hash_mapping = {} + for initializer in onnx_opt_model.graph.initializer: + initializer_data = numpy_helper.to_array( + initializer, base_dir=onnx_opt_dir + ).astype(np.float16) + initializer_hash = hash(initializer_data.data.tobytes()) + initializer_hash_mapping[initializer.name] = ( + initializer_hash, + initializer_data.shape, + ) + return initializer_hash_mapping + + initializer_hash_mapping = init_hash_map(onnx_opt_model) + + weights_name_mapping = {} + weights_shape_mapping = {} + # set to keep track of initializers already added to the name_mapping dict + initializers_mapped = set() + for wt_name, wt in state_dict.items(): + # get weight hash + wt = wt.cpu().detach().numpy().astype(np.float16) + wt_hash = hash(wt.data.tobytes()) + wt_t_hash = hash(np.transpose(wt).data.tobytes()) + + for initializer_name, ( + initializer_hash, + initializer_shape, + ) in initializer_hash_mapping.items(): + # Due to constant folding, some weights are transposed during export + # To account for the transpose op, we compare the initializer hash to the + # hash for the weight and its transpose + if wt_hash == initializer_hash or wt_t_hash == initializer_hash: + # The assert below ensures there is a 1:1 mapping between + # PyTorch and ONNX weight names. 
It can be removed in cases where a 1:many + # mapping is found, in which case name_mapping[wt_name] becomes a list() + assert initializer_name not in initializers_mapped + weights_name_mapping[wt_name] = initializer_name + initializers_mapped.add(initializer_name) + is_transpose = wt_hash != initializer_hash + weights_shape_mapping[wt_name] = ( + initializer_shape, + is_transpose, + ) + + # Sanity check: were any weights left unmatched? + if wt_name not in weights_name_mapping: + print( + f"[I] PyTorch weight {wt_name} not matched with any ONNX initializer" + ) + print( + f"[I] UNet: {len(weights_name_mapping.keys())} PyTorch weights were matched with ONNX initializers" + ) + + assert weights_name_mapping.keys() == weights_shape_mapping.keys() + with open(weights_map_path, "w") as fp: + json.dump([weights_name_mapping, weights_shape_mapping], fp) + + @staticmethod + def optimize(name, onnx_graph, verbose=False): + opt = Optimizer(onnx_graph, verbose=verbose) + opt.info(name + ": original") + opt.cleanup() + opt.info(name + ": cleanup") + opt.fold_constants() + opt.info(name + ": fold constants") + opt.infer_shapes() + opt.info(name + ": shape inference") + onnx_opt_graph = opt.cleanup(return_onnx=True) + opt.info(name + ": finished") + return onnx_opt_graph + + +class Optimizer: + def __init__(self, onnx_graph, verbose=False): + self.graph = gs.import_onnx(onnx_graph) + self.verbose = verbose + + def info(self, prefix): + if self.verbose: + print( + f"{prefix} .. 
{len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs" + ) + + def cleanup(self, return_onnx=False): + self.graph.cleanup().toposort() + if return_onnx: + return gs.export_onnx(self.graph) + + def select_outputs(self, keep, names=None): + self.graph.outputs = [self.graph.outputs[o] for o in keep] + if names: + for i, name in enumerate(names): + self.graph.outputs[i].name = name + + def fold_constants(self, return_onnx=False): + onnx_graph = fold_constants( + gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True + ) + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + def infer_shapes(self, return_onnx=False): + onnx_graph = gs.export_onnx(self.graph) + if onnx_graph.ByteSize() > 2147483648: + temp_dir = tempfile.TemporaryDirectory().name + os.makedirs(temp_dir, exist_ok=True) + onnx_orig_path = os.path.join(temp_dir, "model.onnx") + onnx_inferred_path = os.path.join(temp_dir, "inferred.onnx") + onnx.save_model( + onnx_graph, + onnx_orig_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + convert_attribute=False, + ) + onnx.shape_inference.infer_shapes_path(onnx_orig_path, onnx_inferred_path) + onnx_graph = onnx.load(onnx_inferred_path) + else: + onnx_graph = shape_inference.infer_shapes(onnx_graph) + + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + def clip_add_hidden_states(self, return_onnx=False): + hidden_layers = -1 + onnx_graph = gs.export_onnx(self.graph) + for i in range(len(onnx_graph.graph.node)): + for j in range(len(onnx_graph.graph.node[i].output)): + name = onnx_graph.graph.node[i].output[j] + if "layers" in name: + hidden_layers = max( + int(name.split(".")[1].split("/")[0]), hidden_layers + ) + for i in range(len(onnx_graph.graph.node)): + for j in range(len(onnx_graph.graph.node[i].output)): + if onnx_graph.graph.node[i].output[ + j + ] == 
"/text_model/encoder/layers.{}/Add_1_output_0".format( + hidden_layers - 1 + ): + onnx_graph.graph.node[i].output[j] = "hidden_states" + for j in range(len(onnx_graph.graph.node[i].input)): + if onnx_graph.graph.node[i].input[ + j + ] == "/text_model/encoder/layers.{}/Add_1_output_0".format( + hidden_layers - 1 + ): + onnx_graph.graph.node[i].input[j] = "hidden_states" + if return_onnx: + return onnx_graph diff --git a/model_manager.py b/model_manager.py index b4544cf..9f248c8 100644 --- a/model_manager.py +++ b/model_manager.py @@ -1,13 +1,14 @@ -import json -from json import JSONEncoder - import os +import json +import copy from logging import info, warning -from dataclasses import dataclass import torch -from exporter import get_cc + from modules import paths_internal +from datastructures import ModelConfig, ModelConfigEncoder + + ONNX_MODEL_DIR = os.path.join(paths_internal.models_path, "Unet-onnx") if not os.path.exists(ONNX_MODEL_DIR): os.makedirs(ONNX_MODEL_DIR) @@ -19,6 +20,13 @@ MODEL_FILE = os.path.join(TRT_MODEL_DIR, "model.json") + +def get_cc(): + cc_major = torch.cuda.get_device_properties(0).major + cc_minor = torch.cuda.get_device_properties(0).minor + return cc_major, cc_minor + + cc_major, cc_minor = get_cc() @@ -35,8 +43,8 @@ def __init__(self, model_file=MODEL_FILE) -> None: self.update() @staticmethod - def get_onnx_path(model_name, model_hash): - onnx_filename = "_".join([model_name, model_hash]) + ".onnx" + def get_onnx_path(model_name): + onnx_filename = f"{model_name}.onnx" onnx_path = os.path.join(ONNX_MODEL_DIR, onnx_filename) return onnx_filename, onnx_path @@ -57,6 +65,9 @@ def get_trt_path(self, model_name, model_hash, profile, static_shape): return trt_filename, trt_path + def get_weights_map_path(self, model_name: str): + return os.path.join(TRT_MODEL_DIR, f"{model_name}_weights_map.json") + def update(self): trt_engines = [ trt_file @@ -64,19 +75,21 @@ def update(self): if trt_file.endswith(".trt") ] - tmp_all_models = 
self.all_models.copy() + tmp_all_models = copy.deepcopy(self.all_models) for cc, base_models in tmp_all_models.items(): for base_model, models in base_models.items(): tmp_config_list = {} for model_config in models: if model_config["filepath"] not in trt_engines: info( - "Model config outdated. {} was not found".format(model_config["filepath"]) + "Model config outdated. {} was not found".format( + model_config["filepath"] + ) ) continue tmp_config_list[model_config["filepath"]] = model_config - - tmp_config_list = list(tmp_config_list.values()) + + tmp_config_list = list(tmp_config_list.values()) if len(tmp_config_list) == 0: self.all_models[cc].pop(base_model) else: @@ -84,7 +97,6 @@ def update(self): self.write_json() - def __del__(self): self.update() @@ -163,6 +175,15 @@ def available_models(self): available = self.all_models.get(self.cc, {}) return available + def available_loras(self): + available = {} + for p in os.listdir(TRT_MODEL_DIR): + if not p.endswith(".lora"): + continue + available[os.path.splitext(p)[0]] = os.path.join(TRT_MODEL_DIR, p) + + return available + def get_timing_cache(self): current_dir = os.path.dirname(os.path.abspath(__file__)) cache = os.path.join( @@ -175,47 +196,42 @@ def get_timing_cache(self): return cache - def get_valid_models(self, base_model: str, feed_dict: dict): + def get_valid_models_from_dict(self, base_model: str, feed_dict: dict): + valid_models = [] + distances = [] + idx = [] + models = self.available_models() + for i, model in enumerate(models[base_model]): + valid, distance = model["config"].is_compatible_from_dict(feed_dict) + if valid: + valid_models.append(model) + distances.append(distance) + idx.append(i) + + return valid_models, distances, idx + + def get_valid_models( + self, + base_model: str, + width: int, + height: int, + batch_size: int, + max_embedding: int, + ): valid_models = [] distances = [] + idx = [] models = self.available_models() - for model in models[base_model]: - valid, distance = 
model["config"].is_compatible(feed_dict) + for i, model in enumerate(models[base_model]): + valid, distance = model["config"].is_compatible( + width, height, batch_size, max_embedding + ) if valid: valid_models.append(model) distances.append(distance) + idx.append(i) - return valid_models, distances - - -@dataclass -class ModelConfig: - profile: dict - static_shapes: bool - fp32: bool - inpaint: bool - refit: bool - lora: bool - vram: int - unet_hidden_dim: int = 4 - - def is_compatible(self, feed_dict: dict): - distance = 0 - for k, v in feed_dict.items(): - _min, _opt, _max = self.profile[k] - v_tensor = torch.Tensor(list(v.shape)) - r_min = torch.Tensor(_max) - v_tensor - r_opt = (torch.Tensor(_opt) - v_tensor).abs() - r_max = v_tensor - torch.Tensor(_min) - if torch.any(r_min < 0) or torch.any(r_max < 0): - return (False, distance) - distance += r_opt.sum() + 0.5 * (r_max.sum() + 0.5 * r_min.sum()) - return (True, distance) - - -class ModelConfigEncoder(JSONEncoder): - def default(self, o: ModelConfig): - return o.__dict__ + return valid_models, distances, idx modelmanager = ModelManager() diff --git a/models.py b/models.py deleted file mode 100644 index 1355c2b..0000000 --- a/models.py +++ /dev/null @@ -1,1392 +0,0 @@ -# -# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -import onnx -from onnx import shape_inference -import os -from polygraphy.backend.onnx.loader import fold_constants -import tempfile -import torch -import torch.nn.functional as F -import onnx_graphsurgeon as gs - - -class Optimizer: - def __init__(self, onnx_graph, verbose=False): - self.graph = gs.import_onnx(onnx_graph) - self.verbose = verbose - - def info(self, prefix): - if self.verbose: - print( - f"{prefix} .. {len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs" - ) - - def cleanup(self, return_onnx=False): - self.graph.cleanup().toposort() - if return_onnx: - return gs.export_onnx(self.graph) - - def select_outputs(self, keep, names=None): - self.graph.outputs = [self.graph.outputs[o] for o in keep] - if names: - for i, name in enumerate(names): - self.graph.outputs[i].name = name - - def fold_constants(self, return_onnx=False): - onnx_graph = fold_constants( - gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True - ) - self.graph = gs.import_onnx(onnx_graph) - if return_onnx: - return onnx_graph - - def infer_shapes(self, return_onnx=False): - onnx_graph = gs.export_onnx(self.graph) - if onnx_graph.ByteSize() > 2147483648: - temp_dir = tempfile.TemporaryDirectory().name - os.makedirs(temp_dir, exist_ok=True) - onnx_orig_path = os.path.join(temp_dir, "model.onnx") - onnx_inferred_path = os.path.join(temp_dir, "inferred.onnx") - onnx.save_model( - onnx_graph, - onnx_orig_path, - save_as_external_data=True, - all_tensors_to_one_file=True, - convert_attribute=False, - ) - onnx.shape_inference.infer_shapes_path(onnx_orig_path, onnx_inferred_path) - onnx_graph = onnx.load(onnx_inferred_path) - else: - onnx_graph = shape_inference.infer_shapes(onnx_graph) - - self.graph = gs.import_onnx(onnx_graph) - if return_onnx: - return onnx_graph - - def clip_add_hidden_states(self, return_onnx=False): - hidden_layers = -1 - onnx_graph = gs.export_onnx(self.graph) - 
for i in range(len(onnx_graph.graph.node)): - for j in range(len(onnx_graph.graph.node[i].output)): - name = onnx_graph.graph.node[i].output[j] - if "layers" in name: - hidden_layers = max( - int(name.split(".")[1].split("/")[0]), hidden_layers - ) - for i in range(len(onnx_graph.graph.node)): - for j in range(len(onnx_graph.graph.node[i].output)): - if onnx_graph.graph.node[i].output[ - j - ] == "/text_model/encoder/layers.{}/Add_1_output_0".format( - hidden_layers - 1 - ): - onnx_graph.graph.node[i].output[j] = "hidden_states" - for j in range(len(onnx_graph.graph.node[i].input)): - if onnx_graph.graph.node[i].input[ - j - ] == "/text_model/encoder/layers.{}/Add_1_output_0".format( - hidden_layers - 1 - ): - onnx_graph.graph.node[i].input[j] = "hidden_states" - if return_onnx: - return onnx_graph - - -def get_controlnets_path(controlnet_list): - """ - Currently ControlNet 1.0 is supported. - """ - if controlnet_list is None: - return None - return ["lllyasviel/sd-controlnet-" + controlnet for controlnet in controlnet_list] - - -def get_path(version, pipeline, controlnet=None): - if controlnet is not None: - return ["lllyasviel/sd-controlnet-" + modality for modality in controlnet] - - if version == "1.4": - if pipeline.is_inpaint(): - return "runwayml/stable-diffusion-inpainting" - else: - return "CompVis/stable-diffusion-v1-4" - elif version == "1.5": - if pipeline.is_inpaint(): - return "runwayml/stable-diffusion-inpainting" - else: - return "runwayml/stable-diffusion-v1-5" - elif version == "2.0-base": - if pipeline.is_inpaint(): - return "stabilityai/stable-diffusion-2-inpainting" - else: - return "stabilityai/stable-diffusion-2-base" - elif version == "2.0": - if pipeline.is_inpaint(): - return "stabilityai/stable-diffusion-2-inpainting" - else: - return "stabilityai/stable-diffusion-2" - elif version == "2.1": - return "stabilityai/stable-diffusion-2-1" - elif version == "2.1-base": - return "stabilityai/stable-diffusion-2-1-base" - elif version == 
"xl-1.0": - if pipeline.is_sd_xl_base(): - return "stabilityai/stable-diffusion-xl-base-1.0" - elif pipeline.is_sd_xl_refiner(): - return "stabilityai/stable-diffusion-xl-refiner-1.0" - else: - raise ValueError(f"Unsupported SDXL 1.0 pipeline {pipeline.name}") - else: - raise ValueError(f"Incorrect version {version}") - - -def get_clip_embedding_dim(version, pipeline): - if version in ("1.4", "1.5"): - return 768 - elif version in ("2.0", "2.0-base", "2.1", "2.1-base"): - return 1024 - elif version in ("xl-1.0") and pipeline.is_sd_xl_base(): - return 768 - else: - raise ValueError(f"Invalid version {version} + pipeline {pipeline}") - - -def get_clipwithproj_embedding_dim(version, pipeline): - if version in ("xl-1.0"): - return 1280 - else: - raise ValueError(f"Invalid version {version} + pipeline {pipeline}") - - -def get_unet_embedding_dim(version, pipeline): - if version in ("1.4", "1.5"): - return 768 - elif version in ("2.0", "2.0-base", "2.1", "2.1-base"): - return 1024 - elif version in ("xl-1.0") and pipeline.is_sd_xl_base(): - return 2048 - elif version in ("xl-1.0") and pipeline.is_sd_xl_refiner(): - return 1280 - else: - raise ValueError(f"Invalid version {version} + pipeline {pipeline}") - - -class BaseModel: - def __init__( - self, - version="1.5", - pipeline=None, - hf_token="", - device="cuda", - verbose=True, - fp16=False, - max_batch_size=16, - text_maxlen=77, - embedding_dim=768, - controlnet=None, - ): - self.name = self.__class__.__name__ - self.pipeline = pipeline.name - self.version = version - self.hf_token = hf_token - self.hf_safetensor = pipeline.is_sd_xl() - self.device = device - self.verbose = verbose - self.path = get_path(version, pipeline, controlnet) - - self.fp16 = fp16 - - self.min_batch = 1 - self.max_batch = max_batch_size - self.min_image_shape = 256 # min image resolution: 256x256 - self.max_image_shape = 768 if version in ("1.4", "1.5") else 1024 - self.min_latent_shape = self.min_image_shape // 8 - self.max_latent_shape = 
self.max_image_shape // 8 - - self.text_maxlen = text_maxlen - self.embedding_dim = embedding_dim - self.extra_output_names = [] - - def get_input_names(self): - pass - - def get_output_names(self): - pass - - def get_dynamic_axes(self): - return None - - def get_sample_input(self, batch_size, image_height, image_width): - pass - - def get_input_profile( - self, batch_size, image_height, image_width, static_batch, static_shape - ): - return None - - def get_shape_dict(self, batch_size, image_height, image_width): - return None - - def optimize(self, onnx_graph): - opt = Optimizer(onnx_graph, verbose=self.verbose) - opt.info(self.name + ": original") - opt.cleanup() - opt.info(self.name + ": cleanup") - opt.fold_constants() - opt.info(self.name + ": fold constants") - opt.infer_shapes() - opt.info(self.name + ": shape inference") - onnx_opt_graph = opt.cleanup(return_onnx=True) - opt.info(self.name + ": finished") - return onnx_opt_graph - - def check_dims(self, batch_size, image_height, image_width): - assert batch_size >= self.min_batch and batch_size <= self.max_batch - assert image_height % 8 == 0 or image_width % 8 == 0 - latent_height = image_height // 8 - latent_width = image_width // 8 - assert ( - latent_height >= self.min_latent_shape - and latent_height <= self.max_latent_shape - ) - assert ( - latent_width >= self.min_latent_shape - and latent_width <= self.max_latent_shape - ) - return (latent_height, latent_width) - - def get_minmax_dims( - self, batch_size, image_height, image_width, static_batch, static_shape - ): - min_batch = batch_size if static_batch else self.min_batch - max_batch = batch_size if static_batch else self.max_batch - latent_height = image_height // 8 - latent_width = image_width // 8 - min_image_height = image_height if static_shape else self.min_image_shape - max_image_height = image_height if static_shape else self.max_image_shape - min_image_width = image_width if static_shape else self.min_image_shape - max_image_width = 
image_width if static_shape else self.max_image_shape - min_latent_height = latent_height if static_shape else self.min_latent_shape - max_latent_height = latent_height if static_shape else self.max_latent_shape - min_latent_width = latent_width if static_shape else self.min_latent_shape - max_latent_width = latent_width if static_shape else self.max_latent_shape - return ( - min_batch, - max_batch, - min_image_height, - max_image_height, - min_image_width, - max_image_width, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) - - def get_latent_dim(self, min_h, opt_h, max_h, min_w, opt_w, max_w, static_shape): - if static_shape: - return ( - opt_h // 8, - opt_h // 8, - opt_h // 8, - opt_w // 8, - opt_w // 8, - opt_w // 8, - ) - return min_h // 8, opt_h // 8, max_h // 8, min_w // 8, opt_w // 8, max_w // 8 - - def get_batch_dim(self, min_batch, opt_batch, max_batch, static_batch): - if self.text_maxlen <= 77: - return (min_batch * 2, opt_batch * 2, max_batch * 2) - elif self.text_maxlen > 77 and static_batch: - return (opt_batch, opt_batch, opt_batch) - elif self.text_maxlen > 77 and not static_batch: - if self.text_optlen > 77: - return (min_batch, opt_batch, max_batch * 2) - return (min_batch, opt_batch * 2, max_batch * 2) - else: - raise Exception("Uncovered case in get_batch_dim") - - -class CLIP(BaseModel): - def __init__( - self, - version, - pipeline, - hf_token, - device, - verbose, - max_batch_size, - embedding_dim, - output_hidden_states=False, - subfolder="text_encoder", - ): - super(CLIP, self).__init__( - version, - pipeline, - hf_token, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - embedding_dim=embedding_dim, - ) - self.subfolder = subfolder - - # Output the final hidden state - if output_hidden_states: - self.extra_output_names = ["hidden_states"] - - def get_input_names(self): - return ["input_ids"] - - def get_output_names(self): - return ["text_embeddings"] - - def get_dynamic_axes(self): 
- return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} - - def get_input_profile( - self, batch_size, image_height, image_width, static_batch, static_shape - ): - self.check_dims(batch_size, image_height, image_width) - min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( - batch_size, image_height, image_width, static_batch, static_shape - ) - return { - "input_ids": [ - (min_batch, self.text_maxlen), - (batch_size, self.text_maxlen), - (max_batch, self.text_maxlen), - ] - } - - def get_shape_dict(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - output = { - "input_ids": (batch_size, self.text_maxlen), - "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), - } - if "hidden_states" in self.extra_output_names: - output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim) - return output - - def get_sample_input(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - return torch.zeros( - batch_size, self.text_maxlen, dtype=torch.int32, device=self.device - ) - - def optimize(self, onnx_graph): - opt = Optimizer(onnx_graph, verbose=self.verbose) - opt.info(self.name + ": original") - opt.select_outputs([0]) # delete graph output#1 - opt.cleanup() - opt.info(self.name + ": remove output[1]") - opt.fold_constants() - opt.info(self.name + ": fold constants") - opt.infer_shapes() - opt.info(self.name + ": shape inference") - opt.select_outputs([0], names=["text_embeddings"]) # rename network output - opt.info(self.name + ": remove output[0]") - opt_onnx_graph = opt.cleanup(return_onnx=True) - if "hidden_states" in self.extra_output_names: - opt_onnx_graph = opt.clip_add_hidden_states(return_onnx=True) - opt.info(self.name + ": added hidden_states") - opt.info(self.name + ": finished") - return opt_onnx_graph - - -def make_CLIP( - version, - pipeline, - hf_token, - device, - verbose, - max_batch_size, - 
output_hidden_states=False, - subfolder="text_encoder", -): - return CLIP( - version, - pipeline, - hf_token, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - embedding_dim=get_clip_embedding_dim(version, pipeline), - output_hidden_states=output_hidden_states, - subfolder=subfolder, - ) - - -class CLIPWithProj(CLIP): - def __init__( - self, - version, - pipeline, - hf_token, - device="cuda", - verbose=True, - max_batch_size=16, - output_hidden_states=False, - subfolder="text_encoder_2", - ): - super(CLIPWithProj, self).__init__( - version, - pipeline, - hf_token, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - embedding_dim=get_clipwithproj_embedding_dim(version, pipeline), - output_hidden_states=output_hidden_states, - ) - self.subfolder = subfolder - - def get_shape_dict(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - output = { - "input_ids": (batch_size, self.text_maxlen), - "text_embeddings": (batch_size, self.embedding_dim), - } - if "hidden_states" in self.extra_output_names: - output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim) - - return output - - -def make_CLIPWithProj( - version, - pipeline, - hf_token, - device, - verbose, - max_batch_size, - subfolder="text_encoder_2", - output_hidden_states=False, -): - return CLIPWithProj( - version, - pipeline, - hf_token, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - subfolder=subfolder, - output_hidden_states=output_hidden_states, - ) - - -class UNet2DConditionControlNetModel(torch.nn.Module): - def __init__(self, unet, controlnets) -> None: - super().__init__() - self.unet = unet - self.controlnets = controlnets - - def forward( - self, sample, timestep, encoder_hidden_states, images, controlnet_scales - ): - for i, (image, conditioning_scale, controlnet) in enumerate( - zip(images, controlnet_scales, self.controlnets) - ): - down_samples, mid_sample = 
controlnet( - sample, - timestep, - encoder_hidden_states=encoder_hidden_states, - controlnet_cond=image, - return_dict=False, - ) - - down_samples = [ - down_sample * conditioning_scale for down_sample in down_samples - ] - mid_sample *= conditioning_scale - - # merge samples - if i == 0: - down_block_res_samples, mid_block_res_sample = down_samples, mid_sample - else: - down_block_res_samples = [ - samples_prev + samples_curr - for samples_prev, samples_curr in zip( - down_block_res_samples, down_samples - ) - ] - mid_block_res_sample += mid_sample - - noise_pred = self.unet( - sample, - timestep, - encoder_hidden_states=encoder_hidden_states, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - ) - return noise_pred - - -class UNet(BaseModel): - def __init__( - self, - version, - pipeline, - hf_token, - device="cuda", - verbose=True, - fp16=False, - max_batch_size=16, - text_maxlen=77, - unet_dim=4, - controlnet=None, - ): - super(UNet, self).__init__( - version, - pipeline, - hf_token, - fp16=fp16, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - text_maxlen=text_maxlen, - embedding_dim=get_unet_embedding_dim(version, pipeline), - ) - self.unet_dim = unet_dim - self.controlnet = controlnet - - # def get_model(self, framework_model_dir): - # model_opts = {'variant': 'fp16', 'torch_dtype': torch.float16} if self.fp16 else {} - # if self.controlnet: - # unet_model = UNet2DConditionModel.from_pretrained(self.path, - # subfolder="unet", - # use_safetensors=self.hf_safetensor, - # use_auth_token=self.hf_token, - # **model_opts).to(self.device) - - # cnet_model_opts = {'torch_dtype': torch.float16} if self.fp16 else {} - # controlnets = torch.nn.ModuleList([ControlNetModel.from_pretrained(path, **cnet_model_opts).to(self.device) for path in self.controlnet]) - # # FIXME - cache UNet2DConditionControlNetModel - # model = UNet2DConditionControlNetModel(unet_model, controlnets) - # else: - 
# unet_model_dir = os.path.join(framework_model_dir, self.version, self.pipeline, "unet") - # if not os.path.exists(unet_model_dir): - # model = UNet2DConditionModel.from_pretrained(self.path, - # subfolder="unet", - # use_safetensors=self.hf_safetensor, - # use_auth_token=self.hf_token, - # **model_opts).to(self.device) - # model.save_pretrained(unet_model_dir) - # else: - # print(f"[I] Load UNet pytorch model from: {unet_model_dir}") - # model = UNet2DConditionModel.from_pretrained(unet_model_dir).to(self.device) - # return model - - def get_input_names(self): - if self.controlnet is None: - return ["sample", "timestep", "encoder_hidden_states"] - else: - return [ - "sample", - "timestep", - "encoder_hidden_states", - "images", - "controlnet_scales", - ] - - def get_output_names(self): - return ["latent"] - - def get_dynamic_axes(self): - if self.controlnet is None: - return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "encoder_hidden_states": {0: "2B"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - } - else: - return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "encoder_hidden_states": {0: "2B"}, - "images": {1: "2B", 3: "8H", 4: "8W"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - } - - def get_input_profile( - self, batch_size, image_height, image_width, static_batch, static_shape - ): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - ( - min_batch, - max_batch, - min_image_height, - max_image_height, - min_image_width, - max_image_width, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) = self.get_minmax_dims( - batch_size, image_height, image_width, static_batch, static_shape - ) - if self.controlnet is None: - return { - "sample": [ - (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), - (2 * batch_size, self.unet_dim, latent_height, latent_width), - (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), - ], - "encoder_hidden_states": [ - (2 * min_batch, 
self.text_maxlen, self.embedding_dim), - (2 * batch_size, self.text_maxlen, self.embedding_dim), - (2 * max_batch, self.text_maxlen, self.embedding_dim), - ], - } - else: - return { - "sample": [ - (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), - (2 * batch_size, self.unet_dim, latent_height, latent_width), - (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), - ], - "encoder_hidden_states": [ - (2 * min_batch, self.text_maxlen, self.embedding_dim), - (2 * batch_size, self.text_maxlen, self.embedding_dim), - (2 * max_batch, self.text_maxlen, self.embedding_dim), - ], - "images": [ - ( - len(self.controlnet), - 2 * min_batch, - 3, - min_image_height, - min_image_width, - ), - ( - len(self.controlnet), - 2 * batch_size, - 3, - image_height, - image_width, - ), - ( - len(self.controlnet), - 2 * max_batch, - 3, - max_image_height, - max_image_width, - ), - ], - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - if self.controlnet is None: - return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), - "encoder_hidden_states": ( - 2 * batch_size, - self.text_maxlen, - self.embedding_dim, - ), - "latent": (2 * batch_size, 4, latent_height, latent_width), - } - else: - return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), - "encoder_hidden_states": ( - 2 * batch_size, - self.text_maxlen, - self.embedding_dim, - ), - "images": ( - len(self.controlnet), - 2 * batch_size, - 3, - image_height, - image_width, - ), - "latent": (2 * batch_size, 4, latent_height, latent_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - dtype = torch.float16 if self.fp16 else torch.float32 - if self.controlnet is None: - return ( - torch.randn( - batch_size, - 
self.unet_dim, - latent_height, - latent_width, - dtype=torch.float32, - device=self.device, - ), - torch.tensor([1.0], dtype=torch.float32, device=self.device), - torch.randn( - batch_size, - self.text_maxlen, - self.embedding_dim, - dtype=dtype, - device=self.device, - ), - ) - else: - return ( - torch.randn( - batch_size, - self.unet_dim, - latent_height, - latent_width, - dtype=torch.float32, - device=self.device, - ), - torch.tensor(999, dtype=torch.float32, device=self.device), - torch.randn( - batch_size, - self.text_maxlen, - self.embedding_dim, - dtype=dtype, - device=self.device, - ), - torch.randn( - len(self.controlnet), - batch_size, - 3, - image_height, - image_width, - dtype=dtype, - device=self.device, - ), - torch.randn(len(self.controlnet), dtype=dtype, device=self.device), - ) - - -def make_UNet( - version, pipeline, hf_token, device, verbose, max_batch_size, controlnet=None -): - # Disable torch SDPA - # if hasattr(F, "scaled_dot_product_attention"): - # delattr(F, "scaled_dot_product_attention") - - return UNet( - version, - pipeline, - hf_token, - fp16=True, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - unet_dim=(9 if pipeline.is_inpaint() else 4), - controlnet=get_controlnets_path(controlnet), - ) - - -class OAIUNet(BaseModel): - def __init__( - self, - version, - pipeline, - device="cuda", - verbose=True, - fp16=False, - max_batch_size=16, - text_maxlen=77, - text_optlen=77, - unet_dim=4, - controlnet=None, - ): - super(OAIUNet, self).__init__( - version, - pipeline, - "", - fp16=fp16, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - text_maxlen=text_maxlen, - embedding_dim=get_unet_embedding_dim(version, pipeline), - ) - self.unet_dim = unet_dim - self.controlnet = controlnet - self.text_optlen = text_optlen - - def get_input_names(self): - if self.controlnet is None: - return ["sample", "timesteps", "encoder_hidden_states"] - else: - return [ - "sample", - "timesteps", - 
"encoder_hidden_states", - "images", - "controlnet_scales", - ] - - def get_output_names(self): - return ["latent"] - - def get_dynamic_axes(self): - if self.controlnet is None: - return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "timesteps": {0: "2B"}, - "encoder_hidden_states": {0: "2B", 1: "77N"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - } - else: - return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "timesteps": {0: "2B"}, - "encoder_hidden_states": {0: "2B", 1: "77N"}, - "images": {1: "2B", 3: "8H", 4: "8W"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - } - - def get_input_profile( - self, - min_batch, - opt_batch, - max_batch, - min_h, - opt_h, - max_h, - min_w, - opt_w, - max_w, - static_shape, - ): - min_batch, opt_batch, max_batch = self.get_batch_dim( - min_batch, opt_batch, max_batch, static_shape - ) - ( - min_latent_height, - latent_height, - max_latent_height, - min_latent_width, - latent_width, - max_latent_width, - ) = self.get_latent_dim(min_h, opt_h, max_h, min_w, opt_w, max_w, static_shape) - - if self.controlnet is None: - return { - "sample": [ - (min_batch, self.unet_dim, min_latent_height, min_latent_width), - (opt_batch, self.unet_dim, latent_height, latent_width), - (max_batch, self.unet_dim, max_latent_height, max_latent_width), - ], - "timesteps": [(min_batch,), (opt_batch,), (max_batch,)], - "encoder_hidden_states": [ - (min_batch, self.text_optlen, self.embedding_dim), - (opt_batch, self.text_optlen, self.embedding_dim), - (max_batch, self.text_maxlen, self.embedding_dim), - ], - } - else: - return { - "sample": [ - (min_batch, self.unet_dim, min_latent_height, min_latent_width), - (opt_batch, self.unet_dim, latent_height, latent_width), - (max_batch, self.unet_dim, max_latent_height, max_latent_width), - ], - "timesteps": [(min_batch,), (opt_batch,), (max_batch,)], - "encoder_hidden_states": [ - (min_batch, self.text_optlen, self.embedding_dim), - (opt_batch, self.text_optlen, self.embedding_dim), - (max_batch, self.text_maxlen, 
self.embedding_dim), - ], - "images": [ - ( - len(self.controlnet), - min_batch, - 3, - min_h, - min_w, - ), - ( - len(self.controlnet), - opt_batch, - 3, - opt_h, - opt_w, - ), - ( - len(self.controlnet), - max_batch, - 3, - max_h, - max_w, - ), - ], - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - if self.controlnet is None: - return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), - "timesteps": (2 * batch_size,), - "encoder_hidden_states": ( - 2 * batch_size, - self.text_optlen, - self.embedding_dim, - ), - "latent": (2 * batch_size, 4, latent_height, latent_width), - } - else: - return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), - "timesteps": (2 * batch_size,), - "encoder_hidden_states": ( - 2 * batch_size, - self.text_optlen, - self.embedding_dim, - ), - "images": ( - len(self.controlnet), - 2 * batch_size, - 3, - image_height, - image_width, - ), - "latent": (2 * batch_size, 4, latent_height, latent_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - dtype = torch.float16 if self.fp16 else torch.float32 - if self.controlnet is None: - return ( - torch.randn( - 2 * batch_size, - self.unet_dim, - latent_height, - latent_width, - dtype=torch.float32, - device=self.device, - ), - torch.ones((2 * batch_size,), dtype=torch.float32, device=self.device), - torch.randn( - 2 * batch_size, - self.text_optlen, - self.embedding_dim, - dtype=dtype, - device=self.device, - ), - ) - else: - return ( - torch.randn( - batch_size, - self.unet_dim, - latent_height, - latent_width, - dtype=torch.float32, - device=self.device, - ), - torch.ones((batch_size,), dtype=torch.float32, device=self.device), - torch.randn( - batch_size, - self.text_optlen, - self.embedding_dim, - 
dtype=dtype, - device=self.device, - ), - torch.randn( - len(self.controlnet), - batch_size, - 3, - image_height, - image_width, - dtype=dtype, - device=self.device, - ), - torch.randn(len(self.controlnet), dtype=dtype, device=self.device), - ) - - -def make_OAIUNet( - version, - pipeline, - device, - verbose, - max_batch_size, - text_optlen, - text_maxlen, - controlnet=None, -): - return OAIUNet( - version, - pipeline, - fp16=True, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - text_optlen=text_optlen, - text_maxlen=text_maxlen, - unet_dim=(9 if pipeline.is_inpaint() else 4), - controlnet=get_controlnets_path(controlnet), - ) - - -class OAIUNetXL(BaseModel): - def __init__( - self, - version, - pipeline, - fp16=False, - device="cuda", - verbose=True, - max_batch_size=16, - text_maxlen=77, - text_optlen=77, - unet_dim=4, - time_dim=6, - num_classes=2816, - ): - super(OAIUNetXL, self).__init__( - version, - pipeline, - "", - fp16=fp16, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - text_maxlen=text_maxlen, - embedding_dim=get_unet_embedding_dim(version, pipeline), - ) - self.unet_dim = unet_dim - self.time_dim = time_dim - self.num_classes = num_classes - self.text_optlen = text_optlen - - def get_input_names(self): - return ["sample", "timesteps", "encoder_hidden_states", "y"] - - def get_output_names(self): - return ["latent"] - - def get_dynamic_axes(self): - return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "encoder_hidden_states": {0: "2B", 1: "77N"}, - "timesteps": {0: "2B"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - "y": {0: "2B", 1: "num_classes"}, - } - - def get_input_profile( - self, - min_batch, - opt_batch, - max_batch, - min_h, - opt_h, - max_h, - min_w, - opt_w, - max_w, - static_shape, - ): - min_batch, opt_batch, max_batch = self.get_batch_dim( - min_batch, opt_batch, max_batch, static_shape - ) - - ( - min_latent_height, - latent_height, - max_latent_height, - min_latent_width, - latent_width, - 
max_latent_width, - ) = self.get_latent_dim(min_h, opt_h, max_h, min_w, opt_w, max_w, static_shape) - - return { - "sample": [ - (min_batch, self.unet_dim, min_latent_height, min_latent_width), - (opt_batch, self.unet_dim, latent_height, latent_width), - (max_batch, self.unet_dim, max_latent_height, max_latent_width), - ], - "timesteps": [(min_batch,), (opt_batch,), (max_batch,)], - "encoder_hidden_states": [ - (min_batch, self.text_optlen, self.embedding_dim), - (opt_batch, self.text_optlen, self.embedding_dim), - (max_batch, self.text_maxlen, self.embedding_dim), - ], - "y": [ - (min_batch, self.num_classes), - (opt_batch, self.num_classes), - (max_batch, self.num_classes), - ], - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), - "timesteps": (2 * batch_size,), - "encoder_hidden_states": ( - 2 * batch_size, - self.text_optlen, - self.embedding_dim, - ), - "y": (2 * batch_size, self.num_classes), - "latent": (2 * batch_size, 4, latent_height, latent_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - dtype = torch.float16 if self.fp16 else torch.float32 - return ( - torch.randn( - 2 * batch_size, - self.unet_dim, - latent_height, - latent_width, - dtype=torch.float32, - device=self.device, - ), - torch.ones((2 * batch_size,), dtype=torch.float32, device=self.device), - torch.randn( - 2 * batch_size, - self.text_optlen, - self.embedding_dim, - dtype=dtype, - device=self.device, - ), - torch.randn( - 2 * batch_size, self.num_classes, dtype=dtype, device=self.device - ), - ) - - -def make_OAIUNetXL( - version, pipeline, device, verbose, max_batch_size, text_optlen, text_maxlen -): - # Disable torch SDPA - # if hasattr(F, 
"scaled_dot_product_attention"): - # delattr(F, "scaled_dot_product_attention") - return OAIUNetXL( - version, - pipeline, - fp16=True, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - unet_dim=(9 if pipeline.is_inpaint() else 4), - text_optlen=text_optlen, - text_maxlen=text_maxlen, - ) - - -class VAE(BaseModel): - def __init__( - self, - version, - pipeline, - hf_token, - device, - verbose, - max_batch_size, - ): - super(VAE, self).__init__( - version, - pipeline, - hf_token, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - ) - - def get_input_names(self): - return ["latent"] - - def get_output_names(self): - return ["images"] - - def get_dynamic_axes(self): - return { - "latent": {0: "B", 2: "H", 3: "W"}, - "images": {0: "B", 2: "8H", 3: "8W"}, - } - - def get_input_profile( - self, batch_size, image_height, image_width, static_batch, static_shape - ): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - ( - min_batch, - max_batch, - _, - _, - _, - _, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) = self.get_minmax_dims( - batch_size, image_height, image_width, static_batch, static_shape - ) - return { - "latent": [ - (min_batch, 4, min_latent_height, min_latent_width), - (batch_size, 4, latent_height, latent_width), - (max_batch, 4, max_latent_height, max_latent_width), - ] - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - return { - "latent": (batch_size, 4, latent_height, latent_width), - "images": (batch_size, 3, image_height, image_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - return torch.randn( - batch_size, - 4, - latent_height, - latent_width, - dtype=torch.float32, - 
device=self.device, - ) - - -def make_VAE(version, pipeline, hf_token, device, verbose, max_batch_size): - return VAE( - version, - pipeline, - hf_token, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - ) - - -class VAEEncoder(BaseModel): - def __init__( - self, - version, - pipeline, - hf_token, - device, - verbose, - max_batch_size, - ): - super(VAEEncoder, self).__init__( - version, - pipeline, - hf_token, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - ) - - def get_input_names(self): - return ["images"] - - def get_output_names(self): - return ["latent"] - - def get_dynamic_axes(self): - return { - "images": {0: "B", 2: "8H", 3: "8W"}, - "latent": {0: "B", 2: "H", 3: "W"}, - } - - def get_input_profile( - self, batch_size, image_height, image_width, static_batch, static_shape - ): - assert batch_size >= self.min_batch and batch_size <= self.max_batch - min_batch = batch_size if static_batch else self.min_batch - max_batch = batch_size if static_batch else self.max_batch - self.check_dims(batch_size, image_height, image_width) - ( - min_batch, - max_batch, - min_image_height, - max_image_height, - min_image_width, - max_image_width, - _, - _, - _, - _, - ) = self.get_minmax_dims( - batch_size, image_height, image_width, static_batch, static_shape - ) - - return { - "images": [ - (min_batch, 3, min_image_height, min_image_width), - (batch_size, 3, image_height, image_width), - (max_batch, 3, max_image_height, max_image_width), - ], - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - return { - "images": (batch_size, 3, image_height, image_width), - "latent": (batch_size, 4, latent_height, latent_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - return torch.randn( - batch_size, - 3, - image_height, - image_width, 
- dtype=torch.float32, - device=self.device, - ) - - -def make_VAEEncoder(version, pipeline, hf_token, device, verbose, max_batch_size): - return VAEEncoder( - version, - pipeline, - hf_token, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - ) diff --git a/scripts/lora.py b/scripts/lora.py new file mode 100644 index 0000000..2e21bef --- /dev/null +++ b/scripts/lora.py @@ -0,0 +1,45 @@ +import os +from typing import List + +import numpy as np +import torch +from safetensors.torch import load_file +import onnx +from onnx import numpy_helper + + +def merge_loras(loras: List[str], scales: List[str]) -> dict: + refit_dict = {} + for lora, scale in zip(loras, scales): + lora_dict = load_file(lora) + for k, v in lora_dict.items(): + if k in refit_dict: + refit_dict[k] += scale * v + else: + refit_dict[k] = scale * v + return refit_dict + + +def apply_loras(base_path: str, loras: List[str], scales: List[str]) -> dict: + refit_dict = merge_loras(loras, scales) + base = onnx.load(base_path) + onnx_opt_dir = os.path.dirname(base_path) + + def convert_int64(arr): + if len(arr.shape) == 0: + return np.array([np.int32(arr)]) + return arr + + for initializer in base.graph.initializer: + if initializer.name not in refit_dict: + continue + + wt = refit_dict[initializer.name] + initializer_data = numpy_helper.to_array( + initializer, base_dir=onnx_opt_dir + ).astype(np.float16) + delta = torch.tensor(initializer_data).to(wt.device) + wt + + refit_dict[initializer.name] = delta.contiguous() + + return refit_dict diff --git a/scripts/trt.py b/scripts/trt.py index 2c72ad7..b9f16aa 100644 --- a/scripts/trt.py +++ b/scripts/trt.py @@ -1,17 +1,22 @@ import os -import numpy as np - -import ldm.modules.diffusionmodules.openaimodel +import re +from typing import List +import numpy as np import torch from torch.cuda import nvtx -from modules import script_callbacks, sd_unet, devices +from polygraphy.logger import G_LOGGER +import gradio as gr + +from modules import 
script_callbacks, sd_unet, devices, scripts, shared import ui_trt from utilities import Engine -from typing import List from model_manager import TRT_MODEL_DIR, modelmanager -from modules import sd_models, shared +from datastructures import ModelType +from scripts.lora import apply_loras + +G_LOGGER.module_severity = G_LOGGER.ERROR class TrtUnetOption(sd_unet.SdUnetOption): @@ -21,53 +26,33 @@ def __init__(self, name: str, filename: List[dict]): self.configs = filename def create_unet(self): - lora_path = None - if self.configs[0]["config"].lora: - lora_path = os.path.join(TRT_MODEL_DIR, self.configs[0]["filepath"]) - self.model_name = self.configs[0]["base_model"] - self.configs = modelmanager.available_models()[self.model_name] - validate_sd_version(self.model_name, exact=True) - return TrtUnet(self.model_name, self.configs, lora_path) - - -def validate_sd_version(model_name, exact=False): - loaded_model = shared.sd_model.sd_checkpoint_info.model_name - if exact: - if not loaded_model == model_name: - raise ValueError( - f"Selected torch model ({loaded_model}) does not match the selected TensorRT U-Net ({model_name}). Please ensure that both models are the same." - ) - else: - if shared.sd_model.is_sdxl: - if not "xl" in model_name: - raise ValueError( - f"Selected torch model ({loaded_model}) does not match the selected TensorRT U-Net ({model_name}). Please ensure that both models are the same." - ) - loaded_version = 1 if shared.sd_model.is_sd1 else 2 - if f"v{loaded_version}" not in model_name: - raise ValueError( - f"Selected torch model ({loaded_model}) does not match the selected TensorRT U-Net ({model_name}). Please ensure that both models are the same." 
- ) + return TrtUnet(self.model_name, self.configs) class TrtUnet(sd_unet.SdUnet): - def __init__( - self, model_name: str, configs: List[dict], lora_path, *args, **kwargs - ): + def __init__(self, model_name: str, configs: List[dict], *args, **kwargs): super().__init__(*args, **kwargs) - self.configs = configs + self.stream = None self.model_name = model_name - self.lora_path = lora_path + self.configs = configs + + self.profile_idx = 0 + self.loaded_config = None + self.engine_vram_req = 0 + self.refitted_keys = set() - self.loaded_config = self.configs[0] - self.shape_hash = 0 - self.engine = Engine( - os.path.join(TRT_MODEL_DIR, self.loaded_config["filepath"]) - ) + self.engine = None - def forward(self, x, timesteps, context, *args, **kwargs): + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: nvtx.range_push("forward") feed_dict = { "sample": x.float(), @@ -77,17 +62,6 @@ def forward(self, x, timesteps, context, *args, **kwargs): if "y" in kwargs: feed_dict["y"] = kwargs["y"].float() - # Need to check compatability on the fly - if self.shape_hash != hash(x.shape): - nvtx.range_push("switch_engine") - if x.shape[-1] % 8 or x.shape[-2] % 8: - raise ValueError( - "Input shape must be divisible by 64 in both dimensions." - ) - self.switch_engine(feed_dict) - self.shape_hash = hash(x.shape) - nvtx.range_pop() - tmp = torch.empty( self.engine_vram_req, dtype=torch.uint8, device=devices.device ) @@ -100,42 +74,256 @@ def forward(self, x, timesteps, context, *args, **kwargs): nvtx.range_pop() return out - def switch_engine(self, feed_dict): - valid_models, distances = modelmanager.get_valid_models( - self.model_name, feed_dict - ) - if len(valid_models) == 0: - raise ValueError( - "No valid profile found. Please go to the TensorRT tab and generate an engine with the necessary profile. If using hires.fix, you need an engine for both the base and upscaled resolutions. 
Otherwise, use the default (torch) U-Net. - ) + def apply_loras(self, refit_dict: dict): + if not self.refitted_keys.issubset(set(refit_dict.keys())): + # Reset weights that were refitted previously but are no longer present. + self.refitted_keys = set() + self.switch_engine() - best = valid_models[np.argmin(distances)] - if best["filepath"] == self.loaded_config["filepath"]: - return - self.deactivate() - self.engine = Engine(os.path.join(TRT_MODEL_DIR, best["filepath"])) + self.engine.refit_from_dict(refit_dict, is_fp16=True) + self.refitted_keys = set(refit_dict.keys()) + + def switch_engine(self): + self.loaded_config = self.configs[self.profile_idx] + self.engine.reset(os.path.join(TRT_MODEL_DIR, self.loaded_config["filepath"])) self.activate() - self.loaded_config = best def activate(self): + self.loaded_config = self.configs[self.profile_idx] + if self.engine is None: + self.engine = Engine( + os.path.join(TRT_MODEL_DIR, self.loaded_config["filepath"]) + ) self.engine.load() + print(f"\nLoaded Profile: {self.profile_idx}") print(self.engine) self.engine_vram_req = self.engine.engine.device_memory_size self.engine.activate(True) - if self.lora_path is not None: - self.engine.refit_from_dump(self.lora_path) - def deactivate(self): - self.shape_hash = 0 del self.engine +class TensorRTScript(scripts.Script): + def __init__(self) -> None: + self.loaded_model = None + self.lora_hash = "" + self.update_lora = False + self.lora_refit_dict = {} + self.idx = None + self.hr_idx = None + self.torch_unet = False + + def title(self): + return "TensorRT" + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def setup(self, p, *args): + return super().setup(p, *args) + + def before_process(self, p, *args): # 1 + # Check divisibility + if p.width % 64 or p.height % 64: + gr.Error("Target resolution must be divisible by 64 in both dimensions.") + + if self.is_img2img: + return + if p.enable_hr: + hr_w = int(p.width * p.hr_scale) + hr_h
= int(p.height * p.hr_scale) + if hr_w % 64 or hr_h % 64: + gr.Error( + "HIRES Fix resolution must be divisible by 64 in both dimensions. Please change the upscale factor or disable HIRES Fix." + ) + + def get_profile_idx(self, p, model_name: str, model_type: ModelType) -> (int, int): + best_hr = None + + if self.is_img2img: + hr_scale = 1 + else: + hr_scale = p.hr_scale if p.enable_hr else 1 + ( + valid_models, + distances, + idx, + ) = modelmanager.get_valid_models( + model_name, + p.width, + p.height, + p.batch_size, + 77, # model_type + ) # TODO: max_embedding, just ignore? + if len(valid_models) == 0: + gr.Error( + f"""No valid profile found for ({model_name}) LOWRES. Please go to the TensorRT tab and generate an engine with the necessary profile. + If using hires.fix, you need an engine for both the base and upscaled resolutions. Otherwise, use the default (torch) U-Net.""" + ) + return None, None + best = idx[np.argmin(distances)] + best_hr = best + + if hr_scale != 1: + hr_w = int(p.width * p.hr_scale) + hr_h = int(p.height * p.hr_scale) + valid_models_hr, distances_hr, idx_hr = modelmanager.get_valid_models( + model_name, + hr_w, + hr_h, + p.batch_size, + 77, # model_type + ) # TODO: max_embedding + if len(valid_models_hr) == 0: + gr.Error( + f"""No valid profile found for ({model_name}) HIRES. Please go to the TensorRT tab and generate an engine with the necessary profile. + If using hires.fix, you need an engine for both the base and upscaled resolutions. Otherwise, use the default (torch) U-Net.""" + ) + merged_idx = [i for i, id in enumerate(idx) if id in idx_hr] + if len(merged_idx) == 0: + gr.Warning( + "No model available for both ({}) LOWRES ({}x{}) and HIRES ({}x{}). 
This will slow down inference.".format( + model_name, p.width, p.height, hr_w, hr_h + ) + ) + return None, None + else: + _distances = [distances[i] for i in merged_idx] + best_hr = merged_idx[np.argmin(_distances)] + best = best_hr + + return best, best_hr + + def get_loras(self, p): + lora_paths = [] + lora_scales = [] + + # Get LoRAs from the prompt + _prompt = p.prompt + extra_networks = re.findall(r"\<(.*?)\>", _prompt) + loras = [net for net in extra_networks if net.startswith("lora")] + + # Strip the tags so the webui does not load the extra networks itself + for lora in loras: + _prompt = _prompt.replace(f"<{lora}>", "") + p.prompt = _prompt + + # Check whether the LoRA configuration has changed + if self.lora_hash != "".join(loras): + self.lora_hash = "".join(loras) + self.update_lora = True + if self.lora_hash == "": + self.lora_refit_dict = {} + return + else: + return + + # Get paths + print("Applying LoRAs: " + str(loras)) + available = modelmanager.available_loras() + for lora in loras: + lora_name, lora_scale = lora.split(":")[1:] + lora_scales.append(float(lora_scale)) + if lora_name not in available: + raise Exception( + f"Please export the LoRA checkpoint {lora_name} first from the TensorRT LoRA tab" + ) + lora_paths.append( + available[lora_name] + ) + + # Merge LoRA refit dicts + base_name, base_path = modelmanager.get_onnx_path(p.sd_model_name) + refit_dict = apply_loras(base_path, lora_paths, lora_scales) + + self.lora_refit_dict = refit_dict + + def process(self, p, *args): + # before unet_init + sd_unet_option = sd_unet.get_unet_option() + if sd_unet_option is None: + return + + if not sd_unet_option.model_name == p.sd_model_name: + gr.Error( + """Selected torch model ({}) does not match the selected TensorRT U-Net ({}). 
+ Please ensure that both models are the same or select Automatic from the SD UNet dropdown.""".format( + p.sd_model_name, sd_unet_option.model_name + ) + ) + self.idx, self.hr_idx = self.get_profile_idx(p, p.sd_model_name, ModelType.UNET) + self.torch_unet = self.idx is None or self.hr_idx is None + + try: + if not self.torch_unet: + self.get_loras(p) + except Exception as e: + gr.Error(e) + raise e + + self.apply_unet(sd_unet_option) + + def apply_unet(self, sd_unet_option): + if ( + sd_unet_option == sd_unet.current_unet_option + and sd_unet.current_unet is not None + and not self.torch_unet + ): + return + + if sd_unet.current_unet is not None: + sd_unet.current_unet.deactivate() + + if self.torch_unet: + gr.Warning("Enabling PyTorch fallback as no engine was found.") + sd_unet.current_unet = None + sd_unet.current_unet_option = sd_unet_option + shared.sd_model.model.diffusion_model.to(devices.device) + return + else: + shared.sd_model.model.diffusion_model.to(devices.cpu) + devices.torch_gc() + if self.lora_refit_dict: + self.update_lora = True + sd_unet.current_unet = sd_unet_option.create_unet() + sd_unet.current_unet.profile_idx = self.idx + sd_unet.current_unet.option = sd_unet_option + sd_unet.current_unet_option = sd_unet_option + + print(f"Activating unet: {sd_unet.current_unet.option.label}") + sd_unet.current_unet.activate() + + def process_batch(self, p, *args, **kwargs): + # Called for each batch count + if self.torch_unet: + return super().process_batch(p, *args, **kwargs) + + if self.idx != sd_unet.current_unet.profile_idx: + sd_unet.current_unet.profile_idx = self.idx + sd_unet.current_unet.switch_engine() + + def before_hr(self, p, *args): + if self.idx != self.hr_idx: + sd_unet.current_unet.profile_idx = self.hr_idx + sd_unet.current_unet.switch_engine() + + return super().before_hr(p, *args) # 4 (Only when HR starts.....) 
+ + def after_extra_networks_activate(self, p, *args, **kwargs): + if self.update_lora and not self.torch_unet: + self.update_lora = False + sd_unet.current_unet.apply_loras(self.lora_refit_dict) + + def list_unets(l): model = modelmanager.available_models() for k, v in model.items(): + if v[0]["config"].lora: + continue label = "{} ({})".format(k, v[0]["base_model"]) if v[0]["config"].lora else k l.append(TrtUnetOption(label, v)) + script_callbacks.on_list_unets(list_unets) script_callbacks.on_ui_tabs(ui_trt.on_ui_tabs) diff --git a/ui_trt.py b/ui_trt.py index 4ae84cf..4f19847 100644 --- a/ui_trt.py +++ b/ui_trt.py @@ -1,34 +1,45 @@ import os +import gc +import json +import logging +from collections import defaultdict -from modules import sd_models, shared +import torch +from safetensors.torch import save_file import gradio as gr -from modules.call_queue import wrap_gradio_gpu_call from modules.shared import cmd_opts from modules.ui_components import FormRow - -from exporter import export_onnx, export_trt -from utilities import PIPELINE_TYPE, Engine -from models import make_OAIUNetXL, make_OAIUNet -import logging -import gc -import torch -from model_manager import modelmanager, cc_major, TRT_MODEL_DIR -from time import sleep -from collections import defaultdict +from modules import sd_hijack, sd_models, shared from modules.ui_common import refresh_symbol from modules.ui_components import ToolButton +from model_helper import UNetModel +from exporter import export_onnx, export_trt, export_lora +from model_manager import modelmanager, cc_major, TRT_MODEL_DIR +from datastructures import SDVersion, ProfilePrests, ProfileSettings + + +profile_presets = ProfilePrests() + logging.basicConfig(level=logging.INFO) -def get_version_from_model(sd_model): - if sd_model.is_sd1: - return "1.5" - if sd_model.is_sd2: - return "2.1" - if sd_model.is_sdxl: - return "xl-1.0" +def get_context_dim(): + if shared.sd_model.is_sd1: + return 768 + elif shared.sd_model.is_sd2: + return 1024 
+ elif shared.sd_model.is_sdxl: + return 2048 + + +def is_fp32(): + use_fp32 = False + if cc_major < 7: + use_fp32 = True + print("FP16 has been disabled because your GPU does not support it.") + return use_fp32 def export_unet_to_trt( @@ -47,75 +58,13 @@ def export_unet_to_trt( force_export, static_shapes, preset, - controlnet=None, ): + sd_hijack.model_hijack.apply_optimizations("None") - if preset == "Default": - ( - batch_min, - batch_opt, - batch_max, - height_min, - height_opt, - height_max, - width_min, - width_opt, - width_max, - token_count_min, - token_count_opt, - token_count_max, - ) = export_default_unet_to_trt() - is_inpaint = False - use_fp32 = False - if cc_major < 7: - use_fp32 = True - print("FP16 has been disabled because your GPU does not support it.") - - unet_hidden_dim = shared.sd_model.model.diffusion_model.in_channels - if unet_hidden_dim == 9: - is_inpaint = True - - model_hash = shared.sd_model.sd_checkpoint_info.hash + is_xl = shared.sd_model.is_sdxl model_name = shared.sd_model.sd_checkpoint_info.model_name - onnx_filename, onnx_path = modelmanager.get_onnx_path(model_name, model_hash) - - print(f"Exporting {model_name} to TensorRT") - timing_cache = modelmanager.get_timing_cache() - - version = get_version_from_model(shared.sd_model) - - pipeline = PIPELINE_TYPE.TXT2IMG - if is_inpaint: - pipeline = PIPELINE_TYPE.INPAINT - controlnet = None - - min_textlen = (token_count_min // 75) * 77 - opt_textlen = (token_count_opt // 75) * 77 - max_textlen = (token_count_max // 75) * 77 - if static_shapes: - min_textlen = max_textlen = opt_textlen - - if shared.sd_model.is_sdxl: - pipeline = PIPELINE_TYPE.SD_XL_BASE - modelobj = make_OAIUNetXL( - version, pipeline, "cuda", False, batch_max, opt_textlen, max_textlen - ) - diable_optimizations = True - else: - modelobj = make_OAIUNet( - version, - pipeline, - "cuda", - False, - batch_max, - opt_textlen, - max_textlen, - controlnet, - ) - diable_optimizations = False - - profile = 
modelobj.get_input_profile( + profile_settings = ProfileSettings( batch_min, batch_opt, batch_max, @@ -125,30 +74,55 @@ def export_unet_to_trt( width_min, width_opt, width_max, - static_shapes, + token_count_min, + token_count_opt, + token_count_max, ) - print(profile) + if preset == "Default": + profile_settings = profile_presets.get_default(is_xl=is_xl) + use_fp32 = is_fp32() - if not os.path.exists(onnx_path): - print("No ONNX file found. Exporting ONNX...") - gr.Info("No ONNX file found. Exporting ONNX... Please check the progress in the terminal.") - export_onnx( - onnx_path, - modelobj, - profile=profile, - diable_optimizations=diable_optimizations, - ) - print("Exported to ONNX.") + print(f"Exporting {model_name} to TensorRT using - {profile_settings}") + profile_settings.token_to_dim(static_shapes) + + model_hash = shared.sd_model.sd_checkpoint_info.hash + model_name = shared.sd_model.sd_checkpoint_info.model_name + + onnx_filename, onnx_path = modelmanager.get_onnx_path(model_name) + timing_cache = modelmanager.get_timing_cache() + + diable_optimizations = is_xl + embedding_dim = get_context_dim() + + modelobj = UNetModel( + shared.sd_model.model.diffusion_model, + embedding_dim, + text_minlen=profile_settings.t_min, + is_xl=is_xl, + ) + modelobj.apply_torch_model() + + profile = modelobj.get_input_profile(profile_settings) + export_onnx( + onnx_path, + modelobj, + profile_settings, + diable_optimizations=diable_optimizations, + ) + gc.collect() + torch.cuda.empty_cache() trt_engine_filename, trt_path = modelmanager.get_trt_path( model_name, model_hash, profile, static_shapes ) if not os.path.exists(trt_path) or force_export: - print("Building TensorRT engine... This can take a while, please check the progress in the terminal.") - gr.Info("Building TensorRT engine... This can take a while, please check the progress in the terminal.") - gc.collect() - torch.cuda.empty_cache() + print( + "Building TensorRT engine... 
This can take a while, please check the progress in the terminal." + ) + gr.Info( + "Building TensorRT engine... This can take a while, please check the progress in the terminal." + ) ret = export_trt( trt_path, onnx_path, @@ -166,248 +140,171 @@ profile, static_shapes, fp32=use_fp32, - inpaint=is_inpaint, + inpaint=modelobj.in_channels == 9, refit=True, vram=0, - unet_hidden_dim=unet_hidden_dim, + unet_hidden_dim=modelobj.in_channels, lora=False, ) else: - print("TensorRT engine found. Skipping build. You can enable Force Export in the Advanced Settings to force a rebuild if needed.") + print( + "TensorRT engine found. Skipping build. You can enable Force Export in the Advanced Settings to force a rebuild if needed." + ) + + gc.collect() + torch.cuda.empty_cache() return "## Exported Successfully \n" def export_lora_to_trt(lora_name, force_export): - is_inpaint = False - use_fp32 = False - if cc_major < 7: - use_fp32 = True - print("FP16 has been disabled because your GPU does not support it.") - unet_hidden_dim = shared.sd_model.model.diffusion_model.in_channels - if unet_hidden_dim == 9: - is_inpaint = True - - model_hash = shared.sd_model.sd_checkpoint_info.hash - model_name = shared.sd_model.sd_checkpoint_info.model_name - base_name = f"{model_name}" # _{model_hash} + is_xl = shared.sd_model.is_sdxl available_lora_models = get_lora_checkpoints() lora_name = lora_name.split(" ")[0] - lora_model = available_lora_models[lora_name] + lora_model = available_lora_models.get(lora_name, None) + if lora_model is None: + return f"## No LoRA model found for {lora_name}" + + version = lora_model.get("version", SDVersion.Unknown) + if version == SDVersion.Unknown: + print( + "LoRA SD version couldn't be determined. Please ensure the correct SD Checkpoint is selected." 
+ ) - onnx_base_filename, onnx_base_path = modelmanager.get_onnx_path( - model_name, model_hash - ) - onnx_lora_filename, onnx_lora_path = modelmanager.get_onnx_path( - lora_name, base_name - ) + model_name = shared.sd_model.sd_checkpoint_info.model_name + model_hash = shared.sd_model.sd_checkpoint_info.hash - version = get_version_from_model(shared.sd_model) + if not version.match(shared.sd_model): + print( + f"""LoRA SD version ({version}) does not match the currently loaded checkpoint ({model_name}). + Please ensure the correct SD Checkpoint is selected.""" + ) - pipeline = PIPELINE_TYPE.TXT2IMG - if is_inpaint: - pipeline = PIPELINE_TYPE.INPAINT + profile_settings = profile_presets.get_default(is_xl=is_xl) + print(f"Exporting {lora_name} to TensorRT using - {profile_settings}") + profile_settings.token_to_dim(True) - if shared.sd_model.is_sdxl: - pipeline = PIPELINE_TYPE.SD_XL_BASE - modelobj = make_OAIUNetXL(version, pipeline, "cuda", False, 1, 77, 77) - diable_optimizations = True - else: - modelobj = make_OAIUNet( - version, - pipeline, - "cuda", - False, - 1, - 77, - 77, - None, - ) - diable_optimizations = False - - if not os.path.exists(onnx_lora_path): - print("No ONNX file found. Exporting ONNX...") - gr.Info("No ONNX file found. Exporting ONNX... Please check the progress in the terminal.") - export_onnx( - onnx_lora_path, - modelobj, - profile=modelobj.get_input_profile( - 1, 1, 1, 512, 512, 512, 512, 512, 512, True - ), - diable_optimizations=diable_optimizations, - lora_path=lora_model["filename"], - ) - print("Exported to ONNX.") + onnx_base_filename, onnx_base_path = modelmanager.get_onnx_path(model_name) + if not os.path.exists(onnx_base_path): + return f"## Please export the base model ({model_name}) first." 
- trt_lora_name = onnx_lora_filename.replace(".onnx", ".trt") - trt_lora_path = os.path.join(TRT_MODEL_DIR, trt_lora_name) + embedding_dim = get_context_dim() - available_trt_unet = modelmanager.available_models() - if len(available_trt_unet[base_name]) == 0: - return "## Please export the base model first." - trt_base_path = os.path.join( - TRT_MODEL_DIR, available_trt_unet[base_name][0]["filepath"] + modelobj = UNetModel( + shared.sd_model.model.diffusion_model, + embedding_dim, + text_minlen=profile_settings.t_min, + is_xl=is_xl, ) + modelobj.apply_torch_model() - if not os.path.exists(onnx_base_path): - return "## Please export the base model first." - - if not os.path.exists(trt_lora_path) or force_export: - print("No TensorRT engine found. Building...") - gr.Info("No TensorRT engine found. Building...") - - engine = Engine(trt_base_path) - engine.load() - engine.refit(onnx_base_path, onnx_lora_path, dump_refit_path=trt_lora_path) - print("Built TensorRT engine.") - - modelmanager.add_lora_entry( - base_name, - lora_name, - trt_lora_name, - use_fp32, - is_inpaint, - 0, - unet_hidden_dim, + weights_map_path = modelmanager.get_weights_map_path(model_name) + if not os.path.exists(weights_map_path): + modelobj.export_weights_map(onnx_base_path, weights_map_path) + + lora_trt_name = f"{lora_name}.lora" + lora_trt_path = os.path.join(TRT_MODEL_DIR, lora_trt_name) + + if os.path.exists(lora_trt_path) and not force_export: + print( + "TensorRT engine found. Skipping build. You can enable Force Export in the Advanced Settings to force a rebuild if needed." 
) + return "## Exported Successfully \n" + + profile = modelobj.get_input_profile(profile_settings) + refit_dict = export_lora( + modelobj, + onnx_base_path, + weights_map_path, + lora_model["filename"], + profile_settings, + ) + + save_file(refit_dict, lora_trt_path) + + return "## Exported Successfully \n" -def export_default_unet_to_trt(): - is_xl = shared.sd_model.is_sdxl +def get_version_from_filename(name): + if "v1-" in name: + return "1.5" + elif "v2-" in name: + return "2.1" + elif "xl" in name: + return "xl-1.0" + else: + return "Unknown" - batch_min = 1 - batch_opt = 1 - batch_max = 4 - height_min = 768 if is_xl else 512 - height_opt = 1024 if is_xl else 512 - height_max = 1024 if is_xl else 768 - width_min = 768 if is_xl else 512 - width_opt = 1024 if is_xl else 512 - width_max = 1024 if is_xl else 768 - token_count_min = 75 - token_count_opt = 75 - token_count_max = 150 - - return ( - batch_min, - batch_opt, - batch_max, - height_min, - height_opt, - height_max, - width_min, - width_opt, - width_max, - token_count_min, - token_count_opt, - token_count_max, - ) +def get_lora_checkpoints(): + available_lora_models = {} + allowed_extensions = ["pt", "ckpt", "safetensors"] + candidates = [ + p + for p in os.listdir(cmd_opts.lora_dir) + if p.split(".")[-1] in allowed_extensions + ] -profile_presets = { - "512x512 | Batch Size 1 (Static)": ( - 1, - 1, - 1, - 512, - 512, - 512, - 512, - 512, - 512, - 75, - 75, - 75, - ), - "768x768 | Batch Size 1 (Static)": ( - 1, - 1, - 1, - 768, - 768, - 768, - 768, - 768, - 768, - 75, - 75, - 75, - ), - "1024x1024 | Batch Size 1 (Static)": ( - 1, - 1, - 1, - 1024, - 1024, - 1024, - 1024, - 1024, - 1024, - 75, - 75, - 75, - ), - "256x256 - 512x512 | Batch Size 1-4 (Dynamic)": ( - 1, - 1, - 4, - 256, - 512, - 512, - 256, - 512, - 512, - 75, - 75, - 150, - ), - "512x512 - 768x768 | Batch Size 1-4 (Dynamic)": ( - 1, - 1, - 4, - 512, - 512, - 768, - 512, - 512, - 768, - 75, - 75, - 150, - ), - "768x768 - 1024x1024 | Batch Size 
1-4 (Dynamic)": ( - 1, - 1, - 4, - 768, - 1024, - 1024, - 768, - 1024, - 1024, - 75, - 75, - 150, - ), -} - - -def get_settings_from_version(version): - static = False - if version == "Default": - return *list(profile_presets.values())[-2], static - if "Static" in version: - static = True - return *profile_presets[version], static + for filename in candidates: + metadata = {} + name, ext = os.path.splitext(filename) + config_file = os.path.join(cmd_opts.lora_dir, name + ".json") + + if ext == ".safetensors": + metadata = sd_models.read_metadata_from_safetensors( + os.path.join(cmd_opts.lora_dir, filename) + ) + else: + print( + """LoRA {} is not a safetensor. This might cause issues when exporting to TensorRT. + Please ensure that the correct base model is selected when exporting.""".format( + name + ) + ) + + base_model = metadata.get("ss_sd_model_name", "Unknown") + if os.path.exists(config_file): + with open(config_file, "r") as f: + config = json.load(f) + version = SDVersion.from_str(config["sd version"]) + + else: + version = SDVersion.Unknown + print( + "No config file found for {}. 
You can generate it in the LoRA tab.".format( + name + ) + ) + + available_lora_models[name] = { + "filename": filename, + "version": version, + "base_model": base_model, + } + return available_lora_models + + +def get_valid_lora_checkpoints(): + available_lora_models = get_lora_checkpoints() + return [f"{k} ({v['version']})" for k, v in available_lora_models.items()] def diable_export(version): if version == "Default": - return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False) + return ( + gr.update(visible=False), + gr.update(visible=True), + gr.update(visible=False), + ) else: - return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True) + return ( + gr.update(visible=True), + gr.update(visible=False), + gr.update(visible=True), + ) + def disable_lora_export(lora): if lora is None: @@ -415,6 +312,7 @@ def disable_lora_export(lora): else: return gr.update(visible=True) + def diable_visibility(hide): num_outputs = 8 out = [gr.update(visible=not hide) for _ in range(num_outputs)] @@ -465,9 +363,9 @@ def get_md_table( loras_md = {} for base_model, models in available_models.items(): for i, m in enumerate(models): - if m["config"].lora: - loras_md[base_model] = m.get("base_model", None) - continue + # if m["config"].lora: + # loras_md[base_model] = m.get("base_model", None) + # continue s_min, s_opt, s_max = m["config"].profile.get( "sample", [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]] @@ -492,53 +390,11 @@ def get_md_table( model_md[base_model].append(profile_table) - for lora, base_model in loras_md.items(): - model_md[f"{lora} (*{base_model}*)"] = model_md[base_model] + available_loras = modelmanager.available_loras() + for lora, path in available_loras.items(): + loras_md[f"{lora}"] = "" - return model_md - - -def get_version_from_filename(name): - if "v1-" in name: - return "1.5" - elif "v2-" in name: - return "2.1" - elif "xl" in name: - return "xl-1.0" - else: - return "Unknown" - - -def get_lora_checkpoints(): - 
available_lora_models = {} - canditates = list( - shared.walk_files( - shared.cmd_opts.lora_dir, - allowed_extensions=[".pt", ".ckpt", ".safetensors"], - ) - ) - for filename in canditates: - name = os.path.splitext(os.path.basename(filename))[0] - try: - metadata = sd_models.read_metadata_from_safetensors(filename) - version = get_version_from_filename(metadata.get("ss_sd_model_name")) - except (AssertionError, TypeError): - version = "Unknown" - available_lora_models[name] = { - "filename": filename, - "version": version, - } - return available_lora_models - - -def get_valid_lora_checkpoints(): - available_lora_models = get_lora_checkpoints() - return [ - f"{k} ({v['version']})" - for k, v in available_lora_models.items() - if v["version"] == get_version_from_model(shared.sd_model) - or v["version"] == "Unknown" - ] + return model_md, loras_md def on_ui_tabs(): @@ -551,17 +407,18 @@ def on_ui_tabs(): value="# TensorRT Exporter", ) - default_version = list(profile_presets.keys())[-2] - default_vals = list(profile_presets.values())[-2] + default_vals = profile_presets.get_default(is_xl=False) version = gr.Dropdown( label="Preset", - choices=list(profile_presets.keys()) + ["Default"], + choices=profile_presets.get_choices(), elem_id="sd_version", default="Default", value="Default", ) - with gr.Accordion("Advanced Settings", open=False, visible=False) as advanced_settings: + with gr.Accordion( + "Advanced Settings", open=False, visible=False + ) as advanced_settings: with FormRow( elem_classes="checkboxes-row", variant="compact" ): @@ -571,13 +428,13 @@ def on_ui_tabs(): elem_id="trt_static_shapes", ) - with gr.Column(elem_id="trt_max_batch"): + with gr.Column(elem_id="trt_batch"): trt_min_batch = gr.Slider( minimum=1, maximum=16, step=1, label="Min batch-size", - value=default_vals[0], + value=default_vals.bs_min, elem_id="trt_min_batch", ) @@ -586,7 +443,7 @@ def on_ui_tabs(): maximum=16, step=1, label="Optimal batch-size", - value=default_vals[1], + 
value=default_vals.bs_opt, elem_id="trt_opt_batch", ) trt_max_batch = gr.Slider( @@ -594,59 +451,59 @@ maximum=16, step=1, label="Max batch-size", - value=default_vals[2], + value=default_vals.bs_max, elem_id="trt_max_batch", ) with gr.Column(elem_id="trt_height"): trt_height_min = gr.Slider( minimum=256, - maximum=2048, + maximum=4096, step=64, label="Min height", - value=default_vals[3], + value=default_vals.h_min, elem_id="trt_min_height", ) trt_height_opt = gr.Slider( minimum=256, - maximum=2048, + maximum=4096, step=64, label="Optimal height", - value=default_vals[4], + value=default_vals.h_opt, elem_id="trt_opt_height", ) trt_height_max = gr.Slider( minimum=256, - maximum=2048, + maximum=4096, step=64, label="Max height", - value=default_vals[5], + value=default_vals.h_max, elem_id="trt_max_height", ) with gr.Column(elem_id="trt_width"): trt_width_min = gr.Slider( minimum=256, - maximum=2048, + maximum=4096, step=64, label="Min width", - value=default_vals[6], + value=default_vals.w_min, elem_id="trt_min_width", ) trt_width_opt = gr.Slider( minimum=256, - maximum=2048, + maximum=4096, step=64, label="Optimal width", - value=default_vals[7], + value=default_vals.w_opt, elem_id="trt_opt_width", ) trt_width_max = gr.Slider( minimum=256, - maximum=2048, + maximum=4096, step=64, label="Max width", - value=default_vals[8], + value=default_vals.w_max, elem_id="trt_max_width", ) @@ -656,7 +513,7 @@ maximum=750, step=75, label="Min prompt token count", - value=default_vals[9], + value=default_vals.t_min, elem_id="trt_opt_token_count_min", ) trt_token_count_opt = gr.Slider( @@ -664,7 +521,7 @@ maximum=750, step=75, label="Optimal prompt token count", - value=default_vals[10], + value=default_vals.t_opt, elem_id="trt_opt_token_count_opt", ) trt_token_count_max = gr.Slider( @@ -672,7 +529,7 @@ maximum=750, step=75, label="Max prompt token count", - value=default_vals[11], + value=default_vals.t_max,
elem_id="trt_opt_token_count_max", ) @@ -700,7 +557,7 @@ def on_ui_tabs(): ) version.change( - get_settings_from_version, + profile_presets.get_settings_from_version, version, [ trt_min_batch, @@ -721,7 +578,11 @@ def on_ui_tabs(): version.change( diable_export, version, - [button_export_unet, button_export_default_unet, advanced_settings], + [ + button_export_unet, + button_export_default_unet, + advanced_settings, + ], ) static_shapes.change( @@ -774,14 +635,16 @@ def on_ui_tabs(): trt_lora_dropdown, ) trt_lora_dropdown.change( - disable_lora_export, trt_lora_dropdown, button_export_lora_unet + disable_lora_export, + trt_lora_dropdown, + button_export_lora_unet, ) with gr.Column(variant="panel"): with open( os.path.join(os.path.dirname(os.path.abspath(__file__)), "info.md"), "r", - encoding='utf-8', + encoding="utf-8", ) as f: trt_info = gr.Markdown(elem_id="trt_info", value=f.read()) @@ -791,25 +654,38 @@ def on_ui_tabs(): def get_trt_profiles_markdown(): profiles_md_string = "" - for model, profiles in engine_profile_card().items(): + engine_cards, lora_cards = engine_profile_card() + for model, profiles in engine_cards.items(): profiles_md_string += f"
<details><summary>{model} ({len(profiles)} Profiles)</summary>\n\n" for i, profile in enumerate(profiles): profiles_md_string += f"#### Profile {i} \n{profile}\n\n" profiles_md_string += "</details>\n" profiles_md_string += "\n" - return profiles_md_string + profiles_md_string += "\n --- \n ## LoRA Profiles \n" + for model, details in lora_cards.items(): + profiles_md_string += f"<details><summary>{model}</summary>\n\n" + profiles_md_string += details + profiles_md_string += "</details>
\n" + return profiles_md_string with gr.Column(variant="panel"): with gr.Row(equal_height=True, variant="compact"): - button_refresh_profiles = ToolButton(value=refresh_symbol, elem_id="trt_refresh_profiles", visible=True) + button_refresh_profiles = ToolButton( + value=refresh_symbol, elem_id="trt_refresh_profiles", visible=True + ) profile_header_md = gr.Markdown( value=f"## Available TensorRT Engine Profiles" ) with gr.Row(equal_height=True): - trt_profiles_markdown = gr.Markdown(elem_id=f"trt_profiles_markdown", value=get_trt_profiles_markdown()) - - button_refresh_profiles.click(lambda: gr.Markdown.update(value=get_trt_profiles_markdown()), outputs=[trt_profiles_markdown]) + trt_profiles_markdown = gr.Markdown( + elem_id=f"trt_profiles_markdown", value=get_trt_profiles_markdown() + ) + + button_refresh_profiles.click( + lambda: gr.Markdown.update(value=get_trt_profiles_markdown()), + outputs=[trt_profiles_markdown], + ) button_export_unet.click( export_unet_to_trt, diff --git a/utilities.py b/utilities.py index f943f31..d5285c0 100644 --- a/utilities.py +++ b/utilities.py @@ -15,32 +15,27 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - +import torch +from torch.cuda import nvtx from collections import OrderedDict import numpy as np -import onnx -import onnx_graphsurgeon as gs from polygraphy.backend.common import bytes_from_path from polygraphy import util -from polygraphy.backend.trt import CreateConfig, ModifyNetworkOutputs, Profile +from polygraphy.backend.trt import ModifyNetworkOutputs, Profile from polygraphy.backend.trt import ( engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine, ) +from polygraphy.logger import G_LOGGER import tensorrt as trt -import torch -from torch.cuda import nvtx -from enum import Enum, auto -from safetensors.numpy import save_file, load_file from logging import error, warning -import os -import sys from tqdm import tqdm -import copy +import copy TRT_LOGGER = trt.Logger(trt.Logger.ERROR) +G_LOGGER.module_severity = G_LOGGER.ERROR # Map of numpy dtype -> torch dtype numpy_to_torch_dtype_dict = { @@ -65,33 +60,6 @@ value: key for (key, value) in numpy_to_torch_dtype_dict.items() } - -class PIPELINE_TYPE(Enum): - TXT2IMG = auto() - IMG2IMG = auto() - INPAINT = auto() - SD_XL_BASE = auto() - SD_XL_REFINER = auto() - - def is_txt2img(self): - return self == self.TXT2IMG - - def is_img2img(self): - return self == self.IMG2IMG - - def is_inpaint(self): - return self == self.INPAINT - - def is_sd_xl_base(self): - return self == self.SD_XL_BASE - - def is_sd_xl_refiner(self): - return self == self.SD_XL_REFINER - - def is_sd_xl(self): - return self.is_sd_xl_base() or self.is_sd_xl_refiner() - - class TQDMProgressMonitor(trt.IProgressMonitor): def __init__(self): trt.IProgressMonitor.__init__(self) @@ -181,152 +149,55 @@ def __del__(self): del self.buffers del self.tensors - def refit(self, onnx_path, onnx_refit_path, dump_refit_path=None): - def convert_int64(arr): - # TODO: smarter conversion - if len(arr.shape) == 0: - return np.int32(arr) - return arr - - def add_to_map(refit_dict, name, values): - if name in refit_dict: - assert 
refit_dict[name] is None - if values.dtype == np.int64: - values = convert_int64(values) - refit_dict[name] = values - - print(f"Refitting TensorRT engine with {onnx_refit_path} weights") - refit_nodes = gs.import_onnx(onnx.load(onnx_refit_path)).toposort().nodes - - # Construct mapping from weight names in refit model -> original model - name_map = {} - for n, node in enumerate(gs.import_onnx(onnx.load(onnx_path)).toposort().nodes): - refit_node = refit_nodes[n] - assert node.op == refit_node.op - # Constant nodes in ONNX do not have inputs but have a constant output - if node.op == "Constant": - name_map[refit_node.outputs[0].name] = node.outputs[0].name - # Handle scale and bias weights - elif node.op == "Conv": - if node.inputs[1].__class__ == gs.Constant: - name_map[refit_node.name + "_TRTKERNEL"] = node.name + "_TRTKERNEL" - if node.inputs[2].__class__ == gs.Constant: - name_map[refit_node.name + "_TRTBIAS"] = node.name + "_TRTBIAS" - # For all other nodes: find node inputs that are initializers (gs.Constant) - else: - for i, inp in enumerate(node.inputs): - if inp.__class__ == gs.Constant: - name_map[refit_node.inputs[i].name] = inp.name + def reset(self, engine_path=None): + del self.engine + del self.context + del self.buffers + del self.tensors + self.engine_path = engine_path - def map_name(name): - if name in name_map: - return name_map[name] - return name + self.buffers = OrderedDict() + self.tensors = OrderedDict() + self.inputs = {} + self.outputs = {} - # Construct refit dictionary - refit_dict = {} + def refit_from_dict(self, refit_weights, is_fp16): + # Initialize refitter refitter = trt.Refitter(self.engine, TRT_LOGGER) - all_weights = refitter.get_all() - for layer_name, role in zip(all_weights[0], all_weights[1]): - # for specialized roles, use a unique name in the map: - if role == trt.WeightsRole.KERNEL: - name = layer_name + "_TRTKERNEL" - elif role == trt.WeightsRole.BIAS: - name = layer_name + "_TRTBIAS" - else: - name = layer_name - - 
assert name not in refit_dict, "Found duplicate layer: " + name - refit_dict[name] = None - - for n in refit_nodes: - # Constant nodes in ONNX do not have inputs but have a constant output - if n.op == "Constant": - name = map_name(n.outputs[0].name) - print(f"Add Constant {name}\n") - try: - add_to_map(refit_dict, name, n.outputs[0].values) - except: - error(f"Failed to add Constant {name}\n") - - # Handle scale and bias weights - elif n.op == "Conv": - if n.inputs[1].__class__ == gs.Constant: - name = map_name(n.name + "_TRTKERNEL") - try: - add_to_map(refit_dict, name, n.inputs[1].values) - except: - error(f"Failed to add Conv {name}\n") - - if n.inputs[2].__class__ == gs.Constant: - name = map_name(n.name + "_TRTBIAS") - try: - add_to_map(refit_dict, name, n.inputs[2].values) - except: - error(f"Failed to add Conv {name}\n") - - # For all other nodes: find node inputs that are initializers (AKA gs.Constant) - else: - for inp in n.inputs: - name = map_name(inp.name) - if inp.__class__ == gs.Constant: - add_to_map(refit_dict, name, inp.values) - - if dump_refit_path is not None: - print("Finished refit. 
Dumping result to disk.") - save_file( - refit_dict, dump_refit_path - ) # TODO need to come up with delta system to save only changed weights - return - - for layer_name, weights_role in zip(all_weights[0], all_weights[1]): - if weights_role == trt.WeightsRole.KERNEL: - custom_name = layer_name + "_TRTKERNEL" - elif weights_role == trt.WeightsRole.BIAS: - custom_name = layer_name + "_TRTBIAS" - else: - custom_name = layer_name - # Skip refitting Trilu for now; scalar weights of type int64 value 1 - for clip model - if layer_name.startswith("onnx::Trilu"): + refitted_weights = set() + # iterate through all tensorrt refittable weights + for trt_weight_name in refitter.get_all_weights(): + if trt_weight_name not in refit_weights: continue - if refit_dict[custom_name] is not None: - refitter.set_weights(layer_name, weights_role, refit_dict[custom_name]) - else: - print(f"[W] No refit weights for layer: {layer_name}") - - if not refitter.refit_cuda_engine(): - print("Failed to refit!") - exit(0) - - def refit_from_dump(self, dump_refit_path): - refit_dict = load_file( - dump_refit_path - ) # TODO if deltas are used needs to be unpacked here - - refitter = trt.Refitter(self.engine, TRT_LOGGER) - all_weights = refitter.get_all() - - for layer_name, weights_role in zip(all_weights[0], all_weights[1]): - if weights_role == trt.WeightsRole.KERNEL: - custom_name = layer_name + "_TRTKERNEL" - elif weights_role == trt.WeightsRole.BIAS: - custom_name = layer_name + "_TRTBIAS" - else: - custom_name = layer_name - - # Skip refitting Trilu for now; scalar weights of type int64 value 1 - for clip model - if layer_name.startswith("onnx::Trilu"): - continue + # get weight from state dict + trt_datatype = trt.DataType.FLOAT + if is_fp16: + refit_weights[trt_weight_name] = refit_weights[trt_weight_name].half() + trt_datatype = trt.DataType.HALF + + # trt.Weight and trt.TensorLocation + refit_weights[trt_weight_name] = refit_weights[trt_weight_name].cpu() + trt_wt_tensor = trt.Weights( 
+ trt_datatype, + refit_weights[trt_weight_name].data_ptr(), + torch.numel(refit_weights[trt_weight_name]), + ) + trt_wt_location = ( + trt.TensorLocation.DEVICE + if refit_weights[trt_weight_name].is_cuda + else trt.TensorLocation.HOST + ) - if refit_dict[custom_name] is not None: - refitter.set_weights(layer_name, weights_role, refit_dict[custom_name]) - else: - print(f"[W] No refit weights for layer: {layer_name}") + # apply refit + # refitter.set_named_weights(trt_weight_name, trt_wt_tensor, trt_wt_location) + refitter.set_named_weights(trt_weight_name, trt_wt_tensor) + refitted_weights.add(trt_weight_name) + assert set(refitted_weights) == set(refit_weights.keys()) if not refitter.refit_cuda_engine(): - print("Failed to refit!") + print("Error: failed to refit new weights.") exit(0) def build( @@ -386,18 +257,11 @@ def build( profiles = copy.deepcopy(p) for profile in profiles: # Last profile is used for set_calibration_profile. - calib_profile = profile.fill_defaults(network[1]).to_trt(builder, network[1]) + calib_profile = profile.fill_defaults(network[1]).to_trt( + builder, network[1] + ) config.add_optimization_profile(calib_profile) - - # config = CreateConfig( - # fp16=fp16, - # refittable=enable_refit, - # profiles=p, - # load_timing_cache=timing_cache, - # profiling_verbosity=trt.ProfilingVerbosity.DEFAULT, - # **config_kwargs, - # ) try: engine = engine_from_network( network, @@ -460,7 +324,6 @@ def infer(self, feed_dict, stream, use_cuda_graph=False): def __str__(self): out = "" for opt_profile in range(self.engine.num_optimization_profiles): - out += f"Profile {opt_profile}:\n" for binding_idx in range(self.engine.num_bindings): name = self.engine.get_binding_name(binding_idx) shape = self.engine.get_profile_shape(opt_profile, name)
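Reviewer note: the `get_version_from_filename` helper reintroduced in ui_trt.py infers the SD version purely from substring checks on the checkpoint filename. The sketch below (standalone, no webui imports — the sample filenames are illustrative, not from the diff) mirrors that logic so the heuristic can be exercised in isolation:

```python
def get_version_from_filename(name):
    # Mirrors the heuristic in the diff: ordered substring checks on the
    # checkpoint filename, falling back to "Unknown".
    if "v1-" in name:
        return "1.5"
    elif "v2-" in name:
        return "2.1"
    elif "xl" in name:
        return "xl-1.0"
    else:
        return "Unknown"

print(get_version_from_filename("v1-5-pruned-emaonly"))  # → 1.5
print(get_version_from_filename("v2-1_768-ema-pruned"))  # → 2.1
print(get_version_from_filename("sd_xl_base_1.0"))       # → xl-1.0
print(get_version_from_filename("my-finetune"))          # → Unknown
```

Because the checks are ordered and loose, any filename containing "xl" (but no "v1-"/"v2-" prefix) is treated as SDXL; the per-LoRA JSON config read in `get_lora_checkpoints` ("sd version" key) exists precisely so this guess can be overridden.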