From 0aa51155f91dfe7e89fcc9eafeb359f4766babae Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Mon, 30 Mar 2020 00:53:01 -0700 Subject: [PATCH 1/7] [arith] linear system and equation solver Co-authored-by: Sergei Grechanik --- include/tvm/arith/linear_system.h | 171 +++++++ include/tvm/arith/util.h | 45 ++ python/tvm/arith/__init__.py | 1 + python/tvm/arith/linear_system.py | 99 ++++ src/arith/linear_system.cc | 85 ++++ src/arith/solve_linear_equation.cc | 471 ++++++++++++++++++ src/arith/util.cc | 53 ++ .../test_arith_solve_linear_system.py | 91 ++++ 8 files changed, 1016 insertions(+) create mode 100644 include/tvm/arith/linear_system.h create mode 100644 include/tvm/arith/util.h create mode 100644 python/tvm/arith/linear_system.py create mode 100644 src/arith/linear_system.cc create mode 100644 src/arith/solve_linear_equation.cc create mode 100644 src/arith/util.cc create mode 100644 tests/python/unittest/test_arith_solve_linear_system.py diff --git a/include/tvm/arith/linear_system.h b/include/tvm/arith/linear_system.h new file mode 100644 index 000000000000..f7f33039698a --- /dev/null +++ b/include/tvm/arith/linear_system.h @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tvm/arith/linear_system.h + * \brief Linear system data structures and solvers + */ +#ifndef TVM_ARITH_LINEAR_SYSTEM_H_ +#define TVM_ARITH_LINEAR_SYSTEM_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace arith { + +using tir::Var; +using tir::VarNode; +using tir::IterVar; + +/*! + * \brief Represent a linear system including variables, their ranges and + * the linear relations between them (either equations or inequalities) + * \sa LinearSystem + */ +class LinearSystemNode : public Object { + public: + // e.g., \alpha, \beta + Array variables; + // e.g., 1 <= \alpha <= N, etc. + Map ranges; + // linear equalities or inequalities + // e.g., A \alpha = \beta or A \alpha <= \beta + Array relations; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("variables", &variables); + v->Visit("ranges", &ranges); + v->Visit("relations", &relations); + } + + static constexpr const char* _type_key = "arith.LinearSystem"; + TVM_DECLARE_FINAL_OBJECT_INFO(LinearSystemNode, Object); +}; + +/*! + * \brief Managed reference to LinearSystemNode. + * \sa LinearSystemNode + */ +class LinearSystem : public ObjectRef { + public: + /*! + * \brief Constructor by fields + * \param variables The variables in the system. + * \param ranges The ranges of the variables. + * \param relations The linear relations between the variables + * (either equations or inequalities) + */ + TVM_DLL LinearSystem(Array variables, + Map ranges, + Array relations); + + TVM_DEFINE_OBJECT_REF_METHODS(LinearSystem, ObjectRef, LinearSystemNode); +}; + +/*! + * \brief We can have different set of variables to represent the same linear system. + * For example, the following two systems are equivalent, + * {a + b = 0 | a >= 0, b >= 0} and + * {m - n = 0 | m >= 0, n <= 0} + * This data structure represents the transformation + * between two equivalent linear systems. 
+ * In the above example, + * src : {a + b = 0 | a >= 0, b >= 0} + * dst : {m - n = 0 | m >= 0, n <= 0} + * src_to_dst : {a -> m, b -> -n} + * dst_to_src : {m -> a, n -> -b} + * \sa LinearSystemTransform + */ +class LinearSystemTransformNode : public Object { + public: + LinearSystem src; + LinearSystem dst; + Map src_to_dst; + Map dst_to_src; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("src", &src); + v->Visit("dst", &dst); + v->Visit("src_to_dst", &src_to_dst); + v->Visit("dst_to_src", &dst_to_src); + } + + static constexpr const char* _type_key = "arith.LinearSystemTransform"; + TVM_DECLARE_FINAL_OBJECT_INFO(LinearSystemTransformNode, Object); +}; + +/*! + * \brief Managed reference to LinearSystemTransformNode. + * \sa LinearSystemTransformNode + */ +class LinearSystemTransform : public ObjectRef { + public: + /*! + * \brief Constructor by fields + * \param src source linear system, e.g., {a + b = 0 | a >= 0, b >= 0} + * \param dst linear system equivalent to the source, e.g., {m - n = 0 | m >= 0, n <= 0} + * \param src_to_dst mapping from variables in the \p src to the variables in the \p dst, + * e.g., {a -> m, b -> -n} + * \param dst_to_src mapping from variables in the \p dst to the variables in the \p src, + * e.g., {m -> a, n -> -b} + */ + TVM_DLL LinearSystemTransform(LinearSystem src, + LinearSystem dst, + Map src_to_dst, + Map dst_to_src); + + TVM_DEFINE_OBJECT_REF_METHODS(LinearSystemTransform, ObjectRef, LinearSystemTransformNode); +}; + +/*! + * \brief Obtain Smith Normal Form of linear equation A x = y. + * Smith Normal Form of matrix A_{mxn} is S_{mxn} = U_{mxm} A_{mxn} V_{nxn}, + * in which S_{mxn} is diag(s1, s2, ..., sr, 0, ..., 0), + * such that si | s_{i+1} and r is the rank of A. + * U_{mxm} and V_{nxn} are invertible matrices. + * This function modifies \p S to be S_{mxn}, \p V to be V_{nxn}, + * \p y to be U_{mxm} y_{mx1} and \p x to be V^{-1} x. 
+ * \param S the original A_{mxn}, it will be modified to S_{mxn} + * \param V an identity matrix, it will be modified to V_{nxn} + * \param x the x in A x = y. it will be modified to V^{-1}_{nxn} x_{nx1} + * \param y the y in A x = y. it will be modified to U_{mxm} y_{mx1} + */ +void SmithNormalFormDiag(std::vector> *S, + std::vector> *V, + std::vector* x, + std::vector *y); + +/*! + * \brief Solve linear equations. + * \param system_to_solve the variables to solve, their ranges, and a list of equations. + * \return A new linear system, with less variables (if \p system_to_solve is NOT of full rank), + * or no variable (if \p system_to_solve is of full rank), + * or an empty linear system (if \p system_to_solve is unsolvable). + * It also provides the ranges of the variables in the new system, + * as well as inequalities inferred from the \p system_to_solve. + * You can get the mapping from the original variables to the solution via ret->src_to_dst. + */ +LinearSystemTransform SolveEquations(const LinearSystem& system_to_solve); + +} // namespace arith +} // namespace tvm +#endif // TVM_ARITH_LINEAR_SYSTEM_H_ diff --git a/include/tvm/arith/util.h b/include/tvm/arith/util.h new file mode 100644 index 000000000000..adfcefcd2e21 --- /dev/null +++ b/include/tvm/arith/util.h @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/arith/util.h + * \brief Utils for arithmetic analysis. + */ +#ifndef TVM_ARITH_UTIL_H_ +#define TVM_ARITH_UTIL_H_ + +#include +#include + +namespace tvm { +/*! \brief namespace of arithmetic analysis. */ +namespace arith { + +/*! + * \brief Calculate the extended greatest common divisor for two values. + * See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm. + * \param a an integer number + * \param b an integer number + * \return 3 integers (div, m, n) where div = gcd(a, b) and a*m + b*n = div + */ +std::tuple xgcd(int64_t a, int64_t b); + +} // namespace arith +} // namespace tvm +#endif // TVM_ARITH_UTIL_H_ diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index 40e977e61d75..25f30bc60e26 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -20,3 +20,4 @@ from .analyzer import ModularSet, ConstIntBound, Analyzer from .bound import deduce_bound from .pattern import detect_linear_equation, detect_clip_bound +from .linear_system import solve_equations diff --git a/python/tvm/arith/linear_system.py b/python/tvm/arith/linear_system.py new file mode 100644 index 000000000000..86c6a0f76ba3 --- /dev/null +++ b/python/tvm/arith/linear_system.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
"""Linear system data structures and solvers"""
import tvm._ffi
from tvm.runtime import Object
from . import _ffi_api


@tvm._ffi.register_object("arith.LinearSystem")
class LinearSystem(Object):
    """Represent a linear system including variables, their ranges and
    the linear relations between them (either equations or inequalities)

    Parameters
    ----------
    variables : List[tvm.tir.Var]
        The variables in the system.
    ranges : Map[tvm.tir.Var, tvm.ir.Range]
        The ranges of the variables.
    relations : List[tvm.ir.PrimExpr]
        The linear relations between the variables (either equations or inequalities)
    """
    def __init__(self, variables, ranges, relations):
        self.__init_handle_by_constructor__(
            _ffi_api.LinearSystem, variables, ranges, relations)


@tvm._ffi.register_object("arith.LinearSystemTransform")
class LinearSystemTransform(Object):
    """We can have different sets of variables to represent the same linear system.
    For example, the following two systems are equivalent,
    {a + b = 0 | a >= 0, b >= 0} and
    {m - n = 0 | m >= 0, n <= 0}
    This data structure represents the transformation
    between two equivalent linear systems.
    In the above example,
    src        : {a + b = 0 | a >= 0, b >= 0}
    dst        : {m - n = 0 | m >= 0, n <= 0}
    src_to_dst : {a -> m, b -> -n}
    dst_to_src : {m -> a, n -> -b}

    Parameters
    ----------
    src : arith.LinearSystem
        source linear system, e.g., {a + b = 0 | a >= 0, b >= 0}
    dst : arith.LinearSystem
        linear system equivalent to the source, e.g., {m - n = 0 | m >= 0, n <= 0}
    src_to_dst : Map[tvm.tir.Var, tvm.ir.PrimExpr]
        mapping from variables in the src to the variables in the dst,
        e.g., {a -> m, b -> -n}
    dst_to_src : Map[tvm.tir.Var, tvm.ir.PrimExpr]
        mapping from variables in the dst to the variables in the src,
        e.g., {m -> a, n -> -b}
    """
    def __init__(self, src, dst, src_to_dst, dst_to_src):
        self.__init_handle_by_constructor__(
            _ffi_api.LinearSystemTransform, src, dst, src_to_dst, dst_to_src)


