Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.3.0"
version = "0.4.0-DEV"

[compat]
julia = "^1.0"
Expand Down
2 changes: 1 addition & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ export frule, rrule
export wirtinger_conjugate, wirtinger_primal, refine_differential
export @scalar_rule, @thunk
export extern, cast, store!
export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk
export Wirtinger, Zero, One, DNE, Thunk, InplaceableThunk
export NO_FIELDS

include("differentials.jl")
Expand Down
15 changes: 2 additions & 13 deletions src/differential_arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ subtypes, as we know the full set that might be encountered.
Thus we can avoid any ambiguities.

Notice:
The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
The precidence goes: (:Wirtinger, :Zero, :DNE, :One, :AbstractThunk, :Any)
Thus each of the @eval loops creating definitions of + and *
defines the combination this type with all types of lower precidence.
This means each eval loops is 1 item smaller than the previous.
Expand Down Expand Up @@ -36,7 +36,7 @@ function Base.:+(a::Wirtinger, b::Wirtinger)
return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate)
end

for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
for T in (:Zero, :DNE, :One, :AbstractThunk, :Any)
@eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero())
@eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b

Expand All @@ -45,17 +45,6 @@ for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
end


Base.:+(a::Casted, b::Casted) = Casted(broadcasted(+, a.value, b.value))
Base.:*(a::Casted, b::Casted) = Casted(broadcasted(*, a.value, b.value))
for T in (:Zero, :DNE, :One, :AbstractThunk, :Any)
@eval Base.:+(a::Casted, b::$T) = Casted(broadcasted(+, a.value, b))
@eval Base.:+(a::$T, b::Casted) = Casted(broadcasted(+, a, b.value))

@eval Base.:*(a::Casted, b::$T) = Casted(broadcasted(*, a.value, b))
@eval Base.:*(a::$T, b::Casted) = Casted(broadcasted(*, a, b.value))
end


Base.:+(::Zero, b::Zero) = Zero()
Base.:*(::Zero, ::Zero) = Zero()
for T in (:DNE, :One, :AbstractThunk, :Any)
Expand Down
27 changes: 0 additions & 27 deletions src/differentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,33 +86,6 @@ Base.iterate(::Wirtinger, ::Any) = nothing
# TODO: define `conj` for` `Wirtinger`
Base.conj(x::Wirtinger) = throw(MethodError(conj, x))


#####
##### `Casted`
#####

"""
Casted(v)

This differential wraps another differential (including a number-like type)
to indicate that it should be lazily broadcast.
"""
struct Casted{V} <: AbstractDifferential
value::V
end

cast(x) = Casted(x)
cast(f, args...) = Casted(broadcasted(f, args...))

extern(x::Casted) = materialize(broadcasted(extern, x.value))

Base.Broadcast.broadcastable(x::Casted) = x.value

Base.iterate(x::Casted) = iterate(x.value)
Base.iterate(x::Casted, state) = iterate(x.value, state)

Base.conj(x::Casted) = cast(conj, x.value)

#####
##### `Zero`
#####
Expand Down
5 changes: 1 addition & 4 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,11 @@ so it is best to always call this as `Δ = accumulate!(Δ, ∂)` just in-case.
This function is overloadable by using a [`InplaceThunk`](@ref).
See also: [`accumulate`](@ref), [`store!`](@ref).
"""
function accumulate!(Δ, ∂)
return materialize!(Δ, broadcastable(cast(Δ) + ∂))
end
accumulate!(Δ, ∂) = store!(Δ, accumulate(Δ, ∂))

accumulate!(Δ::Number, ∂) = accumulate(Δ, ∂)



"""
store!(Δ, ∂)

Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ChainRulesCore
using LinearAlgebra: Diagonal
using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule,
Wirtinger, wirtinger_primal, wirtinger_conjugate,
Zero, One, Casted, cast, DNE, Thunk
Zero, One, DNE, Thunk
using Base.Broadcast: broadcastable

@testset "ChainRulesCore" begin
Expand Down