diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
index c3cc0ef60152..2797ee44735c 100644
--- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
+++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
@@ -144,13 +144,40 @@ void RemoveParsedAnn(const Schedule& sch, const BlockRV& block_rv, const ParsedA
   }
 }
 
+int CalculateNumRewritableLoops(const Array<StmtSRef>& loop_srefs,
+                                const std::vector<int>& loop_types) {
+  int rw_loops_num = 0;
+  ICHECK_EQ(loop_srefs.size(), loop_types.size());
+  for (size_t i = 0; i < loop_srefs.size(); ++i) {
+    const StmtSRef& loop_sref = loop_srefs[i];
+    const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
+    if (HasAnnOrBinding(loop)) {
+      continue;
+    }
+    // Cannot vectorize reduce axis
+    if (loop_types[i] != IterVarType::kDataPar) {
+      continue;
+    }
+    // Cannot fuse with a loop with multiple children
+    if (!IsSingleStmt(loop->body)) {
+      continue;
+    }
+    // Check if the loop extent is valid
+    if (GetLoopIntExtent(loop_sref) == nullptr) {
+      continue;
+    }
+    ++rw_loops_num;
+  }
+  return rw_loops_num;
+}
+
 void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv,
                              const Array<LoopRV>& loop_rvs, ParsedAnnotation* parsed) {
   StmtSRef block_sref = sch->GetSRef(block_rv);
   if (parsed->max_parallel_extent == -1 && parsed->max_vectorize_extent == -1) {
     return;
   }
-  int n_loops = loop_rvs.size();
+  const int n_loops = loop_rvs.size();
   if (n_loops == 0) {
     parsed->max_parallel_extent = -1;
     parsed->max_vectorize_extent = -1;
@@ -226,6 +253,10 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv,
     }
     max_fusible = std::min(max_fusible, fusible);
   }
+
+  // Calculate how many loops are rewritable, i.e. valid for vectorization and parallelization.
+  int max_rw_loops = CalculateNumRewritableLoops(loop_srefs, loop_types);
+
   // Calculate the parallelize extent
   if (parsed->max_parallel_extent != -1) {
     int max_extent = parsed->max_parallel_extent;
@@ -290,10 +321,17 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv,
       num_fusible = -1;
     }
   }
-  // Prefer num_vectorize to num_parallel
+
   if (parsed->num_parallel_loops != -1 && parsed->num_vectorize_loops != -1) {
-    parsed->num_parallel_loops = std::min(parsed->num_parallel_loops,  //
-                                          n_loops - parsed->num_vectorize_loops);
+    if (max_rw_loops == n_loops && max_fusible == n_loops) {
+      // All loops can be fused, parallelized and vectorized
+      parsed->num_parallel_loops = n_loops;
+      parsed->num_vectorize_loops = n_loops;
+    } else {
+      // Prefer num_vectorize to num_parallel
+      parsed->num_parallel_loops =
+          std::min(parsed->num_parallel_loops, n_loops - parsed->num_vectorize_loops);
+    }
   }
 }
 
@@ -317,6 +355,21 @@ bool FindAnnotatedRootBlock(const Schedule& sch, ParsedAnnotation* parsed, Block
   return false;
 }
 
+void RewriteFuseSplitParallelVectorize(const Schedule& sch, Array<LoopRV>* loop_rvs, int vec_len) {
+  size_t n_loops = loop_rvs->size();
+  LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->end()});
+  Array<LoopRV> split = sch->Split(fused, {NullOpt, Integer(vec_len)});
+  ICHECK_EQ(split.size(), 2);
+  const LoopRV& outer = split[0];
+  const LoopRV& inner = split[1];
+  sch->Parallel(outer);
+  sch->Vectorize(inner);
+  for (size_t i = 0; i < n_loops - 1; ++i) {
+    loop_rvs->Set(i, outer);
+  }
+  loop_rvs->Set(n_loops - 1, inner);
+}
+
 void RewriteParallel(const Schedule& sch, size_t n, Array<LoopRV>* loop_rvs) {
   ICHECK_LE(n, loop_rvs->size());
   LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->begin() + n});
