Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
TrainingConfig,
)
from torchtitan.distributed import ParallelDims

from torchtitan.distributed.context_parallel import apply_cp_to_attention_module
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
from torchtitan.experiments.graph_trainer.common_utils import (
annotate_ac_regions,
Expand Down Expand Up @@ -138,6 +138,12 @@ def parallelize_deepseekv3(
pad_multiple=pad_multiple,
)

if parallel_dims.cp_enabled:
apply_cp_to_attention_module(
[block.attention.inner_attention for block in model.layers.values()],
parallel_dims.get_mesh("cp"),
)

if ac_config.mode != "none":
apply_graph_ac(compile_config, ac_config)

Expand Down
9 changes: 9 additions & 0 deletions torchtitan/experiments/graph_trainer/llama3/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
TrainingConfig,
)
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.context_parallel import apply_cp_to_attention_module
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
from torchtitan.experiments.graph_trainer.common_utils import (
annotate_ac_regions,
Expand Down Expand Up @@ -101,9 +102,17 @@ def parallelize_llama(
tp_mesh,
enable_loss_parallel=not parallelism.disable_loss_parallel,
enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
enable_cp=parallel_dims.cp_enabled,
enable_sp=parallelism.enable_sequence_parallel,
)
maybe_enable_async_tp(parallelism, compile_config, tp_mesh)

if parallel_dims.cp_enabled:
apply_cp_to_attention_module(
[block.attention.inner_attention for block in model.layers.values()],
parallel_dims.get_mesh("cp"),
)

if ac_config.mode != "none":
apply_graph_ac(compile_config, ac_config)

Expand Down
Loading