Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/TagBot.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
name: TagBot
on:
schedule:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@andreasnoack I forgot to remove these lines in #70, seems I copied the updated file incorrectly. schedule: is not needed anymore, so removing this trigger will reduce the amount of runs.

All other CI changes were addressed in #70.

- cron: 0 * * * *
issue_comment:
types:
- created
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
*.jl.cov
*.jl.*.cov
*.jl.mem
/Manifest.toml
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.3.1"

[deps]
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
LogExpFunctions = "0.3"
NaNMath = "0.3"
SpecialFunctions = "0.8, 0.9, 0.10, 1.0"
julia = "1"
Expand Down
10 changes: 9 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
using Documenter, DiffRules

DocMeta.setdocmeta!(
DiffRules,
:DocTestSetup,
:(using DiffRules);
recursive=true,
)

makedocs(modules=[DiffRules],
doctest = false,
sitename = "DiffRules",
pages = ["Documentation" => "index.md"],
format = Documenter.HTML(
prettyurls = get(ENV, "CI", nothing) == "true"
),
strict=true,
checkdocs=:exports,
)

deploydocs(; repo="github.com/JuliaDiff/DiffRules.jl", push_preview=true)
2 changes: 2 additions & 0 deletions src/DiffRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ __precompile__()

module DiffRules

import LogExpFunctions

include("api.jl")
include("rules.jl")

Expand Down
142 changes: 105 additions & 37 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ interpolated wherever they are used on the RHS.

Note that differentiation rules are purely symbolic, so no type annotations should be used.

Examples:

@define_diffrule Base.cos(x) = :(-sin(\$x))
@define_diffrule Base.:/(x, y) = :(inv(\$y)), :(-\$x / (\$y^2))
@define_diffrule Base.polygamma(m, x) = :NaN, :(polygamma(\$m + 1, \$x))
# Examples

```julia
@define_diffrule Base.cos(x) = :(-sin(\$x))
@define_diffrule Base.:/(x, y) = :(inv(\$y)), :(-\$x / (\$y^2))
@define_diffrule Base.polygamma(m, x) = :NaN, :(polygamma(\$m + 1, \$x))
```
"""
macro define_diffrule(def)
@assert isa(def, Expr) && def.head == :(=) "Diff rule expression does not have a left and right side"
Expand Down Expand Up @@ -50,19 +51,18 @@ interpolated into the returned expression.
In the `n`-ary case, an `n`-tuple of expressions will be returned where the `i`th expression
is the derivative of `f` w.r.t the `i`th argument.

Examples:

julia> DiffRules.diffrule(:Base, :sin, 1)
:(cos(1))
# Examples

julia> DiffRules.diffrule(:Base, :sin, :x)
:(cos(x))
```jldoctest
julia> DiffRules.diffrule(:Base, :sin, 1)
:(cos(1))

julia> DiffRules.diffrule(:Base, :sin, :(x * y^2))
:(cos(x * y ^ 2))
julia> DiffRules.diffrule(:Base, :sin, :x)
:(cos(x))

julia> DiffRules.diffrule(:Base, :^, :(x + 2), :c)
(:(c * (x + 2) ^ (c - 1)), :((x + 2) ^ c * log(x + 2)))
julia> DiffRules.diffrule(:Base, :sin, :(x * y^2))
:(cos(x * y ^ 2))
```
"""
diffrule(M::Union{Expr,Symbol}, f::Symbol, args...) = DEFINED_DIFFRULES[M,f,length(args)](args...)

Expand All @@ -74,41 +74,109 @@ otherwise.

Here, `arity` refers to the number of arguments accepted by `f`.

Examples:
# Examples

julia> DiffRules.hasdiffrule(:Base, :sin, 1)
true
```jldoctest
julia> DiffRules.hasdiffrule(:Base, :sin, 1)
true

julia> DiffRules.hasdiffrule(:Base, :sin, 2)
false
julia> DiffRules.hasdiffrule(:Base, :sin, 2)
false

julia> DiffRules.hasdiffrule(:Base, :-, 1)
true
julia> DiffRules.hasdiffrule(:Base, :-, 1)
true

julia> DiffRules.hasdiffrule(:Base, :-, 2)
true
julia> DiffRules.hasdiffrule(:Base, :-, 2)
true

julia> DiffRules.hasdiffrule(:Base, :-, 3)
false
julia> DiffRules.hasdiffrule(:Base, :-, 3)
false
```
"""
hasdiffrule(M::Union{Expr,Symbol}, f::Symbol, arity::Int) = haskey(DEFINED_DIFFRULES, (M, f, arity))

# show a deprecation warning if `filter_modules` in `diffrules()` is specified implicitly
# we use a custom singleton to figure out if the keyword argument was set explicitly
struct DefaultFilterModules end

function deprecated_modules(modules)
return if modules isa DefaultFilterModules
Base.depwarn(
"the implicit keyword argument " *
"`filter_modules=(:Base, :SpecialFunctions, :NaNMath)` in `diffrules()` is " *
"deprecated and will be changed to `filter_modules=nothing` in an upcoming " *
"breaking release of DiffRules (i.e., `diffrules()` will return all rules " *
"defined in DiffRules)",
:diffrules,
)
(:Base, :SpecialFunctions, :NaNMath)
else
modules
end
end

"""
diffrules()
diffrules(; filter_modules=(:Base, :SpecialFunctions, :NaNMath))

