From d8b6883347e4d68b617e7ce8f86063a10cfdfc87 Mon Sep 17 00:00:00 2001 From: Patrik Persson Date: Tue, 19 Nov 2024 15:10:36 +0100 Subject: [PATCH 1/2] fixed assert by using analyzer to the prove equality --- python/tvm/topi/scatter.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index 799b3d16733f..9cf19e2e61f4 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -18,9 +18,11 @@ """ScatterND operator""" from tvm import te, tir # hide redefinition of min and max from tvm.tir import expr +from tvm.arith.analyzer import Analyzer def _verify_scatter_nd_inputs(data, indices, updates): + analyzer = Analyzer() mdim = int(indices.shape[0]) assert mdim <= len(data.shape), ( f"The first dimension of the indices ({mdim}) must be less than or equal to " @@ -29,7 +31,8 @@ def _verify_scatter_nd_inputs(data, indices, updates): for i in range(len(indices.shape) - 1): if isinstance(indices.shape[i + 1], expr.Var) or isinstance(updates.shape[i], expr.Var): continue - assert indices.shape[i + 1] == updates.shape[i], ( + + assert analyzer.can_prove_equal(indices.shape[i + 1], updates.shape[i]), ( f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of " f"updates[{i}] ({updates.shape[i]})." ) From ead8f46255447b7a9081165dcd1a0d5ae000eac8 Mon Sep 17 00:00:00 2001 From: Patrik Persson Date: Thu, 21 Nov 2024 14:04:52 +0100 Subject: [PATCH 2/2] updated docs in Analyzer class --- python/tvm/arith/analyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index 22555e0fb3a4..f8069a717da3 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -218,7 +218,7 @@ def int_set(self, expr, dom_map): expr : PrimExpr The expression. - dom_map : Dict[Var, tvm.arith.IntSet] + dom_map : Dict[tvm.tir.Var, tvm.arith.IntSet] The domain for variables to be relaxed. Returns