From 4b99ea0f13004c31f7fb6243d1e2b4ab3cd9a35f Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Fri, 6 Jan 2023 21:53:00 -0600 Subject: [PATCH 1/4] [Arith] Use ConstIntBound to remove negative numerator when lowering Negative numerators to modulo/remainder operations are not supported by the Vulkan API. While the SPIR-V instructions [`OpSRem`](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSRem) and [`OpSMod`](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSMod) have identical semantics to `tir::Mod` and `tir::FloorMod`, respectively, use of either instruction within Vulkan results in undefined behavior. From the [Vulkan spec](https://registry.khronos.org/vulkan/specs/1.3/html/chap37.html#spirvenv-op-prec): > For the OpSRem and OpSMod instructions, if either operand is > negative the result is undefined. > > Note: While the OpSRem and OpSMod instructions are supported by the > Vulkan environment, they require non-negative values and thus do not > enable additional functionality beyond what OpUMod provides. This issue was first noticed in https://github.com/apache/tvm/pull/13530, where use of integer arithmetic resulted in negative numerators. This hadn't caused issues previously, because most use of div/mod use a denominator that is a power of two. In these cases, `tir.LowerIntrin` implements floordiv and floormod using only bitwise operations. When the denominator isn't a power of two, both `tir::FloorDiv` and `tir::FloorMod` are implemented in terms of `tir::Mod`, which triggers the undefined behavior for negative numerators. This commit alters the lowering of FloorDiv/FloorMod to TruncDiv/TruncMod, in cases where the denominator is positive, the numerator is sometimes negative, and the range of the numerator is known. In these cases, the FloorDiv/FloorMod is now implemented by offsetting the numerator such that it is always positive. --- src/tir/transforms/lower_intrin.cc | 72 ++++++++++++------- .../unittest/test_target_codegen_vulkan.py | 42 +++++++++++ 2 files changed, 89 insertions(+), 25 deletions(-) diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 2555002d29b0..4c02b5d65ab2 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -112,20 +112,31 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // Common path, positive divisor if (analyzer_->CanProveGreaterEqual(op->a, 0) || analyzer_->CanProveGreaterEqual(e, 0)) { return truncdiv(op->a, op->b); + } + + // If the numerator's lower bound is known, express the floordiv + // in terms of truncdiv using only positive operands. + arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a); + if (const_int_bound->min_value != arith::ConstIntBound::kNegInf && + const_int_bound->min_value < 0) { + IntImm min(op->a->dtype, const_int_bound->min_value); + PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b); + return truncdiv(op->a + op->b * ceildiv, op->b) - ceildiv; + } + + DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident"; + PrimExpr rdiv = truncdiv(op->a, op->b); + PrimExpr rmod = truncmod(op->a, op->b); + // condition on b >= 0. + // truncmod(a, b) < 0 will implies ceildiv, + // So we need to correct these cases. + if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) { + // equivalent to rdiv + (rmod >= 0 ? 0: -1); + return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1)); } else { - DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident"; - PrimExpr rdiv = truncdiv(op->a, op->b); - PrimExpr rmod = truncmod(op->a, op->b); - // condition on b >= 0. - // truncmod(a, b) < 0 will implies ceildiv, - // So we need to correct these cases. - if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) { - // equivalent to rdiv + (rmod >= 0 ? 0: -1); - return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1)); - } else { - return tir::Select(rmod >= 0, rdiv, rdiv - make_const(dtype, 1)); - } + return tir::Select(rmod >= 0, rdiv, rdiv - make_const(dtype, 1)); } + } else { if (dtype.is_float()) { // floor(a / b) @@ -165,21 +176,32 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // Common pass, positive divisor if (analyzer_->CanProveGreaterEqual(op->a, 0)) { return truncmod(op->a, op->b); + } + + // If the numerator's lower bound is known, express the floormod + // in terms of truncmod using only positive operands. + arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a); + if (const_int_bound->min_value != arith::ConstIntBound::kNegInf && + const_int_bound->min_value < 0) { + IntImm min(op->a->dtype, const_int_bound->min_value); + PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b); + return truncmod(op->a + op->b * ceildiv, op->b); + } + + DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident"; + // NOTE:condition on b >= 0. + // mod(a, b) < 0 will imply we are doing ceildiv, + // So we need to correct these cases. + PrimExpr rmod = truncmod(op->a, op->b); + if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) { + // (rmod >> shift) & b + // -> (rmod >= 0 ? 0: -1) & b + // -> rmod >= 0 ? 0 : b + return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1))); } else { - DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident"; - // NOTE:condition on b >= 0. - // mod(a, b) < 0 will imply we are doing ceildiv, - // So we need to correct these cases. - PrimExpr rmod = truncmod(op->a, op->b); - if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) { - // (rmod >> shift) & b - // -> (rmod >= 0 ? 0: -1) & b - // -> rmod >= 0 ? 0 : b - return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1))); - } else { - return tir::Select(rmod >= 0, rmod, rmod + op->b); - } + return tir::Select(rmod >= 0, rmod, rmod + op->b); } + } else { if (dtype.is_float()) { // a - floor(a / b) * b diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index 76cad250e053..7b71f4d4ab17 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -28,6 +28,7 @@ import tvm.testing from tvm import relay, te from tvm.topi.math import cast +from tvm.script import tir as T dtype = tvm.testing.parameter("float32", "int32", "float16", "int8") @@ -558,5 +559,46 @@ def do_compute(ins, outs): tvm.build(s, [Out], target) +def test_negative_operand_divmod(target, dev): + """Test handling of negative offsets to floormod/floordiv + + Even though the SPIR-V spec states that OpSRem and OpSMod can give + the signed modulo, the Vulkan spec states that any use of negative + operands is undefined behavior. This test starts with negative + operands to floordiv, validating that they are simplified into the + corresponding positive operands, such that the final TIR can be + expressed using only positive operands. + + SPIR-V: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSRem + Vulkan: https://registry.khronos.org/vulkan/specs/1.3/html/chap37.html#spirvenv-op-prec + """ + + N = 32 + offset = 16 + divisor = 5 + + @T.prim_func + def func(A: T.Buffer[(N, 2), "int32"]): + for i in T.serial(N): + with T.block("A"): + v_i = T.axis.spatial(N, i) + A[v_i, 0] = T.floordiv(v_i - offset, divisor) + A[v_i, 1] = T.floormod(v_i - offset, divisor) + + if "gpu" in tvm.target.Target(target).keys: + sch = tvm.tir.Schedule(func) + sch.bind(sch.get_loops("A")[0], "threadIdx.x") + func = sch.mod["main"] + + built = tvm.build(func, target=target) + + a_dev = tvm.nd.empty([N, 2], "int32", dev) + built(a_dev) + a = a_dev.numpy() + + np.testing.assert_array_equal(a[:, 0], (np.arange(N) - offset) // divisor) + np.testing.assert_array_equal(a[:, 1], (np.arange(N) - offset) % divisor) + + if __name__ == "__main__": tvm.testing.main() From 2d0a7cf23a9b484ecc81f159084170d7877ba4e1 Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Sun, 8 Jan 2023 20:56:17 -0600 Subject: [PATCH 2/4] Add check to avoid -INT32_MIN --- src/tir/transforms/lower_intrin.cc | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 4c02b5d65ab2..5c0cdc974765 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -27,6 +27,7 @@ #include #include +#include #include #include "../../arith/ir_mutator_with_analyzer.h" @@ -118,10 +119,12 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // in terms of truncdiv using only positive operands. arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a); if (const_int_bound->min_value != arith::ConstIntBound::kNegInf && - const_int_bound->min_value < 0) { + const_int_bound->min_value < 0 && + const_int_bound->min_value > -(1LL << (op->a->dtype.bits() - 1))) { IntImm min(op->a->dtype, const_int_bound->min_value); - PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b); - return truncdiv(op->a + op->b * ceildiv, op->b) - ceildiv; + PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b); + PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv); + return truncdiv(offset_numerator, op->b) - ceildiv; } DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident"; @@ -182,10 +185,12 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // in terms of truncmod using only positive operands. arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a); if (const_int_bound->min_value != arith::ConstIntBound::kNegInf && - const_int_bound->min_value < 0) { + const_int_bound->min_value < 0 && + const_int_bound->min_value > -(1LL << (op->a->dtype.bits() - 1))) { IntImm min(op->a->dtype, const_int_bound->min_value); PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b); - return truncmod(op->a + op->b * ceildiv, op->b); + PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv); + return truncmod(offset_numerator, op->b); } DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident"; From a56f5c3e556af2d6083a06c401e2f46f3ff57c98 Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Tue, 10 Jan 2023 08:11:39 -0600 Subject: [PATCH 3/4] Updated to use `tvm::min_value(DataType)` --- src/tir/transforms/lower_intrin.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 5c0cdc974765..0051b309641f 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -120,7 +120,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a); if (const_int_bound->min_value != arith::ConstIntBound::kNegInf && const_int_bound->min_value < 0 && - const_int_bound->min_value > -(1LL << (op->a->dtype.bits() - 1))) { + const_int_bound->min_value > Downcast(tvm::min_value(op->a->dtype))->value) { IntImm min(op->a->dtype, const_int_bound->min_value); PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b); PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv); @@ -186,7 +186,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a); if (const_int_bound->min_value != arith::ConstIntBound::kNegInf && const_int_bound->min_value < 0 && - const_int_bound->min_value > -(1LL << (op->a->dtype.bits() - 1))) { + const_int_bound->min_value > Downcast(tvm::min_value(op->a->dtype))->value) { IntImm min(op->a->dtype, const_int_bound->min_value); PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b); PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv); From c392b3e7dcb6f46c0a7b23a33766378b012e9f4d Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Tue, 10 Jan 2023 08:41:29 -0600 Subject: [PATCH 4/4] Added derivation for floordiv/floormod in terms of truncdiv/trundmod --- src/tir/transforms/lower_intrin.cc | 59 ++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 0051b309641f..8c850f0dea41 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -121,6 +121,36 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { if (const_int_bound->min_value != arith::ConstIntBound::kNegInf && const_int_bound->min_value < 0 && const_int_bound->min_value > Downcast(tvm::min_value(op->a->dtype))->value) { + // The goal is to write floordiv(a,b) in terms of truncdiv, without using + // negative operands. + // + // For any integer c + // + // floordiv(a,b) == floordiv(a + b*c - b*c, b) + // == floordiv(a + b*c, b) - c + // + // Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms of + // truncdiv as follows. + // + // c == ceildiv(-a_min,b) + // == floordiv(-a_min + (b-1), b) + // == truncdiv(-a_min + (b-1), b) + // + // When substituted into `a + b*c`, this results in a positive argument. + // + // a + b*c + // == a + b*ceildiv(-a_min,b) + // == a - b*floordiv(a_min,b) + // >= a - b*floordiv(a,b) + // == floormod(a, b) + // >= 0 + // + // Since the argument is positive, this allows floordiv to be written as + // followed. + // + // floordiv(a,b) + // == floordiv(a + b*c, b) - c + // == truncdiv(a + b*c, b) - c IntImm min(op->a->dtype, const_int_bound->min_value); PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b); PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv); @@ -187,6 +217,35 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { if (const_int_bound->min_value != arith::ConstIntBound::kNegInf && const_int_bound->min_value < 0 && const_int_bound->min_value > Downcast(tvm::min_value(op->a->dtype))->value) { + // The goal is to write floormod(a,b) in terms of truncdiv and truncmod, + // without using negative operands. + // + // For any integer c + // + // floormod(a, b) == floormod(a + b*c, b) + // + // Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms of + // truncdiv as follows. + // + // c == ceildiv(-a_min,b) + // == floordiv(-a_min + (b-1), b) + // == truncdiv(-a_min + (b-1), b) + // + // When substituted into `a + b*c`, this results in a positive argument. + // + // a + b*c + // == a + b*ceildiv(-a_min,b) + // == a - b*floordiv(a_min,b) + // >= a - b*floordiv(a,b) + // == floormod(a, b) + // >= 0 + // + // Since the argument is positive, this allows floordiv to be written as + // followed. + // + // floormod(a,b) + // == floormod(a + b*c, b) + // == truncmod(a + b*c, b) IntImm min(op->a->dtype, const_int_bound->min_value); PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b); PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv);