diff --git a/Makefile b/Makefile index 10bbed901d53..3fe4f2acfc65 100644 --- a/Makefile +++ b/Makefile @@ -513,6 +513,7 @@ SOURCE_FILES = \ RDom.cpp \ Realization.cpp \ RealizationOrder.cpp \ + RebaseLoopsToZero.cpp \ Reduction.cpp \ RegionCosts.cpp \ RemoveDeadAllocations.cpp \ @@ -689,6 +690,7 @@ HEADER_FILES = \ Realization.h \ RDom.h \ RealizationOrder.h \ + RebaseLoopsToZero.h \ Reduction.h \ RegionCosts.h \ RemoveDeadAllocations.h \ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2ffdf85a51d4..e2400c9ae1d7 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -121,6 +121,7 @@ set(HEADER_FILES RDom.h Realization.h RealizationOrder.h + RebaseLoopsToZero.h Reduction.h RegionCosts.h RemoveDeadAllocations.h @@ -277,6 +278,7 @@ set(SOURCE_FILES RDom.cpp Realization.cpp RealizationOrder.cpp + RebaseLoopsToZero.cpp Reduction.cpp RegionCosts.cpp RemoveDeadAllocations.cpp diff --git a/src/CodeGen_C.cpp b/src/CodeGen_C.cpp index daec6cbf9741..c55a7503dd21 100644 --- a/src/CodeGen_C.cpp +++ b/src/CodeGen_C.cpp @@ -2612,6 +2612,7 @@ void CodeGen_C::visit(const Allocate *op) { size_id = print_expr(make_const(size_id_type, constant_size)); if (op->memory_type == MemoryType::Stack || + op->memory_type == MemoryType::Register || (op->memory_type == MemoryType::Auto && can_allocation_fit_on_stack(stack_bytes))) { on_stack = true; diff --git a/src/Lower.cpp b/src/Lower.cpp index a7227b6ec76d..da4e77d2b430 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -47,6 +47,7 @@ #include "PurifyIndexMath.h" #include "Qualify.h" #include "RealizationOrder.h" +#include "RebaseLoopsToZero.h" #include "RemoveDeadAllocations.h" #include "RemoveExternLoops.h" #include "RemoveUndef.h" @@ -344,6 +345,11 @@ Module lower(const vector &output_funcs, s = trim_no_ops(s); log("Lowering after loop trimming:", s); + debug(1) << "Rebasing loops to zero...\n"; + s = rebase_loops_to_zero(s); + debug(2) << "Lowering after rebasing loops to zero:\n" + << s << "\n\n"; + debug(1) << "Hoisting 
loop invariant if statements...\n"; s = hoist_loop_invariant_if_statements(s); log("Lowering after hoisting loop invariant if statements:", s); diff --git a/src/PartitionLoops.cpp b/src/PartitionLoops.cpp index c1e0d1fb7bfb..6d38d3267efa 100644 --- a/src/PartitionLoops.cpp +++ b/src/PartitionLoops.cpp @@ -329,28 +329,40 @@ class FindSimplifications : public IRVisitor { } } - void visit(const Select *op) override { - op->condition.accept(this); + void visit_select(const Expr &condition, const Expr &old, const Expr &true_value, const Expr &false_value) { + condition.accept(this); - bool likely_t = has_uncaptured_likely_tag(op->true_value); - bool likely_f = has_uncaptured_likely_tag(op->false_value); + bool likely_t = has_uncaptured_likely_tag(true_value); + bool likely_f = has_uncaptured_likely_tag(false_value); if (!likely_t && !likely_f) { - likely_t = has_likely_tag(op->true_value); - likely_f = has_likely_tag(op->false_value); + likely_t = has_likely_tag(true_value); + likely_f = has_likely_tag(false_value); } if (!likely_t) { - op->false_value.accept(this); + false_value.accept(this); } if (!likely_f) { - op->true_value.accept(this); + true_value.accept(this); } if (likely_t && !likely_f) { - new_simplification(op->condition, op, op->true_value, op->false_value); + new_simplification(condition, old, true_value, false_value); } else if (likely_f && !likely_t) { - new_simplification(!op->condition, op, op->false_value, op->true_value); + new_simplification(!condition, old, false_value, true_value); + } + } + + void visit(const Select *op) override { + visit_select(op->condition, op, op->true_value, op->false_value); + } + + void visit(const Call *op) override { + if (op->is_intrinsic(Call::if_then_else)) { + visit_select(op->args[0], op, op->args[1], op->args[2]); + } else { + IRVisitor::visit(op); } } diff --git a/src/RebaseLoopsToZero.cpp b/src/RebaseLoopsToZero.cpp new file mode 100644 index 000000000000..4dccde1e0b9c --- /dev/null +++ 
b/src/RebaseLoopsToZero.cpp @@ -0,0 +1,54 @@ +#include "RebaseLoopsToZero.h" +#include "IRMutator.h" +#include "IROperator.h" + +namespace Halide { +namespace Internal { + +using std::string; + +namespace { + +bool should_rebase(ForType type) { + switch (type) { + case ForType::Extern: + case ForType::GPUBlock: + case ForType::GPUThread: + case ForType::GPULane: + return false; + default: + return true; + } +} + +class RebaseLoopsToZero : public IRMutator { + using IRMutator::visit; + + Stmt visit(const For *op) override { + if (!should_rebase(op->for_type)) { + return IRMutator::visit(op); + } + Stmt body = mutate(op->body); + string name = op->name; + if (!is_const_zero(op->min)) { + // Renaming the loop (intentionally) invalidates any .loop_min/.loop_max lets. + name = op->name + ".rebased"; + Expr loop_var = Variable::make(Int(32), name); + body = LetStmt::make(op->name, loop_var + op->min, body); + } + if (body.same_as(op->body)) { + return op; + } else { + return For::make(name, 0, op->extent, op->for_type, op->device_api, body); + } + } +}; + +} // namespace + +Stmt rebase_loops_to_zero(const Stmt &s) { + return RebaseLoopsToZero().mutate(s); +} + +} // namespace Internal +} // namespace Halide diff --git a/src/RebaseLoopsToZero.h b/src/RebaseLoopsToZero.h new file mode 100644 index 000000000000..930d62ae3f8d --- /dev/null +++ b/src/RebaseLoopsToZero.h @@ -0,0 +1,19 @@ +#ifndef HALIDE_REBASE_LOOPS_TO_ZERO_H +#define HALIDE_REBASE_LOOPS_TO_ZERO_H + +/** \file + * Defines the lowering pass that rewrites loop mins to be 0. + */ + +#include "Expr.h" + +namespace Halide { +namespace Internal { + +/** Rewrite the mins of most loops to 0. 
*/ +Stmt rebase_loops_to_zero(const Stmt &); + +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/Simplify_LT.cpp b/src/Simplify_LT.cpp index 2b15ae9877de..b8f676e88a95 100644 --- a/src/Simplify_LT.cpp +++ b/src/Simplify_LT.cpp @@ -285,9 +285,43 @@ Expr Simplify::visit(const LT *op, ExprInfo *bounds) { rewrite(select(y, z, x + c0) < x + c1, y && (z < x + c1), c0 >= c1) || rewrite(select(y, z, x + c0) < x + c1, !y || (z < x + c1), c0 < c1) || + rewrite(c0 < select(x, c1, c2), select(x, fold(c0 < c1), fold(c0 < c2))) || + rewrite(select(x, c1, c2) < c0, select(x, fold(c1 < c0), fold(c2 < c0))) || + // Normalize comparison of ramps to a comparison of a ramp and a broadacst - rewrite(ramp(x, y, lanes) < ramp(z, w, lanes), ramp(x - z, y - w, lanes) < 0))) || + rewrite(ramp(x, y, lanes) < ramp(z, w, lanes), ramp(x - z, y - w, lanes) < 0) || + + // Rules of the form: + // rewrite(ramp(x, y, lanes) < broadcast(z, lanes), ramp(x - z, y, lanes) < 0) || + // where x and z cancel usefully + rewrite(ramp(x + z, y, lanes) < broadcast(x + w, lanes), ramp(z, y, lanes) < broadcast(w, lanes)) || + rewrite(ramp(z + x, y, lanes) < broadcast(x + w, lanes), ramp(z, y, lanes) < broadcast(w, lanes)) || + rewrite(ramp(x + z, y, lanes) < broadcast(w + x, lanes), ramp(z, y, lanes) < broadcast(w, lanes)) || + rewrite(ramp(z + x, y, lanes) < broadcast(w + x, lanes), ramp(z, y, lanes) < broadcast(w, lanes)) || + + // z = 0 + rewrite(ramp(x, y, lanes) < broadcast(x + w, lanes), ramp(0, y, lanes) < broadcast(w, lanes)) || + rewrite(ramp(x, y, lanes) < broadcast(w + x, lanes), ramp(0, y, lanes) < broadcast(w, lanes)) || + + // w = 0 + rewrite(ramp(x + z, y, lanes) < broadcast(x, lanes), ramp(z, y, lanes) < 0) || + rewrite(ramp(z + x, y, lanes) < broadcast(x, lanes), ramp(z, y, lanes) < 0) || + + // With the args flipped + rewrite(broadcast(x + w, lanes) < ramp(x + z, y, lanes), broadcast(w, lanes) < ramp(z, y, lanes)) || + rewrite(broadcast(x + w, lanes) < ramp(z + x, y, 
lanes), broadcast(w, lanes) < ramp(z, y, lanes)) || + rewrite(broadcast(w + x, lanes) < ramp(x + z, y, lanes), broadcast(w, lanes) < ramp(z, y, lanes)) || + rewrite(broadcast(w + x, lanes) < ramp(z + x, y, lanes), broadcast(w, lanes) < ramp(z, y, lanes)) || + + // z = 0 + rewrite(broadcast(x + w, lanes) < ramp(x, y, lanes), broadcast(w, lanes) < ramp(0, y, lanes)) || + rewrite(broadcast(w + x, lanes) < ramp(x, y, lanes), broadcast(w, lanes) < ramp(0, y, lanes)) || + + // w = 0 + rewrite(broadcast(x, lanes) < ramp(x + z, y, lanes), 0 < ramp(z, y, lanes)) || + rewrite(broadcast(x, lanes) < ramp(z + x, y, lanes), 0 < ramp(z, y, lanes)) || + false)) || (no_overflow_int(ty) && EVAL_IN_LAMBDA (rewrite(x * c0 < y * c1, x < y * fold(c1 / c0), c1 % c0 == 0 && c0 > 0) || rewrite(x * c0 < y * c1, x * fold(c0 / c1) < y, c0 % c1 == 0 && c1 > 0) || diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index 9e1b7114eedb..7d93283dde6e 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -119,6 +119,102 @@ bool find_produce(const Stmt &s, const string &func) { return finder.found; } +// This mutator rewrites calls and provides to a particular +// func: +// - Calls and Provides are shifted to be relative to the min. +// - Provides additionally are rewritten to load values from the +// previous iteration of the loop if they were computed in the +// last iteration. +class RollFunc : public IRMutator { + const Function &func; + int dim; + const string &loop_var; + const Interval &old_bounds; + const Interval &new_bounds; + + Scope scope; + + // It helps simplify the shifted calls/provides to rebase the + // loops that are subtracted from to have a min of 0. 
+ set loops_to_rebase; + bool in_produce = false; + + using IRMutator::visit; + + Stmt visit(const ProducerConsumer *op) override { + bool produce_func = op->name == func.name() && op->is_producer; + ScopedValue old_in_produce(in_produce, in_produce || produce_func); + return IRMutator::visit(op); + } + + Stmt visit(const Provide *op) override { + if (!(in_produce && op->name == func.name())) { + return IRMutator::visit(op); + } + vector values = op->values; + for (Expr &i : values) { + i = mutate(i); + } + vector args = op->args; + for (Expr &i : args) { + i = mutate(i); + } + bool sliding_up = old_bounds.max.same_as(new_bounds.max); + Expr is_new = sliding_up ? new_bounds.min <= args[dim] : args[dim] <= new_bounds.max; + args[dim] -= old_bounds.min; + vector old_args = args; + Expr old_arg_dim = expand_expr(old_args[dim], scope); + old_args[dim] = substitute(loop_var, Variable::make(Int(32), loop_var) - 1, old_arg_dim); + for (int i = 0; i < (int)values.size(); i++) { + Type t = values[i].type(); + Expr old_value = + Call::make(t, op->name, old_args, Call::Halide, func.get_contents(), i); + values[i] = Call::make(values[i].type(), Call::if_then_else, {is_new, values[i], likely(old_value)}, Call::PureIntrinsic); + } + if (const Variable *v = op->args[dim].as()) { + // The subtractions above simplify more easily if the loop is rebased to 0. 
+ loops_to_rebase.insert(v->name); + } + return Provide::make(func.name(), values, args); + } + + Expr visit(const Call *op) override { + if (!(op->call_type == Call::Halide && op->name == func.name())) { + return IRMutator::visit(op); + } + vector args = op->args; + for (Expr &i : args) { + i = mutate(i); + } + args[dim] -= old_bounds.min; + return Call::make(op->type, op->name, args, Call::Halide, op->func, op->value_index, op->image, op->param); + } + + Stmt visit(const For *op) override { + Stmt result = IRMutator::visit(op); + op = result.as(); + internal_assert(op); + if (loops_to_rebase.count(op->name)) { + string new_name = op->name + ".rebased"; + Stmt body = substitute(op->name, Variable::make(Int(32), new_name) + op->min, op->body); + result = For::make(new_name, 0, op->extent, op->for_type, op->device_api, body); + loops_to_rebase.erase(op->name); + } + return result; + } + + Stmt visit(const LetStmt *op) override { + ScopedBinding bind(scope, op->name, simplify(expand_expr(op->value, scope))); + return IRMutator::visit(op); + } + +public: + RollFunc(const Function &func, int dim, const string &loop_var, + const Interval &old_bounds, const Interval &new_bounds) + : func(func), dim(dim), loop_var(loop_var), old_bounds(old_bounds), new_bounds(new_bounds) { + } +}; + // Perform sliding window optimization for a function over a // particular serial for loop class SlidingWindowOnFunctionAndLoop : public IRMutator { @@ -372,7 +468,17 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator { slid_dimensions.insert(dim_idx); - // Now redefine the appropriate regions required + // If we want to slide in registers, we're done here, we just need to + // save the updated bounds for later. 
+ if (func.schedule().memory_type() == MemoryType::Register) { + this->dim_idx = dim_idx; + old_bounds = {min_required, max_required}; + new_bounds = {new_min, new_max}; + return op; + } + + // If we aren't sliding in registers, we need to update the bounds of + // the producer to be only the bounds of the region newly computed. internal_assert(replacements.empty()); if (can_slide_up) { replacements[prefix + dim + ".min"] = new_min; @@ -493,6 +599,13 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator { } Expr new_loop_min; + int dim_idx; + Interval old_bounds; + Interval new_bounds; + + Stmt translate_loop(const Stmt &s) { + return RollFunc(func, dim_idx, loop_var, old_bounds, new_bounds).mutate(s); + } }; // In Stmt s, does the production of b depend on a? @@ -692,6 +805,11 @@ class SlidingWindow : public IRMutator { SlidingWindowOnFunctionAndLoop slider(func, name, prev_loop_min, slid_dimensions[func.name()]); body = slider.mutate(body); + if (func.schedule().memory_type() == MemoryType::Register && + slider.old_bounds.has_lower_bound()) { + body = slider.translate_loop(body); + } + if (slider.new_loop_min.defined()) { Expr new_loop_min = slider.new_loop_min; if (!prev_loop_min.same_as(loop_min)) { diff --git a/src/StorageFolding.cpp b/src/StorageFolding.cpp index 548d46cb57cf..f7c1ffff44bb 100644 --- a/src/StorageFolding.cpp +++ b/src/StorageFolding.cpp @@ -542,6 +542,16 @@ class AttemptStorageFoldingOfFunction : public IRMutator { Expr min = simplify(common_subexpression_elimination(box[dim].min)); Expr max = simplify(common_subexpression_elimination(box[dim].max)); + if (is_const(min) || is_const(max)) { + debug(3) << "\nNot considering folding " << func.name() + << " over for loop over " << op->name + << " dimension " << i - 1 << "\n" + << " because the min or max are constants.\n"
+ << "Min: " << min << "\n" + << "Max: " << max << "\n"; + continue; + } + Expr min_provided, max_provided, min_required, max_required; if (func.schedule().async() && !explicit_only) { if (!provided.empty()) { diff --git a/test/correctness/deferred_loop_level.cpp b/test/correctness/deferred_loop_level.cpp index 8e382b26e996..47b1f41cbf42 100644 --- a/test/correctness/deferred_loop_level.cpp +++ b/test/correctness/deferred_loop_level.cpp @@ -26,18 +26,18 @@ class CheckLoopLevels : public IRVisitor { void visit(const Call *op) override { IRVisitor::visit(op); if (op->name == "sin_f32") { - _halide_user_assert(inside_for_loop == inner_loop_level); + _halide_user_assert(starts_with(inside_for_loop, inner_loop_level)); } else if (op->name == "cos_f32") { - _halide_user_assert(inside_for_loop == outer_loop_level); + _halide_user_assert(starts_with(inside_for_loop, outer_loop_level)); } } void visit(const Store *op) override { IRVisitor::visit(op); if (op->name.substr(0, 5) == "inner") { - _halide_user_assert(inside_for_loop == inner_loop_level); + _halide_user_assert(starts_with(inside_for_loop, inner_loop_level)); } else if (op->name.substr(0, 5) == "outer") { - _halide_user_assert(inside_for_loop == outer_loop_level); + _halide_user_assert(starts_with(inside_for_loop, outer_loop_level)); } else { _halide_user_assert(0); } diff --git a/test/correctness/loop_level_generator_param.cpp b/test/correctness/loop_level_generator_param.cpp index 64505f219a47..c95f28324d78 100644 --- a/test/correctness/loop_level_generator_param.cpp +++ b/test/correctness/loop_level_generator_param.cpp @@ -50,10 +50,10 @@ class CheckLoopLevels : public IRVisitor { void visit(const Call *op) override { IRVisitor::visit(op); if (op->name == "sin_f32") { - _halide_user_assert(inside_for_loop == inner_loop_level) + _halide_user_assert(starts_with(inside_for_loop, inner_loop_level)) << "call sin_f32: expected " << inner_loop_level << ", actual: " << inside_for_loop; } else if (op->name == "cos_f32") 
{ - _halide_user_assert(inside_for_loop == outer_loop_level) + _halide_user_assert(starts_with(inside_for_loop, outer_loop_level)) << "call cos_f32: expected " << outer_loop_level << ", actual: " << inside_for_loop; } } @@ -62,10 +62,10 @@ class CheckLoopLevels : public IRVisitor { IRVisitor::visit(op); std::string op_name = strip_uniquified_names(op->name); if (op_name == "inner") { - _halide_user_assert(inside_for_loop == inner_loop_level) + _halide_user_assert(starts_with(inside_for_loop, inner_loop_level)) << "inside_for_loop: expected " << inner_loop_level << ", actual: " << inside_for_loop; } else if (op_name == "outer") { - _halide_user_assert(inside_for_loop == outer_loop_level) + _halide_user_assert(starts_with(inside_for_loop, outer_loop_level)) << "inside_for_loop: expected " << outer_loop_level << ", actual: " << inside_for_loop; } else { _halide_user_assert(0) << "store at: " << op_name << " inside_for_loop: " << inside_for_loop; diff --git a/test/correctness/sliding_window.cpp b/test/correctness/sliding_window.cpp index 31875158600a..70fe8b7a1baa 100644 --- a/test/correctness/sliding_window.cpp +++ b/test/correctness/sliding_window.cpp @@ -32,13 +32,14 @@ int main(int argc, char **argv) { return 0; } - { + for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { + count = 0; Func f, g; f(x) = call_counter(x, 0); g(x) = f(x) + f(x - 1); - f.store_root().compute_at(g, x); + f.store_root().compute_at(g, x).store_in(store_in); // Test that sliding window works when specializing. g.specialize(g.output_buffer().dim(0).min() == 0); @@ -53,7 +54,7 @@ int main(int argc, char **argv) { } // Try two producers used by the same consumer. 
- { + for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { count = 0; Func f, g, h; @@ -61,8 +62,8 @@ int main(int argc, char **argv) { g(x) = call_counter(2 * x + 1, 0); h(x) = f(x) + f(x - 1) + g(x) + g(x - 1); - f.store_root().compute_at(h, x); - g.store_root().compute_at(h, x); + f.store_root().compute_at(h, x).store_in(store_in); + g.store_root().compute_at(h, x).store_in(store_in); Buffer im = h.realize({100}); if (count != 202) { @@ -72,7 +73,7 @@ int main(int argc, char **argv) { } // Try a sequence of two sliding windows. - { + for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { count = 0; Func f, g, h; @@ -80,25 +81,26 @@ int main(int argc, char **argv) { g(x) = f(x) + f(x - 1); h(x) = g(x) + g(x - 1); - f.store_root().compute_at(h, x); - g.store_root().compute_at(h, x); + f.store_root().compute_at(h, x).store_in(store_in); + g.store_root().compute_at(h, x).store_in(store_in); Buffer im = h.realize({100}); - if (count != 102) { - printf("f was called %d times instead of %d times\n", count, 102); + int correct = store_in == MemoryType::Register ? 103 : 102; + if (count != correct) { + printf("f was called %d times instead of %d times\n", count, correct); return -1; } } // Try again where there's a containing stage - { + for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { count = 0; Func f, g, h; f(x) = call_counter(x, 0); g(x) = f(x) + f(x - 1); h(x) = g(x); - f.store_root().compute_at(g, x); + f.store_root().compute_at(g, x).store_in(store_in); g.compute_at(h, x); Buffer im = h.realize({100}); @@ -109,7 +111,7 @@ int main(int argc, char **argv) { } // Add an inner vectorized dimension. 
- { + for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { count = 0; Func f, g, h; Var c; @@ -119,6 +121,7 @@ int main(int argc, char **argv) { f.store_root() .compute_at(h, x) + .store_in(store_in) .reorder(c, x) .reorder_storage(c, x) .bound(c, 0, 4) @@ -223,14 +226,14 @@ int main(int argc, char **argv) { } } - { + for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { // Sliding where we only need a new value every third iteration of the consumer. Func f, g; f(x) = call_counter(x, 0); g(x) = f(x / 3); - f.store_root().compute_at(g, x); + f.store_root().compute_at(g, x).store_in(store_in); count = 0; Buffer im = g.realize({100}); @@ -242,7 +245,7 @@ int main(int argc, char **argv) { } } - { + for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { // Sliding where we only need a new value every third iteration of the consumer. // This test checks that we don't ask for excessive bounds. ImageParam f(Int(32), 1); @@ -252,7 +255,7 @@ int main(int argc, char **argv) { Var xo; g.split(x, xo, x, 10); - f.in().store_at(g, xo).compute_at(g, x); + f.in().store_at(g, xo).compute_at(g, x).store_in(store_in); Buffer buf(33); f.set(buf); @@ -260,7 +263,7 @@ int main(int argc, char **argv) { Buffer im = g.realize({98}); } - { + for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { // Sliding with an unrolled producer Var x, xi; Func f, g; @@ -269,7 +272,7 @@ int main(int argc, char **argv) { g(x) = f(x) + f(x - 1); g.split(x, x, xi, 10); - f.store_root().compute_at(g, x).unroll(x); + f.store_root().compute_at(g, x).store_in(store_in).unroll(x); count = 0; Buffer im = g.realize({100}); @@ -297,6 +300,29 @@ int main(int argc, char **argv) { } } + { + // Sliding with a vectorized producer and consumer, trying to rotate + // cleanly in registers. + count = 0; + Func f, g; + f(x) = call_counter(x, 0); + g(x) = f(x + 1) + f(x - 1); + + // This currently requires a trick to get everything to be aligned + // nicely. 
This exploits the fact that ShiftInwards splits are + // aligned to the end of the original loop (and extending before the + // min if necessary). + Var xi("xi"); + f.store_root().compute_at(g, x).store_in(MemoryType::Register).split(x, x, xi, 8).vectorize(xi, 4).unroll(xi); + g.vectorize(x, 4, TailStrategy::RoundUp); + + Buffer im = g.realize({100}); + if (count != 102) { + printf("f was called %d times instead of %d times\n", count, 102); + return -1; + } + } + { // A sequence of stencils, all computed at the output. count = 0;