From 604bf784cf8b1ab0a46d3055aebacde9d39576cf Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 6 Sep 2019 12:10:18 +0100 Subject: [PATCH 1/5] make extern(::Thunk) recursive, and improve show(::Thunk) --- src/differentials.jl | 23 ++++++++++++++++++++++- test/differentials.jl | 19 +++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/differentials.jl b/src/differentials.jl index 28869122b..3f3e19d62 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -181,6 +181,24 @@ Base.iterate(::One, ::Any) = nothing Thunk(()->v) A thunk is a deferred computation. It wraps a zero argument closure that when invoked returns a differential. + +Calling that thunk, calls the wrapped closure. +`extern`ing thunks applies recursively, it also externs the differial that the closure returns. +If you do not want that, then simply call the thunk + +``` +julia> t = @thunk(@thunk(3)) +Thunk(var"##7#9"()) + +julia> extern(t) +3 + +julia> t() +Thunk(var"##8#10"()) + +julia> t()() +3 +``` """ struct Thunk{F} <: AbstractDifferential f::F @@ -190,7 +208,8 @@ macro thunk(body) return :(Thunk(() -> $(esc(body)))) end -@inline extern(x::Thunk) = x.f() +(x::Thunk)() = x.f() +@inline extern(x::Thunk) = extern(x()) Base.Broadcast.broadcastable(x::Thunk) = broadcastable(extern(x)) @@ -206,3 +225,5 @@ end end Base.conj(x::Thunk) = @thunk(conj(extern(x))) + +Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))") diff --git a/test/differentials.jl b/test/differentials.jl index 869663037..1085f1b4e 100644 --- a/test/differentials.jl +++ b/test/differentials.jl @@ -49,6 +49,25 @@ @test conj(o) == o end + @testset "Thunk" begin + @test @thunk(3) isa Thunk + + @testset "show" begin + rep = repr(Thunk(rand)) + @test occursin(r"Thunk\(.*rand.*\)", rep) + end + + @testset "Externing" begin + @test extern(@thunk(3)) == 3 + @test extern(@thunk(@thunk(3))) == 3 + end + + @testset "calling thunks should call inner function" begin + @test (@thunk(3))() == 3 + @test (@thunk(@thunk(3)))() isa Thunk + end + end + @testset "No ambiguities in $f" for f in (+, *) # We don't use `Test.detect_ambiguities` as we are only interested in # the +, and * operations. We also would catch any that are unrelated From b6f3be1794dd424f57715a357d1dafb26ca9b37d Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 6 Sep 2019 12:12:46 +0100 Subject: [PATCH 2/5] bump version (breaking) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 6fd16dc2c..938852959 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.2.0" +version = "0.3.0" [compat] julia = "^1.0" From 85b0a9ab579f225d607c8d7d75d4feb07c0666d3 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 6 Sep 2019 13:00:45 +0100 Subject: [PATCH 3/5] make this a unreleased version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 938852959..af8d5996d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.3.0" +version = "0.3.0-DEV" [compat] julia = "^1.0" From 0f5aecf31a61f3a2e087e8f00757290ab97d706e Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 6 Sep 2019 13:16:20 +0100 Subject: [PATCH 4/5] mark nonbreaking --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index af8d5996d..a81fc0d4d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.3.0-DEV" +version = "0.2.1-DEV" [compat] julia = "^1.0" From 5d48459193fceca06146cd4cba6da9e9c271b87e Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 6 Sep 2019 16:45:36 +0100 Subject: [PATCH 5/5] fix tabs --- test/differentials.jl | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/test/differentials.jl b/test/differentials.jl index 1085f1b4e..557c125d8 100644 --- a/test/differentials.jl +++ b/test/differentials.jl @@ -49,24 +49,24 @@ @test conj(o) == o end - @testset "Thunk" begin - @test @thunk(3) isa Thunk + @testset "Thunk" begin + @test @thunk(3) isa Thunk - @testset "show" begin - rep = repr(Thunk(rand)) - @test occursin(r"Thunk\(.*rand.*\)", rep) - end + @testset "show" begin + rep = repr(Thunk(rand)) + @test occursin(r"Thunk\(.*rand.*\)", rep) + end - @testset "Externing" begin - @test extern(@thunk(3)) == 3 - @test extern(@thunk(@thunk(3))) == 3 - end + @testset "Externing" begin + @test extern(@thunk(3)) == 3 + @test extern(@thunk(@thunk(3))) == 3 + end - @testset "calling thunks should call inner function" begin - @test (@thunk(3))() == 3 - @test (@thunk(@thunk(3)))() isa Thunk - end - end + @testset "calling thunks should call inner function" begin + @test (@thunk(3))() == 3 + @test (@thunk(@thunk(3)))() isa Thunk + end + end @testset "No ambiguities in $f" for f in (+, *) # We don't use `Test.detect_ambiguities` as we are only interested in