
[Stream lowering] Allgather p2p #4515

Merged
samnordmann merged 7 commits into main from host_irs/stream_lowering/AG_p2p_pr
Jun 3, 2025

Conversation


@samnordmann samnordmann commented May 26, 2025

On top of NVIDIA#4387.

What

Add stream lowering for the linear allgather-p2p case, with the NCCL backend.

For example: MultiDeviceStreamParallelTypeTest.AllgatherP2p from tests/cpp/test_multidevice_stream_parallel_type.cpp:

  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = set(tv0);
  fusion->addInput(tv0);
  fusion->addOutput(tv1);

  const DeviceMesh mesh =
      DeviceMesh::createForNumDevices(communicator_->size());
  tv0->setDeviceMesh(mesh);
  tv1->setDeviceMesh(mesh);
  tv0->axis(0)->parallelize(ParallelType::DIDx);
  tv1->axis(0)->parallelize(ParallelType::Stream);

is lowered to:

%HostIrContainer { (T0_g_float[ideviceIdx.x0{i0}, iS1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i2 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx2{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx2{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T3_l_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx2{i0}, index = StreamIdx )
    IF Manual ( StreamIdx == deviceIdx.x ):
      T2_l_float[iS4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
         = HirAliasSelect( T0_g_float[ideviceIdx.x0{i0}, iS1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = ideviceIdx.x0{i0}, index = 0 )
      T3_l_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
         = Set( T2_l_float[iS4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), cache_op=Streaming )
    ELSE:
      StartCoalescing
      P2PCommunication 30 (type=recv, buffer=T3_l_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=StreamIdx, backend=NCCL)
      P2PCommunication 31 (type=send, buffer=T0_g_float[ideviceIdx.x0{i0}, iS1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=StreamIdx, backend=NCCL)
      EndCoalescing 32
      Wait Communication 32
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer
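The net effect of the host program above can be sketched in plain C++ (a hypothetical single-process simulation, no NCCL or streams): each loop iteration StreamIdx fills output slice StreamIdx, either by a local copy when StreamIdx equals the rank, or by a recv from peer StreamIdx paired with a send of the local shard.

```cpp
#include <cassert>
#include <vector>

// Hypothetical sketch of the allgather-p2p schedule: for a given rank,
// iterate over StreamIdx and take the local shard when StreamIdx == rank
// (the "local copy" IF branch), otherwise take the peer's shard (modeling
// the recv in the ELSE branch). All ranks end with the same gathered tensor.
std::vector<int> gatherForRank(
    int rank, const std::vector<std::vector<int>>& shards) {
  std::vector<int> out;
  const int n = static_cast<int>(shards.size());
  for (int stream_idx = 0; stream_idx < n; ++stream_idx) {
    const auto& src = (stream_idx == rank) ? shards[rank] : shards[stream_idx];
    out.insert(out.end(), src.begin(), src.end());
  }
  return out;
}
```

In the real lowering, the ELSE branch's recv/send pair is coalesced and waited on per iteration, so each slice becomes valid as its iteration's communication completes.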

A test with an overlapped matmul is also added in AG_matmul_P2p, which generates the following host program:

%HostIrContainer { (T0_g_float[ideviceIdx.x0{i0}, iS1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T2_g_float[iStreamIdx5{i0}, iS6{i2}, iS7{i5}, rS8{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T3_g_float[iStreamIdx9{i0}, iS10{i2}, iS11{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx9{i0}, iS10{i2}, iS11{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i2 ) * i3 ), zero_init=false, resets_to_zero=false)
  T2_g_float[iStreamIdx5{i0}, iS6{i2}, iS7{i5}, rS8{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[iStreamIdx5{i0}, iS6{i2}, iS7{i5}, rS8{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i2 ) * i5 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx9{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx9{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T5_l_float[iS14{i2}, iS15{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T3_g_float[iStreamIdx9{i0}, iS10{i2}, iS11{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx9{i0}, index = StreamIdx )
    IF Manual ( StreamIdx == deviceIdx.x ):
      T4_l_float[iS12{i2}, iS13{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
         = HirAliasSelect( T0_g_float[ideviceIdx.x0{i0}, iS1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = ideviceIdx.x0{i0}, index = 0 )
      T5_l_float[iS14{i2}, iS15{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
         = Set( T4_l_float[iS12{i2}, iS13{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), cache_op=Streaming )
    ELSE:
      StartCoalescing
      P2PCommunication 41 (type=recv, buffer=T5_l_float[iS14{i2}, iS15{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=StreamIdx, backend=NCCL)
      P2PCommunication 42 (type=send, buffer=T0_g_float[ideviceIdx.x0{i0}, iS1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=StreamIdx, backend=NCCL)
      EndCoalescing 43
      Wait Communication 43
    T6_l_float[iS16{i2}, iS17{i5}, rS18{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T2_g_float[iStreamIdx5{i0}, iS6{i2}, iS7{i5}, rS8{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx5{i0}, index = StreamIdx )
    T6_l_float[iS16{i2}, iS17{i5}, rS18{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = matmul(T5_l_float[iS14{i2}, iS15{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                T1_g_float[iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}))
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer
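Both host programs assign loop iterations to streams with the same rule, which is what enables the overlap in the matmul variant. A minimal sketch (not nvFuser code) of that rule:

```cpp
#include <cassert>

// Iteration StreamIdx runs on stream (StreamIdx % numberOfStreams).
// Consecutive iterations land on different streams (when numberOfStreams > 1),
// so the recv of one slice can overlap the matmul of the previous slice;
// iterations that share a stream serialize, bounding resource usage.
int streamFor(int stream_idx, int number_of_streams) {
  return stream_idx % number_of_streams;
}
```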

@samnordmann
Collaborator Author

!test

@github-actions

github-actions bot commented May 26, 2025

Review updated until commit 045202a

Description

  • Added stream lowering for Allgather P2P linear with NCCL backend

  • Updated tests to include Allgather P2P cases

  • Renamed variables for clarity and consistency

  • Fixed Communicator usage in IfThenElseTest


Changes walkthrough 📝

Relevant files (Enhancement):

  • csrc/host_ir/executor.cpp (+2/-1): Bind rank and add non-blocking copy
    - Bound rank to communicator_->deviceId()
    - Added non-blocking copy for out_tensor.copy_
  • csrc/host_ir/host_ir.cpp (+1/-1): Update EndCoalescing toString
    - Updated EndCoalescing::toString to include the name
  • csrc/host_ir/pass/stream_parallel_type.cpp (+107/-5): Implement P2P communication for Allgather P2P
    - Added special handling for the DIDx-to-Stream parallel type conversion
    - Implemented P2P communication for Allgather P2P linear
  • tests/cpp/test_host_irs.cpp (+9/-7): Pass Communicator to HostIrEvaluator
    - Passed Communicator::getInstance() to HostIrEvaluator
  • tests/cpp/test_multidevice_stream_parallel_type.cpp (+91/-0): Add Allgather P2P tests
    - Added tests for Allgather P2P and AG matmul P2P
  • csrc/host_ir/executor.h (+1/-1): Set default communicator
    - Set the default communicator to Communicator::getInstance()

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Error Handling

The code assumes that the input and output tensors are always sharded and stream-parallelized in a specific way. It should include error handling for cases where these assumptions do not hold.

    for (auto* body_expr : for_loop->body().exprs()) {
      // We have special handling for when an axis passes from DIDx to Stream
      // parallel type in one expression. This case should be lowered to a P2P
      // Communication. For now, we only allow the "Linear Allgather" case,
      // where tv0 [DIDx(i0), ...] and tv1=set(tv0) [Stream(i0), ...]. In this
      // case, the set should be lowered to something like
      //
      // FOR StreamIdx in range(i0):
      //   [...]
      //   SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
      //   IF StreamIdx == rank: // This is the local copy
      //     Tv1[StreamIdx, ...].copy_(Tv0[0, ...]) // the index 0 because Tv0
      //     is sharded
      //   ELSE:
      //     Recv (buffer=Tv1[StreamIdx, ...], peer=StreamIdx)
      //     Send (buffer=Tv0[0, ...], peer=StreamIdx)
      //   [...]
      bool needs_p2p_handling = false;
    
      // Check if any input needs P2P handling
      for (auto* input :
           ir_utils::filterByType<TensorView>(body_expr->inputs())) {
        if (auto stream_idx =
                findStreamAxisIndex(input, for_loop->iterDomain(), id_model);
            stream_idx != -1) {
          if (input->getLogicalDomain()[stream_idx]->isDeviceDim()) {
            needs_p2p_handling = true;
            break;
          }
        }
      }
    
      if (needs_p2p_handling) {
        NVF_ERROR(
            body_expr->isA<LoadStoreOp>() &&
                body_expr->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set,
            "expected a set operation but got ",
            body_expr);
        NVF_ERROR(
            body_expr->isA<LoadStoreOp>(),
            "expected a Tv operation but got ",
            body_expr);
        auto* set_op = body_expr->as<LoadStoreOp>();
        auto* input_tv = set_op->in()->as<TensorView>();
        auto* output_tv = set_op->out()->as<TensorView>();
        NVF_ERROR(
            input_tv->axis(0)->isDeviceDim(),
            "expected a sharded first axis on the input but got ",
            input_tv);
        NVF_ERROR(
            output_tv->axis(0)->getParallelType() == ParallelType::Stream,
            "expected a stream parallelized first axis on the output but got ",
            output_tv);
    
        auto* peer = for_loop->index();
        auto* my_device_id =
            IrBuilder::create<NamedScalar>("rank", DataType::Int);
        auto* is_sending_to_self =
            IrBuilder::create<kir::Predicate>(eq(peer, my_device_id));
        auto if_then_else =
            IrBuilder::create<kir::IfThenElse>(is_sending_to_self);
    
        auto [slicing_input, is_new] = tensor_slicing_cache.get(
            input_tv,
            /*dim=*/0,
            /*index=*/FusionGuard::getCurFusion()->zeroVal());
        auto [slicing_output, is_new_] = tensor_slicing_cache.get(
            output_tv, /*dim=*/0, /*index=*/for_loop->index());
    
        auto* local_copy = IrBuilder::create<LoadStoreOp>(
            LoadStoreOpType::Set, slicing_output->out(), slicing_input->out());
    
        if_then_else->thenBody().push_back(slicing_input);
        if_then_else->thenBody().push_back(local_copy);
    
        // Using Start/EndCoalescing here is important to 1) avoid hangs because
        // of a wrong global order of send/recv and 2) enjoy full bi-directional
        // bandwidth.
        auto start_coalescing = IrBuilder::create<hir::StartCoalescing>();
        auto recv = IrBuilder::create<P2PCommunication>(
            P2PCommunicationType::RECV,
            slicing_output->out(),
            /*peer*/ for_loop->index(),
            CommunicatorBackend::kNccl);
        auto send = IrBuilder::create<P2PCommunication>(
            P2PCommunicationType::SEND,
            input_tv,
            /*peer*/ for_loop->index(),
            CommunicatorBackend::kNccl);
        auto end_coalescing = IrBuilder::create<hir::EndCoalescing>();
        auto wait = IrBuilder::create<hir::Wait>(end_coalescing);
    
        if_then_else->elseBody().push_back(start_coalescing);
        if_then_else->elseBody().push_back(recv);
        if_then_else->elseBody().push_back(send);
        if_then_else->elseBody().push_back(end_coalescing);
        if_then_else->elseBody().push_back(wait);
    
        new_loop_body.push_back(slicing_output);
        new_loop_body.push_back(if_then_else);
      } else {
        // Process inputs and outputs normally
        for (auto* input :
             ir_utils::filterByType<TensorView>(body_expr->inputs())) {
          processTensor(body_expr, input);
        }
        for (auto* output :
             ir_utils::filterByType<TensorView>(body_expr->outputs())) {
          processTensor(body_expr, output);
        }
        new_loop_body.push_back(body_expr);
      }
    }

Test Coverage

The new tests cover specific cases but may not cover all edge cases or different configurations. Consider adding more tests to ensure robustness.

    TEST_F(MultiDeviceStreamParallelTypeTest, AllgatherP2p) {
      auto fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
      TensorView* tv0 = makeContigTensor(2);
      TensorView* tv1 = set(tv0);
      fusion->addInput(tv0);
      fusion->addOutput(tv1);
    
      const DeviceMesh mesh =
          DeviceMesh::createForNumDevices(communicator_->size());
      tv0->setDeviceMesh(mesh);
      tv1->setDeviceMesh(mesh);
      tv0->axis(0)->parallelize(ParallelType::DIDx);
      tv1->axis(0)->parallelize(ParallelType::Stream);
    
      MultiDeviceExecutor executor(std::move(fusion), *communicator_);
    
      hir::HostIrContainer* container = executor.hostIrEvaluator()->container();
      EXPECT_THAT(
          container->topLevelExprs(),
          ElementsAre(
              IsA<kir::Allocate>(),
              IsA<hir::GetCurrentStream>(),
              IsA<ForLoop>(),
              IsA<ForLoop>()));
    
      auto options =
          at::TensorOptions().device(at::kCUDA, communicator_->deviceId());
      at::Tensor unsharded_input = at::rand({communicator_->size(), 4}, options);
      at::Tensor input = shardTensor(unsharded_input, /*axis=*/0, mesh);
      auto output =
          executor.runWithInput(KernelArgumentHolder({input}))[0].as<at::Tensor>();
    
      EXPECT_TRUE(torch::allclose(output, unsharded_input, 1e-2, 1e-2))
          << "Output: " << output << "\nExpected: " << unsharded_input;
    }
    
    TEST_F(MultiDeviceStreamParallelTypeTest, AG_matmul_P2p) {
      constexpr int64_t M = 32768;
      constexpr int64_t K = 32768;
      constexpr int64_t N = 1024;
      const int64_t D = communicator_->size();
      if (M % D != 0) {
        GTEST_SKIP() << "M must be a multiple of D, but got M = " << M
                     << ", D = " << D;
      }
    
      auto fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      TensorView* tv0 = makeContigTensor(3); //[DIDx(D), M/D, K]
      TensorView* tv1 = makeContigTensor(2); //[K, N]
      TensorView* tv2 = matmul(tv0, tv1); //[Stream(D), M/D, N]
    
      fusion->addInput(tv0);
      fusion->addInput(tv1);
      fusion->addOutput(tv2);
    
      auto mesh = DeviceMesh::createForNumDevices(D);
      tv0->setDeviceMesh(mesh);
      tv1->setDeviceMesh(mesh);
      tv2->setDeviceMesh(mesh);
    
      tv0->axis(0)->parallelize(ParallelType::DIDx);
      tv2->axis(0)->parallelize(ParallelType::Stream);
    
      MultiDeviceExecutor executor(std::move(fusion), *communicator_);
    
      hir::HostIrContainer* container = executor.hostIrEvaluator()->container();
      EXPECT_THAT(
          container->topLevelExprs(),
          ElementsAre(
              IsA<kir::Allocate>(),
              IsA<kir::Allocate>(),
              IsA<hir::GetCurrentStream>(),
              IsA<ForLoop>(),
              IsA<ForLoop>()));
    
      auto tensor_options =
          at::TensorOptions().dtype(at::kFloat).device(communicator_->device());
      auto t0_unsharded = at::randn({D, M / D, K}, tensor_options);
      auto t0 = t0_unsharded.slice(
          0, communicator_->deviceId(), communicator_->deviceId() + 1);
      auto t1 = at::randn({K, N}, tensor_options);
    
      auto t2 = executor.runWithInput({t0, t1})[0].as<at::Tensor>();
    
      auto t2_ref = at::matmul(t0_unsharded, t1);
      EXPECT_TRUE(torch::allclose(t2_ref, t2, 1e-2, 1e-2));
    }

Non-blocking Copy

The change to use a non-blocking copy (copy_(t, /*non_blocking=*/true)) may lead to race conditions if not handled properly. Ensure that all necessary synchronization is in place.

out_tensor.copy_(t, /*non_blocking=*/true);
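The concern can be illustrated with a host-side analogue (hypothetical, no CUDA): a non-blocking copy returns before the data has landed, so the destination must not be read until the asynchronous work is synchronized. Here the "stream" is modeled by a std::future that must be waited on.

```cpp
#include <future>
#include <vector>

// Hypothetical analogue of a non-blocking copy: the copy is enqueued
// asynchronously and the caller must reach a synchronization point
// (here, pending.wait()) before reading the destination buffer.
std::vector<int> copyNonBlockingThenSync(const std::vector<int>& src) {
  std::vector<int> dst(src.size());
  std::future<void> pending = std::async(std::launch::async, [&] {
    dst = src;  // the asynchronous copy
  });
  pending.wait();  // without this, reading dst below would race
  return dst;
}
```

In the actual host program, this role is played by the per-iteration stream synchronization at the end of the loop body.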

    @samnordmann samnordmann force-pushed the host_irs/stream_lowering/AG_p2p_pr branch from 5c424bd to 99954c6 Compare May 27, 2025 14:16
    @samnordmann
    Collaborator Author

    !test

    @samnordmann samnordmann force-pushed the host_irs/stream_lowering/AG_p2p_pr branch from 99954c6 to 3662100 Compare May 28, 2025 13:52
    @samnordmann
    Collaborator Author

    !test

    @samnordmann samnordmann requested review from nsarka and wujingyue May 28, 2025 14:08

    @wujingyue wujingyue left a comment


    LGTM otherwise

    Comment on lines +257 to +260
FusionGuard fg(container_.get());
expr_evaluator_.bind(
    NamedScalar::getParallelIndex(ParallelType::DIDx),
    communicator_->deviceId());

    Is this the right thing to do in the foreseeable future? Isn't DIDx decided also by the mesh?


    @samnordmann samnordmann Jun 2, 2025


    Is this the right thing to do in the foreseeable future? Isn't DIDx decided also by the mesh?

    Thanks for the question. IIUC, you're asking if deviceIdx.x should be:

    1. An absolute device ID (e.g., always 1 for device 1), or
    2. A mesh-relative index (e.g., 0 for device 1 if the mesh is {1}).

The current PR implements option 1. However, your question made me think, and I now see that option 2 makes more sense, especially once we move to 2D (with the caveat that the mesh is per-tensor and can change during a fusion).

    I decided to change the name of that NamedScalar to "myDeviceId" for now. Let me know if this sounds good to you.
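The distinction between the two options can be sketched in plain C++ (names hypothetical, not nvFuser's DeviceMesh API): option 1 binds the absolute device id directly, while option 2 looks up the device's position inside the tensor's mesh, which can differ once meshes are per-tensor subsets of the world.

```cpp
#include <vector>

// Option 2 sketch: the mesh-relative index is the device's position within
// this tensor's mesh; -1 means the device does not participate in the mesh.
int meshRelativeIndex(const std::vector<int>& mesh, int device_id) {
  for (int i = 0; i < static_cast<int>(mesh.size()); ++i) {
    if (mesh[i] == device_id) {
      return i;  // e.g. absolute device 3 in mesh {1 3 5} has index 1
    }
  }
  return -1;
}
```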


    with the caveat that the mesh is per-Tensor and can change during a fusion

    Yes. Therefore, I was also unsure about deviceIdx.x being a "global" variable as in the previous version. Let me read your new changes...

    @samnordmann
    Collaborator Author

    !test

    @samnordmann
    Collaborator Author

    !test

    @samnordmann
    Collaborator Author

    !test

    @samnordmann samnordmann merged commit 43b9a1a into main Jun 3, 2025
    52 of 53 checks passed
    @samnordmann samnordmann deleted the host_irs/stream_lowering/AG_p2p_pr branch June 3, 2025 14:17
    nsarka pushed a commit to nsarka/Fuser that referenced this pull request Jul 28, 2025