Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/ir/bits.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ inline int32_t lowBitMask(int32_t bits) {
return ret >> (32 - bits);
}

inline int64_t lowBitMask64(int64_t bits) {
uint64_t ret = -1;
if (bits >= 64) {
return ret;
}
return ret >> (64 - bits);
}

// checks if the input is a mask of lower bits, i.e., all 1s up to some high
// bit, and all zeros from there. returns the number of masked bits, or 0 if
// this is not such a mask
Expand Down
118 changes: 114 additions & 4 deletions src/passes/OptimizeInstructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3933,10 +3933,8 @@ struct OptimizeInstructions
Binary* add;
Const* c1;
Const* c2;
if ((matches(curr,
binary(binary(&add, Add, any(), ival(&c1)), ival(&c2))) ||
matches(curr,
binary(binary(&add, Add, any(), ival(&c1)), ival(&c2)))) &&
if (matches(curr,
binary(binary(&add, Add, any(), ival(&c1)), ival(&c2))) &&
!canOverflow(add)) {
if (c2->value.geU(c1->value).getInteger()) {
// This is the first line above, we turn into x > (C2-C1)
Expand All @@ -3957,6 +3955,118 @@ struct OptimizeInstructions
}
}

// (x >> C1) ? C2 -> x ? (C2 << C1) if no overflow in <<
// This may require an adjustment to the constant on the right, see below.
// TODO: unsigned shift
{
Binary* shift;
Const* c1;
Const* c2;
if (matches(
curr,
binary(binary(&shift, ShrS, any(), ival(&c1)), ival(&c2)))) {
// Consider C2 << C1 and see if it can overflow or even change the
// sign. If it can't, then we know that removing the shift from the
// left side does not change the sign there (it's a signed-shift-
// right) and since neither does the left change sign, we can move the
// shift to the right side without altering the result.
auto shifts = Bits::getEffectiveShifts(c1);
auto c2MaxBits = Bits::getMaxBits(c2, this);
auto typeMaxBits = getBitsForType(type);
if (c2MaxBits + shifts < typeMaxBits) {
// Great, the reversed shift is in a range that does not cause any
// problems, so we can in principle try to reverse the operation
// by shifting both sides to the left, which will "undo" the
// existing shift on x, causing us to only have a shift on the
// right:
//
// ((x >> C1) << C1) ? (C2 << C1)
//
// However, an adjustment may be needed since the original Shr is
// not a linear operation - it clears the lower bits. That is, after
// the new shift, we have:
//
// (x & mask) ? (C2 << C1)
//
// Note that we don't want to optimize to this pattern, as it is
// strictly larger than the original (the mask is a larger constant
// than the shift, and the new constant on the right is larger). So
// we only want to optimize here if we can reduce this further,
// which we can in some cases.
//
// To implement the necessary adjustment, consider that the mask
// makes the lower bits not matter. In other words, we are rounding
// x down, and comparing it to a number that is similarly rounded
// (since it is the result of a left shift). As a result, an
// adjustment may be needed as a larger or smaller x may be
// possible, e.g., consider for < :
//
// signed(x & -4) < (100 << 2) = 400
//
// Consider x = 400. It has no lower bits to mask off, and 400 < 400
// which is false. For anything less than 400 the mask will round
// x down to something less than 400, so this will be true:
//
// (x & -4) < x < 400
//
// Thus we can optimize this to x < 400, and no special adjustment
// is necessary. However, consider <= instead of < :
//
// signed(x & -4) <= 400
//
// x = 400 results in 400 <= 400 which is true. But x can also be
// larger since the bits would get masked off. That is, for x in
// [401, 403], we get (x & -4) == 400. Only 404 would be too large.
// And so we can optimize to x <= 403, basically adding the bits to
// the constant on the right.
//
// Note that we cannot optimize == or != here, as e.g.
//
// (x & -4) == 400
//
// is true for all of [400, 403]. Only when we have a range can we
// extend the range with an adjustment, basically.
auto moveShift = [&]() {
// Helper function to remove the shift on the left and add a shift
// onto the constant on the right,
// (x >> C1) ? C2 => x ? (C2 << C1)
curr->left = shift->left;
c2->value = c2->value.shl(c1->value);
};
auto orLowerBits = [&]() {
c2->value = c2->value.or_(
Literal::makeFromInt64(Bits::lowBitMask64(shifts), type));
};
if (curr->op == Abstract::getBinary(type, LtS)) {
// Explained above.
moveShift();
return curr;
} else if (curr->op == Abstract::getBinary(type, LeS)) {
// Explained above.
moveShift();
orLowerBits();
return curr;
} else if (curr->op == Abstract::getBinary(type, GtS)) {
// E.g.
// signed(x & -4) > (100 << 2) = 400
// signed(x & -4) >= 401
// x & -4 is rounded down to a multiple of 4, so this is only true
// when x > 403.
moveShift();
orLowerBits();
return curr;
} else if (curr->op == Abstract::getBinary(type, GeS)) {
// E.g.
// signed(x & -4) >= (100 << 2) = 400
// x & -4 is rounded down to a multiple of 4, so this is only true
// when x >= 400, and no adjustment is needed.
moveShift();
return curr;
}
}
}
}

// Comparisons can sometimes be simplified depending on the number of
// bits, e.g. (unsigned)x > y must be true if x has strictly more bits.
// A common case is a constant on the right, e.g. (x & 255) < 256 must be
Expand Down
200 changes: 200 additions & 0 deletions test/lit/passes/optimize-instructions-shifts.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
;; NOTE: Assertions have been generated by update_lit_checks.py and should not be edited.
;; RUN: wasm-opt %s --optimize-instructions -S -o - | filecheck %s

(module
;; CHECK: (func $less-than-shifted (param $x i32) (param $y i64)
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (i32.lt_s
;; CHECK-NEXT: (local.get $x)
;; CHECK-NEXT: (i32.const 400)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (i64.lt_s
;; CHECK-NEXT: (local.get $y)
;; CHECK-NEXT: (i64.const 400)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: )
(func $less-than-shifted (param $x i32) (param $y i64)
;; (x >> 2) < 100 => x < 400
(drop
(i32.lt_s
(i32.shr_s
(local.get $x)
(i32.const 2)
)
(i32.const 100)
)
)
;; As above, but with i64.
(drop
(i64.lt_s
(i64.shr_s
(local.get $y)
(i64.const 2)
)
(i64.const 100)
)
)
)

;; CHECK: (func $less-than-shifted-overflow (param $x i32)
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (i32.lt_s
;; CHECK-NEXT: (local.get $x)
;; CHECK-NEXT: (i32.const 2139095040)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (i32.lt_s
;; CHECK-NEXT: (i32.shr_s
;; CHECK-NEXT: (local.get $x)
;; CHECK-NEXT: (i32.const 24)
;; CHECK-NEXT: )
;; CHECK-NEXT: (i32.const 255)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (i32.lt_s
;; CHECK-NEXT: (i32.shr_s
;; CHECK-NEXT: (local.get $x)
;; CHECK-NEXT: (i32.const 25)
;; CHECK-NEXT: )
;; CHECK-NEXT: (i32.const 255)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: )
(func $less-than-shifted-overflow (param $x i32)
;; Borderline values: we don't want the constant on the right, when shifted
;; by the number of shifts, to become signed, as that might alter the
;; result. This case can be optimized, and the ones after it not.
(drop
(i32.lt_s
(i32.shr_s
(local.get $x)
(i32.const 23)
)
(i32.const 255)
)
)
(drop
(i32.lt_s
(i32.shr_s
(local.get $x)
(i32.const 24)
)
(i32.const 255)
)
)
(drop
(i32.lt_s
(i32.shr_s
(local.get $x)
(i32.const 25)
)
(i32.const 255)
)
)
)

;; CHECK: (func $less-than-shifted-todo (param $x i32)
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (i32.lt_u
;; CHECK-NEXT: (i32.shr_s
;; CHECK-NEXT: (local.get $x)
;; CHECK-NEXT: (i32.const 2)
;; CHECK-NEXT: )
;; CHECK-NEXT: (i32.const 100)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (i32.lt_u
;; CHECK-NEXT: (i32.shr_u
;; CHECK-NEXT: (local.get $x)
;; CHECK-NEXT: (i32.const 2)
;; CHECK-NEXT: )
;; CHECK-NEXT: (i32.const 100)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: )
(func $less-than-shifted-todo (param $x i32)
;; We don't optimize these yet.
;; This comparison is unsigned.
(drop
(i32.lt_u
(i32.shr_s
(local.get $x)
(i32.const 2)
)
(i32.const 100)
)
)
;; This shift is unsigned.
(drop
(i32.lt_s
(i32.shr_u
(local.get $x)
(i32.const 2)
)
(i32.const 100)
)
)
)

;; CHECK: (func $other-comparisons (param $x i32)
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (i32.le_s
;; CHECK-NEXT: (local.get $x)
;; CHECK-NEXT: (i32.const 403)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (i32.gt_s
;; CHECK-NEXT: (local.get $x)
;; CHECK-NEXT: (i32.const 403)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (i32.ge_s
;; CHECK-NEXT: (local.get $x)
;; CHECK-NEXT: (i32.const 400)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: )
(func $other-comparisons (param $x i32)
;; <= :
;; (x >> 2) <= 100 => x <= 403
(drop
(i32.le_s
(i32.shr_s
(local.get $x)
(i32.const 2)
)
(i32.const 100)
)
)
;; > :
;; (x >> 2) > 100 => x > 403
(drop
(i32.gt_s
(i32.shr_s
(local.get $x)
(i32.const 2)
)
(i32.const 100)
)
)
;; >= :
;; (x >> 2) >= 100 => x >= 400
(drop
(i32.ge_s
(i32.shr_s
(local.get $x)
(i32.const 2)
)
(i32.const 100)
)
)
)
)
7 changes: 2 additions & 5 deletions test/lit/passes/optimize-instructions.wast
Original file line number Diff line number Diff line change
Expand Up @@ -2638,11 +2638,8 @@
)
;; CHECK: (func $lt_s-sext-zero (param $0 i32) (result i32)
;; CHECK-NEXT: (i32.lt_s
;; CHECK-NEXT: (i32.shr_s
;; CHECK-NEXT: (i32.shl
;; CHECK-NEXT: (local.get $0)
;; CHECK-NEXT: (i32.const 24)
;; CHECK-NEXT: )
;; CHECK-NEXT: (i32.shl
;; CHECK-NEXT: (local.get $0)
;; CHECK-NEXT: (i32.const 24)
;; CHECK-NEXT: )
;; CHECK-NEXT: (i32.const 0)
Expand Down