diff --git a/csrc/device_lower/pass/double_buffer.cpp b/csrc/device_lower/pass/double_buffer.cpp
index d8d2100fe46..f9120433137 100644
--- a/csrc/device_lower/pass/double_buffer.cpp
+++ b/csrc/device_lower/pass/double_buffer.cpp
@@ -150,10 +150,9 @@ class DoubleBufferFusionInspector : private IterVisitor {
 // The epilogue loop is only created when the producer of a double
 // buffer tensor is on smem, in which case it would otherwise require
-// an additional predicate to guard buffer overruns. When it's on
-// gmem, that isn't the case, so it does not need to create an
-// epilogue loop.
+// an additional predicate to guard buffer overruns.
 bool requireEpilogue(const std::vector<Expr*>& exprs) {
+  return true;
   return std::any_of(exprs.begin(), exprs.end(), [](const Expr* expr) {
     return expr->input(0)->as<TensorView>()->getMemoryType() ==
         MemoryType::Shared;
   });
@@ -214,13 +213,13 @@ class DoubleBufferLoopCloner : public kir::IrVisitor {
         loop_type_ == DoubleBufferLoopStage::Main &&
         requireEpilogue(double_buffer_load_exprs_)) {
       stop = IrBuilder::subExpr(
-          double_buffer_loop_->stop(), gpu_lower->kernel()->oneVal());
+          double_buffer_loop_->stop(),
+          IrBuilder::create<Val>(int64_t(stage_depth - 1), DataType::Index));
     } else if (loop_type_ == DoubleBufferLoopStage::Epilog) {
       NVF_ERROR(requireEpilogue(double_buffer_load_exprs_));
       start = IrBuilder::subExpr(
           double_buffer_loop_->stop(),
-          SimplifyingIrBuilder::create<Val>(
-              int64_t(stage_depth - 1), DataType::Index));
+          IrBuilder::create<Val>(int64_t(stage_depth - 1), DataType::Index));
     }
 
     cloned_top_level_loop_ = IrBuilder::create<kir::ForLoop>(
@@ -596,6 +595,34 @@ class DoubleBufferInserter : private kir::ExprMutator {
         loads,
         DoubleBufferLoopStage::Epilog,
         alloc_in_main);
+    if (has_cpasync) {
+      // Insert a cp.async.wait_group at the end of each epilogue loop
+      // iteration. This should require all but the remaining future iteration
+      // loads to be completed.
+      // for i in (N-(D-1))..N: // epilog
+      //   for j in ...
+      //     .. = x[(i%2)*S+j]
+      //   cp.async.wait N-1-i; // Ensure all but the future loads are
+      //   complete
+      // For D=4 this unrolls as
+      //   for j in ...
+      //     .. = x[(N-3)%2)*S+j]
+      //   cp.async.wait 2;
+      //   for j in ...
+      //     .. = x[(N-2)%2)*S+j]
+      //   cp.async.wait 1;
+      //   for j in ...
+      //     .. = x[(N-1)%2)*S+j]
+      //   cp.async.wait 0;
+      auto cp_async_wait = IrBuilder::create<kir::AsyncWait>(
+          AsyncOpType::CpAsync,
+          SimplifyingIrBuilder::subExpr(
+              epilogue_loop->stop(),
+              SimplifyingIrBuilder::addExpr(
+                  epilogue_loop->index(),
+                  FusionGuard::getCurFusion()->oneVal())));
+      epilogue_loop->body().push_back(cp_async_wait);
+    }
     registerInsertAfter(double_buffer_loop, epilogue_loop);
   }
 }
diff --git a/csrc/device_lower/pass/double_buffer.h b/csrc/device_lower/pass/double_buffer.h
index 18ff579f221..4c2ecd563d5 100644
--- a/csrc/device_lower/pass/double_buffer.h
+++ b/csrc/device_lower/pass/double_buffer.h
@@ -69,9 +69,7 @@
 // iteration. However, the value loaded by the invalid load would not
 // be used, so instead of adding the additional predicate, the Epilogue
 // loop is replicated from the original loop, except for the load
-// expression since it's not used. Note that this overrun does not
-// happen when the producer is on gmem, so in that case, this
-// additional replication is not done.
+// expression since it's not used.
 //
 // When creating those three types of loops, additional care must be
 // taken when multiple tensors are double buffered. When multiple
@@ -109,16 +107,16 @@
 //     if pred:
 //       x[i*S+j] = y[i, j];
 //
-// for i in 0..N: // main loop
+// for i in 0..(N-(D-1)): // main loop
 //   for j in ...
 //     if pred:
 //       x[((i+D-1)%D)*S+j] = y[i+D-1, j];
 //   for j in ...
 //     .. = x[(i%D)*S+j]
 //
