Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions src/tir/transforms/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -588,12 +588,47 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
}
}

bool all_singlepoints_outside = true;

// Check all partitions to see if they are single points and outside `for_interval`
for (const auto& partition : finder.partitions) {
const auto& intset = partition.second;
// Only proceed if the interval set is a single point
if (intset.IsSinglePoint()) {
auto single_point = intset.PointValue();
// Check if the single point is outside the `for_interval`
bool is_inside = analyzer_.CanProve(single_point >= for_interval.min()) &&
analyzer_.CanProve(single_point <= for_interval.max());
if (is_inside) {
// If any single point is inside, this is an error condition
LOG(ERROR) << "unexpected case happened.";
all_singlepoints_outside = false;
break;
}
} else {
// If there is any intset that is not a single point, follow default logic
// For now, we set all_singlepoints_outside to false to indicate default logic was used
all_singlepoints_outside = false;
break;
}
}

if (all_singlepoints_outside) {
// If all single points are outside `for_interval`, return a nothing interval and false
return {IntSet::Nothing(), ExpressionSet(), false};
}

// we couldn't find an interval in which the conditions are
// provably true or false. Therefore, we can't partition the loop
// based on those conds
return {{}, {}, std::nullopt};
}();

if (middle_interval.IsNothing() && opt_cond_value == false) {
// Return loop directly as it can be simplified.
return stmt;
}

if (!opt_cond_value.has_value()) {
if (has_partition_hint_ && unroll_loop_with_partition_hint_no_interval_ &&
analyzer_.CanProve(max - min > 0)) {
Expand Down
228 changes: 228 additions & 0 deletions tests/python/tir-transform/test_tir_transform_loop_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tvm
import tvm.testing
from tvm import te
Expand Down Expand Up @@ -834,5 +835,232 @@ def after(
assert tvm.ir.structural_equal(mod["main"], after.with_attr("global_symbol", "main"))


@T.prim_func
def concat_func_single_point(
placeholder: T.Buffer((28, 64), "int8"),
placeholder_1: T.Buffer((28, 1), "int8"),
placeholder_2: T.Buffer((28, 63), "int8"),
T_concat: T.Buffer((28, 128), "int8"),
) -> None:
for i0 in range(28):
for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}):
if i1 > 63:
T_concat[i0, i1] = placeholder[i0, i1 - 64]
elif i1 == 63:
T_concat[i0, i1] = placeholder_1[i0, i1 - 63]
else:
T_concat[i0, i1] = placeholder_2[i0, i1]


@T.prim_func
def expected_partitioned_concat_single_point(
placeholder: T.Buffer((28, 64), "int8"),
placeholder_1: T.Buffer((28, 1), "int8"),
placeholder_2: T.Buffer((28, 63), "int8"),
T_concat: T.Buffer((28, 128), "int8"),
):
for i0 in range(28):
T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data)
for i1 in range(63):
placeholder_2_1 = T.Buffer((1764,), "int8", data=placeholder_2.data)
T_concat_1[i0 * 128 + i1] = placeholder_2_1[i0 * 63 + i1]
placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data)
T_concat_1[i0 * 128 + 63] = placeholder_1_1[i0]
for i1 in range(64):
placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data)
T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1]


@T.prim_func
def concat_func_start_point_equality(
placeholder: T.Buffer((28, 64), "int8"),
placeholder_1: T.Buffer((28, 1), "int8"),
placeholder_2: T.Buffer((28, 63), "int8"),
T_concat: T.Buffer((28, 128), "int8"),
) -> None:
for i0 in range(28):
for i1 in range(128, annotations={"pragma_loop_partition_hint": 1}):
if i1 == 0:
# Special case for i1 == 0
T_concat[i0, i1] = placeholder_1[i0, 0]
elif i1 < 64:
# Normal case for i1 in [1, 63]
T_concat[i0, i1] = placeholder_2[i0, i1]
else:
# Case for i1 in [64, 127]
T_concat[i0, i1] = placeholder[i0, i1 - 64]


