diff --git a/REQUIRE b/REQUIRE index 7ead8571..11f39701 100644 --- a/REQUIRE +++ b/REQUIRE @@ -3,3 +3,4 @@ DiffBase 0.0.3 Compat 0.17.0 Calculus 0.2.0 NaNMath 0.2.2 +SpecialFunctions 0.1.0 diff --git a/src/ForwardDiff.jl b/src/ForwardDiff.jl index b21b957b..3b65d67a 100644 --- a/src/ForwardDiff.jl +++ b/src/ForwardDiff.jl @@ -8,6 +8,7 @@ using DiffBase: DiffResult import Calculus import NaNMath +import SpecialFunctions ############################# # types/functions/constants # @@ -35,9 +36,17 @@ end #---------------------# const AUTO_DEFINED_UNARY_FUNCS = map(first, Calculus.symbolic_derivatives_1arg()) + const NANMATH_FUNCS = (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10, :lgamma, :log1p) +const SPECIAL_FUNCS = (:erf, :erfc, :erfinv, :erfcinv, :erfi, :erfcx, + :dawson, :digamma, :eta, :zeta, :airyai, :airyaiprime, + :airybi, :airybiprime, :airyaix, :besselj, :besselj0, + :besselj1, :besseljx, :bessely, :bessely0, :bessely1, + :besselyx, :besselh, :hankelh1, :hankelh1x, :hankelh2, + :hankelh2x, :besseli, :besselix, :besselk, :besselkx) + # chunk settings # #----------------# diff --git a/src/dual.jl b/src/dual.jl index dcbf6dc8..6e0d105d 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -310,10 +310,21 @@ for fsym in AUTO_DEFINED_UNARY_FUNCS # exp and sqrt are manually defined below if !(in(fsym, (:exp, :sqrt))) - @eval begin - @inline function Base.$(fsym)(n::Dual) - $(v) = value(n) - return Dual($(fsym)($v), $(deriv) * partials(n)) + is_special_function = in(fsym, SPECIAL_FUNCS) + if is_special_function + @eval begin + @inline function SpecialFunctions.$(fsym)(n::Dual) + $(v) = value(n) + return Dual(SpecialFunctions.$(fsym)($v), $(deriv) * partials(n)) + end + end + end + if !(is_special_function) || VERSION < v"0.6.0-dev.2767" + @eval begin + @inline function Base.$(fsym)(n::Dual) + $(v) = value(n) + return Dual(Base.$(fsym)($v), $(deriv) * partials(n)) + end end end end diff --git a/test/DualTest.jl b/test/DualTest.jl index 4b1e4255..ca9931a3 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -6,6 +6,7 @@ using ForwardDiff: Partials, Dual, value, partials import NaNMath import Calculus +import SpecialFunctions samerng() = MersenneTwister(1) @@ -405,29 +406,29 @@ for N in (0,3), M in (0,4), T in (Int, Float32) try v = :v deriv = Calculus.differentiate(:($(fsym)($v)), v) - is_domain_err_func = fsym in DOMAIN_ERR_FUNCS - is_nanmath_func = fsym in ForwardDiff.NANMATH_FUNCS - is_unsupported_nested_func = fsym in UNSUPPORTED_NESTED_FUNCS - @eval begin - fdnum = $(is_domain_err_func ? FDNUM + 1 : FDNUM) - $(v) = ForwardDiff.value(fdnum) - $(test_approx_diffnums)($(fsym)(fdnum), ForwardDiff.Dual($(fsym)($v), $(deriv) * ForwardDiff.partials(fdnum))) - if $(is_nanmath_func) - $(test_approx_diffnums)(NaNMath.$(fsym)(fdnum), ForwardDiff.Dual(NaNMath.$(fsym)($v), $(deriv) * ForwardDiff.partials(fdnum))) - end - - if $(!(is_unsupported_nested_func)) - nested_fdnum = $(is_domain_err_func ? NESTED_FDNUM + 1 : NESTED_FDNUM) - $(v) = ForwardDiff.value(nested_fdnum) - $(test_approx_diffnums)($(fsym)(nested_fdnum), ForwardDiff.Dual($(fsym)($v), $(deriv) * ForwardDiff.partials(nested_fdnum))) - if $(is_nanmath_func) - $(test_approx_diffnums)(NaNMath.$(fsym)(nested_fdnum), ForwardDiff.Dual(NaNMath.$(fsym)($v), $(deriv) * ForwardDiff.partials(nested_fdnum))) + is_nanmath_func = in(fsym, ForwardDiff.NANMATH_FUNCS) + is_special_func = in(fsym, ForwardDiff.SPECIAL_FUNCS) + is_domain_err_func = in(fsym, DOMAIN_ERR_FUNCS) + is_unsupported_nested_func = in(fsym, UNSUPPORTED_NESTED_FUNCS) + tested_funcs = Vector{Expr}(0) + is_nanmath_func && push!(tested_funcs, :(NaNMath.$(fsym))) + is_special_func && push!(tested_funcs, :(SpecialFunctions.$(fsym))) + (!(is_special_func) || VERSION < v"0.6.0-dev.2767") && push!(tested_funcs, :(Base.$(fsym))) + for func in tested_funcs + @eval begin + fdnum = $(is_domain_err_func ? FDNUM + 1 : FDNUM) + $(v) = ForwardDiff.value(fdnum) + $(test_approx_diffnums)($(func)(fdnum), ForwardDiff.Dual($(func)($v), $(deriv) * ForwardDiff.partials(fdnum))) + if $(!(is_unsupported_nested_func)) + nested_fdnum = $(is_domain_err_func ? NESTED_FDNUM + 1 : NESTED_FDNUM) + $(v) = ForwardDiff.value(nested_fdnum) + $(test_approx_diffnums)($(func)(nested_fdnum), ForwardDiff.Dual($(func)($v), $(deriv) * ForwardDiff.partials(nested_fdnum))) end end end catch err warn("Encountered error when testing $(fsym)(::Dual):") - throw(err) + rethrow(err) end end end