-// (Epilog omitted since this only makes sense in using
-// cp.async, where producer will be in global mem and consumer will
-// be in shared mem).
+// for i in (N-(D-1))..N: // epilog
+//   for j in ...
+//     .. = x[(i%2)*S+j]
 //
 // The profitability of this optimization comes from extra tolerance
 // of global memory pipeline latency, as on the expression `.. = x[(i%D)*S+j]`
@@ -162,6 +160,12 @@
 // ensure completion of its own async copies so
 // would need to sync to this point to ensure
 // completion of the whole tile.
+//
+// for i in (N-(D-1))..N: // epilog
+//   for j in ...
+//     // TODO: is a sync required here?
+//     .. = x[(i%2)*S+j]
+//   cp.async.wait N-1-i; // Ensure all but the future loads are incomplete
 
 namespace nvfuser {
 
diff --git a/csrc/device_lower/pass/inline_ptx.cpp b/csrc/device_lower/pass/inline_ptx.cpp
index 950bc0bcffc..7b61db9f1c1 100644
--- a/csrc/device_lower/pass/inline_ptx.cpp
+++ b/csrc/device_lower/pass/inline_ptx.cpp
@@ -32,7 +32,8 @@ class LowerToInlinePtx : public kir::ExprMutator {
   void handle(kir::AsyncWait* wait) override {
     if (wait->asyncOpType() == AsyncOpType::CpAsync &&
-        wait->keepStages() == 0) {
+        wait->keepStages()->isConstInt() &&
+        wait->keepStages()->evaluate().as<int64_t>() == 0) {
       // cp.async uses wait_all for zero keep stages, other instructions uses a
       // unified interface for all keep stages.
       registerReplace(
           wait,
@@ -48,7 +49,7 @@ class LowerToInlinePtx : public kir::ExprMutator {
           IrBuilder::create<kir::Asm>(
               wait->ptx(),
               std::vector<Val*>{},
-              std::vector<Val*>{IrBuilder::create<Val>(wait->keepStages())},
+              std::vector<Val*>{wait->keepStages()},
               kir::Asm::Options{/*volatile=*/true, /*memory=*/wait->memory()}));
     }
   }
diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp
index 714fc44c65a..8be95cd43f2 100644
--- a/csrc/kernel_ir.cpp
+++ b/csrc/kernel_ir.cpp
@@ -612,19 +612,29 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(BlockSerializeRelease)
 
 AsyncWait::AsyncWait(
     IrBuilderPasskey passkey,
     AsyncOpType async_op_type,
-    int64_t keep_stages)
+    Val* keep_stages)
     : Expr(passkey) {
   NVF_ERROR(passkey.ir_container_ != nullptr);
   NVF_ERROR(
       passkey.ir_container_->isA<kir::Kernel>(),
       "IR type only valid for Kernel container.");
   addDataAttribute(async_op_type);
-  addDataAttribute(keep_stages);
+  addAttribute(keep_stages);
 }
 
+AsyncWait::AsyncWait(
+    IrBuilderPasskey passkey,
+    AsyncOpType async_op_type,
+    int64_t keep_stages)
+    : AsyncWait(
+          passkey,
+          async_op_type,
+          IrBuilder::create<Val>(keep_stages, DataType::Int)) {}
+
 std::string AsyncWait::toString(int indent_size) const {
   std::stringstream ss;
-  indent(ss, indent_size) << ptx() << " " << keepStages() << "\n";
+  indent(ss, indent_size) << ptx() << " " << keepStages()->toInlineString()
+                          << "\n";
   return ss.str();
 }
diff --git a/csrc/kernel_ir.h b/csrc/kernel_ir.h
index 1a2a74333d5..877723a039f 100644
--- a/csrc/kernel_ir.h
+++ b/csrc/kernel_ir.h
@@ -628,6 +628,11 @@ class AsyncWait final : public Expr {
  public:
   using Expr::Expr;
 
+  explicit AsyncWait(
+      IrBuilderPasskey passkey,
+      AsyncOpType async_op_type,
+      Val* keep_stages);
+
   explicit AsyncWait(
       IrBuilderPasskey passkey,
       AsyncOpType async_op_type,
@@ -651,8 +656,8 @@ class AsyncWait final : public Expr {
 
   //! Returns the remaining number of stages that are not synchronized
   //! after this op.
-  int64_t keepStages() const {
-    return attribute<int64_t>(1);
+  Val* keepStages() const {
+    return attributeVal(1);
   }
 };
 