Return a list of keys that can be used to access all defined differentiation rules.
Return a list of keys that can be used to access all defined differentiation rules for
modules in `filter_modules`.

Each key is of the form `(M::Symbol, f::Symbol, arity::Int)`.

Here, `arity` refers to the number of arguments accepted by `f`.

Examples:

julia> first(DiffRules.diffrules())
(:Base, :asind, 1)

Here, `arity` refers to the number of arguments accepted by `f` and `M` is one of the
modules in `filter_modules`.

To include all rules, specify `filter_modules = nothing`.

!!! note
Calling `diffrules()` with the implicit default keyword argument `filter_modules`
does *not* return all rules defined by this package but rather only rules for the
packages for which DiffRules 1.0 provided rules. This is done in order to not to
break downstream packages that assumed this list would never change.
It is planned to change `diffrules()` to return all rules, i.e., to use the
default keyword argument `filter_modules=nothing`, in an upcoming breaking release
of DiffRules.

# Examples

```jldoctest
julia> first(DiffRules.diffrules())
(:Base, :log2, 1)
```

If you call `diffrules()`, only rules for Base, SpecialFunctions, and
NaNMath are returned but no rules for LogExpFunctions:
```jldoctest
julia> any(M === :LogExpFunctions for (M, _, _) in DiffRules.diffrules())
false
```

If you set `filter_modules=nothing`, all rules defined in DiffRules are
returned and in particular also rules for LogExpFunctions:
```jldoctest
julia> any(
M === :LogExpFunctions
for (M, _, _) in DiffRules.diffrules(; filter_modules=nothing)
)
true
```

