From 1acb062652de607667b49d06dc3f25ff9e1c028a Mon Sep 17 00:00:00 2001 From: pdoane Date: Fri, 21 Apr 2023 08:17:10 -0700 Subject: [PATCH] Fix issue in maybe_convert_prompt When the token used for textual inversion does not have any special symbols (e.g. it is not surrounded by <>), the tokenizer does not properly split the replacement tokens. Adding a space for the padding tokens fixes this. --- src/diffusers/loaders.py | 2 +- tests/pipelines/test_pipelines.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index b4c443fd303b..8878fb116d1d 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -410,7 +410,7 @@ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): replacement = token i = 1 while f"{token}_{i}" in tokenizer.added_tokens_encoder: - replacement += f"{token}_{i}" + replacement += f" {token}_{i}" i += 1 prompt = prompt.replace(token, replacement) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index a5d70b01d453..8fb79f0c4057 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -541,7 +541,7 @@ def test_text_inversion_download(self): assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96 assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128 assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160 - assert pipe._maybe_convert_prompt("<***>", pipe.tokenizer) == "<***><***>_1<***>_2" + assert pipe._maybe_convert_prompt("<***>", pipe.tokenizer) == "<***> <***>_1 <***>_2" prompt = "hey <***>" out = pipe(prompt, num_inference_steps=1, output_type="numpy").images @@ -569,7 +569,7 @@ def test_text_inversion_download(self): assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96 assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128 assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160 - assert pipe._maybe_convert_prompt("<****>", pipe.tokenizer) == "<****><****>_1<****>_2" + assert pipe._maybe_convert_prompt("<****>", pipe.tokenizer) == "<****> <****>_1 <****>_2" prompt = "hey <****>" out = pipe(prompt, num_inference_steps=1, output_type="numpy").images