From a935d5b0465601e062e9373021bfc4c93ee5bcd1 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 7 Dec 2021 13:32:22 -0800 Subject: [PATCH 1/4] Let lerp lowering incorporate a final cast This lets it save a few instructions on x86 and arm. cast(UInt(16), lerp(some_u8s)) produces the following, before and after this PR Before: x86: vmovdqu (%r15,%r13), %xmm4 vpmovzxbw -2(%r15,%r13), %ymm5 vpxor %xmm0, %xmm4, %xmm6 vpmovzxbw %xmm6, %ymm6 vpmovzxbw -1(%r15,%r13), %ymm7 vpmullw %ymm6, %ymm5, %ymm5 vpmovzxbw %xmm4, %ymm4 vpmullw %ymm4, %ymm7, %ymm4 vpaddw %ymm4, %ymm5, %ymm4 vpaddw %ymm1, %ymm4, %ymm4 vpmulhuw %ymm2, %ymm4, %ymm4 vpsrlw $7, %ymm4, %ymm4 vpand %ymm3, %ymm4, %ymm4 vmovdqu %ymm4, (%rbx,%r13,2) addq $16, %r13 decq %r10 jne .LBB0_10 arm: ldr q0, [x17] ldur q2, [x17, #-1] ldur q1, [x17, #-2] subs x0, x0, #1 // =1 mvn v3.16b, v0.16b umull v4.8h, v2.8b, v0.8b umull2 v0.8h, v2.16b, v0.16b umlal v4.8h, v1.8b, v3.8b umlal2 v0.8h, v1.16b, v3.16b urshr v1.8h, v4.8h, #8 urshr v2.8h, v0.8h, #8 raddhn v1.8b, v1.8h, v4.8h raddhn v0.8b, v2.8h, v0.8h ushll v0.8h, v0.8b, #0 ushll v1.8h, v1.8b, #0 add x17, x17, #16 // =16 stp q1, q0, [x18, #-16] add x18, x18, #32 // =32 b.ne .LBB0_10 After: x86: vpmovzxbw -2(%r15,%r13), %ymm3 vmovdqu (%r15,%r13), %xmm4 vpxor %xmm0, %xmm4, %xmm5 vpmovzxbw %xmm5, %ymm5 vpmullw %ymm5, %ymm3, %ymm3 vpmovzxbw -1(%r15,%r13), %ymm5 vpmovzxbw %xmm4, %ymm4 vpmullw %ymm4, %ymm5, %ymm4 vpaddw %ymm4, %ymm3, %ymm3 vpaddw %ymm1, %ymm3, %ymm3 vpmulhuw %ymm2, %ymm3, %ymm3 vpsrlw $7, %ymm3, %ymm3 vmovdqu %ymm3, (%rbp,%r13,2) addq $16, %r13 decq %r10 jne .LBB0_10 arm: ldr q0, [x17] ldur q2, [x17, #-1] ldur q1, [x17, #-2] subs x0, x0, #1 // =1 mvn v3.16b, v0.16b umull v4.8h, v2.8b, v0.8b umull2 v0.8h, v2.16b, v0.16b umlal v4.8h, v1.8b, v3.8b umlal2 v0.8h, v1.16b, v3.16b ursra v4.8h, v4.8h, #8 ursra v0.8h, v0.8h, #8 urshr v1.8h, v4.8h, #8 urshr v0.8h, v0.8h, #8 add x17, x17, #16 // =16 stp q1, q0, [x18, #-16] add x18, x18, #32 // =32 b.ne .LBB0_10 So on 
X86 we skip a pointless vpand (bitwise AND) instruction, and on ARM we get a rounding add and shift right instead of a rounding narrowing add shift right followed by a widen. --- src/CodeGen_C.cpp | 2 +- src/CodeGen_LLVM.cpp | 18 ++++++++++++++++-- src/HexagonOptimize.cpp | 2 +- src/Lerp.cpp | 7 +++++-- src/Lerp.h | 8 +++++--- 5 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/CodeGen_C.cpp b/src/CodeGen_C.cpp index 50a25f6748f4..58bacde63cb5 100644 --- a/src/CodeGen_C.cpp +++ b/src/CodeGen_C.cpp @@ -2314,7 +2314,7 @@ void CodeGen_C::visit(const Call *op) { } } else if (op->is_intrinsic(Call::lerp)) { internal_assert(op->args.size() == 3); - Expr e = lower_lerp(op->args[0], op->args[1], op->args[2], target); + Expr e = lower_lerp(op->type, op->args[0], op->args[1], op->args[2], target); rhs << print_expr(e); } else if (op->is_intrinsic(Call::absd)) { internal_assert(op->args.size() == 2); diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 0d241ea02ea1..0726c9881830 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -1392,6 +1392,20 @@ void CodeGen_LLVM::visit(const Cast *op) { return; } + if (const Call *c = Call::as_intrinsic(op->value, {Call::lerp})) { + // We want to codegen a cast of a lerp as a single thing, because it can + // be done more intelligently than a lerp followed by a cast. + Type t = upgrade_type_for_arithmetic(c->type); + Type wt = upgrade_type_for_arithmetic(c->args[2].type()); + Expr e = lower_lerp(op->type, + cast(t, c->args[0]), + cast(t, c->args[1]), + cast(wt, c->args[2]), + target); + codegen(e); + return; + } + value = codegen(op->value); llvm::Type *llvm_dst = llvm_type_of(dst); @@ -2698,11 +2712,11 @@ void CodeGen_LLVM::visit(const Call *op) { // TODO: This might be surprising behavior? 
Type t = upgrade_type_for_arithmetic(op->type); Type wt = upgrade_type_for_arithmetic(op->args[2].type()); - Expr e = lower_lerp(cast(t, op->args[0]), + Expr e = lower_lerp(op->type, + cast(t, op->args[0]), cast(t, op->args[1]), cast(wt, op->args[2]), target); - e = cast(op->type, e); codegen(e); } else if (op->is_intrinsic(Call::popcount)) { internal_assert(op->args.size() == 1); diff --git a/src/HexagonOptimize.cpp b/src/HexagonOptimize.cpp index 7c18bd6f9a73..3749a9434b42 100644 --- a/src/HexagonOptimize.cpp +++ b/src/HexagonOptimize.cpp @@ -991,7 +991,7 @@ class OptimizePatterns : public IRMutator { // We need to lower lerps now to optimize the arithmetic // that they generate. internal_assert(op->args.size() == 3); - return mutate(lower_lerp(op->args[0], op->args[1], op->args[2], target)); + return mutate(lower_lerp(op->type, op->args[0], op->args[1], op->args[2], target)); } else if ((op->is_intrinsic(Call::div_round_to_zero) || op->is_intrinsic(Call::mod_round_to_zero)) && !op->type.is_float() && op->type.is_vector()) { diff --git a/src/Lerp.cpp b/src/Lerp.cpp index 5134b5619c3e..da7177436ae3 100644 --- a/src/Lerp.cpp +++ b/src/Lerp.cpp @@ -11,7 +11,7 @@ namespace Halide { namespace Internal { -Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight, const Target &target) { +Expr lower_lerp(Type final_type, Expr zero_val, Expr one_val, const Expr &weight, const Target &target) { Expr result; @@ -153,7 +153,6 @@ Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight, const Target &t } else { result = rounding_shift_right(rounding_shift_right(prod_sum, bits) + prod_sum, bits); } - result = Cast::make(UInt(bits, computation_type.lanes()), result); break; } case 64: @@ -172,6 +171,10 @@ Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight, const Target &t } } + if (result.type() != final_type) { + result = Cast::make(final_type, result); + } + return simplify(common_subexpression_elimination(result)); } diff --git a/src/Lerp.h 
b/src/Lerp.h index 163d20c312f6..5d9e5a044879 100644 --- a/src/Lerp.h +++ b/src/Lerp.h @@ -13,9 +13,11 @@ struct Target; namespace Internal { -/** Build Halide IR that computes a lerp. Use by codegen targets that - * don't have a native lerp. */ -Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight, const Target &target); +/** Build Halide IR that computes a lerp. Used by codegen targets that don't have + * a native lerp. The lerp is done in the type of the zero value. The final_type + * is the type of a cast that should occur after the lerp. It's included because in some + * cases you can incorporate a final cast into the lerp math. */ +Expr lower_lerp(Type final_type, Expr zero_val, Expr one_val, const Expr &weight, const Target &target); } // namespace Internal } // namespace Halide From 675303c4cdf94929c638e7196235739c19826736 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 7 Dec 2021 13:32:53 -0800 Subject: [PATCH 2/4] Add test --- test/correctness/CMakeLists.txt | 1 + test/correctness/widening_lerp.cpp | 62 ++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 test/correctness/widening_lerp.cpp diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 63ab09e5eab3..60016b67d4c3 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -355,6 +355,7 @@ tests(GROUPS correctness vectorized_initialization.cpp vectorized_load_from_vectorized_allocation.cpp vectorized_reduction_bug.cpp + widening_lerp.cpp widening_reduction.cpp ) diff --git a/test/correctness/widening_lerp.cpp b/test/correctness/widening_lerp.cpp new file mode 100644 index 000000000000..690dbcfe5f2d --- /dev/null +++ b/test/correctness/widening_lerp.cpp @@ -0,0 +1,62 @@ +#include "Halide.h" + +using namespace Halide; + +std::mt19937 rng(0); + +int main(int argc, char **argv) { + + int fuzz_seed = argc > 1 ? 
atoi(argv[1]) : time(nullptr); + rng.seed(fuzz_seed); + printf("Lerp test seed: %d\n", fuzz_seed); + + // Lerp lowering incorporates a cast. This test checks that a widening lerp + // is equal to the widened version of the lerp. + for (Type t1 : {UInt(8), UInt(16), UInt(32), Int(8), Int(16), Int(32), Float(32)}) { + if (rng() & 1) continue; + for (Type t2 : {UInt(8), UInt(16), UInt(32), Float(32)}) { + if (rng() & 1) continue; + for (Type t3 : {UInt(8), UInt(16), UInt(32), Int(8), Int(16), Int(32), Float(32)}) { + if (rng() & 1) continue; + Func f; + Var x; + f(x) = cast(t1, random_uint((int)rng())); + + Expr weight = cast(t2, f(x + 16)); + if (t2.is_float()) { + weight /= 256.f; + weight = clamp(weight, 0.f, 1.f); + } + + Expr lerped = lerp(f(x), f(x + 8), cast(t2, f(x + 16))); + + Func cast_and_lerp, lerp_alone, cast_of_lerp; + cast_and_lerp(x) = cast(t3, lerped); + lerp_alone(x) = lerped; + cast_of_lerp(x) = cast(t3, lerp_alone(x)); + + RDom r(0, 32 * 1024); + Func check; + check() = maximum(abs(cast<double>(cast_and_lerp(r)) - + cast<double>(cast_of_lerp(r)))); + + f.compute_root().vectorize(x, 8, TailStrategy::RoundUp); + lerp_alone.compute_root().vectorize(x, 8, TailStrategy::RoundUp); + cast_and_lerp.compute_root().vectorize(x, 8, TailStrategy::RoundUp); + cast_of_lerp.compute_root().vectorize(x, 8, TailStrategy::RoundUp); + + double err = evaluate<double>(check()); + + if (err > 1e-5) { + printf("Difference of lerp + cast and lerp alone is %f," + " which exceeds threshold for seed %d\n", + err, fuzz_seed); + return -1; + } + } + } + } + + printf("Success!\n"); + return 0; +} From 8251c5be1fc031f9e88bcb89cbb38facd7f61ed1 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 9 Dec 2021 05:17:47 -0800 Subject: [PATCH 3/4] Fix bug in test --- test/correctness/widening_lerp.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/correctness/widening_lerp.cpp b/test/correctness/widening_lerp.cpp index 690dbcfe5f2d..a531c16ed196 100644 --- 
a/test/correctness/widening_lerp.cpp +++ b/test/correctness/widening_lerp.cpp @@ -28,7 +28,9 @@ int main(int argc, char **argv) { weight = clamp(weight, 0.f, 1.f); } - Expr lerped = lerp(f(x), f(x + 8), cast(t2, f(x + 16))); + Expr zero_val = f(x); + Expr one_val = f(x + 8); + Expr lerped = lerp(zero_val, one_val, weight); Func cast_and_lerp, lerp_alone, cast_of_lerp; cast_and_lerp(x) = cast(t3, lerped); From c54f4a4e9bbff4666cb49827bad490e6e78e2236 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 10 Dec 2021 05:12:56 -0800 Subject: [PATCH 4/4] Don't produce out-of-range lerp values --- src/Lerp.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/Lerp.cpp b/src/Lerp.cpp index da7177436ae3..b36d31c4e84d 100644 --- a/src/Lerp.cpp +++ b/src/Lerp.cpp @@ -164,6 +164,13 @@ Expr lower_lerp(Type final_type, Expr zero_val, Expr one_val, const Expr &weight default: break; } + + if (weight.type().is_float()) { + // Insert an explicit cast to the computation type, even if + // we're going to widen, because out-of-range floats can produce + // out-of-range outputs. + result = Cast::make(computation_type, result); + } } if (!is_const_zero(bias_value)) {