Skip to content

Rules for getindex(::Tuple) and sum(::Tuple)#643

Merged
mcabbott merged 10 commits intoJuliaDiff:mainfrom
mcabbott:tuples
Jul 14, 2022
Merged

Rules for getindex(::Tuple) and sum(::Tuple)#643
mcabbott merged 10 commits intoJuliaDiff:mainfrom
mcabbott:tuples

Conversation

@mcabbott
Copy link
Member

@mcabbott mcabbott commented Jul 7, 2022

No description provided.

Comment on lines +21 to +26
function rrule(::typeof(getindex), x::NTuple{N,T}, i::Integer) where {N, T<:Number}
proj = ProjectTo(x)
len = Val(N)
function getindex_back_2(dy_raw)
dy = unthunk(dy_raw)
dx = ntuple(j -> j == i ? dy : zero(dy), len)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not certain it's worthwhile, but the idea of this was to allow uniform tuples to have type-stable behaviour.

Zygote does not do this, always uses nothing for others. Has a special method for ranges of indices though.

Copy link
Member

@oxinabox oxinabox Jul 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my instinct is that this is not worth doing
but maybe it is is hit in something important.
StaticArrays.jl?

Either way need comment

Copy link
Member

@oxinabox oxinabox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

basically LGTM

Comment on lines +9 to +12
function frule((_, ẋ), ::typeof(getindex), x::Tuple, i)
y = x[i]
return y, Tangent{typeof(y)}(ẋ[i]...)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this rule?
I guess it avoids going through a bunch of indexing machinery

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't know. No test case motivated any of these forward rules, they're just in the name of completeness.

Comment on lines +35 to +45
function rrule(::typeof(getindex), x::T, inds) where {T<:Tuple} # e.g. ranges, not type-stable
function getindex_back_3(dy_raw)
dy = unthunk(dy_raw)
dx = ntuple(Returns(NoTangent()), _tuple_N(T))
for (dyi, i) in zip(dy, inds)
dx = Base.setindex(dx, dyi + dx[i], i)
end
return (NoTangent(), Tangent{T}(dx...), NoTangent())
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, Zygote has two rules here, one for inds:: AbstractUnitRange, and a generic one: https://github.com/FluxML/Zygote.jl/blob/master/src/lib/lib.jl#L125-L142 The generic one accumulates within each tuple element, rather than adding the tuples.

I have not compared carefully whether that performs better in some cases, etc.

@mcabbott mcabbott merged commit 8073c7c into JuliaDiff:main Jul 14, 2022
@mcabbott mcabbott deleted the tuples branch July 14, 2022 18:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants