From da8a614ab540c40ee0e004425a8ea6884c9d56e7 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 23 Mar 2022 12:36:31 -0700 Subject: [PATCH 1/3] [TIR][Schedule] Change BufferIndexType enum in python side to string --- python/tvm/tir/schedule/__init__.py | 2 +- python/tvm/tir/schedule/schedule.py | 18 +++-------- .../primitive/layout_transformation.cc | 3 +- src/tir/schedule/schedule.cc | 4 +-- src/tir/schedule/utils.h | 32 +++++++++++++++++++ .../test_tir_schedule_transform_layout.py | 7 ++-- 6 files changed, 45 insertions(+), 21 deletions(-) diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py index 2314c7fb939f..5f0e169c43e3 100644 --- a/python/tvm/tir/schedule/__init__.py +++ b/python/tvm/tir/schedule/__init__.py @@ -19,6 +19,6 @@ from .block_scope import BlockScope, Dependency, DepKind, StmtSRef from .instruction import Instruction, InstructionKind -from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError, BufferType +from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError from .state import ScheduleDebugMask, ScheduleState from .trace import Trace diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index c54c7f74f24f..0c7e2147a619 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """The TensorIR schedule class""" -import enum from typing import Callable, Dict, List, Optional, Union from tvm._ffi import register_object as _register_object @@ -73,13 +72,6 @@ def __init__(self) -> None: } -class BufferType(enum.IntEnum): - """Type of buffer in access regions of a block""" - - READ = 0 - WRITE = 1 - - def _parse_error_render_level(error_render_level: str) -> int: if error_render_level not in _ERROR_RENDER_LEVEL: raise ValueError( @@ -2127,7 +2119,7 @@ def transform_layout( self, block: BlockRV, buffer_index: int, - buffer_type: BufferType, + buffer_index_type: str, index_map: Union[IndexMap, Callable], ) -> None: """Apply a transformation represented by IndexMap to buffer @@ -2137,8 +2129,8 @@ def transform_layout( The block that accesses the target buffer buffer_index: int The index of the buffer in block's read or write region - buffer_type : BufferType - Type of the buffer, READ or WRITE. + buffer_index_type : str + Type of the buffer index, "read" or "write" index_map : Union[IndexMap, Callable] The transformation to apply @@ -2167,7 +2159,7 @@ def before_transform_layout(a: T.handle, c: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_storage_align) - sch.transform_layout(sch.get_block("B"), buffer_index=0, BufferType.WRITE, + sch.transform_layout(sch.get_block("B"), buffer_index=0, "write", index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16)) print(sch.mod["main"].script()) @@ -2193,7 +2185,7 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> if callable(index_map): index_map = IndexMap.from_func(index_map) _ffi_api.ScheduleTransformLayout( # type: ignore # pylint: disable=no-member - self, block, buffer_index, buffer_type, index_map + self, block, buffer_index, buffer_index_type, index_map ) ########## Schedule: Misc ########## diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 56eedca1120d..cbf1e6dc7896 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -207,7 +207,8 @@ struct TransformLayoutTraits : public UnpackedInstTraits PythonAPICall py("transform_layout"); py.Input("block", block_rv); py.Input("buffer_index", buffer_index); - py.Input("buffer_index_type", buffer_index_type); + py.Input("buffer_index_type", + BufferIndexType2Str(static_cast(buffer_index_type->value))); py.Input("index_map", index_map->ToPythonString()); return py.Str(); } diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 82cd0a4a351a..fad7661aff02 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -229,9 +229,9 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate") /******** (FFI) Layout transformation ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout") .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, - int buffer_index_type, const IndexMap& index_map) { + const String& buffer_index_type, const IndexMap& index_map) { return self->TransformLayout(block_rv, buffer_index, - static_cast(buffer_index_type), index_map); + Str2BufferIndexType(buffer_index_type), index_map); }); /******** (FFI) Misc ********/ diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 2de8ef6e0c93..9564c2d172c5 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -431,6 +431,38 @@ inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::Bl } } +/******** Helper functions for enum conversion ********/ + +/*! + * \brief Convert String to BufferIndexType + * \param str The string representation of BufferIndexType + * \return The converted BufferIndexType + */ +inline BufferIndexType Str2BufferIndexType(const String& str) { + if (str == "read") { + return BufferIndexType::kRead; + } else if (str == "write") { + return BufferIndexType::kWrite; + } else { + LOG(FATAL) << "ValueError: Unknown BufferIndexType: " << str; + throw; + } +} + +/*! + * \brief Convert BufferIndexType to String + * \param buffer_index_type The BufferIndexType value to convert + * \return The string representation of BufferIndexType + */ +inline String BufferIndexType2Str(BufferIndexType buffer_index_type) { + if (buffer_index_type == BufferIndexType::kRead) { + return "read"; + } else { + ICHECK(buffer_index_type == BufferIndexType::kWrite); + return "write"; + } +} + } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index e0a7f66bf278..9547b3c1a38f 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -21,7 +21,6 @@ import tvm from tvm import tir -from tvm.tir import BufferType from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip @@ -99,7 +98,7 @@ def test_two_elementwise_transform_intermediate_buffer(): sch = tir.Schedule(two_elementwise, debug_mask="all") block = sch.get_block("B") sch.transform_layout( - block, 0, BufferType.WRITE, lambda m, n: (m // 16, n // 16, m % 16, n % 16) + block, 0, "write", lambda m, n: (m // 16, n // 16, m % 16, n % 16) ) tvm.ir.assert_structural_equal(two_elementwise_transformed_intermediate_buffer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise) @@ -108,7 +107,7 @@ def test_two_elementwise_transform_intermediate_buffer(): def test_two_elementwise_transform_input_buffer(): sch = tir.Schedule(two_elementwise, debug_mask="all") block = sch.get_block("B") - sch.transform_layout(block, 0, BufferType.READ, packed_index_map_func) + sch.transform_layout(block, 0, "read", packed_index_map_func) tvm.ir.assert_structural_equal(two_elementwise_transformed_input_buffer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise) @@ -116,7 +115,7 @@ def test_two_elementwise_transform_input_buffer(): def test_two_elementwise_transform_output_buffer(): sch = tir.Schedule(two_elementwise, debug_mask="all") block = sch.get_block("C") - sch.transform_layout(block, 0, BufferType.WRITE, packed_index_map_func) + sch.transform_layout(block, 0, "write", packed_index_map_func) tvm.ir.assert_structural_equal(two_elementwise_transformed_output_buffer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise) From 47d02308adc89485a4ece8828d9bc0f565035067 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 23 Mar 2022 13:52:34 -0700 Subject: [PATCH 2/3] address comments --- python/tvm/tir/__init__.py | 2 +- python/tvm/tir/schedule/schedule.py | 4 +++- src/tir/schedule/schedule.cc | 4 ++-- src/tir/schedule/utils.h | 16 ---------------- 4 files changed, 6 insertions(+), 20 deletions(-) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 147360d7e087..17f9aa3d9c60 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -57,7 +57,7 @@ from .op import comm_reducer, min, max, sum from .op import q_multiply_shift -from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError, BufferType +from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError from . import schedule from . import ir_builder diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 0c7e2147a619..d537db28001c 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2184,8 +2184,10 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> """ if callable(index_map): index_map = IndexMap.from_func(index_map) + assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type" + buffer_index_type_enum = 0 if buffer_index_type == "read" else 1 _ffi_api.ScheduleTransformLayout( # type: ignore # pylint: disable=no-member - self, block, buffer_index, buffer_index_type, index_map + self, block, buffer_index, buffer_index_type_enum, index_map ) ########## Schedule: Misc ########## diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index fad7661aff02..82cd0a4a351a 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -229,9 +229,9 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate") /******** (FFI) Layout transformation ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout") .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, - const String& buffer_index_type, const IndexMap& index_map) { + int buffer_index_type, const IndexMap& index_map) { return self->TransformLayout(block_rv, buffer_index, - Str2BufferIndexType(buffer_index_type), index_map); + static_cast(buffer_index_type), index_map); }); /******** (FFI) Misc ********/ diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 9564c2d172c5..53cafa798b54 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -433,22 +433,6 @@ inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::Bl /******** Helper functions for enum conversion ********/ -/*! - * \brief Convert String to BufferIndexType - * \param str The string representation of BufferIndexType - * \return The converted BufferIndexType - */ -inline BufferIndexType Str2BufferIndexType(const String& str) { - if (str == "read") { - return BufferIndexType::kRead; - } else if (str == "write") { - return BufferIndexType::kWrite; - } else { - LOG(FATAL) << "ValueError: Unknown BufferIndexType: " << str; - throw; - } -} - /*! * \brief Convert BufferIndexType to String * \param buffer_index_type The BufferIndexType value to convert From 920b2f291aaba48215cae218575ce94d8e65af3b Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 24 Mar 2022 10:09:56 -0700 Subject: [PATCH 3/3] lint --- tests/python/unittest/test_tir_schedule_transform_layout.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index 9547b3c1a38f..ba8e28845cfc 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -97,9 +97,7 @@ def two_elementwise_transformed_output_buffer( def test_two_elementwise_transform_intermediate_buffer(): sch = tir.Schedule(two_elementwise, debug_mask="all") block = sch.get_block("B") - sch.transform_layout( - block, 0, "write", lambda m, n: (m // 16, n // 16, m % 16, n % 16) - ) + sch.transform_layout(block, 0, "write", lambda m, n: (m // 16, n // 16, m % 16, n % 16)) tvm.ir.assert_structural_equal(two_elementwise_transformed_intermediate_buffer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise)