diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 147360d7e087..17f9aa3d9c60 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -57,7 +57,7 @@ from .op import comm_reducer, min, max, sum from .op import q_multiply_shift -from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError, BufferType +from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError from . import schedule from . import ir_builder diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py index 2314c7fb939f..5f0e169c43e3 100644 --- a/python/tvm/tir/schedule/__init__.py +++ b/python/tvm/tir/schedule/__init__.py @@ -19,6 +19,6 @@ from .block_scope import BlockScope, Dependency, DepKind, StmtSRef from .instruction import Instruction, InstructionKind -from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError, BufferType +from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError from .state import ScheduleDebugMask, ScheduleState from .trace import Trace diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index c54c7f74f24f..d537db28001c 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """The TensorIR schedule class""" -import enum from typing import Callable, Dict, List, Optional, Union from tvm._ffi import register_object as _register_object @@ -73,13 +72,6 @@ def __init__(self) -> None: } -class BufferType(enum.IntEnum): - """Type of buffer in access regions of a block""" - - READ = 0 - WRITE = 1 - - def _parse_error_render_level(error_render_level: str) -> int: if error_render_level not in _ERROR_RENDER_LEVEL: raise ValueError( @@ -2127,7 +2119,7 @@ def transform_layout( self, block: BlockRV, buffer_index: int, - buffer_type: BufferType, + buffer_index_type: str, index_map: Union[IndexMap, Callable], ) -> None: """Apply a transformation represented by IndexMap to buffer @@ -2137,8 +2129,8 @@ def transform_layout( The block that accesses the target buffer buffer_index: int The index of the buffer in block's read or write region - buffer_type : BufferType - Type of the buffer, READ or WRITE. + buffer_index_type : str + Type of the buffer index, "read" or "write" index_map : Union[IndexMap, Callable] The transformation to apply @@ -2167,7 +2159,7 @@ def before_transform_layout(a: T.handle, c: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_storage_align) - sch.transform_layout(sch.get_block("B"), buffer_index=0, BufferType.WRITE, + sch.transform_layout(sch.get_block("B"), buffer_index=0, "write", index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16)) print(sch.mod["main"].script()) @@ -2192,8 +2184,10 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> """ if callable(index_map): index_map = IndexMap.from_func(index_map) + assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type" + buffer_index_type_enum = 0 if buffer_index_type == "read" else 1 _ffi_api.ScheduleTransformLayout( # type: ignore # pylint: disable=no-member - self, block, buffer_index, buffer_type, index_map + self, block, buffer_index, buffer_index_type_enum, index_map ) ########## Schedule: Misc ########## diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 56eedca1120d..cbf1e6dc7896 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -207,7 +207,8 @@ struct TransformLayoutTraits : public UnpackedInstTraits PythonAPICall py("transform_layout"); py.Input("block", block_rv); py.Input("buffer_index", buffer_index); - py.Input("buffer_index_type", buffer_index_type); + py.Input("buffer_index_type", + BufferIndexType2Str(static_cast(buffer_index_type->value))); py.Input("index_map", index_map->ToPythonString()); return py.Str(); } diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 2de8ef6e0c93..53cafa798b54 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -431,6 +431,22 @@ inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::Bl } } +/******** Helper functions for enum conversion ********/ + +/*! + * \brief Convert BufferIndexType to String + * \param buffer_index_type The BufferIndexType value to convert + * \return The string representation of BufferIndexType + */ +inline String BufferIndexType2Str(BufferIndexType buffer_index_type) { + if (buffer_index_type == BufferIndexType::kRead) { + return "read"; + } else { + ICHECK(buffer_index_type == BufferIndexType::kWrite); + return "write"; + } +} + } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index e0a7f66bf278..ba8e28845cfc 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -21,7 +21,6 @@ import tvm from tvm import tir -from tvm.tir import BufferType from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip @@ -98,9 +97,7 @@ def two_elementwise_transformed_output_buffer( def test_two_elementwise_transform_intermediate_buffer(): sch = tir.Schedule(two_elementwise, debug_mask="all") block = sch.get_block("B") - sch.transform_layout( - block, 0, BufferType.WRITE, lambda m, n: (m // 16, n // 16, m % 16, n % 16) - ) + sch.transform_layout(block, 0, "write", lambda m, n: (m // 16, n // 16, m % 16, n % 16)) tvm.ir.assert_structural_equal(two_elementwise_transformed_intermediate_buffer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise) @@ -108,7 +105,7 @@ def test_two_elementwise_transform_intermediate_buffer(): def test_two_elementwise_transform_input_buffer(): sch = tir.Schedule(two_elementwise, debug_mask="all") block = sch.get_block("B") - sch.transform_layout(block, 0, BufferType.READ, packed_index_map_func) + sch.transform_layout(block, 0, "read", packed_index_map_func) tvm.ir.assert_structural_equal(two_elementwise_transformed_input_buffer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise) @@ -116,7 +113,7 @@ def test_two_elementwise_transform_input_buffer(): def test_two_elementwise_transform_output_buffer(): sch = tir.Schedule(two_elementwise, debug_mask="all") block = sch.get_block("C") - sch.transform_layout(block, 0, BufferType.WRITE, packed_index_map_func) + sch.transform_layout(block, 0, "write", packed_index_map_func) tvm.ir.assert_structural_equal(two_elementwise_transformed_output_buffer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise)