From 0e3cb159d2d744dbe394ff427adca89f2cd036c9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 9 Nov 2022 18:24:46 +0900 Subject: [PATCH 1/4] [MetaSchedule] Add a new schedule rule to inline all scalar constants --- include/tvm/meta_schedule/schedule_rule.h | 7 + .../meta_schedule/schedule_rule/__init__.py | 1 + .../schedule_rule/inline_const_scalars.py | 34 +++++ src/meta_schedule/postproc/verify_gpu_code.cc | 2 + .../schedule_rule/inline_const_scalar.cc | 56 ++++++++ .../schedule_rule/schedule_rule.cc | 3 + src/tir/analysis/verify_gpu_code.cc | 13 ++ .../metaschedule_e2e/test_resnet50_int8.py | 5 +- ...e_schedule_rule_inline_constant_scalars.py | 123 ++++++++++++++++++ 9 files changed, 242 insertions(+), 2 deletions(-) create mode 100644 python/tvm/meta_schedule/schedule_rule/inline_const_scalars.py create mode 100644 src/meta_schedule/schedule_rule/inline_const_scalar.cc create mode 100644 tests/python/unittest/test_meta_schedule_schedule_rule_inline_constant_scalars.py diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index da8f1faa8e1d..fe6b8fda895c 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -125,6 +125,13 @@ class ScheduleRule : public runtime::ObjectRef { bool require_injective, // bool require_ordered, // Optional> disallow_op); + + /*! + * \brief TODO + * \return The schedule rule created + */ + TVM_DLL static ScheduleRule InlineConstantScalars(); + /*! * \brief Create a mega rule: multi-level tiling with data reuse * \param structure The tiling structure. Recommended: diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py index 5971ad53c48c..d3bc5fe9a191 100644 --- a/python/tvm/meta_schedule/schedule_rule/__init__.py +++ b/python/tvm/meta_schedule/schedule_rule/__init__.py @@ -34,3 +34,4 @@ from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll from .random_compute_location import RandomComputeLocation from .schedule_rule import PyScheduleRule, ScheduleRule +from .inline_const_scalars import InlineConstantScalars diff --git a/python/tvm/meta_schedule/schedule_rule/inline_const_scalars.py b/python/tvm/meta_schedule/schedule_rule/inline_const_scalars.py new file mode 100644 index 000000000000..6072f2bdefbf --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/inline_const_scalars.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TODO""" +from tvm._ffi import register_object + +from .. 
import _ffi_api +from .schedule_rule import ScheduleRule + + +@register_object("meta_schedule.InlineConstantScalars") +class InlineConstantScalars(ScheduleRule): + """ + """ + + def __init__( + self, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleInlineConstantScalars, # type: ignore # pylint: disable=no-member + ) diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 0828ee538427..f003e886b82b 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -175,10 +175,12 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(tir::transform::InjectDoubleBuffer()); pass_list.push_back(tir::transform::StorageRewrite()); pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); + pass_list.push_back(tir::transform::LowerIntrin()); // Convert Function to IRModule transform::PassContext pass_ctx = transform::PassContext::Current(); tir::PrimFunc f = WithAttr(GetRef(prim_func), "global_symbol", runtime::String(g_var->name_hint)); + f = WithAttr(f, tvm::attr::kTarget, Target("cuda")); // Required for LowerIntrin bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); if (noalias) { f = WithAttr(std::move(f), "tir.noalias", Bool(true)); diff --git a/src/meta_schedule/schedule_rule/inline_const_scalar.cc b/src/meta_schedule/schedule_rule/inline_const_scalar.cc new file mode 100644 index 000000000000..b84ac66080c2 --- /dev/null +++ b/src/meta_schedule/schedule_rule/inline_const_scalar.cc @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief The rule that inlines spatial blocks if it satisfies some conditions. 
*/ +class InlineConstantScalarsNode : public ScheduleRuleNode { + public: + void InitializeWithTuneContext(const TuneContext& context) final {} + + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + const std::string block_name = sch->Get(block_rv)->name_hint; + if (block_name.find("compile_engine_const") != std::string::npos) { + sch->ComputeInline(block_rv); + } + return {sch}; + } + + ScheduleRule Clone() const final { + ObjectPtr n = make_object(*this); + return ScheduleRule(n); + } + + static constexpr const char* _type_key = "meta_schedule.InlineConstantScalars"; + TVM_DECLARE_FINAL_OBJECT_INFO(InlineConstantScalarsNode, ScheduleRuleNode); +}; + +ScheduleRule ScheduleRule::InlineConstantScalars() { + ObjectPtr n = make_object(); + return ScheduleRule(n); +} + +TVM_REGISTER_NODE_TYPE(InlineConstantScalarsNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInlineConstantScalars") + .set_body_typed(ScheduleRule::InlineConstantScalars); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 141b93be5e34..b1e8c3695d3e 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -54,6 +54,7 @@ ScheduleRule ScheduleRule::PyScheduleRule( Array ScheduleRule::DefaultLLVM() { return { ScheduleRule::ApplyCustomRule(), + ScheduleRule::InlineConstantScalars(), ScheduleRule::AutoInline( /*into_producer=*/false, /*into_consumer=*/true, @@ -100,6 +101,7 @@ Array ScheduleRule::DefaultCUDA() { Map{{"req", String("must")}, {"levels", Array{3}}, // {"scope", String("local")}}), + ScheduleRule::InlineConstantScalars(), ScheduleRule::AutoInline( /*into_producer=*/true, /*into_consumer=*/true, @@ -178,6 +180,7 @@ Array ScheduleRule::DefaultCUDATensorCore() { Array ScheduleRule::DefaultHexagon() { return { ScheduleRule::ApplyCustomRule(), + ScheduleRule::InlineConstantScalars(), ScheduleRule::AutoInline( /*into_producer=*/false, /*into_consumer=*/true, diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index f0672f39217a..3377515a9589 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -209,6 +209,19 @@ class GPUCodeVerifier : public StmtExprVisitor { } } + void VisitExpr_(const CastNode* op) { + if (op->dtype.lanes() > 1) { + if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { + std::stringstream s; + s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" + << op->dtype.bytes() << ") for dtype " << op->dtype + << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")"; + errors_.push_back(s.str()); + } + } + ExprVisitor::VisitExpr_(op); + } + void VisitExpr_(const BufferLoadNode* op) { if (op->dtype.lanes() > 1) { if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py index b703c79c5d3a..7633d73b1b2f 100644 --- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py +++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py @@ -33,6 +33,7 @@ ) from tvm.meta_schedule import postproc, schedule_rule from tvm.tir.schedule import BlockRV, Schedule +from tvm.tir.schedule.analysis import has_block from tvm.tir.tensor_intrin.hexagon import 
VRMPY_u8i8i32_INTRIN, VRMPY_u8u8i32_INTRIN from ..infrastructure import get_hexagon_target @@ -206,9 +207,9 @@ def _schedule_packed_8x8x32_conv2d(): def schedule_fn(sch, conv2d_block: Optional[BlockRV] = None) -> bool: if conv2d_block is None: - try: + if has_block("conv2d_NCHWc_int8"): conv2d_block = sch.get_block("conv2d_NCHWc_int8") - except ValueError: + else: return False assert "conv2d_NCHWc_int8" in sch.get(conv2d_block).annotations["schedule_rule"] diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_inline_constant_scalars.py b/tests/python/unittest/test_meta_schedule_schedule_rule_inline_constant_scalars.py new file mode 100644 index 000000000000..df26cbe459e0 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_inline_constant_scalars.py @@ -0,0 +1,123 @@ +import pytest +import tvm + +from tvm.script import tir as T +from tvm import meta_schedule as ms +from tvm.tir import Schedule + + +# fmt: off +@tvm.script.ir_module +class Conv2dInt8: + @T.prim_func + def main(p0: T.Buffer[(16, 14, 14, 256), "int8"], p1: T.Buffer[(1024, 1, 1, 256), "int8"], p2: T.Buffer[(1, 1, 1, 1024), "int32"], p3: T.Buffer[(1, 1, 1, 1024), "int32"], p4: T.Buffer[1024, "int32"], p5: T.Buffer[1024, "int32"], p6: T.Buffer[1024, "int32"], p7: T.Buffer[1, "int32"], p8: T.Buffer[(16, 14, 14, 1024), "int32"], compute: T.Buffer[(16, 14, 14, 1024), "int32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + compile_engine_const = T.alloc_buffer([], dtype="int32") + pad_temp = T.alloc_buffer([16, 14, 14, 256], dtype="int8") + conv2d_nhwc = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + T_subtract = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + T_add = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + compute_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + T_add_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + compute_2 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + T_subtract_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + compute_3 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + T_add_2 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + with T.block("compile_engine_const"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const[()]) + compile_engine_const[()] = 59 + for i0, i1, i2, i3 in T.grid(16, 14, 14, 256): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(p0[i0_1, i1_1, i2_1, i3_1]) + T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = p0[i0_1, i1_1, i2_1, i3_1] + for i0, i1, i2, i3, i4, i5, i6 in T.grid(16, 14, 14, 1024, 1, 1, 256): + with T.block("conv2d_nhwc"): + nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads(pad_temp[nn, yy + ry, xx + rx, rc], p1[ff, ry, rx, rc]) + T.writes(conv2d_nhwc[nn, yy, xx, ff]) + with T.init(): + conv2d_nhwc[nn, yy, xx, ff] = 0 + conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + T.cast(pad_temp[nn, yy + ry, xx + rx, rc], "int32") * T.cast(p1[ff, ry, rx, rc], "int32") + for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(conv2d_nhwc[ax0, ax1, ax2, ax3], p2[0, 0, 0, ax3]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = conv2d_nhwc[ax0, ax1, ax2, ax3] - p2[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024): + with 
T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_subtract[ax0, ax1, ax2, ax3], p3[0, 0, 0, ax3]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] + p3[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024): + with T.block("compute"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2]) + T.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) + compute_1[i0_2, i1_2, i2_2, i3_2] = T.q_multiply_shift_per_axis(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2], 31, False, True, dtype="int32") + for i0_3, i1_3, i2_3, i3_3 in T.grid(16, 14, 14, 1024): + with T.block("T_add_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) + T.reads(compile_engine_const[()], compute_1[ax0, ax1, ax2, ax3]) + T.writes(T_add_1[ax0, ax1, ax2, ax3]) + T_add_1[ax0, ax1, ax2, ax3] = compile_engine_const[()] + compute_1[ax0, ax1, ax2, ax3] + for i0_4, i1_4, i2_4, i3_4 in T.grid(16, 14, 14, 1024): + with T.block("compute_1"): + i0_5, i1_5, i2_5, i3_5 = T.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) + T.reads(T_add_1[i0_5, i1_5, i2_5, i3_5]) + T.writes(compute_2[i0_5, i1_5, i2_5, i3_5]) + compute_2[i0_5, i1_5, i2_5, i3_5] = T.max(T.min(T_add_1[i0_5, i1_5, i2_5, i3_5], 255), 0) + for i0_6, i1_6, i2_6, i3_6 in T.grid(16, 14, 14, 1024): + with T.block("T_subtract_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) + T.reads(compute_2[ax0, ax1, ax2, ax3], p7[0]) + T.writes(T_subtract_1[ax0, ax1, ax2, ax3]) + T_subtract_1[ax0, ax1, ax2, ax3] = compute_2[ax0, ax1, ax2, ax3] - p7[0] + for i0_7, i1_7, i2_7, i3_7 in T.grid(16, 14, 14, 1024): + with T.block("compute_2"): + i0_8, i1_8, i2_8, i3_8 = T.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7]) + T.reads(T_subtract_1[i0_8, i1_8, i2_8, i3_8]) + T.writes(compute_3[i0_8, i1_8, i2_8, i3_8]) + compute_3[i0_8, i1_8, i2_8, i3_8] = T.q_multiply_shift(T_subtract_1[i0_8, i1_8, i2_8, i3_8], 1408572815, 31, 1, dtype="int32") + for i0_9, i1_9, i2_9, i3_9 in T.grid(16, 14, 14, 1024): + with T.block("T_add_2"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_9, i1_9, i2_9, i3_9]) + T.reads(compute_3[ax0, ax1, ax2, ax3], p8[ax0, ax1, ax2, ax3]) + T.writes(T_add_2[ax0, ax1, ax2, ax3]) + T_add_2[ax0, ax1, ax2, ax3] = compute_3[ax0, ax1, ax2, ax3] + p8[ax0, ax1, ax2, ax3] + for i0_10, i1_10, i2_10, i3_10 in T.grid(16, 14, 14, 1024): + with T.block("compute_3"): + i0_11, i1_11, i2_11, i3_11 = T.axis.remap("SSSS", [i0_10, i1_10, i2_10, i3_10]) + T.reads(T_add_2[i0_11, i1_11, i2_11, i3_11]) + T.writes(compute[i0_11, i1_11, i2_11, i3_11]) + compute[i0_11, i1_11, i2_11, i3_11] = T.max(T.min(T_add_2[i0_11, i1_11, i2_11, i3_11], 255), 0) + +# fmt: on + + +def test_conv2d_int8(): + sch = Schedule(Conv2dInt8) + + conv2d = sch.get_block("conv2d_nhwc") + sch.cache_write(conv2d, 0, "shared") + + with pytest.raises(tvm.tir.ScheduleError) as e: + sch.reverse_compute_inline(sch.get_block("T_add_1")) + + err_msg = "The block is only allowed to read a single buffer region, but it reads 2 region(s)" + assert err_msg in str(e) + + ms.schedule_rule.InlineConstantScalars().apply(sch, sch.get_block("compile_engine_const")) + sch.reverse_compute_inline(sch.get_block("T_add_1")) + + +if __name__ == "__main__": + tvm.testing.main() From 11f1f9635406616096689c2856d3985d3335ae71 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 10 Nov 2022 05:11:49 +0900 Subject: [PATCH 2/4] add doc --- 
include/tvm/meta_schedule/schedule_rule.h | 5 ++++- .../tvm/meta_schedule/schedule_rule/inline_const_scalars.py | 6 +++++- src/meta_schedule/schedule_rule/inline_const_scalar.cc | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index fe6b8fda895c..70dec47e60bd 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -127,7 +127,10 @@ class ScheduleRule : public runtime::ObjectRef { Optional> disallow_op); /*! - * \brief TODO + * \brief Inline blocks that produce a constant scalar. Such blocks get in the way of + * ReverseComputeInline during AutoInline, since they are also counted as a producer block + * unless they are inlined first. So it is recommended to run InlineConstantScalars before + * AutoInline. * \return The schedule rule created */ TVM_DLL static ScheduleRule InlineConstantScalars(); diff --git a/python/tvm/meta_schedule/schedule_rule/inline_const_scalars.py b/python/tvm/meta_schedule/schedule_rule/inline_const_scalars.py index 6072f2bdefbf..20b59457ea13 100644 --- a/python/tvm/meta_schedule/schedule_rule/inline_const_scalars.py +++ b/python/tvm/meta_schedule/schedule_rule/inline_const_scalars.py @@ -23,7 +23,11 @@ @register_object("meta_schedule.InlineConstantScalars") class InlineConstantScalars(ScheduleRule): - """ + """Inline blocks that produce a constant scalar. + + Such blocks get in the way of ReverseComputeInline during AutoInline, since they are also + counted as a producer block unless they are inlined first. So it is recommended to run + InlineConstantScalars before AutoInline. """ def __init__( diff --git a/src/meta_schedule/schedule_rule/inline_const_scalar.cc b/src/meta_schedule/schedule_rule/inline_const_scalar.cc index b84ac66080c2..8bb9a917366c 100644 --- a/src/meta_schedule/schedule_rule/inline_const_scalar.cc +++ b/src/meta_schedule/schedule_rule/inline_const_scalar.cc @@ -21,7 +21,7 @@ namespace tvm { namespace meta_schedule { -/*! \brief The rule that inlines spatial blocks if it satisfies some conditions. */ +/*! \brief Inline blocks that produce a constant scalar. 
*/ class InlineConstantScalarsNode : public ScheduleRuleNode { public: void InitializeWithTuneContext(const TuneContext& context) final {} From 4b83416369d84a8ef896c6da67c694fd9d05c7b9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 10 Nov 2022 05:17:12 +0900 Subject: [PATCH 3/4] reorg --- .../meta_schedule/schedule_rule/__init__.py | 3 +- .../schedule_rule/auto_inline.py | 17 +++ .../schedule_rule/inline_const_scalars.py | 38 ------ src/meta_schedule/postproc/verify_gpu_code.cc | 2 +- .../schedule_rule/auto_inline.cc | 30 +++++ .../schedule_rule/inline_const_scalar.cc | 56 -------- .../metaschedule_e2e/test_resnet50_int8.py | 2 +- ...meta_schedule_schedule_rule_auto_inline.py | 115 ++++++++++++++++ ...e_schedule_rule_inline_constant_scalars.py | 123 ------------------ 9 files changed, 165 insertions(+), 221 deletions(-) delete mode 100644 python/tvm/meta_schedule/schedule_rule/inline_const_scalars.py delete mode 100644 src/meta_schedule/schedule_rule/inline_const_scalar.cc delete mode 100644 tests/python/unittest/test_meta_schedule_schedule_rule_inline_constant_scalars.py diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py index d3bc5fe9a191..d330fc713991 100644 --- a/python/tvm/meta_schedule/schedule_rule/__init__.py +++ b/python/tvm/meta_schedule/schedule_rule/__init__.py @@ -22,7 +22,7 @@ from .add_rfactor import AddRFactor from .apply_custom_rule import ApplyCustomRule from .auto_bind import AutoBind -from .auto_inline import AutoInline +from .auto_inline import AutoInline, InlineConstantScalars from .cross_thread_reduction import CrossThreadReduction from .multi_level_tiling import ( MultiLevelTiling, @@ -34,4 +34,3 @@ from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll from .random_compute_location import RandomComputeLocation from .schedule_rule import PyScheduleRule, ScheduleRule -from .inline_const_scalars import InlineConstantScalars diff --git a/python/tvm/meta_schedule/schedule_rule/auto_inline.py b/python/tvm/meta_schedule/schedule_rule/auto_inline.py index 22206f3fcc24..c84dbaf89b97 100644 --- a/python/tvm/meta_schedule/schedule_rule/auto_inline.py +++ b/python/tvm/meta_schedule/schedule_rule/auto_inline.py @@ -65,3 +65,20 @@ def __init__( require_ordered, disallow_op, ) + + +@register_object("meta_schedule.InlineConstantScalars") +class InlineConstantScalars(ScheduleRule): + """Inline blocks that produce a constant scalar. + + Such blocks get in the way of ReverseComputeInline during AutoInline, since they are also + counted as a producer block unless they are inlined first. So it is recommended to run + InlineConstantScalars before AutoInline. + """ + + def __init__( + self, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleInlineConstantScalars, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/schedule_rule/inline_const_scalars.py b/python/tvm/meta_schedule/schedule_rule/inline_const_scalars.py deleted file mode 100644 index 20b59457ea13..000000000000 --- a/python/tvm/meta_schedule/schedule_rule/inline_const_scalars.py +++ /dev/null @@ -1,38 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""TODO""" -from tvm._ffi import register_object - -from .. import _ffi_api -from .schedule_rule import ScheduleRule - - -@register_object("meta_schedule.InlineConstantScalars") -class InlineConstantScalars(ScheduleRule): - """Inline blocks that produce a constant scalar. - - Such blocks get in the way of ReverseComputeInline during AutoInline, since they are also - counted as a producer block unless they are inlined first. So it is recommended to run - InlineConstantScalars before AutoInline. - """ - - def __init__( - self, - ) -> None: - self.__init_handle_by_constructor__( - _ffi_api.ScheduleRuleInlineConstantScalars, # type: ignore # pylint: disable=no-member - ) diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index f003e886b82b..ae6f3474bbd6 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -180,7 +180,7 @@ class VerifyGPUCodeNode : public PostprocNode { transform::PassContext pass_ctx = transform::PassContext::Current(); tir::PrimFunc f = WithAttr(GetRef(prim_func), "global_symbol", runtime::String(g_var->name_hint)); - f = WithAttr(f, tvm::attr::kTarget, Target("cuda")); // Required for LowerIntrin + f = WithAttr(f, tvm::attr::kTarget, Target("cuda")); // Required for LowerIntrin bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); if (noalias) { f = WithAttr(std::move(f), "tir.noalias", Bool(true)); diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc index dcdc83f95cb1..63cc24aca123 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -189,5 +189,35 @@ TVM_REGISTER_NODE_TYPE(AutoInlineNode); TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline") .set_body_typed(ScheduleRule::AutoInline); +/*! \brief Inline blocks that produce a constant scalar. 
*/ +class InlineConstantScalarsNode : public ScheduleRuleNode { + public: + void InitializeWithTuneContext(const TuneContext& context) final {} + + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + const std::string block_name = sch->Get(block_rv)->name_hint; + if (block_name.find("compile_engine_const") != std::string::npos) { + sch->ComputeInline(block_rv); + } + return {sch}; + } + + ScheduleRule Clone() const final { + ObjectPtr n = make_object(*this); + return ScheduleRule(n); + } + + static constexpr const char* _type_key = "meta_schedule.InlineConstantScalars"; + TVM_DECLARE_FINAL_OBJECT_INFO(InlineConstantScalarsNode, ScheduleRuleNode); +}; + +ScheduleRule ScheduleRule::InlineConstantScalars() { + ObjectPtr n = make_object(); + return ScheduleRule(n); +} + +TVM_REGISTER_NODE_TYPE(InlineConstantScalarsNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInlineConstantScalars") + .set_body_typed(ScheduleRule::InlineConstantScalars); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/inline_const_scalar.cc b/src/meta_schedule/schedule_rule/inline_const_scalar.cc deleted file mode 100644 index 8bb9a917366c..000000000000 --- a/src/meta_schedule/schedule_rule/inline_const_scalar.cc +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include "../utils.h" - -namespace tvm { -namespace meta_schedule { - -/*! \brief Inline blocks that produce a constant scalar. 
*/ -class InlineConstantScalarsNode : public ScheduleRuleNode { - public: - void InitializeWithTuneContext(const TuneContext& context) final {} - - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { - const std::string block_name = sch->Get(block_rv)->name_hint; - if (block_name.find("compile_engine_const") != std::string::npos) { - sch->ComputeInline(block_rv); - } - return {sch}; - } - - ScheduleRule Clone() const final { - ObjectPtr n = make_object(*this); - return ScheduleRule(n); - } - - static constexpr const char* _type_key = "meta_schedule.InlineConstantScalars"; - TVM_DECLARE_FINAL_OBJECT_INFO(InlineConstantScalarsNode, ScheduleRuleNode); -}; - -ScheduleRule ScheduleRule::InlineConstantScalars() { - ObjectPtr n = make_object(); - return ScheduleRule(n); -} - -TVM_REGISTER_NODE_TYPE(InlineConstantScalarsNode); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInlineConstantScalars") - .set_body_typed(ScheduleRule::InlineConstantScalars); - -} // namespace meta_schedule -} // namespace tvm diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py index 7633d73b1b2f..9edf5877fd5e 100644 --- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py +++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py @@ -207,7 +207,7 @@ def _schedule_packed_8x8x32_conv2d(): def schedule_fn(sch, conv2d_block: Optional[BlockRV] = None) -> bool: if conv2d_block is None: - if has_block("conv2d_NCHWc_int8"): + if has_block(sch, "conv2d_NCHWc_int8"): conv2d_block = sch.get_block("conv2d_NCHWc_int8") else: return False diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py index c17209e2cb77..1baa13793f38 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py @@ -15,7 +15,10 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import pytest + import tvm +from tvm.tir import Schedule from tvm import meta_schedule as ms from tvm.meta_schedule.testing.space_generation import generate_design_space from tvm.script import tir as T @@ -334,6 +337,101 @@ def main(T_full: T.Buffer[(1, 12, 4096), "int64"]) -> None: T.writes(T_full[ax0, ax1, ax2]) T_full[ax0, ax1, ax2] = T.int64(0) + +@tvm.script.ir_module +class Conv2dInt8: + @T.prim_func + def main(p0: T.Buffer[(16, 14, 14, 256), "int8"], p1: T.Buffer[(1024, 1, 1, 256), "int8"], p2: T.Buffer[(1, 1, 1, 1024), "int32"], p3: T.Buffer[(1, 1, 1, 1024), "int32"], p4: T.Buffer[1024, "int32"], p5: T.Buffer[1024, "int32"], p6: T.Buffer[1024, "int32"], p7: T.Buffer[1, "int32"], p8: T.Buffer[(16, 14, 14, 1024), "int32"], compute: T.Buffer[(16, 14, 14, 1024), "int32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + compile_engine_const = T.alloc_buffer([], dtype="int32") + pad_temp = T.alloc_buffer([16, 14, 14, 256], dtype="int8") + conv2d_nhwc = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + T_subtract = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + T_add = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + compute_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + T_add_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + compute_2 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + T_subtract_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + compute_3 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + T_add_2 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") + with T.block("compile_engine_const"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const[()]) + compile_engine_const[()] = 59 + for i0, i1, i2, i3 in T.grid(16, 14, 14, 256): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(p0[i0_1, i1_1, i2_1, i3_1]) + T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = p0[i0_1, i1_1, i2_1, i3_1] + for i0, i1, i2, i3, i4, i5, i6 in T.grid(16, 14, 14, 1024, 1, 1, 256): + with T.block("conv2d_nhwc"): + nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads(pad_temp[nn, yy + ry, xx + rx, rc], p1[ff, ry, rx, rc]) + T.writes(conv2d_nhwc[nn, yy, xx, ff]) + with T.init(): + conv2d_nhwc[nn, yy, xx, ff] = 0 + conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + T.cast(pad_temp[nn, yy + ry, xx + rx, rc], "int32") * T.cast(p1[ff, ry, rx, rc], "int32") + for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(conv2d_nhwc[ax0, ax1, ax2, ax3], p2[0, 0, 0, ax3]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = conv2d_nhwc[ax0, ax1, ax2, ax3] - p2[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_subtract[ax0, ax1, ax2, ax3], p3[0, 0, 0, ax3]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] + p3[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024): + with T.block("compute"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2]) + T.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) + 
compute_1[i0_2, i1_2, i2_2, i3_2] = T.q_multiply_shift_per_axis(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2], 31, False, True, dtype="int32") + for i0_3, i1_3, i2_3, i3_3 in T.grid(16, 14, 14, 1024): + with T.block("T_add_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) + T.reads(compile_engine_const[()], compute_1[ax0, ax1, ax2, ax3]) + T.writes(T_add_1[ax0, ax1, ax2, ax3]) + T_add_1[ax0, ax1, ax2, ax3] = compile_engine_const[()] + compute_1[ax0, ax1, ax2, ax3] + for i0_4, i1_4, i2_4, i3_4 in T.grid(16, 14, 14, 1024): + with T.block("compute_1"): + i0_5, i1_5, i2_5, i3_5 = T.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) + T.reads(T_add_1[i0_5, i1_5, i2_5, i3_5]) + T.writes(compute_2[i0_5, i1_5, i2_5, i3_5]) + compute_2[i0_5, i1_5, i2_5, i3_5] = T.max(T.min(T_add_1[i0_5, i1_5, i2_5, i3_5], 255), 0) + for i0_6, i1_6, i2_6, i3_6 in T.grid(16, 14, 14, 1024): + with T.block("T_subtract_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) + T.reads(compute_2[ax0, ax1, ax2, ax3], p7[0]) + T.writes(T_subtract_1[ax0, ax1, ax2, ax3]) + T_subtract_1[ax0, ax1, ax2, ax3] = compute_2[ax0, ax1, ax2, ax3] - p7[0] + for i0_7, i1_7, i2_7, i3_7 in T.grid(16, 14, 14, 1024): + with T.block("compute_2"): + i0_8, i1_8, i2_8, i3_8 = T.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7]) + T.reads(T_subtract_1[i0_8, i1_8, i2_8, i3_8]) + T.writes(compute_3[i0_8, i1_8, i2_8, i3_8]) + compute_3[i0_8, i1_8, i2_8, i3_8] = T.q_multiply_shift(T_subtract_1[i0_8, i1_8, i2_8, i3_8], 1408572815, 31, 1, dtype="int32") + for i0_9, i1_9, i2_9, i3_9 in T.grid(16, 14, 14, 1024): + with T.block("T_add_2"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_9, i1_9, i2_9, i3_9]) + T.reads(compute_3[ax0, ax1, ax2, ax3], p8[ax0, ax1, ax2, ax3]) + T.writes(T_add_2[ax0, ax1, ax2, ax3]) + T_add_2[ax0, ax1, ax2, ax3] = compute_3[ax0, ax1, ax2, ax3] + p8[ax0, ax1, ax2, ax3] + for i0_10, i1_10, i2_10, i3_10 in T.grid(16, 14, 14, 1024): + with T.block("compute_3"): + i0_11, i1_11, i2_11, i3_11 = T.axis.remap("SSSS", [i0_10, i1_10, i2_10, i3_10]) + T.reads(T_add_2[i0_11, i1_11, i2_11, i3_11]) + T.writes(compute[i0_11, i1_11, i2_11, i3_11]) + compute[i0_11, i1_11, i2_11, i3_11] = T.max(T.min(T_add_2[i0_11, i1_11, i2_11, i3_11], 255), 0) + + # pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks # fmt: on @@ -398,9 +496,26 @@ def test_inline_constant_tensor(): tvm.ir.assert_structural_equal(lhs=space.mod, rhs=ConstConsumer) +def test_conv2d_int8_inline_constant_scalars(): + sch = Schedule(Conv2dInt8) + + conv2d = sch.get_block("conv2d_nhwc") + sch.cache_write(conv2d, 0, "shared") + + with pytest.raises(tvm.tir.ScheduleError) as e: + sch.reverse_compute_inline(sch.get_block("T_add_1")) + + err_msg = "The block is only allowed to read a single buffer region, but it reads 2 region(s)" + assert err_msg in str(e) + + ms.schedule_rule.InlineConstantScalars().apply(sch, sch.get_block("compile_engine_const")) + sch.reverse_compute_inline(sch.get_block("T_add_1")) + + if __name__ == "__main__": test_inline_consumer_chain() test_inline_into_cache() test_inline_into_multiple_consumers() test_inline_pure_spatial() test_inline_constant_tensor() + test_conv2d_int8_inline_constant_scalars() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_inline_constant_scalars.py b/tests/python/unittest/test_meta_schedule_schedule_rule_inline_constant_scalars.py deleted file mode 100644 index 
df26cbe459e0..000000000000 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_inline_constant_scalars.py +++ /dev/null @@ -1,123 +0,0 @@ -import pytest -import tvm - -from tvm.script import tir as T -from tvm import meta_schedule as ms -from tvm.tir import Schedule - - -# fmt: off -@tvm.script.ir_module -class Conv2dInt8: - @T.prim_func - def main(p0: T.Buffer[(16, 14, 14, 256), "int8"], p1: T.Buffer[(1024, 1, 1, 256), "int8"], p2: T.Buffer[(1, 1, 1, 1024), "int32"], p3: T.Buffer[(1, 1, 1, 1024), "int32"], p4: T.Buffer[1024, "int32"], p5: T.Buffer[1024, "int32"], p6: T.Buffer[1024, "int32"], p7: T.Buffer[1, "int32"], p8: T.Buffer[(16, 14, 14, 1024), "int32"], compute: T.Buffer[(16, 14, 14, 1024), "int32"]) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - compile_engine_const = T.alloc_buffer([], dtype="int32") - pad_temp = T.alloc_buffer([16, 14, 14, 256], dtype="int8") - conv2d_nhwc = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") - T_subtract = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") - T_add = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") - compute_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") - T_add_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") - compute_2 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") - T_subtract_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") - compute_3 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") - T_add_2 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") - with T.block("compile_engine_const"): - vi = T.axis.spatial(1, 0) - T.reads() - T.writes(compile_engine_const[()]) - compile_engine_const[()] = 59 - for i0, i1, i2, i3 in T.grid(16, 14, 14, 256): - with T.block("pad_temp"): - i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(p0[i0_1, i1_1, i2_1, i3_1]) - T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) - pad_temp[i0_1, i1_1, i2_1, i3_1] = p0[i0_1, i1_1, i2_1, i3_1] - for i0, i1, i2, i3, i4, i5, i6 in T.grid(16, 14, 14, 1024, 1, 1, 256): - with T.block("conv2d_nhwc"): - nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) - T.reads(pad_temp[nn, yy + ry, xx + rx, rc], p1[ff, ry, rx, rc]) - T.writes(conv2d_nhwc[nn, yy, xx, ff]) - with T.init(): - conv2d_nhwc[nn, yy, xx, ff] = 0 - conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + T.cast(pad_temp[nn, yy + ry, xx + rx, rc], "int32") * T.cast(p1[ff, ry, rx, rc], "int32") - for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024): - with T.block("T_subtract"): - ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(conv2d_nhwc[ax0, ax1, ax2, ax3], p2[0, 0, 0, ax3]) - T.writes(T_subtract[ax0, ax1, ax2, ax3]) - T_subtract[ax0, ax1, ax2, ax3] = conv2d_nhwc[ax0, ax1, ax2, ax3] - p2[0, 0, 0, ax3] - for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024): - with T.block("T_add"): - ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_subtract[ax0, ax1, ax2, ax3], p3[0, 0, 0, ax3]) - T.writes(T_add[ax0, ax1, ax2, ax3]) - T_add[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] + p3[0, 0, 0, ax3] - for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024): - with T.block("compute"): - i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2]) - T.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) - compute_1[i0_2, i1_2, i2_2, i3_2] = T.q_multiply_shift_per_axis(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2], 31, False, True, dtype="int32") - 
for i0_3, i1_3, i2_3, i3_3 in T.grid(16, 14, 14, 1024): - with T.block("T_add_1"): - ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) - T.reads(compile_engine_const[()], compute_1[ax0, ax1, ax2, ax3]) - T.writes(T_add_1[ax0, ax1, ax2, ax3]) - T_add_1[ax0, ax1, ax2, ax3] = compile_engine_const[()] + compute_1[ax0, ax1, ax2, ax3] - for i0_4, i1_4, i2_4, i3_4 in T.grid(16, 14, 14, 1024): - with T.block("compute_1"): - i0_5, i1_5, i2_5, i3_5 = T.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) - T.reads(T_add_1[i0_5, i1_5, i2_5, i3_5]) - T.writes(compute_2[i0_5, i1_5, i2_5, i3_5]) - compute_2[i0_5, i1_5, i2_5, i3_5] = T.max(T.min(T_add_1[i0_5, i1_5, i2_5, i3_5], 255), 0) - for i0_6, i1_6, i2_6, i3_6 in T.grid(16, 14, 14, 1024): - with T.block("T_subtract_1"): - ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) - T.reads(compute_2[ax0, ax1, ax2, ax3], p7[0]) - T.writes(T_subtract_1[ax0, ax1, ax2, ax3]) - T_subtract_1[ax0, ax1, ax2, ax3] = compute_2[ax0, ax1, ax2, ax3] - p7[0] - for i0_7, i1_7, i2_7, i3_7 in T.grid(16, 14, 14, 1024): - with T.block("compute_2"): - i0_8, i1_8, i2_8, i3_8 = T.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7]) - T.reads(T_subtract_1[i0_8, i1_8, i2_8, i3_8]) - T.writes(compute_3[i0_8, i1_8, i2_8, i3_8]) - compute_3[i0_8, i1_8, i2_8, i3_8] = T.q_multiply_shift(T_subtract_1[i0_8, i1_8, i2_8, i3_8], 1408572815, 31, 1, dtype="int32") - for i0_9, i1_9, i2_9, i3_9 in T.grid(16, 14, 14, 1024): - with T.block("T_add_2"): - ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_9, i1_9, i2_9, i3_9]) - T.reads(compute_3[ax0, ax1, ax2, ax3], p8[ax0, ax1, ax2, ax3]) - T.writes(T_add_2[ax0, ax1, ax2, ax3]) - T_add_2[ax0, ax1, ax2, ax3] = compute_3[ax0, ax1, ax2, ax3] + p8[ax0, ax1, ax2, ax3] - for i0_10, i1_10, i2_10, i3_10 in T.grid(16, 14, 14, 1024): - with T.block("compute_3"): - i0_11, i1_11, i2_11, i3_11 = T.axis.remap("SSSS", [i0_10, i1_10, i2_10, i3_10]) - T.reads(T_add_2[i0_11, i1_11, i2_11, i3_11]) - T.writes(compute[i0_11, i1_11, i2_11, i3_11]) - compute[i0_11, i1_11, i2_11, i3_11] = T.max(T.min(T_add_2[i0_11, i1_11, i2_11, i3_11], 255), 0) - -# fmt: on - - -def test_conv2d_int8(): - sch = Schedule(Conv2dInt8) - - conv2d = sch.get_block("conv2d_nhwc") - sch.cache_write(conv2d, 0, "shared") - - with pytest.raises(tvm.tir.ScheduleError) as e: - sch.reverse_compute_inline(sch.get_block("T_add_1")) - - err_msg = "The block is only allowed to read a single buffer region, but it reads 2 region(s)" - assert err_msg in str(e) - - ms.schedule_rule.InlineConstantScalars().apply(sch, sch.get_block("compile_engine_const")) - sch.reverse_compute_inline(sch.get_block("T_add_1")) - - -if __name__ == "__main__": - tvm.testing.main() From f398453b035fea9b03e40627030603111582b396 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 10 Nov 2022 15:46:42 +0900 Subject: [PATCH 4/4] identify constant block by its structure, not by name --- src/meta_schedule/schedule_rule/auto_inline.cc | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc index 63cc24aca123..d2d48b9008ce 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -195,8 +195,15 @@ class InlineConstantScalarsNode : public ScheduleRuleNode { void InitializeWithTuneContext(const TuneContext& context) final {} Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { - const std::string block_name = 
sch->Get(block_rv)->name_hint; - if (block_name.find("compile_engine_const") != std::string::npos) { + // Look for a block of the form + // block compile_engine_const(iter_var(vi, range(min=0, ext=1))) { + // reads([]) + // writes([compile_engine_const[]]) + // compile_engine_const[] = 59 + // } + auto block = sch->Get(block_rv); + if (block->reads.size() == 0 && block->writes.size() == 1 && + block->writes[0]->buffer->shape.size() == 0) { sch->ComputeInline(block_rv); } return {sch};
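
Notes (illustrative, not part of the commit): the structural check above
accepts any block with no reads and a single write to a zero-dimensional
buffer, regardless of its name. A self-contained sketch of a case the old
name-based check would have missed (block and buffer names here are invented
for illustration; the rule is applied through the same Python API the unit
tests use):

    import tvm
    from tvm.script import tir as T
    from tvm import meta_schedule as ms
    from tvm.tir import Schedule


    @tvm.script.ir_module
    class NamedConst:
        @T.prim_func
        def main(out: T.Buffer[(4,), "int32"]) -> None:
            T.func_attr({"global_symbol": "main", "tir.noalias": True})
            # The producer's name does not contain "compile_engine_const",
            # so the old check skipped it; structurally it is still a
            # constant scalar block (empty reads, one 0-d write).
            some_const = T.alloc_buffer([], dtype="int32")
            with T.block("some_const"):
                vi = T.axis.spatial(1, 0)
                T.reads()
                T.writes(some_const[()])
                some_const[()] = 59
            for i in T.serial(4):
                with T.block("add"):
                    vj = T.axis.spatial(4, i)
                    T.reads(some_const[()])
                    T.writes(out[vj])
                    out[vj] = some_const[()] + vj


    sch = Schedule(NamedConst)
    ms.schedule_rule.InlineConstantScalars().apply(sch, sch.get_block("some_const"))
    # "some_const" has been inlined; block "add" now computes 59 + vj directly.
    print(sch.mod.script())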