Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
e81714e
Sliding in registers
dsharletg Mar 13, 2021
89ef82a
Fix some failure cases.
dsharletg Mar 15, 2021
c8a3fb1
Handle if_then_else in loop partitioning.
dsharletg Mar 15, 2021
d05c72b
Add rebase_loops_to_zero pass.
dsharletg Mar 15, 2021
975d700
Use select instead of if_then_else.
dsharletg Mar 15, 2021
85c0ab5
Merge branch 'master' of github.com:halide/Halide into dsharletg/slid…
dsharletg Mar 15, 2021
085ba48
Add select comparison simplifications.
dsharletg Mar 15, 2021
411e0cb
Don't rewrite lets
dsharletg Mar 15, 2021
ce56515
Rebase producer loops of register slides to 0, and don't overwrite re…
dsharletg Mar 15, 2021
4c0a6c5
Add rules for ramp < broadcast
dsharletg Mar 15, 2021
f422129
Put the likely on the old value instead of the new value.
dsharletg Mar 15, 2021
8466e4d
New rules for comparing ramps and broadcasts
abadams Mar 15, 2021
0f2c9e9
Merge branch 'dsharletg/slide-registers' of https://github.com/halide…
abadams Mar 15, 2021
f7111ca
Switch back to if_then_else
dsharletg Mar 16, 2021
16ad4e0
Update comments.
dsharletg Mar 16, 2021
bc6a7c6
Don't try to fold dimensions with a constant min or max.
dsharletg Mar 16, 2021
944ab79
More comments.
dsharletg Mar 16, 2021
b94a59b
Make the vectorized register sliding window test tighter.
dsharletg Mar 16, 2021
70f9d7a
Remove debug helper.
dsharletg Mar 16, 2021
e54dd90
Fix tests broken by loop rebasing.
dsharletg Mar 16, 2021
86b4fd6
Move rebasing after loop partitioning
dsharletg Mar 16, 2021
afe379d
clang-format
dsharletg Mar 16, 2021
b68bdcf
clang-tidy
dsharletg Mar 16, 2021
c142b77
Also put MemoryType::Register on the stack.
dsharletg Mar 17, 2021
2e1f91e
Merge branch 'master' into dsharletg/slide-registers
steven-johnson Mar 19, 2021
25b3997
Merge branch 'master' of github.com:halide/Halide into dsharletg/slid…
dsharletg Mar 23, 2021
084aba3
Expand arg before substitute.
dsharletg Mar 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,7 @@ SOURCE_FILES = \
RDom.cpp \
Realization.cpp \
RealizationOrder.cpp \
RebaseLoopsToZero.cpp \
Reduction.cpp \
RegionCosts.cpp \
RemoveDeadAllocations.cpp \
Expand Down Expand Up @@ -689,6 +690,7 @@ HEADER_FILES = \
Realization.h \
RDom.h \
RealizationOrder.h \
RebaseLoopsToZero.h \
Reduction.h \
RegionCosts.h \
RemoveDeadAllocations.h \
Expand Down
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ set(HEADER_FILES
RDom.h
Realization.h
RealizationOrder.h
RebaseLoopsToZero.h
Reduction.h
RegionCosts.h
RemoveDeadAllocations.h
Expand Down Expand Up @@ -277,6 +278,7 @@ set(SOURCE_FILES
RDom.cpp
Realization.cpp
RealizationOrder.cpp
RebaseLoopsToZero.cpp
Reduction.cpp
RegionCosts.cpp
RemoveDeadAllocations.cpp
Expand Down
1 change: 1 addition & 0 deletions src/CodeGen_C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2612,6 +2612,7 @@ void CodeGen_C::visit(const Allocate *op) {
size_id = print_expr(make_const(size_id_type, constant_size));

if (op->memory_type == MemoryType::Stack ||
op->memory_type == MemoryType::Register ||
(op->memory_type == MemoryType::Auto &&
can_allocation_fit_on_stack(stack_bytes))) {
on_stack = true;
Expand Down
6 changes: 6 additions & 0 deletions src/Lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "PurifyIndexMath.h"
#include "Qualify.h"
#include "RealizationOrder.h"
#include "RebaseLoopsToZero.h"
#include "RemoveDeadAllocations.h"
#include "RemoveExternLoops.h"
#include "RemoveUndef.h"
Expand Down Expand Up @@ -344,6 +345,11 @@ Module lower(const vector<Function> &output_funcs,
s = trim_no_ops(s);
log("Lowering after loop trimming:", s);

debug(1) << "Rebasing loops to zero...\n";
s = rebase_loops_to_zero(s);
debug(2) << "Lowering after rebasing loops to zero:\n"
<< s << "\n\n";

debug(1) << "Hoisting loop invariant if statements...\n";
s = hoist_loop_invariant_if_statements(s);
log("Lowering after hoisting loop invariant if statements:", s);
Expand Down
32 changes: 22 additions & 10 deletions src/PartitionLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,28 +329,40 @@ class FindSimplifications : public IRVisitor {
}
}

void visit(const Select *op) override {
op->condition.accept(this);
void visit_select(const Expr &condition, const Expr &old, const Expr &true_value, const Expr &false_value) {
condition.accept(this);

bool likely_t = has_uncaptured_likely_tag(op->true_value);
bool likely_f = has_uncaptured_likely_tag(op->false_value);
bool likely_t = has_uncaptured_likely_tag(true_value);
bool likely_f = has_uncaptured_likely_tag(false_value);

if (!likely_t && !likely_f) {
likely_t = has_likely_tag(op->true_value);
likely_f = has_likely_tag(op->false_value);
likely_t = has_likely_tag(true_value);
likely_f = has_likely_tag(false_value);
}

if (!likely_t) {
op->false_value.accept(this);
false_value.accept(this);
}
if (!likely_f) {
op->true_value.accept(this);
true_value.accept(this);
}

if (likely_t && !likely_f) {
new_simplification(op->condition, op, op->true_value, op->false_value);
new_simplification(condition, old, true_value, false_value);
} else if (likely_f && !likely_t) {
new_simplification(!op->condition, op, op->false_value, op->true_value);
new_simplification(!condition, old, false_value, true_value);
}
}

void visit(const Select *op) override {
visit_select(op->condition, op, op->true_value, op->false_value);
}

void visit(const Call *op) override {
if (op->is_intrinsic(Call::if_then_else)) {
visit_select(op->args[0], op, op->args[1], op->args[2]);
} else {
IRVisitor::visit(op);
}
}

Expand Down
54 changes: 54 additions & 0 deletions src/RebaseLoopsToZero.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include "RebaseLoopsToZero.h"
#include "IRMutator.h"
#include "IROperator.h"

namespace Halide {
namespace Internal {

using std::string;

namespace {

bool should_rebase(ForType type) {
switch (type) {
case ForType::Extern:
case ForType::GPUBlock:
case ForType::GPUThread:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gpu loops absolutely need to be rebased to zero, and there's a pass inside FuseGPUThreadLoops.cpp that does it. Is that mutator now redundant?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That mutator works a little differently. It substitutes loop_var + min, rather than makes a new let, and that seems like maybe something that might matter for another pass/logic (there's a lot of stuff happening in FuseGPUThreadLoops). I would also be careful about changing when loops get rebased to 0.

Maybe this PR should make it so rebase_loops_to_zero accepts a set of ForType that get rebased, and use that in FuseGPUThreadLoops? That would at least keep the rebasing happening at the same time, and the only change would be substitute vs. let.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the new mutator is just more correct than the old one. I guess the old one happens earlier though, so we can't just rely on the new one to do the mutation. Your proposal sounds good, but I have no strong feelings one way or the other.

Copy link
Contributor Author

@dsharletg dsharletg Mar 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the duplication here is minimal and there are non-minimal risks in messing with this, so I'll save it for a separate PR (will file an issue).

case ForType::GPULane:
return false;
default:
return true;
}
}

class RebaseLoopsToZero : public IRMutator {
using IRMutator::visit;

Stmt visit(const For *op) override {
if (!should_rebase(op->for_type)) {
return IRMutator::visit(op);
}
Stmt body = mutate(op->body);
string name = op->name;
if (!is_const_zero(op->min)) {
// Renaming the loop (intentionally) invalidates any .loop_min/.loop_max lets.
name = op->name + ".rebased";
Expr loop_var = Variable::make(Int(32), name);
body = LetStmt::make(op->name, loop_var + op->min, body);
}
if (body.same_as(op->body)) {
return op;
} else {
return For::make(name, 0, op->extent, op->for_type, op->device_api, body);
}
}
};

} // namespace

Stmt rebase_loops_to_zero(const Stmt &s) {
return RebaseLoopsToZero().mutate(s);
}

} // namespace Internal
} // namespace Halide
19 changes: 19 additions & 0 deletions src/RebaseLoopsToZero.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#ifndef HALIDE_REBASE_LOOPS_TO_ZERO_H
#define HALIDE_REBASE_LOOPS_TO_ZERO_H

/** \file
* Defines the lowering pass that rewrites loop mins to be 0.
*/

#include "Expr.h"

namespace Halide {
namespace Internal {

/** Rewrite the mins of most loops to 0. */
Stmt rebase_loops_to_zero(const Stmt &);

} // namespace Internal
} // namespace Halide

#endif
36 changes: 35 additions & 1 deletion src/Simplify_LT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,43 @@ Expr Simplify::visit(const LT *op, ExprInfo *bounds) {
rewrite(select(y, z, x + c0) < x + c1, y && (z < x + c1), c0 >= c1) ||
rewrite(select(y, z, x + c0) < x + c1, !y || (z < x + c1), c0 < c1) ||

rewrite(c0 < select(x, c1, c2), select(x, fold(c0 < c1), fold(c0 < c2))) ||
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These rules all formally verify ok

rewrite(select(x, c1, c2) < c0, select(x, fold(c1 < c0), fold(c2 < c0))) ||

// Normalize comparison of ramps to a comparison of a ramp and a broadacst
rewrite(ramp(x, y, lanes) < ramp(z, w, lanes), ramp(x - z, y - w, lanes) < 0))) ||
rewrite(ramp(x, y, lanes) < ramp(z, w, lanes), ramp(x - z, y - w, lanes) < 0) ||

// Rules of the form:
// rewrite(ramp(x, y, lanes) < broadcast(z, lanes), ramp(x - z, y, lanes) < 0) ||
// where x and z cancel usefully
rewrite(ramp(x + z, y, lanes) < broadcast(x + w, lanes), ramp(z, y, lanes) < broadcast(w, lanes)) ||
rewrite(ramp(z + x, y, lanes) < broadcast(x + w, lanes), ramp(z, y, lanes) < broadcast(w, lanes)) ||
rewrite(ramp(x + z, y, lanes) < broadcast(w + x, lanes), ramp(z, y, lanes) < broadcast(w, lanes)) ||
rewrite(ramp(z + x, y, lanes) < broadcast(w + x, lanes), ramp(z, y, lanes) < broadcast(w, lanes)) ||

// z = 0
rewrite(ramp(x, y, lanes) < broadcast(x + w, lanes), ramp(0, y, lanes) < broadcast(w, lanes)) ||
rewrite(ramp(x, y, lanes) < broadcast(w + x, lanes), ramp(0, y, lanes) < broadcast(w, lanes)) ||

// w = 0
rewrite(ramp(x + z, y, lanes) < broadcast(x, lanes), ramp(z, y, lanes) < 0) ||
rewrite(ramp(z + x, y, lanes) < broadcast(x, lanes), ramp(z, y, lanes) < 0) ||

// With the args flipped
rewrite(broadcast(x + w, lanes) < ramp(x + z, y, lanes), broadcast(w, lanes) < ramp(z, y, lanes)) ||
rewrite(broadcast(x + w, lanes) < ramp(z + x, y, lanes), broadcast(w, lanes) < ramp(z, y, lanes)) ||
rewrite(broadcast(w + x, lanes) < ramp(x + z, y, lanes), broadcast(w, lanes) < ramp(z, y, lanes)) ||
rewrite(broadcast(w + x, lanes) < ramp(z + x, y, lanes), broadcast(w, lanes) < ramp(z, y, lanes)) ||

// z = 0
rewrite(broadcast(x + w, lanes) < ramp(x, y, lanes), broadcast(w, lanes) < ramp(0, y, lanes)) ||
rewrite(broadcast(w + x, lanes) < ramp(x, y, lanes), broadcast(w, lanes) < ramp(0, y, lanes)) ||

// w = 0
rewrite(broadcast(x, lanes) < ramp(x + z, y, lanes), 0 < ramp(z, y, lanes)) ||
rewrite(broadcast(x, lanes) < ramp(z + x, y, lanes), 0 < ramp(z, y, lanes)) ||

false)) ||
(no_overflow_int(ty) && EVAL_IN_LAMBDA
(rewrite(x * c0 < y * c1, x < y * fold(c1 / c0), c1 % c0 == 0 && c0 > 0) ||
rewrite(x * c0 < y * c1, x * fold(c0 / c1) < y, c0 % c1 == 0 && c1 > 0) ||
Expand Down
120 changes: 119 additions & 1 deletion src/SlidingWindow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,102 @@ bool find_produce(const Stmt &s, const string &func) {
return finder.found;
}

// This mutator rewrites calls and provides to a particular
// func:
// - Calls and Provides are shifted to be relative to the min.
// - Provides additionally are rewritten to load values from the
// previous iteration of the loop if they were computed in the
// last iteration.
class RollFunc : public IRMutator {
const Function &func;
int dim;
const string &loop_var;
const Interval &old_bounds;
const Interval &new_bounds;

Scope<Expr> scope;

// It helps simplify the shifted calls/provides to rebase the
// loops that are subtracted from to have a min of 0.
set<string> loops_to_rebase;
bool in_produce = false;

using IRMutator::visit;

Stmt visit(const ProducerConsumer *op) override {
bool produce_func = op->name == func.name() && op->is_producer;
ScopedValue<bool> old_in_produce(in_produce, in_produce || produce_func);
return IRMutator::visit(op);
}

Stmt visit(const Provide *op) override {
if (!(in_produce && op->name == func.name())) {
return IRMutator::visit(op);
}
vector<Expr> values = op->values;
for (Expr &i : values) {
i = mutate(i);
}
vector<Expr> args = op->args;
for (Expr &i : args) {
i = mutate(i);
}
bool sliding_up = old_bounds.max.same_as(new_bounds.max);
Expr is_new = sliding_up ? new_bounds.min <= args[dim] : args[dim] <= new_bounds.max;
args[dim] -= old_bounds.min;
vector<Expr> old_args = args;
Expr old_arg_dim = expand_expr(old_args[dim], scope);
old_args[dim] = substitute(loop_var, Variable::make(Int(32), loop_var) - 1, old_arg_dim);
for (int i = 0; i < (int)values.size(); i++) {
Type t = values[i].type();
Expr old_value =
Call::make(t, op->name, old_args, Call::Halide, func.get_contents(), i);
values[i] = Call::make(values[i].type(), Call::if_then_else, {is_new, values[i], likely(old_value)}, Call::PureIntrinsic);
}
if (const Variable *v = op->args[dim].as<Variable>()) {
// The subtractions above simplify more easily if the loop is rebased to 0.
loops_to_rebase.insert(v->name);
}
return Provide::make(func.name(), values, args);
}

Expr visit(const Call *op) override {
if (!(op->call_type == Call::Halide && op->name == func.name())) {
return IRMutator::visit(op);
}
vector<Expr> args = op->args;
for (Expr &i : args) {
i = mutate(i);
}
args[dim] -= old_bounds.min;
return Call::make(op->type, op->name, args, Call::Halide, op->func, op->value_index, op->image, op->param);
}

Stmt visit(const For *op) override {
Stmt result = IRMutator::visit(op);
op = result.as<For>();
internal_assert(op);
if (loops_to_rebase.count(op->name)) {
string new_name = op->name + ".rebased";
Stmt body = substitute(op->name, Variable::make(Int(32), new_name) + op->min, op->body);
result = For::make(new_name, 0, op->extent, op->for_type, op->device_api, body);
loops_to_rebase.erase(op->name);
}
return result;
}

Stmt visit(const LetStmt *op) override {
ScopedBinding<Expr> bind(scope, op->name, simplify(expand_expr(op->value, scope)));
return IRMutator::visit(op);
}

public:
RollFunc(const Function &func, int dim, const string &loop_var,
const Interval &old_bounds, const Interval &new_bounds)
: func(func), dim(dim), loop_var(loop_var), old_bounds(old_bounds), new_bounds(new_bounds) {
}
};

// Perform sliding window optimization for a function over a
// particular serial for loop
class SlidingWindowOnFunctionAndLoop : public IRMutator {
Expand Down Expand Up @@ -372,7 +468,17 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator {

slid_dimensions.insert(dim_idx);

// Now redefine the appropriate regions required
// If we want to slide in registers, we're done here, we just need to
// save the updated bounds for later.
if (func.schedule().memory_type() == MemoryType::Register) {
this->dim_idx = dim_idx;
old_bounds = {min_required, max_required};
new_bounds = {new_min, new_max};
return op;
}

// If we aren't sliding in registers, we need to update the bounds of
// the producer to be only the bounds of the region newly computed.
internal_assert(replacements.empty());
if (can_slide_up) {
replacements[prefix + dim + ".min"] = new_min;
Expand Down Expand Up @@ -493,6 +599,13 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator {
}

Expr new_loop_min;
int dim_idx;
Interval old_bounds;
Interval new_bounds;

Stmt translate_loop(const Stmt &s) {
return RollFunc(func, dim_idx, loop_var, old_bounds, new_bounds).mutate(s);
}
};

// In Stmt s, does the production of b depend on a?
Expand Down Expand Up @@ -692,6 +805,11 @@ class SlidingWindow : public IRMutator {
SlidingWindowOnFunctionAndLoop slider(func, name, prev_loop_min, slid_dimensions[func.name()]);
body = slider.mutate(body);

if (func.schedule().memory_type() == MemoryType::Register &&
slider.old_bounds.has_lower_bound()) {
body = slider.translate_loop(body);
}

if (slider.new_loop_min.defined()) {
Expr new_loop_min = slider.new_loop_min;
if (!prev_loop_min.same_as(loop_min)) {
Expand Down
10 changes: 10 additions & 0 deletions src/StorageFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,16 @@ class AttemptStorageFoldingOfFunction : public IRMutator {
Expr min = simplify(common_subexpression_elimination(box[dim].min));
Expr max = simplify(common_subexpression_elimination(box[dim].max));

if (is_const(min) || is_const(max)) {
debug(3) << "\nNot considering folding " << func.name()
<< " over for loop over " << op->name
<< " dimension " << i - 1 << "\n"
<< " because the min or max are constants."
<< "Min: " << min << "\n"
<< "Max: " << max << "\n";
continue;
}

Expr min_provided, max_provided, min_required, max_required;
if (func.schedule().async() && !explicit_only) {
if (!provided.empty()) {
Expand Down
Loading