
[RFC][WIP] Tensor Expression level automatic differentiation #1996

@sgrechanik-h


I'm working on automatic differentiation at the level of compute expressions, and I would like to share some progress and hear your comments. The automatic differentiation currently works well enough for some operations that it is possible to train a simple model; here is a tutorial on how to do this. For many other operations, however, the performance is still unacceptable, but I'm working on it.
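To give a sense of what the training flow in the tutorial looks like, here is a minimal sketch of one gradient step on a linear model. The `tvm.differentiate` call is an assumption standing in for whatever entry point the branch actually exposes (see the tutorial for the real API); the rest is ordinary TVM compute/schedule/build usage.

```python
import numpy as np
import tvm

# A tiny linear model: loss = sum_j (x[j] . w - y[j])^2
n, m = 32, 10
x = tvm.placeholder((n, m), name='x')
y = tvm.placeholder((n,), name='y')
w = tvm.placeholder((m,), name='w')

k = tvm.reduce_axis((0, m), name='k')
pred = tvm.compute((n,), lambda i: tvm.sum(x[i, k] * w[k], axis=k), name='pred')
j = tvm.reduce_axis((0, n), name='j')
loss = tvm.compute((1,), lambda _: tvm.sum((pred[j] - y[j]) * (pred[j] - y[j]), axis=j),
                   name='loss')

# Assumed entry point: returns the adjoint of `loss` w.r.t. each listed
# input. The name and signature are my guesses; the linked tutorial
# shows the real API exposed by the branch.
[dw] = tvm.differentiate(loss, [w])

s = tvm.create_schedule(dw.op)
grad_fn = tvm.build(s, [x, y, w, dw])

# One SGD step on the host.
ctx = tvm.cpu(0)
x_np = np.random.randn(n, m).astype('float32')
y_np = np.random.randn(n).astype('float32')
w_np = np.zeros(m, dtype='float32')
dw_nd = tvm.nd.array(np.zeros(m, dtype='float32'), ctx)
grad_fn(tvm.nd.array(x_np, ctx), tvm.nd.array(y_np, ctx),
        tvm.nd.array(w_np, ctx), dw_nd)
w_np -= 0.01 * dw_nd.asnumpy()
```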

My implementation mostly follows this paper. In this notebook I describe how the implementation works internally and give a list of operations that are known to work or not to work. The AD consists of two parts:

  • The automatic differentiation itself, which simply differentiates expressions according to the well-known rules and produces inefficient expressions (a toy illustration follows this list). The code is here.
  • A set of transformations that optimize the resulting inefficient expressions. The code is here.
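
To make the first part concrete, here is a toy, self-contained model of the differentiation rules (my illustration, not code from the branch). The key rule is the one for tensor accesses: d A[i] / d A[j] is the Kronecker delta (i == j) ? 1 : 0, which is where all the `cond ? val : 0` subexpressions come from.

```python
# Mini expression language: tuples of the form
#   ('const', c), ('elem', name, index), ('add', a, b), ('mul', a, b),
#   ('sum', var, lo, hi, body)   -- a reduction over an index variable.
# Indices are plain strings for readability.

def diff(expr, wrt_name, wrt_index):
    """Differentiate `expr` w.r.t. the element wrt_name[wrt_index]
    using the well-known rules, without any simplification."""
    kind = expr[0]
    if kind == 'const':
        return ('const', 0)
    if kind == 'elem':
        _, name, index = expr
        # d A[i] / d A[j]  =  (i == j) ? 1 : 0
        if name == wrt_name:
            return ('cond', ('eq', index, wrt_index), ('const', 1), ('const', 0))
        return ('const', 0)
    if kind == 'add':
        return ('add', diff(expr[1], wrt_name, wrt_index),
                       diff(expr[2], wrt_name, wrt_index))
    if kind == 'mul':
        # Product rule
        return ('add',
                ('mul', diff(expr[1], wrt_name, wrt_index), expr[2]),
                ('mul', expr[1], diff(expr[2], wrt_name, wrt_index)))
    if kind == 'sum':
        _, var, lo, hi, body = expr
        # Differentiation commutes with summation
        return ('sum', var, lo, hi, diff(body, wrt_name, wrt_index))
    raise ValueError(kind)

# d/dA[j] of  sum_k A[k] * B[k]  becomes
#   sum_k (((k == j) ? 1 : 0) * B[k] + A[k] * 0)
# i.e. a summation over mostly-zeros that the second part must clean up.
dot = ('sum', 'k', 0, 10, ('mul', ('elem', 'A', 'k'), ('elem', 'B', 'k')))
print(diff(dot, 'A', 'j'))
```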

All transformations work at the level of compute expressions (before scheduling). Their general goal is to eliminate summation over zeros by moving conditional expressions of the form cond ? val : 0 up through the expression tree and then using them to simplify the iteration domains of reductions; a NumPy sketch of this rewrite follows this paragraph. Once they are powerful enough, these transformations may hopefully be useful for other tasks besides AD. Currently their main limitation is that they don't understand modular arithmetic, which is needed for differentiating dilated and strided convolutions and for the flattening operation.
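Here is a NumPy sketch (my example, not code from the branch) of the rewrite these transformations aim for, on an op simple enough that the condition can be solved without modular arithmetic:

```python
import numpy as np

# Forward op: B[i] = A[i + 1] for i in [0, n)   (a one-element shift)
n = 8
A = np.random.randn(n + 1)
dB = np.random.randn(n)   # adjoint arriving from downstream

# Naive adjoint, straight from the differentiation rules:
#   dA[j] = sum_i ((i + 1 == j) ? dB[i] : 0)
dA_naive = np.array([sum(dB[i] if i + 1 == j else 0.0 for i in range(n))
                     for j in range(n + 1)])

# The condition i + 1 == j pins i to j - 1, so the reduction domain
# collapses to a single point guarded by a bound check:
#   dA[j] = (1 <= j <= n) ? dB[j - 1] : 0
dA_opt = np.array([dB[j - 1] if 1 <= j <= n else 0.0 for j in range(n + 1)])

assert np.allclose(dA_naive, dA_opt)
```

For a strided read like B[i] = A[2*i], the same rewrite needs the solution i = j // 2 guarded by j % 2 == 0, which is exactly the modular arithmetic the transformations currently lack.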

  • The git branch
  • The squashed commit
  • The tutorial on training a simple model
  • The notebook describing some internals
