Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions src/AsyncProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -649,9 +649,65 @@ class TightenForkNodes : public IRMutator {
bool in_fork = false;
};

// If in realize node A there is a producer node B and producer A is async and is
// consumer of B then B must be scheduled as async.
class CheckAsyncOrder : public IRVisitor {
const map<string, Function> &env;

Scope<> in_realizes;
Scope<> in_consumers;
map<string, set<string>> realizes_around_producers;

using IRVisitor::visit;

void visit(const ProducerConsumer *op) override {
if (op->is_producer) {
for (auto r = in_realizes.cbegin(); r != in_realizes.cend(); ++r) {
realizes_around_producers[op->name].insert(r.name());
}
auto it = env.find(op->name);
internal_assert(it != env.end());
Function f = it->second;
if (f.schedule().async()) {
for (auto r = in_realizes.cbegin(); r != in_realizes.cend(); ++r) {
if (in_consumers.contains(r.name()) && realizes_around_producers[r.name()].count(op->name)) {
auto other_it = env.find(r.name());
internal_assert(other_it != env.end());
Function other_f = other_it->second;
user_assert(other_f.schedule().async()) << "Invalid async: producer " << op->name
<< " is a consumer of " << r.name() << ", but " << r.name() << " is inside of the "
<< op->name << " Realize node, "
<< "which is scheduled as async(), so " << r.name()
<< " must be scheduled as async() too.";
}
}
}
} else {
in_consumers.push(op->name);
}
IRVisitor::visit(op);
if (op->is_producer) {
} else {
in_consumers.pop(op->name);
}
}

void visit(const Realize *op) override {
ScopedBinding<> bind(in_realizes, op->name);
IRVisitor::visit(op);
}

public:
CheckAsyncOrder(const map<string, Function> &e)
: env(e) {
}
};

// TODO: merge semaphores?

Stmt fork_async_producers(Stmt s, const map<string, Function> &env) {
CheckAsyncOrder check(env);
s.accept(&check);
s = TightenProducerConsumerNodes(env).mutate(s);
s = ForkAsyncProducers(env).mutate(s);
s = ExpandAcquireNodes().mutate(s);
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ tests(GROUPS correctness
async.cpp
async_copy_chain.cpp
async_device_copy.cpp
async_order.cpp
atomic_tuples.cpp
atomics.cpp
autodiff.cpp
Expand Down
87 changes: 87 additions & 0 deletions test/correctness/async_order.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

int main(int argc, char **argv) {
{
Func producer1, producer2, consumer;
Var x, y;

producer1(x, y) = x + y;
producer2(x, y) = producer1(x, y);
consumer(x, y) = producer1(x, y - 1) + producer2(x, y + 1);

consumer.compute_root();

producer1.compute_at(consumer, y);
producer2.compute_at(consumer, y).async();

consumer.bound(x, 0, 16).bound(y, 0, 16);

Buffer<int> out = consumer.realize(16, 16);

out.for_each_element([&](int x, int y) {
int correct = 2 * (x + y);
if (out(x, y) != correct) {
printf("out(%d, %d) = %d instead of %d\n",
x, y, out(x, y), correct);
exit(-1);
}
});
}
{
Func producer1, producer2, consumer;
Var x, y;

producer1(x, y) = x + y;
producer2(x, y) = producer1(x, y);
consumer(x, y) = producer1(x, y - 1) + producer2(x, y + 1);

consumer.compute_root();

producer1.compute_root();
producer2.store_root().compute_at(consumer, y).async();
consumer.bound(x, 0, 16).bound(y, 0, 16);

Buffer<int> out = consumer.realize(16, 16);

out.for_each_element([&](int x, int y) {
int correct = 2 * (x + y);
if (out(x, y) != correct) {
printf("out(%d, %d) = %d instead of %d\n",
x, y, out(x, y), correct);
exit(-1);
}
});
}

{
Func producer1, producer2, consumer;
Var x, y;

producer1(x, y) = x + y;
producer2(x, y) = producer1(x, y);
consumer(x, y) = producer1(x, y - 1) + producer2(x, y + 1);

consumer.compute_root();

producer1.store_root().compute_at(consumer, y).async();
producer2.store_root().compute_at(consumer, y).async();
consumer.bound(x, 0, 16).bound(y, 0, 16);

Buffer<int> out = consumer.realize(16, 16);

out.for_each_element([&](int x, int y) {
int correct = 2 * (x + y);
if (out(x, y) != correct) {
printf("out(%d, %d) = %d instead of %d\n",
x, y, out(x, y), correct);
exit(-1);
}
});
}

printf("Success!\n");
return 0;
}
1 change: 1 addition & 0 deletions test/error/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ tests(GROUPS error
EXPECT_FAILURE
SOURCES
ambiguous_inline_reductions.cpp
async_order.cpp
async_require_fail.cpp
atomics_gpu_8_bit.cpp
atomics_gpu_mutex.cpp
Expand Down
24 changes: 24 additions & 0 deletions test/error/async_order.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

int main(int argc, char **argv) {
Func producer1, producer2, consumer;
Var x, y;

producer1(x, y) = x + y;
producer2(x, y) = producer1(x, y);
consumer(x, y) = producer1(x, y - 1) + producer2(x, y + 1);

consumer.compute_root();

producer1.store_root().compute_at(consumer, y);
producer2.store_root().compute_at(consumer, y).async();

consumer.bound(x, 0, 16).bound(y, 0, 16);

Buffer<int> out = consumer.realize(16, 16);

return 0;
}