From fd5fc7321769dff40f826c843e782834bc8248fa Mon Sep 17 00:00:00 2001 From: Sergey Mironov Date: Tue, 4 Dec 2018 10:06:13 +0300 Subject: [PATCH] Fix missing sigmoid intrinsic in C++ --- python/tvm/intrin.py | 3 --- src/codegen/intrin_rule.cc | 10 ++++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/python/tvm/intrin.py b/python/tvm/intrin.py index 3207b6112b1d..cd9a108c546a 100644 --- a/python/tvm/intrin.py +++ b/python/tvm/intrin.py @@ -492,6 +492,3 @@ def _rule_float_direct(op): register_intrin_rule("opencl", "exp", _rule_float_direct, override=True) # default pattern for exp register_intrin_rule("default", "exp", _rule_float_suffix, override=True) - -# default pattern for sigmoid -register_intrin_rule("default", "sigmoid", lambda op: 1.0 / (1.0 + exp(-op.args[0]))) diff --git a/src/codegen/intrin_rule.cc b/src/codegen/intrin_rule.cc index 822d515fb8a5..f326fceb6ee8 100644 --- a/src/codegen/intrin_rule.cc +++ b/src/codegen/intrin_rule.cc @@ -24,6 +24,16 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid") +.set_body([](const TVMArgs& args, TVMRetValue* rv){ + Expr e = args[0]; + const Call* call = e.as(); + CHECK(call != nullptr); + + auto one = make_const(call->args[0].type(), 1); + *rv = one / (one + exp(-call->args[0])); + }); + } // namespace intrin } // namespace codegen } // namespace tvm