Rules for getindex(::Tuple) and sum(::Tuple)#643
Conversation
src/rulesets/Base/indexing.jl
Outdated
| function rrule(::typeof(getindex), x::NTuple{N,T}, i::Integer) where {N, T<:Number} | ||
| proj = ProjectTo(x) | ||
| len = Val(N) | ||
| function getindex_back_2(dy_raw) | ||
| dy = unthunk(dy_raw) | ||
| dx = ntuple(j -> j == i ? dy : zero(dy), len) |
There was a problem hiding this comment.
Not certain it's worthwhile, but the idea of this was to allow uniform tuples to have type-stable behaviour.
Zygote does not do this, always uses nothing for others. Has a special method for ranges of indices though.
There was a problem hiding this comment.
my instinct is that this is not worth doing
but maybe it is is hit in something important.
StaticArrays.jl?
Either way need comment
| function frule((_, ẋ), ::typeof(getindex), x::Tuple, i) | ||
| y = x[i] | ||
| return y, Tangent{typeof(y)}(ẋ[i]...) | ||
| end |
There was a problem hiding this comment.
Do we need this rule?
I guess it avoids going through a bunch of indexing machinery
There was a problem hiding this comment.
Don't know. No test case motivated any of these forward rules, they're just in the name of completeness.
| function rrule(::typeof(getindex), x::T, inds) where {T<:Tuple} # e.g. ranges, not type-stable | ||
| function getindex_back_3(dy_raw) | ||
| dy = unthunk(dy_raw) | ||
| dx = ntuple(Returns(NoTangent()), _tuple_N(T)) | ||
| for (dyi, i) in zip(dy, inds) | ||
| dx = Base.setindex(dx, dyi + dx[i], i) | ||
| end | ||
| return (NoTangent(), Tangent{T}(dx...), NoTangent()) |
There was a problem hiding this comment.
FWIW, Zygote has two rules here, one for inds:: AbstractUnitRange, and a generic one: https://github.com/FluxML/Zygote.jl/blob/master/src/lib/lib.jl#L125-L142 The generic one accumulates within each tuple element, rather than adding the tuples.
I have not compared carefully whether that performs better in some cases, etc.
Co-authored-by: Frames Catherine White <oxinabox@ucc.asn.au>
No description provided.