
Update propagateSharding preseg pass for DID loop split#3838

Merged
Priya2698 merged 77 commits into main from pm/preseg_sharding_prop on Apr 24, 2025

Conversation

Priya2698 (Collaborator) commented on Feb 6, 2025:

This PR extends the propagateSharding presegmentation pass for DID loop splits.
Key changes:

  1. We use TransformPropagator for all expressions except ViewOp, which is handled manually because TransformPropagator does not support it without first propagating the reshape to the producer.
  2. makeReshardingContiguous sets the allocation domain for tvs with a device mesh. Ideally, we would set it only for global tensors, but which tensors are global is not known until segmentation, while the allocation domain must be set before segmentation.
  3. Some existing tests are modified; see the discussion. PR #4274 (Move MarkAliasAnalysisPreparePass before propagateShardingsPass) resolved this.

Follow-up PRs:

  • ViewOp will be handled in a followup PR.
  • Currently, we only backpropagate sharding to a tv that does not already have a device dimension. This can be extended to propagate all parallel types not present on the tv; that will be done in a follow-up. Naive backpropagation can incorrectly change DIDx to serial or move DIDx to another location, so shardAllLike can be modified to specify which parallel types to propagate. Since insertResharding and propagateSharding require different behavior, I will handle this in a separate PR.
  • Use TransformReplay::CasP in lieu of TransformPropagator.
  • Propagate DID transforms within castOp: privatizeUpcast clones cast operations, which fails segmentation since the transforms are not replicated.

Findings from experiments: #3838 (comment)

github-actions bot commented on Feb 6, 2025:

Review updated until commit b09d926

Description

  • Updated propagateShardings pass to handle ViewOp manually.

  • Added support for multiple merges or splits in a reshape.

  • Enhanced makeReshardingContiguous to set allocation domain for TVs with device mesh.

  • Added new tests for loop split MLP and MHA.


Changes walkthrough 📝

Relevant files

Enhancement

make_resharding_contiguous.cpp (csrc/preseg_passes/make_resharding_contiguous.cpp): Enhance allocation domain setting

  • Added validation of meshes for all TensorViews.
  • Implemented setLoopAndAllocationDomain to set the allocation domain based on the loop domain.
  • Updated MakeReshardingContiguousPass to use setLoopAndAllocationDomain.
  • +122/-17

propagate_shardings.cpp (csrc/preseg_passes/propagate_shardings.cpp): Enhance propagateShardings pass

  • Added a custom selector for directed propagation.
  • Implemented selectiveReorderDIDToFront to reorder DID axes.
  • Updated propagateShardingsPass to handle ViewOp manually and propagate shardings from reference inputs.
  • +255/-76

Cleanup

test_multidevice_sharding.cpp (tests/cpp/test_multidevice_sharding.cpp): Remove old propagateShardings and test

  • Removed the old propagateShardings function.
  • Removed the TransformerFwd test.
  • +0/-155

Tests

test_multidevice_transformer.cpp (tests/cpp/test_multidevice_transformer.cpp): Add new loop split tests

  • Added reference_loop_split_mlp and reference_loop_split_mha functions.
  • Added LoopSplitMLP and LoopSplitMHAFwd tests.
  • +138/-0

test_sharding.cpp (tests/cpp/test_sharding.cpp): Add ResidualAdd test

  • Added a ResidualAdd test to verify backpropagation of shardings.
  • +35/-0

Documentation

make_resharding_contiguous.h (csrc/preseg_passes/make_resharding_contiguous.h): Update comments

  • Updated comments to describe the new functionality of MakeReshardingContiguousPass.
  • +11/-4

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Possible Issue

