A simple explanation of reverse-mode automatic differentiation

My previous rant about automatic differentiation generated several requests for an explanation of how it works. This can be confusing because there are different types of automatic differentiation (forward-mode, reverse-mode, and hybrids). This is my attempt to explain the basic idea of reverse-mode autodiff as simply as possible.

Reverse-mode automatic differentiation is most attractive when you have a function that takes $n$ inputs $x_1,x_2,...,x_n$, and produces a single output $x_N$. We want the derivatives of that function, $\displaystyle{\frac{d x_N}{d x_i}}$, for all $i$.

Point #1: Any differentiable algorithm can be translated into a sequence of assignments of basic operations.

Forward-Prop

for $i=n+1,n+2,...,N$

$x_i \leftarrow f_i({\bf x}_{\pi(i)})$

Here, each function $f_i$ is some very basic operation (e.g. addition, multiplication, a logarithm) and $\pi(i)$ denotes the set of “parents” of $x_i$. So, for example, if $\pi(7)=(2,5)$ and $f_7 = \text{add}$, then $x_7 = x_2 + x_5$.
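A minimal sketch of the forward-prop loop in Python follows. It assumes a hypothetical encoding in which the inputs are $x_1,\ldots,x_n$ and each later node $i$ is stored as the name of $f_i$ plus the tuple of parent indices $\pi(i)$. The `OPS` table, `forward_prop`, and the example function $\log(x_1 x_2 + x_1)$ are illustrative choices, not the format of any particular autodiff tool.

```python
import math

# Hypothetical encoding: x_1..x_n are inputs; graph[k] describes node
# i = n+1+k as (name of f_i, tuple of parent indices pi(i)).
OPS = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
    "log": lambda a: math.log(a),
}

def forward_prop(inputs, graph):
    """Run x_i <- f_i(x_pi(i)) for i = n+1, ..., N; return {index: value}."""
    x = {i + 1: v for i, v in enumerate(inputs)}
    n = len(inputs)
    for k, (op, parents) in enumerate(graph):
        x[n + 1 + k] = OPS[op](*(x[p] for p in parents))
    return x

# log(x1 * x2 + x1) as three basic assignments (n = 2, N = 5).
graph = [
    ("mul", (1, 2)),  # x3 = x1 * x2
    ("add", (3, 1)),  # x4 = x3 + x1
    ("log", (4,)),    # x5 = log(x4)
]
x = forward_prop([2.0, 3.0], graph)
print(x[3], x[4])  # 6.0 8.0
```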

It would be extremely tedious, of course, to actually write an algorithm in this “expression graph” form. So, autodiff tools create this representation automatically from high-level source code.

Point #2: Given an algorithm in the previous format, it is easy to compute its derivatives.

The essential point here is just the application of the chain rule.

$\displaystyle{ \frac{d x_N}{d x_i} = \sum_{j:i\in \pi(j)} \frac{d x_N}{d x_j}\frac{\partial x_j}{\partial x_i}}$
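As a small illustration, take inputs $x_1, x_2$ and the assignments $x_3 \leftarrow x_1 x_2$, $x_4 \leftarrow x_3 + x_1$, $x_5 \leftarrow \log x_4$, so $n=2$ and $N=5$. The input $x_1$ has two children, $x_3$ and $x_4$, so the sum has two terms:

$\displaystyle{ \frac{d x_5}{d x_1} = \frac{d x_5}{d x_3}\frac{\partial x_3}{\partial x_1} + \frac{d x_5}{d x_4}\frac{\partial x_4}{\partial x_1} = \frac{1}{x_4}\cdot x_2 + \frac{1}{x_4}\cdot 1 = \frac{x_2 + 1}{x_1 x_2 + x_1}}$

Here $\frac{d x_5}{d x_4} = \frac{1}{x_4}$, and $\frac{d x_5}{d x_3} = \frac{d x_5}{d x_4}\frac{\partial x_4}{\partial x_3} = \frac{1}{x_4}$ by the same rule, since $x_4$ is the only child of $x_3$.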

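Every child $j$ of $i$ comes later in the sequence ($j > i$), so if we work backwards, the derivatives $\frac{d x_N}{d x_j}$ on the right-hand side are always available. Applying this, we can compute all the derivatives in reverse order.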

Back-Prop

$\displaystyle{ \frac{d x_N}{d x_N} \leftarrow 1}$

for $i=N-1,N-2,...,1$

$\displaystyle{ \frac{d x_N}{d x_i} \leftarrow \sum_{j:i\in \pi(j)} \frac{d x_N}{d x_j}\frac{\partial f_j}{\partial x_i}}$

That’s it!  Just create an expression graph representation of the algorithm and differentiate each basic operation $f_i$ in reverse order using calc 101 rules.
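Continuing the same hypothetical encoding, here is a sketch of back-prop in Python, where each operation now also supplies its local partials $\partial f_j / \partial x_i$. One design choice differs from the back-prop pseudocode: instead of "pulling" the sum over children for each $i$, the loop visits $j = N, \ldots, n+1$ and "pushes" the term $\frac{d x_N}{d x_j}\frac{\partial f_j}{\partial x_i}$ into each parent $i$. The two are equivalent because every child of $i$ has a larger index, so $\frac{d x_N}{d x_j}$ is complete by the time $j$ is visited.

```python
import math

# Hypothetical op table: each entry is (f, partials), where partials(*args)
# returns the tuple of local derivatives d f / d (each argument).
OPS = {
    "add": (lambda a, b: a + b, lambda a, b: (1.0, 1.0)),
    "mul": (lambda a, b: a * b, lambda a, b: (b, a)),
    "log": (lambda a: math.log(a), lambda a: (1.0 / a,)),
}

def forward_prop(inputs, graph):
    """x_1..x_n are the inputs; node n+1+k is graph[k] = (op, parents)."""
    x = {i + 1: v for i, v in enumerate(inputs)}
    n = len(inputs)
    for k, (op, parents) in enumerate(graph):
        f, _ = OPS[op]
        x[n + 1 + k] = f(*(x[p] for p in parents))
    return x

def back_prop(inputs, graph):
    """Return {i: d x_N / d x_i} for every node i = 1..N."""
    x = forward_prop(inputs, graph)
    n, N = len(inputs), len(inputs) + len(graph)
    grad = {i: 0.0 for i in range(1, N + 1)}
    grad[N] = 1.0  # d x_N / d x_N = 1
    for j in range(N, n, -1):  # j = N, N-1, ..., n+1
        op, parents = graph[j - n - 1]
        _, partials = OPS[op]
        # grad[j] is final here: all children of j have indices > j.
        for i, d in zip(parents, partials(*(x[p] for p in parents))):
            grad[i] += grad[j] * d  # accumulate one term of the chain-rule sum
    return grad

graph = [("mul", (1, 2)), ("add", (3, 1)), ("log", (4,))]
grad = back_prop([2.0, 3.0], graph)
print(grad[1], grad[2])  # 0.5 0.25
```

For $\log(x_1 x_2 + x_1)$ at $(2,3)$ this matches the hand derivatives $\frac{x_2+1}{x_1x_2+x_1} = 0.5$ and $\frac{x_1}{x_1x_2+x_1} = 0.25$. A single reverse sweep produces the derivatives for all $n$ inputs at once.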

Other stuff:

• No, this is not the same thing as symbolic differentiation.  This should be obvious:  Most algorithms don’t even have simple symbolic representations.  And, even if yours does,  it is possible that it “explodes” upon symbolic differentiation.  As a contrived example, try computing the derivative of $\exp(\exp(\exp(\exp(\exp(\exp(x))))))$.
• The cost of the back-prop step is of the same order as the cost of the forward-propagation step (a small constant multiple of it), no matter how many inputs there are.
• In machine learning, functions from $n$ inputs to one output come up all the time: The $n$ inputs are parameters defining a model, and the 1 output is a loss, measuring how well the model fits training data. The gradient can be fed into an optimization routine to fit the model to data.
• There are two common ways of implementing this:
1. Operator Overloading. One can create a new variable type that has all the common operations of numeric types and that automatically builds an expression graph when the program is run. One can then call the back-prop routine on this expression graph. Hence, one does not need to modify the program, just replace each numeric type with this new type. This is fairly easy to implement, and very easy to use. David Gay’s RAD toolbox for C++ is a good example, which I use all the time.
The major downside of operator overloading is efficiency: current compilers will not optimize the backprop code. Essentially, this step is interpreted. Thus, one finds in practice a non-negligible overhead of, say, 2-15 times the running time of the original algorithm using a native numeric type. (The overhead depends on how much the original code benefits from compiler optimizations.)
2. Source code transformation. Alternatively, one could write a program that examines the source code of the original program, and transforms this into source code computing the derivatives.  This is much harder to implement, unless one is using a language like Lisp with very uniform syntax.  However, because the backprop source code produced is then optimized like normal code, it offers the potential of zero overhead compared with manually computed derivatives.
• If it isn’t convenient to use automatic differentiation, one can also use “manual automatic differentiation”. That is, to compute the derivatives, work through each intermediate value your algorithm computes in reverse order, applying the back-prop update by hand.
• Some of the most interesting work on autodiff comes from Pearlmutter and Siskind, who have produced a system called Stalingrad for a subset of Scheme that allows for crazy things like taking derivatives of code that itself is taking derivatives. (So you can, for example, produce Hessians.) I think they wouldn’t mind hearing from potential users.
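To make the operator-overloading approach concrete, here is a toy Python sketch. It is not the API of RAD or any other library; the names `Var`, `log`, and `backward` are hypothetical. Each arithmetic operation on a `Var` records its parents and local partial derivatives, so the expression graph is built as the program runs; `backward` then orders that graph topologically and back-propagates through it.

```python
import math

class Var:
    """Toy reverse-mode number: remembers its parents and local partials."""
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # tuple of (parent Var, d self / d parent)
        self.grad = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

    __radd__ = __add__
    __rmul__ = __mul__

def log(v):
    """Overloaded log: records d log(v) / dv = 1 / v."""
    return Var(math.log(v.value), ((v, 1.0 / v.value),))

def backward(output):
    """Back-prop from output; accumulates d output / d v into each v.grad."""
    order, seen = [], set()
    def visit(v):  # depth-first post-order: parents land before children
        if id(v) not in seen:
            seen.add(id(v))
            for p, _ in v.parents:
                visit(p)
            order.append(v)
    visit(output)
    output.grad = 1.0
    for v in reversed(order):  # children before parents
        for p, d in v.parents:
            p.grad += v.grad * d

# The program is written normally; only the number type changed.
x1, x2 = Var(2.0), Var(3.0)
y = log(x1 * x2 + x1)
backward(y)
print(x1.grad, x2.grad)  # 0.5 0.25
```

Because the reverse sweep here is ordinary Python walking a graph of objects, it also shows where the interpretive overhead comes from. The recursive traversal is limited by Python's recursion depth, so a longer-lived version would more likely record nodes on a tape in creation order.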

15 thoughts on “A simple explanation of reverse-mode automatic differentiation”

1. Song Chen says:

Hi, I have a question about
how to compute dXn/dXn-1.
Using the formula above, I get
dXn/dXn-1=partial(Xn)/partial(Xn-1)
Is this correct?

2. justindomke says:

If I understand you, you are asking if the full derivative of X_n with respect to X_{n-1} is equal to the partial derivative. I believe the answer for that is yes. The reason is that X_{n-1} doesn’t have the chance to influence any other variables before X_n is computed. It wouldn’t be true in general, however, for, say, the derivative of X_n with respect to X_{n-2}, since the full derivative would take into account the influence of X_{n-2} on X_{n-1} but the partial derivative would not.

3. Song Chen says:

So, I have a test case like below:
let
t1=2 ; ( or some arbitrary value )
t2=t1^2 ;
t3=t1^2+t2^2 ;

Then, using reverse-mode AD,
dt3/dt2 should be partial(t3)/partial(t2), right?
so that would be 2*t2.

However, if we just use substitution (since t2=t1^2), we can rewrite t3 like this:

t3=t2+t2^2

so obviously the total derivative of t3 with respect to t2 would be

dt3/dt2=1+2*t2

4. justindomke says:

The confusion here is that we have to remember each variable is a function of the previous variables. By substituting as you have done above, you have implicitly used two different function structures. You should really write something like this:

t1 = f1()
t2 = f2(t1)
t3 = f3(t1,t2)

for appropriately defined f1, f2, and f3. Then, calculating derivatives will make more sense. The second substitution you have made above essentially re-writes things as

t2 = f2()
t1 = f1(t2)
t3 = f3(t1,t2)

which is a totally different animal!
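A small Python sketch of the two function structures in this exchange, with back-prop written out by hand, makes the difference concrete. Structure B assumes $f_1(t_2) = \sqrt{t_2}$ (the positive root, so that $t_1^2 = t_2$); that choice is an assumption for illustration.

```python
import math

def structure_a(t1):
    """t1 = f1(); t2 = f2(t1); t3 = f3(t1, t2), as written in the question."""
    t2 = t1 ** 2
    t3 = t1 ** 2 + t2 ** 2
    d_t3 = 1.0
    d_t2 = d_t3 * (2 * t2)                    # t2's only child is t3
    d_t1 = d_t3 * (2 * t1) + d_t2 * (2 * t1)  # t1 reaches t3 directly and via t2
    return t3, d_t2, d_t1

def structure_b(t2):
    """t2 = f2(); t1 = f1(t2); t3 = f3(t1, t2), assuming f1 = sqrt."""
    t1 = math.sqrt(t2)
    t3 = t1 ** 2 + t2 ** 2
    d_t3 = 1.0
    d_t1 = d_t3 * (2 * t1)                                     # t1's only child is t3
    d_t2 = d_t3 * (2 * t2) + d_t1 * (1 / (2 * math.sqrt(t2)))  # direct + via t1
    return t3, d_t2

print(structure_a(2.0))  # (20.0, 8.0, 36.0)
print(structure_b(4.0))  # (20.0, 9.0)
```

Both structures compute $t_3 = 20$ at $t_1 = 2$, but $\frac{d t_3}{d t_2}$ is $2 t_2 = 8$ in structure A, where $t_2$ reaches $t_3$ only directly, and $1 + 2 t_2 = 9$ in structure B, where $t_2$ also influences $t_3$ through $t_1$.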

5. Song Chen says:

Oh , I have one more question.

From a mathematical point of view, say we have a function like this:

f=F(x)

What are the total and partial derivatives of x with respect to f,
dx/df and partial(x)/partial(f)?
Do they have any real meaning,
or are they just equal to zero?

Sometimes we can write an inverse function of f=F(x) as x=g(f); however, there’s no guarantee that this inverse function exists and is unique.

6. justindomke says:

I guess, technically, you can’t differentiate a variable, you can only differentiate a function. Writing dy/dx is a bit of a shorthand for differentiating the function f(x) that produces y. So, technically the things you wrote have no meaning in general. You could (if F is invertible) define

x = G(f), where G(y) = F^-1(y)

then it would make sense to talk about dx/df = dG/df.
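As a concrete case, assuming $F$ is differentiable with $F'(x) \neq 0$ at the point of interest, the inverse function theorem gives

$\displaystyle{ \frac{dG}{df} = \frac{1}{F'(G(f))} = \frac{1}{F'(x)} }$

For example, with $F(x) = e^x$ we get $G(f) = \log f$ and $\frac{dG}{df} = \frac{1}{f} = e^{-x}$.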

7. Song Chen says:

Right, I see.

Thank you ! 🙂

8. I’ve just implemented reverse differentiation in C++ using operator overloading without a dynamic stack being created and interpreted.

9. One of the clearest articles on the topic I have read – thanks for posting!

Thank you for the useful post, Prof. Domke!

One question please: I don’t see how you came up with that chain rule. For me there seems to be a notational issue: on the left-hand side, the $d$ (for total derivative) and the $\partial$ are switched. Shouldn’t it read:
$\frac{d x_N}{d x_i} = \sum_{j: i\in\pi(j)} \frac{\partial x_N}{\partial x_j} \frac{d x_j}{d x_i}$?

11. justindomke says:

I think that looks good to me as is… (though I’m happy to be corrected!) The semantics for that equation are:

(i’s total influence on the output) = (sum over all descendants j of i) (j’s total influence on the output) (i’s direct influence on j)

I’m using partials to denote “direct” influence and totals to represent total influence.