diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc
index a6685fe87b48..5d23e854be02 100644
--- a/src/tir/transforms/inject_ptx_async_copy.cc
+++ b/src/tir/transforms/inject_ptx_async_copy.cc
@@ -24,6 +24,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -79,7 +80,7 @@ class PTXAsyncCopyInjector : public StmtMutator {
     if (indices_lanes == 1) {
       auto src_offset = load->indices[0];
       auto dst_offset = store->indices[0];
-      Array<PrimExpr> args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
+      Array<PrimExpr> args = {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)),
                               load->buffer->data, src_offset, PrimExpr(bytes)};
       // use arguments size to indicate whether or not to use predicated cp.async
       if (predicated) {
@@ -114,7 +115,7 @@ class PTXAsyncCopyInjector : public StmtMutator {
      }();
      if (src_offset.defined() && dst_offset.defined()) {
        return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
-                            {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
+                            {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)),
                              load->buffer->data, src_offset, PrimExpr(bytes)}));
      }
    } else {
@@ -144,7 +145,7 @@ class PTXAsyncCopyInjector : public StmtMutator {
      if (src_offset.defined() && dst_offset.defined()) {
        return Evaluate(
            Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
-                {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
+                {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)),
                  load->buffer->data, src_offset, PrimExpr(bytes), predicate_value}));
      }
    }
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 99ed4376590e..25c10dd6828d 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -417,7 +417,8 @@ Array<PrimExpr> GetBufferAllocationShape(const Buffer& buffer) {
   if (buffer->strides.size()) {
     ICHECK_EQ(buffer->shape.size(), buffer->strides.size());
     for (size_t i = buffer->strides.size() - 1; i > 0; --i) {
-      ICHECK(is_zero(floormod(buffer->strides[i - 1], buffer->strides[i])));
+      ICHECK(
+          arith::Analyzer().CanProveEqual(floormod(buffer->strides[i - 1], buffer->strides[i]), 0));
       alloc_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]);
     }
   }
diff --git a/tests/python/unittest/test_tir_transform_lower_opaque_block.py b/tests/python/unittest/test_tir_transform_lower_opaque_block.py
index 444e36bfbb7a..95df26e66fd8 100644
--- a/tests/python/unittest/test_tir_transform_lower_opaque_block.py
+++ b/tests/python/unittest/test_tir_transform_lower_opaque_block.py
@@ -250,6 +250,34 @@ def transformed_strided_buffer_func(
                 C[i0 * 4 + i1, j] = B[i1, j] * T.float32(2)
 
 
+@T.prim_func
+def compacted_symbolic_strided_buffer_func(a: T.handle) -> None:
+    n = T.int64()
+    A = T.match_buffer(a, (1, n, 10240), "float32")
+    for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160):
+        with T.block(""):
+            T.reads(A[0, i * 128 + j * 32:i * 128 + j * 32 + 96, k * 64:k * 64 + 64])
+            A_pad_shared_dyn = T.alloc_buffer((1, T.min((n + 63) // 64 * 64, 96), 64), "float32", strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1), scope="shared.dyn")
+            for ax0, ax1 in T.grid(96, 64):
+                with T.block("A_pad_shared.dyn"):
+                    T.where(i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64)
+                    T.reads(A[0, i * 128 + j * 32 + ax0, k * 64 + ax1])
+                    T.writes(A_pad_shared_dyn[0, ax0, ax1])
+                    A_pad_shared_dyn[0, ax0, ax1] = T.if_then_else(i * 128 + j * 32 + ax0 < n, A[0, i * 128 + j * 32 + ax0, k * 64 + ax1], T.float16(0))
+
+
+@T.prim_func
+def transformed_symbolic_strided_buffer_func(a: T.handle):
+    n = T.int64()
+    A = T.match_buffer(a, (1, n, 10240))
+    for i, j, k in T.grid(((n + T.int64(63)) // T.int64(64) * T.int64(4) + T.int64(7)) // T.int64(8), 2, 160):
+        A_pad_shared_dyn = T.allocate([1, T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 72], "float32", "shared.dyn")
+        A_pad_shared_dyn_1 = T.decl_buffer((1, T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 64), data=A_pad_shared_dyn, strides=(T.int64(72) * T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 72, 1), scope="shared.dyn")
+        for ax0, ax1 in T.grid(96, 64):
+            if i * T.int64(128) + T.Cast("int64", j) * T.int64(32) + T.Cast("int64", ax0) < (n + T.int64(63)) // T.int64(64) * T.int64(64):
+                A_pad_shared_dyn_1[0, ax0, ax1] = T.if_then_else(i * T.int64(128) + T.Cast("int64", j) * T.int64(32) + T.Cast("int64", ax0) < n, A[0, i * T.int64(128) + T.Cast("int64", j) * T.int64(32) + T.Cast("int64", ax0), k * 64 + ax1], T.float32(0))
+
+
 @T.prim_func
 def annotated_loops(a: T.handle) -> None:
     A = T.match_buffer(a, (16,), "float32")
@@ -301,6 +329,10 @@ def test_strided_buffer():
     _check(compacted_strided_buffer_func, transformed_strided_buffer_func)
 
 
+def test_symbolic_strided_buffer():
+    _check(compacted_symbolic_strided_buffer_func, transformed_symbolic_strided_buffer_func)
+
+
 def test_lower_te():
     x = te.placeholder((1,))
     y = te.compute((1,), lambda i: x[i] + 2)
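
Note on the ir_utils.cc change: the strides introduced by the new test (an outer stride of 72 * T.min(...) over an innermost stride of 72) are symbolic, so the old is_zero(floormod(...)) check, which only accepts a literal zero after constant folding, rejects them, whereas arith::Analyzer can prove the divisibility. The Python sketch below illustrates that distinction; it uses a simplified symbolic stride n * 72 rather than the exact expression from the test.

import tvm
from tvm import tir

n = tir.Var("n", "int64")
outer = n * 72                       # simplified stand-in for the symbolic outer stride in the test
inner = tir.IntImm("int64", 72)      # innermost stride

residue = tir.floormod(outer, inner)
print(residue)  # remains a symbolic floormod, so a literal is_zero()-style check fails

analyzer = tvm.arith.Analyzer()
print(analyzer.simplify(residue))        # the simplifier should reduce this to 0
print(analyzer.can_prove(residue == 0))  # so a CanProveEqual-style check should accept it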
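The added test_symbolic_strided_buffer case reuses the _check helper defined earlier in test_tir_transform_lower_opaque_block.py, which is not visible in this hunk. Roughly, it lowers the compacted function with LowerOpaqueBlock and compares the result structurally with the hand-written expected function. The sketch below approximates that flow; the helper's exact body (for example whether it also runs Simplify) is an assumption here, and _check_sketch is a hypothetical name.

import tvm

def _check_sketch(original, transformed):
    # Wrap the TVMScript PrimFunc in an IRModule and run the pass under test.
    mod = tvm.IRModule.from_expr(original.with_attr("global_symbol", "main"))
    mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
    mod = tvm.tir.transform.Simplify()(mod)  # assumption: the real helper may or may not simplify
    # Structural equality against the expected lowered function.
    expected = transformed.with_attr("global_symbol", "main")
    tvm.ir.assert_structural_equal(mod["main"], expected, map_free_vars=True)

# e.g. for the new case:
# _check_sketch(compacted_symbolic_strided_buffer_func, transformed_symbolic_strided_buffer_func)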