Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError, BufferType
from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError

from . import schedule
from . import ir_builder
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@

from .block_scope import BlockScope, Dependency, DepKind, StmtSRef
from .instruction import Instruction, InstructionKind
from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError, BufferType
from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError
from .state import ScheduleDebugMask, ScheduleState
from .trace import Trace
20 changes: 7 additions & 13 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""The TensorIR schedule class"""
import enum
from typing import Callable, Dict, List, Optional, Union

from tvm._ffi import register_object as _register_object
Expand Down Expand Up @@ -73,13 +72,6 @@ def __init__(self) -> None:
}


class BufferType(enum.IntEnum):
"""Type of buffer in access regions of a block"""

READ = 0
WRITE = 1


def _parse_error_render_level(error_render_level: str) -> int:
if error_render_level not in _ERROR_RENDER_LEVEL:
raise ValueError(
Expand Down Expand Up @@ -2127,7 +2119,7 @@ def transform_layout(
self,
block: BlockRV,
buffer_index: int,
buffer_type: BufferType,
buffer_index_type: str,
index_map: Union[IndexMap, Callable],
) -> None:
"""Apply a transformation represented by IndexMap to buffer
Expand All @@ -2137,8 +2129,8 @@ def transform_layout(
The block that accesses the target buffer
buffer_index: int
The index of the buffer in block's read or write region
buffer_type : BufferType
Type of the buffer, READ or WRITE.
buffer_index_type : str
Type of the buffer index, "read" or "write"
index_map : Union[IndexMap, Callable]
The transformation to apply

Expand Down Expand Up @@ -2167,7 +2159,7 @@ def before_transform_layout(a: T.handle, c: T.handle) -> None:
.. code-block:: python

sch = tir.Schedule(before_storage_align)
sch.transform_layout(sch.get_block("B"), buffer_index=0, BufferType.WRITE,
sch.transform_layout(sch.get_block("B"), buffer_index=0, "write",
index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16))
print(sch.mod["main"].script())

Expand All @@ -2192,8 +2184,10 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
"""
if callable(index_map):
index_map = IndexMap.from_func(index_map)
assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type"
buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
_ffi_api.ScheduleTransformLayout( # type: ignore # pylint: disable=no-member
self, block, buffer_index, buffer_type, index_map
self, block, buffer_index, buffer_index_type_enum, index_map
)

########## Schedule: Misc ##########
Expand Down
3 changes: 2 additions & 1 deletion src/tir/schedule/primitive/layout_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,8 @@ struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits>
PythonAPICall py("transform_layout");
py.Input("block", block_rv);
py.Input("buffer_index", buffer_index);
py.Input("buffer_index_type", buffer_index_type);
py.Input("buffer_index_type",
BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value)));
py.Input("index_map", index_map->ToPythonString());
return py.Str();
}
Expand Down
16 changes: 16 additions & 0 deletions src/tir/schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,22 @@ inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::Bl
}
}

/******** Helper functions for enum conversion ********/

/*!
* \brief Convert BufferIndexType to String
* \param buffer_index_type The BufferIndexType value to convert
* \return The string representation of BufferIndexType
*/
inline String BufferIndexType2Str(BufferIndexType buffer_index_type) {
if (buffer_index_type == BufferIndexType::kRead) {
return "read";
} else {
ICHECK(buffer_index_type == BufferIndexType::kWrite);
return "write";
}
}

} // namespace tir
} // namespace tvm

Expand Down
9 changes: 3 additions & 6 deletions tests/python/unittest/test_tir_schedule_transform_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import tvm
from tvm import tir
from tvm.tir import BufferType
from tvm.script import tir as T
from tvm.tir.schedule.testing import verify_trace_roundtrip

Expand Down Expand Up @@ -98,25 +97,23 @@ def two_elementwise_transformed_output_buffer(
def test_two_elementwise_transform_intermediate_buffer():
sch = tir.Schedule(two_elementwise, debug_mask="all")
block = sch.get_block("B")
sch.transform_layout(
block, 0, BufferType.WRITE, lambda m, n: (m // 16, n // 16, m % 16, n % 16)
)
sch.transform_layout(block, 0, "write", lambda m, n: (m // 16, n // 16, m % 16, n % 16))
tvm.ir.assert_structural_equal(two_elementwise_transformed_intermediate_buffer, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=two_elementwise)


def test_two_elementwise_transform_input_buffer():
sch = tir.Schedule(two_elementwise, debug_mask="all")
block = sch.get_block("B")
sch.transform_layout(block, 0, BufferType.READ, packed_index_map_func)
sch.transform_layout(block, 0, "read", packed_index_map_func)
tvm.ir.assert_structural_equal(two_elementwise_transformed_input_buffer, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=two_elementwise)


def test_two_elementwise_transform_output_buffer():
sch = tir.Schedule(two_elementwise, debug_mask="all")
block = sch.get_block("C")
sch.transform_layout(block, 0, BufferType.WRITE, packed_index_map_func)
sch.transform_layout(block, 0, "write", packed_index_map_func)
tvm.ir.assert_structural_equal(two_elementwise_transformed_output_buffer, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=two_elementwise)

Expand Down