Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 111 additions & 25 deletions src/tir/transforms/lower_intrin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>

#include <limits>
#include <unordered_set>

#include "../../arith/ir_mutator_with_analyzer.h"
Expand Down Expand Up @@ -112,20 +113,63 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
// Common path, positive divisor
if (analyzer_->CanProveGreaterEqual(op->a, 0) || analyzer_->CanProveGreaterEqual(e, 0)) {
return truncdiv(op->a, op->b);
}

// If the numerator's lower bound is known, express the floordiv
// in terms of truncdiv using only positive operands.
arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
if (const_int_bound->min_value != arith::ConstIntBound::kNegInf &&
const_int_bound->min_value < 0 &&
const_int_bound->min_value > Downcast<IntImm>(tvm::min_value(op->a->dtype))->value) {
// The goal is to write floordiv(a,b) in terms of truncdiv, without using
// negative operands.
//
// For any integer c
//
// floordiv(a,b) == floordiv(a + b*c - b*c, b)
// == floordiv(a + b*c, b) - c
//
// Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms of
// truncdiv as follows.
//
// c == ceildiv(-a_min,b)
// == floordiv(-a_min + (b-1), b)
// == truncdiv(-a_min + (b-1), b)
//
// When substituted into `a + b*c`, this results in a positive argument.
//
// a + b*c
// == a + b*ceildiv(-a_min,b)
// == a - b*floordiv(a_min,b)
// >= a - b*floordiv(a,b)
// == floormod(a, b)
// >= 0
//
// Since the argument is positive, this allows floordiv to be written as
// follows.
//
// floordiv(a,b)
// == floordiv(a + b*c, b) - c
// == truncdiv(a + b*c, b) - c
IntImm min(op->a->dtype, const_int_bound->min_value);
PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b);
PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv);
return truncdiv(offset_numerator, op->b) - ceildiv;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, the rationale is:

a // b =>
(a + -min - -min) // b => 
(a + (x*b-y) - (x*b-y)) // b =>  ( where x is ceildiv(-min, b), x >= 0, y >= 0 )
(a + x*b) // b - x =>
(a + x*b) / b - x ( since a + x*b >= a + -min >= 0 )

Right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, that is the rationale. I've added the derivation in a comment, which I probably should have done from the start.

}

DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident";
PrimExpr rdiv = truncdiv(op->a, op->b);
PrimExpr rmod = truncmod(op->a, op->b);
// condition on b >= 0.
// truncmod(a, b) < 0 will implies ceildiv,
// So we need to correct these cases.
if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
// equivalent to rdiv + (rmod >= 0 ? 0: -1);
return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
} else {
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident";
PrimExpr rdiv = truncdiv(op->a, op->b);
PrimExpr rmod = truncmod(op->a, op->b);
// condition on b >= 0.
// truncmod(a, b) < 0 will implies ceildiv,
// So we need to correct these cases.
if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
// equivalent to rdiv + (rmod >= 0 ? 0: -1);
return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
} else {
return tir::Select(rmod >= 0, rdiv, rdiv - make_const(dtype, 1));
}
return tir::Select(rmod >= 0, rdiv, rdiv - make_const(dtype, 1));
}

} else {
if (dtype.is_float()) {
// floor(a / b)
Expand Down Expand Up @@ -165,21 +209,63 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
// Common pass, positive divisor
if (analyzer_->CanProveGreaterEqual(op->a, 0)) {
return truncmod(op->a, op->b);
}

// If the numerator's lower bound is known, express the floormod
// in terms of truncmod using only positive operands.
arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
if (const_int_bound->min_value != arith::ConstIntBound::kNegInf &&
const_int_bound->min_value < 0 &&
const_int_bound->min_value > Downcast<IntImm>(tvm::min_value(op->a->dtype))->value) {
// The goal is to write floormod(a,b) in terms of truncdiv and truncmod,
// without using negative operands.
//
// For any integer c
//
// floormod(a, b) == floormod(a + b*c, b)
//
// Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms of
// truncdiv as follows.
//
// c == ceildiv(-a_min,b)
// == floordiv(-a_min + (b-1), b)
// == truncdiv(-a_min + (b-1), b)
//
// When substituted into `a + b*c`, this results in a positive argument.
//
// a + b*c
// == a + b*ceildiv(-a_min,b)
// == a - b*floordiv(a_min,b)
// >= a - b*floordiv(a,b)
// == floormod(a, b)
// >= 0
//
// Since the argument is positive, this allows floormod to be written as
// follows.
//
// floormod(a,b)
// == floormod(a + b*c, b)
// == truncmod(a + b*c, b)
IntImm min(op->a->dtype, const_int_bound->min_value);
PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b);
PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv);
return truncmod(offset_numerator, op->b);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

floormod(a, b) =>
floormod(a + -min - -min, b) => 
floormod(a + (x*b-y) - (x*b-y), b) =>  ( where x is ceildiv(-min, b), x >= 0, y >= 0 )
floormod(a + x*b, b) =>
(a + x*b) % b  ( since a + x*b >= a + -min >= 0 )

Right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, that derivation is correct, and I've added a comment here as well.

}

DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident";
// NOTE:condition on b >= 0.
// mod(a, b) < 0 will imply we are doing ceildiv,
// So we need to correct these cases.
PrimExpr rmod = truncmod(op->a, op->b);
if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
// (rmod >> shift) & b
// -> (rmod >= 0 ? 0: -1) & b
// -> rmod >= 0 ? 0 : b
return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1)));
} else {
DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident";
// NOTE:condition on b >= 0.
// mod(a, b) < 0 will imply we are doing ceildiv,
// So we need to correct these cases.
PrimExpr rmod = truncmod(op->a, op->b);
if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
// (rmod >> shift) & b
// -> (rmod >= 0 ? 0: -1) & b
// -> rmod >= 0 ? 0 : b
return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1)));
} else {
return tir::Select(rmod >= 0, rmod, rmod + op->b);
}
return tir::Select(rmod >= 0, rmod, rmod + op->b);
}

} else {
if (dtype.is_float()) {
// a - floor(a / b) * b
Expand Down
42 changes: 42 additions & 0 deletions tests/python/unittest/test_target_codegen_vulkan.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import tvm.testing
from tvm import relay, te
from tvm.topi.math import cast
from tvm.script import tir as T


dtype = tvm.testing.parameter("float32", "int32", "float16", "int8")
Expand Down Expand Up @@ -558,5 +559,46 @@ def do_compute(ins, outs):
tvm.build(s, [Out], target)


def test_negative_operand_divmod(target, dev):
    """Check floordiv/floormod results when the numerator may be negative.

    The SPIR-V spec says OpSRem and OpSMod can produce a signed
    modulo, but the Vulkan spec treats any negative operand to them as
    undefined behavior. The kernel below therefore feeds numerators
    that start out negative into floordiv and floormod; lowering is
    expected to rewrite them so the final TIR only uses positive
    operands, and the computed values are compared against NumPy.

    SPIR-V: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSRem
    Vulkan: https://registry.khronos.org/vulkan/specs/1.3/html/chap37.html#spirvenv-op-prec
    """

    extent = 32
    shift = 16
    denom = 5

    @T.prim_func
    def func(A: T.Buffer[(extent, 2), "int32"]):
        for i in T.serial(extent):
            with T.block("A"):
                vi = T.axis.spatial(extent, i)
                # Column 0 holds the quotient, column 1 the remainder.
                A[vi, 0] = T.floordiv(vi - shift, denom)
                A[vi, 1] = T.floormod(vi - shift, denom)

    kernel = func
    target_keys = tvm.target.Target(target).keys
    if "gpu" in target_keys:
        # GPU targets need the loop bound to a thread axis before building.
        schedule = tvm.tir.Schedule(func)
        outer_loop = schedule.get_loops("A")[0]
        schedule.bind(outer_loop, "threadIdx.x")
        kernel = schedule.mod["main"]

    compiled = tvm.build(kernel, target=target)

    result_dev = tvm.nd.empty([extent, 2], "int32", dev)
    compiled(result_dev)
    result = result_dev.numpy()

    # NumPy's // and % on integers are floor-based, matching floordiv/floormod.
    numerators = np.arange(extent) - shift
    np.testing.assert_array_equal(result[:, 0], numerators // denom)
    np.testing.assert_array_equal(result[:, 1], numerators % denom)


if __name__ == "__main__":
tvm.testing.main()