diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h
index 6f26d07dc8a5..17aedbcff308 100644
--- a/include/tvm/te/schedule.h
+++ b/include/tvm/te/schedule.h
@@ -251,6 +251,11 @@ class Stage : public ObjectRef {
    * \return reference to self.
    */
   TVM_DLL Stage& double_buffer();  // NOLINT(*)
+  /*!
+   * \brief Compute current stage with rolling buffering.
+   * \return reference to self.
+   */
+  TVM_DLL Stage& rolling_buffer();  // NOLINT(*)
   /*!
    * \brief whether the stage has been scheduled.
    * \return whether the stage has been scheduled.
@@ -493,6 +498,8 @@ class StageNode : public Object {
   bool is_output{false};
   /*! \brief Whether apply double buffer optimization to this stage */
   bool double_buffer{false};
+  /*! \brief Whether apply rolling buffer optimization to this stage */
+  bool rolling_buffer{false};
   /*!
    * \brief The parent group of the current stage.
    * The stage cannot be assigned to stages outside the group.
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index cc10c218c8ff..6da879812e55 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1270,6 +1270,8 @@ constexpr const char* double_buffer_scope = "double_buffer_scope";
  * \brief Marks region used by double buffer write
  */
 constexpr const char* double_buffer_write = "double_buffer_write";
+/*! \brief Mark realization for rolling buffer optimization */
+constexpr const char* rolling_buffer_scope = "rolling_buffer_scope";
 /*! \brief Mark of scan update scope */
 constexpr const char* scan_update_scope = "scan_update_scope";
 /*! \brief Mark of scan init scope */
diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py
index a3d0bb656736..8896f23a886c 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -221,6 +221,7 @@ def lower(
         pass_list += [
             tvm.tir.transform.VectorizeLoop(not disable_vectorize),
             tvm.tir.transform.InjectVirtualThread(),
+            tvm.tir.transform.InjectRollingBuffer(),
             tvm.tir.transform.InjectDoubleBuffer(),
             tvm.tir.transform.StorageRewrite(),
             tvm.tir.transform.UnrollLoop(),
diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py
index 7bd7dceb03e5..55d07a57e3e4 100644
--- a/python/tvm/te/schedule.py
+++ b/python/tvm/te/schedule.py
@@ -511,6 +511,14 @@ def double_buffer(self):
         """
         _ffi_api.StageDoubleBuffer(self)
 
+    def rolling_buffer(self):
+        """Compute the current stage via rolling buffering.
+
+        This can only be applied to an intermediate stage.
+        It will change the storage cost of the current stage.
+        """
+        _ffi_api.StageRollingBuffer(self)
+
 
 @tvm._ffi.register_object
 class SpecializedCondition(Object):
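A minimal usage sketch of the new primitive, distilled from the unit tests added below (the tensor names are illustrative, not part of the API): the producer stage is computed at a tiled axis of its consumer and then marked as a rolling buffer, so successive tiles reuse the overlapping rows instead of recomputing them.

    import tvm
    from tvm import te, topi

    A = te.placeholder((1, 12, 12, 16), name="A", dtype="int8")
    pool_a = topi.nn.pool2d(A, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")
    pool_b = topi.nn.pool2d(pool_a, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")

    sch = te.create_schedule([pool_b.op])
    _, _, w, _ = pool_b.op.axis
    wo, _ = sch[pool_b].split(w, 4)          # tile the consumer along w
    sch[pool_a].compute_at(sch[pool_b], wo)  # produce pool_a tile by tile
    sch[pool_a].rolling_buffer()             # roll pool_a over the tiled axis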
diff --git a/python/tvm/tir/transform/inject_rolling_buffer.py b/python/tvm/tir/transform/inject_rolling_buffer.py
new file mode 100644
index 000000000000..f531f88356f0
--- /dev/null
+++ b/python/tvm/tir/transform/inject_rolling_buffer.py
@@ -0,0 +1,238 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict, namedtuple
+import math
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():
+    """Inject rolling buffer statements.
+
+    Rolling buffers are buffers where one of the dimensions has been made into
+    a circular buffer. Two optimizations are implemented in order to accomplish
+    this: sliding window and storage folding. In particular, the sliding window
+    optimization is applied to the entire buffer (to avoid recomputing elements)
+    and storage folding is then applied to just the rolling dimension.
+
+    Rolling buffers must be inside a loop with only part of the buffer used per
+    iteration. The outermost axis will be rolled over.
+
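+    For example (an illustrative sketch, simplified from the unit tests added
+    in this change; the names are not part of any API), rolling a producer
+    whose per-tile extent is 6 with a tile stride of 4 folds its realization
+    down to 6 rows and skips recomputing the 2 overlapping rows::
+
+        # before: realized inside every iteration of ax1_outer
+        realize tensor[ax1_outer*4 : ax1_outer*4 + 6, ...]
+        tensor[ax1 + ax1_outer*4, ...] = ...
+
+        # after: hoisted out of the loop, folded to 6 rows, overlap skipped
+        realize tensor[0:6, ...]
+        if ax1_outer < 1 or ax1 >= 2:
+            tensor[floormod(ax1 + ax1_outer*4, 6), ...] = ...
+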
+    For more information, see the RFC:
+    https://discuss.tvm.apache.org/t/rfc-introducing-a-rolling-buffer-scheduling-primitive/9836
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    buffer_to_attrs = defaultdict(list)
+    rolling_buffers = set()
+    rolling_buffer_to_info = dict()
+    for_loops = list()
+    hoist_buffer_to_for = defaultdict(list)
+
+    RollingBufferInfo = namedtuple(
+        "RollingBufferInfo", ["rolling_axis", "rolling_extent", "axis_overlaps", "axis_iter_vars"]
+    )
+
+    def _pre_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            for_loops.append(stmt)
+
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if isinstance(stmt.node, tvm.tir.Buffer):
+                if stmt.attr_key == "rolling_buffer_scope" and stmt.value.value:
+                    # If the attribute indicates that a buffer should be a rolling
+                    # buffer, then update the rolling_buffers set to include the buffer
+                    rolling_buffers.add(stmt.node)
+                # Keep a dictionary associating attribute statements with the buffers
+                # they reference. We'll need this if the buffer gets hoisted and we
+                # need to hoist all of its attributes at the same time.
+                buffer_to_attrs[stmt.node].append(stmt)
+
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # If a BufferRealize has been identified as needing to be made into
+                # a rolling buffer, begin the analysis...
+                bound_iter_vars = []
+                bound_overlaps = []
+                # We use the bound information of the BufferRealize to calculate
+                # how we can legally roll
+                for bound in stmt.bounds:
+                    divisor = 1
+                    # Handle the case of fractional strides
+                    # They take this form: floordiv(hh.outer, 2)
+                    # Strip the floordiv and keep track of the divisor
+                    if isinstance(bound.min, tvm.tir.FloorDiv):
+                        divisor = bound.min.b.value
+                        bound.min = bound.min.a
+                    # If the bound is an int, we can't roll over it
+                    if isinstance(bound.min, tvm.tir.IntImm):
+                        iter_var = None
+                        stride = 0
+                    # If the bound is just a Var, that implies the stride is 1
+                    elif isinstance(bound.min, tvm.tir.Var):
+                        iter_var = bound.min
+                        stride = 1
+                    # Otherwise, it's the iter var multiplied by the stride;
+                    # any other form is unsupported, so assert
+                    else:
+                        assert isinstance(
+                            bound.min, tvm.tir.Mul
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        assert isinstance(
+                            bound.min.a, tvm.tir.Var
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        assert isinstance(
+                            bound.min.b, tvm.tir.IntImm
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        iter_var = bound.min.a
+                        stride = bound.min.b.value
+                    stride = math.ceil(stride / divisor)
+                    bound_iter_vars.append(iter_var)
+                    if iter_var is not None:
+                        bound_overlaps.append(bound.extent.value - stride)
+                    else:
+                        bound_overlaps.append(0)
+
+                # Pick the outermost iter_var that's mentioned in the bounds
+                # to be the rolling axis
+                roll_iter_var = None
+                roll_axis = -1
+                for loop in for_loops:
+                    iter_var = loop.loop_var
+                    if iter_var in bound_iter_vars:
+                        roll_iter_var = iter_var
+                        roll_axis = bound_iter_vars.index(iter_var)
+                        break
+
+                # We must have found an axis to roll over
+                assert (
+                    roll_iter_var is not None
+                ), "Rolling buffer injection failed: no rolling axis found"
+                assert roll_axis != -1, "Rolling buffer injection failed: no rolling axis found"
+                rolling_buffer_info = RollingBufferInfo(
+                    roll_axis, stmt.bounds[roll_axis].extent.value, bound_overlaps, bound_iter_vars
+                )
+                rolling_buffer_to_info[stmt.buffer] = rolling_buffer_info
+                new_bounds = []
+                for i, extent in enumerate(stmt.buffer.shape):
+                    if i == rolling_buffer_info.rolling_axis:
+                        new_bounds.append(tvm.ir.Range(rolling_buffer_info.rolling_extent))
+                    else:
+                        new_bounds.append(tvm.ir.Range(extent))
+                new_realize = tvm.tir.BufferRealize(
+                    stmt.buffer, new_bounds, stmt.condition, stmt.body, stmt.span
+                )
+                hoist_buffer_to_for[roll_iter_var].append(new_realize)
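+
+    # _post_visit rewrites the IR bottom-up using the information gathered in
+    # _pre_visit:
+    #   * BufferRealizes of rolling buffers are hoisted to the loop over the
+    #     rolling axis and their rolling dimension is shrunk to rolling_extent
+    #   * stores and loads on the rolling axis are rewritten to
+    #     floormod(index, rolling_extent)
+    #   * stores into the overlapping region are guarded so that elements
+    #     already computed by a previous iteration are not recomputed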
+    def _post_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            for_loops.pop()
+            # If the loop corresponds to an iter_var that needs a BufferRealize
+            # hoisting to its scope, perform the hoisting
+            if stmt.loop_var in hoist_buffer_to_for:
+                body = stmt
+                for realize in hoist_buffer_to_for[stmt.loop_var]:
+                    attrs = buffer_to_attrs[realize.buffer]
+                    new_realize = tvm.tir.BufferRealize(
+                        realize.buffer, realize.bounds, realize.condition, body, realize.span
+                    )
+                    # The attributes attached to the BufferRealize need hoisting too
+                    for attr in attrs:
+                        if attr.attr_key == "rolling_buffer_scope":
+                            continue
+                        new_realize = tvm.tir.AttrStmt(
+                            attr.node, attr.attr_key, attr.value, new_realize, attr.span
+                        )
+                    body = new_realize
+                return body
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if stmt.node in rolling_buffers:
+                # Remove the attribute statements attached to rolling buffers
+                # because they will have been hoisted to the relevant rolling
+                # scope
+                return stmt.body
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # Remove the original BufferRealize for rolling buffers
+                # because they will have been hoisted to the relevant rolling
+                # scope
+                return stmt.body
+        elif isinstance(stmt, tvm.tir.BufferStore):
+            if stmt.buffer in rolling_buffer_to_info:
+                rolling_buffer_info = rolling_buffer_to_info[stmt.buffer]
+                indices = []
+                # First modify the access indices to use modulo arithmetic
+                # for the rolling axis
+                for i, index in enumerate(stmt.indices):
+                    if i == rolling_buffer_info.rolling_axis:
+                        indices.append(tvm.tir.FloorMod(index, rolling_buffer_info.rolling_extent))
+                    else:
+                        indices.append(index)
+                buffer_store = tvm.tir.BufferStore(stmt.buffer, stmt.value, indices, stmt.span)
+                # Then wrap the BufferStores in some Ifs to avoid recomputing elements
+                for i, iter_var in enumerate(rolling_buffer_info.axis_iter_vars):
+                    if iter_var is not None and rolling_buffer_info.axis_overlaps[i] > 0:
+                        dmap = {iter_var: arith.IntervalSet(0, 0)}
+                        term_2 = arith.Analyzer().int_set(stmt.indices[i], dmap).min_value
+                        buffer_store = tvm.tir.IfThenElse(
+                            tvm.tir.Or(
+                                iter_var < 1, term_2 >= rolling_buffer_info.axis_overlaps[i]
+                            ),
+                            buffer_store,
+                            None,
+                        )
+                return buffer_store
+        elif isinstance(stmt, tvm.tir.BufferLoad):
+            if stmt.buffer in rolling_buffer_to_info:
+                rolling_buffer_info = rolling_buffer_to_info[stmt.buffer]
+                indices = []
+                # Modify the access indices to use modulo arithmetic
+                # for the rolling axis
+                for i, index in enumerate(stmt.indices):
+                    if i == rolling_buffer_info.rolling_axis:
+                        indices.append(tvm.tir.FloorMod(index, rolling_buffer_info.rolling_extent))
+                    else:
+                        indices.append(index)
+                return tvm.tir.BufferLoad(stmt.buffer, indices, stmt.span)
+
+    def _ftransform(f, mod, ctx):
+        return f.with_body(
+            tvm.tir.stmt_functor.ir_transform(
+                f.body,
+                _pre_visit,
+                _post_visit,
+                [
+                    "tir.AttrStmt",
+                    "tir.BufferRealize",
+                    "tir.For",
+                    "tir.BufferStore",
+                    "tir.BufferLoad",
+                ],
+            )
+        )
+
+    return tvm.tir.transform.prim_func_pass(
+        _ftransform, opt_level=0, name="tir.InjectRollingBuffer"
+    )
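To make the fold-and-guard arithmetic above concrete, here is a small plain-Python simulation (illustrative only, assuming the extent-6/stride-4 configuration exercised by the tests below) of which rows each tile writes and which buffer slot they land in:

    # extent of the rolled realization and stride of the consuming tile
    extent, stride = 6, 4
    overlap = extent - stride  # 2 rows shared between consecutive tiles
    for tile in range(2):
        for row in range(extent):
            # the guard from the pass: skip rows a previous tile already produced
            if tile < 1 or row >= overlap:
                logical = tile * stride + row
                print(tile, logical, logical % extent)  # tile, logical row, slot

Tile 0 writes logical rows 0-5 into slots 0-5; tile 1 writes only logical rows 6-9, which wrap into slots 0-3, matching the PostRollingBuffer expectation in the test file.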
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 26b22f99c215..ec93bdacdc74 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -15,9 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 """Wrapping existing transformations."""
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name,unused-import
 from . import _ffi_api
 from . import function_pass as _fpass
+from .inject_rolling_buffer import InjectRollingBuffer
 
 
 def Apply(ftransform):
diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc
index 9a4eadb35619..809da2b1590c 100644
--- a/src/te/operation/compute_op.cc
+++ b/src/te/operation/compute_op.cc
@@ -484,7 +484,8 @@ ComputeLoopNest ComputeLoopNest::Create(const BaseComputeOpNode* self, const Sta
   }
   ret.init_nest = MakeLoopNest(stage, dom_map, begin_loop, true, skip_iter, &(ret.init_vmap),
                                debug_keep_trivial_loop);
-  ret.init_predicates = MakeBoundCheck(stage, dom_map, ret.init_vmap, true, skip_iter);
+  ret.init_predicates =
+      MakeBoundCheck(stage, dom_map, ret.init_vmap, !stage->rolling_buffer, skip_iter);
   for (auto& e : ret.init_predicates) {
     e = likely(e);
   }
diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc
index 8964c1013a53..9960dbd8201a 100644
--- a/src/te/schedule/schedule_lang.cc
+++ b/src/te/schedule/schedule_lang.cc
@@ -423,6 +423,13 @@ Stage& Stage::double_buffer() {
   return *this;
 }
 
+Stage& Stage::rolling_buffer() {
+  StageNode* self = operator->();
+  ICHECK(!self->is_output) << "Cannot apply rolling buffer on output";
+  self->rolling_buffer = true;
+  return *this;
+}
+
 Stage CopyStage(const Stage& s) {
   ObjectPtr<StageNode> n = make_object<StageNode>(*s.operator->());
   return Stage(n);
@@ -879,6 +886,8 @@ TVM_REGISTER_GLOBAL("te.StageStorageAlign").set_body_method(&Stage::storage_alig
 TVM_REGISTER_GLOBAL("te.StageDoubleBuffer").set_body_method(&Stage::double_buffer);
 
+TVM_REGISTER_GLOBAL("te.StageRollingBuffer").set_body_method(&Stage::rolling_buffer);
+
 TVM_REGISTER_GLOBAL("te.ScheduleNormalize").set_body_method(&Schedule::normalize);
 
 TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup").set_body_method(&Schedule::create_group);
diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc
index 355e3c39494b..ae341ffc2fc3 100644
--- a/src/te/schedule/schedule_ops.cc
+++ b/src/te/schedule/schedule_ops.cc
@@ -54,6 +54,9 @@ Stmt MakePipeline(const Stage& s, const std::unordered_map<IterVar, Range>& dom_
   pipeline = s->op->BuildRealize(s, dom_map, pipeline);
   // use attribute to mark scope of the operation.
   pipeline = AttrStmt(s->op, tir::attr::realize_scope, StringImm(s->scope), pipeline);
+  if (s->rolling_buffer) {
+    pipeline = AttrStmt(s->op, tir::attr::rolling_buffer_scope, Bool(true), pipeline);
+  }
   return pipeline;
 }
diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc
index 5c59961fe011..3e0eaa68be78 100644
--- a/src/te/schedule/schedule_postproc_to_primfunc.cc
+++ b/src/te/schedule/schedule_postproc_to_primfunc.cc
@@ -70,7 +70,8 @@ class TensorToBufferMapper : public StmtExprMutator {
     // TODO(tvm-team): remove realize_scope, turn the info into
     // Buffer's scope field in this pass.
     if (op->attr_key == tir::attr::realize_scope ||
-        op->attr_key == tir::attr::double_buffer_scope) {
+        op->attr_key == tir::attr::double_buffer_scope ||
+        op->attr_key == tir::attr::rolling_buffer_scope) {
       Stmt body = op->body;
       Operation operation = Downcast<Operation>(op->node);
       for (int i = operation->num_outputs(); i != 0; --i) {
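In the lowered TIR, the marker emitted by MakePipeline above and consumed by the Python pass looks like this (excerpted from the PreRollingBuffer expectation in the new test file):

    tir.attr(tensor_2, "rolling_buffer_scope", True)
    tir.realize(tensor_2[0:1, (ax1_outer*4):((ax1_outer*4) + 6), 0:12, 0:16], "")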
diff --git a/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py b/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py
new file mode 100644
index 000000000000..b3be7f636985
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py
@@ -0,0 +1,263 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.script
+from tvm.script import tir, ty
+from tvm import te
+from tvm import topi
+from tvm.driver.build_module import get_binds
+import numpy as np
+
+import pytest
+
+
+def _tile_nd(s, tensor, tile):
+    outer_indices = []
+    inner_indices = []
+    for i, size in enumerate(tile):
+        outer, inner = s[tensor].split(tensor.op.axis[i], size)
+        outer_indices.append(outer)
+        inner_indices.append(inner)
+
+    s[tensor].reorder(*outer_indices, *inner_indices)
+    return outer_indices, inner_indices
+
+
+def _lower_schedule(sch, args):
+    sch = sch.normalize()
+    bounds = tvm.te.schedule.InferBound(sch)
+    stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
+
+    compact = tvm.te.schedule.VerifyCompactBuffer(stmt)
+    binds, arg_list = get_binds(args, compact, None)
+    func = tvm.te.schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
+
+    func = func.with_attr("global_symbol", "main")
+    func = func.with_attr("tir.noalias", True)
+    mod = tvm.IRModule({"main": func})
+    return mod
+
+
+def _verify_schedule(sch, inputs, output):
+    mod = _lower_schedule(sch, inputs + [output])
+    mods = []
+    mods.append(mod)
+    mod = tvm.tir.transform.InjectRollingBuffer()(mod)
+
+    def _check(stmt):
+        if isinstance(stmt, tvm.tir.AttrStmt):
+            assert stmt.attr_key != "rolling_buffer_scope", "Failed to lower rolling buffers"
+
+    tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _check)
+    mods.append(mod)
+
+    outputs = []
+    ctx = tvm.cpu(0)
+    input_data = []
+    for tensor in inputs:
+        shape = [i.value for i in tensor.shape]
+        input_data.append(
+            tvm.nd.array(np.random.randint(low=-100, high=100, size=shape).astype("int8"), ctx)
+        )
+    shape = [i.value for i in output.shape]
+    out = tvm.nd.array(np.zeros(shape, dtype="int8"), ctx)
+    for mod in mods:
+        mod = tvm.tir.transform.StorageFlatten(64)(mod)
+        mod = tvm.tir.transform.NarrowDataType(32)(mod)
+        mod = tvm.tir.transform.LoopPartition()(mod)
+        mod = tvm.tir.transform.StorageRewrite()(mod)
+        # Build for CPU execution
+        f = tvm.build(mod)
+        f(*input_data, out)
+        outputs.append(out.asnumpy())
+
+    np.testing.assert_equal(outputs[0], outputs[1])
+
+
+@pytest.mark.parametrize("tile_shape", [(1, 4, 8, 16), (1, 8, 7, 11), (1, 8, 3, 8), (1, 7, 5, 3)])
+def test_tile_shapes(tile_shape):
+    A = te.placeholder((1, 12, 14, 16), name="A", dtype="int8")
+    pool_a = topi.nn.pool2d(A, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")
+    pool_b = topi.nn.pool2d(pool_a, (3, 5), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")
+
+    sch = tvm.te.create_schedule([pool_b.op])
+    oi, ii = _tile_nd(sch, pool_b, tile_shape)
+    sch[pool_a].compute_at(sch[pool_b], oi[-1])
+    sch[pool_a].rolling_buffer()
+
+    _verify_schedule(sch, [A], pool_b)
+
+
+def test_implied_split():
+    A = te.placeholder((1, 12, 12, 16), name="A", dtype="int8")
+    pool_a = topi.nn.pool2d(A, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")
+    pool_b = topi.nn.pool2d(pool_a, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")
+
+    sch = tvm.te.create_schedule([pool_b.op])
+    n, h, w, c = pool_b.op.axis
+    oi, ii = sch[pool_b].split(w, 4)
+    sch[pool_a].compute_at(sch[pool_b], oi)
+    sch[pool_a].rolling_buffer()
+
+    _verify_schedule(sch, [A], pool_b)
+
+
+def test_upscale():
+    A = te.placeholder((1, 12, 12, 16), name="A", dtype="int8")
+    pool = topi.nn.pool2d(A, (1, 1), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")
+    upscale = te.compute((1, 24, 24, 16), lambda nn, hh, ww, cc: pool[nn, hh // 2, ww // 2, cc])
+
+    sch = tvm.te.create_schedule([upscale.op])
+    oi, ii = _tile_nd(sch, upscale, (1, 5, 5, 16))
+    sch[pool].compute_at(sch[upscale], oi[-1])
+    sch[pool].rolling_buffer()
+
+    _verify_schedule(sch, [A], upscale)
+
+
+@pytest.mark.parametrize("tile_shape", [(1, 4, 8, 16), (1, 8, 7, 11), (1, 8, 3, 8), (1, 7, 5, 3)])
+def test_3_tiled_poolings(tile_shape):
+    A = te.placeholder((1, 14, 14, 16), name="A", dtype="int8")
+    pool_a = topi.nn.pool2d(A, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")
+    pool_b = topi.nn.pool2d(pool_a, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")
+    pool_c = topi.nn.pool2d(pool_b, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")
+
+    sch = tvm.te.create_schedule([pool_c.op])
+    oi, ii = _tile_nd(sch, pool_c, tile_shape)
+    sch[pool_b].compute_at(sch[pool_c], oi[-1])
+    sch[pool_b].rolling_buffer()
+    sch[pool_a].compute_at(sch[pool_c], oi[-1])
+    sch[pool_a].rolling_buffer()
+
+    _verify_schedule(sch, [A], pool_c)
+
+
+@pytest.mark.parametrize("tile_shape", [(1, 4, 8, 16), (1, 8, 7, 11), (1, 8, 3, 8), (1, 7, 5, 3)])
+def test_tiled_added_poolings(tile_shape):
+    A = te.placeholder((1, 12, 12, 16), name="A", dtype="int8")
+    B = te.placeholder((1, 14, 14, 16), name="B", dtype="int8")
+    pool_a = topi.nn.pool2d(A, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")
+    pool_b = topi.nn.pool2d(B, (5, 5), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")
+    add = topi.add(pool_a, pool_b)
+    pool_c = topi.nn.pool2d(add, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")
+
+    sch = tvm.te.create_schedule([pool_c.op])
+    oi, ii = _tile_nd(sch, pool_c, tile_shape)
+    sch[add].compute_at(sch[pool_c], oi[-1])
+    sch[add].rolling_buffer()
+    sch[pool_b].compute_at(sch[pool_c], oi[-1])
+    sch[pool_b].rolling_buffer()
+    sch[pool_a].compute_at(sch[pool_c], oi[-1])
+    sch[pool_a].rolling_buffer()
+
+    _verify_schedule(sch, [A, B], pool_c)
+
+
+@pytest.mark.parametrize("make_rolling", [(0, 0), (1, 0), (0, 1), (1, 1)])
+def test_mixed_buffers(make_rolling):
+    A = te.placeholder((1, 14, 14, 16), name="A", dtype="int8")
+    pool_a = topi.nn.pool2d(A, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")
+    pool_b = topi.nn.pool2d(pool_a, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")
+    pool_c = topi.nn.pool2d(pool_b, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")
+
+    sch = tvm.te.create_schedule([pool_c.op])
+    oi, ii = _tile_nd(sch, pool_c, (1, 4, 8, 16))
+    sch[pool_b].compute_at(sch[pool_c], oi[-1])
+    if make_rolling[0]:
+        sch[pool_b].rolling_buffer()
+    sch[pool_a].compute_at(sch[pool_c], oi[-1])
+    if make_rolling[1]:
+        sch[pool_a].rolling_buffer()
+
+    _verify_schedule(sch, [A], pool_c)
+
+
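+# The two TVMScript classes below encode the expected IR before and after the
+# pass for test_rolling_buffer_ir_transform: tensor_2 rolls over axis 1 with an
+# extent of 6 (4 rows per tile plus 2 rows of overlap), so its realization is
+# hoisted and shrunk, accesses become tir.floormod(..., 6), and every iteration
+# after the first skips the 2 already-computed rows (ax1 >= 2).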
+# fmt: off
+@tvm.script.tir
+class PreRollingBuffer:
+    def main(A: ty.handle, tensor: ty.handle) -> None:
+        # function attr dict
+        tir.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # buffer definition
+        tensor_2 = tir.buffer_decl([1, 10, 12, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        A_1 = tir.match_buffer(A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        tensor_1 = tir.match_buffer(tensor, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        # body
+        tir.realize(tensor_1[0:1, 0:8, 0:8, 0:16], "")
+        for ax1_outer in tir.serial(0, 2):
+            tir.attr(tensor_2, "rolling_buffer_scope", True)
+            tir.realize(tensor_2[0:1, (ax1_outer*4):((ax1_outer*4) + 6), 0:12, 0:16], "")
+            for ax1 in tir.serial(0, 6):
+                for ax2 in tir.serial(0, 12):
+                    for ax3 in tir.serial(0, 16):
+                        tensor_2[0, (ax1 + (ax1_outer*4)), ax2, ax3] = tir.int8(0)
+                        for dh in tir.serial(0, 3):
+                            for dw in tir.serial(0, 3):
+                                tensor_2[0, (ax1 + (ax1_outer*4)), ax2, ax3] = tir.max(tensor_2[0, (ax1 + (ax1_outer*4)), ax2, ax3], A_1[0, ((ax1 + (ax1_outer*4)) + dh), (ax2 + dw), ax3])
+            for ax1_inner in tir.serial(0, 4):
+                for ax2_inner in tir.serial(0, 8):
+                    for ax3_inner in tir.serial(0, 16):
+                        tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner] = tir.int8(0)
+                        for dh_1 in tir.serial(0, 3):
+                            for dw_1 in tir.serial(0, 5):
+                                tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner] = tir.max(tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner], tensor_2[0, ((ax1_inner + (ax1_outer*4)) + dh_1), (ax2_inner + dw_1), ax3_inner])
+    __tvm_meta__ = None
+
+
+@tvm.script.tir
+class PostRollingBuffer:
+    def main(A: ty.handle, tensor: ty.handle) -> None:
+        # function attr dict
+        tir.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # buffer definition
+        tensor_2 = tir.buffer_decl([1, 10, 12, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        A_1 = tir.match_buffer(A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        tensor_1 = tir.match_buffer(tensor, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        # body
+        tir.realize(tensor_1[0:1, 0:8, 0:8, 0:16], "")
+        tir.realize(tensor_2[0:1, 0:6, 0:12, 0:16], "")
+        for ax1_outer in tir.serial(0, 2):
+            for ax1 in tir.serial(0, 6):
+                for ax2 in tir.serial(0, 12):
+                    for ax3 in tir.serial(0, 16):
+                        if ((ax1_outer < 1) or (ax1 >= 2)):
+                            tensor_2[0, tir.floormod((ax1 + (ax1_outer*4)), 6), ax2, ax3] = tir.int8(0)
+                        for dh in tir.serial(0, 3):
+                            for dw in tir.serial(0, 3):
+                                if ((ax1_outer < 1) or (ax1 >= 2)):
+                                    tensor_2[0, tir.floormod((ax1 + (ax1_outer*4)), 6), ax2, ax3] = tir.max(tensor_2[0, tir.floormod((ax1 + (ax1_outer*4)), 6), ax2, ax3], A_1[0, ((ax1 + (ax1_outer*4)) + dh), (ax2 + dw), ax3])
+            for ax1_inner in tir.serial(0, 4):
+                for ax2_inner in tir.serial(0, 8):
+                    for ax3_inner in tir.serial(0, 16):
+                        tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner] = tir.int8(0)
+                        for dh_1 in tir.serial(0, 3):
+                            for dw_1 in tir.serial(0, 5):
+                                tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner] = tir.max(tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner], tensor_2[0, tir.floormod(((ax1_inner + (ax1_outer*4)) + dh_1), 6), (ax2_inner + dw_1), ax3_inner])
+    __tvm_meta__ = None
+# fmt: on
+
+
+def test_rolling_buffer_ir_transform():
+    mod = PreRollingBuffer()
+    mod = tvm.tir.transform.InjectRollingBuffer()(mod)
+    script = tvm.script.asscript(mod, True)
+    mod = tvm.script.from_source(script)
+    tvm.ir.assert_structural_equal(mod["main"], PostRollingBuffer()["main"], True)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])