diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 39119f2381..cc12939d46 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -222,10 +222,9 @@ def __init__( f"World size({world_size}) must equal to dp_size({dp_size}) * tp_size({tp_size}) * cp_size({cp_size}) to use DTensor" ) - if sequence_parallel_enabled: - assert tp_size > 1, ( - "Sequence parallel needs to be used together with tensor parallel. " - "Please either set tp_size > 1 or disable sequence parallel." + if sequence_parallel_enabled and tp_size == 1: + print( + "[WARNING]: sequence_parallel=True, but tp_size=1 which has no effect. Enable tp_size > 1 to use sequence parallelism." ) if cp_size > 1: diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py index e208873353..fcd0977117 100644 --- a/tests/unit/models/policy/test_dtensor_worker.py +++ b/tests/unit/models/policy/test_dtensor_worker.py @@ -42,7 +42,7 @@ def create_test_config( sequence_parallel: bool = False, cpu_offload: bool = False, activation_checkpointing: bool = False, - custom_parallel_plan: str = None, + custom_parallel_plan: str | None = None, ) -> PolicyConfig: return { "model_name": model_name, @@ -237,7 +237,7 @@ def test_lm_policy_init(policy_setup): @pytest.fixture def training_setup(request, two_gpu_virtual_cluster): """Setup and teardown specifically for training tests.""" - model_name, tp, cp, cpu_offload, sequence_parallel, activation_checkpointing = ( + model_name, tp, cp, sequence_parallel, cpu_offload, activation_checkpointing = ( request.param ) policy = None @@ -246,7 +246,7 @@ def training_setup(request, two_gpu_virtual_cluster): try: config = create_test_config( - model_name, tp, cp, cpu_offload, sequence_parallel, activation_checkpointing + model_name, tp, cp, sequence_parallel, cpu_offload, activation_checkpointing ) tokenizer = 
get_tokenizer(config["tokenizer"]) print( @@ -300,8 +300,7 @@ def training_setup(request, two_gpu_virtual_cluster): @pytest.mark.parametrize( "training_setup", [ - # model_name, tp, cp, cpu_offload, sequence_parallel, activation_checkpointing - # Split grid over tp/cp/cpu/sp/act across qwen and llama + # model_name tp cp sp cpu act (TEST_ASSETS.TINY_LLAMA_MODEL_PATH, 1, 1, False, False, False), (TEST_ASSETS.TINY_LLAMA_MODEL_PATH, 1, 1, True, False, False), (TEST_ASSETS.TINY_LLAMA_MODEL_PATH, 1, 1, False, True, False), @@ -317,7 +316,14 @@ def training_setup(request, two_gpu_virtual_cluster): (TEST_ASSETS.TINY_QWEN3_MODEL_PATH, 1, 1, False, True, True), (TEST_ASSETS.TINY_QWEN3_MODEL_PATH, 1, 1, True, True, True), (TEST_ASSETS.TINY_QWEN3_MODEL_PATH, 1, 2, False, False, False), - (TEST_ASSETS.TINY_GEMMA3_MODEL_PATH, 1, 1, True, True, False), + ( + TEST_ASSETS.TINY_GEMMA3_MODEL_PATH, + 1, + 1, + True, + True, + False, + ), # gemma3 doesn't support sdpa (TEST_ASSETS.TINY_GEMMA3_MODEL_PATH, 1, 1, True, False, True), (TEST_ASSETS.TINY_GEMMA3_MODEL_PATH, 1, 1, False, True, True), (TEST_ASSETS.TINY_GEMMA3_MODEL_PATH, 1, 1, True, True, True), @@ -363,7 +369,7 @@ def verify_loss_tensor(loss_tensor): @pytest.fixture def logprob_setup(request, two_gpu_virtual_cluster): """Setup and teardown specifically for training tests.""" - model_name, tp, cp, cpu_offload, sequence_parallel, activation_checkpointing = ( + model_name, tp, cp, sequence_parallel, cpu_offload, activation_checkpointing = ( request.param ) policy = None @@ -371,7 +377,7 @@ def logprob_setup(request, two_gpu_virtual_cluster): try: config = create_test_config( - model_name, tp, cp, cpu_offload, sequence_parallel, activation_checkpointing + model_name, tp, cp, sequence_parallel, cpu_offload, activation_checkpointing ) tokenizer = get_tokenizer(config["tokenizer"]) print( @@ -494,8 +500,9 @@ def test_dtensor_tp_and_tied_model_with_custom_parallel_plan(two_gpu_virtual_clu config = create_test_config(
model_name=TEST_ASSETS.TINY_LLAMA_TIED_MODEL_PATH, tp=2, - cpu_offload=False, + cp=1, sequence_parallel=False, + cpu_offload=False, activation_checkpointing=False, custom_parallel_plan=custom_parallel_plan, )