
rules for specific ADs #270

@willtebbutt


Let's assume that Diffractor is going to be better at some things than Zygote is and that, as a consequence, there exist rules we don't want Diffractor to hit (since it generates perfectly good code anyway) but do want Zygote to hit. Presently, we have no way to specify this.

It would be simple to achieve this via dispatch by adding an extra argument to one's rrules, giving them the following signatures:

abstract type AbstractAD end
struct ZygoteAD <: AbstractAD end
struct DiffractorAD <: AbstractAD end

rrule(::AbstractAD, ::typeof(f), args...) # applicable to all ADs
rrule(::ZygoteAD, ::typeof(f), args...) # only applicable to Zygote
rrule(::DiffractorAD, ::typeof(f), args...) # only applicable to Diffractor
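A minimal, runnable sketch of the dispatch mechanism follows. It assumes a local function named rrule that takes the AD tag as its first argument. That is the proposed signature, not the current ChainRulesCore.rrule. The toy function f and the method bodies are hypothetical. The sketch only shows that Julia's dispatch picks the ZygoteAD method when the tag is ZygoteAD and falls back to the AbstractAD method for any other AD. A DiffractorAD-only method would work the same way.

```julia
# AD "tags" from the proposal; not part of any released package.
abstract type AbstractAD end
struct ZygoteAD <: AbstractAD end
struct DiffractorAD <: AbstractAD end

# Hypothetical toy primal function.
f(x::Real) = x^2

# Local, hypothetical `rrule` whose first argument is the AD tag
# (ChainRulesCore.rrule does not take such an argument).
# Applicable to all ADs:
function rrule(::AbstractAD, ::typeof(f), x::Real)
    f_pullback(ȳ) = (nothing, 2x * ȳ)
    return f(x), f_pullback
end

# Only applicable to Zygote: ZygoteAD is more specific than AbstractAD,
# so dispatch selects this method whenever the tag is ZygoteAD.
function rrule(::ZygoteAD, ::typeof(f), x::Real)
    # Same maths as above; in this sketch only the method selection differs.
    f_pullback(ȳ) = (nothing, 2x * ȳ)
    return f(x), f_pullback
end

# Inspect which method dispatch selects for each tag.
which(rrule, Tuple{ZygoteAD, typeof(f), Float64})      # the ::ZygoteAD method
which(rrule, Tuple{DiffractorAD, typeof(f), Float64})  # the ::AbstractAD method
```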

This would also alleviate some of our existing headaches surrounding rrules for "very abstractly typed" arguments, since we could implement generic versions of things that we're sure ought to work (e.g. *(::Matrix{Float64}, ::Matrix{Float64})) in ChainRules without requiring package authors to compromise on choices they've already made. For example, Zygote uses very abstract types for lots of things and, while I don't like it, changing that at this point in time would be very breaking.
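Below is a sketch of what a narrowly typed rule usable by every AD might look like, using the matrix-multiply example. It assumes the same hypothetical AD tags and a local rrule. It also assumes a catch-all method returning nothing to mean "no rule; let the AD differentiate the primal". That mirrors ChainRulesCore's existing fallback, but the AD-tagged version here is illustrative only. The pullback uses the standard identities for C = A * B: Ā = C̄ * B' and B̄ = A' * C̄.

```julia
abstract type AbstractAD end
struct ZygoteAD <: AbstractAD end
struct DiffractorAD <: AbstractAD end

# Fallback: no rule for this AD/argument combination, so the AD should
# differentiate the primal code itself.
rrule(::AbstractAD, f, args...) = nothing

# Concretely typed rule that any AD may use.
# For C = A * B:  Ā = C̄ * B'  and  B̄ = A' * C̄.
function rrule(::AbstractAD, ::typeof(*), A::Matrix{Float64}, B::Matrix{Float64})
    C = A * B
    mul_pullback(C̄) = (nothing, C̄ * B', A' * C̄)
    return C, mul_pullback
end
```

Arguments outside the concrete signature (e.g. integer matrices) hit the fallback and get nothing, leaving them to whatever the AD does by default.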

Recall that there are essentially 3 reasons (I think?) to implement a rule:

  1. Mathematical insight leads to a completely different algorithm than any (existing) AD tool would derive automatically. Rules that use the Implicit Function Theorem to avoid storing intermediate state are good examples of this, e.g. rrules for optimisation and (nice) ODEs.
  2. For some reason it's more efficient to manually write out the algorithm than to have a particular AD derive it.
  3. An AD doesn't know how to differentiate a particular function, but you do, so you write a rule.

Rules of type 1 are those for which you would consider writing a very generic rrule, so you would probably write them to accept any AbstractAD.
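As an illustration of a type-1 rule written against any AbstractAD, here is a hypothetical implicit-function-theorem example. solve_cubic finds the real root x of x^3 + x - p = 0 by Newton's method. For F(x, p) = 0, the theorem gives dx/dp = -(∂F/∂p)/(∂F/∂x) = 1/(3x^2 + 1), so the pullback needs only the solution, not the Newton iterates. The function name, tolerances and AD tags are all assumptions made for this sketch.

```julia
abstract type AbstractAD end
struct ZygoteAD <: AbstractAD end
struct DiffractorAD <: AbstractAD end

# Hypothetical primal: real root of x^3 + x - p = 0 via Newton's method.
function solve_cubic(p::Real; tol=1e-12, maxiter=100)
    x = float(p)
    for _ in 1:maxiter
        step = (x^3 + x - p) / (3x^2 + 1)
        x -= step
        abs(step) < tol && break
    end
    return x
end

# Generic rule for every AD, derived from the implicit function theorem:
# F(x, p) = x^3 + x - p = 0  =>  dx/dp = 1 / (3x^2 + 1).
# Only the converged x is captured; the iterates are never stored.
function rrule(::AbstractAD, ::typeof(solve_cubic), p::Real)
    x = solve_cubic(p)
    solve_cubic_pullback(x̄) = (nothing, x̄ / (3x^2 + 1))
    return x, solve_cubic_pullback
end
```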

Rules of type 2 are somewhat borderline and would probably need to be done on a case-by-case basis. For example, you might write a custom adjoint for a function involving a for-loop if using Zygote, but might not need the rule at all if using Diffractor. While Zygote can usually differentiate through for-loops, it tends to be slow.
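A sketch of the for-loop case as a Zygote-only rule, under the same assumptions as before: hypothetical AD tags, a local rrule, and a catch-all returning nothing to mean "no rule". sumsin is a hypothetical function written as an explicit loop. With a ZygoteAD tag the hand-written pullback is used. Any other tag gets nothing and so differentiates the loop itself.

```julia
abstract type AbstractAD end
struct ZygoteAD <: AbstractAD end
struct DiffractorAD <: AbstractAD end

# Fallback: no rule, so the AD differentiates the primal itself.
rrule(::AbstractAD, f, args...) = nothing

# Hypothetical primal written as an explicit for-loop.
function sumsin(x::AbstractVector{<:Real})
    s = zero(float(eltype(x)))
    for xi in x
        s += sin(xi)
    end
    return s
end

# Zygote-only rule: Zygote uses this pullback instead of tracing the loop.
# Other AD tags fall through to the `nothing` method above.
function rrule(::ZygoteAD, ::typeof(sumsin), x::AbstractVector{<:Real})
    sumsin_pullback(ȳ) = (nothing, cos.(x) .* ȳ)
    return sumsin(x), sumsin_pullback
end
```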

Rules of type 3 are prime candidates for AD-specific rules, since different ADs are able to differentiate through different language features.

This is related to #68 in that we're talking about including some kind of additional information about what ADs to use, but the underlying problem that it addresses is somewhat different.
