Conversation
061197f to e11c001
All the accumulation stuff needs to be rewritten still.
e238a8d to e061404
I have not deleted the AbstractRules yet, as I have yet to work out the story for I guess that will block that PR, but I will work that through as I finish JuliaDiff/ChainRules.jl#91. Good to review this now though.

Cool. I will add that to this PR tomorrow.
simeonschaub left a comment
Just some small typos in the docstrings that caught my eye.
nickrobinson251 left a comment
Nice job! A bunch of pretty small comments.
Look forward to reviewing once the other changes are in!
| macro scalar_rule(call, maybe_setup, partials...) | ||
| ############################################################################ | ||
| # Setup: normalizing input form etc | ||
|
|
Can this be broken up into functions? I'd love for this to not be 100 lines long...
i still feel the same
But if it is not easy / does not make sense to you, @oxinabox , then that's fine by me too
src/differentials.jl
Outdated
|
|
||
| #### | ||
| """ | ||
| differential(𝒟::Type, der) |
Maybe this should take primal and conjugate as arguments, and depending on 𝒟 return either Wirtinger or their sum? I think that would make it more clear to rule authors, that when you create a Wirtinger, you usually also want this fall-through behavior.
It could also be useful to check, whether conjugate isa Zero here, and unwrap Wirtinger if that's the case.
Maybe this should take primal and conjugate as arguments, and depending on 𝒟 return either Wirtinger or their sum?
I'm not sure I understand the advantage of that.
With this we have
differential(𝒟, Wirtinger(primal, conjugate))
which seems fine.
What is the advantage of
differential(𝒟, primal, conjugate)
?
It could also be useful to check, whether conjugate isa Zero here, and unwrap Wirtinger if that's the case.
Maybe.
Maybe even iszero(conjugate) to get constants like 0
Maybe even for inputs that are scalar:
iszero(der) && return Zero()
and similar for One()
We should discuss that kind of thing in an issue and make a follow-up PR for it.
What is the advantage of
differential(𝒟, primal, conjugate)
?
I just don't know if differential is the best name for this function, since it takes a differential and returns a (sometimes different and more efficient) differential again. I would expect a function called differential to work more like a constructor. Or do we maybe want to call this wirtinger and have it take 𝒟, primal, and conjugate, since this probably corresponds better to what it does right now? But I also wouldn't feel too strongly about just leaving this for now, since I'm also struggling to find a better name for this function.
Maybe.
Maybe even iszero(conjugate) to get constants like 0
Maybe even for inputs that are scalar:
iszero(der) && return Zero()
and similar for One()
I'm not quite convinced we benefit from introducing dispatch based on value here. Wouldn't this also cause problems on GPUs? But this is definitely an issue for another day.
How about we rename it to refine_differential ?
I think this actually also interacts with #8 since we will want to apply something recursively.
I have a concern about this decision. Does it mean that, for example, the reverse rule of
function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
    return A * B, Rule(Ȳ -> (Ȳ * B', A' * Ȳ))
end
(which is like what Zygote.jl is doing at the moment) instead of the current definition
function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
    return A * B, (Rule(Ȳ -> Ȳ * B'), Rule(Ȳ -> A' * Ȳ))
end
where you can compute the derivative w.r.t. different arguments separately? Wouldn't it be a huge performance loss when large constant arrays participate in the computation of the intermediate variables that depend on the variables ("trainable parameters") with respect to which the derivatives are taken? For example, in the Generative Adversarial Network (GAN) setting, I think it would be a big issue when taking derivative
I'm by no means an AD or ML specialist so I may be missing something. It would be great if you can clarify that my concern is invalid.
|
Short answer: don't worry, we solve this with
Full answer: @tkf a very reasonable concern, and one I used to have.
becomes: In partner PR, this is one of the ones I've already updated
It basically is a differential that defers computation until it is used. then since
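To make the deferral idea in the answer above concrete, here is a minimal self-contained sketch. It does not use the package's own types: `LazyDiff`, `force`, `counted_mul`, and `mul_with_pullback` are hypothetical names made up for illustration, and the call counter exists only to show which products actually run. The point is that a single pullback can return a tuple of partials without paying for the ones nobody asks for.

```julia
# Hypothetical stand-in for a deferred differential; NOT the package's type.
struct LazyDiff{F}
    f::F   # zero-argument closure holding the deferred computation
end

# Run the deferred computation.
force(d::LazyDiff) = d.f()

# Counts how many matrix products are actually evaluated inside partials.
const CALLS = Ref(0)

function counted_mul(X, Y)
    CALLS[] += 1
    return X * Y
end

# One pullback that returns a tuple of partials, each wrapped lazily.
function mul_with_pullback(A, B)
    pullback(Ȳ) = (
        LazyDiff(() -> counted_mul(Ȳ, B')),   # partial w.r.t. A
        LazyDiff(() -> counted_mul(A', Ȳ)),   # partial w.r.t. B
    )
    return A * B, pullback
end

A = [1.0 2.0; 3.0 4.0]
B = [0.0 1.0; 1.0 0.0]

Y, pullback = mul_with_pullback(A, B)
∂A, ∂B = pullback(ones(2, 2))

# Only the partial w.r.t. A is forced; the product for ∂B never runs.
dA = force(∂A)
```

If `B` were a large constant array, the product behind `∂B` would simply never be evaluated, which is the behaviour the concern was about.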
|
@oxinabox Thanks a lot! I appreciate the full explanation. I should have checked the partner PR.
It is a really important question
make real scalar rules work. correct @scalarrule forward rule return Wirtinger scalar working work WirtingerRule test as a test of @scalar_rule Fix spelling Co-Authored-By: simeonschaub <simeondavidschaub99@gmail.com> Oxford Comma Co-Authored-By: simeonschaub <simeondavidschaub99@gmail.com> spelling Co-Authored-By: Nick Robinson <npr251@gmail.com> docstring for propagator_name spelling Co-Authored-By: Nick Robinson <npr251@gmail.com>
error ratehr than Assert cleanup Update src/rule_definition_tools.jl Co-Authored-By: Nick Robinson <npr251@gmail.com> Add more complex Wirtinger Scalar Rule Test
update accumulate to work on differentials
Co-Authored-By: Curtis Vogt <curtis.vogt@gmail.com>
spelling is hard
zero the storage inplace
This reverts commit 85b5bf9.
5995297 to 3656389
Rebased, and squashed some of them. Normally I am hesitant to squash during PR review but this has had a lot of review so far,

All tests (except integration tests) are passing. Shuffle rebasing is hard, not sure if worth it
nickrobinson251 left a comment
This LGTM
Good work!
I have left a handful of tiny comments :)
| - The expression wrapping something in a `struct`, such as `Adjoint(x)` or `Diagonal(x)` | ||
| - The expression being a constant | ||
| - The expression being itself a `thunk` | ||
| - The expression being from another `rrule` or `frule` (it would be `@thunk`ed if required by the defining rule already) |
This entire section is great
Shall we move it to a page in the docs?
| (Otherwise one can just use a normal `Thunk`). | ||
|
|
||
| Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`; | ||
| and destroy its inplacability. |
| and destroy its inplacability. | |
| and destroy its ability to work inplace. |
| Base.conj(x::Thunk) = @thunk(conj(extern(x))) | ||
| # The real reason we have this: | ||
| accumulate!(Δ, ∂::InplaceableThunk) = ∂.add!(Δ) | ||
| store!(Δ, ∂::InplaceableThunk) = ∂.add!((Δ.*=false)) # zero it, then add to it. |
| store!(Δ, ∂::InplaceableThunk) = ∂.add!((Δ.*=false)) # zero it, then add to it. | |
| store!(Δ, ∂::InplaceableThunk) = ∂.add!((Δ .*= false)) # zero it, then add to it. |
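For readers following the `accumulate!` / `store!` lines above, the sketch below mimics the pattern with a hypothetical stand-in type. `InplaceSketch`, `accumulate_sketch!`, and `store_sketch!` are illustrative names, not the package's definitions, and the value is stored eagerly here for brevity. It also shows why `Δ .*= false` is used for zeroing: in Julia `false` acts as a strong zero, so even `NaN` entries are cleared.

```julia
# Hypothetical stand-in, NOT the package's InplaceableThunk.
struct InplaceSketch{T,F}
    val::T    # the out-of-place value (eager here, for brevity)
    add!::F   # closure that adds the value into its argument, in place
end

# Add the differential into Δ in place.
accumulate_sketch!(Δ, ∂::InplaceSketch) = ∂.add!(Δ)

# Zero Δ in place, then add to it, so the old contents are overwritten.
store_sketch!(Δ, ∂::InplaceSketch) = ∂.add!((Δ .*= false))

∂ = InplaceSketch([1.0, 2.0], Δ -> (Δ .+= [1.0, 2.0]))

acc = [10.0, 20.0]
accumulate_sketch!(acc, ∂)   # acc now holds its old contents plus the differential

buf = [NaN, 5.0]
store_sketch!(buf, ∂)        # old contents (including the NaN) are discarded
```

Multiplying by `0.0` instead of `false` would leave the `NaN` in place, which is presumably why the original line zeroes with `false`.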
| Similar to [`accumulate`](@ref), but attempts to compute `Δ + rule(args...)` in-place, | ||
| storing the result in `Δ`. | ||
|
|
||
| Note: this function may not actually store the result in `Δ` if `Δ` is immutable, |
| Note: this function may not actually store the result in `Δ` if `Δ` is immutable, | |
| !!! note | |
| this function may not actually store the result in `Δ` if `Δ` is immutable, |
| """ | ||
| store!(Δ, ∂) | ||
|
|
||
| Stores `∂`, in `Δ`, overwriting what ever was in `Δ` before. |
| Stores `∂`, in `Δ`, overwriting what ever was in `Δ` before. | |
| Stores `∂` in `Δ` overwriting whatever was in `Δ` before. |
src/rule_definition_tools.jl
Outdated
| Returns the expression for the propagation of | ||
| the input gradient `Δs` though the partials `∂s`. | ||
|
|
||
| 𝒟 is an expression that when evaluated returns the type-of the input domain. |
| 𝒟 is an expression that when evaluated returns the type-of the input domain. | |
| 𝒟 is an expression that when evaluated returns the type of the input domain. |
src/rule_definition_tools.jl
Outdated
| function standard_propagation_expr(Δs, ∂s) | ||
| # This is basically Δs ⋅ ∂s | ||
|
|
||
| # Notice: the thunking of `∂s[i] (potentially) saves us some computation |
| # Notice: the thunking of `∂s[i] (potentially) saves us some computation | |
| # Notice: the thunking of `∂s[i]` (potentially) saves us some computation |
src/rule_definition_tools.jl
Outdated
| # Notice: the thunking of `∂s[i] (potentially) saves us some computation | ||
| # if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon | ||
| # as the pullback is evaluated | ||
| ∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)] |
Yes :) Is it worth opening an issue (to stare hard at this and figure out if all is well)?
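As a rough illustration of the "this is basically Δs ⋅ ∂s" comment in the snippet under review, here is a simplified sketch of building such a sum-of-products expression. `dot_expr_sketch` is a hypothetical name, and unlike the code in the PR it does not wrap each partial in `@thunk`; it only shows the expression-building step.

```julia
# Hypothetical simplification: build the expression `+(∂1 * Δ1, ∂2 * Δ2, ...)`.
# The code under review additionally thunks each partial; that is omitted here.
function dot_expr_sketch(Δs, ∂s)
    products = [:($(∂s[i]) * $(Δs[i])) for i in 1:length(∂s)]
    return :(+($(products...)))
end

# Partials are given as expressions in x and y; input gradients as symbols.
ex = dot_expr_sketch([:Δ1, :Δ2], [:(2x), :(y + 1)])

# Evaluate the generated expression with concrete values bound to every name.
value = eval(:(let x = 3, y = 4, Δ1 = 10, Δ2 = 100
    $ex
end))
```

With the thunked version, each `∂s[i]` would only be evaluated when its product is actually needed, which is the saving the code comment describes.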
Co-Authored-By: Nick Robinson <npr251@gmail.com>
2bd9b6c to e51ff80
This is a very mighty PR.
- `DNE()` as not functors) return value for all `rrule`s, to represent the derivative w.r.t. internals of closures/functors, and similar demands an extra input argument at the start of a call to `frule` (ignored for all current cases as not functors) (Differentiating with respect to a function #22)
- `frule`/`rrule` now return 1 propagator (pushforward/pullback) that returns a tuple of partials, rather than 1 propagator per partial (1 AbstractRule Per Partial, vs 1 AbstractRule returning a tuple of Differentials (one per partial) #38)
- `AbstractRule` subtypes are no longer used anywhere (Remove Rule (or maybe all AbstractRules) and treat functions as Rules #39)
- `@scalar_rule` automatically names pullbacks/pushforwards. Below is what that looks like in Julia Master (with new improved display for gensymed names)
Does not look quite as nice for 1.0 but still useful
It has a corresponding PR to ChainRules.jl
JuliaDiff/ChainRules.jl#91
This is the main blocker for FluxML/Zygote.jl#291