@T.prim_func
def concat_func_start_point_equality_expected(
placeholder: T.Buffer((28, 64), "int8"),
placeholder_1: T.Buffer((28, 1), "int8"),
placeholder_2: T.Buffer((28, 63), "int8"),
T_concat: T.Buffer((28, 128), "int8"),
):
for i0 in range(28):
T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data)
placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data)
T_concat_1[i0 * 128] = placeholder_1_1[i0]
for i1 in range(63):
placeholder_2_1 = T.Buffer((1764,), "int8", data=placeholder_2.data)
T_concat_1[i0 * 128 + i1 + 1] = placeholder_2_1[i0 * 63 + i1 + 1]
for i1 in range(64):
placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data)
T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1]


@T.prim_func
def concat_func_end_point_equality(
placeholder: T.Buffer((28, 64), "int8"),
placeholder_1: T.Buffer((28, 1), "int8"),
placeholder_2: T.Buffer((28, 63), "int8"),
T_concat: T.Buffer((28, 128), "int8"),
) -> None:
for i0 in range(28):
for i1 in range(128, annotations={"pragma_loop_partition_hint": 1}):
if i1 == 127:
# Explicit equality check for the end point i1 == 127
T_concat[i0, i1] = placeholder_1[i0, 0]
elif i1 >= 64:
# Case for i1 in [64, 126]
T_concat[i0, i1] = placeholder[i0, i1 - 64]
else:
# Case for i1 in [0, 63]
T_concat[i0, i1] = placeholder_2[i0, i1]


@T.prim_func
def concat_func_end_point_equality_expected(
placeholder: T.Buffer((28, 64), "int8"),
placeholder_1: T.Buffer((28, 1), "int8"),
placeholder_2: T.Buffer((28, 63), "int8"),
T_concat: T.Buffer((28, 128), "int8"),
):
for i0 in range(28):
T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data)
for i1 in range(64):
placeholder_2_1 = T.Buffer((1764,), "int8", data=placeholder_2.data)
T_concat_1[i0 * 128 + i1] = placeholder_2_1[i0 * 63 + i1]
for i1 in range(63):
placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data)
T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1]
placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data)
T_concat_1[i0 * 128 + 127] = placeholder_1_1[i0]


@T.prim_func
def concat_func_edge_equalities(
placeholder: T.Buffer((28, 64), "int8"),
placeholder_1: T.Buffer((28, 1), "int8"),
placeholder_2: T.Buffer((28, 1), "int8"),
T_concat: T.Buffer((28, 66), "int8"),
) -> None:
for i0 in range(28):
for i1 in range(
66, annotations={"pragma_loop_partition_hint": 1}
): # Loop from 0 to 65 inclusive
if i1 == 0:
# Handle equality at the start of the range: i1 == 0
T_concat[i0, i1] = placeholder_2[i0, 0]
elif i1 == 65:
# Handle equality at the end of the range: i1 == 65
T_concat[i0, i1] = placeholder_1[i0, 0]
else:
# Copying from placeholder (from 0 to 63)
T_concat[i0, i1] = placeholder[i0, i1 - 1]


@T.prim_func
def concat_func_edge_equalities_expected(
placeholder: T.Buffer((28, 64), "int8"),
placeholder_1: T.Buffer((28, 1), "int8"),
placeholder_2: T.Buffer((28, 1), "int8"),
T_concat: T.Buffer((28, 66), "int8"),
):
for i0 in range(28):
T_concat_1 = T.Buffer((1848,), "int8", data=T_concat.data)
placeholder_2_1 = T.Buffer((28,), "int8", data=placeholder_2.data)
T_concat_1[i0 * 66] = placeholder_2_1[i0]
for i1 in range(64):
placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data)
T_concat_1[i0 * 66 + i1 + 1] = placeholder_3[i0 * 64 + i1]
placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data)
T_concat_1[i0 * 66 + 65] = placeholder_1_1[i0]


