diff --git a/src/CodeGen_C.cpp b/src/CodeGen_C.cpp index 50a25f6748f4..58bacde63cb5 100644 --- a/src/CodeGen_C.cpp +++ b/src/CodeGen_C.cpp @@ -2314,7 +2314,7 @@ void CodeGen_C::visit(const Call *op) { } } else if (op->is_intrinsic(Call::lerp)) { internal_assert(op->args.size() == 3); - Expr e = lower_lerp(op->args[0], op->args[1], op->args[2], target); + Expr e = lower_lerp(op->type, op->args[0], op->args[1], op->args[2], target); rhs << print_expr(e); } else if (op->is_intrinsic(Call::absd)) { internal_assert(op->args.size() == 2); diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 0d241ea02ea1..0726c9881830 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -1392,6 +1392,20 @@ void CodeGen_LLVM::visit(const Cast *op) { return; } + if (const Call *c = Call::as_intrinsic(op->value, {Call::lerp})) { + // We want to codegen a cast of a lerp as a single thing, because it can + // be done more intelligently than a lerp followed by a cast. + Type t = upgrade_type_for_arithmetic(c->type); + Type wt = upgrade_type_for_arithmetic(c->args[2].type()); + Expr e = lower_lerp(op->type, + cast(t, c->args[0]), + cast(t, c->args[1]), + cast(wt, c->args[2]), + target); + codegen(e); + return; + } + value = codegen(op->value); llvm::Type *llvm_dst = llvm_type_of(dst); @@ -2698,11 +2712,11 @@ void CodeGen_LLVM::visit(const Call *op) { // TODO: This might be surprising behavior? Type t = upgrade_type_for_arithmetic(op->type); Type wt = upgrade_type_for_arithmetic(op->args[2].type()); - Expr e = lower_lerp(cast(t, op->args[0]), + Expr e = lower_lerp(op->type, + cast(t, op->args[0]), cast(t, op->args[1]), cast(wt, op->args[2]), target); - e = cast(op->type, e); codegen(e); } else if (op->is_intrinsic(Call::popcount)) { internal_assert(op->args.size() == 1); diff --git a/src/HexagonOptimize.cpp b/src/HexagonOptimize.cpp index 7c18bd6f9a73..3749a9434b42 100644 --- a/src/HexagonOptimize.cpp +++ b/src/HexagonOptimize.cpp @@ -991,7 +991,7 @@ class OptimizePatterns : public IRMutator { // We need to lower lerps now to optimize the arithmetic // that they generate. internal_assert(op->args.size() == 3); - return mutate(lower_lerp(op->args[0], op->args[1], op->args[2], target)); + return mutate(lower_lerp(op->type, op->args[0], op->args[1], op->args[2], target)); } else if ((op->is_intrinsic(Call::div_round_to_zero) || op->is_intrinsic(Call::mod_round_to_zero)) && !op->type.is_float() && op->type.is_vector()) { diff --git a/src/Lerp.cpp b/src/Lerp.cpp index 5134b5619c3e..b36d31c4e84d 100644 --- a/src/Lerp.cpp +++ b/src/Lerp.cpp @@ -11,7 +11,7 @@ namespace Halide { namespace Internal { -Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight, const Target &target) { +Expr lower_lerp(Type final_type, Expr zero_val, Expr one_val, const Expr &weight, const Target &target) { Expr result; @@ -153,7 +153,6 @@ Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight, const Target &t } else { result = rounding_shift_right(rounding_shift_right(prod_sum, bits) + prod_sum, bits); } - result = Cast::make(UInt(bits, computation_type.lanes()), result); break; } case 64: @@ -165,6 +164,13 @@ Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight, const Target &t default: break; } + + if (weight.type().is_float()) { + // Insert an explicit cast to the computation type, even if + // we're going to widen, because out-of-range floats can produce + // out-of-range outputs. + result = Cast::make(computation_type, result); + } } if (!is_const_zero(bias_value)) { @@ -172,6 +178,10 @@ Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight, const Target &t } } + if (result.type() != final_type) { + result = Cast::make(final_type, result); + } + return simplify(common_subexpression_elimination(result)); } diff --git a/src/Lerp.h b/src/Lerp.h index 163d20c312f6..5d9e5a044879 100644 --- a/src/Lerp.h +++ b/src/Lerp.h @@ -13,9 +13,11 @@ struct Target; namespace Internal { -/** Build Halide IR that computes a lerp. Use by codegen targets that - * don't have a native lerp. */ -Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight, const Target &target); +/** Build Halide IR that computes a lerp. Use by codegen targets that don't have + * a native lerp. The lerp is done in the type of the zero value. The final_type + * is a cast that should occur after the lerp. It's included because in some + * cases you can incorporate a final cast into the lerp math. */ +Expr lower_lerp(Type final_type, Expr zero_val, Expr one_val, const Expr &weight, const Target &target); } // namespace Internal } // namespace Halide diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 63ab09e5eab3..60016b67d4c3 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -355,6 +355,7 @@ tests(GROUPS correctness vectorized_initialization.cpp vectorized_load_from_vectorized_allocation.cpp vectorized_reduction_bug.cpp + widening_lerp.cpp widening_reduction.cpp ) diff --git a/test/correctness/widening_lerp.cpp b/test/correctness/widening_lerp.cpp new file mode 100644 index 000000000000..a531c16ed196 --- /dev/null +++ b/test/correctness/widening_lerp.cpp @@ -0,0 +1,64 @@ +#include "Halide.h" + +using namespace Halide; + +std::mt19937 rng(0); + +int main(int argc, char **argv) { + + int fuzz_seed = argc > 1 ? atoi(argv[1]) : time(nullptr); + rng.seed(fuzz_seed); + printf("Lerp test seed: %d\n", fuzz_seed); + + // Lerp lowering incorporates a cast. This test checks that a widening lerp + // is equal to the widened version of the lerp. + for (Type t1 : {UInt(8), UInt(16), UInt(32), Int(8), Int(16), Int(32), Float(32)}) { + if (rng() & 1) continue; + for (Type t2 : {UInt(8), UInt(16), UInt(32), Float(32)}) { + if (rng() & 1) continue; + for (Type t3 : {UInt(8), UInt(16), UInt(32), Int(8), Int(16), Int(32), Float(32)}) { + if (rng() & 1) continue; + Func f; + Var x; + f(x) = cast(t1, random_uint((int)rng())); + + Expr weight = cast(t2, f(x + 16)); + if (t2.is_float()) { + weight /= 256.f; + weight = clamp(weight, 0.f, 1.f); + } + + Expr zero_val = f(x); + Expr one_val = f(x + 8); + Expr lerped = lerp(zero_val, one_val, weight); + + Func cast_and_lerp, lerp_alone, cast_of_lerp; + cast_and_lerp(x) = cast(t3, lerped); + lerp_alone(x) = lerped; + cast_of_lerp(x) = cast(t3, lerp_alone(x)); + + RDom r(0, 32 * 1024); + Func check; + check() = maximum(abs(cast(cast_and_lerp(r)) - + cast(cast_of_lerp(r)))); + + f.compute_root().vectorize(x, 8, TailStrategy::RoundUp); + lerp_alone.compute_root().vectorize(x, 8, TailStrategy::RoundUp); + cast_and_lerp.compute_root().vectorize(x, 8, TailStrategy::RoundUp); + cast_of_lerp.compute_root().vectorize(x, 8, TailStrategy::RoundUp); + + double err = evaluate(check()); + + if (err > 1e-5) { + printf("Difference of lerp + cast and lerp alone is %f," + " which exceeds threshold for seed %d\n", + err, fuzz_seed); + return -1; + } + } + } + } + + printf("Success!\n"); + return 0; +}