1 change: 1 addition & 0 deletions src/driver/driver_api.cc
@@ -51,6 +51,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope", Bool);

using runtime::PackedFunc;
using runtime::TVMArgs;
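Note: the new `tir.merge_async_commit_queue_scope` option registered above is read by `InjectSoftwarePipeline` through `PassContext` (see the pass change below). A minimal usage sketch from the Python side, assuming `sch` is a `tir.Schedule` whose pipelined loop already carries the `software_pipeline_*` annotations; the target string is illustrative:

```python
# Minimal sketch, assuming `sch` is an already-annotated tir.Schedule.
import tvm

with tvm.transform.PassContext(
    config={
        "tir.use_async_copy": 1,
        # New option from this diff; it defaults to True, which preserves
        # the previous behavior of merging commits into a single scope.
        "tir.merge_async_commit_queue_scope": False,
    }
):
    func = tvm.build(sch.mod["main"], target="hexagon")  # illustrative target
```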
51 changes: 34 additions & 17 deletions src/tir/transforms/inject_software_pipeline.cc
@@ -309,9 +309,10 @@ class PipelineRewriter : public StmtExprMutator {
const Array<Buffer> pipeline_allocs, const For& pipeline_loop,
const PipelineInfo& pipeline_info,
const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info,
const Map<String, ObjectRef> preserved_annotations) {
const Map<String, ObjectRef> preserved_annotations, bool merge_async_commit_queue_scope) {
PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers, pipeline_allocs, pipeline_loop,
pipeline_info, fragment_info, preserved_annotations);
pipeline_info, fragment_info, preserved_annotations,
merge_async_commit_queue_scope);
return rewriter.BuildPipeline();
}

@@ -321,15 +322,17 @@ class PipelineRewriter : public StmtExprMutator {
const Array<Buffer>& pipeline_allocs, const For& pipeline_loop,
const PipelineInfo& pipeline_info,
const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info,
const Map<String, ObjectRef> preserved_annotations)
const Map<String, ObjectRef> preserved_annotations,
bool merge_async_commit_queue_scope)

: buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
double_buffers_(double_buffers),
pipeline_allocs_(pipeline_allocs),
pipeline_loop_(pipeline_loop),
pipeline_info_(pipeline_info),
fragment_info_(fragment_info),
preserved_annotations_(preserved_annotations) {}
preserved_annotations_(preserved_annotations),
merge_async_commit_queue_scope_(merge_async_commit_queue_scope) {}

Stmt BuildPipeline() {
// Step 1: Analyze accesses to the buffers in the pipeline and compute the number of versions
@@ -762,11 +765,19 @@ class PipelineRewriter : public StmtExprMutator {
<< "Predicates in the same stage are expected to be identical";
group_bodies.push_back(new_blocks[i].block->body);
}
auto body = group_bodies.size() > 1 ? SeqStmt(group_bodies) : group_bodies[0];
auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
tir::attr::async_commit_queue_scope, stage_id, body);
auto new_block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
stmts.push_back(BlockRealize({}, predicate, new_block));

if (merge_async_commit_queue_scope_ && group_bodies.size() > 1) {
auto merged_bodies = SeqStmt(group_bodies);
group_bodies.clear();
group_bodies.push_back(merged_bodies);
}

for (auto body : group_bodies) {
auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
tir::attr::async_commit_queue_scope, stage_id, body);
auto new_block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
stmts.push_back(BlockRealize({}, predicate, new_block));
}
}
}
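In effect, when `merge_async_commit_queue_scope_` is enabled (the default) and a stage produced several async bodies, they are fused into one `SeqStmt` so a single `async_commit_queue_scope` attribute covers the whole group; when disabled, each body is wrapped in its own commit scope. A small Python mirror of that branch, for illustration only (the real pass builds `SeqStmt`/`AttrStmt` TIR nodes, not strings):

```python
# Python mirror of the C++ grouping logic above (illustrative only).
from typing import Callable, List


def wrap_commit_scopes(
    group_bodies: List[str],
    merge_async_commit_queue_scope: bool,
    commit: Callable[[str], str],
) -> List[str]:
    # Merging fuses all bodies of the stage into one sequence, so a single
    # async_commit_queue_scope covers the whole group.
    if merge_async_commit_queue_scope and len(group_bodies) > 1:
        group_bodies = ["; ".join(group_bodies)]
    # Otherwise each body is committed under its own queue scope.
    return [commit(body) for body in group_bodies]


print(wrap_commit_scopes(["copy_a", "copy_b"], True, "commit({})".format))
# ['commit(copy_a; copy_b)']
print(wrap_commit_scopes(["copy_a", "copy_b"], False, "commit({})".format))
# ['commit(copy_a)', 'commit(copy_b)']
```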

@@ -842,7 +853,8 @@ class PipelineRewriter : public StmtExprMutator {
auto& local_state = async_states_local[stage];

int commit_group_id = -1;
if (local_state.commit_groups.empty() || local_state.consumed) {
if (local_state.commit_groups.empty() || local_state.consumed ||
!merge_async_commit_queue_scope_) {
// consumed == true means there is already a consumer stage waiting for an
// earlier async operation of this stage. In such cases, we make multiple commit_queue
// for this stage.
@@ -942,6 +954,7 @@
Array<Block> ordered_stmts_;
std::map<int, AsyncStateGlobal> async_states;
Map<String, ObjectRef> preserved_annotations_;
bool merge_async_commit_queue_scope_ = true;
};

/*!
@@ -980,8 +993,8 @@ void BuildDependencyGraph(

class PipelineInjector : private StmtExprMutator {
public:
static Stmt Inject(const PrimFunc& func) {
PipelineInjector injector;
static Stmt Inject(const PrimFunc& func, bool merge_async_commit_queue_scope) {
PipelineInjector injector(merge_async_commit_queue_scope);
for (const auto& kv : func->buffer_map) {
const Buffer& buffer = kv.second;
injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
@@ -991,7 +1004,8 @@ class PipelineInjector : private StmtExprMutator {
}

private:
PipelineInjector() = default;
explicit PipelineInjector(bool merge_async_commit_queue_scope)
: merge_async_commit_queue_scope_(merge_async_commit_queue_scope) {}

/*!
* \brief Check the pipeline satisfies the following conditions:
@@ -1126,9 +1140,9 @@ class PipelineInjector : private StmtExprMutator {
ValidatePipelineBody(pipeline_info, original_order);

// Step 4: Rewrite the pipeline body.
Stmt pipeline = PipelineRewriter::Rewrite(buffer_data_to_buffer_, double_buffers,
pipeline_allocs, GetRef<For>(op), pipeline_info,
fragment_info_, preserved_annotations);
Stmt pipeline = PipelineRewriter::Rewrite(
buffer_data_to_buffer_, double_buffers, pipeline_allocs, GetRef<For>(op), pipeline_info,
fragment_info_, preserved_annotations, merge_async_commit_queue_scope_);

if (const auto* realize = op->body.as<BlockRealizeNode>()) {
const auto& block = realize->block;
@@ -1197,6 +1211,7 @@ class PipelineInjector : private StmtExprMutator {
Map<Var, Buffer> buffer_data_to_buffer_;
std::unordered_map<const VarNode*, FragmentInfo> fragment_info_;
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> double_buffers;
bool merge_async_commit_queue_scope_ = true;
};

} // namespace software_pipeline
@@ -1210,7 +1225,9 @@ namespace transform {
Pass InjectSoftwarePipeline() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* fptr = f.CopyOnWrite();
fptr->body = software_pipeline::PipelineInjector::Inject(f);
bool merge_async_commit_queue_scope =
ctx->GetConfig<Bool>("tir.merge_async_commit_queue_scope", Bool(true)).value();
fptr->body = software_pipeline::PipelineInjector::Inject(f, merge_async_commit_queue_scope);
fptr->body = ConvertSSA(std::move(fptr->body));
return f;
};
7 changes: 4 additions & 3 deletions src/tir/transforms/lower_async_dma.cc
@@ -94,9 +94,6 @@ class AsyncDMALowerer : public StmtExprMutator {
ICHECK(queue_id_node);
int queue_id = queue_id_node->value;

// save queue ID for inspection in `wait` transform
queue_ids.insert(queue_id);

// walk the graph to verify this is a mem copy ...
// 1) async_commit_queue_scope contains async_scope
auto async_scope = op->body.as<AttrStmtNode>();
@@ -161,6 +158,10 @@
return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap));
});

// now that we are about to perform the `copy` transform,
// save the queue ID for inspection in the `wait` transform
queue_ids.insert(queue_id);

return Evaluate(Call(DataType::Int(32), builtin::dma_copy(),
{queue_id,
Call(DataType::Handle(), builtin::address_of(),
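The reordering above defers `queue_ids.insert(queue_id)` until the statement has passed all of the mem-copy pattern checks and is about to be rewritten into `dma_copy`. Previously the ID was recorded up front, so a copy the lowerer then declined to transform could still leave its queue registered for the `wait` rewrite. A hedged sketch of that intent (not the pass itself; names are hypothetical):

```python
# Sketch: record a queue id only once the copy is definitely lowered, so the
# later `wait` rewrite never targets a queue that received no dma_copy.
def lower_copies(copies, is_supported_mem_copy):
    lowered, queue_ids = [], set()
    for queue_id, stmt in copies:
        if not is_supported_mem_copy(stmt):
            lowered.append(stmt)  # left untouched; no DMA issued
            continue
        queue_ids.add(queue_id)  # saved only when dma_copy is emitted
        lowered.append(("dma_copy", queue_id, stmt))
    return lowered, queue_ids
```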
180 changes: 134 additions & 46 deletions tests/python/contrib/test_hexagon/test_software_pipeline_async.py
@@ -30,81 +30,169 @@
inner = tvm.testing.parameter(64, 128)
dtype = tvm.testing.parameter("uint8", "float16")
scope = tvm.testing.parameter("global", "global.vtcm")
# TODO(Joseph) Turn on "multi_input_diffQ" compute type once we have upstreamed
# changes in the InjectSoftwarePipeline pass to alleviate this restriction:
# 'A dependency on multiple async stages is not supported'
comp_type = tvm.testing.parameter("single_input", "multi_input_sameQ")
# TODO(Straw) Add back "cache_write" schedule type once we have upstreamed
# buffer dependency analysis in InjectSoftwarePipeline pass
# to insert appropriate TIR "wait" attributes for this schedule
sched = tvm.testing.parameter("cache_read", "cache_read_write")
sched_type = tvm.testing.parameter("cache_read", "cache_read_write")


@tvm.testing.fixture
def compute(outer, inner, dtype):
@T.prim_func
def plus_one_primfunc(A: T.Buffer[(outer, inner), dtype], B: T.Buffer[(outer, inner), dtype]):
for i in T.serial(outer):
for j in T.serial(inner):
with T.block("compute"):
with T.block():
B[i, j] = A[i, j] + T.cast(1, dtype)
def data(comp_type, outer, inner, dtype):
out_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype)
a_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype)
if comp_type == "single_input":
return out_np, a_np
else:
b_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype)
return out_np, a_np, b_np


@tvm.testing.fixture
def compute(comp_type, outer, inner, dtype):
if comp_type == "single_input":

@T.prim_func
def a_plus_1_primfunc(
A: T.Buffer[(outer, inner), dtype], OUT: T.Buffer[(outer, inner), dtype]
):
for i in T.serial(outer):
for j in T.serial(inner):
with T.block("compute"):
with T.block():
OUT[i, j] = A[i, j] + T.cast(1, dtype)

return a_plus_1_primfunc
else:

@T.prim_func
def a_plus_b_plus_1_primfunc(
A: T.Buffer[(outer, inner), dtype],
B: T.Buffer[(outer, inner), dtype],
OUT: T.Buffer[(outer, inner), dtype],
):
for i in T.serial(outer):
for j in T.serial(inner):
with T.block("compute"):
with T.block():
OUT[i, j] = A[i, j] + B[i, j] + T.cast(1, dtype)

return a_plus_b_plus_1_primfunc


@tvm.testing.fixture
def reference(comp_type):
if comp_type == "single_input":

def a_plus_1_ref(a):
return a + 1

return a_plus_1_ref
else:

def plus_one_ref(a):
return a + 1
def a_plus_b_plus_1_ref(a, b):
return a + b + 1

return plus_one_primfunc, plus_one_ref
return a_plus_b_plus_1_ref


@tvm.testing.fixture
def schedule(compute, sched, scope):
sch = tir.Schedule(compute[0])
def schedule(comp_type, compute, sched_type, scope):
sch = tir.Schedule(compute)

compute_block = sch.get_block("compute")
i, _ = sch.get_loops(compute_block)

if sched == "cache_read":
cache_read_block = sch.cache_read(compute_block, 0, scope)
sch.compute_at(cache_read_block, i)
sch.annotate(i, "software_pipeline_stage", [0, 1])
sch.annotate(i, "software_pipeline_order", [0, 1])
sch.annotate(i, "software_pipeline_async_stages", [0])
elif sched == "cache_write":
cache_write_block = sch.cache_write(compute_block, 0, scope)
sch.reverse_compute_at(cache_write_block, i)
if "read" in sched_type:
cache_read_a = sch.cache_read(compute_block, 0, scope)
sch.compute_at(cache_read_a, i)

if "multi_input" in comp_type:
cache_read_b = sch.cache_read(compute_block, 1, scope)
sch.compute_at(cache_read_b, i)

if "write" in sched_type:
cache_write_out = sch.cache_write(compute_block, 0, scope)
sch.reverse_compute_at(cache_write_out, i)

if "read" in sched_type and "write" in sched_type:
if comp_type == "single_input":
sch.annotate(i, "software_pipeline_stage", [0, 1, 2])
sch.annotate(i, "software_pipeline_order", [0, 1, 2])
sch.annotate(i, "software_pipeline_async_stages", [0, 2])
elif comp_type == "multi_input_sameQ":
sch.annotate(i, "software_pipeline_stage", [0, 0, 1, 2])
sch.annotate(i, "software_pipeline_order", [0, 1, 2, 3])
sch.annotate(i, "software_pipeline_async_stages", [0, 2])
elif comp_type == "multi_input_diffQ":
sch.annotate(i, "software_pipeline_stage", [0, 1, 2, 3])
sch.annotate(i, "software_pipeline_order", [0, 1, 2, 3])
sch.annotate(i, "software_pipeline_async_stages", [0, 1, 2])

elif "read" in sched_type:
if comp_type == "single_input":
sch.annotate(i, "software_pipeline_stage", [0, 1])
sch.annotate(i, "software_pipeline_order", [0, 1])
sch.annotate(i, "software_pipeline_async_stages", [0])
elif comp_type == "multi_input_sameQ":
sch.annotate(i, "software_pipeline_stage", [0, 0, 1])
sch.annotate(i, "software_pipeline_order", [0, 1, 2])
sch.annotate(i, "software_pipeline_async_stages", [0])
elif comp_type == "multi_input_diffQ":
sch.annotate(i, "software_pipeline_stage", [0, 1, 2])
sch.annotate(i, "software_pipeline_order", [0, 1, 2])
sch.annotate(i, "software_pipeline_async_stages", [0, 1])

elif "write" in sched_type:
sch.annotate(i, "software_pipeline_stage", [0, 1])
sch.annotate(i, "software_pipeline_order", [0, 1])
sch.annotate(i, "software_pipeline_async_stages", [1])
elif sched == "cache_read_write":
cache_read_block = sch.cache_read(compute_block, 0, scope)
sch.compute_at(cache_read_block, i)
cache_write_block = sch.cache_write(compute_block, 0, scope)
sch.reverse_compute_at(cache_write_block, i)
sch.annotate(i, "software_pipeline_stage", [0, 1, 2])
sch.annotate(i, "software_pipeline_order", [0, 1, 2])
sch.annotate(i, "software_pipeline_async_stages", [0, 2])

return sch
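For readers new to these annotations: `software_pipeline_stage` assigns each block under the loop to a pipeline stage, `software_pipeline_order` fixes the blocks' order in the rewritten body, and `software_pipeline_async_stages` marks which stages become asynchronous copies. A self-contained sketch of the simplest case above (`single_input` with `cache_read`; the shape and scope are assumptions):

```python
# Illustrative sketch of the "single_input" + "cache_read" annotation pattern
# from the fixture above; the 4x64 shape and vtcm scope are assumptions.
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def plus_one(A: T.Buffer[(4, 64), "float16"], OUT: T.Buffer[(4, 64), "float16"]):
    for i in T.serial(4):
        for j in T.serial(64):
            with T.block("compute"):
                OUT[i, j] = A[i, j] + T.cast(1, "float16")


sch = tir.Schedule(plus_one)
block = sch.get_block("compute")
i, _ = sch.get_loops(block)
cache_read_a = sch.cache_read(block, 0, "global.vtcm")
sch.compute_at(cache_read_a, i)
# Two blocks now sit under `i`, in order: [async copy-in, compute].
sch.annotate(i, "software_pipeline_stage", [0, 1])  # copy runs a stage ahead
sch.annotate(i, "software_pipeline_order", [0, 1])
sch.annotate(i, "software_pipeline_async_stages", [0])  # stage 0 is async
```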


@tvm.testing.requires_hexagon
def test_async_software_pipeline(hexagon_launcher, compute, schedule, outer, inner, dtype, scope):
sch = schedule
@tvm.testing.fixture
def verify(dtype):
def check(out, ref):
if "int" in dtype:
np.testing.assert_equal(out.numpy(), ref)
else:
np.testing.assert_allclose(out.numpy(), ref, rtol=1e-3, atol=1e-3)

return check

a_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype)
b_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype)
ref = compute[1](a_np)

with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
func = tvm.build(sch.mod["main"], target=get_hexagon_target("v68"))
@tvm.testing.requires_hexagon
def test_async_software_pipeline(hexagon_launcher, comp_type, data, reference, schedule, verify):
out_np = data[0]
a_np = data[1]
if comp_type == "single_input":
ref = reference(a_np)
else:
b_np = data[2]
ref = reference(a_np, b_np)

with tvm.transform.PassContext(
config={"tir.use_async_copy": 1, "tir.merge_async_commit_queue_scope": False}
):
# tvm.lower(schedule.mod["main"]).show()
func = tvm.build(schedule.mod["main"], target=get_hexagon_target("v68"))

with hexagon_launcher.start_session() as hexagon_session:
dev = hexagon_session.device
a = tvm.nd.array(a_np, device=dev)
b = tvm.nd.array(b_np, device=dev)
mod = hexagon_session.load_module(func)
mod(a, b)

if "int" in dtype:
np.testing.assert_equal(b.numpy(), ref)
out = tvm.nd.array(out_np, device=dev)
a = tvm.nd.array(a_np, device=dev)
if comp_type == "single_input":
mod(a, out)
else:
np.testing.assert_allclose(b.numpy(), ref, rtol=1e-3, atol=1e-3)
b = tvm.nd.array(b_np, device=dev)
mod(a, b, out)

verify(out, ref)


if __name__ == "__main__":