
Fix ZeRO-3 from_pretrained: load registered buffers in _load_state_dict_into_zero3_model#45402

Merged
SunMarc merged 9 commits into huggingface:main from saslifat-gif:fix/gemma4-zero3-base-model-prefix
Apr 17, 2026

Conversation

@saslifat-gif (Contributor) commented Apr 13, 2026

Fixes #45397

What does this PR do?

Fixes #45397

Root cause: _load_state_dict_into_zero3_model in
src/transformers/integrations/deepspeed.py only iterates over
named_parameters — never named_buffers. Buffers registered via
register_buffer() are completely skipped during ZeRO-3 loading,
causing them to always appear as MISSING.

Fix: After gathering and loading parameters, explicitly load
buffers directly. Buffers don't need GatheredParameters since
ZeRO-3 doesn't shard them.

Impact: Affects ANY model with registered buffers under ZeRO-3,
not just Gemma-4. Gemma-4's Gemma4ClippableLinear uses buffers
for clipping values (input_max, output_min, output_max).

Tested: Reproduced bug and verified fix on 2xRTX40 GPUs.
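
To illustrate the pattern, here is a minimal sketch of a clippable linear layer (not the actual Gemma-4 implementation; only the buffer names input_max, output_min, and output_max come from the real module):

import torch
import torch.nn as nn

class ClippableLinear(nn.Module):
    """Sketch: a linear layer whose clipping thresholds live in registered buffers."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # register_buffer() puts these tensors in state_dict() but NOT in
        # named_parameters(), the only set the ZeRO-3 loader iterated over.
        self.register_buffer("input_max", torch.tensor(1.0))
        self.register_buffer("output_min", torch.tensor(-1.0))
        self.register_buffer("output_max", torch.tensor(1.0))

    def forward(self, x):
        x = x.clamp(max=self.input_max)
        return self.linear(x).clamp(min=self.output_min, max=self.output_max)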

@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: gemma4

class Gemma4VisionModel(Gemma4PreTrainedModel):
"""The Gemma 4 Vision Encoder."""

base_model_prefix = "vision_tower"
Member

Not sure if DeepSpeed complains about vision as well; can you confirm that it fails with vision weights?

The empty string was intended, since the model itself is a base model. We don't want "vision_tower", because Gemma4VisionModel.vision_tower doesn't exist.

config: Gemma4AudioConfig
main_input_name = "input_features"
- base_model_prefix = "model.audio_tower"  # prefix for Gemma4ForConditionalGeneration saved checkpoints, required for Gemma4AudioModel.from_pretrained()
+ base_model_prefix = "audio_tower"
Member

I remember @eustlb had a reason for this, also something to do with loading.

@saslifat-gif (Contributor, Author)

hi @zucchini-nlp! After deeper investigation, I found the actual root cause.

The bug is in _load_state_dict_into_zero3_model in
src/transformers/integrations/deepspeed.py.

The load() function only iterates over named_parameters:

named_parameters = dict(module.named_parameters(..., recurse=False))
for k in named_parameters:
    if k in state_dict:
        missing_keys.discard(k)  # buffers never reach here

Buffers (named_buffers) are completely skipped — never loaded,
never discarded from missing_keys → show as MISSING.
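
A quick way to see the gap, using a stock torch module that carries both parameters and buffers (nothing Gemma-specific):

import torch.nn as nn

bn = nn.BatchNorm1d(4)
print(sorted(dict(bn.named_parameters(recurse=False))))
# ['bias', 'weight']
print(sorted(dict(bn.named_buffers(recurse=False))))
# ['num_batches_tracked', 'running_mean', 'running_var']
print(sorted(bn.state_dict()))
# all five keys: a loader walking only named_parameters() leaves the
# three buffers unloaded and still sitting in missing_keys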

Reproduced with tiny config on 2xRTX40:

  • 240 ClippableLinear buffers in checkpoint ✓
  • 240 buffers in model state_dict ✓
  • 0/240 buffers loaded under ZeRO-3 ✗

Fix should be in _load_state_dict_into_zero3_model to handle
buffers separately (no GatheredParameters needed — buffers aren't
sharded):

named_buffers = dict(module.named_buffers(prefix=prefix[:-1], recurse=False))
for k, buf in named_buffers.items():
    if k in state_dict:
        missing_keys.discard(k)
        if torch.distributed.get_rank() == 0:
            with torch.no_grad():
                buf.copy_(state_dict[k])

This affects ANY model with registered buffers under ZeRO-3,
not just Gemma-4. The base_model_prefix change in this PR
is incorrect — I'll update the fix.

@saslifat-gif changed the title from "Fix Gemma4 ZeRO-3 weight loading by correcting base_model_prefix in AudioModel and VisionModel" to "Fix ZeRO-3 from_pretrained: load registered buffers in _load_state_dict_into_zero3_model" (Apr 13, 2026)
@saslifat-gif (Contributor, Author) commented Apr 13, 2026

hi @zucchini-nlp @Cyrilvallez

Found the real root cause and updated the PR title and description.
The previous base_model_prefix hypothesis was wrong — reverted that change.

Real fix is in _load_state_dict_into_zero3_model in deepspeed.py.
Reproduced and verified on 2xRTX40 GPUs.
Before/after screenshots attached.

@zucchini-nlp (Member)

Thanks a lot for digging! Cyril is off today and will probably review tomorrow. For the DeepSpeed side, also cc @SunMarc.

@saslifat-gif (Contributor, Author)

Reproduction script — creates a tiny Gemma4 model, saves without ZeRO-3, loads with ZeRO-3, checks for MISSING buffers:

import torch

from transformers.integrations import HfDeepSpeedConfig
from transformers.models.gemma4.modeling_gemma4 import Gemma4ForConditionalGeneration, Gemma4AudioModel
from transformers.models.gemma4.configuration_gemma4 import Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig, Gemma4AudioConfig

print("AudioModel prefix:", Gemma4AudioModel.base_model_prefix)

text_config = Gemma4TextConfig(
    hidden_size=128, num_hidden_layers=2, num_attention_heads=2,
    intermediate_size=256, vocab_size=32000, num_key_value_heads=2, pad_token_id=None,
)
vision_config = Gemma4VisionConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2, intermediate_size=128)
audio_config = Gemma4AudioConfig()
config = Gemma4Config(text_config=text_config, vision_config=vision_config, audio_config=audio_config)

# save without ZeRO3
model = Gemma4ForConditionalGeneration(config)
model.save_pretrained("/tmp/tiny_gemma4_full")
print("saved!")
del model

# load with ZeRO3
ds_config = {
    "bf16": {"enabled": True},
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": 2,
    "zero_optimization": {"stage": 3}
}
dschf = HfDeepSpeedConfig(ds_config)  # must stay alive so from_pretrained detects ZeRO-3 and uses zero.Init
model2 = Gemma4ForConditionalGeneration.from_pretrained(
    "/tmp/tiny_gemma4_full",
    torch_dtype=torch.bfloat16
)
print("loaded!")

@SunMarc (Member) left a comment


Thanks! Can you add a test in the class TestNonTrainerIntegrationDeepSpeed? You will find it in tests/trainer/distributed/test_trainer_distributed_deepspeed.py. I think we didn't catch this because we don't really register persistent buffers in transformers. Thanks for fixing this. cc @Cyrilvallez for confirmation

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@saslifat-gif force-pushed the fix/gemma4-zero3-base-model-prefix branch from 4674baf to e7ff996 (April 15, 2026 05:54)
@saslifat-gif (Contributor, Author)

Added the test in TestNonTrainerIntegrationDeepSpeed as requested. It creates a tiny Gemma4 model, saves without ZeRO-3, loads with ZeRO-3, and asserts that no registered buffers are MISSING. cc @Cyrilvallez

@saslifat-gif force-pushed the fix/gemma4-zero3-base-model-prefix branch from e7ff996 to c414759 (April 16, 2026 06:20)
@saslifat-gif (Contributor, Author)

test_generate_with_static_cache passes locally but fails on CI — this appears to be a flaky test caused by floating-point precision differences across hardware environments, unrelated to this PR's changes (ZeRO-3 buffer loading).

tests/models/gemma4/test_modeling_gemma4.py::Gemma4TextModelTest::test_generate_with_static_cache PASSED [100%] (0.17s)


=================================================== warnings summary ===================================================
../../../opt/homebrew/Caskroom/miniconda/base/lib/python3.13/site-packages/torch/jit/_script.py:1488: 25 warnings
  /opt/homebrew/Caskroom/miniconda/base/lib/python3.13/site-packages/torch/jit/_script.py:1488: DeprecationWarning: `torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================================== 1 passed, 25 warnings in 41.93s ============================================

@SunMarc (Member) commented Apr 16, 2026

I've tested the PR and there were a couple of issues. Can you make the following changes, @saslifat-gif?
We need to set _is_hf_initialized = True, and we also need to copy the buffer on all ranks:

            # Buffers are not partitioned by ZeRO-3, load them directly on all ranks
            named_buffers = dict(module.named_buffers(prefix=prefix[:-1], recurse=False))
            for k, buf in named_buffers.items():
                if k in state_dict and buf is not None:
                    missing_keys.discard(k)
                    with torch.no_grad():
                        buf.copy_(state_dict[k])
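                    # mark as loaded so from_pretrained's weight init does not overwrite it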
                    buf._is_hf_initialized = True

For the test, we also need to check that the values are correct:

    def test_zero3_load_registered_buffers(self):
        """Test that registered buffers are loaded with correct values under ZeRO-3 from_pretrained."""
        from transformers.models.gemma4.configuration_gemma4 import (
            Gemma4AudioConfig,
            Gemma4Config,
            Gemma4TextConfig,
            Gemma4VisionConfig,
        )
        from transformers.models.gemma4.modeling_gemma4 import Gemma4ForConditionalGeneration

        text_config = Gemma4TextConfig(
            hidden_size=128,
            num_hidden_layers=2,
            num_attention_heads=2,
            intermediate_size=256,
            vocab_size=32000,
            num_key_value_heads=2,
            pad_token_id=0,
        )
        vision_config = Gemma4VisionConfig(
            hidden_size=64, num_hidden_layers=2, num_attention_heads=2, intermediate_size=128
        )
        audio_config = Gemma4AudioConfig()
        config = Gemma4Config(text_config=text_config, vision_config=vision_config, audio_config=audio_config)

        # Save without ZeRO-3, with non-default buffer values
        save_path = self.get_auto_remove_tmp_dir()
        model = Gemma4ForConditionalGeneration(config)
        for name, buf in model.named_buffers():
            if "input_max" in name:
                buf.fill_(42.0)
            elif "output_min" in name:
                buf.fill_(-42.0)
            elif "layer_scalar" in name:
                buf.fill_(0.5)
        model.save_pretrained(save_path)
        del model

        # Load with ZeRO-3
        ds_config = self._get_zero3_ds_config(bf16={"enabled": True})
        dschf = HfDeepSpeedConfig(ds_config)
        self.assertTrue(dschf.is_zero3())
        with mockenv_context(**self.dist_env_1_gpu):
            model2 = Gemma4ForConditionalGeneration.from_pretrained(save_path, torch_dtype=torch.bfloat16)

        # Verify buffer VALUES were loaded from checkpoint, not re-initialized
        for name, buf in model2.named_buffers():
            if "input_max" in name:
                self.assertEqual(buf.item(), 42.0, f"{name} was not loaded from checkpoint")
            elif "output_min" in name:
                self.assertEqual(buf.item(), -42.0, f"{name} was not loaded from checkpoint")
            elif "layer_scalar" in name:
                self.assertEqual(buf.item(), 0.5, f"{name} was not loaded from checkpoint")

saslifat-gif added 8 commits April 16, 2026 23:27
Buffers registered via register_buffer() were completely skipped
during from_pretrained() under DeepSpeed ZeRO-3. The load() function
in _load_state_dict_into_zero3_model only iterated over named_parameters,
never named_buffers, so buffer values from checkpoint were never loaded
and always reported as MISSING.

Fix: after gathering and loading parameters, explicitly load buffers
directly (no GatheredParameters needed since buffers are not sharded
by ZeRO-3).

Fixes huggingface#45397
- Remove rank==0 guard so buffers are copied on all ranks
- Set buf._is_hf_initialized = True after copy to prevent re-initialization
- Update test to verify buffer VALUES survive ZeRO-3 from_pretrained round-trip
@saslifat-gif force-pushed the fix/gemma4-zero3-base-model-prefix branch from 16c4474 to a9fcac8 (April 16, 2026 15:41)
@saslifat-gif (Contributor, Author)

Thank you for testing and the clear feedback! I've applied the changes: removed the rank==0 guard so buffers are copied on all ranks, set buf._is_hf_initialized = True after the copy, and updated the test to verify the actual values survive the round-trip.

@SunMarc enabled auto-merge (April 17, 2026 13:35)
@SunMarc added this pull request to the merge queue (Apr 17, 2026)
Merged via the queue into huggingface:main with commit f67ebcd Apr 17, 2026
28 checks passed