From 43f96a25dc5cddc7ddb4561dd54537c5e0833b30 Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Tue, 9 May 2023 16:38:58 +0530 Subject: [PATCH] Add multiclip and image encoder --- src/diffusers/pipelines/kandinsky/multiclip.py | 18 ++++++++++++++++++ .../pipelines/kandinsky/pipeline_kandinsky.py | 16 ++++++++++++++-- 2 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/pipelines/kandinsky/multiclip.py diff --git a/src/diffusers/pipelines/kandinsky/multiclip.py b/src/diffusers/pipelines/kandinsky/multiclip.py new file mode 100644 index 000000000000..cf0b4d8f82ce --- /dev/null +++ b/src/diffusers/pipelines/kandinsky/multiclip.py @@ -0,0 +1,18 @@ +import torch +from torch import nn +from transformers import XLMRobertaPreTrainedModel, XLMRobertaModel + +class MultilingualCLIP(XLMRobertaPreTrainedModel): + def __init__(self, config, in_features=1024, out_features=768): # 1024, 768 + super().__init__(config) + self.transformer = XLMRobertaModel(config) + self.LinearTransformation = torch.nn.Linear( + in_features=in_features, out_features=out_features + ) + + def forward(self, input_ids, attention_mask): + embs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0] + embs2 = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum( + dim=1 + )[:, None] + return self.LinearTransformation(embs2), embs diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py index 83de5526d04e..af422fd848a8 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py @@ -16,12 +16,12 @@ from typing import List, Optional, Tuple, Union import torch -from transformers import CLIPTextModelWithProjection, CLIPTokenizer +from transformers import CLIPTextModelWithProjection, CLIPVisionModelWithProjection, CLIPTokenizer, XLMRobertaTokenizerFast from ...models import PriorTransformer, UNet2DConditionModel from ...pipelines import DiffusionPipeline from ...schedulers import UnCLIPScheduler - +from .multiclip import MultilingualCLIP from ...utils import ( logging, randn_tensor, @@ -39,6 +39,8 @@ class KandinskyPipeline(DiffusionPipeline): Args: text_encoder ([`CLIPTextModelWithProjection`]): Frozen text-encoder. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. prior_tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). @@ -46,14 +48,21 @@ class KandinskyPipeline(DiffusionPipeline): The canonincal unCLIP prior to approximate the image embedding from the text embedding. scheduler ([`UnCLIPScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. + multiclip ([`MultilingualCLIP`]): + A multilingual text encoder. + multiclip_tokenizer ([`XLMRobertaTokenizerFast`]): + Tokenizer for multiclip """ def __init__( self, prior: PriorTransformer, text_encoder: CLIPTextModelWithProjection, + image_encoder: CLIPVisionModelWithProjection, prior_tokenizer: CLIPTokenizer, prior_scheduler: UnCLIPScheduler, + multiclip: MultilingualCLIP, + multiclip_tokenizer: XLMRobertaTokenizerFast, ): super().__init__() @@ -62,6 +71,9 @@ def __init__( text_encoder=text_encoder, prior_tokenizer=prior_tokenizer, prior_scheduler=prior_scheduler, + image_encoder=image_encoder, + multiclip=multiclip, + multiclip_tokenizer=multiclip_tokenizer, ) def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):