In 2012, I wrote a paper that I probably should have called “truncated bi-level optimization”. I vaguely remembered telling the reviewers I would release some code, so I’m finally getting around to it.

The idea of bilevel optimization is quite simple. Imagine that you would like to minimize some function $L(w)$. However, $L$ itself is defined through some optimization. More formally, suppose we would like to solve

$$\min_{w,y} Q(y) \quad \text{s.t.} \quad y = \arg\min_{y'} E(y',w).$$

Or, equivalently,

$$\min_w L(w) = Q(y^*(w)),$$

where $y^*$ is defined as $y^*(w) := \arg\min_y E(y,w)$. This seems a little bit obscure at first, but it actually comes up in several different ways in machine learning and related fields.
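To make this concrete, here is a tiny scalar instance. The particular functions are illustrative assumptions of mine, chosen only so that everything can be solved by hand. Write the inner objective as $E(y,w)$, the outer loss as $Q(y)$, and take

$$E(y,w) = \tfrac{1}{2}(y-w)^2 + \tfrac{1}{2}y^2, \qquad Q(y) = (y-1)^2.$$

Setting $\partial E / \partial y = (y - w) + y = 0$ gives the inner solution $y^*(w) = w/2$, so the outer objective is

$$L(w) = Q(y^*(w)) = \left(\tfrac{w}{2} - 1\right)^2,$$

which is minimized at $w = 2$, where $y^*(w) = 1$. In realistic problems $y^*(w)$ has no closed form, which is what makes the gradient of $L$ awkward to compute.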

## Hyper-parameter learning

The first example would be in learning hyperparameters, such as regularization constants. Inevitably in machine learning, one fits parameters to optimize some tradeoff between the quality of a fit to training data and a regularization function being small. Traditionally, the regularization constant is selected by optimizing on a training dataset with a variety of values, and then picking the one that performs best on a held-out dataset. However, if there are a large number of regularization parameters, a high-dimensional grid-search will not be practical. In the notation above, suppose that $w$ is a vector of regularization constants, and that $y$ are training parameters. Let $E(y,w)$ be the regularized empirical risk on a training dataset, and let $Q(y)$ be how well the parameters $y$ perform on some held-out validation dataset.
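As a minimal sketch of this setup, consider ridge regression with a single log-regularization constant, where the inner problem happens to have a closed-form solution. The synthetic data and all variable names below are my own assumptions for illustration. Following the notation of this post, `w` is the hyperparameter and `y` the vector of training parameters.

```matlab
rng(0);                                   % fixed seed for reproducibility
Xtrain = randn(30,3);  Ytrain = Xtrain*[1; -2; 0] + 0.5*randn(30,1);
Xval   = randn(50,3);  Yval   = Xval  *[1; -2; 0] + 0.5*randn(50,1);
n = size(Xtrain,1);  d = size(Xtrain,2);

% Inner problem: y*(w) = argmin_y ||Xtrain*y - Ytrain||^2/n + exp(w)*||y||^2,
% which for ridge regression can be solved in closed form.
ystar = @(w) (Xtrain'*Xtrain/n + exp(w)*eye(d)) \ (Xtrain'*Ytrain/n);

% Outer loss: validation risk Q evaluated at the inner minimizer.
Q = @(y) sum((Xval*y - Yval).^2) / size(Xval,1);
L = @(w) Q(ystar(w));

% Evaluate the outer objective for a few log-regularization constants.
for w = [-4 -2 0 2]
    fprintf('log-reg = %+d   L(w) = %.4f\n', w, L(w));
end
```

With one hyperparameter this is just a grid search. The point of the bilevel view is that $L(w)$ can instead be minimized by gradient-based methods when $w$ has many dimensions.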

## Energy-based models

Another example (and the one suggesting the notation) is an energy-based model. Suppose that we have some “energy” function $E_x(y,w)$ which measures how well an output $y$ fits to an input $x$. The energy is parametrized by $w$. For a given training input/output pair $(\hat{x},\hat{y})$, we might have that $Q_{\hat{y}}(y^*(w))$ measures how the predicted output $y^*$ compares to the true output $\hat{y}$, where $y^*(w) = \arg\min_y E_{\hat{x}}(y,w)$.

## Computing the gradient exactly

Even if we just have the modest goal of following the gradient of $L(w)$ to a local minimum, computing that gradient is not so simple. Clearly, even evaluating $L(w)$ requires solving the “inner” minimization of $E$. It turns out that one can compute $\nabla_w L(w)$ by first solving the inner minimization, and then solving a linear system.

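- Input $w$.
- Solve $y^* \leftarrow \arg\min_y E(y,w)$.
- Compute:
- (a) the loss $L(w) = Q(y^*)$
- (b) the gradient $g = \nabla_y Q(y^*)$
- (c) the Hessian $H = \dfrac{\partial^2 E(y^*,w)}{\partial y \, \partial y^T}$
- Solve the linear system $H z = g$.
- Compute the parameter gradient $\nabla_w L(w) = -\dfrac{\partial^2 E(y^*,w)}{\partial w \, \partial y^T} z$
- Return $L(w)$ and $\nabla_w L(w)$.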
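The steps above can be sketched for a ridge-regression inner problem, $E(y,w) = \|Xy - Y\|^2/n + e^{w}\|y\|^2$ with a scalar $w$, and a validation loss $Q$. This is a toy setup I am assuming for illustration, with synthetic data, chosen because the inner solve and both second derivatives are available in closed form. The result is checked against a finite difference of $L(w)$.

```matlab
rng(0);
X  = randn(30,3);  Y  = X *[1; -2; 0] + 0.5*randn(30,1);   % training data
Xv = randn(50,3);  Yv = Xv*[1; -2; 0] + 0.5*randn(50,1);   % validation data
n = size(X,1);  nv = size(Xv,1);  d = size(X,2);
w = -1;                                    % log regularization constant

% Solve the inner problem (closed form for ridge regression).
ystar = (X'*X/n + exp(w)*eye(d)) \ (X'*Y/n);

% Loss, gradient of Q, and Hessian of E, all evaluated at y*.
Lw = sum((Xv*ystar - Yv).^2) / nv;
g  = 2*Xv'*(Xv*ystar - Yv) / nv;
H  = 2*(X'*X)/n + 2*exp(w)*eye(d);

% Solve the linear system H z = g.
z = H \ g;

% Parameter gradient; for this E, d/dw of grad_y E is 2*exp(w)*y.
dLdw = -(2*exp(w)*ystar)' * z;

% Sanity check against a central finite difference of L(w).
Lfun = @(w) sum((Xv*((X'*X/n + exp(w)*eye(d)) \ (X'*Y/n)) - Yv).^2) / nv;
h = 1e-5;
dLdw_fd = (Lfun(w+h) - Lfun(w-h)) / (2*h);
fprintf('implicit: %.8f   finite diff: %.8f\n', dLdw, dLdw_fd);
```

The formula follows from differentiating the stationarity condition $\nabla_y E(y^*(w),w) = 0$ with respect to $w$, which gives $\partial y^*/\partial w^T = -H^{-1}\,\partial^2 E/\partial y\,\partial w^T$.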

This looks a bit nasty, since we need to compute second-derivative matrices of $E$. In fact, as long as one has a routine to compute $\nabla_y E$ and $\nabla_w E$, this can be avoided through efficient matrix-vector products. This is essentially what is proposed by Do, Foo, and Ng in “Efficient multiple hyperparameter learning for log-linear models”.
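One common way to get such matrix-vector products is a central finite difference of the gradient routines, which costs two gradient evaluations per product. I am assuming that variant here; it is not necessarily the exact scheme used in the paper or the released code. The sketch below uses a toy ridge energy $E(y,w) = \|Xy - Y\|^2/n + e^{w}\|y\|^2$ (my own illustrative choice), so the products can be compared with the explicit second derivatives.

```matlab
rng(0);
X = randn(30,3);  Y = X*[1; -2; 0] + 0.5*randn(30,1);
n = size(X,1);  d = size(X,2);

% Gradient routines for E(y,w) = ||X*y - Y||^2/n + exp(w)*||y||^2 (scalar w).
dEdy = @(y,w) 2*X'*(X*y - Y)/n + 2*exp(w)*y;
dEdw = @(y,w) exp(w)*sum(y.^2);

y = randn(d,1);  w = -1;  v = randn(d,1);  r = 1e-5;

% Hessian-vector product (d^2E/dy dy') * v from two evaluations of dE/dy.
Hv = (dEdy(y + r*v, w) - dEdy(y - r*v, w)) / (2*r);
% Mixed product (d^2E/dw dy') * v from two evaluations of dE/dw.
Mv = (dEdw(y + r*v, w) - dEdw(y - r*v, w)) / (2*r);

% Compare with the explicit second derivatives, available for this toy E.
Hv_exact = (2*(X'*X)/n + 2*exp(w)*eye(d)) * v;
Mv_exact = (2*exp(w)*y)' * v;
fprintf('Hv error: %.2e   Mv error: %.2e\n', ...
        norm(Hv - Hv_exact), abs(Mv - Mv_exact));
```

With products like `Hv` available, the linear system $Hz = g$ can be solved by an iterative method such as conjugate gradients without ever forming $H$.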

Overall, this is a decent approach, but it can be quite slow, simply because one must solve an “inner” optimization in order to compute each gradient of the “outer” optimization. Often, the inner optimization needs to be solved to very high accuracy in order to estimate a gradient accurately enough to reduce $L$. This is higher accuracy than is needed when one is simply using the predicted value itself.

## Truncated optimization

To get around this expense, a fairly obvious idea is to re-define the problem. The slowness of exactly computing the gradient stems from needing to exactly solve the inner optimization. Hence, perhaps we can re-define the problem so that an inexact solve of the inner problem nevertheless yields an “exact” gradient?

Re-define the problem as solving

$$\min_w L(w) = Q(y^*(w)), \qquad y^*(w) := \underset{y}{\text{opt-alg}}\; E(y,w),$$

where $\text{opt-alg}$ denotes an approximate solve of the inner optimization. In order for this to work, $\text{opt-alg}$ must be defined in such a way that $y^*(w)$ is a continuous function of $w$. With standard optimization methods such as gradient descent or BFGS, this can be achieved by assuming there are a fixed number of iterations applied, with a fixed step-size. Since each iteration of these algorithms is continuous, this clearly defines $y^*(w)$ as a continuous function. Thus, in principle, it could be optimized efficiently through automatic differentiation of the code that optimizes $E$. That’s fine in principle, but often inconvenient in practice.

It turns out, however, that one can derive “backpropagating” versions of algorithms like gradient descent that take as input only a procedure to compute $E$ along with its first derivatives. These algorithms can then produce the gradient of $L$ in the same time as automatic differentiation.

## Back Gradient-Descent

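If the inner optimization is gradient descent for $N$ steps with a step-size of $\lambda$, the algorithm to compute the loss is simple:

- Input $w$ and a starting point $y_0$
- For $k = 0, 1, \ldots, N-1$
- (a) $y_{k+1} \leftarrow y_k - \lambda \nabla_y E(y_k, w)$
- Return $L(w) = Q(y_N)$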
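As a rough illustration of this forward pass, the sketch below runs a fixed number of gradient-descent steps on a toy ridge energy $E(y,w) = \|Xy - Y\|^2/n + e^{w}\|y\|^2$ and evaluates a validation loss $Q$ at the final iterate. The data, step size, and iteration counts are all assumptions of mine. It also prints the loss at the exact inner minimizer, to show how the truncated objective differs from the exact one.

```matlab
rng(0);
X  = randn(30,3);  Y  = X *[1; -2; 0] + 0.5*randn(30,1);
Xv = randn(50,3);  Yv = Xv*[1; -2; 0] + 0.5*randn(50,1);
n = size(X,1);  nv = size(Xv,1);  d = size(X,2);

dEdy = @(y,w) 2*X'*(X*y - Y)/n + 2*exp(w)*y;   % gradient of the inner objective
Q    = @(y) sum((Xv*y - Yv).^2) / nv;          % outer (validation) loss

w = -1;  lambda = 0.1;                         % hyperparameter and fixed step size
for N = [5 20 100]
    y = zeros(d,1);                            % starting point y_0
    for k = 0:N-1
        y = y - lambda*dEdy(y, w);             % y_{k+1} <- y_k - lambda*grad_y E(y_k,w)
    end
    fprintf('N = %3d   truncated L(w) = %.6f\n', N, Q(y));
end

% For reference: the loss at the exact inner minimizer.
yexact = (X'*X/n + exp(w)*eye(d)) \ (X'*Y/n);
fprintf('exact inner solve: L(w) = %.6f\n', Q(yexact));
```

For small $N$ the truncated loss is a different function of $w$ than the exact one. That is intended: the truncated loss is the quantity that actually gets optimized.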

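How to compute the gradient of this quantity? The following algorithm does the trick.

- Initialize $\overleftarrow{y_N} \leftarrow \nabla Q(y_N)$ and $\overleftarrow{w} \leftarrow 0$
- For $k = N-1, \ldots, 0$
- (a) $\overleftarrow{w} \leftarrow \overleftarrow{w} - \lambda \dfrac{\partial^2 E(y_k,w)}{\partial w \, \partial y^T} \overleftarrow{y_{k+1}}$
- (b) $\overleftarrow{y_k} \leftarrow \overleftarrow{y_{k+1}} - \lambda \dfrac{\partial^2 E(y_k,w)}{\partial y \, \partial y^T} \overleftarrow{y_{k+1}}$
- Return $\nabla_w L(w) = \overleftarrow{w}$.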
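Here is a minimal sketch of this backward pass, again on a toy ridge energy $E(y,w) = \|Xy - Y\|^2/n + e^{w}\|y\|^2$ with scalar $w$ and synthetic data (all assumptions of mine, not the released code). For clarity it uses the explicit second derivatives of this particular $E$; in general they would be replaced by matrix-vector products. `ybar` and `wbar` are my names for the backpropagated quantities. The result is compared with a finite difference of the truncated loss.

```matlab
rng(0);
X  = randn(30,3);  Y  = X *[1; -2; 0] + 0.5*randn(30,1);
Xv = randn(50,3);  Yv = Xv*[1; -2; 0] + 0.5*randn(50,1);
n = size(X,1);  nv = size(Xv,1);  d = size(X,2);
w = -1;  N = 20;  lambda = 0.1;

% Forward pass: N gradient-descent steps, storing every iterate.
Ys = zeros(d, N+1);                            % column k+1 holds y_k
for k = 1:N
    Ys(:,k+1) = Ys(:,k) - lambda*(2*X'*(X*Ys(:,k) - Y)/n + 2*exp(w)*Ys(:,k));
end

% Backward pass: propagate dL/dy_N back through the iterations.
ybar = 2*Xv'*(Xv*Ys(:,N+1) - Yv) / nv;         % grad Q(y_N)
wbar = 0;                                      % accumulates dL/dw
H = 2*(X'*X)/n + 2*exp(w)*eye(d);              % d^2E/dy dy' (constant in y here)
for k = N:-1:1
    yk   = Ys(:,k);                            % iterate the k-th step started from
    wbar = wbar - lambda*(2*exp(w)*yk)'*ybar;  % mixed second-derivative term
    ybar = ybar - lambda*H*ybar;               % Hessian term
end

% Check against a central finite difference of the truncated loss.
h = 1e-5;  ws = [w+h, w-h];  Lpm = zeros(1,2);
for j = 1:2
    y = zeros(d,1);
    for k = 1:N
        y = y - lambda*(2*X'*(X*y - Y)/n + 2*exp(ws(j))*y);
    end
    Lpm(j) = sum((Xv*y - Yv).^2) / nv;
end
fprintf('back-GD: %.8f   finite diff: %.8f\n', wbar, (Lpm(1) - Lpm(2))/(2*h));
```

Note that `wbar` is updated before `ybar` in each step, since both updates use the backpropagated value from the later iterate.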

Similar algorithms can be derived for the heavy-ball algorithm (with a little more complexity) and limited memory BFGS (with a lot more complexity).

## Code

So, finally, here is the code, and I’ll give a simple example of how to use it. There are just four simple files:

I think the meanings of these are pretty straightforward, so I’ll just quickly step through the demo here. I’ll start off by taking one of Matlab’s built-in datasets (on cities), so that we are trying to predict a measure of crime from measures of climate, housing, health, transportation, education, arts, recreation, and economy, as well as a constant. There are 329 data points in total, which I split into a training set of size 50, a validation set of size 150, and a test set of size 129.

```matlab
load cities ratings
X = ratings;
% standardize each column to zero mean and unit variance
for i=1:size(X,2)
    X(:,i) = X(:,i) - mean(X(:,i));
    X(:,i) = X(:,i) / std( X(:,i));
end
% predict crime from climate, housing, health, trans, edu, arts, rec, econ,
% plus a constant feature
Y = X(:,4);
X = [X(:,[1:3 5:9]) 1+0*X(:,1)];
% shuffle the data, then split into training / validation / test sets
p = randperm(length(Y));
X = X(p,:);
Y = Y(p);
whotrain = 1:50;
whoval   = 51:200;
whotest  = 201:329;
Xtrain = X(whotrain,:);
Xval   = X(whoval  ,:);
Xtest  = X(whotest ,:);
Ytrain = Y(whotrain);
Yval   = Y(whoval  );
Ytest  = Y(whotest );
```

Next, I’ll set up some simple constants that will be used later on, and define the optimization parameters for minFunc, that I will be using for the outer optimization. In particular, here I choose the inner optimization to use 20 iterations.

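```matlab
opt_iters = 20;               % number of inner-optimization iterations
ndims     = size(Xtrain,2);
w0        = zeros(ndims,1);   % starting point for the inner optimization
ndata     = size(Xtrain,1);
ndata_val = size(Xval,1);
% options for minFunc (outer optimization)
options = [];
options.Method  = 'gd';
options.LS      = 1;
options.MaxIter = 100;
```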

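Now, I’ll define the training risk function ($E$ in the notation above). This computes the risk with a regularization constant of $a = \exp(\texttt{loga})$, as well as its derivatives. I’ll also define the validation risk ($Q$ in the notation above). Note that in the code, `w` holds the training parameters ($y$ in the notation above), while `loga` plays the role of the hyperparameters ($w$ in the notation above).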

```matlab
function [R, dRdw, dRdloga] = training_risk(w,loga)
    a       = exp(loga);
    % squared error on the training set plus a quadratic regularizer
    R       = sum( (Xtrain*w - Ytrain).^2 )/ndata + a*sum(w.^2);
    dRdw    = 2*Xtrain'*(Xtrain*w-Ytrain) /ndata + 2*a*w;
    dRda    = sum(w.^2);
    dRdloga = dRda*a;   % chain rule: da/dloga = a
end

function [R, g] = validation_risk(w)
    R = sum( (Xval*w - Yval).^2 ) / ndata_val;
    g = 2*Xval'*(Xval*w-Yval) / ndata_val;
end
```

Now, before going any further, let’s do a traditional sweep through regularization constants to see what that looks like.

```matlab
LAMBDA   = -5:.25:2;   % candidate values of the log regularization constant
VAL_RISK = 0*LAMBDA;
for i=1:length(LAMBDA)
    VAL_RISK(i) = back_lbfgs(@training_risk,@validation_risk,w0,LAMBDA(i),opt_iters);
end
```

Plotting, we get the following:

This is a reasonable looking curve. Instead, let’s ask the algorithm to find the constant by gradient descent.

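```matlab
% outer objective: validation risk after opt_iters inner iterations
% (named outer_loss to avoid shadowing Matlab's built-in eval)
outer_loss = @(loga) back_lbfgs(@training_risk,@validation_risk,w0,loga,opt_iters);
loga = 0;
[loga, fval] = minFunc(outer_loss,loga,options);
```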

Running the optimization, we see:

```
 Iteration   FunEvals     Step Length    Function Val        Opt Cond
         1          2     1.00000e+00     8.74176e-01     3.71997e-03
         2          3     1.00000e+00     8.73910e-01     1.86453e-04
         3          4     1.00000e+00     8.73909e-01     1.06619e-05
         4          5     1.00000e+00     8.73909e-01     3.60499e-08
```

This leads to a regularizer of the form:

.

We can plot this on the graph, and see that it matches the minimum found by the sweep over regularization constants.

If we actually compute the test-set error, this is 0.708217.

OK, let’s be a little bit more adventurous, and use a third-order regularizer. This is done like so:

```matlab
function [R, dRdw, dRdloga] = training_risk2(w,loga)
    a = exp(loga(1));
    b = exp(loga(2));
    R = sum( (Xtrain*w - Ytrain).^2 ) / ndata + a*sum(abs(w).^2) + b*sum(abs(w).^3);
    % the derivative of sum(abs(w).^p) is p*abs(w).^(p-1).*sign(w)
    dRdw    = 2*Xtrain'*(Xtrain*w-Ytrain) / ndata + ...
              2*a*abs(w).^1.*sign(w) + 3*b*abs(w).^2.*sign(w);
    dRda    = sum(abs(w).^2);
    dRdb    = sum(abs(w).^3);
    dRdloga = [dRda*a; dRdb*b];
end
```

Running the optimization, we see:

```
 Iteration   FunEvals     Step Length    Function Val        Opt Cond
         1          2     1.00000e+00     8.74445e-01     1.17262e-02
         2          3     1.00000e+00     8.73685e-01     3.21956e-03
         3          4     1.00000e+00     8.73608e-01     1.41744e-03
         4          5     1.00000e+00     8.73598e-01     8.20040e-04
         5          6     1.00000e+00     8.73567e-01     1.39830e-03
         6          7     1.00000e+00     8.73513e-01     2.52994e-03
         7          8     1.00000e+00     8.73471e-01     1.77157e-03
       ...
        23         28     1.00000e+00     8.70741e-01     8.42628e-06
```

With a final regularizer of the form

and a hardly-improved test error of 0.679155.

Finally, let’s fit a polynomial regularizer with four terms (powers 2 through 5).

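```matlab
function [R, dRdw, dRdloga] = training_risk3(w,loga)
    a    = exp(loga);
    R    = sum( (Xtrain*w - Ytrain).^2 ) / ndata;
    dRdw = 2*Xtrain'*(Xtrain*w-Ytrain) / ndata;
    dRda = zeros(length(a),1);
    for ii=1:length(a)
        b = ii+1;   % the ii-th term penalizes sum(abs(w).^(ii+1))
        R = R + a(ii)*sum(abs(w).^b);
        % the derivative of sum(abs(w).^b) is b*abs(w).^(b-1).*sign(w)
        dRdw = dRdw + a(ii)*b*abs(w).^(b-1).*sign(w);
        dRda(ii,1) = sum(abs(w).^b);
    end
    dRdloga = dRda.*a;
end
```

```matlab
% named outer_loss to avoid shadowing Matlab's built-in eval
outer_loss = @(loga) back_lbfgs(@training_risk3,@validation_risk,w0,loga,opt_iters);
loga = [0; 0; 0; 0];
loga = minFunc(outer_loss,loga,options);
```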

Running the optimization, we see

```
 Iteration   FunEvals     Step Length    Function Val        Opt Cond
         1          2     1.00000e+00     8.73982e-01     1.02060e-02
         2          3     1.00000e+00     8.73524e-01     3.30445e-03
         3          4     1.00000e+00     8.73447e-01     1.74779e-03
         4          5     1.00000e+00     8.73435e-01     1.43655e-03
         5          6     1.00000e+00     8.73345e-01     4.98340e-03
         6          7     1.00000e+00     8.73295e-01     1.85535e-03
         7          8     1.00000e+00     8.73231e-01     1.81136e-03
       ...
        38         41     1.00000e+00     8.69758e-01     3.44848e-06
```

Yielding the (strange) regularizer

and a final test-error of 0.67187.