From e9f42a09810835b56ec0c66ba08c61c65d2404c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=A7=84=EB=B0=B0=20=EB=B0=95?= Date: Fri, 26 Jan 2024 17:13:02 +0900 Subject: [PATCH] [Unity][Cutlass] Fix C source generation of dense operation This commit fixes an issue that generates wrong c sources of dense operation using cutlass. --- python/tvm/contrib/cutlass/gen_tensor_op.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 298d7895722c..2f21a1d313e2 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -566,7 +566,10 @@ def get_flattened_batch_dim(arg_name, batch_rank): transposed = "transposed" in func_name or "dense" in func_name lhs_arg_idx = _get_optional_int_annotation(annotations, "lhs_arg_idx", 0) rhs_arg_idx = _get_optional_int_annotation(annotations, "rhs_arg_idx", 1) - bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", None) + if "bias" in func_name: + bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", 2) + else: + bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", None) residual_arg_idx = _get_optional_int_annotation(annotations, "residual_arg_idx", None) lhs_arg = func_args[lhs_arg_idx]