From e4010379b1af41511086d1a41c49d975a4922614 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 10 May 2022 15:00:05 -0500 Subject: [PATCH 01/16] [TIR][Schedule] Added Schedule.transform_layout_sugared --- python/tvm/tir/schedule/schedule.py | 84 ++++++++++++++++++- .../test_tir_schedule_transform_layout.py | 49 +++++++++-- 2 files changed, 123 insertions(+), 10 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 8bfd9063158c..2ce3df1dc43b 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -21,7 +21,7 @@ from tvm.error import TVMError, register_error from tvm.ir import IRModule, PrimExpr from tvm.runtime import Object, String -from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc +from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc, Buffer from ..function import IndexMap from . import _ffi_api @@ -2114,6 +2114,88 @@ def after_unannotate(a: T.handle, b: T.handle) -> None: ########## Schedule: Layout transformation ########## + @type_checked + def transform_layout_sugared( + self, + index_map: Union[IndexMap, Callable], + buffer: str, + block: Union[BlockRV, str], + ) -> None: + """Apply a transformation represented by IndexMap to buffer + + This is a wrapper around `transform_layout`, intended for ease + of use. + + Parameters + ---------- + index_map : Union[IndexMap, Callable] + + The transformation to apply + + buffer: Union[Buffer,str] + + The buffer to be transformed. This buffer must exist in + the reads or writes of the block. If a string, the + reads/writes of the block must not contain more than one + buffer with that name. + + block : Union[BlockRV,str] + + The block that accesses the target buffer. If a string, + should refer to a name that uniquely identifies a block + within the schedule. + + """ + + if isinstance(block, str): + block = self.get_block(block) + + def iter_buffers(): + block_obj = self.get(block) + for i, read in enumerate(block_obj.reads): + yield i, "read", read.buffer + for i, write in enumerate(block_obj.writes): + yield i, "write", write.buffer + + possible_buffers = {} + + if isinstance(buffer, str): + # String lookup requires ensuring that the name is unique + for buffer_index, buffer_index_type, buf in iter_buffers(): + if buf.name == buffer: + possible_buffers[buf] = (buffer_index, buffer_index_type) + + block_name = self.get(block).name_hint + assert possible_buffers, f"Could not find buffer '{buffer}' in block '{block_name}'" + assert ( + len(possible_buffers) == 1 + ), f"Multiple buffers named '{buffer}' in block '{block_name}'" + buffer_obj, (buffer_index, buffer_index_type) = next(iter(possible_buffers.items())) + + elif isinstance(buffer, Buffer): + # Buffer lookup has unique id, can break out early + found = False + for buffer_index, buffer_index_type, buffer_obj in iter_buffers(): + if buffer_obj.same_as(buffer): + found = True + break + + block_name = self.get(block).name_hint + assert found, "Could not find buffer '{buffer.name}' in block '{block_name}'" + + else: + raise TypeError( + f"Argument 'buffer' should be str or tir.Buffer, " + f"but found {type(buffer)} instead." + ) + + ndim = len(buffer_obj.shape) + + if callable(index_map): + index_map = IndexMap.from_func(index_map, ndim=ndim) + + self.transform_layout(block, buffer_index, buffer_index_type, index_map) + @type_checked def transform_layout( self, diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index ba8e28845cfc..1b8444e9fb96 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -93,27 +93,58 @@ def two_elementwise_transformed_output_buffer( # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks # fmt: on +use_sugared_transform = tvm.testing.parameter( + by_dict={"transform_layout": False, "transform_layout_sugared": True} +) -def test_two_elementwise_transform_intermediate_buffer(): + +def test_two_elementwise_transform_intermediate_buffer(use_sugared_transform): sch = tir.Schedule(two_elementwise, debug_mask="all") - block = sch.get_block("B") - sch.transform_layout(block, 0, "write", lambda m, n: (m // 16, n // 16, m % 16, n % 16)) + + if use_sugared_transform: + sch.transform_layout_sugared( + index_map=packed_index_map_func, + block="B", + buffer="B", + ) + else: + block = sch.get_block("B") + sch.transform_layout(block, 0, "write", packed_index_map_func) + tvm.ir.assert_structural_equal(two_elementwise_transformed_intermediate_buffer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise) -def test_two_elementwise_transform_input_buffer(): +def test_two_elementwise_transform_input_buffer(use_sugared_transform): sch = tir.Schedule(two_elementwise, debug_mask="all") - block = sch.get_block("B") - sch.transform_layout(block, 0, "read", packed_index_map_func) + + if use_sugared_transform: + sch.transform_layout_sugared( + index_map=packed_index_map_func, + block="B", + buffer="A", + ) + else: + block = sch.get_block("B") + sch.transform_layout(block, 0, "read", packed_index_map_func) + tvm.ir.assert_structural_equal(two_elementwise_transformed_input_buffer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise) -def test_two_elementwise_transform_output_buffer(): +def test_two_elementwise_transform_output_buffer(use_sugared_transform): sch = tir.Schedule(two_elementwise, debug_mask="all") - block = sch.get_block("C") - sch.transform_layout(block, 0, "write", packed_index_map_func) + + if use_sugared_transform: + sch.transform_layout_sugared( + index_map=packed_index_map_func, + block="C", + buffer="C", + ) + else: + block = sch.get_block("C") + sch.transform_layout(block, 0, "write", packed_index_map_func) + tvm.ir.assert_structural_equal(two_elementwise_transformed_output_buffer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise) From 7ef231215207a822a5c1875be15f60c53ab9429b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 10 May 2022 15:53:38 -0500 Subject: [PATCH 02/16] [TE][TIR] Reduced duplication in TE/TIR layout transformations Previously, the implementations of `tir.IndexMap.from_func` and `te.Stage.transform_layout` had significant duplication to handle argument parsing. This commit extracts the shared logic into `tir.IndexMap`. --- python/tvm/te/schedule.py | 70 ++++--------------------- python/tvm/tir/function.py | 102 ++++++++++++++++++++++++++++++++++--- 2 files changed, 104 insertions(+), 68 deletions(-) diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py index fdd08f9208c9..50f9a22ec205 100644 --- a/python/tvm/te/schedule.py +++ b/python/tvm/te/schedule.py @@ -25,7 +25,7 @@ from tvm.runtime import Object, convert from tvm.ir import container as _container -from tvm.tir import IterVar, Buffer, Var +from tvm.tir import IterVar, Buffer, Var, IndexMap from . import tensor as _tensor from . import _ffi_api @@ -599,65 +599,12 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr """ - args = [] - var_arg_name = None - kwargs = collections.OrderedDict() - default_index_dtype = "int32" - - # Make a dummy variable for each explicitly named input index. - # We may have some keyword-only arguments, if the function has - # *args before the last argument. - params = inspect.signature(mapping_function).parameters - for name, param in params.items(): - if param.kind in [ - inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - ]: - args.append(tvm.tir.Var(name, default_index_dtype)) - - elif param.kind == inspect.Parameter.VAR_POSITIONAL: - var_arg_name = name - - elif param.kind == inspect.Parameter.KEYWORD_ONLY: - kwargs[name] = tvm.tir.Var(name, default_index_dtype) - - elif param.kind in [inspect.Parameter.VAR_KEYWORD]: - raise ValueError("transform_layout mapping may not have **kwargs") - ndim = len(self.op.output(0).shape) + index_map, axis_separators = IndexMap.from_func_with_separators(mapping_function, ndim=ndim) - # Now that all the named arguments have been collected, - # everything that remains should go to the *args, if - # specified. - if var_arg_name is not None: - num_var_args = ndim - len(args) - len(kwargs) - for i in range(num_var_args): - args.append(tvm.tir.Var(f"{var_arg_name}[{i}]", default_index_dtype)) - - initial_indices = args + list(kwargs.values()) - if len(initial_indices) != ndim: - raise ValueError( - f"transform_layout mapping accepts {len(params)} initial indices, " - f"but {self.op.name} is {len(self.op.shape)}-dimensional" - ) - - mapping = mapping_function(*args, **kwargs) - - final_indices = [] - axis_separators = [] - for val in mapping: - if isinstance(val, tvm.ir.PrimExpr): - final_indices.append(val) - elif val is AXIS_SEPARATOR: - axis_separators.append(len(final_indices)) - else: - raise TypeError( - "Expected mapping function to return list of " - "either tvm.ir.PrimExpr or tvm.te.AXIS_SEPARATOR. " - "Instead received {val} of type {type(val)}." - ) - - new_iter_vars = _ffi_api.StageTransformLayout(self, initial_indices, final_indices) + new_iter_vars = _ffi_api.StageTransformLayout( + self, index_map.initial_indices, index_map.final_indices + ) _ffi_api.StageSetAxisSeparators(self, axis_separators) return new_iter_vars or None @@ -700,9 +647,10 @@ def __exit__(self, ptype, value, trace): # Sentinel value used to indicate which groups of pre-flattening axes -# should be used to post-flattening axes axes. See -# Stage.transform_layout for more details. -AXIS_SEPARATOR = "axis_separator" +# should be used to post-flattening axes axes. Moved from +# te.AXIS_SEPARATOR to tir.IndexMap.AXIS_SEPARATOR for general use, +# maintained here for backwards compatibility. +AXIS_SEPARATOR = IndexMap.AXIS_SEPARATOR tvm._ffi._init_api("schedule", __name__) diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index d84513e072d3..cdb15b3bf3c0 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -16,8 +16,9 @@ # under the License. """Function data types.""" -from typing import Callable, List, Mapping, Optional, Union, Tuple +import collections import inspect +from typing import Callable, List, Mapping, Optional, Union, Tuple import tvm import tvm._ffi @@ -258,6 +259,11 @@ class IndexMap(Object): initial_indices: List[Var] final_indices: List[PrimExpr] + # Sentinel value used to indicate which groups of pre-flattening axes + # should be used to post-flattening axes axes. See + # Stage.transform_layout for more details. + AXIS_SEPARATOR = "axis_separator" + def __init__(self, initial_indices, final_indices): self.__init_handle_by_constructor__(_ffi_api.IndexMap, initial_indices, final_indices) @@ -268,34 +274,116 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None): Parameters ---------- mapping_function : Callable - The function to map from source indices to target indices + + The function to map from source indices to target indices. + The function should accept tir.Var parameters and return a + list. Each element of the returned list should be a + tir.PrimExpr. + + ndim: Optional[int] + + The dimensionality of the buffer to which this + transformation should be applied. If mapping_function + uses variadic argument *args, ndim must be specified. If + mapping_function does not use variadic arguments, ndim is + optional. + + Returns + ------- + index_map: IndexMap + + Returns an IndexMap representing the mapping_function. + + """ + index_map, axis_separators = IndexMap.from_func_with_separators(mapping_function, ndim) + assert not axis_separators, ( + "The mapping_function provided to IndexMap.from_func may not return IndexMap.AXIS_SEPARATOR. " + "If required, please use IndexMap.from_func_with_separators instead." + ) + return index_map + + @staticmethod + def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] = None): + """Create an index map from a function + + Parameters + ---------- + mapping_function : Callable + + The function to map from source indices to target indices. + The function should accept tir.Var parameters and return a + list. Each element of the returned list should be either a + tir.PrimExpr or the object IndexMap.AXIS_SEPARATOR. + + ndim: Optional[int] + + The dimensionality of the buffer to which this + transformation should be applied. If mapping_function + uses variadic argument *args, ndim must be specified. If + mapping_function does not use variadic arguments, ndim is + optional. + + Returns + ------- + ret: Tuple[IndexMap, List[int]] + + Returns a tuple whose first element is an IndexMap + representing the mapping_function, and whose second index + is a list of indices at which IndexMap.AXIS_SEPARATOR + occurred. + """ params = inspect.signature(mapping_function).parameters - default_index_dtype = "int32" + args = [] var_arg_name = None + kwargs = collections.OrderedDict() + default_index_dtype = "int32" + for name, param in params.items(): if param.kind in [ inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD, ]: args.append(tvm.tir.Var(name, default_index_dtype)) + elif param.kind == inspect.Parameter.VAR_POSITIONAL: var_arg_name = name + + elif param.kind == inspect.Parameter.KEYWORD_ONLY: + kwargs[name] = tvm.tir.Var(name, default_index_dtype) + else: - raise ValueError("transform_layout mapping may not have *args or **kwargs") + raise ValueError("transform_layout mapping may not have *args") # Now that all the named arguments have been collected, # everything that remains should go to the *args, if # specified. if var_arg_name is not None: assert ndim is not None, "ndim must be specified when *args is used" - num_var_args = ndim - len(args) + num_var_args = ndim - len(args) - len(kwargs) for i in range(num_var_args): args.append(tvm.tir.Var(f"{var_arg_name}_{i}", default_index_dtype)) - final_indices = mapping_function(*args) - return IndexMap(args, final_indices) + mapping = mapping_function(*args, **kwargs) + + initial_indices = args + list(kwargs.values()) + + final_indices = [] + axis_separators = [] + for val in mapping: + if isinstance(val, tvm.ir.PrimExpr): + final_indices.append(val) + elif val is IndexMap.AXIS_SEPARATOR: + axis_separators.append(len(final_indices)) + else: + raise TypeError( + "Expected mapping function to return list of " + "either tvm.ir.PrimExpr or IndexMap.AXIS_SEPARATOR. " + "Instead received {val} of type {type(val)}." + ) + + return IndexMap(initial_indices, final_indices), axis_separators def is_equivalent_to(self, other_map: "IndexMap") -> bool: """Return if the index maps are equivalent. From bb215a19f43b5fc622b979e948a36f8d8c8194a0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 10 May 2022 16:08:50 -0500 Subject: [PATCH 03/16] Enabled *args in Schedule.transform_layout_sugared --- .../test_tir_schedule_transform_layout.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index 1b8444e9fb96..9af021177972 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -149,5 +149,33 @@ def test_two_elementwise_transform_output_buffer(use_sugared_transform): verify_trace_roundtrip(sch=sch, mod=two_elementwise) +def test_var_args_sugar(): + @T.prim_func + def summation_3d( + A: T.Buffer[(1024, 1024, 32), "float32"], B: T.Buffer[(1,), "float32"] + ) -> None: + B[0] = 0 + for i, j, k in T.grid(1024, 1024, 32): + with T.block("compute"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[0] = B[0] + A[vi, vj, vk] + + @T.prim_func + def summation_3d_split( + A: T.Buffer[(1024, 1024, 8, 4), "float32"], B: T.Buffer[(1,), "float32"] + ) -> None: + B[0] = 0 + for i, j, k in T.grid(1024, 1024, 32): + with T.block("compute"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[0] = B[0] + A[vi, vj, vk // 4, vk % 4] + + sch = tir.Schedule(summation_3d, debug_mask="all") + sch.transform_layout_sugared( + index_map=lambda *indices, k: [*indices, k // 4, k % 4], block="compute", buffer="A" + ) + tvm.ir.assert_structural_equal(summation_3d_split, sch.mod["main"]) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From 3740d75e7c932007a683901d9c6043a3fe78fbba Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 11 May 2022 09:29:49 -0500 Subject: [PATCH 04/16] Fix lint error --- python/tvm/tir/function.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index cdb15b3bf3c0..254b8e721617 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -297,7 +297,8 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None): """ index_map, axis_separators = IndexMap.from_func_with_separators(mapping_function, ndim) assert not axis_separators, ( - "The mapping_function provided to IndexMap.from_func may not return IndexMap.AXIS_SEPARATOR. " + "The mapping_function provided to IndexMap.from_func " + "may not return IndexMap.AXIS_SEPARATOR. " "If required, please use IndexMap.from_func_with_separators instead." ) return index_map From b4195cd4064b1e6e3251c87aa195045b9979ca30 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 11 May 2022 09:48:42 -0500 Subject: [PATCH 05/16] Allow Schedule.transform_layout_sugared to set axis separators --- python/tvm/tir/schedule/schedule.py | 6 +++- .../test_tir_schedule_set_axis_separator.py | 32 +++++++++++++++---- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 2ce3df1dc43b..b67a0b861915 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2192,9 +2192,13 @@ def iter_buffers(): ndim = len(buffer_obj.shape) if callable(index_map): - index_map = IndexMap.from_func(index_map, ndim=ndim) + index_map, axis_separators = IndexMap.from_func_with_separators(index_map, ndim=ndim) + else: + axis_separators = [] self.transform_layout(block, buffer_index, buffer_index_type, index_map) + if axis_separators: + self.set_axis_separator(block, buffer_index, buffer_index_type, axis_separators) @type_checked def transform_layout( diff --git a/tests/python/unittest/test_tir_schedule_set_axis_separator.py b/tests/python/unittest/test_tir_schedule_set_axis_separator.py index d829a3f1b76c..40fb4b1e3940 100644 --- a/tests/python/unittest/test_tir_schedule_set_axis_separator.py +++ b/tests/python/unittest/test_tir_schedule_set_axis_separator.py @@ -19,6 +19,7 @@ import pytest import tvm from tvm import tir +from tvm.tir import IndexMap from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip @@ -101,11 +102,19 @@ def element_wise_subregion_match_set_axis_separator(A: T.Buffer[(128, 128), "flo # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg +use_sugared_transform = tvm.testing.parameter( + by_dict={"set_axis_separators": False, "transform_layout_sugared": True} +) -def test_set_axis_separator(): +def test_set_axis_separator(use_sugared_transform): func = element_wise s = tir.Schedule(func, debug_mask='all') - s.set_axis_separator(s.get_block("B"), 0, "write", [1]) + + if use_sugared_transform: + s.set_axis_separator(s.get_block("B"), 0, "write", [1]) + else: + s.transform_layout_sugared(block='B', buffer='B', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) + tvm.ir.assert_structural_equal(element_wise_set_axis_separator, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) @@ -119,18 +128,29 @@ def test_set_scope_fail_on_index_out_of_bound(): s.set_axis_separator(s.get_block("B"), -1, "read",[1]) -def test_set_axis_separator_input_buffer(): +def test_set_axis_separator_input_buffer(use_sugared_transform): func = element_wise s = tir.Schedule(func, debug_mask='all') - s.set_axis_separator(s.get_block("B"), 0, "read", [1]) + + if use_sugared_transform: + s.transform_layout_sugared(block='B', buffer='A', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) + else: + s.set_axis_separator(s.get_block("B"), 0, "read", [1]) + + tvm.ir.assert_structural_equal(element_wise_set_axis_separator_input_buffer, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) -def test_set_axis_separator_subregion(): +def test_set_axis_separator_subregion(use_sugared_transform): func = element_wise_subregion_match s = tir.Schedule(func, debug_mask='all') - s.set_axis_separator(s.get_block("B"), 0, "write", [1]) + + if use_sugared_transform: + s.transform_layout_sugared(block='B', buffer='B', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) + else: + s.set_axis_separator(s.get_block("B"), 0, "write", [1]) + tvm.ir.assert_structural_equal(element_wise_subregion_match_set_axis_separator, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) From c048d5cf761f2a2dfef5b11fa40864e4339e8523 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 11 May 2022 12:03:45 -0500 Subject: [PATCH 06/16] Merged transform_layout_sugared functionality into transform_layout --- python/tvm/tir/schedule/schedule.py | 185 ++++++++++-------- .../test_tir_schedule_set_axis_separator.py | 20 +- .../test_tir_schedule_transform_layout.py | 16 +- 3 files changed, 125 insertions(+), 96 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index b67a0b861915..d7acb98a99dc 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """The TensorIR schedule class""" -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union, Tuple from tvm._ffi import register_object as _register_object from tvm.error import TVMError, register_error @@ -2114,58 +2114,32 @@ def after_unannotate(a: T.handle, b: T.handle) -> None: ########## Schedule: Layout transformation ########## - @type_checked - def transform_layout_sugared( - self, - index_map: Union[IndexMap, Callable], - buffer: str, - block: Union[BlockRV, str], - ) -> None: - """Apply a transformation represented by IndexMap to buffer - - This is a wrapper around `transform_layout`, intended for ease - of use. - - Parameters - ---------- - index_map : Union[IndexMap, Callable] - - The transformation to apply - - buffer: Union[Buffer,str] - - The buffer to be transformed. This buffer must exist in - the reads or writes of the block. If a string, the - reads/writes of the block must not contain more than one - buffer with that name. - - block : Union[BlockRV,str] - - The block that accesses the target buffer. If a string, - should refer to a name that uniquely identifies a block - within the schedule. + def _normalize_block_arg(self, block: Union[BlockRV, str]) -> BlockRV: + if isinstance(block, str): + return self.get_block(block) + else: + return block - """ + def _normalize_buffer_arg( + self, block: BlockRV, buffer: Union[Tuple[str, int], str, Buffer] + ) -> Tuple[str, int, Buffer]: - if isinstance(block, str): - block = self.get_block(block) + block_name = self.get(block).name_hint def iter_buffers(): block_obj = self.get(block) for i, read in enumerate(block_obj.reads): - yield i, "read", read.buffer + yield "read", i, read.buffer for i, write in enumerate(block_obj.writes): - yield i, "write", write.buffer - - possible_buffers = {} + yield "write", i, write.buffer if isinstance(buffer, str): + possible_buffers = {} # String lookup requires ensuring that the name is unique for buffer_index, buffer_index_type, buf in iter_buffers(): if buf.name == buffer: - possible_buffers[buf] = (buffer_index, buffer_index_type) + possible_buffers[buf] = (buffer_index_type, buffer_index) - block_name = self.get(block).name_hint assert possible_buffers, f"Could not find buffer '{buffer}' in block '{block_name}'" assert ( len(possible_buffers) == 1 @@ -2180,45 +2154,71 @@ def iter_buffers(): found = True break - block_name = self.get(block).name_hint assert found, "Could not find buffer '{buffer.name}' in block '{block_name}'" - else: - raise TypeError( - f"Argument 'buffer' should be str or tir.Buffer, " - f"but found {type(buffer)} instead." + elif isinstance(buffer, tuple): + buffer_index_type, buffer_index = buffer + assert buffer_index_type in ["read", "write",], ( + f"Invalid buffer_index_type. " + f"Expected 'read' or 'write', " + f"but received {buffer_index_type}" ) + buffer_list = ( + self.get(block).reads if buffer_index_type == "read" else self.get(block).writes + ) + assert 0 <= buffer_index < len(buffer_list), ( + f"Invalid buffer_index {buffer_index}. " + f"Block {block_name} has only " + f"{len(buffer_list)} {buffer_index_type} buffers." + ) + buffer_obj = buffer_list[buffer_index].buffer - ndim = len(buffer_obj.shape) - - if callable(index_map): - index_map, axis_separators = IndexMap.from_func_with_separators(index_map, ndim=ndim) else: - axis_separators = [] + raise TypeError(f"Invalid type for argument 'buffer': {type(buffer)}") - self.transform_layout(block, buffer_index, buffer_index_type, index_map) - if axis_separators: - self.set_axis_separator(block, buffer_index, buffer_index_type, axis_separators) + return (buffer_index_type, buffer_index, buffer_obj) - @type_checked + # @type_checked def transform_layout( self, - block: BlockRV, - buffer_index: int, - buffer_index_type: str, + block: Union[BlockRV, str], + buffer: Union[Tuple[str, int], str, Buffer], index_map: Union[IndexMap, Callable], ) -> None: """Apply a transformation represented by IndexMap to buffer + Parameters ---------- - block : BlockRV - The block that accesses the target buffer - buffer_index: int - The index of the buffer in block's read or write region - buffer_index_type : str - Type of the buffer index, "read" or "write" + block : Union[BlockRV, str] + + The block that accesses the target buffer. If a string, + this must uniquely identify a block. + + buffer: Union[Tuple[str,int], Buffer, str] + + The buffer to be transformed, or a specification of how to + identify the buffer to be transformed. + + If `buffer` if a tuple of ``(str,int)``, the first item + should be either "read" or "write", and the second item is + an index into the block's read or write regions. + + If `buffer` is a string, it is the name of the buffer, + which must exist within the reads/writes of the block. In + addition, the reads/writes of the block may not contain + more than one buffer with this name. + + If `buffer` is a Buffer object, it must exist within the + reads/writes of the block. + index_map : Union[IndexMap, Callable] - The transformation to apply + + The transformation to apply. + + If `index_map` is a callable, and the returned list + contains IndexMap.AXIS_SEPARATOR, the SetAxisSeparators + primitive will be called in addition to the + TransformLayout primitive. Examples -------- @@ -2245,7 +2245,7 @@ def before_transform_layout(a: T.handle, c: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_storage_align) - sch.transform_layout(sch.get_block("B"), buffer_index=0, "write", + sch.transform_layout(sch.get_block("B"), buffer=("write",0), index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16)) print(sch.mod["main"].script()) @@ -2268,20 +2268,29 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0 """ + block = self._normalize_block_arg(block) + buffer_index_type, buffer_index, buffer_obj = self._normalize_buffer_arg(block, buffer) + + ndim = len(buffer_obj.shape) if callable(index_map): - index_map = IndexMap.from_func(index_map) - assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type" + index_map, axis_separators = IndexMap.from_func_with_separators(index_map, ndim=ndim) + else: + axis_separators = [] + buffer_index_type_enum = 0 if buffer_index_type == "read" else 1 _ffi_api.ScheduleTransformLayout( # type: ignore # pylint: disable=no-member self, block, buffer_index, buffer_index_type_enum, index_map ) + if axis_separators: + _ffi_api.ScheduleSetAxisSeparator( # type: ignore # pylint: disable=no-member + self, block, buffer_index, buffer_index_type_enum, axis_separators + ) - @type_checked + # @type_checked def set_axis_separator( self, - block: BlockRV, - buffer_index: int, - buffer_index_type: str, + block: Union[BlockRV, str], + buffer: Union[Tuple[str, int], str, Buffer], axis_separators: Optional[List[int]], ) -> None: """Set the axis separator of a buffer, where the buffer is specified by a block and a read @@ -2289,13 +2298,30 @@ def set_axis_separator( Parameters ---------- - block : BlockRV - The block that accesses the target buffer - buffer_index: int - The index of the buffer in block's read or write region - buffer_index_type : str - Type of the buffer index, "read" or "write" + block : Union[BlockRV, str] + + The block that accesses the target buffer. If a string, + this must uniquely identify a block. + + buffer: Union[Tuple[str,int], Buffer, str] + + The buffer to be transformed, or a specification of how to + identify the buffer to be transformed. + + If `buffer` if a tuple of ``(str,int)``, the first item + should be either "read" or "write", and the second item is + an index into the block's read or write regions. + + If `buffer` is a string, it is the name of the buffer, + which must exist within the reads/writes of the block. In + addition, the reads/writes of the block may not contain + more than one buffer with this name. + + If `buffer` is a Buffer object, it must exist within the + reads/writes of the block. + axis_separators : Optional[List[int]] + The axis separators. Examples @@ -2349,7 +2375,10 @@ def after_set_axis_separators( C[vi, vj] = B[vi, vj] + T.float32(1) """ axis_separators = axis_separators or [] - assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type" + + block = self._normalize_block_arg(block) + buffer_index_type, buffer_index, buffer_obj = self._normalize_buffer_arg(block, buffer) + buffer_index_type_enum = 0 if buffer_index_type == "read" else 1 _ffi_api.ScheduleSetAxisSeparator( # type: ignore # pylint: disable=no-member self, block, buffer_index, buffer_index_type_enum, axis_separators diff --git a/tests/python/unittest/test_tir_schedule_set_axis_separator.py b/tests/python/unittest/test_tir_schedule_set_axis_separator.py index 40fb4b1e3940..3bf04aa1a1c0 100644 --- a/tests/python/unittest/test_tir_schedule_set_axis_separator.py +++ b/tests/python/unittest/test_tir_schedule_set_axis_separator.py @@ -111,9 +111,9 @@ def test_set_axis_separator(use_sugared_transform): s = tir.Schedule(func, debug_mask='all') if use_sugared_transform: - s.set_axis_separator(s.get_block("B"), 0, "write", [1]) + s.set_axis_separator(s.get_block("B"), ("write",0), [1]) else: - s.transform_layout_sugared(block='B', buffer='B', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) + s.transform_layout(block='B', buffer='B', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) tvm.ir.assert_structural_equal(element_wise_set_axis_separator, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) @@ -122,10 +122,10 @@ def test_set_axis_separator(use_sugared_transform): def test_set_scope_fail_on_index_out_of_bound(): func = element_wise s = tir.Schedule(func, debug_mask='all') - with pytest.raises(tvm.tir.ScheduleError): - s.set_axis_separator(s.get_block("B"), 1, "write",[1]) - with pytest.raises(tvm.tir.ScheduleError): - s.set_axis_separator(s.get_block("B"), -1, "read",[1]) + with pytest.raises(AssertionError): + s.set_axis_separator(s.get_block("B"), ("write",1),[1]) + with pytest.raises(AssertionError): + s.set_axis_separator(s.get_block("B"), ("read",-1),[1]) def test_set_axis_separator_input_buffer(use_sugared_transform): @@ -133,9 +133,9 @@ def test_set_axis_separator_input_buffer(use_sugared_transform): s = tir.Schedule(func, debug_mask='all') if use_sugared_transform: - s.transform_layout_sugared(block='B', buffer='A', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) + s.transform_layout(block='B', buffer='A', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) else: - s.set_axis_separator(s.get_block("B"), 0, "read", [1]) + s.set_axis_separator(s.get_block("B"), ("read",0), [1]) tvm.ir.assert_structural_equal(element_wise_set_axis_separator_input_buffer, s.mod["main"]) @@ -147,9 +147,9 @@ def test_set_axis_separator_subregion(use_sugared_transform): s = tir.Schedule(func, debug_mask='all') if use_sugared_transform: - s.transform_layout_sugared(block='B', buffer='B', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) + s.transform_layout(block='B', buffer='B', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) else: - s.set_axis_separator(s.get_block("B"), 0, "write", [1]) + s.set_axis_separator(s.get_block("B"), ("write",0), [1]) tvm.ir.assert_structural_equal(element_wise_subregion_match_set_axis_separator, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index 9af021177972..7a447642baf8 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -102,14 +102,14 @@ def test_two_elementwise_transform_intermediate_buffer(use_sugared_transform): sch = tir.Schedule(two_elementwise, debug_mask="all") if use_sugared_transform: - sch.transform_layout_sugared( - index_map=packed_index_map_func, + sch.transform_layout( block="B", buffer="B", + index_map=packed_index_map_func, ) else: block = sch.get_block("B") - sch.transform_layout(block, 0, "write", packed_index_map_func) + sch.transform_layout(block, ("write", 0), packed_index_map_func) tvm.ir.assert_structural_equal(two_elementwise_transformed_intermediate_buffer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise) @@ -119,14 +119,14 @@ def test_two_elementwise_transform_input_buffer(use_sugared_transform): sch = tir.Schedule(two_elementwise, debug_mask="all") if use_sugared_transform: - sch.transform_layout_sugared( + sch.transform_layout( index_map=packed_index_map_func, block="B", buffer="A", ) else: block = sch.get_block("B") - sch.transform_layout(block, 0, "read", packed_index_map_func) + sch.transform_layout(block, ("read", 0), packed_index_map_func) tvm.ir.assert_structural_equal(two_elementwise_transformed_input_buffer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise) @@ -136,14 +136,14 @@ def test_two_elementwise_transform_output_buffer(use_sugared_transform): sch = tir.Schedule(two_elementwise, debug_mask="all") if use_sugared_transform: - sch.transform_layout_sugared( + sch.transform_layout( index_map=packed_index_map_func, block="C", buffer="C", ) else: block = sch.get_block("C") - sch.transform_layout(block, 0, "write", packed_index_map_func) + sch.transform_layout(block, ("write", 0), packed_index_map_func) tvm.ir.assert_structural_equal(two_elementwise_transformed_output_buffer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise) @@ -171,7 +171,7 @@ def summation_3d_split( B[0] = B[0] + A[vi, vj, vk // 4, vk % 4] sch = tir.Schedule(summation_3d, debug_mask="all") - sch.transform_layout_sugared( + sch.transform_layout( index_map=lambda *indices, k: [*indices, k // 4, k % 4], block="compute", buffer="A" ) tvm.ir.assert_structural_equal(summation_3d_split, sch.mod["main"]) From 7a7810296317e8fff68308d2992cb32e1eafeb7c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 11 May 2022 16:01:58 -0500 Subject: [PATCH 07/16] Fix lint errors --- python/tvm/tir/schedule/schedule.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index d7acb98a99dc..e27d846260be 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2116,9 +2116,9 @@ def after_unannotate(a: T.handle, b: T.handle) -> None: def _normalize_block_arg(self, block: Union[BlockRV, str]) -> BlockRV: if isinstance(block, str): - return self.get_block(block) - else: - return block + block = self.get_block(block) + + return block def _normalize_buffer_arg( self, block: BlockRV, buffer: Union[Tuple[str, int], str, Buffer] @@ -2377,7 +2377,7 @@ def after_set_axis_separators( axis_separators = axis_separators or [] block = self._normalize_block_arg(block) - buffer_index_type, buffer_index, buffer_obj = self._normalize_buffer_arg(block, buffer) + buffer_index_type, buffer_index, _ = self._normalize_buffer_arg(block, buffer) buffer_index_type_enum = 0 if buffer_index_type == "read" else 1 _ffi_api.ScheduleSetAxisSeparator( # type: ignore # pylint: disable=no-member From 3229175812d9d22d681eee9fcfa40a782d60ae0c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 12 May 2022 09:26:52 -0500 Subject: [PATCH 08/16] Fix lint error --- python/tvm/tir/schedule/schedule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index e27d846260be..7b8a81881615 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2116,7 +2116,7 @@ def after_unannotate(a: T.handle, b: T.handle) -> None: def _normalize_block_arg(self, block: Union[BlockRV, str]) -> BlockRV: if isinstance(block, str): - block = self.get_block(block) + return self.get_block(block) return block From ec66ff166e9ab175330913582a292904b2197d27 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 16 May 2022 08:47:23 -0500 Subject: [PATCH 09/16] Fixed docstring errors --- python/tvm/tir/function.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 254b8e721617..a921c5b9fc40 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -276,15 +276,15 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None): mapping_function : Callable The function to map from source indices to target indices. - The function should accept tir.Var parameters and return a - list. Each element of the returned list should be a - tir.PrimExpr. + The function should accept `tir.Var` parameters and return + a list. Each element of the returned list should be a + `tir.PrimExpr`. ndim: Optional[int] The dimensionality of the buffer to which this - transformation should be applied. If mapping_function - uses variadic argument *args, ndim must be specified. If + transformation should be applied. If mapping_function uses + variadic argument `*args`, `ndim` must be specified. If mapping_function does not use variadic arguments, ndim is optional. @@ -292,7 +292,7 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None): ------- index_map: IndexMap - Returns an IndexMap representing the mapping_function. + Returns an IndexMap representing the `mapping_function`. """ index_map, axis_separators = IndexMap.from_func_with_separators(mapping_function, ndim) @@ -314,13 +314,13 @@ def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] = The function to map from source indices to target indices. The function should accept tir.Var parameters and return a list. Each element of the returned list should be either a - tir.PrimExpr or the object IndexMap.AXIS_SEPARATOR. + `tir.PrimExpr` or the object `IndexMap.AXIS_SEPARATOR`. ndim: Optional[int] The dimensionality of the buffer to which this - transformation should be applied. If mapping_function - uses variadic argument *args, ndim must be specified. If + transformation should be applied. If mapping_function uses + variadic argument `*args`, ndim must be specified. If mapping_function does not use variadic arguments, ndim is optional. @@ -329,8 +329,8 @@ def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] = ret: Tuple[IndexMap, List[int]] Returns a tuple whose first element is an IndexMap - representing the mapping_function, and whose second index - is a list of indices at which IndexMap.AXIS_SEPARATOR + representing the `mapping_function`, and whose second index + is a list of indices at which `IndexMap.AXIS_SEPARATOR` occurred. """ From cefee79cad3c2b778c4a597299a1960af3ab1cd3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 17 May 2022 09:31:08 -0500 Subject: [PATCH 10/16] Updated/tested TransformatLayoutTraits::UnpackedAsPython --- python/tvm/tir/schedule/testing.py | 30 +++++++++++++++---- .../primitive/layout_transformation.cc | 11 +++---- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/python/tvm/tir/schedule/testing.py b/python/tvm/tir/schedule/testing.py index 04cbffcd4d87..0286d0e7c56e 100644 --- a/python/tvm/tir/schedule/testing.py +++ b/python/tvm/tir/schedule/testing.py @@ -15,8 +15,9 @@ # specific language governing permissions and limitations # under the License. """Testing utilities for the TensorIR schedule API""" -from typing import Union +from typing import Union, Sequence +import tvm from tvm.ir import IRModule, structural_equal from tvm.tir import PrimFunc from tvm.tir.schedule import Trace, Schedule @@ -27,6 +28,7 @@ def verify_trace_roundtrip( mod: Union[PrimFunc, IRModule], *, debug_mask: Union[str, int] = "all", + text_format: Union[str, Sequence[str]] = ["python", "json"], ) -> Schedule: """Serialize a traced schedule to JSON, then replay the JSON trace by applying to a fresh new schedule, verifying the reproducibility of scheduling. @@ -44,18 +46,36 @@ def verify_trace_roundtrip( 1) "all" - Turn on all the checks 2) "none" - Turn off all the checks 3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask + text_format: Union[str, Sequence[str]] + The text format or formats whose round-trip behavior should be + validated. If a single string, validate round-trips through """ - # Step 1. Serialize the trace to JSON + if not isinstance(text_format, str): + for opt in text_format: + new_sch = verify_trace_roundtrip(sch, mod, debug_mask=debug_mask, text_format=opt) + return new_sch + trace = sch.trace assert trace is not None - json_obj = trace.as_json() - # Step 2. Apply the JSON trace to a new schedule, then check if it reproduces the scheduling + + # Step 1. Perform a round-trip through the text-format new_sch = Schedule(mod=mod, debug_mask=debug_mask) - Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch) + if text_format == "json": + json_obj = trace.as_json() + Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch) + elif text_format == "python": + py_trace = "\n".join(trace.as_python()) + exec(py_trace, tvm.tir.__dict__, {"sch": new_sch}) + else: + assert text_format in ("json", "python"), f"Unknown text format: {text_format}" + + # Step 2. Verify that the round-trip produced the same scheduling assert structural_equal(new_sch.mod, sch.mod) + # Step 3. Check the consistency of the text format between the old and new traces py_repr = "\n".join(trace.as_python()) new_py_repr = "\n".join(new_sch.trace.as_python()) assert py_repr == new_py_repr + # Step 4. Return the new schedule in case it could be useful return new_sch diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index b133f537b5ac..13304dc2e6ae 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -284,11 +284,12 @@ struct TransformLayoutTraits : public UnpackedInstTraits Integer buffer_index_type, IndexMap index_map) { PythonAPICall py("transform_layout"); py.Input("block", block_rv); - py.Input("buffer_index", buffer_index); - py.Input("buffer_index_type", '"' + - std::string(BufferIndexType2Str( - static_cast(buffer_index_type->value))) + - '"'); + + std::ostringstream os; + os << "(\"" << BufferIndexType2Str(static_cast(buffer_index_type->value)) + << "\", " << buffer_index << ")"; + py.Input("buffer", os.str()); + py.Input("index_map", index_map->ToPythonString()); return py.Str(); } From 5caf6f459d18a9225f948b848c231eaf323eaf80 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 17 May 2022 11:17:58 -0500 Subject: [PATCH 11/16] Disabled exec-used check for running trace.as_python() --- python/tvm/tir/schedule/testing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/tir/schedule/testing.py b/python/tvm/tir/schedule/testing.py index 0286d0e7c56e..3689f756e83c 100644 --- a/python/tvm/tir/schedule/testing.py +++ b/python/tvm/tir/schedule/testing.py @@ -65,7 +65,7 @@ def verify_trace_roundtrip( Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch) elif text_format == "python": py_trace = "\n".join(trace.as_python()) - exec(py_trace, tvm.tir.__dict__, {"sch": new_sch}) + exec(py_trace, tvm.tir.__dict__, {"sch": new_sch}) # pylint: disable=exec-used else: assert text_format in ("json", "python"), f"Unknown text format: {text_format}" From 2f63e96b7b3df1bd1cf336e181f6dfd9fc140d95 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 17 May 2022 15:11:24 -0500 Subject: [PATCH 12/16] Updated SetAxisSeparatorTraits::UnpackedAsPython --- src/tir/schedule/primitive/layout_transformation.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 3971ab188622..86b73196dae1 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -342,11 +342,12 @@ struct SetAxisSeparatorTraits : public UnpackedInstTraits axis_separators) { PythonAPICall py("set_axis_separator"); py.Input("block", block_rv); - py.Input("buffer_index", buffer_index); - py.Input("buffer_index_type", '"' + - std::string(BufferIndexType2Str( - static_cast(buffer_index_type->value))) + - '"'); + + std::ostringstream os; + os << "(\"" << BufferIndexType2Str(static_cast(buffer_index_type->value)) + << "\", " << buffer_index << ")"; + py.Input("buffer", os.str()); + py.Input("axis_separators", axis_separators); return py.Str(); } From 015ce32213d7bcf7ca9b17a11a1b969375a8a29b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 18 May 2022 08:38:25 -0500 Subject: [PATCH 13/16] Updated unit test that was added in merge commit --- tests/python/unittest/test_tir_schedule_transform_layout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index c1f841936ccd..5530bcc6d1e1 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -166,7 +166,7 @@ def test_simplify(): block_outer = sch.blockize(i_inner) B = sch.cache_read(block_outer, 0, "global") - sch.transform_layout(B, 0, "write", lambda i, j: (i // 16, j // 16, i % 16, j % 16)) + sch.transform_layout(B, ("write", 0), lambda i, j: (i // 16, j // 16, i % 16, j % 16)) @T.prim_func def ref(B: T.Buffer[(8, 8, 16, 16), "float32"], C: T.Buffer[(128, 128), "float32"]): From accf8ff5a847a1df74661353f5c033567c2fa1e2 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 18 May 2022 08:38:47 -0500 Subject: [PATCH 14/16] Fixed the argument name for TensorizeTraits This wasn't checked before, but was the only other issue caught by the updates to verify_trace_roundtrip. --- src/tir/schedule/primitive/blockize_tensorize.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 331d098347b0..7ed80a1c5b8f 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -699,7 +699,7 @@ struct TensorizeTraits : public UnpackedInstTraits { static String UnpackedAsPython(Array outputs, String block_or_loop_rv, String intrin) { PythonAPICall py("tensorize"); py.Input("block_or_loop", block_or_loop_rv); - py.Input("intrin", intrin); + py.Input("tensor_intrin", intrin); return py.Str(); } From 99fb77583d6a340ed2bf68fcfd86f404ef46738f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 20 May 2022 12:29:52 -0500 Subject: [PATCH 15/16] Re-enable type checks of transform_layout/set_axis_separator Disabled while waiting for https://github.com/apache/tvm/pull/11289, which was required for the `Tuple` argument. --- python/tvm/tir/schedule/schedule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 7b8a81881615..6895c009d3aa 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2178,7 +2178,7 @@ def iter_buffers(): return (buffer_index_type, buffer_index, buffer_obj) - # @type_checked + @type_checked def transform_layout( self, block: Union[BlockRV, str], @@ -2286,7 +2286,7 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> self, block, buffer_index, buffer_index_type_enum, axis_separators ) - # @type_checked + @type_checked def set_axis_separator( self, block: Union[BlockRV, str], From a513f18c2f89f1362282d1aa6a42452d5388f4d2 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 23 May 2022 10:31:13 -0500 Subject: [PATCH 16/16] Updated a few additional transform_layout usages from main --- .../unittest/test_tir_schedule_tensorize_ldmatrix_mma.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py index 67e8ae0ad836..e9ee990a2415 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py +++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py @@ -177,9 +177,9 @@ def tile_wmma_fragment(block_read, height, width): else: loop_b = tile_wmma_fragment(B_warp, k_inner, 16) - sch.transform_layout(A_warp, 0, "write", index_map_A) - sch.transform_layout(B_warp, 0, "write", index_map_B) - sch.transform_layout(C_warp, 0, "read", index_map_C) + sch.transform_layout(A_warp, ("write", 0), index_map_A) + sch.transform_layout(B_warp, ("write", 0), index_map_B) + sch.transform_layout(C_warp, ("read", 0), index_map_C) sch.tensorize(loop_a, ldmatrix_a_intrin) sch.tensorize(loop_b, ldmatrix_b_intrin)