diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc index a6685fe87b48..5d23e854be02 100644 --- a/src/tir/transforms/inject_ptx_async_copy.cc +++ b/src/tir/transforms/inject_ptx_async_copy.cc @@ -24,6 +24,7 @@ #include <tvm/runtime/registry.h> #include <tvm/tir/builtin.h> #include <tvm/tir/expr.h> +#include <tvm/tir/op.h> #include <tvm/tir/stmt_functor.h> #include <tvm/tir/transform.h> @@ -79,7 +80,7 @@ class PTXAsyncCopyInjector : public StmtMutator { if (indices_lanes == 1) { auto src_offset = load->indices[0]; auto dst_offset = store->indices[0]; - Array<PrimExpr> args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), + Array<PrimExpr> args = {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), load->buffer->data, src_offset, PrimExpr(bytes)}; // use arguments size to indicate whether or not to use predicated cp.async if (predicated) { @@ -114,7 +115,7 @@ class PTXAsyncCopyInjector : public StmtMutator { }(); if (src_offset.defined() && dst_offset.defined()) { return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), - {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), + {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), load->buffer->data, src_offset, PrimExpr(bytes)})); } } else { @@ -144,7 +145,7 @@ class PTXAsyncCopyInjector : public StmtMutator { if (src_offset.defined() && dst_offset.defined()) { return Evaluate( Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), - {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), + {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), load->buffer->data, src_offset, PrimExpr(bytes), predicate_value})); } } diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py index 61f0892a9cf3..bf68200944d4 100644 --- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py @@ -938,5 +938,39 @@ def complex_compute( assert "setp.ne.b32" in 
generated_code +class TestMultiplicationNodesAreInlined(tvm.testing.CompareBeforeAfter): + transform = tvm.tir.transform.InjectPTXAsyncCopy() + + def before(A: T.Buffer((32, 128), "float16")): + tx = T.launch_thread("threadIdx.x", T.int64(32)) + A_flattened = T.Buffer((4096,), "float16", data=A.data) + A_shared = T.decl_buffer([4096], "float16", scope="shared") + + T.attr("default", "async_scope", 1) + for i in range(16): + cse_var_1: T.int64 = T.Cast("int64", i) + A_shared[ + T.Ramp(tx * T.int64(128) + cse_var_1 * T.int64(8), T.int64(1), 8) + ] = A_flattened[T.Ramp(tx * T.int64(128) + cse_var_1 * T.int64(8), T.int64(1), 8)] + T.ptx_commit_group() + T.ptx_wait_group(0) + + def expected(A: T.Buffer((32, 128), "float16")): + tx = T.launch_thread("threadIdx.x", T.int64(32)) + A_shared = T.decl_buffer((4096,), "float16", scope="shared") + for i in range(16): + cse_var_1: T.int64 = T.Cast("int64", i) + T.ptx_cp_async( + "float16", + A_shared.data, + T.Cast("int64", tx) * T.int64(128) + cse_var_1 * T.int64(8), + A.data, + T.Cast("int64", tx) * T.int64(128) + cse_var_1 * T.int64(8), + 16, + ) + T.ptx_commit_group() + T.ptx_wait_group(0) + + if __name__ == "__main__": tvm.testing.main()