Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 141 additions & 57 deletions src/AsyncProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@ using std::set;
using std::string;
using std::vector;

namespace {

bool is_folding_semaphore(const string &sema_name, const vector<string> &funcs) {
for (const auto &f : funcs) {
if (starts_with(sema_name, f + ".folding_semaphore.")) {
return true;
}
}
return false;
}

} // namespace

/** A mutator which eagerly folds no-op stmts */
class NoOpCollapsingMutator : public IRMutator {
protected:
Expand Down Expand Up @@ -105,23 +118,15 @@ class NoOpCollapsingMutator : public IRMutator {
};

class GenerateProducerBody : public NoOpCollapsingMutator {
const string &func;
vector<Expr> sema;
const vector<string> funcs;

using NoOpCollapsingMutator::visit;

// Preserve produce nodes and add synchronization
Stmt visit(const ProducerConsumer *op) override {
if (op->name == func && op->is_producer) {
// Add post-synchronization
internal_assert(!sema.empty()) << "Duplicate produce node: " << op->name << "\n";
Stmt body = op->body;
while (!sema.empty()) {
Expr release = Call::make(Int(32), "halide_semaphore_release", {sema.back(), 1}, Call::Extern);
body = Block::make(body, Evaluate::make(release));
sema.pop_back();
}
return ProducerConsumer::make_produce(op->name, body);
auto it = std::find(funcs.begin(), funcs.end(), op->name);
if ((it != funcs.end()) && op->is_producer) {
return op;
} else {
Stmt body = mutate(op->body);
if (is_no_op(body) || op->is_producer) {
Expand All @@ -142,7 +147,7 @@ class GenerateProducerBody : public NoOpCollapsingMutator {
}

Stmt visit(const Store *op) override {
if (starts_with(op->name, func + ".folding_semaphore.") && ends_with(op->name, ".head")) {
if (is_folding_semaphore(op->name, funcs) && ends_with(op->name, ".head")) {
// This is a counter associated with the producer side of a storage-folding semaphore. Keep it.
return op;
} else {
Expand All @@ -164,7 +169,7 @@ class GenerateProducerBody : public NoOpCollapsingMutator {
internal_assert(var);
if (is_no_op(body)) {
return body;
} else if (starts_with(var->name, func + ".folding_semaphore.")) {
} else if (is_folding_semaphore(var->name, funcs)) {
// This is a storage-folding semaphore for the func we're producing. Keep it.
return Acquire::make(op->semaphore, op->count, body);
} else {
Expand Down Expand Up @@ -194,27 +199,24 @@ class GenerateProducerBody : public NoOpCollapsingMutator {
set<string> inner_semaphores;

public:
GenerateProducerBody(const string &f, const vector<Expr> &s, map<string, string> &a)
: func(f), sema(s), cloned_acquires(a) {
GenerateProducerBody(const vector<string> &fg, map<string, string> &a)
: funcs(fg), cloned_acquires(a) {
}
};

class GenerateConsumerBody : public NoOpCollapsingMutator {
const string &func;
vector<Expr> sema;
const vector<string> &funcs;

using NoOpCollapsingMutator::visit;

Stmt visit(const ProducerConsumer *op) override {
if (op->name == func) {
auto it = std::find(funcs.begin(), funcs.end(), op->name);
if (it != funcs.end()) {
if (op->is_producer) {
// Remove the work entirely
return Evaluate::make(0);
} else {
// Synchronize on the work done by the producer before beginning consumption
Expr acquire_sema = sema.back();
sema.pop_back();
return Acquire::make(acquire_sema, 1, op);
return op;
}
} else {
return NoOpCollapsingMutator::visit(op);
Expand All @@ -223,15 +225,15 @@ class GenerateConsumerBody : public NoOpCollapsingMutator {

Stmt visit(const Allocate *op) override {
// Don't want to keep the producer's storage-folding tracker - it's dead code on the consumer side
if (starts_with(op->name, func + ".folding_semaphore.") && ends_with(op->name, ".head")) {
if (is_folding_semaphore(op->name, funcs) && ends_with(op->name, ".head")) {
return mutate(op->body);
} else {
return NoOpCollapsingMutator::visit(op);
}
}

Stmt visit(const Store *op) override {
if (starts_with(op->name, func + ".folding_semaphore.") && ends_with(op->name, ".head")) {
if (is_folding_semaphore(op->name, funcs) && ends_with(op->name, ".head")) {
return Evaluate::make(0);
} else {
return NoOpCollapsingMutator::visit(op);
Expand All @@ -243,16 +245,16 @@ class GenerateConsumerBody : public NoOpCollapsingMutator {
// Ones from folding should go to the producer side.
const Variable *var = op->semaphore.as<Variable>();
internal_assert(var);
if (starts_with(var->name, func + ".folding_semaphore.")) {
if (is_folding_semaphore(var->name, funcs)) {
return mutate(op->body);
} else {
return NoOpCollapsingMutator::visit(op);
}
}

public:
GenerateConsumerBody(const string &f, const vector<Expr> &s)
: func(f), sema(s) {
GenerateConsumerBody(const vector<string> &fg)
: funcs(fg) {
}
};

Expand Down Expand Up @@ -304,38 +306,114 @@ class CountConsumeNodes : public IRVisitor {
int count = 0;
};

class InsertProducerSemaphores : public IRMutator {
const string &func;
vector<Expr> sema;

using IRMutator::visit;

Stmt visit(const ProducerConsumer *op) override {
if (op->name == func && op->is_producer) {
// Add post-synchronization
Stmt body = op->body;
while (!sema.empty()) {
Expr release = Call::make(Int(32), "halide_semaphore_release", {sema.back(), 1}, Call::Extern);
body = Block::make(body, Evaluate::make(release));
sema.pop_back();
}
return ProducerConsumer::make_produce(op->name, body);
}
return IRMutator::visit(op);
}

public:
InsertProducerSemaphores(const string &f, const vector<Expr> &s)
: func(f), sema(s) {
}
};

class InsertConsumerSemaphores : public IRMutator {
const string &func;
vector<Expr> sema;

using IRMutator::visit;

Stmt visit(const ProducerConsumer *op) override {
if (op->name == func && !op->is_producer) {
// Synchronize on the work done by the producer before beginning consumption
Expr acquire_sema = sema.back();
sema.pop_back();
return Acquire::make(acquire_sema, 1, op);
}
return IRMutator::visit(op);
}

public:
InsertConsumerSemaphores(const string &f, const vector<Expr> &s)
: func(f), sema(s) {
}
};

class ForkAsyncProducers : public IRMutator {
using IRMutator::visit;

const map<string, Function> &env;
const std::vector<std::vector<std::string>> &fused_groups;
std::vector<size_t> fused_group_func_counter;

map<string, string> cloned_acquires;

Stmt visit(const Realize *op) override {
auto it = env.find(op->name);
internal_assert(it != env.end());
Function f = it->second;
if (f.schedule().async()) {

size_t fused_group_index;
for (fused_group_index = 0; fused_group_index < fused_groups.size(); fused_group_index++) {
const auto &fused_group = fused_groups[fused_group_index];
auto func_it = std::find(fused_group.begin(), fused_group.end(), op->name);
if (func_it != fused_group.end()) {
break;
}
}
internal_assert(fused_group_index < fused_groups.size());
// We want to make sure that for fused function group transformation is
// applied to the inner Realize node.
fused_group_func_counter[fused_group_index]++;
const auto &current_group = fused_groups[fused_group_index];

Stmt mutated;
if (f.schedule().async() &&
fused_group_func_counter[fused_group_index] == current_group.size()) {
Stmt body = op->body;

// Make two copies of the body, one which only does the
// producer, and one which only does the consumer. Inject
// synchronization to preserve dependencies. Put them in a
// task-parallel block.

vector<vector<string>> sema_names(current_group.size());
vector<vector<Expr>> sema_vars(current_group.size());

// Make a semaphore per consume node
CountConsumeNodes consumes(op->name);
body.accept(&consumes);

vector<string> sema_names;
vector<Expr> sema_vars;
for (int i = 0; i < consumes.count; i++) {
sema_names.push_back(op->name + ".semaphore_" + std::to_string(i));
sema_vars.push_back(Variable::make(Handle(), sema_names.back()));
for (size_t ix = 0; ix < current_group.size(); ix++) {
CountConsumeNodes consumes(current_group[ix]);
body.accept(&consumes);

for (int i = 0; i < consumes.count; i++) {
sema_names[ix].push_back(current_group[ix] + ".semaphore_" + std::to_string(i));
sema_vars[ix].push_back(Variable::make(Handle(), sema_names[ix].back()));
}
}

Stmt producer = GenerateProducerBody(op->name, sema_vars, cloned_acquires).mutate(body);
Stmt consumer = GenerateConsumerBody(op->name, sema_vars).mutate(body);
Stmt producer = GenerateProducerBody(current_group, cloned_acquires).mutate(body);
Stmt consumer = GenerateConsumerBody(current_group).mutate(body);

// Insert producer/consumer semaphores.
for (size_t ix = 0; ix < current_group.size(); ix++) {
producer = InsertProducerSemaphores(current_group[ix], sema_vars[ix]).mutate(producer);
consumer = InsertConsumerSemaphores(current_group[ix], sema_vars[ix]).mutate(consumer);
}

// Recurse on both sides
producer = mutate(producer);
Expand All @@ -344,33 +422,38 @@ class ForkAsyncProducers : public IRMutator {
// Run them concurrently
body = Fork::make(producer, consumer);

for (const string &sema_name : sema_names) {
for (const auto &names : sema_names) {
// Make a semaphore on the stack
Expr sema_space = Call::make(type_of<halide_semaphore_t *>(), "halide_make_semaphore",
{0}, Call::Extern);

// If there's a nested async producer, we may have
// recursively cloned this semaphore inside the mutation
// of the producer and consumer.
auto it = cloned_acquires.find(sema_name);
if (it != cloned_acquires.end()) {
body = CloneAcquire(sema_name, it->second).mutate(body);
body = LetStmt::make(it->second, sema_space, body);
for (const string &sema_name : names) {

// If there's a nested async producer, we may have
// recursively cloned this semaphore inside the mutation
// of the producer and consumer.
auto it = cloned_acquires.find(sema_name);
if (it != cloned_acquires.end()) {
body = CloneAcquire(sema_name, it->second).mutate(body);
body = LetStmt::make(it->second, sema_space, body);
}

body = LetStmt::make(sema_name, sema_space, body);
}

body = LetStmt::make(sema_name, sema_space, body);
}

return Realize::make(op->name, op->types, op->memory_type,
op->bounds, op->condition, body);
mutated = Realize::make(op->name, op->types, op->memory_type,
op->bounds, op->condition, body);
} else {
return IRMutator::visit(op);
mutated = IRMutator::visit(op);
}

fused_group_func_counter[fused_group_index]--;
return mutated;
}

public:
ForkAsyncProducers(const map<string, Function> &e)
: env(e) {
ForkAsyncProducers(const map<string, Function> &e, const std::vector<std::vector<std::string>> &fg)
: env(e), fused_groups(fg), fused_group_func_counter(fg.size()) {
}
};

Expand Down Expand Up @@ -651,9 +734,10 @@ class TightenForkNodes : public IRMutator {

// TODO: merge semaphores?

Stmt fork_async_producers(Stmt s, const map<string, Function> &env) {
Stmt fork_async_producers(Stmt s, const map<string, Function> &env,
const std::vector<std::vector<std::string>> &fused_groups) {
s = TightenProducerConsumerNodes(env).mutate(s);
s = ForkAsyncProducers(env).mutate(s);
s = ForkAsyncProducers(env, fused_groups).mutate(s);
s = ExpandAcquireNodes().mutate(s);
s = TightenForkNodes().mutate(s);
s = InitializeSemaphores().mutate(s);
Expand Down
3 changes: 2 additions & 1 deletion src/AsyncProducers.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ namespace Internal {

class Function;

Stmt fork_async_producers(Stmt s, const std::map<std::string, Function> &env);
Stmt fork_async_producers(Stmt s, const std::map<std::string, Function> &env,
const std::vector<std::vector<std::string>> &fused_groups);

} // namespace Internal
} // namespace Halide
Expand Down
2 changes: 1 addition & 1 deletion src/Lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ Module lower(const vector<Function> &output_funcs,
<< s << "\n\n";

debug(1) << "Forking asynchronous producers...\n";
s = fork_async_producers(s, env);
s = fork_async_producers(s, env, fused_groups);
debug(2) << "Lowering after forking asynchronous producers:\n"
<< s << "\n";

Expand Down
5 changes: 5 additions & 0 deletions src/ScheduleFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2179,6 +2179,11 @@ void validate_fused_group_schedule_helper(const string &fn,
<< ") and " << p.func_2 << ".s" << p.stage_2 << " ("
<< func_2.schedule().compute_level().to_string() << ") do not match.\n";

// Verify that they have matching async flags.
user_assert(func_1.schedule().async() == func_2.schedule().async())
<< "Invalid compute_with: functions " << func_1.name()
<< " and " << func_2.name() << " have different async flags.\n";

const vector<Dim> &dims_1 = def_1.schedule().dims();
const vector<Dim> &dims_2 = def_2.schedule().dims();

Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ tests(GROUPS correctness
argmax.cpp
assertion_failure_in_parallel_for.cpp
async.cpp
async_compute_with.cpp
async_copy_chain.cpp
async_device_copy.cpp
atomic_tuples.cpp
Expand Down
Loading