diff --git a/Project.toml b/Project.toml index c222e0394..8c17914cb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.45" +version = "0.10.0" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index d0f016972..67ca344dd 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -17,9 +17,9 @@ version = "0.9.44" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "0900bc19193b8e672d9cd477e6cd92d9e7c02f99" +git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.29.0" +version = "3.30.0" [[Dates]] deps = ["Printf"] diff --git a/docs/src/arrays.md b/docs/src/arrays.md index 3ba0d5a45..4a387803f 100644 --- a/docs/src/arrays.md +++ b/docs/src/arrays.md @@ -355,7 +355,7 @@ function rrule(::typeof(*), A::Matrix{<:RealOrComplex}, B::Matrix{<:RealOrComple function times_pullback(ΔΩ) ∂A = @thunk(ΔΩ * B') ∂B = @thunk(A' * ΔΩ) - return (NO_FIELDS, ∂A, ∂B) + return (NoTangent(), ∂A, ∂B) end return A * B, times_pullback end @@ -398,7 +398,7 @@ function rrule(::typeof(inv), A::Matrix{<:RealOrComplex}) Ω = inv(A) function inv_pullback(ΔΩ) ∂A = -Ω' * ΔΩ * Ω' - return (NO_FIELDS, ∂A) + return (NoTangent(), ∂A) end return Ω, inv_pullback end @@ -497,7 +497,7 @@ function rrule(::typeof(sum), ::typeof(abs2), X::Array{<:RealOrComplex}; dims = function sum_abs2_pullback(ΔΩ) ∂abs2 = NoTangent() ∂X = @thunk(2 .* real.(ΔΩ) .* X) - return (NO_FIELDS, ∂abs2, ∂X) + return (NoTangent(), ∂abs2, ∂X) end return sum(abs2, X; dims = dims), sum_abs2_pullback end @@ -702,7 +702,7 @@ function rrule(::typeof(logabsdet), A::Matrix{<:RealOrComplex}) imagf = f - real(f) # 0 for real A and Δs, im * imag(f) for complex A and/or Δs g = real(Δl) + imagf ∂A = g * inv(F)' # == g * inv(A)' - return (NO_FIELDS, ∂A) + return (NoTangent(), ∂A) end return (Ω, logabsdet_pullback) end @@ -802,7 +802,7 @@ function rrule(::typeof(sylvester), A, B, C) X = sylvester(A, B, C) function sylvester_pullback(ΔX) ∂C = copy(sylvester(B, A, copy(ΔX'))') - return NO_FIELDS, @thunk(∂C * X'), @thunk(X' * ∂C), ∂C + return NoTangent(), @thunk(∂C * X'), @thunk(X' * ∂C), ∂C end return X, sylvester_pullback end diff --git a/docs/src/index.md b/docs/src/index.md index 62c681039..301fcc91b 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -255,7 +255,7 @@ For example a closure has the fields it closes over; a callable object (i.e. a f **Thus every function is treated as having the extra implicit argument `self`, which captures those fields.** So every `pushforward` takes in an extra argument, which is ignored unless the original function has fields. It is common to write `function foo_pushforward(_, Δargs...)` in the case when `foo` does not have fields. -Similarly every `pullback` returns an extra `∂self`, which for things without fields is the constant `NO_FIELDS`, indicating there are no fields within the function itself. +Similarly every `pullback` returns an extra `∂self`, which for things without fields is `NoTangent()`, indicating there are no fields within the function itself. ### Pushforward / Pullback summary diff --git a/docs/src/writing_good_rules.md b/docs/src/writing_good_rules.md index afce46e37..9a64d14f4 100644 --- a/docs/src/writing_good_rules.md +++ b/docs/src/writing_good_rules.md @@ -61,7 +61,7 @@ Use named local functions for the `pullback` in an `rrule`. function rrule(::typeof(foo), x) Y = foo(x) function foo_pullback(Ȳ) - return NO_FIELDS, bar(Ȳ) + return NoTangent(), bar(Ȳ) end return Y, foo_pullback end @@ -72,7 +72,7 @@ julia> rrule(foo, 2) # bad: function rrule(::typeof(foo), x) - return foo(x), x̄ -> (NO_FIELDS, bar(x̄)) + return foo(x), x̄ -> (NoTangent(), bar(x̄)) end #== output: julia> rrule(foo, 2) @@ -178,7 +178,7 @@ However, upon adding the `rrule` (restart the REPL after calling `gradient`) function ChainRules.rrule(::typeof(addone!), a) y = addone!(a) function addone!_pullback(ȳ) - return NO_FIELDS, ones(length(a)) + return NoTangent(), ones(length(a)) end return y, addone!_pullback end @@ -220,7 +220,7 @@ without an `rrule` defined (restart the REPL after calling `gradient`) function ChainRulesCore.rrule(::typeof(exception), x) y = exception(x) function exception_pullback(ȳ) - return NO_FIELDS, 2*x + return NoTangent(), 2*x end return y, exception_pullback end @@ -261,7 +261,7 @@ function ChainRules.rrule(::typeof(mse), x, x̂) function mse_pullback(ȳ) N = length(x) g = (2 ./ N) .* (x .- x̂) .* ȳ - return NO_FIELDS, g, -g + return NoTangent(), g, -g end return output, mse_pullback end @@ -295,7 +295,7 @@ function ChainRulesCore.rrule(::typeof(sum3), a) function sum3_pullback(ȳ) grad = zeros(length(a)) grad[1:3] .+= ȳ - return NO_FIELDS, grad + return NoTangent(), grad end return y, sum3_pullback end diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index d6e786d02..04f9ad4d4 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -11,7 +11,6 @@ export canonicalize, extern, unthunk # differential operations export add!! # gradient accumulation operations # differentials export Tangent, NoTangent, InplaceableThunk, One, Thunk, ZeroTangent, AbstractZero, AbstractThunk -export NO_FIELDS include("compat.jl") include("debug_mode.jl") diff --git a/src/deprecated.jl b/src/deprecated.jl index b4b4404ac..a9152f5a1 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -2,3 +2,4 @@ Base.@deprecate_binding AbstractDifferential AbstractTangent Base.@deprecate_binding Composite Tangent Base.@deprecate_binding Zero ZeroTangent Base.@deprecate_binding DoesNotExist NoTangent +Base.@deprecate_binding NO_FIELDS NoTangent() diff --git a/src/differentials/abstract_zero.jl b/src/differentials/abstract_zero.jl index 73585e896..46be3b0ad 100644 --- a/src/differentials/abstract_zero.jl +++ b/src/differentials/abstract_zero.jl @@ -66,7 +66,7 @@ arguments. ``` function rrule(fill, x, len::Int) y = fill(x, len) - fill_pullback(ȳ) = (NO_FIELDS, @thunk(sum(Ȳ)), NoTangent()) + fill_pullback(ȳ) = (NoTangent(), @thunk(sum(Ȳ)), NoTangent()) return y, fill_pullback end ``` diff --git a/src/differentials/composite.jl b/src/differentials/composite.jl index 28b5eceaa..9fadf3e76 100644 --- a/src/differentials/composite.jl +++ b/src/differentials/composite.jl @@ -301,12 +301,3 @@ function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} printstyled(io, err.original; color=:yellow) println(io) end - -""" - NO_FIELDS - -Constant for the reverse-mode derivative with respect to a structure that has no fields. -The most notable use for this is for the reverse-mode derivative with respect to the -function itself, when that function is not a closure. -""" -const NO_FIELDS = ZeroTangent() diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 95f4f338a..57d3c47ca 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -16,7 +16,7 @@ A convenience macro that generates simple scalar forward or reverse rules using the provided partial derivatives. Specifically, generates the corresponding methods for `frule` and `rrule`: - function ChainRulesCore.frule((NO_FIELDS, Δx₁, Δx₂, ...), ::typeof(f), x₁::Number, x₂::Number, ...) + function ChainRulesCore.frule((NoTangent(), Δx₁, Δx₂, ...), ::typeof(f), x₁::Number, x₂::Number, ...) Ω = f(x₁, x₂, ...) \$(statement₁, statement₂, ...) return Ω, ( @@ -30,7 +30,7 @@ methods for `frule` and `rrule`: Ω = f(x₁, x₂, ...) \$(statement₁, statement₂, ...) return Ω, ((ΔΩ₁, ΔΩ₂, ...)) -> ( - NO_FIELDS, + NoTangent(), ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...), ∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...), ... @@ -46,7 +46,7 @@ e.g. `f(x₁::Complex, x₂)`, which will constrain `x₁` to `Complex` and `x At present this does not support defining for closures/functors. Thus in reverse-mode, the first returned partial, -representing the derivative with respect to the function itself, is always `NO_FIELDS`. +representing the derivative with respect to the function itself, is always `NoTangent()`. And in forward-mode, the first input to the returned propagator is always ignored. The result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This @@ -196,7 +196,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) pullback = @strip_linenos quote @inline function $(esc(propagator_name(f, :pullback)))($pullback_input) $(__source__) - return (NO_FIELDS, $(pullback_returns...)) + return (NoTangent(), $(pullback_returns...)) end end diff --git a/src/rules.jl b/src/rules.jl index fafe4a2e8..391a2debe 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -14,7 +14,7 @@ Examples: unary input, unary output scalar function: ```jldoctest frule -julia> dself = NO_FIELDS; +julia> dself = NoTangent(); julia> x = rand() 0.8236475079774124 @@ -83,7 +83,7 @@ julia> sinx, sin_pullback = rrule(sin, x); julia> sinx == sin(x) true -julia> sin_pullback(1) == (NO_FIELDS, cos(x)) +julia> sin_pullback(1) == (NoTangent(), cos(x)) true ``` @@ -97,7 +97,7 @@ julia> hypotxy, hypot_pullback = rrule(hypot, x, y); julia> hypotxy == hypot(x, y) true -julia> hypot_pullback(1) == (NO_FIELDS, (x / hypot(x, y)), (y / hypot(x, y))) +julia> hypot_pullback(1) == (NoTangent(), (x / hypot(x, y)), (y / hypot(x, y))) true ``` diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index 13a1b1971..6ff22da25 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -39,7 +39,7 @@ function define_dual_overload(sig) # we use the function call overloading form as it lets us avoid namespacing issues # as we can directly interpolate the function type into to the AST. function (op::$opT)(dual_args::Vararg{Union{Dual, Float64}, $N}; kwargs...) - ȧrgs = (NO_FIELDS, partial.(dual_args)...) + ȧrgs = (NoTangent(), partial.(dual_args)...) args = (op, primal.(dual_args)...) y, ẏ = frule(ȧrgs, args...; kwargs...) return Dual(y, ẏ) # if y, ẏ are not `Float64` this will error. diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl index 540b8ea20..17f33cf5a 100644 --- a/test/demos/reversediffzero.jl +++ b/test/demos/reversediffzero.jl @@ -111,7 +111,7 @@ end function ChainRulesCore.rrule(::typeof(*), x::Number, y::Number) function times_pullback(ΔΩ) # we will use thunks here to show we handle them fine. - return (NO_FIELDS, @thunk(ΔΩ * y'), @thunk(x' * ΔΩ)) + return (NoTangent(), @thunk(ΔΩ * y'), @thunk(x' * ΔΩ)) end return x * y, times_pullback end diff --git a/test/deprecated.jl b/test/deprecated.jl index 4035a8989..da8b17748 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -10,6 +10,7 @@ very_nice(x, y) = x + y @test Zero === ZeroTangent @test DoesNotExist === NoTangent @test Composite === Tangent + @test_deprecated NO_FIELDS @test_deprecated One() end diff --git a/test/differentials/composite.jl b/test/differentials/composite.jl index 9a64e7f1c..9ae2b15c5 100644 --- a/test/differentials/composite.jl +++ b/test/differentials/composite.jl @@ -358,8 +358,4 @@ end c = Tangent{typeof(nt)}(; a=NoTangent(), b=0.1) @test nt + c == (; a=1, b=2.1); end - - @testset "NO_FIELDS" begin - @test NO_FIELDS === ZeroTangent() - end end diff --git a/test/differentials/notimplemented.jl b/test/differentials/notimplemented.jl index e80a1461e..7ed33350a 100644 --- a/test/differentials/notimplemented.jl +++ b/test/differentials/notimplemented.jl @@ -117,21 +117,21 @@ notimplemented1(x, y) = x + y @scalar_rule notimplemented1(x, y) (@not_implemented("notimplemented1"), 1) - y, ẏ = frule((NO_FIELDS, 1.2, 2.3), notimplemented1, 3, 2) + y, ẏ = frule((NoTangent(), 1.2, 2.3), notimplemented1, 3, 2) @test y == 5 @test ẏ isa ChainRulesCore.NotImplemented res, pb = rrule(notimplemented1, 3, 2) @test res == 5 f̄, x̄1, x̄2 = pb(3.1) - @test f̄ == NO_FIELDS + @test f̄ == NoTangent() @test x̄1 isa ChainRulesCore.NotImplemented @test x̄2 == 3.1 notimplemented2(x, y) = (x + y, x - y) @scalar_rule notimplemented2(x, y) (@not_implemented("notimplemented2"), 1) (1, -1) - y, (ẏ1, ẏ2) = frule((NO_FIELDS, 1.2, 2.3), notimplemented2, 3, 2) + y, (ẏ1, ẏ2) = frule((NoTangent(), 1.2, 2.3), notimplemented2, 3, 2) @test y == (5, 1) @test ẏ1 isa ChainRulesCore.NotImplemented @test ẏ2 ≈ -1.1 @@ -139,7 +139,7 @@ res, pb = rrule(notimplemented2, 3, 2) @test res == (5, 1) f̄, x̄1, x̄2 = pb((3.1, 4.5)) - @test f̄ == NO_FIELDS + @test f̄ == NoTangent() @test x̄1 isa ChainRulesCore.NotImplemented @test x̄2 == -1.4 end diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index ebea77e35..28e878bd5 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -228,9 +228,9 @@ end y, simo_pb = rrule(simo, π) - @test simo_pb((10f0, 20f0)) == (NO_FIELDS, 50f0) + @test simo_pb((10f0, 20f0)) == (NoTangent(), 50f0) - y, ẏ = frule((NO_FIELDS, 50f0), simo, π) + y, ẏ = frule((NoTangent(), 50f0), simo, π) @test y == (π, 2π) @test ẏ == Tangent{typeof(y)}(50f0, 100f0) # make sure type is exactly as expected: @@ -315,7 +315,7 @@ module IsolatedModuleForTestingScoping Δy = randn() y, f_pullback = rrule(my_id, x) @test y == x - @test f_pullback(Δy) == (ZeroTangent(), Δy) + @test f_pullback(Δy) == (NoTangent(), Δy) end end end diff --git a/test/rules.jl b/test/rules.jl index 78334e229..a38c2e993 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -85,14 +85,14 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test cool_pushforward === 1 rrx, cool_pullback = rrule(cool, 1) self, rr1 = cool_pullback(1) - @test self === NO_FIELDS + @test self === NoTangent() @test rrx === 2 @test rr1 === 1 frx, nice_pushforward = frule((dself, 1), nice, 1) @test nice_pushforward === ZeroTangent() rrx, nice_pullback = rrule(nice, 1) - @test (NO_FIELDS, ZeroTangent()) === nice_pullback(1) + @test (NoTangent(), ZeroTangent()) === nice_pullback(1) # Test that these run. Do not care about numerical correctness. @@ -133,7 +133,7 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) Ω_rev, back = rrule(complex_times, x) @test Ω_rev == Ω ∂self, ∂x = back(Ω̄) - @test ∂self == NO_FIELDS + @test ∂self == NoTangent() @test ∂x ≈ (1 - 2im) * Ω̄ end end