Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ export frule, rrule
export wirtinger_conjugate, wirtinger_primal, refine_differential
export @scalar_rule, @thunk
export extern, cast, store!
export Wirtinger, Zero, One, DNE, Thunk, InplaceableThunk
export Wirtinger, Zero, One, DoesNotExist, Thunk, InplaceableThunk
export NO_FIELDS

include("differentials.jl")
include("differential_arithmetic.jl")
include("operations.jl")
include("rules.jl")
include("rule_definition_tools.jl")

Base.@deprecate_binding DNE DoesNotExist

end # module
18 changes: 9 additions & 9 deletions src/differential_arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ subtypes, as we know the full set that might be encountered.
Thus we can avoid any ambiguities.

Notice:
The precidence goes: (:Wirtinger, :Zero, :DNE, :One, :AbstractThunk, :Any)
The precidence goes: (:Wirtinger, :Zero, :DoesNotExist, :One, :AbstractThunk, :Any)
Thus each of the @eval loops creating definitions of + and *
defines the combination this type with all types of lower precidence.
This means each eval loops is 1 item smaller than the previous.
Expand Down Expand Up @@ -36,7 +36,7 @@ function Base.:+(a::Wirtinger, b::Wirtinger)
return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate)
end

for T in (:Zero, :DNE, :One, :AbstractThunk, :Any)
for T in (:Zero, :DoesNotExist, :One, :AbstractThunk, :Any)
@eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero())
@eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b

Expand All @@ -47,7 +47,7 @@ end

Base.:+(::Zero, b::Zero) = Zero()
Base.:*(::Zero, ::Zero) = Zero()
for T in (:DNE, :One, :AbstractThunk, :Any)
for T in (:DoesNotExist, :One, :AbstractThunk, :Any)
@eval Base.:+(::Zero, b::$T) = b
@eval Base.:+(a::$T, ::Zero) = a

Expand All @@ -56,14 +56,14 @@ for T in (:DNE, :One, :AbstractThunk, :Any)
end


Base.:+(::DNE, ::DNE) = DNE()
Base.:*(::DNE, ::DNE) = DNE()
Base.:+(::DoesNotExist, ::DoesNotExist) = DoesNotExist()
Base.:*(::DoesNotExist, ::DoesNotExist) = DoesNotExist()
for T in (:One, :AbstractThunk, :Any)
@eval Base.:+(::DNE, b::$T) = b
@eval Base.:+(a::$T, ::DNE) = a
@eval Base.:+(::DoesNotExist, b::$T) = b
@eval Base.:+(a::$T, ::DoesNotExist) = a

@eval Base.:*(::DNE, ::$T) = DNE()
@eval Base.:*(::$T, ::DNE) = DNE()
@eval Base.:*(::DoesNotExist, ::$T) = DoesNotExist()
@eval Base.:*(::$T, ::DoesNotExist) = DoesNotExist()
end


Expand Down
16 changes: 8 additions & 8 deletions src/differentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,26 +106,26 @@ Base.iterate(::Zero, ::Any) = nothing


#####
##### `DNE`
##### `DoesNotExist`
#####

"""
DNE()
DoesNotExist()

This differential indicates that the derivative Does Not Exist (D.N.E).
This is not the cast that it is not implemented, but rather that it mathematically
is not defined.
"""
struct DNE <: AbstractDifferential end
struct DoesNotExist <: AbstractDifferential end

function extern(x::DNE)
function extern(x::DoesNotExist)
throw(ArgumentError("Derivative does not exit. Cannot be converted to an external type."))
end

Base.Broadcast.broadcastable(::DNE) = Ref(DNE())
Base.Broadcast.broadcastable(::DoesNotExist) = Ref(DoesNotExist())

Base.iterate(x::DNE) = (x, nothing)
Base.iterate(::DNE, ::Any) = nothing
Base.iterate(x::DoesNotExist) = (x, nothing)
Base.iterate(::DoesNotExist, ::Any) = nothing

#####
##### `One`
Expand Down Expand Up @@ -270,7 +270,7 @@ Constant for the reverse-mode derivative with respect to a structure that has no
The most notable use for this is for the reverse-mode derivative with respect to the
function itself, when that function is not a closure.
"""
const NO_FIELDS = DNE()
const NO_FIELDS = DoesNotExist()

"""
refine_differential(𝒟::Type, der)
Expand Down
2 changes: 1 addition & 1 deletion test/differentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
@test refine_differential(typeof([1.2]), Wirtinger(2,2)) == 4

# For most differentials, in most domains, this does nothing
for der in (DNE(), @thunk(23), @thunk(Wirtinger(2,2)), [1 2], One(), Zero(), 0.0)
for der in (DoesNotExist(), @thunk(23), @thunk(Wirtinger(2,2)), [1 2], One(), Zero(), 0.0)
for 𝒟 in typeof.((1.0 + 1im, [1.0 + 1im], 1.2, [1.2]))
@test refine_differential(𝒟, der) === der
end
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ChainRulesCore
using LinearAlgebra: Diagonal
using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule,
Wirtinger, wirtinger_primal, wirtinger_conjugate,
Zero, One, DNE, Thunk
Zero, One, DoesNotExist, Thunk
using Base.Broadcast: broadcastable

@testset "ChainRulesCore" begin
Expand Down