From 512ccfcb2ce26ca578aac5adc7e81ed83326e63b Mon Sep 17 00:00:00 2001
From: Noah Verke
Date: Wed, 14 Dec 2022 09:56:53 -0800
Subject: [PATCH 1/4] Add check for non-contiguous memory access when lowering
 to async dma copies.

---
 src/tir/transforms/lower_async_dma.cc         |  27 +++
 .../test_hexagon/test_async_dma_pipeline.py   | 206 ++++++++++++++++++
 2 files changed, 233 insertions(+)

diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc
index 9a950c10c776..d03975d09ae1 100644
--- a/src/tir/transforms/lower_async_dma.cc
+++ b/src/tir/transforms/lower_async_dma.cc
@@ -22,6 +22,7 @@
  */
 
 #include <tvm/arith/analyzer.h>
+#include <tvm/arith/iter_affine_map.h>
 #include <tvm/tir/transform.h>
 
@@ -34,6 +35,12 @@ class AsyncDMALowerer : public StmtExprMutator {
  public:
   explicit AsyncDMALowerer(bool dma_bypass_cache) : dma_bypass_cache_(dma_bypass_cache) {}
 
+  // Record the mapping from each enclosing loop var to its iteration range.
+  Stmt VisitStmt_(const ForNode* op) final {
+    input_iters.Set(op->loop_var, Range(op->min, op->extent));
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     // Convert this, for example:
     // attr [0] "async_wait_queue_scope" = 0;
@@ -146,6 +153,16 @@ class AsyncDMALowerer : public StmtExprMutator {
 
       // map loop variable to zero for the store index & simplify
       Array<PrimExpr> store_index = bufferstorenode->indices;
+
+      // Use DetectIterMap to detect whether store index is non-contiguous.
+      arith::Analyzer analyzer;
+      auto store_iter_map =
+          DetectIterMap(store_index, input_iters, 1, arith::IterMapLevel::NoCheck, &analyzer, false);
+      if (!store_iter_map->errors.empty()) {
+        LOG(FATAL) << "Unable to lower async dma for non contiguous memory access with index: "
+                   << store_index;
+      }
+
       store_index.MutateByApply([&](PrimExpr expr) {
         arith::Analyzer analyzer;
         return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap));
@@ -153,6 +170,15 @@ class AsyncDMALowerer : public StmtExprMutator {
 
       // map loop variable to zero for the load index & simplify
      Array<PrimExpr> load_index = bufferloadnode->indices;
+
+      // Use DetectIterMap to detect whether load index is non-contiguous.
+      auto load_iter_map =
+          DetectIterMap(load_index, input_iters, 1, arith::IterMapLevel::NoCheck, &analyzer, false);
+      if (!load_iter_map->errors.empty()) {
+        LOG(FATAL) << "Unable to lower async dma for non contiguous memory access with index: "
+                   << load_index;
+      }
+
       load_index.MutateByApply([&](PrimExpr expr) {
         arith::Analyzer analyzer;
         return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap));
@@ -176,6 +202,7 @@ class AsyncDMALowerer : public StmtExprMutator {
  private:
   std::set<int> queue_ids_;
   bool dma_bypass_cache_;
+  Map<Var, Range> input_iters = Map<Var, Range>();
 };
 
 namespace transform {
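The new check hinges on DetectIterMap: the ForNode visitor records every enclosing loop var in input_iters, and the store/load indices must be expressible as well-formed fused iterators over those loops; otherwise the result's errors list is non-empty and lowering aborts. A minimal sketch of the same query from Python, assuming the tvm.arith.detect_iter_map binding and its IterMapResult.errors field mirror the C++ API used above:

    import tvm
    from tvm import tir
    from tvm.arith import detect_iter_map  # assumed binding for arith::DetectIterMap

    i = tir.Var("i", "int32")
    j = tir.Var("j", "int32")

    # A loop var used directly is a well-formed, fusable index: no errors.
    ok = detect_iter_map([i], {i: tvm.ir.Range(0, 18)})
    print(ok.errors)  # expected: []

    # i * 18 + j with j of extent 20 overlaps itself (20 > 18), so no fused
    # iterator exists and errors is non-empty -- the shape of access the
    # pass now rejects.
    bad = detect_iter_map([i * 18 + j], {i: tvm.ir.Range(0, 18), j: tvm.ir.Range(0, 20)})
    print(bad.errors)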
diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
index 51427f18f6f4..8f0970f8d61d 100644
--- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
+++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -26,6 +26,194 @@
 VRMPY_SIZE_INT32 = 32
 
+
+@T.prim_func
+def conv2d_async_non_contig(
+    p0: T.Buffer[(T.int64(1), T.int64(1), T.int64(56), T.int64(56), T.int64(4)), "uint8"],
+    fused_constant_1: T.Buffer[
+        (T.int64(1), T.int64(1), T.int64(3), T.int64(3), T.int64(1), T.int64(32), T.int64(4)),
+        "uint8",
+    ],
+    conv2d_NCHWc_int8: T.Buffer[
+        (T.int64(1), T.int64(1), T.int64(54), T.int64(54), T.int64(32)), "int32"
+    ],
+):
+    # function attr dict
+    T.func_attr({"tir.noalias": True, "global_symbol": "main"})
+    # body
+    # with T.block("root")
+    p0_global_vtcm = T.alloc_buffer(
+        [T.int64(1), T.int64(1), T.int64(56), T.int64(56), T.int64(4)],
+        dtype="uint8",
+        scope="global.vtcm",
+    )
+    fused_constant_global_vtcm = T.alloc_buffer(
+        [T.int64(1), T.int64(1), T.int64(3), T.int64(3), T.int64(1), T.int64(32), T.int64(4)],
+        dtype="uint8",
+        scope="global.vtcm",
+    )
+    for oh_0 in T.serial(T.int64(3)):
+        for ow_0 in T.serial(
+            T.int64(3),
+            annotations={
+                "software_pipeline_async_stages": [0],
+                "software_pipeline_order": [0, 1, 2],
+                "software_pipeline_stage": [0, 0, 1],
+            },
+        ):
+            for ax0_ax1_ax2_ax3_ax4_fused in T.serial(T.int64(1600)):
+                with T.block("p0_global.vtcm"):
+                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                    v1 = T.axis.spatial(T.int64(1), T.int64(0))
+                    v2 = T.axis.spatial(
+                        T.int64(56), oh_0 * T.int64(18) + ax0_ax1_ax2_ax3_ax4_fused // T.int64(80)
+                    )
+                    v3 = T.axis.spatial(
+                        T.int64(56),
+                        ow_0 * T.int64(18) + ax0_ax1_ax2_ax3_ax4_fused % T.int64(80) // T.int64(4),
+                    )
+                    v4 = T.axis.spatial(T.int64(4), ax0_ax1_ax2_ax3_ax4_fused % T.int64(4))
+                    T.reads(p0[v0, v1, v2, v3, v4])
+                    T.writes(p0_global_vtcm[v0, v1, v2, v3, v4])
+                    p0_global_vtcm[v0, v1, v2, v3, v4] = p0[v0, v1, v2, v3, v4]
+            for ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused in T.serial(T.int64(1152)):
+                with T.block("fused_constant_global.vtcm"):
+                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                    v1 = T.axis.spatial(T.int64(1), T.int64(0))
+                    v2 = T.axis.spatial(
+                        T.int64(3), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused // T.int64(384)
+                    )
+                    v3 = T.axis.spatial(
+                        T.int64(3), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(384) // T.int64(128)
+                    )
+                    v4 = T.axis.spatial(T.int64(1), T.int64(0))
+                    v5 = T.axis.spatial(
+                        T.int64(32), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(128) // T.int64(4)
+                    )
+                    v6 = T.axis.spatial(T.int64(4), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(4))
+                    T.reads(fused_constant_1[v0, v1, v2, v3, v4, v5, v6])
+                    T.writes(fused_constant_global_vtcm[v0, v1, v2, v3, v4, v5, v6])
+                    fused_constant_global_vtcm[v0, v1, v2, v3, v4, v5, v6] = fused_constant_1[
+                        v0, v1, v2, v3, v4, v5, v6
+                    ]
+            for oh_1, ow_1 in T.grid(T.int64(3), T.int64(6)):
+                for oh_2_init, ow_2_init in T.grid(T.int64(6), T.int64(3)):
+                    with T.block("conv2d_NCHWc_int8_o_init"):
+                        v_n = T.axis.spatial(T.int64(1), T.int64(0))
+                        v_oc_chunk = T.axis.spatial(T.int64(1), T.int64(0))
+                        v_oh = T.axis.spatial(
+                            T.int64(54), oh_0 * T.int64(18) + oh_1 * T.int64(6) + oh_2_init
+                        )
+                        v_ow = T.axis.spatial(
+                            T.int64(54), ow_0 * T.int64(18) + ow_1 * T.int64(3) + ow_2_init
+                        )
+                        v_oc_block_o = T.axis.spatial(T.int64(1), T.int64(0))
+                        T.reads()
+                        T.writes(
+                            conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)]
+                        )
+                        for oc_block_1 in T.vectorized(T.int64(32)):
+                            with T.block("conv2d_NCHWc_int8_init"):
+                                v_oc_block_i_init = T.axis.spatial(T.int64(32), oc_block_1)
+                                T.reads()
+                                T.writes(
+                                    conv2d_NCHWc_int8[
+                                        v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init
+                                    ]
+                                )
+                                conv2d_NCHWc_int8[
+                                    v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init
+                                ] = 0
+                for kh_1, kw_1, oh_2, ow_2 in T.grid(
+                    T.int64(3), T.int64(3), T.int64(6), T.int64(3)
+                ):
+                    with T.block("conv2d_NCHWc_int8_o_update"):
+                        v_n = T.axis.spatial(T.int64(1), T.int64(0))
+                        v_oc_chunk = T.axis.spatial(T.int64(1), T.int64(0))
+                        v_oh = T.axis.spatial(
+                            T.int64(54), oh_0 * T.int64(18) + oh_1 * T.int64(6) + oh_2
+                        )
+                        v_ow = T.axis.spatial(
+                            T.int64(54), ow_0 * T.int64(18) + ow_1 * T.int64(3) + ow_2
+                        )
+                        v_oc_block_o = T.axis.spatial(T.int64(1), T.int64(0))
+                        v_kh, v_kw = T.axis.remap("RR", [kh_1, kw_1])
+                        v_ic_outer = T.axis.reduce(T.int64(1), T.int64(0))
+                        v_ic_f_inner = T.axis.reduce(T.int64(1), T.int64(0))
+                        v_ic_s_inner_o = T.axis.reduce(T.int64(1), T.int64(0))
+                        T.reads(
+                            conv2d_NCHWc_int8[
+                                v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)
+                            ],
+                            p0_global_vtcm[
+                                v_n,
+                                v_ic_outer,
+                                v_oh + v_kh,
+                                v_ow + v_kw,
+                                v_ic_f_inner * T.int64(4) : v_ic_f_inner * T.int64(4) + T.int64(4),
+                            ],
+                            fused_constant_global_vtcm[
+                                v_oc_chunk,
+                                v_ic_outer,
+                                v_kh,
+                                v_kw,
+                                v_ic_f_inner,
+                                T.int64(0) : T.int64(32),
+                                T.int64(0) : T.int64(4),
+                            ],
+                        )
+                        T.writes(
+                            conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)]
+                        )
+                        A = T.match_buffer(
+                            p0_global_vtcm[
+                                v_n,
+                                v_ic_outer,
+                                v_oh + v_kh,
+                                v_ow + v_kw,
+                                v_ic_f_inner * T.int64(4) : v_ic_f_inner * T.int64(4) + T.int64(4),
+                            ],
+                            [T.int64(4)],
+                            dtype="uint8",
+                            scope="global.vtcm",
+                            offset_factor=1,
+                        )
+                        B = T.match_buffer(
+                            fused_constant_global_vtcm[
+                                v_oc_chunk,
+                                v_ic_outer,
+                                v_kh,
+                                v_kw,
+                                v_ic_f_inner,
+                                T.int64(0) : T.int64(32),
+                                T.int64(0) : T.int64(4),
+                            ],
+                            [T.int64(32), T.int64(4)],
+                            dtype="uint8",
+                            scope="global.vtcm",
+                            offset_factor=1,
+                        )
+                        C = T.match_buffer(
+                            conv2d_NCHWc_int8[
+                                v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)
+                            ],
+                            [T.int64(32)],
+                            dtype="int32",
+                            offset_factor=1,
+                        )
+                        A_u8x4: T.uint8x4 = A[T.int64(0) : T.int64(4)]
+                        A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32")
+                        B_i8x128 = B[T.int64(0), T.int64(0) : T.int64(128)]
+                        B_i32x32: T.int32x32 = T.reinterpret(B_i8x128, dtype="int32x32")
+                        C[0:32] = T.call_llvm_pure_intrin(
+                            T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.acc.128B"),
+                            T.uint32(3),
+                            C[0:32],
+                            B_i32x32,
+                            A_i32,
+                            dtype="int32x32",
+                        )
+
 
 def conv_approximation(size_a, size_w):
     """Conv approximation."""
     a_shape = (size_a, VRMPY_SIZE_B)
@@ -695,5 +883,23 @@ def test_meta(hexagon_session):
     )
 
 
+@tvm.testing.requires_hexagon
+def test_non_contiguous(hexagon_session):
+    sch = tvm.tir.Schedule(conv2d_async_non_contig)
+    target_hexagon = tvm.target.hexagon("v68", link_params=True)
+    err_rgx = r"Unable to lower async dma for non contiguous memory access with index: "
+    # Currently we do not support non contiguous memory access being lowered to async dma so we throw an error.
+    with pytest.raises(tvm.TVMError, match=err_rgx):
+        with tvm.transform.PassContext(
+            config={
+                "tir.use_async_copy": 1,
+                "tir.merge_async_commit_queue_scope": 0,
+            }
+        ):
+            tvm.build(
+                sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon)
+            )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
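To see why the kernel above trips the check, flatten the p0 index of block "p0_global.vtcm" with p0's (1, 1, 56, 56, 4) shape: each (oh_0, ow_0) tile copies 20 rows of 80 contiguous bytes out of 224-byte rows, so the flattened address cannot be a single stride-1 iterator. A plain-Python sketch of that address walk (no TVM required, oh_0 = ow_0 = 0):

    # Flattened address of p0[v0, v1, v2, v3, v4] as a function of the fused
    # copy loop; strides of the (1, 1, 56, 56, 4) buffer are (..., 224, 4, 1).
    def flat_addr(fused):
        v2 = fused // 80         # oh_0 * 18 + fused // 80
        v3 = (fused % 80) // 4   # ow_0 * 18 + fused % 80 // 4
        v4 = fused % 4
        return v2 * 56 * 4 + v3 * 4 + v4

    steps = {flat_addr(f + 1) - flat_addr(f) for f in range(1599)}
    print(sorted(steps))  # [1, 145]: 80 bytes advance by 1, then a 145-byte jump

DetectIterMap cannot fuse that 1/145 pattern into one iterator, so the new LOG(FATAL) fires; the regex the final version of the test matches (see patch 3 below) is the load-side message.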
From 43af40aac7c4cf22af687ea7fe78057d4cf514bb Mon Sep 17 00:00:00 2001
From: Noah Verke
Date: Wed, 14 Dec 2022 12:47:38 -0800
Subject: [PATCH 2/4] lint

---
 src/tir/transforms/lower_async_dma.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc
index d03975d09ae1..45a1f4d64080 100644
--- a/src/tir/transforms/lower_async_dma.cc
+++ b/src/tir/transforms/lower_async_dma.cc
@@ -153,11 +153,11 @@ class AsyncDMALowerer : public StmtExprMutator {
 
       // map loop variable to zero for the store index & simplify
       Array<PrimExpr> store_index = bufferstorenode->indices;
-
+
       // Use DetectIterMap to detect whether store index is non-contiguous.
       arith::Analyzer analyzer;
-      auto store_iter_map =
-          DetectIterMap(store_index, input_iters, 1, arith::IterMapLevel::NoCheck, &analyzer, false);
+      auto store_iter_map = DetectIterMap(store_index, input_iters, 1, arith::IterMapLevel::NoCheck,
+                                          &analyzer, false);
       if (!store_iter_map->errors.empty()) {
         LOG(FATAL) << "Unable to lower async dma for non contiguous memory access with index: "
                    << store_index;
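Patch 3 below splits the diagnostic into store and load variants and points the test's regex at the load-side message. Note that pytest.raises(match=...) applies re.search to the stringified exception, so the pattern only has to match a substring of the full TVMError text. A small sketch, with a hypothetical rendering of the message tail:

    import re

    # Illustrative message body only; the index list here is made up.
    msg = ("Unable to lower async dma for non contiguous memory access "
           "with load index: [v0, v1, v2, v3, v4]")
    assert re.search(
        r"Unable to lower async dma for non contiguous memory access with load index: ", msg
    )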
From d77987ecdfa632bf7b5d49b44cacf2469314e73f Mon Sep 17 00:00:00 2001
From: Noah Verke
Date: Thu, 15 Dec 2022 13:28:05 -0800
Subject: [PATCH 3/4] lint and nits

---
 src/tir/transforms/lower_async_dma.cc         |  4 ++--
 .../test_hexagon/test_async_dma_pipeline.py   | 17 ++++++++---------
 2 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc
index 45a1f4d64080..ecdc792cbb7e 100644
--- a/src/tir/transforms/lower_async_dma.cc
+++ b/src/tir/transforms/lower_async_dma.cc
@@ -159,7 +159,7 @@ class AsyncDMALowerer : public StmtExprMutator {
       auto store_iter_map = DetectIterMap(store_index, input_iters, 1, arith::IterMapLevel::NoCheck,
                                           &analyzer, false);
       if (!store_iter_map->errors.empty()) {
-        LOG(FATAL) << "Unable to lower async dma for non contiguous memory access with index: "
+        LOG(FATAL) << "Unable to lower async dma for non contiguous memory access with store index: "
                    << store_index;
       }
 
@@ -175,7 +175,7 @@ class AsyncDMALowerer : public StmtExprMutator {
       auto load_iter_map =
          DetectIterMap(load_index, input_iters, 1, arith::IterMapLevel::NoCheck, &analyzer, false);
       if (!load_iter_map->errors.empty()) {
-        LOG(FATAL) << "Unable to lower async dma for non contiguous memory access with index: "
+        LOG(FATAL) << "Unable to lower async dma for non contiguous memory access with load index: "
                    << load_index;
       }
 
diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
index 8f0970f8d61d..5cc09a098328 100644
--- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
+++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -25,7 +25,7 @@
 VRMPY_SIZE_B = 128
 VRMPY_SIZE_INT32 = 32
 
-
+# pylint: disable=invalid-name
 @T.prim_func
 def conv2d_async_non_contig(
     p0: T.Buffer[(T.int64(1), T.int64(1), T.int64(56), T.int64(56), T.int64(4)), "uint8"],
@@ -37,6 +37,8 @@ def conv2d_async_non_contig(
         (T.int64(1), T.int64(1), T.int64(54), T.int64(54), T.int64(32)), "int32"
     ],
 ):
+    """Non contiguous memory access is used in this conv2d taken from MS."""
+    # pylint: disable=no-self-argument
     # function attr dict
     T.func_attr({"tir.noalias": True, "global_symbol": "main"})
     # body
@@ -106,7 +108,6 @@ def conv2d_async_non_contig(
                         v_ow = T.axis.spatial(
                             T.int64(54), ow_0 * T.int64(18) + ow_1 * T.int64(3) + ow_2_init
                         )
-                        v_oc_block_o = T.axis.spatial(T.int64(1), T.int64(0))
                         T.reads()
                         T.writes(
                             conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)]
@@ -135,11 +136,9 @@ def conv2d_async_non_contig(
                         v_ow = T.axis.spatial(
                             T.int64(54), ow_0 * T.int64(18) + ow_1 * T.int64(3) + ow_2
                         )
-                        v_oc_block_o = T.axis.spatial(T.int64(1), T.int64(0))
                         v_kh, v_kw = T.axis.remap("RR", [kh_1, kw_1])
                         v_ic_outer = T.axis.reduce(T.int64(1), T.int64(0))
                         v_ic_f_inner = T.axis.reduce(T.int64(1), T.int64(0))
-                        v_ic_s_inner_o = T.axis.reduce(T.int64(1), T.int64(0))
                         T.reads(
                             conv2d_NCHWc_int8[
                                 v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)
@@ -882,13 +881,13 @@ def test_meta(hexagon_session):
         },
     )
 
-
-@tvm.testing.requires_hexagon
-def test_non_contiguous(hexagon_session):
+def test_non_contiguous():
+    """Test Non Contiguous memory lowering."""
     sch = tvm.tir.Schedule(conv2d_async_non_contig)
     target_hexagon = tvm.target.hexagon("v68", link_params=True)
-    err_rgx = r"Unable to lower async dma for non contiguous memory access with index: "
-    # Currently we do not support non contiguous memory access being lowered to async dma so we throw an error.
+    err_rgx = r"Unable to lower async dma for non contiguous memory access with load index: "
+    # Currently we do not support non contiguous memory access being lowered to
+    # async dma so we throw an error.
     with pytest.raises(tvm.TVMError, match=err_rgx):
         with tvm.transform.PassContext(
             config={

From 223593fb420049424c825a90fbda020361738519 Mon Sep 17 00:00:00 2001
From: Noah Verke
Date: Thu, 15 Dec 2022 13:46:42 -0800
Subject: [PATCH 4/4] lint

---
 src/tir/transforms/lower_async_dma.cc                        | 5 +++--
 tests/python/contrib/test_hexagon/test_async_dma_pipeline.py | 1 +
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc
index ecdc792cbb7e..94769dae0899 100644
--- a/src/tir/transforms/lower_async_dma.cc
+++ b/src/tir/transforms/lower_async_dma.cc
@@ -159,8 +159,9 @@ class AsyncDMALowerer : public StmtExprMutator {
       auto store_iter_map = DetectIterMap(store_index, input_iters, 1, arith::IterMapLevel::NoCheck,
                                           &analyzer, false);
       if (!store_iter_map->errors.empty()) {
-        LOG(FATAL) << "Unable to lower async dma for non contiguous memory access with store index: "
-                   << store_index;
+        LOG(FATAL)
+            << "Unable to lower async dma for non contiguous memory access with store index: "
+            << store_index;
       }
 
diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
index 5cc09a098328..914a26c51180 100644
--- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
+++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -881,6 +881,7 @@ def test_meta(hexagon_session):
         },
     )
 
+
 def test_non_contiguous():
     """Test Non Contiguous memory lowering."""
     sch = tvm.tir.Schedule(conv2d_async_non_contig)
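After patch 3 the test no longer needs a Hexagon device or session: it only drives tvm.build and expects lowering to abort in the async DMA pass. One way to run just this test in isolation, a sketch that assumes a TVM build with the Hexagon LLVM target enabled:

    # Run only test_non_contiguous through pytest's Python API.
    import sys
    import pytest

    sys.exit(pytest.main([
        "tests/python/contrib/test_hexagon/test_async_dma_pipeline.py::test_non_contiguous",
        "-v",
    ]))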