diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/contrib/cublas.py
index e5bc55c32751..faec63e1907a 100644
--- a/python/tvm/relax/backend/contrib/cublas.py
+++ b/python/tvm/relax/backend/contrib/cublas.py
@@ -21,6 +21,7 @@
 
 import tvm
 from tvm import DataType
+from tvm.arith import Analyzer
 from tvm.relax import transform
 from tvm.relax.transform import PatternCheckContext
 
@@ -119,6 +120,8 @@ def _check_matmul(context: PatternCheckContext) -> bool:
             # cuBLAS only supports bias vector
             return False
 
+    analyzer = Analyzer()
+
     # cuBLASLt does not seem to support batched GEMM with one of matrices having
     # one batch (with batch_stride 0). So for batched GEMM, the two batch counts
     # must be equal. If lhs is batched but rhs is not, we can use the regular GEMM by
@@ -126,7 +129,7 @@ def _check_matmul(context: PatternCheckContext) -> bool:
     return (
         isinstance(lhs_batches, tvm.tir.Var)
         or isinstance(rhs_batches, tvm.tir.Var)
-        or (int(lhs_batches) == int(rhs_batches))
+        or (analyzer.can_prove_equal(lhs_batches, rhs_batches))
        or (lhs_batches >= 1 and rhs_batches == 1)
     )
 
diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py
index 4ff498ae2b93..1b19c891863f 100644
--- a/tests/python/relax/test_codegen_cublas.py
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -149,6 +149,8 @@ def get_relax_matmul_dequantize_module(
         ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu"),
         # ND x ND
         ((5, 3, 32, 8), (5, 3, 8, 10), True, "relu"),
+        ((_vars["a"], 3, 32, 8), (_vars["a"], 3, 8, 10), True, "relu"),
+        ((_vars["a"], _vars["b"], 32, 8), (_vars["a"], _vars["b"], 8, 10), True, "relu"),
         # ND x 2D
         ((5, 3, 32, 8), (8, 10), False, "none"),
     ],
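
Reviewer sketch (not part of the patch): the new test shapes such as `(_vars["a"], 3, 32, 8)` presumably make the inferred batch count a compound expression (e.g. `a * 3`) rather than a bare `tir.Var`, so the `isinstance(..., tvm.tir.Var)` guard no longer short-circuits and the old `int(lhs_batches) == int(rhs_batches)` comparison would raise. A minimal standalone illustration of the `Analyzer` behavior the check now relies on:

```python
import tvm
from tvm import tir
from tvm.arith import Analyzer

analyzer = Analyzer()
a = tir.Var("a", "int64")

# can_prove_equal handles symbolic expressions that are
# structurally different but provably equal.
assert analyzer.can_prove_equal(a * 3, a + a + a)

# int() conversion only works for constant expressions ...
assert int(tir.IntImm("int64", 5)) == 5

# ... and raises for anything symbolic, which is what the old
# int()-based comparison would hit once a batch count is an
# expression like `a * 3` instead of a bare Var.
try:
    int(a * 3)
except TypeError:
    pass  # PrimExpr is not directly convertible to int
```

Since `can_prove_equal` also returns True for equal constants, the new check appears to subsume the old integer comparison while additionally covering the symbolic-batch cases exercised by the new tests.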