diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index d63af560d704..759acd1fa506 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -77,6 +77,8 @@ jobs:
       - name: Minimal Metal Compile-and-Run
         shell: bash -l {0}
         run: >-
+          python -m pytest -v -s 'tests/python/codegen/test_target_codegen_metal.py'
+          python -m pytest -v -s 'tests/python/codegen/test_target_codegen_gpu_common.py'
           python -m pytest -v -s 'tests/python/codegen/test_gpu_codegen_allreduce.py::test_allreduce_sum[dims0-metal]'
 #      - name: Test iOS RPC
 #        shell: bash -l {0}
diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h
index 2695c43173a0..ea8ccd98b1af 100644
--- a/src/target/intrin_rule.h
+++ b/src/target/intrin_rule.h
@@ -53,8 +53,13 @@ struct Direct {
   std::string operator()(DataType t, std::string name) const { return name; }
 };
 
-// Call pure extern function.
-template <typename T>
+/*!
+ * \brief Dispatch pure extern function.
+ * \param e The call expression.
+ * \tparam T The function to dispatch.
+ * \tparam dtype_from_arg Whether the dtype is from the first argument or the call node
+ */
+template <typename T, bool dtype_from_arg = false>
 inline PrimExpr DispatchPureExtern(const PrimExpr& e) {
   const CallNode* call = e.as<CallNode>();
   ICHECK(call != nullptr);
@@ -64,7 +69,14 @@ inline PrimExpr DispatchPureExtern(const PrimExpr& e) {
   ICHECK(op != nullptr);
   std::string name = op->name;
   ICHECK_EQ(name.substr(0, 4), "tir.");
-  name = T()(call->dtype, name.substr(4));
+  DataType dtype;
+  if (dtype_from_arg) {
+    ICHECK_EQ(call->args.size(), 1U);
+    dtype = call->args[0].dtype();
+  } else {
+    dtype = call->dtype;
+  }
+  name = T()(dtype, name.substr(4));
 
   if (name.length() != 0) {
     Array<PrimExpr> new_args = {StringImm(name)};
diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc
index 95fbf7f1a513..79ea7a458ff0 100644
--- a/src/target/source/intrin_rule_cuda.cc
+++ b/src/target/source/intrin_rule_cuda.cc
@@ -54,6 +54,15 @@ struct CUDAMath {
       }
     } else if (t.is_bfloat16()) {
       return 'h' + name;
+    } else if (t.is_int() || t.is_uint()) {
+      switch (t.bits()) {
+        case 32:
+          return "__" + name;
+        case 64:
+          return "__" + name + "ll";
+        default:
+          return "";
+      }
     }
     return "";
   }
@@ -133,6 +142,9 @@ static PrimExpr DispatchCUDAShuffle(const PrimExpr& e) {
   return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), cuda_args);
 }
 
+TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>(
+    "cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath, /*dtype_from_arg=*/true>);
+
 TVM_REGISTER_OP("tir.floor")
     .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);
 
diff --git a/src/target/source/intrin_rule_metal.cc b/src/target/source/intrin_rule_metal.cc
index 50685f6ef269..b7561e86715e 100644
--- a/src/target/source/intrin_rule_metal.cc
+++ b/src/target/source/intrin_rule_metal.cc
@@ -52,6 +52,9 @@ static PrimExpr DispatchMetalShuffle(const PrimExpr& e) {
   return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), metal_args);
 }
 
+TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
+                                                     DispatchPureExtern<Direct>);
+
 TVM_REGISTER_OP("tir.floor")
     .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
 
diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc
index 94ab9d8b9d9c..bd9e148b187d 100644
--- a/src/target/source/intrin_rule_opencl.cc
+++ b/src/target/source/intrin_rule_opencl.cc
@@ -31,6 +31,9 @@ namespace codegen {
 namespace intrin {
 using tir::FLowerIntrinsic;
 
+TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic",
+                                                     DispatchPureExtern<Direct>);
+
 TVM_REGISTER_OP("tir.floor")
     .set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
 
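Reviewer note (not part of the patch): `tir.clz` always returns `int32`, so dispatching on the call node's dtype would pick `__clz` even for a 64-bit operand; the new `dtype_from_arg` template flag lets the CUDA rule mangle the name from the argument type instead (`__clz` for 32-bit, `__clzll` for 64-bit). Below is a minimal sketch of how one might confirm the emitted source, assuming a CUDA-enabled TVM build; the schedule mirrors `run_test` in the new test file, and the `cuda_src` check is illustrative only:

```python
import tvm
from tvm import te

n = 128
A = te.placeholder((n,), name="A", dtype="int64")
B = te.compute(A.shape, lambda *i: tvm.tir.clz(A(*i)), name="B")
sch = tvm.tir.Schedule(te.create_prim_func([A, B]))
(x,) = sch.get_loops(sch.get_block("B"))
sch.bind(x, "threadIdx.x")
f = tvm.build(sch.mod, target="cuda")

# Device code lives on the imported CUDA module; per CUDAMath above, a
# 64-bit argument should select the "ll"-suffixed intrinsic.
cuda_src = f.imported_modules[0].get_source()
assert "__clzll" in cuda_src
```

Metal and OpenCL register the `Direct` dispatcher instead, since both languages expose `clz` as an overloaded builtin and need no name mangling.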
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index a613b8d4bb0c..c03e19137ef0 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -238,10 +238,12 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) {
   } else if (op->op.same_as(Op::Get("tir.clz"))) {
     DataType before_dtype = before->args[0]->dtype;
     DataType after_dtype = op->args[0]->dtype;
-    CHECK(before_dtype.is_int() && (before_dtype.bits() == 32 || before_dtype.bits() == 64))
+    CHECK((before_dtype.is_int() || before_dtype.is_uint()) &&
+          (before_dtype.bits() == 32 || before_dtype.bits() == 64))
         << "clz only supports 32 or 64 bit integer types, but get type before legalizing: "
         << before_dtype;
-    CHECK(after_dtype.is_int() && (after_dtype.bits() == 32 || after_dtype.bits() == 64))
+    CHECK((after_dtype.is_int() || after_dtype.is_uint()) &&
+          (after_dtype.bits() == 32 || after_dtype.bits() == 64))
         << "clz only supports 32 or 64 bit integer types, but get type after legalizing: "
         << after_dtype;
     return e - after_dtype.bits() + before_dtype.bits();
diff --git a/tests/python/codegen/test_target_codegen_gpu_common.py b/tests/python/codegen/test_target_codegen_gpu_common.py
new file mode 100644
index 000000000000..2941f366a43b
--- /dev/null
+++ b/tests/python/codegen/test_target_codegen_gpu_common.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from functools import partial
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import te
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.parametrize_targets("cuda", "metal", "vulkan -supports_int64=1", "opencl")
+@pytest.mark.parametrize("dtype", ["int32", "uint32", "int64", "uint64"])
+def test_int_intrin(target, dev, dtype):
+    test_funcs = [
+        (tvm.tir.clz, lambda x, dtype: int(dtype[-2:]) - (len(bin(x)) - 2)),
+    ]
+
+    def run_test(tvm_intrin, np_func, dtype):
+        n = 128
+        A = te.placeholder((n,), name="A", dtype=dtype)
+        B = te.compute(A.shape, lambda *i: tvm_intrin(A(*i)), name="B")
+        func = te.create_prim_func([A, B])
+        sch = tvm.tir.Schedule(func)
+        (x,) = sch.get_loops(sch.get_block("B"))
+        sch.bind(x, "threadIdx.x")
+        f = tvm.build(sch.mod, target=target)
+        a = tvm.nd.array(np.random.randint(0, 100000, size=n).astype(A.dtype), dev)
+        b = tvm.nd.array(np.zeros(shape=(n,)).astype(B.dtype), dev)
+        f(a, b)
+        ref = np.vectorize(partial(np_func, dtype=dtype))(a.numpy())
+        tvm.testing.assert_allclose(b.numpy(), ref)
+
+    for func in test_funcs:
+        run_test(*func, dtype)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
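Reviewer note (not part of the patch): the numpy reference in the new test, `lambda x, dtype: int(dtype[-2:]) - (len(bin(x)) - 2)`, computes bit width minus bit length, i.e. the leading-zero count for positive inputs. A standalone restatement, where `ref_clz` is a hypothetical name introduced here for illustration:

```python
def ref_clz(x: int, dtype: str) -> int:
    bits = int(dtype[-2:])  # "int32"[-2:] == "32", "uint64"[-2:] == "64"
    bit_length = len(bin(x)) - 2  # bin(5) == "0b101" -> 3 significant bits
    return bits - bit_length  # leading zeros = width - significant bits

assert ref_clz(1, "int32") == 31  # 0b1: one significant bit out of 32
assert ref_clz(5, "uint64") == 61  # 0b101: three significant bits out of 64

# Caveat: bin(0) == "0b0" counts as one bit, so this reference yields
# bits - 1 for x == 0, while hardware clz(0) typically returns the full
# width; np.random.randint(0, 100000) can draw 0, so that edge exists.
```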