From c20db0a1cd3115b76df24a3ad37cae5e9f7dcbc4 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 13 Oct 2022 15:04:29 -0700 Subject: [PATCH 1/3] [TIR] Fix handling of int64 extent in blockize --- src/arith/iter_affine_map.cc | 41 ++++++++++++------- .../unittest/test_tir_schedule_blockize.py | 35 ++++++++++++++++ 2 files changed, 61 insertions(+), 15 deletions(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 182eada24d96..93abf59aa7ed 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -868,10 +868,10 @@ class IterMapRewriter : public ExprMutator { IterSumExpr structured_form = expr, flattened_form = expr; flattened_form.CopyOnWrite()->args = Array(flattened_iters.rbegin(), flattened_iters.rend()); - flattened_form.CopyOnWrite()->base = 0; + flattened_form.CopyOnWrite()->base = make_const(expr.dtype(), 0); structured_form.CopyOnWrite()->args = Array(grouped_iters.rbegin(), grouped_iters.rend()); - structured_form.CopyOnWrite()->base = 0; + structured_form.CopyOnWrite()->base = make_const(expr.dtype(), 0); auto it = sum_fuse_map_.find(flattened_form); if (it != sum_fuse_map_.end()) { // old iter @@ -1829,11 +1829,20 @@ class SubspaceDivider { IterSplitExpr GetInnerAsSplit() const { return GetAsSplit(inner, inner_extent); } static DivisionResult Inner(const IterMapExpr& iter, const PrimExpr& extent) { - return DivisionResult(IterSumExpr({}, 0), 1, iter, extent); + auto dtype = iter.dtype(); + return DivisionResult(IterSumExpr({}, make_const(dtype, 0)), make_const(dtype, 1), iter, + extent); } static DivisionResult Outer(const IterMapExpr& iter, const PrimExpr& extent) { - return DivisionResult(iter, extent, IterSumExpr({}, 0), 1); + auto dtype = iter.dtype(); + return DivisionResult(iter, extent, IterSumExpr({}, make_const(dtype, 0)), + make_const(dtype, 1)); + } + + // Special value to indicate the division is not possible + static DivisionResult Failure() { + return DivisionResult(IterSumExpr({}, 0), 0, 
IterSumExpr({}, 0), 0); } private: @@ -1851,14 +1860,16 @@ class SubspaceDivider { // Divide an IterSumExpr DivisionResult DivideIterSumExpr(const IterSumExpr& expr, const PrimExpr& mark_extent) { + auto dtype = expr.dtype(); if (expr->args.empty()) { // base - return DivisionResult(IterSumExpr({}, 0), 1, IterSumExpr({}, expr->base), 1); + return DivisionResult(IterSumExpr({}, make_const(dtype, 0)), make_const(dtype, 1), + IterSumExpr({}, expr->base), make_const(dtype, 1)); } else if (expr->args.size() == 1) { // arg + base, if arg=Y*E(X)+X, then arg+base = Y*E(X)+(X+base) if (!is_one(expr->args[0]->scale)) { unresolved_count_++; - return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); + return DivisionResult::Failure(); } DivisionResult res = DivideIterSplitExpr(expr->args[0]); if (!is_zero(expr->base)) res = AddBase(res, expr->base); @@ -1867,7 +1878,7 @@ class SubspaceDivider { // arg1 + arg2 + ... + argn + base // then we can write it as Y*E(X)+X // if it starts with contiguous outer splits, followed by contiguous inner splits - PrimExpr extent = 1; + PrimExpr extent = make_const(dtype, 1); std::vector outer_args, inner_args; bool inner = true, scale_is_one = false; // we check in inverse order so we can visit from inner to outer @@ -1879,7 +1890,7 @@ class SubspaceDivider { if (arg_division.IsInner()) { if (!inner) { unresolved_count_++; - return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); + return DivisionResult::Failure(); } new_arg = arg_division.GetInnerAsSplit(); inner_args.push_back(new_arg); @@ -1890,13 +1901,13 @@ class SubspaceDivider { inner = false; } else { unresolved_count_++; - return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); + return DivisionResult::Failure(); } extent *= new_arg->extent; } if (!scale_is_one) { unresolved_count_++; - return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); + return DivisionResult::Failure(); } bool need_predicate = !analyzer_->CanProveEqual(extent, 
mark_extent); const IterMark& outer_mark = MarkFromArgsAndBase(outer_args, 0); @@ -1917,7 +1928,7 @@ class SubspaceDivider { return DivisionResult::Inner(inner_source, mark_extent); } else { unresolved_count_++; - return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); + return DivisionResult::Failure(); } } return DivisionResult(outer_source, outer_mark->extent, inner_source, inner_mark->extent); @@ -1941,7 +1952,7 @@ class SubspaceDivider { // args are sorted from inner to outer static IterMark MarkFromArgsAndBase(const std::vector& args, PrimExpr base) { std::vector res; - PrimExpr extent = 1; + PrimExpr extent = make_const(base.dtype(), 1); for (const IterSplitExpr& it : args) { IterSplitExpr arg = it; arg.CopyOnWrite()->scale = extent; @@ -2004,7 +2015,7 @@ class SubspaceDivider { } if (j == splits.size()) { unresolved_count_++; - return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); + return DivisionResult::Failure(); } used[j] = true; if (!encountered_boundary) { @@ -2018,7 +2029,7 @@ class SubspaceDivider { } if (!encountered_boundary) { unresolved_count_++; - return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); + return DivisionResult::Failure(); } for (const IterSplitExpr& inner_iter : inner_iters) { IterSplitExpr new_iter = inner_iter; @@ -2034,7 +2045,7 @@ class SubspaceDivider { } } else { unresolved_count_++; - return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); + return DivisionResult::Failure(); } return split_map_.at(expr); } diff --git a/tests/python/unittest/test_tir_schedule_blockize.py b/tests/python/unittest/test_tir_schedule_blockize.py index 6d13281320c0..12836cdb9e68 100644 --- a/tests/python/unittest/test_tir_schedule_blockize.py +++ b/tests/python/unittest/test_tir_schedule_blockize.py @@ -247,5 +247,40 @@ def after_rowsum_blockize( verify_trace_roundtrip(sch=s, mod=rowsum) +def test_blockize_outer_int64_shape(): + @T.prim_func + def single_elementwise_int64( + A: 
T.Buffer[(T.int64(16), T.int64(128)), "float32"], + B: T.Buffer[(T.int64(16), T.int64(128)), "float32"], + ) -> None: + for i0, j0, i1, j1 in T.grid(T.int64(1), T.int64(8), T.int64(16), T.int64(16)): + with T.block("B"): + vi = T.axis.S(T.int64(16), i0 * T.int64(16) + i1) + vj = T.axis.S(T.int64(128), j0 * T.int64(16) + j1) + B[vi, vj] = A[vi, vj] + 1.0 + + @T.prim_func + def after_single_elementwise_int64_blockize( + A: T.Buffer[(T.int64(16), T.int64(128)), "float32"], + B: T.Buffer[(T.int64(16), T.int64(128)), "float32"], + ) -> None: + for i0, j0 in T.grid(T.int64(1), T.int64(8)): + with T.block("B_o"): + vi_o = T.axis.spatial(T.int64(1), T.int64(0)) + vj_o = T.axis.spatial(T.int64(8), j0) + for i1, j1 in T.grid(T.int64(16), T.int64(16)): + with T.block("B"): + vi_i, vj_i = T.axis.remap("SS", [i1, j1]) + B[vi_i, vj_o * T.int64(16) + vj_i] = A[ + vi_i, vj_o * T.int64(16) + vj_i + ] + T.float32(1) + + s = tir.Schedule(single_elementwise_int64, debug_mask="all") + _, _, i1, _ = s.get_loops(s.get_block("B")) + s.blockize(i1) + tvm.ir.assert_structural_equal(s.mod["main"], after_single_elementwise_int64_blockize) + verify_trace_roundtrip(sch=s, mod=single_elementwise_int64) + + if __name__ == "__main__": tvm.testing.main() From 6e1c5e0bb20ba27aafc1138028d5bd32577543af Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 13 Oct 2022 15:55:22 -0700 Subject: [PATCH 2/3] Fix handling of int64 extent in tensorize --- src/tir/schedule/ir_comparator.cc | 12 +-- .../schedule/primitive/blockize_tensorize.cc | 2 +- .../unittest/test_tir_schedule_tensorize.py | 85 +++++++++++++++++++ 3 files changed, 93 insertions(+), 6 deletions(-) diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc index 93cb488eaf56..ea0ac0bc733d 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/tir/schedule/ir_comparator.cc @@ -72,9 +72,9 @@ bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) { } bool TensorizeComparator::VisitExpr(const PrimExpr& n, 
const PrimExpr& other) { - bool equal = - n.same_as(other) || ((n->type_index() == other->type_index()) && n->dtype == other->dtype && - ExprComparator::VisitExpr(n, other)); + bool equal = n.same_as(other) || + ((n->type_index() == other->type_index()) && + n.dtype().code() == other.dtype().code() && ExprComparator::VisitExpr(n, other)); if (!equal && assert_mode_) { std::ostringstream os; os << "Expression mismatch: " << n << " vs " << other; @@ -185,7 +185,7 @@ bool TensorizeComparator::VisitExpr_(const VarNode* op, const PrimExpr& other) { const auto* rhs = other.as(); auto lhs = GetRef(op); if (lhs.same_as(other)) return true; - if (op->dtype != rhs->dtype) return false; + if (op->dtype.code() != rhs->dtype.code()) return false; auto it = equal_map_.find(lhs); return it != equal_map_.end() && it->second.same_as(other); } @@ -208,7 +208,9 @@ bool TensorizeComparator::DefEqual(const Var& lhs, const Var& rhs) { if (it != equal_map_.end()) return it->second.same_as(rhs); // Otherwise remap lhs to rhs equal_map_[lhs] = rhs; - analyzer_.Bind(lhs, rhs); + // Cast if necessary. This allows the workload and the tensor intrin to have different dtypes in + // the indices. 
+ analyzer_.Bind(lhs, cast(lhs.dtype(), rhs)); return true; } diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 7481a7c92494..98e30117e172 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -572,7 +572,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int } for (int i = 0; i < static_cast(old_region.size()); i++) { PrimExpr min = indices_base[i + offset]; - PrimExpr extent = old_region[i]->extent; + PrimExpr extent = cast(min.dtype(), old_region[i]->extent); new_region.push_back(Range::FromMinExtent(min, extent)); } match_buffer_regions.push_back(MatchBufferRegion(impl, BufferRegion(cur, new_region))); diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py index f04de8e0051f..ec984ee7cff3 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize.py +++ b/tests/python/unittest/test_tir_schedule_tensorize.py @@ -653,5 +653,90 @@ def test_tensor_intrin_look_up(): tir.TensorIntrin.get(intrin_name) +def test_tensorize_matmul_mixed_dtype(): + # fmt: off + @T.prim_func + def matmul_int64_shape( + A: T.Buffer[(T.int64(128), T.int64(128)), "float32"], + B: T.Buffer[(T.int64(128), T.int64(128)), "float32"], + C: T.Buffer[(T.int64(128), T.int64(128)), "float32"] + ) -> None: + for i_0, j_0 in T.grid(T.int64(8), T.int64(8)): + for i_1_init, j_1_init in T.grid(T.int64(16), T.int64(16)): + with T.block("init"): + vi = T.axis.spatial(T.int64(128), i_0 * T.int64(16) + i_1_init) + vj = T.axis.spatial(T.int64(128), j_0 * T.int64(16) + j_1_init) + C[vi, vj] = T.float32(0) + for k_0, i_1, j_1, k_1 in T.grid(T.int64(8), T.int64(16), T.int64(16), T.int64(16)): + with T.block("update"): + vi = T.axis.spatial(T.int64(128), i_0 * T.int64(16) + i_1) + vj = T.axis.spatial(T.int64(128), j_0 * T.int64(16) + j_1) + vk = T.axis.reduce(T.int64(128), k_0 * 
T.int64(16) + k_1) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + @T.prim_func + def tensorized_matmul_int64_shape( + A: T.Buffer[(T.int64(128), T.int64(128)), "float32"], + B: T.Buffer[(T.int64(128), T.int64(128)), "float32"], + C: T.Buffer[(T.int64(128), T.int64(128)), "float32"] + ) -> None: + for i_outer, j_outer in T.grid(T.int64(8), T.int64(8)): + for i_inner_init, j_inner_init in T.grid(T.int64(16), T.int64(16)): + with T.block("init"): + vi = T.axis.spatial(T.int64(128), i_outer * T.int64(16) + i_inner_init) + vj = T.axis.spatial(T.int64(128), j_outer * T.int64(16) + j_inner_init) + C[vi, vj] = T.float32(0) + for k_outer in T.grid(T.int64(8)): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i_outer, j_outer, k_outer]) + T.reads( + [ + C[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vj * T.int64(16) : vj * T.int64(16) + T.int64(16)], + A[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vk * T.int64(16) : vk * T.int64(16) + T.int64(16)], + B[vj * T.int64(16) : vj * T.int64(16) + T.int64(16), vk * T.int64(16) : vk * T.int64(16) + T.int64(16)], + ] + ) + T.writes(C[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vj * T.int64(16) : vj * T.int64(16) + T.int64(16)]) + A_elem_offset = T.var("int32") + B_elem_offset = T.var("int32") + C_elem_offset = T.var("int32") + A_sub = T.match_buffer( + A[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vk * T.int64(16) : vk * T.int64(16) + T.int64(16)], + [16, 16], + elem_offset=A_elem_offset, + ) + B_sub = T.match_buffer( + B[vj * T.int64(16) : vj * T.int64(16) + T.int64(16), vk * T.int64(16) : vk * T.int64(16) + T.int64(16)], + [16, 16], + elem_offset=B_elem_offset, + ) + C_sub = T.match_buffer( + C[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vj * T.int64(16) : vj * T.int64(16) + T.int64(16)], + [16, 16], + elem_offset=C_elem_offset, + ) + T.evaluate( + T.tvm_mma_sync( + C_sub.data, + T.floordiv(C_sub.elem_offset, 256), + A_sub.data, + T.floordiv(A_sub.elem_offset, 256), + B_sub.data, + 
T.floordiv(B_sub.elem_offset, 256), + C_sub.data, + T.floordiv(C_sub.elem_offset, 256), + dtype="handle", + ) + ) + # fmt: on + + s = tir.Schedule(matmul_int64_shape, debug_mask="all") + update = s.get_block("update") + ii = s.get_loops(update)[-3] + s.tensorize(ii, "test_mma_intrin") + tvm.ir.assert_structural_equal(s.mod["main"], tensorized_matmul_int64_shape) + verify_trace_roundtrip(sch=s, mod=matmul_int64_shape) + + if __name__ == "__main__": tvm.testing.main() From d85a366ceca451899cdb0bef737991f58c614923 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 17 Oct 2022 10:16:44 -0700 Subject: [PATCH 3/3] Update layout_transformation.cc --- src/tir/schedule/primitive/layout_transformation.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 9b60c2240f84..e4c91dac582c 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -1210,8 +1210,9 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, if (iter_type == kOpaque) { throw OpaqueNewIterTypeError(self->mod, GetRef(block_ptr), transformed_block_iters[i]); } + auto dtype = new_block_var.dtype(); new_block_iters.push_back(IterVar( - /*dom=*/Range::FromMinExtent(make_zero(new_block_var.dtype()), new_block_iter_range[i]), + /*dom=*/Range::FromMinExtent(make_zero(dtype), new_block_iter_range[i]), /*var=*/std::move(new_block_var), /*iter_type=*/iter_type));