Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1739,7 +1739,9 @@ class IterMapToExprNormalizer : public ExprMutator {
bool IterMapRewriter::CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs) {
const auto* clhs = lhs.as<IntImmNode>();
const auto* crhs = rhs.as<IntImmNode>();
if (clhs && crhs) {
if (crhs && crhs->value == 0) {
return false;
} else if (clhs && crhs) {
return clhs->value % crhs->value == 0;
}

Expand Down
33 changes: 20 additions & 13 deletions src/tir/ir/index_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,21 @@ IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(A
return IndexMap(initial_indices, func(initial_indices), std::move(inverse_index_map));
}

std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initial_ranges) const {
if ((*this)->inverse_index_map.defined()) {
std::pair<IndexMap, PrimExpr> IndexMapInverseImpl(const IndexMap& self,
const Array<Range>& initial_ranges,
arith::IterMapLevel check_level) {
if (self->inverse_index_map.defined()) {
// return the pre-defined inverse index map if exists. In this
// case, the user-defined inverse is assumed to be correct and
// bijective.
PrimExpr padding_predicate = Bool(false);
return {Downcast<IndexMap>((*this)->inverse_index_map.value()), padding_predicate};
return {Downcast<IndexMap>(self->inverse_index_map.value()), padding_predicate};
}

// Dummy variables to represent the inverse's inputs.
Array<Var> output_vars;
for (size_t i = 0; i < (*this)->final_indices.size(); i++) {
PrimExpr index = (*this)->final_indices[i];
for (size_t i = 0; i < self->final_indices.size(); i++) {
PrimExpr index = self->final_indices[i];
// TODO(Lunderberg): Better names for these variables. A variable
// that is passed through unmodified (`index` is an element of
// `initial_indices`) should use that input index's name. A pair
Expand All @@ -79,16 +81,16 @@ std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initia

// Dummy ranges for the extent of each input.
Map<Var, Range> input_iters;
ICHECK_EQ((*this)->initial_indices.size(), initial_ranges.size());
ICHECK_EQ(self->initial_indices.size(), initial_ranges.size());
for (size_t i = 0; i < initial_ranges.size(); i++) {
input_iters.Set((*this)->initial_indices[i], initial_ranges[i]);
input_iters.Set(self->initial_indices[i], initial_ranges[i]);
}

// Unpack the output indices into linear combinations of the initial
// indices.
arith::Analyzer analyzer;
auto padded_iter_map = DetectIterMap((*this)->final_indices, input_iters, /* predicate = */ 1,
/*check_level=*/arith::IterMapLevel::NoCheck, &analyzer,
auto padded_iter_map = DetectIterMap(self->final_indices, input_iters, /* predicate = */ 1,
/*check_level=*/check_level, &analyzer,
/*simplify_trivial_iterators=*/false);
CHECK(padded_iter_map->errors.empty()) << "Could not parse mapping as sum of iterators. "
<< "Error: " << padded_iter_map->errors[0];
Expand All @@ -100,8 +102,8 @@ std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initia

// Unpack the map to an array, maintaining the same parameter order.
Array<PrimExpr> inverse_exprs;
for (int i = 0, n = (*this)->initial_indices.size(); i < n; ++i) {
Var index = (*this)->initial_indices[i];
for (int i = 0, n = self->initial_indices.size(); i < n; ++i) {
Var index = self->initial_indices[i];
PrimExpr expr;
if (is_one(initial_ranges[i]->extent) && !inverse_exprs_map.count(index)) {
expr = initial_ranges[i]->min;
Expand All @@ -116,7 +118,7 @@ std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initia
padding_predicate = Substitute(padding_predicate, inverse_exprs_map);

{
auto output_ranges = (*this)->MapRanges(initial_ranges);
auto output_ranges = self->MapRanges(initial_ranges);
ICHECK_EQ(output_ranges.size(), output_vars.size());

arith::Analyzer analyzer;
Expand All @@ -131,8 +133,13 @@ std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initia
return {IndexMap(output_vars, inverse_exprs), padding_predicate};
}

std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initial_ranges) const {
return IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::NoCheck);
}

IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
auto [inverse, padding_predicate] = NonSurjectiveInverse(std::move(initial_ranges));
auto [inverse, padding_predicate] =
IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::Bijective);
arith::Analyzer analyzer;
CHECK(analyzer.CanProve(!padding_predicate))
<< "Bijective inverse should not contain padding, but inverse of " << *this << " over range "
Expand Down
13 changes: 13 additions & 0 deletions tests/python/unittest/test_transform_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,5 +575,18 @@ def test_size_one_buffer(shape, transform):
s[B].transform_layout(transform)


def test_non_divisible_transform_raises_error():
A = te.placeholder([1, 3, 8, 8])
B = te.compute(A.shape, lambda *indices: A[indices])
s = te.create_schedule(B.op)

transform = lambda n, c, h, w: [n, c // 4, h, w, c % 4]
# Error occurs here, because the transformation would introduce
# padding. Padded transforms are supported in TIR-based
# schedules.
with pytest.raises(tvm.TVMError):
s[B].transform_layout(transform)


Comment on lines +578 to +590
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Probably we can joint this test case with test_size_one_buffer? Just extend its testing parameters

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could, but I think it would make the test less readable as an example use case, specifically what behavior is being tested, because the desired behavior differs in each case. It would look something like below, but there's nothing to call attention to the fact that is_valid changes the expected behavior, and isn't just a parameter being used in the setup.

shape, transform, is_valid = tvm.testing.parameters(
    ([1, 8], lambda n, i: [i, n], True),
    ([1, 1, 8], lambda i, j, k: [j, te.AXIS_SEPARATOR, i, k], True),
    ([1, 1, 8], lambda i, j, k: [i, te.AXIS_SEPARATOR, j, k], True),
    ([1, 3, 8, 8], lambda i, j, k: [i, te.AXIS_SEPARATOR, j, k], False),
)


def test_transform_validity(shape, transform, is_valid):
    dtype = "int8"
    A = te.placeholder(shape, dtype, name="A")
    B = te.compute(
        shape=A.shape,
        fcompute=lambda *indices: A[indices].astype(dtype),
        name="B",
    )
    s = te.create_schedule(B.op)

    if is_valid:
        s[B].transform_layout(transform)
    else:
        with pytest.raises(tvm.TVMError):
            s[B].transform_layout(transform)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. Thank you. I agree.

if __name__ == "__main__":
tvm.testing.main()