Truncated Bi-Level Optimization

In 2012, I wrote a paper that I probably should have called “truncated bi-level optimization”.  I vaguely remember telling the reviewers I would release some code, so I’m finally getting around to it.

The idea of bilevel optimization is quite simple.  Imagine that you would like to minimize some function $L(w)$.  However, $L$ itself is defined through some optimization.  More formally, suppose we would like to solve

$\min_{w,y} Q(y), \text{ s.t. } y = \arg\min_y E(y,w).$

Or, equivalently,

$\min_{w} L(w)=Q(y^*(w)),$

where $y^*$ is defined as $y^*(w) := \arg\min_y E(y,w)$.  This seems a little bit obscure at first, but actually comes up in several different ways in machine learning and related fields.

Hyper-parameter learning

The first example would be in learning hyperparameters, such as regularization constants.  Inevitably in machine learning, one fits parameters to optimize some tradeoff between the quality of a fit to training data and a regularization function being small.  Traditionally, the regularization constant is selected by optimizing on a training dataset with a variety of values, and then picking the one that performs best on a held-out dataset.  However, if there are a large number of regularization parameters, a high-dimensional grid-search will not be practical.  In the notation above, suppose that $w$ is a vector of regularization constants, and that $y$ is the vector of model parameters being trained.  Let $E$ be the regularized empirical risk on a training dataset, and let $Q$ be how well the parameters $y^*(w)$ perform on some held-out validation dataset.
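
As a concrete instance (a sketch with my own symbols, assuming squared error and a single L2 regularization constant parametrized as $e^{w}$ so that it stays positive), write $(x_i,t_i)$ for the $n$ training pairs and $(x'_j,t'_j)$ for the $m$ validation pairs.  Then one could take

$E(y,w) = \frac{1}{n}\sum_{i=1}^{n} (x_i^T y - t_i)^2 + e^{w} ||y||_2^2,$

$Q(y) = \frac{1}{m}\sum_{j=1}^{m} (x_j'^T y - t'_j)^2.$

Here $E$ depends on the hyperparameter $w$, while $Q$ depends on $w$ only through the fitted parameters $y^*(w)$.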

Energy-based models

Another example (and the one suggesting the notation) is an energy-based model.  Suppose that we have some “energy” function $E_x(y,w)$ which measures how well an output $y$ fits to an input $x$.  The energy is parametrized by $w$.  For a given training input/output pair $(\hat{x},\hat{y})$, we might have that $Q_{\hat{y}}(y^*(w))$ measures how well the predicted output $y^*$ compares to the true output $\hat{y}$, where $y^*(w)=\arg\min_y E_{\hat{x}}(y,w)$.

Even if we just have the modest goal of following the gradient of $L(w)$ to a local minimum, computing that gradient is not so simple.  Clearly, just evaluating $L(w)$ requires solving the “inner” minimization of $E$.  It turns out that one can compute $\nabla_w L(w)$ by first solving the inner minimization, and then solving a linear system.

1. Input $w$
2. Solve $y^* \leftarrow \arg\min_y E(y,w)$.
3. Compute:
    (a) the loss $L(w)= Q(y^*)$
    (b) the gradient $g=\nabla_y Q(y^*)$
    (c) the Hessian $H=\frac{\partial^2 E(y^*,w)}{\partial y\partial y^T}$
4. Solve the linear system $z=H^{-1} g$.
5. Compute the parameter gradient $\nabla_w L(w) = - \frac{\partial^2 E(y^*,w)}{\partial w\partial y^T} z$
6. Return $L(w)$ and $\nabla_w L(w)$.
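
To see where the linear system comes from, here is a short derivation.  It is a sketch that assumes $E$ is twice differentiable and that the Hessian of $E$ with respect to $y$ is invertible at $y^*$.  The inner minimum satisfies the stationarity condition

$\nabla_y E(y^*(w),w) = 0.$

Differentiating both sides with respect to $w$ gives

$\frac{\partial^2 E(y^*,w)}{\partial y\partial y^T} \frac{\partial y^*}{\partial w^T} + \frac{\partial^2 E(y^*,w)}{\partial y\partial w^T} = 0,$

so that

$\frac{\partial y^*}{\partial w^T} = - \left( \frac{\partial^2 E(y^*,w)}{\partial y\partial y^T} \right)^{-1} \frac{\partial^2 E(y^*,w)}{\partial y\partial w^T}.$

Applying the chain rule to $L(w)=Q(y^*(w))$, and using the symmetry of the Hessian in $y$, yields

$\nabla_w L(w) = \left( \frac{\partial y^*}{\partial w^T} \right)^T \nabla_y Q(y^*) = - \frac{\partial^2 E(y^*,w)}{\partial w\partial y^T} \left( \frac{\partial^2 E(y^*,w)}{\partial y\partial y^T} \right)^{-1} \nabla_y Q(y^*).$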

This looks a bit nasty, since we need to compute second-derivative matrices of $E$.  In fact, as long as one has a routine to compute $\nabla_y E$ and $\nabla_w E$, this can be avoided through efficient matrix-vector products.  This is essentially the approach proposed by Do, Foo, and Ng in “Efficient multiple hyperparameter learning for log-linear models”.
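
One common way to get such matrix-vector products is a central difference of gradients.  The following is a minimal sketch on a toy energy of my own choosing, not the code released with this post; the names `dEdy`, `dEdw`, `Hyy_times`, `Hwy_times` and the step `r` are all hypothetical.

```matlab
% Toy energy (assumed for illustration):
%   E(y,w) = 0.5*y'*A*y - b'*y + 0.5*exp(w)*(y'*y)
A = [3 1; 1 2];
b = [1; -1];
dEdy = @(y,w) A*y - b + exp(w)*y;       % gradient of E with respect to y
dEdw = @(y,w) 0.5*exp(w)*(y'*y);        % gradient of E with respect to w

% Central differences of the gradients along a direction v in y-space.
% Only gradient evaluations are needed; no second-derivative matrix is formed.
r = 1e-5;
Hyy_times = @(y,w,v) (dEdy(y + r*v,w) - dEdy(y - r*v,w)) / (2*r);  % ~ (d^2E/dy dy')*v
Hwy_times = @(y,w,v) (dEdw(y + r*v,w) - dEdw(y - r*v,w)) / (2*r);  % ~ (d^2E/dw dy')*v

% Compare against the exact products, which are known for this toy energy
y = [0.5; -0.25];
w = 0.3;
v = [1; 2];
approx_yy = Hyy_times(y,w,v);
exact_yy  = (A + exp(w)*eye(2)) * v;
approx_wy = Hwy_times(y,w,v);
exact_wy  = exp(w) * (y'*v);

disp(norm(approx_yy - exact_yy));       % both differences should be tiny
disp(abs(approx_wy - exact_wy));
```

With products like these, the linear system can be handed to an iterative solver such as conjugate gradients, which only ever asks for the Hessian times a vector.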

Overall, this is a decent approach, but it can be quite slow, simply because one must solve an “inner” optimization in order to compute each gradient of the “outer” optimization.  Often, the inner optimization needs to be solved to very high accuracy in order to estimate a gradient accurately enough to reduce $L(w)$, which is higher accuracy than is needed when one is simply using the predicted value $y^*$ itself.

Truncated optimization

To get around this expense, a fairly obvious idea is to re-define the problem.  The slowness of exactly computing the gradient stems from needing to exactly solve the inner optimization.  Hence, perhaps we can re-define the problem such that an inexact solve of the inner problem nevertheless yields an “exact” gradient?

Re-define the problem as solving

$\min_{w} L(w)=Q(y^*(w)), \text{ } y^*(w) := \text{opt-alg}_y E(y,w)$,

where $\text{opt-alg}$ denotes an approximate solve of the inner optimization.  In order for this to work, $\text{opt-alg}$ must be defined in such a way that $y^*(w)$ is a continuous function of $w$.  With standard optimization methods such as gradient descent or BFGS, this can be achieved by assuming there are a fixed number of iterations applied, with a fixed step-size.  Since each iteration of these algorithms is continuous, this clearly defines $y^*(w)$ as a continuous function.  Thus, in principle, it could be optimized efficiently through automatic differentiation of the code that optimizes $E$.  That’s fine in principle, but often inconvenient in practice.

It turns out, however, that one can derive “backpropagating” versions of algorithms like gradient descent, that take as input only a procedure to compute $E$ along with its first derivatives.  These algorithms can then produce the gradient of $L$ in the same time as automatic differentiation.

If the inner-optimization is gradient descent for $N$ steps with a step-size of $\lambda$, the algorithm to compute the loss is simple:

1. Input $w$ (and a fixed starting point $y_0$)
2. For $k=0,1,...,N-1$
    (a) $y_{k+1} \leftarrow y_k - \lambda \nabla_y E(y_k,w)$
3. Return $L(w) = Q(y_N)$

How to compute the gradient of this quantity?  The following algorithm does the trick.

1. $\overleftarrow{y_N} \leftarrow \nabla Q(y_N)$
2. $\overleftarrow{w} \leftarrow 0$
3. For $k=N-1,...,0$
    (a) $\overleftarrow{w} \leftarrow \overleftarrow{w} - \lambda \frac{\partial^2 E(y_k,w)}{\partial w \partial y^T} \overleftarrow{y_{k+1}}$
    (b) $\overleftarrow{y_k} \leftarrow \overleftarrow{y_{k+1}} - \lambda \frac{\partial^2 E(y_k,w)}{\partial y \partial y^T} \overleftarrow{y_{k+1}}$
4. Return $\nabla L = \overleftarrow{w}$.
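
The forward and backward passes above can be sketched in a few lines.  This is a toy illustration under my own assumptions (a small quadratic energy with exact second-derivative products, a scalar $w$, and $y_0=0$), not the released code; it ends with a finite-difference check of the resulting gradient.

```matlab
% Toy problem (assumed for illustration):
%   E(y,w) = 0.5*y'*A*y - b'*y + 0.5*exp(w)*(y'*y),   Q(y) = 0.5*norm(y - t)^2
A = [3 1; 1 2];  b = [1; -1];  t = [0.2; 0.1];
dEdy  = @(y,w) A*y - b + exp(w)*y;        % gradient of E in y
Hyy_v = @(y,w,v) (A + exp(w)*eye(2))*v;   % (d^2E/dy dy') * v
Hwy_v = @(y,w,v) exp(w)*(y'*v);           % (d^2E/dw dy') * v
dQ    = @(y) y - t;                       % gradient of Q

N = 20;  lambda = 0.1;  w = 0.3;

% Forward pass: keep every iterate, since the backward pass revisits them.
% Column k+1 of Y holds y_k, so column 1 is y_0 = 0.
Y = zeros(2, N+1);
for k = 1:N
    Y(:,k+1) = Y(:,k) - lambda*dEdy(Y(:,k), w);
end
L = 0.5*norm(Y(:,N+1) - t)^2;

% Backward pass: ybar starts as the gradient of Q at y_N
ybar = dQ(Y(:,N+1));
wbar = 0;
for k = N:-1:1
    yk   = Y(:,k);                                 % iterate that produced column k+1
    wbar = wbar - lambda*Hwy_v(yk, w, ybar);       % step (a), uses the old ybar
    ybar = ybar - lambda*Hyy_v(yk, w, ybar);       % step (b)
end

% Finite-difference check of dL/dw
h = 1e-6;  ws = [w+h, w-h];  Lpm = zeros(1,2);
for i = 1:2
    y = zeros(2,1);
    for k = 1:N
        y = y - lambda*dEdy(y, ws(i));
    end
    Lpm(i) = 0.5*norm(y - t)^2;
end
fd = (Lpm(1) - Lpm(2)) / (2*h);
fprintf('backprop: %g   finite difference: %g\n', wbar, fd);
```

The two printed numbers should agree closely.  Note that the memory cost grows with $N$, because all iterates are stored for the backward pass.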

Similar algorithms can be derived for the heavy-ball algorithm (with a little more complexity) and limited memory BFGS (with a lot more complexity).

Code

So, finally, here is the code, and I’ll give a simple example of how to use it. There are just four simple files:

I think the meanings of these are pretty straightforward, so I’ll just quickly step through the demo here.  I’ll start off by taking one of Matlab’s built-in datasets (on cities) so that we are trying to predict a measure of crime from measures of climate, housing, health, transportation, education, arts, recreation, and economy, as well as a constant.  There are 329 data points in total, which I split into a training set of size 50, a validation set of size 150, and a test set of size 129.

load cities ratings
X = ratings;
% standardize each column to zero mean and unit variance
for i=1:size(X,2)
    X(:,i) = X(:,i) - mean(X(:,i));
    X(:,i) = X(:,i) / std( X(:,i));
end

% predict crime from climate, housing, health, trans, edu, arts, rec, econ,
% and a constant
Y = X(:,4);
X = [X(:,[1:3 5:9]) 1+0*X(:,1)];

p = randperm(length(Y));
X = X(p,:);
Y = Y(p);

whotrain = 1:50;
whoval   = 51:200;
whotest  = 201:329;

Xtrain = X(whotrain,:);
Xval   = X(whoval  ,:);
Xtest  = X(whotest ,:);
Ytrain = Y(whotrain);
Yval   = Y(whoval  );
Ytest  = Y(whotest );

Next, I’ll set up some simple constants that will be used later on, and define the optimization parameters for minFunc, which I will be using for the outer optimization. In particular, here I choose the inner optimization to use 20 iterations.

opt_iters = 20;
ndims     = size(Xtrain,2);
w0        = zeros(ndims,1);
ndata     = size(Xtrain,1);
ndata_val = size(Xval,1);

options         = [];
options.Method  = 'gd';
options.LS      = 1;
options.MaxIter = 100;

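Now, I’ll define the training risk function ($E$ in the notation above). This computes the risk with a regularization constant of $a$, as well as derivatives. I’ll also define the validation risk ($Q$ in the notation above).  Note that the names in the code differ from the notation above: in the code, w holds the model parameters ($y$ above), while loga holds the log of the regularization constant (the hyperparameter, $w$ above).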

function [R dRdw dRdloga] = training_risk(w,loga)
    a         = exp(loga);
    R         = sum( (Xtrain*w - Ytrain).^2 )/ndata + a*sum(w.^2);
    dRdw      = 2*Xtrain'*(Xtrain*w-Ytrain)  /ndata + 2*a*w;
    dRda      = sum(w.^2);
    dRdloga   = dRda*a;
end

function [R g] = validation_risk(w)
    R = sum( (Xval*w - Yval).^2 ) / ndata_val;
    g = 2*Xval'*(Xval*w-Yval)     / ndata_val;
end
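
Since everything downstream relies on these derivatives being right, it can be worth checking them numerically.  Below is a minimal sketch of such a check on small synthetic data; the data and the names `risk`, `grad`, and `fd` are made up for illustration, and this is not part of the released code.

```matlab
% Synthetic, deterministic data (assumed for illustration only)
X = reshape(sin(1:120), 30, 4);
y = cos((1:30)');
n = size(X,1);
loga = -1;
a    = exp(loga);

% Same form as the training risk: mean squared error plus a*||w||^2
risk = @(w) sum((X*w - y).^2)/n + a*sum(w.^2);
grad = @(w) 2*X'*(X*w - y)/n + 2*a*w;

% Central finite differences, one coordinate at a time
w  = [0.5; -1; 0.25; 2];
h  = 1e-6;
fd = zeros(size(w));
for i = 1:numel(w)
    e     = zeros(size(w));
    e(i)  = 1;
    fd(i) = (risk(w + h*e) - risk(w - h*e)) / (2*h);
end

disp(max(abs(fd - grad(w))));   % should be close to zero
```

The same loop can be pointed at any risk function and its claimed gradient, including the derivative with respect to loga.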

Now, before going any further, let’s do a traditional sweep through regularization constants to see what that looks like.

% values of loga (the log of the regularization constant) to sweep over
LAMBDA = -5:.25:2;
VAL_RISK = 0*LAMBDA;
for i=1:length(LAMBDA)
    VAL_RISK(i) = back_lbfgs(@training_risk,@validation_risk,w0,LAMBDA(i),opt_iters);
end

Plotting, we get the following:

This is a reasonable looking curve. Instead, let’s ask the algorithm to find the constant by gradient descent.

eval = @(loga) back_lbfgs(@training_risk,@validation_risk,w0,loga,opt_iters);
loga = 0;
[loga fval] = minFunc(eval,loga,options);

Running the optimization, we see:

Iteration   FunEvals     Step Length    Function Val        Opt Cond
1          2     1.00000e+00     8.74176e-01     3.71997e-03
2          3     1.00000e+00     8.73910e-01     1.86453e-04
3          4     1.00000e+00     8.73909e-01     1.06619e-05
4          5     1.00000e+00     8.73909e-01     3.60499e-08

This leads to a regularizer of the form:

$0.860543 ||w||_2^2$.

We can plot this on the graph, and see that it matches the best value found by the sweep over the validation set.

If we actually compute the test-set error, this is 0.708217.

OK, let’s be a little bit more adventurous, and use a third-order regularizer. This is done like so:

function [R dRdw dRdloga] = training_risk2(w,loga)
    a         = exp(loga(1));
    b         = exp(loga(2));
    R         = sum( (Xtrain*w - Ytrain).^2 ) / ndata + a*sum(abs(w).^2)       + b*sum(abs(w).^3);
    % d/dw |w|^p = p*|w|^(p-1)*sign(w)
    dRdw      = 2*Xtrain'*(Xtrain*w-Ytrain)   / ndata + 2*a*abs(w).^1.*sign(w) + 3*b*abs(w).^2.*sign(w);
    dRda      = sum(abs(w).^2);
    dRdb      = sum(abs(w).^3);
    dRdloga   = [dRda*a; dRdb*b];
end

Running the optimization, we see:

Iteration   FunEvals     Step Length    Function Val        Opt Cond
1          2     1.00000e+00     8.74445e-01     1.17262e-02
2          3     1.00000e+00     8.73685e-01     3.21956e-03
3          4     1.00000e+00     8.73608e-01     1.41744e-03
4          5     1.00000e+00     8.73598e-01     8.20040e-04
5          6     1.00000e+00     8.73567e-01     1.39830e-03
6          7     1.00000e+00     8.73513e-01     2.52994e-03
7          8     1.00000e+00     8.73471e-01     1.77157e-03
...
23         28     1.00000e+00     8.70741e-01     8.42628e-06

With a final regularizer of the form

$0.000160 ||w||_2^2 + 5.117057 ||w||_3^3$

and a hardly-improved test error of 0.679155.

Finally, let’s fit a polynomial regularizer with four terms, of orders two through five.

function [R dRdw dRdloga] = training_risk3(w,loga)
    a         = exp(loga);
    R         = sum( (Xtrain*w - Ytrain).^2 ) / ndata;
    dRdw      = 2*Xtrain'*(Xtrain*w-Ytrain)   / ndata;
    dRda      = zeros(length(a),1);
    for ii=1:length(a)
        b = ii+1;    % exponent of this term: 2, 3, 4, ...
        R    = R    + a(ii)*sum(abs(w).^b);
        dRdw = dRdw + a(ii)*b*abs(w).^(b-1).*sign(w);    % d/dw |w|^b = b*|w|^(b-1)*sign(w)
        dRda(ii,1) = sum(abs(w).^b);
    end
    dRdloga = dRda.*a;
end

eval = @(loga) back_lbfgs(@training_risk3,@validation_risk,w0,loga,opt_iters);
loga = [0; 0; 0; 0];
loga = minFunc(eval,loga,options);

Running the optimization, we see

Iteration   FunEvals     Step Length    Function Val        Opt Cond
1          2     1.00000e+00     8.73982e-01     1.02060e-02
2          3     1.00000e+00     8.73524e-01     3.30445e-03
3          4     1.00000e+00     8.73447e-01     1.74779e-03
4          5     1.00000e+00     8.73435e-01     1.43655e-03
5          6     1.00000e+00     8.73345e-01     4.98340e-03
6          7     1.00000e+00     8.73295e-01     1.85535e-03
7          8     1.00000e+00     8.73231e-01     1.81136e-03
...
38         41     1.00000e+00     8.69758e-01     3.44848e-06

Yielding the (strange) regularizer

$0.000000 ||w||_2^2 + 0.000010 ||w||_3^3 + 24.310631 ||w||_4^4 + 0.325565 ||w||_5^5$

and a final test-error of 0.67187.