Skip to content
73 changes: 70 additions & 3 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,73 @@ class IterSumExpr : public IterMapExpr {
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer, bool simplify_trivial_iterators = true);

/*! \brief A utility struct for return values from DetectPaddedIterMap
*/
struct PaddedIterMapResult {
// Any errors that occurred while converting the input indices. If
// the array is empty, the conversion was successful.
Array<String> errors;

// The detected pattern if a match exists.
Array<IterSumExpr> indices;

/* \brief Boolean expression indicating if padding was required
*
* `requires_padding` evaluates to true if the returned indices
* contain padding relative to the provided expressions, and false
* otherwise. If `input_iters` contains a variable extent, this
* expression may be in terms of those variables.
*/
PrimExpr requires_padding;

/* \brief Boolean expression indicating if a specific value w
*
* `padding_predicate` evaluates to true for a set of indices that
* are outside the bounds of the provided index iterators, but
* inside the bounds of the returned index iterators. This
* expression is in terms of the variables provided in
* `input_iters`.
*/
PrimExpr padding_predicate;
};

/*!
* \brief Detect if indices can be written as
* [y_0 + c_0, y_1 + c_1, ..., y_n + c_n]
*
* Here y = some-quasi-affine-iter-map(input_iters) and c are
* symbolic constants. The y_i iterators may be padded to fit this
* representation.
*
* We also requires that y_i and y_j to be independent for i != j.
*
* For returned value rv, the following is always true:
* - rv.indices[i]->args.size() <=1: only one iterator per element.
*
* \param indices The indices to detect pattern for.
*
* \param input_iters Map from variable to iterator's range.
*
* \param predicate The predicate constraints on the input iterators
*
* \param require_bijective A boolean flag that indicates whether the
* mapping should be bijective. If true, no padding may be
* introduced.
*
* \param analyzer Analyzer used to get context information.
*
* \param simplify_trivial_iterators If true, iterators with extent of
* 1 will be replaced with a constant value.
*
* \return An instance of PaddedIterMapResult.
*/
PaddedIterMapResult DetectPaddedIterMap(const Array<PrimExpr>& indices,
const Map<Var, Range>& input_iters,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer,
bool simplify_trivial_iterators = true);

/*!
* \brief Use IterVarMap detector to rewrite and simplify the indices
*
Expand Down Expand Up @@ -352,11 +419,11 @@ Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
bool require_bijective, arith::Analyzer* analyzer);

/*!
* \brief Given an IterMapExpr, transform it to normal PrimExpr.
* \param expr The input IterMapExpr.
* \brief Given an expression that may contain IterMapExpr, transform it to normal PrimExpr.
* \param expr The input expression, which may contain IterMapExpr.
* \return The corresponding normal PrimExpr.
*/
PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr);
PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr);

} // namespace arith
} // namespace tvm
Expand Down
20 changes: 17 additions & 3 deletions include/tvm/tir/index_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
#include <tvm/runtime/object.h>
#include <tvm/tir/var.h>

#include <utility>

namespace tvm {
namespace tir {

Expand Down Expand Up @@ -141,12 +143,24 @@ class IndexMap : public ObjectRef {
*
* TODO(Lunderberg): Look into allowing non-bijective
* transformations. If injective, the inverse mapping could still
* be generated with some predicate. If non-injective, could
* simplify the implementation of other optimizations (e.g. double
* buffering as a map `lambda *indices: [buffer_loop%2, *indices]`).
* be generated with some predicate (see NonSurjectiveInverse). If
* non-injective, could simplify the implementation of other
* optimizations (e.g. double buffering as a map `lambda *indices:
* [buffer_loop%2, *indices]`).
*/
IndexMap Inverse(Array<Range> initial_ranges) const;

/*! \brief Generate the inverse mapping.
*
* Determine the inverse, where the output range may contain
* addresses that do not correspond to an address in the input
* range.
*
* \return The inverted index map, along with the predicate for
* which the inverse maps to a valid range.
*/
std::pair<IndexMap, PrimExpr> NonSurjectiveInverse(Array<Range> initial_ranges) const;

TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode);
};

Expand Down
110 changes: 107 additions & 3 deletions python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
# under the License.
"""Function data types."""

from typing import Callable, List, Mapping, Optional, Union
from typing import Callable, List, Mapping, Optional, Union, Tuple
import inspect

import tvm
import tvm._ffi
import tvm.runtime
from tvm.runtime import Object
from tvm.ir import BaseFunc
from tvm.ir import BaseFunc, Range
from .buffer import Buffer
from .expr import Var, PrimExpr
from . import _ffi_api
Expand Down Expand Up @@ -296,12 +297,42 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None):
final_indices = mapping_function(*args)
return IndexMap(args, final_indices)

