Accept any object exposing a `config` attribute in
UlyssesSPAttentionHF.register_with_transformers, so wrappers that do not
inherit from transformers.PreTrainedModel (e.g. PEFT models) can be passed,
and add a unit test that passes such a mock object.

NOTE(review): this patch was recovered from a whitespace-collapsed copy; the
indentation was reconstructed from the hunk headers and Python syntax. If the
context lines do not match the target tree, apply with
`git apply --ignore-whitespace` (TODO: confirm against the real files).

diff --git a/deepspeed/runtime/sequence_parallel/ulysses_sp.py b/deepspeed/runtime/sequence_parallel/ulysses_sp.py
index 70073bdd696c..eb74cbbf639a 100644
--- a/deepspeed/runtime/sequence_parallel/ulysses_sp.py
+++ b/deepspeed/runtime/sequence_parallel/ulysses_sp.py
@@ -389,8 +389,8 @@ def register_with_transformers(
         mpu.initialize_sequence_parallel(sequence_parallel_size=sequence_parallel_size)
 
         from transformers import PreTrainedModel
-        if isinstance(model_name_or_path, PreTrainedModel):
-            # we already have the model
+        if hasattr(model_name_or_path, "config") or isinstance(model_name_or_path, PreTrainedModel):
+            # we already have the model (or a PEFT wrapper with config attribute)
             hf_model_config = model_name_or_path.config
         else:
             # if we don't have the model yet at this stage
diff --git a/tests/unit/ulysses_alst/test_ulysses_sp_hf.py b/tests/unit/ulysses_alst/test_ulysses_sp_hf.py
index 27db398d7189..9745146089b1 100644
--- a/tests/unit/ulysses_alst/test_ulysses_sp_hf.py
+++ b/tests/unit/ulysses_alst/test_ulysses_sp_hf.py
@@ -185,3 +185,50 @@ def collate_fn(batch):
             torch_assert_close(grad_a, grad_b, rtol=1.6e-02, atol=1e-03)
         else:
             torch_assert_close(grad_a, grad_b)
+
+
+class TestUlyssesSPHFPEFT(DistributedTest):
+    world_size = 2
+
+    def test_ulysses_sp_hf_with_peft_model(self):
+        """Test that UlyssesSPAttentionHF.register_with_transformers works with PEFT models.
+
+        PEFT models don't inherit from transformers.PreTrainedModel but have a config attribute.
+        This test verifies the duck-typing check for the config attribute works correctly.
+        """
+        model_name_or_path = 'hf-internal-testing/tiny-random-LlamaForCausalLM'
+        seq_length = 64
+        sequence_parallel_size = self.world_size
+        micro_batch_size = 1
+
+        # Create a mock PEFT model object that has config but doesn't inherit from PreTrainedModel
+        from transformers import AutoConfig
+        hf_config = AutoConfig.from_pretrained(model_name_or_path)
+
+        class MockPEFTModel:
+            """Mock PEFT model that simulates PeftModel behavior"""
+
+            def __init__(self, config):
+                self.config = config
+
+        mock_peft_model = MockPEFTModel(hf_config)
+
+        # Test that register_with_transformers works with PEFT-like model object
+        # This should not crash and should use the config attribute via duck-typing
+        mpu = UlyssesSPAttentionHF.register_with_transformers(
+            model_name_or_path=mock_peft_model,
+            core_attn_implementation="sdpa",
+            sequence_parallel_size=sequence_parallel_size,
+            micro_batch_size=micro_batch_size,
+            seq_length=seq_length,
+            seq_length_is_variable=True,
+        )
+
+        # Verify mpu is created successfully
+        assert mpu is not None
+
+        # Verify that the sequence parallel groups are initialized
+        sp_group = groups._get_sequence_parallel_group()
+        assert sp_group is not None
+        sp_world_size = groups._get_sequence_parallel_world_size()
+        assert sp_world_size == sequence_parallel_size