diff --git a/Project.toml b/Project.toml index 6fd16dc2c..a81fc0d4d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.2.0" +version = "0.2.1-DEV" [compat] julia = "^1.0" diff --git a/src/differentials.jl b/src/differentials.jl index 28869122b..3f3e19d62 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -181,6 +181,24 @@ Base.iterate(::One, ::Any) = nothing Thunk(()->v) A thunk is a deferred computation. It wraps a zero argument closure that when invoked returns a differential. + +Calling that thunk, calls the wrapped closure. +`extern`ing thunks applies recursively, it also externs the differial that the closure returns. +If you do not want that, then simply call the thunk + +``` +julia> t = @thunk(@thunk(3)) +Thunk(var"##7#9"()) + +julia> extern(t) +3 + +julia> t() +Thunk(var"##8#10"()) + +julia> t()() +3 +``` """ struct Thunk{F} <: AbstractDifferential f::F @@ -190,7 +208,8 @@ macro thunk(body) return :(Thunk(() -> $(esc(body)))) end -@inline extern(x::Thunk) = x.f() +(x::Thunk)() = x.f() +@inline extern(x::Thunk) = extern(x()) Base.Broadcast.broadcastable(x::Thunk) = broadcastable(extern(x)) @@ -206,3 +225,5 @@ end end Base.conj(x::Thunk) = @thunk(conj(extern(x))) + +Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))") diff --git a/test/differentials.jl b/test/differentials.jl index 869663037..557c125d8 100644 --- a/test/differentials.jl +++ b/test/differentials.jl @@ -49,6 +49,25 @@ @test conj(o) == o end + @testset "Thunk" begin + @test @thunk(3) isa Thunk + + @testset "show" begin + rep = repr(Thunk(rand)) + @test occursin(r"Thunk\(.*rand.*\)", rep) + end + + @testset "Externing" begin + @test extern(@thunk(3)) == 3 + @test extern(@thunk(@thunk(3))) == 3 + end + + @testset "calling thunks should call inner function" begin + @test (@thunk(3))() == 3 + @test (@thunk(@thunk(3)))() isa Thunk + end + end + @testset "No ambiguities in $f" for f in (+, *) # We don't use `Test.detect_ambiguities` as we are only interested in # the +, and * operations. We also would catch any that are unrelated