diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 298d7895722c..2f21a1d313e2 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -566,7 +566,10 @@ def get_flattened_batch_dim(arg_name, batch_rank): transposed = "transposed" in func_name or "dense" in func_name lhs_arg_idx = _get_optional_int_annotation(annotations, "lhs_arg_idx", 0) rhs_arg_idx = _get_optional_int_annotation(annotations, "rhs_arg_idx", 1) - bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", None) + if "bias" in func_name: + bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", 2) + else: + bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", None) residual_arg_idx = _get_optional_int_annotation(annotations, "residual_arg_idx", None) lhs_arg = func_args[lhs_arg_idx]