Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions src/relay/backend/contrib/cutlass/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,8 @@ class CodegenCutlass : public backend::MemoizedExprTranslator<std::vector<Output
ICHECK(pattern_name.defined()) << "Only functions with composite attribute are supported.";

if (pattern_name == "cutlass.dense") {
const auto* dense_call = GetRootCall(callee->body.as<CallNode>(), 0, "nn.dense");
const auto* dense_call =
GetRootCall(callee->body.as<CallNode>(), 0, std::vector<std::string>{"nn.dense"});
return GenerateBody(dense_call, "cutlass_dense", GetArgumentNames(caller),
DenseArgs(std::ref(attrs_)));
} else if (pattern_name == "cutlass.dense_bias") {
Expand Down Expand Up @@ -637,11 +638,12 @@ class CodegenCutlass : public backend::MemoizedExprTranslator<std::vector<Output
DenseArgs(std::ref(attrs_)));
} else if (pattern_name == "cutlass.batch_matmul") {
const auto* batch_matmul_call =
GetRootCall(callee->body.as<CallNode>(), 0, "nn.batch_matmul");
GetRootCall(callee->body.as<CallNode>(), 0, std::vector<std::string>{"nn.batch_matmul"});
return GenerateBody(batch_matmul_call, "cutlass_batch_matmul", GetArgumentNames(caller),
BatchMatmulArgs(std::ref(attrs_)));
} else if (pattern_name == "cutlass.conv2d") {
const auto* conv2d_call = GetRootCall(callee->body.as<CallNode>(), 0, "nn.conv2d");
const auto* conv2d_call =
GetRootCall(callee->body.as<CallNode>(), 0, std::vector<std::string>{"nn.conv2d"});
return GenerateBody(conv2d_call, "cutlass_conv2d", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
} else if (pattern_name == "cutlass.conv2d_bias") {
Expand Down Expand Up @@ -704,12 +706,13 @@ class CodegenCutlass : public backend::MemoizedExprTranslator<std::vector<Output
return GenerateBody(conv2d_call, pattern_name.value(), GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
} else if (pattern_name == "cutlass.conv2d_transpose") {
const auto* conv2d_call = GetRootCall(callee->body.as<CallNode>(), 0, "nn.conv2d_transpose");
const auto* conv2d_call = GetRootCall(callee->body.as<CallNode>(), 0,
std::vector<std::string>{"nn.conv2d_transpose"});
return GenerateBody(conv2d_call, "cutlass_conv2d_transpose", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_), true, false));
} else if (pattern_name == "cutlass.conv2d_backward_weight") {
const auto* conv2d_call =
GetRootCall(callee->body.as<CallNode>(), 0, "nn.conv2d_backward_weight");
const auto* conv2d_call = GetRootCall(callee->body.as<CallNode>(), 0,
std::vector<std::string>{"nn.conv2d_backward_weight"});
return GenerateBody(conv2d_call, "cutlass_conv2d_backward_weight", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_), false, true));
}
Expand Down