diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index 9d23e3b0f3b..c5a8ed5c3d9 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -453,6 +453,8 @@ std::vector HostIrLower::lowerToCollectiveBasedPipelinedGemmComm( auto* stream_index = mod(j, number_of_streams); auto* stream = IrBuilder::create(stream_index); auto* set_stream = IrBuilder::create(stream); + auto* initial_sync_stream = + IrBuilder::create(original_stream); TensorView* tva_j = select(tva, 0, j); TensorView* tva_allgathered_j = select(tva_allgathered, 0, j); @@ -488,6 +490,7 @@ std::vector HostIrLower::lowerToCollectiveBasedPipelinedGemmComm( std::vector loop_body = { set_stream, + initial_sync_stream, tva_j->definition(), tva_allgathered_j->definition(), communication,