diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index da8f1faa8e1d..70dec47e60bd 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -125,6 +125,16 @@ class ScheduleRule : public runtime::ObjectRef { bool require_injective, // bool require_ordered, // Optional> disallow_op); + + /*! + * \brief Inline blocks that produce a constant scalar. Such blocks get in the way of + * ReverseComputeInline during AutoInline, since they are also counted as a producer block + * unless they are inlined first. So it is recommended to run InlineConstantScalars before + * AutoInline. + * \return The schedule rule created + */ + TVM_DLL static ScheduleRule InlineConstantScalars(); + /*! * \brief Create a mega rule: multi-level tiling with data reuse * \param structure The tiling structure. Recommended: diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py index 5971ad53c48c..d330fc713991 100644 --- a/python/tvm/meta_schedule/schedule_rule/__init__.py +++ b/python/tvm/meta_schedule/schedule_rule/__init__.py @@ -22,7 +22,7 @@ from .add_rfactor import AddRFactor from .apply_custom_rule import ApplyCustomRule from .auto_bind import AutoBind -from .auto_inline import AutoInline +from .auto_inline import AutoInline, InlineConstantScalars from .cross_thread_reduction import CrossThreadReduction from .multi_level_tiling import ( MultiLevelTiling, diff --git a/python/tvm/meta_schedule/schedule_rule/auto_inline.py b/python/tvm/meta_schedule/schedule_rule/auto_inline.py index 22206f3fcc24..c84dbaf89b97 100644 --- a/python/tvm/meta_schedule/schedule_rule/auto_inline.py +++ b/python/tvm/meta_schedule/schedule_rule/auto_inline.py @@ -65,3 +65,20 @@ def __init__( require_ordered, disallow_op, ) + + +@register_object("meta_schedule.InlineConstantScalars") +class InlineConstantScalars(ScheduleRule): + """Inline blocks that produce a constant scalar. + + Such blocks get in the way of ReverseComputeInline during AutoInline, since they are also + counted as a producer block unless they are inlined first. So it is recommended to run + InlineConstantScalars before AutoInline. + """ + + def __init__( + self, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleInlineConstantScalars, # type: ignore # pylint: disable=no-member + ) diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 0828ee538427..ae6f3474bbd6 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -175,10 +175,12 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(tir::transform::InjectDoubleBuffer()); pass_list.push_back(tir::transform::StorageRewrite()); pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); + pass_list.push_back(tir::transform::LowerIntrin()); // Convert Function to IRModule transform::PassContext pass_ctx = transform::PassContext::Current(); tir::PrimFunc f = WithAttr(GetRef(prim_func), "global_symbol", runtime::String(g_var->name_hint)); + f = WithAttr(f, tvm::attr::kTarget, Target("cuda")); // Required for LowerIntrin bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); if (noalias) { f = WithAttr(std::move(f), "tir.noalias", Bool(true)); diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc index dcdc83f95cb1..d2d48b9008ce 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -189,5 +189,42 @@ TVM_REGISTER_NODE_TYPE(AutoInlineNode); TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline") .set_body_typed(ScheduleRule::AutoInline); +/*! \brief Inline blocks that produce a constant scalar. */ +class InlineConstantScalarsNode : public ScheduleRuleNode { + public: + void InitializeWithTuneContext(const TuneContext& context) final {} + + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + // Look for a block of the form + // block compile_engine_const(iter_var(vi, range(min=0, ext=1))) { + // reads([]) + // writes([compile_engine_const[]]) + // compile_engine_const[] = 59 + // } + auto block = sch->Get(block_rv); + if (block->reads.size() == 0 && block->writes.size() == 1 && + block->writes[0]->buffer->shape.size() == 0) { + sch->ComputeInline(block_rv); + } + return {sch}; + } + + ScheduleRule Clone() const final { + ObjectPtr n = make_object(*this); + return ScheduleRule(n); + } + + static constexpr const char* _type_key = "meta_schedule.InlineConstantScalars"; + TVM_DECLARE_FINAL_OBJECT_INFO(InlineConstantScalarsNode, ScheduleRuleNode); +}; + +ScheduleRule ScheduleRule::InlineConstantScalars() { + ObjectPtr n = make_object(); + return ScheduleRule(n); +} + +TVM_REGISTER_NODE_TYPE(InlineConstantScalarsNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInlineConstantScalars") + .set_body_typed(ScheduleRule::InlineConstantScalars); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 141b93be5e34..b1e8c3695d3e 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -54,6 +54,7 @@ ScheduleRule ScheduleRule::PyScheduleRule( Array ScheduleRule::DefaultLLVM() { return { ScheduleRule::ApplyCustomRule(), + ScheduleRule::InlineConstantScalars(), ScheduleRule::AutoInline( /*into_producer=*/false, /*into_consumer=*/true, @@ -100,6 +101,7 @@ Array ScheduleRule::DefaultCUDA() { Map{{"req", String("must")}, {"levels", Array{3}}, // {"scope", String("local")}}), + ScheduleRule::InlineConstantScalars(), ScheduleRule::AutoInline( /*into_producer=*/true, /*into_consumer=*/true, @@ -178,6 +180,7 @@ Array ScheduleRule::DefaultCUDATensorCore() { Array ScheduleRule::DefaultHexagon() { return { ScheduleRule::ApplyCustomRule(), + ScheduleRule::InlineConstantScalars(), ScheduleRule::AutoInline( /*into_producer=*/false, /*into_consumer=*/true, diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index f0672f39217a..3377515a9589 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -209,6 +209,19 @@ class GPUCodeVerifier : public StmtExprVisitor { } } + void VisitExpr_(const CastNode* op) { + if (op->dtype.lanes() > 1) { + if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { + std::stringstream s; + s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" + << op->dtype.bytes() << ") for dtype " << op->dtype + << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")"; + errors_.push_back(s.str()); + } + } + ExprVisitor::VisitExpr_(op); + } + void VisitExpr_(const BufferLoadNode* op) { if (op->dtype.lanes() > 1) { if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py index b703c79c5d3a..9edf5877fd5e 100644 --- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py +++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py @@ -33,6 +33,7 @@ ) from tvm.meta_schedule import postproc, schedule_rule from tvm.tir.schedule import BlockRV, Schedule +from tvm.tir.schedule.analysis import has_block from tvm.tir.tensor_intrin.hexagon import VRMPY_u8i8i32_INTRIN, VRMPY_u8u8i32_INTRIN from ..infrastructure import get_hexagon_target @@ -206,9 +207,9 @@ def _schedule_packed_8x8x32_conv2d(): def schedule_fn(sch, conv2d_block: Optional[BlockRV] = None) -> bool: if conv2d_block is None: - try: + if has_block(sch, "conv2d_NCHWc_int8"): conv2d_block = sch.get_block("conv2d_NCHWc_int8") - except ValueError: + else: return False assert "conv2d_NCHWc_int8" in sch.get(conv2d_block).annotations["schedule_rule"] diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py index c17209e2cb77..1baa13793f38 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py @@ -15,7 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import pytest + import tvm +from tvm.tir import Schedule from tvm import meta_schedule as ms from tvm.meta_schedule.testing.space_generation import generate_design_space from tvm.script import tir as T @@ -334,6 +337,101 @@ def main(T_full: T.Buffer[(1, 12, 4096), "int64"]) -> None: T.writes(T_full[ax0, ax1, ax2]) T_full[ax0, ax1, ax2] = T.int64(0) + +@tvm.script.ir_module +class Conv2dInt8: + @T.prim_func + def main(p0: T.Buffer[(16, 14, 14, 256), "int8"], p1: T.Buffer[(1024, 1, 1, 256), "int8"], p2: T.Buffer[(1, 1, 1, 1024), "int32"], p3: T.Buffer[(1, 1, 1, 1024), "int32"], p4: T.Buffer[1024, "int32"], p5: T.Buffer[1024, "int32"], p6: T.Buffer[1024, "int32"], p7: T.Buffer[1, "int32"], p8: T.Buffer[(16, 14, 14, 1024), "int32"], compute: T.Buffer[(16, 14, 14, 1024), "int32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + compile_engine_const = T.alloc_buffer([], dtype="int32") + pad_temp = T.alloc_buffer([16, 14, 14, 256], dtype="int8") + conv2d_nhwc = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + T_subtract = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + T_add = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + compute_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + T_add_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + compute_2 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + T_subtract_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + compute_3 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + T_add_2 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + with T.block("compile_engine_const"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const[()]) + compile_engine_const[()] = 59 + for i0, i1, i2, i3 in T.grid(16, 14, 14, 256): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(p0[i0_1, i1_1, i2_1, i3_1]) + T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = p0[i0_1, i1_1, i2_1, i3_1] + for i0, i1, i2, i3, i4, i5, i6 in T.grid(16, 14, 14, 1024, 1, 1, 256): + with T.block("conv2d_nhwc"): + nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads(pad_temp[nn, yy + ry, xx + rx, rc], p1[ff, ry, rx, rc]) + T.writes(conv2d_nhwc[nn, yy, xx, ff]) + with T.init(): + conv2d_nhwc[nn, yy, xx, ff] = 0 + conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + T.cast(pad_temp[nn, yy + ry, xx + rx, rc], "int32") * T.cast(p1[ff, ry, rx, rc], "int32") + for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(conv2d_nhwc[ax0, ax1, ax2, ax3], p2[0, 0, 0, ax3]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = conv2d_nhwc[ax0, ax1, ax2, ax3] - p2[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_subtract[ax0, ax1, ax2, ax3], p3[0, 0, 0, ax3]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] + p3[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024): + with T.block("compute"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2]) + T.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) + compute_1[i0_2, i1_2, i2_2, i3_2] = T.q_multiply_shift_per_axis(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2], 31, False, True, dtype="int32") + for i0_3, i1_3, i2_3, i3_3 in T.grid(16, 14, 14, 1024): + with T.block("T_add_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) + T.reads(compile_engine_const[()], compute_1[ax0, ax1, ax2, ax3]) + T.writes(T_add_1[ax0, ax1, ax2, ax3]) + T_add_1[ax0, ax1, ax2, ax3] = compile_engine_const[()] + compute_1[ax0, ax1, ax2, ax3] + for i0_4, i1_4, i2_4, i3_4 in T.grid(16, 14, 14, 1024): + with T.block("compute_1"): + i0_5, i1_5, i2_5, i3_5 = T.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) + T.reads(T_add_1[i0_5, i1_5, i2_5, i3_5]) + T.writes(compute_2[i0_5, i1_5, i2_5, i3_5]) + compute_2[i0_5, i1_5, i2_5, i3_5] = T.max(T.min(T_add_1[i0_5, i1_5, i2_5, i3_5], 255), 0) + for i0_6, i1_6, i2_6, i3_6 in T.grid(16, 14, 14, 1024): + with T.block("T_subtract_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) + T.reads(compute_2[ax0, ax1, ax2, ax3], p7[0]) + T.writes(T_subtract_1[ax0, ax1, ax2, ax3]) + T_subtract_1[ax0, ax1, ax2, ax3] = compute_2[ax0, ax1, ax2, ax3] - p7[0] + for i0_7, i1_7, i2_7, i3_7 in T.grid(16, 14, 14, 1024): + with T.block("compute_2"): + i0_8, i1_8, i2_8, i3_8 = T.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7]) + T.reads(T_subtract_1[i0_8, i1_8, i2_8, i3_8]) + T.writes(compute_3[i0_8, i1_8, i2_8, i3_8]) + compute_3[i0_8, i1_8, i2_8, i3_8] = T.q_multiply_shift(T_subtract_1[i0_8, i1_8, i2_8, i3_8], 1408572815, 31, 1, dtype="int32") + for i0_9, i1_9, i2_9, i3_9 in T.grid(16, 14, 14, 1024): + with T.block("T_add_2"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_9, i1_9, i2_9, i3_9]) + T.reads(compute_3[ax0, ax1, ax2, ax3], p8[ax0, ax1, ax2, ax3]) + T.writes(T_add_2[ax0, ax1, ax2, ax3]) + T_add_2[ax0, ax1, ax2, ax3] = compute_3[ax0, ax1, ax2, ax3] + p8[ax0, ax1, ax2, ax3] + for i0_10, i1_10, i2_10, i3_10 in T.grid(16, 14, 14, 1024): + with T.block("compute_3"): + i0_11, i1_11, i2_11, i3_11 = T.axis.remap("SSSS", [i0_10, i1_10, i2_10, i3_10]) + T.reads(T_add_2[i0_11, i1_11, i2_11, i3_11]) + T.writes(compute[i0_11, i1_11, i2_11, i3_11]) + compute[i0_11, i1_11, i2_11, i3_11] = T.max(T.min(T_add_2[i0_11, i1_11, i2_11, i3_11], 255), 0) + + # pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks # fmt: on @@ -398,9 +496,26 @@ def test_inline_constant_tensor(): tvm.ir.assert_structural_equal(lhs=space.mod, rhs=ConstConsumer) +def test_conv2d_int8_inline_constant_scalars(): + sch = Schedule(Conv2dInt8) + + conv2d = sch.get_block("conv2d_nhwc") + sch.cache_write(conv2d, 0, "shared") + + with pytest.raises(tvm.tir.ScheduleError) as e: + sch.reverse_compute_inline(sch.get_block("T_add_1")) + + err_msg = "The block is only allowed to read a single buffer region, but it reads 2 region(s)" + assert err_msg in str(e) + + ms.schedule_rule.InlineConstantScalars().apply(sch, sch.get_block("compile_engine_const")) + sch.reverse_compute_inline(sch.get_block("T_add_1")) + + if __name__ == "__main__": test_inline_consumer_chain() test_inline_into_cache() test_inline_into_multiple_consumers() test_inline_pure_spatial() test_inline_constant_tensor() + test_conv2d_int8_inline_constant_scalars()