28 changes: 28 additions & 0 deletions src/tir/transforms/lower_async_dma.cc
@@ -22,6 +22,7 @@
*/

#include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

@@ -34,6 +35,12 @@ class AsyncDMALowerer : public StmtExprMutator {
public:
explicit AsyncDMALowerer(bool dma_bypass_cache) : dma_bypass_cache_(dma_bypass_cache) {}

// Record the iteration range of each loop variable; DetectIterMap uses
// these ranges below to analyze buffer indices for contiguity.
Stmt VisitStmt_(const ForNode* op) final {
input_iters.Set(op->loop_var, Range(op->min, op->extent));
return StmtExprMutator::VisitStmt_(op);
}

Stmt VisitStmt_(const AttrStmtNode* op) final {
// Convert this, for example:
// attr [0] "async_wait_queue_scope" = 0;
@@ -146,13 +153,33 @@

// map loop variable to zero for the store index & simplify
Array<PrimExpr> store_index = bufferstorenode->indices;

// Use DetectIterMap to detect whether store index is non-contiguous.
arith::Analyzer analyzer;
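// The third argument is a trivially-true predicate; any failure to express
// the indices as an affine iterator map is reported in `errors` below.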
auto store_iter_map = DetectIterMap(store_index, input_iters, 1, arith::IterMapLevel::NoCheck,
&analyzer, false);
if (!store_iter_map->errors.empty()) {
LOG(FATAL)
<< "Unable to lower async dma for non contiguous memory access with store index: "
<< store_index;
}

store_index.MutateByApply([&](PrimExpr expr) {
arith::Analyzer analyzer;
return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap));
});

// map loop variable to zero for the load index & simplify
Array<PrimExpr> load_index = bufferloadnode->indices;

// Use DetectIterMap to detect whether load index is non-contiguous.
auto load_iter_map =
DetectIterMap(load_index, input_iters, 1, arith::IterMapLevel::NoCheck, &analyzer, false);
if (!load_iter_map->errors.empty()) {
LOG(FATAL) << "Unable to lower async dma for non contiguous memory access with load index: "
<< load_index;
}

load_index.MutateByApply([&](PrimExpr expr) {
arith::Analyzer analyzer;
return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap));
@@ -176,6 +203,7 @@
private:
std::set<int> queue_ids_;
bool dma_bypass_cache_;
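// Iteration domain of every loop variable seen so far, passed to
// DetectIterMap when analyzing DMA store/load indices.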
Map<Var, Range> input_iters = Map<Var, Range>();
};

namespace transform {
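For reviewers who want to poke at the new check from Python: the `DetectIterMap` call above is exposed as `tvm.arith.detect_iter_map`. The snippet below is a minimal sketch, not part of this PR, and assumes a TVM build in which the binding returns an `IterMapResult` with an `errors` field (matching the C++ API used here). It shows an affine index passing the check and a non-affine index producing the errors that trigger the new `LOG(FATAL)`.

import tvm
from tvm import arith, tir

i = tir.Var("i", "int32")
dom = {i: tvm.ir.Range(0, 16)}

# An index affine in the loop variable pattern-matches cleanly:
# `errors` is empty, so lowering to a DMA copy would proceed.
ok = arith.detect_iter_map([i], dom, check_level=arith.IterMapLevel.NoCheck)
assert len(ok.errors) == 0

# A non-affine index such as i * i cannot be expressed as an affine
# iterator map; `errors` is non-empty, the case the pass now rejects.
bad = arith.detect_iter_map([i * i], dom, check_level=arith.IterMapLevel.NoCheck)
assert len(bad.errors) > 0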
206 changes: 206 additions & 0 deletions tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -25,6 +25,193 @@
VRMPY_SIZE_B = 128
VRMPY_SIZE_INT32 = 32

# pylint: disable=invalid-name
@T.prim_func
def conv2d_async_non_contig(
p0: T.Buffer[(T.int64(1), T.int64(1), T.int64(56), T.int64(56), T.int64(4)), "uint8"],
fused_constant_1: T.Buffer[
(T.int64(1), T.int64(1), T.int64(3), T.int64(3), T.int64(1), T.int64(32), T.int64(4)),
"uint8",
],
conv2d_NCHWc_int8: T.Buffer[
(T.int64(1), T.int64(1), T.int64(54), T.int64(54), T.int64(32)), "int32"
],
):
    """Conv2d with non-contiguous memory access, extracted from MetaSchedule."""
# pylint: disable=no-self-argument
# function attr dict
T.func_attr({"tir.noalias": True, "global_symbol": "main"})
# body
# with T.block("root")
p0_global_vtcm = T.alloc_buffer(
[T.int64(1), T.int64(1), T.int64(56), T.int64(56), T.int64(4)],
dtype="uint8",
scope="global.vtcm",
)
fused_constant_global_vtcm = T.alloc_buffer(
[T.int64(1), T.int64(1), T.int64(3), T.int64(3), T.int64(1), T.int64(32), T.int64(4)],
dtype="uint8",
scope="global.vtcm",
)
for oh_0 in T.serial(T.int64(3)):
for ow_0 in T.serial(
T.int64(3),
annotations={
"software_pipeline_async_stages": [0],
"software_pipeline_order": [0, 1, 2],
"software_pipeline_stage": [0, 0, 1],
},
):
for ax0_ax1_ax2_ax3_ax4_fused in T.serial(T.int64(1600)):
with T.block("p0_global.vtcm"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial(T.int64(1), T.int64(0))
v2 = T.axis.spatial(
T.int64(56), oh_0 * T.int64(18) + ax0_ax1_ax2_ax3_ax4_fused // T.int64(80)
)
v3 = T.axis.spatial(
T.int64(56),
ow_0 * T.int64(18) + ax0_ax1_ax2_ax3_ax4_fused % T.int64(80) // T.int64(4),
)
v4 = T.axis.spatial(T.int64(4), ax0_ax1_ax2_ax3_ax4_fused % T.int64(4))
T.reads(p0[v0, v1, v2, v3, v4])
T.writes(p0_global_vtcm[v0, v1, v2, v3, v4])
p0_global_vtcm[v0, v1, v2, v3, v4] = p0[v0, v1, v2, v3, v4]
for ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused in T.serial(T.int64(1152)):
with T.block("fused_constant_global.vtcm"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial(T.int64(1), T.int64(0))
v2 = T.axis.spatial(
T.int64(3), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused // T.int64(384)
)
v3 = T.axis.spatial(
T.int64(3), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(384) // T.int64(128)
)
v4 = T.axis.spatial(T.int64(1), T.int64(0))
v5 = T.axis.spatial(
T.int64(32), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(128) // T.int64(4)
)
v6 = T.axis.spatial(T.int64(4), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(4))
T.reads(fused_constant_1[v0, v1, v2, v3, v4, v5, v6])
T.writes(fused_constant_global_vtcm[v0, v1, v2, v3, v4, v5, v6])
fused_constant_global_vtcm[v0, v1, v2, v3, v4, v5, v6] = fused_constant_1[
v0, v1, v2, v3, v4, v5, v6
]
for oh_1, ow_1 in T.grid(T.int64(3), T.int64(6)):
for oh_2_init, ow_2_init in T.grid(T.int64(6), T.int64(3)):
with T.block("conv2d_NCHWc_int8_o_init"):
v_n = T.axis.spatial(T.int64(1), T.int64(0))
v_oc_chunk = T.axis.spatial(T.int64(1), T.int64(0))
v_oh = T.axis.spatial(
T.int64(54), oh_0 * T.int64(18) + oh_1 * T.int64(6) + oh_2_init
)
v_ow = T.axis.spatial(
T.int64(54), ow_0 * T.int64(18) + ow_1 * T.int64(3) + ow_2_init
)
T.reads()
T.writes(
conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)]
)
for oc_block_1 in T.vectorized(T.int64(32)):
with T.block("conv2d_NCHWc_int8_init"):
v_oc_block_i_init = T.axis.spatial(T.int64(32), oc_block_1)
T.reads()
T.writes(
conv2d_NCHWc_int8[
v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init
]
)
conv2d_NCHWc_int8[
v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init
] = 0
for kh_1, kw_1, oh_2, ow_2 in T.grid(
T.int64(3), T.int64(3), T.int64(6), T.int64(3)
):
with T.block("conv2d_NCHWc_int8_o_update"):
v_n = T.axis.spatial(T.int64(1), T.int64(0))
v_oc_chunk = T.axis.spatial(T.int64(1), T.int64(0))
v_oh = T.axis.spatial(
T.int64(54), oh_0 * T.int64(18) + oh_1 * T.int64(6) + oh_2
)
v_ow = T.axis.spatial(
T.int64(54), ow_0 * T.int64(18) + ow_1 * T.int64(3) + ow_2
)
v_kh, v_kw = T.axis.remap("RR", [kh_1, kw_1])
v_ic_outer = T.axis.reduce(T.int64(1), T.int64(0))
v_ic_f_inner = T.axis.reduce(T.int64(1), T.int64(0))
T.reads(
conv2d_NCHWc_int8[
v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)
],
p0_global_vtcm[
v_n,
v_ic_outer,
v_oh + v_kh,
v_ow + v_kw,
v_ic_f_inner * T.int64(4) : v_ic_f_inner * T.int64(4) + T.int64(4),
],
fused_constant_global_vtcm[
v_oc_chunk,
v_ic_outer,
v_kh,
v_kw,
v_ic_f_inner,
T.int64(0) : T.int64(32),
T.int64(0) : T.int64(4),
],
)
T.writes(
conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)]
)
A = T.match_buffer(
p0_global_vtcm[
v_n,
v_ic_outer,
v_oh + v_kh,
v_ow + v_kw,
v_ic_f_inner * T.int64(4) : v_ic_f_inner * T.int64(4) + T.int64(4),
],
[T.int64(4)],
dtype="uint8",
scope="global.vtcm",
offset_factor=1,
)
B = T.match_buffer(
fused_constant_global_vtcm[
v_oc_chunk,
v_ic_outer,
v_kh,
v_kw,
v_ic_f_inner,
T.int64(0) : T.int64(32),
T.int64(0) : T.int64(4),
],
[T.int64(32), T.int64(4)],
dtype="uint8",
scope="global.vtcm",
offset_factor=1,
)
C = T.match_buffer(
conv2d_NCHWc_int8[
v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)
],
[T.int64(32)],
dtype="int32",
offset_factor=1,
)
A_u8x4: T.uint8x4 = A[T.int64(0) : T.int64(4)]
A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32")
B_i8x128 = B[T.int64(0), T.int64(0) : T.int64(128)]
B_i32x32: T.int32x32 = T.reinterpret(B_i8x128, dtype="int32x32")
C[0:32] = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.acc.128B"),
T.uint32(3),
C[0:32],
B_i32x32,
A_i32,
dtype="int32x32",
)


def conv_approximation(size_a, size_w):
"""Conv approximation."""
Expand Down Expand Up @@ -695,5 +882,24 @@ def test_meta(hexagon_session):
)


def test_non_contiguous():
    """Test that non-contiguous memory access fails to lower to async DMA."""
sch = tvm.tir.Schedule(conv2d_async_non_contig)
target_hexagon = tvm.target.hexagon("v68", link_params=True)
err_rgx = r"Unable to lower async dma for non contiguous memory access with load index: "
# Currently we do not support lowering non-contiguous memory access to
# async DMA, so an error is raised.
with pytest.raises(tvm.TVMError, match=err_rgx):
with tvm.transform.PassContext(
config={
"tir.use_async_copy": 1,
"tir.merge_async_commit_queue_scope": 0,
}
):
tvm.build(
sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon)
)


if __name__ == "__main__":
tvm.testing.main()
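A closing note for anyone adapting `conv2d_async_non_contig`: the `software_pipeline_*` annotations baked into its TVMScript are normally attached through the schedule API rather than written by hand. A minimal sketch, assuming `mod` is an IRModule holding a similarly structured kernel; the block and loop names are illustrative, not taken from this test.

import tvm

sch = tvm.tir.Schedule(mod)
# Pipeline the outer loop: the two cache fills run in stage 0 and compute in
# stage 1; marking stage 0 asynchronous routes its copies to async DMA lowering.
outer = sch.get_loops(sch.get_block("compute"))[1]
sch.annotate(outer, ann_key="software_pipeline_stage", ann_val=[0, 0, 1])
sch.annotate(outer, ann_key="software_pipeline_order", ann_val=[0, 1, 2])
sch.annotate(outer, ann_key="software_pipeline_async_stages", ann_val=[0])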