def solve_equations(equations, variables=None, ranges=None):
    """Solve linear equations.

    Parameters
    ----------
    equations : List[tvm.ir.PrimExpr] or LinearSystem
        Either the linear relations between the variables (equations or
        inequalities), or an already constructed LinearSystem. In the latter
        case `variables` and `ranges` are ignored.
    variables : List[tvm.tir.Var], optional
        The variables in the system. Needed when `equations` is a list.
    ranges : Map[tvm.tir.Var, tvm.ir.Range], optional
        The ranges of the variables. Needed when `equations` is a list.

    Returns
    -------
    linear_system_transform : LinearSystemTransform
        A new linear system, with fewer variables (if the problem is NOT of full rank),
        or no variable (if the problem is of full rank),
        or an empty linear system (if the problem is unsolvable).
        It also provides the ranges of the variables in the new system,
        as well as inequalities inferred from the problem.
        You can get the mapping from the original variables to the solution via
        linear_system_transform.src_to_dst.
    """
    # The single-argument form of the C++ entry point takes the system to
    # solve (a LinearSystem), not a LinearSystemTransform (which is the
    # *result* type of this function).
    if isinstance(equations, LinearSystem):
        return _ffi_api.SolveEquations(equations)
    return _ffi_api.SolveEquations(variables, ranges, equations)