@T.prim_func
def concat_five_buffers_with_equalities(
buffer_a: T.Buffer((28, 1), "int8"), # Used for i1 == 0
buffer_b: T.Buffer((28, 63), "int8"), # Fills i1 from 1 to 63
buffer_c: T.Buffer((28, 1), "int8"), # Used for i1 == 64
buffer_d: T.Buffer((28, 63), "int8"), # Fills i1 from 65 to 128
buffer_e: T.Buffer((28, 1), "int8"), # Used for i1 == 129
T_concat: T.Buffer((28, 129), "int8"),
) -> None:
for i0 in range(28):
for i1 in range(130, annotations={"pragma_loop_partition_hint": 1}):
if i1 == 0:
T_concat[i0, i1] = buffer_a[i0, 0]
elif i1 == 64:
T_concat[i0, i1] = buffer_c[i0, 0]
elif i1 == 129:
T_concat[i0, i1] = buffer_e[i0, 0]
elif i1 < 64:
T_concat[i0, i1] = buffer_b[i0, i1 - 1]
else: # i1 > 64 and i1 < 128
T_concat[i0, i1] = buffer_d[i0, i1 - 65]


@T.prim_func
def concat_five_buffers_with_equalities_expected(
buffer_a: T.Buffer((28, 1), "int8"), # Used for i1 == 0
buffer_b: T.Buffer((28, 63), "int8"), # Fills i1 from 1 to 63
buffer_c: T.Buffer((28, 1), "int8"), # Used for i1 == 64
buffer_d: T.Buffer((28, 63), "int8"), # Fills i1 from 65 to 128
buffer_e: T.Buffer((28, 1), "int8"), # Used for i1 == 129
T_concat: T.Buffer((28, 129), "int8"),
):
for i0 in range(28):
T_concat_1 = T.Buffer((3612,), "int8", data=T_concat.data)
buffer_a_1 = T.Buffer((28,), "int8", data=buffer_a.data)
T_concat_1[i0 * 129] = buffer_a_1[i0]
for i1 in range(63):
buffer_b_1 = T.Buffer((1764,), "int8", data=buffer_b.data)
T_concat_1[i0 * 129 + i1 + 1] = buffer_b_1[i0 * 63 + i1]
buffer_c_1 = T.Buffer((28,), "int8", data=buffer_c.data)
T_concat_1[i0 * 129 + 64] = buffer_c_1[i0]
for i1 in range(64):
buffer_d_1 = T.Buffer((1764,), "int8", data=buffer_d.data)
T_concat_1[i0 * 129 + i1 + 65] = buffer_d_1[i0 * 63 + i1]
buffer_e_1 = T.Buffer((28,), "int8", data=buffer_e.data)
T_concat_1[i0 * 129 + 129] = buffer_e_1[i0]


@pytest.mark.parametrize(
"origin,expected",
[
(concat_func_single_point, expected_partitioned_concat_single_point),
(concat_func_start_point_equality, concat_func_start_point_equality_expected),
(concat_func_end_point_equality, concat_func_end_point_equality_expected),
(concat_func_edge_equalities, concat_func_edge_equalities_expected),
(concat_five_buffers_with_equalities, concat_five_buffers_with_equalities_expected),
],
)
def test_single_point_partition(origin, expected):
origin = origin.with_attr({"global_symbol": "main"})
expected = expected.with_attr({"global_symbol": "main"})
mod = partition_from_scheduled_tir(
origin,
{
"tir.LoopPartition": {
"partition_const_loop": True,
"unroll_loop_with_partition_hint_no_interval": True,
}
},
)
assert tvm.ir.structural_equal(mod["main"], expected)


if __name__ == "__main__":
tvm.testing.main()