diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index 97c039ee29cf..772007792ae6 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -602,7 +602,8 @@ class CodegenCutlass : public backend::MemoizedExprTranslatorbody.as(), 0, "nn.dense"); + const auto* dense_call = + GetRootCall(callee->body.as(), 0, std::vector{"nn.dense"}); return GenerateBody(dense_call, "cutlass_dense", GetArgumentNames(caller), DenseArgs(std::ref(attrs_))); } else if (pattern_name == "cutlass.dense_bias") { @@ -637,11 +638,12 @@ class CodegenCutlass : public backend::MemoizedExprTranslatorbody.as(), 0, "nn.batch_matmul"); + GetRootCall(callee->body.as(), 0, std::vector{"nn.batch_matmul"}); return GenerateBody(batch_matmul_call, "cutlass_batch_matmul", GetArgumentNames(caller), BatchMatmulArgs(std::ref(attrs_))); } else if (pattern_name == "cutlass.conv2d") { - const auto* conv2d_call = GetRootCall(callee->body.as(), 0, "nn.conv2d"); + const auto* conv2d_call = + GetRootCall(callee->body.as(), 0, std::vector{"nn.conv2d"}); return GenerateBody(conv2d_call, "cutlass_conv2d", GetArgumentNames(caller), Conv2dArgs(std::ref(attrs_))); } else if (pattern_name == "cutlass.conv2d_bias") { @@ -704,12 +706,13 @@ class CodegenCutlass : public backend::MemoizedExprTranslatorbody.as(), 0, "nn.conv2d_transpose"); + const auto* conv2d_call = GetRootCall(callee->body.as(), 0, + std::vector{"nn.conv2d_transpose"}); return GenerateBody(conv2d_call, "cutlass_conv2d_transpose", GetArgumentNames(caller), Conv2dArgs(std::ref(attrs_), true, false)); } else if (pattern_name == "cutlass.conv2d_backward_weight") { - const auto* conv2d_call = - GetRootCall(callee->body.as(), 0, "nn.conv2d_backward_weight"); + const auto* conv2d_call = GetRootCall(callee->body.as(), 0, + std::vector{"nn.conv2d_backward_weight"}); return GenerateBody(conv2d_call, "cutlass_conv2d_backward_weight", GetArgumentNames(caller), Conv2dArgs(std::ref(attrs_), false, true)); }