Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ if(BUILD_TEST)
${NVFUSER_ROOT}/test/test_polymorphic_value.cpp
${NVFUSER_ROOT}/test/test_matmul_sass.cpp
${NVFUSER_ROOT}/test/test_matmul_scheduler.cpp
${NVFUSER_ROOT}/test/test_mbarrier.cpp
${NVFUSER_ROOT}/test/test_memory.cpp
${NVFUSER_ROOT}/test/test_gpu_view.cpp
${NVFUSER_ROOT}/test/test_gpu_transpose.cpp
Expand Down Expand Up @@ -588,6 +589,7 @@ list(APPEND NVFUSER_RUNTIME_FILES
${NVFUSER_ROOT}/runtime/grid_sync.cu
${NVFUSER_ROOT}/runtime/helpers.cu
${NVFUSER_ROOT}/runtime/index_utils.cu
${NVFUSER_ROOT}/runtime/mbarrier.cu
${NVFUSER_ROOT}/runtime/memory.cu
${NVFUSER_ROOT}/runtime/random_numbers.cu
${NVFUSER_ROOT}/runtime/tensor.cu
Expand Down
40 changes: 40 additions & 0 deletions csrc/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2890,6 +2890,46 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
indent() << sync_call << ";\n";
}

void handle(const kir::MBarrierInit* init) final {
auto call = genCall(
"mbarrier::init",
ArgumentBuilder()
.arg(genInline(init->mbarrier()->as<kir::TensorIndex>()->index()))
.arg(genInline(init->threadCount())));
indent() << call << ";\n";
}

void handle(const kir::MBarrierInvalidate* inval) final {
auto call = genCall(
"mbarrier::inval",
ArgumentBuilder().arg(
genInline(inval->mbarrier()->as<kir::TensorIndex>()->index())));
indent() << call << ";\n";
}

void handle(const kir::MBarrierArrive* arrive) final {
if (!print_inline_) {
indent() << gen(arrive->state()) << " = ";
}
auto call = genCall(
"mbarrier::arrive",
ArgumentBuilder().arg(
genInline(arrive->mbarrier()->as<kir::TensorIndex>()->index())));
code_ << call;
if (!print_inline_) {
code_ << ";\n";
}
}

void handle(const kir::MBarrierWait* wait) final {
auto call = genCall(
"mbarrier::wait",
ArgumentBuilder()
.arg(genInline(wait->mbarrier()->as<kir::TensorIndex>()->index()))
.arg(genInline(wait->state())));
indent() << call << ";\n";
}

