diff --git a/Project.toml b/Project.toml index e0727aa64..f5273df1d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.10.1" +version = "0.10.2" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/docs/src/api.md b/docs/src/api.md index 1fffafd8f..fc6ac50ab 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -39,3 +39,8 @@ ChainRulesCore.is_inplaceable_destination ChainRulesCore.AbstractTangent ChainRulesCore.debug_mode ``` + +## Deprecated +```@docs +ChainRulesCore.extern +``` diff --git a/src/deprecated.jl b/src/deprecated.jl index ea7df0674..8e60d4073 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -1 +1,53 @@ Base.@deprecate_binding NO_FIELDS NoTangent() + +const EXTERN_DEPRECATION = "`extern` is deprecated, use `unthunk` or `backing` instead, " * + "depending on the use case." + +""" + extern(x) + +Makes a best effort attempt to convert a differential into a primal value. +This is not always a well-defined operation. +For two reasons: + - It may not be possible to determine the primal type for a given differential. + For example, `Zero` is a valid differential for any primal. + - The primal type might not be a vector space, thus might not be a valid differential type. + For example, if the primal type is `DateTime`, it's not a valid differential type as two + `DateTime` can not be added (fun fact: `Milisecond` is a differential for `DateTime`). + +Where it is defined the operation of `extern` for a primal type `P` should be +`extern(x) = zero(P) + x`. + +!!! note + Because of its limitations, `extern` should only really be used for testing. + It can be useful, if you know what you are getting out, as it recursively removes + thunks, and otherwise makes outputs more consistent with finite differencing. + + The more useful action in general is to call `+`, or in the case of a [`Thunk`](@ref) + to call [`unthunk`](@ref). + +!!! warning + `extern` may return an alias (not necessarily a copy) to data + wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself. +""" +@inline function extern(x) + Base.depwarn(EXTERN_DEPRECATION, :extern) + return x +end + +extern(x::ZeroTangent) = (Base.depwarn(EXTERN_DEPRECATION, :extern); return false) # false is a strong 0. E.g. `false * NaN = 0.0` + +function extern(x::NoTangent) + Base.depwarn(EXTERN_DEPRECATION, :extern) + throw(ArgumentError("Derivative does not exit. Cannot be converted to an external type.")) +end + +extern(comp::Tangent) = (Base.depwarn(EXTERN_DEPRECATION, :extern); return backing(map(extern, comp))) # gives a NamedTuple or Tuple + +extern(x::NotImplemented) = (Base.depwarn(EXTERN_DEPRECATION, :extern); throw(NotImplementedException(x))) + +@inline extern(x::AbstractThunk) = (Base.depwarn(EXTERN_DEPRECATION, :extern); return extern(unthunk(x))) + + + + diff --git a/src/differentials/abstract_differential.jl b/src/differentials/abstract_differential.jl index 9169ce45e..029d03af9 100644 --- a/src/differentials/abstract_differential.jl +++ b/src/differentials/abstract_differential.jl @@ -37,33 +37,4 @@ abstract type AbstractTangent end Base.:+(x::AbstractTangent) = x -""" - extern(x) - -Makes a best effort attempt to convert a differential into a primal value. -This is not always a well-defined operation. -For two reasons: - - It may not be possible to determine the primal type for a given differential. - For example, `Zero` is a valid differential for any primal. - - The primal type might not be a vector space, thus might not be a valid differential type. - For example, if the primal type is `DateTime`, it's not a valid differential type as two - `DateTime` can not be added (fun fact: `Milisecond` is a differential for `DateTime`). - -Where it is defined the operation of `extern` for a primal type `P` should be -`extern(x) = zero(P) + x`. - -!!! note - Because of its limitations, `extern` should only really be used for testing. - It can be useful, if you know what you are getting out, as it recursively removes - thunks, and otherwise makes outputs more consistent with finite differencing. - - The more useful action in general is to call `+`, or in the case of a [`Thunk`](@ref) - to call [`unthunk`](@ref). - -!!! warning - `extern` may return an alias (not necessarily a copy) to data - wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself. -""" -@inline extern(x) = x - @inline Base.conj(x::AbstractTangent) = x diff --git a/src/differentials/abstract_zero.jl b/src/differentials/abstract_zero.jl index 46be3b0ad..52ef38adb 100644 --- a/src/differentials/abstract_zero.jl +++ b/src/differentials/abstract_zero.jl @@ -35,8 +35,6 @@ A derivative of `ZeroTangent()` does not propagate through the primal function. """ struct ZeroTangent <: AbstractZero end -extern(x::ZeroTangent) = false # false is a strong 0. E.g. `false * NaN = 0.0` - Base.eltype(::Type{ZeroTangent}) = ZeroTangent Base.zero(::AbstractTangent) = ZeroTangent() @@ -72,7 +70,3 @@ arguments. ``` """ struct NoTangent <: AbstractZero end - -function extern(x::NoTangent) - throw(ArgumentError("Derivative does not exit. Cannot be converted to an external type.")) -end diff --git a/src/differentials/composite.jl b/src/differentials/composite.jl index 9fadf3e76..af1a05d1f 100644 --- a/src/differentials/composite.jl +++ b/src/differentials/composite.jl @@ -114,9 +114,6 @@ end Base.conj(comp::Tangent) = map(conj, comp) -extern(comp::Tangent) = backing(map(extern, comp)) # gives a NamedTuple or Tuple - - """ backing(x) diff --git a/src/differentials/notimplemented.jl b/src/differentials/notimplemented.jl index 6a5b8961f..a2044fbe1 100644 --- a/src/differentials/notimplemented.jl +++ b/src/differentials/notimplemented.jl @@ -39,8 +39,6 @@ Base.Broadcast.broadcastable(x::NotImplemented) = Ref(x) # throw error with debugging information for other standard information # (`+`, `-`, `*`, and `dot` are defined in differential_arithmetic.jl) -extern(x::NotImplemented) = throw(NotImplementedException(x)) - Base.:/(x::NotImplemented, ::Any) = throw(NotImplementedException(x)) Base.:/(::Any, x::NotImplemented) = throw(NotImplementedException(x)) Base.:/(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x)) diff --git a/src/differentials/thunks.jl b/src/differentials/thunks.jl index 545fb4835..efaebb0c7 100644 --- a/src/differentials/thunks.jl +++ b/src/differentials/thunks.jl @@ -30,13 +30,9 @@ end On `AbstractThunk`s this removes 1 layer of thunking. On any other type, it is the identity operation. - -In contrast to [`extern`](@ref) this is nonrecursive. """ @inline unthunk(x) = x -@inline extern(x::AbstractThunk) = extern(unthunk(x)) - Base.conj(x::AbstractThunk) = @thunk(conj(unthunk(x))) Base.adjoint(x::AbstractThunk) = @thunk(adjoint(unthunk(x))) Base.transpose(x::AbstractThunk) = @thunk(transpose(unthunk(x))) @@ -54,16 +50,11 @@ It wraps a zero argument closure that when invoked returns a differential. Calling a thunk, calls the wrapped closure. If you are unsure if you have a `Thunk`, call [`unthunk`](@ref) which is a no-op when the argument is not a `Thunk`. -If you need to unthunk recursively, call [`extern`](@ref), which also externs the differial -that the closure returns. ```jldoctest julia> t = @thunk(@thunk(3)) Thunk(var"#4#6"()) -julia> extern(t) -3 - julia> t() Thunk(var"#5#7"()) diff --git a/test/deprecated.jl b/test/deprecated.jl index 42555a3c9..7f7eb6c61 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -1,3 +1,24 @@ @testset "NO_FIELDS" begin @test (@test_deprecated NO_FIELDS) isa NoTangent end + +@testset "extern" begin + @test extern(@thunk(3)) == 3 + @test extern(@thunk(@thunk(3))) == 3 + + @test extern(Tangent{Foo}(x=2.0)) == (;x=2.0) + @test extern(Tangent{Tuple{Float64,}}(2.0)) == (2.0,) + @test extern(Tangent{Dict}(Dict(4 => 3))) == Dict(4 => 3) + + # with differentials on the inside + @test extern(Tangent{Foo}(x=@thunk(0+2.0))) == (;x=2.0) + @test extern(Tangent{Tuple{Float64,}}(@thunk(0+2.0))) == (2.0,) + @test extern(Tangent{Dict}(Dict(4 => @thunk(3)))) == Dict(4 => 3) + + z = ZeroTangent() + @test extern(z) === false + dne = NoTangent() + @test_throws Exception extern(dne) + E = ChainRulesCore.NotImplementedException + @test_throws E extern(ni) +end diff --git a/test/differentials/abstract_zero.jl b/test/differentials/abstract_zero.jl index 7a1adb577..9eb037651 100644 --- a/test/differentials/abstract_zero.jl +++ b/test/differentials/abstract_zero.jl @@ -6,7 +6,6 @@ @testset "ZeroTangent" begin z = ZeroTangent() - @test extern(z) === false @test z + z === z @test z + 1 === 1 @test 1 + z === 1 @@ -64,7 +63,6 @@ @testset "NoTangent" begin dne = NoTangent() - @test_throws Exception extern(dne) @test dne + dne == dne @test dne + 1 == 1 @test 1 + dne == 1 diff --git a/test/differentials/composite.jl b/test/differentials/composite.jl index 9ae2b15c5..37da31d47 100644 --- a/test/differentials/composite.jl +++ b/test/differentials/composite.jl @@ -123,17 +123,6 @@ end ) end - @testset "extern" begin - @test extern(Tangent{Foo}(x=2.0)) == (;x=2.0) - @test extern(Tangent{Tuple{Float64,}}(2.0)) == (2.0,) - @test extern(Tangent{Dict}(Dict(4 => 3))) == Dict(4 => 3) - - # with differentials on the inside - @test extern(Tangent{Foo}(x=@thunk(0+2.0))) == (;x=2.0) - @test extern(Tangent{Tuple{Float64,}}(@thunk(0+2.0))) == (2.0,) - @test extern(Tangent{Dict}(Dict(4 => @thunk(3)))) == Dict(4 => 3) - end - @testset "canonicalize" begin # Testing iterate via collect @test ==( diff --git a/test/differentials/notimplemented.jl b/test/differentials/notimplemented.jl index 7ed33350a..fe3668518 100644 --- a/test/differentials/notimplemented.jl +++ b/test/differentials/notimplemented.jl @@ -49,7 +49,6 @@ # unsupported operations E = ChainRulesCore.NotImplementedException - @test_throws E extern(ni) @test_throws E +ni @test_throws E -ni @test_throws E ni - rand() diff --git a/test/differentials/thunks.jl b/test/differentials/thunks.jl index 1bc79290d..2fb532fb0 100644 --- a/test/differentials/thunks.jl +++ b/test/differentials/thunks.jl @@ -6,11 +6,6 @@ @test occursin(r"Thunk\(.*rand.*\)", rep) end - @testset "Externing" begin - @test extern(@thunk(3)) == 3 - @test extern(@thunk(@thunk(3))) == 3 - end - @testset "unthunk" begin @test unthunk(@thunk(3)) == 3 @test unthunk(@thunk(@thunk(3))) isa Thunk @@ -25,7 +20,7 @@ expected_line = (@__LINE__) + 2 # for testing it is at right palce try x = @thunk(error()) - extern(x) + unthunk(x) catch err err isa ErrorException || rethrow() st = stacktrace(catch_backtrace())