diff --git a/cranelift/codegen/src/opts/arithmetic.isle b/cranelift/codegen/src/opts/arithmetic.isle
index ec40c310f4c8..400da6bab18d 100644
--- a/cranelift/codegen/src/opts/arithmetic.isle
+++ b/cranelift/codegen/src/opts/arithmetic.isle
@@ -151,3 +151,19 @@
 (subsume (bxor ty (bxor ty a b) (bxor ty c d))))
 (rule (simplify (bxor ty (bxor ty (bxor ty a b) c) d))
 (subsume (bxor ty (bxor ty a b) (bxor ty c d))))
+
+;; Detect people open-coding `mulhi`: (x as big * y as big) >> bits.
+;; LLVM doesn't have an intrinsic for it, so it shows up open-coded in
+;; source that wants the high half of a double-width product.
+(rule (simplify (sshr ty (imul ty (sextend _ x@(value_type half_ty))
+                                  (sextend _ y@(value_type half_ty)))
+                         (iconst_u _ k)))
+      (if-let $true (ty_equal half_ty (ty_half_width ty)))
+      (if-let $true (u64_eq k (ty_bits_u64 half_ty)))
+      (sextend ty (smulhi half_ty x y)))
+(rule (simplify (ushr ty (imul ty (uextend _ x@(value_type half_ty))
+                                  (uextend _ y@(value_type half_ty)))
+                         (iconst_u _ k)))
+      (if-let $true (ty_equal half_ty (ty_half_width ty)))
+      (if-let $true (u64_eq k (ty_bits_u64 half_ty)))
+      (uextend ty (umulhi half_ty x y)))
diff --git a/cranelift/codegen/src/opts/extends.isle b/cranelift/codegen/src/opts/extends.isle
index 17be06f456cb..40ecec2d29ff 100644
--- a/cranelift/codegen/src/opts/extends.isle
+++ b/cranelift/codegen/src/opts/extends.isle
@@ -58,3 +58,43 @@
 (uextend bigty (bor smallty x y)))
 (rule (simplify (bxor bigty (uextend _ x@(value_type smallty)) (uextend _ y@(value_type smallty))))
 (uextend bigty (bxor smallty x y)))
