From a1f115c1d2705f6437547e7ec9b8fec91f619ef1 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Sep 2022 11:20:05 +0000 Subject: [PATCH 01/14] first commit: - add `from_pt` argument in `from_pretrained` function - add `modeling_flax_pytorch_utils.py` file --- src/diffusers/modeling_flax_pytorch_utils.py | 126 +++++++++++++++++++ src/diffusers/modeling_flax_utils.py | 57 +++++---- 2 files changed, 161 insertions(+), 22 deletions(-) create mode 100644 src/diffusers/modeling_flax_pytorch_utils.py diff --git a/src/diffusers/modeling_flax_pytorch_utils.py b/src/diffusers/modeling_flax_pytorch_utils.py new file mode 100644 index 000000000000..943206d772d8 --- /dev/null +++ b/src/diffusers/modeling_flax_pytorch_utils.py @@ -0,0 +1,126 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch - Flax general utilities.""" +from typing import Dict, Tuple + +import numpy as np + +import jax.numpy as jnp +from flax.traverse_util import flatten_dict, unflatten_dict + +from .utils import logging + + +logger = logging.get_logger(__name__) + + +##################### +# PyTorch => Flax # +##################### + +# Copied from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69 +def rename_key_and_reshape_tensor( + pt_tuple_key: Tuple[str], + pt_tensor: np.ndarray, + random_flax_state_dict: Dict[str, jnp.ndarray], + model_prefix: str, +) -> (Tuple[str], np.ndarray): + """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" + + def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool: + """Checks if `key` of `(prefix,) + key` is in random_flax_state_dict""" + return len(set(random_flax_state_dict) & set([key, (model_prefix,) + key])) > 0 + + # layer norm + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) + if pt_tuple_key[-1] in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key): + return renamed_pt_tuple_key, pt_tensor + + # embedding + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) + if pt_tuple_key[-1] == "weight" and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key): + return renamed_pt_tuple_key, pt_tensor + + # conv layer + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) + if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and not is_key_or_prefix_key_in_dict(pt_tuple_key): + pt_tensor = pt_tensor.transpose(2, 3, 1, 0) + return renamed_pt_tuple_key, pt_tensor + + # linear layer + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) + if pt_tuple_key[-1] == "weight" and not is_key_or_prefix_key_in_dict(pt_tuple_key): + pt_tensor = pt_tensor.T + return renamed_pt_tuple_key, pt_tensor + + # old PyTorch layer norm weight + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",) + if pt_tuple_key[-1] == "gamma": + return renamed_pt_tuple_key, pt_tensor + + # old PyTorch layer norm bias + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",) + if 
pt_tuple_key[-1] == "beta": + return renamed_pt_tuple_key, pt_tensor + + return pt_tuple_key, pt_tensor + + +# Copied from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L116 +def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): + # convert pytorch tensor to numpy + pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} + + model_prefix = flax_model.base_model_prefix + random_flax_state_dict = flatten_dict(flax_model.params) + flax_state_dict = {} + + load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and ( + model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()]) + ) + load_base_model_into_model_with_head = (model_prefix in flax_model.params) and ( + model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()]) + ) + + # Need to change some parameters name to match Flax names + for pt_key, pt_tensor in pt_state_dict.items(): + pt_tuple_key = tuple(pt_key.split(".")) + + # remove base model prefix if necessary + has_base_model_prefix = pt_tuple_key[0] == model_prefix + if load_model_with_head_into_base_model and has_base_model_prefix: + pt_tuple_key = pt_tuple_key[1:] + + # Correctly rename weight parameters + flax_key, flax_tensor = rename_key_and_reshape_tensor( + pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix + ) + + # add model prefix if necessary + require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict + if load_base_model_into_model_with_head and require_base_model_prefix: + flax_key = (model_prefix,) + flax_key + + if flax_key in random_flax_state_dict: + if flax_tensor.shape != random_flax_state_dict[flax_key].shape: + raise ValueError( + f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " + f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." + ) + + # also add unexpected weight so that warning is thrown + flax_state_dict[flax_key] = jnp.asarray(flax_tensor) + + return unflatten_dict(flax_state_dict) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index f195462eca3e..3463e1a18494 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -27,6 +27,7 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from requests import HTTPError +from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax from .modeling_utils import WEIGHTS_NAME from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging @@ -245,6 +246,8 @@ def from_pretrained( The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. + from_pt (`bool`, *optional*, defaults to `False`): + Load the model weights from a PyTorch checkpoint save file. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., `output_attentions=True`). 
Behaves differently depending on whether a `config` is provided or @@ -272,6 +275,7 @@ def from_pretrained( config = kwargs.pop("config", None) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) + from_pt = kwargs.pop("from_pt", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) @@ -305,10 +309,16 @@ def from_pretrained( # Load from a Flax checkpoint model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) # At this stage we don't have a weight file so we will raise an error. - elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): + elif from_pt: + if not os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): + raise EnvironmentError( + f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " + ) + model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + if os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): raise EnvironmentError( - f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " - "but there is a file for PyTorch weights." + f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model" + " using `from_pt=True`." ) else: raise EnvironmentError( @@ -319,7 +329,7 @@ def from_pretrained( try: model_file = hf_hub_download( pretrained_model_name_or_path, - filename=FLAX_WEIGHTS_NAME, + filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME, cache_dir=cache_dir, force_download=force_download, proxies=proxies, @@ -369,25 +379,28 @@ def from_pretrained( f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." ) - try: - with open(model_file, "rb") as state_f: - state = from_bytes(cls, state_f.read()) - except (UnpicklingError, msgpack.exceptions.ExtraData) as e: + if from_pt: + state = convert_pytorch_state_dict_to_flax(model, model_file) + else: try: - with open(model_file) as f: - if f.read().startswith("version"): - raise OSError( - "You seem to have cloned a repository without having git-lfs installed. Please" - " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" - " folder you cloned." - ) - else: - raise ValueError from e - except (UnicodeDecodeError, ValueError): - raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ") - # make sure all arrays are stored as jnp.ndarray - # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: - # https://github.com/google/flax/issues/1261 + with open(model_file, "rb") as state_f: + state = from_bytes(cls, state_f.read()) + except (UnpicklingError, msgpack.exceptions.ExtraData) as e: + try: + with open(model_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please" + " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" + " folder you cloned." + ) + else: + raise ValueError from e + except (UnicodeDecodeError, ValueError): + raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. 
") + # make sure all arrays are stored as jnp.ndarray + # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: + # https://github.com/google/flax/issues/1261 state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state) # flatten dicts From 952741ecf83ccba944a9ebbeee7ac2c4548fad03 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Sep 2022 12:06:03 +0000 Subject: [PATCH 02/14] small nit - fix a small nit - to not enter in the second if condition --- src/diffusers/modeling_flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 3463e1a18494..59b611d328d7 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -315,7 +315,7 @@ def from_pretrained( f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " ) model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) - if os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): + elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): raise EnvironmentError( f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model" " using `from_pt=True`." From fd937eab0b2b88bb4fbeb7a5678ccf413cd3fa37 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Sep 2022 14:00:27 +0000 Subject: [PATCH 03/14] major changes - modify FlaxUnet modules - first conversion script - more keys to be matched --- src/diffusers/modeling_flax_pytorch_utils.py | 67 ++++++++------------ src/diffusers/modeling_flax_utils.py | 8 ++- src/diffusers/models/attention_flax.py | 18 +++--- 3 files changed, 42 insertions(+), 51 deletions(-) diff --git a/src/diffusers/modeling_flax_pytorch_utils.py b/src/diffusers/modeling_flax_pytorch_utils.py index 943206d772d8..667c3efeb50a 100644 --- a/src/diffusers/modeling_flax_pytorch_utils.py +++ b/src/diffusers/modeling_flax_pytorch_utils.py @@ -13,12 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch - Flax general utilities.""" +import re from typing import Dict, Tuple import numpy as np import jax.numpy as jnp from flax.traverse_util import flatten_dict, unflatten_dict +from jax.random import PRNGKey from .utils import logging @@ -26,6 +28,14 @@ logger = logging.get_logger(__name__) +def rename_key(key): + regex = r"\w+[.]\d+" + pats = re.findall(regex, key) + for pat in pats: + key = key.replace(pat, "_".join(pat.split("."))) + return key + + ##################### # PyTorch => Flax # ##################### @@ -34,34 +44,28 @@ def rename_key_and_reshape_tensor( pt_tuple_key: Tuple[str], pt_tensor: np.ndarray, - random_flax_state_dict: Dict[str, jnp.ndarray], - model_prefix: str, ) -> (Tuple[str], np.ndarray): """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" - def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool: - """Checks if `key` of `(prefix,) + key` is in random_flax_state_dict""" - return len(set(random_flax_state_dict) & set([key, (model_prefix,) + key])) > 0 - - # layer norm + # # conv norm or layer norm + # This is not really stable since any module that has the name 'scale' + # Will be affected. Maybe just check pt_tuple_key[-2] ? 
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) - if pt_tuple_key[-1] in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key): + if any("norm" in str_ for str_ in pt_tuple_key) and pt_tuple_key[-1] == "weight": return renamed_pt_tuple_key, pt_tensor - # embedding - renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) - if pt_tuple_key[-1] == "weight" and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key): - return renamed_pt_tuple_key, pt_tensor + # # embedding + # renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) # conv layer renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) - if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and not is_key_or_prefix_key_in_dict(pt_tuple_key): + if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: pt_tensor = pt_tensor.transpose(2, 3, 1, 0) return renamed_pt_tuple_key, pt_tensor # linear layer renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) - if pt_tuple_key[-1] == "weight" and not is_key_or_prefix_key_in_dict(pt_tuple_key): + if pt_tuple_key[-1] == "weight": pt_tensor = pt_tensor.T return renamed_pt_tuple_key, pt_tensor @@ -78,40 +82,23 @@ def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool: return pt_tuple_key, pt_tensor -# Copied from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L116 -def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): - # convert pytorch tensor to numpy +def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): + # Step 1: Convert pytorch tensor to numpy pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} - model_prefix = flax_model.base_model_prefix - random_flax_state_dict = flatten_dict(flax_model.params) - flax_state_dict = {} + # Step 2: Since the model is stateless, get random Flax params + random_flax_params = flax_model.init_weights(PRNGKey(init_key)) - load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and ( - model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()]) - ) - load_base_model_into_model_with_head = (model_prefix in flax_model.params) and ( - model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()]) - ) + random_flax_state_dict = flatten_dict(random_flax_params) + flax_state_dict = {} # Need to change some parameters name to match Flax names for pt_key, pt_tensor in pt_state_dict.items(): - pt_tuple_key = tuple(pt_key.split(".")) - - # remove base model prefix if necessary - has_base_model_prefix = pt_tuple_key[0] == model_prefix - if load_model_with_head_into_base_model and has_base_model_prefix: - pt_tuple_key = pt_tuple_key[1:] + renamed_pt_key = rename_key(pt_key) + pt_tuple_key = tuple(renamed_pt_key.split(".")) # Correctly rename weight parameters - flax_key, flax_tensor = rename_key_and_reshape_tensor( - pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix - ) - - # add model prefix if necessary - require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict - if load_base_model_into_model_with_head and require_base_model_prefix: - flax_key = (model_prefix,) + flax_key + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor) if flax_key in random_flax_state_dict: if flax_tensor.shape != random_flax_state_dict[flax_key].shape: diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 59b611d328d7..6281abfa6e1d 100644 --- a/src/diffusers/modeling_flax_utils.py 
+++ b/src/diffusers/modeling_flax_utils.py @@ -28,7 +28,7 @@ from requests import HTTPError from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax -from .modeling_utils import WEIGHTS_NAME +from .modeling_utils import WEIGHTS_NAME, load_state_dict from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging @@ -380,7 +380,11 @@ def from_pretrained( ) if from_pt: - state = convert_pytorch_state_dict_to_flax(model, model_file) + # Step 1: Get the pytorch file + pytorch_model_file = load_state_dict(model_file) + + # Step 2: Convert the weights + state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model) else: try: with open(model_file, "rb") as state_f: diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 918c7469a74c..4683b871bfc4 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -32,7 +32,7 @@ def setup(self): self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k") self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v") - self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out") + self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0") def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape @@ -82,9 +82,9 @@ class FlaxBasicTransformerBlock(nn.Module): def setup(self): # self attention - self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) # cross attention - self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) @@ -93,12 +93,12 @@ def setup(self): def __call__(self, hidden_states, context, deterministic=True): # self attention residual = hidden_states - hidden_states = self.self_attn(self.norm1(hidden_states), deterministic=deterministic) + hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic) hidden_states = hidden_states + residual # cross attention residual = hidden_states - hidden_states = self.cross_attn(self.norm2(hidden_states), context, deterministic=deterministic) + hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic) hidden_states = hidden_states + residual # feed forward @@ -169,12 +169,12 @@ class FlaxGluFeedForward(nn.Module): def setup(self): inner_dim = self.dim * 4 - self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype) - self.dense2 = nn.Dense(self.dim, dtype=self.dtype) + self.net_0 = nn.Dense(inner_dim * 2, dtype=self.dtype) + self.net_1 = nn.Dense(self.dim, dtype=self.dtype) def __call__(self, hidden_states, deterministic=True): - hidden_states = self.dense1(hidden_states) + hidden_states = self.net_0(hidden_states) hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2) hidden_states = hidden_linear * nn.gelu(hidden_gelu) - hidden_states = self.dense2(hidden_states) + hidden_states = self.net_1(hidden_states) return hidden_states From 1a289b82ae9f967c9c0d1afbd7cea03574549f8c Mon Sep 17 00:00:00 2001 From: younesbelkada Date: 
Fri, 16 Sep 2022 14:31:56 +0000 Subject: [PATCH 04/14] keys match - now all keys match - change module names for correct matching - upsample module name changed --- src/diffusers/models/attention_flax.py | 26 ++++++++++++++++++------ src/diffusers/models/unet_blocks_flax.py | 16 +++++++-------- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 4683b871bfc4..525be4818dcc 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -168,13 +168,27 @@ class FlaxGluFeedForward(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): - inner_dim = self.dim * 4 - self.net_0 = nn.Dense(inner_dim * 2, dtype=self.dtype) - self.net_1 = nn.Dense(self.dim, dtype=self.dtype) + # The second linear layer needs to be called + # net_2 for now to match the index of the Sequential layer + self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype) + self.net_2 = nn.Dense(self.dim, dtype=self.dtype) def __call__(self, hidden_states, deterministic=True): hidden_states = self.net_0(hidden_states) - hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2) - hidden_states = hidden_linear * nn.gelu(hidden_gelu) - hidden_states = self.net_1(hidden_states) + hidden_states = self.net_2(hidden_states) return hidden_states + + +class FlaxGEGLU(nn.Module): + dim: int + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + inner_dim = self.dim * 4 + self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.proj(hidden_states) + hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2) + return hidden_linear * nn.gelu(hidden_gelu) diff --git a/src/diffusers/models/unet_blocks_flax.py b/src/diffusers/models/unet_blocks_flax.py index ce67eb12b19f..fe6cc3b194e3 100644 --- a/src/diffusers/models/unet_blocks_flax.py +++ b/src/diffusers/models/unet_blocks_flax.py @@ -55,7 +55,7 @@ def setup(self): self.attentions = attentions if self.add_downsample: - self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): output_states = () @@ -66,7 +66,7 @@ def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=Tru output_states += (hidden_states,) if self.add_downsample: - hidden_states = self.downsample(hidden_states) + hidden_states = self.downsamplers_0(hidden_states) output_states += (hidden_states,) return hidden_states, output_states @@ -96,7 +96,7 @@ def setup(self): self.resnets = resnets if self.add_downsample: - self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) def __call__(self, hidden_states, temb, deterministic=True): output_states = () @@ -106,7 +106,7 @@ def __call__(self, hidden_states, temb, deterministic=True): output_states += (hidden_states,) if self.add_downsample: - hidden_states = self.downsample(hidden_states) + hidden_states = self.downsamplers_0(hidden_states) output_states += (hidden_states,) return hidden_states, output_states @@ -151,7 +151,7 @@ def setup(self): self.attentions = attentions if self.add_upsample: - self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) def 
__call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True): for resnet, attn in zip(self.resnets, self.attentions): @@ -164,7 +164,7 @@ def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_ hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) if self.add_upsample: - hidden_states = self.upsample(hidden_states) + hidden_states = self.upsamplers_0(hidden_states) return hidden_states @@ -196,7 +196,7 @@ def setup(self): self.resnets = resnets if self.add_upsample: - self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True): for resnet in self.resnets: @@ -208,7 +208,7 @@ def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=T hidden_states = resnet(hidden_states, temb, deterministic=deterministic) if self.add_upsample: - hidden_states = self.upsample(hidden_states) + hidden_states = self.upsamplers_0(hidden_states) return hidden_states From e9dde493c8ec813666999544871f3728ef1ae164 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Sep 2022 15:12:07 +0000 Subject: [PATCH 05/14] working v1 - test pass with atol and rtol= `4e-02` --- src/diffusers/models/unet_2d_condition_flax.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 1ac68e10c159..7b1a24c68e91 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -192,7 +192,7 @@ def setup(self): def __call__( self, sample, - timesteps, + timestep, encoder_hidden_states, return_dict: bool = True, train: bool = False, @@ -214,6 +214,13 @@ def __call__( When returning a tuple, the first element is the sample tensor. """ # 1. time + timesteps = timestep + if not isinstance(timesteps, jnp.ndarray): + timesteps = jnp.array([timesteps], dtype=jnp.int32) + elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0: + timesteps = timesteps.to(dtype=jnp.float32) + timesteps = timesteps[None] + t_emb = self.time_proj(timesteps) t_emb = self.time_embedding(t_emb) From 7527ab185d744ac09948d86581bcae06e027f37c Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Sep 2022 15:26:31 +0000 Subject: [PATCH 06/14] replace unsued arg --- src/diffusers/models/unet_2d_condition_flax.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 7b1a24c68e91..865d02c1c436 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -192,7 +192,7 @@ def setup(self): def __call__( self, sample, - timestep, + timesteps, encoder_hidden_states, return_dict: bool = True, train: bool = False, @@ -214,11 +214,10 @@ def __call__( When returning a tuple, the first element is the sample tensor. """ # 1. 
time - timesteps = timestep if not isinstance(timesteps, jnp.ndarray): timesteps = jnp.array([timesteps], dtype=jnp.int32) elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0: - timesteps = timesteps.to(dtype=jnp.float32) + timesteps = timesteps.astype(dtype=jnp.float32) timesteps = timesteps[None] t_emb = self.time_proj(timesteps) From 84ba1c0f251c2fd44c24b42cff5ecf975560fc10 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Sep 2022 15:27:19 +0000 Subject: [PATCH 07/14] make quality --- src/diffusers/modeling_flax_pytorch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/modeling_flax_pytorch_utils.py b/src/diffusers/modeling_flax_pytorch_utils.py index 667c3efeb50a..9d123ce493f6 100644 --- a/src/diffusers/modeling_flax_pytorch_utils.py +++ b/src/diffusers/modeling_flax_pytorch_utils.py @@ -14,7 +14,7 @@ # limitations under the License. """ PyTorch - Flax general utilities.""" import re -from typing import Dict, Tuple +from typing import Tuple import numpy as np From 09c0c844293a40fe1c0c18284008a3d3a52b0e23 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Sep 2022 15:33:19 +0000 Subject: [PATCH 08/14] add small docstring --- src/diffusers/modeling_flax_pytorch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/modeling_flax_pytorch_utils.py b/src/diffusers/modeling_flax_pytorch_utils.py index 9d123ce493f6..7c1c17fb3a47 100644 --- a/src/diffusers/modeling_flax_pytorch_utils.py +++ b/src/diffusers/modeling_flax_pytorch_utils.py @@ -40,7 +40,7 @@ def rename_key(key): # PyTorch => Flax # ##################### -# Copied from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69 +# Inspired from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69 def rename_key_and_reshape_tensor( pt_tuple_key: Tuple[str], pt_tensor: np.ndarray, From 0bb3f049a88d3042144752a46309da2c44e2f1fb Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Sep 2022 15:38:56 +0000 Subject: [PATCH 09/14] add more comments - add TODO for embedding layers --- src/diffusers/modeling_flax_pytorch_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/modeling_flax_pytorch_utils.py b/src/diffusers/modeling_flax_pytorch_utils.py index 7c1c17fb3a47..0a8934b9d977 100644 --- a/src/diffusers/modeling_flax_pytorch_utils.py +++ b/src/diffusers/modeling_flax_pytorch_utils.py @@ -54,8 +54,9 @@ def rename_key_and_reshape_tensor( if any("norm" in str_ for str_ in pt_tuple_key) and pt_tuple_key[-1] == "weight": return renamed_pt_tuple_key, pt_tensor - # # embedding - # renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) + # embedding + # For now the embedding layers are not converted + # TODO: figure out how to detect embedding layers # conv layer renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) From 1facd9fa4b9fb164be779c9cb5feeb60c23c4895 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Sep 2022 16:01:32 +0000 Subject: [PATCH 10/14] small change - use `jnp.expand_dims` for converting `timesteps` in case it is a 0-dimensional array --- src/diffusers/models/unet_2d_condition_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 865d02c1c436..bb6079b3889d 100644 --- 
a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -218,7 +218,7 @@ def __call__( timesteps = jnp.array([timesteps], dtype=jnp.int32) elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0: timesteps = timesteps.astype(dtype=jnp.float32) - timesteps = timesteps[None] + timesteps = jnp.expand_dims(timesteps, 0) t_emb = self.time_proj(timesteps) t_emb = self.time_embedding(t_emb) From 5973e435410222f4ad702226d04ee800fe5ee13d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Sep 2022 16:13:35 +0000 Subject: [PATCH 11/14] add more conditions on conversion - add better test to check for keys conversion --- src/diffusers/modeling_flax_pytorch_utils.py | 33 +++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/src/diffusers/modeling_flax_pytorch_utils.py b/src/diffusers/modeling_flax_pytorch_utils.py index 0a8934b9d977..9c7a5de2ad6e 100644 --- a/src/diffusers/modeling_flax_pytorch_utils.py +++ b/src/diffusers/modeling_flax_pytorch_utils.py @@ -14,9 +14,6 @@ # limitations under the License. """ PyTorch - Flax general utilities.""" import re -from typing import Tuple - -import numpy as np import jax.numpy as jnp from flax.traverse_util import flatten_dict, unflatten_dict @@ -40,23 +37,29 @@ def rename_key(key): # PyTorch => Flax # ##################### -# Inspired from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69 -def rename_key_and_reshape_tensor( - pt_tuple_key: Tuple[str], - pt_tensor: np.ndarray, -) -> (Tuple[str], np.ndarray): +# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69 +# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py +def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict): """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" - # # conv norm or layer norm - # This is not really stable since any module that has the name 'scale' - # Will be affected. Maybe just check pt_tuple_key[-2] ? 
+ # conv norm or layer norm renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) - if any("norm" in str_ for str_ in pt_tuple_key) and pt_tuple_key[-1] == "weight": + if ( + any("norm" in str_ for str_ in pt_tuple_key) + and (pt_tuple_key[-1] == "bias") + and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict) + and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict) + ): + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) + return renamed_pt_tuple_key, pt_tensor + elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict: + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) return renamed_pt_tuple_key, pt_tensor # embedding - # For now the embedding layers are not converted - # TODO: figure out how to detect embedding layers + if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict: + pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) + return renamed_pt_tuple_key, pt_tensor # conv layer renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) @@ -99,7 +102,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): pt_tuple_key = tuple(renamed_pt_key.split(".")) # Correctly rename weight parameters - flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor) + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict) if flax_key in random_flax_state_dict: if flax_tensor.shape != random_flax_state_dict[flax_key].shape: From 4cad1aeb4aeb224402dad13c018a5d42e96267f6 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 19 Sep 2022 21:55:49 +0000 Subject: [PATCH 12/14] make shapes consistent - output `img_w x img_h x n_channels` from the VAE --- src/diffusers/models/vae_flax.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index 793010e87f67..77451e2ceb42 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -559,7 +559,7 @@ def setup(self): def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: # init input tensors - sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) + sample_shape = (1, self.sample_size, self.sample_size, self.in_channels) sample = jnp.zeros(sample_shape, dtype=jnp.float32) params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3) @@ -568,8 +568,6 @@ def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: return self.init(rngs, sample)["params"] def encode(self, sample, deterministic: bool = True, return_dict: bool = True): - sample = jnp.transpose(sample, (0, 2, 3, 1)) - hidden_states = self.encoder(sample, deterministic=deterministic) moments = self.quant_conv(hidden_states) posterior = DiagonalGaussianDistribution(moments) @@ -586,8 +584,6 @@ def decode(self, latents, deterministic: bool = True, return_dict: bool = True): hidden_states = self.post_quant_conv(latents) hidden_states = self.decoder(hidden_states, deterministic=deterministic) - hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2)) - if not return_dict: return (hidden_states,) From b6d5e2842c1d29beb1df4094b99dc9cded8df189 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 20 Sep 2022 10:15:08 +0000 Subject: [PATCH 13/14] Revert "make shapes consistent" This reverts commit 4cad1aeb4aeb224402dad13c018a5d42e96267f6. 
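For context on the shape changes behind this revert and the follow-up fix in the next patch: Flax convolutions operate on channels-last (NHWC) arrays, while the PyTorch-facing interface keeps channels-first (NCHW), so the final approach keeps NCHW at the module boundary and transposes internally. A minimal sketch of the two transposes involved — the array shapes are illustrative and not taken from the patches:

```python
import jax.numpy as jnp

# Public interface stays channels-first, matching the PyTorch models.
sample_nchw = jnp.zeros((1, 4, 64, 64))                  # (batch, channels, height, width)

# Flax nn.Conv expects channels-last, so the module transposes on entry...
sample_nhwc = jnp.transpose(sample_nchw, (0, 2, 3, 1))   # (batch, height, width, channels)

# ...runs in NHWC internally, and transposes back before returning.
output_nchw = jnp.transpose(sample_nhwc, (0, 3, 1, 2))
assert output_nchw.shape == sample_nchw.shape
```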
--- src/diffusers/models/vae_flax.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index 77451e2ceb42..793010e87f67 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -559,7 +559,7 @@ def setup(self): def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: # init input tensors - sample_shape = (1, self.sample_size, self.sample_size, self.in_channels) + sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample = jnp.zeros(sample_shape, dtype=jnp.float32) params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3) @@ -568,6 +568,8 @@ def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: return self.init(rngs, sample)["params"] def encode(self, sample, deterministic: bool = True, return_dict: bool = True): + sample = jnp.transpose(sample, (0, 2, 3, 1)) + hidden_states = self.encoder(sample, deterministic=deterministic) moments = self.quant_conv(hidden_states) posterior = DiagonalGaussianDistribution(moments) @@ -584,6 +586,8 @@ def decode(self, latents, deterministic: bool = True, return_dict: bool = True): hidden_states = self.post_quant_conv(latents) hidden_states = self.decoder(hidden_states, deterministic=deterministic) + hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2)) + if not return_dict: return (hidden_states,) From d58798afc5dff95af515a27fd59d1e61cbc2fda3 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 20 Sep 2022 10:19:29 +0000 Subject: [PATCH 14/14] fix unet shape - channels first! --- src/diffusers/models/unet_2d_condition_flax.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 33fdc24cefe4..d0fcd9f6ae13 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -76,7 +76,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: # init input tensors - sample_shape = (1, self.sample_size, self.sample_size, self.in_channels) + sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample = jnp.zeros(sample_shape, dtype=jnp.float32) timesteps = jnp.ones((1,), dtype=jnp.int32) encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32) @@ -224,6 +224,7 @@ def __call__( t_emb = self.time_embedding(t_emb) # 2. pre-process + sample = jnp.transpose(sample, (0, 2, 3, 1)) sample = self.conv_in(sample) # 3. down @@ -257,6 +258,7 @@ def __call__( sample = self.conv_norm_out(sample) sample = nn.silu(sample) sample = self.conv_out(sample) + sample = jnp.transpose(sample, (0, 3, 1, 2)) if not return_dict: return (sample,)
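Taken together, the series lets Flax diffusers models load PyTorch checkpoints directly via `from_pt=True`. The sketch below is illustrative rather than part of the patches: the checkpoint path is a placeholder and the `(model, params)` return value of `from_pretrained` is assumed (diffusers installed with its Flax dependencies); only the `from_pt` keyword, the `rename_key` behaviour, and the kernel transposes come from the code above.

```python
import numpy as np
from diffusers import FlaxUNet2DConditionModel
from diffusers.modeling_flax_pytorch_utils import rename_key

# Indexed PyTorch sub-module names are flattened to line up with Flax module names.
print(rename_key("down_blocks.0.resnets.1.conv1.weight"))
# -> down_blocks_0.resnets_1.conv1.weight

# Conv and linear weights are reshaped by rename_key_and_reshape_tensor
# (shapes below are illustrative):
pt_conv = np.zeros((320, 4, 3, 3))           # PyTorch conv weight   (out, in, kh, kw)
flax_conv = pt_conv.transpose(2, 3, 1, 0)    # Flax conv kernel      (kh, kw, in, out)
pt_dense = np.zeros((1280, 320))             # PyTorch linear weight (out, in)
flax_dense = pt_dense.T                      # Flax dense kernel     (in, out)

# Load a directory that only contains the PyTorch weights file
# (placeholder path; assumed (model, params) return value):
model, params = FlaxUNet2DConditionModel.from_pretrained(
    "./path/to/pytorch_unet_checkpoint", from_pt=True
)
```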