diff --git a/src/diffusers/modeling_flax_pytorch_utils.py b/src/diffusers/modeling_flax_pytorch_utils.py
new file mode 100644
index 000000000000..9c7a5de2ad6e
--- /dev/null
+++ b/src/diffusers/modeling_flax_pytorch_utils.py
@@ -0,0 +1,117 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch - Flax general utilities."""
+import re
+
+import jax.numpy as jnp
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax.random import PRNGKey
+
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def rename_key(key):
+    regex = r"\w+[.]\d+"
+    pats = re.findall(regex, key)
+    for pat in pats:
+        key = key.replace(pat, "_".join(pat.split(".")))
+    return key
+
+
+#####################
+# PyTorch => Flax #
+#####################
+
+# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
+# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
+def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
+    """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
+
+    # conv norm or layer norm
+    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
+    if (
+        any("norm" in str_ for str_ in pt_tuple_key)
+        and (pt_tuple_key[-1] == "bias")
+        and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
+        and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
+    ):
+        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
+        return renamed_pt_tuple_key, pt_tensor
+    elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
+        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
+        return renamed_pt_tuple_key, pt_tensor
+
+    # embedding
+    if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
+        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
+        return renamed_pt_tuple_key, pt_tensor
+
+    # conv layer
+    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
+    if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
+        pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
+        return renamed_pt_tuple_key, pt_tensor
+
+    # linear layer
+    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
+    if pt_tuple_key[-1] == "weight":
+        pt_tensor = pt_tensor.T
+        return renamed_pt_tuple_key, pt_tensor
+
+    # old PyTorch layer norm weight
+    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
+    if pt_tuple_key[-1] == "gamma":
+        return renamed_pt_tuple_key, pt_tensor
+
+    # old PyTorch layer norm bias
+    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
+    if pt_tuple_key[-1] == "beta":
+        return renamed_pt_tuple_key, pt_tensor
+
+    return pt_tuple_key, pt_tensor
+
+
+def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
+    # Step 1: Convert pytorch tensor to numpy
+    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
+
+    # Step 2: Since the model is stateless, get random Flax params
+    random_flax_params = flax_model.init_weights(PRNGKey(init_key))
+
+    random_flax_state_dict = flatten_dict(random_flax_params)
+    flax_state_dict = {}
+
+    # Need to change some parameter names to match Flax names
+    for pt_key, pt_tensor in pt_state_dict.items():
+        renamed_pt_key = rename_key(pt_key)
+        pt_tuple_key = tuple(renamed_pt_key.split("."))
+
+        # Correctly rename weight parameters
+        flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)
+
+        if flax_key in random_flax_state_dict:
+            if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
+                raise ValueError(
+                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
+                    f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
+                )
+
+        # also add unexpected weights so that a warning is thrown
+        flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
+
+    return unflatten_dict(flax_state_dict)
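For orientation, here is a minimal sketch of how the new converter can be driven directly. The checkpoint path and the use of FlaxUNet2DConditionModel with its default config are illustrative assumptions, not part of the patch:

import torch

from diffusers import FlaxUNet2DConditionModel
from diffusers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax

# Load the raw PyTorch state dict (file path is hypothetical).
pt_state_dict = torch.load("unet/diffusion_pytorch_model.bin", map_location="cpu")

# The converter calls model.init_weights() to build a reference Flax
# parameter tree, then renames/reshapes each PyTorch tensor against it.
model = FlaxUNet2DConditionModel()
flax_params = convert_pytorch_state_dict_to_flax(pt_state_dict, model)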
diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py
index 505c2881cbab..d1dfacf36265 100644
--- a/src/diffusers/modeling_flax_utils.py
+++ b/src/diffusers/modeling_flax_utils.py
@@ -27,7 +27,8 @@
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
 
-from .modeling_utils import WEIGHTS_NAME
+from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
+from .modeling_utils import WEIGHTS_NAME, load_state_dict
 from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
 
 
@@ -245,6 +246,8 @@ def from_pretrained(
                 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                 git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                 identifier allowed by git.
+            from_pt (`bool`, *optional*, defaults to `False`):
+                Load the model weights from a PyTorch checkpoint save file.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                 `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@@ -272,6 +275,7 @@ def from_pretrained(
         config = kwargs.pop("config", None)
         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         force_download = kwargs.pop("force_download", False)
+        from_pt = kwargs.pop("from_pt", False)
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
@@ -306,10 +310,16 @@ def from_pretrained(
                 # Load from a Flax checkpoint
                 model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
             # At this stage we don't have a weight file so we will raise an error.
+            elif from_pt:
+                if not os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+                    raise EnvironmentError(
+                        f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
+                    )
+                model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
             elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                 raise EnvironmentError(
-                    f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
-                    "but there is a file for PyTorch weights."
+                    f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model"
+                    " using `from_pt=True`."
                 )
             else:
                 raise EnvironmentError(
@@ -320,7 +330,7 @@ def from_pretrained(
             try:
                 model_file = hf_hub_download(
                     pretrained_model_name_or_path,
-                    filename=FLAX_WEIGHTS_NAME,
+                    filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
                     cache_dir=cache_dir,
                     force_download=force_download,
                     proxies=proxies,
@@ -370,25 +380,32 @@ def from_pretrained(
                 f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
             )
 
-        try:
-            with open(model_file, "rb") as state_f:
-                state = from_bytes(cls, state_f.read())
-        except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
+        if from_pt:
+            # Step 1: Get the pytorch file
+            pytorch_model_file = load_state_dict(model_file)
+
+            # Step 2: Convert the weights
+            state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
+        else:
             try:
-                with open(model_file) as f:
-                    if f.read().startswith("version"):
-                        raise OSError(
-                            "You seem to have cloned a repository without having git-lfs installed. Please"
-                            " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
-                            " folder you cloned."
-                        )
-                    else:
-                        raise ValueError from e
-            except (UnicodeDecodeError, ValueError):
-                raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
-        # make sure all arrays are stored as jnp.ndarray
-        # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
-        # https://github.com/google/flax/issues/1261
+                with open(model_file, "rb") as state_f:
+                    state = from_bytes(cls, state_f.read())
+            except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
+                try:
+                    with open(model_file) as f:
+                        if f.read().startswith("version"):
+                            raise OSError(
+                                "You seem to have cloned a repository without having git-lfs installed. Please"
+                                " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
+                                " folder you cloned."
+                            )
+                        else:
+                            raise ValueError from e
+                except (UnicodeDecodeError, ValueError):
+                    raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
+            # make sure all arrays are stored as jnp.ndarray
+            # NOTE: This is to prevent a bug; it will be fixed in Flax >= v0.3.4:
+            # https://github.com/google/flax/issues/1261
         state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
 
         # flatten dicts
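End to end, the new `from_pt` flag lets the Flax class consume a PyTorch-only checkpoint. A minimal sketch, assuming a local directory that contains only the PyTorch weights file (the path is hypothetical):

from diffusers import FlaxUNet2DConditionModel

# With from_pt=True, from_pretrained resolves WEIGHTS_NAME (the PyTorch
# .bin file) instead of FLAX_WEIGHTS_NAME, then converts it on the fly
# via convert_pytorch_state_dict_to_flax.
model, params = FlaxUNet2DConditionModel.from_pretrained(
    "./pytorch-unet-checkpoint",  # hypothetical local directory
    from_pt=True,
)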
") + # make sure all arrays are stored as jnp.ndarray + # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: + # https://github.com/google/flax/issues/1261 state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state) # flatten dicts diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 918c7469a74c..525be4818dcc 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -32,7 +32,7 @@ def setup(self): self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k") self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v") - self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out") + self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0") def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape @@ -82,9 +82,9 @@ class FlaxBasicTransformerBlock(nn.Module): def setup(self): # self attention - self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) # cross attention - self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) @@ -93,12 +93,12 @@ def setup(self): def __call__(self, hidden_states, context, deterministic=True): # self attention residual = hidden_states - hidden_states = self.self_attn(self.norm1(hidden_states), deterministic=deterministic) + hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic) hidden_states = hidden_states + residual # cross attention residual = hidden_states - hidden_states = self.cross_attn(self.norm2(hidden_states), context, deterministic=deterministic) + hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic) hidden_states = hidden_states + residual # feed forward @@ -167,14 +167,28 @@ class FlaxGluFeedForward(nn.Module): dropout: float = 0.0 dtype: jnp.dtype = jnp.float32 + def setup(self): + # The second linear layer needs to be called + # net_2 for now to match the index of the Sequential layer + self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype) + self.net_2 = nn.Dense(self.dim, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.net_0(hidden_states) + hidden_states = self.net_2(hidden_states) + return hidden_states + + +class FlaxGEGLU(nn.Module): + dim: int + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + def setup(self): inner_dim = self.dim * 4 - self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype) - self.dense2 = nn.Dense(self.dim, dtype=self.dtype) + self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype) def __call__(self, hidden_states, deterministic=True): - hidden_states = self.dense1(hidden_states) + hidden_states = self.proj(hidden_states) hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2) - hidden_states = hidden_linear * nn.gelu(hidden_gelu) - hidden_states = self.dense2(hidden_states) - return hidden_states + return hidden_linear * nn.gelu(hidden_gelu) diff --git 
diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py
index 636a7ef9816a..d0fcd9f6ae13 100644
--- a/src/diffusers/models/unet_2d_condition_flax.py
+++ b/src/diffusers/models/unet_2d_condition_flax.py
@@ -76,7 +76,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
 
     def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
         # init input tensors
-        sample_shape = (1, self.sample_size, self.sample_size, self.in_channels)
+        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
         sample = jnp.zeros(sample_shape, dtype=jnp.float32)
         timesteps = jnp.ones((1,), dtype=jnp.int32)
         encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
@@ -214,10 +214,17 @@ def __call__(
             When returning a tuple, the first element is the sample tensor.
         """
         # 1. time
+        if not isinstance(timesteps, jnp.ndarray):
+            timesteps = jnp.array([timesteps], dtype=jnp.int32)
+        elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
+            timesteps = timesteps.astype(dtype=jnp.float32)
+            timesteps = jnp.expand_dims(timesteps, 0)
+
         t_emb = self.time_proj(timesteps)
         t_emb = self.time_embedding(t_emb)
 
         # 2. pre-process
+        sample = jnp.transpose(sample, (0, 2, 3, 1))
         sample = self.conv_in(sample)
 
         # 3. down
@@ -251,6 +258,7 @@ def __call__(
         sample = self.conv_norm_out(sample)
         sample = nn.silu(sample)
         sample = self.conv_out(sample)
+        sample = jnp.transpose(sample, (0, 3, 1, 2))
 
         if not return_dict:
             return (sample,)
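The added transposes keep the model's public interface NCHW, matching the PyTorch UNet, while the Flax convolutions run in NHWC internally. A small shape check, purely illustrative:

import jax.numpy as jnp

x = jnp.zeros((1, 4, 64, 64))                # NCHW, as the caller passes it
x_nhwc = jnp.transpose(x, (0, 2, 3, 1))      # to NHWC for the Flax convs
assert x_nhwc.shape == (1, 64, 64, 4)
x_out = jnp.transpose(x_nhwc, (0, 3, 1, 2))  # back to NCHW for the caller
assert x_out.shape == (1, 4, 64, 64)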
diff --git a/src/diffusers/models/unet_blocks_flax.py b/src/diffusers/models/unet_blocks_flax.py
index ce67eb12b19f..fe6cc3b194e3 100644
--- a/src/diffusers/models/unet_blocks_flax.py
+++ b/src/diffusers/models/unet_blocks_flax.py
@@ -55,7 +55,7 @@ def setup(self):
         self.attentions = attentions
 
         if self.add_downsample:
-            self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
+            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
 
     def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
         output_states = ()
@@ -66,7 +66,7 @@ def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=Tru
             output_states += (hidden_states,)
 
         if self.add_downsample:
-            hidden_states = self.downsample(hidden_states)
+            hidden_states = self.downsamplers_0(hidden_states)
             output_states += (hidden_states,)
 
         return hidden_states, output_states
@@ -96,7 +96,7 @@ def setup(self):
         self.resnets = resnets
 
         if self.add_downsample:
-            self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
+            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
 
     def __call__(self, hidden_states, temb, deterministic=True):
         output_states = ()
@@ -106,7 +106,7 @@ def __call__(self, hidden_states, temb, deterministic=True):
             output_states += (hidden_states,)
 
         if self.add_downsample:
-            hidden_states = self.downsample(hidden_states)
+            hidden_states = self.downsamplers_0(hidden_states)
             output_states += (hidden_states,)
 
         return hidden_states, output_states
@@ -151,7 +151,7 @@ def setup(self):
         self.attentions = attentions
 
         if self.add_upsample:
-            self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
+            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
 
     def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
         for resnet, attn in zip(self.resnets, self.attentions):
@@ -164,7 +164,7 @@ def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_
             hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
 
         if self.add_upsample:
-            hidden_states = self.upsample(hidden_states)
+            hidden_states = self.upsamplers_0(hidden_states)
 
         return hidden_states
 
@@ -196,7 +196,7 @@ def setup(self):
         self.resnets = resnets
 
         if self.add_upsample:
-            self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
+            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
 
     def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
         for resnet in self.resnets:
@@ -208,7 +208,7 @@ def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=T
             hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
 
         if self.add_upsample:
-            hidden_states = self.upsample(hidden_states)
+            hidden_states = self.upsamplers_0(hidden_states)
 
         return hidden_states
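The `downsamplers_0` / `upsamplers_0` names mirror the PyTorch ModuleList entries `downsamplers.0` / `upsamplers.0` after `rename_key` rewrites them. One way to spot-check that the randomly initialized Flax tree now exposes the index-suffixed names the converted PyTorch keys map to; constructing the model with its default config is an illustrative assumption:

from flax.traverse_util import flatten_dict
from jax.random import PRNGKey

from diffusers import FlaxUNet2DConditionModel

model = FlaxUNet2DConditionModel()
flat = flatten_dict(model.init_weights(PRNGKey(0)))

# Keys are tuples of attribute names; the downsampler appears under the
# index-suffixed name that a PyTorch key like "downsamplers.0" converts to.
assert any("downsamplers_0" in key for key in flat)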