diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index d225280b655f..d29495c43007 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -373,14 +373,14 @@ def sample_perfect_tile( @type_checked def sample_compute_location( self, - block: BlockRV, + block: Union[BlockRV, str], decision: Optional[int] = None, ) -> LoopRV: """Sample a compute-at location of the given block Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The block whose compute-at location is to be sampled decision : Optional[int] The sampling decision @@ -390,6 +390,8 @@ def sample_compute_location( result : LoopRV The sampled loop where the input block is to be computed at """ + block = self._normalize_block_arg(block) + return _ffi_api.ScheduleSampleComputeLocation( # type: ignore # pylint: disable=no-member self, block, @@ -425,12 +427,12 @@ def get_block( ) @type_checked - def get_loops(self, block: BlockRV) -> List[LoopRV]: + def get_loops(self, block: Union[BlockRV, str]) -> List[LoopRV]: """Get the parent loops of the block in its scope, from outer to inner Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The query block Returns @@ -438,6 +440,7 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]: loops : List[LoopRV] A list of loops above the given block in its scope, from outer to inner """ + block = self._normalize_block_arg(block) return list(_ffi_api.ScheduleGetLoops(self, block)) # type: ignore # pylint: disable=no-member @type_checked @@ -457,12 +460,12 @@ def get_child_blocks(self, block_or_loop: Union[BlockRV, LoopRV]) -> List[BlockR return list(_ffi_api.ScheduleGetChildBlocks(self, block_or_loop)) # type: ignore # pylint: disable=no-member @type_checked - def get_producers(self, block: BlockRV) -> List[BlockRV]: + def get_producers(self, block: Union[BlockRV, str]) -> List[BlockRV]: """Get the producers of a specific block Parameters ---------- - block : BlockRV 
+ block : Union[BlockRV, str] The block in the query Returns @@ -470,15 +473,16 @@ def get_producers(self, block: BlockRV) -> List[BlockRV]: producers : List[BlockRV] A list of producers of the given block """ + block = self._normalize_block_arg(block) return list(_ffi_api.ScheduleGetProducers(self, block)) # type: ignore # pylint: disable=no-member @type_checked - def get_consumers(self, block: BlockRV) -> List[BlockRV]: + def get_consumers(self, block: Union[BlockRV, str]) -> List[BlockRV]: """Get the consumers of a specific block Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The block in the query Returns @@ -486,6 +490,7 @@ def get_consumers(self, block: BlockRV) -> List[BlockRV]: consumers : List[BlockRV] A list of consumers of the given block """ + block = self._normalize_block_arg(block) return list(_ffi_api.ScheduleGetConsumers(self, block)) # type: ignore # pylint: disable=no-member ########## Schedule: Transform loops ########## @@ -970,7 +975,9 @@ def after_unroll(a: T.handle, b: T.handle) -> None: ########## Schedule: Insert cache stages ########## @type_checked - def cache_read(self, block: BlockRV, read_buffer_index: int, storage_scope: str) -> BlockRV: + def cache_read( + self, block: Union[BlockRV, str], read_buffer_index: int, storage_scope: str + ) -> BlockRV: """Create a block that reads a buffer region into a read cache. It requires: 1) There is at most one block who write the buffer in the scope. @@ -979,7 +986,7 @@ def cache_read(self, block: BlockRV, read_buffer_index: int, storage_scope: str) Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The consumer block of the target buffer. 
read_buffer_index: int @@ -1036,12 +1043,15 @@ def after_cache_read(a: T.handle, b: T.handle) -> None: B[vi, vj] = A_local[vi, vj] * 2.0 """ + block = self._normalize_block_arg(block) return _ffi_api.ScheduleCacheRead( # type: ignore # pylint: disable=no-member self, block, read_buffer_index, storage_scope ) @type_checked - def cache_write(self, block: BlockRV, write_buffer_index: int, storage_scope: str) -> BlockRV: + def cache_write( + self, block: Union[BlockRV, str], write_buffer_index: int, storage_scope: str + ) -> BlockRV: """Create a block that reads a buffer region into a write cache. It requires: 1) There is only one block who write the buffer in the scope. @@ -1050,7 +1060,7 @@ def cache_write(self, block: BlockRV, write_buffer_index: int, storage_scope: st Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The producer block of the target buffer. write_buffer_index: int @@ -1108,12 +1118,17 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: B[vi, vj] = B_local[vi, vj] """ + block = self._normalize_block_arg(block) return _ffi_api.ScheduleCacheWrite( # type: ignore # pylint: disable=no-member self, block, write_buffer_index, storage_scope ) @type_checked - def reindex(self, block: BlockRV, buffer_index: int, buffer_index_type: str) -> BlockRV: + def reindex( + self, + block: Union[BlockRV, str], + buffer: Union[Tuple[str, int], str, Buffer], + ) -> BlockRV: """Create a block that read/write a buffer region into a read/write cache with reindexing. The layout of the cache will be the same as by the iterators of the block that reads/writes the buffer. 
It requires: @@ -1122,12 +1137,27 @@ def reindex(self, block: BlockRV, buffer_index: int, buffer_index_type: str) -> Parameters ---------- - block: BlockRV - The block that accesses the target buffer - buffer_index: int - The index of the buffer in block's read or write region - buffer_index_type : str - Type of the buffer index, "read" or "write" + block : Union[BlockRV, str] + + The block that accesses the target buffer. If a string, + this must uniquely identify a block. + + buffer: Union[Tuple[str,int], Buffer, str] + + The buffer to be transformed, or a specification of how to + identify the buffer to be transformed. + + If `buffer` is a tuple of ``(str,int)``, the first item + should be either "read" or "write", and the second item is + an index into the block's read or write regions. + + If `buffer` is a string, it is the name of the buffer, + which must exist within the reads/writes of the block. In + addition, the reads/writes of the block may not contain + more than one buffer with this name. + + If `buffer` is a Buffer object, it must exist within the + reads/writes of the block. 
Returns ------- @@ -1157,7 +1187,7 @@ def before_reindex( sch = tir.Schedule(before_reindex) block = sch.get_block("B") - sch.reindex(block, 0, "read) + sch.reindex(block, ("read", 0)) After applying reindex, the IR becomes: @@ -1179,6 +1209,8 @@ def after_reindex( B[vi, vj] = A_reindex[vi, vj] * 2.0 """ + block = self._normalize_block_arg(block) + buffer_index_type, buffer_index, _ = self._normalize_buffer_arg(block, buffer) assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type" buffer_index_type_enum = 0 if buffer_index_type == "read" else 1 return _ffi_api.ScheduleReIndex( # type: ignore # pylint: disable=no-member @@ -1190,7 +1222,7 @@ def after_reindex( @type_checked def compute_at( self, - block: BlockRV, + block: Union[BlockRV, str], loop: LoopRV, preserve_unit_loops: bool = False, ) -> None: @@ -1213,7 +1245,7 @@ def compute_at( Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The block to be moved loop: LoopRV @@ -1273,6 +1305,7 @@ def after_compute_at(a: T.handle, c: T.handle) -> None: C[vi, vj] = B[vi, vj] + 1.0 """ + block = self._normalize_block_arg(block) _ffi_api.ScheduleComputeAt( # type: ignore # pylint: disable=no-member self, block, @@ -1283,7 +1316,7 @@ def after_compute_at(a: T.handle, c: T.handle) -> None: @type_checked def reverse_compute_at( self, - block: BlockRV, + block: Union[BlockRV, str], loop: LoopRV, preserve_unit_loops: bool = False, ) -> None: @@ -1303,7 +1336,7 @@ def reverse_compute_at( Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The block to be moved loop: LoopRV @@ -1363,6 +1396,7 @@ def after_reverse_compute_at(a: T.handle, c: T.handle) -> None: C[vi, vj] = B[vi, vj] + 1.0 """ + block = self._normalize_block_arg(block) _ffi_api.ScheduleReverseComputeAt( # type: ignore # pylint: disable=no-member self, block, @@ -1371,7 +1405,7 @@ def after_reverse_compute_at(a: T.handle, c: T.handle) -> None: ) @type_checked - def compute_inline(self, block: BlockRV) -> None: + 
def compute_inline(self, block: Union[BlockRV, str]) -> None: """Inline a block into its consumer(s). It requires: 1) The block is a complete non-root block, which only produces one buffer @@ -1386,7 +1420,7 @@ def compute_inline(self, block: BlockRV) -> None: Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The block to be inlined to its consumer(s) Examples @@ -1432,10 +1466,11 @@ def after_inline(a: T.handle, c: T.handle) -> None: C[vi, vj] = A[vi, vj] * 2.0 + 1.0 """ + block = self._normalize_block_arg(block) _ffi_api.ScheduleComputeInline(self, block) # type: ignore # pylint: disable=no-member @type_checked - def reverse_compute_inline(self, block: BlockRV) -> None: + def reverse_compute_inline(self, block: Union[BlockRV, str]) -> None: """Inline a block into its only producer. It requires: 1) The block is a complete non-root block, which only produces and consumes one buffer @@ -1453,7 +1488,7 @@ def reverse_compute_inline(self, block: BlockRV) -> None: Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The block to be inlined to its producer Examples @@ -1499,12 +1534,13 @@ def after_inline(a: T.handle, c: T.handle) -> None: C[vi, vj] = A[vi, vj] * 2.0 + 1.0 """ + block = self._normalize_block_arg(block) _ffi_api.ScheduleReverseComputeInline(self, block) # type: ignore # pylint: disable=no-member ########## Schedule: Reduction ########## @type_checked - def decompose_reduction(self, block: BlockRV, loop: LoopRV) -> BlockRV: + def decompose_reduction(self, block: Union[BlockRV, str], loop: LoopRV) -> BlockRV: """Decompose a reduction block into two separate blocks. a) The init block, which is translated from the init statement of the reduction block; @@ -1523,7 +1559,7 @@ def decompose_reduction(self, block: BlockRV, loop: LoopRV) -> BlockRV: Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The reduction block to be decomposed loop : LoopRV The loop above which the init block is inserted before. 
@@ -1578,6 +1614,7 @@ def after_decompose(a: ty.handle, c: ty.handle) -> None: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] """ + block = self._normalize_block_arg(block) return _ffi_api.ScheduleDecomposeReduction(self, block, loop) # type: ignore # pylint: disable=no-member @type_checked @@ -1734,7 +1771,7 @@ def after_rfactor(a: T.handle, b: T.handle) -> None: @type_checked def storage_align( # pylint: disable=too-many-arguments self, - block: BlockRV, + block: Union[BlockRV, str], buffer_index: int, axis: int, factor: int, @@ -1747,7 +1784,7 @@ def storage_align( # pylint: disable=too-many-arguments Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The producer block of the buffer. buffer_index : int The index of the buffer in block's write region. @@ -1812,18 +1849,19 @@ def after_storage_align(a: T.handle, c: T.handle) -> None: ---- Storage_align requires the buffer to be an intermediate buffer defined via `alloc_buffer`. """ + block = self._normalize_block_arg(block) _ffi_api.ScheduleStorageAlign( # type: ignore # pylint: disable=no-member self, block, buffer_index, axis, factor, offset ) @type_checked - def set_scope(self, block: BlockRV, buffer_index: int, storage_scope: str) -> None: + def set_scope(self, block: Union[BlockRV, str], buffer_index: int, storage_scope: str) -> None: """Set the storage scope of a buffer, where the buffer is specified by the a block and a write-index Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The producer block of the buffer buffer_index : int The index of the buffer in block's write region @@ -1883,6 +1921,7 @@ def after_set_scope( ---- Set_scope requires the buffer to be an intermediate buffer defined via `alloc_buffer`. 
""" + block = self._normalize_block_arg(block) _ffi_api.ScheduleSetScope( # type: ignore # pylint: disable=no-member self, block, buffer_index, storage_scope ) @@ -2418,14 +2457,14 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> @type_checked def transform_block_layout( self, - block: BlockRV, + block: Union[BlockRV, str], index_map: Union[IndexMap, Callable], ) -> None: """Apply a transformation represented by IndexMap to block Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The block to be transformed index_map : Union[IndexMap, Callable] @@ -2470,6 +2509,7 @@ def after_transform_block_layout( vi, = T.axis.remap("S", [i]) B[vi // 16, vi % 16] = A[vi // 16, vi % 16] * 2.0 """ + block = self._normalize_block_arg(block) if callable(index_map): index_map = IndexMap.from_func(index_map) _ffi_api.ScheduleTransformBlockLayout( # type: ignore # pylint: disable=no-member diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index c96f88e1f633..5a8d452f14b8 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -1241,11 +1241,10 @@ struct ReIndexTraits : public UnpackedInstTraits<ReIndexTraits> { Integer buffer_index_type) { PythonAPICall py("reindex"); py.Input("block", block); - py.Input("buffer_index", buffer_index); - py.Input("buffer_index_type", '"' + - std::string(BufferIndexType2Str( - static_cast<BufferIndexType>(buffer_index_type->value))) + - '"'); + std::ostringstream os; + os << "(\"" << BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value)) + << "\", " << buffer_index << ")"; + py.Input("buffer", os.str()); py.SingleOutput(outputs); return py.Str(); } diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index ef306b2c4929..5cd39c7ddaeb 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++
b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -741,13 +741,15 @@ def block_predicate_cache_write_output_buf() -> None: ########## Testcases for cache_read ########## +use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) -def test_cache_read_elementwise(): + +def test_cache_read_elementwise(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") block_b = sch.get_block("B") block_c = sch.get_block("C") - cached_a = sch.cache_read(block_b, 0, "global") - cached_b = sch.cache_read(block_c, 0, "local") + cached_a = sch.cache_read("B" if use_block_name else block_b, 0, "global") + cached_b = sch.cache_read("C" if use_block_name else block_c, 0, "local") assert sch.get(cached_a) == sch.get(sch.get_block("A_global")) assert sch.get(cached_b) == sch.get(sch.get_block("B_local")) assert sch.get(block_b) == sch.get(sch.get_block("B")) @@ -756,74 +758,74 @@ def test_cache_read_elementwise(): verify_trace_roundtrip(sch=sch, mod=elementwise) -def test_cache_read_under_scope(): +def test_cache_read_under_scope(use_block_name): sch = tir.Schedule(access_under_scope, debug_mask="all") - block_b = sch.get_block("B") - block_c = sch.get_block("C") + block_b = "B" if use_block_name else sch.get_block("B") + block_c = "C" if use_block_name else sch.get_block("C") sch.cache_read(block_b, 0, "local") sch.cache_read(block_c, 0, "global") tvm.ir.assert_structural_equal(cache_read_under_scope, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=access_under_scope) -def test_cache_read_opaque_access(): +def test_cache_read_opaque_access(use_block_name): sch = tir.Schedule(opaque_access, debug_mask="all") - block = sch.get_block("load_store") + block = "load_store" if use_block_name else sch.get_block("load_store") sch.cache_read(block, 0, "global") tvm.ir.assert_structural_equal(cache_read_opaque_access, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=opaque_access) -def test_cache_read_location(): +def 
test_cache_read_location(use_block_name): sch = tir.Schedule(func_multi_consumer, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") sch.cache_read(block_b, 0, "global") tvm.ir.assert_structural_equal(cache_read_multi_consumer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) -def test_continuous_cache_read(): +def test_continuous_cache_read(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") sch.cache_read(block_c, 0, "shared") sch.cache_read(block_c, 0, "local") tvm.ir.assert_structural_equal(continuous_cache_read, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) -def test_cache_read_with_block_predicate(): +def test_cache_read_with_block_predicate(use_block_name): sch = tir.Schedule(func_with_block_predicate, debug_mask="all") - block = sch.get_block("consumer") + block = "consumer" if use_block_name else sch.get_block("consumer") sch.cache_read(block, 0, "shared") tvm.ir.assert_structural_equal(block_predicate_cache_read, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate) -def test_cache_read_non_int32_shape(): +def test_cache_read_non_int32_shape(use_block_name): sch = tir.Schedule(elementwise_shape_int64, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") sch.cache_read(block_b, 0, "global") tvm.ir.assert_structural_equal(cache_read_shape_int64, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_shape_int64) -def test_cache_read_fail_multi_producer(): +def test_cache_read_fail_multi_producer(use_block_name): sch = tir.Schedule(func_multi_producer, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.cache_read(block_b, 0, "global") -def 
test_cache_read_fail_index_out_of_bound(): +def test_cache_read_fail_index_out_of_bound(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.cache_read(block_b, 1, "global") -def test_cache_read_fail_invalid_storage_scope(): +def test_cache_read_fail_invalid_storage_scope(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.cache_read(block_b, 0, "test_scope") @@ -831,12 +833,12 @@ def test_cache_read_fail_invalid_storage_scope(): ########## Testcases for cache_write ########## -def test_cache_write_elementwise(): +def test_cache_write_elementwise(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") block_b = sch.get_block("B") block_c = sch.get_block("C") - cached_b = sch.cache_write(block_b, 0, "local") - cached_c = sch.cache_write(block_c, 0, "global") + cached_b = sch.cache_write("B" if use_block_name else block_b, 0, "local") + cached_c = sch.cache_write("C" if use_block_name else block_c, 0, "global") assert sch.get(cached_b) == sch.get(sch.get_block("B_local")) assert sch.get(cached_c) == sch.get(sch.get_block("C_global")) assert sch.get(block_b) == sch.get(sch.get_block("B")) @@ -845,10 +847,10 @@ def test_cache_write_elementwise(): verify_trace_roundtrip(sch=sch, mod=elementwise) -def test_cache_write_under_scope(): +def test_cache_write_under_scope(use_block_name): sch = tir.Schedule(access_under_scope, debug_mask="all") - block_a = sch.get_block("A") - block_b = sch.get_block("B") + block_a = "A" if use_block_name else sch.get_block("A") + block_b = "B" if use_block_name else sch.get_block("B") block_scope = sch.get_block("scope") sch.cache_write(block_a, 0, "local") sch.cache_write(block_b, 0, "global") @@ -857,11 +859,11 @@ def 
test_cache_write_under_scope(): verify_trace_roundtrip(sch=sch, mod=access_under_scope) -def test_cache_write_opaque_access(): +def test_cache_write_opaque_access(use_block_name): sch = tir.Schedule(opaque_access, debug_mask="all") - block_store = sch.get_block("load_store") - block_opaque = sch.get_block("opaque") - block_match_buffer = sch.get_block("match_buffer") + block_store = "load_store" if use_block_name else sch.get_block("load_store") + block_opaque = "opaque" if use_block_name else sch.get_block("opaque") + block_match_buffer = "match_buffer" if use_block_name else sch.get_block("match_buffer") sch.cache_write(block_store, 0, "global") sch.cache_write(block_opaque, 0, "global") sch.cache_write(block_match_buffer, 0, "global") @@ -869,58 +871,58 @@ def test_cache_write_opaque_access(): verify_trace_roundtrip(sch=sch, mod=opaque_access) -def test_cache_write_location(): +def test_cache_write_location(use_block_name): sch = tir.Schedule(func_multi_consumer, debug_mask="all") - block_a = sch.get_block("A") + block_a = "A" if use_block_name else sch.get_block("A") sch.cache_write(block_a, 0, "global") tvm.ir.assert_structural_equal(cache_write_multi_consumer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) -def test_continuous_cache_write(): +def test_continuous_cache_write(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") sch.cache_write(block_b, 0, "shared") sch.cache_write(block_b, 0, "local") tvm.ir.assert_structural_equal(continuous_cache_write, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) -def test_cache_write_with_block_predicate(): +def test_cache_write_with_block_predicate(use_block_name): # cache write for intermediate buffer sch = tir.Schedule(func_with_block_predicate, debug_mask="all") - block = sch.get_block("producer") + block = "producer" if use_block_name else sch.get_block("producer") 
sch.cache_write(block, 0, "shared") tvm.ir.assert_structural_equal(block_predicate_cache_write_intermediate_buf, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate) # cache write for external buffer sch = tir.Schedule(func_with_block_predicate, debug_mask="all") - block = sch.get_block("consumer") + block = "consumer" if use_block_name else sch.get_block("consumer") sch.cache_write(block, 0, "shared") tvm.ir.assert_structural_equal(block_predicate_cache_write_output_buf, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate) -def test_cache_write_fail_multi_producer(): +def test_cache_write_fail_multi_producer(use_block_name): sch = tir.Schedule(func_multi_producer, debug_mask="all") - block_a0 = sch.get_block("A0") - block_a1 = sch.get_block("A1") + block_a0 = "A0" if use_block_name else sch.get_block("A0") + block_a1 = "A1" if use_block_name else sch.get_block("A1") with pytest.raises(tvm.tir.ScheduleError): sch.cache_write(block_a0, 0, "global") with pytest.raises(tvm.tir.ScheduleError): sch.cache_write(block_a1, 0, "global") -def test_cache_write_fail_index_out_of_bound(): +def test_cache_write_fail_index_out_of_bound(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.cache_write(block_b, 1, "global") -def test_cache_write_fail_invalid_storage_scope(): +def test_cache_write_fail_invalid_storage_scope(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.cache_write(block_b, 0, "test_scope") diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index 3772d9a4e0fe..0c20a4783ca0 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ 
b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -1052,17 +1052,19 @@ def static_bound_after_compute_at(A: T.Buffer[(32, 1), "float32"], C: T.Buffer[( # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks # fmt: on +use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) -def test_compute_at_two_elementwise(): + +def test_compute_at_two_elementwise(use_block_name): sch = tir.Schedule(two_elementwise, debug_mask="all") - block = sch.get_block("B") - loop, _ = sch.get_loops(sch.get_block("C")) + block = "B" if use_block_name else sch.get_block("B") + loop, _ = sch.get_loops("C" if use_block_name else sch.get_block("C")) sch.compute_at(block, loop, preserve_unit_loops=True) tvm.ir.assert_structural_equal(two_elementwise_after_compute_at, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise) -def test_compute_at_blockized_1(): +def test_compute_at_blockized_1(use_block_name): sch = tir.Schedule(blockized_1, debug_mask="all") block = sch.get_block("B") _, loop = sch.get_loops(sch.get_block("C_outer")) @@ -1071,7 +1073,7 @@ def test_compute_at_blockized_1(): verify_trace_roundtrip(sch=sch, mod=blockized_1) -def test_compute_at_blockized_2(): +def test_compute_at_blockized_2(use_block_name): sch = tir.Schedule(blockized_2, debug_mask="all") block = sch.get_block("B_outer") _, loop, _, _ = sch.get_loops(sch.get_block("C")) @@ -1080,7 +1082,7 @@ def test_compute_at_blockized_2(): verify_trace_roundtrip(sch=sch, mod=blockized_2) -def test_compute_at_cuda_matmul_0(): +def test_compute_at_cuda_matmul_0(use_block_name): sch = tir.Schedule(cuda_matmul_0, debug_mask="all") block = sch.get_block("C") _, _, _, _, _, loop, _, _ = sch.get_loops(sch.get_block("C_local")) @@ -1089,7 +1091,7 @@ def test_compute_at_cuda_matmul_0(): verify_trace_roundtrip(sch=sch, mod=cuda_matmul_0) -def test_compute_at_cuda_matmul_1(): +def 
test_compute_at_cuda_matmul_1(use_block_name): sch = tir.Schedule(cuda_matmul_1, debug_mask="all") block = sch.get_block("A_shared_local") _, _, _, _, _, _, _, loop, _, _, _ = sch.get_loops(sch.get_block("C")) @@ -1098,7 +1100,7 @@ def test_compute_at_cuda_matmul_1(): verify_trace_roundtrip(sch=sch, mod=cuda_matmul_1) -def test_compute_at_cuda_matmul_2(): +def test_compute_at_cuda_matmul_2(use_block_name): sch = tir.Schedule(cuda_matmul_2, debug_mask="all") block = sch.get_block("B_shared_local") _, _, _, _, _, _, _, loop, _, _, _ = sch.get_loops(sch.get_block("C")) @@ -1107,7 +1109,7 @@ def test_compute_at_cuda_matmul_2(): verify_trace_roundtrip(sch=sch, mod=cuda_matmul_2) -def test_compute_at_cuda_matmul_3(): +def test_compute_at_cuda_matmul_3(use_block_name): sch = tir.Schedule(cuda_matmul_3, debug_mask="all") block = sch.get_block("A_shared") _, _, _, _, _, _, loop, _, _, _, _ = sch.get_loops(sch.get_block("C")) @@ -1116,7 +1118,7 @@ def test_compute_at_cuda_matmul_3(): verify_trace_roundtrip(sch=sch, mod=cuda_matmul_3) -def test_compute_at_cuda_matmul_4(): +def test_compute_at_cuda_matmul_4(use_block_name): sch = tir.Schedule(cuda_matmul_4, debug_mask="all") block = sch.get_block("B_shared") _, _, _, _, _, _, loop, _, _, _, _ = sch.get_loops(sch.get_block("C")) @@ -1125,7 +1127,7 @@ def test_compute_at_cuda_matmul_4(): verify_trace_roundtrip(sch=sch, mod=cuda_matmul_4) -def test_compute_at_reduction_block(): +def test_compute_at_reduction_block(use_block_name): sch = tir.Schedule(multi_reduction, debug_mask="all") block = sch.get_block("B") (loop,) = sch.get_loops(sch.get_block("C")) @@ -1134,7 +1136,7 @@ def test_compute_at_reduction_block(): verify_trace_roundtrip(sch=sch, mod=multi_reduction) -def test_compute_at_tiled_pooling_read_cache(): +def test_compute_at_tiled_pooling_read_cache(use_block_name): sch = tir.Schedule(tiled_pooling_read_cache, debug_mask="all") compute = sch.get_block("compute") _, w_o, _, _, _, _ = sch.get_loops(compute) @@ -1144,7 
+1146,7 @@ def test_compute_at_tiled_pooling_read_cache(): verify_trace_roundtrip(sch=sch, mod=tiled_pooling_read_cache) -def test_compute_at_non_uniform_tiled_conv(): +def test_compute_at_non_uniform_tiled_conv(use_block_name): sch = tir.Schedule(non_uniform_tiled_conv, debug_mask="all") compute = sch.get_block("compute") sch.compute_at(sch.get_block("cache"), sch.get_loops(compute)[1]) @@ -1152,7 +1154,7 @@ def test_compute_at_non_uniform_tiled_conv(): verify_trace_roundtrip(sch=sch, mod=non_uniform_tiled_conv) -def test_compute_at_concat(): +def test_compute_at_concat(use_block_name): sch = tir.Schedule(concat_two_elemwise, debug_mask="all") concat = sch.get_block("T_concat") add1 = sch.get_block("T_add_1") @@ -1164,7 +1166,7 @@ def test_compute_at_concat(): verify_trace_roundtrip(sch=sch, mod=concat_two_elemwise) -def test_compute_at_tiled_repeat_op(): +def test_compute_at_tiled_repeat_op(use_block_name): sch = tir.Schedule(tiled_repeat_op, debug_mask="all") outer_ax, _ = sch.get_loops(sch.get_block("T_repeat")) sch.compute_at(sch.get_block("T_add"), outer_ax) @@ -1172,7 +1174,7 @@ def test_compute_at_tiled_repeat_op(): verify_trace_roundtrip(sch=sch, mod=tiled_repeat_op) -def test_reverse_compute_at_tiled(): +def test_reverse_compute_at_tiled(use_block_name): sch = tir.Schedule(tiled, debug_mask="all") block = sch.get_block("C") _, _, loop, _ = sch.get_loops(sch.get_block("B")) @@ -1181,7 +1183,7 @@ def test_reverse_compute_at_tiled(): verify_trace_roundtrip(sch=sch, mod=tiled) -def test_reverse_compute_at_tiled_trivial_binding(): +def test_reverse_compute_at_tiled_trivial_binding(use_block_name): sch = tir.Schedule(tiled_trivial_binding, debug_mask="all") block = sch.get_block("C") _, _, loop, _ = sch.get_loops(sch.get_block("B")) @@ -1190,7 +1192,7 @@ def test_reverse_compute_at_tiled_trivial_binding(): verify_trace_roundtrip(sch=sch, mod=tiled_trivial_binding) -def test_reverse_compute_at_blockized_2(): +def 
test_reverse_compute_at_blockized_2(use_block_name): sch = tir.Schedule(blockized_2, debug_mask="all") block = sch.get_block("C") _, loop = sch.get_loops(sch.get_block("B_outer")) @@ -1199,7 +1201,7 @@ def test_reverse_compute_at_blockized_2(): verify_trace_roundtrip(sch=sch, mod=blockized_2) -def test_reverse_compute_at_factorized(): +def test_reverse_compute_at_factorized(use_block_name): sch = tir.Schedule(factorized, debug_mask="all") block = sch.get_block("B") _, loop, _, _ = sch.get_loops(sch.get_block("B_rf")) @@ -1208,7 +1210,7 @@ def test_reverse_compute_at_factorized(): verify_trace_roundtrip(sch=sch, mod=factorized) -def test_reverse_compute_at_floordiv_and_floormod_indices(): +def test_reverse_compute_at_floordiv_and_floormod_indices(use_block_name): sch = tir.Schedule(floordiv_and_floormod_indices, debug_mask="all") A = sch.get_block("A") B = sch.get_block("B") @@ -1219,7 +1221,7 @@ def test_reverse_compute_at_floordiv_and_floormod_indices(): verify_trace_roundtrip(sch=sch, mod=floordiv_and_floormod_indices) -def test_read_out_of_bound(): +def test_read_out_of_bound(use_block_name): sch = tir.Schedule(read_out_of_bound, debug_mask="all") block = sch.get_block("B") (loop,) = sch.get_loops(sch.get_block("C")) @@ -1228,7 +1230,7 @@ def test_read_out_of_bound(): verify_trace_roundtrip(sch=sch, mod=read_out_of_bound) -def test_compact_dataflow(): +def test_compact_dataflow(use_block_name): sch = tir.Schedule(not_all_compact_data_flow, debug_mask="all") block = sch.get_block("B") _, loop = sch.get_loops(sch.get_block("C_1")) @@ -1237,7 +1239,7 @@ def test_compact_dataflow(): verify_trace_roundtrip(sch=sch, mod=not_all_compact_data_flow) -def test_compute_at_simplify_static_bound(): +def test_compute_at_simplify_static_bound(use_block_name): sch = tir.Schedule(static_bound, debug_mask="all") block = sch.get_block("B") loop, _ = sch.get_loops(sch.get_block("C")) @@ -1246,7 +1248,7 @@ def test_compute_at_simplify_static_bound(): verify_trace_roundtrip(sch=sch, 
mod=static_bound) -def test_compute_at_non_perfect_channel_group(): +def test_compute_at_non_perfect_channel_group(use_block_name): @T.prim_func def grouped_channel_bias( X: T.Buffer[(720, 8, 8), "float32"], Y: T.Buffer[(720, 8, 8), "float32"] @@ -1284,7 +1286,7 @@ def grouped_channel_bias_non_perfect_tiled( tvm.ir.assert_structural_equal(sch.mod["main"], grouped_channel_bias_non_perfect_tiled) -def test_fail_subtree_complete_block(): +def test_fail_subtree_complete_block(use_block_name): sch = tir.Schedule(fail_subtree_compact_dataflow, debug_mask="all") block = sch.get_block("B_0") loop, _ = sch.get_loops(sch.get_block("C")) @@ -1292,47 +1294,47 @@ def test_fail_subtree_complete_block(): sch.compute_at(block, loop) -def test_fail_not_in_same_scope(): +def test_fail_not_in_same_scope(use_block_name): sch = tir.Schedule(blockized_1, debug_mask="all") - block = sch.get_block("B") + block = "B" if use_block_name else sch.get_block("B") loop, _ = sch.get_loops(sch.get_block("C_inner")) with pytest.raises(tvm.tir.ScheduleError, match="same block scope"): sch.compute_at(block, loop) -def test_fail_loop_is_ancestor_of_block(): +def test_fail_loop_is_ancestor_of_block(use_block_name): sch = tir.Schedule(two_elementwise, debug_mask="all") - block = sch.get_block("B") + block = "B" if use_block_name else sch.get_block("B") loop, _ = sch.get_loops(sch.get_block("B")) with pytest.raises(tvm.tir.ScheduleError, match="ancestor of block"): sch.compute_at(block, loop) -def test_fail_output_block(): +def test_fail_output_block(use_block_name): sch = tir.Schedule(tiled, debug_mask="all") - block = sch.get_block("C") + block = "C" if use_block_name else sch.get_block("C") loop, _, _, _ = sch.get_loops(sch.get_block("B")) with pytest.raises(tvm.tir.ScheduleError, match="output block"): sch.compute_at(block, loop) -def test_fail_all_consumers_under_loop(): +def test_fail_all_consumers_under_loop(use_block_name): sch = tir.Schedule(fail_all_consumers_under_loop, debug_mask="all") - 
block = sch.get_block("B") + block = "B" if use_block_name else sch.get_block("B") loop, _ = sch.get_loops(sch.get_block("C")) with pytest.raises(tvm.tir.ScheduleError, match="requires all the consumer"): sch.compute_at(block, loop) -def test_fail_all_producers_under_loop(): +def test_fail_all_producers_under_loop(use_block_name): sch = tir.Schedule(fail_all_producers_under_loop, debug_mask="all") - block = sch.get_block("D") + block = "D" if use_block_name else sch.get_block("D") loop, _ = sch.get_loops(sch.get_block("C")) with pytest.raises(tvm.tir.ScheduleError, match="requires all the producer"): sch.reverse_compute_at(block, loop) -def test_compute_at_int64_loop(): +def test_compute_at_int64_loop(use_block_name): def _create_prim_func(): n = te.var("n", dtype="int64") m = te.var("m", dtype="int64") @@ -1344,8 +1346,8 @@ def _create_prim_func(): mod = _create_prim_func() sch = tir.Schedule(mod, debug_mask="all") - block_c = sch.get_block("C") - block_d = sch.get_block("D") + block_c = "C" if use_block_name else sch.get_block("C") + block_d = "D" if use_block_name else sch.get_block("D") i, _ = sch.get_loops(block_d) sch.compute_at(block_c, i) verify_trace_roundtrip(sch=sch, mod=mod) diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index 84fb88218997..617e13db27f6 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -587,10 +587,12 @@ def exp_exp_opaque_access_with_tvm_access_ptr_inlined( # pylint: enable=no-member,invalid-name,unused-variable +use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) -def test_compute_inline_elementwise(): + +def test_compute_inline_elementwise(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") block_c = sch.get_block("C") 
sch.compute_inline(block_b) tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) @@ -598,9 +600,9 @@ def test_compute_inline_elementwise(): verify_trace_roundtrip(sch=sch, mod=elementwise) -def test_compute_inline_under_loop(): +def test_compute_inline_under_loop(use_block_name): sch = tir.Schedule(elementwise_under_loop, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") block_c = sch.get_block("C") sch.compute_inline(block_b) tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) @@ -608,9 +610,9 @@ def test_compute_inline_under_loop(): verify_trace_roundtrip(sch=sch, mod=elementwise_under_loop) -def test_compute_inline_as_dce(): +def test_compute_inline_as_dce(use_block_name): sch = tir.Schedule(elementwise_standalone, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") block_c = sch.get_block("C") sch.compute_inline(block_b) tvm.ir.assert_structural_equal(elementwise_standalone_dce, sch.mod["main"]) @@ -618,9 +620,9 @@ def test_compute_inline_as_dce(): verify_trace_roundtrip(sch=sch, mod=elementwise_standalone) -def test_compute_inline_multi_consumer(): +def test_compute_inline_multi_consumer(use_block_name): sch = tir.Schedule(elementwise_multi_producer_consumer, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") block_c = sch.get_block("C") block_d = sch.get_block("D") sch.compute_inline(block_b) @@ -630,81 +632,81 @@ def test_compute_inline_multi_consumer(): verify_trace_roundtrip(sch=sch, mod=elementwise_multi_producer_consumer) -def test_compute_inline_fail_multi_writer(): +def test_compute_inline_fail_multi_writer(use_block_name): sch = tir.Schedule(fail_multi_reader_writer, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) -def 
test_reverse_compute_inline_elementwise(): +def test_reverse_compute_inline_elementwise(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") block_b = sch.get_block("B") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) assert sch.get(block_b).name_hint == "B" verify_trace_roundtrip(sch=sch, mod=elementwise) -def test_reverse_compute_inline_under_loop(): +def test_reverse_compute_inline_under_loop(use_block_name): sch = tir.Schedule(elementwise_under_loop, debug_mask="all") block_b = sch.get_block("B") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) assert sch.get(block_b).name_hint == "B" verify_trace_roundtrip(sch=sch, mod=elementwise_under_loop) -def test_reverse_compute_inline_fail_as_dce(): +def test_reverse_compute_inline_fail_as_dce(use_block_name): sch = tir.Schedule(elementwise_standalone, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_b) -def test_reverse_compute_inline_fail_multi_producer(): +def test_reverse_compute_inline_fail_multi_producer(use_block_name): sch = tir.Schedule(elementwise_multi_producer_consumer, debug_mask="all") - block_d = sch.get_block("D") + block_d = "D" if use_block_name else sch.get_block("D") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_d) -def test_reverse_compute_inline_fail_multi_reader(): +def test_reverse_compute_inline_fail_multi_reader(use_block_name): sch = tir.Schedule(fail_multi_reader_writer, debug_mask="all") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") with pytest.raises(tvm.tir.ScheduleError): 
sch.reverse_compute_inline(block_c) -def test_reverse_compute_multi_reverse_loads(): +def test_reverse_compute_multi_reverse_loads(use_block_name): sch = tir.Schedule(elementwise_multi_reverse_loads, debug_mask="all") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal(elementwise_multi_reverse_loads_inlined, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_multi_reverse_loads) -def test_reverse_compute_inline_affine_load(): +def test_reverse_compute_inline_affine_load(use_block_name): sch = tir.Schedule(elementwise_reverse_affine_load, debug_mask="all") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal(elementwise_reverse_affine_load_inlined, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_load) -def test_reverse_compute_inline_multi_affine_load(): +def test_reverse_compute_inline_multi_affine_load(use_block_name): sch = tir.Schedule(elementwise_multi_reverse_affine_load, debug_mask="all") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal(elementwise_multi_reverse_affine_load_inlined, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_multi_reverse_affine_load) -def test_reverse_compute_inline_affine_load_unit_iter(): +def test_reverse_compute_inline_affine_load_unit_iter(use_block_name): sch = tir.Schedule(elementwise_reverse_affine_load_unit_iter, debug_mask="all") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal( elementwise_reverse_affine_load_unit_iter_inlined, sch.mod["main"] @@ -712,9 +714,9 @@ def test_reverse_compute_inline_affine_load_unit_iter(): verify_trace_roundtrip(sch=sch, 
mod=elementwise_reverse_affine_load_unit_iter) -def test_reverse_compute_inline_affine_load_unit_iter_simplified(): +def test_reverse_compute_inline_affine_load_unit_iter_simplified(use_block_name): sch = tir.Schedule(elementwise_reverse_affine_load_unit_iter_simplified, debug_mask="all") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal( elementwise_reverse_affine_load_unit_iter_simplified_inlined, sch.mod["main"] @@ -723,10 +725,10 @@ def test_reverse_compute_inline_affine_load_unit_iter_simplified(): @pytest.mark.parametrize("reverse_order", [True, False]) -def test_reverse_compute_inline_affine_chain(reverse_order): +def test_reverse_compute_inline_affine_chain(use_block_name, reverse_order): sch = tir.Schedule(elementwise_reverse_affine_chain, debug_mask="all") - block_c = sch.get_block("C") - block_d = sch.get_block("D") + block_c = "C" if use_block_name else sch.get_block("C") + block_d = "D" if use_block_name else sch.get_block("D") if reverse_order: sch.reverse_compute_inline(block_d) sch.reverse_compute_inline(block_c) @@ -737,68 +739,68 @@ def test_reverse_compute_inline_affine_chain(reverse_order): verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_chain) -def test_reverse_compute_fail_non_affine_load(): +def test_reverse_compute_fail_non_affine_load(use_block_name): sch = tir.Schedule(elementwise_reverse_non_affine_load, debug_mask="all") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_c) -def test_reverse_compute_fail_multi_reverse_loads(): +def test_reverse_compute_fail_multi_reverse_loads(use_block_name): sch = tir.Schedule(elementwise_multi_loads, debug_mask="all") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") with pytest.raises(tvm.tir.ScheduleError): 
sch.reverse_compute_inline(block_c) -def test_opaque_access_load(): +def test_opaque_access_load(use_block_name): sch = tir.Schedule(opaque_access_load, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) -def test_opaque_access_store(): +def test_opaque_access_store(use_block_name): sch = tir.Schedule(opaque_access_store, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) -def test_buffer_matched(): +def test_buffer_matched(use_block_name): sch = tir.Schedule(buffer_matched, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) -def test_output_block(): +def test_output_block(use_block_name): sch = tir.Schedule(matmul_relu, debug_mask="all") block = sch.get_block("compute") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block) -def test_compute_inline_predicate(): +def test_compute_inline_predicate(use_block_name): sch = tir.Schedule(elementwise_predicate, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") sch.compute_inline(block_b) tvm.ir.assert_structural_equal(elementwise_predicate_inlined, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_predicate) -def test_compute_inline_multi_loads(): +def test_compute_inline_multi_loads(use_block_name): sch = tir.Schedule(elementwise_multi_loads, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") sch.compute_inline(block_b) tvm.ir.assert_structural_equal(elementwise_multi_loads_inlined, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_multi_loads) -def test_compute_inline_with_opaque_access(): 
+def test_compute_inline_with_opaque_access(use_block_name): """Test not rewrite opaque reads/writes after irrelavant compute inline""" sch = tir.Schedule(access_opaque_ptr_then_elemwise, debug_mask="all") - BB = sch.get_block("BB") + BB = "BB" if use_block_name else sch.get_block("BB") sch.compute_inline(BB) tvm.ir.assert_structural_equal(access_opaque_ptr_then_elemwise_inline, sch.mod["main"]) @@ -810,10 +812,10 @@ def test_inline_block_with_init(): sch.compute_inline(block=block) -def test_compute_inline_opaque_access_with_tvm_access_ptr(): +def test_compute_inline_opaque_access_with_tvm_access_ptr(use_block_name): """Test opaque access with tvm_access_ptr after compute inline""" sch = tir.Schedule(exp_exp_opaque_access_with_tvm_access_ptr, debug_mask="all") - compute = sch.get_block("compute") + compute = "compute" if use_block_name else sch.get_block("compute") sch.compute_inline(compute) tvm.ir.assert_structural_equal( exp_exp_opaque_access_with_tvm_access_ptr_inlined, sch.mod["main"] diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py index a8348afb457d..f3503460e50a 100644 --- a/tests/python/unittest/test_tir_schedule_reduction.py +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -215,19 +215,21 @@ def colsum_decompose_with_vectorization(a: T.handle, b: T.handle) -> None: # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg +use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) -def test_reduction_decompose0(): + +def test_reduction_decompose0(use_block_name): s = tir.Schedule(matmul, debug_mask="all") - C = s.get_block("update") + C = "update" if use_block_name else s.get_block("update") i, j, k = s.get_loops(C) s.decompose_reduction(C, i) tvm.ir.assert_structural_equal(matmul_decompose0, s.mod["main"]) verify_trace_roundtrip(s, mod=matmul) -def test_reduction_decompose1(): +def 
test_reduction_decompose1(use_block_name): s = tir.Schedule(rowsum_blockized, debug_mask="all") - blockized_B = s.get_block("blockized_B") + blockized_B = "blockized_B" if use_block_name else s.get_block("blockized_B") io, ko = s.get_loops(blockized_B) s.decompose_reduction(blockized_B, io) tvm.ir.assert_structural_equal(matmul_decompose1, s.mod["main"]) diff --git a/tests/python/unittest/test_tir_schedule_reindex.py b/tests/python/unittest/test_tir_schedule_reindex.py index 9b2e37a19813..c6776b0c8a3e 100644 --- a/tests/python/unittest/test_tir_schedule_reindex.py +++ b/tests/python/unittest/test_tir_schedule_reindex.py @@ -168,35 +168,43 @@ def multiple_read(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "f B[vi, vj] = A[vj, vi] + A[vi, vj] -def test_reindex_read_basic(): +use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) +use_buffer_name = tvm.testing.parameter(by_dict={"buffer_index": False, "buffer_name": True}) + + +def test_reindex_read_basic(use_block_name, use_buffer_name): sch = tir.Schedule(transpose_elementwise) - block = sch.get_block("B") - sch.reindex(block, 0, "read") + block = "B" if use_block_name else sch.get_block("B") + buf = "A" if use_buffer_name else ("read", 0) + sch.reindex(block, buf) tvm.ir.assert_structural_equal(transpose_elementwise_reindex_read, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=transpose_elementwise) -def test_conv2d_reindex_read(): +def test_conv2d_reindex_read(use_block_name, use_buffer_name): sch = tir.Schedule(conv2d_nhwc) - block = sch.get_block("conv2d_nhwc") - sch.reindex(block, 1, "read") + block = "conv2d_nhwc" if use_block_name else sch.get_block("conv2d_nhwc") + buf = "Weight" if use_buffer_name else ("read", 1) + sch.reindex(block, buf) tvm.ir.assert_structural_equal(conv2d_nhwc_reindex_weight, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=conv2d_nhwc) -def test_matmul_reindex_write(): +def test_matmul_reindex_write(use_block_name, 
use_buffer_name): sch = tir.Schedule(matmul) - block = sch.get_block("matmul") - sch.reindex(block, 0, "write") + block = "matmul" if use_block_name else sch.get_block("matmul") + buf = "C" if use_buffer_name else ("write", 0) + sch.reindex(block, buf) tvm.ir.assert_structural_equal(matmul_reindex_write, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=matmul) -def test_reindex_fail_multiple_read(): +def test_reindex_fail_multiple_read(use_block_name, use_buffer_name): sch = tir.Schedule(multiple_read) - block = sch.get_block("B") + block = "B" if use_block_name else sch.get_block("B") + buf = "A" if use_buffer_name else ("read", 0) with pytest.raises(ScheduleError): - sch.reindex(block, 0, "read") + sch.reindex(block, buf) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index 17f35ea8f72f..0c2a3d27ffdb 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -179,10 +179,16 @@ def test_sample_perfect_tile_composite(): verify_trace_roundtrip(sch, mod=elementwise) -def test_sample_compute_location(): +use_sugared_block = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) + + +def test_sample_compute_location(use_sugared_block): n = 100 sch = tir.Schedule(tiled_conv2d_with_padding, seed=42, debug_mask="all") - pad_input = sch.get_block("PadInput") + if use_sugared_block: + pad_input = "PadInput" + else: + pad_input = sch.get_block("PadInput") decision_dict = dict() for _ in range(n): _ = sch.sample_compute_location(pad_input) # pylint: disable=invalid-name diff --git a/tests/python/unittest/test_tir_schedule_set_scope.py b/tests/python/unittest/test_tir_schedule_set_scope.py index 29c4880f7762..b2e8479462eb 100644 --- a/tests/python/unittest/test_tir_schedule_set_scope.py +++ b/tests/python/unittest/test_tir_schedule_set_scope.py @@ -86,20 +86,21 @@ def 
element_wise_subregion_match_set_scope(A: T.Buffer[(128, 128), "float32"], C # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg +use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) -def test_set_scope(): +def test_set_scope(use_block_name): func = element_wise s = tir.Schedule(func, debug_mask='all') - s.set_scope(s.get_block("B"), 0, "shared") + s.set_scope('B' if use_block_name else s.get_block("B"), 0, "shared") tvm.ir.assert_structural_equal(element_wise_set_scope, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) -def test_set_scope_fail_on_output_buffer(): +def test_set_scope_fail_on_output_buffer(use_block_name): func = element_wise s = tir.Schedule(func, debug_mask='all') with pytest.raises(tvm.tir.ScheduleError): - s.set_scope(s.get_block("C"), 0, "shared") + s.set_scope('C' if use_block_name else s.get_block("C"), 0, "shared") def test_set_scope_fail_on_index_out_of_bound(): diff --git a/tests/python/unittest/test_tir_schedule_storage_align.py b/tests/python/unittest/test_tir_schedule_storage_align.py index 3b699fd8f1b2..072640c8f3af 100644 --- a/tests/python/unittest/test_tir_schedule_storage_align.py +++ b/tests/python/unittest/test_tir_schedule_storage_align.py @@ -98,10 +98,12 @@ def element_wise_invalid_annotation(a: T.handle, c: T.handle) -> None: C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1)) -def test_storage_align(): +use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) + +def test_storage_align(use_block_name): func = element_wise s = tir.Schedule(func, debug_mask='all') - B = s.get_block("B") + B = 'B' if use_block_name else s.get_block("B") s.storage_align(B, 0, axis=0, factor=128, offset=127) tvm.ir.assert_structural_equal(element_wise_storage_align, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py 
index e184bc3f627c..205bd5091268 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -171,15 +171,13 @@ def conv2d_nhwc_transformed( # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks # fmt: on -use_sugared_transform = tvm.testing.parameter( - by_dict={"transform_layout": False, "transform_layout_sugared": True} -) +use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) -def test_two_elementwise_transform_intermediate_buffer(use_sugared_transform): +def test_two_elementwise_transform_intermediate_buffer(use_block_name): sch = tir.Schedule(two_elementwise, debug_mask="all") - if use_sugared_transform: + if use_block_name: sch.transform_layout( block="B", buffer="B", @@ -193,10 +191,10 @@ def test_two_elementwise_transform_intermediate_buffer(use_sugared_transform): verify_trace_roundtrip(sch=sch, mod=two_elementwise) -def test_two_elementwise_transform_input_buffer(use_sugared_transform): +def test_two_elementwise_transform_input_buffer(use_block_name): sch = tir.Schedule(two_elementwise, debug_mask="all") - if use_sugared_transform: + if use_block_name: sch.transform_layout( index_map=packed_index_map_func, block="B", @@ -210,10 +208,10 @@ def test_two_elementwise_transform_input_buffer(use_sugared_transform): verify_trace_roundtrip(sch=sch, mod=two_elementwise) -def test_two_elementwise_transform_output_buffer(use_sugared_transform): +def test_two_elementwise_transform_output_buffer(use_block_name): sch = tir.Schedule(two_elementwise, debug_mask="all") - if use_sugared_transform: + if use_block_name: sch.transform_layout( index_map=packed_index_map_func, block="C", @@ -295,17 +293,17 @@ def summation_3d_split( tvm.ir.assert_structural_equal(summation_3d_split, sch.mod["main"]) -def test_transform_block_layout_basic(): +def 
test_transform_block_layout_basic(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block = sch.get_block("B") + block = "B" if use_block_name else sch.get_block("B") sch.transform_block_layout(block, lambda i, j: (i * 128 + j,)) tvm.ir.assert_structural_equal(elementwise_transformed, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) -def test_transform_block_layout_conv2d_nhwc(): +def test_transform_block_layout_conv2d_nhwc(use_block_name): sch = tir.Schedule(conv2d_nhwc, debug_mask="all") - block = sch.get_block("conv2d_nhwc") + block = "conv2d_nhwc" if use_block_name else sch.get_block("conv2d_nhwc") sch.transform_block_layout( block, lambda n, h, w, co, rh, rw, rc: (n * 112 * 112 + h * 112 + w, co, rh * 7 * 3 + rw * 3 + rc), @@ -314,16 +312,16 @@ def test_transform_block_layout_conv2d_nhwc(): verify_trace_roundtrip(sch=sch, mod=conv2d_nhwc) -def test_transform_block_layout_fail_non_affine(): +def test_transform_block_layout_fail_non_affine(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block = sch.get_block("B") + block = "B" if use_block_name else sch.get_block("B") with pytest.raises(tir.ScheduleError): sch.transform_block_layout(block, lambda i, j: (i + j,)) -def test_transform_block_layout_fail_mixed_iter_type(): +def test_transform_block_layout_fail_mixed_iter_type(use_block_name): sch = tir.Schedule(conv2d_nhwc, debug_mask="all") - block = sch.get_block("conv2d_nhwc") + block = "conv2d_nhwc" if use_block_name else sch.get_block("conv2d_nhwc") with pytest.raises(tir.ScheduleError): sch.transform_block_layout( block, diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index 0d23d3f95211..b7517aab7cd3 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -104,6 +104,8 @@ def matmul_relu_ann2(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: 
enable=no-member,invalid-name,unused-variable +use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) + def test_tir_schedule_creation(): # Tests: @@ -131,24 +133,24 @@ def test_tir_schedule_get_block(): assert block.same_as(matmul.body.block.body.body.body[1].body.block) -def test_tir_schedule_get_loops(): +def test_tir_schedule_get_loops(use_block_name): # Tests: # - Schedule.get_loops # - Schedule.get sch = tir.Schedule(matmul, debug_mask="all") - block_rv = sch.get_block(name="update") - i, j, k = sch.get_loops(block_rv) + block = "update" if use_block_name else sch.get_block(name="update") + i, j, k = sch.get_loops(block) assert sch.get(i).loop_var.name == "i" assert sch.get(j).loop_var.name == "j" assert sch.get(k).loop_var.name == "k" -def test_tir_schedule_copy_1(): +def test_tir_schedule_copy_1(use_block_name): # Tests: # - Schedule.copy sch_1 = tir.Schedule(matmul, debug_mask="all") block_rv = sch_1.get_block(name="update") - i, j, k = sch_1.get_loops(block_rv) + i, j, k = sch_1.get_loops(block="update" if use_block_name else block_rv) assert sch_1.get(i).loop_var.name == "i" assert sch_1.get(j).loop_var.name == "j" assert sch_1.get(k).loop_var.name == "k" @@ -218,9 +220,9 @@ def test_get_child_blocks(): assert s.get(update) == s.get(blocks[1]) -def test_get_producers(): +def test_get_producers(use_block_name): sch = tir.Schedule(mod=matmul_relu, debug_mask="all") - block = sch.get_block("relu") + block = "relu" if use_block_name else sch.get_block("relu") (producer,) = sch.get_producers(block) assert tvm.ir.structural_equal( sch.get_sref(producer).stmt, @@ -229,9 +231,9 @@ def test_get_producers(): verify_trace_roundtrip(sch, mod=matmul_relu) -def test_get_consumers(): +def test_get_consumers(use_block_name): sch = tir.Schedule(mod=matmul_relu, debug_mask="all") - block = sch.get_block("matmul") + block = "matmul" if use_block_name else sch.get_block("matmul") (consumer,) = sch.get_consumers(block) assert 
tvm.ir.structural_equal( sch.get_sref(consumer).stmt,