diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 57d3c47ca..a1d67490d 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -342,14 +342,16 @@ end function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) @gensym kwargs - # `::Any` instead of `_`: https://github.com/JuliaLang/julia/issues/32727 return @strip_linenos quote - function ChainRulesCore.frule( - @nospecialize(::Any), $(map(esc, primal_sig_parts)...); $(esc(kwargs))... - ) + # Manually defined kw version to save compiler work. See explanation in rules.jl + function (::Core.kwftype(typeof(ChainRulesCore.frule)))(@nospecialize($kwargs::Any), + frule::typeof(ChainRulesCore.frule), @nospecialize(::Any), $(map(esc, primal_sig_parts)...)) + return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) + end + function ChainRulesCore.frule(@nospecialize(::Any), $(map(esc, primal_sig_parts)...)) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent() - return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) + return ($(esc(primal_invoke)), NoTangent()) end end end diff --git a/src/rules.jl b/src/rules.jl index 232b1c558..0abc81205 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -58,10 +58,25 @@ will be hit as a fallback. This is the case for most rules. See also: [`rrule`](@ref), [`@scalar_rule`](@ref), [`RuleConfig`](@ref) """ -frule(::Any, ::Any, ::Vararg{Any}; kwargs...) = nothing +frule(ȧrgs, f, ::Vararg{Any}) = nothing # if no config is present then fallback to config-less rules -frule(::RuleConfig, ȧrgs, f, args...; kwargs...) = frule(ȧrgs, f, args...; kwargs...) +frule(::RuleConfig, args...) = frule(args...) + +# Manual fallback for keyword arguments. Usually this would be generated by +# +# frule(::Any, ::Vararg{Any}; kwargs...) = nothing +# +# However - the fallback method is so hot that we want to avoid any extra code +# that would be required to have the automatically generated method package up +# the keyword arguments (which the optimizer will throw away, but the compiler +# still has to manually analyze). Manually declare this method with an +# explicitly empty body to save the compiler that work. +const frule_kwfunc = Core.kwftype(typeof(frule)).instance +(::typeof(frule_kwfunc))(::Any, ::typeof(frule), ȧrgs, f, ::Vararg{Any}) = nothing +function (::typeof(frule_kwfunc))(kws::Any, ::typeof(frule), ::RuleConfig, args...) + return frule_kwfunc(kws, frule, args...) +end """ rrule([::RuleConfig,] f, x...) @@ -116,18 +131,11 @@ See also: [`frule`](@ref), [`@scalar_rule`](@ref), [`RuleConfig`](@ref) rrule(::Any, ::Vararg{Any}) = nothing # if no config is present then fallback to config-less rules -rrule(::RuleConfig, f, args...; kwargs...) = rrule(f, args...; kwargs...) -# TODO do we need to do something for kwargs special here for performance? -# See: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/368 - -# Manual fallback for keyword arguments. Usually this would be generated by -# -# rrule(::Any, ::Vararg{Any}; kwargs...) = nothing -# -# However - the fallback method is so hot that we want to avoid any extra code -# that would be required to have the automatically generated method package up -# the keyword arguments (which the optimizer will throw away, but the compiler -# still has to manually analyze). Manually declare this method with an -# explicitly empty body to save the compiler that work. - -(::Core.kwftype(typeof(rrule)))(::Any, ::Any, ::Vararg{Any}) = nothing +rrule(::RuleConfig, args...) = rrule(args...) + +# Manual fallback for keyword arguments. See above +const rrule_kwfunc = Core.kwftype(typeof(rrule)).instance +(::typeof(rrule_kwfunc))(::Any, ::typeof(rrule), ::Any, ::Vararg{Any}) = nothing +function (::typeof(rrule_kwfunc))(kws::Any, ::typeof(rrule), ::RuleConfig, args...) + return rrule_kwfunc(kws, rrule, args...) +end diff --git a/test/config.jl b/test/config.jl index c55c0b6c8..466baed9a 100644 --- a/test/config.jl +++ b/test/config.jl @@ -160,4 +160,30 @@ end @testset "RuleConfig broadcasts like a scaler" begin @test (MostBoringConfig() .=> (1,2,3)) isa NTuple{3, Pair{MostBoringConfig,Int}} end + + @testset "fallbacks" begin + no_rule(x; kw="bye") = error() + @test frule((1.0,), no_rule, 2.0) === nothing + @test frule((1.0,), no_rule, 2.0; kw="hello") === nothing + @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0) === nothing + @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0; kw="hello") === nothing + @test rrule(no_rule, 2.0) === nothing + @test rrule(no_rule, 2.0; kw="hello") === nothing + @test rrule(MostBoringConfig(), no_rule, 2.0) === nothing + @test rrule(MostBoringConfig(), no_rule, 2.0; kw="hello") === nothing + + # Test that incorrect use of the fallback rules correctly throws MethodError + @test_throws MethodError frule() + @test_throws MethodError frule(;kw="hello") + @test_throws MethodError frule(sin) + @test_throws MethodError frule(sin;kw="hello") + @test_throws MethodError frule(MostBoringConfig()) + @test_throws MethodError frule(MostBoringConfig(); kw="hello") + @test_throws MethodError frule(MostBoringConfig(), sin) + @test_throws MethodError frule(MostBoringConfig(), sin; kw="hello") + @test_throws MethodError rrule() + @test_throws MethodError rrule(;kw="hello") + @test_throws MethodError rrule(MostBoringConfig()) + @test_throws MethodError rrule(MostBoringConfig();kw="hello") + end end