diff --git a/src/target/llvm/intrin_rule_hexagon.cc b/src/target/llvm/intrin_rule_hexagon.cc index c96245e1399c..7c4b38c1d702 100644 --- a/src/target/llvm/intrin_rule_hexagon.cc +++ b/src/target/llvm/intrin_rule_hexagon.cc @@ -182,10 +182,18 @@ TVM_REGISTER_OP("tir.sigmoid") useqhl = tstring.find("+hvx-qfloat") != std::string::npos; } + PrimExpr MinBound = tir::make_const(x.dtype(), -8); + PrimExpr MaxBound = tir::make_const(x.dtype(), 8); + const PrimExpr v1 = tir::Max(x, MinBound); + const PrimExpr v2 = tir::Min(v1, MaxBound); + + Array new_args = {v2}; + const tir::Call new_call = tir::Call(call->dtype, call->op, new_args); + // Enable QHL library for FP16 data type if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) { std::string tvm_wrapper("tvm_vect_qhmath_hvx_sigmoid_ahf"); - return TVMExternCall(call, tvm_wrapper); + return TVMExternCall(new_call.get(), tvm_wrapper); } #endif PrimExpr one = tir::make_const(x.dtype(), 1); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 509badbebb92..3f7f05fe8e64 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -866,7 +866,7 @@ TIR_REGISTER_PURE_UNARY_OP("tir.erf"); TIR_REGISTER_PURE_UNARY_OP("tir.tanh").set_attr("TVectorizable", true); -TIR_REGISTER_PURE_UNARY_OP("tir.sigmoid"); +TIR_REGISTER_PURE_UNARY_OP("tir.sigmoid").set_attr("TVectorizable", true); TIR_REGISTER_PURE_UNARY_OP("tir.sqrt").set_attr("TVectorizable", true); diff --git a/tests/python/contrib/test_hexagon/test_sigmoid.py b/tests/python/contrib/test_hexagon/test_sigmoid.py new file mode 100644 index 000000000000..9aad35ee76c1 --- /dev/null +++ b/tests/python/contrib/test_hexagon/test_sigmoid.py @@ -0,0 +1,117 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +import pytest + +import tvm +import tvm.testing +from tvm import te +from tvm import tir +from tvm import topi +from tvm.contrib.hexagon.build import HexagonLauncher + +from .infrastructure import allocate_hexagon_array, transform_numpy + + +def sigmoid_compute(Input): + return topi.sigmoid(Input) + + +def sigmoid_stir_schedule(Input, Output): + sigmoid_func = te.create_prim_func([Input, Output]) + sch = tir.Schedule(sigmoid_func, debug_mask="all") + block = sch.get_block("compute") + + (n,) = sch.get_loops(block) + sch.vectorize(n) + return sch + + +@tvm.testing.fixture +def input_np(in_shape, dtype, min_val, max_val): + return np.random.uniform(low=min_val, high=max_val, size=in_shape).astype(dtype) + + +@tvm.testing.fixture +def ref_output_np(input_np): + output_np = 1 / (1 + np.exp(-input_np)) + return output_np + + +class BaseSigmoid: + (in_shape, dtype, min_val, max_val,) = tvm.testing.parameters( + ((64,), "float16", -8.0, 8.0), + ((64,), "float16", -6.0, 7.0), + ((64,), "float16", -10.0, 15.0), + ((64,), "float16", -10.0, 0.0), + ((64,), "float16", 0.0, 10.0), + ) + + +class TestSigmoid(BaseSigmoid): + @tvm.testing.requires_hexagon + def test_sigmoid( + self, + in_shape, + dtype, + input_np, + ref_output_np, + target, + hexagon_session, + ): + InputTensor = te.placeholder(in_shape, name="InputTensor", dtype=dtype) + + OutputTensor = sigmoid_compute(InputTensor) + + target_hexagon = tvm.target.hexagon("v69") + target = tvm.target.Target(target_hexagon, host=target_hexagon) + + tir_s = sigmoid_stir_schedule(InputTensor, OutputTensor) + + input_data = allocate_hexagon_array( + hexagon_session.device, + data=input_np, + ) + output_data = allocate_hexagon_array( + hexagon_session.device, + tensor_shape=ref_output_np.shape, + dtype=ref_output_np.dtype, + ) + + func_name = "sigmoid" + with tvm.transform.PassContext(opt_level=3): + runtime_module = tvm.build(tir_s.mod, target=target, name=func_name) + + assert "hvx_sigmoid" in runtime_module.get_source("asm") + assert "vmin" in runtime_module.get_source("asm") + assert "vmax" in runtime_module.get_source("asm") + mod = hexagon_session.load_module(runtime_module) + + mod(input_data, output_data) + output_np = output_data.numpy() + + tvm.testing.assert_allclose( + output_np, + ref_output_np, + 1e-3, + 1e-3, + ) + + +if __name__ == "__main__": + tvm.testing.main()