From 8924b57381fecb8cc3fb789035fca5f24b96f2b4 Mon Sep 17 00:00:00 2001 From: suzukimain <131413573+suzukimain@users.noreply.github.com> Date: Thu, 21 Nov 2024 06:27:55 +0000 Subject: [PATCH 01/32] Added example of model search. --- examples/model_search/pipeline_output.py | 52 ++ .../model_search/pipeline_search_for_hubs.py | 638 ++++++++++++++++++ 2 files changed, 690 insertions(+) create mode 100644 examples/model_search/pipeline_output.py create mode 100644 examples/model_search/pipeline_search_for_hubs.py diff --git a/examples/model_search/pipeline_output.py b/examples/model_search/pipeline_output.py new file mode 100644 index 000000000000..db21f2f1ad64 --- /dev/null +++ b/examples/model_search/pipeline_output.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass + +@dataclass +class RepoStatus: + """ + Data class for storing repository status information. + + Attributes: + repo_id (str): The name of the repository. + repo_hash (str): The hash of the repository. + version (str): The version ID of the repository. + """ + repo_id: str = "" + repo_hash: str = "" + version: str = "" + + +@dataclass +class ModelStatus: + """ + Data class for storing model status information. + + Attributes: + search_word (str): The search word used to find the model. + download_url (str): The URL to download the model. + file_name (str): The name of the model file. + file_id (str): The ID of the model file. + fp (str): Floating-point precision formats. + local (bool): Whether the model is stored locally. + """ + search_word: str = "" + download_url: str = "" + file_name: str = "" + local: bool = False + + +@dataclass +class SearchPipelineOutput: + """ + Data class for storing model data. + + Attributes: + model_path (str): The path to the model. + load_type (str): The type of loading method used for the model. + repo_status (RepoStatus): The status of the repository. + model_status (ModelStatus): The status of the model. 
+ """ + model_path: str = "" + loading_method: str = "" # "" or "from_single_file" or "from_pretrained" + checkpoint_format: str = None # "single_file" or "diffusers" + repo_status: RepoStatus = RepoStatus() + model_status: ModelStatus = ModelStatus() \ No newline at end of file diff --git a/examples/model_search/pipeline_search_for_hubs.py b/examples/model_search/pipeline_search_for_hubs.py new file mode 100644 index 000000000000..a2f507690d75 --- /dev/null +++ b/examples/model_search/pipeline_search_for_hubs.py @@ -0,0 +1,638 @@ +import os +import re +import requests +from typing import Union +from tqdm.auto import tqdm +from dataclasses import asdict +from huggingface_hub import ( + hf_api, + hf_hub_download, +) + +from diffusers.utils import logging +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.loaders.single_file_utils import ( + VALID_URL_PREFIXES, + _extract_repo_id_and_weights_name, +) + +from .pipeline_output import ( + SearchPipelineOutput, + ModelStatus, + RepoStatus, +) + + + +CUSTOM_SEARCH_KEY = { + "sd" : "stabilityai/stable-diffusion-2-1", + } + + +CONFIG_FILE_LIST = [ + "preprocessor_config.json", + "config.json", + "model.safetensors", + "model.fp16.safetensors", + "model.ckpt", + "pytorch_model.bin", + "pytorch_model.fp16.bin", + "scheduler_config.json", + "special_tokens_map.json", + "tokenizer_config.json", + "vocab.json", + "diffusion_pytorch_model.bin", + "diffusion_pytorch_model.fp16.bin", + "diffusion_pytorch_model.safetensors", + "diffusion_pytorch_model.fp16.safetensors", + "diffusion_pytorch_model.ckpt", + "diffusion_pytorch_model.fp16.ckpt", + "diffusion_pytorch_model.non_ema.bin", + "diffusion_pytorch_model.non_ema.safetensors", + "safety_checker/pytorch_model.bin", + "safety_checker/model.safetensors", + "safety_checker/model.ckpt", + "safety_checker/model.fp16.safetensors", + "safety_checker/model.fp16.ckpt", + "unet/diffusion_pytorch_model.bin", + "unet/diffusion_pytorch_model.safetensors", + "unet/diffusion_pytorch_model.fp16.safetensors", + "unet/diffusion_pytorch_model.ckpt", + "unet/diffusion_pytorch_model.fp16.ckpt", + "vae/diffusion_pytorch_model.bin", + "vae/diffusion_pytorch_model.safetensors", + "vae/diffusion_pytorch_model.fp16.safetensors", + "vae/diffusion_pytorch_model.ckpt", + "vae/diffusion_pytorch_model.fp16.ckpt", + "text_encoder/pytorch_model.bin", + "text_encoder/model.safetensors", + "text_encoder/model.fp16.safetensors", + "text_encoder/model.ckpt", + "text_encoder/model.fp16.ckpt", + "text_encoder_2/model.safetensors", + "text_encoder_2/model.ckpt" +] + +EXTENSION = [".safetensors", ".ckpt",".bin"] + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + +def get_keyword_types(keyword): + """ + Determine the type and loading method for a given keyword. + + Args: + keyword (str): The input keyword to classify. + + Returns: + dict: A dictionary containing the model format, loading method, + and various types and extra types flags. 
+ """ + + # Initialize the status dictionary with default values + status = { + "checkpoint_format": None, + "loading_method": None, + "type": { + "search_word": False, + "hf_url": False, + "hf_repo": False, + "civitai_url": False, + "local": False, + }, + "extra_type": { + "url": False, + "missing_model_index": None, + }, + } + + # Check if the keyword is an HTTP or HTTPS URL + status["extra_type"]["url"] = bool(re.search(r"^(https?)://", keyword)) + + # Check if the keyword is a file + if os.path.isfile(keyword): + status["type"]["local"] = True + status["checkpoint_format"] = "single_file" + status["loading_method"] = "from_single_file" + + # Check if the keyword is a directory + elif os.path.isdir(keyword): + status["type"]["local"] = True + status["checkpoint_format"] = "diffusers" + status["loading_method"] = "from_pretrained" + if not os.path.exists(os.path.join(keyword, "model_index.json")): + status["extra_type"]["missing_model_index"] = True + + # Check if the keyword is a Civitai URL + elif keyword.startswith("https://civitai.com/"): + status["type"]["civitai_url"] = True + status["checkpoint_format"] = "single_file" + status["loading_method"] = None + + # Check if the keyword starts with any valid URL prefixes + elif any(keyword.startswith(prefix) for prefix in VALID_URL_PREFIXES): + repo_id, weights_name = _extract_repo_id_and_weights_name(keyword) + if weights_name: + status["type"]["hf_url"] = True + status["checkpoint_format"] = "single_file" + status["loading_method"] = "from_single_file" + else: + status["type"]["hf_repo"] = True + status["checkpoint_format"] = "diffusers" + status["loading_method"] = "from_pretrained" + + # Check if the keyword matches a Hugging Face repository format + elif re.match(r"^[^/]+/[^/]+$", keyword): + status["type"]["hf_repo"] = True + status["checkpoint_format"] = "diffusers" + status["loading_method"] = "from_pretrained" + + # If none of the above, treat it as a search word + else: + status["type"]["search_word"] = True + status["checkpoint_format"] = None + status["loading_method"] = None + + return status + + +class HFSearchPipeline: + """ + Search for models from Huggingface. + """ + model_info = { + "model_path": "", + "load_type": "", + "repo_status": { + "repo_name": "", + "repo_id": "", + "revision": "" + }, + "model_status": { + "search_word": "", + "download_url": "", + "filename": "", + "local": False, + "single_file": False + }, + } + + def __init__(self): + pass + + + @staticmethod + def create_huggingface_url(repo_id, file_name): + """ + Create a Hugging Face URL for a given repository ID and file name. + + Args: + repo_id (str): The repository ID. + file_name (str): The file name within the repository. + + Returns: + str: The complete URL to the file or repository on Hugging Face. + """ + if file_name: + return f"https://huggingface.co/{repo_id}/blob/main/{file_name}" + else: + return f"https://huggingface.co/{repo_id}" + + @staticmethod + def hf_find_safest_model(models) -> str: + """ + Sort and find the safest model. + + Args: + models (list): A list of model names to sort and check. + + Returns: + The name of the safest model or the first model in the list if no safe model is found. + """ + for model in sorted(models, reverse=True): + if bool(re.search(r"(?i)[-_](safe|sfw)", model)): + return model + return models[0] + + + @classmethod + def for_HF(cls, search_word, **kwargs): + """ + Class method to search and download models from Hugging Face. + + Args: + search_word (str): The search keyword for finding models. 
+ **kwargs: Additional keyword arguments. + + Returns: + str: The path to the downloaded model or search word. + """ + # Extract additional parameters from kwargs + revision = kwargs.pop("revision", None) + checkpoint_format = kwargs.pop("checkpoint_format", "single_file") + download = kwargs.pop("download", False) + force_download = kwargs.pop("force_download", False) + include_params = kwargs.pop("include_params", False) + pipeline_tag = kwargs.pop("pipeline_tag", None) + hf_token = kwargs.pop("hf_token", None) + skip_error = kwargs.pop("skip_error", False) + + # Get the type and loading method for the keyword + search_word_status = get_keyword_types(search_word) + + # Handle different types of keywords + if search_word_status["type"]["hf_repo"]: + if download: + model_path = DiffusionPipeline.download( + search_word, + revision=revision, + token=hf_token + ) + else: + model_path = search_word + elif search_word_status["type"]["hf_url"]: + repo_id, weights_name = _extract_repo_id_and_weights_name(search_word) + if download: + model_path = hf_hub_download( + repo_id=repo_id, + filename=weights_name, + force_download=force_download, + token=hf_token + ) + else: + model_path = search_word + elif search_word_status["type"]["local"]: + model_path = search_word + elif search_word_status["type"]["civitai_url"]: + if skip_error: + return None + else: + raise ValueError("The URL for Civitai is invalid with `for_hf`. Please use `for_civitai` instead.") + + else: + # Get model data from HF API + hf_models = hf_api.list_models( + search=search_word, + sort="downloads", + direction=-1, + limit=100, + fetch_config=True, + pipeline_tag=pipeline_tag, + full=True, + token=hf_token + ) + model_dicts = [asdict(value) for value in list(hf_models)] + + hf_repo_info = {} + file_list = [] + repo_id, file_name = "", "" + + # Loop through models to find a suitable candidate + for repo_info in model_dicts: + repo_id = repo_info["id"] + file_list = [] + hf_repo_info = hf_api.model_info( + repo_id=repo_id, + securityStatus=True + ) + # Lists files with security issues. + hf_security_info = hf_repo_info.security_repo_status + exclusion = [issue['path'] for issue in hf_security_info['filesWithIssues']] + + # Checks for multi-folder diffusers model or valid files (models with security issues are excluded). 
+ diffusers_model_exists = False + if hf_security_info["scansDone"]: + for info in repo_info["siblings"]: + file_path = info["rfilename"] + if ( + "model_index.json" == file_path + and checkpoint_format in ["diffusers", "all"] + ): + diffusers_model_exists = True + break + + elif ( + any(file_path.endswith(ext) for ext in EXTENSION) + and (file_path not in CONFIG_FILE_LIST) + and (file_path not in exclusion) + ): + file_list.append(file_path) + + # Exit from the loop if a multi-folder diffusers model or valid file is found + if diffusers_model_exists or file_list: + break + else: + # Handle case where no models match the criteria + if skip_error: + return None + else: + raise ValueError("No models matching your criteria were found on huggingface.") + + download_url = cls.create_huggingface_url( + repo_id=repo_id, file_name=file_name + ) + if diffusers_model_exists: + if download: + model_path = DiffusionPipeline.download( + repo_id=repo_id, + token=hf_token, + ) + else: + model_path = repo_id + elif file_list: + file_name = cls.hf_find_safest_model(file_list) + if download: + model_path = hf_hub_download( + repo_id=repo_id, + filename=file_name, + revision=revision, + token=hf_token + ) + else: + model_path = cls.create_huggingface_url( + repo_id=repo_id, file_name=file_name + ) + + output_info = get_keyword_types(model_path) + + if include_params: + return SearchPipelineOutput( + model_path=model_path, + loading_method=output_info["loading_method"], + checkpoint_format=output_info["checkpoint_format"], + repo_status=RepoStatus( + repo_id=repo_id, + repo_hash=hf_repo_info["sha"], + version=revision + ), + model_status=ModelStatus( + search_word=search_word, + download_url=download_url, + file_name=file_name, + local=download, + ) + ) + + else: + return model_path + + + +class CivitaiSearchPipeline: + """ + The Civitai class is used to search and download models from Civitai. + + Attributes: + base_civitai_dir (str): Base directory for Civitai. + max_number_of_choices (int): Maximum number of choices. + chunk_size (int): Chunk size. + + Methods: + for_civitai(search_word, auto, model_type, download, civitai_token, skip_error, include_hugface): + Downloads a model from Civitai. + civitai_security_check(value): Performs a security check. + requests_civitai(query, auto, model_type, civitai_token, include_hugface): Retrieves models from Civitai. + repo_select_civitai(state, auto, recursive, include_hugface): Selects a repository from Civitai. + download_model(url, save_path, civitai_token): Downloads a model. + version_select_civitai(state, auto, recursive): Selects a model version from Civitai. + file_select_civitai(state_list, auto, recursive): Selects a file to download. + civitai_save_path(): Sets the save path. + """ + + base_civitai_dir = "/root/.cache/Civitai" + max_number_of_choices: int = 15 + chunk_size: int = 8192 + + def __init__(self): + pass + + @staticmethod + def civitai_find_safest_model(models) -> str: + """ + Sort and find the safest model. + + Args: + models (list): A list of model names to check. + + Returns: + The name of the safest model or the first model in the list if no safe model is found. + """ + + for model_data in models: + if bool(re.search(r"(?i)[-_](safe|sfw)", model_data["filename"])): + return model_data + return models[0] + + + @classmethod + def for_civitai( + cls, + search_word, + **kwargs + ) -> Union[str,SearchPipelineOutput,None]: + """ + Downloads a model from Civitai. + + Parameters: + - search_word (str): Search query string. 
+ - auto (bool): Auto-select flag. + - model_type (str): Type of model to search for. + - download (bool): Whether to download the model. + - include_params (bool): Whether to include parameters in the returned data. + + Returns: + - SearchPipelineOutput + """ + model_type = kwargs.pop("model_type", "Checkpoint") + download = kwargs.pop("download", False) + force_download = kwargs.pop("force_download", False) + civitai_token = kwargs.pop("civitai_token", None) + include_params = kwargs.pop("include_params", False) + skip_error = kwargs.pop("skip_error", False) + + model_info = { + "model_path" : "", + "load_type" : "", + "repo_status":{ + "repo_name":"", + "repo_id":"", + "revision":"" + }, + "model_status":{ + "search_word" : "", + "download_url": "", + "filename":"", + "local" : False, + "single_file" : False + }, + } + + params = { + "query": search_word, + "types": model_type, + "sort": "Highest Rated", + "limit":20 + } + + headers = {} + if civitai_token: + headers["Authorization"] = f"Bearer {civitai_token}" + + try: + response = requests.get( + "https://civitai.com/api/v1/models", params=params, headers=headers + ) + response.raise_for_status() + except requests.exceptions.HTTPError as err: + raise requests.HTTPError(f"Could not get elements from the URL: {err}") + else: + try: + data = response.json() + except AttributeError: + if skip_error: + return None + else: + raise ValueError("Invalid JSON response") + # Put the repo sorting process on this line. + sorted_repos = sorted(data["items"], key=lambda x: x["stats"]["downloadCount"], reverse=True) + + model_path = "" + repo_name = "" + repo_id = "" + version_id = "" + models_list = [] + selected_repo = {} + selected_model = {} + selected_version = {} + + for selected_repo in sorted_repos: + repo_name = selected_repo["name"] + repo_id = selected_repo["id"] + + sorted_versions = sorted(selected_repo["modelVersions"], key=lambda x: x["stats"]["downloadCount"], reverse=True) + for selected_version in sorted_versions: + version_id = selected_version["id"] + models_list = [] + for model_data in selected_version["files"]: + if ( + model_data["pickleScanResult"] == "Success" + and model_data["virusScanResult"] == "Success" + and any(model_data["name"].endswith(ext) for ext in EXTENSION) + ): + file_status = { + "filename": model_data["name"], + "download_url": model_data["downloadUrl"], + } + models_list.append(file_status) + + if models_list: + sorted_models = sorted(models_list, key=lambda x: x["filename"], reverse=True) + selected_model = cls.civitai_find_safest_model(sorted_models) + break + else: + continue + break + + if not selected_model: + if skip_error: + return None + else: + raise ValueError("No models found") + + file_name = selected_model["filename"] + download_url = selected_model["download_url"] + # Handle file download and setting model information + if download: + model_path = f"/root/.cache/Civitai/{repo_id}/{version_id}/{file_name}" + os.makedirs(os.path.dirname(model_path), exist_ok=True) + if (not os.path.exists(model_path)) or force_download: + headers = {} + if civitai_token: + headers["Authorization"] = f"Bearer {civitai_token}" + + try: + response = requests.get(download_url, stream=True, headers=headers) + response.raise_for_status() + except requests.HTTPError: + raise requests.HTTPError(f"Invalid URL: {download_url}, {response.status_code}") + + with tqdm.wrapattr( + open(model_path, "wb"), + "write", + miniters=1, + desc=file_name, + total=int(response.headers.get("content-length", 0)), + ) as 
fetched_model_info: + for chunk in response.iter_content(chunk_size=8192): + fetched_model_info.write(chunk) + else: + model_path = download_url + + output_info = get_keyword_types(model_path) + + # Return appropriate result based on include_params + if not include_params: + return model_path + else: + return SearchPipelineOutput( + model_path=model_path, + loading_method=output_info["loading_method"], + checkpoint_format=output_info["checkpoint_format"], + repo_status=RepoStatus( + repo_id=repo_name, + repo_hash=repo_id, + version=version_id + ), + model_status=ModelStatus( + search_word=search_word, + download_url=download_url, + file_name=file_name, + local=output_info["type"]["local"] + ) + ) + + +class ModelSearchPipeline( + HFSearchPipeline, + CivitaiSearchPipeline + ): + + def __init__(self): + pass + + @classmethod + def for_hubs( + cls, + search_word: str, + **kwargs + ) -> Union[None, str, SearchPipelineOutput]: + """ + Search and retrieve model information from various sources (e.g., Hugging Face or CivitAI). + + This method allows flexible searching of models across different hubs. It accepts several parameters + to customize the search behavior, such as filtering by model type, format, or priority hub. Additionally, + it supports authentication tokens for private or restricted access. + + Args: + search_word (str): The search term or keyword used to locate the desired model. + download (bool, optional): Whether to download the model locally after finding it. Defaults to False. + model_type (str, optional): Type of the model to search for (e.g., "Checkpoint", "LORA"). Defaults to "Checkpoint". + checkpoint_format (str, optional): Specifies the format of the model (e.g., "single_file", "diffusers"). Defaults to "single_file". + branch (str, optional): The branch of the repository to search in. Defaults to "main". + include_params (bool, optional): Whether to include additional parameters about the model in the output. Defaults to False. + hf_token (str, optional): Hugging Face API token for authentication. Required for private or restricted models. + civitai_token (str, optional): CivitAI API token for authentication. Required for private or restricted models. + + Returns: + Union[None, str, SearchPipelineOutput]: + - `None`: If no model is found or accessible. + - `str`: A string path to the retrieved model if `include_params=False`. + - `SearchPipelineOutput`: Detailed model information if `include_params=True`. + """ + return ( + cls.for_HF(search_word=search_word, skip_error=True, **kwargs) + or cls.for_HF(search_word=search_word, skip_error=True, **kwargs) + ) \ No newline at end of file From 3e158b39efaa70a1c29d219de625402b215260e6 Mon Sep 17 00:00:00 2001 From: suzukimain Date: Thu, 21 Nov 2024 21:12:02 +0900 Subject: [PATCH 02/32] Combine processing into one file --- examples/model_search/pipeline_output.py | 52 --- ...r_hubs.py => search_for_civitai_and_HF.py} | 361 ++++++++++-------- 2 files changed, 206 insertions(+), 207 deletions(-) delete mode 100644 examples/model_search/pipeline_output.py rename examples/model_search/{pipeline_search_for_hubs.py => search_for_civitai_and_HF.py} (70%) diff --git a/examples/model_search/pipeline_output.py b/examples/model_search/pipeline_output.py deleted file mode 100644 index db21f2f1ad64..000000000000 --- a/examples/model_search/pipeline_output.py +++ /dev/null @@ -1,52 +0,0 @@ -from dataclasses import dataclass - -@dataclass -class RepoStatus: - """ - Data class for storing repository status information. 
- - Attributes: - repo_id (str): The name of the repository. - repo_hash (str): The hash of the repository. - version (str): The version ID of the repository. - """ - repo_id: str = "" - repo_hash: str = "" - version: str = "" - - -@dataclass -class ModelStatus: - """ - Data class for storing model status information. - - Attributes: - search_word (str): The search word used to find the model. - download_url (str): The URL to download the model. - file_name (str): The name of the model file. - file_id (str): The ID of the model file. - fp (str): Floating-point precision formats. - local (bool): Whether the model is stored locally. - """ - search_word: str = "" - download_url: str = "" - file_name: str = "" - local: bool = False - - -@dataclass -class SearchPipelineOutput: - """ - Data class for storing model data. - - Attributes: - model_path (str): The path to the model. - load_type (str): The type of loading method used for the model. - repo_status (RepoStatus): The status of the repository. - model_status (ModelStatus): The status of the model. - """ - model_path: str = "" - loading_method: str = "" # "" or "from_single_file" or "from_pretrained" - checkpoint_format: str = None # "single_file" or "diffusers" - repo_status: RepoStatus = RepoStatus() - model_status: ModelStatus = ModelStatus() \ No newline at end of file diff --git a/examples/model_search/pipeline_search_for_hubs.py b/examples/model_search/search_for_civitai_and_HF.py similarity index 70% rename from examples/model_search/pipeline_search_for_hubs.py rename to examples/model_search/search_for_civitai_and_HF.py index a2f507690d75..501de9eb0a80 100644 --- a/examples/model_search/pipeline_search_for_hubs.py +++ b/examples/model_search/search_for_civitai_and_HF.py @@ -1,9 +1,15 @@ import os import re import requests -from typing import Union +from typing import ( + Union, + List +) from tqdm.auto import tqdm -from dataclasses import asdict +from dataclasses import ( + asdict, + dataclass +) from huggingface_hub import ( hf_api, hf_hub_download, @@ -16,18 +22,6 @@ _extract_repo_id_and_weights_name, ) -from .pipeline_output import ( - SearchPipelineOutput, - ModelStatus, - RepoStatus, -) - - - -CUSTOM_SEARCH_KEY = { - "sd" : "stabilityai/stable-diffusion-2-1", - } - CONFIG_FILE_LIST = [ "preprocessor_config.json", @@ -80,16 +74,81 @@ -def get_keyword_types(keyword): +@dataclass +class RepoStatus: + r""" + Data class for storing repository status information. + + Attributes: + repo_id (`str`): + The name of the repository. + repo_hash (`str`): + The hash of the repository. + version (`str`): + The version ID of the repository. """ - Determine the type and loading method for a given keyword. + repo_id: str = "" + repo_hash: str = "" + version: str = "" + + +@dataclass +class ModelStatus: + r""" + Data class for storing model status information. + + Attributes: + search_word (`str`): + The search word used to find the model. + download_url (`str`): + The URL to download the model. + file_name (`str`): + The name of the model file. + local (`bool`): + Whether the model exists locally + """ + search_word: str = "" + download_url: str = "" + file_name: str = "" + local: bool = False + + +@dataclass +class SearchPipelineOutput: + r""" + Data class for storing model data. + + Attributes: + model_path (`str`): + The path to the model. 
+ loading_method (`str`): + The type of loading method used for the model ( None or 'from_single_file' or 'from_pretrained') + checkpoint_format (`str`): + The format of the model checkpoint (`single_file` or `diffusers`). + repo_status (`RepoStatus`): + The status of the repository. + model_status (`ModelStatus`): + The status of the model. + """ + model_path: str = "" + loading_method: str = None + checkpoint_format: str = None + repo_status: RepoStatus = RepoStatus() + model_status: ModelStatus = ModelStatus() - Args: - keyword (str): The input keyword to classify. - + + +def get_keyword_types(keyword): + r""" + Determine the type and loading method for a given keyword. + + Parameters: + keyword (`str`): + The input keyword to classify. + Returns: - dict: A dictionary containing the model format, loading method, - and various types and extra types flags. + `dict`: A dictionary containing the model format, loading method, + and various types and extra types flags. """ # Initialize the status dictionary with default values @@ -163,38 +222,23 @@ class HFSearchPipeline: """ Search for models from Huggingface. """ - model_info = { - "model_path": "", - "load_type": "", - "repo_status": { - "repo_name": "", - "repo_id": "", - "revision": "" - }, - "model_status": { - "search_word": "", - "download_url": "", - "filename": "", - "local": False, - "single_file": False - }, - } def __init__(self): pass - @staticmethod def create_huggingface_url(repo_id, file_name): - """ + r""" Create a Hugging Face URL for a given repository ID and file name. - - Args: - repo_id (str): The repository ID. - file_name (str): The file name within the repository. - + + Parameters: + repo_id (`str`): + The repository ID. + file_name (`str`): + The file name within the repository. + Returns: - str: The complete URL to the file or repository on Hugging Face. + `str`: The complete URL to the file or repository on Hugging Face. """ if file_name: return f"https://huggingface.co/{repo_id}/blob/main/{file_name}" @@ -203,32 +247,48 @@ def create_huggingface_url(repo_id, file_name): @staticmethod def hf_find_safest_model(models) -> str: - """ + r""" Sort and find the safest model. - Args: - models (list): A list of model names to sort and check. + Parameters: + models (`list`): + A list of model names to sort and check. Returns: - The name of the safest model or the first model in the list if no safe model is found. + `str`: The name of the safest model or the first model in the list if no safe model is found. """ for model in sorted(models, reverse=True): if bool(re.search(r"(?i)[-_](safe|sfw)", model)): return model return models[0] + + @classmethod + def for_HF(cls, search_word: str, **kwargs) -> Union[str, SearchPipelineOutput, None]: + r""" + Downloads a model from Hugging Face. + Parameters: + search_word (`str`): + The search query string. + revision (`str`, *optional*): + The specific version of the model to download. + checkpoint_format (`str`, *optional*, defaults to `"single_file"`): + The format of the model checkpoint. + download (`bool`, *optional*, defaults to `False`): + Whether to download the model. + force_download (`bool`, *optional*, defaults to `False`): + Whether to force the download if the model already exists. + include_params (`bool`, *optional*, defaults to `False`): + Whether to include parameters in the returned data. + pipeline_tag (`str`, *optional*): + Tag to filter models by pipeline. + hf_token (`str`, *optional*): + API token for Hugging Face authentication. 
+ skip_error (`bool`, *optional*, defaults to `False`): + Whether to skip errors and return None. - @classmethod - def for_HF(cls, search_word, **kwargs): - """ - Class method to search and download models from Hugging Face. - - Args: - search_word (str): The search keyword for finding models. - **kwargs: Additional keyword arguments. - Returns: - str: The path to the downloaded model or search word. + `Union[str, SearchPipelineOutput, None]`: The model path or SearchPipelineOutput or None. """ # Extract additional parameters from kwargs revision = kwargs.pop("revision", None) @@ -383,69 +443,60 @@ def for_HF(cls, search_word, **kwargs): class CivitaiSearchPipeline: """ - The Civitai class is used to search and download models from Civitai. - - Attributes: - base_civitai_dir (str): Base directory for Civitai. - max_number_of_choices (int): Maximum number of choices. - chunk_size (int): Chunk size. - - Methods: - for_civitai(search_word, auto, model_type, download, civitai_token, skip_error, include_hugface): - Downloads a model from Civitai. - civitai_security_check(value): Performs a security check. - requests_civitai(query, auto, model_type, civitai_token, include_hugface): Retrieves models from Civitai. - repo_select_civitai(state, auto, recursive, include_hugface): Selects a repository from Civitai. - download_model(url, save_path, civitai_token): Downloads a model. - version_select_civitai(state, auto, recursive): Selects a model version from Civitai. - file_select_civitai(state_list, auto, recursive): Selects a file to download. - civitai_save_path(): Sets the save path. + Find checkpoints and more from Civitai. """ - base_civitai_dir = "/root/.cache/Civitai" - max_number_of_choices: int = 15 - chunk_size: int = 8192 - def __init__(self): pass @staticmethod - def civitai_find_safest_model(models) -> str: - """ + def civitai_find_safest_model(models: List[dict]) -> dict: + r""" Sort and find the safest model. - - Args: - models (list): A list of model names to check. - + + Parameters: + models (`list`): + A list of model dictionaries to check. Each dictionary should contain a 'filename' key. + Returns: - The name of the safest model or the first model in the list if no safe model is found. + `dict`: The dictionary of the safest model or the first model in the list if no safe model is found. """ - + for model_data in models: if bool(re.search(r"(?i)[-_](safe|sfw)", model_data["filename"])): return model_data return models[0] - @classmethod def for_civitai( cls, - search_word, + search_word: str, **kwargs - ) -> Union[str,SearchPipelineOutput,None]: - """ + ) -> Union[str, SearchPipelineOutput, None]: + r""" Downloads a model from Civitai. Parameters: - - search_word (str): Search query string. - - auto (bool): Auto-select flag. - - model_type (str): Type of model to search for. - - download (bool): Whether to download the model. - - include_params (bool): Whether to include parameters in the returned data. + search_word (`str`): + The search query string. + model_type (`str`, *optional*, defaults to `Checkpoint`): + The type of model to search for. + download (`bool`, *optional*, defaults to `False`): + Whether to download the model. + force_download (`bool`, *optional*, defaults to `False`): + Whether to force the download if the model already exists. + civitai_token (`str`, *optional*): + API token for Civitai authentication. + include_params (`bool`, *optional*, defaults to `False`): + Whether to include parameters in the returned data. 
+ skip_error (`bool`, *optional*, defaults to `False`): + Whether to skip errors and return None. Returns: - - SearchPipelineOutput + `Union[str, SearchPipelineOutput, None]`: The model path or `SearchPipelineOutput` or None. """ + + # Extract additional parameters from kwargs model_type = kwargs.pop("model_type", "Checkpoint") download = kwargs.pop("download", False) force_download = kwargs.pop("force_download", False) @@ -453,35 +504,31 @@ def for_civitai( include_params = kwargs.pop("include_params", False) skip_error = kwargs.pop("skip_error", False) - model_info = { - "model_path" : "", - "load_type" : "", - "repo_status":{ - "repo_name":"", - "repo_id":"", - "revision":"" - }, - "model_status":{ - "search_word" : "", - "download_url": "", - "filename":"", - "local" : False, - "single_file" : False - }, - } - + # Initialize additional variables with default values + model_path = "" + repo_name = "" + repo_id = "" + version_id = "" + models_list = [] + selected_repo = {} + selected_model = {} + selected_version = {} + + # Set up parameters and headers for the CivitAI API request params = { "query": search_word, "types": model_type, "sort": "Highest Rated", - "limit":20 - } + "limit": 20 + } + headers = {} if civitai_token: headers["Authorization"] = f"Bearer {civitai_token}" try: + # Make the request to the CivitAI API response = requests.get( "https://civitai.com/api/v1/models", params=params, headers=headers ) @@ -496,27 +543,21 @@ def for_civitai( return None else: raise ValueError("Invalid JSON response") - # Put the repo sorting process on this line. + + # Sort repositories by download count in descending order sorted_repos = sorted(data["items"], key=lambda x: x["stats"]["downloadCount"], reverse=True) - model_path = "" - repo_name = "" - repo_id = "" - version_id = "" - models_list = [] - selected_repo = {} - selected_model = {} - selected_version = {} - for selected_repo in sorted_repos: repo_name = selected_repo["name"] repo_id = selected_repo["id"] - + + # Sort versions within the selected repo by download count sorted_versions = sorted(selected_repo["modelVersions"], key=lambda x: x["stats"]["downloadCount"], reverse=True) for selected_version in sorted_versions: version_id = selected_version["id"] models_list = [] for model_data in selected_version["files"]: + # Check if the file passes security scans and has a valid extension if ( model_data["pickleScanResult"] == "Success" and model_data["virusScanResult"] == "Success" @@ -527,23 +568,25 @@ def for_civitai( "download_url": model_data["downloadUrl"], } models_list.append(file_status) - + if models_list: + # Sort the models list by filename and find the safest model sorted_models = sorted(models_list, key=lambda x: x["filename"], reverse=True) selected_model = cls.civitai_find_safest_model(sorted_models) break else: continue break - + if not selected_model: if skip_error: return None else: - raise ValueError("No models found") + raise ValueError("No model found. 
Please try changing the word you are searching for.") file_name = selected_model["filename"] download_url = selected_model["download_url"] + # Handle file download and setting model information if download: model_path = f"/root/.cache/Civitai/{repo_id}/{version_id}/{file_name}" @@ -551,14 +594,14 @@ def for_civitai( if (not os.path.exists(model_path)) or force_download: headers = {} if civitai_token: - headers["Authorization"] = f"Bearer {civitai_token}" - + headers["Authorization"] = f"Bearer {civitai_token}" + try: response = requests.get(download_url, stream=True, headers=headers) response.raise_for_status() except requests.HTTPError: raise requests.HTTPError(f"Invalid URL: {download_url}, {response.status_code}") - + with tqdm.wrapattr( open(model_path, "wb"), "write", @@ -573,7 +616,6 @@ def for_civitai( output_info = get_keyword_types(model_path) - # Return appropriate result based on include_params if not include_params: return model_path else: @@ -595,6 +637,7 @@ def for_civitai( ) + class ModelSearchPipeline( HFSearchPipeline, CivitaiSearchPipeline @@ -609,29 +652,37 @@ def for_hubs( search_word: str, **kwargs ) -> Union[None, str, SearchPipelineOutput]: - """ - Search and retrieve model information from various sources (e.g., Hugging Face or CivitAI). - - This method allows flexible searching of models across different hubs. It accepts several parameters - to customize the search behavior, such as filtering by model type, format, or priority hub. Additionally, - it supports authentication tokens for private or restricted access. - - Args: - search_word (str): The search term or keyword used to locate the desired model. - download (bool, optional): Whether to download the model locally after finding it. Defaults to False. - model_type (str, optional): Type of the model to search for (e.g., "Checkpoint", "LORA"). Defaults to "Checkpoint". - checkpoint_format (str, optional): Specifies the format of the model (e.g., "single_file", "diffusers"). Defaults to "single_file". - branch (str, optional): The branch of the repository to search in. Defaults to "main". - include_params (bool, optional): Whether to include additional parameters about the model in the output. Defaults to False. - hf_token (str, optional): Hugging Face API token for authentication. Required for private or restricted models. - civitai_token (str, optional): CivitAI API token for authentication. Required for private or restricted models. + r""" + Search and download models from multiple hubs. + + Parameters: + search_word (`str`): + The search query string. + model_type (`str`, *optional*, defaults to `Checkpoint`, Civitai only): + The type of model to search for. + revision (`str`, *optional*, Hugging Face only): + The specific version of the model to download. + include_params (`bool`, *optional*, defaults to `False`, both): + Whether to include parameters in the returned data. + checkpoint_format (`str`, *optional*, defaults to `"single_file"`, Hugging Face only): + The format of the model checkpoint. + download (`bool`, *optional*, defaults to `False`, both): + Whether to download the model. + pipeline_tag (`str`, *optional*, Hugging Face only): + Tag to filter models by pipeline. + force_download (`bool`, *optional*, defaults to `False`, both): + Whether to force the download if the model already exists. + hf_token (`str`, *optional*, Hugging Face only): + API token for Hugging Face authentication. + civitai_token (`str`, *optional*, Civitai only): + API token for Civitai authentication. 
+ skip_error (`bool`, *optional*, defaults to `False`, both): + Whether to skip errors and return None. Returns: - Union[None, str, SearchPipelineOutput]: - - `None`: If no model is found or accessible. - - `str`: A string path to the retrieved model if `include_params=False`. - - `SearchPipelineOutput`: Detailed model information if `include_params=True`. + `Union[None, str, SearchPipelineOutput]`: The model path, SearchPipelineOutput, or None if not found. """ + return ( cls.for_HF(search_word=search_word, skip_error=True, **kwargs) or cls.for_HF(search_word=search_word, skip_error=True, **kwargs) From ca50857bb10817a20e4b8046b81af1ea992da3c1 Mon Sep 17 00:00:00 2001 From: suzukimain Date: Thu, 21 Nov 2024 22:06:42 +0900 Subject: [PATCH 03/32] Add parameters for base model --- examples/model_search/search_for_civitai_and_HF.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/model_search/search_for_civitai_and_HF.py b/examples/model_search/search_for_civitai_and_HF.py index 501de9eb0a80..817b379abeab 100644 --- a/examples/model_search/search_for_civitai_and_HF.py +++ b/examples/model_search/search_for_civitai_and_HF.py @@ -481,6 +481,8 @@ def for_civitai( The search query string. model_type (`str`, *optional*, defaults to `Checkpoint`): The type of model to search for. + base_model (`str`, *optional*): + The base model to filter by. download (`bool`, *optional*, defaults to `False`): Whether to download the model. force_download (`bool`, *optional*, defaults to `False`): @@ -499,6 +501,7 @@ def for_civitai( # Extract additional parameters from kwargs model_type = kwargs.pop("model_type", "Checkpoint") download = kwargs.pop("download", False) + base_model = kwargs.pop("base_model", None) force_download = kwargs.pop("force_download", False) civitai_token = kwargs.pop("civitai_token", None) include_params = kwargs.pop("include_params", False) @@ -521,7 +524,8 @@ def for_civitai( "sort": "Highest Rated", "limit": 20 } - + if base_model is not None: + params["baseModel"] = base_model headers = {} if civitai_token: @@ -670,6 +674,8 @@ def for_hubs( Whether to download the model. pipeline_tag (`str`, *optional*, Hugging Face only): Tag to filter models by pipeline. + base_model (`str`, *optional*, Civitai only): + The base model to filter by. force_download (`bool`, *optional*, defaults to `False`, both): Whether to force the download if the model already exists. 
            hf_token (`str`, *optional*, Hugging Face only):
                API token for Hugging Face authentication.
            civitai_token (`str`, *optional*, Civitai only):
                API token for Civitai authentication.

From 28352b00f1726a99dc2948f8acaf160cc102064c Mon Sep 17 00:00:00 2001
From: suzukimain
Date: Fri, 22 Nov 2024 00:35:59 +0900
Subject: [PATCH 04/32] Bug Fixes

---
 examples/model_search/search_for_civitai_and_HF.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/examples/model_search/search_for_civitai_and_HF.py b/examples/model_search/search_for_civitai_and_HF.py
index 817b379abeab..980ab5395457 100644
--- a/examples/model_search/search_for_civitai_and_HF.py
+++ b/examples/model_search/search_for_civitai_and_HF.py
@@ -44,11 +44,13 @@
     "diffusion_pytorch_model.non_ema.bin",
     "diffusion_pytorch_model.non_ema.safetensors",
     "safety_checker/pytorch_model.bin",
+    "safety_checker/pytorch_model.fp16.bin",
     "safety_checker/model.safetensors",
     "safety_checker/model.ckpt",
     "safety_checker/model.fp16.safetensors",
     "safety_checker/model.fp16.ckpt",
     "unet/diffusion_pytorch_model.bin",
+    "unet/diffusion_pytorch_model.fp16.bin",
     "unet/diffusion_pytorch_model.safetensors",
     "unet/diffusion_pytorch_model.fp16.safetensors",
     "unet/diffusion_pytorch_model.ckpt",
@@ -56,6 +58,7 @@
     "vae/diffusion_pytorch_model.bin",
     "vae/diffusion_pytorch_model.safetensors",
     "vae/diffusion_pytorch_model.fp16.safetensors",
+    "vae/diffusion_pytorch_model.fp16.bin",
     "vae/diffusion_pytorch_model.ckpt",
     "vae/diffusion_pytorch_model.fp16.ckpt",
     "text_encoder/pytorch_model.bin",
@@ -376,8 +379,8 @@ def for_HF(cls, search_word: str, **kwargs) -> Union[str, SearchPipelineOutput,
 
                 elif (
                     any(file_path.endswith(ext) for ext in EXTENSION)
-                    and (file_path not in CONFIG_FILE_LIST)
-                    and (file_path not in exclusion)
+                    and not any(config in file_path for config in CONFIG_FILE_LIST)
+                    and not any(exc in file_path for exc in exclusion)
                 ):
                     file_list.append(file_path)
 

From 76b32dacb246a5b85859f87fbc7f79837ac3e617 Mon Sep 17 00:00:00 2001
From: suzukimain
Date: Fri, 22 Nov 2024 00:49:33 +0900
Subject: [PATCH 05/32] bug fix

---
 examples/model_search/search_for_civitai_and_HF.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/model_search/search_for_civitai_and_HF.py b/examples/model_search/search_for_civitai_and_HF.py
index 980ab5395457..3576969e4240 100644
--- a/examples/model_search/search_for_civitai_and_HF.py
+++ b/examples/model_search/search_for_civitai_and_HF.py
@@ -428,7 +428,7 @@ def for_HF(cls, search_word: str, **kwargs) -> Union[str, SearchPipelineOutput,
                     checkpoint_format=output_info["checkpoint_format"],
                     repo_status=RepoStatus(
                         repo_id=repo_id,
-                        repo_hash=hf_repo_info["sha"],
+                        repo_hash=hf_repo_info.sha,
                         version=revision
                     ),
                     model_status=ModelStatus(

From 17e38b00aba45646e1d521a0d55475743ffcb215 Mon Sep 17 00:00:00 2001
From: suzukimain <131413573+suzukimain@users.noreply.github.com>
Date: Fri, 22 Nov 2024 01:12:16 +0900
Subject: [PATCH 06/32] Create README.md

---
 examples/model_search/README.md | 147 ++++++++++++++++++++++++++++++++
 1 file changed, 147 insertions(+)
 create mode 100644 examples/model_search/README.md

diff --git a/examples/model_search/README.md b/examples/model_search/README.md
new file mode 100644
index 000000000000..6619279ec6b4
--- /dev/null
+++ b/examples/model_search/README.md
@@ -0,0 +1,147 @@
+# Searching Models on Civitai and Hugging Face
+Please refer to the original library [here](https://pypi.org/project/auto-diffusers/)
+
+## Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run:
+```bash
+pip install -r requirements.txt
+```
+
+
+## Example
+```bash
+!wget https://raw.githubusercontent.com/suzukimain/diffusers/refs/heads/ModelSearch/examples/model_search/search_for_civitai_and_HF.py
+```
+
+```python
+# Search for Civitai
+
+from search_for_civitai_and_HF import CivitaiSearchPipeline
+from diffusers import StableDiffusionPipeline
+
+
+model_path = CivitaiSearchPipeline.for_civitai(
+    "any",
+    base_model="SD 1.5",
+    download=True
+)
+pipe = StableDiffusionPipeline.from_single_file(model_path).to("cuda")
+
+```
+
+
+```python
+# Search for Hugging Face
+
+from search_for_civitai_and_HF import HFSearchPipeline
+from diffusers import StableDiffusionPipeline
+
+model_path = HFSearchPipeline.for_HF(
+    "stable",
+    model_format="diffusers",
+    download = False
+    )
+
+pipe = StableDiffusionPipeline.from_pretrained(model_path).to("cuda")
+
+# or
+
+model_path = HFSearchPipeline.for_HF(
+    "stable",
+    model_format="single_file",
+    download = False
+    )
+
+pipe = StableDiffusionPipeline.from_single_file(model_path).to("cuda")
+```
+
+
+> Arguments of `HFSearchPipeline.for_HF`
+>
+| Name | Type | Default | Description |
+|:----------------:|:-------:|:-------------:|:-------------------------------------------------------------:|
+| search_word | string | ー | The search query string. |
+| revision | string | None | The specific version of the model to download. |
+| checkpoint_format| string | "single_file" | The format of the model checkpoint. |
+| download | bool | False | Whether to download the model. |
+| force_download | bool | False | Whether to force the download if the model already exists. |
+| include_params | bool | False | Whether to include parameters in the returned data. |
+| pipeline_tag | string | None | Tag to filter models by pipeline. |
+| hf_token | string | None | API token for Hugging Face authentication. |
+| skip_error | bool | False | Whether to skip errors and return None. |
+
+
+
+> Arguments of `CivitaiSearchPipeline.for_civitai`
+>
+| Name | Type | Default | Description |
+|:----------------:|:-------:|:-------------:|:-------------------------------------------------------------:|
+| search_word | string | ー | The search query string. |
+| model_type | string | "Checkpoint" | The type of model to search for. |
+| base_model | string | None | The base model to filter by. |
+| download | bool | False | Whether to download the model. |
+| force_download | bool | False | Whether to force the download if the model already exists. |
+| civitai_token | string | None | API token for Civitai authentication. |
+| include_params | bool | False | Whether to include parameters in the returned data. |
+| skip_error | bool | False | Whether to skip errors and return None. |
+
+
+
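+The returned value can be more than a plain path. As a minimal sketch (assuming the script was fetched with the `!wget` step above), passing `include_params=True` makes the search return a `SearchPipelineOutput` dataclass instead of a string:
+
+```python
+from search_for_civitai_and_HF import HFSearchPipeline
+
+# include_params=True returns a SearchPipelineOutput with extra metadata
+result = HFSearchPipeline.for_HF(
+    "stable",
+    include_params=True,
+)
+
+if result is not None:
+    print(result.model_path)         # path, repo id, or URL of the match
+    print(result.loading_method)     # "from_single_file" or "from_pretrained"
+    print(result.checkpoint_format)  # "single_file" or "diffusers"
+    print(result.model_status.download_url)
+```
+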
+<details open>
+<summary>search_word</summary>
+
+| Type | Description |
+| :--------------------------: | :--------------------------------------------------------------------: |
+| keyword | Keywords to search model<br> |
+| url | URL of either huggingface or Civitai |
+| Local directory or file path | Locally stored model paths |
+| huggingface path | The following format: `< creator > / < repo >` |
+
+</details>
+
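+All of these `search_word` forms go through the same entry point. A short sketch (the repository name below is only an illustration):
+
+```python
+from search_for_civitai_and_HF import HFSearchPipeline
+
+# Plain keyword
+HFSearchPipeline.for_HF("stable", skip_error=True)
+
+# Hugging Face `<creator>/<repo>` path
+HFSearchPipeline.for_HF("stabilityai/stable-diffusion-2-1", skip_error=True)
+
+# A direct URL (or a local file/directory path) also works
+HFSearchPipeline.for_HF("https://huggingface.co/stabilityai/stable-diffusion-2-1", skip_error=True)
+```
+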
+<details open>
+<summary>model_type</summary>
+
+| Input Available |
+| :--------------------------: |
+| `Checkpoint` |
+| `TextualInversion` |
+| `Hypernetwork` |
+| `AestheticGradient` |
+| `LORA` |
+| `Controlnet` |
+| `Poses` |
+
+</details>
+
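+For example, restricting a Civitai search to LORA files only (a sketch; the keyword is arbitrary):
+
+```python
+from search_for_civitai_and_HF import CivitaiSearchPipeline
+
+# Only files whose Civitai model type is "LORA" are considered
+model_path = CivitaiSearchPipeline.for_civitai(
+    "any",
+    model_type="LORA",
+    download=False,
+    skip_error=True,
+)
+```
+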
+<details open>
+<summary>checkpoint_format</summary>
+
+| Argument | Description |
+| :--------------------------: | :--------------------------------------------------------------------: |
+| all | In auto, `multifolder diffusers format checkpoint` takes precedence |
+| single_file | Only `single file checkpoint` are searched. |
+| diffusers | Search only for `multifolder diffusers format checkpoint` |
+
+</details>
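+
+As a rough sketch of combined usage, `ModelSearchPipeline.for_hubs` accepts the parameters of both backends and falls back from one hub to the other when nothing is found:
+
+```python
+from search_for_civitai_and_HF import ModelSearchPipeline
+
+# Tries Hugging Face first, then Civitai; returns None if neither matches
+model_path = ModelSearchPipeline.for_hubs("any", download=False)
+```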

From aac8073c084ef00455da0df48fb3492667ca886b Mon Sep 17 00:00:00 2001
From: suzukimain <131413573+suzukimain@users.noreply.github.com>
Date: Fri, 22 Nov 2024 01:15:46 +0900
Subject: [PATCH 07/32] Update search_for_civitai_and_HF.py

---
 examples/model_search/search_for_civitai_and_HF.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/examples/model_search/search_for_civitai_and_HF.py b/examples/model_search/search_for_civitai_and_HF.py
index 3576969e4240..e389eb32e074 100644
--- a/examples/model_search/search_for_civitai_and_HF.py
+++ b/examples/model_search/search_for_civitai_and_HF.py
@@ -306,7 +306,6 @@ def for_HF(cls, search_word: str, **kwargs) -> Union[str, SearchPipelineOutput,
         # Get the type and loading method for the keyword
         search_word_status = get_keyword_types(search_word)
 
-        # Handle different types of keywords
         if search_word_status["type"]["hf_repo"]:
             if download:
                 model_path = DiffusionPipeline.download(
@@ -695,4 +694,4 @@ def for_hubs(
         return (
             cls.for_HF(search_word=search_word, skip_error=True, **kwargs)
             or cls.for_HF(search_word=search_word, skip_error=True, **kwargs)
-        )
\ No newline at end of file
+        )

From 6c3c2591ed8e5d14364876b59525bd8e392209d7 Mon Sep 17 00:00:00 2001
From: suzukimain <131413573+suzukimain@users.noreply.github.com>
Date: Fri, 22 Nov 2024 01:17:59 +0900
Subject: [PATCH 08/32] Create requirements.txt

---
 examples/model_search/requirements.txt | 1 +
 1 file changed, 1 insertion(+)
 create mode 100644 examples/model_search/requirements.txt

diff --git a/examples/model_search/requirements.txt b/examples/model_search/requirements.txt
new file mode 100644
index 000000000000..db7bc19a3a2b
--- /dev/null
+++ b/examples/model_search/requirements.txt
@@ -0,0 +1 @@
+huggingface-hub>=0.26.2

From 5202bb142ece3097734725a918f76d22cff82b52 Mon Sep 17 00:00:00 2001
From: suzukimain
Date: Fri, 22 Nov 2024 01:30:09 +0900
Subject: [PATCH 09/32] bug fix

---
 examples/model_search/search_for_civitai_and_HF.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/examples/model_search/search_for_civitai_and_HF.py b/examples/model_search/search_for_civitai_and_HF.py
index e389eb32e074..dd0373382c9c 100644
--- a/examples/model_search/search_for_civitai_and_HF.py
+++ b/examples/model_search/search_for_civitai_and_HF.py
@@ -351,7 +351,8 @@ def for_HF(cls, search_word: str, **kwargs) -> Union[str, SearchPipelineOutput,
             hf_repo_info = {}
             file_list = []
             repo_id, file_name = "", ""
-
+            diffusers_model_exists = False
+
             # Loop through models to find a suitable candidate
             for repo_info in model_dicts:
                 repo_id = repo_info["id"]
@@ -365,7 +366,6 @@ def for_HF(cls, search_word: str, **kwargs) -> Union[str, SearchPipelineOutput,
                 exclusion = [issue['path'] for issue in hf_security_info['filesWithIssues']]
 
                 # Checks for multi-folder diffusers model or valid files (models with security issues are excluded).
-                diffusers_model_exists = False
                 if hf_security_info["scansDone"]:
                     for info in repo_info["siblings"]:
                         file_path = info["rfilename"]

From 13a66bbabd723d7a017660a54358ff3cea8e1b0f Mon Sep 17 00:00:00 2001
From: suzukimain <131413573+suzukimain@users.noreply.github.com>
Date: Fri, 22 Nov 2024 01:34:34 +0900
Subject: [PATCH 10/32] Update README.md

---
 examples/model_search/README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/model_search/README.md b/examples/model_search/README.md
index 6619279ec6b4..f0059692af9c 100644
--- a/examples/model_search/README.md
+++ b/examples/model_search/README.md
@@ -50,7 +50,7 @@ from diffusers import StableDiffusionPipeline
 
 model_path = HFSearchPipeline.for_HF(
     "stable",
-    model_format="diffusers",
+    checkpoint_format="diffusers",
     download = False
     )
 
@@ -60,7 +60,7 @@ pipe = StableDiffusionPipeline.from_pretrained(model_path).to("cuda")
 
 model_path = HFSearchPipeline.for_HF(
     "stable",
-    model_format="single_file",
+    checkpoint_format="single_file",
     download = False
     )
 

From 920876da4902364e3f7e795dd708c7a351da79d4 Mon Sep 17 00:00:00 2001
From: suzukimain <131413573+suzukimain@users.noreply.github.com>
Date: Fri, 22 Nov 2024 07:47:59 +0900
Subject: [PATCH 11/32] bug fix

---
 examples/model_search/search_for_civitai_and_HF.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/model_search/search_for_civitai_and_HF.py b/examples/model_search/search_for_civitai_and_HF.py
index dd0373382c9c..dc44b7ec675f 100644
--- a/examples/model_search/search_for_civitai_and_HF.py
+++ b/examples/model_search/search_for_civitai_and_HF.py
@@ -693,6 +693,6 @@ def for_hubs(
         """
 
         return (
-            cls.for_HF(search_word=search_word, skip_error=True, **kwargs)
-            or cls.for_HF(search_word=search_word, skip_error=True, **kwargs)
-        )
+            cls.for_HF(search_word, skip_error=True, **kwargs)
+            or cls.for_civitai(search_word, skip_error=True, **kwargs)
+        )

From d953d7f0566d4dc08e16b632a988bd25b4904105 Mon Sep 17 00:00:00 2001
From: suzukimain <131413573+suzukimain@users.noreply.github.com>
Date: Fri, 22 Nov 2024 08:14:17 +0900
Subject: [PATCH 12/32] Correction of typos

---
 examples/model_search/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/model_search/README.md b/examples/model_search/README.md
index f0059692af9c..711271d758a9 100644
--- a/examples/model_search/README.md
+++ b/examples/model_search/README.md
@@ -138,7 +138,7 @@ pipe = StableDiffusionPipeline.from_single_file(model_path).to("cuda")
 
 | Argument | Description |
 | :--------------------------: | :--------------------------------------------------------------------: |
-| all | In auto, `multifolder diffusers format checkpoint` takes precedence |
+| all | The `multifolder diffusers format checkpoint` takes precedence. |
 | single_file | Only `single file checkpoint` are searched. |
 | diffusers | Search only for `multifolder diffusers format checkpoint` |

From 4b02904ce6d7a873a2b1541c28046d7d11dd0736 Mon Sep 17 00:00:00 2001
From: suzukimain <131413573+suzukimain@users.noreply.github.com>
Date: Tue, 3 Dec 2024 11:27:48 +0900
Subject: [PATCH 13/32] Update examples/model_search/README.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
---
 examples/model_search/README.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/examples/model_search/README.md b/examples/model_search/README.md
index 711271d758a9..4b377a71b169 100644
--- a/examples/model_search/README.md
+++ b/examples/model_search/README.md
@@ -1,4 +1,6 @@
-# Searching Models on Civitai and Hugging Face
+# Search models on Civitai and Hugging Face
+
+The [auto_diffusers](https://github.com/suzukimain/auto_diffusers) library provides additional functionalities to Diffusers such as searching for models on Civitai and the Hugging Face Hub.
 Please refer to the original library [here](https://pypi.org/project/auto-diffusers/)
 
 ## Installing the dependencies

From 994fe95f645dc653a356e73fea994a067e68dec3 Mon Sep 17 00:00:00 2001
From: suzukimain <131413573+suzukimain@users.noreply.github.com>
Date: Tue, 3 Dec 2024 11:28:02 +0900
Subject: [PATCH 14/32] Update examples/model_search/README.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
---
 examples/model_search/README.md | 14 +++-----------
 1 file changed, 3 insertions(+), 11 deletions(-)

diff --git a/examples/model_search/README.md b/examples/model_search/README.md
index 4b377a71b169..d26df237d419 100644
--- a/examples/model_search/README.md
+++ b/examples/model_search/README.md
@@ -3,25 +3,17 @@
 The [auto_diffusers](https://github.com/suzukimain/auto_diffusers) library provides additional functionalities to Diffusers such as searching for models on Civitai and the Hugging Face Hub.
 Please refer to the original library [here](https://pypi.org/project/auto-diffusers/)
 
-## Installing the dependencies
+## Installation
 
 Before running the scripts, make sure to install the library's training dependencies:
 
-**Important**
+> [!IMPORTANT]
+> To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the installation up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment.
 
-To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
 ```bash
 git clone https://github.com/huggingface/diffusers
 cd diffusers
 pip install .
-```
-
-Then cd into the example folder and run:
-```bash
-pip install -r requirements.txt
-```
-
-
 ## Example
 ```bash
 !wget https://raw.githubusercontent.com/suzukimain/diffusers/refs/heads/ModelSearch/examples/model_search/search_for_civitai_and_HF.py

From 65372d45372db4cdf4fa31d960421f8539cf21d8 Mon Sep 17 00:00:00 2001
From: suzukimain <131413573+suzukimain@users.noreply.github.com>
Date: Tue, 3 Dec 2024 11:28:15 +0900
Subject: [PATCH 15/32] Update examples/model_search/README.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
---
 examples/model_search/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/model_search/README.md b/examples/model_search/README.md
index d26df237d419..55bb421349a1 100644
--- a/examples/model_search/README.md
+++ b/examples/model_search/README.md
@@ -14,7 +14,7 @@ Before running the scripts, make sure to install the library's training dependen
 git clone https://github.com/huggingface/diffusers
 cd diffusers
 pip install .
-## Example
+## Search for models
 ```bash
 !wget https://raw.githubusercontent.com/suzukimain/diffusers/refs/heads/ModelSearch/examples/model_search/search_for_civitai_and_HF.py
 ```

From ab76aa4aa85b17792f7557cd7fed7ad5178f74f4 Mon Sep 17 00:00:00 2001
From: suzukimain <131413573+suzukimain@users.noreply.github.com>
Date: Tue, 3 Dec 2024 11:28:29 +0900
Subject: [PATCH 16/32] Update examples/model_search/README.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
---
 examples/model_search/README.md | 2 --
 1 file changed, 2 deletions(-)

diff --git a/examples/model_search/README.md b/examples/model_search/README.md
index 55bb421349a1..a0257b37fad9 100644
--- a/examples/model_search/README.md
+++ b/examples/model_search/README.md
@@ -61,8 +61,6 @@ model_path = HFSearchPipeline.for_HF(
 
 pipe = StableDiffusionPipeline.from_single_file(model_path).to("cuda")
 ```
-
-
> Arguments of `HFSearchPipeline.for_HF` > From 341f3a30c7028f67b0a9ca4cc4543ad7b8734595 Mon Sep 17 00:00:00 2001 From: suzukimain <131413573+suzukimain@users.noreply.github.com> Date: Tue, 3 Dec 2024 11:28:46 +0900 Subject: [PATCH 17/32] Update examples/model_search/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- examples/model_search/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/model_search/README.md b/examples/model_search/README.md index a0257b37fad9..cadd63a7e03c 100644 --- a/examples/model_search/README.md +++ b/examples/model_search/README.md @@ -63,7 +63,6 @@ pipe = StableDIffusionPipeline.from_single_file(model_path).to("cuda") > Arguments of `HFSearchPipeline.for_HF` -> | Name | Type | Default | Description | |:----------------:|:-------:|:-------------:|:-------------------------------------------------------------:| | search_word | string | ー | The search query string. | From 05594cc6e2e96683f9623eb9402d7b7a62160245 Mon Sep 17 00:00:00 2001 From: suzukimain <131413573+suzukimain@users.noreply.github.com> Date: Tue, 3 Dec 2024 11:28:58 +0900 Subject: [PATCH 18/32] Update examples/model_search/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- examples/model_search/README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/model_search/README.md b/examples/model_search/README.md index cadd63a7e03c..be8024df0e8e 100644 --- a/examples/model_search/README.md +++ b/examples/model_search/README.md @@ -77,8 +77,7 @@ pipe = StableDIffusionPipeline.from_single_file(model_path).to("cuda") -> Arguments of `CivitaiSearchPipeline.for_civitai` -> +### CivitaiSearchPipeline.for_civitai parameters | Name | Type | Default | Description | |:----------------:|:-------:|:-------------:|:-------------------------------------------------------------:| | search_word | string | ー | The search query string. | From 241c943bb23710925f31f7ad148398e7f245eecd Mon Sep 17 00:00:00 2001 From: suzukimain <131413573+suzukimain@users.noreply.github.com> Date: Wed, 4 Dec 2024 21:22:05 +0900 Subject: [PATCH 19/32] apply the changes --- examples/model_search/README.md | 228 ++++++++++++++++++-------------- 1 file changed, 129 insertions(+), 99 deletions(-) diff --git a/examples/model_search/README.md b/examples/model_search/README.md index be8024df0e8e..a82d4fae07e8 100644 --- a/examples/model_search/README.md +++ b/examples/model_search/README.md @@ -14,124 +14,154 @@ Before running the scripts, make sure to install the library's training dependen git clone https://github.com/huggingface/diffusers cd diffusers pip install . 
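+```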
-## Search for models ```bash -!wget https://raw.githubusercontent.com/suzukimain/diffusers/refs/heads/ModelSearch/examples/model_search/search_for_civitai_and_HF.py +!wget https://raw.githubusercontent.com/suzukimain/auto_diffusers/refs/heads/master/src/auto_diffusers/pipeline_easy.py ``` +### Search for Civitai ```python -# Search for Civitai +from pipeline_easy import ( + EasyPipelineForText2Image, + EasyPipelineForImage2Image, + EasyPipelineForInpainting, +) -from search_for_civitai_and_HF import CivitaiSearchPipeline -from diffusers import StableDiffusionPipeline +# Text-to-Image +pipeline = EasyPipelineForText2Image.from_civitai( + "search_word", + base_model="SD 1.5", +).to("cuda") -model_path = CivitaiSearchPipeline.for_civitai( - "any", +# Image-to-Image +pipeline = EasyPipelineForImage2Image.from_civitai( + "search_word", base_model="SD 1.5", - download=True -) -pipe = StableDiffusionPipeline.from_single_file(model_path).to("cuda") +).to("cuda") -``` +# Inpainting +pipeline = EasyPipelineForInpainting.from_civitai( + "search_word", + base_model="SD 1.5", +).to("cuda") +``` +### Search for Hugging Face ```python -# Search for Hugging Face - -from search_for_civitai_and_HF import HFSearchPipeline -from diffusers import StableDiffusionPipeline +from pipeline_easy import ( + EasyPipelineForText2Image, + EasyPipelineForImage2Image, + EasyPipelineForInpainting, +) -model_path = HFSearchPipeline.for_HF( - "stable", - checkpoint_format="diffusers", - download = False - ) +# Text-to-Image +pipeline = EasyPipelineForText2Image.from_huggingface( + "search_word", + checkpoint_format="diffusers", +).to("cuda") -pipe = StableDiffusionPipeline.from_pretrained(model_path).to("cuda") -# or +# Image-to-Image +pipeline = EasyPipelineForImage2Image.from_huggingface( + "search_word", + checkpoint_format="diffusers", +).to("cuda") -model_path = HFSearchPipeline.for_HF( - "stable", - checkpoint_format="single_file", - download = False - ) -pipe = StableDIffusionPipeline.from_single_file(model_path).to("cuda") +# Inpainting +pipeline = EasyPipelineForInpainting.from_huggingface( + "search_word", + checkpoint_format="diffusers", +).to("cuda") ``` - -> Arguments of `HFSearchPipeline.for_HF` -| Name | Type | Default | Description | -|:----------------:|:-------:|:-------------:|:-------------------------------------------------------------:| -| search_word | string | ー | The search query string. | -| revision | string | None | The specific version of the model to download. | -| checkpoint_format| string | "single_file" | The format of the model checkpoint. | -| download | bool | False | Whether to download the model. | -| force_download | bool | False | Whether to force the download if the model already exists. | -| include_params | bool | False | Whether to include parameters in the returned data. | -| pipeline_tag | string | None | Tag to filter models by pipeline. | -| hf_token | string | None | API token for Hugging Face authentication. | -| skip_error | bool | False | Whether to skip errors and return None. | - - - -### CivitaiSearchPipeline.for_civitai parameters -| Name | Type | Default | Description | -|:----------------:|:-------:|:-------------:|:-------------------------------------------------------------:| -| search_word | string | ー | The search query string. | -| model_type | string | "Checkpoint" | The type of model to search for. | -| base_model | string | None | The base model to filter by. | -| download | bool | False | Whether to download the model. 
| -| force_download | bool | False | Whether to force the download if the model already exists. | -| civitai_token | string | None | API token for Civitai authentication. | -| include_params | bool | False | Whether to include parameters in the returned data. | -| skip_error | bool | False | Whether to skip errors and return None. | - - - -
-search_word +### Application Examples -| Type | Description | -| :--------------------------: | :--------------------------------------------------------------------: | -| keyword | Keywords to search model
| -| url | URL of either huggingface or Civitai | -| Local directory or file path | Locally stored model paths | -| huggingface path | The following format: `< creator > / < repo >` | - -
- - - -
-model_type - -| Input Available | -| :--------------------------: | -| `Checkpoint` | -| `TextualInversion` | -| `Hypernetwork` | -| `AestheticGradient` | -| `LORA` | -| `Controlnet` | -| `Poses` | - -
- - - -
-checkpoint_format - -| Argument | Description | -| :--------------------------: | :--------------------------------------------------------------------: | -| all | The `multifolder diffusers format checkpoint` takes precedence. | -| single_file | Only `single file checkpoint` are searched. | -| diffusers | Search only for `multifolder diffusers format checkpoint` | - -
+```python +from pipeline_easy import ( + search_huggingface, + search_civitai, +) + +# Search Lora +Lora = search_civitai( + "Keyword_to_search_Lora", + model_type="LORA", + base_model = "SD 1.5", + download=True, + ) +# Load Lora into the pipeline. +pipeline.load_lora_weights(Lora) + + +# Search TextualInversion +TextualInversion = search_civitai( + "EasyNegative", + model_type="TextualInversion", + base_model = "SD 1.5", + download=True +) +# Load TextualInversion into the pipeline. +pipeline.load_textual_inversion(TextualInversion, token="EasyNegative") +``` -
+> [!TIP] +> **If an error occurs, insert the `token` and run again.** + +### `EasyPipeline.from_civitai` parameters + +| Name | Type | Default | Description | +|:---------------:|:----------------------:|:-------------:|:-----------------------------------------------------------------------------------:| +| search_word | string, Path | ー | The search query string. Can be a keyword, Civitai URL, local directory or file path. | +| model_type | string | `Checkpoint` | The type of model to search for.
(for example `Checkpoint`, `TextualInversion`, `Controlnet`, `LORA`, `Hypernetwork`, `AestheticGradient`, `Poses`) | +| base_model | string | None | Trained model tag (for example `SD 1.5`, `SD 3.5`, `SDXL 1.0`) | +| torch_dtype | string, torch.dtype | None | Override the default `torch.dtype` and load the model with another dtype. | +| force_download | bool | False | Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. | +| cache_dir | string, Path | None | Path to the folder where cached files are stored. | +| resume | bool | False | Whether to resume an incomplete download. | +| token | string | None | API token for Civitai authentication. | + + +### `search_civitai` parameters + +| Name | Type | Default | Description | +|:---------------:|:--------------:|:-------------:|:-----------------------------------------------------------------------------------:| +| search_word | string, Path | ー | The search query string. Can be a keyword, Civitai URL, local directory or file path. | +| model_type | string | `Checkpoint` | The type of model to search for.
(for example `Checkpoint`, `TextualInversion`, `Controlnet`, `LORA`, `Hypernetwork`, `AestheticGradient`, `Poses`) | +| base_model | string | None | Trained model tag (for example `SD 1.5`, `SD 3.5`, `SDXL 1.0`) | +| download | bool | False | Whether to download the model. | +| force_download | bool | False | Whether to force the download if the model already exists. | +| cache_dir | string, Path | None | Path to the folder where cached files are stored. | +| resume | bool | False | Whether to resume an incomplete download. | +| token | string | None | API token for Civitai authentication. | +| include_params | bool | False | Whether to include parameters in the returned data. | +| skip_error | bool | False | Whether to skip errors and return None. | + + +### `EasyPipeline.from_huggingface` parameters + +| Name | Type | Default | Description | +|:---------------------:|:-------------------:|:--------------:|:----------------------------------------------------------------:| +| search_word | string, Path | ー | The search query string. Can be a keyword, Hugging Face URL, local directory or file path, or a Hugging Face path (`/`). | +| checkpoint_format | string | `single_file` | The format of the model checkpoint.
● `single_file` to search for `single file checkpoint`
●`diffusers` to search for `multifolder diffusers format checkpoint` | +| torch_dtype | string, torch.dtype | None | Override the default `torch.dtype` and load the model with another dtype. | +| force_download | bool | False | Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. | +| cache_dir | string, Path | None | Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. | +| token | string, bool | None | The token to use as HTTP bearer authorization for remote files. | + + +### `search_huggingface` parameters + +| Name | Type | Default | Description | +|:---------------------:|:-------------------:|:--------------:|:----------------------------------------------------------------:| +| search_word | string, Path | ー | The search query string. Can be a keyword, Hugging Face URL, local directory or file path, or a Hugging Face path (`/`). | +| checkpoint_format | string | `single_file` | The format of the model checkpoint.
● `single_file` to search for `single file checkpoint`
●`diffusers` to search for `multifolder diffusers format checkpoint` | +| pipeline_tag | string | None | Tag to filter models by pipeline. | +| download | bool | False | Whether to download the model. | +| force_download | bool | False | Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. | +| cache_dir | string, Path | None | Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. | +| token | string, bool | None | The token to use as HTTP bearer authorization for remote files. | +| include_params | bool | False | Whether to include parameters in the returned data. | +| skip_error | bool | False | Whether to skip errors and return None. | From 43d936a1418798e4ed913943159dc2f5d655bb7a Mon Sep 17 00:00:00 2001 From: suzukimain Date: Wed, 4 Dec 2024 21:28:23 +0900 Subject: [PATCH 20/32] Replace search_for_civitai_and_HF.py with pipeline_easy.py --- examples/model_search/pipeline_easy.py | 1563 +++++++++++++++++ .../model_search/search_for_civitai_and_HF.py | 698 -------- 2 files changed, 1563 insertions(+), 698 deletions(-) create mode 100644 examples/model_search/pipeline_easy.py delete mode 100644 examples/model_search/search_for_civitai_and_HF.py diff --git a/examples/model_search/pipeline_easy.py b/examples/model_search/pipeline_easy.py new file mode 100644 index 000000000000..74045f716a75 --- /dev/null +++ b/examples/model_search/pipeline_easy.py @@ -0,0 +1,1563 @@ +# coding=utf-8 +# Copyright 2024 suzukimain +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import re +import requests +from tqdm.auto import tqdm +from typing import Union +from collections import OrderedDict +from dataclasses import ( + dataclass, + asdict +) + +from huggingface_hub.file_download import http_get +from huggingface_hub.utils import validate_hf_hub_args +from huggingface_hub import ( + hf_api, + hf_hub_download, +) + +from diffusers.utils import logging +from diffusers.loaders.single_file_utils import ( + infer_diffusers_model_type, + load_single_file_checkpoint, + _extract_repo_id_and_weights_name, + VALID_URL_PREFIXES, +) +from diffusers.pipelines.auto_pipeline import ( + AutoPipelineForText2Image, + AutoPipelineForImage2Image, + AutoPipelineForInpainting, +) +from diffusers.pipelines.controlnet import ( + StableDiffusionControlNetImg2ImgPipeline, + StableDiffusionControlNetInpaintPipeline, + StableDiffusionControlNetPipeline, +) +from diffusers.pipelines.stable_diffusion import ( + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, +) +from diffusers.pipelines.stable_diffusion_xl import ( + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLPipeline, +) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline + +logger = logging.get_logger(__name__) + + +SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING = OrderedDict( + [ + ("xl_base", StableDiffusionXLPipeline), + ("xl_refiner", StableDiffusionXLPipeline), + ("xl_inpaint", None), + ("playground-v2-5", StableDiffusionXLPipeline), + ("upscale", None), + ("inpainting", None), + ("inpainting_v2", None), + ("controlnet", StableDiffusionControlNetPipeline), + ("v2", StableDiffusionPipeline), + ("v1", StableDiffusionPipeline), + ] +) + +SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING = OrderedDict( + [ + ("xl_base", StableDiffusionXLImg2ImgPipeline), + ("xl_refiner", StableDiffusionXLImg2ImgPipeline), + ("xl_inpaint", None), + ("playground-v2-5", StableDiffusionXLImg2ImgPipeline), + ("upscale", None), + ("inpainting", None), + ("inpainting_v2", None), + ("controlnet", StableDiffusionControlNetImg2ImgPipeline), + ("v2", StableDiffusionImg2ImgPipeline), + ("v1", StableDiffusionImg2ImgPipeline), + ] +) + +SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING = OrderedDict( + [ + ("xl_base", None), + ("xl_refiner", None), + ("xl_inpaint", StableDiffusionXLInpaintPipeline), + ("playground-v2-5", None), + ("upscale", None), + ("inpainting", StableDiffusionInpaintPipeline), + ("inpainting_v2", StableDiffusionInpaintPipeline), + ("controlnet", StableDiffusionControlNetInpaintPipeline), + ("v2", None), + ("v1", None), + ] +) + + +CONFIG_FILE_LIST = [ + "pytorch_model.bin", + "pytorch_model.fp16.bin", + "diffusion_pytorch_model.bin", + "diffusion_pytorch_model.fp16.bin", + "diffusion_pytorch_model.safetensors", + "diffusion_pytorch_model.fp16.safetensors", + "diffusion_pytorch_model.ckpt", + "diffusion_pytorch_model.fp16.ckpt", + "diffusion_pytorch_model.non_ema.bin", + "diffusion_pytorch_model.non_ema.safetensors", +] + +DIFFUSERS_CONFIG_DIR = ["safety_checker", "unet", "vae", "text_encoder", "text_encoder_2"] + +INPAINT_PIPELINE_KEYS = [ + "xl_inpaint", + "inpainting", + "inpainting_v2", +] + +EXTENSION = [".safetensors", ".ckpt", ".bin"] + +CACHE_HOME = os.path.expanduser("~/.cache") + +@dataclass +class RepoStatus: + r""" + Data class for storing repository status information. + + Attributes: + repo_id (`str`): + The name of the repository. + repo_hash (`str`): + The hash of the repository. 
+ version (`str`): + The version ID of the repository. + """ + repo_id: str = "" + repo_hash: str = "" + version: str = "" + +@dataclass +class ModelStatus: + r""" + Data class for storing model status information. + + Attributes: + search_word (`str`): + The search word used to find the model. + download_url (`str`): + The URL to download the model. + file_name (`str`): + The name of the model file. + local (`bool`): + Whether the model exists locally + """ + search_word: str = "" + download_url: str = "" + file_name: str = "" + local: bool = False + +@dataclass +class SearchResult: + r""" + Data class for storing model data. + + Attributes: + model_path (`str`): + The path to the model. + loading_method (`str`): + The type of loading method used for the model ( None or 'from_single_file' or 'from_pretrained') + checkpoint_format (`str`): + The format of the model checkpoint (`single_file` or `diffusers`). + repo_status (`RepoStatus`): + The status of the repository. + model_status (`ModelStatus`): + The status of the model. + """ + model_path: str = "" + loading_method: Union[str, None] = None + checkpoint_format: Union[str, None] = None + repo_status: RepoStatus = RepoStatus() + model_status: ModelStatus = ModelStatus() + + +@validate_hf_hub_args +def load_pipeline_from_single_file(pretrained_model_or_path, pipeline_mapping, **kwargs): + r""" + Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors` + format. The pipeline is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pretrained_model_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + - A link to the `.ckpt` file (for example + `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. + - A path to a *file* containing all pipeline weights. + pipeline_mapping (`dict`): + A mapping of model types to their corresponding pipeline classes. This is used to determine + which pipeline class to instantiate based on the model type inferred from the checkpoint. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + original_config_file (`str`, *optional*): + The path to the original config file that was used to train the model. If not provided, the config file + will be inferred from the checkpoint file. 
+ config (`str`, *optional*): + Can be either: + - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline + hosted on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline + component configs in Diffusers format. + checkpoint (`dict`, *optional*): + The loaded state dictionary of the model. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline + class). The overwritten components are passed directly to the pipelines `__init__` method. See example + below for more information. + """ + + # Load the checkpoint from the provided link or path + checkpoint = load_single_file_checkpoint(pretrained_model_or_path) + + # Infer the model type from the loaded checkpoint + model_type = infer_diffusers_model_type(checkpoint) + + # Get the corresponding pipeline class from the pipeline mapping + pipeline_class = pipeline_mapping[model_type] + + # For tasks not supported by this pipeline + if pipeline_class is None: + raise ValueError( + f"{model_type} is not supported in this pipeline." + "For `Text2Image`, please use `AutoPipelineForText2Image.from_pretrained`, " + "for `Image2Image` , please use `AutoPipelineForImage2Image.from_pretrained`, " + "and `inpaint` is only supported in `AutoPipelineForInpainting.from_pretrained`" + ) + + else: + # Instantiate and return the pipeline with the loaded checkpoint and any additional kwargs + return pipeline_class.from_single_file(pretrained_model_or_path, **kwargs) + + +def get_keyword_types(keyword): + r""" + Determine the type and loading method for a given keyword. + + Parameters: + keyword (`str`): + The input keyword to classify. + + Returns: + `dict`: A dictionary containing the model format, loading method, + and various types and extra types flags. 
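+
+    Example (illustrative):
+        For a Hub-style repo id such as `"stabilityai/stable-diffusion-2-1"` (assuming
+        no local file or directory of that name exists), the returned dict reports
+        `checkpoint_format="diffusers"` and `loading_method="from_pretrained"`.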
+ """ + + # Initialize the status dictionary with default values + status = { + "checkpoint_format": None, + "loading_method": None, + "type": { + "other": False, + "hf_url": False, + "hf_repo": False, + "civitai_url": False, + "local": False, + }, + "extra_type": { + "url": False, + "missing_model_index": None, + }, + } + + # Check if the keyword is an HTTP or HTTPS URL + status["extra_type"]["url"] = bool(re.search(r"^(https?)://", keyword)) + + # Check if the keyword is a file + if os.path.isfile(keyword): + status["type"]["local"] = True + status["checkpoint_format"] = "single_file" + status["loading_method"] = "from_single_file" + + # Check if the keyword is a directory + elif os.path.isdir(keyword): + status["type"]["local"] = True + status["checkpoint_format"] = "diffusers" + status["loading_method"] = "from_pretrained" + if not os.path.exists(os.path.join(keyword, "model_index.json")): + status["extra_type"]["missing_model_index"] = True + + # Check if the keyword is a Civitai URL + elif keyword.startswith("https://civitai.com/"): + status["type"]["civitai_url"] = True + status["checkpoint_format"] = "single_file" + status["loading_method"] = None + + # Check if the keyword starts with any valid URL prefixes + elif any(keyword.startswith(prefix) for prefix in VALID_URL_PREFIXES): + repo_id, weights_name = _extract_repo_id_and_weights_name(keyword) + if weights_name: + status["type"]["hf_url"] = True + status["checkpoint_format"] = "single_file" + status["loading_method"] = "from_single_file" + else: + status["type"]["hf_repo"] = True + status["checkpoint_format"] = "diffusers" + status["loading_method"] = "from_pretrained" + + # Check if the keyword matches a Hugging Face repository format + elif re.match(r"^[^/]+/[^/]+$", keyword): + status["type"]["hf_repo"] = True + status["checkpoint_format"] = "diffusers" + status["loading_method"] = "from_pretrained" + + # If none of the above apply + else: + status["type"]["other"] = True + status["checkpoint_format"] = None + status["loading_method"] = None + + return status + + +def file_downloader( + url, + save_path, + **kwargs, + ) -> None: + """ + Downloads a file from a given URL and saves it to the specified path. + + parameters: + url (`str`): + The URL of the file to download. + save_path (`str`): + The local path where the file will be saved. + resume (`bool`, *optional*, defaults to `False`): + Whether to resume an incomplete download. + headers (`dict`, *optional*, defaults to `None`): + Dictionary of HTTP Headers to send with the request. + proxies (`dict`, *optional*, defaults to `None`): + Dictionary mapping protocol to the URL of the proxy passed to `requests.request`. + force_download (`bool`, *optional*, defaults to `False`): + Whether to force the download even if the file already exists. + displayed_filename (`str`, *optional*): + The filename of the file that is being downloaded. Value is used only to display a nice progress bar. If + not set, the filename is guessed from the URL or the `Content-Disposition` header. 
+ + returns: + None + """ + + # Get optional parameters from kwargs, with their default values + resume = kwargs.pop("resume", False) + headers = kwargs.pop("headers", None) + proxies = kwargs.pop("proxies", None) + force_download = kwargs.pop("force_download", False) + displayed_filename = kwargs.pop("displayed_filename", None) + # Default mode for file writing and initial file size + mode = "wb" + file_size = 0 + + # Create directory + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + # Check if the file already exists at the save path + if os.path.exists(save_path): + if not force_download: + # If the file exists and force_download is False, skip the download + logger.warning(f"File already exists: {save_path}, skipping download.") + return None + elif resume: + # If resuming, set mode to append binary and get current file size + mode = "ab" + file_size = os.path.getsize(save_path) + + # Open the file in the appropriate mode (write or append) + with open(save_path, mode) as model_file: + # Call the http_get function to perform the file download + return http_get( + url=url, + temp_file=model_file, + resume_size=file_size, + displayed_filename=displayed_filename, + headers=headers, + proxies=proxies, + **kwargs, + ) + + +def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, None]: + r""" + Downloads a model from Hugging Face. + + Parameters: + search_word (`str`): + The search query string. + revision (`str`, *optional*): + The specific version of the model to download. + checkpoint_format (`str`, *optional*, defaults to `"single_file"`): + The format of the model checkpoint. + download (`bool`, *optional*, defaults to `False`): + Whether to download the model. + force_download (`bool`, *optional*, defaults to `False`): + Whether to force the download if the model already exists. + include_params (`bool`, *optional*, defaults to `False`): + Whether to include parameters in the returned data. + pipeline_tag (`str`, *optional*): + Tag to filter models by pipeline. + token (`str`, *optional*): + API token for Hugging Face authentication. + gated (`bool`, *optional*, defaults to `False` ): + A boolean to filter models on the Hub that are gated or not. + skip_error (`bool`, *optional*, defaults to `False`): + Whether to skip errors and return None. + + Returns: + `Union[str, SearchResult, None]`: The model path or SearchResult or None. 
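+
+    Example (illustrative):
+        `search_huggingface("stable diffusion", checkpoint_format="diffusers")` resolves
+        to the repo id of the best-ranked match; `download=True` fetches the weights and
+        returns a local path, and `include_params=True` returns a `SearchResult` carrying
+        the repository and file status instead of a plain path.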
+ """ + # Extract additional parameters from kwargs + revision = kwargs.pop("revision", None) + checkpoint_format = kwargs.pop("checkpoint_format", "single_file") + download = kwargs.pop("download", False) + force_download = kwargs.pop("force_download", False) + include_params = kwargs.pop("include_params", False) + pipeline_tag = kwargs.pop("pipeline_tag", None) + token = kwargs.pop("token", None) + gated = kwargs.pop("gated", False) + skip_error = kwargs.pop("skip_error", False) + + # Get the type and loading method for the keyword + search_word_status = get_keyword_types(search_word) + + if search_word_status["type"]["hf_repo"]: + if download: + model_path = DiffusionPipeline.download( + search_word, + revision=revision, + token=token, + force_download=force_download, + **kwargs, + ) + else: + model_path = search_word + elif search_word_status["type"]["hf_url"]: + repo_id, weights_name = _extract_repo_id_and_weights_name(search_word) + if download: + model_path = hf_hub_download( + repo_id=repo_id, + filename=weights_name, + force_download=force_download, + token=token, + ) + else: + model_path = search_word + elif search_word_status["type"]["local"]: + model_path = search_word + elif search_word_status["type"]["civitai_url"]: + if skip_error: + return None + else: + raise ValueError("The URL for Civitai is invalid with `for_hf`. Please use `for_civitai` instead.") + else: + # Get model data from HF API + hf_models = hf_api.list_models( + search=search_word, + direction=-1, + limit=100, + fetch_config=True, + pipeline_tag=pipeline_tag, + full=True, + gated=gated, + token=token + ) + model_dicts = [asdict(value) for value in list(hf_models)] + + file_list = [] + hf_repo_info = {} + hf_security_info = {} + model_path = "" + repo_id, file_name = "", "" + diffusers_model_exists = False + + # Loop through models to find a suitable candidate + for repo_info in model_dicts: + repo_id = repo_info["id"] + file_list = [] + hf_repo_info = hf_api.model_info( + repo_id=repo_id, + securityStatus=True + ) + # Lists files with security issues. + hf_security_info = hf_repo_info.security_repo_status + exclusion = [issue['path'] for issue in hf_security_info['filesWithIssues']] + + # Checks for multi-folder diffusers model or valid files (models with security issues are excluded). 
+ if hf_security_info["scansDone"]: + for info in repo_info["siblings"]: + file_path = info["rfilename"] + if ( + "model_index.json" == file_path + and checkpoint_format in ["diffusers", "all"] + ): + diffusers_model_exists = True + break + + elif ( + any(file_path.endswith(ext) for ext in EXTENSION) + and not any(config in file_path for config in CONFIG_FILE_LIST) + and not any(exc in file_path for exc in exclusion) + and os.path.basename(os.path.dirname(file_path)) not in DIFFUSERS_CONFIG_DIR + ): + file_list.append(file_path) + + # Exit from the loop if a multi-folder diffusers model or valid file is found + if diffusers_model_exists or file_list: + break + else: + # Handle case where no models match the criteria + if skip_error: + return None + else: + raise ValueError("No models matching your criteria were found on huggingface.") + + if diffusers_model_exists: + if download: + model_path = DiffusionPipeline.download( + repo_id, + token=token, + **kwargs, + ) + else: + model_path = repo_id + + elif file_list: + # Sort and find the safest model + file_name = next( + ( + model + for model in sorted(file_list, reverse=True) + if re.search(r"(?i)[-_](safe|sfw)", model) + ), + file_list[0] + ) + + + if download: + model_path = hf_hub_download( + repo_id=repo_id, + filename=file_name, + revision=revision, + token=token, + force_download=force_download, + ) + + if file_name: + download_url = f"https://huggingface.co/{repo_id}/blob/main/{file_name}" + else: + download_url = f"https://huggingface.co/{repo_id}" + + output_info = get_keyword_types(model_path) + + if include_params: + return SearchResult( + model_path=model_path or download_url, + loading_method=output_info["loading_method"], + checkpoint_format=output_info["checkpoint_format"], + repo_status=RepoStatus( + repo_id=repo_id, + repo_hash=hf_repo_info.sha, + version=revision + ), + model_status=ModelStatus( + search_word=search_word, + download_url=download_url, + file_name=file_name, + local=download, + ) + ) + + else: + return model_path + + +def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]: + r""" + Downloads a model from Civitai. + + Parameters: + search_word (`str`): + The search query string. + model_type (`str`, *optional*, defaults to `Checkpoint`): + The type of model to search for. + base_model (`str`, *optional*): + The base model to filter by. + download (`bool`, *optional*, defaults to `False`): + Whether to download the model. + force_download (`bool`, *optional*, defaults to `False`): + Whether to force the download if the model already exists. + token (`str`, *optional*): + API token for Civitai authentication. + include_params (`bool`, *optional*, defaults to `False`): + Whether to include parameters in the returned data. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + resume (`bool`, *optional*, defaults to `False`): + Whether to resume an incomplete download. + skip_error (`bool`, *optional*, defaults to `False`): + Whether to skip errors and return None. + + Returns: + `Union[str, SearchResult, None]`: The model path or ` SearchResult` or None. 
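+
+    Example (illustrative):
+        `search_civitai("search_word", base_model="SD 1.5")` returns the download URL of
+        the most downloaded matching checkpoint; with `download=True` the file is saved
+        under the Civitai cache directory and its local path is returned instead.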
+ """ + + # Extract additional parameters from kwargs + model_type = kwargs.pop("model_type", "Checkpoint") + download = kwargs.pop("download", False) + base_model = kwargs.pop("base_model", None) + force_download = kwargs.pop("force_download", False) + token = kwargs.pop("token", None) + include_params = kwargs.pop("include_params", False) + resume = kwargs.pop("resume", False) + cache_dir = kwargs.pop("cache_dir", None) + skip_error = kwargs.pop("skip_error", False) + + # Initialize additional variables with default values + model_path = "" + repo_name = "" + repo_id = "" + version_id = "" + models_list = [] + selected_repo = {} + selected_model = {} + selected_version = {} + civitai_cache_dir = cache_dir or os.path.join(CACHE_HOME, "Civitai") + + # Set up parameters and headers for the CivitAI API request + params = { + "query": search_word, + "types": model_type, + "sort": "Most Downloaded", + "limit": 20, + } + if base_model is not None: + params["baseModel"] = base_model + + headers = {} + if token: + headers["Authorization"] = f"Bearer {token}" + + try: + # Make the request to the CivitAI API + response = requests.get( + "https://civitai.com/api/v1/models", params=params, headers=headers + ) + response.raise_for_status() + except requests.exceptions.HTTPError as err: + raise requests.HTTPError(f"Could not get elements from the URL: {err}") + else: + try: + data = response.json() + except AttributeError: + if skip_error: + return None + else: + raise ValueError("Invalid JSON response") + + # Sort repositories by download count in descending order + sorted_repos = sorted(data["items"], key=lambda x: x["stats"]["downloadCount"], reverse=True) + + for selected_repo in sorted_repos: + repo_name = selected_repo["name"] + repo_id = selected_repo["id"] + + # Sort versions within the selected repo by download count + sorted_versions = sorted(selected_repo["modelVersions"], key=lambda x: x["stats"]["downloadCount"], reverse=True) + for selected_version in sorted_versions: + version_id = selected_version["id"] + models_list = [] + for model_data in selected_version["files"]: + # Check if the file passes security scans and has a valid extension + file_name = model_data["name"] + if ( + model_data["pickleScanResult"] == "Success" + and model_data["virusScanResult"] == "Success" + and any(file_name.endswith(ext) for ext in EXTENSION) + and os.path.basename(os.path.dirname(file_name)) not in DIFFUSERS_CONFIG_DIR + ): + file_status = { + "filename": file_name, + "download_url": model_data["downloadUrl"], + } + models_list.append(file_status) + + if models_list: + # Sort the models list by filename and find the safest model + sorted_models = sorted(models_list, key=lambda x: x["filename"], reverse=True) + selected_model = next( + ( + model_data + for model_data in sorted_models + if bool(re.search(r"(?i)[-_](safe|sfw)", model_data["filename"])) + ), + sorted_models[0] + ) + + break + else: + continue + break + + # Exception handling when search candidates are not found + if not selected_model: + if skip_error: + return None + else: + raise ValueError("No model found. Please try changing the word you are searching for.") + + # Define model file status + file_name = selected_model["filename"] + download_url = selected_model["download_url"] + + # Handle file download and setting model information + if download: + # The path where the model is to be saved. 
+ model_path = os.path.join( + str(civitai_cache_dir), str(repo_id), str(version_id), str(file_name) + ) + # Download Model File + file_downloader( + url=download_url, + save_path=model_path, + resume=resume, + force_download=force_download, + displayed_filename=file_name, + headers=headers, + **kwargs, + ) + + else: + model_path = download_url + + output_info = get_keyword_types(model_path) + + if not include_params: + return model_path + else: + return SearchResult( + model_path=model_path, + loading_method=output_info["loading_method"], + checkpoint_format=output_info["checkpoint_format"], + repo_status=RepoStatus( + repo_id=repo_name, + repo_hash=repo_id, + version=version_id + ), + model_status=ModelStatus( + search_word=search_word, + download_url=download_url, + file_name=file_name, + local=output_info["type"]["local"] + ) + ) + + +class EasyPipelineForText2Image(AutoPipelineForText2Image): + r""" + + [`AutoPipelineForText2Image`] is a generic pipeline class that instantiates a text-to-image pipeline class. The + specific underlying pipeline class is automatically selected from either the + [`~AutoPipelineForText2Image.from_pretrained`] or [`~AutoPipelineForText2Image.from_pipe`] methods. + + This class cannot be instantiated using `__init__()` (throws an error). + + Class attributes: + + - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the + diffusion pipeline's components. + + """ + + config_name = "model_index.json" + + def __init__(self, *args, **kwargs): + # EnvironmentError is returned + super().__init__() + + + @classmethod + @validate_hf_hub_args + def from_huggingface(cls, pretrained_model_link_or_path, **kwargs): + r""" + Parameters: + pretrained_model_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A keyword to search for Hugging Face (for example `Stable Diffusion`) + - Link to `.ckpt` or `.safetensors` file (for example + `"https://huggingface.co//blob/main/.safetensors"`) on the Hub. + - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline + hosted on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights + saved using + [`~DiffusionPipeline.save_pretrained`]. + checkpoint_format (`str`, *optional*, defaults to `"single_file"`): + The format of the model checkpoint. + pipeline_tag (`str`, *optional*): + Tag to filter models by pipeline. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. 
+ local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + custom_revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id similar to + `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a + custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn’t need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the + same device. + + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + The path to offload weights if device_map contains the value `"disk"`. + offload_state_dict (`bool`, *optional*): + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` + when there is some disk offload. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + gated (`bool`, *optional*, defaults to `False` ): + A boolean to filter models on the Hub that are gated or not. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline + class). The overwritten components are passed directly to the pipelines `__init__` method. See example + below for more information. 
+ variant (`str`, *optional*): + Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + + + + To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with + `huggingface-cli login`. + + + + Examples: + + ```py + >>> from diffusers import AutoPipelineForText2Image + + >>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5") + >>> image = pipeline(prompt).images[0] + ``` + """ + # Update kwargs to ensure the model is downloaded and parameters are included + _status = { + "download": True, + "include_params": True, + "skip_error": False, + "pipeline_tag": "text-to-image", + } + kwargs.update(_status) + + # Search for the model on Hugging Face and get the model status + hf_model_status = search_huggingface(pretrained_model_link_or_path, **kwargs) + logger.warning(f"checkpoint_path: {hf_model_status.model_status.download_url}") + checkpoint_path = hf_model_status.model_path + + # Check the format of the model checkpoint + if hf_model_status.checkpoint_format == "single_file": + # Load the pipeline from a single file checkpoint + return load_pipeline_from_single_file( + pretrained_model_or_path=checkpoint_path, + pipeline_mapping=SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING, + **kwargs + ) + else: + return cls.from_pretrained(checkpoint_path, **kwargs) + + @classmethod + def from_civitai(cls, pretrained_model_link_or_path, **kwargs): + r""" + Parameters: + pretrained_model_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A keyword to search for Hugging Face (for example `Stable Diffusion`) + - Link to `.ckpt` or `.safetensors` file (for example + `"https://huggingface.co//blob/main/.safetensors"`) on the Hub. + - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline + hosted on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights + saved using + [`~DiffusionPipeline.save_pretrained`]. + model_type (`str`, *optional*, defaults to `Checkpoint`): + The type of model to search for. (for example `Checkpoint`, `TextualInversion`, `LORA`, `Controlnet`) + base_model (`str`, *optional*): + The base model to filter by. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + resume (`bool`, *optional*, defaults to `False`): + Whether to resume an incomplete download. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str`, *optional*): + The token to use as HTTP bearer authorization for remote files. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. 
It doesn’t need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the + same device. + + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + The path to offload weights if device_map contains the value `"disk"`. + offload_state_dict (`bool`, *optional*): + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` + when there is some disk offload. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline + class). The overwritten components are passed directly to the pipelines `__init__` method. See example + below for more information. + + + + To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with + `huggingface-cli login`. + + + + Examples: + + ```py + >>> from diffusers import AutoPipelineForText2Image + + >>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5") + >>> image = pipeline(prompt).images[0] + ``` + """ + # Update kwargs to ensure the model is downloaded and parameters are included + _status = { + "download": True, + "include_params": True, + "skip_error": False, + "model_type": "Checkpoint", + } + kwargs.update(_status) + + # Search for the model on Civitai and get the model status + model_status = search_civitai(pretrained_model_link_or_path, **kwargs) + logger.warning(f"checkpoint_path: {model_status.model_status.download_url}") + checkpoint_path = model_status.model_path + + # Load the pipeline from a single file checkpoint + return load_pipeline_from_single_file( + pretrained_model_or_path=checkpoint_path, + pipeline_mapping=SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING, + **kwargs + ) + + + +class EasyPipelineForImage2Image(AutoPipelineForImage2Image): + r""" + + [`AutoPipelineForImage2Image`] is a generic pipeline class that instantiates an image-to-image pipeline class. The + specific underlying pipeline class is automatically selected from either the + [`~AutoPipelineForImage2Image.from_pretrained`] or [`~AutoPipelineForImage2Image.from_pipe`] methods. 
+ + This class cannot be instantiated using `__init__()` (throws an error). + + Class attributes: + + - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the + diffusion pipeline's components. + + """ + + config_name = "model_index.json" + + def __init__(self, *args, **kwargs): + # EnvironmentError is returned + super().__init__() + + @classmethod + @validate_hf_hub_args + def from_huggingface(cls, pretrained_model_link_or_path, **kwargs): + r""" + Parameters: + pretrained_model_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A keyword to search for Hugging Face (for example `Stable Diffusion`) + - Link to `.ckpt` or `.safetensors` file (for example + `"https://huggingface.co//blob/main/.safetensors"`) on the Hub. + - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline + hosted on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights + saved using + [`~DiffusionPipeline.save_pretrained`]. + checkpoint_format (`str`, *optional*, defaults to `"single_file"`): + The format of the model checkpoint. + pipeline_tag (`str`, *optional*): + Tag to filter models by pipeline. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + custom_revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id similar to + `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a + custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. 
It doesn’t need to be defined for each
+                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+                same device.
+
+                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            max_memory (`Dict`, *optional*):
+                A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory
+                available for each GPU and the available CPU RAM if unset.
+            offload_folder (`str` or `os.PathLike`, *optional*):
+                The path to offload weights if device_map contains the value `"disk"`.
+            offload_state_dict (`bool`, *optional*):
+                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+                when there is some disk offload.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+                weights. If set to `False`, safetensors weights are not loaded.
+            gated (`bool`, *optional*, defaults to `False`):
+                A boolean to filter models on the Hub by whether or not they are gated.
+            variant (`str`, *optional*):
+                Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
+                loading `from_flax`.
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
+                class). The overwritten components are passed directly to the pipeline's `__init__` method. See example
+                below for more information.
+
+        <Tip>
+
+        To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+        `huggingface-cli login`.
+
+        </Tip>
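+
+        The sketch below is a minimal, illustrative combination of the search keyword with `checkpoint_format` and
+        `torch_dtype` (the keyword is a hypothetical query, not a guaranteed match on the Hub):
+
+        ```py
+        >>> import torch
+        >>> from pipeline_easy import EasyPipelineForImage2Image
+
+        >>> pipeline = EasyPipelineForImage2Image.from_huggingface(
+        ...     "stable diffusion",  # hypothetical search keyword
+        ...     checkpoint_format="diffusers",
+        ...     torch_dtype=torch.float16,
+        ... )
+        ```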
+
+        Examples:
+
+        ```py
+        >>> from pipeline_easy import EasyPipelineForImage2Image
+        >>> from diffusers.utils import load_image
+
+        >>> pipeline = EasyPipelineForImage2Image.from_huggingface("stable-diffusion-v1-5").to("cuda")
+        >>> init_image = load_image(
+        ...     "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+        ... )
+        >>> image = pipeline(prompt="A fantasy landscape", image=init_image).images[0]
+        ```
+        """
+        # Update kwargs to ensure the model is downloaded and parameters are included
+        _params = {
+            "download": True,
+            "include_params": True,
+            "skip_error": False,
+            "pipeline_tag": "image-to-image",
+        }
+        kwargs.update(_params)
+
+        # Search for the model on Hugging Face and get the model status
+        model_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
+        logger.warning(f"checkpoint_path: {model_status.model_status.download_url}")
+        checkpoint_path = model_status.model_path
+
+        # Check the format of the model checkpoint
+        if model_status.checkpoint_format == "single_file":
+            # Load the pipeline from a single file checkpoint
+            return load_pipeline_from_single_file(
+                pretrained_model_or_path=checkpoint_path,
+                pipeline_mapping=SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING,
+                **kwargs
+            )
+        else:
+            return cls.from_pretrained(checkpoint_path, **kwargs)
+
+    @classmethod
+    def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
+        r"""
+        Parameters:
+            pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                    - A keyword to search for on Civitai (for example `Stable Diffusion`)
+                    - A link to a model page or a model file hosted on Civitai (for example
+                      `https://civitai.com/models/<model_id>`)
+                    - A path to a local *file* or *directory* containing the model weights.
+            model_type (`str`, *optional*, defaults to `Checkpoint`):
+                The type of model to search for. (for example `Checkpoint`, `TextualInversion`, `LORA`, `Controlnet`)
+            base_model (`str`, *optional*):
+                The base model to filter by.
+            cache_dir (`str`, `Path`, *optional*):
+                Path to the folder where cached files are stored.
+            resume (`bool`, *optional*, defaults to `False`):
+                Whether to resume an incomplete download.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
+                dtype is automatically derived from the model's weights.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str`, *optional*):
+                The token to use as HTTP bearer authorization for remote files.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn’t need to be defined for each
+                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+                same device.
+
+                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`.
For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            max_memory (`Dict`, *optional*):
+                A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory
+                available for each GPU and the available CPU RAM if unset.
+            offload_folder (`str` or `os.PathLike`, *optional*):
+                The path to offload weights if device_map contains the value `"disk"`.
+            offload_state_dict (`bool`, *optional*):
+                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+                when there is some disk offload.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+                weights. If set to `False`, safetensors weights are not loaded.
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
+                class). The overwritten components are passed directly to the pipeline's `__init__` method. See example
+                below for more information.
+
+        <Tip>
+
+        To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+        `huggingface-cli login`.
+
+        </Tip>
+
+        Examples:
+
+        ```py
+        >>> from pipeline_easy import EasyPipelineForImage2Image
+        >>> from diffusers.utils import load_image
+
+        >>> pipeline = EasyPipelineForImage2Image.from_civitai("search_word").to("cuda")
+        >>> init_image = load_image(
+        ...     "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+        ... )
+        >>> image = pipeline(prompt="A fantasy landscape", image=init_image).images[0]
+        ```
+        """
+        # Update kwargs to ensure the model is downloaded and parameters are included
+        _status = {
+            "download": True,
+            "include_params": True,
+            "skip_error": False,
+            "model_type": "Checkpoint",
+        }
+        kwargs.update(_status)
+
+        # Search for the model on Civitai and get the model status
+        model_status = search_civitai(pretrained_model_link_or_path, **kwargs)
+        logger.warning(f"checkpoint_path: {model_status.model_status.download_url}")
+        checkpoint_path = model_status.model_path
+
+        # Load the pipeline from a single file checkpoint
+        return load_pipeline_from_single_file(
+            pretrained_model_or_path=checkpoint_path,
+            pipeline_mapping=SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING,
+            **kwargs
+        )
+
+
+
+class EasyPipelineForInpainting(AutoPipelineForInpainting):
+    r"""
+
+    [`AutoPipelineForInpainting`] is a generic pipeline class that instantiates an inpainting pipeline class. The
+    specific underlying pipeline class is automatically selected from either the
+    [`~AutoPipelineForInpainting.from_pretrained`] or [`~AutoPipelineForInpainting.from_pipe`] methods.
+
+    This class cannot be instantiated using `__init__()` (throws an error).
+
+    Class attributes:
+
+        - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
+          diffusion pipeline's components.
+
+    """
+
+    config_name = "model_index.json"
+
+    def __init__(self, *args, **kwargs):
+        # EnvironmentError is raised
+        super().__init__()
+
+    @classmethod
+    @validate_hf_hub_args
+    def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
+        r"""
+        Parameters:
+            pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                    - A keyword to search for on Hugging Face (for example `Stable Diffusion`)
+                    - A link to a `.ckpt` or `.safetensors` file (for example
+                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.
+                    - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
+                      hosted on the Hub.
+                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
+                      saved using [`~DiffusionPipeline.save_pretrained`].
+            checkpoint_format (`str`, *optional*, defaults to `"single_file"`):
+                The format of the model checkpoint.
+            pipeline_tag (`str`, *optional*):
+                Tag to filter models by pipeline.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
+                dtype is automatically derived from the model's weights.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            custom_revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
+                `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
+                custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
+            mirror (`str`, *optional*):
+                Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
+                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+                information.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn’t need to be defined for each
+                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+                same device.
+
+                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            max_memory (`Dict`, *optional*):
+                A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory
+                available for each GPU and the available CPU RAM if unset.
+            offload_folder (`str` or `os.PathLike`, *optional*):
+                The path to offload weights if device_map contains the value `"disk"`.
+            offload_state_dict (`bool`, *optional*):
+                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+                when there is some disk offload.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+                weights. If set to `False`, safetensors weights are not loaded.
+            gated (`bool`, *optional*, defaults to `False`):
+                A boolean to filter models on the Hub by whether or not they are gated.
+            variant (`str`, *optional*):
+                Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
+                loading `from_flax`.
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
+                class). The overwritten components are passed directly to the pipeline's `__init__` method. See example
+                below for more information.
+
+        <Tip>
+
+        To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+        `huggingface-cli login`.
+
+        </Tip>
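+
+        A minimal sketch of overriding the checkpoint format and dtype (the search keyword below is an illustrative
+        assumption, not a guaranteed match on the Hub):
+
+        ```py
+        >>> import torch
+        >>> from pipeline_easy import EasyPipelineForInpainting
+
+        >>> pipeline = EasyPipelineForInpainting.from_huggingface(
+        ...     "stable-diffusion-inpainting",  # hypothetical search keyword
+        ...     checkpoint_format="diffusers",
+        ...     torch_dtype=torch.float16,
+        ... )
+        ```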
+
+        Examples:
+
+        ```py
+        >>> from pipeline_easy import EasyPipelineForInpainting
+        >>> from diffusers.utils import load_image
+
+        >>> pipeline = EasyPipelineForInpainting.from_huggingface("stable-diffusion-2-inpainting").to("cuda")
+        >>> init_image = load_image(
+        ...     "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sII6.png"
+        ... )
+        >>> mask_image = load_image(
+        ...     "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sII6_mask.png"
+        ... )
+        >>> image = pipeline(prompt="A cat sitting on a bench", image=init_image, mask_image=mask_image).images[0]
+        ```
+        """
+        # Update kwargs to ensure the model is downloaded and parameters are included
+        _status = {
+            "download": True,
+            "include_params": True,
+            "skip_error": False,
+            "pipeline_tag": "image-to-image",
+        }
+        kwargs.update(_status)
+
+        # Search for the model on Hugging Face and get the model status
+        model_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
+        logger.warning(f"checkpoint_path: {model_status.model_status.download_url}")
+        checkpoint_path = model_status.model_path
+
+        # Check the format of the model checkpoint
+        if model_status.checkpoint_format == "single_file":
+            # Load the pipeline from a single file checkpoint
+            return load_pipeline_from_single_file(
+                pretrained_model_or_path=checkpoint_path,
+                pipeline_mapping=SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING,
+                **kwargs
+            )
+        else:
+            return cls.from_pretrained(checkpoint_path, **kwargs)
+
+    @classmethod
+    def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
+        r"""
+        Parameters:
+            pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                    - A keyword to search for on Civitai (for example `Stable Diffusion`)
+                    - A link to a model page or a model file hosted on Civitai (for example
+                      `https://civitai.com/models/<model_id>`)
+                    - A path to a local *file* or *directory* containing the model weights.
+            model_type (`str`, *optional*, defaults to `Checkpoint`):
+                The type of model to search for. (for example `Checkpoint`, `TextualInversion`, `LORA`, `Controlnet`)
+            base_model (`str`, *optional*):
+                The base model to filter by.
+            cache_dir (`str`, `Path`, *optional*):
+                Path to the folder where cached files are stored.
+            resume (`bool`, *optional*, defaults to `False`):
+                Whether to resume an incomplete download.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
+                dtype is automatically derived from the model's weights.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str`, *optional*):
+                The token to use as HTTP bearer authorization for remote files.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn’t need to be defined for each
+                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+                same device.
+
+                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`.
For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            max_memory (`Dict`, *optional*):
+                A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory
+                available for each GPU and the available CPU RAM if unset.
+            offload_folder (`str` or `os.PathLike`, *optional*):
+                The path to offload weights if device_map contains the value `"disk"`.
+            offload_state_dict (`bool`, *optional*):
+                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+                when there is some disk offload.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+                weights. If set to `False`, safetensors weights are not loaded.
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
+                class). The overwritten components are passed directly to the pipeline's `__init__` method. See example
+                below for more information.
+
+
+
+        To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+        `huggingface-cli login`.
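+
+        A minimal sketch of narrowing a Civitai search with `base_model` and loading in half precision (the keyword
+        and tag values are illustrative assumptions):
+
+        ```py
+        >>> import torch
+        >>> from pipeline_easy import EasyPipelineForInpainting
+
+        >>> pipeline = EasyPipelineForInpainting.from_civitai(
+        ...     "inpainting",  # hypothetical search keyword
+        ...     base_model="SD 1.5",
+        ...     torch_dtype=torch.float16,
+        ... )
+        ```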
+ + + + Examples: + + ```py + >>> from diffusers import AutoPipelineForText2Image + + >>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5") + >>> image = pipeline(prompt).images[0] + ``` + """ + # Update kwargs to ensure the model is downloaded and parameters are included + _status = { + "download": True, + "include_params": True, + "skip_error": False, + "model_type": "Checkpoint", + } + kwargs.update(_status) + + # Search for the model on Civitai and get the model status + model_status = search_civitai(pretrained_model_link_or_path, **kwargs) + logger.warning(f"checkpoint_path: {model_status.model_status.download_url}") + checkpoint_path = model_status.model_path + + # Load the pipeline from a single file checkpoint + return load_pipeline_from_single_file( + pretrained_model_or_path=checkpoint_path, + pipeline_mapping=SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING, + **kwargs + ) \ No newline at end of file diff --git a/examples/model_search/search_for_civitai_and_HF.py b/examples/model_search/search_for_civitai_and_HF.py deleted file mode 100644 index dc44b7ec675f..000000000000 --- a/examples/model_search/search_for_civitai_and_HF.py +++ /dev/null @@ -1,698 +0,0 @@ -import os -import re -import requests -from typing import ( - Union, - List -) -from tqdm.auto import tqdm -from dataclasses import ( - asdict, - dataclass -) -from huggingface_hub import ( - hf_api, - hf_hub_download, -) - -from diffusers.utils import logging -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.loaders.single_file_utils import ( - VALID_URL_PREFIXES, - _extract_repo_id_and_weights_name, -) - - -CONFIG_FILE_LIST = [ - "preprocessor_config.json", - "config.json", - "model.safetensors", - "model.fp16.safetensors", - "model.ckpt", - "pytorch_model.bin", - "pytorch_model.fp16.bin", - "scheduler_config.json", - "special_tokens_map.json", - "tokenizer_config.json", - "vocab.json", - "diffusion_pytorch_model.bin", - "diffusion_pytorch_model.fp16.bin", - "diffusion_pytorch_model.safetensors", - "diffusion_pytorch_model.fp16.safetensors", - "diffusion_pytorch_model.ckpt", - "diffusion_pytorch_model.fp16.ckpt", - "diffusion_pytorch_model.non_ema.bin", - "diffusion_pytorch_model.non_ema.safetensors", - "safety_checker/pytorch_model.bin", - "safety_checker/pytorch_model.fp16.bin", - "safety_checker/model.safetensors", - "safety_checker/model.ckpt", - "safety_checker/model.fp16.safetensors", - "safety_checker/model.fp16.ckpt", - "unet/diffusion_pytorch_model.bin", - "unet/diffusion_pytorch_model.fp16.bin", - "unet/diffusion_pytorch_model.safetensors", - "unet/diffusion_pytorch_model.fp16.safetensors", - "unet/diffusion_pytorch_model.ckpt", - "unet/diffusion_pytorch_model.fp16.ckpt", - "vae/diffusion_pytorch_model.bin", - "vae/diffusion_pytorch_model.safetensors", - "vae/diffusion_pytorch_model.fp16.safetensors", - "vae/diffusion_pytorch_model.fp16.bin", - "vae/diffusion_pytorch_model.ckpt", - "vae/diffusion_pytorch_model.fp16.ckpt", - "text_encoder/pytorch_model.bin", - "text_encoder/model.safetensors", - "text_encoder/model.fp16.safetensors", - "text_encoder/model.ckpt", - "text_encoder/model.fp16.ckpt", - "text_encoder_2/model.safetensors", - "text_encoder_2/model.ckpt" -] - -EXTENSION = [".safetensors", ".ckpt",".bin"] - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - - -@dataclass -class RepoStatus: - r""" - Data class for storing repository status information. - - Attributes: - repo_id (`str`): - The name of the repository. 
- repo_hash (`str`): - The hash of the repository. - version (`str`): - The version ID of the repository. - """ - repo_id: str = "" - repo_hash: str = "" - version: str = "" - - -@dataclass -class ModelStatus: - r""" - Data class for storing model status information. - - Attributes: - search_word (`str`): - The search word used to find the model. - download_url (`str`): - The URL to download the model. - file_name (`str`): - The name of the model file. - local (`bool`): - Whether the model exists locally - """ - search_word: str = "" - download_url: str = "" - file_name: str = "" - local: bool = False - - -@dataclass -class SearchPipelineOutput: - r""" - Data class for storing model data. - - Attributes: - model_path (`str`): - The path to the model. - loading_method (`str`): - The type of loading method used for the model ( None or 'from_single_file' or 'from_pretrained') - checkpoint_format (`str`): - The format of the model checkpoint (`single_file` or `diffusers`). - repo_status (`RepoStatus`): - The status of the repository. - model_status (`ModelStatus`): - The status of the model. - """ - model_path: str = "" - loading_method: str = None - checkpoint_format: str = None - repo_status: RepoStatus = RepoStatus() - model_status: ModelStatus = ModelStatus() - - - -def get_keyword_types(keyword): - r""" - Determine the type and loading method for a given keyword. - - Parameters: - keyword (`str`): - The input keyword to classify. - - Returns: - `dict`: A dictionary containing the model format, loading method, - and various types and extra types flags. - """ - - # Initialize the status dictionary with default values - status = { - "checkpoint_format": None, - "loading_method": None, - "type": { - "search_word": False, - "hf_url": False, - "hf_repo": False, - "civitai_url": False, - "local": False, - }, - "extra_type": { - "url": False, - "missing_model_index": None, - }, - } - - # Check if the keyword is an HTTP or HTTPS URL - status["extra_type"]["url"] = bool(re.search(r"^(https?)://", keyword)) - - # Check if the keyword is a file - if os.path.isfile(keyword): - status["type"]["local"] = True - status["checkpoint_format"] = "single_file" - status["loading_method"] = "from_single_file" - - # Check if the keyword is a directory - elif os.path.isdir(keyword): - status["type"]["local"] = True - status["checkpoint_format"] = "diffusers" - status["loading_method"] = "from_pretrained" - if not os.path.exists(os.path.join(keyword, "model_index.json")): - status["extra_type"]["missing_model_index"] = True - - # Check if the keyword is a Civitai URL - elif keyword.startswith("https://civitai.com/"): - status["type"]["civitai_url"] = True - status["checkpoint_format"] = "single_file" - status["loading_method"] = None - - # Check if the keyword starts with any valid URL prefixes - elif any(keyword.startswith(prefix) for prefix in VALID_URL_PREFIXES): - repo_id, weights_name = _extract_repo_id_and_weights_name(keyword) - if weights_name: - status["type"]["hf_url"] = True - status["checkpoint_format"] = "single_file" - status["loading_method"] = "from_single_file" - else: - status["type"]["hf_repo"] = True - status["checkpoint_format"] = "diffusers" - status["loading_method"] = "from_pretrained" - - # Check if the keyword matches a Hugging Face repository format - elif re.match(r"^[^/]+/[^/]+$", keyword): - status["type"]["hf_repo"] = True - status["checkpoint_format"] = "diffusers" - status["loading_method"] = "from_pretrained" - - # If none of the above, treat it as a search word - else: - 
status["type"]["search_word"] = True - status["checkpoint_format"] = None - status["loading_method"] = None - - return status - - -class HFSearchPipeline: - """ - Search for models from Huggingface. - """ - - def __init__(self): - pass - - @staticmethod - def create_huggingface_url(repo_id, file_name): - r""" - Create a Hugging Face URL for a given repository ID and file name. - - Parameters: - repo_id (`str`): - The repository ID. - file_name (`str`): - The file name within the repository. - - Returns: - `str`: The complete URL to the file or repository on Hugging Face. - """ - if file_name: - return f"https://huggingface.co/{repo_id}/blob/main/{file_name}" - else: - return f"https://huggingface.co/{repo_id}" - - @staticmethod - def hf_find_safest_model(models) -> str: - r""" - Sort and find the safest model. - - Parameters: - models (`list`): - A list of model names to sort and check. - - Returns: - `str`: The name of the safest model or the first model in the list if no safe model is found. - """ - for model in sorted(models, reverse=True): - if bool(re.search(r"(?i)[-_](safe|sfw)", model)): - return model - return models[0] - - @classmethod - def for_HF(cls, search_word: str, **kwargs) -> Union[str, SearchPipelineOutput, None]: - r""" - Downloads a model from Hugging Face. - - Parameters: - search_word (`str`): - The search query string. - revision (`str`, *optional*): - The specific version of the model to download. - checkpoint_format (`str`, *optional*, defaults to `"single_file"`): - The format of the model checkpoint. - download (`bool`, *optional*, defaults to `False`): - Whether to download the model. - force_download (`bool`, *optional*, defaults to `False`): - Whether to force the download if the model already exists. - include_params (`bool`, *optional*, defaults to `False`): - Whether to include parameters in the returned data. - pipeline_tag (`str`, *optional*): - Tag to filter models by pipeline. - hf_token (`str`, *optional*): - API token for Hugging Face authentication. - skip_error (`bool`, *optional*, defaults to `False`): - Whether to skip errors and return None. - - Returns: - `Union[str, SearchPipelineOutput, None]`: The model path or SearchPipelineOutput or None. - """ - # Extract additional parameters from kwargs - revision = kwargs.pop("revision", None) - checkpoint_format = kwargs.pop("checkpoint_format", "single_file") - download = kwargs.pop("download", False) - force_download = kwargs.pop("force_download", False) - include_params = kwargs.pop("include_params", False) - pipeline_tag = kwargs.pop("pipeline_tag", None) - hf_token = kwargs.pop("hf_token", None) - skip_error = kwargs.pop("skip_error", False) - - # Get the type and loading method for the keyword - search_word_status = get_keyword_types(search_word) - - if search_word_status["type"]["hf_repo"]: - if download: - model_path = DiffusionPipeline.download( - search_word, - revision=revision, - token=hf_token - ) - else: - model_path = search_word - elif search_word_status["type"]["hf_url"]: - repo_id, weights_name = _extract_repo_id_and_weights_name(search_word) - if download: - model_path = hf_hub_download( - repo_id=repo_id, - filename=weights_name, - force_download=force_download, - token=hf_token - ) - else: - model_path = search_word - elif search_word_status["type"]["local"]: - model_path = search_word - elif search_word_status["type"]["civitai_url"]: - if skip_error: - return None - else: - raise ValueError("The URL for Civitai is invalid with `for_hf`. 
Please use `for_civitai` instead.") - - else: - # Get model data from HF API - hf_models = hf_api.list_models( - search=search_word, - sort="downloads", - direction=-1, - limit=100, - fetch_config=True, - pipeline_tag=pipeline_tag, - full=True, - token=hf_token - ) - model_dicts = [asdict(value) for value in list(hf_models)] - - hf_repo_info = {} - file_list = [] - repo_id, file_name = "", "" - diffusers_model_exists = False - - # Loop through models to find a suitable candidate - for repo_info in model_dicts: - repo_id = repo_info["id"] - file_list = [] - hf_repo_info = hf_api.model_info( - repo_id=repo_id, - securityStatus=True - ) - # Lists files with security issues. - hf_security_info = hf_repo_info.security_repo_status - exclusion = [issue['path'] for issue in hf_security_info['filesWithIssues']] - - # Checks for multi-folder diffusers model or valid files (models with security issues are excluded). - if hf_security_info["scansDone"]: - for info in repo_info["siblings"]: - file_path = info["rfilename"] - if ( - "model_index.json" == file_path - and checkpoint_format in ["diffusers", "all"] - ): - diffusers_model_exists = True - break - - elif ( - any(file_path.endswith(ext) for ext in EXTENSION) - and not any(config in file_path for config in CONFIG_FILE_LIST) - and not any(exc in file_path for exc in exclusion) - ): - file_list.append(file_path) - - # Exit from the loop if a multi-folder diffusers model or valid file is found - if diffusers_model_exists or file_list: - break - else: - # Handle case where no models match the criteria - if skip_error: - return None - else: - raise ValueError("No models matching your criteria were found on huggingface.") - - download_url = cls.create_huggingface_url( - repo_id=repo_id, file_name=file_name - ) - if diffusers_model_exists: - if download: - model_path = DiffusionPipeline.download( - repo_id=repo_id, - token=hf_token, - ) - else: - model_path = repo_id - - elif file_list: - file_name = cls.hf_find_safest_model(file_list) - if download: - model_path = hf_hub_download( - repo_id=repo_id, - filename=file_name, - revision=revision, - token=hf_token - ) - else: - model_path = cls.create_huggingface_url( - repo_id=repo_id, file_name=file_name - ) - - output_info = get_keyword_types(model_path) - - if include_params: - return SearchPipelineOutput( - model_path=model_path, - loading_method=output_info["loading_method"], - checkpoint_format=output_info["checkpoint_format"], - repo_status=RepoStatus( - repo_id=repo_id, - repo_hash=hf_repo_info.sha, - version=revision - ), - model_status=ModelStatus( - search_word=search_word, - download_url=download_url, - file_name=file_name, - local=download, - ) - ) - - else: - return model_path - - - -class CivitaiSearchPipeline: - """ - Find checkpoints and more from Civitai. - """ - - def __init__(self): - pass - - @staticmethod - def civitai_find_safest_model(models: List[dict]) -> dict: - r""" - Sort and find the safest model. - - Parameters: - models (`list`): - A list of model dictionaries to check. Each dictionary should contain a 'filename' key. - - Returns: - `dict`: The dictionary of the safest model or the first model in the list if no safe model is found. - """ - - for model_data in models: - if bool(re.search(r"(?i)[-_](safe|sfw)", model_data["filename"])): - return model_data - return models[0] - - @classmethod - def for_civitai( - cls, - search_word: str, - **kwargs - ) -> Union[str, SearchPipelineOutput, None]: - r""" - Downloads a model from Civitai. 
- - Parameters: - search_word (`str`): - The search query string. - model_type (`str`, *optional*, defaults to `Checkpoint`): - The type of model to search for. - base_model (`str`, *optional*): - The base model to filter by. - download (`bool`, *optional*, defaults to `False`): - Whether to download the model. - force_download (`bool`, *optional*, defaults to `False`): - Whether to force the download if the model already exists. - civitai_token (`str`, *optional*): - API token for Civitai authentication. - include_params (`bool`, *optional*, defaults to `False`): - Whether to include parameters in the returned data. - skip_error (`bool`, *optional*, defaults to `False`): - Whether to skip errors and return None. - - Returns: - `Union[str, SearchPipelineOutput, None]`: The model path or `SearchPipelineOutput` or None. - """ - - # Extract additional parameters from kwargs - model_type = kwargs.pop("model_type", "Checkpoint") - download = kwargs.pop("download", False) - base_model = kwargs.pop("base_model", None) - force_download = kwargs.pop("force_download", False) - civitai_token = kwargs.pop("civitai_token", None) - include_params = kwargs.pop("include_params", False) - skip_error = kwargs.pop("skip_error", False) - - # Initialize additional variables with default values - model_path = "" - repo_name = "" - repo_id = "" - version_id = "" - models_list = [] - selected_repo = {} - selected_model = {} - selected_version = {} - - # Set up parameters and headers for the CivitAI API request - params = { - "query": search_word, - "types": model_type, - "sort": "Highest Rated", - "limit": 20 - } - if base_model is not None: - params["baseModel"] = base_model - - headers = {} - if civitai_token: - headers["Authorization"] = f"Bearer {civitai_token}" - - try: - # Make the request to the CivitAI API - response = requests.get( - "https://civitai.com/api/v1/models", params=params, headers=headers - ) - response.raise_for_status() - except requests.exceptions.HTTPError as err: - raise requests.HTTPError(f"Could not get elements from the URL: {err}") - else: - try: - data = response.json() - except AttributeError: - if skip_error: - return None - else: - raise ValueError("Invalid JSON response") - - # Sort repositories by download count in descending order - sorted_repos = sorted(data["items"], key=lambda x: x["stats"]["downloadCount"], reverse=True) - - for selected_repo in sorted_repos: - repo_name = selected_repo["name"] - repo_id = selected_repo["id"] - - # Sort versions within the selected repo by download count - sorted_versions = sorted(selected_repo["modelVersions"], key=lambda x: x["stats"]["downloadCount"], reverse=True) - for selected_version in sorted_versions: - version_id = selected_version["id"] - models_list = [] - for model_data in selected_version["files"]: - # Check if the file passes security scans and has a valid extension - if ( - model_data["pickleScanResult"] == "Success" - and model_data["virusScanResult"] == "Success" - and any(model_data["name"].endswith(ext) for ext in EXTENSION) - ): - file_status = { - "filename": model_data["name"], - "download_url": model_data["downloadUrl"], - } - models_list.append(file_status) - - if models_list: - # Sort the models list by filename and find the safest model - sorted_models = sorted(models_list, key=lambda x: x["filename"], reverse=True) - selected_model = cls.civitai_find_safest_model(sorted_models) - break - else: - continue - break - - if not selected_model: - if skip_error: - return None - else: - raise ValueError("No model found. 
Please try changing the word you are searching for.") - - file_name = selected_model["filename"] - download_url = selected_model["download_url"] - - # Handle file download and setting model information - if download: - model_path = f"/root/.cache/Civitai/{repo_id}/{version_id}/{file_name}" - os.makedirs(os.path.dirname(model_path), exist_ok=True) - if (not os.path.exists(model_path)) or force_download: - headers = {} - if civitai_token: - headers["Authorization"] = f"Bearer {civitai_token}" - - try: - response = requests.get(download_url, stream=True, headers=headers) - response.raise_for_status() - except requests.HTTPError: - raise requests.HTTPError(f"Invalid URL: {download_url}, {response.status_code}") - - with tqdm.wrapattr( - open(model_path, "wb"), - "write", - miniters=1, - desc=file_name, - total=int(response.headers.get("content-length", 0)), - ) as fetched_model_info: - for chunk in response.iter_content(chunk_size=8192): - fetched_model_info.write(chunk) - else: - model_path = download_url - - output_info = get_keyword_types(model_path) - - if not include_params: - return model_path - else: - return SearchPipelineOutput( - model_path=model_path, - loading_method=output_info["loading_method"], - checkpoint_format=output_info["checkpoint_format"], - repo_status=RepoStatus( - repo_id=repo_name, - repo_hash=repo_id, - version=version_id - ), - model_status=ModelStatus( - search_word=search_word, - download_url=download_url, - file_name=file_name, - local=output_info["type"]["local"] - ) - ) - - - -class ModelSearchPipeline( - HFSearchPipeline, - CivitaiSearchPipeline - ): - - def __init__(self): - pass - - @classmethod - def for_hubs( - cls, - search_word: str, - **kwargs - ) -> Union[None, str, SearchPipelineOutput]: - r""" - Search and download models from multiple hubs. - - Parameters: - search_word (`str`): - The search query string. - model_type (`str`, *optional*, defaults to `Checkpoint`, Civitai only): - The type of model to search for. - revision (`str`, *optional*, Hugging Face only): - The specific version of the model to download. - include_params (`bool`, *optional*, defaults to `False`, both): - Whether to include parameters in the returned data. - checkpoint_format (`str`, *optional*, defaults to `"single_file"`, Hugging Face only): - The format of the model checkpoint. - download (`bool`, *optional*, defaults to `False`, both): - Whether to download the model. - pipeline_tag (`str`, *optional*, Hugging Face only): - Tag to filter models by pipeline. - base_model (`str`, *optional*, Civitai only): - The base model to filter by. - force_download (`bool`, *optional*, defaults to `False`, both): - Whether to force the download if the model already exists. - hf_token (`str`, *optional*, Hugging Face only): - API token for Hugging Face authentication. - civitai_token (`str`, *optional*, Civitai only): - API token for Civitai authentication. - skip_error (`bool`, *optional*, defaults to `False`, both): - Whether to skip errors and return None. - - Returns: - `Union[None, str, SearchPipelineOutput]`: The model path, SearchPipelineOutput, or None if not found. 
- """ - - return ( - cls.for_HF(search_word, skip_error=True, **kwargs) - or cls.for_civitai(search_word, skip_error=True, **kwargs) - ) From 4f3d4b0feae1acbac6a143b7d003d1fe6a5011d6 Mon Sep 17 00:00:00 2001 From: suzukimain <131413573+suzukimain@users.noreply.github.com> Date: Thu, 5 Dec 2024 07:26:16 +0900 Subject: [PATCH 21/32] Update examples/model_search/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- examples/model_search/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/model_search/README.md b/examples/model_search/README.md index a82d4fae07e8..b2e2aecd6a9b 100644 --- a/examples/model_search/README.md +++ b/examples/model_search/README.md @@ -18,7 +18,7 @@ pip install . !wget https://raw.githubusercontent.com/suzukimain/auto_diffusers/refs/heads/master/src/auto_diffusers/pipeline_easy.py ``` -### Search for Civitai +## Load from Civitai ```python from pipeline_easy import ( EasyPipelineForText2Image, From 8b2e3e694b862c5582bcd6bc05f6ed78ca017b2d Mon Sep 17 00:00:00 2001 From: suzukimain <131413573+suzukimain@users.noreply.github.com> Date: Thu, 5 Dec 2024 07:27:05 +0900 Subject: [PATCH 22/32] Update examples/model_search/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- examples/model_search/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/model_search/README.md b/examples/model_search/README.md index b2e2aecd6a9b..9677a53147aa 100644 --- a/examples/model_search/README.md +++ b/examples/model_search/README.md @@ -77,7 +77,7 @@ pipeline = EasyPipelineForInpainting.from_huggingface( ``` -### Application Examples +## Search Civitai and Hugging Face ```python from pipeline_easy import ( From 2e32020c0c8ea947e7351d8f60f5a7d3c8093889 Mon Sep 17 00:00:00 2001 From: suzukimain <131413573+suzukimain@users.noreply.github.com> Date: Thu, 5 Dec 2024 07:27:21 +0900 Subject: [PATCH 23/32] Update examples/model_search/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- examples/model_search/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/model_search/README.md b/examples/model_search/README.md index 9677a53147aa..7c2740bd18cc 100644 --- a/examples/model_search/README.md +++ b/examples/model_search/README.md @@ -47,7 +47,7 @@ pipeline = EasyPipelineForInpainting.from_civitai( ).to("cuda") ``` -### Search for Hugging Face +## Load from Hugging Face ```python from pipeline_easy import ( EasyPipelineForText2Image, From 68000fa526faba6759485bd4936aa12283009a75 Mon Sep 17 00:00:00 2001 From: suzukimain <131413573+suzukimain@users.noreply.github.com> Date: Thu, 5 Dec 2024 07:49:30 +0900 Subject: [PATCH 24/32] Update README.md --- examples/model_search/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/model_search/README.md b/examples/model_search/README.md index 7c2740bd18cc..ad73b6ff4d7b 100644 --- a/examples/model_search/README.md +++ b/examples/model_search/README.md @@ -14,6 +14,8 @@ Before running the scripts, make sure to install the library's training dependen git clone https://github.com/huggingface/diffusers cd diffusers pip install . +``` +Set up the pipeline. You can also cd to this folder and run it. 
```bash !wget https://raw.githubusercontent.com/suzukimain/auto_diffusers/refs/heads/master/src/auto_diffusers/pipeline_easy.py ``` From cd706471900e365f092371c9b3188ef29189484d Mon Sep 17 00:00:00 2001 From: suzukimain <131413573+suzukimain@users.noreply.github.com> Date: Thu, 5 Dec 2024 08:12:11 +0900 Subject: [PATCH 25/32] Organize the table of parameters --- examples/model_search/README.md | 28 ++-------------------------- 1 file changed, 2 insertions(+), 26 deletions(-) diff --git a/examples/model_search/README.md b/examples/model_search/README.md index ad73b6ff4d7b..e6c0a3d21be0 100644 --- a/examples/model_search/README.md +++ b/examples/model_search/README.md @@ -112,20 +112,6 @@ pipeline.load_textual_inversion(TextualInversion, token="EasyNegative") > [!TIP] > **If an error occurs, insert the `token` and run again.** -### `EasyPipeline.from_civitai` parameters - -| Name | Type | Default | Description | -|:---------------:|:----------------------:|:-------------:|:-----------------------------------------------------------------------------------:| -| search_word | string, Path | ー | The search query string. Can be a keyword, Civitai URL, local directory or file path. | -| model_type | string | `Checkpoint` | The type of model to search for.
(for example `Checkpoint`, `TextualInversion`, `Controlnet`, `LORA`, `Hypernetwork`, `AestheticGradient`, `Poses`) | -| base_model | string | None | Trained model tag (for example `SD 1.5`, `SD 3.5`, `SDXL 1.0`) | -| torch_dtype | string, torch.dtype | None | Override the default `torch.dtype` and load the model with another dtype. | -| force_download | bool | False | Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. | -| cache_dir | string, Path | None | Path to the folder where cached files are stored. | -| resume | bool | False | Whether to resume an incomplete download. | -| token | string | None | API token for Civitai authentication. | - - ### `search_civitai` parameters | Name | Type | Default | Description | @@ -133,6 +119,7 @@ pipeline.load_textual_inversion(TextualInversion, token="EasyNegative") | search_word | string, Path | ー | The search query string. Can be a keyword, Civitai URL, local directory or file path. | | model_type | string | `Checkpoint` | The type of model to search for.
(for example `Checkpoint`, `TextualInversion`, `Controlnet`, `LORA`, `Hypernetwork`, `AestheticGradient`, `Poses`) | | base_model | string | None | Trained model tag (for example `SD 1.5`, `SD 3.5`, `SDXL 1.0`) | +| torch_dtype | string, torch.dtype | None | Override the default `torch.dtype` and load the model with another dtype. | | download | bool | False | Whether to download the model. | | force_download | bool | False | Whether to force the download if the model already exists. | | cache_dir | string, Path | None | Path to the folder where cached files are stored. | @@ -142,18 +129,6 @@ pipeline.load_textual_inversion(TextualInversion, token="EasyNegative") | skip_error | bool | False | Whether to skip errors and return None. | -### `EasyPipeline.from_huggingface` parameters - -| Name | Type | Default | Description | -|:---------------------:|:-------------------:|:--------------:|:----------------------------------------------------------------:| -| search_word | string, Path | ー | The search query string. Can be a keyword, Hugging Face URL, local directory or file path, or a Hugging Face path (`/`). | -| checkpoint_format | string | `single_file` | The format of the model checkpoint.
● `single_file` to search for `single file checkpoint`
●`diffusers` to search for `multifolder diffusers format checkpoint` | -| torch_dtype | string, torch.dtype | None | Override the default `torch.dtype` and load the model with another dtype. | -| force_download | bool | False | Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. | -| cache_dir | string, Path | None | Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. | -| token | string, bool | None | The token to use as HTTP bearer authorization for remote files. | - - ### `search_huggingface` parameters | Name | Type | Default | Description | @@ -161,6 +136,7 @@ pipeline.load_textual_inversion(TextualInversion, token="EasyNegative") | search_word | string, Path | ー | The search query string. Can be a keyword, Hugging Face URL, local directory or file path, or a Hugging Face path (`/`). | | checkpoint_format | string | `single_file` | The format of the model checkpoint.
● `single_file` to search for `single file checkpoint`
●`diffusers` to search for `multifolder diffusers format checkpoint` | | pipeline_tag | string | None | Tag to filter models by pipeline. | +| torch_dtype | string, torch.dtype | None | Override the default `torch.dtype` and load the model with another dtype. | | download | bool | False | Whether to download the model. | | force_download | bool | False | Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. | | cache_dir | string, Path | None | Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. | From c9e87b82ca73b93b51d6b69c254de499b4b6e199 Mon Sep 17 00:00:00 2001 From: suzukimain <131413573+suzukimain@users.noreply.github.com> Date: Thu, 5 Dec 2024 08:30:10 +0900 Subject: [PATCH 26/32] Update README.md --- examples/model_search/README.md | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/examples/model_search/README.md b/examples/model_search/README.md index e6c0a3d21be0..ad73b6ff4d7b 100644 --- a/examples/model_search/README.md +++ b/examples/model_search/README.md @@ -112,6 +112,20 @@ pipeline.load_textual_inversion(TextualInversion, token="EasyNegative") > [!TIP] > **If an error occurs, insert the `token` and run again.** +### `EasyPipeline.from_civitai` parameters + +| Name | Type | Default | Description | +|:---------------:|:----------------------:|:-------------:|:-----------------------------------------------------------------------------------:| +| search_word | string, Path | ー | The search query string. Can be a keyword, Civitai URL, local directory or file path. | +| model_type | string | `Checkpoint` | The type of model to search for.
(for example `Checkpoint`, `TextualInversion`, `Controlnet`, `LORA`, `Hypernetwork`, `AestheticGradient`, `Poses`) | +| base_model | string | None | Trained model tag (for example `SD 1.5`, `SD 3.5`, `SDXL 1.0`) | +| torch_dtype | string, torch.dtype | None | Override the default `torch.dtype` and load the model with another dtype. | +| force_download | bool | False | Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. | +| cache_dir | string, Path | None | Path to the folder where cached files are stored. | +| resume | bool | False | Whether to resume an incomplete download. | +| token | string | None | API token for Civitai authentication. | + + ### `search_civitai` parameters | Name | Type | Default | Description | @@ -119,7 +133,6 @@ pipeline.load_textual_inversion(TextualInversion, token="EasyNegative") | search_word | string, Path | ー | The search query string. Can be a keyword, Civitai URL, local directory or file path. | | model_type | string | `Checkpoint` | The type of model to search for.
(for example `Checkpoint`, `TextualInversion`, `Controlnet`, `LORA`, `Hypernetwork`, `AestheticGradient`, `Poses`) | | base_model | string | None | Trained model tag (for example `SD 1.5`, `SD 3.5`, `SDXL 1.0`) | -| torch_dtype | string, torch.dtype | None | Override the default `torch.dtype` and load the model with another dtype. | | download | bool | False | Whether to download the model. | | force_download | bool | False | Whether to force the download if the model already exists. | | cache_dir | string, Path | None | Path to the folder where cached files are stored. | @@ -129,6 +142,18 @@ pipeline.load_textual_inversion(TextualInversion, token="EasyNegative") | skip_error | bool | False | Whether to skip errors and return None. | +### `EasyPipeline.from_huggingface` parameters + +| Name | Type | Default | Description | +|:---------------------:|:-------------------:|:--------------:|:----------------------------------------------------------------:| +| search_word | string, Path | ー | The search query string. Can be a keyword, Hugging Face URL, local directory or file path, or a Hugging Face path (`/`). | +| checkpoint_format | string | `single_file` | The format of the model checkpoint.
● `single_file` to search for `single file checkpoint`
●`diffusers` to search for `multifolder diffusers format checkpoint` | +| torch_dtype | string, torch.dtype | None | Override the default `torch.dtype` and load the model with another dtype. | +| force_download | bool | False | Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. | +| cache_dir | string, Path | None | Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. | +| token | string, bool | None | The token to use as HTTP bearer authorization for remote files. | + + ### `search_huggingface` parameters | Name | Type | Default | Description | @@ -136,7 +161,6 @@ pipeline.load_textual_inversion(TextualInversion, token="EasyNegative") | search_word | string, Path | ー | The search query string. Can be a keyword, Hugging Face URL, local directory or file path, or a Hugging Face path (`/`). | | checkpoint_format | string | `single_file` | The format of the model checkpoint.
● `single_file` to search for `single file checkpoint`
●`diffusers` to search for `multifolder diffusers format checkpoint` | | pipeline_tag | string | None | Tag to filter models by pipeline. | -| torch_dtype | string, torch.dtype | None | Override the default `torch.dtype` and load the model with another dtype. | | download | bool | False | Whether to download the model. | | force_download | bool | False | Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. | | cache_dir | string, Path | None | Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. | From b10a55a317d03e6280edaee29b3fc99b067d9056 Mon Sep 17 00:00:00 2001 From: suzukimain <131413573+suzukimain@users.noreply.github.com> Date: Thu, 5 Dec 2024 08:50:34 +0900 Subject: [PATCH 27/32] Update README.md --- examples/model_search/README.md | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/model_search/README.md b/examples/model_search/README.md index ad73b6ff4d7b..f7e2d944c77e 100644 --- a/examples/model_search/README.md +++ b/examples/model_search/README.md @@ -79,7 +79,7 @@ pipeline = EasyPipelineForInpainting.from_huggingface( ``` -## Search Civitai and Hugging Face +## Search Civitai and Huggingface ```python from pipeline_easy import ( @@ -109,10 +109,12 @@ TextualInversion = search_civitai( pipeline.load_textual_inversion(TextualInversion, token="EasyNegative") ``` +## Search Civitai + > [!TIP] > **If an error occurs, insert the `token` and run again.** -### `EasyPipeline.from_civitai` parameters +#### `EasyPipeline.from_civitai` parameters | Name | Type | Default | Description | |:---------------:|:----------------------:|:-------------:|:-----------------------------------------------------------------------------------:| @@ -126,7 +128,7 @@ pipeline.load_textual_inversion(TextualInversion, token="EasyNegative") | token | string | None | API token for Civitai authentication. | -### `search_civitai` parameters +#### `search_civitai` parameters | Name | Type | Default | Description | |:---------------:|:--------------:|:-------------:|:-----------------------------------------------------------------------------------:| @@ -141,8 +143,12 @@ pipeline.load_textual_inversion(TextualInversion, token="EasyNegative") | include_params | bool | False | Whether to include parameters in the returned data. | | skip_error | bool | False | Whether to skip errors and return None. | +## Search Huggingface + +> [!TIP] +> **If an error occurs, insert the `token` and run again.** -### `EasyPipeline.from_huggingface` parameters +#### `EasyPipeline.from_huggingface` parameters | Name | Type | Default | Description | |:---------------------:|:-------------------:|:--------------:|:----------------------------------------------------------------:| @@ -154,7 +160,7 @@ pipeline.load_textual_inversion(TextualInversion, token="EasyNegative") | token | string, bool | None | The token to use as HTTP bearer authorization for remote files. 
From b10a55a317d03e6280edaee29b3fc99b067d9056 Mon Sep 17 00:00:00 2001
From: suzukimain <131413573+suzukimain@users.noreply.github.com>
Date: Thu, 5 Dec 2024 08:50:34 +0900
Subject: [PATCH 27/32] Update README.md

---
 examples/model_search/README.md | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/examples/model_search/README.md b/examples/model_search/README.md
index ad73b6ff4d7b..f7e2d944c77e 100644
--- a/examples/model_search/README.md
+++ b/examples/model_search/README.md
@@ -79,7 +79,7 @@ pipeline = EasyPipelineForInpainting.from_huggingface(
 ```
 
-## Search Civitai and Hugging Face
+## Search Civitai and Huggingface
 
 ```python
 from pipeline_easy import (
@@ -109,10 +109,12 @@ TextualInversion = search_civitai(
 pipeline.load_textual_inversion(TextualInversion, token="EasyNegative")
 ```
 
+## Search Civitai
+
 > [!TIP]
 > **If an error occurs, provide the `token` and run again.**
 
-### `EasyPipeline.from_civitai` parameters
+#### `EasyPipeline.from_civitai` parameters
 
 | Name | Type | Default | Description |
 |:---------------:|:----------------------:|:-------------:|:-----------------------------------------------------------------------------------:|
@@ -126,7 +128,7 @@ pipeline.load_textual_inversion(TextualInversion, token="EasyNegative")
 | token | string | None | API token for Civitai authentication. |
 
-### `search_civitai` parameters
+#### `search_civitai` parameters
 
 | Name | Type | Default | Description |
 |:---------------:|:--------------:|:-------------:|:-----------------------------------------------------------------------------------:|
@@ -141,8 +143,12 @@ pipeline.load_textual_inversion(TextualInversion, token="EasyNegative")
 | include_params | bool | False | Whether to include parameters in the returned data. |
 | skip_error | bool | False | Whether to skip errors and return None. |
 
+## Search Huggingface
+
+> [!TIP]
+> **If an error occurs, provide the `token` and run again.**
 
-### `EasyPipeline.from_huggingface` parameters
+#### `EasyPipeline.from_huggingface` parameters
 
 | Name | Type | Default | Description |
 |:---------------------:|:-------------------:|:--------------:|:----------------------------------------------------------------:|
 | token | string, bool | None | The token to use as HTTP bearer authorization for remote files. |
 
-### `search_huggingface` parameters
+#### `search_huggingface` parameters
 
 | Name | Type | Default | Description |
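`file_downloader`, whose docstring the next patch tidies, streams a remote file to `save_path` and displays a progress bar while doing so. A minimal sketch of calling it (the URL and paths are illustrative):

```python
from pipeline_easy import file_downloader

# Stream a checkpoint to disk with a progress bar.
file_downloader(
    url="https://example.com/checkpoints/model.safetensors",  # illustrative URL
    save_path="./models/model.safetensors",
    displayed_filename="model.safetensors",  # only affects the progress bar label
)
```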
From 026202cd55e2cb8d104a1079bfb7e7d32104fd83 Mon Sep 17 00:00:00 2001
From: suzukimain <131413573+suzukimain@users.noreply.github.com>
Date: Thu, 5 Dec 2024 09:54:46 +0900
Subject: [PATCH 28/32] Update README.md

---
 examples/model_search/README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/model_search/README.md b/examples/model_search/README.md
index f7e2d944c77e..ae91fd47569d 100644
--- a/examples/model_search/README.md
+++ b/examples/model_search/README.md
@@ -109,7 +109,7 @@ TextualInversion = search_civitai(
 pipeline.load_textual_inversion(TextualInversion, token="EasyNegative")
 ```
 
-## Search Civitai
+### Search Civitai
 
 > [!TIP]
 > **If an error occurs, provide the `token` and run again.**
@@ -143,7 +143,7 @@ pipeline.load_textual_inversion(TextualInversion, token="EasyNegative")
 | include_params | bool | False | Whether to include parameters in the returned data. |
 | skip_error | bool | False | Whether to skip errors and return None. |
 
-## Search Huggingface
+### Search Huggingface
 
 > [!TIP]
 > **If an error occurs, provide the `token` and run again.**

From efe5fba76cc12790085b3b9268ac27afbe211511 Mon Sep 17 00:00:00 2001
From: suzukimain
Date: Fri, 6 Dec 2024 13:43:52 +0900
Subject: [PATCH 29/32] make style

---
 examples/model_search/pipeline_easy.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/model_search/pipeline_easy.py b/examples/model_search/pipeline_easy.py
index 74045f716a75..79e89ba2d69d 100644
--- a/examples/model_search/pipeline_easy.py
+++ b/examples/model_search/pipeline_easy.py
@@ -374,7 +374,7 @@ def file_downloader(
         displayed_filename (`str`, *optional*):
             The filename of the file that is being downloaded. Value is used only to display a nice progress bar. If
             not set, the filename is guessed from the URL or the `Content-Disposition` header.
-
+
     returns: None
     """

From 5df8f99dc78df328b8ca3b834b88bb8b3104851e Mon Sep 17 00:00:00 2001
From: suzukimain
Date: Fri, 6 Dec 2024 19:18:30 +0900
Subject: [PATCH 30/32] Fixing the style of pipeline

---
 examples/model_search/pipeline_easy.py | 81 +++++++++++++------------
 1 file changed, 40 insertions(+), 41 deletions(-)

diff --git a/examples/model_search/pipeline_easy.py b/examples/model_search/pipeline_easy.py
index 79e89ba2d69d..50b5711609b3 100644
--- a/examples/model_search/pipeline_easy.py
+++ b/examples/model_search/pipeline_easy.py
@@ -16,13 +16,12 @@
 import os
 import re
 import requests
-from tqdm.auto import tqdm
-from typing import Union
 from collections import OrderedDict
 from dataclasses import (
     dataclass,
     asdict
 )
+from typing import Union
 
 from huggingface_hub.file_download import http_get
 from huggingface_hub.utils import validate_hf_hub_args
@@ -31,11 +30,10 @@
     hf_hub_download,
 )
 
-from diffusers.utils import logging
 from diffusers.loaders.single_file_utils import (
+    _extract_repo_id_and_weights_name,
     infer_diffusers_model_type,
     load_single_file_checkpoint,
-    _extract_repo_id_and_weights_name,
     VALID_URL_PREFIXES,
 )
 from diffusers.pipelines.auto_pipeline import (
@@ -48,6 +46,7 @@
     StableDiffusionControlNetInpaintPipeline,
     StableDiffusionControlNetPipeline,
 )
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import (
     StableDiffusionImg2ImgPipeline,
     StableDiffusionInpaintPipeline,
@@ -58,7 +57,7 @@
     StableDiffusionXLInpaintPipeline,
     StableDiffusionXLPipeline,
 )
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.utils import logging
 
 logger = logging.get_logger(__name__)
@@ -189,7 +188,7 @@ class SearchResult:
         The status of the model.
     """
     model_path: str = ""
-    loading_method: Union[str, None] = None 
+    loading_method: Union[str, None] = None
     checkpoint_format: Union[str, None] = None
     repo_status: RepoStatus = RepoStatus()
     model_status: ModelStatus = ModelStatus()
@@ -282,7 +281,7 @@ def get_keyword_types(keyword):
         `dict`: A dictionary containing the model format, loading method,
         and various types and extra types flags.
""" - + # Initialize the status dictionary with default values status = { "checkpoint_format": None, @@ -299,16 +298,16 @@ def get_keyword_types(keyword): "missing_model_index": None, }, } - + # Check if the keyword is an HTTP or HTTPS URL status["extra_type"]["url"] = bool(re.search(r"^(https?)://", keyword)) - + # Check if the keyword is a file if os.path.isfile(keyword): status["type"]["local"] = True status["checkpoint_format"] = "single_file" status["loading_method"] = "from_single_file" - + # Check if the keyword is a directory elif os.path.isdir(keyword): status["type"]["local"] = True @@ -316,13 +315,13 @@ def get_keyword_types(keyword): status["loading_method"] = "from_pretrained" if not os.path.exists(os.path.join(keyword, "model_index.json")): status["extra_type"]["missing_model_index"] = True - + # Check if the keyword is a Civitai URL elif keyword.startswith("https://civitai.com/"): status["type"]["civitai_url"] = True status["checkpoint_format"] = "single_file" status["loading_method"] = None - + # Check if the keyword starts with any valid URL prefixes elif any(keyword.startswith(prefix) for prefix in VALID_URL_PREFIXES): repo_id, weights_name = _extract_repo_id_and_weights_name(keyword) @@ -334,19 +333,19 @@ def get_keyword_types(keyword): status["type"]["hf_repo"] = True status["checkpoint_format"] = "diffusers" status["loading_method"] = "from_pretrained" - + # Check if the keyword matches a Hugging Face repository format elif re.match(r"^[^/]+/[^/]+$", keyword): status["type"]["hf_repo"] = True status["checkpoint_format"] = "diffusers" status["loading_method"] = "from_pretrained" - + # If none of the above apply else: status["type"]["other"] = True status["checkpoint_format"] = None status["loading_method"] = None - + return status @@ -532,7 +531,7 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N ): diffusers_model_exists = True break - + elif ( any(file_path.endswith(ext) for ext in EXTENSION) and not any(config in file_path for config in CONFIG_FILE_LIST) @@ -540,7 +539,7 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N and os.path.basename(os.path.dirname(file_path)) not in DIFFUSERS_CONFIG_DIR ): file_list.append(file_path) - + # Exit from the loop if a multi-folder diffusers model or valid file is found if diffusers_model_exists or file_list: break @@ -560,7 +559,7 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N ) else: model_path = repo_id - + elif file_list: # Sort and find the safest model file_name = next( @@ -571,7 +570,7 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N ), file_list[0] ) - + if download: model_path = hf_hub_download( @@ -581,12 +580,12 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N token=token, force_download=force_download, ) - + if file_name: download_url = f"https://huggingface.co/{repo_id}/blob/main/{file_name}" else: download_url = f"https://huggingface.co/{repo_id}" - + output_info = get_keyword_types(model_path) if include_params: @@ -606,10 +605,10 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N local=download, ) ) - + else: return model_path - + def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]: r""" @@ -693,7 +692,7 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None] return None else: raise ValueError("Invalid JSON response") - + # Sort repositories by download 
count in descending order sorted_repos = sorted(data["items"], key=lambda x: x["stats"]["downloadCount"], reverse=True) @@ -737,14 +736,14 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None] else: continue break - + # Exception handling when search candidates are not found if not selected_model: if skip_error: return None else: raise ValueError("No model found. Please try changing the word you are searching for.") - + # Define model file status file_name = selected_model["filename"] download_url = selected_model["download_url"] @@ -814,7 +813,7 @@ def __init__(self, *args, **kwargs): # EnvironmentError is returned super().__init__() - + @classmethod @validate_hf_hub_args def from_huggingface(cls, pretrained_model_link_or_path, **kwargs): @@ -929,10 +928,10 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs): kwargs.update(_status) # Search for the model on Hugging Face and get the model status - hf_model_status = search_huggingface(pretrained_model_link_or_path, **kwargs) + hf_model_status = search_huggingface(pretrained_model_link_or_path, **kwargs) logger.warning(f"checkpoint_path: {hf_model_status.model_status.download_url}") checkpoint_path = hf_model_status.model_path - + # Check the format of the model checkpoint if hf_model_status.checkpoint_format == "single_file": # Load the pipeline from a single file checkpoint @@ -943,7 +942,7 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs): ) else: return cls.from_pretrained(checkpoint_path, **kwargs) - + @classmethod def from_civitai(cls, pretrained_model_link_or_path, **kwargs): r""" @@ -1047,7 +1046,7 @@ def from_civitai(cls, pretrained_model_link_or_path, **kwargs): pipeline_mapping=SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING, **kwargs ) - + class EasyPipelineForImage2Image(AutoPipelineForImage2Image): @@ -1071,7 +1070,7 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image): def __init__(self, *args, **kwargs): # EnvironmentError is returned super().__init__() - + @classmethod @validate_hf_hub_args def from_huggingface(cls, pretrained_model_link_or_path, **kwargs): @@ -1186,10 +1185,10 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs): kwargs.update(_parmas) # Search for the model on Hugging Face and get the model status - model_status = search_huggingface(pretrained_model_link_or_path, **kwargs) + model_status = search_huggingface(pretrained_model_link_or_path, **kwargs) logger.warning(f"checkpoint_path: {model_status.model_status.download_url}") checkpoint_path = model_status.model_path - + # Check the format of the model checkpoint if model_status.checkpoint_format == "single_file": # Load the pipeline from a single file checkpoint @@ -1200,7 +1199,7 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs): ) else: return cls.from_pretrained(checkpoint_path, **kwargs) - + @classmethod def from_civitai(cls, pretrained_model_link_or_path, **kwargs): r""" @@ -1305,7 +1304,7 @@ def from_civitai(cls, pretrained_model_link_or_path, **kwargs): **kwargs ) - + class EasyPipelineForInpainting(AutoPipelineForInpainting): r""" @@ -1328,7 +1327,7 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting): def __init__(self, *args, **kwargs): # EnvironmentError is returned super().__init__() - + @classmethod @validate_hf_hub_args def from_huggingface(cls, pretrained_model_link_or_path, **kwargs): @@ -1443,10 +1442,10 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs): kwargs.update(_status) # Search for the 
model on Hugging Face and get the model status - model_status = search_huggingface(pretrained_model_link_or_path, **kwargs) + model_status = search_huggingface(pretrained_model_link_or_path, **kwargs) logger.warning(f"checkpoint_path: {model_status.model_status.download_url}") checkpoint_path = model_status.model_path - + # Check the format of the model checkpoint if model_status.checkpoint_format == "single_file": # Load the pipeline from a single file checkpoint @@ -1457,7 +1456,7 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs): ) else: return cls.from_pretrained(checkpoint_path, **kwargs) - + @classmethod def from_civitai(cls, pretrained_model_link_or_path, **kwargs): r""" @@ -1560,4 +1559,4 @@ def from_civitai(cls, pretrained_model_link_or_path, **kwargs): pretrained_model_or_path=checkpoint_path, pipeline_mapping=SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING, **kwargs - ) \ No newline at end of file + ) From 4ff6e388c7890eb84c8f3442d50f992153967df5 Mon Sep 17 00:00:00 2001 From: suzukimain Date: Fri, 6 Dec 2024 20:09:17 +0900 Subject: [PATCH 31/32] Fix pipeline style --- examples/model_search/pipeline_easy.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/examples/model_search/pipeline_easy.py b/examples/model_search/pipeline_easy.py index 50b5711609b3..6555850cd71b 100644 --- a/examples/model_search/pipeline_easy.py +++ b/examples/model_search/pipeline_easy.py @@ -15,31 +15,25 @@ import os import re -import requests from collections import OrderedDict -from dataclasses import ( - dataclass, - asdict -) +from dataclasses import asdict, dataclass from typing import Union +import requests +from huggingface_hub import hf_api, hf_hub_download from huggingface_hub.file_download import http_get from huggingface_hub.utils import validate_hf_hub_args -from huggingface_hub import ( - hf_api, - hf_hub_download, -) from diffusers.loaders.single_file_utils import ( + VALID_URL_PREFIXES, _extract_repo_id_and_weights_name, infer_diffusers_model_type, load_single_file_checkpoint, - VALID_URL_PREFIXES, ) from diffusers.pipelines.auto_pipeline import ( - AutoPipelineForText2Image, AutoPipelineForImage2Image, AutoPipelineForInpainting, + AutoPipelineForText2Image, ) from diffusers.pipelines.controlnet import ( StableDiffusionControlNetImg2ImgPipeline, @@ -59,6 +53,7 @@ ) from diffusers.utils import logging + logger = logging.get_logger(__name__) From 5e92b8e0a904ce86cf5393a3747e08cfaeaba404 Mon Sep 17 00:00:00 2001 From: suzukimain Date: Fri, 6 Dec 2024 20:24:38 +0900 Subject: [PATCH 32/32] fix --- examples/model_search/pipeline_easy.py | 86 ++++++++++---------------- 1 file changed, 34 insertions(+), 52 deletions(-) diff --git a/examples/model_search/pipeline_easy.py b/examples/model_search/pipeline_easy.py index 6555850cd71b..8264ffad28f6 100644 --- a/examples/model_search/pipeline_easy.py +++ b/examples/model_search/pipeline_easy.py @@ -124,10 +124,11 @@ "inpainting_v2", ] -EXTENSION = [".safetensors", ".ckpt", ".bin"] +EXTENSION = [".safetensors", ".ckpt", ".bin"] CACHE_HOME = os.path.expanduser("~/.cache") + @dataclass class RepoStatus: r""" @@ -141,10 +142,12 @@ class RepoStatus: version (`str`): The version ID of the repository. 
""" + repo_id: str = "" repo_hash: str = "" version: str = "" + @dataclass class ModelStatus: r""" @@ -160,11 +163,13 @@ class ModelStatus: local (`bool`): Whether the model exists locally """ + search_word: str = "" download_url: str = "" file_name: str = "" local: bool = False + @dataclass class SearchResult: r""" @@ -182,6 +187,7 @@ class SearchResult: model_status (`ModelStatus`): The status of the model. """ + model_path: str = "" loading_method: Union[str, None] = None checkpoint_format: Union[str, None] = None @@ -345,10 +351,10 @@ def get_keyword_types(keyword): def file_downloader( - url, - save_path, - **kwargs, - ) -> None: + url, + save_path, + **kwargs, +) -> None: """ Downloads a file from a given URL and saves it to the specified path. @@ -493,7 +499,7 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N pipeline_tag=pipeline_tag, full=True, gated=gated, - token=token + token=token, ) model_dicts = [asdict(value) for value in list(hf_models)] @@ -508,22 +514,16 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N for repo_info in model_dicts: repo_id = repo_info["id"] file_list = [] - hf_repo_info = hf_api.model_info( - repo_id=repo_id, - securityStatus=True - ) + hf_repo_info = hf_api.model_info(repo_id=repo_id, securityStatus=True) # Lists files with security issues. hf_security_info = hf_repo_info.security_repo_status - exclusion = [issue['path'] for issue in hf_security_info['filesWithIssues']] + exclusion = [issue["path"] for issue in hf_security_info["filesWithIssues"]] # Checks for multi-folder diffusers model or valid files (models with security issues are excluded). if hf_security_info["scansDone"]: for info in repo_info["siblings"]: file_path = info["rfilename"] - if ( - "model_index.json" == file_path - and checkpoint_format in ["diffusers", "all"] - ): + if "model_index.json" == file_path and checkpoint_format in ["diffusers", "all"]: diffusers_model_exists = True break @@ -558,15 +558,10 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N elif file_list: # Sort and find the safest model file_name = next( - ( - model - for model in sorted(file_list, reverse=True) - if re.search(r"(?i)[-_](safe|sfw)", model) - ), - file_list[0] + (model for model in sorted(file_list, reverse=True) if re.search(r"(?i)[-_](safe|sfw)", model)), + file_list[0], ) - if download: model_path = hf_hub_download( repo_id=repo_id, @@ -588,17 +583,13 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N model_path=model_path or download_url, loading_method=output_info["loading_method"], checkpoint_format=output_info["checkpoint_format"], - repo_status=RepoStatus( - repo_id=repo_id, - repo_hash=hf_repo_info.sha, - version=revision - ), + repo_status=RepoStatus(repo_id=repo_id, repo_hash=hf_repo_info.sha, version=revision), model_status=ModelStatus( search_word=search_word, download_url=download_url, file_name=file_name, local=download, - ) + ), ) else: @@ -673,9 +664,7 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None] try: # Make the request to the CivitAI API - response = requests.get( - "https://civitai.com/api/v1/models", params=params, headers=headers - ) + response = requests.get("https://civitai.com/api/v1/models", params=params, headers=headers) response.raise_for_status() except requests.exceptions.HTTPError as err: raise requests.HTTPError(f"Could not get elements from the URL: {err}") @@ -696,7 +685,9 @@ def 
search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None] repo_id = selected_repo["id"] # Sort versions within the selected repo by download count - sorted_versions = sorted(selected_repo["modelVersions"], key=lambda x: x["stats"]["downloadCount"], reverse=True) + sorted_versions = sorted( + selected_repo["modelVersions"], key=lambda x: x["stats"]["downloadCount"], reverse=True + ) for selected_version in sorted_versions: version_id = selected_version["id"] models_list = [] @@ -724,7 +715,7 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None] for model_data in sorted_models if bool(re.search(r"(?i)[-_](safe|sfw)", model_data["filename"])) ), - sorted_models[0] + sorted_models[0], ) break @@ -746,9 +737,7 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None] # Handle file download and setting model information if download: # The path where the model is to be saved. - model_path = os.path.join( - str(civitai_cache_dir), str(repo_id), str(version_id), str(file_name) - ) + model_path = os.path.join(str(civitai_cache_dir), str(repo_id), str(version_id), str(file_name)) # Download Model File file_downloader( url=download_url, @@ -772,17 +761,13 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None] model_path=model_path, loading_method=output_info["loading_method"], checkpoint_format=output_info["checkpoint_format"], - repo_status=RepoStatus( - repo_id=repo_name, - repo_hash=repo_id, - version=version_id - ), + repo_status=RepoStatus(repo_id=repo_name, repo_hash=repo_id, version=version_id), model_status=ModelStatus( search_word=search_word, download_url=download_url, file_name=file_name, - local=output_info["type"]["local"] - ) + local=output_info["type"]["local"], + ), ) @@ -808,7 +793,6 @@ def __init__(self, *args, **kwargs): # EnvironmentError is returned super().__init__() - @classmethod @validate_hf_hub_args def from_huggingface(cls, pretrained_model_link_or_path, **kwargs): @@ -933,7 +917,7 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs): return load_pipeline_from_single_file( pretrained_model_or_path=checkpoint_path, pipeline_mapping=SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING, - **kwargs + **kwargs, ) else: return cls.from_pretrained(checkpoint_path, **kwargs) @@ -1039,11 +1023,10 @@ def from_civitai(cls, pretrained_model_link_or_path, **kwargs): return load_pipeline_from_single_file( pretrained_model_or_path=checkpoint_path, pipeline_mapping=SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING, - **kwargs + **kwargs, ) - class EasyPipelineForImage2Image(AutoPipelineForImage2Image): r""" @@ -1190,7 +1173,7 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs): return load_pipeline_from_single_file( pretrained_model_or_path=checkpoint_path, pipeline_mapping=SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING, - **kwargs + **kwargs, ) else: return cls.from_pretrained(checkpoint_path, **kwargs) @@ -1296,11 +1279,10 @@ def from_civitai(cls, pretrained_model_link_or_path, **kwargs): return load_pipeline_from_single_file( pretrained_model_or_path=checkpoint_path, pipeline_mapping=SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING, - **kwargs + **kwargs, ) - class EasyPipelineForInpainting(AutoPipelineForInpainting): r""" @@ -1447,7 +1429,7 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs): return load_pipeline_from_single_file( pretrained_model_or_path=checkpoint_path, 
pipeline_mapping=SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING, - **kwargs + **kwargs, ) else: return cls.from_pretrained(checkpoint_path, **kwargs) @@ -1553,5 +1535,5 @@ def from_civitai(cls, pretrained_model_link_or_path, **kwargs): return load_pipeline_from_single_file( pretrained_model_or_path=checkpoint_path, pipeline_mapping=SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING, - **kwargs + **kwargs, )
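With the style fixes applied, the public surface of `pipeline_easy.py` is settled; a closing sketch of the end-to-end flow, using illustrative search words:

```python
from pipeline_easy import EasyPipelineForImage2Image, search_huggingface

# Inspect how a checkpoint would be loaded before committing to a download.
result = search_huggingface("stable-diffusion", checkpoint_format="diffusers", include_params=True)
print(result.checkpoint_format)  # "diffusers"
print(result.loading_method)     # "from_pretrained" for a multifolder repo

# Build a pipeline straight from the hub; from_huggingface dispatches to
# from_single_file or from_pretrained based on the detected checkpoint format.
pipeline = EasyPipelineForImage2Image.from_huggingface("stable-diffusion")
```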