non-differentiability of ops on AbstractArray{Bool}#310
non-differentiability of ops on AbstractArray{Bool}#310mzgubic merged 13 commits intoJuliaDiff:masterfrom
Conversation
|
Some of these are dublicated of ones that we have for Also can't do types until JuliaDiff/ChainRulesCore.jl#213 is solved. |
|
What do you think about defining these on |
src/rulesets/Base/nondiff.jl
Outdated
| @non_differentiable cumprod!(::Any, ::BitArray) | ||
| @non_differentiable cumsum(::BitArray) | ||
| @non_differentiable cumsum!(::Any, ::BitArray) | ||
| @non_differentiable DenseMatrix(::BitArray) |
There was a problem hiding this comment.
A general Julia query, probably not too much to do with this PR specifically. If I keep this line of code, I get a warning on doing using ChainRules:
┌ Info: Precompiling ChainRules [082447d4-558c-5d27-93f4-14fc19e9eca2]
└ @ Base loading.jl:1260
WARNING: Method definition frule(Any, UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313.
** incremental compilation may be fatally broken for this module **
WARNING: Method definition frule##kw(Any, typeof(ChainRulesCore.frule), Any, UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313.
** incremental compilation may be fatally broken for this module **
WARNING: Method definition rrule(UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328.
** incremental compilation may be fatally broken for this module **
WARNING: Method definition rrule##kw(Any, typeof(ChainRulesCore.rrule), UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328.
** incremental compilation may be fatally broken for this module **A warning, which I believe I understand that somewhere I made a re-definition of a function call, though don't know at which line was the re-definition made.
And when I remove this line of code (remove re-definition of non-differentiability of DenseMatrix(::BitArray), and keep the below line of code about non-differentiability of Matrix(::BitArray) I get the same exact warning:
┌ Info: Precompiling ChainRules [082447d4-558c-5d27-93f4-14fc19e9eca2]
└ @ Base loading.jl:1260
WARNING: Method definition frule(Any, UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313.
** incremental compilation may be fatally broken for this module **
WARNING: Method definition frule##kw(Any, typeof(ChainRulesCore.frule), Any, UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:313.
** incremental compilation may be fatally broken for this module **
WARNING: Method definition rrule(UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328.
** incremental compilation may be fatally broken for this module **
WARNING: Method definition rrule##kw(Any, typeof(ChainRulesCore.rrule), UnionAll, Base.BitArray{N} where N) in module ChainRules at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328 overwritten at /Users/gaurav/.julia/dev/ChainRulesCore/src/rule_definition_tools.jl:328.
** incremental compilation may be fatally broken for this module **, it didn't even point me to the line that was causing the error (was it DenseMatrix line or the Matrix line) which often leads me to think over about the non-exactness of the error and warning details in Julia (probably atleast compared to Python), is it just me or am I wrong in thinking like that?
There was a problem hiding this comment.
That is because of the bug that is fixed by JuliaDiff/ChainRulesCore.jl#243
so if you bump the requirement for ChainRulesCore up to 0.9.19
it should be fixed (0.9.19 is currently registering)
|
Let's consider a method like
because first one will cause an error from the primal's side, instead of returning Is my understanding correct? |
Would it be ok, if we write the non-differentiability rules for them, but don't merge this until JuliaDiff/ChainRulesCore.jl#213 is solved? |
It causes #310 (comment) but now that it is fixed we don't have to worry. |
I would do the first, because even if someone defined something like |
|
While I don't understand the innards of this package, is there a possibility that many of these might be built-in at a higher level? Perhaps, before applying any rule, ChainRules can check whether any inputs are |
ChainRules can't do this. So yes, maybe at some point in the future we will clear house and remove a bunch of rules. |
src/rulesets/Base/nondiff.jl
Outdated
| @non_differentiable strides(::AbstractArray{Bool}) | ||
| @non_differentiable vcat(::AbstractArray{Bool}) |
There was a problem hiding this comment.
Should this be vcat(::AbstractArray{Bool}...) if that's accepted?
Also, many of these like strudes & isperm surely aren't differentiable with any input, but isperm([true]) is an error anyway.
similar also accepts further argumens.
There was a problem hiding this comment.
Yes, I think that should be the case. We'll wait on the implementation of support of Vararg's as its implemented in PR JuliaDiff/ChainRulesCore.jl#254 ?
|
This is probably ready for a review now.. |
Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com>
Fixes #293
I think definitely there are more rules to be added.