diff --git a/src/transformers/models/gemma3/configuration_gemma3.py b/src/transformers/models/gemma3/configuration_gemma3.py
index c0184c1993d3..b1ec3311ba66 100644
--- a/src/transformers/models/gemma3/configuration_gemma3.py
+++ b/src/transformers/models/gemma3/configuration_gemma3.py
@@ -136,6 +136,8 @@ class Gemma3TextConfig(PretrainedConfig):
                 Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
         rope_local_base_freq (float, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings for local attention.
+        use_bidirectional_attention (`bool`, *optional*, defaults to `False`): If True, the model will attend to all
+            text tokens instead of using a causal mask. This does not change behavior for vision tokens.

     ```python
     >>> from transformers import Gemma3TextModel, Gemma3TextConfig
@@ -193,6 +195,7 @@ def __init__(
         attn_logit_softcapping=None,
         rope_scaling=None,
         rope_local_base_freq=10_000.0,
+        use_bidirectional_attention=False,
         **kwargs,
     ):
         super().__init__(
@@ -222,6 +225,7 @@
         self.final_logit_softcapping = final_logit_softcapping
         self.attn_logit_softcapping = attn_logit_softcapping
         self.layer_types = layer_types
+        self.use_bidirectional_attention = use_bidirectional_attention
         self.rope_local_base_freq = rope_local_base_freq
         self.rope_scaling = rope_scaling
diff --git a/src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py b/src/transformers/models/gemma3/convert_gemma3_weights.py
similarity index 79%
rename from src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py
rename to src/transformers/models/gemma3/convert_gemma3_weights.py
index 6bd2b7da4cc0..8d7a21219197 100644
--- a/src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py
+++ b/src/transformers/models/gemma3/convert_gemma3_weights.py
@@ -16,7 +16,7 @@
 r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint.

-python -m transformers.models.gemma3.convert_gemma3_weights_orbax_to_hf \
+python src/transformers/models/gemma3/convert_gemma3_weights.py \
     --variant='gemma3_4b' \
     --tokenizer_path="$HOME/gemma3/tokenizer/gemma3_cleaned_262144_v2.spiece.model" \
     --checkpoint_path="$HOME/gemma3/gemma3_4b_pt_orbax/" \
@@ -24,7 +24,7 @@
 """

 from collections.abc import Iterator, Sequence
-from typing import Any
+from typing import Any, Optional

 import accelerate
 import numpy as np
@@ -40,6 +40,7 @@
     Gemma3ImageProcessor,
     Gemma3Processor,
     Gemma3TextConfig,
+    Gemma3TextModel,
     GemmaTokenizerFast,
     GenerationConfig,
     SiglipVisionConfig,
@@ -100,10 +101,10 @@
 _SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN = len(_SIGLIP_TRANSFORMER_ENCODER_BLOCK)
 _SIGLIP_TRANSFORMER_ENCODER_NORM = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoder_norm"

-_TRANSFORMER_DECODER_BLOCK = "transformer/layer_"
+_TRANSFORMER_DECODER_BLOCK = "/layer_"
 _TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK)
-_TRANSFORMER_EMBEDDER = "transformer/embedder"
+_TRANSFORMER_EMBEDDER = "/embedder"
-_TRANSFORMER_FINAL_NORM = "transformer/final_norm"
+_TRANSFORMER_FINAL_NORM = "/final_norm"
 _TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/"
 _TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX)
@@ -121,11 +122,46 @@
     "vision_use_head": False,
 }

+_VARIANT_EMBEDDINGGEMMA = "embedding"
+_VARIANT_GEMMA_3_270M = "gemma3_270m"
 _VARIANT_GEMMA_3_1B = "gemma3_1b"
 _VARIANT_GEMMA_3_4B = "gemma3_4b"
 _VARIANT_GEMMA_3_12B = "gemma3_12b"
 _VARIANT_GEMMA_3_27B = "gemma3_27b"
 _VARIANTS = {
+    _VARIANT_EMBEDDINGGEMMA: Gemma3Config(
+        text_config=Gemma3TextConfig(
+            vocab_size=262_144,
+            hidden_size=768,
+            intermediate_size=1152,
+            num_hidden_layers=24,
+            num_attention_heads=3,
+            num_key_value_heads=1,
+            head_dim=256,
+            max_position_embeddings=1024,
+            query_pre_attn_scalar=256,
+            sliding_window=512,
+            rope_scaling=None,
+            use_bidirectional_attention=True,
+        ),
+        vision_config=None,
+    ),
+    _VARIANT_GEMMA_3_270M: Gemma3Config(
+        text_config=Gemma3TextConfig(
+            vocab_size=262_144,
+            hidden_size=640,
+            intermediate_size=2048,
+            num_hidden_layers=18,
+            num_attention_heads=4,
+            num_key_value_heads=1,
+            head_dim=256,
+            max_position_embeddings=32768,
+            query_pre_attn_scalar=256,
+            sliding_window=512,
+            rope_scaling=None,
+        ),
+        vision_config=None,
+    ),
     _VARIANT_GEMMA_3_1B: Gemma3Config(
         text_config=Gemma3TextConfig(
             vocab_size=262_144,
@@ -200,6 +236,8 @@
     ),
 }

+_TEXT_ONLY_VARIANTS = (_VARIANT_EMBEDDINGGEMMA, _VARIANT_GEMMA_3_270M, _VARIANT_GEMMA_3_1B)
+
 # ==== Flags ====

 _CHECKPOINT_PATH = flags.DEFINE_string(
@@ -220,6 +258,12 @@
     required=True,
 )

+_NUM_LINEAR_LAYERS = flags.DEFINE_integer(
+    name="num_linear_layers",
+    default=2,
+    help="Number of linear projection layers at the end of the Sentence Transformer.",
+)
+
 _TRANSFORMER_DTYPE = flags.DEFINE_enum(
     name="text_dtype",
     default="bfloat16",
@@ -358,12 +402,12 @@ def convert_transformer_weights(
     attn_head_dim = config.num_attention_heads * config.head_dim
     kv_head_dim = config.num_key_value_heads * config.head_dim

-    if path == _TRANSFORMER_EMBEDDER:
+    if path.endswith(_TRANSFORMER_EMBEDDER):
         if prop == "input_embedding":
             # Tied to language_model.lm_head.weight, assigned at the end.
             converted_paths = ["language_model.model.embed_tokens.weight"]

-            if _VARIANT.value != _VARIANT_GEMMA_3_1B:
+            if _VARIANT.value not in _TEXT_ONLY_VARIANTS:
                 # Gemma3 model doesn't have image soft token in input and output embeddings, resize to avoid bugs we had with Mllama
                 pre_expansion_embeddings = weights
                 mu = np.mean(pre_expansion_embeddings, axis=0)
@@ -372,12 +416,12 @@ def convert_transformer_weights(
                 weights = np.vstack([pre_expansion_embeddings, new_embeddings])

             converted_weights = [weights]
-        elif _VARIANT.value == _VARIANT_GEMMA_3_1B or prop in ("mm_output_embedding", "mm_input_embedding_extra"):
+        elif _VARIANT.value in _TEXT_ONLY_VARIANTS or prop in ("mm_output_embedding", "mm_input_embedding_extra"):
             return zip([], [])
         else:
             raise ValueError(f"Unexpected member, {prop}, in Embedder.")
     elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"):
-        if _VARIANT.value == _VARIANT_GEMMA_3_1B:
+        if _VARIANT.value in _TEXT_ONLY_VARIANTS:
             return zip([], [])

         if path.endswith("/mm_input_projection"):
@@ -388,14 +432,16 @@ def convert_transformer_weights(
             converted_weights = [weights]
         else:
             raise ValueError(f"Unexpected subpath, `{path}`, in Embedder.")
-    elif path == _TRANSFORMER_FINAL_NORM:
+    elif path.endswith(_TRANSFORMER_FINAL_NORM):
         converted_paths = ["language_model.model.norm.weight"]
         converted_weights = [weights]
-    elif path.startswith(_TRANSFORMER_DECODER_BLOCK):
-        decoder_block_path = path[_TRANSFORMER_DECODER_BLOCK_LEN:]
-        next_path_separator_idx = decoder_block_path.find("/")
-        layer_idx = decoder_block_path[:next_path_separator_idx]
-        decoder_block_path = decoder_block_path[next_path_separator_idx:]
+    elif _TRANSFORMER_DECODER_BLOCK in path:
+        decoder_block_start = path.find(_TRANSFORMER_DECODER_BLOCK)
+        decoder_block_offset = decoder_block_start + _TRANSFORMER_DECODER_BLOCK_LEN
+        decoder_block_path = path[decoder_block_offset:]
+        next_path_separator_idx = decoder_block_path.find("/")
+        layer_idx = decoder_block_path[:next_path_separator_idx]
+        decoder_block_path = decoder_block_path[next_path_separator_idx:]

         base_path = f"language_model.model.layers.{layer_idx}"

@@ -445,8 +491,6 @@ def convert_transformer_weights(
             converted_weights = [weights]
         else:
             raise ValueError(f"Unexpected path `{path}` in Decoder Block.")
-    else:
-        raise ValueError(f"Unexpected path `{path}`.")

     if (cpl := len(converted_paths)) != (cwl := len(converted_weights)):
         raise ValueError(
@@ -457,11 +501,14 @@ def convert_transformer_weights(
     return zip(converted_paths, converted_weights)


-def convert(checkpoint_path: str, config: Gemma3Config) -> dict[str, torch.Tensor]:
+def convert(
+    checkpoint_path: str, config: Gemma3Config, variant: str
+) -> tuple[dict[str, torch.Tensor], Optional[Sequence[np.ndarray]]]:
     """Loads Orbax checkpoint from `input_path` and converts it to HF tree."""
     checkpointer = obc.PyTreeCheckpointer()
     ckpt = checkpointer.restore(checkpoint_path)
     hf_tree: dict[str, torch.Tensor] = {}
+    orbax_tree_flat = tree.flatten_with_path(ckpt)

     def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> None:
         hf_tree[path] = torch.from_numpy(weights.astype("float32")).type(target_dtype)
@@ -473,7 +520,7 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> No
             target_dtype,
         )

-    for paths, value in tree.flatten_with_path(ckpt):
+    for paths, value in orbax_tree_flat:
         if paths[0].startswith("SigLiPFromPatches_"):
             if config.vision_config is None:
                 continue
@@ -482,17 +529,21 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> No
                 update_tree(path, weights, config.vision_config.dtype)
         else:
             for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value):
-                if config.vision_config is None:
+                if variant in _TEXT_ONLY_VARIANTS:
                     path = path[len("language_model.") :]
+                    if variant == _VARIANT_EMBEDDINGGEMMA:
+                        path = path[len("model.") :]

                 update_tree(path, weights, config.text_config.dtype)

-    if config.vision_config is None:
+    if variant == _VARIANT_EMBEDDINGGEMMA:
+        return hf_tree, [weight[1].T for weight in orbax_tree_flat[: _NUM_LINEAR_LAYERS.value]]
+    elif config.vision_config is None:
         hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"]
     else:
         hf_tree["language_model.lm_head.weight"] = hf_tree["language_model.model.embed_tokens.weight"]

-    return hf_tree
+    return hf_tree, None


 def main(*args):
@@ -504,7 +555,7 @@ def main(*args):
     config = _VARIANTS[variant]
     config.text_config.dtype = getattr(torch, _TRANSFORMER_DTYPE.value)

-    if variant == _VARIANT_GEMMA_3_1B:
+    if variant in _TEXT_ONLY_VARIANTS:
         config.vision_config = None
     else:
         config.vision_config.dtype = getattr(torch, _VISION_DTYPE.value)
@@ -520,11 +571,13 @@ def main(*args):
         _TRANSFORMER_DTYPE.value,
         _VISION_DTYPE.value,
     )
-    state_tree = convert(_CHECKPOINT_PATH.value, config)
+    state_tree, st_linears = convert(_CHECKPOINT_PATH.value, config, variant)
     logging.info("Converted Gemma 3 (%s) state tree from Orbax to Hugging Face.", variant)

     with accelerate.init_empty_weights():
-        if variant == _VARIANT_GEMMA_3_1B:
+        if variant == _VARIANT_EMBEDDINGGEMMA:
+            model = Gemma3TextModel(config=config.text_config)
+        elif variant in _TEXT_ONLY_VARIANTS:
             model = Gemma3ForCausalLM(config=config.text_config)
         else:
             model = Gemma3ForConditionalGeneration(config)
@@ -548,6 +601,8 @@ def main(*args):
     tokenizer = GemmaTokenizerFast(
         _TOKENIZER_PATH.value,
         add_bos_token=True,
+        add_eos_token=variant == _VARIANT_EMBEDDINGGEMMA,
+        padding_side="right" if variant == _VARIANT_EMBEDDINGGEMMA else "left",
         extra_special_tokens={
             "image_token": "<image_soft_token>",  # Should be ID=262_144
             "boi_token": "<start_of_image>",  # Should be ID=255_999
@@ -558,7 +613,7 @@ def main(*args):
     tokenizer.save_pretrained(output_path)
     logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path)

-    if variant != _VARIANT_GEMMA_3_1B:
+    if variant not in _TEXT_ONLY_VARIANTS:
         image_processor = Gemma3ImageProcessor(
             image_seq_length=256,
             image_mean=(0.5,) * 3,
@@ -589,6 +644,46 @@ def main(*args):
         )
         generation_config.save_pretrained(output_path)

+    if variant == _VARIANT_EMBEDDINGGEMMA:
+        from sentence_transformers import SentenceTransformer, models
+
+        # TODO: Support Retrieval tasks where we use `"title: {title} | text: {passage}"` internally and construct this
+        # from split-records cached data, but externally these come through as a single string with components
+        # separated by a newline. This should be used for `passage` for SentenceTransformers and the relevant MTEB
+        # Retrieval tasks.
+        # https://github.com/embeddings-benchmark/mteb/blob/main/docs/usage/usage.md#running-sentencetransformer-model-with-prompts
+        task_prompts = {
+            "query": "task: search result | query: ",
+            "document": "title: none | text: ",
+            "BitextMining": "task: search result | query: ",
+            "Clustering": "task: clustering | query: ",
+            "Classification": "task: classification | query: ",
+            "InstructionRetrieval": "task: code retrieval | query: ",
+            "MultilabelClassification": "task: classification | query: ",
+            "PairClassification": "task: sentence similarity | query: ",
+            "Reranking": "task: search result | query: ",
+            "Retrieval": "task: search result | query: ",
+            "Retrieval-query": "task: search result | query: ",
+            "Retrieval-document": "title: none | text: ",
+            "STS": "task: sentence similarity | query: ",
+            "Summarization": "task: summarization | query: ",
+        }
+
+        transformer = models.Transformer(output_path)
+        pooling = models.Pooling(config.text_config.hidden_size, pooling_mode="mean")
+        normalize = models.Normalize()
+        linears = []
+
+        for linear_weight in st_linears:
+            out_size, in_size = linear_weight.shape[:2]
+            dense = models.Dense(in_size, out_size, bias=False, activation_function=None)
+            dense.linear.weight.data = torch.from_numpy(linear_weight.astype("float32"))
+            linears.append(dense)
+
+        model = SentenceTransformer(modules=[transformer, pooling, *linears, normalize], prompts=task_prompts)
+        model = model.to(getattr(torch, _TRANSFORMER_DTYPE.value))
+        model.save_pretrained(output_path)
+

 if __name__ == "__main__":
     app.run(main)
diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index 2b60466d7ff1..d2ba04298dec 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -443,6 +443,19 @@ def _init_weights(self, module):
                 module.mm_input_projection_weight.data.zero_()


+def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
+    """
+    Enables a bidirectional mask within the sliding window.
+    """
+
+    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+        """A token can attend to any other token if their absolute distance is within
+        half the sliding window size (distance <= sliding_window // 2)."""
+        return abs(q_idx - kv_idx) <= sliding_window // 2
+
+    return inner_mask
+
+
 @auto_docstring
 class Gemma3TextModel(Gemma3PreTrainedModel):
     config: Gemma3TextConfig
@@ -531,10 +544,16 @@ def forward(
                 "past_key_values": past_key_values,
                 "position_ids": position_ids,
             }
+            sliding_mask_kwargs = mask_kwargs.copy()
+
+            if self.config.use_bidirectional_attention:
+                mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool)
+                sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window)
+
             # Create the masks
             causal_mask_mapping = {
                 "full_attention": create_causal_mask(**mask_kwargs),
-                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+                "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
             }

         # embed positions
diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py
index 947a22ab8eaa..fc70fa6e9d8e 100644
--- a/src/transformers/models/gemma3/modular_gemma3.py
+++ b/src/transformers/models/gemma3/modular_gemma3.py
@@ -162,6 +162,8 @@ class Gemma3TextConfig(Gemma2Config, PretrainedConfig):
                 Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
         rope_local_base_freq (float, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings for local attention.
+        use_bidirectional_attention (`bool`, *optional*, defaults to `False`): If True, the model will attend to all
+            text tokens instead of using a causal mask. This does not change behavior for vision tokens.

     ```python
     >>> from transformers import Gemma3TextModel, Gemma3TextConfig
@@ -204,6 +206,7 @@ def __init__(
         attn_logit_softcapping=None,
         rope_scaling=None,
         rope_local_base_freq=10_000.0,
+        use_bidirectional_attention=False,
         **kwargs,
     ):
         PretrainedConfig.__init__(
@@ -233,6 +236,7 @@
         self.final_logit_softcapping = final_logit_softcapping
         self.attn_logit_softcapping = attn_logit_softcapping
         self.layer_types = layer_types
+        self.use_bidirectional_attention = use_bidirectional_attention
         self.rope_local_base_freq = rope_local_base_freq
         self.rope_scaling = rope_scaling
@@ -535,6 +539,19 @@ def _init_weights(self, module):
                 module.mm_input_projection_weight.data.zero_()


+def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
+    """
+    Enables a bidirectional mask within the sliding window.
+    """
+
+    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+        """A token can attend to any other token if their absolute distance is within
+        half the sliding window size (distance <= sliding_window // 2)."""
+        return abs(q_idx - kv_idx) <= sliding_window // 2
+
+    return inner_mask
+
+
 class Gemma3TextModel(Gemma2Model):
     config: Gemma3TextConfig

@@ -609,10 +626,16 @@ def forward(
                 "past_key_values": past_key_values,
                 "position_ids": position_ids,
             }
+            sliding_mask_kwargs = mask_kwargs.copy()
+
+            if self.config.use_bidirectional_attention:
+                mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool)
+                sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window)
+
             # Create the masks
             causal_mask_mapping = {
                 "full_attention": create_causal_mask(**mask_kwargs),
-                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+                "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
             }

         # embed positions
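A minimal sketch of the mask semantics that `use_bidirectional_attention=True` enables: full-attention layers become fully bidirectional (their `or_mask_function` always returns True), and sliding-attention layers additionally allow any pair of tokens within `sliding_window // 2` of each other, mirroring `_bidirectional_window_overlay`. The boolean tensors below are hand-rolled for illustration and only approximate the exact window-boundary convention used by `create_sliding_window_causal_mask`; `seq_len` and `sliding_window` are illustrative values.

```python
import torch

seq_len, sliding_window = 8, 4
q = torch.arange(seq_len)[:, None]
kv = torch.arange(seq_len)[None, :]

causal = kv <= q
# Approximate sliding-window causal mask: causal, and the key within the window.
windowed_causal = causal & (q - kv < sliding_window)

# Full-attention layers: the or_mask_function always returns True, so the mask is all-True.
full_attention = causal | torch.ones_like(causal)

# Sliding-attention layers: bidirectional within half the window, as in _bidirectional_window_overlay.
sliding_attention = windowed_causal | ((q - kv).abs() <= sliding_window // 2)

print(full_attention.int())
print(sliding_attention.int())  # band of width sliding_window // 2 on each side of the diagonal
```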
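When converted with `--variant='embedding'`, the script also saves a SentenceTransformer pipeline (Transformer, mean Pooling, Dense projections, Normalize) together with the task prompts registered above. A hypothetical usage sketch, assuming a recent sentence-transformers release; the directory name stands in for the conversion run's `output_path`:

```python
from sentence_transformers import SentenceTransformer

# "/tmp/embeddinggemma" is a placeholder for the conversion script's output_path.
model = SentenceTransformer("/tmp/embeddinggemma")

queries = ["Which Gemma 3 variants are text-only?"]
documents = ["EmbeddingGemma, Gemma 3 270M and Gemma 3 1B are converted without a vision tower."]

# prompt_name selects one of the prompts stored with the model,
# e.g. "query" -> "task: search result | query: " and "document" -> "title: none | text: ".
query_embeddings = model.encode(queries, prompt_name="query")
document_embeddings = model.encode(documents, prompt_name="document")

print(model.similarity(query_embeddings, document_embeddings))
```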