+ */ +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace arith { + +LinearSystem::LinearSystem(Array variables, + Map ranges, + Array relations) { + ObjectPtr node = make_object(); + node->variables = std::move(variables); + node->ranges = std::move(ranges); + node->relations = std::move(relations); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(LinearSystemNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "LinearSystem(" + << op->variables + << ", " << op->ranges + << ", " << op->relations + << ")"; + }); + + +LinearSystemTransform::LinearSystemTransform(LinearSystem src, + LinearSystem dst, + Map src_to_dst, + Map dst_to_src) { + ObjectPtr node = make_object(); + node->src = std::move(src); + node->dst = std::move(dst); + node->src_to_dst = std::move(src_to_dst); + node->dst_to_src = std::move(dst_to_src); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(LinearSystemTransformNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "LinearSystemTransform(" + << "\n\t" << op->src + << "\n\t" << op->dst + << "\n\t" << op->src_to_dst + << "\n\t" << op->dst_to_src + << "\n)"; + }); + +} // namespace arith +} // namespace tvm diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc new file mode 100644 index 000000000000..480f5e69a1fa --- /dev/null +++ b/src/arith/solve_linear_equation.cc @@ -0,0 +1,471 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/arith/linear_solver.cc + * \brief Solve linear equations. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace arith { + +using namespace tvm::runtime; + +void SmithNormalFormDiag(std::vector >* S, + std::vector >* V, + std::vector* x, + std::vector* y) { + if (S->empty() || V->empty()) return; + size_t m = S->size(); + size_t n = (*S)[0].size(); // n is # of variables + CHECK_EQ(V->size(), n); + CHECK_EQ((*V)[0].size(), n); + + for (size_t index = 0; index < std::min(m, n); ++index) { + // Here A is partially diagonalized, that is A[i, j] is zero for all i, j + // such that (i < index) or (j < index), unless (i == j). 
+ // That is, now we are diagonalizing the submatrix with i >= index and j >= index + + // Find a row with a nonzero element in the index-th column + // (We also prefer rows where this element has minimal abs value) + size_t best_i = index; + for (size_t i = best_i; i < m; ++i) { + int64_t s_old = (*S)[best_i][index]; + int64_t s_new = (*S)[i][index]; + if (s_new != 0) { + if (s_old == 0 || std::abs(s_new) < std::abs(s_old)) { + best_i = i; + } + } + } + // Move the row we found to the index-th position + std::swap((*S)[index], (*S)[best_i]); + std::swap((*y)[index], (*y)[best_i]); + + // If the index-th diagonal element is still zero, try to find a column with nonzero index-th + // element and move it to the index-th position + if ((*S)[index][index] == 0) { + for (size_t j = index + 1; j < n; ++j) { + if ((*S)[index][j] != 0) { + for (size_t i = index; i < m; ++i) { + std::swap((*S)[i][index], (*S)[i][j]); + } + // swapping columns corresponds to swapping the corresponding x + std::swap((*x)[index], (*x)[j]); + for (size_t i = 0; i < n; ++i) { + std::swap((*V)[i][index], (*V)[i][j]); + } + break; + } + } + } + + // If the index-th diagonal element is still zero, then both the index-th row and the index-th + // column are completely zero, and we don't need to do anything; just go to the next index + if ((*S)[index][index] == 0) { + continue; + } + + // Now the index-th diagonal element is non-zero and we can zero all the index-th column + // below it by subtracting rows from each other + for (auto i = index + 1; i < m; ++i) { + if ((*S)[i][index] != 0) { + int64_t g, a, b; + // g = a*matrix[index][index] + b*matrix[i][index] + if ((*S)[i][index] % (*S)[index][index] != 0) { + std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[i][index]); + } else { + // Explicitly avoid changing the index-th row. This is important to avoid infinite loop. 
+ g = (*S)[index][index]; + a = 1; + b = 0; + } + + // Let m = S[index][index], n = S[i][index], then the following is true: + // + // [ a n/g ][ m/g n/g ] = [ 1 0 ] + // [ b -m/g ][ b -a ] = [ 0 1 ] + // + // Note that the two matrices are integer (since g = gcd(m, n)). + // We will essentially multiply our matrix on the left by a dilated and transposed version + // of the first of these two matrices. The second matrix is not needed here, however we will + // use it while zeroing the index-th row. + + int64_t m_g = (*S)[index][index] / g; + int64_t n_g = (*S)[i][index] / g; + + // Note that j is the index of the column, not the row + for (size_t j = index; j < (*S)[i].size(); ++j) { + // Multiply index-th row by a and add the i-th row multiplied by b + // This will make the index-th diagonal element equal to the gcd + int64_t new_index_j = a*(*S)[index][j] + b*(*S)[i][j]; + // This transformation performs zeroing of matrix[i][index] + int64_t new_i_j = n_g*(*S)[index][j] - m_g*(*S)[i][j]; + (*S)[index][j] = new_index_j; + (*S)[i][j] = new_i_j; + } + // We have to do the same with rhs + PrimExpr ea = te::make_const((*y)[index].dtype(), a); + PrimExpr eb = te::make_const((*y)[i].dtype(), b); + PrimExpr e_m_g = te::make_const((*y)[i].dtype(), m_g); + PrimExpr e_n_g = te::make_const((*y)[index].dtype(), n_g); + PrimExpr new_index_rhs = ea*(*y)[index] + eb*(*y)[i]; + PrimExpr new_i_rhs = e_n_g*(*y)[index] - e_m_g*(*y)[i]; + (*y)[index] = new_index_rhs; + (*y)[i] = new_i_rhs; + } + } + + bool changed = false; + + // Now we have to zero the elements of the index-th row by manipulating columns. + // This is more difficult because column manipulation corresponds to variable manipulation, + // but the algorithm is essentially the same as before. 
+ for (size_t j = index + 1; j < n; ++j) { + if ((*S)[index][j] != 0) { + int64_t g, a, b; + // g = a*matrix[index][index] + b*matrix[index][j] + if ((*S)[index][j] % (*S)[index][index] != 0) { + std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[index][j]); + // During this phase we may disrupt the zeroness of the index-th column, so we will + // have to take some action if this might have happened. + changed = true; + } else { + // Explicitly avoid changing the index-th column. This is important to avoid infinite + // loop. Note that here we don't have to set `changed` to true since we don't change the + // index-th column. + g = (*S)[index][index]; + a = 1; + b = 0; + } + + // Let m = S[index][index], n = S[index][j], then the following is true: + // + // [ a n/g ][ m/g n/g ] = [ 1 0 ] + // [ b -m/g ][ b -a ] = [ 0 1 ] + // + // Now we are going to multiply our matrix on the right (to manipulate columns instead of + // rows), we will also transform the old_to_new matrix the same way, and we will use the + // second matrix to transform new_to_old. + + int64_t m_g = (*S)[index][index] / g; + int64_t n_g = (*S)[index][j] / g; + + for (size_t i = index; i < m; ++i) { + int64_t new_i_index = a*(*S)[i][index] + b*(*S)[i][j]; + int64_t new_i_j = n_g*(*S)[i][index] - m_g*(*S)[i][j]; + (*S)[i][index] = new_i_index; + (*S)[i][j] = new_i_j; + } + // We do exactly the same transformations with V + for (size_t i = 0; i < n; ++i) { + int64_t new_i_index = a*(*V)[i][index] + b*(*V)[i][j]; + int64_t new_i_j = n_g*(*V)[i][index] - m_g*(*V)[i][j]; + (*V)[i][index] = new_i_index; + (*V)[i][j] = new_i_j; + } + // And apply reverse transformations to new_to_old. 
+ PrimExpr ea = te::make_const((*x)[j].dtype(), a); + PrimExpr eb = te::make_const((*x)[index].dtype(), b); + PrimExpr e_m_g = te::make_const((*x)[index].dtype(), m_g); + PrimExpr e_n_g = te::make_const((*x)[j].dtype(), n_g); + PrimExpr new_index = e_m_g*(*x)[index] + e_n_g*(*x)[j]; + PrimExpr new_j = eb*(*x)[index] - ea*(*x)[j]; + (*x)[index] = new_index; + (*x)[j] = new_j; + } + } + + if (changed) { + // We might have changed the first column, so we have to zero it once more + // (or at least check if it's zero), so just perform this iteration once more. + index -= 1; + } + } +} + +Map InferRange(const Map& vars_to_infer, + const Array& ori_vars, + const Map& ori_ranges) { + // The resulting ranges + Map new_ranges; + + std::unordered_set ori_vset; + for (const Var& v : ori_vars) { + ori_vset.insert(v.get()); + } + + std::unordered_map var_intsets; + for (const auto& p : ori_ranges) { + if (!ori_vset.count(p.first.get())) { + // First of all, fill the new ranges with outer variable ranges + new_ranges.Set(p.first, p.second); + } + // Convert original ranges to IntSets + var_intsets[p.first.get()] = IntSet::range(p.second); + } + + // Infer ranges for the new variables and add them to the resulting ranges + for (const auto& p : vars_to_infer) { + const auto& var = p.first; + const auto& expr = p.second; + Range range = EvalSet(expr, var_intsets).cover_range(Range()); + if (range.defined()) { + new_ranges.Set(var, range); + } + } + return new_ranges; +} + +// pretty print matrix equation +void DebugPrint(const std::vector>& S, + const std::vector>& V, + const std::vector& V_inv_x, + const std::vector& rhs) { + std::cout << "S:\n"; + for (size_t i = 0; i < S.size(); ++i) { + for (auto e : S[i]) { + std::cout << e << "\t"; + } + std::cout << "\t->\t" << rhs[i]; + std::cout << "\n"; + } + std::cout << "V:\n"; + for (const auto& r : V) { + for (auto e : r) { + std::cout << e << "\t"; + } + std::cout << "\n"; + } + std::cout << "V_inv x:\n" << Array(V_inv_x); + 
std::cout << "\n" << std::endl; +} + +LinearSystemTransform SolveEquations(const LinearSystem& system_to_solve) { + // m: # of equations + // n: # of variables + // we first construct A_{mxn} x_{nx1} = y_{mx1} + // then get Smith normal form of matrix A, + // S_{mxn} = U_{mxm} A_{mxn} V_{nxn} + // => U^{-1} S V^{-1} x = y + // S V^{-1} x = U y + std::vector Uy; // mx1 + std::vector> S; // mxn + std::vector> V; // nxn + std::vector V_inv_x; // V^{-1} x, nx1 + // Conditions we don't know what to do with + std::vector rest; + + size_t num_vars = system_to_solve->variables.size(); + + // initialize V_{nxn} with identity matrix, + // initialize V^{-1} x as x + for (size_t i = 0; i < num_vars; ++i) { + V.emplace_back(num_vars); + V.back()[i] = 1; + V_inv_x.push_back(system_to_solve->variables[i]); + } + + // Transform formulas into rows of the matrix + // S_{mxn} V^{-1}_{nxn} x_{nx1} = U y, in which n is # of variables + // here we initialize S_{mxn} to be A, U to be identity matrix. + for (const PrimExpr& equation : system_to_solve->relations) { + if (const tir::EQNode* eq = equation.as()) { + // a-b = sum_{i=0}^{n-1} variables[i] * coeff[i] + coeff[n] + Array coeffs = arith::DetectLinearEquation( + tir::Simplify(eq->a - eq->b, system_to_solve->ranges), + system_to_solve->variables); + if (!coeffs.empty()) { + std::vector row; + for (size_t j = 0; j < coeffs.size() - 1; ++j) { + PrimExpr c = coeffs[j]; + if (const IntImmNode* ic = c.as()) { + row.push_back(ic->value); + } else { + // elements in matrix S V must be integers + // ignore equations that we cannot deal with. 
+ row.clear(); + break; + } + } + + if (!row.empty()) { + // S V (a-b) = Uy + // V is identity for now + S.push_back(row); + Uy.push_back(-coeffs[coeffs.size() - 1]); + continue; + } + } + } + + // otherwise + rest.push_back(equation); + } + + // After diagonalizing, we have + // S_{mxn} is the Smith normal form (diagonal matrix) + // V_{nxn} is invertible + // however, to simplify the calculation, we modify inplace so that + // x' = V^{-1} x + SmithNormalFormDiag(&S, &V, &V_inv_x, &Uy); + + Array new_vars; + Array new_relations; + Map new_to_old_map; + Map old_to_new_map; + + // Simplify right hand sides + for (PrimExpr r : Uy) { + r = tir::Simplify(r, system_to_solve->ranges); + } + + // Create the relations of the existence of a solution + for (size_t j = 0; j < S.size(); ++j) { + PrimExpr new_relation; + if (j >= num_vars || S[j][j] == 0) { + // The row of matrix is zero. A solution exists only if the Ub[j] is also zero + new_relation = (Uy[j] == 0); + } else { + // The diagonal element is non-zero. A solution exists only if the diagonal element + // is a divisor of the Ub[j] + new_relation = (floormod(Uy[j], std::abs(S[j][j])) == 0); + } + new_relation = tir::Simplify(new_relation, system_to_solve->ranges); + if (tir::is_const_int(new_relation, 0)) { + // unable to solve the system. 
+ return LinearSystemTransform( + system_to_solve, + LinearSystem( + /*variables=*/{}, + /*ranges=*/{}, + /*relations=*/{te::make_zero(DataType::Bool())}), + {}, {}); + } else if (!tir::is_const_int(new_relation, 1)) { + new_relations.push_back(new_relation); + } + } + + Array solution_for_V_inv_x; + // Now create new variables or directly solve the equations + // suppose the rank of A is r, aka r = # of non-zeros in S + // the solution of S_{mxn} V^{-1}_{nxn} x_{nx1} = U b + // is + // x = (pseudo-inverse of A) b + K_{n, n-r} z_{n-r} + // = V_{n,n} S^{-1}_{n,m} (Ub)_{mxn} + K_{n, n-r} z_{n-r} + // in which K is the right n-r columns of V, z is variable vector + // thus, + // V^{-1} x = S^{-1}_{n,m} (Ub)_{mxn} + + // [[0, ... 0]_{n-r}, ... [0, ..., 0], diag(1, ..., 1)_{(n-r)x(n-r)}] z_{n-r} + for (size_t j = 0; j < num_vars; ++j) { + if (j >= S.size() || S[j][j] == 0) { + // The j-th variable can take any integer value, create a tvm variable for it + PrimExpr to_old = tir::Simplify(V_inv_x[j], system_to_solve->ranges); + std::string name_hint = "n" + std::to_string(new_vars.size()); + if (const VarNode* v_old = to_old.as()) { + name_hint += "_" + v_old->name_hint; + } + Var v = Var(name_hint, V_inv_x[j].dtype()); + solution_for_V_inv_x.push_back(v); + new_vars.push_back(v); + new_to_old_map.Set(v, to_old); + } else { + // The j-th variable is just a single value, don't create a tvm variable + // S^{-1}_{nxm} Uy_{mxn} + if (S[j][j] >= 0) { + PrimExpr a = te::make_const(Uy[j].dtype(), S[j][j]); + solution_for_V_inv_x.push_back( + tir::Simplify(floordiv(Uy[j], a), system_to_solve->ranges)); + } else { + // This is required because some simplifiers + // have problems with dividing by negative numbers + PrimExpr a = te::make_const(Uy[j].dtype(), -S[j][j]); + solution_for_V_inv_x.push_back( + tir::Simplify(floordiv(-Uy[j], a), system_to_solve->ranges)); + } + } + } + + // V V^{-1} x = x + for (size_t i = 0; i < num_vars; ++i) { + PrimExpr e = 
te::make_zero(system_to_solve->variables[i].dtype()); + for (size_t j = 0; j < num_vars; ++j) { + e = e + te::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j]; + } + e = tir::Simplify(e); + old_to_new_map.Set(system_to_solve->variables[i], e); + } + + // The resulting ranges + Map new_ranges = InferRange( + new_to_old_map, system_to_solve->variables, system_to_solve->ranges); + + // We have to transform ranges of the old variables into relations over new variables because + // new ranges are not enough usually. + for (const auto& p : system_to_solve->ranges) { + const Var& old_var = p.first; + const Range& old_range = p.second; + if (old_to_new_map.count(old_var)) { + PrimExpr express_by_new_vars = old_to_new_map[old_var]; + PrimExpr lower_cond = tir::Simplify( + old_range->min <= express_by_new_vars, new_ranges); + PrimExpr upper_cond = tir::Simplify( + express_by_new_vars < old_range->min + old_range->extent, new_ranges); + if (!tir::is_const_int(lower_cond, 1)) { + new_relations.push_back(lower_cond); + } + if (!tir::is_const_int(upper_cond, 1)) { + new_relations.push_back(upper_cond); + } + } + } + + // Add the rest conditions + for (const PrimExpr& cond : rest) { + new_relations.push_back(Substitute(cond, old_to_new_map)); + } + + LinearSystem solution(new_vars, new_ranges, new_relations); + LinearSystemTransform transform( + system_to_solve, solution, old_to_new_map, new_to_old_map); + + return transform; +} + +TVM_REGISTER_GLOBAL("arith.SolveEquations") +.set_body([](TVMArgs args, TVMRetValue *ret) { + if (args.size() == 1) { + *ret = SolveEquations(args[0]); + } else if (args.size() == 3) { + LinearSystem problem(args[0], args[1], args[2]); + *ret = SolveEquations(problem); + } + }); + +} // namespace arith +} // namespace tvm diff --git a/src/arith/util.cc b/src/arith/util.cc new file mode 100644 index 000000000000..058c3e959528 --- /dev/null +++ b/src/arith/util.cc @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + 
* or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file util.cc + * \brief The utils for arithmetic analysis. + */ +#include +#include + +namespace tvm { +namespace arith { + +std::tuple xgcd(int64_t a, int64_t b) { + int64_t s = 0, old_s = 1; + int64_t t = 1, old_t = 0; + int64_t r = b, old_r = a; + + while (r != 0) { + int64_t q = old_r / r; + std::swap(r, old_r); + r -= q * old_r; + std::swap(s, old_s); + s -= q * old_s; + std::swap(t, old_t); + t -= q * old_t; + } + + CHECK_EQ(a % old_r, 0); + CHECK_EQ(b % old_r, 0); + CHECK(old_r == old_s*a + old_t*b); + + return std::make_tuple(old_r, old_s, old_t); +} + +} // namespace arith +} // namespace tvm diff --git a/tests/python/unittest/test_arith_solve_linear_system.py b/tests/python/unittest/test_arith_solve_linear_system.py new file mode 100644 index 000000000000..11b63c259ffd --- /dev/null +++ b/tests/python/unittest/test_arith_solve_linear_system.py @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import te, arith +from tvm.tir import ir_pass + + +def test_unique_solution(): + x, y = te.var("x"), te.var("y") + ranges = {} + + solution = arith.solve_equations([ + tvm.tir.EQ(x + y, 20), + tvm.tir.EQ(x - y, 10), + ], [x, y], ranges) + assert list(solution.dst.variables) == [] + assert ir_pass.Equal(solution.src_to_dst[x], 15) + assert ir_pass.Equal(solution.src_to_dst[y], 5) + + +def test_low_rank(): + x, y, z = te.var("x"), te.var("y"), te.var("z") + ranges = {} + + solution = arith.solve_equations([ + tvm.tir.EQ(x + y + z, 15), + tvm.tir.EQ(x + y, 10), + ], [x, y, z], ranges) + [n0] = solution.dst.variables + assert ir_pass.Equal(solution.src_to_dst[x], n0 + 10) + assert ir_pass.Equal(solution.src_to_dst[y], -n0) + assert ir_pass.Equal(solution.src_to_dst[z], 5) + + +def test_infer_range(): + x, y = te.var("x"), te.var("y") + ranges = { + x: tvm.ir.Range.make_by_min_extent(-5, 10), + y: tvm.ir.Range.make_by_min_extent(0, 10), + } + + solution = arith.solve_equations([ + tvm.tir.EQ(x + y, 0), + ], [x, y], ranges) + [n0] = solution.dst.variables + assert ir_pass.Equal(solution.src_to_dst[x], n0) + assert ir_pass.Equal(solution.src_to_dst[y], -n0) + # inferred from y's range + assert ir_pass.Equal(solution.dst.ranges[n0].min, -9) + assert ir_pass.Equal(solution.dst.ranges[n0].extent, 10) + # additional inequality is added into the system for x + [ineq] = 
solution.dst.relations + assert isinstance(ineq, tvm.tir.LE) + assert ir_pass.Equal(ineq.a, -5) + assert ir_pass.Equal(ineq.b, n0) + + +def test_ill_formed(): + x, y = te.var("x"), te.var("y") + + solution = arith.solve_equations([ + tvm.tir.EQ(x + y, 0), + tvm.tir.EQ(x - y, 0), + tvm.tir.EQ(x, 5), + ], [x, y], {}) + assert list(solution.dst.variables) == [] + [rel] = solution.dst.relations + assert ir_pass.Equal(rel, False) + assert len(solution.src_to_dst) == 0 + assert len(solution.dst_to_src) == 0 + + +if __name__ == "__main__": + test_unique_solution() + test_low_rank() + test_infer_range() + test_ill_formed() From 42317d94e5bb0cf9409a699953860b1e7299508f Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 31 Mar 2020 16:19:25 -0700 Subject: [PATCH 2/7] avoid constructing analyzer every time --- include/tvm/arith/analyzer.h | 6 ++++++ python/tvm/arith/linear_system.py | 6 +++--- src/arith/analyzer.cc | 5 +++++ src/arith/solve_linear_equation.cc | 31 +++++++++++++++++------------- 4 files changed, 32 insertions(+), 16 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 1889e16fef66..3a71e5eb5fbf 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -423,6 +423,12 @@ class Analyzer { * \param range The range we bind to. */ void Bind(const Var& var, const Range& range); + /*! + * \brief Bind all the vars in the Map + * + * \param variables The {variable -> range} map. + */ + void Bind(const Map& variables); /*! * \brief Whether can we prove expr >= val. 
diff --git a/python/tvm/arith/linear_system.py b/python/tvm/arith/linear_system.py index 86c6a0f76ba3..d1e6f8502be1 100644 --- a/python/tvm/arith/linear_system.py +++ b/python/tvm/arith/linear_system.py @@ -76,8 +76,8 @@ def solve_equations(equations, variables, ranges): Parameters ---------- - equations: List[tvm.ir.PrimExpr] or LinearSystemTransform - The linear relations between the variables (either equations or inequalities) + equations: List[tvm.ir.PrimExpr] or LinearSystem + The equations of the variables variables : List[tvm.tir.Var] The variables in the system. ranges : Map[tvm.tir.Var, tvm.ir.Range] @@ -94,6 +94,6 @@ def solve_equations(equations, variables, ranges): You can get the mapping from the original variables to the solution via linear_system_transform.src_to_dst. """ - if isinstance(equations, LinearSystemTransform): + if isinstance(equations, LinearSystem): return _ffi_api.SolveEquations(equations) return _ffi_api.SolveEquations(variables, ranges, equations) diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 9df5aa2d246d..83dfc64009cf 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -58,6 +58,11 @@ void Analyzer::Bind(const Var& var, const Range& range) { // skip rewrite simplify } +void Analyzer::Bind(const Map& variables) { + for (const auto& iter : variables) { + this->Bind(iter.first, iter.second); + } +} void ConstraintContext::EnterWithScope() { CHECK(exit_ == nullptr); diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index 480f5e69a1fa..87979587aca3 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -281,6 +281,9 @@ LinearSystemTransform SolveEquations(const LinearSystem& system_to_solve) { // Conditions we don't know what to do with std::vector rest; + Analyzer analyzer_problem; + analyzer_problem.Bind(system_to_solve->ranges); + size_t num_vars = system_to_solve->variables.size(); // initialize V_{nxn} with identity matrix, @@ 
-298,7 +301,7 @@ LinearSystemTransform SolveEquations(const LinearSystem& system_to_solve) { if (const tir::EQNode* eq = equation.as()) { // a-b = sum_{i=0}^{n-1} variables[i] * coeff[i] + coeff[n] Array coeffs = arith::DetectLinearEquation( - tir::Simplify(eq->a - eq->b, system_to_solve->ranges), + analyzer_problem.Simplify(eq->a - eq->b), system_to_solve->variables); if (!coeffs.empty()) { std::vector row; @@ -315,7 +318,7 @@ LinearSystemTransform SolveEquations(const LinearSystem& system_to_solve) { } if (!row.empty()) { - // S V (a-b) = Uy + // S V^{-1} (a-b) = Uy // V is identity for now S.push_back(row); Uy.push_back(-coeffs[coeffs.size() - 1]); @@ -331,8 +334,8 @@ LinearSystemTransform SolveEquations(const LinearSystem& system_to_solve) { // After diagonalizing, we have // S_{mxn} is the Smith normal form (diagonal matrix) // V_{nxn} is invertible - // however, to simplify the calculation, we modify inplace so that - // x' = V^{-1} x + // V_inv_x is V^{-1} \times x + // Uy is U \times y SmithNormalFormDiag(&S, &V, &V_inv_x, &Uy); Array new_vars; @@ -342,7 +345,7 @@ LinearSystemTransform SolveEquations(const LinearSystem& system_to_solve) { // Simplify right hand sides for (PrimExpr r : Uy) { - r = tir::Simplify(r, system_to_solve->ranges); + r = analyzer_problem.Simplify(r); } // Create the relations of the existence of a solution @@ -356,7 +359,7 @@ LinearSystemTransform SolveEquations(const LinearSystem& system_to_solve) { // is a divisor of the Ub[j] new_relation = (floormod(Uy[j], std::abs(S[j][j])) == 0); } - new_relation = tir::Simplify(new_relation, system_to_solve->ranges); + new_relation = analyzer_problem.Simplify(new_relation); if (tir::is_const_int(new_relation, 0)) { // unable to solve the system. 
return LinearSystemTransform( @@ -385,7 +388,7 @@ LinearSystemTransform SolveEquations(const LinearSystem& system_to_solve) { for (size_t j = 0; j < num_vars; ++j) { if (j >= S.size() || S[j][j] == 0) { // The j-th variable can take any integer value, create a tvm variable for it - PrimExpr to_old = tir::Simplify(V_inv_x[j], system_to_solve->ranges); + PrimExpr to_old = analyzer_problem.Simplify(V_inv_x[j]); std::string name_hint = "n" + std::to_string(new_vars.size()); if (const VarNode* v_old = to_old.as()) { name_hint += "_" + v_old->name_hint; @@ -400,7 +403,7 @@ LinearSystemTransform SolveEquations(const LinearSystem& system_to_solve) { if (S[j][j] >= 0) { PrimExpr a = te::make_const(Uy[j].dtype(), S[j][j]); solution_for_V_inv_x.push_back( - tir::Simplify(floordiv(Uy[j], a), system_to_solve->ranges)); + analyzer_problem.Simplify(floordiv(Uy[j], a))); } else { // This is required because some simplifiers // have problems with dividing by negative numbers @@ -417,13 +420,15 @@ LinearSystemTransform SolveEquations(const LinearSystem& system_to_solve) { for (size_t j = 0; j < num_vars; ++j) { e = e + te::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j]; } - e = tir::Simplify(e); + e = analyzer_problem.Simplify(e); old_to_new_map.Set(system_to_solve->variables[i], e); } // The resulting ranges Map new_ranges = InferRange( new_to_old_map, system_to_solve->variables, system_to_solve->ranges); + Analyzer analyzer_solution; + analyzer_solution.Bind(new_ranges); // We have to transform ranges of the old variables into relations over new variables because // new ranges are not enough usually. 
@@ -432,10 +437,10 @@ LinearSystemTransform SolveEquations(const LinearSystem& system_to_solve) { const Range& old_range = p.second; if (old_to_new_map.count(old_var)) { PrimExpr express_by_new_vars = old_to_new_map[old_var]; - PrimExpr lower_cond = tir::Simplify( - old_range->min <= express_by_new_vars, new_ranges); - PrimExpr upper_cond = tir::Simplify( - express_by_new_vars < old_range->min + old_range->extent, new_ranges); + PrimExpr lower_cond = analyzer_solution.Simplify( + old_range->min <= express_by_new_vars); + PrimExpr upper_cond = analyzer_solution.Simplify( + express_by_new_vars < old_range->min + old_range->extent); if (!tir::is_const_int(lower_cond, 1)) { new_relations.push_back(lower_cond); } From ff42ed1bec9daec194ac732206f7aa97b047bd98 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 3 Apr 2020 00:53:21 -0700 Subject: [PATCH 3/7] generate random test cases and address comments Co-authored-by: Sergei Grechanik --- include/tvm/arith/linear_system.h | 7 +- src/arith/solve_linear_equation.cc | 10 +- .../test_arith_solve_linear_system.py | 150 ++++++++++++++++-- 3 files changed, 147 insertions(+), 20 deletions(-) diff --git a/include/tvm/arith/linear_system.h b/include/tvm/arith/linear_system.h index f7f33039698a..4046a58836d6 100644 --- a/include/tvm/arith/linear_system.h +++ b/include/tvm/arith/linear_system.h @@ -38,7 +38,7 @@ using tir::IterVar; /*! * \brief Represent a linear system including variables, their ranges and - * the linear relations between them (either equations or inequalities) + * the relations between them (either equations or inequalities). * \sa LinearSystem */ class LinearSystemNode : public Object { @@ -139,8 +139,9 @@ class LinearSystemTransform : public ObjectRef { /*! * \brief Obtain Smith Normal Form of linear equation A x = y. * Smith Normal Form of matrix A_{mxn} is S_{mxn} = U_{mxm} A_{mxn} V_{nxn}, - * in which S_{mxn} is diag(s1, s2, ..., sr, 0, ..., 0), - * such that si | s_{i+1} and r is the rank of A. 
+ * in which S_{mxn} is diag(s1, s2, ..., sr, 0, ..., 0) and r is the rank of A. + * NOTE: Although in standard Smith Normal Form the diagonal elements satisfy + * s_i | s_{i+1} (| means divides), the implement here does not guarantee it. * U_{mxm} and V_{nxn} are invertible matrices. * This function modifies \p S to be S_{mxn}, \p V to be V_{nxn}, * \p y to be U_{mxm} y_{mx1} and \p x to be V^{-1} x. diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index 87979587aca3..fc8e03669344 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -312,6 +312,8 @@ LinearSystemTransform SolveEquations(const LinearSystem& system_to_solve) { } else { // elements in matrix S V must be integers // ignore equations that we cannot deal with. + LOG(WARNING) << "Cannot deal with non-integer coefficients, ignore equation " + << equation; row.clear(); break; } @@ -379,11 +381,11 @@ LinearSystemTransform SolveEquations(const LinearSystem& system_to_solve) { // suppose the rank of A is r, aka r = # of non-zeros in S // the solution of S_{mxn} V^{-1}_{nxn} x_{nx1} = U b // is - // x = (pseudo-inverse of A) b + K_{n, n-r} z_{n-r} - // = V_{n,n} S^{-1}_{n,m} (Ub)_{mxn} + K_{n, n-r} z_{n-r} + // x = (pseudo-inverse of A) b + K_{(n)x(n-r)} z_{n-r} + // = V_{nxn} S^{-1}_{nxm} (Ub)_{mxn} + K_{(n)x(n-r)} z_{n-r} // in which K is the right n-r columns of V, z is variable vector // thus, - // V^{-1} x = S^{-1}_{n,m} (Ub)_{mxn} + + // V^{-1} x = S^{-1}_{nxm} (Ub)_{mxn} + // [[0, ... 0]_{n-r}, ... 
[0, ..., 0], diag(1, ..., 1)_{(n-r)x(n-r)}] z_{n-r} for (size_t j = 0; j < num_vars; ++j) { if (j >= S.size() || S[j][j] == 0) { @@ -469,6 +471,8 @@ TVM_REGISTER_GLOBAL("arith.SolveEquations") } else if (args.size() == 3) { LinearSystem problem(args[0], args[1], args[2]); *ret = SolveEquations(problem); + } else { + LOG(FATAL) << "arith.SolveEquations expects 1 or 3 arguments, gets " << args.size(); } }); diff --git a/tests/python/unittest/test_arith_solve_linear_system.py b/tests/python/unittest/test_arith_solve_linear_system.py index 11b63c259ffd..c50fddf558da 100644 --- a/tests/python/unittest/test_arith_solve_linear_system.py +++ b/tests/python/unittest/test_arith_solve_linear_system.py @@ -14,9 +14,130 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import random +import numpy as np import tvm -from tvm import te, arith -from tvm.tir import ir_pass +from tvm import te, arith, ir, tir + + +def run_expr(expr, vranges): + def _compute_body(*us): + vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)} + return tir.ir_pass.Substitute(expr, vmap) + + A = te.compute([r.extent.value for v, r in vranges.items()], _compute_body) + args = [tvm.nd.empty(A.shape, A.dtype)] + sch = te.create_schedule(A.op) + mod = tvm.build(sch, [A]) + mod(*args) + return args[0].asnumpy() + + +def check_bruteforce(bool_expr, vranges, cond=None): + if cond is not None: + bool_expr = te.any(tir.Not(cond), bool_expr) + + res = run_expr(bool_expr, vranges) + if not np.all(res): + indices = list(np.argwhere(res == 0)[0]) + counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)] + counterex = sorted(counterex, key=lambda x: x[0]) + counterex = ", ".join([v + " = " + str(i) for v, i in counterex]) + raise AssertionError("Expression {}\nis not true on {}\n" + "Counterexample: {}" + .format(tir.ir_pass.CanonicalSimplify(bool_expr), vranges, counterex)) + + +def 
check_solution(solution, vranges={}): + def _check_forward(formula1, formula2, varmap, backvarmap): + all_vranges = vranges.copy() + all_vranges.update({v: r for v, r in formula1.ranges.items()}) + + # Check that the transformation is injective + cond_on_vars = tir.const(1, 'bool') + for v in formula1.variables: + # variable mapping is consistent + v_back = tir.ir_pass.Simplify(tir.ir_pass.Substitute(varmap[v], backvarmap)) + cond_on_vars = te.all(cond_on_vars, v == v_back) + # Also we have to check that the new relations are true when old relations are true + cond_subst = tir.ir_pass.Substitute( + te.all(tir.const(1, 'bool'), *formula2.relations), backvarmap) + # We have to include relations from vranges too + for v in formula2.variables: + if v in formula2.ranges: + r = formula2.ranges[v] + range_cond = te.all(v >= r.min, v < r.min + r.extent) + range_cond = tir.ir_pass.Substitute(range_cond, backvarmap) + cond_subst = te.all(cond_subst, range_cond) + cond_subst = tir.ir_pass.Simplify(cond_subst) + check_bruteforce(te.all(cond_subst, cond_on_vars), all_vranges, + cond=te.all(tir.const(1, 'bool'), *formula1.relations)) + + rels = solution.dst.relations + if len(rels) == 1 and ir.structural_equal(rels[0], False): + # not solvable, skip + return + _check_forward(solution.src, solution.dst, + solution.src_to_dst, solution.dst_to_src) + _check_forward(solution.dst, solution.src, + solution.dst_to_src, solution.src_to_dst) + + +def test_solution_consistency(): + random.seed(0) + + def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)): + variables = [te.var("x" + str(i)) for i in range(num_vars)] + + relations = [] + for i in range(num_formulas): + s1 = sum([v*random.randint(coef[0], coef[1]) for v in variables]) + s1 += random.randint(coef[0], coef[1]) + s2 = sum([v*random.randint(coef[0], coef[1]) for v in variables]) + s2 += random.randint(coef[0], coef[1]) + if random.random() < 0.7: + op = tvm.tir.EQ + else: + # we also make sure it can correctly 
handle inequalities + op = random.choice([tvm.tir.LE, tvm.tir.LT, tvm.tir.GE, tvm.tir.GT]) + relations.append(op(s1, s2)) + + vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in variables} + solution = arith.solve_equations(relations, variables, vranges) + + check_solution(solution) + + # leaving some variables as parameters should also be ok + for k in [1, 2]: + if len(variables) > k: + solution = arith.solve_equations(relations, variables[:-k], vranges) + param_ranges = {v: vranges[v] for v in variables[-k:]} + check_solution(solution, param_ranges) + + for i in range(2): + _check(num_vars=1, num_formulas=1) + for i in range(2): + _check(num_vars=1, num_formulas=2) + + for i in range(2): + _check(num_vars=2, num_formulas=1) + for i in range(2): + _check(num_vars=2, num_formulas=2) + for i in range(2): + _check(num_vars=2, num_formulas=3) + + for i in range(3): + _check(num_vars=3, num_formulas=3, coef=(-2, 2)) + for i in range(3): + _check(num_vars=3, num_formulas=4, coef=(-2, 2)) + + for i in range(3): + _check(num_vars=4, num_formulas=3, coef=(-1, 1)) + + for i in range(3): + _check(num_vars=10, num_formulas=2, coef=(-1, 1), bounds=(0, 4)) + for i in range(3): + _check(num_vars=10, num_formulas=3, coef=(0, 1), bounds=(0, 4)) def test_unique_solution(): @@ -28,8 +149,8 @@ def test_unique_solution(): tvm.tir.EQ(x - y, 10), ], [x, y], ranges) assert list(solution.dst.variables) == [] - assert ir_pass.Equal(solution.src_to_dst[x], 15) - assert ir_pass.Equal(solution.src_to_dst[y], 5) + assert ir.structural_equal(solution.src_to_dst[x], 15) + assert ir.structural_equal(solution.src_to_dst[y], 5) def test_low_rank(): @@ -41,9 +162,9 @@ def test_low_rank(): tvm.tir.EQ(x + y, 10), ], [x, y, z], ranges) [n0] = solution.dst.variables - assert ir_pass.Equal(solution.src_to_dst[x], n0 + 10) - assert ir_pass.Equal(solution.src_to_dst[y], -n0) - assert ir_pass.Equal(solution.src_to_dst[z], 5) + assert ir.structural_equal(solution.src_to_dst[x], n0 + 10) + 
assert ir.structural_equal(solution.src_to_dst[y], -n0) + assert ir.structural_equal(solution.src_to_dst[z], 5) def test_infer_range(): @@ -57,16 +178,16 @@ def test_infer_range(): tvm.tir.EQ(x + y, 0), ], [x, y], ranges) [n0] = solution.dst.variables - assert ir_pass.Equal(solution.src_to_dst[x], n0) - assert ir_pass.Equal(solution.src_to_dst[y], -n0) + assert ir.structural_equal(solution.src_to_dst[x], n0) + assert ir.structural_equal(solution.src_to_dst[y], -n0) # inferred from y's range - assert ir_pass.Equal(solution.dst.ranges[n0].min, -9) - assert ir_pass.Equal(solution.dst.ranges[n0].extent, 10) + assert ir.structural_equal(solution.dst.ranges[n0].min, -9) + assert ir.structural_equal(solution.dst.ranges[n0].extent, 10) # additional inequality is added into the system for x [ineq] = solution.dst.relations assert isinstance(ineq, tvm.tir.LE) - assert ir_pass.Equal(ineq.a, -5) - assert ir_pass.Equal(ineq.b, n0) + assert ir.structural_equal(ineq.a, -5) + assert ir.structural_equal(ineq.b, n0) def test_ill_formed(): @@ -79,7 +200,7 @@ def test_ill_formed(): ], [x, y], {}) assert list(solution.dst.variables) == [] [rel] = solution.dst.relations - assert ir_pass.Equal(rel, False) + assert ir.structural_equal(rel, False) assert len(solution.src_to_dst) == 0 assert len(solution.dst_to_src) == 0 @@ -89,3 +210,4 @@ def test_ill_formed(): test_low_rank() test_infer_range() test_ill_formed() + test_solution_consistency() From 8256321953ca7245c2b1bab5eaf788abcd3b42ab Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 4 Apr 2020 23:04:54 -0700 Subject: [PATCH 4/7] rename linear_system to int_constraints --- .../arith/{linear_system.h => int_solver.h} | 73 ++++++++++--------- python/tvm/arith/__init__.py | 2 +- .../arith/{linear_system.py => int_solver.py} | 54 +++++++------- .../{linear_system.cc => int_constraints.cc} | 44 ++++++----- src/arith/solve_linear_equation.cc | 24 +++--- .../test_arith_solve_linear_system.py | 12 +-- 6 files changed, 107 insertions(+), 102 
deletions(-) rename include/tvm/arith/{linear_system.h => int_solver.h} (70%) rename python/tvm/arith/{linear_system.py => int_solver.py} (61%) rename src/arith/{linear_system.cc => int_constraints.cc} (56%) diff --git a/include/tvm/arith/linear_system.h b/include/tvm/arith/int_solver.h similarity index 70% rename from include/tvm/arith/linear_system.h rename to include/tvm/arith/int_solver.h index 4046a58836d6..db5c84989338 100644 --- a/include/tvm/arith/linear_system.h +++ b/include/tvm/arith/int_solver.h @@ -18,11 +18,11 @@ */ /*! - * \file tvm/arith/linear_system.h - * \brief Linear system data structures and solvers + * \file tvm/arith/int_solver.h + * \brief integer constraints data structures and solvers */ -#ifndef TVM_ARITH_LINEAR_SYSTEM_H_ -#define TVM_ARITH_LINEAR_SYSTEM_H_ +#ifndef TVM_ARITH_INT_SOLVER_H_ +#define TVM_ARITH_INT_SOLVER_H_ #include #include @@ -37,13 +37,13 @@ using tir::VarNode; using tir::IterVar; /*! - * \brief Represent a linear system including variables, their ranges and + * \brief Represent integer constrains including (integer) variables, their ranges and * the relations between them (either equations or inequalities). * \sa LinearSystem */ -class LinearSystemNode : public Object { +class IntConstraintsNode : public Object { public: - // e.g., \alpha, \beta + // e.g., \alpha, \beta, must be integers Array variables; // e.g., 1 <= \alpha <= N, etc. Map ranges; @@ -57,32 +57,32 @@ class LinearSystemNode : public Object { v->Visit("relations", &relations); } - static constexpr const char* _type_key = "arith.LinearSystem"; - TVM_DECLARE_FINAL_OBJECT_INFO(LinearSystemNode, Object); + static constexpr const char* _type_key = "arith.IntConstraints"; + TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object); }; /*! - * \brief Managed reference to LinearSystemNode. - * \sa LinearSystemNode + * \brief Managed reference to IntConstraintsNode. 
+ * \sa IntConstraintsNode */ -class LinearSystem : public ObjectRef { +class IntConstraints : public ObjectRef { public: /*! * \brief Constructor by fields - * \param variables The variables in the system. + * \param variables The variables in the constraints, must be integers. * \param ranges The ranges of the variables. * \param relations The linear relations between the variables * (either equations or inequalities) */ - TVM_DLL LinearSystem(Array variables, - Map ranges, - Array relations); + TVM_DLL IntConstraints(Array variables, + Map ranges, + Array relations); - TVM_DEFINE_OBJECT_REF_METHODS(LinearSystem, ObjectRef, LinearSystemNode); + TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode); }; /*! - * \brief We can have different set of variables to represent the same linear system. + * \brief We can have different set of variables to represent the same constraints. * For example, the following two systems are equivalent, * {a + b = 0 | a >= 0, b >= 0} and * {m - n = 0 | m >= 0, n <= 0} @@ -93,12 +93,12 @@ class LinearSystem : public ObjectRef { * dst : {m - n = 0 | m >= 0, n <= 0} * src_to_dst : {a -> m, b -> -n} * dst_to_src : {m -> a, n -> -b} - * \sa LinearSystemTransform + * \sa IntConstraintsTransform */ -class LinearSystemTransformNode : public Object { +class IntConstraintsTransformNode : public Object { public: - LinearSystem src; - LinearSystem dst; + IntConstraints src; + IntConstraints dst; Map src_to_dst; Map dst_to_src; @@ -109,31 +109,32 @@ class LinearSystemTransformNode : public Object { v->Visit("dst_to_src", &dst_to_src); } - static constexpr const char* _type_key = "arith.LinearSystemTransform"; - TVM_DECLARE_FINAL_OBJECT_INFO(LinearSystemTransformNode, Object); + static constexpr const char* _type_key = "arith.IntConstraintsTransform"; + TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object); }; /*! - * \brief Managed reference to LinearSystemTransformNode. 
- * \sa LinearSystemTransformNode + * \brief Managed reference to IntConstraintsTransformNode. + * \sa IntConstraintsTransformNode */ -class LinearSystemTransform : public ObjectRef { +class IntConstraintsTransform : public ObjectRef { public: /*! * \brief Constructor by fields - * \param src source linear system, e.g., {a + b = 0 | a >= 0, b >= 0} - * \param dst linear system equivalent to the source, e.g., {m - n = 0 | m >= 0, n <= 0} + * \param src source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0} + * \param dst integer constraints equivalent to the source, + * e.g., {m - n = 0 | m >= 0, n <= 0} * \param src_to_dst mapping from variables in the \p src to the variables in the \p dst, * e.g., {a -> m, b -> -n} * \param dst_to_src mapping from variables in the \p dst to the variables in the \p src, * e.g., {m -> a, n -> -b} */ - TVM_DLL LinearSystemTransform(LinearSystem src, - LinearSystem dst, - Map src_to_dst, - Map dst_to_src); + TVM_DLL IntConstraintsTransform(IntConstraints src, + IntConstraints dst, + Map src_to_dst, + Map dst_to_src); - TVM_DEFINE_OBJECT_REF_METHODS(LinearSystemTransform, ObjectRef, LinearSystemTransformNode); + TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode); }; /*! @@ -165,8 +166,8 @@ void SmithNormalFormDiag(std::vector> *S, * as well as inequalities inferred from the \p system_to_solve. * You can get the mapping from the original variables to the solution via ret->src_to_dst. 
*/ -LinearSystemTransform SolveEquations(const LinearSystem& system_to_solve); +IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve); } // namespace arith } // namespace tvm -#endif // TVM_ARITH_LINEAR_SYSTEM_H_ +#endif // TVM_ARITH_INT_SOLVER_H_ diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index 25f30bc60e26..017934a03b34 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -20,4 +20,4 @@ from .analyzer import ModularSet, ConstIntBound, Analyzer from .bound import deduce_bound from .pattern import detect_linear_equation, detect_clip_bound -from .linear_system import solve_equations +from .int_solver import solve_linear_equations diff --git a/python/tvm/arith/linear_system.py b/python/tvm/arith/int_solver.py similarity index 61% rename from python/tvm/arith/linear_system.py rename to python/tvm/arith/int_solver.py index d1e6f8502be1..c0e169f30dbd 100644 --- a/python/tvm/arith/linear_system.py +++ b/python/tvm/arith/int_solver.py @@ -14,39 +14,39 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Linear system data structures and solvers""" +"""integer constraints data structures and solvers""" import tvm._ffi from tvm.runtime import Object from . import _ffi_api -@tvm._ffi.register_object("arith.LinearSystem") -class LinearSystem(Object): - """Represent a linear system including variables, their ranges and - the linear relations between them (either equations or inequalities) +@tvm._ffi.register_object("arith.IntConstraints") +class IntConstraints(Object): + """Represent a set of integer constraints including variables, their ranges and + the relations between them (either equations or inequalities) Parameters ---------- variables : List[tvm.tir.Var] - The variables in the system. + The variables in the constraints. 
Must be integers ranges : Map[tvm.tir.Var, tvm.ir.Range] The ranges of the variables. relations : List[tvm.ir.PrimExpr] - The linear relations between the variables (either equations or inequalities) + The relations between the variables (either equations or inequalities) """ def __init__(self, variables, ranges, relations): self.__init_handle_by_constructor__( - _ffi_api.LinearSystem, variables, ranges, relations) + _ffi_api.IntConstraints, variables, ranges, relations) -@tvm._ffi.register_object("arith.LinearSystemTransform") -class LinearSystemTransform(Object): - """We can have different set of variables to represent the same linear system. - For example, the following two systems are equivalent, +@tvm._ffi.register_object("arith.IntConstraintsTransform") +class IntConstraintsTransform(Object): + """We can have different set of variables to represent the same integer constraints. + For example, the following two constrains are equivalent, {a + b = 0 | a >= 0, b >= 0} and {m - n = 0 | m >= 0, n <= 0} This data structure represents the transformation - between two equivalent linear systems. + between two equivalent integer constraints. 
In the above example, src : {a + b = 0 | a >= 0, b >= 0} dst : {m - n = 0 | m >= 0, n <= 0} @@ -55,10 +55,10 @@ class LinearSystemTransform(Object): Parameters ---------- - src : arith.LinearSystem - source linear system, e.g., {a + b = 0 | a >= 0, b >= 0} - dst : arith.LinearSystem - linear system equivalent to the source, e.g., {m - n = 0 | m >= 0, n <= 0} + src : arith.IntConstraints + source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0} + dst : arith.IntConstraints + integer constraints equivalent to the source, e.g., {m - n = 0 | m >= 0, n <= 0} src_to_dst : Map[tvm.tir.Var, tvm.ir.PrimExpr] mapping from variables in the src to the variables in the dst, e.g., {a -> m, b -> -n} @@ -68,15 +68,15 @@ class LinearSystemTransform(Object): """ def __init__(self, src, dst, src_to_dst, dst_to_src): self.__init_handle_by_constructor__( - _ffi_api.LinearSystemTransform, src, dst, src_to_dst, dst_to_src) + _ffi_api.IntConstraintsTransform, src, dst, src_to_dst, dst_to_src) -def solve_equations(equations, variables, ranges): +def solve_linear_equations(equations, variables, ranges): """Solve linear equations. Parameters ---------- - equations: List[tvm.ir.PrimExpr] or LinearSystem + equations: List[tvm.ir.PrimExpr] or IntConstraints The equations of the variables variables : List[tvm.tir.Var] The variables in the system. @@ -85,15 +85,15 @@ def solve_equations(equations, variables, ranges): Returns ------- - linear_system_transform : LinearSystemTransform - A new linear system, with less variables (if the problem is NOT of full rank), + int_constraints_transform : IntConstraintsTransform + New integer constraints, with less variables (if the problem is NOT of full rank), or no variable (if the problem is of full rank), - or an empty linear system (if the problem is unsolvable). + or an empty integer constraints (if the problem is unsolvable). It also provides the ranges of the variables in the new system, as well as inequalities inferred from the problem. 
You can get the mapping from the original variables to the solution via - linear_system_transform.src_to_dst. + int_constraints_transform.src_to_dst. """ - if isinstance(equations, LinearSystem): - return _ffi_api.SolveEquations(equations) - return _ffi_api.SolveEquations(variables, ranges, equations) + if isinstance(equations, IntConstraints): + return _ffi_api.SolveLinearEquations(equations) + return _ffi_api.SolveLinearEquations(variables, ranges, equations) diff --git a/src/arith/linear_system.cc b/src/arith/int_constraints.cc similarity index 56% rename from src/arith/linear_system.cc rename to src/arith/int_constraints.cc index 169d5e384e19..21b0d27fde1d 100644 --- a/src/arith/linear_system.cc +++ b/src/arith/int_constraints.cc @@ -18,10 +18,10 @@ */ /*! - * \file linear_system.cc - * \brief The linear system data structures. + * \file int_constraints.cc + * \brief The integer constraints data structures. */ -#include +#include #include #include #include @@ -33,22 +33,26 @@ namespace tvm { namespace arith { -LinearSystem::LinearSystem(Array variables, - Map ranges, - Array relations) { - ObjectPtr node = make_object(); +IntConstraints::IntConstraints(Array variables, + Map ranges, + Array relations) { + ObjectPtr node = make_object(); + for (const auto& var : variables) { + CHECK(var.dtype().is_int() || var.dtype().is_uint()) + << "Variables in IntConstraints must be integers"; + } node->variables = std::move(variables); node->ranges = std::move(ranges); node->relations = std::move(relations); data_ = std::move(node); } -TVM_REGISTER_NODE_TYPE(LinearSystemNode); +TVM_REGISTER_NODE_TYPE(IntConstraintsNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "LinearSystem(" +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "IntConstraints(" << op->variables << ", " << op->ranges << ", " << 
op->relations @@ -56,11 +60,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); -LinearSystemTransform::LinearSystemTransform(LinearSystem src, - LinearSystem dst, - Map src_to_dst, - Map dst_to_src) { - ObjectPtr node = make_object(); +IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, + IntConstraints dst, + Map src_to_dst, + Map dst_to_src) { + ObjectPtr node = make_object(); node->src = std::move(src); node->dst = std::move(dst); node->src_to_dst = std::move(src_to_dst); @@ -68,12 +72,12 @@ LinearSystemTransform::LinearSystemTransform(LinearSystem src, data_ = std::move(node); } -TVM_REGISTER_NODE_TYPE(LinearSystemTransformNode); +TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "LinearSystemTransform(" +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "IntConstraintsTransform(" << "\n\t" << op->src << "\n\t" << op->dst << "\n\t" << op->src_to_dst diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index fc8e03669344..0e89378389d5 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -18,13 +18,13 @@ */ /*! - * \file tvm/arith/linear_solver.cc + * \file tvm/arith/solve_linear_equation.cc * \brief Solve linear equations. 
*/ #include #include #include -#include +#include #include #include #include @@ -266,7 +266,7 @@ void DebugPrint(const std::vector>& S, std::cout << "\n" << std::endl; } -LinearSystemTransform SolveEquations(const LinearSystem& system_to_solve) { +IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve) { // m: # of equations // n: # of variables // we first construct A_{mxn} x_{nx1} = y_{mx1} @@ -364,9 +364,9 @@ LinearSystemTransform SolveEquations(const LinearSystem& system_to_solve) { new_relation = analyzer_problem.Simplify(new_relation); if (tir::is_const_int(new_relation, 0)) { // unable to solve the system. - return LinearSystemTransform( + return IntConstraintsTransform( system_to_solve, - LinearSystem( + IntConstraints( /*variables=*/{}, /*ranges=*/{}, /*relations=*/{te::make_zero(DataType::Bool())}), @@ -457,22 +457,22 @@ LinearSystemTransform SolveEquations(const LinearSystem& system_to_solve) { new_relations.push_back(Substitute(cond, old_to_new_map)); } - LinearSystem solution(new_vars, new_ranges, new_relations); - LinearSystemTransform transform( + IntConstraints solution(new_vars, new_ranges, new_relations); + IntConstraintsTransform transform( system_to_solve, solution, old_to_new_map, new_to_old_map); return transform; } -TVM_REGISTER_GLOBAL("arith.SolveEquations") +TVM_REGISTER_GLOBAL("arith.SolveLinearEquations") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args.size() == 1) { - *ret = SolveEquations(args[0]); + *ret = SolveLinearEquations(args[0]); } else if (args.size() == 3) { - LinearSystem problem(args[0], args[1], args[2]); - *ret = SolveEquations(problem); + IntConstraints problem(args[0], args[1], args[2]); + *ret = SolveLinearEquations(problem); } else { - LOG(FATAL) << "arith.SolveEquations expects 1 or 3 arguments, gets " << args.size(); + LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size(); } }); diff --git a/tests/python/unittest/test_arith_solve_linear_system.py 
b/tests/python/unittest/test_arith_solve_linear_system.py index c50fddf558da..6d85abe6b893 100644 --- a/tests/python/unittest/test_arith_solve_linear_system.py +++ b/tests/python/unittest/test_arith_solve_linear_system.py @@ -103,14 +103,14 @@ def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)): relations.append(op(s1, s2)) vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in variables} - solution = arith.solve_equations(relations, variables, vranges) + solution = arith.solve_linear_equations(relations, variables, vranges) check_solution(solution) # leaving some variables as parameters should also be ok for k in [1, 2]: if len(variables) > k: - solution = arith.solve_equations(relations, variables[:-k], vranges) + solution = arith.solve_linear_equations(relations, variables[:-k], vranges) param_ranges = {v: vranges[v] for v in variables[-k:]} check_solution(solution, param_ranges) @@ -144,7 +144,7 @@ def test_unique_solution(): x, y = te.var("x"), te.var("y") ranges = {} - solution = arith.solve_equations([ + solution = arith.solve_linear_equations([ tvm.tir.EQ(x + y, 20), tvm.tir.EQ(x - y, 10), ], [x, y], ranges) @@ -157,7 +157,7 @@ def test_low_rank(): x, y, z = te.var("x"), te.var("y"), te.var("z") ranges = {} - solution = arith.solve_equations([ + solution = arith.solve_linear_equations([ tvm.tir.EQ(x + y + z, 15), tvm.tir.EQ(x + y, 10), ], [x, y, z], ranges) @@ -174,7 +174,7 @@ def test_infer_range(): y: tvm.ir.Range.make_by_min_extent(0, 10), } - solution = arith.solve_equations([ + solution = arith.solve_linear_equations([ tvm.tir.EQ(x + y, 0), ], [x, y], ranges) [n0] = solution.dst.variables @@ -193,7 +193,7 @@ def test_infer_range(): def test_ill_formed(): x, y = te.var("x"), te.var("y") - solution = arith.solve_equations([ + solution = arith.solve_linear_equations([ tvm.tir.EQ(x + y, 0), tvm.tir.EQ(x - y, 0), tvm.tir.EQ(x, 5), From 8cac9e36a44483370a5e70c4ed2408a4a8dd9662 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 7 
Apr 2020 19:15:09 -0700 Subject: [PATCH 5/7] add comments and use random seed --- include/tvm/arith/int_solver.h | 3 ++ .../test_arith_solve_linear_system.py | 40 ++++++++++++------- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index db5c84989338..cb038c0b5c91 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -143,6 +143,9 @@ class IntConstraintsTransform : public ObjectRef { * in which S_{mxn} is diag(s1, s2, ..., sr, 0, ..., 0) and r is the rank of A. * NOTE: Although in standard Smith Normal Form the diagonal elements satisfy * s_i | s_{i+1} (| means divides), the implement here does not guarantee it. + * TODO(yzhliu): From sergei-grechanik: + * computing the proper Smith normal form may improve stability of automatic differentiation + * (generating the same gradient code for slightly different but equivalent input code). * U_{mxm} and V_{nxn} are invertible matrices. * This function modifies \p S to be S_{mxn}, \p V to be V_{nxn}, * \p y to be U_{mxm} y_{mx1} and \p x to be V^{-1} x. diff --git a/tests/python/unittest/test_arith_solve_linear_system.py b/tests/python/unittest/test_arith_solve_linear_system.py index 6d85abe6b893..ff6917010e18 100644 --- a/tests/python/unittest/test_arith_solve_linear_system.py +++ b/tests/python/unittest/test_arith_solve_linear_system.py @@ -16,11 +16,17 @@ # under the License. import random import numpy as np +import sys +import pytest import tvm from tvm import te, arith, ir, tir def run_expr(expr, vranges): + """ Evaluate expr for every value of free variables + given by vranges and return the tensor of results. 
+ TODO(yzhliu): move to utils + """ def _compute_body(*us): vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)} return tir.ir_pass.Substitute(expr, vmap) @@ -34,6 +40,10 @@ def _compute_body(*us): def check_bruteforce(bool_expr, vranges, cond=None): + """ Check that bool_expr holds given the condition cond + for every value of free variables from vranges. + TODO(yzhliu): move to utils + """ if cond is not None: bool_expr = te.any(tir.Not(cond), bool_expr) @@ -49,29 +59,30 @@ def check_bruteforce(bool_expr, vranges, cond=None): def check_solution(solution, vranges={}): - def _check_forward(formula1, formula2, varmap, backvarmap): + """Check that solution is a bijective transformation""" + def _check_forward(constraints1, constraints2, varmap, backvarmap): all_vranges = vranges.copy() - all_vranges.update({v: r for v, r in formula1.ranges.items()}) + all_vranges.update({v: r for v, r in constraints1.ranges.items()}) # Check that the transformation is injective cond_on_vars = tir.const(1, 'bool') - for v in formula1.variables: + for v in constraints1.variables: # variable mapping is consistent v_back = tir.ir_pass.Simplify(tir.ir_pass.Substitute(varmap[v], backvarmap)) cond_on_vars = te.all(cond_on_vars, v == v_back) # Also we have to check that the new relations are true when old relations are true cond_subst = tir.ir_pass.Substitute( - te.all(tir.const(1, 'bool'), *formula2.relations), backvarmap) + te.all(tir.const(1, 'bool'), *constraints2.relations), backvarmap) # We have to include relations from vranges too - for v in formula2.variables: - if v in formula2.ranges: - r = formula2.ranges[v] + for v in constraints2.variables: + if v in constraints2.ranges: + r = constraints2.ranges[v] range_cond = te.all(v >= r.min, v < r.min + r.extent) range_cond = tir.ir_pass.Substitute(range_cond, backvarmap) cond_subst = te.all(cond_subst, range_cond) cond_subst = tir.ir_pass.Simplify(cond_subst) check_bruteforce(te.all(cond_subst, cond_on_vars), all_vranges, - 
cond=te.all(tir.const(1, 'bool'), *formula1.relations)) + cond=te.all(tir.const(1, 'bool'), *constraints1.relations)) rels = solution.dst.relations if len(rels) == 1 and ir.structural_equal(rels[0], False): @@ -83,8 +94,11 @@ def _check_forward(formula1, formula2, varmap, backvarmap): solution.dst_to_src, solution.src_to_dst) -def test_solution_consistency(): - random.seed(0) +def test_solution_consistency(capsys): + seed = random.randrange(sys.maxsize) + with capsys.disabled(): + print("\nUse seed {} to reproduce the results.\n".format(seed)) + random.seed(seed) def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)): variables = [te.var("x" + str(i)) for i in range(num_vars)] @@ -206,8 +220,4 @@ def test_ill_formed(): if __name__ == "__main__": - test_unique_solution() - test_low_rank() - test_infer_range() - test_ill_formed() - test_solution_consistency() + pytest.main([__file__]) From 2d89fef7a802bfe242b99344b5c0a0da122b74b9 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 8 Apr 2020 00:27:26 -0700 Subject: [PATCH 6/7] message for reporting failure with seed --- tests/python/unittest/test_arith_solve_linear_system.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_arith_solve_linear_system.py b/tests/python/unittest/test_arith_solve_linear_system.py index ff6917010e18..79ecbc3b33b0 100644 --- a/tests/python/unittest/test_arith_solve_linear_system.py +++ b/tests/python/unittest/test_arith_solve_linear_system.py @@ -97,7 +97,8 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap): def test_solution_consistency(capsys): seed = random.randrange(sys.maxsize) with capsys.disabled(): - print("\nUse seed {} to reproduce the results.\n".format(seed)) + print("\nThis test is intentionally non-deterministic, " + "if it fails please report it in github issue together with this seed {}\n".format(seed)) random.seed(seed) def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)): From 
dc76eac227a2a276d3245d5e962be3f6e33c8351 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Thu, 9 Apr 2020 17:17:21 -0700 Subject: [PATCH 7/7] add SEqualReduce to IntConstraints; allow variables & ranges to be None --- include/tvm/arith/int_solver.h | 32 +++++++++++++++++++ python/tvm/arith/int_solver.py | 6 ++-- src/arith/int_constraints.cc | 7 ++++ src/arith/solve_linear_equation.cc | 2 +- .../test_arith_solve_linear_system.py | 25 +++++++++++---- 5 files changed, 62 insertions(+), 10 deletions(-) diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index cb038c0b5c91..57f3af4bb67b 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -46,6 +46,8 @@ class IntConstraintsNode : public Object { // e.g., \alpha, \beta, must be integers Array variables; // e.g., 1 <= \alpha <= N, etc. + // it is absolutely ok to include ranges for parameters + // (variables that are not in this->variables) in this map Map ranges; // linear equalities or inequalities // e.g., A \alpha = \beta or A \alpha <= \beta @@ -57,6 +59,20 @@ class IntConstraintsNode : public Object { v->Visit("relations", &relations); } + bool SEqualReduce(const IntConstraintsNode* other, SEqualReducer equal) const { + return + equal(variables, other->variables) && + equal(ranges, other->ranges) && + equal(relations, other->relations); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(variables); + hash_reduce(ranges); + hash_reduce(relations); + } + + static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const char* _type_key = "arith.IntConstraints"; TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object); }; @@ -109,6 +125,22 @@ class IntConstraintsTransformNode : public Object { v->Visit("dst_to_src", &dst_to_src); } + bool SEqualReduce(const IntConstraintsTransformNode* other, SEqualReducer equal) const { + return + equal(src, other->src) && + equal(dst, other->dst) && + equal(src_to_dst, 
other->src_to_dst) && + equal(dst_to_src, other->dst_to_src); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(src); + hash_reduce(dst); + hash_reduce(src_to_dst); + hash_reduce(dst_to_src); + } + + static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const char* _type_key = "arith.IntConstraintsTransform"; TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object); }; diff --git a/python/tvm/arith/int_solver.py b/python/tvm/arith/int_solver.py index c0e169f30dbd..e35435c1da03 100644 --- a/python/tvm/arith/int_solver.py +++ b/python/tvm/arith/int_solver.py @@ -71,16 +71,16 @@ def __init__(self, src, dst, src_to_dst, dst_to_src): _ffi_api.IntConstraintsTransform, src, dst, src_to_dst, dst_to_src) -def solve_linear_equations(equations, variables, ranges): +def solve_linear_equations(equations, variables=None, ranges=None): """Solve linear equations. Parameters ---------- equations: List[tvm.ir.PrimExpr] or IntConstraints The equations of the variables - variables : List[tvm.tir.Var] + variables : Optional[List[tvm.tir.Var]] The variables in the system. - ranges : Map[tvm.tir.Var, tvm.ir.Range] + ranges : Optional[Map[tvm.tir.Var, tvm.ir.Range]] The ranges of the variables. 
Returns diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 21b0d27fde1d..34efa986e985 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -37,6 +37,13 @@ IntConstraints::IntConstraints(Array variables, Map ranges, Array relations) { ObjectPtr node = make_object(); + if (!variables.defined()) { + variables = Array(); + } + if (!ranges.defined()) { + ranges = Map(); + } + CHECK(relations.defined()); for (const auto& var : variables) { CHECK(var.dtype().is_int() || var.dtype().is_uint()) << "Variables in IntConstraints must be integers"; diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index 0e89378389d5..8142a03155c8 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -411,7 +411,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol // have problems with dividing by negative numbers PrimExpr a = te::make_const(Uy[j].dtype(), -S[j][j]); solution_for_V_inv_x.push_back( - tir::Simplify(floordiv(-Uy[j], a), system_to_solve->ranges)); + analyzer_problem.Simplify(floordiv(-Uy[j], a))); } } } diff --git a/tests/python/unittest/test_arith_solve_linear_system.py b/tests/python/unittest/test_arith_solve_linear_system.py index 79ecbc3b33b0..45f8fc10aaf0 100644 --- a/tests/python/unittest/test_arith_solve_linear_system.py +++ b/tests/python/unittest/test_arith_solve_linear_system.py @@ -94,11 +94,10 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap): solution.dst_to_src, solution.src_to_dst) -def test_solution_consistency(capsys): +def test_solution_consistency(): seed = random.randrange(sys.maxsize) - with capsys.disabled(): - print("\nThis test is intentionally non-deterministic, " - "if it fails please report it in github issue together with this seed {}\n".format(seed)) + print("\nThis test is intentionally non-deterministic, " + "if it fails please report it in github issue together with this 
seed {}\n".format(seed)) random.seed(seed) def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)): @@ -155,14 +154,28 @@ def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)): _check(num_vars=10, num_formulas=3, coef=(0, 1), bounds=(0, 4)) +def test_empty_var_to_solve(): + x, y = te.var("x"), te.var("y") + equations = [ + tvm.tir.EQ(x + y, 20), + tvm.tir.EQ(x - y, 10), + ] + solution = arith.solve_linear_equations(equations) + assert len(solution.src_to_dst) == 0 + assert len(solution.dst_to_src) == 0 + assert len(solution.src.variables) == 0 + assert len(solution.src.ranges) == 0 + assert ir.structural_equal(solution.src.relations, equations) + assert ir.structural_equal(solution.src, solution.dst) + + def test_unique_solution(): x, y = te.var("x"), te.var("y") - ranges = {} solution = arith.solve_linear_equations([ tvm.tir.EQ(x + y, 20), tvm.tir.EQ(x - y, 10), - ], [x, y], ranges) + ], [x, y]) assert list(solution.dst.variables) == [] assert ir.structural_equal(solution.src_to_dst[x], 15) assert ir.structural_equal(solution.src_to_dst[y], 5)