From e515c72ed176d04b228e49c8f7e3ba4c2c57bbf4 Mon Sep 17 00:00:00 2001
From: saslifat-gif
Date: Mon, 13 Apr 2026 19:49:30 +0800
Subject: [PATCH 1/8] Fix Gemma4 ZeRO-3 weight loading by correcting
 base_model_prefix in AudioModel and VisionModel

---
 src/transformers/models/gemma4/modeling_gemma4.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/gemma4/modeling_gemma4.py b/src/transformers/models/gemma4/modeling_gemma4.py
index 88c340a9414b..1ca3bc350407 100644
--- a/src/transformers/models/gemma4/modeling_gemma4.py
+++ b/src/transformers/models/gemma4/modeling_gemma4.py
@@ -1876,7 +1876,7 @@ class Gemma4AudioModel(Gemma4PreTrainedModel):
 
     config: Gemma4AudioConfig
     main_input_name = "input_features"
-    base_model_prefix = "model.audio_tower"  # prefix for Gemma4ForConditionalGeneration saved checkpoints, required for Gemma4AudioModel.from_pretrained()
+    base_model_prefix = "audio_tower"
     _can_record_outputs = {
         "hidden_states": Gemma4AudioLayer,
         "attentions": Gemma4AudioAttention,
@@ -1959,6 +1959,7 @@ def forward(
 class Gemma4VisionModel(Gemma4PreTrainedModel):
     """The Gemma 4 Vision Encoder."""
 
+    base_model_prefix = "vision_tower"
     config = Gemma4VisionConfig
     _can_record_outputs = {
         "hidden_states": Gemma4VisionEncoderLayer,

From 2ddf0c69c26b21eb70883c29130e93f54b064c0c Mon Sep 17 00:00:00 2001
From: saslifat-gif
Date: Mon, 13 Apr 2026 20:36:46 +0800
Subject: [PATCH 2/8] Revert VisionModel base_model_prefix change per review
 feedback

---
 src/transformers/models/gemma4/modeling_gemma4.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/transformers/models/gemma4/modeling_gemma4.py b/src/transformers/models/gemma4/modeling_gemma4.py
index 1ca3bc350407..ce8f7be666a8 100644
--- a/src/transformers/models/gemma4/modeling_gemma4.py
+++ b/src/transformers/models/gemma4/modeling_gemma4.py
@@ -1959,7 +1959,6 @@ def forward(
 class Gemma4VisionModel(Gemma4PreTrainedModel):
     """The Gemma 4 Vision Encoder."""
 
-    base_model_prefix = "vision_tower"
     config = Gemma4VisionConfig
     _can_record_outputs = {
         "hidden_states": Gemma4VisionEncoderLayer,

From 28d38cb2595bb3fa9696d9d3048cc1a133c8265e Mon Sep 17 00:00:00 2001
From: saslifat-gif
Date: Mon, 13 Apr 2026 22:01:06 +0800
Subject: [PATCH 3/8] Fix ZeRO-3 loading: handle buffers in
 _load_state_dict_into_zero3_model

Buffers registered via register_buffer() were completely skipped during
from_pretrained() under DeepSpeed ZeRO-3. The load() function in
_load_state_dict_into_zero3_model only iterated over named_parameters,
never named_buffers, so buffer values from the checkpoint were never
loaded and were always reported as missing.

Fix: after gathering and loading parameters, load buffers directly; no
GatheredParameters context is needed, since buffers are not sharded by
ZeRO-3. (A standalone sketch of this loading scheme appears after
PATCH 5/8 below.)

Fixes #45397
---
 src/transformers/integrations/deepspeed.py        | 15 ++++++++++++---
 src/transformers/models/gemma4/modeling_gemma4.py |  2 +-
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index 7a64a753c6d4..6e71fcbad51d 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -474,7 +474,7 @@ def _load_state_dict_into_zero3_model(model_to_load, state_dict, load_config=Non
 
     # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
     # so we need to apply the function recursively.
- def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False): +def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) local_metadata["assign_to_params_buffers"] = assign_to_params_buffers @@ -504,13 +504,22 @@ def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=Fals if torch.distributed.get_rank() == 0: module._load_from_state_dict(*args) + # Buffers are not partitioned by ZeRO-3, load them directly + named_buffers = dict(module.named_buffers(prefix=prefix[:-1], recurse=False)) + for k, buf in named_buffers.items(): + if k in state_dict and buf is not None: + missing_keys.discard(k) + if torch.distributed.get_rank() == 0: + with torch.no_grad(): + buf.copy_(state_dict[k]) + for name, child in module._modules.items(): if child is not None: load(child, state_dict, prefix + name + ".", assign_to_params_buffers) - load(model_to_load, state_dict, assign_to_params_buffers=False) + load(model_to_load, state_dict, assign_to_params_buffers=False) - return error_msgs, missing_keys + return error_msgs, missing_keys def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters): diff --git a/src/transformers/models/gemma4/modeling_gemma4.py b/src/transformers/models/gemma4/modeling_gemma4.py index ce8f7be666a8..88c340a9414b 100644 --- a/src/transformers/models/gemma4/modeling_gemma4.py +++ b/src/transformers/models/gemma4/modeling_gemma4.py @@ -1876,7 +1876,7 @@ class Gemma4AudioModel(Gemma4PreTrainedModel): config: Gemma4AudioConfig main_input_name = "input_features" - base_model_prefix = "audio_tower" + base_model_prefix = "model.audio_tower" # prefix for Gemma4ForConditionalGeneration saved checkpoints, required for Gemma4AudioModel.from_pretrained() _can_record_outputs = { "hidden_states": Gemma4AudioLayer, "attentions": Gemma4AudioAttention, From ffc0067d5779f2b442262cbfb119f64bdf5b6138 Mon Sep 17 00:00:00 2001 From: saslifat-gif Date: Mon, 13 Apr 2026 22:20:20 +0800 Subject: [PATCH 4/8] Fix indentation in _load_state_dict_into_zero3_model buffer handling --- src/transformers/integrations/deepspeed.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py index 6e71fcbad51d..ca46f03e7b48 100644 --- a/src/transformers/integrations/deepspeed.py +++ b/src/transformers/integrations/deepspeed.py @@ -474,7 +474,7 @@ def _load_state_dict_into_zero3_model(model_to_load, state_dict, load_config=Non # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants # so we need to apply the function recursively. 
-def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
+    def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
         local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
         local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
 
@@ -491,7 +491,7 @@ def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=Fals
         for k in named_parameters:
             if k in state_dict:
                 param = named_parameters[k]
-                # crutial to not init the weight again
+                # crucial to not init the weight again
                 param._is_hf_initialized = True
                 params_to_gather.append(param)
                 missing_keys.discard(k)
@@ -517,9 +517,9 @@ def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=Fals
             if child is not None:
                 load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
 
-load(model_to_load, state_dict, assign_to_params_buffers=False)
+    load(model_to_load, state_dict, assign_to_params_buffers=False)
 
-return error_msgs, missing_keys
+    return error_msgs, missing_keys
 
 
 def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):

From d5f64abea813309feb77730d1faf5e8de9d26560 Mon Sep 17 00:00:00 2001
From: saslifat-gif
Date: Wed, 15 Apr 2026 13:51:21 +0800
Subject: [PATCH 5/8] Add test for ZeRO-3 registered buffer loading

---
 .../test_trainer_distributed_deepspeed.py | 34 +++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/tests/trainer/distributed/test_trainer_distributed_deepspeed.py b/tests/trainer/distributed/test_trainer_distributed_deepspeed.py
index 8d3672a55c26..f1b81f7ef89a 100644
--- a/tests/trainer/distributed/test_trainer_distributed_deepspeed.py
+++ b/tests/trainer/distributed/test_trainer_distributed_deepspeed.py
@@ -1525,6 +1525,40 @@ def test_resize_token_embeddings_zero3(self):
             embedding = model.get_input_embeddings()
             with deepspeed.zero.GatheredParameters([embedding.weight]):
                 self.assertEqual(embedding.weight.shape[0], new_size)
+
+    def test_zero3_load_registered_buffers(self):
+        """Test that registered buffers are loaded correctly under ZeRO-3 from_pretrained."""
+        from transformers.models.gemma4.configuration_gemma4 import (
+            Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig,
+        )
+        from transformers.models.gemma4.modeling_gemma4 import Gemma4ForConditionalGeneration
+
+        text_config = Gemma4TextConfig(
+            hidden_size=128, num_hidden_layers=2, num_attention_heads=2,
+            intermediate_size=256, vocab_size=32000, num_key_value_heads=2, pad_token_id=None,
+        )
+        vision_config = Gemma4VisionConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2, intermediate_size=128)
+        audio_config = Gemma4AudioConfig()
+        config = Gemma4Config(text_config=text_config, vision_config=vision_config, audio_config=audio_config)
+
+        # save without ZeRO-3
+        save_path = self.get_auto_remove_tmp_dir()
+        model = Gemma4ForConditionalGeneration(config)
+        model.save_pretrained(save_path)
+        del model
+
+        # load with ZeRO-3
+        ds_config = self._get_zero3_ds_config(bf16={"enabled": True}, train_micro_batch_size_per_gpu=1)
+        with mockenv_context(**self.dist_env_1_gpu):
+            dschf = HfDeepSpeedConfig(ds_config)
+            model2 = Gemma4ForConditionalGeneration.from_pretrained(save_path, torch_dtype=torch.bfloat16)
+
+        # verify no registered buffers are MISSING
+        missing = [
+            name for name, buf in model2.named_buffers()
+            if buf is None
+        ]
+        self.assertEqual(missing, [], f"Registered buffers missing after ZeRO-3 load: {missing}")
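A minimal, self-contained sketch of the loading scheme PATCH 3/8 implements,
reduced to plain PyTorch with no DeepSpeed involved. The helper
load_params_and_buffers and the Toy module below are hypothetical names for
illustration, not transformers API. Under ZeRO-3 the parameter branch would
additionally run inside deepspeed.zero.GatheredParameters(...); buffers are
never partitioned, which is why the patch can copy them directly.

import torch
import torch.nn as nn


def load_params_and_buffers(module: nn.Module, state_dict: dict, prefix: str = "") -> list:
    """Recursively copy matching checkpoint entries into `module`.

    Mirrors the recursion in _load_state_dict_into_zero3_model: parameters and
    buffers of the current module are handled here, children via the recursive
    call at the bottom. Returns the checkpoint keys that were consumed.
    """
    loaded = []
    # Direct parameters and buffers of this module only (recurse=False);
    # under ZeRO-3 the parameters would be gathered first, the buffers not.
    local_tensors = list(module.named_parameters(recurse=False)) + list(module.named_buffers(recurse=False))
    for name, tensor in local_tensors:
        key = prefix + name
        if key in state_dict:
            with torch.no_grad():
                tensor.copy_(state_dict[key])
            loaded.append(key)
    for child_name, child in module.named_children():
        loaded += load_params_and_buffers(child, state_dict, prefix + child_name + ".")
    return loaded


class Toy(nn.Module):
    """Hypothetical module with one parameter-holding child and one buffer."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.register_buffer("scale", torch.ones(1))


src, dst = Toy(), Toy()
src.scale.fill_(42.0)  # sentinel value, so a successful buffer copy is detectable
consumed = load_params_and_buffers(dst, src.state_dict())
assert dst.scale.item() == 42.0  # fails if the buffer branch is dropped, as in the bug
assert "proj.weight" in consumed and "scale" in consumed

Without the named_buffers() pass, this sketch reproduces the reported failure
mode exactly: every parameter loads, "scale" stays at its initialized value,
and its key is left over as missing.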
From 21f9e952466923ebe12c8c19beb1a48c6953ceaa Mon Sep 17 00:00:00 2001
From: saslifat-gif
Date: Thu, 16 Apr 2026 14:24:16 +0800
Subject: [PATCH 6/8] fix: organize imports and remove unused variable in
 deepspeed test

---
 .../distributed/test_trainer_distributed_deepspeed.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/tests/trainer/distributed/test_trainer_distributed_deepspeed.py b/tests/trainer/distributed/test_trainer_distributed_deepspeed.py
index f1b81f7ef89a..9080603c40ea 100644
--- a/tests/trainer/distributed/test_trainer_distributed_deepspeed.py
+++ b/tests/trainer/distributed/test_trainer_distributed_deepspeed.py
@@ -1525,11 +1525,14 @@ def test_resize_token_embeddings_zero3(self):
             embedding = model.get_input_embeddings()
             with deepspeed.zero.GatheredParameters([embedding.weight]):
                 self.assertEqual(embedding.weight.shape[0], new_size)
-
+
     def test_zero3_load_registered_buffers(self):
         """Test that registered buffers are loaded correctly under ZeRO-3 from_pretrained."""
         from transformers.models.gemma4.configuration_gemma4 import (
-            Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig,
+            Gemma4AudioConfig,
+            Gemma4Config,
+            Gemma4TextConfig,
+            Gemma4VisionConfig,
         )
         from transformers.models.gemma4.modeling_gemma4 import Gemma4ForConditionalGeneration
 
@@ -1550,7 +1553,7 @@ def test_zero3_load_registered_buffers(self):
         # load with ZeRO-3
         ds_config = self._get_zero3_ds_config(bf16={"enabled": True}, train_micro_batch_size_per_gpu=1)
         with mockenv_context(**self.dist_env_1_gpu):
-            dschf = HfDeepSpeedConfig(ds_config)
+            HfDeepSpeedConfig(ds_config)
             model2 = Gemma4ForConditionalGeneration.from_pretrained(save_path, torch_dtype=torch.bfloat16)
 
         # verify no registered buffers are MISSING

From 238dd3156f0e458975626c140ea70ddee8daaad2 Mon Sep 17 00:00:00 2001
From: saslifat-gif
Date: Thu, 16 Apr 2026 14:33:02 +0800
Subject: [PATCH 7/8] fix: apply ruff formatting to deepspeed test

---
 .../test_trainer_distributed_deepspeed.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/tests/trainer/distributed/test_trainer_distributed_deepspeed.py b/tests/trainer/distributed/test_trainer_distributed_deepspeed.py
index 9080603c40ea..554465d79881 100644
--- a/tests/trainer/distributed/test_trainer_distributed_deepspeed.py
+++ b/tests/trainer/distributed/test_trainer_distributed_deepspeed.py
@@ -1537,10 +1537,17 @@ def test_zero3_load_registered_buffers(self):
         from transformers.models.gemma4.modeling_gemma4 import Gemma4ForConditionalGeneration
 
         text_config = Gemma4TextConfig(
-            hidden_size=128, num_hidden_layers=2, num_attention_heads=2,
-            intermediate_size=256, vocab_size=32000, num_key_value_heads=2, pad_token_id=None,
+            hidden_size=128,
+            num_hidden_layers=2,
+            num_attention_heads=2,
+            intermediate_size=256,
+            vocab_size=32000,
+            num_key_value_heads=2,
+            pad_token_id=None,
+        )
+        vision_config = Gemma4VisionConfig(
+            hidden_size=64, num_hidden_layers=2, num_attention_heads=2, intermediate_size=128
         )
-        vision_config = Gemma4VisionConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2, intermediate_size=128)
         audio_config = Gemma4AudioConfig()
         config = Gemma4Config(text_config=text_config, vision_config=vision_config, audio_config=audio_config)
 
@@ -1557,10 +1564,7 @@ def test_zero3_load_registered_buffers(self):
             model2 = Gemma4ForConditionalGeneration.from_pretrained(save_path, torch_dtype=torch.bfloat16)
 
         # verify no registered buffers are MISSING
-        missing = [
-            name for name, buf in model2.named_buffers()
-            if buf is None
-        ]
+        missing = [name for name, buf in model2.named_buffers() if buf is None]
         self.assertEqual(missing, [], f"Registered buffers missing after ZeRO-3 load: {missing}")

From a9fcac857dc1c562595eb5a828840c67364835f3 Mon Sep 17 00:00:00 2001
From: saslifat-gif
Date: Thu, 16 Apr 2026 23:39:58 +0800
Subject: [PATCH 8/8] fix: copy buffers on all ranks and set _is_hf_initialized
 in ZeRO-3 load

- Remove rank==0 guard so buffers are copied on all ranks
- Set buf._is_hf_initialized = True after copy to prevent re-initialization
- Update test to verify buffer VALUES survive ZeRO-3 from_pretrained round-trip
---
 src/transformers/integrations/deepspeed.py        |  6 ++--
 .../test_trainer_distributed_deepspeed.py         | 31 +++++++++++++------
 2 files changed, 25 insertions(+), 12 deletions(-)

diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index ca46f03e7b48..9703f642f8bc 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -509,9 +509,9 @@ def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=Fals
         for k, buf in named_buffers.items():
             if k in state_dict and buf is not None:
                 missing_keys.discard(k)
-                if torch.distributed.get_rank() == 0:
-                    with torch.no_grad():
-                        buf.copy_(state_dict[k])
+                with torch.no_grad():
+                    buf.copy_(state_dict[k])
+                buf._is_hf_initialized = True
 
         for name, child in module._modules.items():
             if child is not None:

diff --git a/tests/trainer/distributed/test_trainer_distributed_deepspeed.py b/tests/trainer/distributed/test_trainer_distributed_deepspeed.py
index 554465d79881..6a0b3c49160e 100644
--- a/tests/trainer/distributed/test_trainer_distributed_deepspeed.py
+++ b/tests/trainer/distributed/test_trainer_distributed_deepspeed.py
@@ -1527,7 +1527,7 @@ def test_resize_token_embeddings_zero3(self):
                 self.assertEqual(embedding.weight.shape[0], new_size)
 
     def test_zero3_load_registered_buffers(self):
-        """Test that registered buffers are loaded correctly under ZeRO-3 from_pretrained."""
+        """Test that registered buffers are loaded with correct values under ZeRO-3 from_pretrained."""
         from transformers.models.gemma4.configuration_gemma4 import (
             Gemma4AudioConfig,
             Gemma4Config,
             Gemma4TextConfig,
             Gemma4VisionConfig,
         )
         from transformers.models.gemma4.modeling_gemma4 import Gemma4ForConditionalGeneration
 
@@ -1543,7 +1543,7 @@ def test_zero3_load_registered_buffers(self):
             intermediate_size=256,
             vocab_size=32000,
             num_key_value_heads=2,
-            pad_token_id=None,
+            pad_token_id=0,
         )
         vision_config = Gemma4VisionConfig(
             hidden_size=64, num_hidden_layers=2, num_attention_heads=2, intermediate_size=128
         )
         audio_config = Gemma4AudioConfig()
         config = Gemma4Config(text_config=text_config, vision_config=vision_config, audio_config=audio_config)
 
-        # save without ZeRO-3
+        # Save without ZeRO-3, with non-default buffer values
         save_path = self.get_auto_remove_tmp_dir()
         model = Gemma4ForConditionalGeneration(config)
+        for name, buf in model.named_buffers():
+            if "input_max" in name:
+                buf.fill_(42.0)
+            elif "output_min" in name:
+                buf.fill_(-42.0)
+            elif "layer_scalar" in name:
+                buf.fill_(0.5)
         model.save_pretrained(save_path)
         del model
 
-        # load with ZeRO-3
-        ds_config = self._get_zero3_ds_config(bf16={"enabled": True}, train_micro_batch_size_per_gpu=1)
+        # Load with ZeRO-3
+        ds_config = self._get_zero3_ds_config(bf16={"enabled": True})
+        dschf = HfDeepSpeedConfig(ds_config)
+        self.assertTrue(dschf.is_zero3())
         with mockenv_context(**self.dist_env_1_gpu):
-            HfDeepSpeedConfig(ds_config)
             model2 = Gemma4ForConditionalGeneration.from_pretrained(save_path, torch_dtype=torch.bfloat16)
 
-        # verify no registered buffers are MISSING
-        missing = [name for name, buf in model2.named_buffers() if buf is None]
-        self.assertEqual(missing, [], f"Registered buffers missing after ZeRO-3 load: {missing}")
+        # Verify buffer VALUES were loaded from checkpoint, not re-initialized
+        for name, buf in model2.named_buffers():
+            if "input_max" in name:
+                self.assertEqual(buf.item(), 42.0, f"{name} was not loaded from checkpoint")
+            elif "output_min" in name:
+                self.assertEqual(buf.item(), -42.0, f"{name} was not loaded from checkpoint")
+            elif "layer_scalar" in name:
+                self.assertEqual(buf.item(), 0.5, f"{name} was not loaded from checkpoint")
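The round-trip check that PATCH 8/8's updated test performs, shrunk down to
plain torch.save/load_state_dict with a hypothetical ToyEncoder (no DeepSpeed,
no Gemma4; the buffer names input_max and output_min only echo the test's
sentinels). The point of the patch 8 test change is to assert buffer VALUES,
not mere presence: a re-initialized buffer is present but silently wrong.

import os
import tempfile

import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for a model that carries registered buffers."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.register_buffer("input_max", torch.zeros(1))
        self.register_buffer("output_min", torch.zeros(1))


model = ToyEncoder()
model.input_max.fill_(42.0)  # sentinels that default initialization cannot produce
model.output_min.fill_(-42.0)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy.pt")
    torch.save(model.state_dict(), path)

    reloaded = ToyEncoder()  # fresh instance: buffers start at their zero defaults
    reloaded.load_state_dict(torch.load(path))

# If buffers were skipped on load (the bug PATCH 3/8 fixes), both would still be 0.0.
assert reloaded.input_max.item() == 42.0
assert reloaded.output_min.item() == -42.0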