diff --git a/tests/fsdp/test_context_parallel.py b/tests/fsdp/test_context_parallel.py
new file mode 100644
index 000000000000..8e0b58a32187
--- /dev/null
+++ b/tests/fsdp/test_context_parallel.py
@@ -0,0 +1,224 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import sys
+from pathlib import Path
+
+from transformers import is_torch_available
+from transformers.testing_utils import (
+    TestCasePlus,
+    execute_subprocess_async,
+    require_accelerate,
+    require_torch_multi_accelerator,
+    run_first,
+    slow,
+)
+
+
+if is_torch_available():
+    import torch
+
+    from transformers import (
+        AutoModelForCausalLM,
+        AutoTokenizer,
+        DataCollatorForLanguageModeling,
+        HfArgumentParser,
+        Trainer,
+        TrainingArguments,
+    )
+
+
+class TestContextParallel(TestCasePlus):
+    """Test Trainer with Torch context parallelism enabled via accelerate's ParallelismConfig."""
+
+    @require_torch_multi_accelerator
+    @require_accelerate
+    @slow
+    @run_first
+    def test_cp_equivalence(self):
+        """Test that CP produces the same losses as without CP."""
+
+        # Shared setup
+        world_size = 2
+        script_path = __file__
+
+        # Step 1: Run with CP enabled (cp_size=world_size)
+        cp_yes_output_dir = Path(self.get_auto_remove_tmp_dir()).resolve()
+        cp_yes_config_path = cp_yes_output_dir / "context_parallel_config.yaml"
+        cp_yes_losses_path = cp_yes_output_dir / "cp_yes_losses.json"
+
+        # Write config file inline (self-contained test)
+        with open(cp_yes_config_path, "w") as f:
+            f.write(
+                f"""distributed_type: FSDP
+fsdp_config:
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_version: 2
+mixed_precision: bf16
+num_processes: {world_size}
+parallelism_config:
+  parallelism_config_dp_replicate_size: 1
+  parallelism_config_dp_shard_size: 1
+  parallelism_config_tp_size: 1
+  parallelism_config_cp_size: {world_size}
+  parallelism_config_cp_comm_strategy: alltoall
+"""
+            )
+
+        cmd_cp_yes = f"""
+            accelerate launch
+            --config_file {cp_yes_config_path}
+            {script_path}
+            --output_dir {cp_yes_output_dir}
+            --report_to none
+            --max_steps 10
+            --per_device_train_batch_size 1
+            --gradient_accumulation_steps 1
+            --logging_steps 1
+            --remove_unused_columns False
+            --seed 42
+            --loss_output_file {cp_yes_losses_path}
+        """.split()
+
+        execute_subprocess_async(cmd_cp_yes, env=self.get_env())
+
+        # Step 2: Run without CP (FSDP with num_processes=1, no parallelism_config)
+        cp_no_output_dir = Path(self.get_auto_remove_tmp_dir()).resolve()
+        cp_no_config_path = cp_no_output_dir / "context_parallel_config.yaml"
+        cp_no_losses_path = cp_no_output_dir / "cp_no_losses.json"
+
+        # Write config file inline (self-contained test)
+        with open(cp_no_config_path, "w") as f:
+            f.write(
+                """distributed_type: FSDP
+fsdp_config:
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
+  fsdp_version: 2
+mixed_precision: bf16
+num_processes: 1
+"""
+            )
+
+        cmd_cp_no = f"""
+            accelerate launch
+            --config_file {cp_no_config_path}
+            {script_path}
+            --output_dir {cp_no_output_dir}
+            --report_to none
+            --max_steps 10
+            --per_device_train_batch_size 1
+            --gradient_accumulation_steps 1
+            --logging_steps 1
+            --remove_unused_columns False
+            --seed 42
+            --loss_output_file {cp_no_losses_path}
+        """.split()
+
+        execute_subprocess_async(cmd_cp_no, env=self.get_env())
+
+        # Compare losses - should be very close since CP just splits sequence computation
+        with open(cp_yes_losses_path) as f:
+            cp_yes_losses = json.load(f)
+        with open(cp_no_losses_path) as f:
+            cp_no_losses = json.load(f)
+
+        assert len(cp_yes_losses) == len(cp_no_losses), (
+            f"Different number of losses: CP has {len(cp_yes_losses)}, no-CP has {len(cp_no_losses)}"
+        )
+
+        # CP should produce very similar results (small numerical differences expected)
+        # The differences come from:
+        # - Different gradient reduction patterns in distributed training
+        # - BF16 mixed precision accumulated differences
+        # - Sequence splitting and gathering in CP mode
+        cp_yes_losses_tensor = torch.tensor(cp_yes_losses)
+        cp_no_losses_tensor = torch.tensor(cp_no_losses)
+
+        # Use torch.testing.assert_close with rtol=2% and atol=0.02
+        # Testing shows actual differences are typically <1.5%
+        torch.testing.assert_close(
+            cp_yes_losses_tensor,
+            cp_no_losses_tensor,
+            rtol=2e-2,  # 2% relative tolerance
+            atol=2e-2,  # 0.02 absolute tolerance
+            msg=f"CP losses {cp_yes_losses} do not match non-CP losses {cp_no_losses}",
+        )
+
+
+if __name__ == "__main__":
+    # Parse custom arguments (not TrainingArguments parameters)
+    loss_output_file = None
+
+    if "--loss_output_file" in sys.argv:
+        idx = sys.argv.index("--loss_output_file")
+        loss_output_file = sys.argv[idx + 1]
+        sys.argv.pop(idx)
+        sys.argv.pop(idx)
+
+    parser = HfArgumentParser((TrainingArguments,))
+    training_args = parser.parse_args_into_dataclasses()[0]
+
+    # Use SmolLM (small Llama-based model that works with CP)
+    model_name = "HuggingFaceTB/SmolLM-135M"
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        attn_implementation="sdpa",  # CP requires SDPA
+    )
+
+    # Create simple dataset: just tokenize some text
+    texts = [
+        "The quick brown fox jumps over the lazy dog. " * 10,
+        "Hello world, this is a test sentence for training. " * 10,
+    ] * 4  # 8 samples total
+
+    def tokenize_function(examples):
+        return tokenizer(examples, max_length=128, truncation=True, padding="max_length")
+
+    train_dataset = [tokenize_function(text) for text in texts]
+
+    # Use standard DataCollatorForLanguageModeling for causal LM
+    # pad_to_multiple_of=4 ensures sequences are divisible by cp_size * 2 (for cp_size=2)
+    # Trainer will automatically generate position_ids and shift_labels as needed
+    data_collator = DataCollatorForLanguageModeling(
+        tokenizer=tokenizer,
+        mlm=False,  # Causal language modeling
+        pad_to_multiple_of=4,
+    )
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        data_collator=data_collator,
+    )
+
+    # Train for a few steps
+    trainer.train()
+
+    # Verify training completed
+    assert trainer.state.global_step > 0, "Training should have completed at least one step"
+
+    # Save losses to file if requested (for equivalence testing)
+    if loss_output_file and training_args.process_index == 0:
+        losses = [log["loss"] for log in trainer.state.log_history if "loss" in log]
+        with open(loss_output_file, "w") as f:
+            json.dump(losses, f)