diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index 006b866d6bad..f9f332fdaccb 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -161,7 +161,8 @@ def batch_matmul_cublas(cfg, x, y, out_shape=None): """ b, m, k = x.shape b, n, k = y.shape - cfg.add_flop(b * m * k * n * 2) + if isinstance(b, int) and isinstance(m, int) and isinstance(n, int) and isinstance(k, int): + cfg.add_flop(b * m * k * n * 2) return cublas.batch_matmul(x, y, False, True)