From 02a0643480279e4a749c3c61a5800c31835580d8 Mon Sep 17 00:00:00 2001 From: Rudra <92840555+Rudra-Ji@users.noreply.github.com> Date: Fri, 20 Oct 2023 13:15:33 +0530 Subject: [PATCH 01/12] fix typo (#79) --- README.md | 20 ++++++++++---------- scripts/trt.py | 2 +- ui_trt.py | 4 ++-- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 4f3f0b8..dc87b95 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ -# TensorRT Extension for Stable Diffusion +# TensorRT Extension for Stable Diffusion -This extension enables the best performance on NVIDIA RTX GPUs for Stable Diffusion with TensorRT. +This extension enables the best performance on NVIDIA RTX GPUs for Stable Diffusion with TensorRT. -You need to install the extension and generate optimized engines before using the extension. Please follow the instructions below to set everything up. +You need to install the extension and generate optimized engines before using the extension. Please follow the instructions below to set everything up. Supports Stable Diffusion 1.5 and 2.1. Native SDXL support coming in a future release. Please use the [dev branch](https://github.com/AUTOMATIC1111/stable-diffusion-webui/tree/dev) if you would like to use it today. Note that the Dev branch is not intended for production work and may break other things that you are currently using. @@ -17,10 +17,10 @@ Example instructions for Automatic1111: ## How to use -1. Click on the “Generate Default Engines” button. This step takes 2-10 minutes depending on your GPU. You can generate engines for other combinations. +1. Click on the “Generate Default Engines” button. This step takes 2-10 minutes depending on your GPU. You can generate engines for other combinations. 2. Go to Settings → User Interface → Quick Settings List, add sd_unet. Apply these settings, then reload the UI. -3. Back in the main UI, select the TRT model from the sd_unet dropdown menu at the top of the page. -4. You can now start generating images accelerated by TRT. If you need to create more Engines, go to the TensorRT tab. +3. Back in the main UI, select the TRT model from the sd_unet dropdown menu at the top of the page. +4. You can now start generating images accelerated by TRT. If you need to create more Engines, go to the TensorRT tab. Happy prompting! @@ -29,15 +29,15 @@ Happy prompting! TensorRT uses optimized engines for specific resolutions and batch sizes. You can generate as many optimized engines as desired. Types: - The "Export Default Engines” selection adds support for resolutions between 512x512 and 768x768 for Stable Diffusion 1.5 and 768x768 to 1024x1024 for SDXL with batch sizes 1 to 4. -- Static engines support a single specific output resolution and batch size. -- Dynamic engines support a range of resolutions and batch sizes, at a small cost in performance. Wider ranges will use more VRAM. +- Static engines support a single specific output resolution and batch size. +- Dynamic engines support a range of resolutions and batch sizes, at a small cost in performance. Wider ranges will use more VRAM. Each preset can be adjusted with the “Advanced Settings” option. More detailed instructions can be found [here](https://nvidia.custhelp.com/app/answers/detail/a_id/5487/~/tensorrt-extension-for-stable-diffusion-web-ui). ### Common Issues/Limitations **HIRES FIX:** If using the hires.fix option in Automatic1111 you must build engines that match both the starting and ending resolutions. For instance, if initial size is `512 x 512` and hires.fix upscales to `1024 x 1024`, you must either generate two engines, one at 512 and one at 1024, or generate a single dynamic engine that covers the whole range. -Having two seperate engines will heavily impact performance at the moment. Stay tuned for updates. +Having two separate engines will heavily impact performance at the moment. Stay tuned for updates. **Resolution:** When generating images the resolution needs to be a multiple of 64. This applies to hires.fix as well, requiring the low and high-res to be divisible by 64. @@ -55,4 +55,4 @@ Having two seperate engines will heavily impact performance at the moment. Stay - Linux: >= 450.80.02 - Windows: >=452.39 -We always recommend keeping the driver up-to-date for system wide performance improvments. \ No newline at end of file +We always recommend keeping the driver up-to-date for system wide performance improvements. diff --git a/scripts/trt.py b/scripts/trt.py index 2c72ad7..1a8084c 100644 --- a/scripts/trt.py +++ b/scripts/trt.py @@ -77,7 +77,7 @@ def forward(self, x, timesteps, context, *args, **kwargs): if "y" in kwargs: feed_dict["y"] = kwargs["y"].float() - # Need to check compatability on the fly + # Need to check compatibility on the fly if self.shape_hash != hash(x.shape): nvtx.range_push("switch_engine") if x.shape[-1] % 8 or x.shape[-2] % 8: diff --git a/ui_trt.py b/ui_trt.py index 4ae84cf..fca152b 100644 --- a/ui_trt.py +++ b/ui_trt.py @@ -511,13 +511,13 @@ def get_version_from_filename(name): def get_lora_checkpoints(): available_lora_models = {} - canditates = list( + candidates = list( shared.walk_files( shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"], ) ) - for filename in canditates: + for filename in candidates: name = os.path.splitext(os.path.basename(filename))[0] try: metadata = sd_models.read_metadata_from_safetensors(filename) From 7618777d061fe9c56bd2a33e724c1a8f06e9fe9a Mon Sep 17 00:00:00 2001 From: lspindler Date: Mon, 23 Oct 2023 06:36:05 -0700 Subject: [PATCH 02/12] Faster and simplified install --- install.py | 57 +++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/install.py b/install.py index 7c07285..a4ba8d7 100644 --- a/install.py +++ b/install.py @@ -1,31 +1,52 @@ import launch +import sys from importlib_metadata import version +python = sys.executable + + def install(): + import torch + + if not torch.cuda.is_available(): + print( + "Torch CUDA is not available! Please install Torch with CUDA and try again." + ) + return + if launch.is_installed("tensorrt"): if not version("tensorrt") == "9.0.1.post11.dev4": - launch.run(["python","-m","pip","uninstall","-y","tensorrt"], "removing old version of tensorrt") - - + print("Removing old TensorRT package and try reinstalling...") + launch.run( + f'"{python}" -m pip uninstall -y tensorrt', + "removing old version of tensorrt", + ) + if not launch.is_installed("tensorrt"): - print("TensorRT is not installed! Installing...") - launch.run_pip("install nvidia-cudnn-cu11==8.9.4.25 --no-cache-dir", "nvidia-cudnn-cu11") - launch.run_pip("install --pre --extra-index-url https://pypi.nvidia.com tensorrt==9.0.1.post11.dev4 --no-cache-dir", "tensorrt", live=True) - launch.run(["python","-m","pip","uninstall","-y","nvidia-cudnn-cu11"], "removing nvidia-cudnn-cu11") - - if launch.is_installed("nvidia-cudnn-cu11"): - if version("nvidia-cudnn-cu11") == "8.9.4.25": - launch.run(["python","-m","pip","uninstall","-y","nvidia-cudnn-cu11"], "removing nvidia-cudnn-cu11") - - # Polygraphy + launch.run_pip( + "install --pre --extra-index-url https://pypi.nvidia.com --no-cache-dir --no-deps tensorrt==9.0.1.post11.dev4", + "tensorrt", + live=True, + ) + + # Polygraphy if not launch.is_installed("polygraphy"): print("Polygraphy is not installed! Installing...") - launch.run_pip("install polygraphy --extra-index-url https://pypi.ngc.nvidia.com", "polygraphy", live=True) - + launch.run_pip( + "install polygraphy --extra-index-url https://pypi.ngc.nvidia.com", + "polygraphy", + live=True, + ) + # ONNX GS if not launch.is_installed("onnx_graphsurgeon"): print("GS is not installed! Installing...") launch.run_pip("install protobuf==3.20.2", "protobuf", live=True) - launch.run_pip('install onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com', "onnx-graphsurgeon", live=True) - -install() \ No newline at end of file + launch.run_pip( + "install onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com", + "onnx-graphsurgeon", + live=True, + ) + + +install() From 3615f9a08a94ee848d9df45294b030ba13286fa0 Mon Sep 17 00:00:00 2001 From: lspindler Date: Mon, 23 Oct 2023 06:38:10 -0700 Subject: [PATCH 03/12] Print correct profile when engine is loaded --- model_manager.py | 6 ++++-- scripts/trt.py | 14 ++++++++------ utilities.py | 23 ++++++----------------- 3 files changed, 18 insertions(+), 25 deletions(-) diff --git a/model_manager.py b/model_manager.py index b4544cf..81b6dbd 100644 --- a/model_manager.py +++ b/model_manager.py @@ -178,14 +178,16 @@ def get_timing_cache(self): def get_valid_models(self, base_model: str, feed_dict: dict): valid_models = [] distances = [] + idx = [] models = self.available_models() - for model in models[base_model]: + for i, model in enumerate(models[base_model]): valid, distance = model["config"].is_compatible(feed_dict) if valid: valid_models.append(model) distances.append(distance) + idx.append(i) - return valid_models, distances + return valid_models, distances, idx @dataclass diff --git a/scripts/trt.py b/scripts/trt.py index 1a8084c..08bdf5d 100644 --- a/scripts/trt.py +++ b/scripts/trt.py @@ -1,18 +1,17 @@ import os import numpy as np -import ldm.modules.diffusionmodules.openaimodel - import torch from torch.cuda import nvtx -from modules import script_callbacks, sd_unet, devices +from modules import script_callbacks, sd_unet, devices, scripts import ui_trt from utilities import Engine from typing import List from model_manager import TRT_MODEL_DIR, modelmanager from modules import sd_models, shared - +from polygraphy.logger import G_LOGGER +G_LOGGER.module_severity = G_LOGGER.ERROR class TrtUnetOption(sd_unet.SdUnetOption): def __init__(self, name: str, filename: List[dict]): @@ -60,8 +59,9 @@ def __init__( self.model_name = model_name self.lora_path = lora_path self.engine_vram_req = 0 + self.profile_idx = 0 - self.loaded_config = self.configs[0] + self.loaded_config = self.configs[self.profile_idx] self.shape_hash = 0 self.engine = Engine( os.path.join(TRT_MODEL_DIR, self.loaded_config["filepath"]) @@ -101,7 +101,7 @@ def forward(self, x, timesteps, context, *args, **kwargs): return out def switch_engine(self, feed_dict): - valid_models, distances = modelmanager.get_valid_models( + valid_models, distances, idx = modelmanager.get_valid_models( self.model_name, feed_dict ) if len(valid_models) == 0: @@ -109,6 +109,7 @@ def switch_engine(self, feed_dict): "No valid profile found. Please go to the TensorRT tab and generate an engine with the necessary profile. If using hires.fix, you need an engine for both the base and upscaled resolutions. Otherwise, use the default (torch) U-Net." ) + self.profile_idx = idx[np.argmin(distances)] best = valid_models[np.argmin(distances)] if best["filepath"] == self.loaded_config["filepath"]: return @@ -119,6 +120,7 @@ def switch_engine(self, feed_dict): def activate(self): self.engine.load() + print(f"\nLoaded Profile: {self.profile_idx}") print(self.engine) self.engine_vram_req = self.engine.engine.device_memory_size self.engine.activate(True) diff --git a/utilities.py b/utilities.py index f943f31..a71c72d 100644 --- a/utilities.py +++ b/utilities.py @@ -15,32 +15,31 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import torch +from torch.cuda import nvtx from collections import OrderedDict import numpy as np import onnx import onnx_graphsurgeon as gs from polygraphy.backend.common import bytes_from_path from polygraphy import util -from polygraphy.backend.trt import CreateConfig, ModifyNetworkOutputs, Profile +from polygraphy.backend.trt import ModifyNetworkOutputs, Profile from polygraphy.backend.trt import ( engine_from_bytes, engine_from_network, network_from_onnx_path, - save_engine, + save_engine ) +from polygraphy.logger import G_LOGGER import tensorrt as trt -import torch -from torch.cuda import nvtx from enum import Enum, auto from safetensors.numpy import save_file, load_file from logging import error, warning -import os -import sys from tqdm import tqdm import copy TRT_LOGGER = trt.Logger(trt.Logger.ERROR) +G_LOGGER.module_severity = G_LOGGER.ERROR # Map of numpy dtype -> torch dtype numpy_to_torch_dtype_dict = { @@ -389,15 +388,6 @@ def build( calib_profile = profile.fill_defaults(network[1]).to_trt(builder, network[1]) config.add_optimization_profile(calib_profile) - - # config = CreateConfig( - # fp16=fp16, - # refittable=enable_refit, - # profiles=p, - # load_timing_cache=timing_cache, - # profiling_verbosity=trt.ProfilingVerbosity.DEFAULT, - # **config_kwargs, - # ) try: engine = engine_from_network( network, @@ -460,7 +450,6 @@ def infer(self, feed_dict, stream, use_cuda_graph=False): def __str__(self): out = "" for opt_profile in range(self.engine.num_optimization_profiles): - out += f"Profile {opt_profile}:\n" for binding_idx in range(self.engine.num_bindings): name = self.engine.get_binding_name(binding_idx) shape = self.engine.get_profile_shape(opt_profile, name) From 26103b15dc1814401d6dbc18f7660194158054f3 Mon Sep 17 00:00:00 2001 From: lspindler Date: Mon, 23 Oct 2023 12:27:55 -0700 Subject: [PATCH 04/12] Use scripts callbacks --- model_manager.py | 70 ++++++++++++++++++--- scripts/trt.py | 160 +++++++++++++++++++++++++++++++++-------------- 2 files changed, 176 insertions(+), 54 deletions(-) diff --git a/model_manager.py b/model_manager.py index 81b6dbd..2bbc2d8 100644 --- a/model_manager.py +++ b/model_manager.py @@ -71,12 +71,14 @@ def update(self): for model_config in models: if model_config["filepath"] not in trt_engines: info( - "Model config outdated. {} was not found".format(model_config["filepath"]) + "Model config outdated. {} was not found".format( + model_config["filepath"] + ) ) continue tmp_config_list[model_config["filepath"]] = model_config - - tmp_config_list = list(tmp_config_list.values()) + + tmp_config_list = list(tmp_config_list.values()) if len(tmp_config_list) == 0: self.all_models[cc].pop(base_model) else: @@ -84,7 +86,6 @@ def update(self): self.write_json() - def __del__(self): self.update() @@ -175,13 +176,36 @@ def get_timing_cache(self): return cache - def get_valid_models(self, base_model: str, feed_dict: dict): + def get_valid_models_from_dict(self, base_model: str, feed_dict: dict): + valid_models = [] + distances = [] + idx = [] + models = self.available_models() + for i, model in enumerate(models[base_model]): + valid, distance = model["config"].is_compatible_from_dict(feed_dict) + if valid: + valid_models.append(model) + distances.append(distance) + idx.append(i) + + return valid_models, distances, idx + + def get_valid_models( + self, + base_model: str, + width: int, + height: int, + batch_size: int, + max_embedding: int, + ): valid_models = [] distances = [] idx = [] models = self.available_models() for i, model in enumerate(models[base_model]): - valid, distance = model["config"].is_compatible(feed_dict) + valid, distance = model["config"].is_compatible( + width, height, batch_size, max_embedding + ) if valid: valid_models.append(model) distances.append(distance) @@ -201,7 +225,7 @@ class ModelConfig: vram: int unet_hidden_dim: int = 4 - def is_compatible(self, feed_dict: dict): + def is_compatible_from_dict(self, feed_dict: dict): distance = 0 for k, v in feed_dict.items(): _min, _opt, _max = self.profile[k] @@ -214,6 +238,38 @@ def is_compatible(self, feed_dict: dict): distance += r_opt.sum() + 0.5 * (r_max.sum() + 0.5 * r_min.sum()) return (True, distance) + def is_compatible( + self, width: int, height: int, batch_size: int, max_embedding: int + ): + distance = 0 + sample = self.profile["sample"] + embedding = self.profile["encoder_hidden_states"] + + batch_size *= 2 + width = width // 8 + height = height // 8 + + _min, _opt, _max = sample + if _min[0] > batch_size or _max[0] < batch_size: + return (False, distance) + if _min[2] > height or _max[2] < height: + return (False, distance) + if _min[3] > width or _max[3] < width: + return (False, distance) + + _min_em, _opt_em, _max_em = embedding + if _min_em[1] > max_embedding or _max_em[1] < max_embedding: + return (False, distance) + + distance = ( + abs(_opt[0] - batch_size) + + abs(_opt[2] - height) + + abs(_opt[3] - width) + + 0.5 * (abs(_max[2] - height) + abs(_max[3] - width)) + ) + + return (True, distance) + class ModelConfigEncoder(JSONEncoder): def default(self, o: ModelConfig): diff --git a/scripts/trt.py b/scripts/trt.py index 08bdf5d..891641c 100644 --- a/scripts/trt.py +++ b/scripts/trt.py @@ -9,10 +9,12 @@ from utilities import Engine from typing import List from model_manager import TRT_MODEL_DIR, modelmanager -from modules import sd_models, shared from polygraphy.logger import G_LOGGER +import gradio as gr + G_LOGGER.module_severity = G_LOGGER.ERROR + class TrtUnetOption(sd_unet.SdUnetOption): def __init__(self, name: str, filename: List[dict]): self.label = f"[TRT] {name}" @@ -25,28 +27,11 @@ def create_unet(self): lora_path = os.path.join(TRT_MODEL_DIR, self.configs[0]["filepath"]) self.model_name = self.configs[0]["base_model"] self.configs = modelmanager.available_models()[self.model_name] - validate_sd_version(self.model_name, exact=True) return TrtUnet(self.model_name, self.configs, lora_path) -def validate_sd_version(model_name, exact=False): - loaded_model = shared.sd_model.sd_checkpoint_info.model_name - if exact: - if not loaded_model == model_name: - raise ValueError( - f"Selected torch model ({loaded_model}) does not match the selected TensorRT U-Net ({model_name}). Please ensure that both models are the same." - ) - else: - if shared.sd_model.is_sdxl: - if not "xl" in model_name: - raise ValueError( - f"Selected torch model ({loaded_model}) does not match the selected TensorRT U-Net ({model_name}). Please ensure that both models are the same." - ) - loaded_version = 1 if shared.sd_model.is_sd1 else 2 - if f"v{loaded_version}" not in model_name: - raise ValueError( - f"Selected torch model ({loaded_model}) does not match the selected TensorRT U-Net ({model_name}). Please ensure that both models are the same." - ) +# This is ugly. Is there a better way to parse this as kwargs to the SD Unet? +GLOBAL_KWARGS = {"profile_idx": None, "profile_hr_idx": None, "model_name": ""} class TrtUnet(sd_unet.SdUnet): @@ -54,15 +39,21 @@ def __init__( self, model_name: str, configs: List[dict], lora_path, *args, **kwargs ): super().__init__(*args, **kwargs) + if not model_name == GLOBAL_KWARGS["model_name"]: + raise ValueError( + """Selected torch model ({}) does not match the selected TensorRT U-Net ({}). + Please ensure that both models are the same or select Automatic from the SD UNet dropdown.""".format( + GLOBAL_KWARGS["model_name"], model_name + ) + ) self.configs = configs self.stream = None self.model_name = model_name self.lora_path = lora_path self.engine_vram_req = 0 - self.profile_idx = 0 + self.profile_idx = GLOBAL_KWARGS["profile_idx"] self.loaded_config = self.configs[self.profile_idx] - self.shape_hash = 0 self.engine = Engine( os.path.join(TRT_MODEL_DIR, self.loaded_config["filepath"]) ) @@ -77,16 +68,8 @@ def forward(self, x, timesteps, context, *args, **kwargs): if "y" in kwargs: feed_dict["y"] = kwargs["y"].float() - # Need to check compatibility on the fly - if self.shape_hash != hash(x.shape): - nvtx.range_push("switch_engine") - if x.shape[-1] % 8 or x.shape[-2] % 8: - raise ValueError( - "Input shape must be divisible by 64 in both dimensions." - ) - self.switch_engine(feed_dict) - self.shape_hash = hash(x.shape) - nvtx.range_pop() + if not self.profile_idx == GLOBAL_KWARGS["profile_idx"]: + self.switch_engine() tmp = torch.empty( self.engine_vram_req, dtype=torch.uint8, device=devices.device @@ -100,23 +83,14 @@ def forward(self, x, timesteps, context, *args, **kwargs): nvtx.range_pop() return out - def switch_engine(self, feed_dict): - valid_models, distances, idx = modelmanager.get_valid_models( - self.model_name, feed_dict - ) - if len(valid_models) == 0: - raise ValueError( - "No valid profile found. Please go to the TensorRT tab and generate an engine with the necessary profile. If using hires.fix, you need an engine for both the base and upscaled resolutions. Otherwise, use the default (torch) U-Net." - ) - - self.profile_idx = idx[np.argmin(distances)] - best = valid_models[np.argmin(distances)] - if best["filepath"] == self.loaded_config["filepath"]: - return + def switch_engine(self): + self.profile_idx = GLOBAL_KWARGS["profile_idx"] + self.loaded_config = self.configs[self.profile_idx] self.deactivate() - self.engine = Engine(os.path.join(TRT_MODEL_DIR, best["filepath"])) + self.engine = Engine( + os.path.join(TRT_MODEL_DIR, self.loaded_config["filepath"]) + ) self.activate() - self.loaded_config = best def activate(self): self.engine.load() @@ -133,11 +107,103 @@ def deactivate(self): del self.engine +class TensorRTScript(scripts.Script): + def __init__(self) -> None: + self.loaded_model = None + pass + + def title(self): + return "TensorRT" + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def setup(self, p, *args): + return super().setup(p, *args) + + def before_process(self, p, *args): # 1 + # Check divisibilty + if p.width % 64 or p.height % 64: + raise ValueError( + "Target resolution must be divisible by 64 in both dimensions." + ) + + if p.enable_hr: + hr_w = int(p.width * p.hr_scale) + hr_h = int(p.height * p.hr_scale) + if hr_w % 64 or hr_h % 64: + raise ValueError( + "HIRES Fix resolution must be divisible by 64 in both dimensions. Please change the upscale factor or disable HIRES Fix." + ) + + # lora p.prompt == ' + + def process(self, p, *args): # 2 + # before unet_init + hr_scale = p.hr_scale if p.enable_hr else 1 + ( + valid_models, + distances, + idx, + ) = modelmanager.get_valid_models( + p.sd_model_name, p.width, p.height, p.batch_size, 77 + ) # TODO: max_embedding + if len(valid_models) == 0: + raise ValueError( + """No valid profile found for LOWRES. Please go to the TensorRT tab and generate an engine with the necessary profile. + If using hires.fix, you need an engine for both the base and upscaled resolutions. Otherwise, use the default (torch) U-Net.""" + ) + best = idx[np.argmin(distances)] + + if hr_scale != 1: + hr_w = int(p.width * p.hr_scale) + hr_h = int(p.height * p.hr_scale) + valid_models_hr, distances_hr, idx_hr = modelmanager.get_valid_models( + p.sd_model_name, hr_w, hr_h, p.batch_size, 77 + ) # TODO: max_embedding + if len(valid_models) == 0: + raise ValueError( + "No valid profile found for HIRES. Please go to the TensorRT tab and generate an engine with the necessary profile. If using hires.fix, you need an engine for both the base and upscaled resolutions. Otherwise, use the default (torch) U-Net." + ) + merged_idx = [i for i, id in enumerate(idx) if id in idx_hr] + if len(merged_idx) == 0: + gr.Warning( + "No model available for both LOWRES ({}x{}) and HIRES ({}x{}). This will slow-down inference.".format( + p.width, p.height, hr_w, hr_h + ) + ) + best_hr = idx_hr[np.argmin(distances_hr)] + else: + _distances = [distances[i] for i in merged_idx] + best_hr = idx_hr[merged_idx[np.argmin(_distances)]] + best = best_hr + GLOBAL_KWARGS["profile_hr_idx"] = best_hr + GLOBAL_KWARGS["profile_idx"] = best + GLOBAL_KWARGS["model_name"] = p.sd_model_name + + def process_batch(self, p, *args, **kwargs): + return super().process_batch(p, *args, **kwargs) + + def before_hr(self, p, *args): + GLOBAL_KWARGS["profile_idx"] = GLOBAL_KWARGS["profile_hr_idx"] + return super().before_hr(p, *args) # 4 (Only when HR starts.....) + + def after_extra_networks_activate(self, p, *args, **kwargs): + # if self.lora_path is not None: + # self.engine.refit_from_dump(self.lora_path) + + # Called after UNet activate + # p.extra_network_data + # Contains dict of modules.extra_networks.ExtraNetworkParams + return super().after_extra_networks_activate(p, *args, **kwargs) # 3 + + def list_unets(l): model = modelmanager.available_models() for k, v in model.items(): label = "{} ({})".format(k, v[0]["base_model"]) if v[0]["config"].lora else k l.append(TrtUnetOption(label, v)) + script_callbacks.on_list_unets(list_unets) script_callbacks.on_ui_tabs(ui_trt.on_ui_tabs) From b3a5a7391171ee595b31dbe765a8e0d8d99d23fa Mon Sep 17 00:00:00 2001 From: lspindler Date: Wed, 25 Oct 2023 05:56:19 -0700 Subject: [PATCH 05/12] increase max resolution --- ui_trt.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ui_trt.py b/ui_trt.py index fca152b..bbfc249 100644 --- a/ui_trt.py +++ b/ui_trt.py @@ -601,7 +601,7 @@ def on_ui_tabs(): with gr.Column(elem_id="trt_height"): trt_height_min = gr.Slider( minimum=256, - maximum=2048, + maximum=4096, step=64, label="Min height", value=default_vals[3], @@ -609,7 +609,7 @@ def on_ui_tabs(): ) trt_height_opt = gr.Slider( minimum=256, - maximum=2048, + maximum=4096, step=64, label="Optimal height", value=default_vals[4], @@ -617,7 +617,7 @@ def on_ui_tabs(): ) trt_height_max = gr.Slider( minimum=256, - maximum=2048, + maximum=4096, step=64, label="Max height", value=default_vals[5], @@ -627,7 +627,7 @@ def on_ui_tabs(): with gr.Column(elem_id="trt_width"): trt_width_min = gr.Slider( minimum=256, - maximum=2048, + maximum=4096, step=64, label="Min width", value=default_vals[6], @@ -635,7 +635,7 @@ def on_ui_tabs(): ) trt_width_opt = gr.Slider( minimum=256, - maximum=2048, + maximum=4096, step=64, label="Optimal width", value=default_vals[7], @@ -643,7 +643,7 @@ def on_ui_tabs(): ) trt_width_max = gr.Slider( minimum=256, - maximum=2048, + maximum=4096, step=64, label="Max width", value=default_vals[8], From 19a46840dd41a89acb16a20dd5a2ec06439b1122 Mon Sep 17 00:00:00 2001 From: lspindler Date: Tue, 19 Dec 2023 08:08:41 -0800 Subject: [PATCH 06/12] adding native LoRA support, avoiding model.py import error and some refactoring --- datastructures.py | 183 ++++++ exporter.py | 230 +++++--- model_helper.py | 288 ++++++++++ model_manager.py | 19 +- models.py | 1392 --------------------------------------------- scripts/lora.py | 45 ++ scripts/trt.py | 205 +++++-- ui_trt.py | 639 +++++++++------------ utilities.py | 185 ++---- 9 files changed, 1145 insertions(+), 2041 deletions(-) create mode 100644 datastructures.py create mode 100644 model_helper.py delete mode 100644 models.py create mode 100644 scripts/lora.py diff --git a/datastructures.py b/datastructures.py new file mode 100644 index 0000000..ac3b5eb --- /dev/null +++ b/datastructures.py @@ -0,0 +1,183 @@ +from dataclasses import dataclass +from enum import Enum + + +@dataclass +class UNetEngineArgs: + idx: int + hr_idx: int = None + lora: dict = None + controlnets: dict = None + + +class SDVersion(Enum): + SD1 = 1 + SD2 = 2 + SDXL = 3 + Unknown = -1 + + def __str__(self): + return self.name + + @classmethod + def from_str(cls, str): + try: + return cls[str] + except KeyError: + return cls.Unknown + + def match(self, sd_model): + if sd_model.is_sd1 and self == SDVersion.SD1: + return True + elif sd_model.is_sd2 and self == SDVersion.SD2: + return True + elif sd_model.is_sdxl and self == SDVersion.SDXL: + return True + elif self == SDVersion.Unknown: + return True + else: + return False + + +class ModelType(Enum): + UNET = 0 + CONTROLNET = 1 + LORA = 2 + UNDEFINED = -1 + + @classmethod + def from_string(cls, s): + return getattr(cls, s.upper(), None) + + def __str__(self): + return self.name.lower() + + +@dataclass +class ProfileSettings: + bs_min: int + bs_opt: int + bs_max: int + h_min: int + h_opt: int + h_max: int + w_min: int + w_opt: int + w_max: int + t_min: int + t_opt: int + t_max: int + static_shape: bool = False + + def __str__(self) -> str: + return "Batch Size: {}-{}-{}\nHeight: {}-{}-{}\nWidth: {}-{}-{}\nToken Count: {}-{}-{}".format( + self.bs_min, + self.bs_opt, + self.bs_max, + self.h_min, + self.h_opt, + self.h_max, + self.w_min, + self.w_opt, + self.w_max, + self.t_min, + self.t_opt, + self.t_max, + ) + + def out(self): + return ( + self.bs_min, + self.bs_opt, + self.bs_max, + self.h_min, + self.h_opt, + self.h_max, + self.w_min, + self.w_opt, + self.w_max, + self.t_min, + self.t_opt, + self.t_max, + ) + + def token_to_dim(self, static_shapes: bool): + self.t_min = (self.t_min // 75) * 77 + self.t_opt = (self.t_opt // 75) * 77 + self.t_max = (self.t_max // 75) * 77 + + if static_shapes: + self.t_min = self.t_max = self.t_opt + self.bs_min = self.bs_max = self.bs_opt + self.h_min = self.h_max = self.h_opt + self.w_min = self.w_max = self.w_opt + self.static_shape = True + + def get_latent_dim(self): + return ( + self.h_min // 8, + self.h_opt // 8, + self.h_max // 8, + self.w_min // 8, + self.w_opt // 8, + self.w_max // 8, + ) + + def get_a1111_batch_dim(self): + static_batch = self.bs_min == self.bs_max == self.bs_opt + if self.t_max <= 77: + return (self.bs_min * 2, self.bs_opt * 2, self.bs_max * 2) + elif self.t_max > 77 and static_batch: + return (self.bs_opt, self.bs_opt, self.bs_opt) + elif self.t_max > 77 and not static_batch: + if self.t_opt > 77: + return (self.bs_min, self.bs_opt, self.bs_max * 2) + return (self.bs_min, self.bs_opt * 2, self.bs_max * 2) + else: + raise Exception("Uncovered case in get_batch_dim") + + +class ProfilePrests: + def __init__(self): + self.profile_presets = { + "512x512 | Batch Size 1 (Static)": ProfileSettings( + 1, 1, 1, 512, 512, 512, 512, 512, 512, 75, 75, 75 + ), + "768x768 | Batch Size 1 (Static)": ProfileSettings( + 1, 1, 1, 768, 768, 768, 768, 768, 768, 75, 75, 75 + ), + "1024x1024 | Batch Size 1 (Static)": ProfileSettings( + 1, 1, 1, 1024, 1024, 1024, 1024, 1024, 1024, 75, 75, 75 + ), + "256x256 - 512x512 | Batch Size 1-4": ProfileSettings( + 1, 1, 4, 256, 512, 512, 256, 512, 512, 75, 75, 150 + ), + "512x512 - 768x768 | Batch Size 1-4": ProfileSettings( + 1, 1, 4, 512, 512, 768, 512, 512, 768, 75, 75, 150 + ), + "768x768 - 1024x1024 | Batch Size 1-4": ProfileSettings( + 1, 1, 4, 768, 1024, 1024, 768, 1024, 1024, 75, 75, 150 + ), + } + self.default = ProfileSettings( + 1, 1, 4, 512, 512, 768, 512, 512, 768, 75, 75, 150 + ) + self.default_xl = ProfileSettings( + 1, 1, 4, 768, 1024, 1024, 768, 1024, 1024, 75, 75, 150 + ) + + def get_settings_from_version(self, version): + static = False + if version == "Default": + return *self.default.out(), static + if "Static" in version: + static = True + return *self.profile_presets[version].out(), static + + def get_choices(self): + return list(self.profile_presets.keys()) + ["Default"] + + def get_default(self, is_xl: bool): + if is_xl: + return self.default_xl + return self.default diff --git a/exporter.py b/exporter.py index e4c6c5b..b6c6663 100644 --- a/exporter.py +++ b/exporter.py @@ -8,13 +8,19 @@ from modules import sd_hijack, sd_unet, shared from utilities import Engine +from datastructures import ProfileSettings +from model_helper import UNetModel import os - -def get_cc(): - cc_major = torch.cuda.get_device_properties(0).major - cc_minor = torch.cuda.get_device_properties(0).minor - return cc_major, cc_minor +from pathlib import Path +from optimum.onnx.utils import ( + _get_onnx_external_data_tensors, + check_model_uses_external_data, +) +from collections import OrderedDict +from onnx import numpy_helper +import numpy as np +import json def apply_lora(model, lora_path, inputs): @@ -40,19 +46,50 @@ def apply_lora(model, lora_path, inputs): return model -def export_onnx( - onnx_path, - modelobj=None, - profile=None, - opset=17, - diable_optimizations=False, - lora_path=None, +def get_refit_weights( + state_dict, onnx_opt_path, weight_name_mapping, weight_shape_mapping ): - swap_sdpa = hasattr(F, "scaled_dot_product_attention") - old_sdpa = getattr(F, "scaled_dot_product_attention", None) if swap_sdpa else None - if swap_sdpa: - delattr(F, "scaled_dot_product_attention") + refit_weights = OrderedDict() + onnx_opt_dir = os.path.dirname(onnx_opt_path) + onnx_opt_model = onnx.load(onnx_opt_path) + # Create initializer data hashes + initializer_hash_mapping = {} + onnx_data_mapping = {} + for initializer in onnx_opt_model.graph.initializer: + initializer_data = numpy_helper.to_array( + initializer, base_dir=onnx_opt_dir + ).astype(np.float16) + initializer_hash = hash(initializer_data.data.tobytes()) + initializer_hash_mapping[initializer.name] = initializer_hash + onnx_data_mapping[initializer.name] = initializer_data + + for torch_name, initializer_name in weight_name_mapping.items(): + initializer_hash = initializer_hash_mapping[initializer_name] + wt = state_dict[torch_name] + + # get shape transform info + initializer_shape, is_transpose = weight_shape_mapping[torch_name] + if is_transpose: + wt = torch.transpose(wt, 0, 1) + else: + wt = torch.reshape(wt, initializer_shape) + + # include weight if hashes differ + wt_hash = hash(wt.cpu().detach().numpy().astype(np.float16).data.tobytes()) + if initializer_hash != wt_hash: + delta = wt - torch.tensor(onnx_data_mapping[initializer_name]).to(wt.device) + refit_weights[initializer_name] = delta.contiguous() + return refit_weights + + +def export_lora( + modelobj: UNetModel, + onnx_path: str, + weights_map_path: str, + lora_name: str, + profile: ProfileSettings, +): def disable_checkpoint(self): if getattr(self, "use_checkpoint", False) == True: self.use_checkpoint = False @@ -60,27 +97,81 @@ def disable_checkpoint(self): self.checkpoint = False shared.sd_model.model.diffusion_model.apply(disable_checkpoint) - is_xl = shared.sd_model.is_sdxl - sd_unet.apply_unet("None") sd_hijack.model_hijack.apply_optimizations("None") - os.makedirs("onnx_tmp", exist_ok=True) - tmp_path = os.path.abspath(os.path.join("onnx_tmp", "tmp.onnx")) + info("Exporting to ONNX...") + inputs = modelobj.get_sample_input( + profile.bs_opt * 2, + profile.h_opt // 8, + profile.w_opt // 8, + profile.t_opt, + ) + model = shared.sd_model.model.diffusion_model + + with open(weights_map_path, "r") as fp_wts: + print(f"[I] Loading weights map: {weights_map_path} ") + [weights_name_mapping, weights_shape_mapping] = json.load(fp_wts) + + with torch.inference_mode(), torch.autocast("cuda"): + model = apply_lora(model, os.path.splitext(lora_name)[0], inputs) + + refit_dict = get_refit_weights( + model.state_dict(), + onnx_path, + weights_name_mapping, + weights_shape_mapping, + ) + + return refit_dict + +def export_onnx( + onnx_path: str, + modelobj: UNetModel, + profile: ProfileSettings, + opset=17, + diable_optimizations=False, +): + swap_sdpa = hasattr(F, "scaled_dot_product_attention") + old_sdpa = getattr(F, "scaled_dot_product_attention", None) if swap_sdpa else None + if swap_sdpa: + delattr(F, "scaled_dot_product_attention") + + info("Exporting to ONNX...") + inputs = modelobj.get_sample_input( + profile.bs_opt * 2, + profile.h_opt // 8, + profile.w_opt // 8, + profile.t_opt, + ) + + if not os.path.exists(onnx_path): + _export_onnx( + modelobj.unet, + inputs, + Path(onnx_path), + opset, + modelobj.get_input_names(), + modelobj.get_output_names(), + modelobj.get_dynamic_axes(), + modelobj.optimize if not diable_optimizations else None, + ) + + # CleanUp + if swap_sdpa and old_sdpa: + setattr(F, "scaled_dot_product_attention", old_sdpa) + + +def _export_onnx( + model, inputs, path, opset, in_names, out_names, dyn_axes, optimizer=None +): + tmp_dir = os.path.abspath("onnx_tmp") + os.makedirs(tmp_dir, exist_ok=True) + tmp_path = os.path.join(tmp_dir, "model.onnx") try: info("Exporting to ONNX...") with torch.inference_mode(), torch.autocast("cuda"): - inputs = modelobj.get_sample_input( - profile["sample"][1][0] // 2, - profile["sample"][1][-2] * 8, - profile["sample"][1][-1] * 8, - ) - model = shared.sd_model.model.diffusion_model - - if lora_path: - model = apply_lora(model, lora_path, inputs) - torch.onnx.export( model, inputs, @@ -88,51 +179,44 @@ def disable_checkpoint(self): export_params=True, opset_version=opset, do_constant_folding=True, - input_names=modelobj.get_input_names(), - output_names=modelobj.get_output_names(), - dynamic_axes=modelobj.get_dynamic_axes(), - ) - - info("Optimize ONNX.") - - onnx_graph = onnx.load(tmp_path) - if diable_optimizations: - onnx_opt_graph = onnx_graph - else: - onnx_opt_graph = modelobj.optimize(onnx_graph) - - if onnx_opt_graph.ByteSize() > 2147483648 or is_xl: - onnx.save_model( - onnx_opt_graph, - onnx_path, - save_as_external_data=True, - all_tensors_to_one_file=True, - convert_attribute=False, + input_names=in_names, + output_names=out_names, + dynamic_axes=dyn_axes, ) - else: - try: - onnx.save(onnx_opt_graph, onnx_path) - except Exception as e: - error(e) - error("ONNX file too large. Saving as external data.") - onnx.save_model( - onnx_opt_graph, - onnx_path, - save_as_external_data=True, - all_tensors_to_one_file=True, - convert_attribute=False, - ) - info("ONNX export complete.") - del onnx_opt_graph except Exception as e: - error(e) - exit() - - # CleanUp - if swap_sdpa and old_sdpa: - setattr(F, "scaled_dot_product_attention", old_sdpa) - shutil.rmtree(os.path.abspath("onnx_tmp")) - del model + error("Exporting to ONNX failed. {}".format(e)) + return + + info("Optimize ONNX.") + os.makedirs(path.parent, exist_ok=True) + onnx_model = onnx.load(tmp_path, load_external_data=False) + model_uses_external_data = check_model_uses_external_data(onnx_model) + + if model_uses_external_data: + info("ONNX model uses external data. Saving as external data.") + tensors_paths = _get_onnx_external_data_tensors(onnx_model) + onnx_model = onnx.load(tmp_path, load_external_data=True) + onnx.save( + onnx_model, + str(path), + save_as_external_data=True, + all_tensors_to_one_file=True, + location=path.name + "_data", + size_threshold=1024, + ) + + if optimizer is not None: + try: + onnx_opt_graph = optimizer("unet", onnx_model) + onnx.save(onnx_opt_graph, path) + except Exception as e: + error("Optimizing ONNX failed. {}".format(e)) + return + + if not model_uses_external_data and optimizer is None: + shutil.move(tmp_path, str(path)) + + shutil.rmtree(tmp_dir) def export_trt(trt_path, onnx_path, timing_cache, profile, use_fp16): diff --git a/model_helper.py b/model_helper.py new file mode 100644 index 0000000..547395f --- /dev/null +++ b/model_helper.py @@ -0,0 +1,288 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import onnx +from onnx import shape_inference, numpy_helper +import os +from polygraphy.backend.onnx.loader import fold_constants +import tempfile +import torch +import torch.nn.functional as F +import onnx_graphsurgeon as gs +from datastructures import ProfileSettings + +import numpy as np +import json + + +class UNetModel(torch.nn.Module): + def __init__(self, unet, embedding_dim, text_minlen=77, is_xl=False) -> None: + super().__init__() + self.unet = unet + self.is_xl = is_xl + + self.text_minlen = text_minlen + self.embedding_dim = embedding_dim + self.num_xl_classes = 2816 # Magic number for num_classes + self.emb_chn = 1280 + + self.dyn_axes = { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B", 1: "77N"}, + "timesteps": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + "y": {0: "2B"}, + } + + def get_input_names(self): + names = ["sample", "timesteps", "encoder_hidden_states"] + if self.is_xl: + names.append("y") + return names + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + io_names = self.get_input_names() + self.get_output_names() + dyn_axes = {name: self.dyn_axes[name] for name in io_names} + return dyn_axes + + def get_sample_input( + self, + batch_size, + latent_height, + latent_width, + text_len, + device="cuda", + dtype=torch.float32, + ): + return ( + torch.randn( + batch_size, + self.unet.in_channels, + latent_height, + latent_width, + dtype=dtype, + device=device, + ), + torch.randn(batch_size, dtype=dtype, device=device), + torch.randn( + batch_size, + text_len, + self.embedding_dim, + dtype=dtype, + device=device, + ), + torch.randn(batch_size, self.num_xl_classes, dtype=dtype, device=device) + if self.is_xl + else None, + ) + + def get_input_profile(self, profile: ProfileSettings): + min_batch, opt_batch, max_batch = profile.get_a1111_batch_dim() + ( + min_latent_height, + latent_height, + max_latent_height, + min_latent_width, + latent_width, + max_latent_width, + ) = profile.get_latent_dim() + + shape_dict = { + "sample": [ + (min_batch, self.unet.in_channels, min_latent_height, min_latent_width), + (opt_batch, self.unet.in_channels, latent_height, latent_width), + (max_batch, self.unet.in_channels, max_latent_height, max_latent_width), + ], + "timesteps": [(min_batch,), (opt_batch,), (max_batch,)], + "encoder_hidden_states": [ + (min_batch, profile.t_min, self.embedding_dim), + (opt_batch, profile.t_opt, self.embedding_dim), + (max_batch, profile.t_max, self.embedding_dim), + ], + } + if self.is_xl: + shape_dict["y"] = [ + (min_batch, self.num_xl_classes), + (opt_batch, self.num_xl_classes), + (max_batch, self.num_xl_classes), + ] + + return shape_dict + + # Helper utility for weights map + def export_weights_map(self, onnx_opt_path, weights_map_path): + onnx_opt_dir = onnx_opt_path + state_dict = self.unet.state_dict() + onnx_opt_model = onnx.load(onnx_opt_path) + + # Create initializer data hashes + def init_hash_map(onnx_opt_model): + initializer_hash_mapping = {} + for initializer in onnx_opt_model.graph.initializer: + initializer_data = numpy_helper.to_array( + initializer, base_dir=onnx_opt_dir + ).astype(np.float16) + initializer_hash = hash(initializer_data.data.tobytes()) + initializer_hash_mapping[initializer.name] = ( + initializer_hash, + initializer_data.shape, + ) + return initializer_hash_mapping + + initializer_hash_mapping = init_hash_map(onnx_opt_model) + + weights_name_mapping = {} + weights_shape_mapping = {} + # set to keep track of initializers already added to the name_mapping dict + initializers_mapped = set() + for wt_name, wt in state_dict.items(): + # get weight hash + wt = wt.cpu().detach().numpy().astype(np.float16) + wt_hash = hash(wt.data.tobytes()) + wt_t_hash = hash(np.transpose(wt).data.tobytes()) + + for initializer_name, ( + initializer_hash, + initializer_shape, + ) in initializer_hash_mapping.items(): + # Due to constant folding, some weights are transposed during export + # To account for the transpose op, we compare the initializer hash to the + # hash for the weight and its transpose + if wt_hash == initializer_hash or wt_t_hash == initializer_hash: + # The assert below ensures there is a 1:1 mapping between + # PyTorch and ONNX weight names. It can be removed in cases where 1:many + # mapping is found and name_mapping[wt_name] = list() + assert initializer_name not in initializers_mapped + weights_name_mapping[wt_name] = initializer_name + initializers_mapped.add(initializer_name) + is_transpose = False if wt_hash == initializer_hash else True + weights_shape_mapping[wt_name] = ( + initializer_shape, + is_transpose, + ) + + # Sanity check: Were any weights not matched + if wt_name not in weights_name_mapping: + print( + f"[I] PyTorch weight {wt_name} not matched with any ONNX initializer" + ) + print( + f"[I] UNet: {len(weights_name_mapping.keys())} PyTorch weights were matched with ONNX initializers" + ) + + assert weights_name_mapping.keys() == weights_shape_mapping.keys() + with open(weights_map_path, "w") as fp: + json.dump([weights_name_mapping, weights_shape_mapping], fp) + + @staticmethod + def optimize(name, onnx_graph, verbose=False): + opt = Optimizer(onnx_graph, verbose=verbose) + opt.info(name + ": original") + opt.cleanup() + opt.info(name + ": cleanup") + opt.fold_constants() + opt.info(name + ": fold constants") + opt.infer_shapes() + opt.info(name + ": shape inference") + onnx_opt_graph = opt.cleanup(return_onnx=True) + opt.info(name + ": finished") + return onnx_opt_graph + + +class Optimizer: + def __init__(self, onnx_graph, verbose=False): + self.graph = gs.import_onnx(onnx_graph) + self.verbose = verbose + + def info(self, prefix): + if self.verbose: + print( + f"{prefix} .. {len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs" + ) + + def cleanup(self, return_onnx=False): + self.graph.cleanup().toposort() + if return_onnx: + return gs.export_onnx(self.graph) + + def select_outputs(self, keep, names=None): + self.graph.outputs = [self.graph.outputs[o] for o in keep] + if names: + for i, name in enumerate(names): + self.graph.outputs[i].name = name + + def fold_constants(self, return_onnx=False): + onnx_graph = fold_constants( + gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True + ) + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + def infer_shapes(self, return_onnx=False): + onnx_graph = gs.export_onnx(self.graph) + if onnx_graph.ByteSize() > 2147483648: + temp_dir = tempfile.TemporaryDirectory().name + os.makedirs(temp_dir, exist_ok=True) + onnx_orig_path = os.path.join(temp_dir, "model.onnx") + onnx_inferred_path = os.path.join(temp_dir, "inferred.onnx") + onnx.save_model( + onnx_graph, + onnx_orig_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + convert_attribute=False, + ) + onnx.shape_inference.infer_shapes_path(onnx_orig_path, onnx_inferred_path) + onnx_graph = onnx.load(onnx_inferred_path) + else: + onnx_graph = shape_inference.infer_shapes(onnx_graph) + + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + def clip_add_hidden_states(self, return_onnx=False): + hidden_layers = -1 + onnx_graph = gs.export_onnx(self.graph) + for i in range(len(onnx_graph.graph.node)): + for j in range(len(onnx_graph.graph.node[i].output)): + name = onnx_graph.graph.node[i].output[j] + if "layers" in name: + hidden_layers = max( + int(name.split(".")[1].split("/")[0]), hidden_layers + ) + for i in range(len(onnx_graph.graph.node)): + for j in range(len(onnx_graph.graph.node[i].output)): + if onnx_graph.graph.node[i].output[ + j + ] == "/text_model/encoder/layers.{}/Add_1_output_0".format( + hidden_layers - 1 + ): + onnx_graph.graph.node[i].output[j] = "hidden_states" + for j in range(len(onnx_graph.graph.node[i].input)): + if onnx_graph.graph.node[i].input[ + j + ] == "/text_model/encoder/layers.{}/Add_1_output_0".format( + hidden_layers - 1 + ): + onnx_graph.graph.node[i].input[j] = "hidden_states" + if return_onnx: + return onnx_graph diff --git a/model_manager.py b/model_manager.py index 2bbc2d8..fdb50b4 100644 --- a/model_manager.py +++ b/model_manager.py @@ -4,9 +4,10 @@ import os from logging import info, warning from dataclasses import dataclass +from datastructures import ModelType import torch -from exporter import get_cc from modules import paths_internal +import copy ONNX_MODEL_DIR = os.path.join(paths_internal.models_path, "Unet-onnx") if not os.path.exists(ONNX_MODEL_DIR): @@ -19,6 +20,13 @@ MODEL_FILE = os.path.join(TRT_MODEL_DIR, "model.json") + +def get_cc(): + cc_major = torch.cuda.get_device_properties(0).major + cc_minor = torch.cuda.get_device_properties(0).minor + return cc_major, cc_minor + + cc_major, cc_minor = get_cc() @@ -35,8 +43,8 @@ def __init__(self, model_file=MODEL_FILE) -> None: self.update() @staticmethod - def get_onnx_path(model_name, model_hash): - onnx_filename = "_".join([model_name, model_hash]) + ".onnx" + def get_onnx_path(model_name): + onnx_filename = f"{model_name}.onnx" onnx_path = os.path.join(ONNX_MODEL_DIR, onnx_filename) return onnx_filename, onnx_path @@ -57,6 +65,9 @@ def get_trt_path(self, model_name, model_hash, profile, static_shape): return trt_filename, trt_path + def get_weights_map_path(self, model_name: str): + return os.path.join(TRT_MODEL_DIR, f"{model_name}_weights_map.json") + def update(self): trt_engines = [ trt_file @@ -64,7 +75,7 @@ def update(self): if trt_file.endswith(".trt") ] - tmp_all_models = self.all_models.copy() + tmp_all_models = copy.deepcopy(self.all_models) for cc, base_models in tmp_all_models.items(): for base_model, models in base_models.items(): tmp_config_list = {} diff --git a/models.py b/models.py deleted file mode 100644 index 1355c2b..0000000 --- a/models.py +++ /dev/null @@ -1,1392 +0,0 @@ -# -# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import onnx -from onnx import shape_inference -import os -from polygraphy.backend.onnx.loader import fold_constants -import tempfile -import torch -import torch.nn.functional as F -import onnx_graphsurgeon as gs - - -class Optimizer: - def __init__(self, onnx_graph, verbose=False): - self.graph = gs.import_onnx(onnx_graph) - self.verbose = verbose - - def info(self, prefix): - if self.verbose: - print( - f"{prefix} .. {len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs" - ) - - def cleanup(self, return_onnx=False): - self.graph.cleanup().toposort() - if return_onnx: - return gs.export_onnx(self.graph) - - def select_outputs(self, keep, names=None): - self.graph.outputs = [self.graph.outputs[o] for o in keep] - if names: - for i, name in enumerate(names): - self.graph.outputs[i].name = name - - def fold_constants(self, return_onnx=False): - onnx_graph = fold_constants( - gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True - ) - self.graph = gs.import_onnx(onnx_graph) - if return_onnx: - return onnx_graph - - def infer_shapes(self, return_onnx=False): - onnx_graph = gs.export_onnx(self.graph) - if onnx_graph.ByteSize() > 2147483648: - temp_dir = tempfile.TemporaryDirectory().name - os.makedirs(temp_dir, exist_ok=True) - onnx_orig_path = os.path.join(temp_dir, "model.onnx") - onnx_inferred_path = os.path.join(temp_dir, "inferred.onnx") - onnx.save_model( - onnx_graph, - onnx_orig_path, - save_as_external_data=True, - all_tensors_to_one_file=True, - convert_attribute=False, - ) - onnx.shape_inference.infer_shapes_path(onnx_orig_path, onnx_inferred_path) - onnx_graph = onnx.load(onnx_inferred_path) - else: - onnx_graph = shape_inference.infer_shapes(onnx_graph) - - self.graph = gs.import_onnx(onnx_graph) - if return_onnx: - return onnx_graph - - def clip_add_hidden_states(self, return_onnx=False): - hidden_layers = -1 - onnx_graph = gs.export_onnx(self.graph) - for i in range(len(onnx_graph.graph.node)): - for j in range(len(onnx_graph.graph.node[i].output)): - name = onnx_graph.graph.node[i].output[j] - if "layers" in name: - hidden_layers = max( - int(name.split(".")[1].split("/")[0]), hidden_layers - ) - for i in range(len(onnx_graph.graph.node)): - for j in range(len(onnx_graph.graph.node[i].output)): - if onnx_graph.graph.node[i].output[ - j - ] == "/text_model/encoder/layers.{}/Add_1_output_0".format( - hidden_layers - 1 - ): - onnx_graph.graph.node[i].output[j] = "hidden_states" - for j in range(len(onnx_graph.graph.node[i].input)): - if onnx_graph.graph.node[i].input[ - j - ] == "/text_model/encoder/layers.{}/Add_1_output_0".format( - hidden_layers - 1 - ): - onnx_graph.graph.node[i].input[j] = "hidden_states" - if return_onnx: - return onnx_graph - - -def get_controlnets_path(controlnet_list): - """ - Currently ControlNet 1.0 is supported. - """ - if controlnet_list is None: - return None - return ["lllyasviel/sd-controlnet-" + controlnet for controlnet in controlnet_list] - - -def get_path(version, pipeline, controlnet=None): - if controlnet is not None: - return ["lllyasviel/sd-controlnet-" + modality for modality in controlnet] - - if version == "1.4": - if pipeline.is_inpaint(): - return "runwayml/stable-diffusion-inpainting" - else: - return "CompVis/stable-diffusion-v1-4" - elif version == "1.5": - if pipeline.is_inpaint(): - return "runwayml/stable-diffusion-inpainting" - else: - return "runwayml/stable-diffusion-v1-5" - elif version == "2.0-base": - if pipeline.is_inpaint(): - return "stabilityai/stable-diffusion-2-inpainting" - else: - return "stabilityai/stable-diffusion-2-base" - elif version == "2.0": - if pipeline.is_inpaint(): - return "stabilityai/stable-diffusion-2-inpainting" - else: - return "stabilityai/stable-diffusion-2" - elif version == "2.1": - return "stabilityai/stable-diffusion-2-1" - elif version == "2.1-base": - return "stabilityai/stable-diffusion-2-1-base" - elif version == "xl-1.0": - if pipeline.is_sd_xl_base(): - return "stabilityai/stable-diffusion-xl-base-1.0" - elif pipeline.is_sd_xl_refiner(): - return "stabilityai/stable-diffusion-xl-refiner-1.0" - else: - raise ValueError(f"Unsupported SDXL 1.0 pipeline {pipeline.name}") - else: - raise ValueError(f"Incorrect version {version}") - - -def get_clip_embedding_dim(version, pipeline): - if version in ("1.4", "1.5"): - return 768 - elif version in ("2.0", "2.0-base", "2.1", "2.1-base"): - return 1024 - elif version in ("xl-1.0") and pipeline.is_sd_xl_base(): - return 768 - else: - raise ValueError(f"Invalid version {version} + pipeline {pipeline}") - - -def get_clipwithproj_embedding_dim(version, pipeline): - if version in ("xl-1.0"): - return 1280 - else: - raise ValueError(f"Invalid version {version} + pipeline {pipeline}") - - -def get_unet_embedding_dim(version, pipeline): - if version in ("1.4", "1.5"): - return 768 - elif version in ("2.0", "2.0-base", "2.1", "2.1-base"): - return 1024 - elif version in ("xl-1.0") and pipeline.is_sd_xl_base(): - return 2048 - elif version in ("xl-1.0") and pipeline.is_sd_xl_refiner(): - return 1280 - else: - raise ValueError(f"Invalid version {version} + pipeline {pipeline}") - - -class BaseModel: - def __init__( - self, - version="1.5", - pipeline=None, - hf_token="", - device="cuda", - verbose=True, - fp16=False, - max_batch_size=16, - text_maxlen=77, - embedding_dim=768, - controlnet=None, - ): - self.name = self.__class__.__name__ - self.pipeline = pipeline.name - self.version = version - self.hf_token = hf_token - self.hf_safetensor = pipeline.is_sd_xl() - self.device = device - self.verbose = verbose - self.path = get_path(version, pipeline, controlnet) - - self.fp16 = fp16 - - self.min_batch = 1 - self.max_batch = max_batch_size - self.min_image_shape = 256 # min image resolution: 256x256 - self.max_image_shape = 768 if version in ("1.4", "1.5") else 1024 - self.min_latent_shape = self.min_image_shape // 8 - self.max_latent_shape = self.max_image_shape // 8 - - self.text_maxlen = text_maxlen - self.embedding_dim = embedding_dim - self.extra_output_names = [] - - def get_input_names(self): - pass - - def get_output_names(self): - pass - - def get_dynamic_axes(self): - return None - - def get_sample_input(self, batch_size, image_height, image_width): - pass - - def get_input_profile( - self, batch_size, image_height, image_width, static_batch, static_shape - ): - return None - - def get_shape_dict(self, batch_size, image_height, image_width): - return None - - def optimize(self, onnx_graph): - opt = Optimizer(onnx_graph, verbose=self.verbose) - opt.info(self.name + ": original") - opt.cleanup() - opt.info(self.name + ": cleanup") - opt.fold_constants() - opt.info(self.name + ": fold constants") - opt.infer_shapes() - opt.info(self.name + ": shape inference") - onnx_opt_graph = opt.cleanup(return_onnx=True) - opt.info(self.name + ": finished") - return onnx_opt_graph - - def check_dims(self, batch_size, image_height, image_width): - assert batch_size >= self.min_batch and batch_size <= self.max_batch - assert image_height % 8 == 0 or image_width % 8 == 0 - latent_height = image_height // 8 - latent_width = image_width // 8 - assert ( - latent_height >= self.min_latent_shape - and latent_height <= self.max_latent_shape - ) - assert ( - latent_width >= self.min_latent_shape - and latent_width <= self.max_latent_shape - ) - return (latent_height, latent_width) - - def get_minmax_dims( - self, batch_size, image_height, image_width, static_batch, static_shape - ): - min_batch = batch_size if static_batch else self.min_batch - max_batch = batch_size if static_batch else self.max_batch - latent_height = image_height // 8 - latent_width = image_width // 8 - min_image_height = image_height if static_shape else self.min_image_shape - max_image_height = image_height if static_shape else self.max_image_shape - min_image_width = image_width if static_shape else self.min_image_shape - max_image_width = image_width if static_shape else self.max_image_shape - min_latent_height = latent_height if static_shape else self.min_latent_shape - max_latent_height = latent_height if static_shape else self.max_latent_shape - min_latent_width = latent_width if static_shape else self.min_latent_shape - max_latent_width = latent_width if static_shape else self.max_latent_shape - return ( - min_batch, - max_batch, - min_image_height, - max_image_height, - min_image_width, - max_image_width, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) - - def get_latent_dim(self, min_h, opt_h, max_h, min_w, opt_w, max_w, static_shape): - if static_shape: - return ( - opt_h // 8, - opt_h // 8, - opt_h // 8, - opt_w // 8, - opt_w // 8, - opt_w // 8, - ) - return min_h // 8, opt_h // 8, max_h // 8, min_w // 8, opt_w // 8, max_w // 8 - - def get_batch_dim(self, min_batch, opt_batch, max_batch, static_batch): - if self.text_maxlen <= 77: - return (min_batch * 2, opt_batch * 2, max_batch * 2) - elif self.text_maxlen > 77 and static_batch: - return (opt_batch, opt_batch, opt_batch) - elif self.text_maxlen > 77 and not static_batch: - if self.text_optlen > 77: - return (min_batch, opt_batch, max_batch * 2) - return (min_batch, opt_batch * 2, max_batch * 2) - else: - raise Exception("Uncovered case in get_batch_dim") - - -class CLIP(BaseModel): - def __init__( - self, - version, - pipeline, - hf_token, - device, - verbose, - max_batch_size, - embedding_dim, - output_hidden_states=False, - subfolder="text_encoder", - ): - super(CLIP, self).__init__( - version, - pipeline, - hf_token, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - embedding_dim=embedding_dim, - ) - self.subfolder = subfolder - - # Output the final hidden state - if output_hidden_states: - self.extra_output_names = ["hidden_states"] - - def get_input_names(self): - return ["input_ids"] - - def get_output_names(self): - return ["text_embeddings"] - - def get_dynamic_axes(self): - return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} - - def get_input_profile( - self, batch_size, image_height, image_width, static_batch, static_shape - ): - self.check_dims(batch_size, image_height, image_width) - min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( - batch_size, image_height, image_width, static_batch, static_shape - ) - return { - "input_ids": [ - (min_batch, self.text_maxlen), - (batch_size, self.text_maxlen), - (max_batch, self.text_maxlen), - ] - } - - def get_shape_dict(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - output = { - "input_ids": (batch_size, self.text_maxlen), - "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), - } - if "hidden_states" in self.extra_output_names: - output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim) - return output - - def get_sample_input(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - return torch.zeros( - batch_size, self.text_maxlen, dtype=torch.int32, device=self.device - ) - - def optimize(self, onnx_graph): - opt = Optimizer(onnx_graph, verbose=self.verbose) - opt.info(self.name + ": original") - opt.select_outputs([0]) # delete graph output#1 - opt.cleanup() - opt.info(self.name + ": remove output[1]") - opt.fold_constants() - opt.info(self.name + ": fold constants") - opt.infer_shapes() - opt.info(self.name + ": shape inference") - opt.select_outputs([0], names=["text_embeddings"]) # rename network output - opt.info(self.name + ": remove output[0]") - opt_onnx_graph = opt.cleanup(return_onnx=True) - if "hidden_states" in self.extra_output_names: - opt_onnx_graph = opt.clip_add_hidden_states(return_onnx=True) - opt.info(self.name + ": added hidden_states") - opt.info(self.name + ": finished") - return opt_onnx_graph - - -def make_CLIP( - version, - pipeline, - hf_token, - device, - verbose, - max_batch_size, - output_hidden_states=False, - subfolder="text_encoder", -): - return CLIP( - version, - pipeline, - hf_token, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - embedding_dim=get_clip_embedding_dim(version, pipeline), - output_hidden_states=output_hidden_states, - subfolder=subfolder, - ) - - -class CLIPWithProj(CLIP): - def __init__( - self, - version, - pipeline, - hf_token, - device="cuda", - verbose=True, - max_batch_size=16, - output_hidden_states=False, - subfolder="text_encoder_2", - ): - super(CLIPWithProj, self).__init__( - version, - pipeline, - hf_token, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - embedding_dim=get_clipwithproj_embedding_dim(version, pipeline), - output_hidden_states=output_hidden_states, - ) - self.subfolder = subfolder - - def get_shape_dict(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - output = { - "input_ids": (batch_size, self.text_maxlen), - "text_embeddings": (batch_size, self.embedding_dim), - } - if "hidden_states" in self.extra_output_names: - output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim) - - return output - - -def make_CLIPWithProj( - version, - pipeline, - hf_token, - device, - verbose, - max_batch_size, - subfolder="text_encoder_2", - output_hidden_states=False, -): - return CLIPWithProj( - version, - pipeline, - hf_token, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - subfolder=subfolder, - output_hidden_states=output_hidden_states, - ) - - -class UNet2DConditionControlNetModel(torch.nn.Module): - def __init__(self, unet, controlnets) -> None: - super().__init__() - self.unet = unet - self.controlnets = controlnets - - def forward( - self, sample, timestep, encoder_hidden_states, images, controlnet_scales - ): - for i, (image, conditioning_scale, controlnet) in enumerate( - zip(images, controlnet_scales, self.controlnets) - ): - down_samples, mid_sample = controlnet( - sample, - timestep, - encoder_hidden_states=encoder_hidden_states, - controlnet_cond=image, - return_dict=False, - ) - - down_samples = [ - down_sample * conditioning_scale for down_sample in down_samples - ] - mid_sample *= conditioning_scale - - # merge samples - if i == 0: - down_block_res_samples, mid_block_res_sample = down_samples, mid_sample - else: - down_block_res_samples = [ - samples_prev + samples_curr - for samples_prev, samples_curr in zip( - down_block_res_samples, down_samples - ) - ] - mid_block_res_sample += mid_sample - - noise_pred = self.unet( - sample, - timestep, - encoder_hidden_states=encoder_hidden_states, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - ) - return noise_pred - - -class UNet(BaseModel): - def __init__( - self, - version, - pipeline, - hf_token, - device="cuda", - verbose=True, - fp16=False, - max_batch_size=16, - text_maxlen=77, - unet_dim=4, - controlnet=None, - ): - super(UNet, self).__init__( - version, - pipeline, - hf_token, - fp16=fp16, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - text_maxlen=text_maxlen, - embedding_dim=get_unet_embedding_dim(version, pipeline), - ) - self.unet_dim = unet_dim - self.controlnet = controlnet - - # def get_model(self, framework_model_dir): - # model_opts = {'variant': 'fp16', 'torch_dtype': torch.float16} if self.fp16 else {} - # if self.controlnet: - # unet_model = UNet2DConditionModel.from_pretrained(self.path, - # subfolder="unet", - # use_safetensors=self.hf_safetensor, - # use_auth_token=self.hf_token, - # **model_opts).to(self.device) - - # cnet_model_opts = {'torch_dtype': torch.float16} if self.fp16 else {} - # controlnets = torch.nn.ModuleList([ControlNetModel.from_pretrained(path, **cnet_model_opts).to(self.device) for path in self.controlnet]) - # # FIXME - cache UNet2DConditionControlNetModel - # model = UNet2DConditionControlNetModel(unet_model, controlnets) - # else: - # unet_model_dir = os.path.join(framework_model_dir, self.version, self.pipeline, "unet") - # if not os.path.exists(unet_model_dir): - # model = UNet2DConditionModel.from_pretrained(self.path, - # subfolder="unet", - # use_safetensors=self.hf_safetensor, - # use_auth_token=self.hf_token, - # **model_opts).to(self.device) - # model.save_pretrained(unet_model_dir) - # else: - # print(f"[I] Load UNet pytorch model from: {unet_model_dir}") - # model = UNet2DConditionModel.from_pretrained(unet_model_dir).to(self.device) - # return model - - def get_input_names(self): - if self.controlnet is None: - return ["sample", "timestep", "encoder_hidden_states"] - else: - return [ - "sample", - "timestep", - "encoder_hidden_states", - "images", - "controlnet_scales", - ] - - def get_output_names(self): - return ["latent"] - - def get_dynamic_axes(self): - if self.controlnet is None: - return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "encoder_hidden_states": {0: "2B"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - } - else: - return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "encoder_hidden_states": {0: "2B"}, - "images": {1: "2B", 3: "8H", 4: "8W"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - } - - def get_input_profile( - self, batch_size, image_height, image_width, static_batch, static_shape - ): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - ( - min_batch, - max_batch, - min_image_height, - max_image_height, - min_image_width, - max_image_width, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) = self.get_minmax_dims( - batch_size, image_height, image_width, static_batch, static_shape - ) - if self.controlnet is None: - return { - "sample": [ - (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), - (2 * batch_size, self.unet_dim, latent_height, latent_width), - (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), - ], - "encoder_hidden_states": [ - (2 * min_batch, self.text_maxlen, self.embedding_dim), - (2 * batch_size, self.text_maxlen, self.embedding_dim), - (2 * max_batch, self.text_maxlen, self.embedding_dim), - ], - } - else: - return { - "sample": [ - (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), - (2 * batch_size, self.unet_dim, latent_height, latent_width), - (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), - ], - "encoder_hidden_states": [ - (2 * min_batch, self.text_maxlen, self.embedding_dim), - (2 * batch_size, self.text_maxlen, self.embedding_dim), - (2 * max_batch, self.text_maxlen, self.embedding_dim), - ], - "images": [ - ( - len(self.controlnet), - 2 * min_batch, - 3, - min_image_height, - min_image_width, - ), - ( - len(self.controlnet), - 2 * batch_size, - 3, - image_height, - image_width, - ), - ( - len(self.controlnet), - 2 * max_batch, - 3, - max_image_height, - max_image_width, - ), - ], - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - if self.controlnet is None: - return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), - "encoder_hidden_states": ( - 2 * batch_size, - self.text_maxlen, - self.embedding_dim, - ), - "latent": (2 * batch_size, 4, latent_height, latent_width), - } - else: - return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), - "encoder_hidden_states": ( - 2 * batch_size, - self.text_maxlen, - self.embedding_dim, - ), - "images": ( - len(self.controlnet), - 2 * batch_size, - 3, - image_height, - image_width, - ), - "latent": (2 * batch_size, 4, latent_height, latent_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - dtype = torch.float16 if self.fp16 else torch.float32 - if self.controlnet is None: - return ( - torch.randn( - batch_size, - self.unet_dim, - latent_height, - latent_width, - dtype=torch.float32, - device=self.device, - ), - torch.tensor([1.0], dtype=torch.float32, device=self.device), - torch.randn( - batch_size, - self.text_maxlen, - self.embedding_dim, - dtype=dtype, - device=self.device, - ), - ) - else: - return ( - torch.randn( - batch_size, - self.unet_dim, - latent_height, - latent_width, - dtype=torch.float32, - device=self.device, - ), - torch.tensor(999, dtype=torch.float32, device=self.device), - torch.randn( - batch_size, - self.text_maxlen, - self.embedding_dim, - dtype=dtype, - device=self.device, - ), - torch.randn( - len(self.controlnet), - batch_size, - 3, - image_height, - image_width, - dtype=dtype, - device=self.device, - ), - torch.randn(len(self.controlnet), dtype=dtype, device=self.device), - ) - - -def make_UNet( - version, pipeline, hf_token, device, verbose, max_batch_size, controlnet=None -): - # Disable torch SDPA - # if hasattr(F, "scaled_dot_product_attention"): - # delattr(F, "scaled_dot_product_attention") - - return UNet( - version, - pipeline, - hf_token, - fp16=True, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - unet_dim=(9 if pipeline.is_inpaint() else 4), - controlnet=get_controlnets_path(controlnet), - ) - - -class OAIUNet(BaseModel): - def __init__( - self, - version, - pipeline, - device="cuda", - verbose=True, - fp16=False, - max_batch_size=16, - text_maxlen=77, - text_optlen=77, - unet_dim=4, - controlnet=None, - ): - super(OAIUNet, self).__init__( - version, - pipeline, - "", - fp16=fp16, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - text_maxlen=text_maxlen, - embedding_dim=get_unet_embedding_dim(version, pipeline), - ) - self.unet_dim = unet_dim - self.controlnet = controlnet - self.text_optlen = text_optlen - - def get_input_names(self): - if self.controlnet is None: - return ["sample", "timesteps", "encoder_hidden_states"] - else: - return [ - "sample", - "timesteps", - "encoder_hidden_states", - "images", - "controlnet_scales", - ] - - def get_output_names(self): - return ["latent"] - - def get_dynamic_axes(self): - if self.controlnet is None: - return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "timesteps": {0: "2B"}, - "encoder_hidden_states": {0: "2B", 1: "77N"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - } - else: - return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "timesteps": {0: "2B"}, - "encoder_hidden_states": {0: "2B", 1: "77N"}, - "images": {1: "2B", 3: "8H", 4: "8W"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - } - - def get_input_profile( - self, - min_batch, - opt_batch, - max_batch, - min_h, - opt_h, - max_h, - min_w, - opt_w, - max_w, - static_shape, - ): - min_batch, opt_batch, max_batch = self.get_batch_dim( - min_batch, opt_batch, max_batch, static_shape - ) - ( - min_latent_height, - latent_height, - max_latent_height, - min_latent_width, - latent_width, - max_latent_width, - ) = self.get_latent_dim(min_h, opt_h, max_h, min_w, opt_w, max_w, static_shape) - - if self.controlnet is None: - return { - "sample": [ - (min_batch, self.unet_dim, min_latent_height, min_latent_width), - (opt_batch, self.unet_dim, latent_height, latent_width), - (max_batch, self.unet_dim, max_latent_height, max_latent_width), - ], - "timesteps": [(min_batch,), (opt_batch,), (max_batch,)], - "encoder_hidden_states": [ - (min_batch, self.text_optlen, self.embedding_dim), - (opt_batch, self.text_optlen, self.embedding_dim), - (max_batch, self.text_maxlen, self.embedding_dim), - ], - } - else: - return { - "sample": [ - (min_batch, self.unet_dim, min_latent_height, min_latent_width), - (opt_batch, self.unet_dim, latent_height, latent_width), - (max_batch, self.unet_dim, max_latent_height, max_latent_width), - ], - "timesteps": [(min_batch,), (opt_batch,), (max_batch,)], - "encoder_hidden_states": [ - (min_batch, self.text_optlen, self.embedding_dim), - (opt_batch, self.text_optlen, self.embedding_dim), - (max_batch, self.text_maxlen, self.embedding_dim), - ], - "images": [ - ( - len(self.controlnet), - min_batch, - 3, - min_h, - min_w, - ), - ( - len(self.controlnet), - opt_batch, - 3, - opt_h, - opt_w, - ), - ( - len(self.controlnet), - max_batch, - 3, - max_h, - max_w, - ), - ], - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - if self.controlnet is None: - return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), - "timesteps": (2 * batch_size,), - "encoder_hidden_states": ( - 2 * batch_size, - self.text_optlen, - self.embedding_dim, - ), - "latent": (2 * batch_size, 4, latent_height, latent_width), - } - else: - return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), - "timesteps": (2 * batch_size,), - "encoder_hidden_states": ( - 2 * batch_size, - self.text_optlen, - self.embedding_dim, - ), - "images": ( - len(self.controlnet), - 2 * batch_size, - 3, - image_height, - image_width, - ), - "latent": (2 * batch_size, 4, latent_height, latent_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - dtype = torch.float16 if self.fp16 else torch.float32 - if self.controlnet is None: - return ( - torch.randn( - 2 * batch_size, - self.unet_dim, - latent_height, - latent_width, - dtype=torch.float32, - device=self.device, - ), - torch.ones((2 * batch_size,), dtype=torch.float32, device=self.device), - torch.randn( - 2 * batch_size, - self.text_optlen, - self.embedding_dim, - dtype=dtype, - device=self.device, - ), - ) - else: - return ( - torch.randn( - batch_size, - self.unet_dim, - latent_height, - latent_width, - dtype=torch.float32, - device=self.device, - ), - torch.ones((batch_size,), dtype=torch.float32, device=self.device), - torch.randn( - batch_size, - self.text_optlen, - self.embedding_dim, - dtype=dtype, - device=self.device, - ), - torch.randn( - len(self.controlnet), - batch_size, - 3, - image_height, - image_width, - dtype=dtype, - device=self.device, - ), - torch.randn(len(self.controlnet), dtype=dtype, device=self.device), - ) - - -def make_OAIUNet( - version, - pipeline, - device, - verbose, - max_batch_size, - text_optlen, - text_maxlen, - controlnet=None, -): - return OAIUNet( - version, - pipeline, - fp16=True, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - text_optlen=text_optlen, - text_maxlen=text_maxlen, - unet_dim=(9 if pipeline.is_inpaint() else 4), - controlnet=get_controlnets_path(controlnet), - ) - - -class OAIUNetXL(BaseModel): - def __init__( - self, - version, - pipeline, - fp16=False, - device="cuda", - verbose=True, - max_batch_size=16, - text_maxlen=77, - text_optlen=77, - unet_dim=4, - time_dim=6, - num_classes=2816, - ): - super(OAIUNetXL, self).__init__( - version, - pipeline, - "", - fp16=fp16, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - text_maxlen=text_maxlen, - embedding_dim=get_unet_embedding_dim(version, pipeline), - ) - self.unet_dim = unet_dim - self.time_dim = time_dim - self.num_classes = num_classes - self.text_optlen = text_optlen - - def get_input_names(self): - return ["sample", "timesteps", "encoder_hidden_states", "y"] - - def get_output_names(self): - return ["latent"] - - def get_dynamic_axes(self): - return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "encoder_hidden_states": {0: "2B", 1: "77N"}, - "timesteps": {0: "2B"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - "y": {0: "2B", 1: "num_classes"}, - } - - def get_input_profile( - self, - min_batch, - opt_batch, - max_batch, - min_h, - opt_h, - max_h, - min_w, - opt_w, - max_w, - static_shape, - ): - min_batch, opt_batch, max_batch = self.get_batch_dim( - min_batch, opt_batch, max_batch, static_shape - ) - - ( - min_latent_height, - latent_height, - max_latent_height, - min_latent_width, - latent_width, - max_latent_width, - ) = self.get_latent_dim(min_h, opt_h, max_h, min_w, opt_w, max_w, static_shape) - - return { - "sample": [ - (min_batch, self.unet_dim, min_latent_height, min_latent_width), - (opt_batch, self.unet_dim, latent_height, latent_width), - (max_batch, self.unet_dim, max_latent_height, max_latent_width), - ], - "timesteps": [(min_batch,), (opt_batch,), (max_batch,)], - "encoder_hidden_states": [ - (min_batch, self.text_optlen, self.embedding_dim), - (opt_batch, self.text_optlen, self.embedding_dim), - (max_batch, self.text_maxlen, self.embedding_dim), - ], - "y": [ - (min_batch, self.num_classes), - (opt_batch, self.num_classes), - (max_batch, self.num_classes), - ], - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), - "timesteps": (2 * batch_size,), - "encoder_hidden_states": ( - 2 * batch_size, - self.text_optlen, - self.embedding_dim, - ), - "y": (2 * batch_size, self.num_classes), - "latent": (2 * batch_size, 4, latent_height, latent_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - dtype = torch.float16 if self.fp16 else torch.float32 - return ( - torch.randn( - 2 * batch_size, - self.unet_dim, - latent_height, - latent_width, - dtype=torch.float32, - device=self.device, - ), - torch.ones((2 * batch_size,), dtype=torch.float32, device=self.device), - torch.randn( - 2 * batch_size, - self.text_optlen, - self.embedding_dim, - dtype=dtype, - device=self.device, - ), - torch.randn( - 2 * batch_size, self.num_classes, dtype=dtype, device=self.device - ), - ) - - -def make_OAIUNetXL( - version, pipeline, device, verbose, max_batch_size, text_optlen, text_maxlen -): - # Disable torch SDPA - # if hasattr(F, "scaled_dot_product_attention"): - # delattr(F, "scaled_dot_product_attention") - return OAIUNetXL( - version, - pipeline, - fp16=True, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - unet_dim=(9 if pipeline.is_inpaint() else 4), - text_optlen=text_optlen, - text_maxlen=text_maxlen, - ) - - -class VAE(BaseModel): - def __init__( - self, - version, - pipeline, - hf_token, - device, - verbose, - max_batch_size, - ): - super(VAE, self).__init__( - version, - pipeline, - hf_token, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - ) - - def get_input_names(self): - return ["latent"] - - def get_output_names(self): - return ["images"] - - def get_dynamic_axes(self): - return { - "latent": {0: "B", 2: "H", 3: "W"}, - "images": {0: "B", 2: "8H", 3: "8W"}, - } - - def get_input_profile( - self, batch_size, image_height, image_width, static_batch, static_shape - ): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - ( - min_batch, - max_batch, - _, - _, - _, - _, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) = self.get_minmax_dims( - batch_size, image_height, image_width, static_batch, static_shape - ) - return { - "latent": [ - (min_batch, 4, min_latent_height, min_latent_width), - (batch_size, 4, latent_height, latent_width), - (max_batch, 4, max_latent_height, max_latent_width), - ] - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - return { - "latent": (batch_size, 4, latent_height, latent_width), - "images": (batch_size, 3, image_height, image_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - return torch.randn( - batch_size, - 4, - latent_height, - latent_width, - dtype=torch.float32, - device=self.device, - ) - - -def make_VAE(version, pipeline, hf_token, device, verbose, max_batch_size): - return VAE( - version, - pipeline, - hf_token, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - ) - - -class VAEEncoder(BaseModel): - def __init__( - self, - version, - pipeline, - hf_token, - device, - verbose, - max_batch_size, - ): - super(VAEEncoder, self).__init__( - version, - pipeline, - hf_token, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - ) - - def get_input_names(self): - return ["images"] - - def get_output_names(self): - return ["latent"] - - def get_dynamic_axes(self): - return { - "images": {0: "B", 2: "8H", 3: "8W"}, - "latent": {0: "B", 2: "H", 3: "W"}, - } - - def get_input_profile( - self, batch_size, image_height, image_width, static_batch, static_shape - ): - assert batch_size >= self.min_batch and batch_size <= self.max_batch - min_batch = batch_size if static_batch else self.min_batch - max_batch = batch_size if static_batch else self.max_batch - self.check_dims(batch_size, image_height, image_width) - ( - min_batch, - max_batch, - min_image_height, - max_image_height, - min_image_width, - max_image_width, - _, - _, - _, - _, - ) = self.get_minmax_dims( - batch_size, image_height, image_width, static_batch, static_shape - ) - - return { - "images": [ - (min_batch, 3, min_image_height, min_image_width), - (batch_size, 3, image_height, image_width), - (max_batch, 3, max_image_height, max_image_width), - ], - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims( - batch_size, image_height, image_width - ) - return { - "images": (batch_size, 3, image_height, image_width), - "latent": (batch_size, 4, latent_height, latent_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - return torch.randn( - batch_size, - 3, - image_height, - image_width, - dtype=torch.float32, - device=self.device, - ) - - -def make_VAEEncoder(version, pipeline, hf_token, device, verbose, max_batch_size): - return VAEEncoder( - version, - pipeline, - hf_token, - device=device, - verbose=verbose, - max_batch_size=max_batch_size, - ) diff --git a/scripts/lora.py b/scripts/lora.py new file mode 100644 index 0000000..102ba41 --- /dev/null +++ b/scripts/lora.py @@ -0,0 +1,45 @@ +import os +from typing import List +import numpy as np +from safetensors.torch import load_file +import onnx_graphsurgeon as gs +import onnx +import torch +from onnx import numpy_helper + + +def merge_loras(loras: List[str], scales: List[str]): + refit_dict = {} + for lora, scale in zip(loras, scales): + lora_dict = load_file(lora) + for k, v in lora_dict.items(): + if k in refit_dict: + refit_dict[k] += scale * v + else: + refit_dict[k] = scale * v + return refit_dict + + +def apply_loras(base_path: str, loras: List[str], scales: List[str]) -> dict: + refit_dict = merge_loras(loras, scales) + base = onnx.load(base_path) + onnx_opt_dir = os.path.dirname(base_path) + + def convert_int64(arr): + if len(arr.shape) == 0: + return np.array([np.int32(arr)]) + return arr + + for initializer in base.graph.initializer: + if initializer.name not in refit_dict: + continue + + wt = refit_dict[initializer.name] + initializer_data = numpy_helper.to_array( + initializer, base_dir=onnx_opt_dir + ).astype(np.float16) + delta = torch.tensor(initializer_data).to(wt.device) + wt + + refit_dict[initializer.name] = delta.contiguous() + + return refit_dict diff --git a/scripts/trt.py b/scripts/trt.py index 891641c..a62acbb 100644 --- a/scripts/trt.py +++ b/scripts/trt.py @@ -11,6 +11,9 @@ from model_manager import TRT_MODEL_DIR, modelmanager from polygraphy.logger import G_LOGGER import gradio as gr +import re +from datastructures import UNetEngineArgs, ModelType +from scripts.lora import apply_loras G_LOGGER.module_severity = G_LOGGER.ERROR @@ -22,38 +25,29 @@ def __init__(self, name: str, filename: List[dict]): self.configs = filename def create_unet(self): - lora_path = None - if self.configs[0]["config"].lora: - lora_path = os.path.join(TRT_MODEL_DIR, self.configs[0]["filepath"]) - self.model_name = self.configs[0]["base_model"] - self.configs = modelmanager.available_models()[self.model_name] - return TrtUnet(self.model_name, self.configs, lora_path) + return TrtUnet(self.model_name, self.configs) -# This is ugly. Is there a better way to parse this as kwargs to the SD Unet? -GLOBAL_KWARGS = {"profile_idx": None, "profile_hr_idx": None, "model_name": ""} +GLOBAL_ARGS = UNetEngineArgs(0, 0, None, {}) class TrtUnet(sd_unet.SdUnet): - def __init__( - self, model_name: str, configs: List[dict], lora_path, *args, **kwargs - ): + def __init__(self, model_name: str, configs: List[dict], *args, **kwargs): super().__init__(*args, **kwargs) - if not model_name == GLOBAL_KWARGS["model_name"]: - raise ValueError( - """Selected torch model ({}) does not match the selected TensorRT U-Net ({}). - Please ensure that both models are the same or select Automatic from the SD UNet dropdown.""".format( - GLOBAL_KWARGS["model_name"], model_name - ) - ) - self.configs = configs + self.stream = None self.model_name = model_name - self.lora_path = lora_path - self.engine_vram_req = 0 + self.configs = configs - self.profile_idx = GLOBAL_KWARGS["profile_idx"] + self.profile_idx = GLOBAL_ARGS.idx + if self.profile_idx is None: + self.profile_idx = 0 self.loaded_config = self.configs[self.profile_idx] + + self.engine_vram_req = 0 + self.shape_hash = 0 + self.refitted_keys = set() + self.engine = Engine( os.path.join(TRT_MODEL_DIR, self.loaded_config["filepath"]) ) @@ -68,7 +62,7 @@ def forward(self, x, timesteps, context, *args, **kwargs): if "y" in kwargs: feed_dict["y"] = kwargs["y"].float() - if not self.profile_idx == GLOBAL_KWARGS["profile_idx"]: + if not self.profile_idx == GLOBAL_ARGS.idx: self.switch_engine() tmp = torch.empty( @@ -83,25 +77,39 @@ def forward(self, x, timesteps, context, *args, **kwargs): nvtx.range_pop() return out + def apply_loras(self): + if GLOBAL_ARGS.lora is None: + refit_dict = {} + else: + refit_dict = GLOBAL_ARGS.lora + if not self.refitted_keys.issubset(set(refit_dict.keys())): + # Need to ensure that weights that have been modified before and are not present anymore are reset. + self.refitted_keys = set() + self.switch_engine() + + self.engine.refit_from_dict(refit_dict, is_fp16=True) + self.refitted_keys = set(refit_dict.keys()) + + def set_idx(self): + if GLOBAL_ARGS.idx is None: + raise Exception("No valid profile found. Please generate a profile first.") + self.profile_idx = GLOBAL_ARGS.idx + def switch_engine(self): - self.profile_idx = GLOBAL_KWARGS["profile_idx"] + self.set_idx() self.loaded_config = self.configs[self.profile_idx] - self.deactivate() - self.engine = Engine( - os.path.join(TRT_MODEL_DIR, self.loaded_config["filepath"]) - ) + self.engine.reset(os.path.join(TRT_MODEL_DIR, self.loaded_config["filepath"])) self.activate() + self.shape_hash = 0 def activate(self): + self.shape_hash = 0 self.engine.load() print(f"\nLoaded Profile: {self.profile_idx}") print(self.engine) self.engine_vram_req = self.engine.engine.device_memory_size self.engine.activate(True) - if self.lora_path is not None: - self.engine.refit_from_dump(self.lora_path) - def deactivate(self): self.shape_hash = 0 del self.engine @@ -110,7 +118,8 @@ def deactivate(self): class TensorRTScript(scripts.Script): def __init__(self) -> None: self.loaded_model = None - pass + self.lora_hash = "" + self.update_lora = False def title(self): return "TensorRT" @@ -124,78 +133,150 @@ def setup(self, p, *args): def before_process(self, p, *args): # 1 # Check divisibilty if p.width % 64 or p.height % 64: - raise ValueError( - "Target resolution must be divisible by 64 in both dimensions." - ) + gr.Error("Target resolution must be divisible by 64 in both dimensions.") if p.enable_hr: hr_w = int(p.width * p.hr_scale) hr_h = int(p.height * p.hr_scale) if hr_w % 64 or hr_h % 64: - raise ValueError( + gr.Error( "HIRES Fix resolution must be divisible by 64 in both dimensions. Please change the upscale factor or disable HIRES Fix." ) - # lora p.prompt == ' - - def process(self, p, *args): # 2 - # before unet_init + def get_profile_idx(self, p, model_name, model_type): + best_hr = None hr_scale = p.hr_scale if p.enable_hr else 1 ( valid_models, distances, idx, ) = modelmanager.get_valid_models( - p.sd_model_name, p.width, p.height, p.batch_size, 77 - ) # TODO: max_embedding + model_name, + p.width, + p.height, + p.batch_size, + 77, # model_type + ) # TODO: max_embedding, just ignore? if len(valid_models) == 0: - raise ValueError( - """No valid profile found for LOWRES. Please go to the TensorRT tab and generate an engine with the necessary profile. + gr.Error( + f"""No valid profile found for ({model_name}) LOWRES. Please go to the TensorRT tab and generate an engine with the necessary profile. If using hires.fix, you need an engine for both the base and upscaled resolutions. Otherwise, use the default (torch) U-Net.""" ) + return None, None best = idx[np.argmin(distances)] if hr_scale != 1: hr_w = int(p.width * p.hr_scale) hr_h = int(p.height * p.hr_scale) valid_models_hr, distances_hr, idx_hr = modelmanager.get_valid_models( - p.sd_model_name, hr_w, hr_h, p.batch_size, 77 + model_name, + hr_w, + hr_h, + p.batch_size, + 77, # model_type ) # TODO: max_embedding - if len(valid_models) == 0: - raise ValueError( - "No valid profile found for HIRES. Please go to the TensorRT tab and generate an engine with the necessary profile. If using hires.fix, you need an engine for both the base and upscaled resolutions. Otherwise, use the default (torch) U-Net." + if len(valid_models_hr) == 0: + gr.Error( + f"""No valid profile found for ({model_name}) HIRES. Please go to the TensorRT tab and generate an engine with the necessary profile. + If using hires.fix, you need an engine for both the base and upscaled resolutions. Otherwise, use the default (torch) U-Net.""" ) merged_idx = [i for i, id in enumerate(idx) if id in idx_hr] if len(merged_idx) == 0: gr.Warning( - "No model available for both LOWRES ({}x{}) and HIRES ({}x{}). This will slow-down inference.".format( - p.width, p.height, hr_w, hr_h + "No model available for both ({}) LOWRES ({}x{}) and HIRES ({}x{}). This will slow-down inference.".format( + model_name, p.width, p.height, hr_w, hr_h ) ) - best_hr = idx_hr[np.argmin(distances_hr)] + return None, None else: _distances = [distances[i] for i in merged_idx] best_hr = idx_hr[merged_idx[np.argmin(_distances)]] best = best_hr - GLOBAL_KWARGS["profile_hr_idx"] = best_hr - GLOBAL_KWARGS["profile_idx"] = best - GLOBAL_KWARGS["model_name"] = p.sd_model_name + + return best, best_hr + + def get_loras(self, p): + lora_pathes = [] + lora_scales = [] + + # get lora from prompt + _prompt = p.prompt + extra_networks = re.findall("\<(.*?)\>", _prompt) + loras = [net for net in extra_networks if net.startswith("lora")] + + # Avoid that extra networks will be loaded + for lora in loras: + _prompt = _prompt.replace(f"<{lora}>", "") + p.prompt = _prompt + + # check if lora config has changes + if self.lora_hash != "".join(loras): + self.lora_hash = "".join(loras) + self.update_lora = True + if self.lora_hash == "": + GLOBAL_ARGS.lora = None + return + else: + return + + # Get pathes + print("Apllying LoRAs: " + str(loras)) + available = modelmanager.available_models() + for lora in loras: + lora_name, lora_scale = lora.split(":")[1:] + lora_scales.append(float(lora_scale)) + if lora_name not in available: + raise Exception( + f"Please export the LoRA checkpoint {lora_name} first from the TensorRT LoRA tab" + ) + lora_pathes.append( + os.path.join(TRT_MODEL_DIR, available[lora_name][0]["filepath"]) + ) + + # Merge lora refit dicts + base_name, base_path = modelmanager.get_onnx_path(p.sd_model_name) + refit_dict = apply_loras(base_path, lora_pathes, lora_scales) + + GLOBAL_ARGS.lora = refit_dict + + def process(self, p, *args): + # before unet_init + sd_unet_option = sd_unet.get_unet_option() + if sd_unet_option is None: + return + + if not sd_unet_option.model_name == p.sd_model_name: + gr.Error( + """Selected torch model ({}) does not match the selected TensorRT U-Net ({}). + Please ensure that both models are the same or select Automatic from the SD UNet dropdown.""".format( + p.sd_model_name, sd_unet_option.model_name + ) + ) + GLOBAL_ARGS.idx, GLOBAL_ARGS.hr_idx = self.get_profile_idx( + p, p.sd_model_name, ModelType.UNET + ) + + try: + self.get_loras(p) + except Exception as e: + gr.Error(e) + raise e def process_batch(self, p, *args, **kwargs): + # Called for each batch count return super().process_batch(p, *args, **kwargs) def before_hr(self, p, *args): - GLOBAL_KWARGS["profile_idx"] = GLOBAL_KWARGS["profile_hr_idx"] + GLOBAL_ARGS.idx = GLOBAL_ARGS.hr_idx + return super().before_hr(p, *args) # 4 (Only when HR starts.....) def after_extra_networks_activate(self, p, *args, **kwargs): - # if self.lora_path is not None: - # self.engine.refit_from_dump(self.lora_path) - - # Called after UNet activate - # p.extra_network_data - # Contains dict of modules.extra_networks.ExtraNetworkParams - return super().after_extra_networks_activate(p, *args, **kwargs) # 3 + if self.update_lora: + self.update_lora = False + # Not the fastest, but safest option. Larger bottlenecks to solve first! + # Other two options: Overengingeer, Refit whole model + sd_unet.current_unet.apply_loras() def list_unets(l): diff --git a/ui_trt.py b/ui_trt.py index bbfc249..0a35bb4 100644 --- a/ui_trt.py +++ b/ui_trt.py @@ -1,34 +1,49 @@ import os -from modules import sd_models, shared import gradio as gr - from modules.call_queue import wrap_gradio_gpu_call from modules.shared import cmd_opts from modules.ui_components import FormRow +from modules import sd_hijack, sd_unet, sd_models, shared +from modules.ui_common import refresh_symbol +from modules.ui_components import ToolButton + +from exporter import export_onnx, export_trt, export_lora +from utilities import Engine +from safetensors.torch import save_file + -from exporter import export_onnx, export_trt -from utilities import PIPELINE_TYPE, Engine -from models import make_OAIUNetXL, make_OAIUNet import logging import gc import torch -from model_manager import modelmanager, cc_major, TRT_MODEL_DIR -from time import sleep from collections import defaultdict -from modules.ui_common import refresh_symbol -from modules.ui_components import ToolButton +import json + +from model_helper import UNetModel +from model_manager import modelmanager, cc_major, TRT_MODEL_DIR +from datastructures import SDVersion, ProfilePrests, ProfileSettings + +profile_presets = ProfilePrests() logging.basicConfig(level=logging.INFO) -def get_version_from_model(sd_model): - if sd_model.is_sd1: - return "1.5" - if sd_model.is_sd2: - return "2.1" - if sd_model.is_sdxl: - return "xl-1.0" +# TODO get info from model config +def get_context_dim(): + if shared.sd_model.is_sd1: + return 768 + elif shared.sd_model.is_sd2: + return 1024 + elif shared.sd_model.is_sdxl: + return 2048 + + +def is_fp32(): + use_fp32 = False + if cc_major < 7: + use_fp32 = True + print("FP16 has been disabled because your GPU does not support it.") + return use_fp32 def export_unet_to_trt( @@ -47,75 +62,21 @@ def export_unet_to_trt( force_export, static_shapes, preset, - controlnet=None, ): + def disable_checkpoint(self): + if getattr(self, "use_checkpoint", False) == True: + self.use_checkpoint = False + if getattr(self, "checkpoint", False) == True: + self.checkpoint = False - if preset == "Default": - ( - batch_min, - batch_opt, - batch_max, - height_min, - height_opt, - height_max, - width_min, - width_opt, - width_max, - token_count_min, - token_count_opt, - token_count_max, - ) = export_default_unet_to_trt() - is_inpaint = False - use_fp32 = False - if cc_major < 7: - use_fp32 = True - print("FP16 has been disabled because your GPU does not support it.") - - unet_hidden_dim = shared.sd_model.model.diffusion_model.in_channels - if unet_hidden_dim == 9: - is_inpaint = True + shared.sd_model.model.diffusion_model.apply(disable_checkpoint) + sd_unet.apply_unet("None") + sd_hijack.model_hijack.apply_optimizations("None") - model_hash = shared.sd_model.sd_checkpoint_info.hash + is_xl = shared.sd_model.is_sdxl model_name = shared.sd_model.sd_checkpoint_info.model_name - onnx_filename, onnx_path = modelmanager.get_onnx_path(model_name, model_hash) - - print(f"Exporting {model_name} to TensorRT") - - timing_cache = modelmanager.get_timing_cache() - - version = get_version_from_model(shared.sd_model) - - pipeline = PIPELINE_TYPE.TXT2IMG - if is_inpaint: - pipeline = PIPELINE_TYPE.INPAINT - controlnet = None - min_textlen = (token_count_min // 75) * 77 - opt_textlen = (token_count_opt // 75) * 77 - max_textlen = (token_count_max // 75) * 77 - if static_shapes: - min_textlen = max_textlen = opt_textlen - - if shared.sd_model.is_sdxl: - pipeline = PIPELINE_TYPE.SD_XL_BASE - modelobj = make_OAIUNetXL( - version, pipeline, "cuda", False, batch_max, opt_textlen, max_textlen - ) - diable_optimizations = True - else: - modelobj = make_OAIUNet( - version, - pipeline, - "cuda", - False, - batch_max, - opt_textlen, - max_textlen, - controlnet, - ) - diable_optimizations = False - - profile = modelobj.get_input_profile( + profile_settings = ProfileSettings( batch_min, batch_opt, batch_max, @@ -125,28 +86,54 @@ def export_unet_to_trt( width_min, width_opt, width_max, - static_shapes, + token_count_min, + token_count_opt, + token_count_max, ) - print(profile) + if preset == "Default": + profile_settings = profile_presets.get_default(is_xl=is_xl) + use_fp32 = is_fp32() - if not os.path.exists(onnx_path): - print("No ONNX file found. Exporting ONNX...") - gr.Info("No ONNX file found. Exporting ONNX... Please check the progress in the terminal.") - export_onnx( - onnx_path, - modelobj, - profile=profile, - diable_optimizations=diable_optimizations, - ) - print("Exported to ONNX.") + print(f"Exporting {model_name} to TensorRT using - {profile_settings}") + profile_settings.token_to_dim(static_shapes) + + model_hash = shared.sd_model.sd_checkpoint_info.hash + model_name = shared.sd_model.sd_checkpoint_info.model_name + + onnx_filename, onnx_path = modelmanager.get_onnx_path(model_name) + timing_cache = modelmanager.get_timing_cache() + + diable_optimizations = is_xl + embedding_dim = get_context_dim() + + modelobj = UNetModel( + shared.sd_model.model.diffusion_model, + embedding_dim, + text_minlen=profile_settings.t_min, + is_xl=is_xl, + ) + + profile = modelobj.get_input_profile(profile_settings) + export_onnx( + onnx_path, + modelobj, + profile_settings, + diable_optimizations=diable_optimizations, + ) + gc.collect() + torch.cuda.empty_cache() trt_engine_filename, trt_path = modelmanager.get_trt_path( model_name, model_hash, profile, static_shapes ) if not os.path.exists(trt_path) or force_export: - print("Building TensorRT engine... This can take a while, please check the progress in the terminal.") - gr.Info("Building TensorRT engine... This can take a while, please check the progress in the terminal.") + print( + "Building TensorRT engine... This can take a while, please check the progress in the terminal." + ) + gr.Info( + "Building TensorRT engine... This can take a while, please check the progress in the terminal." + ) gc.collect() torch.cuda.empty_cache() ret = export_trt( @@ -166,248 +153,189 @@ def export_unet_to_trt( profile, static_shapes, fp32=use_fp32, - inpaint=is_inpaint, + inpaint=False, # TODO refit=True, vram=0, - unet_hidden_dim=unet_hidden_dim, + unet_hidden_dim=4, # TODO lora=False, ) else: - print("TensorRT engine found. Skipping build. You can enable Force Export in the Advanced Settings to force a rebuild if needed.") + print( + "TensorRT engine found. Skipping build. You can enable Force Export in the Advanced Settings to force a rebuild if needed." + ) + + gc.collect() + torch.cuda.empty_cache() return "## Exported Successfully \n" def export_lora_to_trt(lora_name, force_export): - is_inpaint = False - use_fp32 = False - if cc_major < 7: - use_fp32 = True - print("FP16 has been disabled because your GPU does not support it.") - unet_hidden_dim = shared.sd_model.model.diffusion_model.in_channels - if unet_hidden_dim == 9: - is_inpaint = True - - model_hash = shared.sd_model.sd_checkpoint_info.hash - model_name = shared.sd_model.sd_checkpoint_info.model_name - base_name = f"{model_name}" # _{model_hash} + is_xl = shared.sd_model.is_sdxl available_lora_models = get_lora_checkpoints() lora_name = lora_name.split(" ")[0] - lora_model = available_lora_models[lora_name] + lora_model = available_lora_models.get(lora_name, None) + if lora_model is None: + return f"## No LoRA model found for {lora_name}" + + version = lora_model.get("version", SDVersion.Unknown) + if version == SDVersion.Unknown: + print( + "LoRA SD version couldm't be determined. Please ensure the correct SD Checkpoint is selected." + ) - onnx_base_filename, onnx_base_path = modelmanager.get_onnx_path( - model_name, model_hash - ) - onnx_lora_filename, onnx_lora_path = modelmanager.get_onnx_path( - lora_name, base_name + model_name = shared.sd_model.sd_checkpoint_info.model_name + model_hash = shared.sd_model.sd_checkpoint_info.hash + + if not version.match(shared.sd_model): + print( + f"""LoRA SD version ({version}) does not match the current SD version ({model_name}). + Please ensure the correct SD Checkpoint is selected.""" + ) + + profile_settings = profile_presets.get_default(is_xl=False) + print(f"Exporting {lora_name} to TensorRT using - {profile_settings}") + profile_settings.token_to_dim(True) + + onnx_base_filename, onnx_base_path = modelmanager.get_onnx_path(model_name) + if not os.path.exists(onnx_base_path): + return f"## Please export the base model ({model_name}) first." + + embedding_dim = get_context_dim() + + def disable_checkpoint(self): + if getattr(self, "use_checkpoint", False) == True: + self.use_checkpoint = False + if getattr(self, "checkpoint", False) == True: + self.checkpoint = False + + shared.sd_model.model.diffusion_model.apply(disable_checkpoint) + sd_unet.apply_unet("None") + sd_hijack.model_hijack.apply_optimizations("None") + + modelobj = UNetModel( + shared.sd_model.model.diffusion_model, + embedding_dim, + text_minlen=profile_settings.t_min, + is_xl=is_xl, ) - version = get_version_from_model(shared.sd_model) + weights_map_path = modelmanager.get_weights_map_path(model_name) + if not os.path.exists(weights_map_path): + modelobj.export_weights_map(onnx_base_path, weights_map_path) - pipeline = PIPELINE_TYPE.TXT2IMG - if is_inpaint: - pipeline = PIPELINE_TYPE.INPAINT + lora_trt_name = f"{lora_name}.trt" + lora_trt_path = os.path.join(TRT_MODEL_DIR, lora_trt_name) - if shared.sd_model.is_sdxl: - pipeline = PIPELINE_TYPE.SD_XL_BASE - modelobj = make_OAIUNetXL(version, pipeline, "cuda", False, 1, 77, 77) - diable_optimizations = True - else: - modelobj = make_OAIUNet( - version, - pipeline, - "cuda", - False, - 1, - 77, - 77, - None, - ) - diable_optimizations = False - - if not os.path.exists(onnx_lora_path): - print("No ONNX file found. Exporting ONNX...") - gr.Info("No ONNX file found. Exporting ONNX... Please check the progress in the terminal.") - export_onnx( - onnx_lora_path, - modelobj, - profile=modelobj.get_input_profile( - 1, 1, 1, 512, 512, 512, 512, 512, 512, True - ), - diable_optimizations=diable_optimizations, - lora_path=lora_model["filename"], + if os.path.exists(lora_trt_path) and not force_export: + print( + "TensorRT engine found. Skipping build. You can enable Force Export in the Advanced Settings to force a rebuild if needed." ) - print("Exported to ONNX.") + return "## Exported Successfully \n" + + profile = modelobj.get_input_profile(profile_settings) + refit_dict = export_lora( + modelobj, + onnx_base_path, + weights_map_path, + lora_model["filename"], + profile_settings, + ) - trt_lora_name = onnx_lora_filename.replace(".onnx", ".trt") - trt_lora_path = os.path.join(TRT_MODEL_DIR, trt_lora_name) + save_file(refit_dict, lora_trt_path) - available_trt_unet = modelmanager.available_models() - if len(available_trt_unet[base_name]) == 0: - return "## Please export the base model first." - trt_base_path = os.path.join( - TRT_MODEL_DIR, available_trt_unet[base_name][0]["filepath"] + modelmanager.add_lora_entry( + model_name, + lora_name, + lora_trt_name, + is_fp32(), + False, + 0, + 4, ) - if not os.path.exists(onnx_base_path): - return "## Please export the base model first." - - if not os.path.exists(trt_lora_path) or force_export: - print("No TensorRT engine found. Building...") - gr.Info("No TensorRT engine found. Building...") - - engine = Engine(trt_base_path) - engine.load() - engine.refit(onnx_base_path, onnx_lora_path, dump_refit_path=trt_lora_path) - print("Built TensorRT engine.") - - modelmanager.add_lora_entry( - base_name, - lora_name, - trt_lora_name, - use_fp32, - is_inpaint, - 0, - unet_hidden_dim, - ) return "## Exported Successfully \n" -def export_default_unet_to_trt(): - is_xl = shared.sd_model.is_sdxl +def get_version_from_filename(name): + if "v1-" in name: + return "1.5" + elif "v2-" in name: + return "2.1" + elif "xl" in name: + return "xl-1.0" + else: + return "Unknown" - batch_min = 1 - batch_opt = 1 - batch_max = 4 - height_min = 768 if is_xl else 512 - height_opt = 1024 if is_xl else 512 - height_max = 1024 if is_xl else 768 - width_min = 768 if is_xl else 512 - width_opt = 1024 if is_xl else 512 - width_max = 1024 if is_xl else 768 - token_count_min = 75 - token_count_opt = 75 - token_count_max = 150 - - return ( - batch_min, - batch_opt, - batch_max, - height_min, - height_opt, - height_max, - width_min, - width_opt, - width_max, - token_count_min, - token_count_opt, - token_count_max, - ) +def get_lora_checkpoints(): + available_lora_models = {} + allowed_extensions = ["pt", "ckpt", "safetensors"] + candidates = [ + p + for p in os.listdir(cmd_opts.lora_dir) + if p.split(".")[-1] in allowed_extensions + ] -profile_presets = { - "512x512 | Batch Size 1 (Static)": ( - 1, - 1, - 1, - 512, - 512, - 512, - 512, - 512, - 512, - 75, - 75, - 75, - ), - "768x768 | Batch Size 1 (Static)": ( - 1, - 1, - 1, - 768, - 768, - 768, - 768, - 768, - 768, - 75, - 75, - 75, - ), - "1024x1024 | Batch Size 1 (Static)": ( - 1, - 1, - 1, - 1024, - 1024, - 1024, - 1024, - 1024, - 1024, - 75, - 75, - 75, - ), - "256x256 - 512x512 | Batch Size 1-4 (Dynamic)": ( - 1, - 1, - 4, - 256, - 512, - 512, - 256, - 512, - 512, - 75, - 75, - 150, - ), - "512x512 - 768x768 | Batch Size 1-4 (Dynamic)": ( - 1, - 1, - 4, - 512, - 512, - 768, - 512, - 512, - 768, - 75, - 75, - 150, - ), - "768x768 - 1024x1024 | Batch Size 1-4 (Dynamic)": ( - 1, - 1, - 4, - 768, - 1024, - 1024, - 768, - 1024, - 1024, - 75, - 75, - 150, - ), -} - - -def get_settings_from_version(version): - static = False - if version == "Default": - return *list(profile_presets.values())[-2], static - if "Static" in version: - static = True - return *profile_presets[version], static + for filename in candidates: + metadata = {} + name, ext = os.path.splitext(filename) + config_file = os.path.join(cmd_opts.lora_dir, name + ".json") + + if ext == ".safetensors": + metadata = sd_models.read_metadata_from_safetensors( + os.path.join(cmd_opts.lora_dir, filename) + ) + else: + print( + """LoRA {} is not a safetensor. This might cause issues when exporting to TensorRT. + Please ensure that the correct base model is selected when exporting.""".format( + name + ) + ) + + base_model = metadata.get("ss_sd_model_name", "Unknown") + if os.path.exists(config_file): + with open(config_file, "r") as f: + config = json.load(f) + version = SDVersion.from_str(config["sd version"]) + + else: + version = SDVersion.Unknown + print( + "No config file found for {}. You can generate it in the LoRA tab.".format( + name + ) + ) + + available_lora_models[name] = { + "filename": filename, + "version": version, + "base_model": base_model, + } + return available_lora_models + + +def get_valid_lora_checkpoints(): + available_lora_models = get_lora_checkpoints() + return [f"{k} ({v['version']})" for k, v in available_lora_models.items()] def diable_export(version): if version == "Default": - return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False) + return ( + gr.update(visible=False), + gr.update(visible=True), + gr.update(visible=False), + ) else: - return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True) + return ( + gr.update(visible=True), + gr.update(visible=False), + gr.update(visible=True), + ) + def disable_lora_export(lora): if lora is None: @@ -415,6 +343,7 @@ def disable_lora_export(lora): else: return gr.update(visible=True) + def diable_visibility(hide): num_outputs = 8 out = [gr.update(visible=not hide) for _ in range(num_outputs)] @@ -498,49 +427,6 @@ def get_md_table( return model_md -def get_version_from_filename(name): - if "v1-" in name: - return "1.5" - elif "v2-" in name: - return "2.1" - elif "xl" in name: - return "xl-1.0" - else: - return "Unknown" - - -def get_lora_checkpoints(): - available_lora_models = {} - candidates = list( - shared.walk_files( - shared.cmd_opts.lora_dir, - allowed_extensions=[".pt", ".ckpt", ".safetensors"], - ) - ) - for filename in candidates: - name = os.path.splitext(os.path.basename(filename))[0] - try: - metadata = sd_models.read_metadata_from_safetensors(filename) - version = get_version_from_filename(metadata.get("ss_sd_model_name")) - except (AssertionError, TypeError): - version = "Unknown" - available_lora_models[name] = { - "filename": filename, - "version": version, - } - return available_lora_models - - -def get_valid_lora_checkpoints(): - available_lora_models = get_lora_checkpoints() - return [ - f"{k} ({v['version']})" - for k, v in available_lora_models.items() - if v["version"] == get_version_from_model(shared.sd_model) - or v["version"] == "Unknown" - ] - - def on_ui_tabs(): with gr.Blocks(analytics_enabled=False) as trt_interface: with gr.Row(equal_height=True): @@ -551,17 +437,18 @@ def on_ui_tabs(): value="# TensorRT Exporter", ) - default_version = list(profile_presets.keys())[-2] - default_vals = list(profile_presets.values())[-2] + default_vals = profile_presets.get_default(is_xl=False) version = gr.Dropdown( label="Preset", - choices=list(profile_presets.keys()) + ["Default"], + choices=profile_presets.get_choices(), elem_id="sd_version", default="Default", value="Default", ) - with gr.Accordion("Advanced Settings", open=False, visible=False) as advanced_settings: + with gr.Accordion( + "Advanced Settings", open=False, visible=False + ) as advanced_settings: with FormRow( elem_classes="checkboxes-row", variant="compact" ): @@ -577,7 +464,7 @@ def on_ui_tabs(): maximum=16, step=1, label="Min batch-size", - value=default_vals[0], + value=default_vals.bs_min, elem_id="trt_min_batch", ) @@ -586,7 +473,7 @@ def on_ui_tabs(): maximum=16, step=1, label="Optimal batch-size", - value=default_vals[1], + value=default_vals.bs_opt, elem_id="trt_opt_batch", ) trt_max_batch = gr.Slider( @@ -594,7 +481,7 @@ def on_ui_tabs(): maximum=16, step=1, label="Max batch-size", - value=default_vals[2], + value=default_vals.bs_min, elem_id="trt_max_batch", ) @@ -604,7 +491,7 @@ def on_ui_tabs(): maximum=4096, step=64, label="Min height", - value=default_vals[3], + value=default_vals.h_min, elem_id="trt_min_height", ) trt_height_opt = gr.Slider( @@ -612,7 +499,7 @@ def on_ui_tabs(): maximum=4096, step=64, label="Optimal height", - value=default_vals[4], + value=default_vals.h_opt, elem_id="trt_opt_height", ) trt_height_max = gr.Slider( @@ -620,7 +507,7 @@ def on_ui_tabs(): maximum=4096, step=64, label="Max height", - value=default_vals[5], + value=default_vals.h_max, elem_id="trt_max_height", ) @@ -630,7 +517,7 @@ def on_ui_tabs(): maximum=4096, step=64, label="Min width", - value=default_vals[6], + value=default_vals.w_min, elem_id="trt_min_width", ) trt_width_opt = gr.Slider( @@ -638,7 +525,7 @@ def on_ui_tabs(): maximum=4096, step=64, label="Optimal width", - value=default_vals[7], + value=default_vals.w_opt, elem_id="trt_opt_width", ) trt_width_max = gr.Slider( @@ -646,7 +533,7 @@ def on_ui_tabs(): maximum=4096, step=64, label="Max width", - value=default_vals[8], + value=default_vals.w_max, elem_id="trt_max_width", ) @@ -656,7 +543,7 @@ def on_ui_tabs(): maximum=750, step=75, label="Min prompt token count", - value=default_vals[9], + value=default_vals.t_min, elem_id="trt_opt_token_count_min", ) trt_token_count_opt = gr.Slider( @@ -664,7 +551,7 @@ def on_ui_tabs(): maximum=750, step=75, label="Optimal prompt token count", - value=default_vals[10], + value=default_vals.t_opt, elem_id="trt_opt_token_count_opt", ) trt_token_count_max = gr.Slider( @@ -672,7 +559,7 @@ def on_ui_tabs(): maximum=750, step=75, label="Max prompt token count", - value=default_vals[11], + value=default_vals.t_max, elem_id="trt_opt_token_count_max", ) @@ -700,7 +587,7 @@ def on_ui_tabs(): ) version.change( - get_settings_from_version, + profile_presets.get_settings_from_version, version, [ trt_min_batch, @@ -721,7 +608,11 @@ def on_ui_tabs(): version.change( diable_export, version, - [button_export_unet, button_export_default_unet, advanced_settings], + [ + button_export_unet, + button_export_default_unet, + advanced_settings, + ], ) static_shapes.change( @@ -774,14 +665,16 @@ def on_ui_tabs(): trt_lora_dropdown, ) trt_lora_dropdown.change( - disable_lora_export, trt_lora_dropdown, button_export_lora_unet + disable_lora_export, + trt_lora_dropdown, + button_export_lora_unet, ) with gr.Column(variant="panel"): with open( os.path.join(os.path.dirname(os.path.abspath(__file__)), "info.md"), "r", - encoding='utf-8', + encoding="utf-8", ) as f: trt_info = gr.Markdown(elem_id="trt_info", value=f.read()) @@ -799,17 +692,23 @@ def get_trt_profiles_markdown(): profiles_md_string += "\n" return profiles_md_string - with gr.Column(variant="panel"): with gr.Row(equal_height=True, variant="compact"): - button_refresh_profiles = ToolButton(value=refresh_symbol, elem_id="trt_refresh_profiles", visible=True) + button_refresh_profiles = ToolButton( + value=refresh_symbol, elem_id="trt_refresh_profiles", visible=True + ) profile_header_md = gr.Markdown( value=f"## Available TensorRT Engine Profiles" ) with gr.Row(equal_height=True): - trt_profiles_markdown = gr.Markdown(elem_id=f"trt_profiles_markdown", value=get_trt_profiles_markdown()) - - button_refresh_profiles.click(lambda: gr.Markdown.update(value=get_trt_profiles_markdown()), outputs=[trt_profiles_markdown]) + trt_profiles_markdown = gr.Markdown( + elem_id=f"trt_profiles_markdown", value=get_trt_profiles_markdown() + ) + + button_refresh_profiles.click( + lambda: gr.Markdown.update(value=get_trt_profiles_markdown()), + outputs=[trt_profiles_markdown], + ) button_export_unet.click( export_unet_to_trt, diff --git a/utilities.py b/utilities.py index a71c72d..fc7822f 100644 --- a/utilities.py +++ b/utilities.py @@ -28,7 +28,7 @@ engine_from_bytes, engine_from_network, network_from_onnx_path, - save_engine + save_engine, ) from polygraphy.logger import G_LOGGER import tensorrt as trt @@ -36,7 +36,7 @@ from safetensors.numpy import save_file, load_file from logging import error, warning from tqdm import tqdm -import copy +import copy TRT_LOGGER = trt.Logger(trt.Logger.ERROR) G_LOGGER.module_severity = G_LOGGER.ERROR @@ -180,152 +180,55 @@ def __del__(self): del self.buffers del self.tensors - def refit(self, onnx_path, onnx_refit_path, dump_refit_path=None): - def convert_int64(arr): - # TODO: smarter conversion - if len(arr.shape) == 0: - return np.int32(arr) - return arr - - def add_to_map(refit_dict, name, values): - if name in refit_dict: - assert refit_dict[name] is None - if values.dtype == np.int64: - values = convert_int64(values) - refit_dict[name] = values - - print(f"Refitting TensorRT engine with {onnx_refit_path} weights") - refit_nodes = gs.import_onnx(onnx.load(onnx_refit_path)).toposort().nodes - - # Construct mapping from weight names in refit model -> original model - name_map = {} - for n, node in enumerate(gs.import_onnx(onnx.load(onnx_path)).toposort().nodes): - refit_node = refit_nodes[n] - assert node.op == refit_node.op - # Constant nodes in ONNX do not have inputs but have a constant output - if node.op == "Constant": - name_map[refit_node.outputs[0].name] = node.outputs[0].name - # Handle scale and bias weights - elif node.op == "Conv": - if node.inputs[1].__class__ == gs.Constant: - name_map[refit_node.name + "_TRTKERNEL"] = node.name + "_TRTKERNEL" - if node.inputs[2].__class__ == gs.Constant: - name_map[refit_node.name + "_TRTBIAS"] = node.name + "_TRTBIAS" - # For all other nodes: find node inputs that are initializers (gs.Constant) - else: - for i, inp in enumerate(node.inputs): - if inp.__class__ == gs.Constant: - name_map[refit_node.inputs[i].name] = inp.name + def reset(self, engine_path=None): + del self.engine + del self.context + del self.buffers + del self.tensors + self.engine_path = engine_path - def map_name(name): - if name in name_map: - return name_map[name] - return name + self.buffers = OrderedDict() + self.tensors = OrderedDict() + self.inputs = {} + self.outputs = {} - # Construct refit dictionary - refit_dict = {} + def refit_from_dict(self, refit_weights, is_fp16): + # Initialize refitter refitter = trt.Refitter(self.engine, TRT_LOGGER) - all_weights = refitter.get_all() - for layer_name, role in zip(all_weights[0], all_weights[1]): - # for specialized roles, use a unique name in the map: - if role == trt.WeightsRole.KERNEL: - name = layer_name + "_TRTKERNEL" - elif role == trt.WeightsRole.BIAS: - name = layer_name + "_TRTBIAS" - else: - name = layer_name - - assert name not in refit_dict, "Found duplicate layer: " + name - refit_dict[name] = None - - for n in refit_nodes: - # Constant nodes in ONNX do not have inputs but have a constant output - if n.op == "Constant": - name = map_name(n.outputs[0].name) - print(f"Add Constant {name}\n") - try: - add_to_map(refit_dict, name, n.outputs[0].values) - except: - error(f"Failed to add Constant {name}\n") - - # Handle scale and bias weights - elif n.op == "Conv": - if n.inputs[1].__class__ == gs.Constant: - name = map_name(n.name + "_TRTKERNEL") - try: - add_to_map(refit_dict, name, n.inputs[1].values) - except: - error(f"Failed to add Conv {name}\n") - - if n.inputs[2].__class__ == gs.Constant: - name = map_name(n.name + "_TRTBIAS") - try: - add_to_map(refit_dict, name, n.inputs[2].values) - except: - error(f"Failed to add Conv {name}\n") - - # For all other nodes: find node inputs that are initializers (AKA gs.Constant) - else: - for inp in n.inputs: - name = map_name(inp.name) - if inp.__class__ == gs.Constant: - add_to_map(refit_dict, name, inp.values) - - if dump_refit_path is not None: - print("Finished refit. Dumping result to disk.") - save_file( - refit_dict, dump_refit_path - ) # TODO need to come up with delta system to save only changed weights - return - - for layer_name, weights_role in zip(all_weights[0], all_weights[1]): - if weights_role == trt.WeightsRole.KERNEL: - custom_name = layer_name + "_TRTKERNEL" - elif weights_role == trt.WeightsRole.BIAS: - custom_name = layer_name + "_TRTBIAS" - else: - custom_name = layer_name - # Skip refitting Trilu for now; scalar weights of type int64 value 1 - for clip model - if layer_name.startswith("onnx::Trilu"): + refitted_weights = set() + # iterate through all tensorrt refittable weights + for trt_weight_name in refitter.get_all_weights(): + if trt_weight_name not in refit_weights: continue - if refit_dict[custom_name] is not None: - refitter.set_weights(layer_name, weights_role, refit_dict[custom_name]) - else: - print(f"[W] No refit weights for layer: {layer_name}") - - if not refitter.refit_cuda_engine(): - print("Failed to refit!") - exit(0) - - def refit_from_dump(self, dump_refit_path): - refit_dict = load_file( - dump_refit_path - ) # TODO if deltas are used needs to be unpacked here - - refitter = trt.Refitter(self.engine, TRT_LOGGER) - all_weights = refitter.get_all() - - for layer_name, weights_role in zip(all_weights[0], all_weights[1]): - if weights_role == trt.WeightsRole.KERNEL: - custom_name = layer_name + "_TRTKERNEL" - elif weights_role == trt.WeightsRole.BIAS: - custom_name = layer_name + "_TRTBIAS" - else: - custom_name = layer_name - - # Skip refitting Trilu for now; scalar weights of type int64 value 1 - for clip model - if layer_name.startswith("onnx::Trilu"): - continue + # get weight from state dict + trt_datatype = trt.DataType.FLOAT + if is_fp16: + refit_weights[trt_weight_name] = refit_weights[trt_weight_name].half() + trt_datatype = trt.DataType.HALF + + # trt.Weight and trt.TensorLocation + refit_weights[trt_weight_name] = refit_weights[trt_weight_name].cpu() + trt_wt_tensor = trt.Weights( + trt_datatype, + refit_weights[trt_weight_name].data_ptr(), + torch.numel(refit_weights[trt_weight_name]), + ) + trt_wt_location = ( + trt.TensorLocation.DEVICE + if refit_weights[trt_weight_name].is_cuda + else trt.TensorLocation.HOST + ) - if refit_dict[custom_name] is not None: - refitter.set_weights(layer_name, weights_role, refit_dict[custom_name]) - else: - print(f"[W] No refit weights for layer: {layer_name}") + # apply refit + # refitter.set_named_weights(trt_weight_name, trt_wt_tensor, trt_wt_location) + refitter.set_named_weights(trt_weight_name, trt_wt_tensor) + refitted_weights.add(trt_weight_name) + assert set(refitted_weights) == set(refit_weights.keys()) if not refitter.refit_cuda_engine(): - print("Failed to refit!") + print("Error: failed to refit new weights.") exit(0) def build( @@ -385,7 +288,9 @@ def build( profiles = copy.deepcopy(p) for profile in profiles: # Last profile is used for set_calibration_profile. - calib_profile = profile.fill_defaults(network[1]).to_trt(builder, network[1]) + calib_profile = profile.fill_defaults(network[1]).to_trt( + builder, network[1] + ) config.add_optimization_profile(calib_profile) try: From 7a44989b453db92808483e1a9f1defbf8582523d Mon Sep 17 00:00:00 2001 From: lspindler Date: Wed, 20 Dec 2023 03:09:00 -0800 Subject: [PATCH 07/12] more refactoring and added typing --- datastructures.py | 65 ++++++++++++++++++++++++++++++++- exporter.py | 93 ++++++++++++++++++++++++----------------------- install.py | 9 +++++ model_helper.py | 63 +++++++++++++++++++++----------- model_manager.py | 74 +++---------------------------------- scripts/lora.py | 6 +-- scripts/trt.py | 23 ++++++------ ui_trt.py | 52 ++++++++------------------ utilities.py | 31 ---------------- 9 files changed, 197 insertions(+), 219 deletions(-) diff --git a/datastructures.py b/datastructures.py index ac3b5eb..338c9e4 100644 --- a/datastructures.py +++ b/datastructures.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from enum import Enum - +from json import JSONEncoder +import torch @dataclass class UNetEngineArgs: @@ -52,7 +53,67 @@ def from_string(cls, s): def __str__(self): return self.name.lower() +@dataclass +class ModelConfig: + profile: dict + static_shapes: bool + fp32: bool + inpaint: bool + refit: bool + lora: bool + vram: int + unet_hidden_dim: int = 4 + + def is_compatible_from_dict(self, feed_dict: dict): + distance = 0 + for k, v in feed_dict.items(): + _min, _opt, _max = self.profile[k] + v_tensor = torch.Tensor(list(v.shape)) + r_min = torch.Tensor(_max) - v_tensor + r_opt = (torch.Tensor(_opt) - v_tensor).abs() + r_max = v_tensor - torch.Tensor(_min) + if torch.any(r_min < 0) or torch.any(r_max < 0): + return (False, distance) + distance += r_opt.sum() + 0.5 * (r_max.sum() + 0.5 * r_min.sum()) + return (True, distance) + + def is_compatible( + self, width: int, height: int, batch_size: int, max_embedding: int + ): + distance = 0 + sample = self.profile["sample"] + embedding = self.profile["encoder_hidden_states"] + + batch_size *= 2 + width = width // 8 + height = height // 8 + + _min, _opt, _max = sample + if _min[0] > batch_size or _max[0] < batch_size: + return (False, distance) + if _min[2] > height or _max[2] < height: + return (False, distance) + if _min[3] > width or _max[3] < width: + return (False, distance) + + _min_em, _opt_em, _max_em = embedding + if _min_em[1] > max_embedding or _max_em[1] < max_embedding: + return (False, distance) + + distance = ( + abs(_opt[0] - batch_size) + + abs(_opt[2] - height) + + abs(_opt[3] - width) + + 0.5 * (abs(_max[2] - height) + abs(_max[3] - width)) + ) + + return (True, distance) + +class ModelConfigEncoder(JSONEncoder): + def default(self, o: ModelConfig): + return o.__dict__ + @dataclass class ProfileSettings: bs_min: int @@ -166,7 +227,7 @@ def __init__(self): 1, 1, 4, 768, 1024, 1024, 768, 1024, 1024, 75, 75, 150 ) - def get_settings_from_version(self, version): + def get_settings_from_version(self, version: str): static = False if version == "Default": return *self.default.out(), static diff --git a/exporter.py b/exporter.py index b6c6663..fdc2379 100644 --- a/exporter.py +++ b/exporter.py @@ -1,29 +1,31 @@ +import os +import time +import shutil +import json +from pathlib import Path +from logging import info, error +from collections import OrderedDict +from typing import List, Tuple + import torch import torch.nn.functional as F +import numpy as np import onnx -from logging import info, error -import time -import shutil +from onnx import numpy_helper +from optimum.onnx.utils import ( + _get_onnx_external_data_tensors, + check_model_uses_external_data, +) + -from modules import sd_hijack, sd_unet, shared +from modules import shared from utilities import Engine from datastructures import ProfileSettings from model_helper import UNetModel -import os - -from pathlib import Path -from optimum.onnx.utils import ( - _get_onnx_external_data_tensors, - check_model_uses_external_data, -) -from collections import OrderedDict -from onnx import numpy_helper -import numpy as np -import json -def apply_lora(model, lora_path, inputs): +def apply_lora(model: torch.nn.Module, lora_path: str, inputs: Tuple[torch.Tensor]) -> torch.nn.Module: try: import sys @@ -40,15 +42,15 @@ def apply_lora(model, lora_path, inputs): lora_name = os.path.splitext(os.path.basename(lora_path))[0] networks.load_networks( [lora_name], [1.0], [1.0], [None] - ) # todo: UI for parameters, multiple loras -> Struct of Arrays? + ) model.forward(*inputs) return model def get_refit_weights( - state_dict, onnx_opt_path, weight_name_mapping, weight_shape_mapping -): + state_dict: dict, onnx_opt_path: str, weight_name_mapping: dict, weight_shape_mapping: dict +) -> dict: refit_weights = OrderedDict() onnx_opt_dir = os.path.dirname(onnx_opt_path) onnx_opt_model = onnx.load(onnx_opt_path) @@ -89,17 +91,7 @@ def export_lora( weights_map_path: str, lora_name: str, profile: ProfileSettings, -): - def disable_checkpoint(self): - if getattr(self, "use_checkpoint", False) == True: - self.use_checkpoint = False - if getattr(self, "checkpoint", False) == True: - self.checkpoint = False - - shared.sd_model.model.diffusion_model.apply(disable_checkpoint) - sd_unet.apply_unet("None") - sd_hijack.model_hijack.apply_optimizations("None") - +) -> dict: info("Exporting to ONNX...") inputs = modelobj.get_sample_input( profile.bs_opt * 2, @@ -107,17 +99,18 @@ def disable_checkpoint(self): profile.w_opt // 8, profile.t_opt, ) - model = shared.sd_model.model.diffusion_model with open(weights_map_path, "r") as fp_wts: print(f"[I] Loading weights map: {weights_map_path} ") [weights_name_mapping, weights_shape_mapping] = json.load(fp_wts) with torch.inference_mode(), torch.autocast("cuda"): - model = apply_lora(model, os.path.splitext(lora_name)[0], inputs) + modelobj.unet = apply_lora( + modelobj.unet, os.path.splitext(lora_name)[0], inputs + ) refit_dict = get_refit_weights( - model.state_dict(), + modelobj.unet.state_dict(), onnx_path, weights_name_mapping, weights_shape_mapping, @@ -126,18 +119,30 @@ def disable_checkpoint(self): return refit_dict +def swap_sdpa(func): + def wrapper(*args, **kwargs): + swap_sdpa = hasattr(F, "scaled_dot_product_attention") + old_sdpa = ( + getattr(F, "scaled_dot_product_attention", None) if swap_sdpa else None + ) + if swap_sdpa: + delattr(F, "scaled_dot_product_attention") + ret = func(*args, **kwargs) + if swap_sdpa and old_sdpa: + setattr(F, "scaled_dot_product_attention", old_sdpa) + return ret + + return wrapper + + +@swap_sdpa def export_onnx( onnx_path: str, modelobj: UNetModel, profile: ProfileSettings, - opset=17, - diable_optimizations=False, + opset: int = 17, + diable_optimizations: bool = False, ): - swap_sdpa = hasattr(F, "scaled_dot_product_attention") - old_sdpa = getattr(F, "scaled_dot_product_attention", None) if swap_sdpa else None - if swap_sdpa: - delattr(F, "scaled_dot_product_attention") - info("Exporting to ONNX...") inputs = modelobj.get_sample_input( profile.bs_opt * 2, @@ -158,13 +163,9 @@ def export_onnx( modelobj.optimize if not diable_optimizations else None, ) - # CleanUp - if swap_sdpa and old_sdpa: - setattr(F, "scaled_dot_product_attention", old_sdpa) - def _export_onnx( - model, inputs, path, opset, in_names, out_names, dyn_axes, optimizer=None + model: torch.nn.Module, inputs: Tuple[torch.Tensor], path: str, opset: int, in_names: List[str], out_names: List[str], dyn_axes: dict, optimizer=None ): tmp_dir = os.path.abspath("onnx_tmp") os.makedirs(tmp_dir, exist_ok=True) @@ -219,7 +220,7 @@ def _export_onnx( shutil.rmtree(tmp_dir) -def export_trt(trt_path, onnx_path, timing_cache, profile, use_fp16): +def export_trt(trt_path: str, onnx_path: str, timing_cache: str, profile: dict, use_fp16: bool): engine = Engine(trt_path) # TODO Still approx. 2gb of VRAM unaccounted for... diff --git a/install.py b/install.py index a4ba8d7..02e5cc8 100644 --- a/install.py +++ b/install.py @@ -48,5 +48,14 @@ def install(): live=True, ) + # OPTIMUM + if not launch.is_installed("optimum"): + print("Optimum is not installed! Installing...") + launch.run_pip( + "install optimum", + "optimum", + live=True, + ) + install() diff --git a/model_helper.py b/model_helper.py index 547395f..9a68d5d 100644 --- a/model_helper.py +++ b/model_helper.py @@ -15,22 +15,27 @@ # limitations under the License. # -import onnx -from onnx import shape_inference, numpy_helper import os -from polygraphy.backend.onnx.loader import fold_constants +import json import tempfile +from typing import List, Tuple + import torch -import torch.nn.functional as F +import numpy as np +import onnx +from onnx import shape_inference, numpy_helper import onnx_graphsurgeon as gs -from datastructures import ProfileSettings +from polygraphy.backend.onnx.loader import fold_constants -import numpy as np -import json +from modules import sd_hijack, sd_unet + +from datastructures import ProfileSettings class UNetModel(torch.nn.Module): - def __init__(self, unet, embedding_dim, text_minlen=77, is_xl=False) -> None: + def __init__( + self, unet, embedding_dim: int, text_minlen: int = 77, is_xl: bool = False + ) -> None: super().__init__() self.unet = unet self.is_xl = is_xl @@ -39,6 +44,7 @@ def __init__(self, unet, embedding_dim, text_minlen=77, is_xl=False) -> None: self.embedding_dim = embedding_dim self.num_xl_classes = 2816 # Magic number for num_classes self.emb_chn = 1280 + self.in_channels = self.unet.in_channels self.dyn_axes = { "sample": {0: "2B", 2: "H", 3: "W"}, @@ -48,33 +54,48 @@ def __init__(self, unet, embedding_dim, text_minlen=77, is_xl=False) -> None: "y": {0: "2B"}, } - def get_input_names(self): + def apply_torch_model(self): + def disable_checkpoint(self): + if getattr(self, "use_checkpoint", False) == True: + self.use_checkpoint = False + if getattr(self, "checkpoint", False) == True: + self.checkpoint = False + + self.unet.apply(disable_checkpoint) + self.set_unet("None") + + def set_unet(self, ckpt: str): + # TODO test if using this with TRT works + sd_unet.apply_unet(ckpt) + sd_hijack.model_hijack.apply_optimizations(ckpt) + + def get_input_names(self) -> List[str]: names = ["sample", "timesteps", "encoder_hidden_states"] if self.is_xl: names.append("y") return names - def get_output_names(self): + def get_output_names(self) -> List[str]: return ["latent"] - def get_dynamic_axes(self): + def get_dynamic_axes(self) -> dict: io_names = self.get_input_names() + self.get_output_names() dyn_axes = {name: self.dyn_axes[name] for name in io_names} return dyn_axes def get_sample_input( self, - batch_size, - latent_height, - latent_width, - text_len, - device="cuda", - dtype=torch.float32, - ): + batch_size: int, + latent_height: int, + latent_width: int, + text_len: int, + device: str = "cuda", + dtype: torch.dtype = torch.float32, + ) -> Tuple[torch.Tensor]: return ( torch.randn( batch_size, - self.unet.in_channels, + self.in_channels, latent_height, latent_width, dtype=dtype, @@ -93,7 +114,7 @@ def get_sample_input( else None, ) - def get_input_profile(self, profile: ProfileSettings): + def get_input_profile(self, profile: ProfileSettings) -> dict: min_batch, opt_batch, max_batch = profile.get_a1111_batch_dim() ( min_latent_height, @@ -127,7 +148,7 @@ def get_input_profile(self, profile: ProfileSettings): return shape_dict # Helper utility for weights map - def export_weights_map(self, onnx_opt_path, weights_map_path): + def export_weights_map(self, onnx_opt_path: str, weights_map_path: dict): onnx_opt_dir = onnx_opt_path state_dict = self.unet.state_dict() onnx_opt_model = onnx.load(onnx_opt_path) diff --git a/model_manager.py b/model_manager.py index fdb50b4..9f7604e 100644 --- a/model_manager.py +++ b/model_manager.py @@ -1,13 +1,13 @@ -import json -from json import JSONEncoder - import os +import json +import copy from logging import info, warning -from dataclasses import dataclass -from datastructures import ModelType import torch + from modules import paths_internal -import copy + +from datastructures import ModelConfig, ModelConfigEncoder + ONNX_MODEL_DIR = os.path.join(paths_internal.models_path, "Unet-onnx") if not os.path.exists(ONNX_MODEL_DIR): @@ -225,66 +225,4 @@ def get_valid_models( return valid_models, distances, idx -@dataclass -class ModelConfig: - profile: dict - static_shapes: bool - fp32: bool - inpaint: bool - refit: bool - lora: bool - vram: int - unet_hidden_dim: int = 4 - - def is_compatible_from_dict(self, feed_dict: dict): - distance = 0 - for k, v in feed_dict.items(): - _min, _opt, _max = self.profile[k] - v_tensor = torch.Tensor(list(v.shape)) - r_min = torch.Tensor(_max) - v_tensor - r_opt = (torch.Tensor(_opt) - v_tensor).abs() - r_max = v_tensor - torch.Tensor(_min) - if torch.any(r_min < 0) or torch.any(r_max < 0): - return (False, distance) - distance += r_opt.sum() + 0.5 * (r_max.sum() + 0.5 * r_min.sum()) - return (True, distance) - - def is_compatible( - self, width: int, height: int, batch_size: int, max_embedding: int - ): - distance = 0 - sample = self.profile["sample"] - embedding = self.profile["encoder_hidden_states"] - - batch_size *= 2 - width = width // 8 - height = height // 8 - - _min, _opt, _max = sample - if _min[0] > batch_size or _max[0] < batch_size: - return (False, distance) - if _min[2] > height or _max[2] < height: - return (False, distance) - if _min[3] > width or _max[3] < width: - return (False, distance) - - _min_em, _opt_em, _max_em = embedding - if _min_em[1] > max_embedding or _max_em[1] < max_embedding: - return (False, distance) - - distance = ( - abs(_opt[0] - batch_size) - + abs(_opt[2] - height) - + abs(_opt[3] - width) - + 0.5 * (abs(_max[2] - height) + abs(_max[3] - width)) - ) - - return (True, distance) - - -class ModelConfigEncoder(JSONEncoder): - def default(self, o: ModelConfig): - return o.__dict__ - - modelmanager = ModelManager() diff --git a/scripts/lora.py b/scripts/lora.py index 102ba41..2e21bef 100644 --- a/scripts/lora.py +++ b/scripts/lora.py @@ -1,14 +1,14 @@ import os from typing import List + import numpy as np +import torch from safetensors.torch import load_file -import onnx_graphsurgeon as gs import onnx -import torch from onnx import numpy_helper -def merge_loras(loras: List[str], scales: List[str]): +def merge_loras(loras: List[str], scales: List[str]) -> dict: refit_dict = {} for lora, scale in zip(loras, scales): lora_dict = load_file(lora) diff --git a/scripts/trt.py b/scripts/trt.py index a62acbb..c91d145 100644 --- a/scripts/trt.py +++ b/scripts/trt.py @@ -1,21 +1,23 @@ import os -import numpy as np +import re +from typing import List +import numpy as np import torch from torch.cuda import nvtx +from polygraphy.logger import G_LOGGER +import gradio as gr + from modules import script_callbacks, sd_unet, devices, scripts import ui_trt from utilities import Engine -from typing import List from model_manager import TRT_MODEL_DIR, modelmanager -from polygraphy.logger import G_LOGGER -import gradio as gr -import re from datastructures import UNetEngineArgs, ModelType from scripts.lora import apply_loras G_LOGGER.module_severity = G_LOGGER.ERROR +GLOBAL_ARGS = UNetEngineArgs(0, 0, None, {}) class TrtUnetOption(sd_unet.SdUnetOption): @@ -28,9 +30,6 @@ def create_unet(self): return TrtUnet(self.model_name, self.configs) -GLOBAL_ARGS = UNetEngineArgs(0, 0, None, {}) - - class TrtUnet(sd_unet.SdUnet): def __init__(self, model_name: str, configs: List[dict], *args, **kwargs): super().__init__(*args, **kwargs) @@ -52,7 +51,7 @@ def __init__(self, model_name: str, configs: List[dict], *args, **kwargs): os.path.join(TRT_MODEL_DIR, self.loaded_config["filepath"]) ) - def forward(self, x, timesteps, context, *args, **kwargs): + def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor, *args, **kwargs) -> torch.Tensor: nvtx.range_push("forward") feed_dict = { "sample": x.float(), @@ -134,7 +133,7 @@ def before_process(self, p, *args): # 1 # Check divisibilty if p.width % 64 or p.height % 64: gr.Error("Target resolution must be divisible by 64 in both dimensions.") - + # TODO img2img has not enable hr if p.enable_hr: hr_w = int(p.width * p.hr_scale) hr_h = int(p.height * p.hr_scale) @@ -143,7 +142,7 @@ def before_process(self, p, *args): # 1 "HIRES Fix resolution must be divisible by 64 in both dimensions. Please change the upscale factor or disable HIRES Fix." ) - def get_profile_idx(self, p, model_name, model_type): + def get_profile_idx(self, p, model_name: str, model_type: ModelType) -> (int, int): best_hr = None hr_scale = p.hr_scale if p.enable_hr else 1 ( @@ -282,6 +281,8 @@ def after_extra_networks_activate(self, p, *args, **kwargs): def list_unets(l): model = modelmanager.available_models() for k, v in model.items(): + if v[0]["config"].lora: + continue label = "{} ({})".format(k, v[0]["base_model"]) if v[0]["config"].lora else k l.append(TrtUnetOption(label, v)) diff --git a/ui_trt.py b/ui_trt.py index 0a35bb4..fe71075 100644 --- a/ui_trt.py +++ b/ui_trt.py @@ -1,34 +1,30 @@ import os +import gc +import json +import logging +from collections import defaultdict +import torch +from safetensors.torch import save_file import gradio as gr -from modules.call_queue import wrap_gradio_gpu_call + from modules.shared import cmd_opts from modules.ui_components import FormRow -from modules import sd_hijack, sd_unet, sd_models, shared +from modules import sd_hijack, sd_models, shared from modules.ui_common import refresh_symbol from modules.ui_components import ToolButton -from exporter import export_onnx, export_trt, export_lora -from utilities import Engine -from safetensors.torch import save_file - - -import logging -import gc -import torch -from collections import defaultdict -import json - from model_helper import UNetModel +from exporter import export_onnx, export_trt, export_lora from model_manager import modelmanager, cc_major, TRT_MODEL_DIR from datastructures import SDVersion, ProfilePrests, ProfileSettings + profile_presets = ProfilePrests() logging.basicConfig(level=logging.INFO) -# TODO get info from model config def get_context_dim(): if shared.sd_model.is_sd1: return 768 @@ -63,14 +59,6 @@ def export_unet_to_trt( static_shapes, preset, ): - def disable_checkpoint(self): - if getattr(self, "use_checkpoint", False) == True: - self.use_checkpoint = False - if getattr(self, "checkpoint", False) == True: - self.checkpoint = False - - shared.sd_model.model.diffusion_model.apply(disable_checkpoint) - sd_unet.apply_unet("None") sd_hijack.model_hijack.apply_optimizations("None") is_xl = shared.sd_model.is_sdxl @@ -112,6 +100,7 @@ def disable_checkpoint(self): text_minlen=profile_settings.t_min, is_xl=is_xl, ) + modelobj.apply_torch_model() profile = modelobj.get_input_profile(profile_settings) export_onnx( @@ -134,8 +123,6 @@ def disable_checkpoint(self): gr.Info( "Building TensorRT engine... This can take a while, please check the progress in the terminal." ) - gc.collect() - torch.cuda.empty_cache() ret = export_trt( trt_path, onnx_path, @@ -153,10 +140,10 @@ def disable_checkpoint(self): profile, static_shapes, fp32=use_fp32, - inpaint=False, # TODO + inpaint=True if modelobj.in_channels == 6 else False, refit=True, vram=0, - unet_hidden_dim=4, # TODO + unet_hidden_dim=modelobj.in_channels, lora=False, ) else: @@ -204,22 +191,13 @@ def export_lora_to_trt(lora_name, force_export): embedding_dim = get_context_dim() - def disable_checkpoint(self): - if getattr(self, "use_checkpoint", False) == True: - self.use_checkpoint = False - if getattr(self, "checkpoint", False) == True: - self.checkpoint = False - - shared.sd_model.model.diffusion_model.apply(disable_checkpoint) - sd_unet.apply_unet("None") - sd_hijack.model_hijack.apply_optimizations("None") - modelobj = UNetModel( shared.sd_model.model.diffusion_model, embedding_dim, text_minlen=profile_settings.t_min, is_xl=is_xl, ) + modelobj.apply_torch_model() weights_map_path = modelmanager.get_weights_map_path(model_name) if not os.path.exists(weights_map_path): @@ -458,7 +436,7 @@ def on_ui_tabs(): elem_id="trt_static_shapes", ) - with gr.Column(elem_id="trt_max_batch"): + with gr.Column(elem_id="trt_batch"): trt_min_batch = gr.Slider( minimum=1, maximum=16, diff --git a/utilities.py b/utilities.py index fc7822f..d5285c0 100644 --- a/utilities.py +++ b/utilities.py @@ -19,8 +19,6 @@ from torch.cuda import nvtx from collections import OrderedDict import numpy as np -import onnx -import onnx_graphsurgeon as gs from polygraphy.backend.common import bytes_from_path from polygraphy import util from polygraphy.backend.trt import ModifyNetworkOutputs, Profile @@ -32,8 +30,6 @@ ) from polygraphy.logger import G_LOGGER import tensorrt as trt -from enum import Enum, auto -from safetensors.numpy import save_file, load_file from logging import error, warning from tqdm import tqdm import copy @@ -64,33 +60,6 @@ value: key for (key, value) in numpy_to_torch_dtype_dict.items() } - -class PIPELINE_TYPE(Enum): - TXT2IMG = auto() - IMG2IMG = auto() - INPAINT = auto() - SD_XL_BASE = auto() - SD_XL_REFINER = auto() - - def is_txt2img(self): - return self == self.TXT2IMG - - def is_img2img(self): - return self == self.IMG2IMG - - def is_inpaint(self): - return self == self.INPAINT - - def is_sd_xl_base(self): - return self == self.SD_XL_BASE - - def is_sd_xl_refiner(self): - return self == self.SD_XL_REFINER - - def is_sd_xl(self): - return self.is_sd_xl_base() or self.is_sd_xl_refiner() - - class TQDMProgressMonitor(trt.IProgressMonitor): def __init__(self): trt.IProgressMonitor.__init__(self) From d635c39c40b8928942fb2bf59042f4698a91d7b5 Mon Sep 17 00:00:00 2001 From: lspindler Date: Wed, 20 Dec 2023 05:49:55 -0800 Subject: [PATCH 08/12] Enable torch fallback --- datastructures.py | 13 ++--- scripts/trt.py | 124 ++++++++++++++++++++++++++++++---------------- 2 files changed, 85 insertions(+), 52 deletions(-) diff --git a/datastructures.py b/datastructures.py index 338c9e4..1257f65 100644 --- a/datastructures.py +++ b/datastructures.py @@ -1,14 +1,7 @@ from dataclasses import dataclass from enum import Enum from json import JSONEncoder -import torch - -@dataclass -class UNetEngineArgs: - idx: int - hr_idx: int = None - lora: dict = None - controlnets: dict = None +import torch class SDVersion(Enum): @@ -53,6 +46,7 @@ def from_string(cls, s): def __str__(self): return self.name.lower() + @dataclass class ModelConfig: profile: dict @@ -113,7 +107,8 @@ def is_compatible( class ModelConfigEncoder(JSONEncoder): def default(self, o: ModelConfig): return o.__dict__ - + + @dataclass class ProfileSettings: bs_min: int diff --git a/scripts/trt.py b/scripts/trt.py index c91d145..e727e64 100644 --- a/scripts/trt.py +++ b/scripts/trt.py @@ -8,16 +8,15 @@ from polygraphy.logger import G_LOGGER import gradio as gr -from modules import script_callbacks, sd_unet, devices, scripts +from modules import script_callbacks, sd_unet, devices, scripts, shared import ui_trt from utilities import Engine from model_manager import TRT_MODEL_DIR, modelmanager -from datastructures import UNetEngineArgs, ModelType +from datastructures import ModelType from scripts.lora import apply_loras G_LOGGER.module_severity = G_LOGGER.ERROR -GLOBAL_ARGS = UNetEngineArgs(0, 0, None, {}) class TrtUnetOption(sd_unet.SdUnetOption): @@ -38,20 +37,22 @@ def __init__(self, model_name: str, configs: List[dict], *args, **kwargs): self.model_name = model_name self.configs = configs - self.profile_idx = GLOBAL_ARGS.idx - if self.profile_idx is None: - self.profile_idx = 0 - self.loaded_config = self.configs[self.profile_idx] + self.profile_idx = 0 + self.loaded_config = None self.engine_vram_req = 0 - self.shape_hash = 0 self.refitted_keys = set() - self.engine = Engine( - os.path.join(TRT_MODEL_DIR, self.loaded_config["filepath"]) - ) + self.engine = None - def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor, *args, **kwargs) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: nvtx.range_push("forward") feed_dict = { "sample": x.float(), @@ -61,9 +62,6 @@ def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tenso if "y" in kwargs: feed_dict["y"] = kwargs["y"].float() - if not self.profile_idx == GLOBAL_ARGS.idx: - self.switch_engine() - tmp = torch.empty( self.engine_vram_req, dtype=torch.uint8, device=devices.device ) @@ -76,11 +74,7 @@ def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tenso nvtx.range_pop() return out - def apply_loras(self): - if GLOBAL_ARGS.lora is None: - refit_dict = {} - else: - refit_dict = GLOBAL_ARGS.lora + def apply_loras(self, refit_dict: dict): if not self.refitted_keys.issubset(set(refit_dict.keys())): # Need to ensure that weights that have been modified before and are not present anymore are reset. self.refitted_keys = set() @@ -89,20 +83,17 @@ def apply_loras(self): self.engine.refit_from_dict(refit_dict, is_fp16=True) self.refitted_keys = set(refit_dict.keys()) - def set_idx(self): - if GLOBAL_ARGS.idx is None: - raise Exception("No valid profile found. Please generate a profile first.") - self.profile_idx = GLOBAL_ARGS.idx - def switch_engine(self): - self.set_idx() self.loaded_config = self.configs[self.profile_idx] self.engine.reset(os.path.join(TRT_MODEL_DIR, self.loaded_config["filepath"])) self.activate() - self.shape_hash = 0 def activate(self): - self.shape_hash = 0 + self.loaded_config = self.configs[self.profile_idx] + if self.engine is None: + self.engine = Engine( + os.path.join(TRT_MODEL_DIR, self.loaded_config["filepath"]) + ) self.engine.load() print(f"\nLoaded Profile: {self.profile_idx}") print(self.engine) @@ -110,7 +101,6 @@ def activate(self): self.engine.activate(True) def deactivate(self): - self.shape_hash = 0 del self.engine @@ -119,6 +109,10 @@ def __init__(self) -> None: self.loaded_model = None self.lora_hash = "" self.update_lora = False + self.lora_refit_dict = {} + self.idx = None + self.hr_idx = None + self.torch_unet = False def title(self): return "TensorRT" @@ -133,7 +127,9 @@ def before_process(self, p, *args): # 1 # Check divisibilty if p.width % 64 or p.height % 64: gr.Error("Target resolution must be divisible by 64 in both dimensions.") - # TODO img2img has not enable hr + + if self.is_img2img: + return if p.enable_hr: hr_w = int(p.width * p.hr_scale) hr_h = int(p.height * p.hr_scale) @@ -144,7 +140,11 @@ def before_process(self, p, *args): # 1 def get_profile_idx(self, p, model_name: str, model_type: ModelType) -> (int, int): best_hr = None - hr_scale = p.hr_scale if p.enable_hr else 1 + + if self.is_img2img: + hr_scale = 1 + else: + hr_scale = p.hr_scale if p.enable_hr else 1 ( valid_models, distances, @@ -163,6 +163,7 @@ def get_profile_idx(self, p, model_name: str, model_type: ModelType) -> (int, in ) return None, None best = idx[np.argmin(distances)] + best_hr = best if hr_scale != 1: hr_w = int(p.width * p.hr_scale) @@ -213,7 +214,7 @@ def get_loras(self, p): self.lora_hash = "".join(loras) self.update_lora = True if self.lora_hash == "": - GLOBAL_ARGS.lora = None + self.lora_refit_dict = {} return else: return @@ -236,7 +237,7 @@ def get_loras(self, p): base_name, base_path = modelmanager.get_onnx_path(p.sd_model_name) refit_dict = apply_loras(base_path, lora_pathes, lora_scales) - GLOBAL_ARGS.lora = refit_dict + self.lora_refit_dict = refit_dict def process(self, p, *args): # before unet_init @@ -251,31 +252,68 @@ def process(self, p, *args): p.sd_model_name, sd_unet_option.model_name ) ) - GLOBAL_ARGS.idx, GLOBAL_ARGS.hr_idx = self.get_profile_idx( - p, p.sd_model_name, ModelType.UNET - ) + self.idx, self.hr_idx = self.get_profile_idx(p, p.sd_model_name, ModelType.UNET) + self.torch_unet = self.idx is None or self.hr_idx is None try: - self.get_loras(p) + if not self.torch_unet: + self.get_loras(p) except Exception as e: gr.Error(e) raise e + self.apply_unet(sd_unet_option) + + def apply_unet(self, sd_unet_option): + if ( + sd_unet_option == sd_unet.current_unet_option + and sd_unet.current_unet is not None + and not self.torch_unet + ): + return + + if sd_unet.current_unet is not None: + sd_unet.current_unet.deactivate() + + if self.torch_unet: + gr.Warning("Enabling PyTorch fallback as no engine was found.") + sd_unet.current_unet = None + sd_unet.current_unet_option = sd_unet_option + shared.sd_model.model.diffusion_model.to(devices.device) + return + else: + shared.sd_model.model.diffusion_model.to(devices.cpu) + devices.torch_gc() + if self.lora_refit_dict: + self.update_lora = True + sd_unet.current_unet = sd_unet_option.create_unet() + sd_unet.current_unet.profile_idx = self.idx + sd_unet.current_unet.option = sd_unet_option + sd_unet.current_unet_option = sd_unet_option + + print(f"Activating unet: {sd_unet.current_unet.option.label}") + sd_unet.current_unet.activate() + def process_batch(self, p, *args, **kwargs): # Called for each batch count - return super().process_batch(p, *args, **kwargs) + if self.torch_unet: + return super().process_batch(p, *args, **kwargs) + + if self.idx != sd_unet.current_unet.profile_idx: + sd_unet.current_unet.profile_idx = self.idx + sd_unet.current_unet.switch_engine() def before_hr(self, p, *args): - GLOBAL_ARGS.idx = GLOBAL_ARGS.hr_idx + if self.idx != self.hr_idx: + sd_unet.current_unet.profile_idx = self.hr_idx + sd_unet.current_unet.switch_engine() return super().before_hr(p, *args) # 4 (Only when HR starts.....) def after_extra_networks_activate(self, p, *args, **kwargs): - if self.update_lora: + if self.update_lora and not self.torch_unet: self.update_lora = False - # Not the fastest, but safest option. Larger bottlenecks to solve first! - # Other two options: Overengingeer, Refit whole model - sd_unet.current_unet.apply_loras() + sd_unet.current_unet.apply_loras(self.lora_refit_dict) def list_unets(l): From 7959e4cafe4ce0ea9d89d80647463869b2637eee Mon Sep 17 00:00:00 2001 From: lspindler Date: Thu, 28 Dec 2023 08:36:59 -0800 Subject: [PATCH 09/12] update inastall --- install.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/install.py b/install.py index 02e5cc8..b02753d 100644 --- a/install.py +++ b/install.py @@ -6,28 +6,34 @@ def install(): - import torch - - if not torch.cuda.is_available(): - print( - "Torch CUDA is not available! Please install Torch with CUDA and try again." - ) - return - if launch.is_installed("tensorrt"): if not version("tensorrt") == "9.0.1.post11.dev4": - print("Removing old TensorRT package and try reinstalling...") launch.run( - f'"{python}" -m pip uninstall -y tensorrt', + ["python", "-m", "pip", "uninstall", "-y", "tensorrt"], "removing old version of tensorrt", ) if not launch.is_installed("tensorrt"): + print("TensorRT is not installed! Installing...") + launch.run_pip( + "install nvidia-cudnn-cu11==8.9.4.25 --no-cache-dir", "nvidia-cudnn-cu11" + ) launch.run_pip( - "install --pre --extra-index-url https://pypi.nvidia.com --no-cache-dir --no-deps tensorrt==9.0.1.post11.dev4", + "install --pre --extra-index-url https://pypi.nvidia.com tensorrt==9.0.1.post11.dev4 --no-cache-dir", "tensorrt", live=True, ) + launch.run( + ["python", "-m", "pip", "uninstall", "-y", "nvidia-cudnn-cu11"], + "removing nvidia-cudnn-cu11", + ) + + if launch.is_installed("nvidia-cudnn-cu11"): + if version("nvidia-cudnn-cu11") == "8.9.4.25": + launch.run( + ["python", "-m", "pip", "uninstall", "-y", "nvidia-cudnn-cu11"], + "removing nvidia-cudnn-cu11", + ) # Polygraphy if not launch.is_installed("polygraphy"): From f3ca680e8355ed220a1939ed2552cd15ba9e185e Mon Sep 17 00:00:00 2001 From: lspindler Date: Thu, 28 Dec 2023 08:43:53 -0800 Subject: [PATCH 10/12] change default XL engine --- datastructures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datastructures.py b/datastructures.py index 1257f65..50f0cc7 100644 --- a/datastructures.py +++ b/datastructures.py @@ -219,7 +219,7 @@ def __init__(self): 1, 1, 4, 512, 512, 768, 512, 512, 768, 75, 75, 150 ) self.default_xl = ProfileSettings( - 1, 1, 4, 768, 1024, 1024, 768, 1024, 1024, 75, 75, 150 + 1, 1, 1, 1024, 1024, 1024, 1024, 1024, 1024, 75, 75, 75 ) def get_settings_from_version(self, version: str): From 1a8bab0abdddc92b7ba34dde0012cd0ef59e693e Mon Sep 17 00:00:00 2001 From: lspindler Date: Tue, 2 Jan 2024 01:22:52 -0800 Subject: [PATCH 11/12] cc independent lora --- model_manager.py | 9 +++++++++ scripts/trt.py | 6 +++--- ui_trt.py | 33 ++++++++++++++++----------------- 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/model_manager.py b/model_manager.py index 9f7604e..9f248c8 100644 --- a/model_manager.py +++ b/model_manager.py @@ -175,6 +175,15 @@ def available_models(self): available = self.all_models.get(self.cc, {}) return available + def available_loras(self): + available = {} + for p in os.listdir(TRT_MODEL_DIR): + if not p.endswith(".lora"): + continue + available[os.path.splitext(p)[0]] = os.path.join(TRT_MODEL_DIR, p) + + return available + def get_timing_cache(self): current_dir = os.path.dirname(os.path.abspath(__file__)) cache = os.path.join( diff --git a/scripts/trt.py b/scripts/trt.py index e727e64..b9f16aa 100644 --- a/scripts/trt.py +++ b/scripts/trt.py @@ -190,7 +190,7 @@ def get_profile_idx(self, p, model_name: str, model_type: ModelType) -> (int, in return None, None else: _distances = [distances[i] for i in merged_idx] - best_hr = idx_hr[merged_idx[np.argmin(_distances)]] + best_hr = merged_idx[np.argmin(_distances)] best = best_hr return best, best_hr @@ -221,7 +221,7 @@ def get_loras(self, p): # Get pathes print("Apllying LoRAs: " + str(loras)) - available = modelmanager.available_models() + available = modelmanager.available_loras() for lora in loras: lora_name, lora_scale = lora.split(":")[1:] lora_scales.append(float(lora_scale)) @@ -230,7 +230,7 @@ def get_loras(self, p): f"Please export the LoRA checkpoint {lora_name} first from the TensorRT LoRA tab" ) lora_pathes.append( - os.path.join(TRT_MODEL_DIR, available[lora_name][0]["filepath"]) + available[lora_name] ) # Merge lora refit dicts diff --git a/ui_trt.py b/ui_trt.py index fe71075..4f19847 100644 --- a/ui_trt.py +++ b/ui_trt.py @@ -203,7 +203,7 @@ def export_lora_to_trt(lora_name, force_export): if not os.path.exists(weights_map_path): modelobj.export_weights_map(onnx_base_path, weights_map_path) - lora_trt_name = f"{lora_name}.trt" + lora_trt_name = f"{lora_name}.lora" lora_trt_path = os.path.join(TRT_MODEL_DIR, lora_trt_name) if os.path.exists(lora_trt_path) and not force_export: @@ -223,15 +223,6 @@ def export_lora_to_trt(lora_name, force_export): save_file(refit_dict, lora_trt_path) - modelmanager.add_lora_entry( - model_name, - lora_name, - lora_trt_name, - is_fp32(), - False, - 0, - 4, - ) return "## Exported Successfully \n" @@ -372,9 +363,9 @@ def get_md_table( loras_md = {} for base_model, models in available_models.items(): for i, m in enumerate(models): - if m["config"].lora: - loras_md[base_model] = m.get("base_model", None) - continue + # if m["config"].lora: + # loras_md[base_model] = m.get("base_model", None) + # continue s_min, s_opt, s_max = m["config"].profile.get( "sample", [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]] @@ -399,10 +390,11 @@ def get_md_table( model_md[base_model].append(profile_table) - for lora, base_model in loras_md.items(): - model_md[f"{lora} (*{base_model}*)"] = model_md[base_model] + available_loras = modelmanager.available_loras() + for lora, path in available_loras.items(): + loras_md[f"{lora}"] = "" - return model_md + return model_md, loras_md def on_ui_tabs(): @@ -662,12 +654,19 @@ def on_ui_tabs(): def get_trt_profiles_markdown(): profiles_md_string = "" - for model, profiles in engine_profile_card().items(): + engine_cards, lora_cards = engine_profile_card() + for model, profiles in engine_cards.items(): profiles_md_string += f"
{model} ({len(profiles)} Profiles)\n\n" for i, profile in enumerate(profiles): profiles_md_string += f"#### Profile {i} \n{profile}\n\n" profiles_md_string += "
\n" profiles_md_string += "\n" + + profiles_md_string += "\n --- \n ## LoRA Profiles \n" + for model, details in lora_cards.items(): + profiles_md_string += f"
{model}\n\n" + profiles_md_string += details + profiles_md_string += "
\n" return profiles_md_string with gr.Column(variant="panel"): From 7bcb9f42c1ccf2b726891c7978effa6d4c3dba92 Mon Sep 17 00:00:00 2001 From: Luca Date: Fri, 5 Jan 2024 12:10:31 +0100 Subject: [PATCH 12/12] Update Instrucitons --- README.md | 38 +++++++++++++++++++++----------------- info.md | 26 ++++++++++++++++---------- 2 files changed, 37 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index dc87b95..eacd212 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,8 @@ # TensorRT Extension for Stable Diffusion This extension enables the best performance on NVIDIA RTX GPUs for Stable Diffusion with TensorRT. - You need to install the extension and generate optimized engines before using the extension. Please follow the instructions below to set everything up. - -Supports Stable Diffusion 1.5 and 2.1. Native SDXL support coming in a future release. Please use the [dev branch](https://github.com/AUTOMATIC1111/stable-diffusion-webui/tree/dev) if you would like to use it today. Note that the Dev branch is not intended for production work and may break other things that you are currently using. +Supports Stable Diffusion 1.5,2.1, SDXL, SDXL Turbo, and LCM. For SDXL and SDXL Turbo, we recommend using a GPU with 12 GB or more VRAM for best performance due to its size and computational intensity. ## Installation @@ -15,44 +13,50 @@ Example instructions for Automatic1111: 3. Copy the link to this repository and paste it into URL for extension's git repository 4. Click Install + ## How to use 1. Click on the “Generate Default Engines” button. This step takes 2-10 minutes depending on your GPU. You can generate engines for other combinations. 2. Go to Settings → User Interface → Quick Settings List, add sd_unet. Apply these settings, then reload the UI. -3. Back in the main UI, select the TRT model from the sd_unet dropdown menu at the top of the page. +3. Back in the main UI, select “Automatic” from the sd_unet dropdown menu at the top of the page if not already selected. 4. You can now start generating images accelerated by TRT. If you need to create more Engines, go to the TensorRT tab. Happy prompting! +### LoRA + +To use LoRA / LyCORIS checkpoints they first need to be converted to a TensorRT format. This can be done in the TensorRT extension in the Export LoRA tab. +1. Select a LoRA checkpoint from the dropdown. +2. Export. (This will not generate an engine but only convert the weights in ~20s) +3. You can use the exported LoRAs as usual using the prompt embedding. + + ## More Information TensorRT uses optimized engines for specific resolutions and batch sizes. You can generate as many optimized engines as desired. Types: - -- The "Export Default Engines” selection adds support for resolutions between 512x512 and 768x768 for Stable Diffusion 1.5 and 768x768 to 1024x1024 for SDXL with batch sizes 1 to 4. +- The "Export Default Engines” selection adds support for resolutions between `512 x 512` and 768x768 for Stable Diffusion 1.5 and 2.1 with batch sizes 1 to 4. For SDXL, this selection generates an engine supporting a resolution of `1024 x 1024` with a batch size of `1`. - Static engines support a single specific output resolution and batch size. - Dynamic engines support a range of resolutions and batch sizes, at a small cost in performance. Wider ranges will use more VRAM. +- The first time generating an engine for a checkpoint will take longer. Additional engines generated for the same checkpoint will be much faster. Each preset can be adjusted with the “Advanced Settings” option. More detailed instructions can be found [here](https://nvidia.custhelp.com/app/answers/detail/a_id/5487/~/tensorrt-extension-for-stable-diffusion-web-ui). ### Common Issues/Limitations -**HIRES FIX:** If using the hires.fix option in Automatic1111 you must build engines that match both the starting and ending resolutions. For instance, if initial size is `512 x 512` and hires.fix upscales to `1024 x 1024`, you must either generate two engines, one at 512 and one at 1024, or generate a single dynamic engine that covers the whole range. -Having two separate engines will heavily impact performance at the moment. Stay tuned for updates. +**HIRES FIX**: If using the hires.fix option in Automatic1111 you must build engines that match both the starting and ending resolutions. For instance, if the initial size is `512 x 512` and hires.fix upscales to `1024 x 1024`, you must generate a single dynamic engine that covers the whole range. -**Resolution:** When generating images the resolution needs to be a multiple of 64. This applies to hires.fix as well, requiring the low and high-res to be divisible by 64. +**Resolution**: When generating images, the resolution needs to be a multiple of 64. This applies to hires.fix as well, requiring the low and high-res to be divisible by 64. -**Failing CMD arguments:** +**Failing CMD arguments**: -- `medvram` and `lowvram` Have caused issues when compiling the engine and running it. +- `medvram` and `lowvram` Have caused issues when compiling the engine. - `api` Has caused the `model.json` to not be updated. Resulting in SD Unets not appearing after compilation. - -**Failing installation or TensorRT tab not appearing in UI:** This is most likely due to a failed install. To resolve this manually use this [guide](https://github.com/NVIDIA/Stable-Diffusion-WebUI-TensorRT/issues/27#issuecomment-1767570566). +- Failing installation or TensorRT tab not appearing in UI: This is most likely due to a failed install. To resolve this manually use this [guide](https://github.com/NVIDIA/Stable-Diffusion-WebUI-TensorRT/issues/27#issuecomment-1767570566). ## Requirements +Driver: -**Driver**: - -- Linux: >= 450.80.02 -- Windows: >=452.39 + Linux: >= 450.80.02 +- Windows: >= 452.39 We always recommend keeping the driver up-to-date for system wide performance improvements. diff --git a/info.md b/info.md index 8ef3797..64f3073 100644 --- a/info.md +++ b/info.md @@ -1,26 +1,32 @@ # TensorRT Extension -Use this extension to generate optimized engines and enable the best performance on NVIDIA RTX GPUs with TensorRT. Please follow the instructions below to set everything up. +This extension enables the best performance on NVIDIA RTX GPUs for Stable Diffusion with TensorRT. -## Set Up +## How to use -1. Click on the "Generate Default Engines" button. This step can take 2-10 min depending on your GPU. You can generate engines for other combinations. +1. Click on the “Generate Default Engines” button. This step takes 2-10 minutes depending on your GPU. You can generate engines for other combinations. 2. Go to Settings → User Interface → Quick Settings List, add sd_unet. Apply these settings, then reload the UI. -3. Back in the main UI, select the TRT model from the sd_unet dropdown menu at the top of the page. +3. Back in the main UI, select “Automatic” from the sd_unet dropdown menu at the top of the page if not already selected. 4. You can now start generating images accelerated by TRT. If you need to create more Engines, go to the TensorRT tab. Happy prompting! +### LoRA + +To use LoRA / LyCORIS checkpoints they first need to be converted to a TensorRT format. This can be done in the TensorRT extension in the Export LoRA tab. +1. Select a LoRA checkpoint from the dropdown. +2. Export. (This will not generate an engine but only convert the weights in ~20s) +3. You can use the exported LoRAs as usual using the prompt embedding. + + ## More Information TensorRT uses optimized engines for specific resolutions and batch sizes. You can generate as many optimized engines as desired. Types: - -- The "Export Default Engines" selection adds support for resolutions between 512x512 and 768x768 for Stable Diffusion 1.5 and 768x768 to 1024x1024 for SDXL with batch sizes 1 to 4. +- The "Export Default Engines” selection adds support for resolutions between `512 x 512` and 768x768 for Stable Diffusion 1.5 and 2.1 with batch sizes 1 to 4. For SDXL, this selection generates an engine supporting a resolution of `1024 x 1024` with a batch size of `1`. - Static engines support a single specific output resolution and batch size. -- Dynamic engines support a range of resolutions and batch sizes, at a small cost in performance. Wider ranges will use more VRAM. - ---- +- Dynamic engines support a range of resolutions and batch sizes, at a small cost in performance. Wider ranges will use more VRAM. +- The first time generating an engine for a checkpoint will take longer. Additional engines generated for the same checkpoint will be much faster. -Each preset can be adjusted with the "Advanced Settings" option. +Each preset can be adjusted with the “Advanced Settings” option. More detailed instructions can be found [here](https://nvidia.custhelp.com/app/answers/detail/a_id/5487/~/tensorrt-extension-for-stable-diffusion-web-ui). For more information, please visit the TensorRT Extension GitHub page [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui-tensorrt).