diff --git a/src/diffusers/pipelines/kandinsky/multiclip.py b/src/diffusers/pipelines/kandinsky/multiclip.py
new file mode 100644
index 000000000000..cf0b4d8f82ce
--- /dev/null
+++ b/src/diffusers/pipelines/kandinsky/multiclip.py
@@ -0,0 +1,18 @@
+import torch
+from torch import nn
+from transformers import XLMRobertaModel, XLMRobertaPreTrainedModel
+
+
+class MultilingualCLIP(XLMRobertaPreTrainedModel):
+    def __init__(self, config, in_features=1024, out_features=768):
+        super().__init__(config)
+        self.transformer = XLMRobertaModel(config)
+        # Project the pooled XLM-R hidden states (1024 for XLM-R Large) into
+        # the CLIP text-embedding space (768).
+        self.LinearTransformation = nn.Linear(in_features=in_features, out_features=out_features)
+
+    def forward(self, input_ids, attention_mask):
+        embs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0]
+        # Masked mean pooling over the sequence: padding positions contribute nothing.
+        embs2 = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None]
+        return self.LinearTransformation(embs2), embs
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
index e8b2af450515..d5b619535aa6 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
@@ -16,11 +16,13 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
-from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, XLMRobertaTokenizerFast
 
 from ...models import PriorTransformer, UNet2DConditionModel
 from ...pipelines import DiffusionPipeline
 from ...schedulers import UnCLIPScheduler
+
+from .multiclip import MultilingualCLIP
 from .text_proj import KandinskyTextProjModel
 
 from ...utils import (
@@ -52,19 +54,31 @@ class KandinskyPriorPipeline(DiffusionPipeline):
             The canonical unCLIP prior to approximate the image embedding from the text embedding.
         prior_text_encoder ([`CLIPTextModelWithProjection`]):
             Frozen text-encoder.
+        text_encoder ([`CLIPTextModelWithProjection`]):
+            Frozen CLIP text-encoder.
+        image_encoder ([`CLIPVisionModelWithProjection`]):
+            Frozen image-encoder.
         prior_tokenizer (`CLIPTokenizer`):
             Tokenizer of class
             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
         prior_scheduler ([`UnCLIPScheduler`]):
             A scheduler to be used in combination with `prior` to generate image embedding.
+        multiclip ([`MultilingualCLIP`]):
+            A multilingual text encoder.
+        multiclip_tokenizer ([`XLMRobertaTokenizerFast`]):
+            Tokenizer for `multiclip`.
     """
 
     def __init__(
         self,
         prior: PriorTransformer,
+        text_encoder: CLIPTextModelWithProjection,
+        image_encoder: CLIPVisionModelWithProjection,
         prior_text_encoder: CLIPTextModelWithProjection,
         prior_tokenizer: CLIPTokenizer,
         prior_scheduler: UnCLIPScheduler,
+        multiclip: MultilingualCLIP,
+        multiclip_tokenizer: XLMRobertaTokenizerFast,
     ):
         super().__init__()
 
@@ -73,6 +87,10 @@ def __init__(
             prior_text_encoder=prior_text_encoder,
             prior_tokenizer=prior_tokenizer,
             prior_scheduler=prior_scheduler,
+            text_encoder=text_encoder,
+            image_encoder=image_encoder,
+            multiclip=multiclip,
+            multiclip_tokenizer=multiclip_tokenizer,
         )
 
     def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
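
Reviewer note: below is a minimal sketch (not part of the patch) of the masked mean pooling that MultilingualCLIP.forward performs, followed by a hypothetical standalone invocation. The toy tensors, the scaled-down XLMRobertaConfig, and the dummy token ids are illustrative assumptions rather than values shipped in this PR.

    import torch
    from transformers import XLMRobertaConfig

    from diffusers.pipelines.kandinsky.multiclip import MultilingualCLIP

    # 1) The pooling step in isolation: batch of 1, seq_len 3, hidden size 2,
    #    where the last position is padding (mask == 0).
    embs = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])
    mask = torch.tensor([[1.0, 1.0, 0.0]])
    pooled = (embs * mask.unsqueeze(2)).sum(dim=1) / mask.sum(dim=1)[:, None]
    # pooled == tensor([[2., 3.]]) -- the padding token contributes nothing.

    # 2) Hypothetical end-to-end call with a randomly initialized, scaled-down
    #    encoder (a real checkpoint would instead be loaded via from_pretrained,
    #    and the ids would come from XLMRobertaTokenizerFast).
    config = XLMRobertaConfig(
        vocab_size=1000, hidden_size=1024, num_hidden_layers=2,
        num_attention_heads=16, intermediate_size=4096,
    )
    model = MultilingualCLIP(config).eval()
    input_ids = torch.randint(2, 1000, (2, 7))           # dummy token ids
    attention_mask = torch.ones(2, 7, dtype=torch.long)  # no padding here
    with torch.no_grad():
        text_emb, per_token = model(input_ids, attention_mask)
    # text_emb:  (2, 768)      projection into the CLIP text-embedding space
    # per_token: (2, 7, 1024)  raw XLM-R hidden states, returned unpooled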