Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions deepspeed/runtime/sequence_parallel/ulysses_sp.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,8 @@ def register_with_transformers(
mpu.initialize_sequence_parallel(sequence_parallel_size=sequence_parallel_size)

from transformers import PreTrainedModel
if isinstance(model_name_or_path, PreTrainedModel):
# we already have the model
if hasattr(model_name_or_path, "config") or isinstance(model_name_or_path, PreTrainedModel):
# we already have the model (or a PEFT wrapper with config attribute)
hf_model_config = model_name_or_path.config
else:
# if we don't have the model yet at this stage
Expand Down
47 changes: 47 additions & 0 deletions tests/unit/ulysses_alst/test_ulysses_sp_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,50 @@ def collate_fn(batch):
torch_assert_close(grad_a, grad_b, rtol=1.6e-02, atol=1e-03)
else:
torch_assert_close(grad_a, grad_b)


class TestUlyssesSPHFPEFT(DistributedTest):
    """Registration test for objects that carry a ``config`` but are not ``PreTrainedModel`` instances."""
    world_size = 2

    def test_ulysses_sp_hf_with_peft_model(self):
        """Register Ulysses SP using a PEFT-like stand-in instead of a model name or HF model.

        The stand-in is a plain class (no ``transformers.PreTrainedModel`` base) whose only
        attribute is ``config``. The test passes when ``register_with_transformers`` returns
        a non-None mpu and the sequence-parallel group reports a world size equal to
        ``self.world_size``.
        """
        from transformers import AutoConfig

        class MockPEFTModel:
            """Bare object exposing nothing but a ``config`` attribute."""

            def __init__(self, config):
                self.config = config

        sp_size = self.world_size

        # Only the config is loaded for the tiny test checkpoint; no model is instantiated here.
        tiny_config = AutoConfig.from_pretrained('hf-internal-testing/tiny-random-LlamaForCausalLM')
        peft_like_model = MockPEFTModel(tiny_config)

        # The stand-in object is handed over in the slot that usually receives a name/path or a model.
        registration_kwargs = dict(
            model_name_or_path=peft_like_model,
            core_attn_implementation="sdpa",
            sequence_parallel_size=sp_size,
            micro_batch_size=1,
            seq_length=64,
            seq_length_is_variable=True,
        )
        mpu = UlyssesSPAttentionHF.register_with_transformers(**registration_kwargs)

        # Registration must hand back an mpu object.
        assert mpu is not None

        # The sequence-parallel group must exist and span the requested number of ranks.
        assert groups._get_sequence_parallel_group() is not None
        assert groups._get_sequence_parallel_world_size() == sp_size