From 3628c3091b06a20eeb4e5197af089bb5d7a701af Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 26 Mar 2024 19:40:22 +0000 Subject: [PATCH 1/3] [WIP] Add epilogue loop for circular buffering --- csrc/device_lower/pass/double_buffer.cpp | 42 ++++++++++++++++++++---- csrc/device_lower/pass/double_buffer.h | 18 ++++++---- csrc/kernel_ir.h | 15 +++++++-- 3 files changed, 59 insertions(+), 16 deletions(-) diff --git a/csrc/device_lower/pass/double_buffer.cpp b/csrc/device_lower/pass/double_buffer.cpp index d8d2100fe46..9693f908876 100644 --- a/csrc/device_lower/pass/double_buffer.cpp +++ b/csrc/device_lower/pass/double_buffer.cpp @@ -150,10 +150,9 @@ class DoubleBufferFusionInspector : private IterVisitor { // The epilogue loop is only created when the producer of a double // buffer tensor is on smem, in which case it would otherwise require -// an additional predicate to guard buffer overruns. When it's on -// gmem, that isn't the case, so it does not need to create an -// epilogue loop. +// an additional predicate to guard buffer overruns. 
bool requireEpilogue(const std::vector& exprs) { + return true; return std::any_of(exprs.begin(), exprs.end(), [](const Expr* expr) { return expr->input(0)->as()->getMemoryType() == MemoryType::Shared; @@ -207,22 +206,29 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { double_buffer_loop_->iter_domain()); if (loop_type_ == DoubleBufferLoopStage::Prolog) { + std::cout << "Prolog" << std::endl; NVF_ERROR(start->isZeroInt()); stop = SimplifyingIrBuilder::create( int64_t(stage_depth - 1), DataType::Index); } else if ( loop_type_ == DoubleBufferLoopStage::Main && requireEpilogue(double_buffer_load_exprs_)) { + std::cout << "Main" << std::endl; stop = IrBuilder::subExpr( - double_buffer_loop_->stop(), gpu_lower->kernel()->oneVal()); + double_buffer_loop_->stop(), + IrBuilder::create(int64_t(stage_depth - 1), DataType::Index)); } else if (loop_type_ == DoubleBufferLoopStage::Epilog) { + std::cout << "Epilog" << std::endl; NVF_ERROR(requireEpilogue(double_buffer_load_exprs_)); start = IrBuilder::subExpr( double_buffer_loop_->stop(), - SimplifyingIrBuilder::create( - int64_t(stage_depth - 1), DataType::Index)); + IrBuilder::create(int64_t(stage_depth - 1), DataType::Index)); } + std::cout << "start=" << start->toInlineString() + << " stop=" << stop->toInlineString() + << " stage_depth=" << stage_depth << std::endl; + cloned_top_level_loop_ = IrBuilder::create( double_buffer_loop_->iter_domain(), index, @@ -596,6 +602,30 @@ class DoubleBufferInserter : private kir::ExprMutator { loads, DoubleBufferLoopStage::Epilog, alloc_in_main); + if (has_cpasync) { + // Insert a cp.async.wait_group at the end of each epilogue loop + // iteration. This should require all but the remaining future iteration + // loads to be completed. + // for i in (N-(D-1))..N: // epilog + // for j in ... + // .. = x[(i%2)*S+j] + // cp.async.wait N-1-i; // Ensure all but the future loads are + // complete + // For D=4 this unrolls as + // for j in ... + // .. 
= x[((N-3)%2)*S+j] + // cp.async.wait 2; + // for j in ... + // .. = x[((N-2)%2)*S+j] + // cp.async.wait 1; + // for j in ... + // .. = x[((N-1)%2)*S+j] + // cp.async.wait 0; + auto cp_async_wait = IrBuilder::create( + AsyncOpType::CpAsync, + SimplifyingIrBuilder::subExpr( + epilogue_loop->index() FusionGuard::getCurFusion()->oneVal())); + } registerInsertAfter(double_buffer_loop, epilogue_loop); } } diff --git a/csrc/device_lower/pass/double_buffer.h b/csrc/device_lower/pass/double_buffer.h index 18ff579f221..4c2ecd563d5 100644 --- a/csrc/device_lower/pass/double_buffer.h +++ b/csrc/device_lower/pass/double_buffer.h @@ -69,9 +69,7 @@ // iteration. However, the value loaded by the invalid load would not // be used, so instead of adding the additional predicate, the Epilogue // loop is replicated from the original loop, except for the load -// expression since it's not used. Note that this overrun does not -// happen when the producer is on gmem, so in that case, this -// additional replication is not done. +// expression since it's not used. // // When creating those three types of loops, additional care must be // taken when multiple tensors are double buffered. When multiple @@ -109,16 +107,16 @@ // if pred: // x[i*S+j] = y[i, j]; // -// for i in 0..N: // main loop +// for i in 0..(N-(D-1)): // main loop // for j in ... // if pred: // x[((i+D-1)%D)*S+j] = y[i+D-1, j]; // for j in ... // .. = x[(i%D)*S+j] // -// (Epilog omitted since this only makes sense in using -// cp.async, where producer will be in global mem and consumer will -// be in shared mem). +// for i in (N-(D-1))..N: // epilog +// for j in ... +// .. = x[(i%D)*S+j] // // The profitability of this optimization comes from extra tolerance // of global memory pipeline latency, as on the expression `.. = x[(i%D)*S+j]` @@ -162,6 +160,12 @@ // ensure completion of its own async copies so // would need to sync to this point to ensure // completion of the whole tile. 
+// // for i in (N-(D-1))..N: // epilog // for j in ... // // TODO: is a sync required here? // .. = x[(i%D)*S+j] // cp.async.wait N-1-i; // Ensure all but the future loads are complete namespace nvfuser { diff --git a/csrc/kernel_ir.h b/csrc/kernel_ir.h index 1a2a74333d5..b784575c8dc 100644 --- a/csrc/kernel_ir.h +++ b/csrc/kernel_ir.h @@ -631,7 +631,16 @@ class AsyncWait final : public Expr { explicit AsyncWait( IrBuilderPasskey passkey, AsyncOpType async_op_type, - int64_t keep_stages = 0); + Val* keep_stages); + + explicit AsyncWait( + IrBuilderPasskey passkey, + AsyncOpType async_op_type, + int64_t keep_stages = 0) + : AsyncWait( + passkey, + async_op_type, + IrBuilder::create(keep_stages, DataType::Int)) {} NVFUSER_DECLARE_CLONE_AND_CREATE @@ -651,8 +660,8 @@ class AsyncWait final : public Expr { //! Returns the remaining number of stages that are not synchronized //! after this op. - int64_t keepStages() const { - return attribute(1); + Val* keepStages() const { + return attributeVal(1); } }; From b1e1e7d4d2c275e6a39e49c4bfa33564a7688bdd Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 26 Mar 2024 19:56:14 +0000 Subject: [PATCH 2/3] Some fixes --- csrc/device_lower/pass/double_buffer.cpp | 13 +++++-------- csrc/kernel_ir.cpp | 13 +++++++++++-- csrc/kernel_ir.h | 6 +----- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/csrc/device_lower/pass/double_buffer.cpp b/csrc/device_lower/pass/double_buffer.cpp index 9693f908876..f9120433137 100644 --- a/csrc/device_lower/pass/double_buffer.cpp +++ b/csrc/device_lower/pass/double_buffer.cpp @@ -206,29 +206,22 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { double_buffer_loop_->iter_domain()); if (loop_type_ == DoubleBufferLoopStage::Prolog) { - std::cout << "Prolog" << std::endl; NVF_ERROR(start->isZeroInt()); stop = SimplifyingIrBuilder::create( int64_t(stage_depth - 1), DataType::Index); } else if ( loop_type_ == DoubleBufferLoopStage::Main && 
requireEpilogue(double_buffer_load_exprs_)) { - std::cout << "Main" << std::endl; stop = IrBuilder::subExpr( double_buffer_loop_->stop(), IrBuilder::create(int64_t(stage_depth - 1), DataType::Index)); } else if (loop_type_ == DoubleBufferLoopStage::Epilog) { - std::cout << "Epilog" << std::endl; NVF_ERROR(requireEpilogue(double_buffer_load_exprs_)); start = IrBuilder::subExpr( double_buffer_loop_->stop(), IrBuilder::create(int64_t(stage_depth - 1), DataType::Index)); } - std::cout << "start=" << start->toInlineString() - << " stop=" << stop->toInlineString() - << " stage_depth=" << stage_depth << std::endl; - cloned_top_level_loop_ = IrBuilder::create( double_buffer_loop_->iter_domain(), index, @@ -624,7 +617,11 @@ class DoubleBufferInserter : private kir::ExprMutator { auto cp_async_wait = IrBuilder::create( AsyncOpType::CpAsync, SimplifyingIrBuilder::subExpr( - epilogue_loop->index() FusionGuard::getCurFusion()->oneVal())); + epilogue_loop->stop(), + SimplifyingIrBuilder::addExpr( + epilogue_loop->index(), + FusionGuard::getCurFusion()->oneVal()))); + epilogue_loop->body().push_back(cp_async_wait); } registerInsertAfter(double_buffer_loop, epilogue_loop); } diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp index 714fc44c65a..b561e83f35a 100644 --- a/csrc/kernel_ir.cpp +++ b/csrc/kernel_ir.cpp @@ -612,16 +612,25 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(BlockSerializeRelease) AsyncWait::AsyncWait( IrBuilderPasskey passkey, AsyncOpType async_op_type, - int64_t keep_stages) + Val* keep_stages) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); addDataAttribute(async_op_type); - addDataAttribute(keep_stages); + addAttribute(keep_stages); } +AsyncWait::AsyncWait( + IrBuilderPasskey passkey, + AsyncOpType async_op_type, + int64_t keep_stages) + : AsyncWait( + passkey, + async_op_type, + IrBuilder::create(keep_stages, DataType::Int)) {} + std::string AsyncWait::toString(int 
indent_size) const { std::stringstream ss; indent(ss, indent_size) << ptx() << " " << keepStages() << "\n"; diff --git a/csrc/kernel_ir.h b/csrc/kernel_ir.h index b784575c8dc..877723a039f 100644 --- a/csrc/kernel_ir.h +++ b/csrc/kernel_ir.h @@ -636,11 +636,7 @@ class AsyncWait final : public Expr { explicit AsyncWait( IrBuilderPasskey passkey, AsyncOpType async_op_type, - int64_t keep_stages = 0) - : AsyncWait( - passkey, - async_op_type, - IrBuilder::create(keep_stages, DataType::Int)) {} + int64_t keep_stages = 0); NVFUSER_DECLARE_CLONE_AND_CREATE From 2e5cc9e045ee23f35f53b3e92c0f27930c96e71b Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 26 Mar 2024 20:09:38 +0000 Subject: [PATCH 3/3] More updates from converting keepStages to Val* --- csrc/device_lower/pass/inline_ptx.cpp | 5 +++-- csrc/kernel_ir.cpp | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/csrc/device_lower/pass/inline_ptx.cpp b/csrc/device_lower/pass/inline_ptx.cpp index 950bc0bcffc..7b61db9f1c1 100644 --- a/csrc/device_lower/pass/inline_ptx.cpp +++ b/csrc/device_lower/pass/inline_ptx.cpp @@ -32,7 +32,8 @@ class LowerToInlinePtx : public kir::ExprMutator { void handle(kir::AsyncWait* wait) override { if (wait->asyncOpType() == AsyncOpType::CpAsync && - wait->keepStages() == 0) { + wait->keepStages()->isConstInt() && + wait->keepStages()->evaluate().as() == 0) { // cp.async uses wait_all for zero keep stages, other instructions uses a // unified interface for all keep stages. 
registerReplace( @@ -48,7 +49,7 @@ class LowerToInlinePtx : public kir::ExprMutator { IrBuilder::create( wait->ptx(), std::vector{}, - std::vector{IrBuilder::create(wait->keepStages())}, + std::vector{wait->keepStages()}, kir::Asm::Options{/*volatile=*/true, /*memory=*/wait->memory()})); } } diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp index b561e83f35a..8be95cd43f2 100644 --- a/csrc/kernel_ir.cpp +++ b/csrc/kernel_ir.cpp @@ -633,7 +633,8 @@ AsyncWait::AsyncWait( std::string AsyncWait::toString(int indent_size) const { std::stringstream ss; - indent(ss, indent_size) << ptx() << " " << keepStages() << "\n"; + indent(ss, indent_size) << ptx() << " " << keepStages()->toInlineString() + << "\n"; return ss.str(); }