From e81714ef1519f81c7d38bf2d83f7f71802b1470f Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Sat, 13 Mar 2021 00:36:14 -0700 Subject: [PATCH 01/23] Sliding in registers --- src/SlidingWindow.cpp | 98 ++++++++++++++++++++++++++++- test/correctness/sliding_window.cpp | 47 +++++++------- 2 files changed, 122 insertions(+), 23 deletions(-) diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index 9e1b7114eedb..32938b822ea7 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -119,6 +119,74 @@ bool find_produce(const Stmt &s, const string &func) { return finder.found; } +class RollFunc : public IRMutator { +public: + const Function &func; + int dim; + const string &loop_var; + const Interval &old_bounds; + const Interval &new_bounds; + + using IRMutator::visit; + + Stmt visit(const Provide *op) override { + if (op->name == func.name()) { + vector values(op->values); + for (Expr &i : values) { + i = mutate(i); + } + vector args = op->args; + for (Expr &i : args) { + i = mutate(i); + } + Expr is_new = old_bounds.min.same_as(new_bounds.min) ? args[dim] <= new_bounds.max : new_bounds.min <= args[dim]; + args[dim] -= old_bounds.min; + vector old_args = args; + old_args[dim] = substitute(loop_var, Variable::make(Int(32), loop_var) - 1, old_args[dim]); + for (int i = 0; i < (int)values.size(); i++) { + Type t = values[i].type(); + Expr old_value = + Call::make(t, op->name, old_args, Call::Halide, func.get_contents(), i); + values[i] = Call::make(t, Call::if_then_else, {is_new, values[i], old_value}, Call::PureIntrinsic); + } + return Provide::make(func.name(), values, args); + } else { + return IRMutator::visit(op); + } + } + + Expr visit(const Call *op) override { + if (op->call_type == Call::Halide && op->name == func.name()) { + vector args = op->args; + for (Expr &i : args) { + i = mutate(i); + } + args[dim] -= old_bounds.min; + return Call::make(op->type, op->name, args, Call::Halide, op->func, op->value_index, op->image, op->param); + } else { + return IRMutator::visit(op); + } + } + + Stmt visit(const LetStmt *op) override { + string prefix = func.name() + ".s0." + func.args()[dim]; + if (starts_with(op->name, prefix + ".min") || starts_with(op->name, prefix + ".max")) { + Expr value = mutate(op->value); + Stmt body = substitute(op->name, value, op->body); + body = mutate(body); + return LetStmt::make(op->name, value - old_bounds.min, body); + } else { + return IRMutator::visit(op); + } + } + +public: + RollFunc(const Function &func, int dim, const string &loop_var, + const Interval &old_bounds, const Interval &new_bounds) + : func(func), dim(dim), loop_var(loop_var), old_bounds(old_bounds), new_bounds(new_bounds) { + } +}; + // Perform sliding window optimization for a function over a // particular serial for loop class SlidingWindowOnFunctionAndLoop : public IRMutator { @@ -372,6 +440,13 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator { slid_dimensions.insert(dim_idx); + if (func.schedule().memory_type() == MemoryType::Register) { + this->dim_idx = dim_idx; + old_bounds = {min_required, max_required}; + new_bounds = {new_min, new_max}; + return op; + } + // Now redefine the appropriate regions required internal_assert(replacements.empty()); if (can_slide_up) { @@ -493,6 +568,13 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator { } Expr new_loop_min; + int dim_idx; + Interval old_bounds; + Interval new_bounds; + + Stmt translate_loop(const Stmt &s) { + return RollFunc(func, dim_idx, loop_var, old_bounds, new_bounds).mutate(s); + } }; // In Stmt s, does the production of b depend on a? @@ -619,6 +701,9 @@ class SlidingWindow : public IRMutator { // outermost. list sliding; + // Keep track of updated bounds for realizations. + map> new_bounds; + using IRMutator::visit; Stmt visit(const Realize *op) override { @@ -653,8 +738,13 @@ class SlidingWindow : public IRMutator { if (new_body.same_as(op->body)) { return op; } else { + vector bounds = op->bounds; + auto i = new_bounds.find(op->name); + if (i != new_bounds.end()) { + bounds[i->second.first] = {0, i->second.second.max - i->second.second.min + 1}; + } return Realize::make(op->name, op->types, op->memory_type, - op->bounds, op->condition, new_body); + bounds, op->condition, new_body); } } @@ -692,6 +782,12 @@ class SlidingWindow : public IRMutator { SlidingWindowOnFunctionAndLoop slider(func, name, prev_loop_min, slid_dimensions[func.name()]); body = slider.mutate(body); + if (func.schedule().memory_type() == MemoryType::Register && + slider.old_bounds.has_lower_bound()) { + body = slider.translate_loop(body); + new_bounds[func.name()] = {slider.dim_idx, slider.old_bounds}; + } + if (slider.new_loop_min.defined()) { Expr new_loop_min = slider.new_loop_min; if (!prev_loop_min.same_as(loop_min)) { diff --git a/test/correctness/sliding_window.cpp b/test/correctness/sliding_window.cpp index 31875158600a..f84e6b77f2cb 100644 --- a/test/correctness/sliding_window.cpp +++ b/test/correctness/sliding_window.cpp @@ -32,16 +32,18 @@ int main(int argc, char **argv) { return 0; } - { + for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { + count = 0; Func f, g; f(x) = call_counter(x, 0); g(x) = f(x) + f(x - 1); - f.store_root().compute_at(g, x); + f.store_root().compute_at(g, x).store_in(store_in); // Test that sliding window works when specializing. - g.specialize(g.output_buffer().dim(0).min() == 0); + //g.specialize(g.output_buffer().dim(0).min() == 0); + g.output_buffer().dim(0).set_min(0).set_extent(100); Buffer im = g.realize({100}); @@ -53,7 +55,7 @@ int main(int argc, char **argv) { } // Try two producers used by the same consumer. - { + for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { count = 0; Func f, g, h; @@ -61,8 +63,8 @@ int main(int argc, char **argv) { g(x) = call_counter(2 * x + 1, 0); h(x) = f(x) + f(x - 1) + g(x) + g(x - 1); - f.store_root().compute_at(h, x); - g.store_root().compute_at(h, x); + f.store_root().compute_at(h, x).store_in(store_in); + g.store_root().compute_at(h, x).store_in(store_in); Buffer im = h.realize({100}); if (count != 202) { @@ -72,7 +74,7 @@ int main(int argc, char **argv) { } // Try a sequence of two sliding windows. - { + for (auto store_in : {MemoryType::Heap}) { count = 0; Func f, g, h; @@ -80,8 +82,8 @@ int main(int argc, char **argv) { g(x) = f(x) + f(x - 1); h(x) = g(x) + g(x - 1); - f.store_root().compute_at(h, x); - g.store_root().compute_at(h, x); + f.store_root().compute_at(h, x).store_in(store_in); + g.store_root().compute_at(h, x).store_in(store_in); Buffer im = h.realize({100}); if (count != 102) { @@ -91,14 +93,14 @@ int main(int argc, char **argv) { } // Try again where there's a containing stage - { + for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { count = 0; Func f, g, h; f(x) = call_counter(x, 0); g(x) = f(x) + f(x - 1); h(x) = g(x); - f.store_root().compute_at(g, x); + f.store_root().compute_at(g, x).store_in(store_in); g.compute_at(h, x); Buffer im = h.realize({100}); @@ -109,7 +111,7 @@ int main(int argc, char **argv) { } // Add an inner vectorized dimension. - { + for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { count = 0; Func f, g, h; Var c; @@ -119,6 +121,7 @@ int main(int argc, char **argv) { f.store_root() .compute_at(h, x) + .store_in(store_in) .reorder(c, x) .reorder_storage(c, x) .bound(c, 0, 4) @@ -136,14 +139,14 @@ int main(int argc, char **argv) { } // Now try with a reduction - { + for (auto store_in : {MemoryType::Heap/*, MemoryType::Register*/}) { count = 0; RDom r(0, 100); Func f, g; f(x, y) = 0; f(r, y) = call_counter(r, y); - f.store_root().compute_at(g, y); + f.store_root().compute_at(g, y).store_in(store_in); g(x, y) = f(x, y) + f(x, y - 1); @@ -223,14 +226,14 @@ int main(int argc, char **argv) { } } - { + for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { // Sliding where we only need a new value every third iteration of the consumer. Func f, g; f(x) = call_counter(x, 0); g(x) = f(x / 3); - f.store_root().compute_at(g, x); + f.store_root().compute_at(g, x).store_in(store_in); count = 0; Buffer im = g.realize({100}); @@ -242,7 +245,7 @@ int main(int argc, char **argv) { } } - { + for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { // Sliding where we only need a new value every third iteration of the consumer. // This test checks that we don't ask for excessive bounds. ImageParam f(Int(32), 1); @@ -252,7 +255,7 @@ int main(int argc, char **argv) { Var xo; g.split(x, xo, x, 10); - f.in().store_at(g, xo).compute_at(g, x); + f.in().store_at(g, xo).compute_at(g, x).store_in(store_in); Buffer buf(33); f.set(buf); @@ -260,7 +263,7 @@ int main(int argc, char **argv) { Buffer im = g.realize({98}); } - { + for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { // Sliding with an unrolled producer Var x, xi; Func f, g; @@ -269,7 +272,7 @@ int main(int argc, char **argv) { g(x) = f(x) + f(x - 1); g.split(x, x, xi, 10); - f.store_root().compute_at(g, x).unroll(x); + f.store_root().compute_at(g, x).store_in(store_in).unroll(x); count = 0; Buffer im = g.realize({100}); @@ -280,14 +283,14 @@ int main(int argc, char **argv) { } } - { + for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { // Sliding with a vectorized producer and consumer. count = 0; Func f, g; f(x) = call_counter(x, 0); g(x) = f(x + 1) + f(x - 1); - f.store_root().compute_at(g, x).vectorize(x, 4); + f.store_root().compute_at(g, x).store_in(store_in).vectorize(x, 4); g.vectorize(x, 4); Buffer im = g.realize({100}); From 89ef82a01ae3f0fb330018fccd8b3efbce2f8a56 Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 15 Mar 2021 11:54:56 -0600 Subject: [PATCH 02/23] Fix some failure cases. --- src/SlidingWindow.cpp | 106 +++++++++++++++++++--------- test/correctness/sliding_window.cpp | 23 +++--- 2 files changed, 84 insertions(+), 45 deletions(-) diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index 32938b822ea7..d0dcc809cd77 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -127,45 +127,71 @@ class RollFunc : public IRMutator { const Interval &old_bounds; const Interval &new_bounds; + set loops_to_rebase; + bool in_produce = false; + using IRMutator::visit; + Stmt visit(const ProducerConsumer *op) override { + bool produce_func = op->name == func.name() && op->is_producer; + ScopedValue old_in_produce(in_produce, in_produce || produce_func); + return IRMutator::visit(op); + } + Stmt visit(const Provide *op) override { - if (op->name == func.name()) { - vector values(op->values); - for (Expr &i : values) { - i = mutate(i); - } - vector args = op->args; - for (Expr &i : args) { - i = mutate(i); - } - Expr is_new = old_bounds.min.same_as(new_bounds.min) ? args[dim] <= new_bounds.max : new_bounds.min <= args[dim]; - args[dim] -= old_bounds.min; - vector old_args = args; - old_args[dim] = substitute(loop_var, Variable::make(Int(32), loop_var) - 1, old_args[dim]); - for (int i = 0; i < (int)values.size(); i++) { - Type t = values[i].type(); - Expr old_value = - Call::make(t, op->name, old_args, Call::Halide, func.get_contents(), i); - values[i] = Call::make(t, Call::if_then_else, {is_new, values[i], old_value}, Call::PureIntrinsic); - } - return Provide::make(func.name(), values, args); - } else { + if (!(in_produce && op->name == func.name())) { return IRMutator::visit(op); } + vector values = op->values; + for (Expr &i : values) { + i = mutate(i); + } + vector args = op->args; + for (Expr &i : args) { + i = mutate(i); + } + bool sliding_up = old_bounds.max.same_as(new_bounds.max); + Expr is_new = sliding_up ? new_bounds.min <= args[dim] : args[dim] <= new_bounds.max; + args[dim] -= old_bounds.min; + vector old_args = args; + old_args[dim] = substitute(loop_var, Variable::make(Int(32), loop_var) - 1, old_args[dim]); + for (int i = 0; i < (int)values.size(); i++) { + Type t = values[i].type(); + Expr old_value = + Call::make(t, op->name, old_args, Call::Halide, func.get_contents(), i); + values[i] = likely(values[i]); + values[i] = Call::make(t, Call::if_then_else, {is_new, values[i], old_value}, Call::PureIntrinsic); + } + if (const Variable *v = op->args[dim].as()) { + // The subtractions above simplify more easily if the loop is rebased to 0. + loops_to_rebase.insert(v->name); + } + return Provide::make(func.name(), values, args); } Expr visit(const Call *op) override { - if (op->call_type == Call::Halide && op->name == func.name()) { - vector args = op->args; - for (Expr &i : args) { - i = mutate(i); - } - args[dim] -= old_bounds.min; - return Call::make(op->type, op->name, args, Call::Halide, op->func, op->value_index, op->image, op->param); - } else { + if (!(op->call_type == Call::Halide && op->name == func.name())) { return IRMutator::visit(op); } + vector args = op->args; + for (Expr &i : args) { + i = mutate(i); + } + args[dim] -= old_bounds.min; + return Call::make(op->type, op->name, args, Call::Halide, op->func, op->value_index, op->image, op->param); + } + + Stmt visit(const For *op) override { + Stmt result = IRMutator::visit(op); + op = result.as(); + internal_assert(op); + if (loops_to_rebase.count(op->name)) { + string new_name = op->name + ".rebased"; + Stmt body = substitute(op->name, Variable::make(Int(32), new_name) + op->min, op->body); + result = For::make(new_name, 0, op->extent, op->for_type, op->device_api, body); + loops_to_rebase.erase(op->name); + } + return result; } Stmt visit(const LetStmt *op) override { @@ -441,6 +467,8 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator { slid_dimensions.insert(dim_idx); if (func.schedule().memory_type() == MemoryType::Register) { + // If we're going to slide in registers, save the bounds + // for doing that later. this->dim_idx = dim_idx; old_bounds = {min_required, max_required}; new_bounds = {new_min, new_max}; @@ -702,7 +730,9 @@ class SlidingWindow : public IRMutator { list sliding; // Keep track of updated bounds for realizations. - map> new_bounds; + map> new_extents; + + Scope scope; using IRMutator::visit; @@ -739,9 +769,9 @@ class SlidingWindow : public IRMutator { return op; } else { vector bounds = op->bounds; - auto i = new_bounds.find(op->name); - if (i != new_bounds.end()) { - bounds[i->second.first] = {0, i->second.second.max - i->second.second.min + 1}; + auto i = new_extents.find(op->name); + if (i != new_extents.end()) { + bounds[i->second.first] = {0, i->second.second}; } return Realize::make(op->name, op->types, op->memory_type, bounds, op->condition, new_body); @@ -784,8 +814,11 @@ class SlidingWindow : public IRMutator { if (func.schedule().memory_type() == MemoryType::Register && slider.old_bounds.has_lower_bound()) { + // If we're sliding in registers, we need to rewrite the bounds + // of the realization, like storage folding would do. body = slider.translate_loop(body); - new_bounds[func.name()] = {slider.dim_idx, slider.old_bounds}; + Expr new_extent = slider.old_bounds.max - slider.old_bounds.min + 1; + new_extents[func.name()] = {slider.dim_idx, expand_expr(new_extent, scope)}; } if (slider.new_loop_min.defined()) { @@ -853,6 +886,11 @@ class SlidingWindow : public IRMutator { } } + Stmt visit(const LetStmt *op) override { + ScopedBinding bind(scope, op->name, simplify(expand_expr(op->value, scope))); + return IRMutator::visit(op); + } + public: SlidingWindow(const map &e) : env(e) { diff --git a/test/correctness/sliding_window.cpp b/test/correctness/sliding_window.cpp index f84e6b77f2cb..84b141742913 100644 --- a/test/correctness/sliding_window.cpp +++ b/test/correctness/sliding_window.cpp @@ -42,8 +42,7 @@ int main(int argc, char **argv) { f.store_root().compute_at(g, x).store_in(store_in); // Test that sliding window works when specializing. - //g.specialize(g.output_buffer().dim(0).min() == 0); - g.output_buffer().dim(0).set_min(0).set_extent(100); + g.specialize(g.output_buffer().dim(0).min() == 0); Buffer im = g.realize({100}); @@ -139,14 +138,14 @@ int main(int argc, char **argv) { } // Now try with a reduction - for (auto store_in : {MemoryType::Heap/*, MemoryType::Register*/}) { + { count = 0; RDom r(0, 100); Func f, g; f(x, y) = 0; f(r, y) = call_counter(r, y); - f.store_root().compute_at(g, y).store_in(store_in); + f.store_root().compute_at(g, y); g(x, y) = f(x, y) + f(x, y - 1); @@ -294,8 +293,10 @@ int main(int argc, char **argv) { g.vectorize(x, 4); Buffer im = g.realize({100}); - if (count != 104) { - printf("f was called %d times instead of %d times\n", count, 104); + // TODO: We shouldn't need the extra calls for registers. + int correct = store_in == MemoryType::Register ? 152 : 104; + if (count != correct) { + printf("f was called %d times instead of %d times\n", count, correct); return -1; } } @@ -344,7 +345,7 @@ int main(int argc, char **argv) { } } - { + for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { // Sliding a func that has a boundary condition before the beginning // of the loop. This needs an explicit warmup before we start sliding. count = 0; @@ -352,7 +353,7 @@ int main(int argc, char **argv) { f(x) = call_counter(x, 0); g(x) = f(max(x, 3)); - f.store_root().compute_at(g, x); + f.store_root().compute_at(g, x).store_in(store_in); g.realize({10}); if (count != 7) { @@ -361,7 +362,7 @@ int main(int argc, char **argv) { } } - { + for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { // Sliding a func that has a boundary condition on both sides. count = 0; Func f, g, h; @@ -369,8 +370,8 @@ int main(int argc, char **argv) { g(x) = f(clamp(x, 0, 9)); h(x) = g(x - 1) + g(x + 1); - f.store_root().compute_at(h, x); - g.store_root().compute_at(h, x); + f.store_root().compute_at(h, x).store_in(store_in); + g.store_root().compute_at(h, x).store_in(store_in); h.realize({10}); if (count != 10) { From c8a3fb1fdf477836ac1b703394c74adb18a5d1da Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 15 Mar 2021 11:55:06 -0600 Subject: [PATCH 03/23] Handle if_then_else in loop partitioning. --- src/PartitionLoops.cpp | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/src/PartitionLoops.cpp b/src/PartitionLoops.cpp index c1e0d1fb7bfb..6d38d3267efa 100644 --- a/src/PartitionLoops.cpp +++ b/src/PartitionLoops.cpp @@ -329,28 +329,40 @@ class FindSimplifications : public IRVisitor { } } - void visit(const Select *op) override { - op->condition.accept(this); + void visit_select(const Expr &condition, const Expr &old, const Expr &true_value, const Expr &false_value) { + condition.accept(this); - bool likely_t = has_uncaptured_likely_tag(op->true_value); - bool likely_f = has_uncaptured_likely_tag(op->false_value); + bool likely_t = has_uncaptured_likely_tag(true_value); + bool likely_f = has_uncaptured_likely_tag(false_value); if (!likely_t && !likely_f) { - likely_t = has_likely_tag(op->true_value); - likely_f = has_likely_tag(op->false_value); + likely_t = has_likely_tag(true_value); + likely_f = has_likely_tag(false_value); } if (!likely_t) { - op->false_value.accept(this); + false_value.accept(this); } if (!likely_f) { - op->true_value.accept(this); + true_value.accept(this); } if (likely_t && !likely_f) { - new_simplification(op->condition, op, op->true_value, op->false_value); + new_simplification(condition, old, true_value, false_value); } else if (likely_f && !likely_t) { - new_simplification(!op->condition, op, op->false_value, op->true_value); + new_simplification(!condition, old, false_value, true_value); + } + } + + void visit(const Select *op) override { + visit_select(op->condition, op, op->true_value, op->false_value); + } + + void visit(const Call *op) override { + if (op->is_intrinsic(Call::if_then_else)) { + visit_select(op->args[0], op, op->args[1], op->args[2]); + } else { + IRVisitor::visit(op); } } From d05c72b8b6f0c33bde1b53af1178fc9e2778c7bc Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 15 Mar 2021 14:54:35 -0600 Subject: [PATCH 04/23] Add rebase_loops_to_zero pass. --- Makefile | 4 ++- src/CMakeLists.txt | 2 ++ src/Lower.cpp | 6 +++++ src/RebaseLoopsToZero.cpp | 54 +++++++++++++++++++++++++++++++++++++++ src/RebaseLoopsToZero.h | 19 ++++++++++++++ 5 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 src/RebaseLoopsToZero.cpp create mode 100644 src/RebaseLoopsToZero.h diff --git a/Makefile b/Makefile index 1e80eba0a4a5..89f1f23eeee7 100644 --- a/Makefile +++ b/Makefile @@ -259,7 +259,7 @@ TUTORIAL_CXX_FLAGS ?= -std=c++11 -g -fno-omit-frame-pointer $(RTTI_CXX_FLAGS) -I # Also allow tests, via conditional compilation, to use the entire # capability of the CPU being compiled on via -march=native. This # presumes tests are run on the same machine they are compiled on. -TEST_CXX_FLAGS ?= $(TUTORIAL_CXX_FLAGS) $(CXX_WARNING_FLAGS) +TEST_CXX_FLAGS ?= $(TUTORIAL_CXX_FLAGS) $(CXX_WARNING_FLAGS) TEST_LD_FLAGS = -L$(BIN_DIR) -lHalide $(COMMON_LD_FLAGS) # In the tests, some of our expectations change depending on the llvm version @@ -513,6 +513,7 @@ SOURCE_FILES = \ RDom.cpp \ Realization.cpp \ RealizationOrder.cpp \ + RebaseLoopsToZero.cpp \ Reduction.cpp \ RegionCosts.cpp \ RemoveDeadAllocations.cpp \ @@ -689,6 +690,7 @@ HEADER_FILES = \ Realization.h \ RDom.h \ RealizationOrder.h \ + RebaseLoopsToZero.h \ Reduction.h \ RegionCosts.h \ RemoveDeadAllocations.h \ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8bc68eeeb26c..6c4b9371545b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -121,6 +121,7 @@ set(HEADER_FILES RDom.h Realization.h RealizationOrder.h + RebaseLoopsToZero.h Reduction.h RegionCosts.h RemoveDeadAllocations.h @@ -277,6 +278,7 @@ set(SOURCE_FILES RDom.cpp Realization.cpp RealizationOrder.cpp + RebaseLoopsToZero.cpp Reduction.cpp RegionCosts.cpp RemoveDeadAllocations.cpp diff --git a/src/Lower.cpp b/src/Lower.cpp index 07cb56f3556c..3810a0885940 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -47,6 +47,7 @@ #include "PurifyIndexMath.h" #include "Qualify.h" #include "RealizationOrder.h" +#include "RebaseLoopsToZero.h" #include "RemoveDeadAllocations.h" #include "RemoveExternLoops.h" #include "RemoveUndef.h" @@ -269,6 +270,11 @@ Module lower(const vector &output_funcs, debug(2) << "Lowering after bounding small realizations:\n" << s << "\n\n"; + debug(1) << "Rebasing loops to zero...\n"; + s = rebase_loops_to_zero(s); + debug(2) << "Lowering after rebasing loops to zero:\n" + << s << "\n\n"; + debug(1) << "Performing storage flattening...\n"; s = storage_flattening(s, outputs, env, t); debug(2) << "Lowering after storage flattening:\n" diff --git a/src/RebaseLoopsToZero.cpp b/src/RebaseLoopsToZero.cpp new file mode 100644 index 000000000000..4dccde1e0b9c --- /dev/null +++ b/src/RebaseLoopsToZero.cpp @@ -0,0 +1,54 @@ +#include "RebaseLoopsToZero.h" +#include "IRMutator.h" +#include "IROperator.h" + +namespace Halide { +namespace Internal { + +using std::string; + +namespace { + +bool should_rebase(ForType type) { + switch (type) { + case ForType::Extern: + case ForType::GPUBlock: + case ForType::GPUThread: + case ForType::GPULane: + return false; + default: + return true; + } +} + +class RebaseLoopsToZero : public IRMutator { + using IRMutator::visit; + + Stmt visit(const For *op) override { + if (!should_rebase(op->for_type)) { + return IRMutator::visit(op); + } + Stmt body = mutate(op->body); + string name = op->name; + if (!is_const_zero(op->min)) { + // Renaming the loop (intentionally) invalidates any .loop_min/.loop_max lets. + name = op->name + ".rebased"; + Expr loop_var = Variable::make(Int(32), name); + body = LetStmt::make(op->name, loop_var + op->min, body); + } + if (body.same_as(op->body)) { + return op; + } else { + return For::make(name, 0, op->extent, op->for_type, op->device_api, body); + } + } +}; + +} // namespace + +Stmt rebase_loops_to_zero(const Stmt &s) { + return RebaseLoopsToZero().mutate(s); +} + +} // namespace Internal +} // namespace Halide diff --git a/src/RebaseLoopsToZero.h b/src/RebaseLoopsToZero.h new file mode 100644 index 000000000000..930d62ae3f8d --- /dev/null +++ b/src/RebaseLoopsToZero.h @@ -0,0 +1,19 @@ +#ifndef HALIDE_REBASE_LOOPS_TO_ZERO_H +#define HALIDE_REBASE_LOOPS_TO_ZERO_H + +/** \file + * Defines the lowering pass that rewrites loop mins to be 0. + */ + +#include "Expr.h" + +namespace Halide { +namespace Internal { + +/** Rewrite the mins of most loops to 0. */ +Stmt rebase_loops_to_zero(const Stmt &); + +} // namespace Internal +} // namespace Halide + +#endif From 975d7005db33704c52f9d7a9eda05a5b2ddd0fa7 Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 15 Mar 2021 14:55:55 -0600 Subject: [PATCH 05/23] Use select instead of if_then_else. --- src/SlidingWindow.cpp | 21 +-------------------- test/correctness/sliding_window.cpp | 29 ++++++++++++++++------------- 2 files changed, 17 insertions(+), 33 deletions(-) diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index d0dcc809cd77..7b8104462c37 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -127,7 +127,6 @@ class RollFunc : public IRMutator { const Interval &old_bounds; const Interval &new_bounds; - set loops_to_rebase; bool in_produce = false; using IRMutator::visit; @@ -159,12 +158,7 @@ class RollFunc : public IRMutator { Type t = values[i].type(); Expr old_value = Call::make(t, op->name, old_args, Call::Halide, func.get_contents(), i); - values[i] = likely(values[i]); - values[i] = Call::make(t, Call::if_then_else, {is_new, values[i], old_value}, Call::PureIntrinsic); - } - if (const Variable *v = op->args[dim].as()) { - // The subtractions above simplify more easily if the loop is rebased to 0. - loops_to_rebase.insert(v->name); + values[i] = select(is_new, likely(values[i]), old_value); } return Provide::make(func.name(), values, args); } @@ -181,19 +175,6 @@ class RollFunc : public IRMutator { return Call::make(op->type, op->name, args, Call::Halide, op->func, op->value_index, op->image, op->param); } - Stmt visit(const For *op) override { - Stmt result = IRMutator::visit(op); - op = result.as(); - internal_assert(op); - if (loops_to_rebase.count(op->name)) { - string new_name = op->name + ".rebased"; - Stmt body = substitute(op->name, Variable::make(Int(32), new_name) + op->min, op->body); - result = For::make(new_name, 0, op->extent, op->for_type, op->device_api, body); - loops_to_rebase.erase(op->name); - } - return result; - } - Stmt visit(const LetStmt *op) override { string prefix = func.name() + ".s0." + func.args()[dim]; if (starts_with(op->name, prefix + ".min") || starts_with(op->name, prefix + ".max")) { diff --git a/test/correctness/sliding_window.cpp b/test/correctness/sliding_window.cpp index 84b141742913..e03f00bd5145 100644 --- a/test/correctness/sliding_window.cpp +++ b/test/correctness/sliding_window.cpp @@ -73,7 +73,7 @@ int main(int argc, char **argv) { } // Try a sequence of two sliding windows. - for (auto store_in : {MemoryType::Heap}) { + for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { count = 0; Func f, g, h; @@ -85,8 +85,9 @@ int main(int argc, char **argv) { g.store_root().compute_at(h, x).store_in(store_in); Buffer im = h.realize({100}); - if (count != 102) { - printf("f was called %d times instead of %d times\n", count, 102); + int correct = store_in == MemoryType::Register ? 103 : 102; + if (count != correct) { + printf("f was called %d times instead of %d times\n", count, correct); return -1; } } @@ -238,8 +239,9 @@ int main(int argc, char **argv) { Buffer im = g.realize({100}); // f should be able to tell that it only needs to compute each value once - if (count != 34) { - printf("f was called %d times instead of %d times\n", count, 34); + int correct = store_in == MemoryType::Register ? 100 : 34; + if (count != correct) { + printf("f was called %d times instead of %d times\n", count, correct); return -1; } } @@ -276,8 +278,9 @@ int main(int argc, char **argv) { count = 0; Buffer im = g.realize({100}); - if (count != 101) { - printf("f was called %d times instead of %d times\n", count, 101); + int correct = store_in == MemoryType::Register ? 110 : 101; + if (count != correct) { + printf("f was called %d times instead of %d times\n", count, correct); return -1; } } @@ -294,7 +297,7 @@ int main(int argc, char **argv) { Buffer im = g.realize({100}); // TODO: We shouldn't need the extra calls for registers. - int correct = store_in == MemoryType::Register ? 152 : 104; + int correct = store_in == MemoryType::Register ? 200 : 104; if (count != correct) { printf("f was called %d times instead of %d times\n", count, correct); return -1; @@ -345,7 +348,7 @@ int main(int argc, char **argv) { } } - for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { + { // Sliding a func that has a boundary condition before the beginning // of the loop. This needs an explicit warmup before we start sliding. count = 0; @@ -353,7 +356,7 @@ int main(int argc, char **argv) { f(x) = call_counter(x, 0); g(x) = f(max(x, 3)); - f.store_root().compute_at(g, x).store_in(store_in); + f.store_root().compute_at(g, x); g.realize({10}); if (count != 7) { @@ -362,7 +365,7 @@ int main(int argc, char **argv) { } } - for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { + { // Sliding a func that has a boundary condition on both sides. count = 0; Func f, g, h; @@ -370,8 +373,8 @@ int main(int argc, char **argv) { g(x) = f(clamp(x, 0, 9)); h(x) = g(x - 1) + g(x + 1); - f.store_root().compute_at(h, x).store_in(store_in); - g.store_root().compute_at(h, x).store_in(store_in); + f.store_root().compute_at(h, x); + g.store_root().compute_at(h, x); h.realize({10}); if (count != 10) { From 085ba4801d116016104b817e4ef8d26c8b12f9db Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 15 Mar 2021 17:04:30 -0600 Subject: [PATCH 06/23] Add select comparison simplifications. --- src/Simplify_LT.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Simplify_LT.cpp b/src/Simplify_LT.cpp index 2b15ae9877de..34f2bc9ea8d1 100644 --- a/src/Simplify_LT.cpp +++ b/src/Simplify_LT.cpp @@ -285,6 +285,9 @@ Expr Simplify::visit(const LT *op, ExprInfo *bounds) { rewrite(select(y, z, x + c0) < x + c1, y && (z < x + c1), c0 >= c1) || rewrite(select(y, z, x + c0) < x + c1, !y || (z < x + c1), c0 < c1) || + rewrite(c0 < select(x, c1, c2), select(x, fold(c0 < c1), fold(c0 < c2))) || + rewrite(select(x, c1, c2) < c0, select(x, fold(c1 < c0), fold(c2 < c0))) || + // Normalize comparison of ramps to a comparison of a ramp and a broadacst rewrite(ramp(x, y, lanes) < ramp(z, w, lanes), ramp(x - z, y - w, lanes) < 0))) || From 411e0cbbbeacb9ee6e420c8d92941e9da4326605 Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 15 Mar 2021 17:08:23 -0600 Subject: [PATCH 07/23] Don't rewrite lets --- src/SlidingWindow.cpp | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index 7b8104462c37..f58ad7a97172 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -175,18 +175,6 @@ class RollFunc : public IRMutator { return Call::make(op->type, op->name, args, Call::Halide, op->func, op->value_index, op->image, op->param); } - Stmt visit(const LetStmt *op) override { - string prefix = func.name() + ".s0." + func.args()[dim]; - if (starts_with(op->name, prefix + ".min") || starts_with(op->name, prefix + ".max")) { - Expr value = mutate(op->value); - Stmt body = substitute(op->name, value, op->body); - body = mutate(body); - return LetStmt::make(op->name, value - old_bounds.min, body); - } else { - return IRMutator::visit(op); - } - } - public: RollFunc(const Function &func, int dim, const string &loop_var, const Interval &old_bounds, const Interval &new_bounds) From ce56515d1473a217f67347e33304803e10998244 Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 15 Mar 2021 17:19:43 -0600 Subject: [PATCH 08/23] Rebase producer loops of register slides to 0, and don't overwrite realization bounds. --- src/SlidingWindow.cpp | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index f58ad7a97172..b20885f607b2 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -127,6 +127,7 @@ class RollFunc : public IRMutator { const Interval &old_bounds; const Interval &new_bounds; + set loops_to_rebase; bool in_produce = false; using IRMutator::visit; @@ -160,6 +161,10 @@ class RollFunc : public IRMutator { Call::make(t, op->name, old_args, Call::Halide, func.get_contents(), i); values[i] = select(is_new, likely(values[i]), old_value); } + if (const Variable *v = op->args[dim].as()) { + // The subtractions above simplify more easily if the loop is rebased to 0. + loops_to_rebase.insert(v->name); + } return Provide::make(func.name(), values, args); } @@ -175,6 +180,19 @@ class RollFunc : public IRMutator { return Call::make(op->type, op->name, args, Call::Halide, op->func, op->value_index, op->image, op->param); } + Stmt visit(const For *op) override { + Stmt result = IRMutator::visit(op); + op = result.as(); + internal_assert(op); + if (loops_to_rebase.count(op->name)) { + string new_name = op->name + ".rebased"; + Stmt body = substitute(op->name, Variable::make(Int(32), new_name) + op->min, op->body); + result = For::make(new_name, 0, op->extent, op->for_type, op->device_api, body); + loops_to_rebase.erase(op->name); + } + return result; + } + public: RollFunc(const Function &func, int dim, const string &loop_var, const Interval &old_bounds, const Interval &new_bounds) @@ -698,11 +716,6 @@ class SlidingWindow : public IRMutator { // outermost. list sliding; - // Keep track of updated bounds for realizations. - map> new_extents; - - Scope scope; - using IRMutator::visit; Stmt visit(const Realize *op) override { @@ -737,13 +750,8 @@ class SlidingWindow : public IRMutator { if (new_body.same_as(op->body)) { return op; } else { - vector bounds = op->bounds; - auto i = new_extents.find(op->name); - if (i != new_extents.end()) { - bounds[i->second.first] = {0, i->second.second}; - } return Realize::make(op->name, op->types, op->memory_type, - bounds, op->condition, new_body); + op->bounds, op->condition, new_body); } } @@ -786,8 +794,6 @@ class SlidingWindow : public IRMutator { // If we're sliding in registers, we need to rewrite the bounds // of the realization, like storage folding would do. body = slider.translate_loop(body); - Expr new_extent = slider.old_bounds.max - slider.old_bounds.min + 1; - new_extents[func.name()] = {slider.dim_idx, expand_expr(new_extent, scope)}; } if (slider.new_loop_min.defined()) { @@ -855,11 +861,6 @@ class SlidingWindow : public IRMutator { } } - Stmt visit(const LetStmt *op) override { - ScopedBinding bind(scope, op->name, simplify(expand_expr(op->value, scope))); - return IRMutator::visit(op); - } - public: SlidingWindow(const map &e) : env(e) { From 4c0a6c5dedad1fc833f57ab3ae0acb480410aa52 Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 15 Mar 2021 17:50:26 -0600 Subject: [PATCH 09/23] Add rules for ramp < broadcast --- src/Simplify_LT.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Simplify_LT.cpp b/src/Simplify_LT.cpp index 34f2bc9ea8d1..6420f795ec15 100644 --- a/src/Simplify_LT.cpp +++ b/src/Simplify_LT.cpp @@ -291,6 +291,9 @@ Expr Simplify::visit(const LT *op, ExprInfo *bounds) { // Normalize comparison of ramps to a comparison of a ramp and a broadacst rewrite(ramp(x, y, lanes) < ramp(z, w, lanes), ramp(x - z, y - w, lanes) < 0))) || + rewrite(ramp(x + c0, c1, c2) < broadcast(x + c3, c2), ramp(c0, c1, c2) < broadcast(c3, c2)) || + rewrite(broadcast(x + c3, c2) < ramp(x + c0, c1, c2), broadcast(c3, c2) < ramp(c0, c1, c2)) || + (no_overflow_int(ty) && EVAL_IN_LAMBDA (rewrite(x * c0 < y * c1, x < y * fold(c1 / c0), c1 % c0 == 0 && c0 > 0) || rewrite(x * c0 < y * c1, x * fold(c0 / c1) < y, c0 % c1 == 0 && c1 > 0) || From f42212989a74060a74446f686c94a914863ca713 Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 15 Mar 2021 17:50:45 -0600 Subject: [PATCH 10/23] Put the likely on the old value instead of the new value. --- src/SlidingWindow.cpp | 2 +- test/correctness/sliding_window.cpp | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index b20885f607b2..73a5d9f5d696 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -159,7 +159,7 @@ class RollFunc : public IRMutator { Type t = values[i].type(); Expr old_value = Call::make(t, op->name, old_args, Call::Halide, func.get_contents(), i); - values[i] = select(is_new, likely(values[i]), old_value); + values[i] = select(is_new, values[i], likely(old_value)); } if (const Variable *v = op->args[dim].as()) { // The subtractions above simplify more easily if the loop is rebased to 0. diff --git a/test/correctness/sliding_window.cpp b/test/correctness/sliding_window.cpp index e03f00bd5145..a0f95b6d6515 100644 --- a/test/correctness/sliding_window.cpp +++ b/test/correctness/sliding_window.cpp @@ -278,9 +278,8 @@ int main(int argc, char **argv) { count = 0; Buffer im = g.realize({100}); - int correct = store_in == MemoryType::Register ? 110 : 101; - if (count != correct) { - printf("f was called %d times instead of %d times\n", count, correct); + if (count != 101) { + printf("f was called %d times instead of %d times\n", count, 101); return -1; } } From 8466e4dfae9c1d8265b2479241c12e9f136202ed Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 15 Mar 2021 16:58:24 -0700 Subject: [PATCH 11/23] New rules for comparing ramps and broadcasts --- src/Simplify_LT.cpp | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/src/Simplify_LT.cpp b/src/Simplify_LT.cpp index 34f2bc9ea8d1..b8f676e88a95 100644 --- a/src/Simplify_LT.cpp +++ b/src/Simplify_LT.cpp @@ -289,8 +289,39 @@ Expr Simplify::visit(const LT *op, ExprInfo *bounds) { rewrite(select(x, c1, c2) < c0, select(x, fold(c1 < c0), fold(c2 < c0))) || // Normalize comparison of ramps to a comparison of a ramp and a broadacst - rewrite(ramp(x, y, lanes) < ramp(z, w, lanes), ramp(x - z, y - w, lanes) < 0))) || + rewrite(ramp(x, y, lanes) < ramp(z, w, lanes), ramp(x - z, y - w, lanes) < 0) || + // Rules of the form: + // rewrite(ramp(x, y, lanes) < broadcast(z, lanes), ramp(x - z, y, lanes) < 0) || + // where x and z cancel usefully + rewrite(ramp(x + z, y, lanes) < broadcast(x + w, lanes), ramp(z, y, lanes) < broadcast(w, lanes)) || + rewrite(ramp(z + x, y, lanes) < broadcast(x + w, lanes), ramp(z, y, lanes) < broadcast(w, lanes)) || + rewrite(ramp(x + z, y, lanes) < broadcast(w + x, lanes), ramp(z, y, lanes) < broadcast(w, lanes)) || + rewrite(ramp(z + x, y, lanes) < broadcast(w + x, lanes), ramp(z, y, lanes) < broadcast(w, lanes)) || + + // z = 0 + rewrite(ramp(x, y, lanes) < broadcast(x + w, lanes), ramp(0, y, lanes) < broadcast(w, lanes)) || + rewrite(ramp(x, y, lanes) < broadcast(w + x, lanes), ramp(0, y, lanes) < broadcast(w, lanes)) || + + // w = 0 + rewrite(ramp(x + z, y, lanes) < broadcast(x, lanes), ramp(z, y, lanes) < 0) || + rewrite(ramp(z + x, y, lanes) < broadcast(x, lanes), ramp(z, y, lanes) < 0) || + + // With the args flipped + rewrite(broadcast(x + w, lanes) < ramp(x + z, y, lanes), broadcast(w, lanes) < ramp(z, y, lanes)) || + rewrite(broadcast(x + w, lanes) < ramp(z + x, y, lanes), broadcast(w, lanes) < ramp(z, y, lanes)) || + rewrite(broadcast(w + x, lanes) < ramp(x + z, y, lanes), broadcast(w, lanes) < ramp(z, y, lanes)) || + rewrite(broadcast(w + x, lanes) < ramp(z + x, y, lanes), broadcast(w, lanes) < ramp(z, y, lanes)) || + + // z = 0 + rewrite(broadcast(x + w, lanes) < ramp(x, y, lanes), broadcast(w, lanes) < ramp(0, y, lanes)) || + rewrite(broadcast(w + x, lanes) < ramp(x, y, lanes), broadcast(w, lanes) < ramp(0, y, lanes)) || + + // w = 0 + rewrite(broadcast(x, lanes) < ramp(x + z, y, lanes), 0 < ramp(z, y, lanes)) || + rewrite(broadcast(x, lanes) < ramp(z + x, y, lanes), 0 < ramp(z, y, lanes)) || + + false)) || (no_overflow_int(ty) && EVAL_IN_LAMBDA (rewrite(x * c0 < y * c1, x < y * fold(c1 / c0), c1 % c0 == 0 && c0 > 0) || rewrite(x * c0 < y * c1, x * fold(c0 / c1) < y, c0 % c1 == 0 && c1 > 0) || From f7111cad56291b1580f7d18ff59b4290ac55fac4 Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 15 Mar 2021 18:06:25 -0600 Subject: [PATCH 12/23] Switch back to if_then_else --- src/SlidingWindow.cpp | 2 +- test/correctness/sliding_window.cpp | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index 73a5d9f5d696..e52860a6b069 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -159,7 +159,7 @@ class RollFunc : public IRMutator { Type t = values[i].type(); Expr old_value = Call::make(t, op->name, old_args, Call::Halide, func.get_contents(), i); - values[i] = select(is_new, values[i], likely(old_value)); + values[i] = Call::make(values[i].type(), Call::if_then_else, {is_new, values[i], likely(old_value)}, Call::PureIntrinsic); } if (const Variable *v = op->args[dim].as()) { // The subtractions above simplify more easily if the loop is rebased to 0. diff --git a/test/correctness/sliding_window.cpp b/test/correctness/sliding_window.cpp index a0f95b6d6515..8fd91fd36b0f 100644 --- a/test/correctness/sliding_window.cpp +++ b/test/correctness/sliding_window.cpp @@ -239,9 +239,8 @@ int main(int argc, char **argv) { Buffer im = g.realize({100}); // f should be able to tell that it only needs to compute each value once - int correct = store_in == MemoryType::Register ? 100 : 34; - if (count != correct) { - printf("f was called %d times instead of %d times\n", count, correct); + if (count != 34) { + printf("f was called %d times instead of %d times\n", count, 34); return -1; } } @@ -296,7 +295,7 @@ int main(int argc, char **argv) { Buffer im = g.realize({100}); // TODO: We shouldn't need the extra calls for registers. - int correct = store_in == MemoryType::Register ? 200 : 104; + int correct = store_in == MemoryType::Register ? 152 : 104; if (count != correct) { printf("f was called %d times instead of %d times\n", count, correct); return -1; From 16ad4e089ffc1788b70f46a68599454a67533249 Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 15 Mar 2021 18:32:11 -0600 Subject: [PATCH 13/23] Update comments. --- src/SlidingWindow.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index e52860a6b069..d4c78de53f57 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -119,6 +119,12 @@ bool find_produce(const Stmt &s, const string &func) { return finder.found; } +// This mutator rewrites calls and provides to a particular +// func: +// - Calls and Provides are shifted to be relative to the min. +// - Provides additionally are rewritten to load values from the +// previous iteration of the loop if they were computed in the +// last iteration. class RollFunc : public IRMutator { public: const Function &func; @@ -127,6 +133,8 @@ class RollFunc : public IRMutator { const Interval &old_bounds; const Interval &new_bounds; + // It helps simplify the shifted calls/provides to rebase the + // loops that are subtracted from to have a min of 0. set loops_to_rebase; bool in_produce = false; @@ -791,8 +799,6 @@ class SlidingWindow : public IRMutator { if (func.schedule().memory_type() == MemoryType::Register && slider.old_bounds.has_lower_bound()) { - // If we're sliding in registers, we need to rewrite the bounds - // of the realization, like storage folding would do. body = slider.translate_loop(body); } From bc6a7c6b11e032721b85d172e8b266af6939372e Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 15 Mar 2021 19:40:34 -0600 Subject: [PATCH 14/23] Don't try to fold dimensions with a constant min or max. --- src/StorageFolding.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/StorageFolding.cpp b/src/StorageFolding.cpp index 548d46cb57cf..f7c1ffff44bb 100644 --- a/src/StorageFolding.cpp +++ b/src/StorageFolding.cpp @@ -542,6 +542,16 @@ class AttemptStorageFoldingOfFunction : public IRMutator { Expr min = simplify(common_subexpression_elimination(box[dim].min)); Expr max = simplify(common_subexpression_elimination(box[dim].max)); + if (is_const(min) || is_const(max)) { + debug(3) << "\nNot considering folding " << func.name() + << " over for loop over " << op->name + << " dimension " << i - 1 << "\n" + << " because the min or max are constants." + << "Min: " << min << "\n" + << "Max: " << max << "\n"; + continue; + } + Expr min_provided, max_provided, min_required, max_required; if (func.schedule().async() && !explicit_only) { if (!provided.empty()) { From 944ab792f25e6930da6f60b2d33745296382e28f Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 15 Mar 2021 20:09:26 -0600 Subject: [PATCH 15/23] More comments. --- src/SlidingWindow.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index d4c78de53f57..f496f06a74ea 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -461,16 +461,17 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator { slid_dimensions.insert(dim_idx); + // If we want to slide in registers, we're done here, we just need to + // save the updated bounds for later. if (func.schedule().memory_type() == MemoryType::Register) { - // If we're going to slide in registers, save the bounds - // for doing that later. this->dim_idx = dim_idx; old_bounds = {min_required, max_required}; new_bounds = {new_min, new_max}; return op; } - // Now redefine the appropriate regions required + // If we aren't sliding in registers, we need to update the bounds of + // the producer to be only the bounds of the region newly computed. internal_assert(replacements.empty()); if (can_slide_up) { replacements[prefix + dim + ".min"] = new_min; From b94a59b5b856386cecf1c1d9bc68a71e6585083e Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 15 Mar 2021 20:32:13 -0600 Subject: [PATCH 16/23] Make the vectorized register sliding window test tighter. --- test/correctness/sliding_window.cpp | 35 ++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/test/correctness/sliding_window.cpp b/test/correctness/sliding_window.cpp index 8fd91fd36b0f..5ed809ff6522 100644 --- a/test/correctness/sliding_window.cpp +++ b/test/correctness/sliding_window.cpp @@ -283,25 +283,48 @@ int main(int argc, char **argv) { } } - for (auto store_in : {MemoryType::Heap, MemoryType::Register}) { + { // Sliding with a vectorized producer and consumer. count = 0; Func f, g; f(x) = call_counter(x, 0); g(x) = f(x + 1) + f(x - 1); - f.store_root().compute_at(g, x).store_in(store_in).vectorize(x, 4); + f.store_root().compute_at(g, x).vectorize(x, 4); g.vectorize(x, 4); Buffer im = g.realize({100}); - // TODO: We shouldn't need the extra calls for registers. - int correct = store_in == MemoryType::Register ? 152 : 104; - if (count != correct) { - printf("f was called %d times instead of %d times\n", count, correct); + if (count != 104) { + printf("f was called %d times instead of %d times\n", count, 104); return -1; } } + { + // Sliding with a vectorized producer and consumer, trying to rotate + // cleanly in registers. + count = 0; + Func f, g; + f(x) = call_counter(x, 0); + g(x) = f(x + 1) + f(x - 1); + + // This currently requires a trick to get everything to be aligned + // nicely. This exploits the fact that ShiftInwards splits are + // aligned to the end of the original loop (and extending before the + // min if necessary). + Var xi("xi"); + f.store_root().compute_at(g, x).store_in(MemoryType::Register) + .split(x, x, xi, 8).vectorize(xi, 4).unroll(xi); + g.vectorize(x, 4, TailStrategy::RoundUp); + + Buffer im = g.realize({100}); + if (count != 102) { + printf("f was called %d times instead of %d times\n", count, 102); + return -1; + } + } + return 0; + { // A sequence of stencils, all computed at the output. count = 0; From 70f9d7abef6e3c5efeeeec06e45c2a634a5e450f Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 15 Mar 2021 20:37:37 -0600 Subject: [PATCH 17/23] Remove debug helper. --- test/correctness/sliding_window.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/test/correctness/sliding_window.cpp b/test/correctness/sliding_window.cpp index 5ed809ff6522..b6e888ed8088 100644 --- a/test/correctness/sliding_window.cpp +++ b/test/correctness/sliding_window.cpp @@ -323,7 +323,6 @@ int main(int argc, char **argv) { return -1; } } - return 0; { // A sequence of stencils, all computed at the output. From e54dd908a6f648cd13a0787c97cf8adb0620f34c Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Mon, 15 Mar 2021 20:46:10 -0600 Subject: [PATCH 18/23] Fix tests broken by loop rebasing. --- test/correctness/deferred_loop_level.cpp | 8 ++++---- test/correctness/loop_level_generator_param.cpp | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/correctness/deferred_loop_level.cpp b/test/correctness/deferred_loop_level.cpp index 8e382b26e996..47b1f41cbf42 100644 --- a/test/correctness/deferred_loop_level.cpp +++ b/test/correctness/deferred_loop_level.cpp @@ -26,18 +26,18 @@ class CheckLoopLevels : public IRVisitor { void visit(const Call *op) override { IRVisitor::visit(op); if (op->name == "sin_f32") { - _halide_user_assert(inside_for_loop == inner_loop_level); + _halide_user_assert(starts_with(inside_for_loop, inner_loop_level)); } else if (op->name == "cos_f32") { - _halide_user_assert(inside_for_loop == outer_loop_level); + _halide_user_assert(starts_with(inside_for_loop, outer_loop_level)); } } void visit(const Store *op) override { IRVisitor::visit(op); if (op->name.substr(0, 5) == "inner") { - _halide_user_assert(inside_for_loop == inner_loop_level); + _halide_user_assert(starts_with(inside_for_loop, inner_loop_level)); } else if (op->name.substr(0, 5) == "outer") { - _halide_user_assert(inside_for_loop == outer_loop_level); + _halide_user_assert(starts_with(inside_for_loop, outer_loop_level)); } else { _halide_user_assert(0); } diff --git a/test/correctness/loop_level_generator_param.cpp b/test/correctness/loop_level_generator_param.cpp index 64505f219a47..c95f28324d78 100644 --- a/test/correctness/loop_level_generator_param.cpp +++ b/test/correctness/loop_level_generator_param.cpp @@ -50,10 +50,10 @@ class CheckLoopLevels : public IRVisitor { void visit(const Call *op) override { IRVisitor::visit(op); if (op->name == "sin_f32") { - _halide_user_assert(inside_for_loop == inner_loop_level) + _halide_user_assert(starts_with(inside_for_loop, inner_loop_level)) << "call sin_f32: expected " << inner_loop_level << ", actual: " << inside_for_loop; } else if (op->name == "cos_f32") { - _halide_user_assert(inside_for_loop == outer_loop_level) + _halide_user_assert(starts_with(inside_for_loop, outer_loop_level)) << "call cos_f32: expected " << outer_loop_level << ", actual: " << inside_for_loop; } } @@ -62,10 +62,10 @@ class CheckLoopLevels : public IRVisitor { IRVisitor::visit(op); std::string op_name = strip_uniquified_names(op->name); if (op_name == "inner") { - _halide_user_assert(inside_for_loop == inner_loop_level) + _halide_user_assert(starts_with(inside_for_loop, inner_loop_level)) << "inside_for_loop: expected " << inner_loop_level << ", actual: " << inside_for_loop; } else if (op_name == "outer") { - _halide_user_assert(inside_for_loop == outer_loop_level) + _halide_user_assert(starts_with(inside_for_loop, outer_loop_level)) << "inside_for_loop: expected " << outer_loop_level << ", actual: " << inside_for_loop; } else { _halide_user_assert(0) << "store at: " << op_name << " inside_for_loop: " << inside_for_loop; From 86b4fd6e1a23ca39844d78af3d6e865326af0685 Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Tue, 16 Mar 2021 12:38:44 -0600 Subject: [PATCH 19/23] Move rebasing after loop partitioning --- src/Lower.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/Lower.cpp b/src/Lower.cpp index 099892351c9c..da4e77d2b430 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -268,11 +268,6 @@ Module lower(const vector &output_funcs, s = bound_small_allocations(s); log("Lowering after bounding small realizations:", s); - debug(1) << "Rebasing loops to zero...\n"; - s = rebase_loops_to_zero(s); - debug(2) << "Lowering after rebasing loops to zero:\n" - << s << "\n\n"; - debug(1) << "Performing storage flattening...\n"; s = storage_flattening(s, outputs, env, t); log("Lowering after storage flattening:", s); @@ -350,6 +345,11 @@ Module lower(const vector &output_funcs, s = trim_no_ops(s); log("Lowering after loop trimming:", s); + debug(1) << "Rebasing loops to zero...\n"; + s = rebase_loops_to_zero(s); + debug(2) << "Lowering after rebasing loops to zero:\n" + << s << "\n\n"; + debug(1) << "Hoisting loop invariant if statements...\n"; s = hoist_loop_invariant_if_statements(s); log("Lowering after hoisting loop invariant if statements:", s); From afe379d1571cc53c755a74c5729f15405b8e9a56 Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Tue, 16 Mar 2021 12:47:32 -0600 Subject: [PATCH 20/23] clang-format --- test/correctness/sliding_window.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/correctness/sliding_window.cpp b/test/correctness/sliding_window.cpp index b6e888ed8088..70fe8b7a1baa 100644 --- a/test/correctness/sliding_window.cpp +++ b/test/correctness/sliding_window.cpp @@ -313,8 +313,7 @@ int main(int argc, char **argv) { // aligned to the end of the original loop (and extending before the // min if necessary). Var xi("xi"); - f.store_root().compute_at(g, x).store_in(MemoryType::Register) - .split(x, x, xi, 8).vectorize(xi, 4).unroll(xi); + f.store_root().compute_at(g, x).store_in(MemoryType::Register).split(x, x, xi, 8).vectorize(xi, 4).unroll(xi); g.vectorize(x, 4, TailStrategy::RoundUp); Buffer im = g.realize({100}); From b68bdcf481db4c589881535b9621addd7d71e3db Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Tue, 16 Mar 2021 13:02:28 -0600 Subject: [PATCH 21/23] clang-tidy --- src/SlidingWindow.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index f496f06a74ea..9855f3bf30fa 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -126,7 +126,6 @@ bool find_produce(const Stmt &s, const string &func) { // previous iteration of the loop if they were computed in the // last iteration. class RollFunc : public IRMutator { -public: const Function &func; int dim; const string &loop_var; From c142b77fb5f25bf93f13e7739179546021b8eae4 Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Tue, 16 Mar 2021 23:31:41 -0600 Subject: [PATCH 22/23] Also put MemoryType::Register on the stack. --- src/CodeGen_C.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/CodeGen_C.cpp b/src/CodeGen_C.cpp index 51b380a94fe7..e92271ceb436 100644 --- a/src/CodeGen_C.cpp +++ b/src/CodeGen_C.cpp @@ -2612,6 +2612,7 @@ void CodeGen_C::visit(const Allocate *op) { size_id = print_expr(make_const(size_id_type, constant_size)); if (op->memory_type == MemoryType::Stack || + op->memory_type == MemoryType::Register || (op->memory_type == MemoryType::Auto && can_allocation_fit_on_stack(stack_bytes))) { on_stack = true; From 084aba3b76012b8127aca31c683b012721a41ce0 Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Tue, 23 Mar 2021 13:56:29 -0600 Subject: [PATCH 23/23] Expand arg before substitute. --- src/SlidingWindow.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index 9855f3bf30fa..7d93283dde6e 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -132,6 +132,8 @@ class RollFunc : public IRMutator { const Interval &old_bounds; const Interval &new_bounds; + Scope scope; + // It helps simplify the shifted calls/provides to rebase the // loops that are subtracted from to have a min of 0. set loops_to_rebase; @@ -161,7 +163,8 @@ class RollFunc : public IRMutator { Expr is_new = sliding_up ? new_bounds.min <= args[dim] : args[dim] <= new_bounds.max; args[dim] -= old_bounds.min; vector old_args = args; - old_args[dim] = substitute(loop_var, Variable::make(Int(32), loop_var) - 1, old_args[dim]); + Expr old_arg_dim = expand_expr(old_args[dim], scope); + old_args[dim] = substitute(loop_var, Variable::make(Int(32), loop_var) - 1, old_arg_dim); for (int i = 0; i < (int)values.size(); i++) { Type t = values[i].type(); Expr old_value = @@ -200,6 +203,11 @@ class RollFunc : public IRMutator { return result; } + Stmt visit(const LetStmt *op) override { + ScopedBinding bind(scope, op->name, simplify(expand_expr(op->value, scope))); + return IRMutator::visit(op); + } + public: RollFunc(const Function &func, int dim, const string &loop_var, const Interval &old_bounds, const Interval &new_bounds)