Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,30 @@ TVM_DLL Expr operator||(Expr a, Expr b);
* \note This operator does eager constant folding.
*/
TVM_DLL Expr operator!(Expr a);
/*!
* \brief compute trunc(a / b)
*
* This is the default integer division behavior in C.
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr truncdiv(Expr a, Expr b);
/*!
* \brief compute the remainder of truncdiv
*
* This is the default integer division behavior in C.
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr truncmod(Expr a, Expr b);
/*!
* \brief compute floor(a / b)
*
Expand Down
46 changes: 46 additions & 0 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,52 @@ def reducer(expr, axis, where=None, *args):
return reducer


def truncdiv(a, b):
"""Compute the truncdiv of two expressions.

Parameters
----------
a : Expr
The left hand operand

b : Expr
The right hand operand

Returns
-------
res : Expr
The result expression.

Note
----
This is the default integer division behavior in C.
"""
return _make._OpTruncDiv(a, b)


def truncmod(a, b):
"""Compute the truncmod of two expressions.

Parameters
----------
a : Expr
The left hand operand

b : Expr
The right hand operand

Returns
-------
res : Expr
The result expression.

Note
----
This is the default integer division behavior in C.
"""
return _make._OpTruncMod(a, b)


def floordiv(a, b):
"""Compute the floordiv of two expressions.

Expand Down
2 changes: 2 additions & 0 deletions src/api/api_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ REGISTER_MAKE_BINARY_OP(_OpDiv, operator/);
REGISTER_MAKE_BINARY_OP(_OpMod, operator%);
REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
REGISTER_MAKE_BINARY_OP(_OpPow, pow);
REGISTER_MAKE_BINARY_OP(_OpMin, min);
REGISTER_MAKE_BINARY_OP(_OpMax, max);
Expand Down
1 change: 0 additions & 1 deletion src/arithmetic/const_fold.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/

/*!
* Copyright (c) 2019 by Contributors
* \file const_fold.h
* \brief Centralized location for constant folding.
*/
Expand Down
21 changes: 11 additions & 10 deletions src/arithmetic/int_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,12 @@ inline bool WillOverflow<ir::Mod>(int64_t x,
* \return the result.
*/
inline int64_t floordiv(int64_t x, int64_t y) {
bool round_down =
(x >= 0 && y >= 0) ||
(x <= 0 && y <= 0) ||
(x % y == 0);
return round_down ? (x / y) : (x / y - 1);
int64_t rdiv = x / y;
int64_t rmod = x % y;
bool is_floor_div =
(y >= 0 && rmod >= 0) ||
(y < 0 && rmod <= 0);
return is_floor_div ? rdiv : (rdiv - 1);
}


Expand All @@ -114,11 +115,11 @@ inline int64_t floordiv(int64_t x, int64_t y) {
* \return the result.
*/
inline int64_t floormod(int64_t x, int64_t y) {
bool round_down =
(x >= 0 && y >= 0) ||
(x <= 0 && y <= 0) ||
(x % y == 0);
return round_down ? (x % y) : (x % y + y);
int64_t rmod = x % y;
bool is_floor_div =
(y >= 0 && rmod >= 0) ||
(y < 0 && rmod <= 0);
return is_floor_div ? rmod : rmod + y;
}

} // namespace arith
Expand Down
6 changes: 4 additions & 2 deletions src/arithmetic/ir_mutator_with_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ Mutate_(const LetStmt* op, const Stmt& s) {
Expr value = this->Mutate(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
return this->Mutate(op->body);
}
// We keep the let-binding here
// as sub-class may or maynot choose to replace it.
Stmt body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
Expand Down Expand Up @@ -152,8 +153,9 @@ Mutate_(const Let* op, const Expr& self) {
Expr value = this->Mutate(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
return this->Mutate(op->body);
}
// We keep the let-binding here
// as sub-class may or maynot choose to replace it.
Expr body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
Expand Down
2 changes: 2 additions & 0 deletions src/arithmetic/ir_mutator_with_analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class IRMutatorWithAnalyzer : public ir::IRMutator {
explicit IRMutatorWithAnalyzer(Analyzer* analyzer)
: analyzer_(analyzer) {}

using IRMutator::Mutate_;

// override functions that need to populate the context information.
Stmt Mutate_(const ir::For* op, const Stmt& self) override;
Stmt Mutate_(const ir::LetStmt* op, const Stmt& self) override;
Expand Down
3 changes: 2 additions & 1 deletion src/arithmetic/pattern_match.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/

/*!
* Copyright (c) 2019 by Contributors
* \file tvm/arithmetic/pattern_match.h
*
* \brief Internal tool for expression-template based pattern matching.
Expand Down Expand Up @@ -326,6 +325,8 @@ TVM_PATTERN_BINARY_OP(operator/, ir::Div);
TVM_PATTERN_BINARY_OP(operator%, ir::Mod);
TVM_PATTERN_BINARY_OP(min, ir::Min);
TVM_PATTERN_BINARY_OP(max, ir::Max);
TVM_PATTERN_BINARY_OP(truncdiv, ir::Div);
TVM_PATTERN_BINARY_OP(truncmod, ir::Mod);
TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv);
TVM_PATTERN_BINARY_OP(floormod, ir::FloorMod);

Expand Down
28 changes: 28 additions & 0 deletions src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1674,6 +1674,16 @@ Mutate_(const Call* op, const Expr& self) {
if (op == nullptr) return ret;
if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) {
return op->args[0];
} else if (op->is_intrinsic(Call::shift_right)) {
if (op->args[0].as<IntImm>() && op->args[1].as<IntImm>()) {
// the operator overload will eagerly constant fold.
return op->args[0] >> op->args[1];
}
} else if (op->is_intrinsic(Call::bitwise_and)) {
if (op->args[0].as<IntImm>() && op->args[1].as<IntImm>()) {
// the operator overload will eagerly constant fold.
return op->args[0] & op->args[1];
}
}
return ret;
}
Expand All @@ -1695,6 +1705,24 @@ Mutate_(const Cast* op, const Expr& self) {
return cast(op->type, op->value);
}

Expr RewriteSimplifier::Impl::
Mutate_(const Let* op, const Expr& self) {
Expr value = this->Mutate(op->value);
if (!ir::HasSideEffect(value)) {
// it is fine to discard the let binding
// because the value will always be inlined in the simplifier.
analyzer_->Bind(op->var, value);
return this->Mutate(op->body);
}
Expr body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return self;
} else {
return Let::make(op->var, value, body);
}
}

Expr RewriteSimplifier::operator()(const Expr& expr) {
// Run simplification in post order
Expr res = expr;
Expand Down
1 change: 1 addition & 0 deletions src/arithmetic/rewrite_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
Expr Mutate_(const Call* op, const Expr& self) override;
Expr Mutate_(const Variable* op, const Expr& self) override;
Expr Mutate_(const Cast* op, const Expr& self) override;
Expr Mutate_(const Let* op, const Expr& self) override;

protected:
/*! \brief internal structure for comparison. */
Expand Down
17 changes: 17 additions & 0 deletions src/arithmetic/stmt_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,23 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return Mutate(stmt);
}

Stmt Mutate_(const LetStmt* op, const Stmt& s) {
Expr value = this->Mutate(op->value);
if (!ir::HasSideEffect(value)) {
// it is fine to discard the let binding
// because the call to simplify will always inline the var.
analyzer_->Bind(op->var, value);
return Mutate(op->body);
}
Stmt body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
} else {
return LetStmt::make(op->var, value, body);
}
}

// eliminate useless stores
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Expand Down
12 changes: 10 additions & 2 deletions src/lang/expr_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,20 +178,28 @@ Expr operator*(Expr a, Expr b) {
return ir::Mul::make(a, b);
}

Expr operator/(Expr a, Expr b) {
Expr truncdiv(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::Div>(a, b);
if (ret.defined()) return ret;
return ir::Div::make(a, b);
}

Expr operator%(Expr a, Expr b) {
Expr truncmod(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::Mod>(a, b);
if (ret.defined()) return ret;
return ir::Mod::make(a, b);
}

Expr operator/(Expr a, Expr b) {
return truncdiv(a, b);
}

Expr operator%(Expr a, Expr b) {
return truncmod(a, b);
}

Expr floordiv(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::FloorDiv>(a, b);
Expand Down
Loading