-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Fast exponent #4790
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fast exponent #4790
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -360,5 +360,85 @@ inline Tensor full_like(const Tensor& x, | |
| }, name, tag); | ||
| } | ||
|
|
||
| /*! | ||
| * \brief Fast exponential function implementation | ||
| * | ||
| * \param _x The input tensor | ||
| * \param name The name of the operation | ||
| * \param tag The tag to mark the operation | ||
| * | ||
| * \return A Tensor whose op member is exponent operation | ||
| * | ||
| * \note Function computes: | ||
| * log2(e^x) = x * log2(e) => | ||
| * log2(e^x) = log2(2^(x*log2(e))) => | ||
| * e^x = 2^(x*log2(e)) | ||
| * Splitting x into a multiple of ln(2) and a remainder f: | ||
| * n = floor(x*log2(e) + 1/2) | ||
| * f = x - n * ln(2) | ||
| * exp(x) = 2^n * exp(f) | ||
| * Approximation for the remainder term (polynomial in f, evaluated in Horner form): | ||
| * y = exp(f) ~= 1 + f + f^2 * (p[5] + p[4]*f + p[3]*f^2 + p[2]*f^3 + p[1]*f^4 + p[0]*f^5) | ||
| * The factor 2^n is produced without calling a power function: (n + 127) is | ||
| * cast to int32, shifted left by 23 bits and reinterpreted as float32. | ||
| * The input is clamped to [x_lo, x_hi] before n and f are computed. | ||
| * The returned value is max(2^n * y, _x(i)), i.e. it is compared against the | ||
| * unclamped input. NOTE(review): the purpose of this final max is not evident | ||
| * from the code here - confirm with the author. | ||
| */ | ||
| inline Tensor fast_exp_float32(const Tensor& _x, | ||
| std::string name, | ||
| std::string tag) { | ||
| auto x_hi = make_const(DataType::Float(32), 88.3762626647950f); | ||
| auto x_lo = make_const(DataType::Float(32), -88.3762626647949f); | ||
| auto log2e = make_const(DataType::Float(32), 1.44269504088896341f); | ||
| auto ln2 = make_const(DataType::Float(32), 0.6931471805599453f); | ||
| PrimExpr p[6] = {make_const(DataType::Float(32), 1.9875691500E-4f), | ||
| make_const(DataType::Float(32), 1.3981999507E-3f), | ||
| make_const(DataType::Float(32), 8.3334519073E-3f), | ||
| make_const(DataType::Float(32), 4.1665795894E-2f), | ||
| make_const(DataType::Float(32), 1.6666665459E-1f), | ||
| make_const(DataType::Float(32), 5.0000001201E-1f)}; | ||
| auto one = make_const(DataType::Float(32), 1.0f); | ||
| auto one_half = make_const(DataType::Float(32), 0.5f); | ||
| auto b = make_const(DataType::Float(32), 127.0f); | ||
|
|
||
| return compute(_x->shape, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A high-level design question - Can we do this at Relay level? Relay can then fuse things accordingly.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Somewhat relevant PR was l2_normalize where unfolding at Relay leads to much better performance - #4795 |
||
| [&](const Array<Var>& i) { | ||
| // clamp x to the range [x_lo, x_hi] | ||
| auto x = ::tvm::max(::tvm::min(_x(i), x_hi), x_lo); | ||
| // integer part: n = floor(x * log2(e) + 1/2) | ||
| auto n = ::tvm::floor(x * log2e + one_half); | ||
| // remainder f = x - n * ln(2); y below is the polynomial approximation of exp(f) in Horner form | ||
| auto f = x - n * ln2; | ||
| auto y = (((((p[0] * f + p[1]) * f + p[2]) * f + p[3])* f+ p[4]) * f | ||
| + p[5]) * f * f + f + one; | ||
| // Build ef = 2^n by placing (n + 127) at bit 23 of an int32 and reinterpreting it as float32; result is 2^n * exp(f). | ||
| auto ef = tvm::reinterpret(DataType::Float(32), | ||
| ::tvm::cast(DataType::Int(32), n + b) << 23); | ||
| return ::tvm::max(ef * y, _x(i)); // NOLINT(*) | ||
| }, | ||
| name, tag); | ||
| } | ||
|
|
||
|
|
||
| /*! | ||
| * \brief Element-wise exponential with a fast path for float32 inputs | ||
| * | ||
| * Float32 tensors are forwarded to fast_exp_float32; for any other dtype | ||
| * ::tvm::exp is applied to each element. | ||
| * | ||
| * \param x The input tensor | ||
| * \param name The name of the operation | ||
| * \param tag The tag to mark the operation | ||
| * | ||
| * \return A Tensor whose op member is exponent operation | ||
| * | ||
| */ | ||
| inline Tensor fast_exp(const Tensor& x, | ||
| std::string name = "T_fast_exp", | ||
| std::string tag = kElementWise) { | ||
| // The approximation is only provided for float32. | ||
| if (x->dtype == DataType::Float(32)) { | ||
| return fast_exp_float32(x, name, tag); | ||
| } | ||
| // Every other dtype falls back to the regular exp per element. | ||
| return compute(x->shape, | ||
| [&](const Array<Var>& idx) { return ::tvm::exp(x(idx)); }, | ||
| name, tag); | ||
| } | ||
|
|
||
| } // namespace topi | ||
| #endif // TOPI_ELEMWISE_H_ | ||
Uh oh!
There was an error while loading. Please reload this page.