def is_equivalent_to(self, other_map: "IndexMap") -> bool:
"""Return if the index maps are equivalent.

Parameters
----------
other_map: IndexMap

The IndexMap to which the comparison should be made.

Returns
-------
is_equivalent: bool

True if the two mappings represent the same
transformation, otherwise False
"""
if len(self.initial_indices) != len(other_map.initial_indices):
return False
if len(self.final_indices) != len(other_map.final_indices):
return False

analyzer = tvm.arith.Analyzer()

mapped_other_final_indices = other_map.map_indices(self.initial_indices)
for self_index, other_index in zip(self.final_indices, mapped_other_final_indices):
if not analyzer.can_prove_equal(self_index, other_index):
return False

return True

def map_indices(self, indices: List[PrimExpr]) -> List[PrimExpr]:
"""Apply the index map to a set of indices

Parameters
----------
indices : List[PriExpr]
indices : List[PrimExpr]
The indices to be mapped

Returns
Expand All @@ -310,3 +341,76 @@ def map_indices(self, indices: List[PrimExpr]) -> List[PrimExpr]:
The mapped indices
"""
return _ffi_api.IndexMapMapIndices(self, indices)

def map_shape(self, shape: List[PrimExpr]) -> List[PrimExpr]:
"""Apply the index map to a buffer shape

Parameters
----------
shape : List[PrimExpr]
The buffer shape to be mapped

Returns
-------
result : List[PrimExpr]
The mapped shape
"""
return _ffi_api.IndexMapMapShape(self, shape)

def inverse(self, shape: List[Union[Range, PrimExpr]]) -> "IndexMap":
"""Return the inverse of the map

Throws an error if the function is not bijective.

Parameters
----------
shape: List[Union[Range,PrimExpr]]

The region over which the inverse should be determined.
Used for validating that the mapping is bijective over
this range.

Returns
-------
inverse : IndexMap

The inverse
"""

shape = [dim if isinstance(dim, Range) else Range(0, dim) for dim in shape]
return _ffi_api.IndexMapInverse(self, shape)

def non_surjective_inverse(
self, shape: List[Union[Range, PrimExpr]]
) -> Tuple["IndexMap", PrimExpr]:
"""Return the inverse of the map

Can be applied to transformations that introduce padding.

Parameters
----------
shape: List[Union[Range,PrimExpr]]

The region over which the inverse should be determined.
Used for determining the predicate.

Returns
-------
result : Tuple[IndexMap, PrimExpr]

The inverse, and a predicate for which the inverse maps to
a valid index in the input range.

Examples
--------

.. code-block:: python

index_map = IndexMap.from_func(lambda i: [i//4, i%4])
inverse_map, predicate = index_map.non_surjective_inverse([14])
assert inverse_map.is_equivalent_to(IndexMap.from_func(lambda j,k: [4*j + k])
print(predicate) # Prints "(axis0==3) && (axis2 >= 2)"
"""

shape = [dim if isinstance(dim, Range) else Range(0, dim) for dim in shape]
return _ffi_api.IndexMapNonSurjectiveInverse(self, shape)
Loading