src/transformers/integrations/deepspeed.py (10 additions, 1 deletion)
@@ -491,7 +491,7 @@ def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
    for k in named_parameters:
        if k in state_dict:
            param = named_parameters[k]
-            # crutial to not init the weight again
+            # crucial to not init the weight again
            param._is_hf_initialized = True
            params_to_gather.append(param)
            missing_keys.discard(k)
@@ -504,6 +504,15 @@ def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
            if torch.distributed.get_rank() == 0:
                module._load_from_state_dict(*args)

+    # Buffers are not partitioned by ZeRO-3, load them directly
+    named_buffers = dict(module.named_buffers(prefix=prefix[:-1], recurse=False))
+    for k, buf in named_buffers.items():
+        if k in state_dict and buf is not None:
+            missing_keys.discard(k)
+            with torch.no_grad():
+                buf.copy_(state_dict[k])
+            buf._is_hf_initialized = True
+
    for name, child in module._modules.items():
        if child is not None:
            load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
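For readers less familiar with ZeRO-3: parameters are sharded across ranks and must be gathered (via deepspeed.zero.GatheredParameters) before they can be written, whereas registered buffers keep their full shape on every rank, so the new block can copy them straight from the state dict. Below is a minimal, standalone sketch of that idea; ToyModule and load_buffers_directly are hypothetical names used only for illustration and are not part of this PR.

import torch
import torch.nn as nn

class ToyModule(nn.Module):
    """Illustrative module with one parameter and one registered buffer."""
    def __init__(self):
        super().__init__()
        # Under ZeRO-3 this parameter would be partitioned and need gathering before writes.
        self.weight = nn.Parameter(torch.zeros(4, 4))
        # Buffers are never partitioned by ZeRO-3; every rank holds the full tensor.
        self.register_buffer("input_max", torch.zeros(1))

def load_buffers_directly(module: nn.Module, state_dict: dict, prefix: str = ""):
    # Mirrors the idea of the added block above: buffers can be copied in place
    # without deepspeed.zero.GatheredParameters.
    for name, buf in module.named_buffers(recurse=False):
        key = prefix + name
        if key in state_dict and buf is not None:
            with torch.no_grad():
                buf.copy_(state_dict[key])
            buf._is_hf_initialized = True  # mark as loaded so it is not re-initialized later

module = ToyModule()
load_buffers_directly(module, {"input_max": torch.tensor([42.0])})
print(module.input_max)  # tensor([42.])

Parameters, by contrast, are replaced by partitioned placeholders under ZeRO-3, which is why the existing code gathers them with deepspeed.zero.GatheredParameters before calling _load_from_state_dict.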
tests/trainer/distributed/test_trainer_distributed_deepspeed.py (54 additions, 0 deletions)
@@ -1526,6 +1526,60 @@ def test_resize_token_embeddings_zero3(self):
        with deepspeed.zero.GatheredParameters([embedding.weight]):
            self.assertEqual(embedding.weight.shape[0], new_size)

+    def test_zero3_load_registered_buffers(self):
+        """Test that registered buffers are loaded with correct values under ZeRO-3 from_pretrained."""
+        from transformers.models.gemma4.configuration_gemma4 import (
+            Gemma4AudioConfig,
+            Gemma4Config,
+            Gemma4TextConfig,
+            Gemma4VisionConfig,
+        )
+        from transformers.models.gemma4.modeling_gemma4 import Gemma4ForConditionalGeneration
+
+        text_config = Gemma4TextConfig(
+            hidden_size=128,
+            num_hidden_layers=2,
+            num_attention_heads=2,
+            intermediate_size=256,
+            vocab_size=32000,
+            num_key_value_heads=2,
+            pad_token_id=0,
+        )
+        vision_config = Gemma4VisionConfig(
+            hidden_size=64, num_hidden_layers=2, num_attention_heads=2, intermediate_size=128
+        )
+        audio_config = Gemma4AudioConfig()
+        config = Gemma4Config(text_config=text_config, vision_config=vision_config, audio_config=audio_config)
+
+        # Save without ZeRO-3, with non-default buffer values
+        save_path = self.get_auto_remove_tmp_dir()
+        model = Gemma4ForConditionalGeneration(config)
+        for name, buf in model.named_buffers():
+            if "input_max" in name:
+                buf.fill_(42.0)
+            elif "output_min" in name:
+                buf.fill_(-42.0)
+            elif "layer_scalar" in name:
+                buf.fill_(0.5)
+        model.save_pretrained(save_path)
+        del model
+
+        # Load with ZeRO-3
+        ds_config = self._get_zero3_ds_config(bf16={"enabled": True})
+        dschf = HfDeepSpeedConfig(ds_config)
+        self.assertTrue(dschf.is_zero3())
+        with mockenv_context(**self.dist_env_1_gpu):
+            model2 = Gemma4ForConditionalGeneration.from_pretrained(save_path, torch_dtype=torch.bfloat16)
+
+        # Verify buffer VALUES were loaded from checkpoint, not re-initialized
+        for name, buf in model2.named_buffers():
+            if "input_max" in name:
+                self.assertEqual(buf.item(), 42.0, f"{name} was not loaded from checkpoint")
+            elif "output_min" in name:
+                self.assertEqual(buf.item(), -42.0, f"{name} was not loaded from checkpoint")
+            elif "layer_scalar" in name:
+                self.assertEqual(buf.item(), 0.5, f"{name} was not loaded from checkpoint")
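A note on the test design: the sentinel values (42.0, -42.0, 0.5) are presumably chosen because they differ from the buffers' default values, and the buffer names (input_max, output_min, layer_scalar) are matched by substring so every layer that registers them is covered. If from_pretrained under ZeRO-3 re-initialized the buffers instead of copying them from the checkpoint, the assertions above would fail, which is exactly the regression the deepspeed.py change prevents.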


# ---------------------------------------------------------------------------
# Model Zoo — test many architectures with DeepSpeed + zero_to_fp32 recovery