From 32c2be514bb4d3bd76b23a5ceed7865bc34f821d Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 16:50:20 +0000 Subject: [PATCH 1/7] Add `init_weights` method to `FlaxMixin` --- src/diffusers/modeling_flax_utils.py | 83 +++++++++++++++++++++++++++- 1 file changed, 82 insertions(+), 1 deletion(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 4f2d25dfb168..f151a08b2c1b 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -20,7 +20,7 @@ import jax import jax.numpy as jnp import msgpack.exceptions -from flax.core.frozen_dict import FrozenDict +from flax.core.frozen_dict import FrozenDict, unfreeze from flax.serialization import from_bytes, to_bytes from flax.traverse_util import flatten_dict, unflatten_dict from huggingface_hub import hf_hub_download @@ -183,6 +183,9 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): ```""" return self._cast_floating_to(params, jnp.float16, mask) + def init_weights(self, rng: jax.random.PRNGKey) -> Dict: + raise NotImplementedError(f"init method has to be implemented for {self}") + @classmethod def from_pretrained( cls, @@ -272,6 +275,7 @@ def from_pretrained( ```""" config = kwargs.pop("config", None) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) @@ -280,6 +284,7 @@ def from_pretrained( revision = kwargs.pop("revision", None) from_auto_class = kwargs.pop("_from_auto", False) subfolder = kwargs.pop("subfolder", None) + prng_key = kwargs.pop("prng_key", None) user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class} @@ -394,6 +399,82 @@ def from_pretrained( # flatten dicts state = flatten_dict(state) + prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0) + params_shape_tree = jax.eval_shape(model.init_weights, prng_key) + required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) + + random_state = flatten_dict(unfreeze(params_shape_tree)) + + missing_keys = required_params - set(state.keys()) + unexpected_keys = set(state.keys()) - required_params + + if missing_keys: + logger.warning( + f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. " + "Make sure to call model.init_weights to initialize the missing weights." + ) + cls._missing_keys = missing_keys + + # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not + # matching the weights in the model. + mismatched_keys = [] + for key in state.keys(): + if key in random_state and state[key].shape != random_state[key].shape: + if ignore_mismatched_sizes: + mismatched_keys.append((key, state[key].shape, random_state[key].shape)) + state[key] = random_state[key] + else: + raise ValueError( + f"Trying to load the pretrained weight for {key} failed: checkpoint has shape " + f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. " + "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this " + "model." 
+ ) + + # remove unexpected keys to not be saved again + for unexpected_key in unexpected_keys: + del state[unexpected_key] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" + " with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical" + " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" + f" was trained on, you can already use {model.__class__.__name__} for predictions without further" + " training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" + " to use it for predictions and inference." + ) + # dictionary of key: dtypes for the model params param_dtypes = jax.tree_map(lambda x: x.dtype, state) # extract keys of parameters not in jnp.float32 From 305a5448d1de9c2b4c45ee47b150d37f6919466f Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Thu, 15 Sep 2022 08:02:17 +0000 Subject: [PATCH 2/7] Rn `random_state` -> `shape_state` --- src/diffusers/modeling_flax_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index f151a08b2c1b..a444385a4769 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -403,7 +403,7 @@ def from_pretrained( params_shape_tree = jax.eval_shape(model.init_weights, prng_key) required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) - random_state = flatten_dict(unfreeze(params_shape_tree)) + shape_state = flatten_dict(unfreeze(params_shape_tree)) missing_keys = required_params - set(state.keys()) unexpected_keys = set(state.keys()) - required_params @@ -419,14 +419,14 @@ def from_pretrained( # matching the weights in the model. 
mismatched_keys = [] for key in state.keys(): - if key in random_state and state[key].shape != random_state[key].shape: + if key in shape_state and state[key].shape != shape_state[key].shape: if ignore_mismatched_sizes: - mismatched_keys.append((key, state[key].shape, random_state[key].shape)) - state[key] = random_state[key] + mismatched_keys.append((key, state[key].shape, shape_state[key].shape)) + state[key] = shape_state[key] else: raise ValueError( f"Trying to load the pretrained weight for {key} failed: checkpoint has shape " - f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. " + f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. " "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this " "model." ) From 803da8f666b2f17c6accf4a73c804fdefd459461 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Thu, 15 Sep 2022 08:04:03 +0000 Subject: [PATCH 3/7] `PRNGKey(0)` for `jax.eval_shape` --- src/diffusers/modeling_flax_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index a444385a4769..80eb91b17dc5 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -284,7 +284,6 @@ def from_pretrained( revision = kwargs.pop("revision", None) from_auto_class = kwargs.pop("_from_auto", False) subfolder = kwargs.pop("subfolder", None) - prng_key = kwargs.pop("prng_key", None) user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class} @@ -399,7 +398,7 @@ def from_pretrained( # flatten dicts state = flatten_dict(state) - prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0) + prng_key = jax.random.PRNGKey(0) params_shape_tree = jax.eval_shape(model.init_weights, prng_key) required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) From 73e0bc692c5761e55faff39c80a26d7a3cfc748c Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Thu, 15 Sep 2022 08:06:06 +0000 Subject: [PATCH 4/7] No allow mismatched sizes --- src/diffusers/modeling_flax_utils.py | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 80eb91b17dc5..044df71fdf27 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -230,10 +230,6 @@ def from_pretrained( cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. - ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): - Whether or not to raise an error if some of the weights from the checkpoint do not have the same size - as the weights of the model (if for instance, you are instantiating a model with 10 labels from a - checkpoint with 3 labels). force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. 
@@ -275,7 +271,6 @@ def from_pretrained( ```""" config = kwargs.pop("config", None) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) @@ -419,16 +414,10 @@ def from_pretrained( mismatched_keys = [] for key in state.keys(): if key in shape_state and state[key].shape != shape_state[key].shape: - if ignore_mismatched_sizes: - mismatched_keys.append((key, state[key].shape, shape_state[key].shape)) - state[key] = shape_state[key] - else: - raise ValueError( - f"Trying to load the pretrained weight for {key} failed: checkpoint has shape " - f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. " - "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this " - "model." - ) + raise ValueError( + f"Trying to load the pretrained weight for {key} failed: checkpoint has shape " + f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. " + ) # remove unexpected keys to not be saved again for unexpected_key in unexpected_keys: From f9675e96b7de78fa7af8b580259d35fd439b66b6 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Thu, 15 Sep 2022 16:48:07 +0200 Subject: [PATCH 5/7] Update src/diffusers/modeling_flax_utils.py Co-authored-by: Suraj Patil --- src/diffusers/modeling_flax_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 044df71fdf27..f98d51e2cfca 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -393,8 +393,7 @@ def from_pretrained( # flatten dicts state = flatten_dict(state) - prng_key = jax.random.PRNGKey(0) - params_shape_tree = jax.eval_shape(model.init_weights, prng_key) + params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0)) required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) shape_state = flatten_dict(unfreeze(params_shape_tree)) From 1ed00f30fed374c0f3ca9c03c9336b6a3aa083fc Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Thu, 15 Sep 2022 16:48:14 +0200 Subject: [PATCH 6/7] Update src/diffusers/modeling_flax_utils.py Co-authored-by: Suraj Patil --- src/diffusers/modeling_flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index f98d51e2cfca..9d590b52d383 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -184,7 +184,7 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): return self._cast_floating_to(params, jnp.float16, mask) def init_weights(self, rng: jax.random.PRNGKey) -> Dict: - raise NotImplementedError(f"init method has to be implemented for {self}") + raise NotImplementedError(f"init_weights method has to be implemented for {self}") @classmethod def from_pretrained( From 869014ca901aabd82b7e053feb251caed74f9484 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Thu, 15 Sep 2022 14:54:57 +0000 Subject: [PATCH 7/7] docstring diffusers --- src/diffusers/modeling_flax_utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 9d590b52d383..0e298db63b29 100644 --- 
a/src/diffusers/modeling_flax_utils.py
+++ b/src/diffusers/modeling_flax_utils.py
@@ -427,10 +427,7 @@ def from_pretrained(
                 f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                 f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                 f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
-                " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
-                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
-                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
-                " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+                " with another architecture."
             )
         else:
             logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
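
A minimal sketch of the mechanism this series adds, assuming a hypothetical `TinyFlaxModel` (the model class, its dummy input shape, and the fake `loaded_state` below are illustrative assumptions, not diffusers API): `from_pretrained` now calls `jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))` to obtain the expected parameter shape tree without allocating real arrays, flattens it, and compares it against the flattened checkpoint state to find missing, unexpected, and shape-mismatched keys.

```python
from typing import Dict

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict


class TinyFlaxModel(nn.Module):
    """Hypothetical stand-in for a diffusers Flax model (illustration only)."""

    features: int = 8

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features)(x)

    def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
        # The contract PATCH 1/7 adds to the mixin: build the parameter pytree
        # from a PRNG key (here with an assumed dummy input shape).
        dummy_input = jnp.zeros((1, 4), dtype=jnp.float32)
        return self.init(rng, dummy_input)["params"]


model = TinyFlaxModel()

# eval_shape traces init_weights abstractly, so the result is a pytree of
# ShapeDtypeStructs (shapes/dtypes only) rather than real arrays.
params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
shape_state = flatten_dict(unfreeze(params_shape_tree))
required_params = set(shape_state.keys())

# Fake "checkpoint" state for illustration: one key missing, one unexpected,
# one with a shape that disagrees with the model.
loaded_state = {
    ("Dense_0", "kernel"): jnp.zeros((4, 16)),  # model expects (4, 8)
    ("extra", "bias"): jnp.zeros((8,)),         # not a model parameter
}                                               # ("Dense_0", "bias") is absent

missing_keys = required_params - set(loaded_state.keys())
unexpected_keys = set(loaded_state.keys()) - required_params
mismatched_keys = [
    (key, loaded_state[key].shape, shape_state[key].shape)
    for key in loaded_state
    if key in shape_state and loaded_state[key].shape != shape_state[key].shape
]

print("missing:", missing_keys)        # {('Dense_0', 'bias')}
print("unexpected:", unexpected_keys)  # {('extra', 'bias')}
print("mismatched:", mismatched_keys)  # [(('Dense_0', 'kernel'), (4, 16), (4, 8))]
```

Note that after [PATCH 4/7] the real `from_pretrained` no longer collects mismatched keys the way this sketch does: with `ignore_mismatched_sizes` removed, the first shape mismatch raises a `ValueError`, while missing keys and unexpected keys are still reported through `logger.warning` (and unexpected keys are dropped from the state before the dtype checks).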