10 changes: 10 additions & 0 deletions include/tvm/meta_schedule/schedule_rule.h
@@ -125,6 +125,16 @@ class ScheduleRule : public runtime::ObjectRef {
bool require_injective, //
bool require_ordered, //
Optional<Array<String>> disallow_op);

/*!
* \brief Inline blocks that produce a constant scalar. Such blocks get in the way of
 * ReverseComputeInline during AutoInline, since they are also counted as producer blocks
 * unless they are inlined first. It is therefore recommended to run InlineConstantScalars
 * before AutoInline.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule InlineConstantScalars();

/*!
* \brief Create a mega rule: multi-level tiling with data reuse
* \param structure The tiling structure. Recommended:
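For readers following along, a minimal sketch of driving the new rule by hand from Python, mirroring the unit test added at the bottom of this PR (`Conv2dInt8` and the `compile_engine_const` block name come from that test):

```python
from tvm import meta_schedule as ms
from tvm.tir import Schedule

# Conv2dInt8 is the TVMScript module from the unit test below; its
# "compile_engine_const" block writes the constant 59 into a 0-d buffer.
sch = Schedule(Conv2dInt8)
const_block = sch.get_block("compile_engine_const")

# Inlining the constant-scalar block first means a later
# reverse_compute_inline on its consumer sees only one producer.
ms.schedule_rule.InlineConstantScalars().apply(sch, const_block)
```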
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -22,7 +22,7 @@
from .add_rfactor import AddRFactor
from .apply_custom_rule import ApplyCustomRule
from .auto_bind import AutoBind
-from .auto_inline import AutoInline
+from .auto_inline import AutoInline, InlineConstantScalars
from .cross_thread_reduction import CrossThreadReduction
from .multi_level_tiling import (
MultiLevelTiling,
17 changes: 17 additions & 0 deletions python/tvm/meta_schedule/schedule_rule/auto_inline.py
@@ -65,3 +65,20 @@ def __init__(
require_ordered,
disallow_op,
)


@register_object("meta_schedule.InlineConstantScalars")
class InlineConstantScalars(ScheduleRule):
"""Inline blocks that produce a constant scalar.

    Such blocks get in the way of ReverseComputeInline during AutoInline, since they are also
    counted as producer blocks unless they are inlined first. It is therefore recommended to
    run InlineConstantScalars before AutoInline.
"""

def __init__(
self,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ScheduleRuleInlineConstantScalars, # type: ignore # pylint: disable=no-member
)
2 changes: 2 additions & 0 deletions src/meta_schedule/postproc/verify_gpu_code.cc
@@ -175,10 +175,12 @@ class VerifyGPUCodeNode : public PostprocNode {
pass_list.push_back(tir::transform::InjectDoubleBuffer());
pass_list.push_back(tir::transform::StorageRewrite());
pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
pass_list.push_back(tir::transform::LowerIntrin());
// Convert Function to IRModule
transform::PassContext pass_ctx = transform::PassContext::Current();
tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
runtime::String(g_var->name_hint));
f = WithAttr(f, tvm::attr::kTarget, Target("cuda")); // Required for LowerIntrin
bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
if (noalias) {
f = WithAttr(std::move(f), "tir.noalias", Bool(true));
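The comment above notes that LowerIntrin requires a target on the function. A hedged Python sketch of the equivalent preparation (`prepare_for_lower_intrin` is an illustrative helper, not part of this PR):

```python
import tvm
from tvm import tir

def prepare_for_lower_intrin(f: tir.PrimFunc) -> tvm.IRModule:
    # Mirror the C++ postproc: attach a global symbol and a CUDA target,
    # since LowerIntrin dispatches on the function's "target" attribute.
    f = f.with_attr("global_symbol", "main")
    f = f.with_attr("target", tvm.target.Target("cuda"))
    return tir.transform.LowerIntrin()(tvm.IRModule({"main": f}))
```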
37 changes: 37 additions & 0 deletions src/meta_schedule/schedule_rule/auto_inline.cc
@@ -189,5 +189,42 @@ TVM_REGISTER_NODE_TYPE(AutoInlineNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline")
.set_body_typed(ScheduleRule::AutoInline);

/*! \brief Inline blocks that produce a constant scalar. */
class InlineConstantScalarsNode : public ScheduleRuleNode {
public:
void InitializeWithTuneContext(const TuneContext& context) final {}

Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
// Look for a block of the form
// block compile_engine_const(iter_var(vi, range(min=0, ext=1))) {
// reads([])
// writes([compile_engine_const[]])
// compile_engine_const[] = 59
// }
auto block = sch->Get(block_rv);
if (block->reads.size() == 0 && block->writes.size() == 1 &&
block->writes[0]->buffer->shape.size() == 0) {
sch->ComputeInline(block_rv);
}
return {sch};
}

ScheduleRule Clone() const final {
ObjectPtr<InlineConstantScalarsNode> n = make_object<InlineConstantScalarsNode>(*this);
return ScheduleRule(n);
}

static constexpr const char* _type_key = "meta_schedule.InlineConstantScalars";
TVM_DECLARE_FINAL_OBJECT_INFO(InlineConstantScalarsNode, ScheduleRuleNode);
};

ScheduleRule ScheduleRule::InlineConstantScalars() {
ObjectPtr<InlineConstantScalarsNode> n = make_object<InlineConstantScalarsNode>();
return ScheduleRule(n);
}

TVM_REGISTER_NODE_TYPE(InlineConstantScalarsNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInlineConstantScalars")
.set_body_typed(ScheduleRule::InlineConstantScalars);
} // namespace meta_schedule
} // namespace tvm
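For reference, the same matching logic can be sketched as a Python `PyScheduleRule` (assuming the `derived_object`/`PyScheduleRule` extension points in `tvm.meta_schedule`); the C++ rule above is what actually ships:

```python
from typing import List

from tvm import tir
from tvm.meta_schedule.schedule_rule import PyScheduleRule
from tvm.meta_schedule.utils import derived_object

@derived_object
class PyInlineConstantScalars(PyScheduleRule):
    """Python mirror of InlineConstantScalarsNode::Apply."""

    def _initialize_with_tune_context(self, context) -> None:
        pass

    def apply(self, sch: tir.Schedule, block: tir.schedule.BlockRV) -> List[tir.Schedule]:
        block_stmt = sch.get(block)
        # Match a block with no reads and a single write to a 0-d buffer,
        # i.e. one that stores a constant scalar.
        if (
            len(block_stmt.reads) == 0
            and len(block_stmt.writes) == 1
            and len(block_stmt.writes[0].buffer.shape) == 0
        ):
            sch.compute_inline(block)
        return [sch]

    def clone(self) -> "PyInlineConstantScalars":
        return PyInlineConstantScalars()
```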
3 changes: 3 additions & 0 deletions src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -54,6 +54,7 @@ ScheduleRule ScheduleRule::PyScheduleRule(
Array<ScheduleRule> ScheduleRule::DefaultLLVM() {
return {
ScheduleRule::ApplyCustomRule(),
ScheduleRule::InlineConstantScalars(),
ScheduleRule::AutoInline(
/*into_producer=*/false,
/*into_consumer=*/true,
@@ -100,6 +101,7 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDA() {
Map<String, ObjectRef>{{"req", String("must")},
{"levels", Array<Integer>{3}}, //
{"scope", String("local")}}),
ScheduleRule::InlineConstantScalars(),
ScheduleRule::AutoInline(
/*into_producer=*/true,
/*into_consumer=*/true,
@@ -178,6 +180,7 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
Array<ScheduleRule> ScheduleRule::DefaultHexagon() {
return {
ScheduleRule::ApplyCustomRule(),
ScheduleRule::InlineConstantScalars(),
ScheduleRule::AutoInline(
/*into_producer=*/false,
/*into_consumer=*/true,
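The ordering in these default lists is the point: InlineConstantScalars must run before AutoInline. A sketch of hand-assembling the same pair (the AutoInline arguments are illustrative; the exact defaults are elided in the diff above):

```python
from tvm import meta_schedule as ms

rules = [
    ms.schedule_rule.InlineConstantScalars(),  # must precede AutoInline
    ms.schedule_rule.AutoInline(
        into_producer=False,
        into_consumer=True,
        inline_const_tensor=True,
        disallow_if_then_else=True,
        require_injective=True,
        require_ordered=True,
        disallow_op=None,
    ),
]
```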
13 changes: 13 additions & 0 deletions src/tir/analysis/verify_gpu_code.cc
@@ -209,6 +209,19 @@ class GPUCodeVerifier : public StmtExprVisitor {
}
}

void VisitExpr_(const CastNode* op) {
if (op->dtype.lanes() > 1) {
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
<< op->dtype.bytes() << ") for dtype " << op->dtype
<< " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
errors_.push_back(s.str());
}
}
ExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const BufferLoadNode* op) {
if (op->dtype.lanes() > 1) {
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
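A hedged sketch of exercising this new check from Python, assuming `tvm.tir.analysis.verify_gpu_code` and a kernel containing a vectorized cast (`prim_func` is a placeholder, not defined here):

```python
import tvm
from tvm import tir

def check_vector_bytes(prim_func: tir.PrimFunc) -> bool:
    # prim_func is assumed to contain a vectorized cast, e.g. one producing
    # int32x8 (8 lanes * 4 bytes = 32 bytes).
    valid = tvm.tir.analysis.verify_gpu_code(prim_func, {"max_vector_bytes": 16})
    # With the CastNode visitor above, the 32-byte cast is now reported
    # instead of slipping through to codegen.
    return valid
```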
@@ -33,6 +33,7 @@
)
from tvm.meta_schedule import postproc, schedule_rule
from tvm.tir.schedule import BlockRV, Schedule
from tvm.tir.schedule.analysis import has_block
from tvm.tir.tensor_intrin.hexagon import VRMPY_u8i8i32_INTRIN, VRMPY_u8u8i32_INTRIN

from ..infrastructure import get_hexagon_target
@@ -206,9 +207,9 @@ def _schedule_packed_8x8x32_conv2d():

def schedule_fn(sch, conv2d_block: Optional[BlockRV] = None) -> bool:
if conv2d_block is None:
-        try:
+        if has_block(sch, "conv2d_NCHWc_int8"):
            conv2d_block = sch.get_block("conv2d_NCHWc_int8")
-        except ValueError:
+        else:
return False

assert "conv2d_NCHWc_int8" in sch.get(conv2d_block).annotations["schedule_rule"]
115 changes: 115 additions & 0 deletions tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
@@ -15,7 +15,10 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import pytest

import tvm
from tvm.tir import Schedule
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.space_generation import generate_design_space
from tvm.script import tir as T
@@ -334,6 +337,101 @@ def main(T_full: T.Buffer[(1, 12, 4096), "int64"]) -> None:
T.writes(T_full[ax0, ax1, ax2])
T_full[ax0, ax1, ax2] = T.int64(0)


@tvm.script.ir_module
class Conv2dInt8:
@T.prim_func
def main(p0: T.Buffer[(16, 14, 14, 256), "int8"], p1: T.Buffer[(1024, 1, 1, 256), "int8"], p2: T.Buffer[(1, 1, 1, 1024), "int32"], p3: T.Buffer[(1, 1, 1, 1024), "int32"], p4: T.Buffer[1024, "int32"], p5: T.Buffer[1024, "int32"], p6: T.Buffer[1024, "int32"], p7: T.Buffer[1, "int32"], p8: T.Buffer[(16, 14, 14, 1024), "int32"], compute: T.Buffer[(16, 14, 14, 1024), "int32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with T.block("root")
compile_engine_const = T.alloc_buffer([], dtype="int32")
pad_temp = T.alloc_buffer([16, 14, 14, 256], dtype="int8")
conv2d_nhwc = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
T_subtract = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
T_add = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
compute_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
T_add_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
compute_2 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
T_subtract_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
compute_3 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
T_add_2 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
with T.block("compile_engine_const"):
vi = T.axis.spatial(1, 0)
T.reads()
T.writes(compile_engine_const[()])
compile_engine_const[()] = 59
for i0, i1, i2, i3 in T.grid(16, 14, 14, 256):
with T.block("pad_temp"):
i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(p0[i0_1, i1_1, i2_1, i3_1])
T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1])
pad_temp[i0_1, i1_1, i2_1, i3_1] = p0[i0_1, i1_1, i2_1, i3_1]
for i0, i1, i2, i3, i4, i5, i6 in T.grid(16, 14, 14, 1024, 1, 1, 256):
with T.block("conv2d_nhwc"):
nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
T.reads(pad_temp[nn, yy + ry, xx + rx, rc], p1[ff, ry, rx, rc])
T.writes(conv2d_nhwc[nn, yy, xx, ff])
with T.init():
conv2d_nhwc[nn, yy, xx, ff] = 0
conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + T.cast(pad_temp[nn, yy + ry, xx + rx, rc], "int32") * T.cast(p1[ff, ry, rx, rc], "int32")
for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024):
with T.block("T_subtract"):
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(conv2d_nhwc[ax0, ax1, ax2, ax3], p2[0, 0, 0, ax3])
T.writes(T_subtract[ax0, ax1, ax2, ax3])
T_subtract[ax0, ax1, ax2, ax3] = conv2d_nhwc[ax0, ax1, ax2, ax3] - p2[0, 0, 0, ax3]
for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024):
with T.block("T_add"):
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(T_subtract[ax0, ax1, ax2, ax3], p3[0, 0, 0, ax3])
T.writes(T_add[ax0, ax1, ax2, ax3])
T_add[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] + p3[0, 0, 0, ax3]
for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024):
with T.block("compute"):
i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2])
T.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
compute_1[i0_2, i1_2, i2_2, i3_2] = T.q_multiply_shift_per_axis(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2], 31, False, True, dtype="int32")
for i0_3, i1_3, i2_3, i3_3 in T.grid(16, 14, 14, 1024):
with T.block("T_add_1"):
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
T.reads(compile_engine_const[()], compute_1[ax0, ax1, ax2, ax3])
T.writes(T_add_1[ax0, ax1, ax2, ax3])
T_add_1[ax0, ax1, ax2, ax3] = compile_engine_const[()] + compute_1[ax0, ax1, ax2, ax3]
for i0_4, i1_4, i2_4, i3_4 in T.grid(16, 14, 14, 1024):
with T.block("compute_1"):
i0_5, i1_5, i2_5, i3_5 = T.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
T.reads(T_add_1[i0_5, i1_5, i2_5, i3_5])
T.writes(compute_2[i0_5, i1_5, i2_5, i3_5])
compute_2[i0_5, i1_5, i2_5, i3_5] = T.max(T.min(T_add_1[i0_5, i1_5, i2_5, i3_5], 255), 0)
for i0_6, i1_6, i2_6, i3_6 in T.grid(16, 14, 14, 1024):
with T.block("T_subtract_1"):
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
T.reads(compute_2[ax0, ax1, ax2, ax3], p7[0])
T.writes(T_subtract_1[ax0, ax1, ax2, ax3])
T_subtract_1[ax0, ax1, ax2, ax3] = compute_2[ax0, ax1, ax2, ax3] - p7[0]
for i0_7, i1_7, i2_7, i3_7 in T.grid(16, 14, 14, 1024):
with T.block("compute_2"):
i0_8, i1_8, i2_8, i3_8 = T.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
T.reads(T_subtract_1[i0_8, i1_8, i2_8, i3_8])
T.writes(compute_3[i0_8, i1_8, i2_8, i3_8])
compute_3[i0_8, i1_8, i2_8, i3_8] = T.q_multiply_shift(T_subtract_1[i0_8, i1_8, i2_8, i3_8], 1408572815, 31, 1, dtype="int32")
for i0_9, i1_9, i2_9, i3_9 in T.grid(16, 14, 14, 1024):
with T.block("T_add_2"):
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_9, i1_9, i2_9, i3_9])
T.reads(compute_3[ax0, ax1, ax2, ax3], p8[ax0, ax1, ax2, ax3])
T.writes(T_add_2[ax0, ax1, ax2, ax3])
T_add_2[ax0, ax1, ax2, ax3] = compute_3[ax0, ax1, ax2, ax3] + p8[ax0, ax1, ax2, ax3]
for i0_10, i1_10, i2_10, i3_10 in T.grid(16, 14, 14, 1024):
with T.block("compute_3"):
i0_11, i1_11, i2_11, i3_11 = T.axis.remap("SSSS", [i0_10, i1_10, i2_10, i3_10])
T.reads(T_add_2[i0_11, i1_11, i2_11, i3_11])
T.writes(compute[i0_11, i1_11, i2_11, i3_11])
compute[i0_11, i1_11, i2_11, i3_11] = T.max(T.min(T_add_2[i0_11, i1_11, i2_11, i3_11], 255), 0)


# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
# fmt: on

@@ -398,9 +496,26 @@ def test_inline_constant_tensor():
tvm.ir.assert_structural_equal(lhs=space.mod, rhs=ConstConsumer)


def test_conv2d_int8_inline_constant_scalars():
sch = Schedule(Conv2dInt8)

conv2d = sch.get_block("conv2d_nhwc")
sch.cache_write(conv2d, 0, "shared")

with pytest.raises(tvm.tir.ScheduleError) as e:
sch.reverse_compute_inline(sch.get_block("T_add_1"))

err_msg = "The block is only allowed to read a single buffer region, but it reads 2 region(s)"
    assert err_msg in str(e.value)

ms.schedule_rule.InlineConstantScalars().apply(sch, sch.get_block("compile_engine_const"))
sch.reverse_compute_inline(sch.get_block("T_add_1"))


if __name__ == "__main__":
test_inline_consumer_chain()
test_inline_into_cache()
test_inline_into_multiple_consumers()
test_inline_pure_spatial()
test_inline_constant_tensor()
test_conv2d_int8_inline_constant_scalars()