diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index eee5d9a685b3..80f1fe1765c3 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -839,7 +839,7 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target): ) else: strategy.add_implementation( - wrap_compute_batch_matmul(topi.cuda.batch_matmul), + wrap_compute_batch_matmul(topi.cuda.batch_matmul, need_out_dtype=True), wrap_topi_schedule(topi.cuda.schedule_batch_matmul), name="batch_matmul.cuda", plevel=10,