From 9a687498716ab107f533c7af473f4f3b6e9d841c Mon Sep 17 00:00:00 2001 From: snordmann Date: Fri, 14 Feb 2025 06:41:23 -0800 Subject: [PATCH] add sync stream at start of for-loop --- csrc/host_ir/lower.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index 9d23e3b0f3b..c5a8ed5c3d9 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -453,6 +453,8 @@ std::vector HostIrLower::lowerToCollectiveBasedPipelinedGemmComm( auto* stream_index = mod(j, number_of_streams); auto* stream = IrBuilder::create(stream_index); auto* set_stream = IrBuilder::create(stream); + auto* initial_sync_stream = + IrBuilder::create(original_stream); TensorView* tva_j = select(tva, 0, j); TensorView* tva_allgathered_j = select(tva_allgathered, 0, j); @@ -488,6 +490,7 @@ std::vector HostIrLower::lowerToCollectiveBasedPipelinedGemmComm( std::vector loop_body = { set_stream, + initial_sync_stream, tva_j->definition(), tva_allgathered_j->definition(), communication,