@@ -364,13 +417,19 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode {
         }
         tir::ParsedAnnotation parsed = parsed_root;
         tir::AdjustParallelVectorize(sch, block_rv, loop_rvs, &parsed);
-        // Parallel
-        if (parsed.num_parallel_loops > 0) {
-          tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs);
-        }
-        // Vectorize
-        if (parsed.num_vectorize_loops > 0) {
-          tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs);
+        const int loops_num = loop_rvs.size();
+        if (parsed.num_parallel_loops == loops_num && parsed.num_vectorize_loops == loops_num) {
+          // Fuse, split, vectorize and parallelize
+          tir::RewriteFuseSplitParallelVectorize(sch, &loop_rvs, parsed.max_vectorize_extent);
+        } else {
+          // Parallel
+          if (parsed.num_parallel_loops > 0) {
+            tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs);
+          }
+          // Vectorize
+          if (parsed.num_vectorize_loops > 0) {
+            tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs);
+          }
         }
         // AutoUnroll
         if (parsed.unroll_explicit != -1 || parsed.unroll_implicit != -1) {
diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
index a3b1cc5e0139..0b6f891cca7d 100644
--- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
+++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
@@ -142,6 +142,42 @@ def after_matmul_vectorize(
                 T_matmul_NT[v0, v1] = T_matmul_NT_global[v0, v1]
 
 
+@T.prim_func
+def before_postproc_add(
+    lhs: T.Buffer((1, 8, 56, 56, 32), "uint8"),
+    rhs: T.Buffer((1, 8, 56, 56, 32), "uint8"),
+    add_compute: T.Buffer((1, 8, 56, 56, 32), "uint8"),
+) -> None:
+    with T.block("root"):
+        T.block_attr({"meta_schedule.parallel": 64, "meta_schedule.vectorize": 128})
+        for n, c0, h, w, c1 in T.grid(1, 8, 56, 56, 32):
+            with T.block("add_compute"):
+                v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [n, c0, h, w, c1])
+                T.reads(lhs[v0, v1, v2, v3, v4], rhs[v0, v1, v2, v3, v4])
+                T.writes(add_compute[v0, v1, v2, v3, v4])
+                add_compute[v0, v1, v2, v3, v4] = lhs[v0, v1, v2, v3, v4] + rhs[v0, v1, v2, v3, v4]
+
+
+@T.prim_func
+def after_postproc_add(
+    lhs: T.Buffer((1, 8, 56, 56, 32), "uint8"),
+    rhs: T.Buffer((1, 8, 56, 56, 32), "uint8"),
+    add_compute: T.Buffer((1, 8, 56, 56, 32), "uint8"),
+) -> None:
+    with T.block("root"):
+        for n_c0_h_w_c1_fused_0 in T.parallel(0, 6272):
+            for n_c0_h_w_c1_fused_1 in T.vectorized(0, 128):
+                with T.block("add_compute"):
+                    v0 = T.axis.spatial(1, 0)
+                    v1 = T.axis.spatial(8, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) // 100352)
+                    v2 = T.axis.spatial(56, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) % 100352 // 1792)
+                    v3 = T.axis.spatial(56, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) % 1792 // 32)
+                    v4 = T.axis.spatial(32, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) % 32)
+                    T.reads(lhs[v0, v1, v2, v3, v4], rhs[v0, v1, v2, v3, v4])
+                    T.writes(add_compute[v0, v1, v2, v3, v4])
+                    add_compute[v0, v1, v2, v3, v4] = lhs[v0, v1, v2, v3, v4] + rhs[v0, v1, v2, v3, v4]
+
+
 # fmt: on
 # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable
 
@@ -161,6 +197,14 @@ def test_vectorize_inner_loop():
     tvm.ir.assert_structural_equal(sch.mod["main"], after_matmul_vectorize)
 
 
+def test_parallel_vectorize_add():
+    sch = Schedule(before_postproc_add)
+    rule = RewriteParallelVectorizeUnroll()
+    assert rule.apply(sch)
+    tvm.ir.assert_structural_equal(sch.mod["main"], after_postproc_add)
+
+
 if __name__ == "__main__":
     test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize()
     test_vectorize_inner_loop()
+    test_parallel_vectorize_add()