Status: Closed
Labels: design (requires some design before changes are made), enhancement (new feature or request), rule definition helper (relating to helpers for declaring rules)
Description
julia> @benchmark cos(x) setup=(x=0.5)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 4.409 ns (0.00% GC)
median time: 4.421 ns (0.00% GC)
mean time: 4.530 ns (0.00% GC)
maximum time: 101.587 ns (0.00% GC)
--------------
samples: 10000
evals/sample: 1000
julia> @benchmark rrule(sin, x)[2](1.0) setup=(x=0.5)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 7.840 ns (0.00% GC)
median time: 8.102 ns (0.00% GC)
mean time: 8.157 ns (0.00% GC)
maximum time: 35.264 ns (0.00% GC)
--------------
samples: 10000
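evals/sample: 999

Evaluating the pullback from rrule(sin, x) takes roughly twice as long as cos(x) alone (median 8.1 ns vs 4.4 ns), even though the derivative of sin only needs cos(x). This is because rrule also computes the primal value, as the code generated by the @scalar_rule macro shows (here for sinc):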
julia> MacroTools.prettify(@macroexpand @scalar_rule sinc(x) cosc(x))
quote
if !(sinc isa ChainRulesCore.Type) && ChainRulesCore.fieldcount(ChainRulesCore.typeof(sinc)) > 0
ChainRulesCore.throw(ChainRulesCore.ArgumentError("@scalar_rule cannot be used on closures/functors (such as $(sinc))"))
end
function (ChainRulesCore.ChainRulesCore).frule((ChainRulesCore._, Δ1), ::ChainRulesCore.typeof(sinc), x::Number)
Ω = sinc(x)
nothing
return (Ω, cosc(x) * Δ1)
end
function (ChainRulesCore.ChainRulesCore).rrule(::ChainRulesCore.typeof(sinc), x::Number)
Ω = sinc(x)
nothing
return (Ω, function sinc_pullback(gull)
return (ChainRulesCore.NO_FIELDS, ChainRulesCore.conj(cosc(x)) * gull)
end)
end
end

To make these scalar rules friendlier to packages that only need the pullback and not the primal value, perhaps the macro could generate the following code instead, which moves the pullback into a separate scalar_pullback function:
julia> MacroTools.prettify(@macroexpand @scalar_rule sinc(x) cosc(x))
quote
if !(sinc isa ChainRulesCore.Type) && ChainRulesCore.fieldcount(ChainRulesCore.typeof(sinc)) > 0
ChainRulesCore.throw(ChainRulesCore.ArgumentError("@scalar_rule cannot be used on closures/functors (such as $(sinc))"))
end
function (ChainRulesCore.ChainRulesCore).frule((ChainRulesCore._, Δ1), ::ChainRulesCore.typeof(sinc), x::Number)
Ω = sinc(x)
nothing
return (Ω, cosc(x) * Δ1)
end
function scalar_pullback(::ChainRulesCore.typeof(sinc), x::Number)
function sinc_pullback(gull)
return (ChainRulesCore.NO_FIELDS, ChainRulesCore.conj(cosc(x)) * gull)
end
end
function (ChainRulesCore.ChainRulesCore).rrule(f::ChainRulesCore.typeof(sinc), x::Number)
Ω = sinc(x)
nothing
return (Ω, scalar_pullback(f, x))
end
end

Note: a small number of functions, such as +, have pullbacks that do not need the value of x; these could be defined separately.
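The proposed split can be sketched without depending on ChainRulesCore. This is a minimal illustration under stated assumptions, not the code @scalar_rule actually generates: scalar_pullback and my_rrule are hypothetical names, and nothing stands in for ChainRulesCore.NO_FIELDS. It shows a derivative-only caller getting the derivative of sinc without ever evaluating sinc(x), and a pullback for + that captures no input values.

```julia
# Hypothetical sketch of the proposal; not ChainRulesCore's API.
# `nothing` is used as a stand-in for ChainRulesCore.NO_FIELDS.

# Pullback for sinc: the derivative cosc(x) needs the input x,
# but not the primal result sinc(x).
scalar_pullback(::typeof(sinc), x::Number) = Δ -> (nothing, conj(cosc(x)) * Δ)

# Pullback for +: the derivative does not depend on the inputs,
# so the returned closure captures no input values.
scalar_pullback(::typeof(+), x::Number, y::Number) = Δ -> (nothing, Δ, Δ)

# Hypothetical rrule built on scalar_pullback: it computes the primal,
# then returns the same pullback a derivative-only caller would get.
function my_rrule(f::typeof(sinc), x::Number)
    Ω = f(x)
    return Ω, scalar_pullback(f, x)
end

# A package that only needs the derivative calls scalar_pullback directly,
# so sinc(0.5) is never evaluated on this path.
back = scalar_pullback(sinc, 0.5)
∂x = back(1.0)[2]

# The rrule path yields the primal value plus the same pullback.
Ω, back2 = my_rrule(sinc, 0.5)
```

With this layout, rrule keeps its current behaviour for AD systems that need the primal, while scalar_pullback gives other packages a way to skip it.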