The function setLoopAndAllocationDomain assumes that the allocation domain is a permutation of the logical domain, which may not always be true. This could lead to incorrect reordering and setting of the allocation domain.

    std::vector<Expr*> transform_exprs = DependencyCheck::getAllExprsBetween(
        {alloc_dom.begin(), alloc_dom.end()},
        {tv->getLoopDomain().begin(), tv->getLoopDomain().end()});
    
    NVF_ERROR(
        std::all_of(
            transform_exprs.begin(),
            transform_exprs.end(),
            [](Expr* expr) { return expr->isA<Split>(); }),
        "Expected all transform exprs to be a split between logical and loop domain during sharding propagation.");
    
    for (auto* expr : transform_exprs) {
      Split* split = dynamic_cast<Split*>(expr);
      auto find_it = std::find(alloc_dom.begin(), alloc_dom.end(), split->in());
      NVF_ERROR(
          find_it != alloc_dom.end(),
          "Split input ",
          split->in()->toString(),
          " not found in given ids: ",
          alloc_dom);
    
      auto pos = std::distance(alloc_dom.begin(), find_it);
      auto [outer_contiguity, inner_contiguity] =
          splitContiguity(contiguity.at(pos));
    
      alloc_dom[pos] = split->inner();
      alloc_dom.insert(alloc_dom.begin() + pos, split->outer());
    
      contiguity[pos] = inner_contiguity;
      contiguity.insert(contiguity.begin() + pos, outer_contiguity);
    }
    
    std::optional<std::vector<int64_t>> permutation =
        ir_utils::computePermutation(alloc_dom, tv->getLoopDomain());
    NVF_ERROR(
        permutation.has_value(),
        "Failed to find a valid permutation for reordering",
        tv->getLoopDomain(),
        " as ",
        alloc_dom);
    tv->reorder(permutation.value());
    tv->setAllocationDomain(tv->getLoopDomain(), contiguity);
    Performance Concern

    The pass uses TransformPropagator (via propagateDIDTransform) for all expressions except ViewOp, which is handled manually. This manual handling might introduce inconsistencies or performance issues that should be evaluated.

    for (auto* ref_input : reference_inputs) {
      // Skip if the input has no device mesh or is nullptr.
      NVF_ERROR(
          ref_input != nullptr && ref_input->hasDeviceMesh(),
          "Reference input ",
          ref_input,
          " has no device mesh.");
    
      // Reorder the DID axis to the front only if it does not have a parallel
      // type already seen on the outputs. This avoids propagating the same
      // parallel type on multiple axis of the output when using multiple
      // reference inputs. Consider out [M, N] = linear (inp [M, K], weight (N,
      // K)) with inp sharded on M ([DIDx(d), M/d, K]) and weight sharded on N
      // ([DIDy(d), N/d, K]). We propagate from weights first, so the output
      // will be [M, DIDx(d), N/d]. When we propagate from inp next, we should
      // not propagate DIDx parallel type to the output. Otherwise, the output
      // will have multiple DIDx shardings which is invalid.
      std::unordered_set<ParallelType> selected_parallel_types =
          getParallelTypesToPropagate(outputs_without_mesh);
    
      // This restricts the transform propagation to only the relevant DID axis.
      int64_t did_pos =
          selectiveReorderDIDToFront(ref_input, selected_parallel_types);
    
      // Propagate the DID loop split to the outputs without mesh.
      propagateDIDTransform(
          /*ref=*/ref_input,
          /*tvs=*/outputs_without_mesh,
          /*did_pos=*/did_pos,
          /*allow_c2p=*/false,
          /*allow_p2c=*/true);
    
      // Apply parallelization on the outputs without mesh.
      shardAllLike(ref_input, outputs_without_mesh, selected_parallel_types);
    }
    Testing Limitation

    The tests for LoopSplitMLP and LoopSplitMHAFwd are currently limited to DataType::Float and do not include float16 or bfloat16 due to issues with privatizeUpcast. This should be addressed to ensure comprehensive testing.

    // TODO: Allow testing for float16 and bfloat16 for loop split mlp and mha
    // This currently fails because privatizeUpcast clones cast operations,
    // which fails segmentation since the transforms are not replicated.
    TEST_F(DistributedTransformerTest, LoopSplitMLP) {
      if ((4 * E) % D != 0) {
        GTEST_SKIP() << "Requires number of devices=" << D
                     << " evenly divide 4*E=" << 4 * E;
      }
      auto dtype = DataType::Float;
      at::ScalarType at_dtype = data_type_to_aten(dtype);
    
      auto fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      const int d = communicator_->size();
      auto mesh = DeviceMesh::createForNumDevices(d);
    
      TensorView* inp = makeContigConcreteTensor({B, S, E}, dtype);
      TensorView* w0 = makeContigConcreteTensor({4 * E, E}, dtype);
      TensorView* w1 = makeContigConcreteTensor({E, 4 * E}, dtype);
    
      TensorView* linear0 = linear(inp, w0);
      TensorView* linear0_float = castOp(DataType::Float, linear0);
      TensorView* gelu = tanh_gelu(linear0_float);
      TensorView* gelu_dtype = castOp(dtype, gelu);
      TensorView* linear1 = linear(gelu_dtype, w1);
    
      std::vector<TensorView*> fusion_inputs{inp, w0, w1};
      for (auto tv : fusion_inputs) {
        fusion->addInput(tv);
        tv->setDeviceMesh(mesh);
      }
      fusion->addOutput(linear1);
    
      w0->outer_split(0, d);
      w0->axis(0)->parallelize(ParallelType::DIDx);
      w1->outer_split(1, d);
      w1->axis(1)->parallelize(ParallelType::DIDx);
    
      FusionExecutorCache executor_cache(std::move(fusion));
      at::Tensor inp_tensor = at::randn({B, S, E}, tensor_options.dtype(at_dtype));
      at::Tensor w0_tensor = at::randn({4 * E, E}, tensor_options.dtype(at_dtype));
      at::Tensor w1_tensor = at::randn({E, 4 * E}, tensor_options.dtype(at_dtype));
    
      at::Tensor w0_sharded = shardTensor(w0_tensor, 0, mesh);
      at::Tensor w1_sharded = shardTensor(w1_tensor, 1, mesh);
    
      KernelArgumentHolder args = {inp_tensor, w0_sharded, w1_sharded};
      auto outputs = executor_cache.runFusionWithInputs(args);
      at::Tensor nvf_out = outputs[0].as<at::Tensor>();
    
      at::Tensor ref_out =
          reference_loop_split_mlp(inp_tensor, w0_tensor, w1_tensor);
      validate({ref_out}, {nvf_out}, {0.02});
    }
    
    TEST_F(DistributedTransformerTest, LoopSplitMHAFwd) {
      if (H % D != 0) {
        GTEST_SKIP() << "Requires number of devices=" << D
                     << " evenly divide H=" << H;
      }
    
      auto fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      auto dtype = DataType::Half;
      at::ScalarType at_dtype = data_type_to_aten(dtype);
    
      const int d = communicator_->size();
    
      auto mesh = DeviceMesh::createForNumDevices(d);
    
      TensorView* qkv = makeContigConcreteTensor({B, S, H, 3 * E / H}, dtype);
      TensorView* q = slice(qkv, {0, 0, 0, 0}, {B, S, H, E / H});
      TensorView* k = slice(qkv, {0, 0, 0, E / H}, {B, S, H, 2 * E / H});
      TensorView* v = slice(qkv, {0, 0, 0, 2 * E / H}, {B, S, H, 3 * E / H});
    
      TensorView* q_permuted = permute(q, {0, 2, 1, 3});
      TensorView* k_permuted = permute(k, {0, 2, 1, 3});
      TensorView* v_permuted = permute(v, {0, 2, 1, 3});
    
      SdpfaFwdResult sdpa_out = sdpfa_fwd(
          q_permuted,
          k_permuted,
          v_permuted,
          /*dropout_p=*/IrBuilder::create<Val>(kDropoutProb),
          /*is_causal=*/IrBuilder::create<Val>(true),
          /*scale=*/nullptr);
    
      TensorView* attn = sdpa_out.output;
      TensorView* attn_permute = permute(attn, {0, 2, 1, 3});
    
      fusion->addInput(qkv);
      fusion->addOutput(attn_permute);
    
      qkv->setDeviceMesh(mesh);
      qkv->outer_split(2, d);
      qkv->axis(2)->parallelize(ParallelType::DIDx);
    
      FusionExecutorCache executor_cache(std::move(fusion));
      at::Tensor unsharded_inp_tensor =
          at::randn({B, S, H, 3 * E / H}, tensor_options.dtype(at_dtype));
      at::Tensor inp_tensor = shardTensor(unsharded_inp_tensor, 2, mesh);
    
      KernelArgumentHolder args = {inp_tensor};
      auto outputs = executor_cache.runFusionWithInputs(args);
      at::Tensor nvf_out = outputs[0].as<at::Tensor>();
      at::Tensor ref_out = reference_loop_split_mha(inp_tensor);
      validate({ref_out}, {nvf_out}, {0.02});
    }

