diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index b166b16b7721..f33432645cc3 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -632,6 +632,25 @@ TVM_DLL const Op& ptx_mma_sp();
  */
 TVM_DLL const Op& ptx_ldmatrix();
 
+/*!
+ * \brief tvm intrinsics for ptx async copy from global to shared memory
+ *
+ * void ptx_cp_async(Var shared_ptr, Expr shared_offset, Var global_ptr, Expr global_offset, size_t
+ * bytes);
+ *
+ */
+TVM_DLL const Op& ptx_cp_async();
+
+/*!
+ * \brief tvm intrinsics for ptx async copy commit and wait.
+ *
+ * void ptx_commit_group();
+ * void ptx_wait_group(int num);
+ *
+ */
+TVM_DLL const Op& ptx_commit_group();
+TVM_DLL const Op& ptx_wait_group();
+
 // TODO(tvm-team) replace the usage of the vector operations by Shuffle.
 /*!
  * \brief Get the high level half of the vector
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index d4ec536fb001..7459d4c250ba 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -821,6 +821,18 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
     std::string smem_elem_offset = this->PrintExpr(op->args[6]);
     this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
                                             smem_ptr, smem_elem_offset);
+  } else if (op->op.same_as(builtin::ptx_cp_async())) {
+    std::string dst = this->PrintExpr(op->args[0]);
+    std::string dst_offset = this->PrintExpr(op->args[1]);
+    std::string src = this->PrintExpr(op->args[2]);
+    std::string src_offset = this->PrintExpr(op->args[3]);
+    std::string size = this->PrintExpr(op->args[4]);
+    this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size);
+  } else if (op->op.same_as(builtin::ptx_commit_group())) {
+    this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
+  } else if (op->op.same_as(builtin::ptx_wait_group())) {
+    std::string N = this->PrintExpr(op->args[0]);
+    this->stream << "__asm__ __volatile__(\"cp.async.wait_group " + N + ";\");\n\n";
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc
index 02a98ffbbabd..71c68baed6dc 100644
--- a/src/target/source/ptx.cc
+++ b/src/target/source/ptx.cc
@@ -638,5 +638,31 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type
   return asm_code;
 }
 
+std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
+                                 const std::string& shared_elem_offset,
+                                 const std::string& global_ptr,
+                                 const std::string& global_elem_offset, const std::string& bytes) {
+  std::string asm_code = R"(
+  {
+    unsigned int addr;
+    __asm__ __volatile__(
+      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+      : "=r"(addr)
+      : "l"((void *)({smem_addr}))
+    );
+    __asm__ __volatile__(
+      "cp.async.cg.shared.global [%0], [%1], %2;"
+      :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
+    );
+  }
+)";
+  Replacer replacer;
+  replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset);
+  replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset);
+  replacer.register_rule("{bytes}", bytes);
+  asm_code = replacer.rewrite(asm_code);
+  return asm_code;
+}
+
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/target/source/ptx.h b/src/target/source/ptx.h
index c4255d737ad0..c811a1b9c1d6 100644
--- a/src/target/source/ptx.h
+++ b/src/target/source/ptx.h
@@ -79,6 +79,19 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type
                                     const std::string& smem_ptr,
                                     const std::string& smem_elem_offset);
 
+/*!
+ * \brief Print ptx cp.async assembly string given parameters.
+ * \param shared_ptr: The pointer to the destination shared memory.
+ * \param shared_elem_offset: The offset into the shared memory.
+ * \param global_ptr: The pointer to the global memory.
+ * \param global_elem_offset: The offset into the global memory.
+ * \param bytes: The number of bytes to copy, valid values are 4, 8, and 16.
+ */
+std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
+                                 const std::string& shared_elem_offset,
+                                 const std::string& global_ptr,
+                                 const std::string& global_elem_offset, const std::string& bytes);
+
 }  // namespace codegen
 }  // namespace tvm
 
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 4e8d83dd32df..0415d1bbec9e 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -247,6 +247,15 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp)
 TIR_DEFINE_BUILTIN_FUNC(ptx_ldmatrix)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_BUILTIN_FUNC(ptx_cp_async)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(ptx_wait_group)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
 
diff --git a/tests/python/unittest/test_tir_ptx_cp_async.py b/tests/python/unittest/test_tir_ptx_cp_async.py
new file mode 100644
index 000000000000..17b60885509f
--- /dev/null
+++ b/tests/python/unittest/test_tir_ptx_cp_async.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm.script import tir as T
+import numpy as np
+import tvm.testing
+
+
+@T.prim_func
+def ptx_cp_async(A: T.Buffer[(32, 128), "float16"], B: T.Buffer[(32, 128), "float16"]) -> None:
+    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+    bx = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(bx, 1)
+    T.launch_thread(tx, 32)
+    with T.block():
+        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared")
+        T.reads(A[0:32, 0:128])
+        T.writes(B[0:32, 0:128])
+
+        for i in range(16):
+            T.evaluate(
+                T.ptx_cp_async(
+                    A_shared.data, tx * 128 + 8 * i, A.data, tx * 128 + 8 * i, 16, dtype="float16"
+                )
+            )
+
+        # TODO(masahi): Remove dtype requirement from TVMScript parser
+        T.evaluate(T.ptx_commit_group(dtype="float16"))
+        T.evaluate(T.ptx_wait_group(0, dtype="float16"))
+
+        for i in range(128):
+            B[tx, i] = A_shared[tx, i]
+
+
+@tvm.testing.requires_cuda
+def test_ptx_cp_async():
+    f = ptx_cp_async
+    arch = tvm.contrib.nvcc.get_target_compute_version()
+    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
+    if major < 8:
+        # Require at least SM80
+        return
+
+    mod = tvm.build(f, target="cuda")
+    A_np = np.random.rand(32, 128).astype("float16")
+    B_np = np.zeros((32, 128)).astype("float16")
+    dev = tvm.cuda(0)
+    A_nd = tvm.nd.array(A_np, device=dev)
+    B_nd = tvm.nd.array(B_np, device=dev)
+    mod(A_nd, B_nd)
+    tvm.testing.assert_allclose(B_nd.numpy(), A_np)
+
+
+if __name__ == "__main__":
+    test_ptx_cp_async()
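
Note: as a rough sketch of what this patch generates, the CUDA fragment below is the PrintCpAsyncAssembly template with its {smem_addr}, {global_ptr}, and {bytes} placeholders filled in for one 16-byte copy from the test above, followed by the output of the commit/wait intrinsics. The pointer expressions (A_shared, A, i) are illustrative stand-ins for whatever identifiers CodeGenCUDA actually prints, not verbatim compiler output.

  {
    unsigned int addr;
    __asm__ __volatile__(
      // Convert the generic shared-memory pointer to a 32-bit shared-state-space address,
      // which is what cp.async expects as its destination operand.
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)(A_shared + (threadIdx.x * 128 + 8 * i)))
    );
    __asm__ __volatile__(
      // Asynchronous 16-byte copy, global memory -> shared memory.
      "cp.async.cg.shared.global [%0], [%1], %2;"
      :: "r"(addr), "l"((void*)(A + (threadIdx.x * 128 + 8 * i))), "n"(16)
    );
  }
  __asm__ __volatile__("cp.async.commit_group;");   // emitted for ptx_commit_group()
  __asm__ __volatile__("cp.async.wait_group 0;");   // emitted for ptx_wait_group(0)

The wait_group 0 blocks until all outstanding copy groups have completed, so the subsequent reads of A_shared in the test see the copied data.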