I'm working on automatic differentiation at the level of compute expressions, and I would like to share some progress and hear any comments. Currently the automatic differentiation works well enough for some operations to make it possible to train a simple model; here is a tutorial on how to do this. However, for many operations the performance is still unacceptable; I'm working on it.
My implementation mostly follows this paper. In this notebook I describe how my implementation works internally and give a list of operations that are known to work or not to work. Basically, the AD consists of two parts (sketched in code after the list below):
- The automatic differentiation itself, which differentiates expressions according to the well-known rules and produces correct but inefficient expressions. The code is here.
- A set of transformations that optimize the resulting inefficient expressions. The code is here.
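
To make the two phases concrete, here is a small sketch of the first one (my own illustration written with TVM's te API, not code from the branch). For a matmul `C[i, j] = sum_k A[i, k] * B[k, j]`, rule-based differentiation of `C` with respect to `A` yields an adjoint that sums over the full `(i, j)` domain and relies on a `cond ? val : 0` term to select the single contributing element:

```python
import tvm
from tvm import te

n = 32
A = te.placeholder((n, n), name="A")
B = te.placeholder((n, n), name="B")
H = te.placeholder((n, n), name="H")  # head gradient dL/dC

# Forward computation: C[i, j] = sum_k A[i, k] * B[k, j]
k = te.reduce_axis((0, n), name="k")
C = te.compute((n, n), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")

# Illustrative naive adjoint dL/dA, written out by hand the way plain
# rule-based differentiation would produce it (not the branch's actual output):
# dA[m, l] = sum_{i, j} (i == m ? H[i, j] * B[l, j] : 0)
ri = te.reduce_axis((0, n), name="ri")
rj = te.reduce_axis((0, n), name="rj")
dA_naive = te.compute(
    (n, n),
    lambda m, l: te.sum(
        tvm.tir.if_then_else(
            tvm.tir.EQ(ri.var, m),               # Kronecker delta from the chain rule
            H[ri, rj] * B[l, rj],
            tvm.tir.FloatImm("float32", 0.0),    # everything else contributes zero
        ),
        axis=[ri, rj],
    ),
    name="dA_naive",
)
```

Most of the work in this naive adjoint is summing zeros, which is what the second phase is meant to remove.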
All transformations work on the level of compute expressions (before scheduling). Their general goal is to eliminate summation over zeros by moving conditional expressions of the form `cond ? val : 0` upwards and then using them to simplify the iteration domains of reductions. Hopefully, once they are powerful enough, these transformations will be useful for other tasks besides AD. Currently the main problem is that they don't understand modular arithmetic, which is needed for differentiating dilated and strided convolutions and for the flattening operation.
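
As an illustration of what the simplification aims for (again my own sketch, not the actual output of the pass), the condition `ri == m` in the naive matmul adjoint above can be hoisted out of the sum and used to shrink the reduction domain, collapsing the summation over `ri` to the single index `m`:

```python
import tvm
from tvm import te

n = 32
B = te.placeholder((n, n), name="B")
H = te.placeholder((n, n), name="H")  # head gradient dL/dC

# Simplified adjoint after eliminating the summation over zeros:
# dA[m, l] = sum_j H[m, j] * B[l, j]
rj = te.reduce_axis((0, n), name="rj")
dA_simplified = te.compute(
    (n, n),
    lambda m, l: te.sum(H[m, rj] * B[l, rj], axis=rj),
    name="dA_simplified",
)
```

This is just the familiar matmul gradient `dA = H * B^T`, recovered purely by using the hoisted condition to reduce the iteration domain.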
- The git branch
- The squashed commit
- The tutorial on training a simple model
- The notebook describing some internals