Truncated Bi-Level Optimization

In 2012, I wrote a paper that I probably should have called “truncated bi-level optimization”.  I vaguely remember telling the reviewers I would release some code, so I’m finally getting around to it.

The idea of bilevel optimization is quite simple.  Imagine that you would like to minimize some function $L(w)$.  However, $L$ itself is defined through some optimization.  More formally, suppose we would like to solve

$\min_{w,y} Q(y), \text{ s.t. } y = \arg\min_y E(y,w).$

Or, equivalently,

$\min_{w} L(w)=Q(y^*(w)),$

where $y^*$ is defined as $y^*(w) := \arg\min_y E(y,w)$.  This seems a little bit obscure at first, but actually comes up in several different ways in machine learning and related fields.

Hyper-parameter learning

The first example would be in learning hyperparameters, such as regularization constants.  Inevitably in machine learning, one fits parameters to optimize some tradeoff between the quality of a fit to training data and a regularization function being small.  Traditionally, the regularization constant is selected by optimizing on a training dataset with a variety of values, and then picking the one that performs best on a held-out dataset.  However, if there are a large number of regularization parameters, a high-dimensional grid-search will not be practical.  In the notation above, suppose that $w$ is a vector of regularization constants, and that $y$ is the vector of model parameters fit in training.  Let $E$ be the regularized empirical risk on a training dataset, and let $Q$ be how well the parameters $y^*(w)$ perform on some held-out validation dataset.

Energy-based models

Another example (and the one suggesting the notation) is an energy-based model.  Suppose that we have some “energy” function $E_x(y,w)$ which measures how well an output $y$ fits an input $x$ (lower energy meaning a better fit).  The energy is parametrized by $w$.  For a given training input/output pair $(\hat{x},\hat{y})$, we might have that $Q_{\hat{y}}(y^*(w))$ measures how the predicted output $y^*$ compares to the true output $\hat{y}$, where $y^*(w)=\arg\min_y E_{\hat{x}}(y,w)$.

Computing the gradient exactly

Even if we just have the modest goal of following the gradient of $L(w)$ to a local minimum, computing that gradient is not so simple.  Merely evaluating $L(w)$ requires solving the “inner” minimization of $E$. It turns out that one can compute $\nabla_w L(w)$ by first solving the inner minimization, and then solving a linear system.

1. Input $w$
2. Solve $y^* \leftarrow \arg\min_y E(y,w)$.
3. Compute:
    (a) the loss $L(w)= Q(y^*)$
    (b) the gradient $g=\nabla_y Q(y^*)$
    (c) the Hessian $H=\frac{\partial^2 E(y^*,w)}{\partial y\partial y^T}$
4. Solve the linear system $z=H^{-1} g$.
5. Compute the parameter gradient $\nabla_w L(w) = - \frac{\partial^2 E(y^*,w)}{\partial w\partial y^T} z$
6. Return $L(w)$ and $\nabla_w L(w)$.
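As a sanity check on this recipe, here is a minimal Python sketch (not the released Matlab code) on a made-up one-dimensional ridge problem, where $w$ is the log of the regularization constant and the inner problem happens to have a closed-form solution. The data and the names `inner_solve`, `Q`, and `loss_and_grad` are assumptions for illustration only. The gradient from the linear-system recipe is compared against a finite difference of $L$.

```python
import math

# Hypothetical toy data: scalar inputs x with targets t.
x_tr, t_tr = [1.0, 2.0, 3.0], [1.1, 1.9, 3.2]   # training set
x_va, t_va = [1.5, 2.5], [1.4, 2.7]             # validation set

def inner_solve(w):
    """y* = argmin_y E(y, w) with E = mean((x*y - t)^2) + exp(w) * y^2."""
    n = len(x_tr)
    sxt = sum(x * t for x, t in zip(x_tr, t_tr))
    sxx = sum(x * x for x in x_tr)
    return sxt / (sxx + n * math.exp(w))

def Q(y):
    """Validation loss."""
    return sum((x * y - t) ** 2 for x, t in zip(x_va, t_va)) / len(x_va)

def loss_and_grad(w):
    y = inner_solve(w)                                              # inner solve
    g = sum(2 * x * (x * y - t) for x, t in zip(x_va, t_va)) / len(x_va)  # dQ/dy
    H = 2 * sum(x * x for x in x_tr) / len(x_tr) + 2 * math.exp(w)  # d2E/dy2
    z = g / H                                                       # solve H z = g
    d2E_dwdy = 2 * math.exp(w) * y                                  # mixed derivative
    return Q(y), -d2E_dwdy * z

w = 0.3
L, dL = loss_and_grad(w)

# Central finite difference of L(w) = Q(y*(w)) for comparison.
eps = 1e-6
dL_fd = (Q(inner_solve(w + eps)) - Q(inner_solve(w - eps))) / (2 * eps)
```

If the recipe is implemented correctly, `dL` and `dL_fd` should agree to several digits. In higher dimensions the division by `H` becomes a linear solve.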

This looks a bit nasty, since we need to compute second-derivative matrices of $E$.  In fact, as long as one has a routine to compute $\nabla_y E$ and $\nabla_w E$, this can be avoided through efficient matrix-vector products.  This is essentially the approach proposed by Do, Foo, and Ng in “Efficient multiple hyperparameter learning for log-linear models”.

Overall, this is a decent approach, but it can be quite slow, simply because one must solve an “inner” optimization in order to compute each gradient of the “outer” optimization.  Often, the inner-optimization needs to be solved to very high accuracy in order to estimate a gradient accurately enough to reduce $L(w)$— higher accuracy than is needed when one is simply using the predicted value $y^*$ itself.

Truncated optimization

To get around this expense, a fairly obvious idea is to re-define the problem.  The slowness of exactly computing the gradient stems from needing to exactly solve the inner optimization.  Hence, perhaps we can re-define the problem such that an inexact solve of the inner problem nevertheless yields an “exact” gradient?

Re-define the problem as solving

$\min_{w} L(w)=Q(y^*(w)), \text{ } y^*(w) := \text{opt-alg}_y E(y,w)$,

where $\text{opt-alg}$ denotes an approximate solve of the inner optimization.  In order for this to work, $\text{opt-alg}$ must be defined in such a way that $y^*(w)$ is a continuous (and, so that gradients exist, differentiable) function of $w$.  With standard optimization methods such as gradient descent or BFGS, this can be achieved by applying a fixed number of iterations with a fixed step-size.  Since each iteration of these algorithms is a continuous function of $w$ and of the previous iterate, their composition defines $y^*(w)$ as a continuous function.  Thus, in principle, $L$ could be optimized efficiently through automatic differentiation of the code that optimizes $E$.  That’s fine in principle, but often inconvenient in practice.

It turns out, however, that one can derive “backpropagating” versions of algorithms like gradient descent, that take as input only a procedure to compute $E$ along with its first derivatives.  These algorithms can then produce the gradient of $L$ in the same time as automatic differentiation.

If the inner-optimization is gradient descent for $N$ steps with a step-size of $\lambda$, the algorithm to compute the loss is simple:

1. Input $w$ and a fixed starting point $y_0$
2. For $k=0,1,...,N-1$
    (a) $y_{k+1} \leftarrow y_k - \lambda \nabla_y E(y_k,w)$
3. Return $L(w) = Q(y_N)$

How to compute the gradient of this quantity?  The following algorithm does the trick.

1. $\overleftarrow{y_N} \leftarrow \nabla Q(y_N)$
2. $\overleftarrow{w} \leftarrow 0$
3. For $k=N-1,...,0$
    (a) $\overleftarrow{w} \leftarrow \overleftarrow{w} - \lambda \frac{\partial^2 E(y_k,w)}{\partial w \partial y^T} \overleftarrow{y_{k+1}}$
    (b) $\overleftarrow{y_k} \leftarrow \overleftarrow{y_{k+1}} - \lambda \frac{\partial^2 E(y_k,w)}{\partial y \partial y^T} \overleftarrow{y_{k+1}}$
4. Return $\nabla L = \overleftarrow{w}$.
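The forward and backward passes can be sketched in a few lines of Python. This is only an illustration under stated assumptions, not the released code: the toy data, the per-coordinate log-regularization vector `w`, and all function names are made up, and the two second-derivative products are written out analytically for this quadratic $E$ rather than obtained from first derivatives. The result is checked against finite differences of the truncated loss.

```python
import math

# Hypothetical toy problem (not the cities data): y has two entries, and
# w holds one log-regularization constant per entry of y.
X, t = [[1.0, 0.5], [0.3, 2.0], [1.5, 1.0]], [1.0, 2.0, 2.5]   # training set
Xv, tv = [[0.8, 1.2], [1.1, 0.4]], [1.9, 1.2]                  # validation set

def matvec(A, y):
    return [sum(a * yj for a, yj in zip(row, y)) for row in A]

def grad_sq(A, b, y):
    """Gradient with respect to y of mean((A y - b)^2)."""
    r = [p - bi for p, bi in zip(matvec(A, y), b)]
    return [2 * sum(row[j] * ri for row, ri in zip(A, r)) / len(b)
            for j in range(len(y))]

def Q(y):
    """Validation loss."""
    return sum((p - bi) ** 2 for p, bi in zip(matvec(Xv, y), tv)) / len(tv)

def grad_E(y, w):
    """Gradient in y of E(y, w) = mean((X y - t)^2) + sum_j exp(w_j) y_j^2."""
    return [g + 2 * math.exp(wj) * yj for g, wj, yj in zip(grad_sq(X, t, y), w, y)]

def hess_yy_times(v, w):
    """(d^2 E / dy dy^T) v. For this quadratic E the Hessian does not depend on y."""
    prod = matvec(X, v)
    return [2 * sum(row[j] * s for row, s in zip(X, prod)) / len(t)
            + 2 * math.exp(w[j]) * v[j] for j in range(len(v))]

def hess_wy_times(v, y, w):
    """(d^2 E / dw dy^T) v. Here that matrix is diagonal with entries 2 exp(w_j) y_j."""
    return [2 * math.exp(wj) * yj * vj for wj, yj, vj in zip(w, y, v)]

def truncated_loss_and_grad(w, N=20, lam=0.1):
    ys = [[0.0, 0.0]]                                   # fixed starting point y_0
    for _ in range(N):                                  # forward pass
        g = grad_E(ys[-1], w)
        ys.append([yj - lam * gj for yj, gj in zip(ys[-1], g)])
    y_bar = grad_sq(Xv, tv, ys[-1])                     # gradient of Q at y_N
    w_bar = [0.0] * len(w)
    for k in range(N - 1, -1, -1):                      # backward pass
        m = hess_wy_times(y_bar, ys[k], w)              # both products use y_bar_{k+1}
        h = hess_yy_times(y_bar, w)
        w_bar = [wb - lam * mj for wb, mj in zip(w_bar, m)]
        y_bar = [yb - lam * hj for yb, hj in zip(y_bar, h)]
    return Q(ys[-1]), w_bar

w = [0.0, -1.0]
L, dL = truncated_loss_and_grad(w)

# Central finite differences of the truncated loss, for comparison.
eps, dL_fd = 1e-6, []
for j in range(len(w)):
    wp, wm = list(w), list(w)
    wp[j] += eps
    wm[j] -= eps
    dL_fd.append((truncated_loss_and_grad(wp)[0] - truncated_loss_and_grad(wm)[0]) / (2 * eps))
```

Note that the backward pass needs the stored iterates `ys`, so memory grows with the number of inner iterations $N$.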

Similar algorithms can be derived for the heavy-ball algorithm (with a little more complexity) and limited memory BFGS (with a lot more complexity).

Code

So, finally, here is the code, and I’ll give a simple example of how to use it. There are just four simple files:

I think the purpose of each is pretty straightforward, so I’ll just quickly step through the demo here.  I’ll start off by taking one of Matlab’s built-in datasets (on cities) so that we are trying to predict a measure of crime from measures of climate, housing, health, transportation, education, arts, recreation, and economy, as well as a constant.  There are 329 data points in total, which I split into a training set of size 50, a validation set of size 150, and a test set of size 129.

load cities ratings
X = ratings;
for i=1:size(X,2)
    X(:,i) = X(:,i) - mean(X(:,i));
    X(:,i) = X(:,i) / std( X(:,i));
end

% predict crime (column 4) from climate, housing, health, trans, edu, arts, rec, econ, plus a constant
Y = X(:,4);
X = [X(:,[1:3 5:9]) ones(size(X,1),1)];

p = randperm(length(Y));
X = X(p,:);
Y = Y(p);

whotrain = 1:50;
whoval   = 51:200;
whotest  = 201:329;

Xtrain = X(whotrain,:);
Xval   = X(whoval  ,:);
Xtest  = X(whotest ,:);
Ytrain = Y(whotrain);
Yval   = Y(whoval  );
Ytest  = Y(whotest );

Next, I’ll set up some simple constants that will be used later on, and define the optimization parameters for minFunc, which I will be using for the outer optimization. In particular, here I choose the inner optimization to use 20 iterations.

opt_iters = 20;
ndims     = size(Xtrain,2);
w0        = zeros(ndims,1);
ndata     = size(Xtrain,1);
ndata_val = size(Xval,1);

options = [];
options.Method  = 'gd';
options.LS      = 1;
options.MaxIter = 100;

Now, I’ll define the training risk function ($E$ in the notation above). This computes the risk with a regularization constant of $a$, as well as derivatives. I’ll also define the validation risk ($Q$ in the notation above).

function [R dRdw dRdloga] = training_risk(w,loga)
    a         = exp(loga);
    R         = sum( (Xtrain*w - Ytrain).^2 )/ndata + a*sum(w.^2);
    dRdw      = 2*Xtrain'*(Xtrain*w-Ytrain)  /ndata + 2*a*w;
    dRda      = sum(w.^2);
    dRdloga   = dRda*a;   % chain rule: da/dloga = a
end

function [R g] = validation_risk(w)
    R = sum( (Xval*w - Yval).^2 ) / ndata_val;
    g = 2*Xval'*(Xval*w-Yval)     / ndata_val;
end

Now, before going any further, let’s do a traditional sweep through regularization constants to see what that looks like.

LAMBDA = -5:.25:2;
VAL_RISK = 0*LAMBDA;
for i=1:length(LAMBDA)
    VAL_RISK(i) = back_lbfgs(@training_risk,@validation_risk,w0,LAMBDA(i),opt_iters);
end

Plotting, we get the following:

This is a reasonable looking curve. Instead, let’s ask the algorithm to find the constant by gradient descent.

eval = @(loga) back_lbfgs(@training_risk,@validation_risk,w0,loga,opt_iters);
loga = 0;
[loga fval] = minFunc(eval,loga,options);

Running the optimization, we see:

Iteration   FunEvals     Step Length    Function Val        Opt Cond
1          2     1.00000e+00     8.74176e-01     3.71997e-03
2          3     1.00000e+00     8.73910e-01     1.86453e-04
3          4     1.00000e+00     8.73909e-01     1.06619e-05
4          5     1.00000e+00     8.73909e-01     3.60499e-08

This leads to a regularizer of the form:

$0.860543 ||w||_2^2$.

We can plot this on the graph, and see it matches the minimum found by the sweep over regularization constants.

If we actually compute the test-set error, this is 0.708217.

OK, let’s be a little bit more adventurous, and use a third-order regularizer. This is done like so:

function [R dRdw dRdloga] = training_risk2(w,loga)
    a         = exp(loga(1));
    b         = exp(loga(2));
    R         = sum( (Xtrain*w - Ytrain).^2 ) / ndata + a*sum(abs(w).^2)    + b*sum(abs(w).^3);
    % d/dw of a*|w|^2 is 2*a*|w|.*sign(w); d/dw of b*|w|^3 is 3*b*|w|.^2.*sign(w)
    dRdw      = 2*Xtrain'*(Xtrain*w-Ytrain)   / ndata + 2*a*abs(w).*sign(w) + 3*b*abs(w).^2.*sign(w);
    dRda      = sum(abs(w).^2);
    dRdb      = sum(abs(w).^3);
    dRdloga   = [dRda*a; dRdb*b];
end

Running the optimization, we see:

Iteration   FunEvals     Step Length    Function Val        Opt Cond
1          2     1.00000e+00     8.74445e-01     1.17262e-02
2          3     1.00000e+00     8.73685e-01     3.21956e-03
3          4     1.00000e+00     8.73608e-01     1.41744e-03
4          5     1.00000e+00     8.73598e-01     8.20040e-04
5          6     1.00000e+00     8.73567e-01     1.39830e-03
6          7     1.00000e+00     8.73513e-01     2.52994e-03
7          8     1.00000e+00     8.73471e-01     1.77157e-03
...
23         28     1.00000e+00     8.70741e-01     8.42628e-06

With a final regularizer of the form

$0.000160 ||w||_2^2 + 5.117057 ||w||_3^3$

and a hardly-improved test error of 0.679155.

Finally, let’s fit a four-term polynomial regularizer, with powers 2 through 5.

function [R dRdw dRdloga] = training_risk3(w,loga)
    a         = exp(loga);
    R         = sum( (Xtrain*w - Ytrain).^2 ) / ndata;
    dRdw      = 2*Xtrain'*(Xtrain*w-Ytrain)   / ndata;
    dRda      = zeros(length(a),1);
    for ii=1:length(a)
        b = ii+1;   % term ii uses the power ii+1, i.e. powers 2..5 for four terms
        R    = R    + a(ii)*sum(abs(w).^b);
        dRdw = dRdw + a(ii)*b*abs(w).^(b-1).*sign(w);   % d/dw |w|^b = b*|w|^(b-1)*sign(w)
        dRda(ii,1) = sum(abs(w).^b);
    end
    dRdloga = dRda.*a;
end
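Hand-derived gradients like `dRdw` are easy to get subtly wrong, and a finite-difference check is a cheap safeguard. The sketch below is in Python rather than Matlab, and the helper names `reg` and `reg_grad` are hypothetical. It checks the derivative of a single regularization term $a\sum_i |w_i|^b$, which is $a\,b\,|w_i|^{b-1}\,\text{sign}(w_i)$, for each of the powers 2 through 5.

```python
def reg(w, a, b):
    """One regularization term: a * sum_i |w_i|^b."""
    return a * sum(abs(wi) ** b for wi in w)

def reg_grad(w, a, b):
    """Gradient of reg with respect to w: a * b * |w_i|^(b-1) * sign(w_i)."""
    sign = lambda v: (v > 0) - (v < 0)
    return [a * b * abs(wi) ** (b - 1) * sign(wi) for wi in w]

w = [0.7, -1.3, 0.2]        # arbitrary test point, away from the kink at 0
a, eps = 0.5, 1e-6
max_err = 0.0
for b in (2, 3, 4, 5):
    g = reg_grad(w, a, b)
    for i in range(len(w)):
        wp, wm = list(w), list(w)
        wp[i] += eps
        wm[i] -= eps
        fd = (reg(wp, a, b) - reg(wm, a, b)) / (2 * eps)   # central difference
        max_err = max(max_err, abs(fd - g[i]))
```

A small `max_err` indicates that the analytic gradient and the function value are consistent with each other.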

eval = @(loga) back_lbfgs(@training_risk3,@validation_risk,w0,loga,opt_iters);
loga = [0; 0; 0; 0];
loga = minFunc(eval,loga,options);

Running the optimization, we see

Iteration   FunEvals     Step Length    Function Val        Opt Cond
1          2     1.00000e+00     8.73982e-01     1.02060e-02
2          3     1.00000e+00     8.73524e-01     3.30445e-03
3          4     1.00000e+00     8.73447e-01     1.74779e-03
4          5     1.00000e+00     8.73435e-01     1.43655e-03
5          6     1.00000e+00     8.73345e-01     4.98340e-03
6          7     1.00000e+00     8.73295e-01     1.85535e-03
7          8     1.00000e+00     8.73231e-01     1.81136e-03
...
38         41     1.00000e+00     8.69758e-01     3.44848e-06

Yielding the (strange) regularizer

$0.000000 ||w||_2^2 + 0.000010 ||w||_3^3 + 24.310631 ||w||_4^4 + 0.325565 ||w||_5^5$

and a final test-error of 0.67187.

Fitting an inference algorithm instead of a model

One recent trend seems to be the realization that one can get better performance by tuning a CRF (Conditional Random Field) to a particular inference algorithm. Basically, forget about the distribution that the CRF represents, and instead only care how accurate are the results that pop out of inference. An extreme example of this is the recent paper “Learning Real-Time MRF Inference for Image Denoising” by Adrian Barbu.

The basic idea is to fit a FoE (Field of Experts) image prior such that when one takes a very few gradient descent steps on a denoising posterior, the results are accurate. From the abstract:

We argue that through appropriate training, a MRF/CRF model can be trained to perform very well on a suboptimal inference algorithm. The model is trained together with a fast inference algorithm through an optimization of a loss function […] We apply the proposed method to an image denoising application […] with a 1-4 iteration gradient descent inference algorithm. […] the proposed training approach obtains an improved benchmark performance as well as a 1000-3000 times speedup compared to the Fields of Experts MRF.

The implausible-sounding 1000-fold speedup comes simply from using only 4 iterations of gradient descent rather than several thousand. (Incidentally, L-BFGS requires far fewer iterations for this problem.) The results are a bit better than those of the generative FoE model, which takes much more work for training and inference.

I have every confidence that this does work well, and similar strategies could probably be used to create fast inference models/algorithms for many different problems. My thesis was largely an attempt to do the same thing for marginal, rather than MAP inference.

The disturbing/embarrassing question, for me, is: does this really have anything to do with probabilistic modeling any more? Formally speaking, a probability density is being fit, but I doubt it would transfer to, say, inpainting, or that samples from the density would look like natural images. The best interpretation of what is happening might be that one is simply fitting a big, nonlinear, black box function approximation.

It seems that the more effort we expend to wring the best performance out of a probabilistic model, the less “probabilistic” it is.

What Gauss-Seidel is Really Doing

I’ve been reading Alan Sokal’s lecture notes “Monte Carlo Methods in Statistical Mechanics: Foundations and New Algorithms” today. Once I learned to take the word “Hamiltonian” and mentally substitute “function to be minimized”, they are very clearly written.

Anyway, the notes give an explanation of the Gauss-Seidel iterative method for solving linear systems that is so clear, I feel a little cheated that it was never explained to me before. Let’s start with the typical (confusing!) explanation of Gauss-Seidel (as in, say, Wikipedia). You want to solve some linear system

$A{\bf x}={\bf b}$,

for $\bf x$. What you do is decompose $A$ into a diagonal matrix $D$, a strictly lower triangular matrix $L$, and a strictly upper triangular matrix $U$, like

$A=L+D+U$.

To solve this, you iterate like so:

${\bf x} \leftarrow (D+L)^{-1}({\bf b}-U{\bf x})$.

This works when $A$ is symmetric positive definite.

Now, what the hell is going on here? The first observation is that instead of inverting $A$, you can think of minimizing the function

$\frac{1}{2} {\bf x}^T A {\bf x} - {\bf x}^T{\bf b}$.

(It is easy to see that the global minimum of this is found when $A{\bf x}={\bf b}$.) Now, how to minimize this function? One natural way to approach things would be to just iterate through the coordinates of ${\bf x}$, minimizing the function over each one. Say we want to minimize with respect to $x_i$. So, we are going to make the change $x_i \leftarrow x_i + d$. Thus, we want to minimize (with respect to $d$) this thing:

$\frac{1}{2} ({\bf x}+d {\bf e_i})^T A ({\bf x}+d {\bf e_i}) - ({\bf x}+d {\bf e_i})^T {\bf b}$

Expanding this, taking the derivative w.r.t. $d$, and solving gives us (where ${\bf e_i}$ is the unit vector in the $i$th direction)

$d = \frac{{\bf e_i}^T({\bf b}-A{\bf x})}{{\bf e_i}^TA{\bf e_i}}$,

or, equivalently,

$d = \frac{1}{A_{ii}}(b_i-A_{i*}{\bf x})$.

So, taking the solution at one timestep, ${\bf x}^t$, and solving for the solution ${\bf x}^{t+1}$ at the next timestep, we will have, for the first index,

$x^{t+1}_1 = x^t_1 + \frac{1}{A_{11}}(b_1-A_{1*}x^t)$.

Meanwhile, for the second index,

$x^{t+1}_2 = x^t_2 + \frac{1}{A_{22}}(b_2-A_{2*}(x^t-{\bf e_1} x^t_1 +{\bf e_1} x^{t+1}_1 ))$.

And, more generally

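$x^{t+1}_n = x^t_n + \frac{1}{A_{nn}}\left(b_n-A_{n*}\left(\sum_{i<n} {\bf e_i} x^{t+1}_i + \sum_{i\geq n} {\bf e_i} x^t_i\right)\right)$.

Now, all is a matter of cranking through the equations.

$x^{t+1}_n = x^t_n + \frac{1}{A_{nn}}\left(b_n-\sum_{i<n} A_{ni} x^{t+1}_i - \sum_{i\geq n} A_{ni} x^t_i\right)$

$D {\bf x}^{t+1} = D {\bf x}^t + {\bf b}- L {\bf x}^{t+1} - (D+U) {\bf x}^t$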

$(D+L) {\bf x}^{t+1} = {\bf b}-U {\bf x}^t$

Which is exactly the iteration we started with.
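The equivalence is easy to check numerically. Here is a small Python sketch, with a made-up symmetric positive definite matrix and hypothetical function names, that runs exact coordinate minimization and the matrix-form Gauss-Seidel update side by side.

```python
def coordinate_sweep(A, b, x):
    """One pass of exact coordinate minimization of 0.5 x^T A x - x^T b."""
    x = list(x)
    n = len(b)
    for i in range(n):
        residual_i = b[i] - sum(A[i][j] * x[j] for j in range(n))
        x[i] += residual_i / A[i][i]      # optimal step d for coordinate i
    return x

def gauss_seidel_sweep(A, b, x):
    """One step x <- (D + L)^{-1} (b - U x), done by forward substitution."""
    n = len(b)
    rhs = [b[i] - sum(A[i][j] * x[j] for j in range(i + 1, n)) for i in range(n)]
    new = [0.0] * n
    for i in range(n):
        new[i] = (rhs[i] - sum(A[i][j] * new[j] for j in range(i))) / A[i][i]
    return new

# A made-up symmetric positive definite system.
A = [[4.0, 1.0, 0.5],
     [1.0, 3.0, 1.0],
     [0.5, 1.0, 5.0]]
b = [1.0, 2.0, 3.0]

x_cd = [0.0, 0.0, 0.0]
x_gs = [0.0, 0.0, 0.0]
for _ in range(50):
    x_cd = coordinate_sweep(A, b, x_cd)
    x_gs = gauss_seidel_sweep(A, b, x_gs)

# Residual b - A x of the final Gauss-Seidel iterate.
residual = [b[i] - sum(A[i][j] * x_gs[j] for j in range(3)) for i in range(3)]
```

The two iterates should coincide up to rounding at every sweep, and the residual should shrink toward zero.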

The fact that this would work when $A$ is symmetric positive definite is now obvious: that is when the objective function we started with is bounded below and has a unique minimizer.

The understanding that Gauss-Seidel is just a convenient way to implement coordinate descent also helps to get intuition for the numerical properties of Gauss-Seidel (e.g. quickly removing "high frequency error", and very slowly removing "low frequency error").

Hessian-Vector products

You have some function $f({\bf x})$.  You have figured out how to compute its gradient, ${\bf g}({\bf x})=\frac{\partial}{\partial \bf x}f({\bf x})$.  Now, however, you find that you are implementing some algorithm (like, say, Stochastic Meta Descent), and you need to compute the product of the Hessian $H({\bf x})=\frac{\partial^2}{\partial {\bf x}\partial{\bf x}^T}f({\bf x})$ with certain vectors.  You become very upset, because either A) you don’t feel like deriving the Hessian (probable), or B) the Hessian has $N^2$ elements, where $N$ is the length of $\bf x$ and that is too big to deal with (more probable).  What to do?  Behold:

${\bf g}({\bf x}+{\bf \Delta x}) \approx {\bf g}({\bf x}) + H({\bf x}){\bf \Delta x}$

Consider the Hessian-Vector product we want to compute, $H({\bf x}){\bf v}$.  For small $r$,

${\bf g}({\bf x}+r{\bf v}) \approx {\bf g}({\bf x}) + r H({\bf x}){\bf v}$

And so,

$\boxed{H({\bf x}){\bf v}\approx\frac{{\bf g}({\bf x}+r{\bf v}) - {\bf g}({\bf x})}{r}}$

The trick above has apparently been around forever.  The approximation becomes exact in the limit $r \rightarrow 0$.  Of course, for small $r$, numerical problems will also start to kill you.  Pearlmutter’s algorithm is a way to compute $H {\bf v}$ with the same complexity, without suffering rounding errors.  Unfortunately, Pearlmutter’s algorithm is kind of complex, while the above is absolutely trivial.

Update: Pearlmutter himself comments below that if we want to use the finite difference trick, we would do better to use the approximation:

$\boxed{H({\bf x}){\bf v}\approx\frac{{\bf g}({\bf x}+r{\bf v}) - {\bf g}({\bf x}-r{\bf v})}{2r}}.$

This expression will be closer to the true value for larger $r$, meaning that we are less likely to get hurt by rounding error. This is nicely illustrated on page 6 of these notes.
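To make the comparison concrete, here is a short Python sketch of both finite-difference formulas. The test function, its hand-derived gradient and Hessian-vector product, and all names are assumptions chosen only so that the exact answer is available for comparison.

```python
import math

def grad(x):
    """Gradient of the toy function f(x) = sum_i exp(x_i) + x_0^2 * x_1."""
    return [math.exp(x[0]) + 2 * x[0] * x[1],
            math.exp(x[1]) + x[0] ** 2,
            math.exp(x[2])]

def hvp_exact(x, v):
    """Hessian-vector product of the same f, derived by hand."""
    return [(math.exp(x[0]) + 2 * x[1]) * v[0] + 2 * x[0] * v[1],
            2 * x[0] * v[0] + math.exp(x[1]) * v[1],
            math.exp(x[2]) * v[2]]

def hvp_forward(g, x, v, r):
    """(g(x + r v) - g(x)) / r"""
    gp = g([xi + r * vi for xi, vi in zip(x, v)])
    g0 = g(x)
    return [(a - b) / r for a, b in zip(gp, g0)]

def hvp_central(g, x, v, r):
    """(g(x + r v) - g(x - r v)) / (2 r)"""
    gp = g([xi + r * vi for xi, vi in zip(x, v)])
    gm = g([xi - r * vi for xi, vi in zip(x, v)])
    return [(a - b) / (2 * r) for a, b in zip(gp, gm)]

x, v, r = [0.5, -0.2, 1.0], [1.0, 2.0, -1.0], 1e-3
exact = hvp_exact(x, v)
err_fwd = max(abs(a - b) for a, b in zip(hvp_forward(grad, x, v, r), exact))
err_ctr = max(abs(a - b) for a, b in zip(hvp_central(grad, x, v, r), exact))
```

With this fairly large $r$, the one-sided error shrinks like $r$ while the central error shrinks like $r^2$, so `err_ctr` should come out orders of magnitude smaller than `err_fwd`. Both cost two gradient evaluations, assuming ${\bf g}({\bf x})$ is not already available.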

Why does regularization work?

When fitting statistical models, we usually need to “regularize” the model. The simplest example is probably linear regression. Take some training data, $\{({\hat {\bf x}}, {\hat y})\}$. Given a vector of weights $\bf w$, the total squared distance is

$\sum_{ ({\hat {\bf x}}, {\hat y})} ({\bf w}^T{\hat {\bf x}} - {\hat y})^2$

So to fit the model, we might find ${\bf w}$ to minimize the above loss. Commonly, (particularly when ${\bf x}$ has many dimensions), we find that the above procedure leads to overfitting: very large weights ${\bf w}$ that fit the training data very well, but poorly predict future data. “Regularization” means modifying the optimization problem to prefer small weights. Consider

$\sum_{ ({\hat {\bf x}}, {\hat y})} ({\bf w}^T{\hat {\bf x}} - {\hat y})^2 + \lambda ||{\bf w}||^2.$

Here, $\lambda$ trades off how much we care about the model fitting well versus how much we care about ${\bf w}$ being small. Of course, this is a much more general concept than linear regression, and one could pick many alternatives to the squared norm to measure how simple a model is. Let’s abstract a bit, and just write some function $M$ to represent how well the model fits the data, and some function $R$ to represent how simple it is.

$M({\bf w}) + \lambda R({\bf w})$

The obvious question is: how do we pick $\lambda$? This, of course, has been the subject of much theoretical study. What is commonly done in practice is this:

1. Divide the data into a training set and test set.
2. Fit the model for a range of different $\lambda$s using only the training set.
3. For each of the models fit in step 2, check how well the resulting weights fit the test data.
4. Output the weights that perform best on test data.

(One can also retrain on all the data using the $\lambda$ that did best in step 3.)
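As a concrete, if tiny, illustration of this procedure, here is a Python sketch for one-dimensional ridge regression, where the fit has a closed form. The data, the grid of $\lambda$ values, and the function names are all made up for the example.

```python
# Step 1: hypothetical toy data, already divided into training and test pairs (x, y).
train = [(1.0, 1.3), (2.0, 1.7), (3.0, 3.4)]
test = [(1.5, 1.4), (2.5, 2.6), (3.5, 3.3)]

def fit(lam, data):
    """Weight minimizing sum((w*x - y)^2) + lam * w^2 (closed form in 1-D)."""
    return sum(x * y for x, y in data) / (sum(x * x for x, _ in data) + lam)

def squared_error(w, data):
    return sum((w * x - y) ** 2 for x, y in data)

# Step 2: fit the model for a range of lambdas using only the training set.
lambdas = [10.0 ** k for k in range(-3, 3)]
weights = [fit(lam, train) for lam in lambdas]

# Step 3: check how well each set of weights fits the test data.
scores = [squared_error(w, test) for w in weights]

# Step 4: output the weights that perform best on the test data.
best = min(range(len(lambdas)), key=lambda i: scores[i])
best_lambda, best_w = lambdas[best], weights[best]
```

With a single $\lambda$ the grid is cheap. With $k$ regularization parameters and the same resolution per axis, the number of fits grows like the grid size to the power $k$.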

Now, there are many ways to measure simplicity. (In the example above, we might consider $||{\bf w}||^2$, $||{\bf w}||$, $||{\bf w}||_1$, $||{\bf w}||_0$, …). Which one to use? If you drank too much coffee one afternoon, you might decide to include a bunch of regularizers simultaneously, each with a corresponding regularization parameter, ending up with a problem like

$M({\bf w}) + \sum_i \lambda_i R_i({\bf w})$.

And yet, staring at that equation, things take a sinister hue: if we include too many regularization parameters, won’t we eventually overfit the regularization parameters? And furthermore, won’t it be very hard to test all possible combinations of $\lambda_i$ to some reasonable resolution? What should we do?

And now I will finally get to the point.

Recently, I’ve been looking at regularization in a different way, which seems very obvious in retrospect. The idea is that when searching for the optimal regularization parameters, we are fitting a model– a model that just happens to be defined in an odd way.

First, let me define some notation. Let $T$ denote the training set, and let $S$ denote the test set. Now, define

${\bf f}(\lambda) = \arg\min_{\bf w} M_T({\bf w}) + \lambda R({\bf w})$

where $M_T$ is the function measuring how well the model fits the training data. So, for a given regularization parameter, ${\bf f}$ returns the weights solving the optimization problem. Given that, the optimal $\lambda$ will be the one minimizing the corresponding fit $M_S$ on the test data:

$\lambda^* = \arg\min_\lambda M_S( {\bf f}(\lambda) )$

This extends naturally to the setting where we have a vector of regularization parameters ${\boldsymbol \lambda}$.

${\bf f}({\boldsymbol \lambda}) = \arg\min_{\bf w} M_T({\bf w}) + \sum_i \lambda_i R_i({\bf w})$

${\boldsymbol \lambda}^* = \arg\min_{\boldsymbol \lambda} M_S( {\bf f}({\boldsymbol \lambda}) )$

From this perspective, doing an optimization over regularization parameters makes sense: it is just fitting the parameters $\boldsymbol \lambda$ to the data $S$— it just so happens that the objective function is implicitly defined by a minimization over the data $T$.

Regularization works because it is fitting a single parameter to some data. (In general, one is safe fitting a single parameter to any reasonably sized dataset, although there are exceptions.) Fitting multiple regularization parameters should work, supposing the test data is big enough to support them. (There are computational issues I’m not discussing here with doing this, stemming from the fact that ${\bf f}({\boldsymbol \lambda})$ isn’t known in closed form, but is implicitly defined by a minimization. So far as I know, this is an open problem, although I suspect the solution is “implicit differentiation”.)
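To give a feel for what implicit differentiation would look like here, below is a hedged Python sketch with a scalar weight and two regularizers, $R_1(w)=w^2$ and $R_2(w)=w^4$. Everything in it (data, names, the choice of Newton's method for the inner minimization) is an assumption for illustration. Differentiating the stationarity condition $M_T'(w) + \sum_i \lambda_i R_i'(w) = 0$ with respect to $\lambda_j$ gives $\frac{dw}{d\lambda_j} = -R_j'(w) / \left(M_T''(w) + \sum_i \lambda_i R_i''(w)\right)$.

```python
T = [(1.0, 1.3), (2.0, 1.7), (3.0, 3.4)]   # hypothetical training pairs (x, y)
S = [(1.5, 1.4), (2.5, 2.6), (3.5, 3.3)]   # hypothetical test pairs (x, y)

def f(lam):
    """argmin_w M_T(w) + lam[0]*w^2 + lam[1]*w^4, via Newton's method."""
    w = 0.0
    for _ in range(50):
        d1 = sum(2 * x * (w * x - y) for x, y in T) + 2 * lam[0] * w + 4 * lam[1] * w ** 3
        d2 = sum(2 * x * x for x, _ in T) + 2 * lam[0] + 12 * lam[1] * w ** 2
        w -= d1 / d2
    return w

def M_S(w):
    """Squared error on the test set."""
    return sum((w * x - y) ** 2 for x, y in S)

def hypergradient(lam):
    """Gradient of M_S(f(lam)) with respect to lam, by implicit differentiation."""
    w = f(lam)
    curvature = sum(2 * x * x for x, _ in T) + 2 * lam[0] + 12 * lam[1] * w ** 2
    dR = [2 * w, 4 * w ** 3]                      # R_1'(w), R_2'(w)
    dw_dlam = [-r / curvature for r in dR]        # from the stationarity condition
    dM = sum(2 * x * (w * x - y) for x, y in S)   # M_S'(w)
    return [dM * d for d in dw_dlam]

lam = [0.5, 0.2]
grad = hypergradient(lam)

# Central finite differences for comparison.
eps, grad_fd = 1e-5, []
for j in range(len(lam)):
    lp, lm = list(lam), list(lam)
    lp[j] += eps
    lm[j] -= eps
    grad_fd.append((M_S(f(lp)) - M_S(f(lm))) / (2 * eps))
```

With a vector of weights, the division by `curvature` becomes a linear solve against the Hessian of the regularized training objective.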

Even regularizing the regularization parameters isn’t necessarily a crazy idea, though clearly $||{\boldsymbol \lambda}||^2$ isn’t the way to do it. (Simpler models correspond to larger regularization constants. Maybe $1/||{\boldsymbol \lambda}||^2$?) Then, you can regularize those parameters!

Well. In retrospect that wasn’t quite so simple as it seemed.

It’s a strong bet that this is a standard interpretation in some communities.

Update: The idea of fitting multiple parameters using implicit differentiation is published in the 1996 paper “Adaptive regularization in neural network modeling” by Jan Larsen, Claus Svarer, Lars Nonboe Andersen and Lars Kai Hansen.