diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index bef9705e4749..9b94e9179ec3 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -332,6 +332,30 @@ TVM_DLL Expr operator||(Expr a, Expr b); * \note This operator does eager constant folding. */ TVM_DLL Expr operator!(Expr a); +/*! + * \brief compute trunc(a / b) + * + * This is the default integer division behavior in C. + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr truncdiv(Expr a, Expr b); +/*! + * \brief compute the remainder of truncdiv + * + * This is the default integer division behavior in C. + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr truncmod(Expr a, Expr b); /*! * \brief compute floor(a / b) * diff --git a/python/tvm/api.py b/python/tvm/api.py index 490899ebe69a..b54d36426ba9 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -891,6 +891,52 @@ def reducer(expr, axis, where=None, *args): return reducer +def truncdiv(a, b): + """Compute the truncdiv of two expressions. + + Parameters + ---------- + a : Expr + The left hand operand + + b : Expr + The right hand operand + + Returns + ------- + res : Expr + The result expression. + + Note + ---- + This is the default integer division behavior in C. + """ + return _make._OpTruncDiv(a, b) + + +def truncmod(a, b): + """Compute the truncmod of two expressions. + + Parameters + ---------- + a : Expr + The left hand operand + + b : Expr + The right hand operand + + Returns + ------- + res : Expr + The result expression. + + Note + ---- + This is the default integer division behavior in C. + """ + return _make._OpTruncMod(a, b) + + def floordiv(a, b): """Compute the floordiv of two expressions. diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 2216793898e3..9e14048dbbfe 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -196,6 +196,8 @@ REGISTER_MAKE_BINARY_OP(_OpDiv, operator/); REGISTER_MAKE_BINARY_OP(_OpMod, operator%); REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv); REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod); +REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv); +REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod); REGISTER_MAKE_BINARY_OP(_OpPow, pow); REGISTER_MAKE_BINARY_OP(_OpMin, min); REGISTER_MAKE_BINARY_OP(_OpMax, max); diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index c921afc9b7ff..57f90534fbb4 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file const_fold.h * \brief Centralized location for constant folding. */ diff --git a/src/arithmetic/int_operator.h b/src/arithmetic/int_operator.h index 4d940b2aa0ce..d92094415eba 100644 --- a/src/arithmetic/int_operator.h +++ b/src/arithmetic/int_operator.h @@ -99,11 +99,12 @@ inline bool WillOverflow(int64_t x, * \return the result. */ inline int64_t floordiv(int64_t x, int64_t y) { - bool round_down = - (x >= 0 && y >= 0) || - (x <= 0 && y <= 0) || - (x % y == 0); - return round_down ? (x / y) : (x / y - 1); + int64_t rdiv = x / y; + int64_t rmod = x % y; + bool is_floor_div = + (y >= 0 && rmod >= 0) || + (y < 0 && rmod <= 0); + return is_floor_div ? rdiv : (rdiv - 1); } @@ -114,11 +115,11 @@ inline int64_t floordiv(int64_t x, int64_t y) { * \return the result. */ inline int64_t floormod(int64_t x, int64_t y) { - bool round_down = - (x >= 0 && y >= 0) || - (x <= 0 && y <= 0) || - (x % y == 0); - return round_down ? (x % y) : (x % y + y); + int64_t rmod = x % y; + bool is_floor_div = + (y >= 0 && rmod >= 0) || + (y < 0 && rmod <= 0); + return is_floor_div ? rmod : rmod + y; } } // namespace arith diff --git a/src/arithmetic/ir_mutator_with_analyzer.cc b/src/arithmetic/ir_mutator_with_analyzer.cc index 2e230990eea7..04e166ae52c0 100644 --- a/src/arithmetic/ir_mutator_with_analyzer.cc +++ b/src/arithmetic/ir_mutator_with_analyzer.cc @@ -41,8 +41,9 @@ Mutate_(const LetStmt* op, const Stmt& s) { Expr value = this->Mutate(op->value); if (!ir::HasSideEffect(value)) { analyzer_->Bind(op->var, value); - return this->Mutate(op->body); } + // We keep the let-binding here + // as sub-class may or maynot choose to replace it. Stmt body = this->Mutate(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { @@ -152,8 +153,9 @@ Mutate_(const Let* op, const Expr& self) { Expr value = this->Mutate(op->value); if (!ir::HasSideEffect(value)) { analyzer_->Bind(op->var, value); - return this->Mutate(op->body); } + // We keep the let-binding here + // as sub-class may or maynot choose to replace it. Expr body = this->Mutate(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { diff --git a/src/arithmetic/ir_mutator_with_analyzer.h b/src/arithmetic/ir_mutator_with_analyzer.h index 1c03a233c910..bf4118e9c698 100644 --- a/src/arithmetic/ir_mutator_with_analyzer.h +++ b/src/arithmetic/ir_mutator_with_analyzer.h @@ -45,6 +45,8 @@ class IRMutatorWithAnalyzer : public ir::IRMutator { explicit IRMutatorWithAnalyzer(Analyzer* analyzer) : analyzer_(analyzer) {} + using IRMutator::Mutate_; + // override functions that need to populate the context information. Stmt Mutate_(const ir::For* op, const Stmt& self) override; Stmt Mutate_(const ir::LetStmt* op, const Stmt& self) override; diff --git a/src/arithmetic/pattern_match.h b/src/arithmetic/pattern_match.h index 514b493ee05d..1278c7d32ee5 100644 --- a/src/arithmetic/pattern_match.h +++ b/src/arithmetic/pattern_match.h @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file tvm/arithmetic/pattern_match.h * * \brief Internal tool for expression-template based pattern matching. @@ -326,6 +325,8 @@ TVM_PATTERN_BINARY_OP(operator/, ir::Div); TVM_PATTERN_BINARY_OP(operator%, ir::Mod); TVM_PATTERN_BINARY_OP(min, ir::Min); TVM_PATTERN_BINARY_OP(max, ir::Max); +TVM_PATTERN_BINARY_OP(truncdiv, ir::Div); +TVM_PATTERN_BINARY_OP(truncmod, ir::Mod); TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv); TVM_PATTERN_BINARY_OP(floormod, ir::FloorMod); diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 740840961605..a567f502f766 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -1674,6 +1674,16 @@ Mutate_(const Call* op, const Expr& self) { if (op == nullptr) return ret; if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) { return op->args[0]; + } else if (op->is_intrinsic(Call::shift_right)) { + if (op->args[0].as() && op->args[1].as()) { + // the operator overload will eagerly constant fold. + return op->args[0] >> op->args[1]; + } + } else if (op->is_intrinsic(Call::bitwise_and)) { + if (op->args[0].as() && op->args[1].as()) { + // the operator overload will eagerly constant fold. + return op->args[0] & op->args[1]; + } } return ret; } @@ -1695,6 +1705,24 @@ Mutate_(const Cast* op, const Expr& self) { return cast(op->type, op->value); } +Expr RewriteSimplifier::Impl:: +Mutate_(const Let* op, const Expr& self) { + Expr value = this->Mutate(op->value); + if (!ir::HasSideEffect(value)) { + // it is fine to discard the let binding + // because the value will always be inlined in the simplifier. + analyzer_->Bind(op->var, value); + return this->Mutate(op->body); + } + Expr body = this->Mutate(op->body); + if (value.same_as(op->value) && + body.same_as(op->body)) { + return self; + } else { + return Let::make(op->var, value, body); + } +} + Expr RewriteSimplifier::operator()(const Expr& expr) { // Run simplification in post order Expr res = expr; diff --git a/src/arithmetic/rewrite_simplify.h b/src/arithmetic/rewrite_simplify.h index 8ce27604d8a4..55965ce42d6a 100644 --- a/src/arithmetic/rewrite_simplify.h +++ b/src/arithmetic/rewrite_simplify.h @@ -72,6 +72,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { Expr Mutate_(const Call* op, const Expr& self) override; Expr Mutate_(const Variable* op, const Expr& self) override; Expr Mutate_(const Cast* op, const Expr& self) override; + Expr Mutate_(const Let* op, const Expr& self) override; protected: /*! \brief internal structure for comparison. */ diff --git a/src/arithmetic/stmt_simplify.cc b/src/arithmetic/stmt_simplify.cc index d0cd921d12c6..f784514e1302 100644 --- a/src/arithmetic/stmt_simplify.cc +++ b/src/arithmetic/stmt_simplify.cc @@ -51,6 +51,23 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { return Mutate(stmt); } + Stmt Mutate_(const LetStmt* op, const Stmt& s) { + Expr value = this->Mutate(op->value); + if (!ir::HasSideEffect(value)) { + // it is fine to discard the let binding + // because the call to simplify will always inline the var. + analyzer_->Bind(op->var, value); + return Mutate(op->body); + } + Stmt body = this->Mutate(op->body); + if (value.same_as(op->value) && + body.same_as(op->body)) { + return s; + } else { + return LetStmt::make(op->var, value, body); + } + } + // eliminate useless stores Stmt Mutate_(const Store* op, const Stmt& s) final { Stmt stmt = IRMutator::Mutate_(op, s); diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index cd61ccaa0147..50da8a144c45 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -178,20 +178,28 @@ Expr operator*(Expr a, Expr b) { return ir::Mul::make(a, b); } -Expr operator/(Expr a, Expr b) { +Expr truncdiv(Expr a, Expr b) { BinaryOpMatchTypes(a, b); Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::Div::make(a, b); } -Expr operator%(Expr a, Expr b) { +Expr truncmod(Expr a, Expr b) { BinaryOpMatchTypes(a, b); Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::Mod::make(a, b); } +Expr operator/(Expr a, Expr b) { + return truncdiv(a, b); +} + +Expr operator%(Expr a, Expr b) { + return truncmod(a, b); +} + Expr floordiv(Expr a, Expr b) { BinaryOpMatchTypes(a, b); Expr ret = arith::TryConstFold(a, b); diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index 9403b71eb935..916724725eac 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,23 +18,28 @@ */ /*! - * Copyright (c) 2017 by Contributors - * Lower intrinsic calls to device specific ir when possible. + * Lower intrinsic calls and ops to device specific ir when possible. * \file lower_intrin.cc */ #include #include #include #include +#include #include #include "ir_util.h" +#include "../arithmetic/pattern_match.h" +#include "../arithmetic/ir_mutator_with_analyzer.h" namespace tvm { namespace ir { -class IntrinInjecter : public IRMutator { +class IntrinInjecter : public arith::IRMutatorWithAnalyzer { public: - explicit IntrinInjecter(std::string target) { + using IRMutatorWithAnalyzer::Mutate_; + + IntrinInjecter(arith::Analyzer* analyzer, std::string target) + : IRMutatorWithAnalyzer(analyzer) { std::istringstream is(target); std::string starget; is >> starget; @@ -61,6 +66,118 @@ class IntrinInjecter : public IRMutator { return IRMutator::Mutate_(op, e); } + // We use floordiv for integer analysis, + // but will need to lower them to native truncdiv instructions + Expr Mutate_(const FloorDiv* op, const Expr& e) final { + Expr ret = IRMutatorWithAnalyzer::Mutate_(op, e); + op = ret.as(); + if (op == nullptr) return ret; + int shift; + const DataType& dtype = op->type; + if (dtype.is_float()) { + return floor(Div::make(op->a, op->b)); + } + CHECK(dtype.is_int() || !dtype.is_uint()); + + if (is_const_power_of_two_integer(op->b, &shift)) { + // lower to right shift if possible. + return op->a >> make_const(dtype, shift); + } + + if (analyzer_->CanProveGreaterEqual(op->b, 0)) { + // Common path, positive divisor + if (analyzer_->CanProveGreaterEqual(op->a, 0) || + analyzer_->CanProveGreaterEqual(e, 0)) { + return truncdiv(op->a, op->b); + } else { + DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident"; + Expr rdiv = truncdiv(op->a, op->b); + Expr rmod = truncmod(op->a, op->b); + // condition on b >= 0. + // truncmod(a, b) < 0 will implies ceildiv, + // So we need to correct these cases. + if (dtype == Int(32) || dtype == Int(64)) { + // equivalent to rdiv + (rmod >= 0 ? 0: -1); + return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1)); + } else { + return ir::Select::make(rmod >= 0 , rdiv, rdiv - make_const(dtype, 1)); + } + } + } else { + // uncommon case + DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor"; + // b >= 0 => (rmod >=0 ? rdiv : rdiv - 1) + // b < 0 => (rmod <= 0 ? rdiv : rdiv - 1) + Expr rdiv = truncdiv(op->a, op->b); + Expr rmod = truncmod(op->a, op->b); + return ir::Select::make( + (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), + rdiv, rdiv - make_const(dtype, 1)); + } + } + + Expr Mutate_(const FloorMod* op, const Expr& e) final { + Expr ret = IRMutatorWithAnalyzer::Mutate_(op, e); + op = ret.as(); + if (op == nullptr) return ret; + // Lower floordiv to native truncdiv. + int shift; + const DataType& dtype = op->type; + CHECK(dtype.is_int() || !dtype.is_uint()); + + if (is_const_power_of_two_integer(op->b, &shift)) { + // lower to masking if possible. + int64_t mask = ( + static_cast(1) << static_cast(shift)) - 1; + return op->a & make_const(dtype, mask); + } + + if (analyzer_->CanProveGreaterEqual(op->b, 0)) { + // Common pass, positive divisor + if (analyzer_->CanProveGreaterEqual(op->a, 0) || + analyzer_->CanProveGreaterEqual(e, 0)) { + return truncmod(op->a, op->b); + } else { + DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident"; + // NOTE:condition on b >= 0. + // mod(a, b) < 0 will imply we are doing ceildiv, + // So we need to correct these cases. + Expr rmod = truncmod(op->a, op->b); + if (dtype == Int(32) || dtype == Int(64)) { + // (rmod >> shift) & b + // -> (rmod >= 0 ? 0: -1) & b + // -> rmod >= 0 ? 0 : b + return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1))); + } else { + return ir::Select::make(rmod >= 0, rmod, rmod + op->b); + } + } + } else { + // uncommon case + DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident"; + Expr rmod = truncmod(op->a, op->b); + // b > 0 && rmod >= 0 -> rmod + // b > 0 && rmod < 0 -> rmod + b + // b < 0 && rmod < 0 -> rmod + // b < 0 && rmod > 0 -> rmod + b + return ir::Select::make( + (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), + rmod, rmod + op->b); + } + } + + Expr Mutate_(const Max* op, const Expr& e) final { + using namespace arith; + PVar x, y; + PVar c; + if (max(floordiv(x, y), c).Match(e) && + c.Eval()->value >= 0 && + analyzer_->CanProveGreaterEqual(y.Eval(), 0)) { + return max(Mutate(truncdiv(x, y).Eval()), c.Eval()); + } + return IRMutatorWithAnalyzer::Mutate_(op, e); + } + private: Expr SwapBroadcastCast(const Expr& e) { // Try to change broadcast(cast(x)) to cast(broadcast(x)) @@ -132,17 +249,27 @@ class IntrinInjecter : public IRMutator { } return Expr(); } + // patterns std::vector patterns_; const PackedFunc* fma_{nullptr}; }; +Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) { + arith::Analyzer analyzer; + return IntrinInjecter(&analyzer, target).Mutate(stmt); +} + LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target) { auto n = make_node(*f.operator->()); - n->body = IntrinInjecter(target).Mutate(n->body); + n->body = LowerIntrinStmt(n->body, target); return LoweredFunc(n); } +// Register the api only for test purposes +TVM_REGISTER_API("ir_pass._LowerIntrinStmt") +.set_body_typed(LowerIntrinStmt); + } // namespace ir } // namespace tvm diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py index 34dad36a9076..06f1801d59bb 100644 --- a/tests/python/unittest/test_codegen_llvm.py +++ b/tests/python/unittest/test_codegen_llvm.py @@ -87,6 +87,7 @@ def test_llvm_lookup_intrin(): func = tvm.ir_pass.MakeAPI(body, "ctpop", [A], 1, True) fcode = tvm.build(func, None, "llvm") + def test_llvm_add_pipeline(): nn = 1024 n = tvm.convert(nn) diff --git a/tests/python/unittest/test_pass_lower_intrin.py b/tests/python/unittest/test_pass_lower_intrin.py new file mode 100644 index 000000000000..d2d106df001e --- /dev/null +++ b/tests/python/unittest/test_pass_lower_intrin.py @@ -0,0 +1,111 @@ + # Licensed to the Apache Software Foundation (ASF) under one + # or more contributor license agreements. See the NOTICE file + # distributed with this work for additional information + # regarding copyright ownership. The ASF licenses this file + # to you under the Apache License, Version 2.0 (the + # "License"); you may not use this file except in compliance + # with the License. You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, + # software distributed under the License is distributed on an + # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + # KIND, either express or implied. See the License for the + # specific language governing permissions and limitations + # under the License. +import tvm +import numpy as np + +def lower_intrin(stmt): + """wrapper to call transformation in stmt""" + lower_expr = isinstance(stmt, tvm.expr.Expr) + stmt = tvm.stmt.Evaluate(stmt) if lower_expr else stmt + stmt = tvm.ir_pass.CanonicalSimplify(stmt) + stmt = tvm.ir_pass._LowerIntrinStmt(stmt, "llvm") + return stmt.value if lower_expr else stmt.body + + +def check_value(expr, vx, vy, data, fref): + n = len(data) + A = tvm.placeholder((n,), name="A", dtype=expr.dtype) + B = tvm.placeholder((n,), name="B", dtype=expr.dtype) + + def make_binds(i): + x = expr + x = tvm.expr.Let(vx, A[i], x) + x = tvm.expr.Let(vy, B[i], x) + return x + + C = tvm.compute((n,), make_binds) + s = tvm.create_schedule([C.op]) + + if not tvm.module.enabled("llvm"): + return + + f = tvm.build(s, [A, B, C], "llvm") + a = tvm.nd.array(np.array([x for x, y in data], dtype=expr.dtype)) + b = tvm.nd.array(np.array([y for x, y in data], dtype=expr.dtype)) + c = tvm.nd.array(np.zeros(len(data), dtype=expr.dtype)) + f(a, b, c) + cref = np.array([fref(x, y) for x, y in data]) + np.testing.assert_equal(c.asnumpy(), cref) + + + +def get_ref_data(): + """Get reference data for every pairs""" + import itertools + x = range(-10, 10) + y = list(range(-10, 10)) + y.remove(0) + return list(itertools.product(x, y)) + + +def test_lower_floordiv(): + data = get_ref_data() + for dtype in ["int32", "int64", "int16"]: + x = tvm.var("x", dtype=dtype) + y = tvm.var("y", dtype=dtype) + zero = tvm.const(0, dtype) + # no constraints + res = lower_intrin(tvm.floordiv(x, y)) + check_value(res, x, y, data, lambda a, b: a // b) + # rhs >= 0 + res = lower_intrin(tvm.expr.Select(y >= 0, tvm.floordiv(x, y), zero)) + check_value(res, x, y, data, lambda a, b: a // b if b > 0 else 0) + # involves max + res = lower_intrin(tvm.expr.Select(y >= 0, tvm.max(tvm.floordiv(x, y), zero), zero)) + check_value(res, x, y, data, lambda a, b: max(a // b, 0) if b > 0 else 0) + # lhs >= 0 + res = lower_intrin(tvm.expr.Select(tvm.all(y >= 0, x >= 0), tvm.floordiv(x, y), zero)) + check_value(res, x, y, data, lambda a, b: a // b if b > 0 and a >= 0 else 0) + # const power of two + res = lower_intrin(tvm.floordiv(x, tvm.const(8, dtype=dtype))) + check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a // b) + + +def test_lower_floormod(): + data = get_ref_data() + for dtype in ["int32", "int64", "int16"]: + x = tvm.var("x", dtype=dtype) + y = tvm.var("y", dtype=dtype) + zero = tvm.const(0, dtype) + # no constraints + res = lower_intrin(tvm.floormod(x, y)) + check_value(res, x, y, data, lambda a, b: a % b) + # rhs >= 0 + res = lower_intrin(tvm.expr.Select(y >= 0, tvm.floormod(x, y), zero)) + check_value(res, x, y, data, lambda a, b: a % b if b > 0 else 0) + # lhs >= 0 + res = lower_intrin(tvm.expr.Select(tvm.all(y >= 0, x >= 0), tvm.floormod(x, y), zero)) + check_value(res, x, y, data, lambda a, b: a % b if b > 0 and a >= 0 else 0) + # const power of two + res = lower_intrin(tvm.floormod(x, tvm.const(8, dtype=dtype))) + check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a % b) + + + +if __name__ == "__main__": + test_lower_floordiv() + test_lower_floormod()