Add methods for thunks #371

Merged
mzgubic merged 14 commits into master from mz/thunks
Jun 18, 2021

Conversation

@mzgubic
Member

mzgubic commented Jun 17, 2021

mzgubic changed the title from "Add methods for thunks" to "WIP: Add methods for thunks" Jun 17, 2021
@codecov-commenter

codecov-commenter commented Jun 17, 2021

Codecov Report

Merging #371 (2a78a05) into master (bf01ddf) will increase coverage by 0.50%.
The diff coverage is 92.95%.

@@            Coverage Diff             @@
##           master     #371      +/-   ##
==========================================
+ Coverage   88.86%   89.36%   +0.50%     
==========================================
  Files          14       14              
  Lines         485      555      +70     
==========================================
+ Hits          431      496      +65     
- Misses         54       59       +5     
Impacted Files                        Coverage Δ
src/ChainRulesCore.jl                 100.00% <ø> (ø)
src/differentials/abstract_zero.jl    86.66% <0.00%> (-6.20%) ⬇️
src/differentials/thunks.jl           93.75% <94.28%> (+1.15%) ⬆️

mzgubic changed the title from "WIP: Add methods for thunks" to "Add methods for thunks" Jun 17, 2021
include("differentials/abstract_zero.jl")
include("differentials/thunks.jl")
include("differentials/composite.jl")
include("differentials/combinations.jl")
Member

I think this is the wrong location for this file.
It should be in src directly, like differential_arithmetic.jl.
Everything in differentials/*.jl focuses on one type and operations on just that type.

Member

oxinabox left a comment

One significant concern re: mutation: I think we shouldn't allow it, as it doesn't work.
And we should just make the user unthunk.

Similar to that, but less severe: size, axes, and eltype.
These operations are basically never the only operation you are going to do to a thunk.
Often it is something like:
for i in 1:size(x, 1); for j in 1:size(x, 2); net += x[i, j]; end; end
that is going to be really slow compared to just unthunking at the start, and then doing that.
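
A minimal sketch of the performance concern above, written without ChainRulesCore so it runs on its own. `LazyValue` and `force` are hypothetical stand-ins for `Thunk` and `unthunk`, and the forwarding `size` and `getindex` methods are assumptions made for illustration, not the methods from this PR. The sketch assumes that forcing is not memoized, so every forwarded call re-runs the wrapped closure.

```julia
# Hypothetical stand-in for a thunk: wraps a zero-argument closure.
struct LazyValue{F}
    f::F
end

# Stand-in for `unthunk`: runs the wrapped computation on every call.
force(x::LazyValue) = x.f()

# Forwarding methods in the style under discussion: each call forces again.
Base.size(x::LazyValue, d) = size(force(x), d)
Base.getindex(x::LazyValue, i...) = force(x)[i...]

# Slow pattern when `x` is lazy: every `size` and `x[i, j]` re-runs the closure.
function sum_elementwise(x)
    net = 0.0
    for i in 1:size(x, 1), j in 1:size(x, 2)
        net += x[i, j]
    end
    return net
end

# Preferred pattern: force once at the start, then loop over the concrete array.
sum_unthunked(x::LazyValue) = sum_elementwise(force(x))

# Count how often the wrapped computation actually runs.
calls = Ref(0)
lazy = LazyValue(() -> (calls[] += 1; ones(3, 4)))

calls[] = 0
total_forwarded = sum_elementwise(lazy)
calls_forwarded = calls[]

calls[] = 0
total_unthunked = sum_unthunked(lazy)
calls_unthunked = calls[]
```

Both functions return the same total, but `calls_forwarded` grows with the number of elements while `calls_unthunked` stays at one, which is the argument for unthunking up front.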

@@ -19,6 +19,74 @@ Base.:(==)(a::AbstractThunk, b::AbstractThunk) = unthunk(a) == unthunk(b)
Base.:(==)(a::AbstractThunk, b) = unthunk(a) == b
Base.:(==)(a, b::AbstractThunk) = a == unthunk(b)

Member

Will some of these, especially the two-arg ones, move into differential_arithmetic.jl once we define them on everything?

Member Author

I guess just the two arg ones?

Base.:(-)(a::AbstractThunk, b) = unthunk(a) - b
Base.:(-)(a, b::AbstractThunk) = a - unthunk(b)
Base.:(/)(a::AbstractThunk, b) = unthunk(a) / b
Base.:(/)(a, b::AbstractThunk) = a / unthunk(b)
Member

I guess we don't have to worry about both being thunks because that would be a nonlinear operation between two tangent types?

Member Author

Generally I only defined the functions I needed to get the ChainRules tests to pass, and didn't worry too much about all possible combinations. Not sure how much less lightweight ChainRulesCore is going to get once we define all these functions? Or is that not a concern at all?

Member

Not sure how much less lightweight ChainRulesCore is going to get once we define all these functions? Or is that not a concern at all?

I don't think it is too much of a concern, not compared to correctness and ease of use,
unless the load-time increase is huge.

E.g. for me, the current release of ChainRulesCore gives

julia> @time using ChainRulesCore
  0.251466 seconds (452.55 k allocations: 26.666 MiB)

and this branch, which already adds a ton, gets it up to

julia> @time using ChainRulesCore
  0.273390 seconds (486.60 k allocations: 28.348 MiB)

I suspect we can get it down more with some precompilation.
If we have to, we can move them out to ChainRules.jl (and then require people to load ChainRules.jl for testing etc.).
But for now let's not worry too much.

LinearAlgebra.Hermitian(a::AbstractThunk, uplo=:U) = Hermitian(unthunk(a), uplo)

LinearAlgebra.diagm(kv::Pair{<:Integer, <:AbstractThunk}...) = diagm((k => unthunk(v) for (k, v) in kv)...)
LinearAlgebra.diagm(m, n, kv::Pair{<:Integer, <:AbstractThunk}...) = diagm(m, n, (k => unthunk(v) for (k, v) in kv)...)
Member

These lines are over-length.
Run JuliaFormatter to fix?

@oxinabox
Member

Did you have thoughts on

... less severe: size, axes, and eltype.
These operations are basically never the only operation you are going to do to a thunk.
Often it is something like:
for i in 1:size(x, 1); for j in 1:size(x, 2); net += x[i, j]; end; end
that is going to be really slow compared to just unthunking at the start, and then doing that.

@mzgubic
Member Author

mzgubic commented Jun 17, 2021

I did (let's make it work, and not worry about performance).

But I will investigate it case by case, and then make up my mind.

@oxinabox
Member

oxinabox commented Jun 17, 2021

we could put in a warning, if in debug mode?
Maybe out of scope for this PR?

@mzgubic
Member Author

mzgubic commented Jun 18, 2021

Documentation is failing in a weird way:

MethodError: convert(::Type{Any}, ::Thunk{Main.##doctest-#328.var"#4#6"}) is ambiguous

julia> include("make.jl")
[ Info: Precompiling ChainRulesCore [d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4]
[ Info: SetupBuildDirectory: setting up build directory.
[ Info: Doctest: running doctests.
ERROR: LoadError: MethodError: convert(::Type{Any}, ::Thunk{Main.##doctest-#328.var"#4#6"}) is ambiguous. Candidates:
  convert(T, a::AbstractThunk) in ChainRulesCore at /Users/mzgubic/JuliaEnvs/ChainRules.jl/dev/ChainRulesCore/src/differentials/thunks.jl:42
  convert(::Type{Any}, x) in Base at essentials.jl:204
  convert(::Type{T}, x::T) where T>:Union{Missing, Nothing} in Base at missing.jl:68
  convert(::Type{T}, x::T) where T>:Nothing in Base at some.jl:35
  convert(::Type{T}, x::T) where T>:Missing in Base at missing.jl:67
  convert(::Type{T}, x::T) where T in Base at essentials.jl:205
  convert(::Type{T}, x) where T>:Union{Missing, Nothing} in Base at missing.jl:70
  convert(::Type{T}, x) where T>:Nothing in Base at some.jl:36
  convert(::Type{T}, x) where T>:Missing in Base at missing.jl:69
Possible fix, define
  convert(::Type{Any}, ::AbstractThunk)
Stacktrace:
  [1] setproperty!(x::Documenter.DocTests.Result, f::Symbol, v::Thunk{Main.##doctest-#328.var"#4#6"})
    @ Base ./Base.jl:34
  [2] eval_repl(block::Documenter.DocTests.MutableMD2CodeBlock, sandbox::Module, meta::Dict{Symbol, Any}, doc::Documenter.Documents.Document, page::String)
    @ Documenter.DocTests ~/.julia/packages/Documenter/6vUwN/src/DocTests.jl:224
  [3] doctest(ctx::Documenter.DocTests.DocTestContext, block_immutable::Documenter.Utilities.Markdown2.CodeBlock)
    @ Documenter.DocTests ~/.julia/packages/Documenter/6vUwN/src/DocTests.jl:170
  [4] (::Documenter.DocTests.var"#1#2"{Documenter.DocTests.DocTestContext})(node::Documenter.Utilities.Markdown2.CodeBlock)
    @ Documenter.DocTests ~/.julia/packages/Documenter/6vUwN/src/DocTests.jl:112
  [5] walk(f::Documenter.DocTests.var"#1#2"{Documenter.DocTests.DocTestContext}, node::Documenter.Utilities.Markdown2.CodeBlock)
    @ Documenter.Utilities.Markdown2 ~/.julia/packages/Documenter/6vUwN/src/Utilities/Markdown2.jl:297
  [6] walk(f::Function, nodes::Vector{Documenter.Utilities.Markdown2.MarkdownBlockNode})
    @ Documenter.Utilities.Markdown2 ~/.julia/packages/Documenter/6vUwN/src/Utilities/Markdown2.jl:306
  [7] walk(f::Documenter.DocTests.var"#1#2"{Documenter.DocTests.DocTestContext}, node::Documenter.Utilities.Markdown2.MD)
    @ Documenter.Utilities.Markdown2 ~/.julia/packages/Documenter/6vUwN/src/Utilities/Markdown2.jl:299
  [8] doctest
    @ ~/.julia/packages/Documenter/6vUwN/src/DocTests.jl:109 [inlined]
  [9] doctest(docstr::Base.Docs.DocStr, mod::Module, doc::Documenter.Documents.Document)
    @ Documenter.DocTests ~/.julia/packages/Documenter/6vUwN/src/DocTests.jl:89
 [10] doctest(blueprint::Documenter.Documents.DocumentBlueprint, doc::Documenter.Documents.Document)
    @ Documenter.DocTests ~/.julia/packages/Documenter/6vUwN/src/DocTests.jl:57
 [11] runner(#unused#::Type{Documenter.Builder.Doctest}, doc::Documenter.Documents.Document)
    @ Documenter.Builder ~/.julia/packages/Documenter/6vUwN/src/Builder.jl:214
 [12] dispatch(#unused#::Type{Documenter.Builder.DocumentPipeline}, x::Documenter.Documents.Document)
    @ Documenter.Utilities.Selectors ~/.julia/packages/Documenter/6vUwN/src/Utilities/Selectors.jl:170
 [13] #2
    @ ~/.julia/packages/Documenter/6vUwN/src/Documenter.jl:247 [inlined]
 [14] cd(f::Documenter.var"#2#3"{Documenter.Documents.Document}, dir::String)
    @ Base.Filesystem ./file.jl:106
 [15] #makedocs#1
    @ ~/.julia/packages/Documenter/6vUwN/src/Documenter.jl:246 [inlined]
 [16] top-level scope
    @ ~/JuliaEnvs/ChainRules.jl/dev/ChainRulesCore/docs/make.jl:24
 [17] include(fname::String)
    @ Base.MainInclude ./client.jl:444
 [18] top-level scope
    @ REPL[2]:1
in expression starting at /Users/mzgubic/JuliaEnvs/ChainRules.jl/dev/ChainRulesCore/docs/make.jl:24
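
The ambiguity in the log above can be reproduced in miniature without ChainRulesCore or Documenter. In this sketch `AbstractLazy`, `Lazy`, `force`, and `conversion_error` are hypothetical names standing in for `AbstractThunk`, `Thunk`, `unthunk`, and a test helper; the catch-all `convert` method is an assumption modelled on the candidate listed in the trace, not a copy of the PR's code. The tie-breaking method at the end follows the "Possible fix" hint and is only one option; the PR itself ended up removing the methods instead.

```julia
abstract type AbstractLazy end

# Hypothetical stand-in for a thunk: wraps a zero-argument closure.
struct Lazy{F} <: AbstractLazy
    f::F
end

# Stand-in for `unthunk`.
force(x::Lazy) = x.f()

# Catch-all conversion: force the lazy value, then convert the result.
Base.convert(::Type{T}, a::AbstractLazy) where {T} = convert(T, force(a))

# Returns the exception thrown by `convert(T, x)`, or `nothing` on success.
function conversion_error(T, x)
    try
        convert(T, x)
        return nothing
    catch err
        return err
    end
end

lazy = Lazy(() -> 1.0)

# `convert(Any, lazy)` matches both the method above and Base's
# `convert(::Type{Any}, x)`, and neither is more specific than the other.
err_before = conversion_error(Any, lazy)

# One way to break the tie: a method more specific than both candidates.
# Returning `a` unchanged mirrors what Base does for `convert(Any, x)` and
# never runs the closure.
Base.convert(::Type{Any}, a::AbstractLazy) = a

err_after = conversion_error(Any, lazy)
```

Storing a value in a field typed `Any`, as `setproperty!` does in the trace, goes through `convert(Any, value)`, which is why an otherwise unrelated doctest hits the new method.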

And if I define

Base.convert(::Type{Any}, a::AbstractThunk) = convert(Type{Any}, unthunk(a))

Then it fails with

ERROR: LoadError: MethodError: no method matching (::Main.##doctest-#410.var"#4#6")() The applicable method may be too new: running in world age 29758, while current world is 29870

julia> include("make.jl")
┌ Warning: DocTestSetup already set for module ChainRulesCore. Overwriting.
└ @ Documenter.DocMeta ~/.julia/packages/Documenter/6vUwN/src/DocMeta.jl:81
[ Info: SetupBuildDirectory: setting up build directory.
[ Info: Doctest: running doctests.
ERROR: LoadError: MethodError: no method matching (::Main.##doctest-#410.var"#4#6")()
The applicable method may be too new: running in world age 29758, while current world is 29870.
Closest candidates are:
  (::Main.##doctest-#410.var"#4#6")() at none:1 (method too new to be called from this world context.)
Stacktrace:
  [1] (::Thunk{Main.##doctest-#410.var"#4#6"})()
    @ ChainRulesCore ~/JuliaEnvs/ChainRules.jl/dev/ChainRulesCore/src/differentials/thunks.jl:173
  [2] unthunk
    @ ~/JuliaEnvs/ChainRules.jl/dev/ChainRulesCore/src/differentials/thunks.jl:173 [inlined]
  [3] convert(#unused#::Type{Any}, a::Thunk{Main.##doctest-#410.var"#4#6"})
    @ ChainRulesCore ~/JuliaEnvs/ChainRules.jl/dev/ChainRulesCore/src/differentials/thunks.jl:44
  [4] setproperty!(x::Documenter.DocTests.Result, f::Symbol, v::Thunk{Main.##doctest-#410.var"#4#6"})
    @ Base ./Base.jl:34
  [5] eval_repl(block::Documenter.DocTests.MutableMD2CodeBlock, sandbox::Module, meta::Dict{Symbol, Any}, doc::Documenter.Documents.Document, page::String)
    @ Documenter.DocTests ~/.julia/packages/Documenter/6vUwN/src/DocTests.jl:224
  [6] doctest(ctx::Documenter.DocTests.DocTestContext, block_immutable::Documenter.Utilities.Markdown2.CodeBlock)
    @ Documenter.DocTests ~/.julia/packages/Documenter/6vUwN/src/DocTests.jl:170
  [7] (::Documenter.DocTests.var"#1#2"{Documenter.DocTests.DocTestContext})(node::Documenter.Utilities.Markdown2.CodeBlock)
    @ Documenter.DocTests ~/.julia/packages/Documenter/6vUwN/src/DocTests.jl:112
  [8] walk(f::Documenter.DocTests.var"#1#2"{Documenter.DocTests.DocTestContext}, node::Documenter.Utilities.Markdown2.CodeBlock)
    @ Documenter.Utilities.Markdown2 ~/.julia/packages/Documenter/6vUwN/src/Utilities/Markdown2.jl:297
  [9] walk(f::Function, nodes::Vector{Documenter.Utilities.Markdown2.MarkdownBlockNode})
    @ Documenter.Utilities.Markdown2 ~/.julia/packages/Documenter/6vUwN/src/Utilities/Markdown2.jl:306
 [10] walk(f::Documenter.DocTests.var"#1#2"{Documenter.DocTests.DocTestContext}, node::Documenter.Utilities.Markdown2.MD)
    @ Documenter.Utilities.Markdown2 ~/.julia/packages/Documenter/6vUwN/src/Utilities/Markdown2.jl:299
 [11] doctest
    @ ~/.julia/packages/Documenter/6vUwN/src/DocTests.jl:109 [inlined]
 [12] doctest(docstr::Base.Docs.DocStr, mod::Module, doc::Documenter.Documents.Document)
    @ Documenter.DocTests ~/.julia/packages/Documenter/6vUwN/src/DocTests.jl:89
 [13] doctest(blueprint::Documenter.Documents.DocumentBlueprint, doc::Documenter.Documents.Document)
    @ Documenter.DocTests ~/.julia/packages/Documenter/6vUwN/src/DocTests.jl:57
 [14] runner(#unused#::Type{Documenter.Builder.Doctest}, doc::Documenter.Documents.Document)
    @ Documenter.Builder ~/.julia/packages/Documenter/6vUwN/src/Builder.jl:214
 [15] dispatch(#unused#::Type{Documenter.Builder.DocumentPipeline}, x::Documenter.Documents.Document)
    @ Documenter.Utilities.Selectors ~/.julia/packages/Documenter/6vUwN/src/Utilities/Selectors.jl:170
 [16] #2
    @ ~/.julia/packages/Documenter/6vUwN/src/Documenter.jl:247 [inlined]
 [17] cd(f::Documenter.var"#2#3"{Documenter.Documents.Document}, dir::String)
    @ Base.Filesystem ./file.jl:106
 [18] #makedocs#1
    @ ~/.julia/packages/Documenter/6vUwN/src/Documenter.jl:246 [inlined]
 [19] top-level scope
    @ ~/JuliaEnvs/ChainRules.jl/dev/ChainRulesCore/docs/make.jl:24
 [20] include(fname::String)
    @ Base.MainInclude ./client.jl:444
 [21] top-level scope
    @ REPL[2]:1
in expression starting at /Users/mzgubic/JuliaEnvs/ChainRules.jl/dev/ChainRulesCore/docs/make.jl:24

😬
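
For readers unfamiliar with the "method too new" message, here is a minimal sketch of a world-age error in plain Julia. It is not the Documenter code path: `call_directly` and `call_latest` are hypothetical helpers, and the assumption is that the doctest closure is created by `eval` after the surrounding code has started running, as the trace suggests.

```julia
# Creates a brand-new anonymous function with `eval` and calls it from inside
# an already-running function. The new method is younger than the world the
# caller is running in, so the direct call fails.
function call_directly()
    f = eval(:(() -> 42))
    try
        return f()
    catch err
        return err
    end
end

# Same setup, but `Base.invokelatest` performs the call in the latest world,
# where the new method is visible.
function call_latest()
    f = eval(:(() -> 42))
    return Base.invokelatest(f)
end

direct = call_directly()   # the captured exception
latest = call_latest()     # the closure's return value
```

In the trace above the closure is run by `unthunk` inside `convert`, so a `convert` method that forces the thunk is what exposes the world-age problem; a method that leaves the thunk untouched would not call the closure at all.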

@mzgubic
Member Author

mzgubic commented Jun 18, 2021

... less severe: size, axes, and eltype.

Turns out we don't need them (after a more careful pass through the rules). I removed them so that if they show up they can serve as a reminder to unthunk.

@oxinabox
Member

Is the docs build still failing?
Did you try restarting Julia after adding that method ambiguity fix?

@mzgubic
Member Author

mzgubic commented Jun 18, 2021

I removed those methods (they were no longer necessary after unthunking some more inside the rules).

I will review again but I think this is ready to merge

mzgubic merged commit 6073a47 into master Jun 18, 2021
mzgubic deleted the mz/thunks branch June 18, 2021 15:20
3 participants