diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index dff6374a6185..aa75fb05a067 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1592,10 +1592,12 @@ inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& in * \param dst_layout the destination layout. * \param name output tensor name. * \param tag output tensor tag. + * \param schedule_rule name of specialized schedule rule to use. * \return A tensor with shape in \p dst_layout */ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, const std::string& dst_layout, + const std::string schedule_rule = "None", const std::string name = "T_layout_trans", const std::string tag = kInjective) { Layout src_layout_struct(src_layout); @@ -1614,6 +1616,12 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, Array<PrimExpr> dst_shape = layout_converter.ForwardShape(src->shape); + Map<String, ObjectRef> attrs = {{"schedule_rule", String(schedule_rule)}, + // Information about layouts needed for the schedule rule + {"src_layout", String(src_layout)}, + {"dst_layout", String(dst_layout)}, + {"input_shape", src->shape}}; + return compute( dst_shape, [&](const Array<Var>& dst_indices) { @@ -1625,7 +1633,7 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, } return if_then_else(in_range, src(src_indices), tvm::cast(src->dtype, PrimExpr(0))); }, - name, tag); + name, tag, attrs); } /*! \brief Utility function for auto_scheduler_layout_transform */ diff --git a/python/tvm/meta_schedule/schedule/cuda/__init__.py b/python/tvm/meta_schedule/schedule/cuda/__init__.py index 937a6e16a91b..ce79d15cc4b4 100644 --- a/python/tvm/meta_schedule/schedule/cuda/__init__.py +++ b/python/tvm/meta_schedule/schedule/cuda/__init__.py @@ -15,3 +15,5 @@ # specific language governing permissions and limitations # under the License. """Per-block schedule rules in MetaSchedule for target key 'cuda'""" + +from . import layout_transform diff --git a/python/tvm/meta_schedule/schedule/cuda/layout_transform.py b/python/tvm/meta_schedule/schedule/cuda/layout_transform.py new file mode 100644 index 000000000000..949ef915c9ff --- /dev/null +++ b/python/tvm/meta_schedule/schedule/cuda/layout_transform.py @@ -0,0 +1,583 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""layout_transform scheduling rule for cuda.""" + +import math +from collections import deque +from typing import List, Optional, Tuple, Union + +import tvm +from tvm import meta_schedule +from tvm.tir.schedule import BlockRV, ExprRV, LoopRV + +## Tiling layout transforms: +# Assume we have an input shape of [A, B, C, D] and want to layout transform +# ABCD --> DBAC so the output shape would be [D, B, A, C]. 
+# +# Consider reading from the input buffer in a cache-friendly fashion on CPU. We would +# expect a loop structure like: +# lAr, lBr, lCr, lDr = T.grid(A, B, C, D) +# +# Meanwhile consider writing to the output buffer in a cache-friendly fashion on CPU: +# lDw, lBw, lAw, lCw = T.grid(D, B, A, C) +# +# Clearly in many scenarios it is impossible to guarantee contiguous writes and reads +# within a single loop due to non-adjacent dimensions. Instead we work on transposing some +# small sub-tensor of our input writing and then reading from shared memory. We must now +# construct our submatrix so that reading and writing can both be done with some contiguous +# access in global memory. +# +# Consider the case of a 2D transpose. For example [1024, 2048] -> [2048, 1024]. +# We note that if we deal with a submatrix of shape [32, 32] which corresponds +# to the dimension of our input tensor, then rows of the submatrix are contiguous +# in the input tensor. Meanwhile, columns of our submatrix are contiguous in our +# output vector. Therefore, with this tile shape we have opportunity to read +# contiguously in our input tensor and write to shared memory, and write contiguously +# to our output tensor. +# +# The multiple dimensional case has a similar analogue. We want to allocate shared +# memory per block of [`tile_size`, `tile_size`]. We want the inner most dimension +# of our shared memory to correspond to contiguous reads from the input tensor and +# the outer dimension to correspond to contiguous writes into the output tensor. +# +# In terms of the loop structure reading from the input tensor, the inner most loops +# of our tile must correspond to the inner most dimensions of the input shape, +# while the outer dimensions correspond to the inner most dimensions of the output shape. +# To obtain an inner tile with this loop structure we factor out a contiguous `tile_size` +# chunk of our loop in the shape of interest. +# +# An example is probably best to show this idea: +# Let's say we want a layout transform of ABCD --> DCAB. With shape +# [1024_a, 2_b, 32_c, 8_d] --> [8_d, 32_c, 1024_a, 2_b] +# +# And tile size 32. +# +# Then we initially have a coalesced-read loop pattern of: +# T.grid(1024_a, 2_b, 32_c, 8_d) +# +# To obtain an inner tile of 32, we factor 4 from 32_c and 8 from 8_d: +# T.grid(1024_a, 2_b, 8_c1, 1_d1, 4_c2t, 8_d2t) +# T.grid(1024_a, 2_b, 8_cr, 1_dr, 32_dim1) +# +# To obtain an outer tile of 32, we factor from B then A to follow contiguous write +# pattern: +# +# T.grid(64_a1, 1_b1, 8_cr, 1_dr, 16_a2t, 2_b2t, 32_dim1) +# T.grid(64_ar, 1_br, 8_cr, 1_dr, 32_dim0, 32_dim1) +# +# Which allows us to read a tile with our wanted properties. +# For writing we use the existing analysis infrastructure to generate the structure for writing. + + +def tile_layout_transform( + sch: tvm.tir.Schedule, + block_read: BlockRV, + block_write: BlockRV, + src_layout: str, + dst_layout: str, + input_shape: List[int], + tile_size: ExprRV, +) -> Tuple[BlockRV, BlockRV]: + """ + High level tiling for layout transform block. Mutates sch in place. + + Parameters + ---------- + sch: + The initial schedule. We expect `block_read` and `block_write` to correspond to + the blocks which reads and writes from global memory respectively. We also expect + block_read's initial loops to follow + + block_read: + The block which reads from global memory and writes to shared memory buffer. + + block_write: + The block which writes to global memory and reads from shared memory buffer. 
+ + src_layout : + The src_layout, each character should appear once and also appear in dst_layout. + There should be not numeric characters and refer to potentially implicit reshapes. + E.g. the transform NCHW --> NCHW4c really implies NCcHW --> NCHWc. In this case + src_layout should be NCcHW. + + dst_layout: + The dst_layout. There should not be numeric characters, e.g. NCHW4c becomes NCHWc. + + input_shape: + The input shape after applying potentially implicit reshapes. Should match the loop + extants corresponding to src_layout. + + tile_size: + The tile size of read and writes. There will be tile_size threads per block, each of which + reads up to tile_size elements. + + Returns + ------- + ret: + A tuple of the block that writes to global memory, and the block that reads from + global memory. + """ + + def pad_dimension_to_at_least_number(loop: LoopRV, requested_size: int): + """E.g. if loop has extant of 8 but we want 10, returns size 10 loop with padding.""" + left, right = sch.split(loop, [None, requested_size]) + return sch.fuse(left, right) + + def pad_dimension_to_factor_of_tile_size( + loop: LoopRV, initial_size: int, tile_size: int = tile_size + ) -> Tuple[LoopRV, int]: + """ + Pads loop of given size until it is divisible into tile_size. + If the given size of the loop is greater than tile size. Do not pad. + + examples: + - loop_size = 5 , tile_size = 32. loop_size --> 8 + - loop_size = 5 , tile_size = 36. loop_size --> 6 + - loop_size = 8 , tile_size = 32. loop_size --> 8 : since 8 already divides 32. + - loop_size = 33, tile_size = 32. loop_size --> 33 : since 33 > 32. + + Returns padded loopRV and the new size. + """ + if tile_size % initial_size == 0: + return loop, int(initial_size) + + if initial_size > tile_size or initial_size == tile_size: + return loop, int(initial_size) + + # if initial_size > tile_size return without change, factor = 1 + size = initial_size + while (tile_size % size) % tile_size > 0: + size += 1 + + return pad_dimension_to_at_least_number(loop, size), int(size) + + def spin_out_factor( + loops: List[LoopRV], loop_extants: List[int], index: int, factor_needed: int + ) -> Tuple[List[LoopRV], List[int], int]: + """ + Factor out the requested loop's dimensions to reach the requested factor and + places the requested factor as the innermost loop. + + Updates the schedule in-place. + + E.g. say we want to factors which eventually multiply to 32 (factor_needed). + + Say we have the index we chose is a loop with an extant of 8. + E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed = 32, index=3 (dim=8) + - 8 divides into 32 so we just split up the loop into two loops with extants 1 and 8. + - we then keep the 1-loop in place and move the new 8-loop to back of the list of loops + - ending loops / loop_extants = [3, 32, 6, 1, 8], remaining_factor_needed = 32 / 8 = 4 + + E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed=32, index=0 (dim=3) + - 3 does not divide 32, so we pad until the extant divides 32, e.g. 4 + - we then split up the loop into extants 1 and 4, moving the 4 to the back + - ending loops / loop_extants = [1, 32, 6, 8, 4], remaining_factor_needed = 32 / 4 = 8 + + E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed=5, index=3 (dim=8) + - 8 is larger than 5 so we immediately do the splitting routine. 
+ - the 8 extant loop becomes loops with extants 2 and 5 + - ending loops / loop_extants = [1, 32, 6, 2, 5], remaining_factor_needed = 5 / 5 = 1 + + After updating loop ordering in place, returns the new list of loops, extants, and the + remaining factor needed. + """ + cur_loop = loops[index] + cur_extant = loop_extants[index] + + # Pad loops to divide evenly for factors needed, and split + new_loop, new_size = pad_dimension_to_factor_of_tile_size( + cur_loop, cur_extant, tile_size=factor_needed + ) + + split_factor = min(new_size, factor_needed) + new_loop_split, factored_loop = sch.split(new_loop, [None, split_factor]) + factor_needed = factor_needed // split_factor + + # update caching + loops[index] = new_loop_split + loops.append(factored_loop) + + loop_extants[index] = math.ceil(int(new_size) / int(split_factor)) + loop_extants.append(split_factor) + + sch.reorder(*loops) + return loops, loop_extants, factor_needed + + def factor_dim_in_order( + indices: List[int], + loops: List[LoopRV], + cur_loop_extants: List[int], + work_needed_inner_loop: int = tile_size, + ) -> Tuple[List[LoopRV], List[int]]: + """Factors out the loops in the order of indices until we reach needed work. + + Adds new loop factors to the back in reverse order of access. Returns new list + of loops and their extants. + """ + for i in indices: + loops, cur_loop_extants, work_needed_inner_loop = spin_out_factor( + loops, cur_loop_extants, i, work_needed_inner_loop + ) + if work_needed_inner_loop == 1: + break + return loops, cur_loop_extants + + def get_high_level_loop_structure( + block_read: BlockRV, input_shape: List[int], src_layout: str, dst_layout: str + ): + """Runs the factorization described above.""" + # index 0 ... rank - 1 will always correspond to original loops + # perhaps after they have been factored. + rank = len(input_shape) + loops = sch.get_loops(block_read) + cur_loop_extants = list(input_shape) + + # Factor dim0 tile size and fuse things together + loops, cur_loop_extants = factor_dim_in_order( + list(range(rank - 1, -1, -1)), + loops, + cur_loop_extants, + work_needed_inner_loop=tile_size, + ) + # The factors which multiply to tile_size are now in back of our + # list of loops. However because we added them by traversing the inner + # dimensions, they are actually reversed order to guarantee the best access + # so reorder before fusing. + loops = loops[:rank] + loops[rank:][::-1] + cur_loop_extants = cur_loop_extants[:rank] + cur_loop_extants[rank::-1] + sch.reorder(*loops) + dim0_loop_tiled = sch.fuse(*loops[rank:]) + loops = loops[:rank] + loops.append(dim0_loop_tiled) + cur_loop_extants = cur_loop_extants[:rank] + cur_loop_extants.append(tile_size) + + # Same thing with dim1 + # [:rank + 1], since we placed dim0_loop_tiled in the end which we want to keep + loops, cur_loop_extants = factor_dim_in_order( + list( + ( + src_layout.index(dst_layout[loop_index_dst]) + for loop_index_dst in range(rank - 1, -1, -1) + ) + ), + loops, + cur_loop_extants, + work_needed_inner_loop=tile_size, + ) + loops = loops[: rank + 1] + loops[rank + 1 :][::-1] + cur_loop_extants = cur_loop_extants[: rank + 1] + cur_loop_extants[rank + 1 :: -1] + sch.reorder(*loops) + dim1_loop_tiled = sch.fuse(*loops[rank + 1 :]) + loops = loops[: rank + 1] + loops.append(dim1_loop_tiled) + cur_loop_extants = cur_loop_extants[: rank + 1] + cur_loop_extants.append(tile_size) + + # After this we have loops: [loop1, loop2, loop3 ... 
dim0_tiled, dim1_tiled] + get_high_level_loop_structure(block_read, input_shape, src_layout, dst_layout) + + # If there are insufficient elements, than dim1_tiled or dim0_tiled might be too small + # In all likelihood you should use a smaller tile, but I don't want things to crash. + loops = sch.get_loops(block_read) + loops[-1] = pad_dimension_to_at_least_number(loops[-1], tile_size) + loops[-2] = pad_dimension_to_at_least_number(loops[-2], tile_size) + + # We want the dim0 and dim1 parent loops to be the inner most. Right now dim1 is inner-msot + # and we just need to move dim0 in (last dimension of dst). + # Recall right now structure is at least [l1 l2 ... ln, dim0_tiled, dim1_tiled] + # where n >= 2. + dim0_loop_index = src_layout.index(dst_layout[-1]) + dim0_loop = loops.pop(dim0_loop_index) + loops = loops[:-3] + [dim0_loop, loops[-3]] + loops[-2:] + sch.reorder(*loops) + + # After this loops are: [outer_loop (block binding), dim0_tiled, dim1_tiled] + outer_loop = sch.fuse(*loops[:-2]) + + # Now that we have the high level loop structure, we can use reverse_compute_at magic + # To get the proper loop structure for writing! This is also as coalesced as possible + # already. + sch.reverse_compute_at(block_write, outer_loop) + + # Fuse all inner loops for the write into 2 loops, grab inner loops for both read + # and write block which have locality (we will bind these to threadIdx) + fused_write_loop = sch.fuse(*sch.get_loops(block_write)[1:]) + _, inner_write_loop = sch.split(fused_write_loop, [None, tile_size]) + inner_read_loop = sch.get_loops(block_read)[-2] + + sch.bind(loop=outer_loop, thread_axis="blockIdx.x") + sch.bind(loop=inner_write_loop, thread_axis="threadIdx.x") + sch.bind(loop=inner_read_loop, thread_axis="threadIdx.x") + + return block_write, block_read + + +def create_cached_read( + sch: tvm.tir.Schedule, + block_write: BlockRV, + orig_input_shape: List[int], + orig_src_layout: str, + orig_dst_layout: str, +) -> Tuple[BlockRV, List[int], str, str]: + """ + Creates the cached read block with expected structure. + + Loop extants should follow the input shape closely. E.g. if the input is [2, 6, 8], we + expect our loop structure to be T.grid(2, 6, 8). Possibly reshape to handle implicit reshapes, + in which case we will match the implicit reshape shape. + + Layout transform allows semantics like NCHW --> NCHW4c. Which involves splitting the original C + axis into contiguous 4-element chunks. This axis is then moved to the end (NCHWc). This is + guaranteed by the operator to be done without additional padding. To handle this we just split + the associating axis (prev. type checking ensures C is divisible by 4)in src_layout found in + block_read. E.g. NCHW -> NCHW4c now becomes NC4cHW -> NCHW4c. + + Note: NCHW4c --> NCHW is not allowed, so the only numeric digits will be in dst. + + The returned layout strings will be santized and made compatible. E.g. NCHW --> NCHW4c becomes + NCcHW --> NCHWc. + + TODO(AndrewZhaoLuo): Investigate using proper memory alignment to avoid bank conflict. + + Parameters + ---------- + sch: + The initial schedule. We expect `block_read`. We also expect + block_read's initial loops to follow the original input shape. + + block_read: + The block which reads from global memory and writes to shared memory buffer. + + orig_input_shape: + The input shape of the input buffer to the primfunc. + + orig_src_layout: + The original src_layout string. + + orig_dst_layout: + The original dst_layout string. 
+ + Returns + ------- + ret: + A tuple of the cached read block, new input shape of shared memory buffer, + the new src_layout, and new dst_layout string. + """ + # Figure out split dimensions, entries are (loop index in src_layout, split amount) + split_dimensions: List[Tuple[int, int]] = [] + + # This is without numeric digits, e.g. NCHW4c -> NCHWc + new_dst_layout = [] + + # Use state machine to parse NCHW4c string + split_size = 0 + for char in orig_dst_layout: + if char.isnumeric(): + split_size = split_size * 10 + int(char) + else: + if char.islower(): + # hit axis like 'c', need to find parent axis 'C' in src_layout + src_layout_index = orig_src_layout.index(char.upper()) + split_dimensions.append((src_layout_index, split_size)) + split_size = 0 + new_dst_layout.append(char) + + # If no splits were detected we are done + if len(split_dimensions) == 0: + block_read = sch.cache_read(block_write, 0, "shared") + return block_read, orig_input_shape, orig_src_layout, orig_dst_layout + + # Calculate final input shapes, each of these are a single element for unsplit dims + # and tuples for split dims associated with the two new axis + input_shape: List[Union[int, Tuple]] = list(orig_input_shape) + new_src_layout: List[Union[str, Tuple]] = list(orig_src_layout) + for src_layout_split_index, split_factor in split_dimensions: + dimension_name = orig_src_layout[src_layout_split_index] + new_src_layout[src_layout_split_index] = (dimension_name, dimension_name.lower()) + input_shape[src_layout_split_index] = ( + orig_input_shape[src_layout_split_index] // split_factor, + split_factor, + ) + + # Unpack any tuples introduced via appending + def unpack_list(target_list) -> List: + output: List = [] + for ele in target_list: + if isinstance(ele, tuple): + output.extend(ele) + else: + output.append(ele) + return output + + new_src_layout_str = "".join(unpack_list(new_src_layout)) + new_dst_layout_str = "".join(unpack_list(new_dst_layout)) + + # Write block loop extants match + dst_to_src_map = [new_dst_layout_str.index(dim) for dim in new_src_layout_str] + block_read = sch.reindex_cache_read( + block_write, + read_buffer_index=0, + index_map=tvm.tir.IndexMap.from_func( + lambda *loops: [loops[dst_to_src_map[i]] for i, _ in enumerate(loops)], + ndim=len(new_src_layout_str), + ), + storage_scope="shared", + ) + + loops_read = sch.get_loops(block_read) + sch.reorder( + *[loops_read[new_dst_layout_str.index(dst_dim_name)] for dst_dim_name in new_src_layout_str] + ) + return block_read, unpack_list(input_shape), new_src_layout_str, new_dst_layout_str + + +def auto_inline_into(sch: tvm.tir.Schedule, start_block: BlockRV) -> BlockRV: + """ + Inlines given start_block's consumers and future dependencies into start_block. + + Parameters + ---------- + sch: + The initial schedule. + + start_block: + The block to inline into, should be a block which reads and writes to global memory, doing + layout transform. + + Returns + ------- + ret: + The new block inlined into it's consumers. + """ + # Rules defined by DefaultCUDA schedule_rule set. 
+ autoinline_rule = meta_schedule.schedule_rule.AutoInline( + into_producer=True, + into_consumer=False, + inline_const_tensor=True, + disallow_if_then_else=False, + require_injective=False, + require_ordered=False, + ) + + fringe = deque(sch.get_consumers(start_block)) + visited = set() + while len(fringe) > 0: + cur_block = fringe.popleft() + if cur_block in visited: + continue + + visited.add(cur_block) + consumer_blocks = sch.get_consumers(cur_block) + fringe.extend(consumer_blocks) + + sch = autoinline_rule.apply(sch, cur_block)[0] + + +def get_max_tile_size() -> int: + """Returns the max tile size. + + This is assuming only threads in a warp can have coalesced accesses. 32 is the default if + no target information can be gotten. + """ + max_tile_size = 32 + cur_target = tvm.target.Target.current() + if cur_target is not None and hasattr(cur_target, "thread_warp_size"): + max_tile_size = int(cur_target.thread_warp_size) + return max_tile_size + + +@tvm.register_func("meta_schedule.cuda.layout_transform") +def cuda_layout_transform_schedule_rule( + sch: tvm.tir.Schedule, block: BlockRV, testing_tile_sizes: Optional[List[int]] = None +) -> List[tvm.tir.Schedule]: + """ + Applies tiling scheme to layout transform task (potentially fused with other injective funcs). + + Returned schedules will be the default schedule, as well as tiled versions with tile_size in + the range of 2,3...threads_per_warp. + + This is assuming only threads in a warp can have coalesced accesses. 32 is the default if + no target information can be gotten. + + Parameters + ---------- + sch: + The initial schedule. + + block: + The block corresponding to the layout transform. + Should be a block which reads and writes to global memory, doing layout transform. + + testing_tile_sizes: + A list of tile sizes to try, overriding normal settings. For testing. None means + ignore. Else overrides normal settings of tile sizes to try. + + Returns + ------- + ret: + A list of new schedules to try. + """ + # Info needed for tiling + src_layout = sch.get_sref(block).stmt.annotations["src_layout"] + dst_layout = sch.get_sref(block).stmt.annotations["dst_layout"] + input_shape = [int(c) for c in sch.get_sref(block).stmt.annotations["input_shape"]] + + schedules = [] + + # Always include the default schedules which will be handled via AutoBind schedule rule + # Except during testing + if not testing_tile_sizes: + schedules.append(sch) + + sch = sch.copy() + + # Inline consumers of the layout transform into the layout transform block. + # Normally default for injective schedules but must manually be called in new schedule rule + # for consumers of the layout transform. TODO(AndrewZhaoLuo): Figure out why this is the case. + auto_inline_into(sch, block) + + # Setup up basic structure of schedule of creating read into shared mem, before applying tiling + # Outer loop structure of read block matches that of src_layout + # E.g. if input_shape is [4, 6, 8]. Loops for read block will be + # for i, j, k in T.grid(4, 6, 8): + # ... + # Read block will read from global memory coalesced at the start + # Assume write to output global memory is coalesced in block_write + # + # This also handles the case where there is an implicit reshape going on. + # e.g. NCHW -> NCHW4c which is equivalent to reshaping NCHW + # to NCcHW and then applying the new layout where the extant of c is 4. + # Grab final input shape and src and dst layouts with possible implicit reshape. 
+ block_read, input_shape, src_layout, dst_layout = create_cached_read( + sch, block, input_shape, src_layout, dst_layout + ) + + # Try tile size 2,3...threads_per_warp as tile size of 1 has no coaslescing. + if testing_tile_sizes is None: + tile_sizes = list(range(2, get_max_tile_size() + 1)) + else: + tile_sizes = testing_tile_sizes + + for tile_size in tile_sizes: + new_sch = sch.copy() + tile_layout_transform( + new_sch, block_read, block, src_layout, dst_layout, input_shape, tile_size + ) + schedules.append(new_sch) + + return schedules diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 4499d5e2266e..93df67ff6b99 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -94,7 +94,7 @@ def compute_strided_set(attrs, inputs, output_type): _reg.register_injective_schedule("strided_set") # layout_transform -_reg.register_injective_schedule("layout_transform") +_reg.register_strategy("layout_transform", strategy.layout_transform_strategy) _reg.register_pattern("layout_transform", OpPattern.INJECTIVE) _reg.register_injective_schedule("auto_scheduler_layout_transform") _reg.register_pattern("auto_scheduler_layout_transform", OpPattern.INJECTIVE) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 416637c14905..65573321f76c 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -1396,3 +1396,14 @@ def dft_strategy_cuda(attrs, inputs, out_type, target): name="dft.cuda", ) return strategy + + +@layout_transform_strategy.register(["cuda", "gpu"]) +def layout_transform_strategy_cuda(attrs, inputs, out_type, target): + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_layout_transform(topi.layout_transform, schedule_rule="layout_transform"), + schedule_injective, + name="layout_transform.cuda", + ) + return strategy diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 733b630fc4da..2883e5e1fb77 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -21,12 +21,7 @@ from tvm import _ffi, ir, te, topi from tvm.target import generic_func, override_native_generic_func -from tvm.topi.utils import ( - get_const_float, - get_const_int, - get_const_tuple, - get_float_tuple, -) +from tvm.topi.utils import get_const_float, get_const_int, get_const_tuple, get_float_tuple from .. import op as _op @@ -2060,3 +2055,32 @@ def conv2d_backward_weight_strategy(attrs, inputs, out_type, target): "conv2d_backward_weight is currently only supported with cudnn. " "Please run Legalize pass to decompose this op into supported ops." 
) + + +@override_native_generic_func("layout_transform_strategy") +def layout_transform_strategy(attrs, inputs, out_type, target): + """layout transform generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_layout_transform(topi.layout_transform), + # Defined earlier in the file + schedule_injective, + name="layout_transform.generic", + ) + return strategy + + +def wrap_compute_layout_transform(topi_compute, schedule_rule="None"): + """Wrap layout transform compute""" + + def _compute_layout_transform(attrs, inputs, output_type): + return [ + topi_compute( + inputs[0], + attrs.src_layout, + attrs.dst_layout, + schedule_rule, + ) + ] + + return _compute_layout_transform diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 23334da9c25c..e4fe3c583990 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -17,13 +17,13 @@ # pylint: disable=invalid-name,consider-using-enumerate,redefined-outer-name """Injective transformation operators""" from __future__ import absolute_import as _abs + import tvm -from tvm import te -from tvm import topi +from tvm import te, topi from tvm.te import hybrid -from . import cpp -from . import tag -from .utils import within_index, make_idx, const_vector + +from . import cpp, tag +from .utils import const_vector, make_idx, within_index def expand_dims(a, axis, num_newaxis=1): @@ -636,7 +636,7 @@ def tile(a, reps): return cpp.tile(a, reps) -def layout_transform(array, src_layout, dst_layout): +def layout_transform(array, src_layout, dst_layout, schedule_rule="None"): """Transform the layout according to src_layout and dst_layout Parameters @@ -649,8 +649,11 @@ def layout_transform(array, src_layout, dst_layout): dst_layout : str the destination layout. 
+ + schedule_rule : str + the schedule rule to apply if any """ - return cpp.layout_transform(array, src_layout, dst_layout) + return cpp.layout_transform(array, src_layout, dst_layout, schedule_rule) def shape(array, dtype="int32"): diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 0ea1392e5daf..bbefa19c2055 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -87,7 +87,7 @@ TVM_REGISTER_GLOBAL("topi.split").set_body([](TVMArgs args, TVMRetValue* rv) { }); TVM_REGISTER_GLOBAL("topi.layout_transform").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = layout_transform(args[0], args[1], args[2]); + *rv = layout_transform(args[0], args[1], args[2], args[3]); }); TVM_REGISTER_GLOBAL("topi.take").set_body([](TVMArgs args, TVMRetValue* rv) { diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py index f1d74348db17..ee148db94d0a 100644 --- a/tests/python/unittest/test_meta_schedule_relay_integration.py +++ b/tests/python/unittest/test_meta_schedule_relay_integration.py @@ -20,6 +20,7 @@ import numpy as np import pytest + import tvm import tvm.testing from tvm import IRModule @@ -420,6 +421,7 @@ def main( # type: ignore ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(placeholder[ax0, ax1 * T.int64(3) + ax4, ax2, ax3]) T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4]) + T.block_attr({"dst_layout": "NCHW3c", "input_shape": [1, 3, 16, 16], "schedule_rule": "None", "src_layout": "NCHW"}) T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else( ax0 < T.int64(1) and ax1 * T.int64(3) + ax4 < T.int64(3) and ax2 < T.int64(16) and ax3 < T.int64(16), # type: ignore placeholder[ax0, ax1 * T.int64(3) + ax4, ax2, ax3], @@ -440,6 +442,7 @@ def main(placeholder: T.Buffer((T.int64(1), T.int64(2), T.int64(16), T.int64(16) ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(placeholder[ax0, ax1 // T.int64(4), ax2, ax3, ax1 % T.int64(4)]) # type: ignore T.writes(T_layout_trans[ax0, ax1, ax2, ax3]) + T.block_attr({"dst_layout": "NCHW", "input_shape": [1, 2, 16, 16, 4], "schedule_rule": "None", "src_layout": "NCHW4c"}) T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < T.int64(1) and ax1 < T.int64(8) and ax2 < T.int64(16) and ax3 < T.int64(16), placeholder[ax0, ax1 // T.int64(4), ax2, ax3, ax1 % T.int64(4)], T.float32(0), dtype="float32") # type: ignore @tvm.script.ir_module diff --git a/tests/python/unittest/test_meta_schedule_schedule_cuda_layout_transform.py b/tests/python/unittest/test_meta_schedule_schedule_cuda_layout_transform.py new file mode 100644 index 000000000000..d1ba84d836be --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_cuda_layout_transform.py @@ -0,0 +1,466 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +import itertools +import random +import tempfile +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np + +import tvm +import tvm.testing +from tvm import meta_schedule, relay +from tvm.meta_schedule.schedule.cuda.layout_transform import cuda_layout_transform_schedule_rule +from tvm.relay.op import OpPattern +from tvm.script import ir as I +from tvm.script import tir as T +from tvm.tir.schedule import BlockRV + +# Small gpu parameters which should work for nearly every (modern-ish) gpu. +TARGET = tvm.target.Target( + "cuda -max_threads_per_block=32 -max_num_threads=128 -thread_warp_size=32 -max_shared_memory_per_block=8192 -registers_per_block=1024" +) + + +class PatchCustomLayoutTransformScheduleRule: + """Patch the custom layout transform schedule to test only specific tile sizes. + + If tile_sizes = [], then returns the default (non-tiled) schedule, otherwise + returns only the schedule with the given tiles. + """ + + FUNC_NAME = "meta_schedule.cuda.layout_transform" + + def __init__(self, tile_sizes: List[int]) -> None: + self.tile_sizes = tile_sizes + self.old_func = None + + def __enter__(self, *args, **kwargs) -> None: + self.old_func = tvm.get_global_func(self.FUNC_NAME) + + def new_layout_rule( + sch: tvm.tir.Schedule, + block: BlockRV, + tile_sizes: Optional[List[int]] = self.tile_sizes, + ) -> List[tvm.tir.Schedule]: + return cuda_layout_transform_schedule_rule(sch, block, tile_sizes) + + tvm.register_func(self.FUNC_NAME, new_layout_rule, override=True) + + def __exit__(self, *args, **kwargs) -> None: + tvm.register_func(self.FUNC_NAME, self.old_func, override=True) + + +# Create unary functions which apply ops with compatible fusion levels to layout transform +def get_random_axis(data: relay.Expr): + rank = len(relay.transform.InferTypeLocal(data).shape) + return random.randint(0, rank - 1) + + +def apply_elemwise_clip(data: relay.Expr, min=0, max=10): + assert relay.op.get("clip").get_attr("TOpPattern") == OpPattern.ELEMWISE + return relay.clip(data, min, max) + + +def apply_broadcast_add(data: relay.Expr, val_to_add=5): + assert relay.op.get("add").get_attr("TOpPattern") == OpPattern.BROADCAST + type_info = relay.transform.InferTypeLocal(data) + return relay.add(data, relay.const(val_to_add, dtype=type_info.dtype)) + + +def apply_injective_concatenate(data: relay.Expr, axis=None): + if axis is None: + axis = get_random_axis(data) + assert relay.op.get("concatenate").get_attr("TOpPattern") == OpPattern.INJECTIVE + return relay.concatenate([data, data], axis) + + +def apply_comm_reduce_max(data: relay.Expr, axis=None): + if axis is None: + axis = get_random_axis(data) + assert relay.op.get("max").get_attr("TOpPattern") == OpPattern.COMM_REDUCE + + # Do this to maintain dimensions + return relay.add(data, relay.max(data, axis, keepdims=True)) + + +pattern_level_to_op = { + OpPattern.ELEMWISE: apply_elemwise_clip, + OpPattern.BROADCAST: apply_broadcast_add, + OpPattern.INJECTIVE: apply_injective_concatenate, + OpPattern.COMM_REDUCE: apply_comm_reduce_max, +} + + +def apply_layout_transform(data: relay.Expr, src_layout: str, dst_layout: str): + assert relay.op.get("layout_transform").get_attr("TOpPattern") == OpPattern.INJECTIVE + return relay.layout_transform(data, src_layout, dst_layout) + + +def create_relay_module( + input_shape: List[int], dtype: str, ops: List[Union[OpPattern, Tuple[str, str]]] +) -> tvm.IRModule: + """Create a relay 
module with the given string of ops. + + ops: + Applies the associated operators in order. If an integer, refers to applying + the unary operator from `extra_pattern_level_to_op` map. If a tuple, applies + a layout transform with the given (src_layout, dst_layout) + """ + input_data = relay.var("input", shape=input_shape, dtype=dtype) + + cur_data = input_data + for op_info in ops: + # Progressively build type info + relay.transform.InferTypeLocal(cur_data) + if isinstance(op_info, tuple): + # layout transform case + src_layout, dst_layout = op_info + cur_data = apply_layout_transform(cur_data, src_layout, dst_layout) + else: + cur_data = pattern_level_to_op[op_info](cur_data) + + relay.transform.InferTypeLocal(cur_data) + return tvm.IRModule.from_expr(cur_data) + + +def extract_layout_transform_task( + mod: tvm.IRModule, target: tvm.target.Target +) -> meta_schedule.ExtractedTask: + """Given a relay IRModule, return the PrimFunc IRModule with fused layout transform task.""" + extracted_tasks = meta_schedule.relay_integration.extract_tasks( + mod, + target, + {}, + pass_config={"relay.backend.use_meta_schedule": True}, + ) + task_of_interest = None + for task in extracted_tasks: + if "layout_transform" in task.task_name: + task_of_interest = task + break + assert task_of_interest is not None + return task_of_interest + + +def run_primfunc( + primfunc_mod: tvm.IRModule, target: tvm.target.Target, input_tensors: List[tvm.nd.NDArray] +): + """Compile and run the primfunc with the given input tensors.""" + with tvm.transform.PassContext( + config={"relay.backend.use_meta_schedule": True}, + opt_level=3, + ): + lib = tvm.build(primfunc_mod, target=target) + lib(*input_tensors) + + +class TestRandomRelayE2ECorrectness: + """Tests E2E correctness of layout transform schedule. + + Randomly generates relay mod with layout transform and fusable ops. Checks the + layout transform task for correctness by comparing against its unscheduled result. + """ + + @staticmethod + def generate_test_case( + input_shape: List[int], + implicit_reshape_info: Optional[Tuple[int, int]], + dtype: str, + num_additional_ops: int, + ) -> tvm.IRModule: + """Creates a random layout transform module with up to num_additional_ops fused.""" + # Create layout transforms + rank = len(input_shape) + + # src_layout is a string like ABCDEFG... 
with length as rank + src_layout = "".join([chr(i + ord("A")) for i in range(rank)]) + + # dst_layout is randomly shuffled src_layout, potentially after adding split axis + dst_layout = list(src_layout) + if implicit_reshape_info: + axis_to_reshape, size_new_dim = implicit_reshape_info + cur_dim = dst_layout[axis_to_reshape] + dst_layout[axis_to_reshape] = f"{cur_dim}" + dst_layout.append(f"{size_new_dim}{cur_dim.lower()}") + + random.shuffle(dst_layout) + while "".join(dst_layout) == src_layout: + random.shuffle(dst_layout) + dst_layout = "".join(dst_layout) + + # Randomly sample a list of potentially fusable ops to layout transform + op_order = random.choices( + list(pattern_level_to_op.keys()), + k=num_additional_ops, + ) + + # Append tuple, representing layout transfomr from src --> dst layout + op_order.append((src_layout, dst_layout)) + + random.shuffle(op_order) + return create_relay_module(input_shape, dtype, op_order) + + @staticmethod + def get_primfunc(extracted_task: meta_schedule.ExtractedTask, tile_size: Optional[int]): + with PatchCustomLayoutTransformScheduleRule( + tile_sizes=[] if tile_size is None else [tile_size] + ): + with tempfile.TemporaryDirectory() as tmpdir: + ( + tune_contexts, + _, + ) = meta_schedule.relay_integration.extracted_tasks_to_tune_contexts( + [extracted_task], + tmpdir, + ) + tune_contexts[0].pre_tuning(1) + candidates = tune_contexts[0].generate_measure_candidates() + primfunc = candidates[0].sch.mod["main"] + return primfunc + + @staticmethod + def verify_layout_transform_task( + extracted_task: meta_schedule.ExtractedTask, + target: tvm.target.Target, + tile_sizes: List[int], + ): + """Given a layout transform task, tests the given tile_sizes and verifies output matches.""" + device = tvm.cuda(0) + relay_mod = extracted_task.mod + + # Create and cache inputs + func_type = relay.transform.InferTypeLocal(relay_mod[relay_mod.get_global_vars()[0]]) + input_tensors = [] + for input_type in func_type.arg_types: + orig_input_np = np.random.uniform(0, 10, size=list(map(int, input_type.shape))).astype( + input_type.dtype + ) + orig_input_np = np.arange(0, orig_input_np.size, dtype=input_type.dtype).reshape( + orig_input_np.shape + ) + input_tensors.append(tvm.nd.array(orig_input_np, device)) + ret_type = func_type.ret_type + + def get_output_tensor() -> Tuple[tvm.nd.NDArray, tvm.nd.NDArray]: + numpy_init = np.random.uniform(0, 1000, size=list(map(int, ret_type.shape))).astype( + ret_type.dtype + ) + return tvm.nd.array(numpy_init, device) + + def run_and_get_output(tile_size: Optional[int]) -> np.ndarray: + returned_primfunc = TestRandomRelayE2ECorrectness.get_primfunc( + extracted_task, tile_size + ) + output_tensor = get_output_tensor() + run_primfunc(returned_primfunc, target, [*input_tensors, output_tensor]) + return output_tensor.numpy() + + # Passing None, we basically do not apply the custom rule we have created + # and instead use the old default schedule which is the ground truth. + ground_truth_np = run_and_get_output(None) + + for tile_size in tile_sizes: + experimental_np = run_and_get_output(tile_size) + np.testing.assert_allclose(ground_truth_np, experimental_np) + + ( + input_shape, + implicit_reshape_info, + dtype, + tile_sizes, + num_additional_ops, + ) = tvm.testing.parameters( + *itertools.product( + # input_shape: Each has ~10k elements, should take single microseconds on modern gpu + [ + [12, 48, 18], + [890, 14], + [10, 12, 2, 5, 3, 3], + ], + # implicit_reshape_info: Implicit reshape conditions. 
+ # None is do no implicit reshape, (0, 2) means divide axis 0 in half, e.g. AB --> A2aB + [None, (0, 2), (1, 2)], + # dtype: dtypes to test, should not matter that much + ["float16"], + # tile_sizes: Tile sizes to try + [[8, 7]], + # num_additional_ops: number of non-layout transform ops to include and may be fused + [5], + ) + ) + + @tvm.testing.requires_gpu + def test_all_test_case( + self, + input_shape, + implicit_reshape_info, + dtype, + tile_sizes, + num_additional_ops, + ): + """Tests the product of all conditions `repeat_per_condition` times.""" + # Generate random module of fusable ops + layout transform and extract fused layout transform task + full_mod = self.generate_test_case( + input_shape, implicit_reshape_info, dtype, num_additional_ops + ) + + # Fused layout transform task + extracted_task = extract_layout_transform_task(full_mod, TARGET) + self.verify_layout_transform_task(extracted_task, TARGET, tile_sizes) + + +@tvm.testing.requires_gpu +class TestManualCases: + def assert_extracted_equals_expected( + self, relay_mod: tvm.IRModule, expected_mod: tvm.IRModule, tile_size: int + ): + extracted_task = extract_layout_transform_task(relay_mod, TARGET) + dispatched_mod = extracted_task.dispatched[0] + sch = tvm.tir.Schedule(dispatched_mod) + block = sch.get_block("T_layout_trans") + output_sch = cuda_layout_transform_schedule_rule(sch, block, [tile_size])[0] + assert output_sch.mod.script() == expected_mod.script() + + def test_simple_tiling(self): + mod = create_relay_module([1, 32, 32, 32], "float16", [("NCHW", "NHWC")]) + + # Main things to notice: + # - two blocks each with 16, 16 extents which write/read shared mem + # - coalesced accesses in inner loop of global memory buffer for both + # fmt: off + @I.ir_module + class ExpectedModule: + @T.prim_func + def main(p0: T.Buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(32)), "float16"), T_layout_trans: T.Buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(32)), "float16")): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # with T.block("root"): + p0_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(32)), "float16", scope="shared") + for ax0_ax2_ax1_0_ax3_0_fused in T.thread_binding(T.int64(128), thread="blockIdx.x"): + for ax3_1_fused_0_ax3_1_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"): + for ax1_1_fused_0_ax1_1_fused_1_fused in range(T.int64(16)): + with T.block("p0_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(32), ax0_ax2_ax1_0_ax3_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + ax1_1_fused_0_ax1_1_fused_1_fused) + v2 = T.axis.spatial(T.int64(32), ax0_ax2_ax1_0_ax3_0_fused // T.int64(4)) + v3 = T.axis.spatial(T.int64(32), ax0_ax2_ax1_0_ax3_0_fused % T.int64(2) * T.int64(16) + ax3_1_fused_0_ax3_1_fused_1_fused) + T.reads(p0[v0, v1, v2, v3]) + T.writes(p0_shared[v0, v1, v2, v3]) + p0_shared[v0, v1, v2, v3] = p0[v0, v1, v2, v3] + for ax0_ax1_fused_0 in range(T.int64(16)): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"): + with T.block("T_layout_trans"): + v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_ax1 = T.axis.spatial(T.int64(32), ax0_ax2_ax1_0_ax3_0_fused // T.int64(4)) + v_ax2 = T.axis.spatial(T.int64(32), ax0_ax2_ax1_0_ax3_0_fused % T.int64(2) * T.int64(16) + (ax0_ax1_fused_0 * T.int64(16) + ax0_ax1_fused_1) // T.int64(16)) + v_ax3 = T.axis.spatial(T.int64(32), ax0_ax2_ax1_0_ax3_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + (ax0_ax1_fused_0 * T.int64(16) + 
ax0_ax1_fused_1) % T.int64(16)) + T.reads(p0_shared[v_ax0, v_ax3, v_ax1, v_ax2]) + T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3]) + T.block_attr({"dst_layout": "NHWC", "input_shape": [1, 32, 32, 32], "schedule_rule": "layout_transform", "src_layout": "NCHW"}) + T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(v_ax0 < T.int64(1) and v_ax3 < T.int64(32) and v_ax1 < T.int64(32) and v_ax2 < T.int64(32), p0_shared[v_ax0, v_ax3, v_ax1, v_ax2], T.float16(0)) + + self.assert_extracted_equals_expected(mod, ExpectedModule, 16) + + def test_simple_implicit_reshape(self): + mod = create_relay_module([1, 32, 32, 32], "float16", [("NCHW", "NCHW4c")]) + + # Main things to notice: + # - two blocks each with 16, 16 extents which write/read shared mem + # - coalesced accesses in inner loop of global memory buffer for both + # - an implicit reshape is done (see p0_shared) + # fmt: off + @I.ir_module + class ExpectedModule: + @T.prim_func + def main(p0: T.Buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(32)), "float16"), T_layout_trans: T.Buffer((T.int64(1), T.int64(8), T.int64(32), T.int64(32), T.int64(4)), "float16")): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # with T.block("root"): + p0_shared = T.alloc_buffer((T.int64(1), T.int64(8), T.int64(4), T.int64(32), T.int64(32)), "float16", scope="shared") + for ax0_ax1_ax2_0_ax4_0_ax3_0_0_fused in T.thread_binding(T.int64(128), thread="blockIdx.x"): + for ax3_1_fused_0_ax3_1_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"): + for ax2_1_ax3_0_1_ax4_1_fused_0_ax2_1_ax3_0_1_ax4_1_fused_1_fused in range(T.int64(16)): + with T.block("p0_shared"): + v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_ax1 = T.axis.spatial(T.int64(8), ax0_ax1_ax2_0_ax4_0_ax3_0_0_fused // T.int64(16)) + v_ax2 = T.axis.spatial(T.int64(32), ax0_ax1_ax2_0_ax4_0_ax3_0_0_fused % T.int64(16) * T.int64(2) + ax2_1_ax3_0_1_ax4_1_fused_0_ax2_1_ax3_0_1_ax4_1_fused_1_fused // T.int64(8)) + v_ax3 = T.axis.spatial(T.int64(32), ax2_1_ax3_0_1_ax4_1_fused_0_ax2_1_ax3_0_1_ax4_1_fused_1_fused % T.int64(8) // T.int64(4) * T.int64(16) + ax3_1_fused_0_ax3_1_fused_1_fused) + v_ax4 = T.axis.spatial(T.int64(4), ax2_1_ax3_0_1_ax4_1_fused_0_ax2_1_ax3_0_1_ax4_1_fused_1_fused % T.int64(4)) + T.reads(p0[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3]) + T.writes(p0_shared[v_ax0, v_ax1, v_ax4, v_ax2, v_ax3]) + p0_shared[v_ax0, v_ax1, v_ax4, v_ax2, v_ax3] = p0[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3] + for ax0_ax1_ax2_fused_0 in range(T.int64(16)): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"): + with T.block("T_layout_trans"): + v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_ax1 = T.axis.spatial(T.int64(8), ax0_ax1_ax2_0_ax4_0_ax3_0_0_fused // T.int64(16)) + v_ax2 = T.axis.spatial(T.int64(32), ax0_ax1_ax2_0_ax4_0_ax3_0_0_fused % T.int64(16) * T.int64(2) + (ax0_ax1_ax2_fused_0 * T.int64(16) + ax0_ax1_ax2_fused_1) // T.int64(128)) + v_ax3 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(16) + ax0_ax1_ax2_fused_1) % T.int64(128) // T.int64(4)) + v_ax4 = T.axis.spatial(T.int64(4), (ax0_ax1_ax2_fused_0 * T.int64(16) + ax0_ax1_ax2_fused_1) % T.int64(4)) + T.reads(p0_shared[v_ax0, v_ax1, v_ax4, v_ax2, v_ax3]) + T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T.block_attr({"dst_layout": "NCHW4c", "input_shape": [1, 32, 32, 32], "schedule_rule": "layout_transform", "src_layout": "NCHW"}) + T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.if_then_else(v_ax0 < T.int64(1) and v_ax1 * 
T.int64(4) + v_ax4 < T.int64(32) and v_ax2 < T.int64(32) and v_ax3 < T.int64(32), p0_shared[v_ax0, v_ax1, v_ax4, v_ax2, v_ax3], T.float16(0)) + self.assert_extracted_equals_expected(mod, ExpectedModule, 16) + + def test_expected_fusion_post(self): + mod = create_relay_module( + [1, 32, 32, 32], "float16", [("NCHW", "NCHW4c"), OpPattern.BROADCAST] + ) + + # Main things to notice: + # - two blocks each with 16, 16 extents which write/read shared mem + # - coalesced accesses in inner loop of global memory buffer for both + # - an implicit reshape is done (see p0_shared) + # - an addition is inlined in the final block (p1 input) + # fmt: off + @I.ir_module + class ExpectedModule: + @T.prim_func + def main(p0: T.Buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(32)), "float16"), p1: T.Buffer((), "float16"), T_add: T.Buffer((T.int64(1), T.int64(8), T.int64(32), T.int64(32), T.int64(4)), "float16")): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # with T.block("root"): + p0_shared = T.alloc_buffer((T.int64(1), T.int64(8), T.int64(4), T.int64(32), T.int64(32)), "float16", scope="shared") + for ax0_ax1_ax2_0_ax4_0_ax3_0_0_fused in T.thread_binding(T.int64(128), thread="blockIdx.x"): + for ax3_1_fused_0_ax3_1_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"): + for ax2_1_ax3_0_1_ax4_1_fused_0_ax2_1_ax3_0_1_ax4_1_fused_1_fused in range(T.int64(16)): + with T.block("p0_shared"): + v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_ax1 = T.axis.spatial(T.int64(8), ax0_ax1_ax2_0_ax4_0_ax3_0_0_fused // T.int64(16)) + v_ax2 = T.axis.spatial(T.int64(32), ax0_ax1_ax2_0_ax4_0_ax3_0_0_fused % T.int64(16) * T.int64(2) + ax2_1_ax3_0_1_ax4_1_fused_0_ax2_1_ax3_0_1_ax4_1_fused_1_fused // T.int64(8)) + v_ax3 = T.axis.spatial(T.int64(32), ax2_1_ax3_0_1_ax4_1_fused_0_ax2_1_ax3_0_1_ax4_1_fused_1_fused % T.int64(8) // T.int64(4) * T.int64(16) + ax3_1_fused_0_ax3_1_fused_1_fused) + v_ax4 = T.axis.spatial(T.int64(4), ax2_1_ax3_0_1_ax4_1_fused_0_ax2_1_ax3_0_1_ax4_1_fused_1_fused % T.int64(4)) + T.reads(p0[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3]) + T.writes(p0_shared[v_ax0, v_ax1, v_ax4, v_ax2, v_ax3]) + p0_shared[v_ax0, v_ax1, v_ax4, v_ax2, v_ax3] = p0[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3] + for ax0_ax1_ax2_fused_0 in range(T.int64(16)): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"): + with T.block("T_layout_trans"): + v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_ax1 = T.axis.spatial(T.int64(8), ax0_ax1_ax2_0_ax4_0_ax3_0_0_fused // T.int64(16)) + v_ax2 = T.axis.spatial(T.int64(32), ax0_ax1_ax2_0_ax4_0_ax3_0_0_fused % T.int64(16) * T.int64(2) + (ax0_ax1_ax2_fused_0 * T.int64(16) + ax0_ax1_ax2_fused_1) // T.int64(128)) + v_ax3 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(16) + ax0_ax1_ax2_fused_1) % T.int64(128) // T.int64(4)) + v_ax4 = T.axis.spatial(T.int64(4), (ax0_ax1_ax2_fused_0 * T.int64(16) + ax0_ax1_ax2_fused_1) % T.int64(4)) + T.reads(p0_shared[v_ax0, v_ax1, v_ax4, v_ax2, v_ax3], p1[()]) + T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T.block_attr({"dst_layout": "NCHW4c", "input_shape": [1, 32, 32, 32], "schedule_rule": "layout_transform", "src_layout": "NCHW"}) + T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.if_then_else(v_ax0 < T.int64(1) and v_ax1 * T.int64(4) + v_ax4 < T.int64(32) and v_ax2 < T.int64(32) and v_ax3 < T.int64(32), p0_shared[v_ax0, v_ax1, v_ax4, v_ax2, v_ax3], T.float16(0)) + p1[()] + self.assert_extracted_equals_expected(mod, ExpectedModule, 16) + + +if __name__ == 
"__main__": + tvm.testing.main()