void handle(const kir::InitMagicZero*) final {
indent() << "NVFUSER_DEFINE_MAGIC_ZERO;\n";
}
Expand Down
56 changes: 56 additions & 0 deletions csrc/dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,22 @@ void Expr::dispatch(T handler, Expr* expr) {
ptr(handler)->handle(expr->as<kir::GridSync>());
return;
}
if (expr->isStrictlyA<kir::MBarrierInit>()) {
ptr(handler)->handle(expr->as<kir::MBarrierInit>());
return;
}
if (expr->isStrictlyA<kir::MBarrierInvalidate>()) {
ptr(handler)->handle(expr->as<kir::MBarrierInvalidate>());
return;
}
if (expr->isStrictlyA<kir::MBarrierArrive>()) {
ptr(handler)->handle(expr->as<kir::MBarrierArrive>());
return;
}
if (expr->isStrictlyA<kir::MBarrierWait>()) {
ptr(handler)->handle(expr->as<kir::MBarrierWait>());
return;
}
if (expr->isStrictlyA<kir::CpAsyncWait>()) {
ptr(handler)->handle(expr->as<kir::CpAsyncWait>());
return;
Expand Down Expand Up @@ -532,6 +548,22 @@ void Expr::constDispatch(T handler, const Expr* expr) {
ptr(handler)->handle(expr->as<kir::GridSync>());
return;
}
if (expr->isStrictlyA<kir::MBarrierInit>()) {
ptr(handler)->handle(expr->as<kir::MBarrierInit>());
return;
}
if (expr->isStrictlyA<kir::MBarrierInvalidate>()) {
ptr(handler)->handle(expr->as<kir::MBarrierInvalidate>());
return;
}
if (expr->isStrictlyA<kir::MBarrierArrive>()) {
ptr(handler)->handle(expr->as<kir::MBarrierArrive>());
return;
}
if (expr->isStrictlyA<kir::MBarrierWait>()) {
ptr(handler)->handle(expr->as<kir::MBarrierWait>());
return;
}
if (expr->isStrictlyA<kir::CpAsyncWait>()) {
ptr(handler)->handle(expr->as<kir::CpAsyncWait>());
return;
Expand Down Expand Up @@ -914,6 +946,18 @@ void OptOutConstDispatch::handle(const kir::BlockSync* stmt) {
void OptOutConstDispatch::handle(const kir::GridSync* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::MBarrierInit* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::MBarrierInvalidate* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::MBarrierArrive* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::MBarrierWait* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::CpAsyncWait* stmt) {
unhandled(stmt);
}
Expand Down Expand Up @@ -1123,6 +1167,18 @@ void OptOutDispatch::handle(kir::BlockSync* stmt) {
void OptOutDispatch::handle(kir::GridSync* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::MBarrierInit* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::MBarrierInvalidate* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::MBarrierArrive* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::MBarrierWait* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::CpAsyncWait* stmt) {
unhandled(stmt);
}
Expand Down
12 changes: 12 additions & 0 deletions csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ class TensorIndex;
class Allocate;
class BlockSync;
class GridSync;
class MBarrierInit;
class MBarrierInvalidate;
class MBarrierArrive;
class MBarrierWait;
class CpAsyncWait;
class CpAsyncCommit;
class ForLoop;
Expand Down Expand Up @@ -209,6 +213,10 @@ class OptOutConstDispatch : public PolymorphicBase {
virtual void handle(const kir::Allocate*);
virtual void handle(const kir::BlockSync*);
virtual void handle(const kir::GridSync*);
virtual void handle(const kir::MBarrierInit*);
virtual void handle(const kir::MBarrierInvalidate*);
virtual void handle(const kir::MBarrierArrive*);
virtual void handle(const kir::MBarrierWait*);
virtual void handle(const kir::CpAsyncWait*);
virtual void handle(const kir::CpAsyncCommit*);
virtual void handle(const kir::InitMagicZero*);
Expand Down Expand Up @@ -295,6 +303,10 @@ class OptOutDispatch : public PolymorphicBase {
virtual void handle(kir::Allocate* stmt);
virtual void handle(kir::BlockSync* stmt);
virtual void handle(kir::GridSync* stmt);
virtual void handle(kir::MBarrierInit* stmt);
virtual void handle(kir::MBarrierInvalidate* stmt);
virtual void handle(kir::MBarrierArrive* stmt);
virtual void handle(kir::MBarrierWait* stmt);
virtual void handle(kir::CpAsyncWait* stmt);
virtual void handle(kir::CpAsyncCommit* stmt);
virtual void handle(kir::InitMagicZero* stmt);
Expand Down
3 changes: 3 additions & 0 deletions csrc/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,9 @@ void FusionExecutor::compileFusion(
lowered_ = std::make_unique<GpuLower>(fusion, compile_params);

const auto kernel = lowered_->kernel();
for (const auto& hook : post_lowering_hooks_) {
hook(kernel);
}
fusion_ = lowered_->kernel()->as<Fusion>();

fusion_id_ = ++fusion_id_counter_;
Expand Down
12 changes: 12 additions & 0 deletions csrc/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

#include <c10/core/DeviceType.h>

#include <functional>

namespace nvfuser {

bool shouldFillAllocationWithNan();
Expand Down Expand Up @@ -111,6 +113,12 @@ class FusionExecutor : public NonCopyable {
return runFusion(inputs, {}, launch_constraints, compile_params, opt_code);
}

// Register a post-lowering hooks that are called to modify the kernel after
// lowering. The main use case is for unit tests to modify the kernel.
void registerPostLoweringHook(std::function<void(kir::Kernel*)> hook) {
post_lowering_hooks_.push_back(std::move(hook));
}

// function to query whether a `FusionExecutor` has a compiled kernel to
// execute
bool isCompiled() const {
Expand Down Expand Up @@ -456,6 +464,10 @@ class FusionExecutor : public NonCopyable {

// Profiling support: kept copy of the cuda kernel
std::string kernel_code_;

// Post-lowering hooks that are called to modify the kernel after lowering.
// The main use case is for unit tests to modify the kernel.
std::vector<std::function<void(kir::Kernel*)>> post_lowering_hooks_;
};

} // namespace nvfuser
2 changes: 2 additions & 0 deletions csrc/executor_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#include <nvfuser_resources/grid_sync.h>
#include <nvfuser_resources/helpers.h>
#include <nvfuser_resources/index_utils.h>
#include <nvfuser_resources/mbarrier.h>
#include <nvfuser_resources/memory.h>
#include <nvfuser_resources/random_numbers.h>
#include <nvfuser_resources/tensor.h>
Expand Down Expand Up @@ -89,6 +90,7 @@ std::string kernelPreamble() {
ss << nvfuser_resources::block_sync_default_cu;
}
ss << nvfuser_resources::grid_sync_cu;
ss << nvfuser_resources::mbarrier_cu;

// Communication classes
ss << nvfuser_resources::block_reduction_cu;
Expand Down
88 changes: 88 additions & 0 deletions csrc/kernel_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,94 @@ std::string GridSync::toInlineString(int indent_size) const {

NVFUSER_DEFINE_CLONE_AND_CREATE(GridSync)

MBarrierInit::MBarrierInit(
IrBuilderPasskey passkey,
Val* mbarrier,
Val* thread_count)
: Expr(passkey) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_CHECK(thread_count->dtype() == DataType::UInt32);
addInput(mbarrier);
addInput(thread_count);
}

std::string MBarrierInit::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "MBarrierInit(" << mbarrier()->toString() << ", "
<< threadCount()->toString() << ")\n";
return ss.str();
}

std::string MBarrierInit::toInlineString(int indent_size) const {
NVF_CHECK(false, "MBarrierInit can not be printed inline");
}

NVFUSER_DEFINE_CLONE_AND_CREATE(MBarrierInit)

MBarrierInvalidate::MBarrierInvalidate(IrBuilderPasskey passkey, Val* mbarrier)
: Expr(passkey) {
NVF_ERROR(passkey.ir_container_ != nullptr);
addInput(mbarrier);
}

std::string MBarrierInvalidate::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "MBarrierInvalidate(" << mbarrier()->toString()
<< ")\n";
return ss.str();
}

std::string MBarrierInvalidate::toInlineString(int indent_size) const {
NVF_CHECK(false, "MBarrierInvalidate can not be printed inline");
}

NVFUSER_DEFINE_CLONE_AND_CREATE(MBarrierInvalidate)

MBarrierArrive::MBarrierArrive(
IrBuilderPasskey passkey,
Val* state,
Val* mbarrier)
: Expr(passkey) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_CHECK(state->dtype() == DataType::UInt);
addInput(mbarrier);
addOutput(state);
}

std::string MBarrierArrive::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "MBarrierArrive(" << mbarrier()->toString() << ", "
<< state()->toString() << ")\n";
return ss.str();
}

std::string MBarrierArrive::toInlineString(int indent_size) const {
NVF_CHECK(false, "MBarrierArrive can not be printed inline");
}

NVFUSER_DEFINE_CLONE_AND_CREATE(MBarrierArrive)

MBarrierWait::MBarrierWait(IrBuilderPasskey passkey, Val* mbarrier, Val* state)
: Expr(passkey) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_CHECK(state->dtype() == DataType::UInt);
addInput(mbarrier);
addInput(state);
}

std::string MBarrierWait::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "MBarrierWait(" << mbarrier()->toString() << ", "
<< state()->toString() << ")\n";
return ss.str();
}

std::string MBarrierWait::toInlineString(int indent_size) const {
NVF_CHECK(false, "MBarrierWait can not be printed inline");
}

NVFUSER_DEFINE_CLONE_AND_CREATE(MBarrierWait)

CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey, int64_t keep_stages)
: Expr(passkey) {
NVF_ERROR(passkey.ir_container_ != nullptr);
Expand Down
Loading