Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 0 additions & 52 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -736,40 +736,7 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type
#endif // TVM_LLVM_VERSION
}

// Match a warp shuffle intrinsic call against its corresponding nvvm
// intrinsic. On success the nvvm intrinsic id is written to *id and true
// is returned; otherwise *id is left untouched and false is returned.
static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) {
  // Only 32 bit scalar data types are supported.
  const bool is_scalar_32bit = !op->dtype.is_vector() && op->dtype.bits() == 32;
  if (!is_scalar_32bit) {
    return false;
  }

  // It is difficult to emit the _sync version that works on Pascal.
  // We ignore the mask and only pick the non-sync version for nvptx.
  const bool is_float = op->dtype.is_float();
  if (op->is_intrinsic(intrinsic::tvm_warp_shuffle)) {
    *id = is_float ? llvm::Intrinsic::nvvm_shfl_idx_f32 : llvm::Intrinsic::nvvm_shfl_idx_i32;
  } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_up)) {
    *id = is_float ? llvm::Intrinsic::nvvm_shfl_up_f32 : llvm::Intrinsic::nvvm_shfl_up_i32;
  } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_down)) {
    *id = is_float ? llvm::Intrinsic::nvvm_shfl_down_f32 : llvm::Intrinsic::nvvm_shfl_down_i32;
  } else {
    // Not a warp shuffle call at all.
    return false;
  }
  return true;
}

llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;

if (op->is_intrinsic("llvm_intrin")) {
CHECK_GE(op->args.size(), 2U);
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
Expand Down Expand Up @@ -814,25 +781,6 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
}
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
return CreateStorageSync(op);
} else if (GetWarpShuffleIntrinsic(op, &id)) {
std::vector<llvm::Value*> arg_value;
std::vector<llvm::Type*> arg_type;
// Ignore the first mask operand and remove the last
// redundant warp_size..
size_t n_args = op->args.size() - 1;
for (size_t i = 1; i < n_args; ++i) {
arg_value.push_back(MakeValue(op->args[i]));
arg_type.push_back(arg_value.back()->getType());
}
llvm::Type* return_type = arg_type[0];
llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type);
return builder_->CreateCall(func, arg_value);
} else if (op->is_intrinsic(intrinsic::tvm_warp_activemask)) {
// Only nvptx target may keep this intrinsic at this point.
// PTX assembly: asm "activemask.b32 r1;"
auto fty = llvm::FunctionType::get(t_int32_, false);
auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true);
return builder_->CreateCall(val);
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const LoadNode* l = op->args[0].as<LoadNode>();
CHECK(op->args.size() == 1 && l);
Expand Down
62 changes: 61 additions & 1 deletion src/target/llvm/codegen_nvptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ class CodeGenNVPTX : public CodeGenLLVM {
CodeGenLLVM::Optimize();
}

llvm::Value* CreateIntrinsic(const CallNode* op) override;

protected:
void InitTarget(llvm::TargetMachine* tm) final {
// Maximum vector lane = float4
Expand All @@ -178,6 +180,62 @@ class CodeGenNVPTX : public CodeGenLLVM {
}
};

// Match a warp shuffle intrinsic call against its corresponding nvvm
// intrinsic. On success the nvvm intrinsic id is written to *id and true
// is returned; otherwise *id is left untouched and false is returned.
static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) {
  // Only 32 bit scalar data types are supported.
  const bool is_scalar_32bit = !op->dtype.is_vector() && op->dtype.bits() == 32;
  if (!is_scalar_32bit) {
    return false;
  }

  // It is difficult to emit the _sync version that works on Pascal.
  // We ignore the mask and only pick the non-sync version for nvptx.
  const bool is_float = op->dtype.is_float();
  if (op->is_intrinsic(intrinsic::tvm_warp_shuffle)) {
    *id = is_float ? llvm::Intrinsic::nvvm_shfl_idx_f32 : llvm::Intrinsic::nvvm_shfl_idx_i32;
  } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_up)) {
    *id = is_float ? llvm::Intrinsic::nvvm_shfl_up_f32 : llvm::Intrinsic::nvvm_shfl_up_i32;
  } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_down)) {
    *id = is_float ? llvm::Intrinsic::nvvm_shfl_down_f32 : llvm::Intrinsic::nvvm_shfl_down_i32;
  } else {
    // Not a warp shuffle call at all.
    return false;
  }
  return true;
}

// Lower the intrinsics that need nvptx-specific handling (warp shuffles and
// tvm_warp_activemask); every other call is forwarded unchanged to
// CodeGenLLVM::CreateIntrinsic.
llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) {
  llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;
  if (GetWarpShuffleIntrinsic(op, &id)) {
    // The first operand (mask) and the last operand (redundant warp_size)
    // are dropped. At least one operand must remain between them, because
    // the return type is taken from the first forwarded operand; without
    // this guard `size() - 1` could underflow and `arg_type[0]` could read
    // an empty vector.
    CHECK_GE(op->args.size(), 3U);
    const size_t n_args = op->args.size() - 1;

    std::vector<llvm::Value*> arg_value;
    std::vector<llvm::Type*> arg_type;
    arg_value.reserve(n_args - 1);
    arg_type.reserve(n_args - 1);
    for (size_t i = 1; i < n_args; ++i) {
      arg_value.push_back(MakeValue(op->args[i]));
      arg_type.push_back(arg_value.back()->getType());
    }
    // The result type is taken from the first forwarded operand.
    llvm::Type* return_type = arg_type[0];
    llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type);
    return builder_->CreateCall(func, arg_value);
  } else if (op->is_intrinsic(intrinsic::tvm_warp_activemask)) {
    // Only nvptx target may keep this intrinsic at this point.
    // PTX assembly: asm "activemask.b32 r1;"
    auto fty = llvm::FunctionType::get(t_int32_, false);
    auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true);
    return builder_->CreateCall(val);
  }
  // Not nvptx-specific: defer to the generic LLVM lowering.
  return CodeGenLLVM::CreateIntrinsic(op);
}

inline int DetectCUDAComputeVersion() {
TVMContext tvm_ctx;
tvm_ctx.device_type = kDLGPU;
Expand All @@ -204,8 +262,10 @@ runtime::Module BuildNVPTX(IRModule mod, std::string target) {
config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_" << compute_ver
<< target.substr(5, target.length() - 5);
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
// careful: cg will hold a naked pointer reference to ctx, so it should
// have a shorter lifetime than the ctx.
std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());

cg->Init("TVMPTXModule", tm.get(), ctx.get(), false, false);

Expand Down
3 changes: 3 additions & 0 deletions topi/python/topi/cuda/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def schedule_softmax(outs):
def sched_warp_softmax():
    """Return True when the warp-based softmax schedule should be used.

    Reads ``tgt`` and ``softmax`` from the enclosing scope.
    """
    target_name = tgt.target_name
    if target_name == "nvptx":
        # On nvptx only float32 / int32 outputs take the warp schedule.
        dtype = softmax.dtype
        return dtype == "float32" or dtype == "int32"
    # this is used as the gpu schedule for other arches which may not have
    # warp reductions, so only cuda qualifies here.
    return target_name == "cuda"

if len(softmax.shape) > 2:
Expand Down