From 6847f2d9b81be63d3e686e4cc1f70cb211fdb8e8 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 17 Oct 2019 11:49:33 +0100 Subject: [PATCH 1/3] Remove Casted --- src/ChainRulesCore.jl | 2 +- src/differential_arithmetic.jl | 15 ++------------- src/differentials.jl | 27 --------------------------- src/operations.jl | 5 +---- test/runtests.jl | 2 +- 5 files changed, 5 insertions(+), 46 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 118e7f841..c5f1622e3 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -5,7 +5,7 @@ export frule, rrule export wirtinger_conjugate, wirtinger_primal, refine_differential export @scalar_rule, @thunk export extern, cast, store! -export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk +export Wirtinger, Zero, One, DNE, Thunk, InplaceableThunk export NO_FIELDS include("differentials.jl") diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index e65748d34..0dd636b18 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -7,7 +7,7 @@ subtypes, as we know the full set that might be encountered. Thus we can avoid any ambiguities. Notice: - The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :AbstractThunk, :Any) + The precidence goes: (:Wirtinger, :Zero, :DNE, :One, :AbstractThunk, :Any) Thus each of the @eval loops creating definitions of + and * defines the combination this type with all types of lower precidence. This means each eval loops is 1 item smaller than the previous. @@ -36,7 +36,7 @@ function Base.:+(a::Wirtinger, b::Wirtinger) return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate) end -for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk, :Any) +for T in (:Zero, :DNE, :One, :AbstractThunk, :Any) @eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero()) @eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b @@ -45,17 +45,6 @@ for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk, :Any) end -Base.:+(a::Casted, b::Casted) = Casted(broadcasted(+, a.value, b.value)) -Base.:*(a::Casted, b::Casted) = Casted(broadcasted(*, a.value, b.value)) -for T in (:Zero, :DNE, :One, :AbstractThunk, :Any) - @eval Base.:+(a::Casted, b::$T) = Casted(broadcasted(+, a.value, b)) - @eval Base.:+(a::$T, b::Casted) = Casted(broadcasted(+, a, b.value)) - - @eval Base.:*(a::Casted, b::$T) = Casted(broadcasted(*, a.value, b)) - @eval Base.:*(a::$T, b::Casted) = Casted(broadcasted(*, a, b.value)) -end - - Base.:+(::Zero, b::Zero) = Zero() Base.:*(::Zero, ::Zero) = Zero() for T in (:DNE, :One, :AbstractThunk, :Any) diff --git a/src/differentials.jl b/src/differentials.jl index 5ad2f8818..1f048f6d6 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -86,33 +86,6 @@ Base.iterate(::Wirtinger, ::Any) = nothing # TODO: define `conj` for` `Wirtinger` Base.conj(x::Wirtinger) = throw(MethodError(conj, x)) - -##### -##### `Casted` -##### - -""" - Casted(v) - -This differential wraps another differential (including a number-like type) -to indicate that it should be lazily broadcast. -""" -struct Casted{V} <: AbstractDifferential - value::V -end - -cast(x) = Casted(x) -cast(f, args...) = Casted(broadcasted(f, args...)) - -extern(x::Casted) = materialize(broadcasted(extern, x.value)) - -Base.Broadcast.broadcastable(x::Casted) = x.value - -Base.iterate(x::Casted) = iterate(x.value) -Base.iterate(x::Casted, state) = iterate(x.value, state) - -Base.conj(x::Casted) = cast(conj, x.value) - ##### ##### `Zero` ##### diff --git a/src/operations.jl b/src/operations.jl index c60134e9f..da0bfc879 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -22,14 +22,11 @@ so it is best to always call this as `Δ = accumulate!(Δ, ∂)` just in-case. This function is overloadable by using a [`InplaceThunk`](@ref). See also: [`accumulate`](@ref), [`store!`](@ref). """ -function accumulate!(Δ, ∂) - return materialize!(Δ, broadcastable(cast(Δ) + ∂)) -end +accumulate!(Δ, ∂) = store!(Δ, accumulate(Δ, ∂)) accumulate!(Δ::Number, ∂) = accumulate(Δ, ∂) - """ store!(Δ, ∂) diff --git a/test/runtests.jl b/test/runtests.jl index 5f8552935..e0282a66b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ using ChainRulesCore using LinearAlgebra: Diagonal using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule, Wirtinger, wirtinger_primal, wirtinger_conjugate, - Zero, One, Casted, cast, DNE, Thunk + Zero, One, DNE, Thunk using Base.Broadcast: broadcastable @testset "ChainRulesCore" begin From 05af76b3b9339deec7f16d548f88d246a868721e Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 17 Oct 2019 12:57:05 +0100 Subject: [PATCH 2/3] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 938852959..ae6c77687 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.3.0" +version = "0.4.0" [compat] julia = "^1.0" From aececb0ea2adb69b28f2dbd54bde2f4eeb4b379c Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 17 Oct 2019 14:14:58 +0100 Subject: [PATCH 3/3] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ae6c77687..5943665ff 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.4.0" +version = "0.4.0-DEV" [compat] julia = "^1.0"