From 032e505dacfa42502e8cfb609de4f3c1ccc38ffc Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Wed, 22 Mar 2023 18:22:22 +0300 Subject: [PATCH] [MetaSchedule][Hexagon] Improve vectorization for standalone elementwise ops Motivation: It was found that for standalone elementwise operations (add, sub, etc.) MetaScheduler generates code with poor performance due to lack of vector code on some input tensor shapes. Current implementation is not able to vectorize if innermost loops extent is not multiple of the vector length. What was done: Core changes: it checks current loops nest, if all loops are "simple", i.e. loops without annotations, bindings, reduce axis, then it does the following: 1) Fuse all loops into single one. 2) Split this new loop into 2 parts: inner and outer. Herewith split factor for the inner loop is equal to 'max_vectorize_extent' MetaScheduler parameter. 3) Parallelize outer loop and vectorize inner loop. Performance measurement: Measurement was done on Qualcomm Snapdragon 888. As it was expected, 1 and 2 got significant performance boost, 3 and 4 - without changes. N | op | Dtype | Shape | Before fix, ms | After fix, ms | speedup | --|---------|-------|------------------|----------------|---------------|---------| 1 | add | uint8 | 1, 8, 56, 56, 32 | 1.264 | 0.167 | 7.5x | 2 | qnn.add | uint8 | 1, 8, 56, 56, 32 | 2.213 | 0.336 | 6.6x | 3 | add | int32 | 1, 8, 56, 56, 32 | 0.161 | 0.150 | 1.07x | 4 | seq* | uint8 | 1, 64, 56, 56 | 2.634 | 2.679 | 0.98x | ----------------------------------------------------------------------------------| seq* - test of the ops sequence: qnn.conv2d + bias_add + qnn.requantize, weights shape = [256, 64, 1, 1] --- .../rewrite_parallel_vectorize_unroll.cc | 81 ++++++++++++++++--- ...tproc_rewrite_parallel_vectorize_unroll.py | 44 ++++++++++ 2 files changed, 114 insertions(+), 11 deletions(-) diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index c3cc0ef60152..2797ee44735c 100644 --- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -144,13 +144,40 @@ void RemoveParsedAnn(const Schedule& sch, const BlockRV& block_rv, const ParsedA } } +int CalculateNumRewritableLoops(const Array& loop_srefs, + const std::vector& loop_types) { + int rw_loops_num = 0; + ICHECK_EQ(loop_srefs.size(), loop_types.size()); + for (size_t i = 0; i < loop_srefs.size(); ++i) { + const StmtSRef& loop_sref = loop_srefs[i]; + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); + if (HasAnnOrBinding(loop)) { + continue; + } + // Cannot vectorize reduce axis + if (loop_types[i] != IterVarType::kDataPar) { + continue; + } + // Cannot fuse with a loop with multiple children + if (!IsSingleStmt(loop->body)) { + continue; + } + // Check if the loop extent is valid + if (GetLoopIntExtent(loop_sref) == nullptr) { + continue; + } + ++rw_loops_num; + } + return rw_loops_num; +} + void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, const Array& loop_rvs, ParsedAnnotation* parsed) { StmtSRef block_sref = sch->GetSRef(block_rv); if (parsed->max_parallel_extent == -1 && parsed->max_vectorize_extent == -1) { return; } - int n_loops = loop_rvs.size(); + const int n_loops = loop_rvs.size(); if (n_loops == 0) { parsed->max_parallel_extent = -1; parsed->max_vectorize_extent = -1; @@ -226,6 +253,10 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, } max_fusible = std::min(max_fusible, fusible); } + + // Calculate how many loops are rewritable, i.e. valid for vectorization and parallelization. + int max_rw_loops = CalculateNumRewritableLoops(loop_srefs, loop_types); + // Calculate the parallelize extent if (parsed->max_parallel_extent != -1) { int max_extent = parsed->max_parallel_extent; @@ -290,10 +321,17 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, num_fusible = -1; } } - // Prefer num_vectorize to num_parallel + if (parsed->num_parallel_loops != -1 && parsed->num_vectorize_loops != -1) { - parsed->num_parallel_loops = std::min(parsed->num_parallel_loops, // - n_loops - parsed->num_vectorize_loops); + if (max_rw_loops == n_loops && max_fusible == n_loops) { + // All loops can be fused, parallelized and vectorized + parsed->num_parallel_loops = n_loops; + parsed->num_vectorize_loops = n_loops; + } else { + // Prefer num_vectorize to num_parallel + parsed->num_parallel_loops = + std::min(parsed->num_parallel_loops, n_loops - parsed->num_vectorize_loops); + } } } @@ -317,6 +355,21 @@ bool FindAnnotatedRootBlock(const Schedule& sch, ParsedAnnotation* parsed, Block return false; } +void RewriteFuseSplitParallelVectorize(const Schedule& sch, Array* loop_rvs, int vec_len) { + size_t n_loops = loop_rvs->size(); + LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->end()}); + Array split = sch->Split(fused, {NullOpt, Integer(vec_len)}); + ICHECK_EQ(split.size(), 2); + const LoopRV& outer = split[0]; + const LoopRV& inner = split[1]; + sch->Parallel(outer); + sch->Vectorize(inner); + for (size_t i = 0; i < n_loops - 1; ++i) { + loop_rvs->Set(i, outer); + } + loop_rvs->Set(n_loops - 1, inner); +} + void RewriteParallel(const Schedule& sch, size_t n, Array* loop_rvs) { ICHECK_LE(n, loop_rvs->size()); LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->begin() + n}); @@ -364,13 +417,19 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode { } tir::ParsedAnnotation parsed = parsed_root; tir::AdjustParallelVectorize(sch, block_rv, loop_rvs, &parsed); - // Parallel - if (parsed.num_parallel_loops > 0) { - tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs); - } - // Vectorize - if (parsed.num_vectorize_loops > 0) { - tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs); + const int loops_num = loop_rvs.size(); + if (parsed.num_parallel_loops == loops_num && parsed.num_vectorize_loops == loops_num) { + // Fuse, split, vectorize and parallelize + tir::RewriteFuseSplitParallelVectorize(sch, &loop_rvs, parsed.max_vectorize_extent); + } else { + // Parallel + if (parsed.num_parallel_loops > 0) { + tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs); + } + // Vectorize + if (parsed.num_vectorize_loops > 0) { + tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs); + } } // AutoUnroll if (parsed.unroll_explicit != -1 || parsed.unroll_implicit != -1) { diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py index a3b1cc5e0139..0b6f891cca7d 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py @@ -142,6 +142,42 @@ def after_matmul_vectorize( T_matmul_NT[v0, v1] = T_matmul_NT_global[v0, v1] +@T.prim_func +def before_postproc_add( + lhs: T.Buffer((1, 8, 56, 56, 32), "uint8"), + rhs: T.Buffer((1, 8, 56, 56, 32), "uint8"), + add_compute: T.Buffer((1, 8, 56, 56, 32), "uint8"), +) -> None: + with T.block("root"): + T.block_attr({"meta_schedule.parallel":64, "meta_schedule.vectorize":128}) + for n, c0, h, w, c1 in T.grid(1, 8, 56, 56, 32): + with T.block("add_compute"): + v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [n, c0, h, w, c1]) + T.reads(lhs[v0, v1, v2, v3, v4], rhs[v0, v1, v2, v3, v4]) + T.writes(add_compute[v0, v1, v2, v3, v4]) + add_compute[v0, v1, v2, v3, v4] = lhs[v0, v1, v2, v3, v4] + rhs[v0, v1, v2, v3, v4] + + +@T.prim_func +def after_postproc_add( + lhs: T.Buffer((1, 8, 56, 56, 32), "uint8"), + rhs: T.Buffer((1, 8, 56, 56, 32), "uint8"), + add_compute: T.Buffer((1, 8, 56, 56, 32), "uint8"), +) -> None: + with T.block("root"): + for n_c0_h_w_c1_fused_0 in T.parallel(0, 6272): + for n_c0_h_w_c1_fused_1 in T.vectorized(0, 128): + with T.block("add_compute"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(8, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) // 100352) + v2 = T.axis.spatial(56, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) % 100352 // 1792) + v3 = T.axis.spatial(56, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) % 1792 // 32) + v4 = T.axis.spatial(32, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) % 32) + T.reads(lhs[v0, v1, v2, v3, v4], rhs[v0, v1, v2, v3, v4]) + T.writes(add_compute[v0, v1, v2, v3, v4]) + add_compute[v0, v1, v2, v3, v4] = lhs[v0, v1, v2, v3, v4] + rhs[v0, v1, v2, v3, v4] + + # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable @@ -161,6 +197,14 @@ def test_vectorize_inner_loop(): tvm.ir.assert_structural_equal(sch.mod["main"], after_matmul_vectorize) +def test_parallel_vectorize_add(): + sch = Schedule(before_postproc_add) + rule = RewriteParallelVectorizeUnroll() + assert rule.apply(sch) + tvm.ir.assert_structural_equal(sch.mod["main"], after_postproc_add) + + if __name__ == "__main__": test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize() test_vectorize_inner_loop() + test_parallel_vectorize_add()