A simple explanation of reverse-mode automatic differentiation

My previous rant about automatic differentiation generated several requests for an explanation of how it works. This can be confusing because there are different types of automatic differentiation (forward-mode, reverse-mode, and hybrids). This is my attempt to explain the basic idea of reverse-mode autodiff as simply as possible.

Reverse-mode automatic differentiation is most attractive when you have a function that takes n inputs x_1,x_2,...,x_n, and produces a single output x_N. We want the derivatives of that function, \displaystyle{\frac{d x_N}{d x_i}}, for all i.

Point #1: Any differentiable algorithm can be translated into a sequence of assignments of basic operations.


for i=n+1,n+2,...,N

x_i \leftarrow f_i({\bf x}_{\pi(i)})

Here, each function f_i is some very basic operation (e.g. addition, multiplication, a logarithm) and \pi(i) denotes the set of “parents” of x_i. So, for example, if \pi(7)=(2,5) and f_7 = \text{add}, then x_7 = x_2 + x_5.

It would be extremely tedious, of course, to actually write an algorithm in this “expression graph” form. So, autodiff tools create this representation automatically from high-level source code.
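To make the format concrete, here is a minimal sketch in Python of what such a representation might look like when written out by hand. The example function f(x_1, x_2) = log(x_1 x_2) + x_1, the `OPS` table, and the `program` and `forward` names are all made up for illustration; real autodiff tools use their own internal formats.

```python
import math

# Basic operations f_i. (Hypothetical table; a real tool would have many more.)
OPS = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
    "log": lambda a: math.log(a),
}

# Expression graph for f(x_1, x_2) = log(x_1 * x_2) + x_1, with n = 2 and N = 5.
# Each entry maps i to (f_i, pi(i)), using 1-based indices as in the text.
program = {
    3: ("mul", (1, 2)),  # x_3 = x_1 * x_2
    4: ("log", (3,)),    # x_4 = log(x_3)
    5: ("add", (4, 1)),  # x_5 = x_4 + x_1
}

def forward(inputs, program):
    """Run the assignments x_i <- f_i(x_pi(i)) for i = n+1, ..., N."""
    x = dict(enumerate(inputs, start=1))  # x[1], ..., x[n]
    for i in sorted(program):
        op, parents = program[i]
        x[i] = OPS[op](*(x[p] for p in parents))
    return x

x = forward([2.0, 3.0], program)
print(x[5])  # the output x_N
```

Nothing here is specific to derivatives yet: it is just the original computation, flattened into one basic operation per line.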

Point #2: Given an algorithm in the previous format, it is easy to compute its derivatives.

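The essential point here is just the application of the chain rule. The variable x_i can only influence x_N through its children, that is, the x_j that have x_i as a parent.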

\displaystyle{ \frac{d x_N}{d x_i} = \sum_{j:i\in \pi(j)} \frac{d x_N}{d x_j}\frac{\partial x_j}{\partial x_i}}

Applying this, we can compute all the derivatives in reverse order.


\displaystyle{ \frac{d x_N}{d x_N} \leftarrow 1}

for i=N-1,N-2,...,1

\displaystyle{ \frac{d x_N}{d x_i} \leftarrow \sum_{j:i\in \pi(j)} \frac{d x_N}{d x_j}\frac{\partial f_j}{\partial x_i}}

That’s it!  Just create an expression graph representation of the algorithm and differentiate each basic operation f_i in reverse order using calc 101 rules.
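The two loops above can be sketched in a few lines of Python. This is only an illustration under assumed names (`OPS`, `grad`, and the list-of-tuples program format are all hypothetical), not how any particular tool does it. One small difference from the formula: instead of summing over the children of each x_i, the code visits each x_j in reverse order and adds its contribution to each of its parents. The sums come out the same, because every child of x_i has a larger index than i and so is processed first.

```python
import math

# Each basic operation comes with its value and its partial derivatives
# with respect to each argument (the "calc 101 rules").
OPS = {
    "add": (lambda a, b: a + b, lambda a, b: (1.0, 1.0)),
    "mul": (lambda a, b: a * b, lambda a, b: (b, a)),
    "log": (lambda a: math.log(a), lambda a: (1.0 / a,)),
}

def grad(inputs, program):
    """Return (x_N, [dx_N/dx_1, ..., dx_N/dx_n]).

    program is a list of (op, parents) for x_{n+1}, ..., x_N, 1-based parents.
    """
    n = len(inputs)
    N = n + len(program)

    # Forward pass: compute and store every intermediate value.
    x = [None] + list(inputs)  # x[0] is unused so indices match the text
    for op, parents in program:
        f, _ = OPS[op]
        x.append(f(*(x[p] for p in parents)))

    # Reverse pass: dx[i] accumulates dx_N/dx_i.
    dx = [0.0] * (N + 1)
    dx[N] = 1.0
    for j in range(N, n, -1):
        op, parents = program[j - n - 1]
        _, df = OPS[op]
        partials = df(*(x[p] for p in parents))
        for p, d in zip(parents, partials):
            dx[p] += dx[j] * d  # dx_N/dx_j * df_j/dx_p
    return x[N], dx[1:n + 1]

# f(x_1, x_2) = log(x_1 * x_2) + x_1
program = [("mul", (1, 2)), ("log", (3,)), ("add", (4, 1))]
value, g = grad([2.0, 3.0], program)
print(value, g)  # gradient should match (1/x_1 + 1, 1/x_2)
```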

Other stuff:

  • No, this is not the same thing as symbolic differentiation.  This should be obvious:  Most algorithms don’t even have simple symbolic representations.  And, even if yours does,  it is possible that it “explodes” upon symbolic differentiation.  As a contrived example, try computing the derivative of \exp(\exp(\exp(\exp(\exp(\exp(x)))))).
  • The cost of the back-prop step is within a small constant factor of the cost of the forward propagation step.  (The catch is memory:  the back-prop step needs the intermediate values from the forward step, so they must be stored.)
  • In machine learning, functions from n inputs to one output come up all the time:  The n inputs are parameters defining a model, and the 1 output is a loss, measuring how well the model fits training data.  The gradient can be fed into an optimization routine to fit the model to data.
  • There are two common ways of implementing this:
  1. Operator Overloading.  One can create a new variable type that has all the common operations of numeric types, which automatically creates an expression graph when the program is run.   One can then call the back-prop routine on this expression graph.  Hence,  one does not need to modify the program, just replace each numeric type with this new type.  This is fairly easy to implement, and very easy to use.  David Gay's RAD toolbox for C++ is a good example, which I use all the time.
    The major downside of operator overloading is efficiency:  current compilers will not optimize the backprop code.  Essentially, this  step is interpreted.  Thus, one finds in practice a non-negligible overhead of, say, 2-15 times the complexity of the original algorithm using a native numeric type.  (The overhead depends on how much the original code benefits from compiler optimizations.)
  2. Source code transformation. Alternatively, one could write a program that examines the source code of the original program, and transforms this into source code computing the derivatives.  This is much harder to implement, unless one is using a language like Lisp with very uniform syntax.  However, because the backprop source code produced is then optimized like normal code, it offers the potential of zero overhead compared with manually computed derivatives.
  • If it isn’t convenient to use automatic differentiation, one can also use “manual automatic differentiation”.  That is, to compute the derivatives, just work through each intermediate value your algorithm computes, in reverse order, writing the chain-rule update for each one by hand.
  • Some of the most interesting work on autodiff comes from Pearlmutter and Siskind, who have produced a system called Stalingrad for a subset of Scheme that allows for crazy things like taking derivatives of code that itself is taking derivatives.  (So you can, for example, produce Hessians.)  I think they wouldn’t mind hearing from potential users.
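To give a feel for the operator overloading approach mentioned above, here is a toy sketch in Python. It is not how RAD or any other real tool is implemented; the `Var` class, the `log` helper, and the `backward` routine are hypothetical names, and only addition, multiplication, and log are supported. The point is that the function `f` is ordinary code, and the expression graph gets built as a side effect of running it on `Var` objects.

```python
import math

class Var:
    """A number that remembers how it was computed."""
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # tuple of (parent Var, partial derivative)
        self.grad = 0.0         # will hold d(output)/d(this)

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))
    __rmul__ = __mul__

def log(v):
    return Var(math.log(v.value), ((v, 1.0 / v.value),))

def backward(output):
    """Back-prop over the recorded graph, children before parents."""
    order, seen = [], set()
    def visit(v):  # depth-first search; recursive, so only for small graphs
        if id(v) in seen:
            return
        seen.add(id(v))
        for p, _ in v.parents:
            visit(p)
        order.append(v)  # parents are appended before their children
    visit(output)
    output.grad = 1.0
    for v in reversed(order):
        for p, d in v.parents:
            p.grad += v.grad * d

# The "original program": written as if for plain floats.
def f(a, b):
    return log(a * b) + a

x1, x2 = Var(2.0), Var(3.0)
y = f(x1, x2)
backward(y)
print(y.value, x1.grad, x2.grad)
```

Every arithmetic operation here allocates an object and the back-prop loop walks a data structure at run time, which is a small-scale picture of where the interpretive overhead comes from.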