Priya2698 force-pushed the pm/preseg_sharding_prop branch from dcab5b7 to c41b7ee on February 6, 2025.
Priya2698 marked this pull request as draft on March 7, 2025.
Priya2698 force-pushed the pm/preseg_sharding_prop branch from baad591 to 6d159ac on March 7, 2025.
Priya2698 force-pushed the pm/preseg_sharding_prop branch from b50134f to f136246 on March 18, 2025.
Priya2698 changed the title from "[WIP] update propagateSharding preseg pass for DID loop split" to "Update propagateSharding preseg pass for DID loop split" on March 18, 2025.
Priya2698 (Collaborator, Author) commented:

!test

Priya2698 requested a review from wujingyue on March 21, 2025.

Priya2698 (Collaborator, Author) commented:

> propagateShardingsPass runs after markAliasPreparePass: in this case, a missing device mesh triggers isResharding for the permute operation, so it is not picked up by the expr_eval scheduler.

> Sorry, I missed this comment. This sounds like a bug to fix. I fixed several csrc/ops APIs to propagate DeviceMesh and parallel types on the logical domain, so preseg passes that run after sharding propagation (including markAliasPreparePass) can use these ops APIs for convenience. Which function gave you an empty DeviceMesh?

Got it. I wasn't aware of this change; I narrowed it down to permute called through transpose in the affected tests. I'll introduce a fix in this PR and check.

