diff --git a/src/AsyncProducers.cpp b/src/AsyncProducers.cpp index ff7356409c58..9d7251865bb1 100644 --- a/src/AsyncProducers.cpp +++ b/src/AsyncProducers.cpp @@ -14,6 +14,19 @@ using std::set; using std::string; using std::vector; +namespace { + +bool is_folding_semaphore(const string &sema_name, const vector &funcs) { + for (const auto &f : funcs) { + if (starts_with(sema_name, f + ".folding_semaphore.")) { + return true; + } + } + return false; +} + +} // namespace + /** A mutator which eagerly folds no-op stmts */ class NoOpCollapsingMutator : public IRMutator { protected: @@ -105,23 +118,15 @@ class NoOpCollapsingMutator : public IRMutator { }; class GenerateProducerBody : public NoOpCollapsingMutator { - const string &func; - vector sema; + const vector funcs; using NoOpCollapsingMutator::visit; // Preserve produce nodes and add synchronization Stmt visit(const ProducerConsumer *op) override { - if (op->name == func && op->is_producer) { - // Add post-synchronization - internal_assert(!sema.empty()) << "Duplicate produce node: " << op->name << "\n"; - Stmt body = op->body; - while (!sema.empty()) { - Expr release = Call::make(Int(32), "halide_semaphore_release", {sema.back(), 1}, Call::Extern); - body = Block::make(body, Evaluate::make(release)); - sema.pop_back(); - } - return ProducerConsumer::make_produce(op->name, body); + auto it = std::find(funcs.begin(), funcs.end(), op->name); + if ((it != funcs.end()) && op->is_producer) { + return op; } else { Stmt body = mutate(op->body); if (is_no_op(body) || op->is_producer) { @@ -142,7 +147,7 @@ class GenerateProducerBody : public NoOpCollapsingMutator { } Stmt visit(const Store *op) override { - if (starts_with(op->name, func + ".folding_semaphore.") && ends_with(op->name, ".head")) { + if (is_folding_semaphore(op->name, funcs) && ends_with(op->name, ".head")) { // This is a counter associated with the producer side of a storage-folding semaphore. Keep it. 
return op; } else { @@ -164,7 +169,7 @@ class GenerateProducerBody : public NoOpCollapsingMutator { internal_assert(var); if (is_no_op(body)) { return body; - } else if (starts_with(var->name, func + ".folding_semaphore.")) { + } else if (is_folding_semaphore(var->name, funcs)) { // This is a storage-folding semaphore for the func we're producing. Keep it. return Acquire::make(op->semaphore, op->count, body); } else { @@ -194,27 +199,24 @@ class GenerateProducerBody : public NoOpCollapsingMutator { set inner_semaphores; public: - GenerateProducerBody(const string &f, const vector &s, map &a) - : func(f), sema(s), cloned_acquires(a) { + GenerateProducerBody(const vector &fg, map &a) + : funcs(fg), cloned_acquires(a) { } }; class GenerateConsumerBody : public NoOpCollapsingMutator { - const string &func; - vector sema; + const vector &funcs; using NoOpCollapsingMutator::visit; Stmt visit(const ProducerConsumer *op) override { - if (op->name == func) { + auto it = std::find(funcs.begin(), funcs.end(), op->name); + if (it != funcs.end()) { if (op->is_producer) { // Remove the work entirely return Evaluate::make(0); } else { - // Synchronize on the work done by the producer before beginning consumption - Expr acquire_sema = sema.back(); - sema.pop_back(); - return Acquire::make(acquire_sema, 1, op); + return op; } } else { return NoOpCollapsingMutator::visit(op); @@ -223,7 +225,7 @@ class GenerateConsumerBody : public NoOpCollapsingMutator { Stmt visit(const Allocate *op) override { // Don't want to keep the producer's storage-folding tracker - it's dead code on the consumer side - if (starts_with(op->name, func + ".folding_semaphore.") && ends_with(op->name, ".head")) { + if (is_folding_semaphore(op->name, funcs) && ends_with(op->name, ".head")) { return mutate(op->body); } else { return NoOpCollapsingMutator::visit(op); @@ -231,7 +233,7 @@ class GenerateConsumerBody : public NoOpCollapsingMutator { } Stmt visit(const Store *op) override { - if (starts_with(op->name, 
func + ".folding_semaphore.") && ends_with(op->name, ".head")) { + if (is_folding_semaphore(op->name, funcs) && ends_with(op->name, ".head")) { return Evaluate::make(0); } else { return NoOpCollapsingMutator::visit(op); @@ -243,7 +245,7 @@ class GenerateConsumerBody : public NoOpCollapsingMutator { // Ones from folding should go to the producer side. const Variable *var = op->semaphore.as(); internal_assert(var); - if (starts_with(var->name, func + ".folding_semaphore.")) { + if (is_folding_semaphore(var->name, funcs)) { return mutate(op->body); } else { return NoOpCollapsingMutator::visit(op); @@ -251,8 +253,8 @@ class GenerateConsumerBody : public NoOpCollapsingMutator { } public: - GenerateConsumerBody(const string &f, const vector &s) - : func(f), sema(s) { + GenerateConsumerBody(const vector &fg) + : funcs(fg) { } }; @@ -304,10 +306,60 @@ class CountConsumeNodes : public IRVisitor { int count = 0; }; +class InsertProducerSemaphores : public IRMutator { + const string &func; + vector sema; + + using IRMutator::visit; + + Stmt visit(const ProducerConsumer *op) override { + if (op->name == func && op->is_producer) { + // Add post-synchronization + Stmt body = op->body; + while (!sema.empty()) { + Expr release = Call::make(Int(32), "halide_semaphore_release", {sema.back(), 1}, Call::Extern); + body = Block::make(body, Evaluate::make(release)); + sema.pop_back(); + } + return ProducerConsumer::make_produce(op->name, body); + } + return IRMutator::visit(op); + } + +public: + InsertProducerSemaphores(const string &f, const vector &s) + : func(f), sema(s) { + } +}; + +class InsertConsumerSemaphores : public IRMutator { + const string &func; + vector sema; + + using IRMutator::visit; + + Stmt visit(const ProducerConsumer *op) override { + if (op->name == func && !op->is_producer) { + // Synchronize on the work done by the producer before beginning consumption + Expr acquire_sema = sema.back(); + sema.pop_back(); + return Acquire::make(acquire_sema, 1, op); + } + return 
IRMutator::visit(op); + } + +public: + InsertConsumerSemaphores(const string &f, const vector &s) + : func(f), sema(s) { + } +}; + class ForkAsyncProducers : public IRMutator { using IRMutator::visit; const map &env; + const std::vector> &fused_groups; + std::vector fused_group_func_counter; map cloned_acquires; @@ -315,7 +367,24 @@ class ForkAsyncProducers : public IRMutator { auto it = env.find(op->name); internal_assert(it != env.end()); Function f = it->second; - if (f.schedule().async()) { + + size_t fused_group_index; + for (fused_group_index = 0; fused_group_index < fused_groups.size(); fused_group_index++) { + const auto &fused_group = fused_groups[fused_group_index]; + auto func_it = std::find(fused_group.begin(), fused_group.end(), op->name); + if (func_it != fused_group.end()) { + break; + } + } + internal_assert(fused_group_index < fused_groups.size()); + // We want to make sure that for fused function group transformation is + // applied to the inner Realize node. + fused_group_func_counter[fused_group_index]++; + const auto &current_group = fused_groups[fused_group_index]; + + Stmt mutated; + if (f.schedule().async() && + fused_group_func_counter[fused_group_index] == current_group.size()) { Stmt body = op->body; // Make two copies of the body, one which only does the @@ -323,19 +392,28 @@ class ForkAsyncProducers : public IRMutator { // synchronization to preserve dependencies. Put them in a // task-parallel block.
+ vector> sema_names(current_group.size()); + vector> sema_vars(current_group.size()); + // Make a semaphore per consume node - CountConsumeNodes consumes(op->name); - body.accept(&consumes); - - vector sema_names; - vector sema_vars; - for (int i = 0; i < consumes.count; i++) { - sema_names.push_back(op->name + ".semaphore_" + std::to_string(i)); - sema_vars.push_back(Variable::make(Handle(), sema_names.back())); + for (size_t ix = 0; ix < current_group.size(); ix++) { + CountConsumeNodes consumes(current_group[ix]); + body.accept(&consumes); + + for (int i = 0; i < consumes.count; i++) { + sema_names[ix].push_back(current_group[ix] + ".semaphore_" + std::to_string(i)); + sema_vars[ix].push_back(Variable::make(Handle(), sema_names[ix].back())); + } } - Stmt producer = GenerateProducerBody(op->name, sema_vars, cloned_acquires).mutate(body); - Stmt consumer = GenerateConsumerBody(op->name, sema_vars).mutate(body); + Stmt producer = GenerateProducerBody(current_group, cloned_acquires).mutate(body); + Stmt consumer = GenerateConsumerBody(current_group).mutate(body); + + // Insert producer/consumer semaphores. + for (size_t ix = 0; ix < current_group.size(); ix++) { + producer = InsertProducerSemaphores(current_group[ix], sema_vars[ix]).mutate(producer); + consumer = InsertConsumerSemaphores(current_group[ix], sema_vars[ix]).mutate(consumer); + } // Recurse on both sides producer = mutate(producer); @@ -344,33 +422,38 @@ class ForkAsyncProducers : public IRMutator { // Run them concurrently body = Fork::make(producer, consumer); - for (const string &sema_name : sema_names) { + for (const auto &names : sema_names) { // Make a semaphore on the stack Expr sema_space = Call::make(type_of(), "halide_make_semaphore", {0}, Call::Extern); - - // If there's a nested async producer, we may have - // recursively cloned this semaphore inside the mutation - // of the producer and consumer. 
- auto it = cloned_acquires.find(sema_name); - if (it != cloned_acquires.end()) { - body = CloneAcquire(sema_name, it->second).mutate(body); - body = LetStmt::make(it->second, sema_space, body); + for (const string &sema_name : names) { + + // If there's a nested async producer, we may have + // recursively cloned this semaphore inside the mutation + // of the producer and consumer. + auto it = cloned_acquires.find(sema_name); + if (it != cloned_acquires.end()) { + body = CloneAcquire(sema_name, it->second).mutate(body); + body = LetStmt::make(it->second, sema_space, body); + } + + body = LetStmt::make(sema_name, sema_space, body); } - - body = LetStmt::make(sema_name, sema_space, body); } - return Realize::make(op->name, op->types, op->memory_type, - op->bounds, op->condition, body); + mutated = Realize::make(op->name, op->types, op->memory_type, + op->bounds, op->condition, body); } else { - return IRMutator::visit(op); + mutated = IRMutator::visit(op); } + + fused_group_func_counter[fused_group_index]--; + return mutated; } public: - ForkAsyncProducers(const map &e) - : env(e) { + ForkAsyncProducers(const map &e, const std::vector> &fg) + : env(e), fused_groups(fg), fused_group_func_counter(fg.size()) { } }; @@ -651,9 +734,10 @@ class TightenForkNodes : public IRMutator { // TODO: merge semaphores? 
-Stmt fork_async_producers(Stmt s, const map &env) { +Stmt fork_async_producers(Stmt s, const map &env, + const std::vector> &fused_groups) { s = TightenProducerConsumerNodes(env).mutate(s); - s = ForkAsyncProducers(env).mutate(s); + s = ForkAsyncProducers(env, fused_groups).mutate(s); s = ExpandAcquireNodes().mutate(s); s = TightenForkNodes().mutate(s); s = InitializeSemaphores().mutate(s); diff --git a/src/AsyncProducers.h b/src/AsyncProducers.h index dffc331c60ae..016cbf146fa1 100644 --- a/src/AsyncProducers.h +++ b/src/AsyncProducers.h @@ -14,7 +14,8 @@ namespace Internal { class Function; -Stmt fork_async_producers(Stmt s, const std::map &env); +Stmt fork_async_producers(Stmt s, const std::map &env, + const std::vector> &fused_groups); } // namespace Internal } // namespace Halide diff --git a/src/Lower.cpp b/src/Lower.cpp index dc6d7f4174de..35700db2f037 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -245,7 +245,7 @@ Module lower(const vector &output_funcs, << s << "\n\n"; debug(1) << "Forking asynchronous producers...\n"; - s = fork_async_producers(s, env); + s = fork_async_producers(s, env, fused_groups); debug(2) << "Lowering after forking asynchronous producers:\n" << s << "\n"; diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index e67dadbd8272..f91e967c5570 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -2179,6 +2179,11 @@ void validate_fused_group_schedule_helper(const string &fn, << ") and " << p.func_2 << ".s" << p.stage_2 << " (" << func_2.schedule().compute_level().to_string() << ") do not match.\n"; + // Verify that they have matching async flags. 
+ user_assert(func_1.schedule().async() == func_2.schedule().async()) + << "Invalid compute_with: functions " << func_1.name() + << " and " << func_2.name() << " have different async flags.\n"; + const vector &dims_1 = def_1.schedule().dims(); const vector &dims_2 = def_2.schedule().dims(); diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 17a4d4401aa0..bea0e22bd5bb 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -9,6 +9,7 @@ tests(GROUPS correctness argmax.cpp assertion_failure_in_parallel_for.cpp async.cpp + async_compute_with.cpp async_copy_chain.cpp async_device_copy.cpp atomic_tuples.cpp diff --git a/test/correctness/async_compute_with.cpp b/test/correctness/async_compute_with.cpp new file mode 100644 index 000000000000..997f06e928a3 --- /dev/null +++ b/test/correctness/async_compute_with.cpp @@ -0,0 +1,152 @@ +#include "Halide.h" + +using namespace Halide; + +int main(int argc, char **argv) { + // Two producers scheduled as async and two separate consumers. + { + Func producer1, producer2, consumer, consumer1, consumer2; + Var x, y; + + producer1(x, y) = x + y; + producer2(x, y) = 3 * x + 2 * y; + consumer1(x, y) = producer2(x, y); + consumer2(x, y) = producer1(x, y); + consumer(x, y) = consumer1(x, y) + consumer2(x, y); + + consumer.compute_root(); + consumer1.compute_root(); + consumer2.compute_root(); + + producer1.compute_root().async(); + producer2.compute_root().compute_with(producer1, Var::outermost()).async(); + + consumer.bound(x, 0, 16).bound(y, 0, 16); + + Buffer out = consumer.realize(16, 16); + + out.for_each_element([&](int x, int y) { + int correct = 4 * x + 3 * y; + if (out(x, y) != correct) { + printf("out(%d, %d) = %d instead of %d\n", + x, y, out(x, y), correct); + exit(-1); + } + }); + } + + // Two producers scheduled as async and one consumer.
+ { + Func producer1, producer2, producer3, consumer, consumer1, consumer2; + Var x, y; + + producer1(x, y) = x + y; + producer2(x, y) = 3 * x + 2 * y; + consumer(x, y) = producer1(x, y - 3) + producer1(x, y + 3) + producer2(x, y - 1) + producer2(x, y + 1); + consumer.compute_root(); + producer1.compute_at(consumer, y).store_root().async(); + producer2.compute_at(consumer, y).store_root().compute_with(producer1, y).async(); + + consumer.bound(x, 0, 16).bound(y, 0, 16); + + Buffer out = consumer.realize(16, 16); + + out.for_each_element([&](int x, int y) { + int correct = 8 * x + 6 * y; + if (out(x, y) != correct) { + printf("out(%d, %d) = %d instead of %d\n", + x, y, out(x, y), correct); + exit(-1); + } + }); + } + // Two fused producers + one producer scheduled as async and one consumer. + { + Func producer1, producer2, producer3, consumer; + Var x, y; + + producer1(x, y) = x + y; + producer2(x, y) = 3 * x + 2 * y; + producer3(x, y) = x + y; + consumer(x, y) = producer1(x, y - 1) + producer1(x, y + 1) + producer2(x, y - 1) + producer2(x, y + 1) + producer3(x, y); + consumer.compute_root(); + producer1.compute_at(consumer, y).store_root().async(); + producer2.compute_at(consumer, y).store_root().compute_with(producer1, y).async(); + producer3.compute_at(consumer, y).store_root().async(); + + consumer.bound(x, 0, 16).bound(y, 0, 16); + + Buffer out = consumer.realize(16, 16); + + out.for_each_element([&](int x, int y) { + int correct = 9 * x + 7 * y; + if (out(x, y) != correct) { + printf("out(%d, %d) = %d instead of %d\n", + x, y, out(x, y), correct); + exit(-1); + } + }); + } + + // Two producers scheduled as async + one producer and one consumer.
+ { + Func producer1, producer2, producer3, consumer; + Var x, y; + + producer1(x, y) = x + y; + producer2(x, y) = 3 * x + 2 * y; + producer3(x, y) = x + y; + consumer(x, y) = producer1(x, y - 1) + producer1(x, y + 1) + producer2(x, y - 1) + producer2(x, y + 1) + producer3(x, y); + consumer.compute_root(); + producer1.compute_at(consumer, y).store_root().async(); + producer2.compute_at(consumer, y).store_root().compute_with(producer1, y).async(); + // producer3 is not async. + producer3.compute_at(consumer, y).store_root(); + + consumer.bound(x, 0, 16).bound(y, 0, 16); + + Buffer out = consumer.realize(16, 16); + + out.for_each_element([&](int x, int y) { + int correct = 9 * x + 7 * y; + if (out(x, y) != correct) { + printf("out(%d, %d) = %d instead of %d\n", + x, y, out(x, y), correct); + exit(-1); + } + }); + } + + // Two producers scheduled as async and two separate consumers. + { + Func producer1, producer2, producer3, consumer, consumer1, consumer2; + Var x, y; + + producer1(x, y) = x + y; + producer2(x, y) = 3 * x + 2 * y; + consumer1(x, y) = 2 * producer1(x, y) + producer2(x, y); + consumer2(x, y) = producer1(x, y) + 2 * producer2(x, y); + consumer(x, y) = consumer1(x, y) + consumer2(x, y); + consumer.compute_root(); + consumer1.compute_root(); + consumer2.compute_root(); + producer1.compute_root().async(); + producer2.compute_root().compute_with(producer1, Var::outermost()).async(); + + consumer.bound(x, 0, 16).bound(y, 0, 16); + + Buffer out = consumer.realize(16, 16); + + out.for_each_element([&](int x, int y) { + int correct = 12 * x + 9 * y; + if (out(x, y) != correct) { + printf("out(%d, %d) = %d instead of %d\n", + x, y, out(x, y), correct); + exit(-1); + } + }); + } + + printf("Success!\n"); + return 0; +} diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index f92466bc34d4..4987207f0ba2 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -2,6 +2,7 @@ tests(GROUPS error EXPECT_FAILURE SOURCES 
ambiguous_inline_reductions.cpp + async_compute_with.cpp async_require_fail.cpp atomics_gpu_8_bit.cpp atomics_gpu_mutex.cpp diff --git a/test/error/async_compute_with.cpp b/test/error/async_compute_with.cpp new file mode 100644 index 000000000000..9b61793732fb --- /dev/null +++ b/test/error/async_compute_with.cpp @@ -0,0 +1,24 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func producer1, producer2, consumer; + Var x, y; + + producer1(x, y) = x + y; + producer2(x, y) = 3 * x + 2 * y; + consumer(x, y) = producer1(x, y - 1) + producer1(x, y + 1) + producer2(x, y - 1) + producer2(x, y + 1); + consumer.compute_root(); + // Both functions should have been scheduled as async. + producer1.compute_at(consumer, y).store_root(); + producer2.compute_at(consumer, y).store_root().compute_with(producer1, y).async(); + + consumer.bound(x, 0, 16).bound(y, 0, 16); + + Buffer out = consumer.realize(16, 16); + + printf("Success!\n"); + return 0; +}