+
+;; Matches values where `ireducing` them will not actually introduce another
+;; instruction, since other rules will collapse them with the reduction.
+(decl pure multi will_simplify_with_ireduce (Value) Value) +(rule (will_simplify_with_ireduce x@(uextend _ _)) x) +(rule (will_simplify_with_ireduce x@(sextend _ _)) x) +(rule (will_simplify_with_ireduce x@(iconst _ _)) x) +(rule (will_simplify_with_ireduce x@(unary_op _ _ a)) + (if-let _ (will_simplify_with_ireduce a)) + x) +(rule (will_simplify_with_ireduce x@(binary_op _ _ a b)) + (if-let _ (will_simplify_with_ireduce a)) + (if-let _ (will_simplify_with_ireduce b)) + x) + +;; Matches values where the high bits of the input don't affect lower bits of +;; the output, and thus the inputs can be reduced before the operation rather +;; than doing the wide operation then reducing afterwards. +(decl pure multi reducible_modular_op (Value) Value) +(rule (reducible_modular_op x@(ineg _ _)) x) +(rule (reducible_modular_op x@(bnot _ _)) x) +(rule (reducible_modular_op x@(iadd _ _ _)) x) +(rule (reducible_modular_op x@(isub _ _ _)) x) +(rule (reducible_modular_op x@(imul _ _ _)) x) +(rule (reducible_modular_op x@(bor _ _ _)) x) +(rule (reducible_modular_op x@(bxor _ _ _)) x) +(rule (reducible_modular_op x@(band _ _ _)) x) + +;; Replace `(small)(x OP y)` with `(small)x OP (small)y` in cases where that's +;; legal and it reduces the total number of instructions since the reductions +;; to the arguments simplify further. 
+(rule (simplify (ireduce smallty val@(unary_op _ op x))) + (if-let _ (reducible_modular_op val)) + (if-let _ (will_simplify_with_ireduce x)) + (unary_op smallty op (ireduce smallty x))) +(rule (simplify (ireduce smallty val@(binary_op _ op x y))) + (if-let _ (reducible_modular_op val)) + (if-let _ (will_simplify_with_ireduce x)) + (if-let _ (will_simplify_with_ireduce y)) + (binary_op smallty op (ireduce smallty x) (ireduce smallty y))) diff --git a/cranelift/codegen/src/prelude_opt.isle b/cranelift/codegen/src/prelude_opt.isle index e0cd14860e88..1f02387d21f3 100644 --- a/cranelift/codegen/src/prelude_opt.isle +++ b/cranelift/codegen/src/prelude_opt.isle @@ -120,3 +120,15 @@ (extractor (sextend_maybe ty val) (sextend_maybe_etor ty val)) (rule 0 (sextend_maybe ty val) (sextend ty val)) (rule 1 (sextend_maybe ty val@(value_type ty)) val) + +(decl unary_op (Type Opcode Value) Value) +(extractor (unary_op ty opcode x) + (inst_data ty (InstructionData.Unary opcode x))) +(rule (unary_op ty opcode x) + (make_inst ty (InstructionData.Unary opcode x))) + +(decl binary_op (Type Opcode Value Value) Value) +(extractor (binary_op ty opcode x y) + (inst_data ty (InstructionData.Binary opcode (value_array_2 x y)))) +(rule (binary_op ty opcode x y) + (make_inst ty (InstructionData.Binary opcode (value_array_2_ctor x y)))) diff --git a/cranelift/filetests/filetests/egraph/arithmetic.clif b/cranelift/filetests/filetests/egraph/arithmetic.clif index 1abdeae1e1e8..1873f4b9b942 100644 --- a/cranelift/filetests/filetests/egraph/arithmetic.clif +++ b/cranelift/filetests/filetests/egraph/arithmetic.clif @@ -250,3 +250,108 @@ block0(v1: f32, v2: f32): ; check: v6 = fmul v1, v2 ; check: return v6 + +function %manual_smulhi_32(i32, i32) -> i32 { +block0(v0: i32, v1: i32): + v2 = sextend.i64 v0 + v3 = sextend.i64 v1 + v4 = imul v2, v3 + v5 = iconst.i32 32 + v6 = sshr v4, v5 + v7 = ireduce.i32 v6 + return v7 +} + +; check: v8 = smulhi v0, v1 +; check: return v8 + +function 
%manual_smulhi_64(i64, i64) -> i64 { +block0(v0: i64, v1: i64): + v2 = sextend.i128 v0 + v3 = sextend.i128 v1 + v4 = imul v2, v3 + v5 = iconst.i32 64 + v6 = sshr v4, v5 + v7 = ireduce.i64 v6 + return v7 +} + +; check: v8 = smulhi v0, v1 +; check: return v8 + +function %manual_umulhi_32(i32, i32) -> i32 { +block0(v0: i32, v1: i32): + v2 = uextend.i64 v0 + v3 = uextend.i64 v1 + v4 = imul v2, v3 + v5 = iconst.i32 32 + v6 = ushr v4, v5 + v7 = ireduce.i32 v6 + return v7 +} + +; check: v8 = umulhi v0, v1 +; check: return v8 + +function %manual_umulhi_64(i64, i64) -> i64 { +block0(v0: i64, v1: i64): + v2 = uextend.i128 v0 + v3 = uextend.i128 v1 + v4 = imul v2, v3 + v5 = iconst.i32 64 + v6 = ushr v4, v5 + v7 = ireduce.i64 v6 + return v7 +} + +; check: v8 = umulhi v0, v1 +; check: return v8 + +function %u64_widening_mul(i64, i64, i64) { +block0(v0: i64, v1: i64, v2: i64): + v3 = uextend.i128 v1 + v4 = uextend.i128 v2 + v5 = imul v3, v4 + v6 = iconst.i32 64 + v7 = ushr v5, v6 + v8 = ireduce.i64 v7 + v9 = ireduce.i64 v5 + store.i64 v9, v0 + store.i64 v8, v0+8 + return +} + +; check: v18 = imul v1, v2 +; check: store v18, v0 +; check: v10 = umulhi v1, v2 +; check: store v10, v0+8 + +function %char_plus_one(i8) -> i8 { +block0(v0: i8): + v1 = sextend.i32 v0 + v2 = iconst.i32 257 + v3 = iadd v1, v2 + v4 = ireduce.i8 v3 + return v4 +} + +; check: v8 = iconst.i8 1 +; check: v9 = iadd v0, v8 ; v8 = 1 +; check: return v9 + +;; Adding three `short`s together and storing them in a `short`, +;; which in C involves extending them to `int`s in the middle. 
+function %extend_iadd_iadd_reduce(i16, i16, i16) -> i16 { +block0(v0: i16, v1: i16, v2: i16): + v3 = sextend.i32 v0 + v4 = sextend.i32 v1 + v5 = sextend.i32 v2 + v6 = iadd v3, v4 + v7 = iadd v6, v5 + v8 = ireduce.i16 v7 + return v8 +} + +; check: v14 = iadd v0, v1 +; check: v18 = iadd v14, v2 +; check: return v18 diff --git a/cranelift/filetests/filetests/egraph/extends.clif b/cranelift/filetests/filetests/egraph/extends.clif index 6562dc58d46b..0cb790f4d069 100644 --- a/cranelift/filetests/filetests/egraph/extends.clif +++ b/cranelift/filetests/filetests/egraph/extends.clif @@ -118,3 +118,72 @@ block0(v0: i8): ; check: v5 = icmp ne v0, v4 ; check: return v5 +function %extend_imul_reduce(i64, i64) -> i64 { +block0(v0: i64, v1: i64): + v2 = uextend.i128 v0 + v3 = uextend.i128 v1 + v4 = imul v2, v3 + v5 = ireduce.i64 v4 + return v5 +} + +; check: v10 = imul v0, v1 +; check: return v10 + +function %extend_iadd_reduce(i16, i16) -> i16 { +block0(v0: i16, v1: i16): + v2 = sextend.i32 v0 + v3 = sextend.i32 v1 + v4 = iadd v2, v3 + v5 = ireduce.i16 v4 + return v5 +} + +; check: v10 = iadd v0, v1 +; check: return v10 + +function %extend_bxor_reduce(i64, i64) -> i64 { +block0(v0: i64, v1: i64): + v2 = uextend.i128 v0 + v3 = uextend.i128 v1 + v4 = bxor v2, v3 + v5 = ireduce.i64 v4 + return v5 +} + +; check: v6 = bxor v0, v1 +; check: return v6 + +function %extend_band_reduce(i16, i16) -> i16 { +block0(v0: i16, v1: i16): + v2 = sextend.i32 v0 + v3 = sextend.i32 v1 + v4 = band v2, v3 + v5 = ireduce.i16 v4 + return v5 +} + +; check: v10 = band v0, v1 +; check: return v10 + +function %extend_ineg_reduce(i64) -> i64 { +block0(v0: i64): + v1 = sextend.i128 v0 + v2 = ineg v1 + v3 = ireduce.i64 v2 + return v3 +} + +; check: v6 = ineg v0 +; check: return v6 + +function %extend_bnot_reduce(i16) -> i16 { +block0(v0: i16): + v1 = uextend.i32 v0 + v2 = bnot v1 + v3 = ireduce.i16 v2 + return v3 +} + +; check: v6 = bnot v0 +; check: return v6