From ef6919f33a596fd20db06fafbc5f411318e0e361 Mon Sep 17 00:00:00 2001 From: Rick Zhou Date: Wed, 8 May 2024 16:41:39 -0400 Subject: [PATCH 1/3] [Unity][BYOC] Use arith.Analyzer to check batch equality of matmul in cublas --- python/tvm/relax/backend/contrib/cublas.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/contrib/cublas.py index e5bc55c32751..aaaf335d8bed 100644 --- a/python/tvm/relax/backend/contrib/cublas.py +++ b/python/tvm/relax/backend/contrib/cublas.py @@ -21,6 +21,7 @@ import tvm from tvm import DataType +from tvm.arith import Analyzer from tvm.relax import transform from tvm.relax.transform import PatternCheckContext @@ -118,6 +119,8 @@ def _check_matmul(context: PatternCheckContext) -> bool: if not isinstance(bias_batches, (tvm.tir.expr.IntImm, int)) or int(bias_batches) > 1: # cuBLAS only supports bias vector return False + + analyzer = Analyzer() # cuBLASLt does not seem to support batched GEMM with one of matrices having # one batch (with batch_stride 0). So for batched GEMM, the two batch counts @@ -126,7 +129,7 @@ def _check_matmul(context: PatternCheckContext) -> bool: return ( isinstance(lhs_batches, tvm.tir.Var) or isinstance(rhs_batches, tvm.tir.Var) - or (int(lhs_batches) == int(rhs_batches)) + or (analyzer.can_prove_equal(lhs_batches, rhs_batches)) or (lhs_batches >= 1 and rhs_batches == 1) ) From b735e4e9c5a5f2004298e1ec202ef9f90cabe386 Mon Sep 17 00:00:00 2001 From: Rick Zhou Date: Wed, 8 May 2024 17:05:53 -0400 Subject: [PATCH 2/3] Fix lint, add unit tests --- python/tvm/relax/backend/contrib/cublas.py | 2 +- tests/python/relax/test_codegen_cublas.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/contrib/cublas.py index aaaf335d8bed..faec63e1907a 100644 --- a/python/tvm/relax/backend/contrib/cublas.py +++ b/python/tvm/relax/backend/contrib/cublas.py @@ -119,7 +119,7 @@ def _check_matmul(context: PatternCheckContext) -> bool: if not isinstance(bias_batches, (tvm.tir.expr.IntImm, int)) or int(bias_batches) > 1: # cuBLAS only supports bias vector return False - + analyzer = Analyzer() # cuBLASLt does not seem to support batched GEMM with one of matrices having diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py index 4ff498ae2b93..1b19c891863f 100644 --- a/tests/python/relax/test_codegen_cublas.py +++ b/tests/python/relax/test_codegen_cublas.py @@ -149,6 +149,8 @@ def get_relax_matmul_dequantize_module( ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu"), # ND x ND ((5, 3, 32, 8), (5, 3, 8, 10), True, "relu"), + ((_vars["a"], 3, 32, 8), (_vars["a"], 3, 8, 10), True, "relu"), + ((_vars["a"], _vars["b"], 32, 8), (_vars["a"], _vars["b"], 8, 10), True, "relu"), # ND x 2D ((5, 3, 32, 8), (8, 10), False, "none"), ], From 96a1dd56e7ccfe9737d079ad116f6fc20fa86937 Mon Sep 17 00:00:00 2001 From: Rick Zhou Date: Wed, 8 May 2024 22:13:25 -0400 Subject: [PATCH 3/3] Trigger Build