Fix ZeRO-3 from_pretrained: load registered buffers in _load_state_dict_into_zero3_model #45402
Conversation
[For maintainers] Suggested jobs to run (before merge): run-slow: gemma4
```python
class Gemma4VisionModel(Gemma4PreTrainedModel):
    """The Gemma 4 Vision Encoder."""

    base_model_prefix = "vision_tower"
```
Not sure if deepspeed complains about vision as well, can you confirm that it fails with vision weights?
The empty string was intended, since the model itself is a base model. We don't want `vision_tower`, because `Gemma4VisionModel.vision_tower` doesn't exist.
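For context, here is a simplified sketch of the prefix handling this is about. The helper below is hypothetical and only illustrates the idea that a non-empty `base_model_prefix` makes the loader look for a matching submodule (which `Gemma4VisionModel` doesn't have); it is not the actual `from_pretrained` internals:

```python
# Hypothetical helper, only to illustrate the idea above (not the real loader code).
def strip_base_prefix(checkpoint_keys, model, prefix):
    """Strip "<prefix>." from checkpoint keys when loading a bare submodel."""
    if not prefix:
        return list(checkpoint_keys)  # empty prefix: keys are used as-is
    keys_are_prefixed = any(k.startswith(prefix + ".") for k in checkpoint_keys)
    model_has_prefix_module = hasattr(model, prefix)
    if keys_are_prefixed and not model_has_prefix_module:
        # checkpoint came from a wrapper model, but we are loading the bare submodel
        return [k[len(prefix) + 1:] for k in checkpoint_keys if k.startswith(prefix + ".")]
    return list(checkpoint_keys)
```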
```diff
  config: Gemma4AudioConfig
  main_input_name = "input_features"
- base_model_prefix = "model.audio_tower"  # prefix for Gemma4ForConditionalGeneration saved checkpoints, required for Gemma4AudioModel.from_pretrained()
+ base_model_prefix = "audio_tower"
```
I remember @eustlb had a reason for this, also something related to loading.
Hi @zucchini-nlp, after deeper investigation I found the actual root cause. The bug is in `_load_state_dict_into_zero3_model`: the `load()` function only iterates over parameters:

```python
named_parameters = dict(module.named_parameters(..., recurse=False))
for k in named_parameters:
    if k in state_dict:
        missing_keys.discard(k)  # buffers never reach here
```

Buffers (registered via `register_buffer()`) are never enumerated here, so they are never loaded and always reported as MISSING. Reproduced with a tiny config on 2xRTX40 GPUs (repro script below).
The fix should go in the same `load()` function:

```python
named_buffers = dict(module.named_buffers(prefix=prefix[:-1], recurse=False))
for k, buf in named_buffers.items():
    if k in state_dict:
        missing_keys.discard(k)
        if torch.distributed.get_rank() == 0:
            with torch.no_grad():
                buf.copy_(state_dict[k])
```

This affects ANY model with registered buffers under ZeRO-3, not just Gemma-4.
Thanks a lot for digging, Cyril is off today and will probably review tomorrow. And for deepspeed, also cc @SunMarc
Reproduction script: creates a tiny Gemma4 model, saves without ZeRO-3, loads with ZeRO-3, and checks for MISSING buffers:

```python
import torch

from transformers.integrations import HfDeepSpeedConfig
from transformers.models.gemma4.modeling_gemma4 import Gemma4ForConditionalGeneration, Gemma4AudioModel
from transformers.models.gemma4.configuration_gemma4 import Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig, Gemma4AudioConfig

print("AudioModel prefix:", Gemma4AudioModel.base_model_prefix)

text_config = Gemma4TextConfig(
    hidden_size=128, num_hidden_layers=2, num_attention_heads=2,
    intermediate_size=256, vocab_size=32000, num_key_value_heads=2, pad_token_id=None,
)
vision_config = Gemma4VisionConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2, intermediate_size=128)
audio_config = Gemma4AudioConfig()
config = Gemma4Config(text_config=text_config, vision_config=vision_config, audio_config=audio_config)

# save without ZeRO-3
model = Gemma4ForConditionalGeneration(config)
model.save_pretrained("/tmp/tiny_gemma4_full")
print("saved!")
del model

# load with ZeRO-3
ds_config = {
    "bf16": {"enabled": True},
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": 2,
    "zero_optimization": {"stage": 3},
}
dschf = HfDeepSpeedConfig(ds_config)
model2 = Gemma4ForConditionalGeneration.from_pretrained(
    "/tmp/tiny_gemma4_full",
    torch_dtype=torch.bfloat16,
)
print("loaded!")
```
SunMarc left a comment:
Thanks! Can you add a test in the class `TestNonTrainerIntegrationDeepSpeed`? You will find it in `tests/trainer/distributed/test_trainer_distributed_deepspeed.py`. I think we didn't catch this since we don't really register persistent buffers in transformers. Thanks for fixing this. cc @Cyrilvallez for confirmation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Force-pushed from 4674baf to e7ff996
Added the test in `TestNonTrainerIntegrationDeepSpeed` as requested. It creates a tiny Gemma4 model, saves without ZeRO-3, loads with ZeRO-3, and asserts no registered buffers are MISSING. cc @Cyrilvallez
Force-pushed from e7ff996 to c414759
`test_generate_with_static_cache` passes locally but fails on CI. This appears to be a flaky test due to floating-point precision differences across hardware environments, unrelated to this PR's changes (ZeRO-3 buffer loading).
I've tested the PR and there were a couple of issues. Can you make the following changes @saslifat-gif?

```python
# Buffers are not partitioned by ZeRO-3, load them directly on all ranks
named_buffers = dict(module.named_buffers(prefix=prefix[:-1], recurse=False))
for k, buf in named_buffers.items():
    if k in state_dict and buf is not None:
        missing_keys.discard(k)
        with torch.no_grad():
            buf.copy_(state_dict[k])
        buf._is_hf_initialized = True
```

For the test, we also need to check that the values are correct:

```python
def test_zero3_load_registered_buffers(self):
    """Test that registered buffers are loaded with correct values under ZeRO-3 from_pretrained."""
    from transformers.models.gemma4.configuration_gemma4 import (
        Gemma4AudioConfig,
        Gemma4Config,
        Gemma4TextConfig,
        Gemma4VisionConfig,
    )
    from transformers.models.gemma4.modeling_gemma4 import Gemma4ForConditionalGeneration

    text_config = Gemma4TextConfig(
        hidden_size=128,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=256,
        vocab_size=32000,
        num_key_value_heads=2,
        pad_token_id=0,
    )
    vision_config = Gemma4VisionConfig(
        hidden_size=64, num_hidden_layers=2, num_attention_heads=2, intermediate_size=128
    )
    audio_config = Gemma4AudioConfig()
    config = Gemma4Config(text_config=text_config, vision_config=vision_config, audio_config=audio_config)

    # Save without ZeRO-3, with non-default buffer values
    save_path = self.get_auto_remove_tmp_dir()
    model = Gemma4ForConditionalGeneration(config)
    for name, buf in model.named_buffers():
        if "input_max" in name:
            buf.fill_(42.0)
        elif "output_min" in name:
            buf.fill_(-42.0)
        elif "layer_scalar" in name:
            buf.fill_(0.5)
    model.save_pretrained(save_path)
    del model

    # Load with ZeRO-3
    ds_config = self._get_zero3_ds_config(bf16={"enabled": True})
    dschf = HfDeepSpeedConfig(ds_config)
    self.assertTrue(dschf.is_zero3())
    with mockenv_context(**self.dist_env_1_gpu):
        model2 = Gemma4ForConditionalGeneration.from_pretrained(save_path, torch_dtype=torch.bfloat16)

        # Verify buffer VALUES were loaded from checkpoint, not re-initialized
        for name, buf in model2.named_buffers():
            if "input_max" in name:
                self.assertEqual(buf.item(), 42.0, f"{name} was not loaded from checkpoint")
            elif "output_min" in name:
                self.assertEqual(buf.item(), -42.0, f"{name} was not loaded from checkpoint")
            elif "layer_scalar" in name:
                self.assertEqual(buf.item(), 0.5, f"{name} was not loaded from checkpoint")
```
…udioModel and VisionModel
Buffers registered via register_buffer() were completely skipped during from_pretrained() under DeepSpeed ZeRO-3. The load() function in _load_state_dict_into_zero3_model only iterated over named_parameters, never named_buffers, so buffer values from checkpoint were never loaded and always reported as MISSING. Fix: after gathering and loading parameters, explicitly load buffers directly (no GatheredParameters needed since buffers are not sharded by ZeRO-3). Fixes huggingface#45397
- Remove rank==0 guard so buffers are copied on all ranks
- Set buf._is_hf_initialized = True after copy to prevent re-initialization
- Update test to verify buffer VALUES survive ZeRO-3 from_pretrained round-trip
Force-pushed from 16c4474 to a9fcac8
Thank you for testing and the clear feedback! I've applied the changes: removed the rank==0 guard so buffers are copied on all ranks, added `buf._is_hf_initialized = True`, and updated the test to verify the actual values survive the round-trip.


What does this PR do?
Fixes #45397
Root cause:
`_load_state_dict_into_zero3_model` in `src/transformers/integrations/deepspeed.py` only iterates over `named_parameters`, never `named_buffers`. Buffers registered via `register_buffer()` are completely skipped during ZeRO-3 loading, causing them to always appear as MISSING.
Fix: After gathering and loading parameters, explicitly load buffers directly. Buffers don't need `GatheredParameters` since ZeRO-3 doesn't shard them.
Impact: Affects ANY model with registered buffers under ZeRO-3, not just Gemma-4. Gemma-4's `Gemma4ClippableLinear` uses buffers for clipping values (`input_max`, `output_min`, `output_max`).

Tested: Reproduced the bug and verified the fix on 2xRTX40 GPUs.