-
Notifications
You must be signed in to change notification settings - Fork 36
Add rules for LogExpFunctions #69
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
cba9df4
1514bf4
06f4025
34c47b1
d34683a
b4438ff
95f1a61
be9d649
ebc9987
da0a416
e1c8cfd
1449b08
69cad0a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,5 @@ | ||
| name: TagBot | ||
| on: | ||
| schedule: | ||
| - cron: 0 * * * * | ||
| issue_comment: | ||
| types: | ||
| - created | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| *.jl.cov | ||
| *.jl.*.cov | ||
| *.jl.mem | ||
| /Manifest.toml |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,12 +1,20 @@ | ||
| using Documenter, DiffRules | ||
|
|
||
| DocMeta.setdocmeta!( | ||
| DiffRules, | ||
| :DocTestSetup, | ||
| :(using DiffRules); | ||
| recursive=true, | ||
| ) | ||
|
|
||
| makedocs(modules=[DiffRules], | ||
| doctest = false, | ||
| sitename = "DiffRules", | ||
| pages = ["Documentation" => "index.md"], | ||
| format = Documenter.HTML( | ||
| prettyurls = get(ENV, "CI", nothing) == "true" | ||
| ), | ||
| strict=true, | ||
| checkdocs=:exports, | ||
| ) | ||
|
|
||
| deploydocs(; repo="github.com/JuliaDiff/DiffRules.jl", push_preview=true) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,8 @@ __precompile__() | |
|
|
||
| module DiffRules | ||
|
|
||
| import LogExpFunctions | ||
|
|
||
| include("api.jl") | ||
| include("rules.jl") | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -232,3 +232,30 @@ end | |
| :(ifelse(($y > $x) | (signbit($y) < signbit($x)), ifelse(isnan($y), zero($y), one($y)), ifelse(isnan($x), one($y), zero($y)))) | ||
| @define_diffrule NaNMath.min(x, y) = :(ifelse(($y < $x) | (signbit($y) > signbit($x)), ifelse(isnan($y), one($x), zero($x)), ifelse(isnan($x), zero($x), one($x)))), | ||
| :(ifelse(($y < $x) | (signbit($y) > signbit($x)), ifelse(isnan($y), zero($y), one($y)), ifelse(isnan($x), one($x), zero($x)))) | ||
|
|
||
| ################### | ||
| # LogExpFunctions # | ||
| ################### | ||
|
|
||
| # unary | ||
| @define_diffrule LogExpFunctions.xlogx(x) = :(1 + log($x)) | ||
| @define_diffrule LogExpFunctions.logistic(x) = :(z = LogExpFunctions.logistic($x); z * (1 - z)) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe I missed it but it seems there is no way to reuse the result of the primal computation in DiffRules? |
||
| @define_diffrule LogExpFunctions.logit(x) = :(inv($x * (1 - $x))) | ||
| @define_diffrule LogExpFunctions.log1psq(x) = :(2 * $x / (1 + $x^2)) | ||
| @define_diffrule LogExpFunctions.log1pexp(x) = :(LogExpFunctions.logistic($x)) | ||
| @define_diffrule LogExpFunctions.log1mexp(x) = :(-exp($x - LogExpFunctions.log1mexp($x))) | ||
| @define_diffrule LogExpFunctions.log2mexp(x) = :(-exp($x - LogExpFunctions.log2mexp($x))) | ||
| @define_diffrule LogExpFunctions.logexpm1(x) = :(exp($x - LogExpFunctions.logexpm1($x))) | ||
|
|
||
| # binary | ||
| @define_diffrule LogExpFunctions.xlogy(x, y) = :(log($y)), :($x / $y) | ||
| @define_diffrule LogExpFunctions.logaddexp(x, y) = | ||
| :(exp($x - LogExpFunctions.logaddexp($x, $y))), :(exp($y - LogExpFunctions.logaddexp($x, $y))) | ||
| @define_diffrule LogExpFunctions.logsubexp(x, y) = | ||
| :(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? exp($x - z) : -exp($x - z)), | ||
| :(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? -exp($y - z) : exp($y - z)) | ||
|
|
||
| # only defined in LogExpFunctions >= 0.3.2 | ||
| if isdefined(LogExpFunctions, :xlog1py) | ||
| @define_diffrule LogExpFunctions.xlog1py(x, y) = :(log1p($y)), :($x / (1 + $y)) | ||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,49 +1,58 @@ | ||
| if VERSION < v"0.7-" | ||
| using Base.Test | ||
| srand(1) | ||
| else | ||
| using Test | ||
| import Random | ||
| Random.seed!(1) | ||
| end | ||
| import SpecialFunctions, NaNMath | ||
| using DiffRules | ||
| using Test | ||
|
|
||
| import SpecialFunctions, NaNMath, LogExpFunctions | ||
| import Random | ||
| Random.seed!(1) | ||
|
|
||
| function finitediff(f, x) | ||
| ϵ = cbrt(eps(typeof(x))) * max(one(typeof(x)), abs(x)) | ||
| return (f(x + ϵ) - f(x - ϵ)) / (ϵ + ϵ) | ||
| end | ||
|
|
||
| @testset "DiffRules" begin | ||
| @testset "check rules" begin | ||
|
|
||
| non_numeric_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ifelse, 3)] | ||
|
|
||
| for (M, f, arity) in DiffRules.diffrules() | ||
| for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) | ||
| (M, f, arity) ∈ non_numeric_arg_functions && continue | ||
| if arity == 1 | ||
| @test DiffRules.hasdiffrule(M, f, 1) | ||
| deriv = DiffRules.diffrule(M, f, :goo) | ||
| modifier = in(f, (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) ? 1 : 0 | ||
| modifier = if f in (:asec, :acsc, :asecd, :acscd, :acosh, :acoth) | ||
| 1.0 | ||
| elseif f === :log1mexp | ||
| -1.0 | ||
| elseif f === :log2mexp | ||
| -0.5 | ||
| else | ||
| 0.0 | ||
| end | ||
| @eval begin | ||
| goo = rand() + $modifier | ||
| @test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05) | ||
| # test for 2pi functions | ||
| if "mod2pi" == string($M.$f) | ||
| goo = 4pi + $modifier | ||
| @test NaN === $deriv | ||
| let | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added the |
||
| goo = rand() + $modifier | ||
| @test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05) | ||
| # test for 2pi functions | ||
| if "mod2pi" == string($M.$f) | ||
| goo = 4pi + $modifier | ||
| @test NaN === $deriv | ||
| end | ||
| end | ||
| end | ||
| elseif arity == 2 | ||
| @test DiffRules.hasdiffrule(M, f, 2) | ||
| derivs = DiffRules.diffrule(M, f, :foo, :bar) | ||
| @eval begin | ||
| foo, bar = rand(1:10), rand() | ||
| dx, dy = $(derivs[1]), $(derivs[2]) | ||
| if !(isnan(dx)) | ||
| @test isapprox(dx, finitediff(z -> $M.$f(z, bar), float(foo)), rtol=0.05) | ||
| end | ||
| if !(isnan(dy)) | ||
| @test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05) | ||
| let | ||
| foo, bar = rand(1:10), rand() | ||
| dx, dy = $(derivs[1]), $(derivs[2]) | ||
| if !(isnan(dx)) | ||
| @test isapprox(dx, finitediff(z -> $M.$f(z, bar), float(foo)), rtol=0.05) | ||
| end | ||
| if !(isnan(dy)) | ||
| @test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05) | ||
| end | ||
| end | ||
| end | ||
| elseif arity == 3 | ||
|
|
@@ -72,14 +81,29 @@ derivs = DiffRules.diffrule(:Base, :rem2pi, :x, :y) | |
| for xtype in [:Float64, :BigFloat, :Int64] | ||
| for mode in [:RoundUp, :RoundDown, :RoundToZero, :RoundNearest] | ||
| @eval begin | ||
| x = $xtype(rand(1 : 10)) | ||
| y = $mode | ||
| dx, dy = $(derivs[1]), $(derivs[2]) | ||
| @test isapprox(dx, finitediff(z -> rem2pi(z, y), float(x)), rtol=0.05) | ||
| @test isnan(dy) | ||
| let | ||
| x = $xtype(rand(1 : 10)) | ||
| y = $mode | ||
| dx, dy = $(derivs[1]), $(derivs[2]) | ||
| @test isapprox(dx, finitediff(z -> rem2pi(z, y), float(x)), rtol=0.05) | ||
| @test isnan(dy) | ||
| end | ||
| end | ||
| end | ||
| end | ||
| end | ||
|
|
||
| @testset "diffrules" begin | ||
| rules = @test_deprecated(DiffRules.diffrules()) | ||
| @test Set(M for (M, _, _) in rules) == Set((:Base, :SpecialFunctions, :NaNMath)) | ||
|
|
||
| rules = DiffRules.diffrules(; filter_modules=nothing) | ||
| @test Set(M for (M, _, _) in rules) == Set((:Base, :SpecialFunctions, :NaNMath, :LogExpFunctions)) | ||
|
|
||
| rules = DiffRules.diffrules(; filter_modules=(:Base, :LogExpFunctions)) | ||
| @test Set(M for (M, _, _) in rules) == Set((:Base, :LogExpFunctions)) | ||
| end | ||
| end | ||
|
|
||
| # Test ifelse separately as first argument is boolean | ||
| #= | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@andreasnoack I forgot to remove these lines in #70, seems I copied the updated file incorrectly.
schedule:is not needed anymore, so removing this trigger will reduce the amount of runs.All other CI changes were addressed in #70.