Sometimes faster sum(f,x) rule #529
Conversation
Zygote test failure
src/rulesets/Base/base.jl
Outdated
```julia
    return (x, identity_pullback)
end

derivatives_given_output(Ω, ::typeof(identity), x) = tuple(tuple(true))
```
I am not sure how I feel about overloading this directly, rather than leaving it for `@scalar_rule` to do.
How should I feel?
Yeah, I don't know. It's a bit weird to have ever more functions you have to know to define. But some functions don't use the macro.
For some functions, it might in fact be ideal to provide several methods of this, if you can equally well compute the derivative using input or output. Using input only is what this PR exploits, but using output only might be useful for fusing some broadcast things. Something like relu can be equally efficient either way. The macro can't really figure that out.
We will need to keep thinking about that, but let's do this, so we have an example.
We can always back it out later, as it is an implementation detail.
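The relu point above can be sketched without ChainRulesCore. These helper names (`deriv_relu_from_input`, `deriv_relu_from_output`) are purely illustrative, not the real API; the point is that the same derivative is reachable from either side:

```julia
# Hypothetical helpers (illustrative names, not the real ChainRulesCore API):
# for relu, the derivative can be computed from either the input or the output.
relu(x) = max(x, zero(x))

deriv_relu_from_input(x)  = x > 0 ? one(x) : zero(x)  # uses only the input
deriv_relu_from_output(y) = y > 0 ? one(y) : zero(y)  # uses only the output

xs = [-2.0, 0.5, 3.0]
deriv_relu_from_input.(xs) == deriv_relu_from_output.(relu.(xs))  # true
```

A macro like `@scalar_rule` only ever sees one expression for the derivative, so it cannot know that both formulations are equally cheap here.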
src/rulesets/Base/mapreduce.jl
Outdated
```julia
# Then we can compute the forward pass as usual, save nothing but `xs`:
y = sum(f, xs; dims=dims)
function sum_pullback_easy(dy)
    dxs = unthunk(dy) .* conj.(only.(only.(derivatives_given_output.(nothing, f, xs))))
```
Should we have something like `_siso_derivatives_given_output(f, x) = only(only(derivatives_given_output(nothing, f, x)))`, so this can be:

```julia
# before
dxs = unthunk(dy) .* conj.(only.(only.(derivatives_given_output.(nothing, f, xs))))
# after
dxs = unthunk(dy) .* conj.(_siso_derivatives_given_output.(f, xs))
```
Or `only_derivative_given_output`. But I'm not sure it's worth the extra complication of one more function to track down and know about.
I made this a do-block broadcast like the other cases now, as I think that's more readable.
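To make the nested `only(only(...))` unwrapping concrete, here is a toy stand-in for `derivatives_given_output` (the name `toy_derivatives_given_output` and the `abs2` method are assumptions for illustration; the real function lives in ChainRulesCore). The `tuple(tuple(...))` shape mirrors the convention of one derivative per input per output:

```julia
# Toy stand-in (hypothetical, for illustration only): returns a 1-tuple of
# 1-tuples, matching the one-input, one-output shape the rule expects.
toy_derivatives_given_output(Ω, ::typeof(abs2), x) = tuple(tuple(2x))

xs = [1.0, 2.0, 3.0]
dy = 1.0
# The two `only` calls strip the nested 1-tuples down to the scalar derivative:
dxs = dy .* conj.(only.(only.(toy_derivatives_given_output.(nothing, abs2, xs))))
# dxs == [2.0, 4.0, 6.0], the gradient of sum(abs2, xs)
```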
src/rulesets/Base/mapreduce.jl
Outdated
```julia
function sum_pullback_f(dy)
    # For arrays of arrays, we ought to protect the element against broadcasting:
    dys = dims isa Colon ? Ref(unthunk(dy)) : unthunk(dy)
```
It may be more performant to use a 1-tuple rather than `Ref`:

```julia
# before
dys = dims isa Colon ? Ref(unthunk(dy)) : unthunk(dy)
# after
dys = dims isa Colon ? (unthunk(dy),) : unthunk(dy)
```
Isn't `Ref` the standard thing, and what's used internally? I think this makes the intention a little clearer.

```julia
julia> Broadcast.broadcastable(:x)
Base.RefValue{Symbol}(:x)

julia> Broadcast.broadcastable(sin)
Base.RefValue{typeof(sin)}(sin)
```
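A quick check that the two spellings behave identically under broadcasting: both make the wrapped array act as a single element rather than being iterated over.

```julia
v = [1.0, 2.0]
# Both wrappers protect the vector from broadcasting over its elements:
a = Ref(v) .* [10, 20]   # [[10.0, 20.0], [20.0, 40.0]]
b = (v,) .* [10, 20]
a == b                   # true
```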
src/rulesets/Base/mapreduce.jl
Outdated
```julia
fx_and_pullbacks = map(x -> rrule_via_ad(config, f, x), xs)
y = sum(first, fx_and_pullbacks; dims=dims)

function sum_pullback_f(dy)
```
Can we bring back the unicode? I feel it actually adds worthwhile clarity here.
See what you think of the current level of unicode-ness.
src/rulesets/Base/mapreduce.jl
Outdated
```julia
# Then at least `f` has no gradient. Note that broadcasting here
# gets the shape right with or without `dims` keyword.
dxs = broadcast(fx_and_pullbacks, dys) do (_, back), dy1
    unthunk(last(back(dy1)))
```
I don't think the `unthunk` is required anymore, as didn't we fix `ProjectTo`?

```julia
# before
unthunk(last(back(dy1)))
# after
last(back(dy1))
```
I think my test case here was `sum(sum, ...)`, where you get an `InplaceableThunk` from `back` here.
And `ProjectTo` does not like those:

```julia
julia> proj = ProjectTo([1,2,3]);

julia> ith = rrule(sum, [1,2,3])[2](1)[2]; ith isa InplaceableThunk
true

julia> proj(ith)
ERROR: MethodError: no method matching (::ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(::InplaceableThunk{Thunk{ChainRules.var"#1409#1412"{Int64, Colon, Vector{Int64}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}}, ChainRules.var"#1408#1411"{Int64, Colon}})
```
Although I think at some point I wanted to make it un-wrap, insert itself into the @thunk part, and return that.
In which case you'd get an array of thunks back, not an array of arrays. Not sure what we think about that.
The first test which fails without this unthunk is:

```julia
julia> test_rrule(sum, sum, [[1,2], [3,4], [5,6]]; check_inferred=false)
test_rrule: sum on typeof(sum),Vector{Vector{Int64}}: Test Failed at /Users/me/.julia/packages/ChainRulesTestUtils/XI7i2/src/testers.jl:307
  Expression: ad_cotangent isa NoTangent
   Evaluated: Thunk{ComposedFunction{ProjectTo{AbstractArray, NamedTuple{(:element, :axes), ...

julia> CFG = ChainRulesTestUtils.ADviaRuleConfig();

julia> rrule(CFG, sum, sum, [[1,2], [3,4], [5,6]])[2](1.0)
(NoTangent(), NoTangent(), Thunk{ComposedFunction{ProjectTo{AbstractArray, NamedTuple{(:element, :axes), ...
```

This works on the tagged version; something unthunks:

```julia
julia> test_rrule(sum, sum, [[1,2], [3,4], [5,6]]; check_inferred=false)
Test Summary:                                        | Pass  Total  Time
test_rrule: sum on typeof(sum),Vector{Vector{Int64}} |    7      7  0.8s
Test.DefaultTestSet("test_rrule: sum on typeof(sum),Vector{Vector{Int64}}", Any[], 7, false, false, true, 1.645230607920931e9, 1.645230608709222e9)

julia> CFG = ChainRulesTestUtils.ADviaRuleConfig();

julia> rrule(CFG, sum, sum, [[1,2], [3,4], [5,6]])[2](1.0)
(NoTangent(), NoTangent(), [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]])

(@v1.8) pkg> st ChainRules
Status `~/.julia/environments/v1.8/Project.toml`
⌃ [082447d4] ChainRules v1.26.0
```
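For readers unfamiliar with thunks, here is a minimal sketch of the idea (`ToyThunk` and `toy_unthunk` are hypothetical stand-ins, not ChainRulesCore's `Thunk`/`unthunk`): a thunk delays a gradient computation, and forcing it is a no-op on values that are already materialized. The failure above happens when a still-wrapped thunk reaches code that only accepts materialized arrays.

```julia
# Toy types (not ChainRulesCore's): a thunk wraps a delayed computation.
struct ToyThunk{F}
    f::F
end
toy_unthunk(t::ToyThunk) = t.f()  # force the delayed computation
toy_unthunk(x) = x                # no-op on anything already materialized

lazy = ToyThunk(() -> [1.0, 1.0])
toy_unthunk(lazy)         # == [1.0, 1.0], the forced result
toy_unthunk([1.0, 1.0])   # returned unchanged
```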
src/rulesets/Base/mapreduce.jl
Outdated
```julia
    ∇prod_dims!(dx, vald, x, dy, y)
    return dx
end
∇prod_dims(::Val{dims}, x, dy::AbstractZero, y=prod(x; dims=dims)) where {dims} = dy
```
Is this meant to be part of this PR?

It's here because I found out it was missing when writing a test for this PR. Could be done separately, I guess.
src/rulesets/Base/mapreduce.jl
Outdated
```julia
    ∇prod!(dx, x, dy, y)
    return dx
end
∇prod(x, dy::AbstractZero, y::Number=prod(x)) = dy
```
Is this meant to be part of this PR?
src/rulesets/Base/mapreduce.jl
Outdated
```julia
        return dx
    end
end
∇cumprod_dim(vald::Val{dim}, x::AbstractArray, dy::AbstractZero, y=cumprod(x; dims=dim)) where {dim} = dy
```
Is this meant to be part of this PR?
src/rulesets/Base/mapreduce.jl
Outdated
```julia
    ∇cumprod!(dx, x, dy, y)
    return dx
end
∇cumprod(x::AbstractVector, dy::AbstractZero, y=cumprod(x)) = dy
```
Is this meant to be part of this PR?
Is this waiting for me to respond to something?

No, it's waiting for me to circle back, sorry.

I think this is ready to go. It would be nice to fix #85 on top of it; there are various ways, but probably better to explore in another PR.

Sorry for responding late, I was away. Would it be better for @oxinabox to review this instead, since she has the context?

Bump? I think this is fine, and faster. But if we wait long enough, eventually it will rot.
mzgubic left a comment

Thanks for being persistent; it's a great addition and it would be a shame if it got stale. Looks like a nice improvement to me overall.
oxinabox left a comment

Yep, OK, a few last comments to address, then merge when happy.
Sorry about dropping the ball on this one.
I am really hoping I can find time to catch up on my GitHub backlog soon.
Co-authored-by: Frames Catherine White <oxinabox@ucc.asn.au>
This does two things to the `sum(f,x)` rule.

First, it is a bit more efficient in how many temporary arrays it creates. It closes over the array of `(y, back)` tuples instead of making a new array for just the pullbacks. And when broadcasting the pullbacks, it avoids making a tuple in cases where `f` doesn't have a gradient anyway.

Second, it checks `derivatives_given_output` (added in ChainRulesCore.jl#456) to see if the gradient can be computed from just the input. If so, it can avoid storing the pullbacks at all.

Best, good, and worst case times:
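A toy sketch (illustrative code, not the actual rule) contrasting the two pullback strategies the description mentions, using a concrete `square` in place of a general `f`:

```julia
square(x) = x^2
xs = [1.0, 2.0, 3.0]

# Generic path: run f once per element and keep a (y, pullback) pair for each,
# analogous to map(x -> rrule_via_ad(config, f, x), xs).
ys_and_backs = map(x -> (square(x), dy -> 2x * dy), xs)
y = sum(first, ys_and_backs)                          # 14.0
dxs_generic = map(((_, back),) -> back(1.0), ys_and_backs)

# Easy path: the derivative 2x is known from the input alone,
# so nothing but xs needs to be saved for the backward pass.
dxs_easy = 1.0 .* 2 .* xs

dxs_generic == dxs_easy   # true: [2.0, 4.0, 6.0] either way
```

The easy path allocates no array of closures, which is where the speedup in the title comes from.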