Closed
Description
The rule in `ChainRules.jl/src/rulesets/Base/mapreduce.jl` (lines 66 to 89 in a130b8f)

    function rrule(
        config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f, xs::AbstractArray; dims=:
    )
        fx_and_pullbacks = map(x->rrule_via_ad(config, f, x), xs)
        y = sum(first, fx_and_pullbacks; dims=dims)

        pullbacks = last.(fx_and_pullbacks)

        project = ProjectTo(xs)

        function sum_pullback(ȳ)
            call(f, x) = f(x)  # we need to broadcast this to handle dims kwarg
            f̄_and_x̄s = call.(pullbacks, ȳ)
            # no point thunking as most of work is in f̄_and_x̄s which we need to compute for both
            f̄ = if fieldcount(typeof(f)) === 0 # Then don't need to worry about derivative wrt f
                NoTangent()
            else
                sum(first, f̄_and_x̄s)
            end
            x̄s = map(unthunk ∘ last, f̄_and_x̄s) # project does not support receiving InplaceableThunks
            return NoTangent(), f̄, project(x̄s)
        end
        return y, sum_pullback
    end

which is meant to be used where `f` is a function/functor, also dispatches on calls whose first argument is a plain array, such as a weighted sum from StatsBase:

    julia> using Zygote
    julia> using ChainRulesCore
    julia> using StatsBase
    julia> rrule(Zygote.ZygoteRuleConfig(), sum, rand(3), AnalyticWeights([1.0, 2.0, 3.0]))
    ERROR: MethodError: objects of type Vector{Float64} are not callable
    Use square brackets [] for indexing an Array.
    Stacktrace:
      [1] macro expansion
        @ ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface2.jl:0 [inlined]
      [2] _pullback(ctx::Zygote.Context, f::Vector{Float64}, args::Float64)
        @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface2.jl:9
      [3] rrule_via_ad(::Zygote.ZygoteRuleConfig{Zygote.Context}, ::Vector{Float64}, ::Vararg{Any, N} where N; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
        @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/chainrules.jl:185
      [4] rrule_via_ad(::Zygote.ZygoteRuleConfig{Zygote.Context}, ::Vector{Float64}, ::Vararg{Any, N} where N)
        @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/chainrules.jl:180
      [5] (::ChainRules.var"#1388#1389"{Zygote.ZygoteRuleConfig{Zygote.Context}, Vector{Float64}})(x::Float64)
        @ ChainRules ~/JuliaEnvs/PortfolioNets.jl/dev/ChainRules/src/rulesets/Base/mapreduce.jl:71
      [6] iterate
        @ ./generator.jl:47 [inlined]
      [7] _collect
        @ ./array.jl:691 [inlined]
      [8] collect_similar
        @ ./array.jl:606 [inlined]
      [9] map
        @ ./abstractarray.jl:2294 [inlined]
     [10] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context}, ::typeof(sum), f::Vector{Float64}, xs::AnalyticWeights{Float64, Float64, Vector{Float64}}; dims::Function)
        @ ChainRules ~/JuliaEnvs/PortfolioNets.jl/dev/ChainRules/src/rulesets/Base/mapreduce.jl:71
     [11] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context}, ::typeof(sum), f::Vector{Float64}, xs::AnalyticWeights{Float64, Float64, Vector{Float64}})
        @ ChainRules ~/JuliaEnvs/PortfolioNets.jl/dev/ChainRules/src/rulesets/Base/mapreduce.jl:70
     [12] top-level scope
        @ REPL[40]:1

Is there anything we can do about this? There are several things that come to mind but none are amazing:
- define a rule locally (I've done that, but others might run into the exact same problem)
- add a more informative error message checking if `f` is callable (so that users know they need to add a rule)
- add a rule to StatsBase, which does not depend on ChainRulesCore (they probably won't accept the PR?)
- tighten `f` to `f::Function`, but that means the rule no longer dispatches on functors
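
To make the trade-off in the last option concrete, here is a minimal sketch using stand-in functions rather than the real `rrule`. The names `loose_rule`, `tight_rule` and `Scale` are hypothetical and exist only for this illustration; a plain `Vector` plays the role of the weights.

```julia
# Hypothetical stand-ins for the two candidate signatures; not ChainRules code.
loose_rule(f, xs::AbstractArray) = :rule_hit            # current signature: any `f`
tight_rule(f::Function, xs::AbstractArray) = :rule_hit  # proposed tightening
tight_rule(f, xs::AbstractArray) = :no_rule             # stands for "no rule applies"

# A functor: callable, but not a subtype of `Function`.
struct Scale
    a::Float64
end
(s::Scale)(x) = s.a * x

data = [1.0, 2.0, 3.0]
weights = [0.2, 0.3, 0.5]  # assumption: stands in for a StatsBase weights vector

loose_rule(abs2, data)        # rule applies, as intended
loose_rule(data, weights)     # rule also applies, although `data` is not callable
tight_rule(abs2, data)        # rule applies
tight_rule(data, weights)     # rule is skipped, which fixes the weighted sum
tight_rule(Scale(2.0), data)  # rule is skipped too, so functors lose the rule
```

The tightened signature avoids the weighted-sum case only at the price of also excluding callable structs.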
I am leaning towards just throwing a better error. Maybe also open an issue in StatsBase?
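
As a rough sketch of what the better error could look like, the check below uses `Base.applicable` to probe whether `f` has a method for the first element of `xs`. The helper name `assert_callable` and the message wording are assumptions made for illustration, not existing ChainRules code, and where such a check would sit inside the rule is left open.

```julia
# Hypothetical helper; name and message are illustrative only.
function assert_callable(f, xs::AbstractArray)
    # An empty array gives nothing to probe with, so skip the check in that case.
    if !isempty(xs) && !applicable(f, first(xs))
        throw(ArgumentError(
            "the rule for sum(f, xs) received a first argument of type $(typeof(f)), " *
            "which is not callable on elements of type $(eltype(xs)). " *
            "If this call is a different method of sum, it needs its own rule."
        ))
    end
    return nothing
end

assert_callable(abs2, [1.0, 2.0])  # passes silently

err = try
    assert_callable([1.0, 2.0], [0.5, 0.5])  # first argument is a Vector, not callable
catch e
    e
end
```

Note that `applicable` only checks that some method matches the argument types; it does not guarantee that the call itself succeeds.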