diff --git a/src/ir/bits.h b/src/ir/bits.h index 25d80fba74b..c4b8a76de18 100644 --- a/src/ir/bits.h +++ b/src/ir/bits.h @@ -34,6 +34,14 @@ inline int32_t lowBitMask(int32_t bits) { return ret >> (32 - bits); } +inline int64_t lowBitMask64(int64_t bits) { + uint64_t ret = -1; + if (bits >= 64) { + return ret; + } + return ret >> (64 - bits); +} + // checks if the input is a mask of lower bits, i.e., all 1s up to some high // bit, and all zeros from there. returns the number of masked bits, or 0 if // this is not such a mask diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index f16b0c1be1c..f4c20b287b3 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -3933,10 +3933,8 @@ struct OptimizeInstructions Binary* add; Const* c1; Const* c2; - if ((matches(curr, - binary(binary(&add, Add, any(), ival(&c1)), ival(&c2))) || - matches(curr, - binary(binary(&add, Add, any(), ival(&c1)), ival(&c2)))) && + if (matches(curr, + binary(binary(&add, Add, any(), ival(&c1)), ival(&c2))) && !canOverflow(add)) { if (c2->value.geU(c1->value).getInteger()) { // This is the first line above, we turn into x > (C2-C1) @@ -3957,6 +3955,118 @@ struct OptimizeInstructions } } + // (x >> C1) ? C2 -> x ? (C2 << C1) if no overflow in << + // This may require an adjustment to the constant on the right, see below. + // TODO: unsigned shift + { + Binary* shift; + Const* c1; + Const* c2; + if (matches( + curr, + binary(binary(&shift, ShrS, any(), ival(&c1)), ival(&c2)))) { + // Consider C2 << C1 and see if it can overflow or even change the + // sign. If it can't, then we know that removing the shift from the + // left side does not change the sign there (it's a signed-shift- + // right) and since neither does the left change sign, we can move the + // shift to the right side without altering the result. 
+ auto shifts = Bits::getEffectiveShifts(c1); + auto c2MaxBits = Bits::getMaxBits(c2, this); + auto typeMaxBits = getBitsForType(type); + if (c2MaxBits + shifts < typeMaxBits) { + // Great, the reversed shift is in a range that does not cause any + // problems, so we can in principle try to reverse the operation + // by shifting both sides to the left, which will "undo" the + // existing shift on x, causing us to only have a shift on the + // right: + // + // ((x >> C1) << C1) ? (C2 << C1) + // + // However, an adjustment may be needed since the original Shr is + // not a linear operation - it clears the lower bits. That is, after + // the new shift, we have: + // + // (x & mask) ? (C2 << C1) + // + // Note that we don't want to optimize to this pattern, as it is + // strictly larger than the original (the mask is a larger constant + // than the shift, and the new constant on the right is larger). So + // we only want to optimize here if we can reduce this further, + // which we can in some cases. + // + // To implement the necessary adjustment, consider that the mask + // makes the lower bits not matter. In other words, we are rounding + // x down, and comparing it to a number that is similarly rounded + // (since it is the result of a left shift). As a result, an + // adjustment may be needed as a larger or smaller x may be + // possible, e.g., consider for < : + // + // signed(x & -4) < (100 << 2) = 400 + // + // Consider x = 400. It has no lower bits to mask off, and 400 < 400 + // which is false. For anything less than 400 the mask will round + // x down to something less than 400, so this will be true: + // + // (x & -4) < x < 400 + // + // Thus we can optimize this to x < 400, and no special adjustment + // is necessary. However, consider <= instead of < : + // + // signed(x & -4) <= 400 + // + // x = 400 results in 400 <= 400 which is true. But x can also be + // larger since the bits would get masked off. 
That is, for x in + // [401, 403], we get (x & -4) == 400. Only 404 would be too large. + // And so we can optimize to x <= 403, basically adding the bits to + // the constant on the right. + // + // Note that we cannot optimize == or != here, as e.g. + // + // (x & -4) == 400 + // + // is true for all of [400, 403]. Only when we have a range can we + // extend the range with an adjustment, basically. + auto moveShift = [&]() { + // Helper function to remove the shift on the left and add a shift + // onto the constant on the right, + // (x >> C1) ? C2 => x ? (C2 << C1) + curr->left = shift->left; + c2->value = c2->value.shl(c1->value); + }; + auto orLowerBits = [&]() { + c2->value = c2->value.or_( + Literal::makeFromInt64(Bits::lowBitMask64(shifts), type)); + }; + if (curr->op == Abstract::getBinary(type, LtS)) { + // Explained above. + moveShift(); + return curr; + } else if (curr->op == Abstract::getBinary(type, LeS)) { + // Explained above. + moveShift(); + orLowerBits(); + return curr; + } else if (curr->op == Abstract::getBinary(type, GtS)) { + // E.g. + // signed(x & -4) > (100 << 2) = 400 + // signed(x & -4) >= 401 + // x & -4 is rounded down to a multiple of 4, so this is only true + // when x > 403. + moveShift(); + orLowerBits(); + return curr; + } else if (curr->op == Abstract::getBinary(type, GeS)) { + // E.g. + // signed(x & -4) >= (100 << 2) = 400 + // x & -4 is rounded down to a multiple of 4, so this is only true + // when x >= 400, and no adjustment is needed. + moveShift(); + return curr; + } + } + } + } + // Comparisons can sometimes be simplified depending on the number of // bits, e.g. (unsigned)x > y must be true if x has strictly more bits. // A common case is a constant on the right, e.g. 
(x & 255) < 256 must be diff --git a/test/lit/passes/optimize-instructions-shifts.wast b/test/lit/passes/optimize-instructions-shifts.wast new file mode 100644 index 00000000000..1ed271ae148 --- /dev/null +++ b/test/lit/passes/optimize-instructions-shifts.wast @@ -0,0 +1,200 @@ +;; NOTE: Assertions have been generated by update_lit_checks.py and should not be edited. +;; RUN: wasm-opt %s --optimize-instructions -S -o - | filecheck %s + +(module + ;; CHECK: (func $less-than-shifted (param $x i32) (param $y i64) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.lt_s + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: (i32.const 400) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i64.lt_s + ;; CHECK-NEXT: (local.get $y) + ;; CHECK-NEXT: (i64.const 400) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $less-than-shifted (param $x i32) (param $y i64) + ;; (x >> 2) < 100 => x < 400 + (drop + (i32.lt_s + (i32.shr_s + (local.get $x) + (i32.const 2) + ) + (i32.const 100) + ) + ) + ;; As above, but with i64. 
+ (drop + (i64.lt_s + (i64.shr_s + (local.get $y) + (i64.const 2) + ) + (i64.const 100) + ) + ) + ) + + ;; CHECK: (func $less-than-shifted-overflow (param $x i32) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.lt_s + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: (i32.const 2139095040) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.lt_s + ;; CHECK-NEXT: (i32.shr_s + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: (i32.const 24) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (i32.const 255) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.lt_s + ;; CHECK-NEXT: (i32.shr_s + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: (i32.const 25) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (i32.const 255) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $less-than-shifted-overflow (param $x i32) + ;; Borderline values: we don't want the constant on the right, when shifted + ;; by the number of shifts, to become negative, as that might alter the + ;; result. This first case can be optimized, but the ones after it cannot. + (drop + (i32.lt_s + (i32.shr_s + (local.get $x) + (i32.const 23) + ) + (i32.const 255) + ) + ) + (drop + (i32.lt_s + (i32.shr_s + (local.get $x) + (i32.const 24) + ) + (i32.const 255) + ) + ) + (drop + (i32.lt_s + (i32.shr_s + (local.get $x) + (i32.const 25) + ) + (i32.const 255) + ) + ) + ) + + ;; CHECK: (func $less-than-shifted-todo (param $x i32) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.lt_u + ;; CHECK-NEXT: (i32.shr_s + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: (i32.const 2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (i32.const 100) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.lt_u + ;; CHECK-NEXT: (i32.shr_u + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: (i32.const 2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (i32.const 100) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $less-than-shifted-todo (param $x i32) + ;; We don't optimize these yet. 
+ ;; This comparison is unsigned. + (drop + (i32.lt_u + (i32.shr_s + (local.get $x) + (i32.const 2) + ) + (i32.const 100) + ) + ) + ;; This shift is unsigned. + (drop + (i32.lt_s + (i32.shr_u + (local.get $x) + (i32.const 2) + ) + (i32.const 100) + ) + ) + ) + + ;; CHECK: (func $other-comparisons (param $x i32) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.le_s + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: (i32.const 403) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.gt_s + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: (i32.const 403) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.ge_s + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: (i32.const 400) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $other-comparisons (param $x i32) + ;; <= : + ;; (x >> 2) <= 100 => x <= 403 + (drop + (i32.le_s + (i32.shr_s + (local.get $x) + (i32.const 2) + ) + (i32.const 100) + ) + ) + ;; > : + ;; (x >> 2) > 100 => x > 403 + (drop + (i32.gt_s + (i32.shr_s + (local.get $x) + (i32.const 2) + ) + (i32.const 100) + ) + ) + ;; >= : + ;; (x >> 2) >= 100 => x >= 400 + (drop + (i32.ge_s + (i32.shr_s + (local.get $x) + (i32.const 2) + ) + (i32.const 100) + ) + ) + ) +) diff --git a/test/lit/passes/optimize-instructions.wast b/test/lit/passes/optimize-instructions.wast index c1835442d34..7907e985b58 100644 --- a/test/lit/passes/optimize-instructions.wast +++ b/test/lit/passes/optimize-instructions.wast @@ -2638,11 +2638,8 @@ ) ;; CHECK: (func $lt_s-sext-zero (param $0 i32) (result i32) ;; CHECK-NEXT: (i32.lt_s - ;; CHECK-NEXT: (i32.shr_s - ;; CHECK-NEXT: (i32.shl - ;; CHECK-NEXT: (local.get $0) - ;; CHECK-NEXT: (i32.const 24) - ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (i32.shl + ;; CHECK-NEXT: (local.get $0) ;; CHECK-NEXT: (i32.const 24) ;; CHECK-NEXT: ) ;; CHECK-NEXT: (i32.const 0)