Automatically provide tangents#116

Merged
oxinabox merged 15 commits into master from ox/autotangent on Feb 5, 2021

Conversation

oxinabox (Member) commented Feb 3, 2021

Closes #67
Does option 2D, with x ⊢ ẋ
(that is, \vdash + tab).
We can argue about symbols though.

Before:

x = rand(4, 5)
x̄ = rand(4, 5)

ȳ = rand(2, 10)

rrule_test(reshape, ȳ, (x, x̄), ((2, 10), nothing))
rrule_test(reshape, ȳ, (x, x̄), (2, nothing), (10, nothing))

After:

No point declaring x in advance any more.
Can just put the rand in place.

test_rrule(reshape, rand(4, 5), (2, 10) ⊢ nothing)
test_rrule(reshape, rand(4, 5), 2 ⊢ nothing, 10 ⊢ nothing)
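
As a rough illustration of how the two input forms above could be told apart, here is a minimal sketch. All definitions in it (PrimalAndTangent, Auto, rand_tangent, auto_primal_and_tangent) are simplified stand-ins written for this example; the names mirror ones that appear in this thread, but the bodies are assumptions and not the package's actual code.

```julia
# Hypothetical stand-ins for illustration only.
struct Auto end                      # marker meaning "generate a tangent for me"

struct PrimalAndTangent{P,T}
    primal::P
    tangent::T
end

# Infix constructor. In the Julia REPL, type \vdash and press Tab to get ⊢.
⊢(primal, tangent) = PrimalAndTangent(primal, tangent)

# Simplified random tangent generator, covering only floats and float arrays.
rand_tangent(x::AbstractFloat) = randn(typeof(x))
rand_tangent(x::AbstractArray{<:AbstractFloat}) = randn(eltype(x), size(x))

# A bare primal gets a generated tangent.
auto_primal_and_tangent(x) = x ⊢ rand_tangent(x)
# An explicitly annotated input is passed through unchanged.
auto_primal_and_tangent(both::PrimalAndTangent) = both
# An input annotated with Auto() is treated like a bare primal.
auto_primal_and_tangent(both::PrimalAndTangent{<:Any,Auto}) =
    both.primal ⊢ rand_tangent(both.primal)

a = auto_primal_and_tangent(rand(4, 5))          # tangent is generated
b = auto_primal_and_tangent((2, 10) ⊢ nothing)   # tangent is kept as given
c = auto_primal_and_tangent(rand(3) ⊢ Auto())    # tangent is generated
```

With this shape, a tester only needs to call one function on every input, whether or not the caller supplied a tangent.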

@oxinabox oxinabox marked this pull request as draft February 3, 2021 19:48
@oxinabox oxinabox mentioned this pull request Feb 4, 2021
@oxinabox oxinabox changed the title WIP: automatically provide tangents Automatically provide tangents Feb 4, 2021
@oxinabox oxinabox marked this pull request as ready for review February 4, 2021 19:35
oxinabox (Member Author) commented Feb 4, 2021

I have not added tests for this explicitly, but in some places we still use this because we explicitly want to pass special tangents in.
So we basically have real-world tests that this works.
So I think this is good,
but reviewers should comment if they disagree.

oxinabox (Member Author) commented Feb 4, 2021

I would say that we should wait until ChainRules is updated to use this before merging.
But it is going to be a big job to update all the tests. Probably multiple PRs.
So I think we should merge this, and then patch it if we find something was missed.

This is nonbreaking as the old way is deprecated.

function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Real=1e-9, fdm=_fdm, check_inferred::Bool=true, fkwargs::NamedTuple=NamedTuple(), kwargs...)
function test_rrule(
f, inputs...;
output_tangent=Auto(),
Member Author

Should this be output_cotangent?

src/testers.jl Outdated
y = f(xs...; fkwargs...)
check_equal(y_ad, y; isapprox_kwargs...) # make sure primal is correct

ȳ = tangent(auto_primal_and_tangent(y ⊢ output_tangent))
Member Author

It's a bit silly to attach a tangent to the primal explicitly and then peel it off, but it makes both a given tangent and Auto() work.

mzgubic (Member), Feb 5, 2021

Why don't we just use rand_tangent(y)? I don't understand this part of your comment:

but it makes both given and Auto() work

Member Author

the alternative is:

Suggested change
ȳ = tangent(auto_primal_and_tangent(y ⊢ output_tangent))
ȳ = output_tangent === Auto() ? rand_tangent(primal) : output_tangent

Which maybe is cleaner?

idk yesterday I was hating on branches.
I started to write auto_tangent(primal, output_tangent) and realized that it overlapped with auto_primal_and_tangent

Contributor

ȳ = output_tangent === Auto() ? rand_tangent(primal) : output_tangent is certainly much clearer, to my eyes
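
To make the branch-based alternative discussed above concrete, here is a minimal sketch. Auto and rand_tangent are simplified stand-ins, and resolve_output_tangent is a hypothetical helper name introduced only for this example. The suggestion refers to `primal`, which is assumed here to be the `y` from the hunk under review.

```julia
# Hypothetical marker type, mirroring the `Auto()` default seen in the diff.
struct Auto end

# Simplified stand-in; assumed to cover only float arrays for this sketch.
rand_tangent(x::AbstractArray{<:AbstractFloat}) = randn(eltype(x), size(x))

# Use the caller's tangent if one was given, otherwise generate one.
resolve_output_tangent(y, output_tangent) =
    output_tangent === Auto() ? rand_tangent(y) : output_tangent

y = rand(2, 10)
ȳ_auto = resolve_output_tangent(y, Auto())        # generated
ȳ_given = resolve_output_tangent(y, ones(2, 10))  # passed through
```

The trade-off is the one raised in the thread: the branch is easier to read, while the attach-then-peel version reuses the same code path as the inputs.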

oxinabox (Member Author) commented Feb 4, 2021

Ah, I forgot to push some changes,
namely the ones that update all the tests.
Will do so tomorrow.

mzgubic (Member) commented Feb 5, 2021

That's wonderful! Please ping me when it is done, happy to review

mzgubic (Member) left a comment

I love this change, makes the API so much easier to use. Should we also update the examples in the docs to reflect the new usage?

Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com>
@oxinabox oxinabox merged commit 340e21c into master Feb 5, 2021
CarloLucibello commented Feb 6, 2021

I just got the deprecation message

┌ Warning: `rrule_test(f, ȳ, inputs::Tuple{Any, Any}...; kwargs...)` is deprecated, use `test_rrule(f, (x ⊢ dx for (x, dx) = inputs)...; output_tangent = ȳ, kwargs...)` instead.

and it looked a bit too exotic. Why not pairs => instead of vdash?

edit: maybe you need to differentiate with respect to a pair, I see now the problem
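
The problem with pairs can be sketched as follows. This is only an illustration of the ambiguity, under the assumption that a tester would detect annotated inputs by their type; every name below is hypothetical.

```julia
# Hypothetical dedicated wrapper for "primal with an explicit tangent".
struct PrimalAndTangent{P,T}
    primal::P
    tangent::T
end

# A function under test whose single argument is itself a Pair.
first_plus_last(p::Pair) = first(p) + last(p)

# Two hypothetical ways a tester could decide "did the caller give a tangent?".
looks_annotated_with_pairs(x) = x isa Pair
looks_annotated_with_wrapper(x) = x isa PrimalAndTangent

input = 1.0 => 2.0   # meant as a plain primal input for first_plus_last

# With pairs, this input is indistinguishable from "primal 1.0, tangent 2.0".
misread = looks_annotated_with_pairs(input)
# With a dedicated wrapper type, it is correctly seen as a bare primal.
correct = !looks_annotated_with_wrapper(input)
```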

If `check_inferred=true`, then the inferrability of the `rrule` is checked — if `f` is
itself inferrable — along with the inferrability of the pullback it returns.
All remaining keyword arguments are passed to `isapprox`.
- `inputs` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ`

ẋ should be x̄.
Maybe it should also be explained here how to produce a vdash.

oxinabox (Member Author) commented Feb 6, 2021

@CarloLucibello note the deprecation warning is misleading as to the correct action.
Correct is to just delete the partials and let the automatic process kick in.

nickrobinson251 (Contributor)

@CarloLucibello note the deprecation warning is misleading as to the correct action.
Correct is to just delete the partials and let the automatic process kick in.

Maybe manually define the deprecation, with a depwarn giving the better instructions, rather than rely on @deprecate

Otherwise people will do what the deprecation warning tells them to
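
A hand-written deprecation along these lines might look like the sketch below. It is an assumption about one possible shape, not the change that was eventually made: test_rrule is replaced by a stub so the example runs on its own, PrimalAndTangent and ⊢ are simplified stand-ins, and the message wording is invented for illustration. Base.depwarn is the standard Julia function for emitting a deprecation warning with custom text.

```julia
# Hypothetical stand-ins so the sketch is self-contained.
struct PrimalAndTangent{P,T}
    primal::P
    tangent::T
end
⊢(primal, tangent) = PrimalAndTangent(primal, tangent)

# Stub: the real test_rrule performs the checks; this one only echoes its arguments.
test_rrule(f, inputs...; output_tangent=nothing, kwargs...) = (f, inputs, output_tangent)

function rrule_test(f, ȳ, inputs::Tuple{Any,Any}...; kwargs...)
    Base.depwarn(
        "`rrule_test(f, ȳ, (x, x̄)...)` is deprecated. In most cases, call " *
        "`test_rrule(f, x...)` and let the tangents be generated automatically. " *
        "Pass `x ⊢ x̄` or `output_tangent=ȳ` only when a specific tangent is needed.",
        :rrule_test,
    )
    # Forward the old-style call unchanged so existing tests keep working.
    return test_rrule(f, (x ⊢ x̄ for (x, x̄) in inputs)...; output_tangent=ȳ, kwargs...)
end

result = rrule_test(reshape, rand(2, 10), (rand(4, 5), rand(4, 5)))
```

The forwarding call is the same one the automatically generated warning suggests; only the message shown to the user changes.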

oxinabox (Member Author) commented Feb 6, 2021

Good idea, I will see if I find time

Development

Successfully merging this pull request may close these issues.

Don't make user provide tangents

5 participants