diff --git a/CMakeLists.txt b/CMakeLists.txt index 09e7699e129..d948fc3d7af 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -449,6 +449,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/test/test_polymorphic_value.cpp ${NVFUSER_ROOT}/test/test_matmul_sass.cpp ${NVFUSER_ROOT}/test/test_matmul_scheduler.cpp + ${NVFUSER_ROOT}/test/test_mbarrier.cpp ${NVFUSER_ROOT}/test/test_memory.cpp ${NVFUSER_ROOT}/test/test_gpu_view.cpp ${NVFUSER_ROOT}/test/test_gpu_transpose.cpp @@ -588,6 +589,7 @@ list(APPEND NVFUSER_RUNTIME_FILES ${NVFUSER_ROOT}/runtime/grid_sync.cu ${NVFUSER_ROOT}/runtime/helpers.cu ${NVFUSER_ROOT}/runtime/index_utils.cu + ${NVFUSER_ROOT}/runtime/mbarrier.cu ${NVFUSER_ROOT}/runtime/memory.cu ${NVFUSER_ROOT}/runtime/random_numbers.cu ${NVFUSER_ROOT}/runtime/tensor.cu diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index dd08594f6e5..5eb148f0ef5 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -2890,6 +2890,46 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { indent() << sync_call << ";\n"; } + void handle(const kir::MBarrierInit* init) final { + auto call = genCall( + "mbarrier::init", + ArgumentBuilder() + .arg(genInline(init->mbarrier()->as<kir::TensorIndex>()->index())) + .arg(genInline(init->threadCount()))); + indent() << call << ";\n"; + } + + void handle(const kir::MBarrierInvalidate* inval) final { + auto call = genCall( + "mbarrier::inval", + ArgumentBuilder().arg( + genInline(inval->mbarrier()->as<kir::TensorIndex>()->index()))); + indent() << call << ";\n"; + } + + void handle(const kir::MBarrierArrive* arrive) final { + if (!print_inline_) { + indent() << gen(arrive->state()) << " = "; + } + auto call = genCall( + "mbarrier::arrive", + ArgumentBuilder().arg( + genInline(arrive->mbarrier()->as<kir::TensorIndex>()->index()))); + code_ << call; + if (!print_inline_) { + code_ << ";\n"; + } + } + + void handle(const kir::MBarrierWait* wait) final { + auto call = genCall( + "mbarrier::wait", + ArgumentBuilder() + .arg(genInline(wait->mbarrier()->as<kir::TensorIndex>()->index())) + .arg(genInline(wait->state()))); + 
indent() << call << ";\n"; + } + void handle(const kir::InitMagicZero*) final { indent() << "NVFUSER_DEFINE_MAGIC_ZERO;\n"; } diff --git a/csrc/dispatch.cpp b/csrc/dispatch.cpp index 172bc65f294..07c23baed54 100644 --- a/csrc/dispatch.cpp +++ b/csrc/dispatch.cpp @@ -248,6 +248,22 @@ void Expr::dispatch(T handler, Expr* expr) { ptr(handler)->handle(expr->as<kir::GridSync>()); return; } + if (expr->isStrictlyA<kir::MBarrierInit>()) { + ptr(handler)->handle(expr->as<kir::MBarrierInit>()); + return; + } + if (expr->isStrictlyA<kir::MBarrierInvalidate>()) { + ptr(handler)->handle(expr->as<kir::MBarrierInvalidate>()); + return; + } + if (expr->isStrictlyA<kir::MBarrierArrive>()) { + ptr(handler)->handle(expr->as<kir::MBarrierArrive>()); + return; + } + if (expr->isStrictlyA<kir::MBarrierWait>()) { + ptr(handler)->handle(expr->as<kir::MBarrierWait>()); + return; + } if (expr->isStrictlyA<kir::CpAsyncWait>()) { ptr(handler)->handle(expr->as<kir::CpAsyncWait>()); return; } @@ -532,6 +548,22 @@ void Expr::constDispatch(T handler, const Expr* expr) { ptr(handler)->handle(expr->as<kir::GridSync>()); return; } + if (expr->isStrictlyA<kir::MBarrierInit>()) { + ptr(handler)->handle(expr->as<kir::MBarrierInit>()); + return; + } + if (expr->isStrictlyA<kir::MBarrierInvalidate>()) { + ptr(handler)->handle(expr->as<kir::MBarrierInvalidate>()); + return; + } + if (expr->isStrictlyA<kir::MBarrierArrive>()) { + ptr(handler)->handle(expr->as<kir::MBarrierArrive>()); + return; + } + if (expr->isStrictlyA<kir::MBarrierWait>()) { + ptr(handler)->handle(expr->as<kir::MBarrierWait>()); + return; + } if (expr->isStrictlyA<kir::CpAsyncWait>()) { ptr(handler)->handle(expr->as<kir::CpAsyncWait>()); return; } @@ -914,6 +946,18 @@ void OptOutConstDispatch::handle(const kir::BlockSync* stmt) { unhandled(stmt); } void OptOutConstDispatch::handle(const kir::GridSync* stmt) { unhandled(stmt); } +void OptOutConstDispatch::handle(const kir::MBarrierInit* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::MBarrierInvalidate* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::MBarrierArrive* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::MBarrierWait* stmt) { + unhandled(stmt); +} void OptOutConstDispatch::handle(const kir::CpAsyncWait* stmt) { unhandled(stmt); } @@ -1123,6 +1167,18 @@ void OptOutDispatch::handle(kir::BlockSync* stmt) { unhandled(stmt); } void OptOutDispatch::handle(kir::GridSync* 
stmt) { unhandled(stmt); } +void OptOutDispatch::handle(kir::MBarrierInit* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::MBarrierInvalidate* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::MBarrierArrive* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::MBarrierWait* stmt) { + unhandled(stmt); +} void OptOutDispatch::handle(kir::CpAsyncWait* stmt) { unhandled(stmt); } diff --git a/csrc/dispatch.h b/csrc/dispatch.h index d5eac436595..ec55a7f064c 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -123,6 +123,10 @@ class TensorIndex; class Allocate; class BlockSync; class GridSync; +class MBarrierInit; +class MBarrierInvalidate; +class MBarrierArrive; +class MBarrierWait; class CpAsyncWait; class CpAsyncCommit; class ForLoop; @@ -209,6 +213,10 @@ class OptOutConstDispatch : public PolymorphicBase { virtual void handle(const kir::Allocate*); virtual void handle(const kir::BlockSync*); virtual void handle(const kir::GridSync*); + virtual void handle(const kir::MBarrierInit*); + virtual void handle(const kir::MBarrierInvalidate*); + virtual void handle(const kir::MBarrierArrive*); + virtual void handle(const kir::MBarrierWait*); virtual void handle(const kir::CpAsyncWait*); virtual void handle(const kir::CpAsyncCommit*); virtual void handle(const kir::InitMagicZero*); @@ -295,6 +303,10 @@ class OptOutDispatch : public PolymorphicBase { virtual void handle(kir::Allocate* stmt); virtual void handle(kir::BlockSync* stmt); virtual void handle(kir::GridSync* stmt); + virtual void handle(kir::MBarrierInit* stmt); + virtual void handle(kir::MBarrierInvalidate* stmt); + virtual void handle(kir::MBarrierArrive* stmt); + virtual void handle(kir::MBarrierWait* stmt); virtual void handle(kir::CpAsyncWait* stmt); virtual void handle(kir::CpAsyncCommit* stmt); virtual void handle(kir::InitMagicZero* stmt); diff --git a/csrc/executor.cpp b/csrc/executor.cpp index 260c2bea3d4..f7792d4cb16 100644 --- a/csrc/executor.cpp +++ 
b/csrc/executor.cpp @@ -326,6 +326,9 @@ void FusionExecutor::compileFusion( lowered_ = std::make_unique<GpuLower>(fusion, compile_params); const auto kernel = lowered_->kernel(); + for (const auto& hook : post_lowering_hooks_) { + hook(kernel); + } fusion_ = lowered_->kernel()->as<Fusion>(); fusion_id_ = ++fusion_id_counter_; diff --git a/csrc/executor.h b/csrc/executor.h index 24a23cf3ff7..4922bd162db 100644 --- a/csrc/executor.h +++ b/csrc/executor.h @@ -20,6 +20,8 @@ #include +#include <functional> + namespace nvfuser { bool shouldFillAllocationWithNan(); @@ -111,6 +113,12 @@ class FusionExecutor : public NonCopyable { return runFusion(inputs, {}, launch_constraints, compile_params, opt_code); } + // Register a post-lowering hook that is called to modify the kernel after + // lowering. The main use case is for unit tests to modify the kernel. + void registerPostLoweringHook(std::function<void(kir::Kernel*)> hook) { + post_lowering_hooks_.push_back(std::move(hook)); + } + // function to query whether a `FusionExecutor` has a compiled kernel to // execute bool isCompiled() const { @@ -456,6 +464,10 @@ class FusionExecutor : public NonCopyable { // Profiling support: kept copy of the cuda kernel std::string kernel_code_; + + // Post-lowering hooks that are called to modify the kernel after lowering. + // The main use case is for unit tests to modify the kernel. 
+ std::vector<std::function<void(kir::Kernel*)>> post_lowering_hooks_; }; } // namespace nvfuser diff --git a/csrc/executor_utils.cpp b/csrc/executor_utils.cpp index 573ae4e1d9e..5205822416c 100644 --- a/csrc/executor_utils.cpp +++ b/csrc/executor_utils.cpp @@ -46,6 +46,7 @@ #include #include #include +#include <nvfuser_resources/mbarrier.h> #include #include #include @@ -89,6 +90,7 @@ std::string kernelPreamble() { ss << nvfuser_resources::block_sync_default_cu; } ss << nvfuser_resources::grid_sync_cu; + ss << nvfuser_resources::mbarrier_cu; // Communication classes ss << nvfuser_resources::block_reduction_cu; diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp index a0d5b0bed2a..a4559f8b53c 100644 --- a/csrc/kernel_ir.cpp +++ b/csrc/kernel_ir.cpp @@ -261,6 +261,94 @@ std::string GridSync::toInlineString(int indent_size) const { NVFUSER_DEFINE_CLONE_AND_CREATE(GridSync) +MBarrierInit::MBarrierInit( + IrBuilderPasskey passkey, + Val* mbarrier, + Val* thread_count) + : Expr(passkey) { + NVF_ERROR(passkey.ir_container_ != nullptr); + NVF_CHECK(thread_count->dtype() == DataType::UInt32); + addInput(mbarrier); + addInput(thread_count); +} + +std::string MBarrierInit::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "MBarrierInit(" << mbarrier()->toString() << ", " + << threadCount()->toString() << ")\n"; + return ss.str(); +} + +std::string MBarrierInit::toInlineString(int indent_size) const { + NVF_CHECK(false, "MBarrierInit can not be printed inline"); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(MBarrierInit) + +MBarrierInvalidate::MBarrierInvalidate(IrBuilderPasskey passkey, Val* mbarrier) + : Expr(passkey) { + NVF_ERROR(passkey.ir_container_ != nullptr); + addInput(mbarrier); +} + +std::string MBarrierInvalidate::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "MBarrierInvalidate(" << mbarrier()->toString() + << ")\n"; + return ss.str(); +} + +std::string MBarrierInvalidate::toInlineString(int indent_size) const { + NVF_CHECK(false, "MBarrierInvalidate can 
not be printed inline"); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(MBarrierInvalidate) + +MBarrierArrive::MBarrierArrive( + IrBuilderPasskey passkey, + Val* state, + Val* mbarrier) + : Expr(passkey) { + NVF_ERROR(passkey.ir_container_ != nullptr); + NVF_CHECK(state->dtype() == DataType::UInt); + addInput(mbarrier); + addOutput(state); +} + +std::string MBarrierArrive::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "MBarrierArrive(" << mbarrier()->toString() << ", " + << state()->toString() << ")\n"; + return ss.str(); +} + +std::string MBarrierArrive::toInlineString(int indent_size) const { + NVF_CHECK(false, "MBarrierArrive can not be printed inline"); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(MBarrierArrive) + +MBarrierWait::MBarrierWait(IrBuilderPasskey passkey, Val* mbarrier, Val* state) + : Expr(passkey) { + NVF_ERROR(passkey.ir_container_ != nullptr); + NVF_CHECK(state->dtype() == DataType::UInt); + addInput(mbarrier); + addInput(state); +} + +std::string MBarrierWait::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "MBarrierWait(" << mbarrier()->toString() << ", " + << state()->toString() << ")\n"; + return ss.str(); +} + +std::string MBarrierWait::toInlineString(int indent_size) const { + NVF_CHECK(false, "MBarrierWait can not be printed inline"); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(MBarrierWait) + CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey, int64_t keep_stages) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); diff --git a/csrc/kernel_ir.h b/csrc/kernel_ir.h index 1046aa1f04f..b8bd670375f 100644 --- a/csrc/kernel_ir.h +++ b/csrc/kernel_ir.h @@ -37,6 +37,10 @@ class TensorIndex; class Allocate; class BlockSync; class GridSync; +class MBarrierInit; +class MBarrierInvalidate; +class MBarrierArrive; +class MBarrierWait; class CpAsyncWait; class CpAsyncCommit; class InitMagicZero; @@ -306,6 +310,97 @@ class GridSync final : public Expr { } }; +class MBarrierInit final : 
public Expr { + public: + using Expr::Expr; + explicit MBarrierInit( + IrBuilderPasskey passkey, + Val* mbarrier, + Val* thread_count); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + const char* getOpString() const override { + return "MBarrierInit"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + + Val* mbarrier() const { + return input(0); + } + + Val* threadCount() const { + return input(1); + } +}; + +class MBarrierInvalidate final : public Expr { + public: + using Expr::Expr; + explicit MBarrierInvalidate(IrBuilderPasskey passkey, Val* mbarrier); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + const char* getOpString() const override { + return "MBarrierInvalidate"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + + Val* mbarrier() const { + return input(0); + } +}; + +class MBarrierArrive final : public Expr { + public: + using Expr::Expr; + explicit MBarrierArrive(IrBuilderPasskey passkey, Val* state, Val* mbarrier); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + const char* getOpString() const override { + return "MBarrierArrive"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + + Val* state() const { + return output(0); + } + + Val* mbarrier() const { + return input(0); + } +}; + +class MBarrierWait final : public Expr { + public: + using Expr::Expr; + explicit MBarrierWait(IrBuilderPasskey passkey, Val* mbarrier, Val* state); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + const char* getOpString() const override { + return "MBarrierWait"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + + Val* mbarrier() const { + return input(0); + } + + Val* state() const { + return input(1); + } +}; + // CpAsyncWait represents wait intrinsics for cp.async 
class CpAsyncWait final : public Expr { public: diff --git a/csrc/type.cpp b/csrc/type.cpp index 02e41317785..3a3b90152cf 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -220,6 +220,10 @@ static std::string data_type2string(DataType t) { return "nvfuser_index_t"; case DataType::Int32: return "int"; + case DataType::UInt: + return "uint64_t"; + case DataType::UInt32: + return "uint32_t"; case DataType::SMemAddress: return "unsigned"; case DataType::ComplexFloat: @@ -1227,6 +1231,8 @@ std::string typePrefix(const DataType data_type) { case DataType::Index: case DataType::Int: case DataType::Int32: + case DataType::UInt: + case DataType::UInt32: case DataType::SMemAddress: return "i"; case DataType::ComplexFloat: diff --git a/csrc/type.h b/csrc/type.h index 177333621b5..d99a6666011 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -75,6 +75,8 @@ enum class PrimDataType { // Integral types Int, Int32, + UInt, + UInt32, Index, // Boolean types Bool, @@ -178,6 +180,8 @@ struct DataType { static constexpr PrimDataType Int = PrimDataType::Int; static constexpr PrimDataType Index = PrimDataType::Index; static constexpr PrimDataType Int32 = PrimDataType::Int32; + static constexpr PrimDataType UInt = PrimDataType::UInt; + static constexpr PrimDataType UInt32 = PrimDataType::UInt32; static constexpr PrimDataType Bool = PrimDataType::Bool; static constexpr PrimDataType BFloat16 = PrimDataType::BFloat16; static constexpr PrimDataType ComplexFloat = PrimDataType::ComplexFloat; @@ -258,6 +262,8 @@ inline bool isIntegralType(DataType dtype) { case DataType::Index: case DataType::Int: case DataType::Int32: + case DataType::UInt: + case DataType::UInt32: return true; default: return false; @@ -336,29 +342,16 @@ struct IsPrimitiveNativeType : std::false_type {}; template <> \ struct IsPrimitiveNativeType<native_type> : std::true_type {} -#define DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE( \ - data_type, at_type, native_type) \ - template <> \ - struct DataTypeToNativeType<data_type> { \ - using type = 
native_type; \ - }; \ - template <> \ - struct DataTypeToAtenType<data_type> { \ - static constexpr at::ScalarType type = at_type; \ - }; \ - template <> \ - struct NativeTypeToDataType<native_type> { \ - static constexpr PrimDataType type = data_type; \ - }; \ - template <> \ - struct IsPrimitiveNativeType<native_type> : std::true_type {}; \ - template <> \ - struct AtenTypeToDataType<at_type> { \ - static constexpr PrimDataType type = data_type; \ - }; \ - template <> \ - struct AtenTypeToNativeType<at_type> { \ - using type = native_type; \ +#define DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE( \ + data_type, at_type, native_type) \ + DEFINE_DATATYPE_TO_NATIVE_TYPE(data_type, native_type); \ + template <> \ + struct AtenTypeToDataType<at_type> { \ + static constexpr PrimDataType type = data_type; \ + }; \ + template <> \ + struct AtenTypeToNativeType<at_type> { \ + using type = native_type; \ } DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE( @@ -385,6 +378,8 @@ DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE( DataType::Int32, at::ScalarType::Int, int); +DEFINE_DATATYPE_TO_NATIVE_TYPE(DataType::UInt, uint64_t); +DEFINE_DATATYPE_TO_NATIVE_TYPE(DataType::UInt32, uint32_t); DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE( DataType::Bool, at::ScalarType::Bool, @@ -938,8 +933,12 @@ constexpr inline size_t primDataTypeSize(PrimDataType type) { NVF_ERROR( false, "The actual type of Index is only known at compile time."); case DataType::Int: - return sizeof(uint64_t); + return sizeof(int64_t); case DataType::Int32: + return sizeof(int32_t); + case DataType::UInt: + return sizeof(uint64_t); + case DataType::UInt32: return sizeof(uint32_t); case DataType::SMemAddress: return sizeof(unsigned); diff --git a/runtime/mbarrier.cu b/runtime/mbarrier.cu new file mode 100644 index 00000000000..79359041bfb --- /dev/null +++ b/runtime/mbarrier.cu @@ -0,0 +1,65 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +// Reference: +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#asynchronous-barrier +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier +// https://github.com/NVIDIA/cutlass/blob/main/include/cute/arch/copy_sm90_desc.hpp + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + +namespace mbarrier { + +__device__ inline void init( + uint32_t smem_barrier_ptr, + uint32_t thread_count = 1) { + asm volatile( + "mbarrier.init.shared.b64 [%0], %1;\n" ::"r"(smem_barrier_ptr), + "r"(thread_count)); +} + +__device__ inline void inval(uint32_t smem_barrier_ptr) { + asm volatile("mbarrier.inval.shared.b64 [%0];\n" ::"r"(smem_barrier_ptr)); +} + +__device__ inline uint64_t arrive(uint32_t smem_barrier_ptr) { + volatile uint64_t state; + asm volatile("mbarrier.arrive.shared.b64 %0, [%1];\n" + : "=l"(state) + : "r"(smem_barrier_ptr)); + return state; +} + +__device__ inline void wait(uint32_t smem_barrier_ptr, uint64_t state) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile( + "{\n" + ".reg .pred complete;\n" + "waitLoop:\n" + "mbarrier.try_wait.shared.b64 complete, [%0], %1;\n" + "@!complete bra waitLoop;\n" + "}\n" ::"r"(smem_barrier_ptr), + "l"(state)); +#else + asm volatile( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.test_wait.shared.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "nanosleep.u32 20;\n" + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" ::"r"(smem_barrier_ptr), + "l"(state)); +#endif +} + +} // namespace mbarrier + +#endif // (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) diff --git a/test/test_mbarrier.cpp b/test/test_mbarrier.cpp new file mode 100644 index 00000000000..3a3115c0fb9 --- /dev/null +++ b/test/test_mbarrier.cpp @@ -0,0 +1,148 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. 
+ * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include + +#include +#include +#include +#include +#include + +#include + +namespace nvfuser { + +class MBarrierTest : public NVFuserTest { + void SetUp() override { + // requires Ampere or newer + if (!deviceMajorMinorCheck(8)) { + GTEST_SKIP() << "skipping tests on pre-Ampere GPUs"; + } + NVFuserTest::SetUp(); + } +}; + +TEST_F(MBarrierTest, Simple) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeContigConcreteTensor({32, 32}); + auto tv1 = set(tv0); + auto tv2 = set(tv1); + fusion.addInput(tv0); + fusion.addOutput(tv2); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(0)->parallelize(ParallelType::TIDx); + tv1->axis(1)->parallelize(ParallelType::TIDy); + + tv2->axis(0)->parallelize(ParallelType::TIDy); + tv2->axis(1)->parallelize(ParallelType::TIDx); + + FusionExecutor fe; + + fe.registerPostLoweringHook([](kir::Kernel* kernel) { + // Replace block sync with mbarrier + FusionGuard fg(kernel); + + std::vector<Expr*>& top_level_exprs = + const_cast<std::vector<Expr*>&>(kernel->topLevelExprs()); + kir::KernelSummary& summary = + const_cast<kir::KernelSummary&>(kernel->summary()); + + // Allocate mbarrier + std::vector<const kir::Allocate*>& dynamic_smem_allocations = + summary.dynamic_smem_allocations; + ASSERT_EQ(dynamic_smem_allocations.size(), 1); + + TensorView* mbarrier = makeContigConcreteTensor({}, DataType::UInt); + mbarrier->setMemoryType(MemoryType::Shared); + kir::Allocate* mbarrier_alloc = + IrBuilder::create<kir::Allocate>(mbarrier, MemoryType::Shared); + dynamic_smem_allocations.push_back(mbarrier_alloc); + + Val* mbarrier_address = SimplifyingIrBuilder::mulExpr( + dynamic_smem_allocations.at(0)->size(), + dataTypeSize(dynamic_smem_allocations.at(0)->buffer()->dtype())); + mbarrier_alloc->setAddress(mbarrier_address); + + auto smem_alloc_it = std::find_if( + top_level_exprs.begin(), top_level_exprs.end(), [](Expr* expr) { + if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) { + if (auto tv = 
dynamic_cast<TensorView*>(alloc->buffer())) { + return tv->getMemoryType() == MemoryType::Shared; + } else { + return false; + } + } + return false; + }); + smem_alloc_it++; + ASSERT_NE(smem_alloc_it, top_level_exprs.end()); + smem_alloc_it = top_level_exprs.insert(smem_alloc_it, mbarrier_alloc); + + // Indexing mbarrier + auto mbarrier_smem_addr = IrBuilder::create<Val>(DataType::SMemAddress); + IrBuilder::create<UnaryOp>( + UnaryOpType::ToUnsignedSmemAddr, + mbarrier_smem_addr, + IrBuilder::metadataExpr(mbarrier)); + auto mbarrier_index = + IrBuilder::create<kir::TensorIndex>(mbarrier, mbarrier_smem_addr); + + // Initialize mbarrier + smem_alloc_it++; + ASSERT_NE(smem_alloc_it, top_level_exprs.end()); + auto init = IrBuilder::create<kir::MBarrierInit>( + mbarrier_index, IrBuilder::create<Val>(1024, DataType::UInt32)); + top_level_exprs.insert(smem_alloc_it, init); + + // Arrive and wait + auto sync_it = std::find_if( + top_level_exprs.begin(), top_level_exprs.end(), [](Expr* expr) { + return expr->isA<kir::BlockSync>(); + }); + ASSERT_NE(sync_it, top_level_exprs.end()); + auto state = IrBuilder::create<Val>(DataType::UInt); + auto alloc_state = IrBuilder::create<kir::Allocate>( + state, MemoryType::Local, kernel->oneVal()); + auto arrive = IrBuilder::create<kir::MBarrierArrive>(state, mbarrier_index); + auto wait = IrBuilder::create<kir::MBarrierWait>(mbarrier_index, state); + *sync_it = wait; + sync_it = top_level_exprs.insert(sync_it, arrive); + top_level_exprs.insert(sync_it, alloc_state); + + // Invalidate mbarrier + auto invalidate = + IrBuilder::create<kir::MBarrierInvalidate>(mbarrier_index); + top_level_exprs.push_back(invalidate); + }); + + fe.compileFusion(&fusion); + + // Make sure that the post-lowering hook successfully inserted all mbarrier + // operations + std::unordered_set<const std::type_info*> remaining_mbarrier_exprs{ + &typeid(kir::MBarrierInit), + &typeid(kir::MBarrierArrive), + &typeid(kir::MBarrierWait), + &typeid(kir::MBarrierInvalidate)}; + for (auto expr : fe.kernel()->topLevelExprs()) { + remaining_mbarrier_exprs.erase(&typeid(*expr)); + } + EXPECT_TRUE(remaining_mbarrier_exprs.empty()); + + auto input = at::randn( + {32, 
32}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0)); + auto outputs = fe.runFusion({input}); + + testValidate(&fusion, outputs, {input}, __LINE__, __FILE__); +} + +} // namespace nvfuser