41 changes: 26 additions & 15 deletions src/arith/iter_affine_map.cc
@@ -868,10 +868,10 @@ class IterMapRewriter : public ExprMutator {
IterSumExpr structured_form = expr, flattened_form = expr;
flattened_form.CopyOnWrite()->args =
Array<IterSplitExpr>(flattened_iters.rbegin(), flattened_iters.rend());
flattened_form.CopyOnWrite()->base = 0;
flattened_form.CopyOnWrite()->base = make_const(expr.dtype(), 0);
structured_form.CopyOnWrite()->args =
Array<IterSplitExpr>(grouped_iters.rbegin(), grouped_iters.rend());
structured_form.CopyOnWrite()->base = 0;
structured_form.CopyOnWrite()->base = make_const(expr.dtype(), 0);
auto it = sum_fuse_map_.find(flattened_form);
if (it != sum_fuse_map_.end()) {
// old iter
@@ -1831,11 +1831,20 @@ class SubspaceDivider {
IterSplitExpr GetInnerAsSplit() const { return GetAsSplit(inner, inner_extent); }

static DivisionResult Inner(const IterMapExpr& iter, const PrimExpr& extent) {
return DivisionResult(IterSumExpr({}, 0), 1, iter, extent);
auto dtype = iter.dtype();
return DivisionResult(IterSumExpr({}, make_const(dtype, 0)), make_const(dtype, 1), iter,
extent);
}

static DivisionResult Outer(const IterMapExpr& iter, const PrimExpr& extent) {
return DivisionResult(iter, extent, IterSumExpr({}, 0), 1);
auto dtype = iter.dtype();
return DivisionResult(iter, extent, IterSumExpr({}, make_const(dtype, 0)),
make_const(dtype, 1));
}

// Special value to indicate the division is not possible
static DivisionResult Failure() {
return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0);
}

private:
@@ -1853,14 +1862,16 @@ class SubspaceDivider {

// Divide an IterSumExpr
DivisionResult DivideIterSumExpr(const IterSumExpr& expr, const PrimExpr& mark_extent) {
auto dtype = expr.dtype();
if (expr->args.empty()) {
// base
return DivisionResult(IterSumExpr({}, 0), 1, IterSumExpr({}, expr->base), 1);
return DivisionResult(IterSumExpr({}, make_const(dtype, 0)), make_const(dtype, 1),
IterSumExpr({}, expr->base), make_const(dtype, 1));
} else if (expr->args.size() == 1) {
// arg + base, if arg=Y*E(X)+X, then arg+base = Y*E(X)+(X+base)
if (!is_one(expr->args[0]->scale)) {
unresolved_count_++;
return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0);
return DivisionResult::Failure();
}
DivisionResult res = DivideIterSplitExpr(expr->args[0]);
if (!is_zero(expr->base)) res = AddBase(res, expr->base);
@@ -1869,7 +1880,7 @@ class SubspaceDivider {
// arg1 + arg2 + ... + argn + base
// then we can write it as Y*E(X)+X
// if it starts with contiguous outer splits, followed by contiguous inner splits
PrimExpr extent = 1;
PrimExpr extent = make_const(dtype, 1);
std::vector<IterSplitExpr> outer_args, inner_args;
bool inner = true, scale_is_one = false;
// we check in inverse order so we can visit from inner to outer
@@ -1881,7 +1892,7 @@ class SubspaceDivider {
if (arg_division.IsInner()) {
if (!inner) {
unresolved_count_++;
return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0);
return DivisionResult::Failure();
}
new_arg = arg_division.GetInnerAsSplit();
inner_args.push_back(new_arg);
@@ -1892,13 +1903,13 @@ class SubspaceDivider {
inner = false;
} else {
unresolved_count_++;
return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0);
return DivisionResult::Failure();
}
extent *= new_arg->extent;
}
if (!scale_is_one) {
unresolved_count_++;
return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0);
return DivisionResult::Failure();
}
bool need_predicate = !analyzer_->CanProveEqual(extent, mark_extent);
const IterMark& outer_mark = MarkFromArgsAndBase(outer_args, 0);
@@ -1919,7 +1930,7 @@ class SubspaceDivider {
return DivisionResult::Inner(inner_source, mark_extent);
} else {
unresolved_count_++;
return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0);
return DivisionResult::Failure();
}
}
return DivisionResult(outer_source, outer_mark->extent, inner_source, inner_mark->extent);
@@ -1943,7 +1954,7 @@ class SubspaceDivider {
// args are sorted from inner to outer
static IterMark MarkFromArgsAndBase(const std::vector<IterSplitExpr>& args, PrimExpr base) {
std::vector<IterSplitExpr> res;
PrimExpr extent = 1;
PrimExpr extent = make_const(base.dtype(), 1);
for (const IterSplitExpr& it : args) {
IterSplitExpr arg = it;
arg.CopyOnWrite()->scale = extent;
@@ -2006,7 +2017,7 @@ class SubspaceDivider {
}
if (j == splits.size()) {
unresolved_count_++;
return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0);
return DivisionResult::Failure();
}
used[j] = true;
if (!encountered_boundary) {
@@ -2020,7 +2031,7 @@ class SubspaceDivider {
}
if (!encountered_boundary) {
unresolved_count_++;
return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0);
return DivisionResult::Failure();
}
for (const IterSplitExpr& inner_iter : inner_iters) {
IterSplitExpr new_iter = inner_iter;
@@ -2036,7 +2047,7 @@ class SubspaceDivider {
}
} else {
unresolved_count_++;
return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0);
return DivisionResult::Failure();
}
return split_map_.at(expr);
}
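
Context for the iter_affine_map.cc changes above: a PrimExpr built from a bare C++ integer literal defaults to int32, so using 0 or 1 as the base or extent of an IterSumExpr mixes int32 constants into otherwise int64 expressions. The rewrite builds those constants with make_const in the dtype of the surrounding expression and routes every "division impossible" return through the new DivisionResult::Failure() helper. A minimal sketch of the dtype behaviour (not from this diff; it assumes make_const from tvm/tir/op.h and Var from tvm/tir/var.h; the function name is illustrative only):

// Sketch only: the int32 default for bare literals vs. make_const.
#include <tvm/tir/op.h>
#include <tvm/tir/var.h>

void DtypeOfConstants() {
  using namespace tvm;
  tir::Var i("i", DataType::Int(64));          // e.g. an int64 loop variable
  PrimExpr bare = 0;                           // IntImm with dtype int32
  PrimExpr typed = make_const(i.dtype(), 0);   // IntImm with dtype int64
  // Sums built from `i` keep a single dtype only if their constants are
  // created like `typed`; that is what the rewritten bases above rely on.
  PrimExpr sum = i + typed;
  (void)bare;
  (void)sum;
}
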
12 changes: 7 additions & 5 deletions src/tir/schedule/ir_comparator.cc
@@ -72,9 +72,9 @@ bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) {
}

bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) {
bool equal =
n.same_as(other) || ((n->type_index() == other->type_index()) && n->dtype == other->dtype &&
ExprComparator::VisitExpr(n, other));
bool equal = n.same_as(other) ||
((n->type_index() == other->type_index()) &&
n.dtype().code() == other.dtype().code() && ExprComparator::VisitExpr(n, other));
if (!equal && assert_mode_) {
std::ostringstream os;
os << "Expression mismatch: " << n << " vs " << other;
@@ -185,7 +185,7 @@ bool TensorizeComparator::VisitExpr_(const VarNode* op, const PrimExpr& other) {
const auto* rhs = other.as<VarNode>();
auto lhs = GetRef<Var>(op);
if (lhs.same_as(other)) return true;
if (op->dtype != rhs->dtype) return false;
if (op->dtype.code() != rhs->dtype.code()) return false;
auto it = equal_map_.find(lhs);
return it != equal_map_.end() && it->second.same_as(other);
}
@@ -208,7 +208,9 @@ bool TensorizeComparator::DefEqual(const Var& lhs, const Var& rhs) {
if (it != equal_map_.end()) return it->second.same_as(rhs);
// Otherwise remap lhs to rhs
equal_map_[lhs] = rhs;
analyzer_.Bind(lhs, rhs);
// Cast if necessary. This allows the workload and the tensor intrin to have different dtypes in
// the indices.
analyzer_.Bind(lhs, cast(lhs.dtype(), rhs));
return true;
}

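
Context for the ir_comparator.cc changes above: expressions and variables are now matched on the dtype code (int vs. float vs. handle) rather than the full dtype, so an int64 index on the workload side can still compare against an int32 index in the tensor intrin description, and DefEqual binds the left-hand variable to a cast of the right-hand one so the arithmetic analyzer sees a consistent type. A small sketch of that binding (not from this diff; it assumes arith::Analyzer from tvm/arith/analyzer.h and tvm::cast from tvm/tir/op.h; the function name is illustrative only):

#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
#include <tvm/tir/var.h>

void BindAcrossDtypes() {
  using namespace tvm;
  arith::Analyzer analyzer;
  tir::Var workload_iv("vi", DataType::Int(64));  // index var on the workload side
  tir::Var intrin_iv("v", DataType::Int(32));     // index var on the intrin side
  // Both dtypes share DataType::kInt as their type code, so the comparator
  // accepts the pair; the bind still inserts an explicit cast to int64.
  analyzer.Bind(workload_iv, cast(workload_iv.dtype(), intrin_iv));
}
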
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive/blockize_tensorize.cc
@@ -572,7 +572,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int
}
for (int i = 0; i < static_cast<int>(old_region.size()); i++) {
PrimExpr min = indices_base[i + offset];
PrimExpr extent = old_region[i]->extent;
PrimExpr extent = cast(min.dtype(), old_region[i]->extent);
new_region.push_back(Range::FromMinExtent(min, extent));
}
match_buffer_regions.push_back(MatchBufferRegion(impl, BufferRegion(cur, new_region)));
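
Context for the blockize_tensorize.cc change above: the min of each matched region comes from indices_base, which may now be int64, while the old region's extent may still be int32; the two ends of a Range are expected to share a dtype, so the extent is cast to follow the min. A one-function sketch (not from this diff; it assumes Range::FromMinExtent from tvm/ir/expr.h and tvm::cast; the function name is illustrative only):

#include <tvm/ir/expr.h>
#include <tvm/tir/op.h>

// Build one dimension of a buffer region with a dtype-consistent Range.
tvm::Range RegionDim(const tvm::PrimExpr& min, const tvm::PrimExpr& extent) {
  return tvm::Range::FromMinExtent(min, tvm::cast(min.dtype(), extent));
}
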
11 changes: 7 additions & 4 deletions src/tir/schedule/primitive/layout_transformation.cc
@@ -369,7 +369,8 @@ class TransformLayoutPlanner : private StmtExprVisitor {
ss << "v_" << var->name_hint;
Var virtual_var(ss.str(), var.dtype());
new_iter_values.push_back(var);
new_iter_vars.push_back(IterVar(Range::FromMinExtent(0, dim), virtual_var, kDataPar));
new_iter_vars.push_back(
IterVar(Range::FromMinExtent(make_zero(dim.dtype()), dim), virtual_var, kDataPar));
new_access_indices.push_back(virtual_var);
loop_var_to_virtual_var.Set(var, virtual_var);
}
@@ -990,7 +991,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
auto [inverse, padding_predicate] = [&]() {
Array<Range> region;
for (const auto& dim : old_buffer->shape) {
region.push_back(Range::FromMinExtent(0, dim));
region.push_back(Range::FromMinExtent(make_zero(dim.dtype()), dim));
}
return index_map.NonSurjectiveInverse(region);
}();
@@ -1209,8 +1210,10 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
if (iter_type == kOpaque) {
throw OpaqueNewIterTypeError(self->mod, GetRef<Block>(block_ptr), transformed_block_iters[i]);
}
new_block_iters.push_back(IterVar(/*dom=*/Range::FromMinExtent(0, new_block_iter_range[i]),
/*var=*/std::move(new_block_var), /*iter_type=*/iter_type));
auto dtype = new_block_var.dtype();
new_block_iters.push_back(IterVar(
/*dom=*/Range::FromMinExtent(make_zero(dtype), new_block_iter_range[i]),
/*var=*/std::move(new_block_var), /*iter_type=*/iter_type));
}

// Step 5.2: Update the block body. Use the inverse map f^{-1} to replace the original block iters
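
Context for the layout_transformation.cc changes above: iteration domains are built as [0, dim), and the zero must carry the same dtype as the extent, so the bare literal 0 is replaced by make_zero(dim.dtype()) (and by make_zero of the new block var's dtype for the transformed block iters). A sketch of the pattern (not from this diff; it assumes make_zero from tvm/tir/op.h; the function name is illustrative only):

#include <tvm/ir/expr.h>
#include <tvm/tir/op.h>

// Iteration domain [0, dim) whose minimum matches the extent's dtype.
tvm::Range IterDomain(const tvm::PrimExpr& dim) {
  return tvm::Range::FromMinExtent(tvm::make_zero(dim.dtype()), dim);
}
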
35 changes: 35 additions & 0 deletions tests/python/unittest/test_tir_schedule_blockize.py
@@ -247,5 +247,40 @@ def after_rowsum_blockize(
verify_trace_roundtrip(sch=s, mod=rowsum)


def test_blockize_outer_int64_shape():
@T.prim_func
def single_elementwise_int64(
A: T.Buffer[(T.int64(16), T.int64(128)), "float32"],
B: T.Buffer[(T.int64(16), T.int64(128)), "float32"],
) -> None:
for i0, j0, i1, j1 in T.grid(T.int64(1), T.int64(8), T.int64(16), T.int64(16)):
with T.block("B"):
vi = T.axis.S(T.int64(16), i0 * T.int64(16) + i1)
vj = T.axis.S(T.int64(128), j0 * T.int64(16) + j1)
B[vi, vj] = A[vi, vj] + 1.0

@T.prim_func
def after_single_elementwise_int64_blockize(
A: T.Buffer[(T.int64(16), T.int64(128)), "float32"],
B: T.Buffer[(T.int64(16), T.int64(128)), "float32"],
) -> None:
for i0, j0 in T.grid(T.int64(1), T.int64(8)):
with T.block("B_o"):
vi_o = T.axis.spatial(T.int64(1), T.int64(0))
vj_o = T.axis.spatial(T.int64(8), j0)
for i1, j1 in T.grid(T.int64(16), T.int64(16)):
with T.block("B"):
vi_i, vj_i = T.axis.remap("SS", [i1, j1])
B[vi_i, vj_o * T.int64(16) + vj_i] = A[
vi_i, vj_o * T.int64(16) + vj_i
] + T.float32(1)

s = tir.Schedule(single_elementwise_int64, debug_mask="all")
_, _, i1, _ = s.get_loops(s.get_block("B"))
s.blockize(i1)
tvm.ir.assert_structural_equal(s.mod["main"], after_single_elementwise_int64_blockize)
verify_trace_roundtrip(sch=s, mod=single_elementwise_int64)


if __name__ == "__main__":
tvm.testing.main()
85 changes: 85 additions & 0 deletions tests/python/unittest/test_tir_schedule_tensorize.py
@@ -653,5 +653,90 @@ def test_tensor_intrin_look_up():
tir.TensorIntrin.get(intrin_name)


def test_tensorize_matmul_mixed_dtype():
# fmt: off
@T.prim_func
def matmul_int64_shape(
A: T.Buffer[(T.int64(128), T.int64(128)), "float32"],
B: T.Buffer[(T.int64(128), T.int64(128)), "float32"],
C: T.Buffer[(T.int64(128), T.int64(128)), "float32"]
) -> None:
for i_0, j_0 in T.grid(T.int64(8), T.int64(8)):
for i_1_init, j_1_init in T.grid(T.int64(16), T.int64(16)):
with T.block("init"):
vi = T.axis.spatial(T.int64(128), i_0 * T.int64(16) + i_1_init)
vj = T.axis.spatial(T.int64(128), j_0 * T.int64(16) + j_1_init)
C[vi, vj] = T.float32(0)
for k_0, i_1, j_1, k_1 in T.grid(T.int64(8), T.int64(16), T.int64(16), T.int64(16)):
with T.block("update"):
vi = T.axis.spatial(T.int64(128), i_0 * T.int64(16) + i_1)
vj = T.axis.spatial(T.int64(128), j_0 * T.int64(16) + j_1)
vk = T.axis.reduce(T.int64(128), k_0 * T.int64(16) + k_1)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

@T.prim_func
def tensorized_matmul_int64_shape(
A: T.Buffer[(T.int64(128), T.int64(128)), "float32"],
B: T.Buffer[(T.int64(128), T.int64(128)), "float32"],
C: T.Buffer[(T.int64(128), T.int64(128)), "float32"]
) -> None:
for i_outer, j_outer in T.grid(T.int64(8), T.int64(8)):
for i_inner_init, j_inner_init in T.grid(T.int64(16), T.int64(16)):
with T.block("init"):
vi = T.axis.spatial(T.int64(128), i_outer * T.int64(16) + i_inner_init)
vj = T.axis.spatial(T.int64(128), j_outer * T.int64(16) + j_inner_init)
C[vi, vj] = T.float32(0)
for k_outer in T.grid(T.int64(8)):
with T.block("update"):
vi, vj, vk = T.axis.remap("SSR", [i_outer, j_outer, k_outer])
T.reads(
[
C[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vj * T.int64(16) : vj * T.int64(16) + T.int64(16)],
A[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vk * T.int64(16) : vk * T.int64(16) + T.int64(16)],
B[vj * T.int64(16) : vj * T.int64(16) + T.int64(16), vk * T.int64(16) : vk * T.int64(16) + T.int64(16)],
]
)
T.writes(C[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vj * T.int64(16) : vj * T.int64(16) + T.int64(16)])
A_elem_offset = T.var("int32")
B_elem_offset = T.var("int32")
C_elem_offset = T.var("int32")
A_sub = T.match_buffer(
A[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vk * T.int64(16) : vk * T.int64(16) + T.int64(16)],
[16, 16],
elem_offset=A_elem_offset,
)
B_sub = T.match_buffer(
B[vj * T.int64(16) : vj * T.int64(16) + T.int64(16), vk * T.int64(16) : vk * T.int64(16) + T.int64(16)],
[16, 16],
elem_offset=B_elem_offset,
)
C_sub = T.match_buffer(
C[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vj * T.int64(16) : vj * T.int64(16) + T.int64(16)],
[16, 16],
elem_offset=C_elem_offset,
)
T.evaluate(
T.tvm_mma_sync(
C_sub.data,
T.floordiv(C_sub.elem_offset, 256),
A_sub.data,
T.floordiv(A_sub.elem_offset, 256),
B_sub.data,
T.floordiv(B_sub.elem_offset, 256),
C_sub.data,
T.floordiv(C_sub.elem_offset, 256),
dtype="handle",
)
)
# fmt: on

s = tir.Schedule(matmul_int64_shape, debug_mask="all")
update = s.get_block("update")
ii = s.get_loops(update)[-3]
s.tensorize(ii, "test_mma_intrin")
tvm.ir.assert_structural_equal(s.mod["main"], tensorized_matmul_int64_shape)
verify_trace_roundtrip(sch=s, mod=matmul_int64_shape)


if __name__ == "__main__":
tvm.testing.main()