For my reference: the permute operation does set the device mesh for the output tv. However, in the given test, the device meshes are set after the fusion definition, so at the point the permute op is called, the input tv does not have any device mesh. The tests were unimpacted because sharding propagation runs before markAliasAnalysis.

Priya2698 requested a review from wujingyue on April 23, 2025.

Priya2698 commented on MakeReshardingContiguousPass::runPass in csrc/preseg_passes/make_resharding_contiguous.cpp:

I'll rename this pass in a separate PR.

Priya2698 added a commit that referenced this pull request on Apr 23, 2025:

…4274)

This makes #3838 performance neutral.

Benchmarking results on GH200 nodes:

On main:
```
Name (time in ms)               Min     Max    Mean  StdDev  Median     IQR  Outliers       OPS  Rounds  Iterations
test_transformer_forward     6.2744  7.0567  6.4946  0.3369  6.2961  0.4077       1;0  153.9732       5           1
test_transformer_forward     6.2781  7.0573  6.4949  0.3368  6.2962  0.4076       1;0  153.9664       5           1

test_transformer_backward     12.5244  13.7777  13.0152  0.6278  12.5900  1.1082       1;0  76.8331       5           1
test_transformer_backward     12.5348  13.7620  13.0204  0.6094  12.6391  1.0909       1;0  76.8024       5           1
```

This branch:
```
Name (time in ms)               Min     Max    Mean  StdDev  Median     IQR  Outliers       OPS  Rounds  Iterations
test_transformer_forward     6.2889  7.0885  6.5132  0.3481  6.2960  0.4302       1;0  153.5349       5           1
test_transformer_forward     6.2895  7.0262  6.5010  0.3231  6.2963  0.4195       1;0  153.8221       5           1

Name (time in ms)                 Min      Max     Mean  StdDev   Median     IQR  Outliers      OPS  Rounds  Iterations
test_transformer_backward     12.4542  13.6518  12.9532  0.5625  12.6231  0.9795       1;0  77.2012       5           1
test_transformer_backward     12.4778  13.6544  12.9510  0.5641  12.5828  0.9724       1;0  77.2139       5           1
```

wujingyue (Collaborator) left a comment:

I don't yet fully understand the code around selectiveReorderDIDToFront. LGTM otherwise!


Priya2698 and others added 3 commits on April 23, 2025.
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>

Priya2698 merged commit c9d2cc9 into main on Apr 24, 2025; 53 checks passed.
Priya2698 deleted the pm/preseg_sharding_prop branch on April 24, 2025.