sum(f, xs) dispatches on non functions #522

@mzgubic

The rule in

function rrule(
    config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f, xs::AbstractArray; dims=:
)
    fx_and_pullbacks = map(x->rrule_via_ad(config, f, x), xs)
    y = sum(first, fx_and_pullbacks; dims=dims)
    pullbacks = last.(fx_and_pullbacks)
    project = ProjectTo(xs)
    function sum_pullback(ȳ)
        call(f, x) = f(x)  # we need to broadcast this to handle dims kwarg
        f̄_and_x̄s = call.(pullbacks, ȳ)
        # no point thunking as most of work is in f̄_and_x̄s which we need to compute for both
        f̄ = if fieldcount(typeof(f)) === 0 # Then don't need to worry about derivative wrt f
            NoTangent()
        else
            sum(first, f̄_and_x̄s)
        end
        x̄s = map(unthunk ∘ last, f̄_and_x̄s) # project does not support receiving InplaceableThunks
        return NoTangent(), f̄, project(x̄s)
    end
    return y, sum_pullback
end

which is meant to be used where f is a function/functor, also dispatches on calls where f is not callable, e.g.

julia> using Zygote

julia> using ChainRulesCore

julia> using StatsBase

julia> rrule(Zygote.ZygoteRuleConfig(), sum, rand(3), AnalyticWeights([1.0, 2.0, 3.0]))
ERROR: MethodError: objects of type Vector{Float64} are not callable
Use square brackets [] for indexing an Array.
Stacktrace:
  [1] macro expansion
    @ ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(ctx::Zygote.Context, f::Vector{Float64}, args::Float64)
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface2.jl:9
  [3] rrule_via_ad(::Zygote.ZygoteRuleConfig{Zygote.Context}, ::Vector{Float64}, ::Vararg{Any, N} where N; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/chainrules.jl:185
  [4] rrule_via_ad(::Zygote.ZygoteRuleConfig{Zygote.Context}, ::Vector{Float64}, ::Vararg{Any, N} where N)
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/chainrules.jl:180
  [5] (::ChainRules.var"#1388#1389"{Zygote.ZygoteRuleConfig{Zygote.Context}, Vector{Float64}})(x::Float64)
    @ ChainRules ~/JuliaEnvs/PortfolioNets.jl/dev/ChainRules/src/rulesets/Base/mapreduce.jl:71
  [6] iterate
    @ ./generator.jl:47 [inlined]
  [7] _collect
    @ ./array.jl:691 [inlined]
  [8] collect_similar
    @ ./array.jl:606 [inlined]
  [9] map
    @ ./abstractarray.jl:2294 [inlined]
 [10] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context}, ::typeof(sum), f::Vector{Float64}, xs::AnalyticWeights{Float64, Float64, Vector{Float64}}; dims::Function)
    @ ChainRules ~/JuliaEnvs/PortfolioNets.jl/dev/ChainRules/src/rulesets/Base/mapreduce.jl:71
 [11] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context}, ::typeof(sum), f::Vector{Float64}, xs::AnalyticWeights{Float64, Float64, Vector{Float64}})
    @ ChainRules ~/JuliaEnvs/PortfolioNets.jl/dev/ChainRules/src/rulesets/Base/mapreduce.jl:70
 [12] top-level scope
    @ REPL[40]:1
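
For anyone who wants to see the mechanism without Zygote or StatsBase installed, here is a minimal sketch using only Base. `my_rule` is a hypothetical stand-in that copies just the shape of the rule's signature (untyped `f`, `xs::AbstractArray`); it is not the ChainRules method. Because `f` has no type constraint, a plain vector in that slot still selects the method, and the failure only appears once `f` is called inside `map`.

```julia
# Hypothetical stand-in for the rule: `f` is untyped, so any value matches.
my_rule(::typeof(sum), f, xs::AbstractArray) = map(f, xs)

# Dispatch succeeds even though the second argument is a Vector, not a function.
matches = hasmethod(my_rule, Tuple{typeof(sum), Vector{Float64}, Vector{Float64}})

# The error only surfaces when `f` is actually called on an element.
err = try
    my_rule(sum, [1.0, 2.0], [3.0, 4.0])
catch e
    e
end

println(matches)             # whether the method matched
println(err isa MethodError) # whether calling the vector failed
```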

Is there anything we can do about this? There are several things that come to mind but none are amazing:

  • define a rule locally (I've done that, but others might run into the exact same problem)
  • add a more informative error message checking if f is callable (so that users know they need to add a rule)
  • add a rule to StatsBase, which does not depend on ChainRulesCore (they probably won't accept the PR?)
  • tighten f to f::Function, but that means the rule no longer dispatches on functors

I am leaning towards just throwing a better error. Maybe also open an issue in StatsBase?
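
A rough sketch of what the "better error" option could look like, using only Base. `check_callable` and `Scale` are hypothetical names made up for this illustration, not existing ChainRules API, and the error wording is a placeholder. The assumption is that probing `f` on the first element with `applicable` is an acceptable check; it also shows why `f::Function` would be too tight, since a functor passes the check without being a `Function`.

```julia
# Hypothetical helper: fail early with a hint when `f` cannot be applied
# to an element of `xs`. Empty arrays are let through unchecked.
function check_callable(f, xs::AbstractArray)
    if !isempty(xs) && !applicable(f, first(xs))
        throw(ArgumentError(
            "sum(f, xs): f of type $(typeof(f)) is not callable on elements of xs; " *
            "this method of sum probably needs its own rrule"
        ))
    end
    return nothing
end

# Hypothetical functor used only to illustrate the `f::Function` trade-off.
struct Scale
    a::Float64
end
(s::Scale)(x) = s.a * x

check_callable(abs2, [1.0, 2.0])        # ordinary function: passes
check_callable(Scale(2.0), [1.0, 2.0])  # functor: passes
Scale(2.0) isa Function                 # false, so `f::Function` would exclude it

# A vector in the `f` slot now produces the more descriptive error.
caught = try
    check_callable([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
    nothing
catch e
    e
end
```

Probing a single element is cheap but not airtight: it only shows that some method of `f` accepts the first element, so it is meant as a hint to the user rather than a guarantee.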
