diff --git a/src/dual.jl b/src/dual.jl index d52c5736..03330c00 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -393,7 +393,7 @@ Base.float(d::Dual) = convert(float(typeof(d)), d) ################################### for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing) - if (M, f) in ((:Base, :^), (:NaNMath, :pow), (:Base, :/), (:Base, :+), (:Base, :-)) + if (M, f) in ((:Base, :^), (:NaNMath, :pow), (:Base, :/), (:Base, :+), (:Base, :-), (:Base, :sin), (:Base, :cos)) continue # Skip methods which we define elsewhere. elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f)) continue # Skip rules for methods not defined in the current scope @@ -622,12 +622,19 @@ end Dual{Tz}(muladd(x, y, value(z)), partials(z)) # z_body ) -# sincos # +# sin/cos # #--------# +function Base.sin(d::Dual{T}) where T + s, c = sincos(value(d)) + return Dual{T}(s, c * partials(d)) +end -@inline sincos(x) = (sin(x), cos(x)) +function Base.cos(d::Dual{T}) where T + s, c = sincos(value(d)) + return Dual{T}(c, -s * partials(d)) +end -@inline function sincos(d::Dual{T}) where T +@inline function Base.sincos(d::Dual{T}) where T sd, cd = sincos(value(d)) return (Dual{T}(sd, cd * partials(d)), Dual{T}(cd, -sd * partials(d))) end