Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.45"
version = "0.10.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
4 changes: 2 additions & 2 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ version = "0.9.44"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "0900bc19193b8e672d9cd477e6cd92d9e7c02f99"
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.29.0"
version = "3.30.0"

[[Dates]]
deps = ["Printf"]
Expand Down
10 changes: 5 additions & 5 deletions docs/src/arrays.md
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ function rrule(::typeof(*), A::Matrix{<:RealOrComplex}, B::Matrix{<:RealOrComple
function times_pullback(ΔΩ)
∂A = @thunk(ΔΩ * B')
∂B = @thunk(A' * ΔΩ)
return (NO_FIELDS, ∂A, ∂B)
return (NoTangent(), ∂A, ∂B)
end
return A * B, times_pullback
end
Expand Down Expand Up @@ -398,7 +398,7 @@ function rrule(::typeof(inv), A::Matrix{<:RealOrComplex})
Ω = inv(A)
function inv_pullback(ΔΩ)
∂A = -Ω' * ΔΩ * Ω'
return (NO_FIELDS, ∂A)
return (NoTangent(), ∂A)
end
return Ω, inv_pullback
end
Expand Down Expand Up @@ -497,7 +497,7 @@ function rrule(::typeof(sum), ::typeof(abs2), X::Array{<:RealOrComplex}; dims =
function sum_abs2_pullback(ΔΩ)
∂abs2 = NoTangent()
∂X = @thunk(2 .* real.(ΔΩ) .* X)
return (NO_FIELDS, ∂abs2, ∂X)
return (NoTangent(), ∂abs2, ∂X)
end
return sum(abs2, X; dims = dims), sum_abs2_pullback
end
Expand Down Expand Up @@ -702,7 +702,7 @@ function rrule(::typeof(logabsdet), A::Matrix{<:RealOrComplex})
imagf = f - real(f) # 0 for real A and Δs, im * imag(f) for complex A and/or Δs
g = real(Δl) + imagf
∂A = g * inv(F)' # == g * inv(A)'
return (NO_FIELDS, ∂A)
return (NoTangent(), ∂A)
end
return (Ω, logabsdet_pullback)
end
Expand Down Expand Up @@ -802,7 +802,7 @@ function rrule(::typeof(sylvester), A, B, C)
X = sylvester(A, B, C)
function sylvester_pullback(ΔX)
∂C = copy(sylvester(B, A, copy(ΔX'))')
return NO_FIELDS, @thunk(∂C * X'), @thunk(X' * ∂C), ∂C
return NoTangent(), @thunk(∂C * X'), @thunk(X' * ∂C), ∂C
end
return X, sylvester_pullback
end
Expand Down
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ For example a closure has the fields it closes over; a callable object (i.e. a f
**Thus every function is treated as having the extra implicit argument `self`, which captures those fields.**
So every `pushforward` takes in an extra argument, which is ignored unless the original function has fields.
It is common to write `function foo_pushforward(_, Δargs...)` in the case when `foo` does not have fields.
Similarly every `pullback` returns an extra `∂self`, which for things without fields is the constant `NO_FIELDS`, indicating there are no fields within the function itself.
Similarly every `pullback` returns an extra `∂self`, which for things without fields is `NoTangent()`, indicating there are no fields within the function itself.


### Pushforward / Pullback summary
Expand Down
12 changes: 6 additions & 6 deletions docs/src/writing_good_rules.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ Use named local functions for the `pullback` in an `rrule`.
function rrule(::typeof(foo), x)
Y = foo(x)
function foo_pullback(Ȳ)
return NO_FIELDS, bar(Ȳ)
return NoTangent(), bar(Ȳ)
end
return Y, foo_pullback
end
Expand All @@ -72,7 +72,7 @@ julia> rrule(foo, 2)

# bad:
function rrule(::typeof(foo), x)
return foo(x), x̄ -> (NO_FIELDS, bar(x̄))
return foo(x), x̄ -> (NoTangent(), bar(x̄))
end
#== output:
julia> rrule(foo, 2)
Expand Down Expand Up @@ -178,7 +178,7 @@ However, upon adding the `rrule` (restart the REPL after calling `gradient`)
function ChainRules.rrule(::typeof(addone!), a)
y = addone!(a)
function addone!_pullback(ȳ)
return NO_FIELDS, ones(length(a))
return NoTangent(), ones(length(a))
end
return y, addone!_pullback
end
Expand Down Expand Up @@ -220,7 +220,7 @@ without an `rrule` defined (restart the REPL after calling `gradient`)
function ChainRulesCore.rrule(::typeof(exception), x)
y = exception(x)
function exception_pullback(ȳ)
return NO_FIELDS, 2*x
return NoTangent(), 2*x
end
return y, exception_pullback
end
Expand Down Expand Up @@ -261,7 +261,7 @@ function ChainRules.rrule(::typeof(mse), x, x̂)
function mse_pullback(ȳ)
N = length(x)
g = (2 ./ N) .* (x .- x̂) .* ȳ
return NO_FIELDS, g, -g
return NoTangent(), g, -g
end
return output, mse_pullback
end
Expand Down Expand Up @@ -295,7 +295,7 @@ function ChainRulesCore.rrule(::typeof(sum3), a)
function sum3_pullback(ȳ)
grad = zeros(length(a))
grad[1:3] .+= ȳ
return NO_FIELDS, grad
return NoTangent(), grad
end
return y, sum3_pullback
end
Expand Down
1 change: 0 additions & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ export canonicalize, extern, unthunk # differential operations
export add!! # gradient accumulation operations
# differentials
export Tangent, NoTangent, InplaceableThunk, One, Thunk, ZeroTangent, AbstractZero, AbstractThunk
export NO_FIELDS

include("compat.jl")
include("debug_mode.jl")
Expand Down
1 change: 1 addition & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ Base.@deprecate_binding AbstractDifferential AbstractTangent
Base.@deprecate_binding Composite Tangent
Base.@deprecate_binding Zero ZeroTangent
Base.@deprecate_binding DoesNotExist NoTangent
Base.@deprecate_binding NO_FIELDS NoTangent()
2 changes: 1 addition & 1 deletion src/differentials/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ arguments.
```
function rrule(fill, x, len::Int)
y = fill(x, len)
fill_pullback(ȳ) = (NO_FIELDS, @thunk(sum(Ȳ)), NoTangent())
fill_pullback(ȳ) = (NoTangent(), @thunk(sum(Ȳ)), NoTangent())
return y, fill_pullback
end
```
Expand Down
9 changes: 0 additions & 9 deletions src/differentials/composite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -301,12 +301,3 @@ function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P}
printstyled(io, err.original; color=:yellow)
println(io)
end

"""
NO_FIELDS

Constant for the reverse-mode derivative with respect to a structure that has no fields.
The most notable use for this is for the reverse-mode derivative with respect to the
function itself, when that function is not a closure.
"""
const NO_FIELDS = ZeroTangent()
8 changes: 4 additions & 4 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ A convenience macro that generates simple scalar forward or reverse rules using
the provided partial derivatives. Specifically, generates the corresponding
methods for `frule` and `rrule`:

function ChainRulesCore.frule((NO_FIELDS, Δx₁, Δx₂, ...), ::typeof(f), x₁::Number, x₂::Number, ...)
function ChainRulesCore.frule((NoTangent(), Δx₁, Δx₂, ...), ::typeof(f), x₁::Number, x₂::Number, ...)
Ω = f(x₁, x₂, ...)
\$(statement₁, statement₂, ...)
return Ω, (
Expand All @@ -30,7 +30,7 @@ methods for `frule` and `rrule`:
Ω = f(x₁, x₂, ...)
\$(statement₁, statement₂, ...)
return Ω, ((ΔΩ₁, ΔΩ₂, ...)) -> (
NO_FIELDS,
NoTangent(),
∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...),
∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...),
...
Expand All @@ -46,7 +46,7 @@ e.g. `f(x₁::Complex, x₂)`, which will constrain `x₁` to `Complex` and `x

At present this does not support defining for closures/functors.
Thus in reverse-mode, the first returned partial,
representing the derivative with respect to the function itself, is always `NO_FIELDS`.
representing the derivative with respect to the function itself, is always `NoTangent()`.
And in forward-mode, the first input to the returned propagator is always ignored.

The result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This
Expand Down Expand Up @@ -196,7 +196,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials)
pullback = @strip_linenos quote
@inline function $(esc(propagator_name(f, :pullback)))($pullback_input)
$(__source__)
return (NO_FIELDS, $(pullback_returns...))
return (NoTangent(), $(pullback_returns...))
end
end

Expand Down
6 changes: 3 additions & 3 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Examples:
unary input, unary output scalar function:

```jldoctest frule
julia> dself = NO_FIELDS;
julia> dself = NoTangent();

julia> x = rand()
0.8236475079774124
Expand Down Expand Up @@ -83,7 +83,7 @@ julia> sinx, sin_pullback = rrule(sin, x);
julia> sinx == sin(x)
true

julia> sin_pullback(1) == (NO_FIELDS, cos(x))
julia> sin_pullback(1) == (NoTangent(), cos(x))
true
```

Expand All @@ -97,7 +97,7 @@ julia> hypotxy, hypot_pullback = rrule(hypot, x, y);
julia> hypotxy == hypot(x, y)
true

julia> hypot_pullback(1) == (NO_FIELDS, (x / hypot(x, y)), (y / hypot(x, y)))
julia> hypot_pullback(1) == (NoTangent(), (x / hypot(x, y)), (y / hypot(x, y)))
true
```

Expand Down
2 changes: 1 addition & 1 deletion test/demos/forwarddiffzero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ function define_dual_overload(sig)
# we use the function call overloading form as it lets us avoid namespacing issues
# as we can directly interpolate the function type into to the AST.
function (op::$opT)(dual_args::Vararg{Union{Dual, Float64}, $N}; kwargs...)
ȧrgs = (NO_FIELDS, partial.(dual_args)...)
ȧrgs = (NoTangent(), partial.(dual_args)...)
args = (op, primal.(dual_args)...)
y, ẏ = frule(ȧrgs, args...; kwargs...)
return Dual(y, ẏ) # if y, ẏ are not `Float64` this will error.
Expand Down
2 changes: 1 addition & 1 deletion test/demos/reversediffzero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ end
function ChainRulesCore.rrule(::typeof(*), x::Number, y::Number)
function times_pullback(ΔΩ)
# we will use thunks here to show we handle them fine.
return (NO_FIELDS, @thunk(ΔΩ * y'), @thunk(x' * ΔΩ))
return (NoTangent(), @thunk(ΔΩ * y'), @thunk(x' * ΔΩ))
end
return x * y, times_pullback
end
Expand Down
1 change: 1 addition & 0 deletions test/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ very_nice(x, y) = x + y
@test Zero === ZeroTangent
@test DoesNotExist === NoTangent
@test Composite === Tangent
@test_deprecated NO_FIELDS
@test_deprecated One()
end

Expand Down
4 changes: 0 additions & 4 deletions test/differentials/composite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,4 @@ end
c = Tangent{typeof(nt)}(; a=NoTangent(), b=0.1)
@test nt + c == (; a=1, b=2.1);
end

@testset "NO_FIELDS" begin
@test NO_FIELDS === ZeroTangent()
end
end
8 changes: 4 additions & 4 deletions test/differentials/notimplemented.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,29 +117,29 @@
notimplemented1(x, y) = x + y
@scalar_rule notimplemented1(x, y) (@not_implemented("notimplemented1"), 1)

y, ẏ = frule((NO_FIELDS, 1.2, 2.3), notimplemented1, 3, 2)
y, ẏ = frule((NoTangent(), 1.2, 2.3), notimplemented1, 3, 2)
@test y == 5
@test ẏ isa ChainRulesCore.NotImplemented

res, pb = rrule(notimplemented1, 3, 2)
@test res == 5
f̄, x̄1, x̄2 = pb(3.1)
@test f̄ == NO_FIELDS
@test f̄ == NoTangent()
@test x̄1 isa ChainRulesCore.NotImplemented
@test x̄2 == 3.1

notimplemented2(x, y) = (x + y, x - y)
@scalar_rule notimplemented2(x, y) (@not_implemented("notimplemented2"), 1) (1, -1)

y, (ẏ1, ẏ2) = frule((NO_FIELDS, 1.2, 2.3), notimplemented2, 3, 2)
y, (ẏ1, ẏ2) = frule((NoTangent(), 1.2, 2.3), notimplemented2, 3, 2)
@test y == (5, 1)
@test ẏ1 isa ChainRulesCore.NotImplemented
@test ẏ2 ≈ -1.1

res, pb = rrule(notimplemented2, 3, 2)
@test res == (5, 1)
f̄, x̄1, x̄2 = pb((3.1, 4.5))
@test f̄ == NO_FIELDS
@test f̄ == NoTangent()
@test x̄1 isa ChainRulesCore.NotImplemented
@test x̄2 == -1.4
end
Expand Down
6 changes: 3 additions & 3 deletions test/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,9 @@ end

y, simo_pb = rrule(simo, π)

@test simo_pb((10f0, 20f0)) == (NO_FIELDS, 50f0)
@test simo_pb((10f0, 20f0)) == (NoTangent(), 50f0)

y, ẏ = frule((NO_FIELDS, 50f0), simo, π)
y, ẏ = frule((NoTangent(), 50f0), simo, π)
@test y == (π, 2π)
@test ẏ == Tangent{typeof(y)}(50f0, 100f0)
# make sure type is exactly as expected:
Expand Down Expand Up @@ -315,7 +315,7 @@ module IsolatedModuleForTestingScoping
Δy = randn()
y, f_pullback = rrule(my_id, x)
@test y == x
@test f_pullback(Δy) == (ZeroTangent(), Δy)
@test f_pullback(Δy) == (NoTangent(), Δy)
end
end
end
6 changes: 3 additions & 3 deletions test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
@test cool_pushforward === 1
rrx, cool_pullback = rrule(cool, 1)
self, rr1 = cool_pullback(1)
@test self === NO_FIELDS
@test self === NoTangent()
@test rrx === 2
@test rr1 === 1

frx, nice_pushforward = frule((dself, 1), nice, 1)
@test nice_pushforward === ZeroTangent()
rrx, nice_pullback = rrule(nice, 1)
@test (NO_FIELDS, ZeroTangent()) === nice_pullback(1)
@test (NoTangent(), ZeroTangent()) === nice_pullback(1)


# Test that these run. Do not care about numerical correctness.
Expand Down Expand Up @@ -133,7 +133,7 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
Ω_rev, back = rrule(complex_times, x)
@test Ω_rev == Ω
∂self, ∂x = back(Ω̄)
@test ∂self == NO_FIELDS
@test ∂self == NoTangent()
@test ∂x ≈ (1 - 2im) * Ω̄
end
end