From 5ae5962e9700bbd18bde80ab1731233e53efe35b Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 8 Jul 2023 23:41:28 -0700 Subject: [PATCH] [Unity][TIR] Allow symbolic bounds in IndexMap analysis Following #15264, this PR makes changes accordingly to the Unity branch to enable symbolic bounds in IndexMap analysis. --- src/relax/analysis/layout_transformation.cc | 3 ++- src/relax/op/tensor/manipulate.cc | 3 ++- src/relax/transform/alter_op_impl.cc | 12 ++++++++---- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/relax/analysis/layout_transformation.cc b/src/relax/analysis/layout_transformation.cc index 44538fea98e5..8348365761fa 100644 --- a/src/relax/analysis/layout_transformation.cc +++ b/src/relax/analysis/layout_transformation.cc @@ -22,6 +22,7 @@ * \brief Analyze the PrimFunc and suggest layout transformation on it's blocks and buffers based on * the user provided layout transformations on it's outputs. */ +#include #include #include #include @@ -172,8 +173,8 @@ static bool AreIdenticalTransforms(const IndexMap& t0, const IndexMap& t1) { // Create a new shape expression. Array t1_initial_indices = t1->initial_indices.Map([](tir::Var i) -> PrimExpr { return i; }); - auto t0_output = t0->MapIndices(t1_initial_indices); arith::Analyzer analyzer; + auto t0_output = t0->MapIndices(t1_initial_indices, &analyzer); for (size_t i = 0; i < t0_output.size(); ++i) { if (!analyzer.CanProveEqual(t0_output[i], t1->final_indices[i])) return false; } diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 5b298110be55..a55d19982224 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -474,7 +474,8 @@ StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size()); } - Array output_shape = index_map->MapShape(shape_sinfo->values.value()); + arith::Analyzer analyzer; + Array output_shape = index_map->MapShape(shape_sinfo->values.value(), &analyzer); return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype); } diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc index 1fadb86d715c..f40ee3b3bf48 100644 --- a/src/relax/transform/alter_op_impl.cc +++ b/src/relax/transform/alter_op_impl.cc @@ -23,6 +23,7 @@ * identify PrimFuncs to be replaced. Marks the new PrimFuncs with kFrozenLayout attribute set to * true. */ +#include #include #include #include @@ -60,9 +61,9 @@ static IndexMap DeepCopyIndexMap(const IndexMap& index_map) { bool IsTransformBijective(const Expr& expr, const IndexMap& transform) { Array input_shape = GetShapeFromTensor(expr); Array initial_ranges = ConstructRangeFromShape(input_shape); - auto [inverse, padding_predicate] = transform.NonSurjectiveInverse(initial_ranges); - (void)inverse; // to avoid unused variable warning; arith::Analyzer analyzer; + auto [inverse, padding_predicate] = transform.NonSurjectiveInverse(initial_ranges, &analyzer); + (void)inverse; // to avoid unused variable warning; if (!analyzer.CanProve(!padding_predicate)) return false; return true; } @@ -169,7 +170,9 @@ class AlterOpImplMutator : public ExprMutator { const TensorStructInfo& old_tensor_sinfo) { Array old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo); Array initial_ranges = ConstructRangeFromShape(old_shape); - auto [inverse_index_map, padding_predicate] = index_map.NonSurjectiveInverse(initial_ranges); + arith::Analyzer analyzer; + auto [inverse_index_map, padding_predicate] = + index_map.NonSurjectiveInverse(initial_ranges, &analyzer); ICHECK(tir::is_zero(padding_predicate)) << "Only bijective transformations on input/output buffers are supported, but found " "padding predicate " @@ -245,7 +248,8 @@ class AlterOpImplMutator : public ExprMutator { /*! \brief Returns the TensorStructInfo after applying the \p transform on its shape */ StructInfo UpdateStructInfo(const TensorStructInfo& tensor_sinfo, const IndexMap& transform) { auto shape = GetShapeFromTensorStructInfo(tensor_sinfo); - auto new_shape = transform->MapShape(shape); + arith::Analyzer analyzer; + auto new_shape = transform->MapShape(shape, &analyzer); return TensorStructInfo(ShapeExpr(new_shape), tensor_sinfo->dtype); }