diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index 3a0441ef84af..d74d5d2a845a 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -262,7 +262,17 @@ def simplify(expr): out : Expr or int The simplified output """ - return tvm.arith.Analyzer().simplify(expr) if isinstance(expr, tvm.tir.PrimExpr) else expr + if isinstance(expr, te.Tensor): + return te.compute( + expr.shape, + lambda *indices: tvm.arith.Analyzer().simplify(expr[indices]), + name="simplify_output", + tag="simplify", + ) + elif isinstance(expr, tvm.tir.PrimExpr): + return tvm.arith.Analyzer().simplify(expr) + else: + return expr def ravel_index(indices, shape):