From 6ffa9d467ac46e94841deaed9284814ec40daf4d Mon Sep 17 00:00:00 2001 From: "Gregory D. Hunkins" Date: Mon, 15 May 2023 16:48:40 -0400 Subject: [PATCH 1/9] enable state dict for textual inversion loader --- src/diffusers/loaders.py | 71 +++++++++++++++++++++------------------- 1 file changed, 38 insertions(+), 33 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index a1f0d8ec2a52..157a94dab135 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -465,7 +465,7 @@ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): def load_textual_inversion( self, - pretrained_model_name_or_path: Union[str, List[str]], + pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]], token: Optional[Union[str, List[str]]] = None, **kwargs, ): @@ -480,7 +480,7 @@ def load_textual_inversion( Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]`): + pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`): Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. @@ -489,6 +489,8 @@ def load_textual_inversion( - A path to a *directory* containing textual inversion weights, e.g. `./my_text_inversion_directory/`. - A path to a *file* containing textual inversion weights, e.g. `./my_text_inversions.pt`. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). Or a list of those elements. token (`str` or `List[str]`, *optional*): @@ -613,7 +615,7 @@ def load_textual_inversion( "framework": "pytorch", } - if isinstance(pretrained_model_name_or_path, str): + if not isinstance(pretrained_model_name_or_path, list): pretrained_model_name_or_paths = [pretrained_model_name_or_path] else: pretrained_model_name_or_paths = pretrained_model_name_or_path @@ -638,16 +640,38 @@ def load_textual_inversion( token_ids_and_embeddings = [] for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens): - # 1. Load textual inversion file - model_file = None - # Let's first try to load .safetensors weights - if (use_safetensors and weight_name is None) or ( - weight_name is not None and weight_name.endswith(".safetensors") - ): - try: + if not isinstance(pretrained_model_name_or_path, dict): + # 1. 
Load textual inversion file + model_file = None + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=weight_name or TEXT_INVERSION_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except Exception as e: + if not allow_pickle: + raise e + + model_file = None + + if model_file is None: model_file = _get_model_file( pretrained_model_name_or_path, - weights_name=weight_name or TEXT_INVERSION_NAME_SAFE, + weights_name=weight_name or TEXT_INVERSION_NAME, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, @@ -658,28 +682,9 @@ def load_textual_inversion( subfolder=subfolder, user_agent=user_agent, ) - state_dict = safetensors.torch.load_file(model_file, device="cpu") - except Exception as e: - if not allow_pickle: - raise e - - model_file = None - - if model_file is None: - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=weight_name or TEXT_INVERSION_NAME, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = torch.load(model_file, map_location="cpu") + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path # 2. Load token and embedding correcly from file if isinstance(state_dict, torch.Tensor): From e5fba6985f9624fc8581507b267de56365268648 Mon Sep 17 00:00:00 2001 From: "Gregory D. Hunkins" Date: Wed, 17 May 2023 12:19:22 -0400 Subject: [PATCH 2/9] Empty-Commit | restart CI From b12690727a936a192f64ba96bfe69a0aa05f8a03 Mon Sep 17 00:00:00 2001 From: "Gregory D. Hunkins" Date: Wed, 17 May 2023 14:21:58 -0400 Subject: [PATCH 3/9] Empty-Commit | restart CI From 6a8b49bca4268ab4b7aa8f5864e719bb2b79a3bb Mon Sep 17 00:00:00 2001 From: "Gregory D. Hunkins" Date: Wed, 17 May 2023 18:24:59 -0400 Subject: [PATCH 4/9] Empty-Commit | restart CI From d80fbd3e1fb0ac5fa1462ff7d5f13eec30e3467c Mon Sep 17 00:00:00 2001 From: "Gregory D. Hunkins" Date: Wed, 17 May 2023 23:38:21 -0400 Subject: [PATCH 5/9] Empty-Commit | restart CI From a85b3714000d70d6490dd337b39c11a4423a3e73 Mon Sep 17 00:00:00 2001 From: "Gregory D. 
Hunkins" Date: Fri, 26 May 2023 13:55:19 -0400 Subject: [PATCH 6/9] add tests --- tests/pipelines/test_pipelines.py | 59 +++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 8eaee0915a4f..94d1faae3696 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -663,6 +663,65 @@ def test_text_inversion_download(self): out = pipe(prompt, num_inference_steps=1, output_type="numpy").images assert out.shape == (1, 128, 128, 3) + # single token state dict load + ten = {"<*>": torch.ones((32,))} + pipe.load_textual_inversion(ten) + + token = pipe.tokenizer.convert_tokens_to_ids("<*>") + assert token == num_tokens, "Added token must be at spot `num_tokens`" + assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 32 + assert pipe._maybe_convert_prompt("<*>", pipe.tokenizer) == "<*>" + + prompt = "hey <*>" + out = pipe(prompt, num_inference_steps=1, output_type="numpy").images + assert out.shape == (1, 128, 128, 3) + + # multi embedding state dict load + ten1 = {"<*****>": torch.ones((32,))} + ten2 = {"<******>": 2 * torch.ones((1, 32))} + + pipe.load_textual_inversion([ten1, ten2]) + + token = pipe.tokenizer.convert_tokens_to_ids("<*****>") + assert token == num_tokens + 8, "Added token must be at spot `num_tokens`" + assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 32 + assert pipe._maybe_convert_prompt("<*****>", pipe.tokenizer) == "<*****>" + + token = pipe.tokenizer.convert_tokens_to_ids("<******>") + assert token == num_tokens + 9, "Added token must be at spot `num_tokens`" + assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64 + assert pipe._maybe_convert_prompt("<******>", pipe.tokenizer) == "<******>" + + prompt = "hey <*****> <******>" + out = pipe(prompt, num_inference_steps=1, output_type="numpy").images + assert out.shape == (1, 128, 128, 3) + + # auto1111 multi-token state dict load + ten = { + "string_to_param": { + "*": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))]) + }, + "name": "<****>", + } + + pipe.load_textual_inversion(ten) + + token = pipe.tokenizer.convert_tokens_to_ids("<****>") + token_1 = pipe.tokenizer.convert_tokens_to_ids("<****>_1") + token_2 = pipe.tokenizer.convert_tokens_to_ids("<****>_2") + + assert token == num_tokens + 5, "Added token must be at spot `num_tokens`" + assert token_1 == num_tokens + 6, "Added token must be at spot `num_tokens`" + assert token_2 == num_tokens + 7, "Added token must be at spot `num_tokens`" + assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96 + assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128 + assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160 + assert pipe._maybe_convert_prompt("<****>", pipe.tokenizer) == "<****> <****>_1 <****>_2" + + prompt = "hey <****>" + out = pipe(prompt, num_inference_steps=1, output_type="numpy").images + assert out.shape == (1, 128, 128, 3) + def test_download_ignore_files(self): # Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4 with tempfile.TemporaryDirectory() as tmpdirname: From 9877fff46109a7d8c4883f4c7842acc614de2a44 Mon Sep 17 00:00:00 2001 From: "Gregory D. 
Hunkins" Date: Fri, 26 May 2023 14:10:28 -0400 Subject: [PATCH 7/9] fix tests --- tests/pipelines/test_pipelines.py | 36 +++++++++++++++---------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 94d1faae3696..48dbe4990250 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -664,51 +664,51 @@ def test_text_inversion_download(self): assert out.shape == (1, 128, 128, 3) # single token state dict load - ten = {"<*>": torch.ones((32,))} + ten = {"": torch.ones((32,))} pipe.load_textual_inversion(ten) - token = pipe.tokenizer.convert_tokens_to_ids("<*>") + token = pipe.tokenizer.convert_tokens_to_ids("") assert token == num_tokens, "Added token must be at spot `num_tokens`" assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 32 - assert pipe._maybe_convert_prompt("<*>", pipe.tokenizer) == "<*>" + assert pipe._maybe_convert_prompt("", pipe.tokenizer) == "" - prompt = "hey <*>" + prompt = "hey " out = pipe(prompt, num_inference_steps=1, output_type="numpy").images assert out.shape == (1, 128, 128, 3) # multi embedding state dict load - ten1 = {"<*****>": torch.ones((32,))} - ten2 = {"<******>": 2 * torch.ones((1, 32))} + ten1 = {"": torch.ones((32,))} + ten2 = {"": 2 x torch.ones((1, 32))} pipe.load_textual_inversion([ten1, ten2]) - token = pipe.tokenizer.convert_tokens_to_ids("<*****>") + token = pipe.tokenizer.convert_tokens_to_ids("") assert token == num_tokens + 8, "Added token must be at spot `num_tokens`" assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 32 - assert pipe._maybe_convert_prompt("<*****>", pipe.tokenizer) == "<*****>" + assert pipe._maybe_convert_prompt("", pipe.tokenizer) == "" - token = pipe.tokenizer.convert_tokens_to_ids("<******>") + token = pipe.tokenizer.convert_tokens_to_ids("") assert token == num_tokens + 9, "Added token must be at spot `num_tokens`" assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64 - assert pipe._maybe_convert_prompt("<******>", pipe.tokenizer) == "<******>" + assert pipe._maybe_convert_prompt("", pipe.tokenizer) == "" - prompt = "hey <*****> <******>" + prompt = "hey " out = pipe(prompt, num_inference_steps=1, output_type="numpy").images assert out.shape == (1, 128, 128, 3) # auto1111 multi-token state dict load ten = { "string_to_param": { - "*": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))]) + "*": torch.cat([3 x torch.ones((1, 32)), 4 x torch.ones((1, 32)), 5 x torch.ones((1, 32))]) }, - "name": "<****>", + "name": "", } pipe.load_textual_inversion(ten) - token = pipe.tokenizer.convert_tokens_to_ids("<****>") - token_1 = pipe.tokenizer.convert_tokens_to_ids("<****>_1") - token_2 = pipe.tokenizer.convert_tokens_to_ids("<****>_2") + token = pipe.tokenizer.convert_tokens_to_ids("") + token_1 = pipe.tokenizer.convert_tokens_to_ids("_1") + token_2 = pipe.tokenizer.convert_tokens_to_ids("_2") assert token == num_tokens + 5, "Added token must be at spot `num_tokens`" assert token_1 == num_tokens + 6, "Added token must be at spot `num_tokens`" @@ -716,9 +716,9 @@ def test_text_inversion_download(self): assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96 assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128 assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160 - assert pipe._maybe_convert_prompt("<****>", pipe.tokenizer) == "<****> 
<****>_1 <****>_2" + assert pipe._maybe_convert_prompt("", pipe.tokenizer) == " _1 _2" - prompt = "hey <****>" + prompt = "hey " out = pipe(prompt, num_inference_steps=1, output_type="numpy").images assert out.shape == (1, 128, 128, 3) From 72f06ebe29bec6084c3870a5eb50fdab9070b9de Mon Sep 17 00:00:00 2001 From: "Gregory D. Hunkins" Date: Fri, 26 May 2023 14:15:21 -0400 Subject: [PATCH 8/9] fix tests --- tests/pipelines/test_pipelines.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 48dbe4990250..0b0dd2b74441 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -678,7 +678,7 @@ def test_text_inversion_download(self): # multi embedding state dict load ten1 = {"": torch.ones((32,))} - ten2 = {"": 2 x torch.ones((1, 32))} + ten2 = {"": 2 * torch.ones((1, 32))} pipe.load_textual_inversion([ten1, ten2]) @@ -699,7 +699,7 @@ def test_text_inversion_download(self): # auto1111 multi-token state dict load ten = { "string_to_param": { - "*": torch.cat([3 x torch.ones((1, 32)), 4 x torch.ones((1, 32)), 5 x torch.ones((1, 32))]) + "*": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))]) }, "name": "", } From 1c1a69d19f6b8d325d3256462d361e87d05abe8c Mon Sep 17 00:00:00 2001 From: "Gregory D. Hunkins" Date: Fri, 26 May 2023 14:31:43 -0400 Subject: [PATCH 9/9] fix tests --- tests/pipelines/test_pipelines.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 0b0dd2b74441..bb7c980875ef 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -668,7 +668,7 @@ def test_text_inversion_download(self): pipe.load_textual_inversion(ten) token = pipe.tokenizer.convert_tokens_to_ids("") - assert token == num_tokens, "Added token must be at spot `num_tokens`" + assert token == num_tokens + 10, "Added token must be at spot `num_tokens`" assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 32 assert pipe._maybe_convert_prompt("", pipe.tokenizer) == "" @@ -683,12 +683,12 @@ def test_text_inversion_download(self): pipe.load_textual_inversion([ten1, ten2]) token = pipe.tokenizer.convert_tokens_to_ids("") - assert token == num_tokens + 8, "Added token must be at spot `num_tokens`" + assert token == num_tokens + 11, "Added token must be at spot `num_tokens`" assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 32 assert pipe._maybe_convert_prompt("", pipe.tokenizer) == "" token = pipe.tokenizer.convert_tokens_to_ids("") - assert token == num_tokens + 9, "Added token must be at spot `num_tokens`" + assert token == num_tokens + 12, "Added token must be at spot `num_tokens`" assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64 assert pipe._maybe_convert_prompt("", pipe.tokenizer) == "" @@ -710,9 +710,9 @@ def test_text_inversion_download(self): token_1 = pipe.tokenizer.convert_tokens_to_ids("_1") token_2 = pipe.tokenizer.convert_tokens_to_ids("_2") - assert token == num_tokens + 5, "Added token must be at spot `num_tokens`" - assert token_1 == num_tokens + 6, "Added token must be at spot `num_tokens`" - assert token_2 == num_tokens + 7, "Added token must be at spot `num_tokens`" + assert token == num_tokens + 13, "Added token must be at spot `num_tokens`" + assert token_1 == num_tokens + 14, "Added token must be at spot `num_tokens`" + assert token_2 == 
num_tokens + 15, "Added token must be at spot `num_tokens`" assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96 assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128 assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160
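
Taken together, these patches let `load_textual_inversion` accept in-memory state dicts in addition to repo ids, directories, and weight files. Below is a minimal usage sketch of the three dict formats the new tests exercise. The checkpoint name `runwayml/stable-diffusion-v1-5`, the placeholder tokens, and the random tensors are illustrative assumptions only; in practice each tensor would be a trained textual inversion embedding whose width matches the pipeline's text encoder (768 for Stable Diffusion 1.x).

import torch
from diffusers import StableDiffusionPipeline

# Any pipeline that mixes in TextualInversionLoaderMixin works; this checkpoint is an example.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
embed_dim = pipe.text_encoder.get_input_embeddings().weight.shape[-1]  # 768 for SD 1.x

# 1. Plain state dict: {placeholder token: embedding tensor}
pipe.load_textual_inversion({"<my-style>": torch.randn(embed_dim)})

# 2. List of state dicts, one new token per entry
pipe.load_textual_inversion(
    [{"<token-a>": torch.randn(embed_dim)}, {"<token-b>": torch.randn(1, embed_dim)}]
)

# 3. Automatic1111-style dict: n vectors are registered as <name>, <name>_1, ..., <name>_{n-1}
pipe.load_textual_inversion(
    {"string_to_param": {"*": torch.randn(3, embed_dim)}, "name": "<multi-vector>"}
)

# Prompts can then reference the newly added tokens directly.
image = pipe("a photo in the style of <my-style>", num_inference_steps=25).images[0]

The same call still accepts model ids and file paths as before, so existing callers are unaffected; only the accepted input types are broadened.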