From 7c40e726c91492417f1074232620fbb02667d31c Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Thu, 17 Jun 2021 21:49:39 -0400 Subject: [PATCH 1/4] Update kwargs fallback rules after RuleConfig rewrite Fixes #368 --- src/rule_definition_tools.jl | 12 +++++++----- src/rules.jl | 38 +++++++++++++++++++++--------------- test/config.jl | 16 +++++++++++++++ 3 files changed, 45 insertions(+), 21 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 57d3c47ca..a1d67490d 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -342,14 +342,16 @@ end function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) @gensym kwargs - # `::Any` instead of `_`: https://github.com/JuliaLang/julia/issues/32727 return @strip_linenos quote - function ChainRulesCore.frule( - @nospecialize(::Any), $(map(esc, primal_sig_parts)...); $(esc(kwargs))... - ) + # Manually defined kw version to save compiler work. See explanation in rules.jl + function (::Core.kwftype(typeof(ChainRulesCore.frule)))(@nospecialize($kwargs::Any), + frule::typeof(ChainRulesCore.frule), @nospecialize(::Any), $(map(esc, primal_sig_parts)...)) + return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) + end + function ChainRulesCore.frule(@nospecialize(::Any), $(map(esc, primal_sig_parts)...)) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent() - return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) + return ($(esc(primal_invoke)), NoTangent()) end end end diff --git a/src/rules.jl b/src/rules.jl index 232b1c558..b0b830fb7 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -58,10 +58,24 @@ will be hit as a fallback. This is the case for most rules. See also: [`rrule`](@ref), [`@scalar_rule`](@ref), [`RuleConfig`](@ref) """ -frule(::Any, ::Any, ::Vararg{Any}; kwargs...) = nothing +frule(ȧrgs, f, ::Vararg{Any}) = nothing # if no config is present then fallback to config-less rules -frule(::RuleConfig, ȧrgs, f, args...; kwargs...) = frule(ȧrgs, f, args...; kwargs...) +frule(::RuleConfig, args...) = frule(args...) + +# Manual fallback for keyword arguments. Usually this would be generated by +# +# frule(::Any, ::Vararg{Any}; kwargs...) = nothing +# +# However - the fallback method is so hot that we want to avoid any extra code +# that would be required to have the automatically generated method package up +# the keyword arguments (which the optimizer will throw away, but the compiler +# still has to manually analyze). Manually declare this method with an +# explicitly empty body to save the compiler that work. +const frule_kwfunc = Core.kwftype(typeof(frule)).instance +(::typeof(frule_kwfunc))(::Any, ::typeof(frule), ȧrgs, f, ::Vararg{Any}) = nothing +(::typeof(frule_kwfunc))(kws::Any, ::typeof(frule), ::RuleConfig, args...) = + (frule_kwfunc)(kws, frule, args...) """ rrule([::RuleConfig,] f, x...) @@ -116,18 +130,10 @@ See also: [`frule`](@ref), [`@scalar_rule`](@ref), [`RuleConfig`](@ref) rrule(::Any, ::Vararg{Any}) = nothing # if no config is present then fallback to config-less rules -rrule(::RuleConfig, f, args...; kwargs...) = rrule(f, args...; kwargs...) -# TODO do we need to do something for kwargs special here for performance? -# See: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/368 - -# Manual fallback for keyword arguments. Usually this would be generated by -# -# rrule(::Any, ::Vararg{Any}; kwargs...) = nothing -# -# However - the fallback method is so hot that we want to avoid any extra code -# that would be required to have the automatically generated method package up -# the keyword arguments (which the optimizer will throw away, but the compiler -# still has to manually analyze). Manually declare this method with an -# explicitly empty body to save the compiler that work. +rrule(::RuleConfig, args...) = rrule(args...) -(::Core.kwftype(typeof(rrule)))(::Any, ::Any, ::Vararg{Any}) = nothing +# Manual fallback for keyword arguments. See above +const rrule_kwfunc = Core.kwftype(typeof(rrule)).instance +(::typeof(rrule_kwfunc))(::Any, ::typeof(rrule), ::Any, ::Vararg{Any}) = nothing +(::typeof(rrule_kwfunc))(kws::Any, ::typeof(rrule), ::RuleConfig, args...) = + (rrule_kwfunc)(kws, rrule, args...) diff --git a/test/config.jl b/test/config.jl index c55c0b6c8..d8465db97 100644 --- a/test/config.jl +++ b/test/config.jl @@ -161,3 +161,19 @@ end @test (MostBoringConfig() .=> (1,2,3)) isa NTuple{3, Pair{MostBoringConfig,Int}} end end + +@testset "fallbacks" begin + # Test that incorrect use of the fallback rules correctly throws MethodError + @test_throws MethodError frule() + @test_throws MethodError frule(;kw="hello") + @test_throws MethodError frule(sin) + @test_throws MethodError frule(sin;kw="hello") + @test_throws MethodError frule(MostBoringConfig()) + @test_throws MethodError frule(MostBoringConfig(); kw="hello") + @test_throws MethodError frule(MostBoringConfig(), sin) + @test_throws MethodError frule(MostBoringConfig(), sin; kw="hello") + @test_throws MethodError rrule() + @test_throws MethodError rrule(;kw="hello") + @test_throws MethodError rrule(MostBoringConfig()) + @test_throws MethodError rrule(MostBoringConfig();kw="hello") +end From bf6c662177906d828640c970cb9e62b74a5af5b9 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 18 Jun 2021 22:59:11 +0100 Subject: [PATCH 2/4] fix styleguide re-multiline functions --- src/rules.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index b0b830fb7..4baed3318 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -74,8 +74,9 @@ frule(::RuleConfig, args...) = frule(args...) # explicitly empty body to save the compiler that work. const frule_kwfunc = Core.kwftype(typeof(frule)).instance (::typeof(frule_kwfunc))(::Any, ::typeof(frule), ȧrgs, f, ::Vararg{Any}) = nothing -(::typeof(frule_kwfunc))(kws::Any, ::typeof(frule), ::RuleConfig, args...) = - (frule_kwfunc)(kws, frule, args...) +function (::typeof(frule_kwfunc))(kws::Any, ::typeof(frule), ::RuleConfig, args...) + return frule_kwfunc(kws, frule, args...) +end """ rrule([::RuleConfig,] f, x...) @@ -135,5 +136,6 @@ rrule(::RuleConfig, args...) = rrule(args...) # Manual fallback for keyword arguments. See above const rrule_kwfunc = Core.kwftype(typeof(rrule)).instance (::typeof(rrule_kwfunc))(::Any, ::typeof(rrule), ::Any, ::Vararg{Any}) = nothing -(::typeof(rrule_kwfunc))(kws::Any, ::typeof(rrule), ::RuleConfig, args...) = - (rrule_kwfunc)(kws, rrule, args...) +function (::typeof(rrule_kwfunc))(kws::Any, ::typeof(rrule), ::RuleConfig, args...) + return (rule_kwfunc(kws, rrule, args...) +end From 54db71d9fa5b59ec7c184c79b1e736d372236c71 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 18 Jun 2021 23:09:07 +0100 Subject: [PATCH 3/4] fix typo in prev restyling --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index 4baed3318..0abc81205 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -137,5 +137,5 @@ rrule(::RuleConfig, args...) = rrule(args...) const rrule_kwfunc = Core.kwftype(typeof(rrule)).instance (::typeof(rrule_kwfunc))(::Any, ::typeof(rrule), ::Any, ::Vararg{Any}) = nothing function (::typeof(rrule_kwfunc))(kws::Any, ::typeof(rrule), ::RuleConfig, args...) - return (rule_kwfunc(kws, rrule, args...) + return rrule_kwfunc(kws, rrule, args...) end From a2e266bf0870f3d0996fc159b0e5ad5042ba623b Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 18 Jun 2021 23:15:03 +0100 Subject: [PATCH 4/4] add happy path tests of fallback and move tests of fallback inside main testset --- test/config.jl | 40 +++++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/test/config.jl b/test/config.jl index d8465db97..466baed9a 100644 --- a/test/config.jl +++ b/test/config.jl @@ -160,20 +160,30 @@ end @testset "RuleConfig broadcasts like a scaler" begin @test (MostBoringConfig() .=> (1,2,3)) isa NTuple{3, Pair{MostBoringConfig,Int}} end -end -@testset "fallbacks" begin - # Test that incorrect use of the fallback rules correctly throws MethodError - @test_throws MethodError frule() - @test_throws MethodError frule(;kw="hello") - @test_throws MethodError frule(sin) - @test_throws MethodError frule(sin;kw="hello") - @test_throws MethodError frule(MostBoringConfig()) - @test_throws MethodError frule(MostBoringConfig(); kw="hello") - @test_throws MethodError frule(MostBoringConfig(), sin) - @test_throws MethodError frule(MostBoringConfig(), sin; kw="hello") - @test_throws MethodError rrule() - @test_throws MethodError rrule(;kw="hello") - @test_throws MethodError rrule(MostBoringConfig()) - @test_throws MethodError rrule(MostBoringConfig();kw="hello") + @testset "fallbacks" begin + no_rule(x; kw="bye") = error() + @test frule((1.0,), no_rule, 2.0) === nothing + @test frule((1.0,), no_rule, 2.0; kw="hello") === nothing + @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0) === nothing + @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0; kw="hello") === nothing + @test rrule(no_rule, 2.0) === nothing + @test rrule(no_rule, 2.0; kw="hello") === nothing + @test rrule(MostBoringConfig(), no_rule, 2.0) === nothing + @test rrule(MostBoringConfig(), no_rule, 2.0; kw="hello") === nothing + + # Test that incorrect use of the fallback rules correctly throws MethodError + @test_throws MethodError frule() + @test_throws MethodError frule(;kw="hello") + @test_throws MethodError frule(sin) + @test_throws MethodError frule(sin;kw="hello") + @test_throws MethodError frule(MostBoringConfig()) + @test_throws MethodError frule(MostBoringConfig(); kw="hello") + @test_throws MethodError frule(MostBoringConfig(), sin) + @test_throws MethodError frule(MostBoringConfig(), sin; kw="hello") + @test_throws MethodError rrule() + @test_throws MethodError rrule(;kw="hello") + @test_throws MethodError rrule(MostBoringConfig()) + @test_throws MethodError rrule(MostBoringConfig();kw="hello") + end end