diff --git a/torchtitan/experiments/graph_trainer/deepseek_v3/parallelize.py b/torchtitan/experiments/graph_trainer/deepseek_v3/parallelize.py index 09e533ba4f..9e243fc188 100644 --- a/torchtitan/experiments/graph_trainer/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/graph_trainer/deepseek_v3/parallelize.py @@ -16,7 +16,7 @@ TrainingConfig, ) from torchtitan.distributed import ParallelDims - +from torchtitan.distributed.context_parallel import apply_cp_to_attention_module from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.experiments.graph_trainer.common_utils import ( annotate_ac_regions, @@ -138,6 +138,12 @@ def parallelize_deepseekv3( pad_multiple=pad_multiple, ) + if parallel_dims.cp_enabled: + apply_cp_to_attention_module( + [block.attention.inner_attention for block in model.layers.values()], + parallel_dims.get_mesh("cp"), + ) + if ac_config.mode != "none": apply_graph_ac(compile_config, ac_config) diff --git a/torchtitan/experiments/graph_trainer/llama3/parallelize.py b/torchtitan/experiments/graph_trainer/llama3/parallelize.py index ece0627b5a..d707edad12 100644 --- a/torchtitan/experiments/graph_trainer/llama3/parallelize.py +++ b/torchtitan/experiments/graph_trainer/llama3/parallelize.py @@ -15,6 +15,7 @@ TrainingConfig, ) from torchtitan.distributed import ParallelDims +from torchtitan.distributed.context_parallel import apply_cp_to_attention_module from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.experiments.graph_trainer.common_utils import ( annotate_ac_regions, @@ -101,9 +102,17 @@ def parallelize_llama( tp_mesh, enable_loss_parallel=not parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, + enable_cp=parallel_dims.cp_enabled, + enable_sp=parallelism.enable_sequence_parallel, ) maybe_enable_async_tp(parallelism, compile_config, tp_mesh) + if parallel_dims.cp_enabled: + apply_cp_to_attention_module( + [block.attention.inner_attention for block in model.layers.values()], + parallel_dims.get_mesh("cp"), + ) + if ac_config.mode != "none": apply_graph_ac(compile_config, ac_config)