If you set `filter_modules=(:Base,)` only rules for functions in Base are
returned:
```jldoctest
julia> all(M === :Base for (M, _, _) in DiffRules.diffrules(; filter_modules=(:Base,)))
true
```
"""
diffrules() = keys(DEFINED_DIFFRULES)
function diffrules(; filter_modules=DefaultFilterModules())
modules = deprecated_modules(filter_modules)
return if modules === nothing
keys(DEFINED_DIFFRULES)
else
Iterators.filter(keys(DEFINED_DIFFRULES)) do (M, _, _)
return M in modules
end
end
end

# For v0.6 and v0.7 compatibility, need to support having the diff rule function enter as a
# `Expr(:quote...)` and a `QuoteNode`. When v0.6 support is dropped, the function will
Expand Down
27 changes: 27 additions & 0 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,30 @@ end
:(ifelse(($y > $x) | (signbit($y) < signbit($x)), ifelse(isnan($y), zero($y), one($y)), ifelse(isnan($x), one($y), zero($y))))
@define_diffrule NaNMath.min(x, y) = :(ifelse(($y < $x) | (signbit($y) > signbit($x)), ifelse(isnan($y), one($x), zero($x)), ifelse(isnan($x), zero($x), one($x)))),
:(ifelse(($y < $x) | (signbit($y) > signbit($x)), ifelse(isnan($y), zero($y), one($y)), ifelse(isnan($x), one($x), zero($x))))

###################
# LogExpFunctions #
###################

# unary
@define_diffrule LogExpFunctions.xlogx(x) = :(1 + log($x))
@define_diffrule LogExpFunctions.logistic(x) = :(z = LogExpFunctions.logistic($x); z * (1 - z))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I missed it but it seems there is no way to reuse the result of the primal computation in DiffRules?

@define_diffrule LogExpFunctions.logit(x) = :(inv($x * (1 - $x)))
@define_diffrule LogExpFunctions.log1psq(x) = :(2 * $x / (1 + $x^2))
@define_diffrule LogExpFunctions.log1pexp(x) = :(LogExpFunctions.logistic($x))
@define_diffrule LogExpFunctions.log1mexp(x) = :(-exp($x - LogExpFunctions.log1mexp($x)))
@define_diffrule LogExpFunctions.log2mexp(x) = :(-exp($x - LogExpFunctions.log2mexp($x)))
@define_diffrule LogExpFunctions.logexpm1(x) = :(exp($x - LogExpFunctions.logexpm1($x)))

# binary
@define_diffrule LogExpFunctions.xlogy(x, y) = :(log($y)), :($x / $y)
@define_diffrule LogExpFunctions.logaddexp(x, y) =
:(exp($x - LogExpFunctions.logaddexp($x, $y))), :(exp($y - LogExpFunctions.logaddexp($x, $y)))
@define_diffrule LogExpFunctions.logsubexp(x, y) =
:(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? exp($x - z) : -exp($x - z)),
:(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? -exp($y - z) : exp($y - z))

# only defined in LogExpFunctions >= 0.3.2
if isdefined(LogExpFunctions, :xlog1py)
@define_diffrule LogExpFunctions.xlog1py(x, y) = :(log1p($y)), :($x / (1 + $y))
end
82 changes: 53 additions & 29 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,49 +1,58 @@
if VERSION < v"0.7-"
using Base.Test
srand(1)
else
using Test
import Random
Random.seed!(1)
end
import SpecialFunctions, NaNMath
using DiffRules
using Test

import SpecialFunctions, NaNMath, LogExpFunctions
import Random
Random.seed!(1)

function finitediff(f, x)
ϵ = cbrt(eps(typeof(x))) * max(one(typeof(x)), abs(x))
return (f(x + ϵ) - f(x - ϵ)) / (ϵ + ϵ)
end

@testset "DiffRules" begin
@testset "check rules" begin

non_numeric_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ifelse, 3)]

for (M, f, arity) in DiffRules.diffrules()
for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing)
(M, f, arity) ∈ non_numeric_arg_functions && continue
if arity == 1
@test DiffRules.hasdiffrule(M, f, 1)
deriv = DiffRules.diffrule(M, f, :goo)
modifier = in(f, (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) ? 1 : 0
modifier = if f in (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)
1.0
elseif f === :log1mexp
-1.0
elseif f === :log2mexp
-0.5
else
0.0
end
@eval begin
goo = rand() + $modifier
@test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05)
# test for 2pi functions
if "mod2pi" == string($M.$f)
goo = 4pi + $modifier
@test NaN === $deriv
let
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the let blocks to fix some local/global scope warnings.

goo = rand() + $modifier
@test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05)
# test for 2pi functions
if "mod2pi" == string($M.$f)
goo = 4pi + $modifier
@test NaN === $deriv
end
end
end
elseif arity == 2
@test DiffRules.hasdiffrule(M, f, 2)
derivs = DiffRules.diffrule(M, f, :foo, :bar)
@eval begin
foo, bar = rand(1:10), rand()
dx, dy = $(derivs[1]), $(derivs[2])
if !(isnan(dx))
@test isapprox(dx, finitediff(z -> $M.$f(z, bar), float(foo)), rtol=0.05)
end
if !(isnan(dy))
@test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05)
let
foo, bar = rand(1:10), rand()
dx, dy = $(derivs[1]), $(derivs[2])
if !(isnan(dx))
@test isapprox(dx, finitediff(z -> $M.$f(z, bar), float(foo)), rtol=0.05)
end
if !(isnan(dy))
@test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05)
end
end
end
elseif arity == 3
Expand Down Expand Up @@ -72,14 +81,29 @@ derivs = DiffRules.diffrule(:Base, :rem2pi, :x, :y)
for xtype in [:Float64, :BigFloat, :Int64]
for mode in [:RoundUp, :RoundDown, :RoundToZero, :RoundNearest]
@eval begin
x = $xtype(rand(1 : 10))
y = $mode
dx, dy = $(derivs[1]), $(derivs[2])
@test isapprox(dx, finitediff(z -> rem2pi(z, y), float(x)), rtol=0.05)
@test isnan(dy)
let
x = $xtype(rand(1 : 10))
y = $mode
dx, dy = $(derivs[1]), $(derivs[2])
@test isapprox(dx, finitediff(z -> rem2pi(z, y), float(x)), rtol=0.05)
@test isnan(dy)
end
end
end
end
end

@testset "diffrules" begin
rules = @test_deprecated(DiffRules.diffrules())
@test Set(M for (M, _, _) in rules) == Set((:Base, :SpecialFunctions, :NaNMath))

rules = DiffRules.diffrules(; filter_modules=nothing)
@test Set(M for (M, _, _) in rules) == Set((:Base, :SpecialFunctions, :NaNMath, :LogExpFunctions))

rules = DiffRules.diffrules(; filter_modules=(:Base, :LogExpFunctions))
@test Set(M for (M, _, _) in rules) == Set((:Base, :LogExpFunctions))
end
end

# Test ifelse separately as first argument is boolean
#=
Expand Down