In 2012, I wrote a paper that I probably should have called “truncated bi-level optimization”. I vaguely remembered telling the reviewers I would release some code, so I’m finally getting around to it.
The idea of bilevel optimization is quite simple. Imagine that you would like to minimize some function $L(w)$. However, $L$ itself is defined through some optimization. More formally, suppose we would like to solve

$$\min_w \; L(w) = Q(y^*(w), w),$$

where $y^*(w)$ is defined as $y^*(w) = \arg\min_y E(y, w)$. This seems a little bit obscure at first, but it actually comes up in several different ways in machine learning and related fields.
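To make the definitions concrete, here is a minimal Python sketch of a toy bilevel problem. The quadratic $E$, the data `H_DIAG`, `T`, `V`, and the function names are all hypothetical, chosen (as an assumption, not from the post) so that the inner minimization has a closed form; in general $y^*(w)$ has to be found by a numerical solver.

```python
import math

# Hypothetical toy problem (not from the post), chosen so y*(w) has a closed form:
#   inner: E(y, w) = 1/2 * sum_i h_i * (y_i - t_i)^2 + 1/2 * exp(w) * sum_i y_i^2
#   outer: Q(y, w) = 1/2 * sum_i (y_i - v_i)^2
H_DIAG = [1.0, 2.0, 0.5]  # curvature h_i of the data term
T = [1.0, -1.0, 2.0]      # inner ("training") targets t_i
V = [0.8, -0.5, 1.0]      # outer ("validation") targets v_i


def grad_y_E(y, w):
    """Gradient of the inner objective E with respect to y."""
    a = math.exp(w)
    return [h * (yi - t) + a * yi for yi, h, t in zip(y, H_DIAG, T)]


def inner_argmin(w):
    """y*(w) = argmin_y E(y, w); E is a separable quadratic, so solve each coordinate exactly."""
    a = math.exp(w)
    return [h * t / (h + a) for h, t in zip(H_DIAG, T)]


def Q(y, w):
    """Outer objective: how well the inner solution fits the validation targets."""
    return 0.5 * sum((yi - vi) ** 2 for yi, vi in zip(y, V))


def outer_loss(w):
    """L(w) = Q(y*(w), w): evaluating L requires solving the inner problem first."""
    return Q(inner_argmin(w), w)


if __name__ == "__main__":
    for w in (-2.0, 0.0, 2.0):
        print(w, outer_loss(w))
```

Even in this tiny case, every evaluation of $L$ hides a full inner minimization; here it just happens to be free.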
The first example would be in learning hyperparameters, such as regularization constants. Inevitably in machine learning, one fits parameters to optimize some tradeoff between the quality of a fit to training data and a regularization function being small. Traditionally, the regularization constant is selected by optimizing on a training dataset with a variety of values, and then picking the one that performs best on a held-out dataset. However, if there are a large number of regularization parameters, a high-dimensional grid search will not be practical, since the number of grid points grows exponentially with the number of parameters. In the notation above, suppose that $w$ is a vector of regularization constants, and that $y$ are the training parameters. Let $E(y, w)$ be the regularized empirical risk on a training dataset, and let $Q(y, w)$ measure how well the parameters $y$ perform on some held-out validation dataset.
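A small Python sketch shows why grid search breaks down as the number of regularization constants grows. It assumes a hypothetical toy problem with one log regularization constant per coordinate (so the inner problem has a closed form); with $m$ grid values per constant and $k$ constants, exhaustive search needs $m^k$ inner solves.

```python
import itertools
import math

# Hypothetical toy (not from the post): one regularization constant exp(w_j) per
# coordinate, so the inner solution y*_j = h_j t_j / (h_j + exp(w_j)) has a closed form.
H_DIAG = [1.0, 2.0, 0.5]
T = [1.0, -1.0, 2.0]
V = [0.8, -0.5, 1.0]


def validation_loss(log_reg):
    """Q(y*(w), w) for a vector w = log_reg of per-coordinate log regularizers."""
    y = [h * t / (h + math.exp(lw)) for h, t, lw in zip(H_DIAG, T, log_reg)]
    return 0.5 * sum((yi - vi) ** 2 for yi, vi in zip(y, V))


def grid_search(loss, grid_1d, n_hyper):
    """Try every point of grid_1d^n_hyper; return (best_loss, best_point, n_inner_solves)."""
    best_loss, best_point, n_solves = math.inf, None, 0
    for point in itertools.product(grid_1d, repeat=n_hyper):
        n_solves += 1  # each grid point costs one full inner optimization
        value = loss(point)
        if value < best_loss:
            best_loss, best_point = value, point
    return best_loss, best_point, n_solves


if __name__ == "__main__":
    grid = [-3.0 + 0.5 * i for i in range(13)]  # 13 values per hyperparameter
    best_loss, best_point, n_solves = grid_search(validation_loss, grid, len(H_DIAG))
    print(best_point, best_loss, n_solves)  # n_solves = 13**3 = 2197
```

With three constants this is still cheap, but the count is already in the thousands and multiplies by 13 for every extra constant.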
Another example (and the one suggesting the notation) is an energy-based model. Suppose that we have some “energy” function $E(y, x; w)$ which measures how well an output $y$ fits an input $x$. The energy is parametrized by $w$. For a given training input/output pair $(\hat{x}, \hat{y})$, we might have that $Q(y^*(w), w)$ measures how the predicted output $y^*(w)$ compares to the true output $\hat{y}$, where $y^*(w) = \arg\min_y E(y, \hat{x}; w)$.
Computing the gradient exactly
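Even if we just have the modest goal of following the gradient of $L$ to a local minimum, computing that gradient is not so simple. Clearly, even evaluating $L(w)$ requires solving the “inner” minimization of $E$. It turns out that one can compute $\frac{dL}{dw}$ by first solving the inner minimization and then solving a linear system. (The reason is that $y^*(w)$ satisfies $\nabla_y E(y^*(w), w) = 0$; differentiating this condition with respect to $w$ gives $\frac{dy^*}{dw^T} = -H^{-1}\frac{\partial^2 E}{\partial y\,\partial w^T}$, where $H = \frac{\partial^2 E}{\partial y\,\partial y^T}$ is the Hessian of $E$ with respect to $y$, and the chain rule does the rest.)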
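- Solve $y^* \leftarrow \arg\min_y E(y, w)$.
- Compute:
  - (a) the loss $L = Q(y^*, w)$
  - (b) the gradient $g = \nabla_y Q(y^*, w)$
  - (c) the Hessian $H = \frac{\partial^2 E}{\partial y\,\partial y^T}(y^*, w)$
- Solve the linear system $z = H^{-1} g$.
- Compute the parameter gradient $\frac{dL}{dw} = \frac{\partial Q}{\partial w}(y^*, w) - \frac{\partial^2 E}{\partial w\,\partial y^T}(y^*, w)\, z$.
- Return $L$ and $\frac{dL}{dw}$.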
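The steps above can be sketched in Python on a hypothetical toy problem (a separable quadratic $E$ with a scalar hyperparameter $w$, the log of a regularization constant; not from the post). As an assumption, the linear system $Hz = g$ is solved with conjugate gradient, which only touches $H$ through Hessian-vector products; since $Q$ does not depend on $w$ here, $\frac{\partial Q}{\partial w} = 0$.

```python
import math

# Hypothetical toy problem (not from the post):
#   E(y, w) = 1/2 sum_i h_i (y_i - t_i)^2 + 1/2 exp(w) sum_i y_i^2,   Q(y, w) = 1/2 ||y - v||^2
H_DIAG = [1.0, 2.0, 0.5]
T = [1.0, -1.0, 2.0]
V = [0.8, -0.5, 1.0]


def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))


def inner_argmin(w):
    a = math.exp(w)
    return [h * t / (h + a) for h, t in zip(H_DIAG, T)]


def Q(y, w):
    return 0.5 * sum((yi - vi) ** 2 for yi, vi in zip(y, V))


def grad_y_Q(y, w):
    return [yi - vi for yi, vi in zip(y, V)]


def hess_vec(y, w, v):
    """H v with H = d^2E/dy dy^T = diag(h_i + exp(w)) for this toy E."""
    a = math.exp(w)
    return [(h + a) * vi for h, vi in zip(H_DIAG, v)]


def cross_vec(y, w, z):
    """(d^2E/dw dy^T) z; since d/dw grad_y E = exp(w) * y, this is exp(w) * (y . z)."""
    return math.exp(w) * dot(y, z)


def conjugate_gradient(matvec, b, max_iters=100, tol=1e-20):
    """Solve A z = b for symmetric positive-definite A, touching A only via matvec."""
    z = [0.0] * len(b)
    r = list(b)  # residual b - A z for z = 0
    p = list(r)
    rs = dot(r, r)
    for _ in range(max_iters):
        if rs < tol:
            break
        Ap = matvec(p)
        alpha = rs / dot(p, Ap)
        z = [zi + alpha * pi for zi, pi in zip(z, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return z


def loss_and_grad(w):
    y = inner_argmin(w)                                      # y* = argmin_y E(y, w)
    L = Q(y, w)                                              # (a) the loss
    g = grad_y_Q(y, w)                                       # (b) gradient of Q w.r.t. y
    z = conjugate_gradient(lambda v: hess_vec(y, w, v), g)   # (c) H only via products; z = H^{-1} g
    dQdw = 0.0                                               # this toy Q does not depend on w
    return L, dQdw - cross_vec(y, w, z)                      # dL/dw
```

A finite difference of `Q(inner_argmin(w), w)` is a quick way to confirm that the implicit gradient is right.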
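This looks a bit nasty, since we need to compute second-derivative matrices of $E$. In fact, as long as one has routines to compute $\nabla_y E$ and $\nabla_w E$, this can be avoided through efficient matrix-vector products: the second derivatives are only ever needed multiplied by a vector (if the linear system is solved by an iterative method such as conjugate gradient), and such products can be approximated by finite differences of gradients. This is essentially the approach proposed by Do, Foo, and Ng in “Efficient multiple hyperparameter learning for log-linear models”.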
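Here is a minimal sketch, assuming a hypothetical toy energy, of how both second-derivative products can be formed from the two gradient routines alone, using central differences with step `r`. For this toy, which is quadratic in $y$, the central differences are exact up to rounding; for a general $E$ they carry $O(r^2)$ error.

```python
import math

# Hypothetical toy energy (not from the post):
#   E(y, w) = 1/2 sum_i h_i (y_i - t_i)^2 + 1/2 exp(w) sum_i y_i^2
H_DIAG = [1.0, 2.0, 0.5]
T = [1.0, -1.0, 2.0]


def grad_y_E(y, w):
    a = math.exp(w)
    return [h * (yi - t) + a * yi for yi, h, t in zip(y, H_DIAG, T)]


def grad_w_E(y, w):
    """dE/dw for a scalar hyperparameter w."""
    return 0.5 * math.exp(w) * sum(yi * yi for yi in y)


def fd_hess_vec(grad_y, y, w, v, r=1e-4):
    """H v ~= (grad_y E(y + r v, w) - grad_y E(y - r v, w)) / (2 r)."""
    plus = grad_y([yi + r * vi for yi, vi in zip(y, v)], w)
    minus = grad_y([yi - r * vi for yi, vi in zip(y, v)], w)
    return [(p - m) / (2 * r) for p, m in zip(plus, minus)]


def fd_cross_vec(grad_w, y, w, z, r=1e-4):
    """(d^2E/dw dy^T) z ~= (grad_w E(y + r z, w) - grad_w E(y - r z, w)) / (2 r)."""
    plus = grad_w([yi + r * zi for yi, zi in zip(y, z)], w)
    minus = grad_w([yi - r * zi for yi, zi in zip(y, z)], w)
    return (plus - minus) / (2 * r)
```

Each product costs two gradient evaluations, so the Hessian is never formed or stored.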
Overall, this is a decent approach, but it can be quite slow, simply because one must solve an “inner” optimization in order to compute each gradient of the “outer” optimization. Often, the inner optimization needs to be solved to very high accuracy in order to estimate a gradient accurately enough to reduce $L$, higher accuracy than is needed when one is simply using the predicted value $y^*$ itself.
To get around this expense, a fairly obvious idea is to re-define the problem. The slowness of exactly computing the gradient stems from needing to exactly solve the inner optimization. Hence, perhaps we re-define the problem such that an inexact solve of the inner problem nevertheless yields an “exact” gradient?
Re-define the problem as solving

$$\min_w \; L(w) = Q(\tilde{y}(w), w),$$

where $\tilde{y}(w)$ denotes an approximate solve of the inner optimization. In order for this to work, $\tilde{y}$ must be defined in such a way that $\tilde{y}(w)$ is a continuous (and, to take gradients, differentiable) function of $w$. With standard optimization methods such as gradient descent or BFGS, this can be achieved by assuming that a fixed number of iterations is applied, with a fixed step size. Since each iteration of these algorithms is continuous (and differentiable when $E$ is smooth), this clearly defines $\tilde{y}$ as a continuous function of $w$. Thus, in principle, $L$ could be optimized efficiently through automatic differentiation of the code that optimizes $E$. That’s fine in principle, but often inconvenient in practice.
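One way to see that fixing the number of iterations and the step size makes $\tilde{y}(w)$ differentiable is to differentiate the unrolled loop directly. The Python sketch below, on a hypothetical toy problem, carries the tangent $\frac{dy_k}{dw}$ alongside each gradient-descent iterate (forward-mode differentiation). With a scalar $w$ this is cheap, but forward mode needs one such tangent per hyperparameter, which is one reason a backward (reverse-mode) pass is preferable when there are many hyperparameters.

```python
import math

# Hypothetical toy problem (not from the post):
#   E(y, w) = 1/2 sum_i h_i (y_i - t_i)^2 + 1/2 exp(w) sum_i y_i^2,   Q(y) = 1/2 ||y - v||^2
H_DIAG = [1.0, 2.0, 0.5]
T = [1.0, -1.0, 2.0]
V = [0.8, -0.5, 1.0]


def truncated_loss_forward_mode(w, steps=10, lr=0.2):
    """Run `steps` fixed-step gradient-descent iterations on E from y_0 = 0, carrying
    the tangent dy_k/dw alongside y_k. Returns (L, dL/dw) for L(w) = Q(y_N)."""
    a = math.exp(w)
    y = [0.0] * len(T)
    dy = [0.0] * len(T)  # y_0 does not depend on w
    for _ in range(steps):
        grad = [h * (yi - t) + a * yi for yi, h, t in zip(y, H_DIAG, T)]
        # d/dw of grad_y E(y_k(w), w) = H dy_k/dw + d(grad_y E)/dw = (h_i + a) dy_i + a y_i
        dgrad = [(h + a) * dyi + a * yi for yi, dyi, h in zip(y, dy, H_DIAG)]
        y = [yi - lr * gi for yi, gi in zip(y, grad)]
        dy = [dyi - lr * dgi for dyi, dgi in zip(dy, dgrad)]
    L = 0.5 * sum((yi - vi) ** 2 for yi, vi in zip(y, V))
    dLdw = sum((yi - vi) * dyi for yi, vi, dyi in zip(y, V, dy))
    return L, dLdw
```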
It turns out, however, that one can derive “backpropagating” versions of algorithms like gradient descent that take as input only a procedure to compute $E$ along with its first derivatives. These algorithms can then produce the gradient of $L$ in the same time as automatic differentiation.
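If the inner optimization is gradient descent for $N$ steps with a step size of $\lambda$, the algorithm to compute the loss is simple:
- Start from an initial point $y_0$.
- For $k = 0, 1, \ldots, N-1$: $y_{k+1} \leftarrow y_k - \lambda \nabla_y E(y_k, w)$.
- Return $L = Q(y_N, w)$.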
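How to compute the gradient of this quantity? Writing $y_0, \ldots, y_N$ for the gradient-descent iterates and $\lambda$ for the step size, the following algorithm does the trick, where $\bar{y}_k$ denotes $\frac{dL}{dy_k}$ and $\bar{w}$ accumulates $\frac{dL}{dw}$.
- $\bar{y}_N \leftarrow \nabla_y Q(y_N, w)$
- $\bar{w} \leftarrow \nabla_w Q(y_N, w)$
- For $k = N-1, N-2, \ldots, 0$:
  - $\bar{w} \leftarrow \bar{w} - \lambda \frac{\partial^2 E}{\partial w\,\partial y^T}(y_k, w)\,\bar{y}_{k+1}$
  - $\bar{y}_k \leftarrow \bar{y}_{k+1} - \lambda \frac{\partial^2 E}{\partial y\,\partial y^T}(y_k, w)\,\bar{y}_{k+1}$
- Return $\frac{dL}{dw} = \bar{w}$.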
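Here is a Python sketch of backpropagating through fixed-step gradient descent for a hypothetical toy problem whose Hessian-vector and cross-derivative products have simple closed forms. It is an illustration under these assumptions, not the post's released code. It simply stores the whole trajectory $y_0, \ldots, y_N$, so memory grows linearly with $N$; both products at step $k$ use the stored $y_k$ and the adjoint `y_bar` (the derivative of $L$ with respect to $y_{k+1}$) before it is overwritten.

```python
import math

# Hypothetical toy problem (not from the post):
#   E(y, w) = 1/2 sum_i h_i (y_i - t_i)^2 + 1/2 exp(w) sum_i y_i^2,   Q(y, w) = 1/2 ||y - v||^2
H_DIAG = [1.0, 2.0, 0.5]
T = [1.0, -1.0, 2.0]
V = [0.8, -0.5, 1.0]


def grad_y_E(y, w):
    a = math.exp(w)
    return [h * (yi - t) + a * yi for yi, h, t in zip(y, H_DIAG, T)]


def hess_vec(y, w, v):
    """(d^2E/dy dy^T)(y, w) v = diag(h_i + exp(w)) v for this toy E."""
    a = math.exp(w)
    return [(h + a) * vi for h, vi in zip(H_DIAG, v)]


def cross_vec(y, w, v):
    """(d^2E/dw dy^T)(y, w) v = exp(w) * (y . v) for this toy E."""
    return math.exp(w) * sum(yi * vi for yi, vi in zip(y, v))


def back_gd_sketch(w, y0, steps, lr):
    """Loss L = Q(y_N, w) after `steps` gradient-descent iterations, plus dL/dw
    computed by running the iterations backwards."""
    ys = [list(y0)]
    for _ in range(steps):                       # forward pass, storing every y_k
        y = ys[-1]
        ys.append([yi - lr * gi for yi, gi in zip(y, grad_y_E(y, w))])
    y_N = ys[-1]
    L = 0.5 * sum((yi - vi) ** 2 for yi, vi in zip(y_N, V))

    y_bar = [yi - vi for yi, vi in zip(y_N, V)]  # grad_y Q(y_N, w)
    w_bar = 0.0                                  # grad_w Q(y_N, w) = 0 for this toy Q
    for k in range(steps - 1, -1, -1):           # backward pass, k = N-1, ..., 0
        y_k = ys[k]
        w_bar -= lr * cross_vec(y_k, w, y_bar)   # uses y_bar = dL/dy_{k+1}
        y_bar = [b - lr * hb for b, hb in zip(y_bar, hess_vec(y_k, w, y_bar))]
    return L, w_bar
```

For a real model, `hess_vec` and `cross_vec` would be replaced by finite differences of gradient routines, so only first derivatives of $E$ are needed.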
Similar algorithms can be derived for the heavy-ball algorithm (with a little more complexity) and limited-memory BFGS (with a lot more complexity).
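The post does not spell out the heavy-ball version, so the following is only a sketch under assumptions: a Python reverse pass through heavy-ball (momentum) iterations $v_{k+1} = \beta v_k - \lambda \nabla_y E(y_k, w)$, $y_{k+1} = y_k + v_{k+1}$, starting from $v_0 = 0$, on a hypothetical toy problem. The extra complexity relative to plain gradient descent is a second adjoint, `v_bar`, for the velocity.

```python
import math

# Hypothetical toy problem (not from the post):
#   E(y, w) = 1/2 sum_i h_i (y_i - t_i)^2 + 1/2 exp(w) sum_i y_i^2,   Q(y, w) = 1/2 ||y - v||^2
H_DIAG = [1.0, 2.0, 0.5]
T = [1.0, -1.0, 2.0]
V = [0.8, -0.5, 1.0]


def grad_y_E(y, w):
    a = math.exp(w)
    return [h * (yi - t) + a * yi for yi, h, t in zip(y, H_DIAG, T)]


def hess_vec(y, w, v):
    a = math.exp(w)
    return [(h + a) * vi for h, vi in zip(H_DIAG, v)]


def cross_vec(y, w, v):
    return math.exp(w) * sum(yi * vi for yi, vi in zip(y, v))


def back_heavy_ball_sketch(w, y0, steps, lr, beta):
    """Heavy-ball iterations from v_0 = 0; returns L = Q(y_N, w) and dL/dw
    by reverse-mode differentiation through the stored trajectory."""
    ys = [list(y0)]
    v = [0.0] * len(y0)
    for _ in range(steps):
        y = ys[-1]
        v = [beta * vi - lr * gi for vi, gi in zip(v, grad_y_E(y, w))]
        ys.append([yi + vi for yi, vi in zip(y, v)])
    y_N = ys[-1]
    L = 0.5 * sum((yi - vi) ** 2 for yi, vi in zip(y_N, V))

    y_bar = [yi - vi for yi, vi in zip(y_N, V)]  # dL/dy_N
    v_bar = [0.0] * len(y0)                      # part of dL/dv_N from later steps (none)
    w_bar = 0.0
    for k in range(steps - 1, -1, -1):
        y_k = ys[k]
        # y_{k+1} = y_k + v_{k+1}: v_{k+1} receives dL/dy_{k+1}
        v_bar = [vb + yb for vb, yb in zip(v_bar, y_bar)]
        # v_{k+1} = beta v_k - lr grad_y E(y_k, w): push v_bar into w, y_k and v_k
        w_bar -= lr * cross_vec(y_k, w, v_bar)
        y_bar = [yb - lr * hv for yb, hv in zip(y_bar, hess_vec(y_k, w, v_bar))]
        v_bar = [beta * vb for vb in v_bar]
    return L, w_bar
```

Setting `beta = 0` recovers plain gradient descent, with `v_bar` reduced to a copy of `y_bar` at each step.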
So, finally, here is the code, and I’ll give a simple example of how to use it. There are just four simple files:
I think the meanings of these are pretty straightforward, so I’ll just quickly step through the demo here. I’ll start off by taking one of Matlab’s built-in datasets (on cities), in which we try to predict a measure of crime from measures of climate, housing, health, transportation, education, arts, recreation, and economy, as well as a constant. There are 329 data points in total, which I split into a training set of size 50, a validation set of size 150, and a test set of size 129.
```matlab
load cities ratings
X = ratings;
% standardize each column to zero mean and unit variance
for i = 1:size(X,2)
    X(:,i) = X(:,i) - mean(X(:,i));
    X(:,i) = X(:,i) / std(X(:,i));
end
% predict crime from climate, housing, health, trans, edu, arts, rec, econ,
% plus a constant feature
Y = X(:,4);
X = [X(:,[1:3 5:9]) ones(size(X,1),1)];
% shuffle, then split into training, validation and test sets
p = randperm(length(Y));
X = X(p,:);
Y = Y(p);
whotrain = 1:50;
whoval   = 51:200;
whotest  = 201:329;
Xtrain = X(whotrain,:);  Xval = X(whoval,:);  Xtest = X(whotest,:);
Ytrain = Y(whotrain);    Yval = Y(whoval);    Ytest = Y(whotest);
```
Next, I’ll set up some simple constants that will be used later on, and define the options for minFunc, which I will use for the outer optimization. In particular, here I choose the inner optimization to use 20 iterations.
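```matlab
opt_iters = 20;              % number of inner-optimization iterations
ndims     = size(Xtrain,2);
w0        = zeros(ndims,1);  % starting point for the inner optimization
ndata     = size(Xtrain,1);
ndata_val = size(Xval,1);
% minFunc options for the outer optimization
options = [];
options.Method  = 'gd';
options.LS      = 1;
options.MaxIter = 100;
```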
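Now, I’ll define the training risk function ($E$ in the notation above). The function `training_risk` computes the risk with a regularization constant of `a = exp(loga)`, as well as its derivatives. I’ll also define the validation risk ($Q$ in the notation above). Note that in the code, `w` holds the model weights (the inner variable, $y$ above), while `loga` holds the log of the regularization constant (the hyperparameter, $w$ above).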
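```matlab
% Training risk (E above): mean squared error plus a*sum(w.^2), with a = exp(loga).
% Like validation_risk, it reads data (Xtrain, Ytrain, ndata / Xval, Yval, ndata_val)
% from the enclosing workspace, so both are assumed to be nested functions
% inside the demo's main function.
function [R, dRdw, dRdloga] = training_risk(w,loga)
    a = exp(loga);
    R = sum((Xtrain*w - Ytrain).^2)/ndata + a*sum(w.^2);
    dRdw = 2*Xtrain'*(Xtrain*w - Ytrain)/ndata + 2*a*w;
    dRda = sum(w.^2);
    dRdloga = dRda*a;   % chain rule: dR/dloga = dR/da * a
end

% Validation risk (Q above) and its gradient with respect to w.
function [R, g] = validation_risk(w)
    R = sum((Xval*w - Yval).^2)/ndata_val;
    g = 2*Xval'*(Xval*w - Yval)/ndata_val;
end
```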
Now, before going any further, let’s do a traditional sweep through regularization constants to see what that looks like.
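```matlab
% sweep log regularization constants, recording the validation risk of each
LAMBDA = -5:.25:2;
VAL_RISK = zeros(size(LAMBDA));
for i = 1:length(LAMBDA)
    VAL_RISK(i) = back_lbfgs(@training_risk,@validation_risk,w0,LAMBDA(i),opt_iters);
end
```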
Plotting, we get the following:
This is a reasonable looking curve. Instead, let’s ask the algorithm to find the constant by gradient descent.
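```matlab
% Outer objective: validation risk after opt_iters inner iterations, as a
% function of loga (named outer_obj rather than eval, which is a MATLAB built-in).
outer_obj = @(loga) back_lbfgs(@training_risk,@validation_risk,w0,loga,opt_iters);
loga = 0;
[loga, fval] = minFunc(outer_obj,loga,options);
```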
Running the optimization, we see:
```
 Iteration   FunEvals     Step Length    Function Val        Opt Cond
         1          2     1.00000e+00     8.74176e-01     3.71997e-03
         2          3     1.00000e+00     8.73910e-01     1.86453e-04
         3          4     1.00000e+00     8.73909e-01     1.06619e-05
         4          5     1.00000e+00     8.73909e-01     3.60499e-08
```
This leads to a regularizer of the form:
We can plot this on the graph, and see that it matches the minimum found by the sweep above.
If we actually compute the test-set error, this is 0.708217.
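As a language-neutral stand-in for the `minFunc` call, here is a Python sketch of the outer loop on a hypothetical toy problem: plain fixed-step gradient descent (not `minFunc`'s method) on the log regularization constant, with each hypergradient obtained by backpropagating through a truncated 20-step gradient-descent inner loop. The demo's `back_lbfgs` uses L-BFGS for the inner loop instead; the step sizes here are assumptions tuned only for this toy.

```python
import math

# Hypothetical toy problem (not from the post); w is the log of the regularization constant:
#   E(y, w) = 1/2 sum_i h_i (y_i - t_i)^2 + 1/2 exp(w) sum_i y_i^2,   Q(y) = 1/2 ||y - v||^2
H_DIAG = [1.0, 2.0, 0.5]
T = [1.0, -1.0, 2.0]
V = [0.8, -0.5, 1.0]


def truncated_loss_and_grad(w, steps=20, lr=0.2):
    """L(w) = Q(y_N) after `steps` fixed-step gradient-descent iterations from y_0 = 0,
    and dL/dw obtained by backpropagating through those iterations."""
    a = math.exp(w)
    ys = [[0.0] * len(T)]
    for _ in range(steps):
        y = ys[-1]
        ys.append([yi - lr * (h * (yi - t) + a * yi) for yi, h, t in zip(y, H_DIAG, T)])
    L = 0.5 * sum((yi - vi) ** 2 for yi, vi in zip(ys[-1], V))
    y_bar = [yi - vi for yi, vi in zip(ys[-1], V)]
    w_bar = 0.0
    for k in range(steps - 1, -1, -1):
        # cross-derivative product a * (y_k . y_bar), then Hessian product diag(h_i + a) y_bar
        w_bar -= lr * a * sum(yi * b for yi, b in zip(ys[k], y_bar))
        y_bar = [b - lr * (h + a) * b for b, h in zip(y_bar, H_DIAG)]
    return L, w_bar


def outer_gradient_descent(w0, step=0.5, iters=50):
    """Fixed-step gradient descent on the truncated outer loss (a stand-in for minFunc)."""
    w = w0
    history = []
    for _ in range(iters):
        L, g = truncated_loss_and_grad(w)
        history.append(L)
        w -= step * g
    return w, history


if __name__ == "__main__":
    w_opt, history = outer_gradient_descent(0.0)
    print("log regularizer:", w_opt, "loss:", history[0], "->", history[-1])
```

Because the inner loop has a fixed length, the function being minimized is exactly the truncated loss, so its gradient is exact even though the inner problem is never solved to convergence.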
OK, let’s be a little bit more adventurous, and use a third-order regularizer. This is done like so:
```matlab
% Regularizer a*sum(|w|.^2) + b*sum(|w|.^3), with a = exp(loga(1)), b = exp(loga(2)).
function [R, dRdw, dRdloga] = training_risk2(w,loga)
    a = exp(loga(1));
    b = exp(loga(2));
    R = sum((Xtrain*w - Ytrain).^2)/ndata + a*sum(abs(w).^2) + b*sum(abs(w).^3);
    % d/dw |w|^p = p*|w|.^(p-1).*sign(w)
    dRdw = 2*Xtrain'*(Xtrain*w - Ytrain)/ndata ...
         + 2*a*abs(w).*sign(w) + 3*b*abs(w).^2.*sign(w);
    dRda = sum(abs(w).^2);
    dRdb = sum(abs(w).^3);
    dRdloga = [dRda*a; dRdb*b];
end
```
Running the optimization, we see:
```
 Iteration   FunEvals     Step Length    Function Val        Opt Cond
         1          2     1.00000e+00     8.74445e-01     1.17262e-02
         2          3     1.00000e+00     8.73685e-01     3.21956e-03
         3          4     1.00000e+00     8.73608e-01     1.41744e-03
         4          5     1.00000e+00     8.73598e-01     8.20040e-04
         5          6     1.00000e+00     8.73567e-01     1.39830e-03
         6          7     1.00000e+00     8.73513e-01     2.52994e-03
         7          8     1.00000e+00     8.73471e-01     1.77157e-03
...
        23         28     1.00000e+00     8.70741e-01     8.42628e-06
```
With a final regularizer of the form
and a hardly-improved test error of 0.679155.
Finally, let’s fit a regularizer that is a polynomial in $|w_i|$ with four terms (powers 2 through 5).
```matlab
% Regularizer sum_k a(k)*sum(|w|.^(k+1)), k = 1..length(a), with a = exp(loga).
function [R, dRdw, dRdloga] = training_risk3(w,loga)
    a = exp(loga);
    R = sum((Xtrain*w - Ytrain).^2)/ndata;
    dRdw = 2*Xtrain'*(Xtrain*w - Ytrain)/ndata;
    dRda = zeros(length(a),1);
    for ii = 1:length(a)
        b = ii+1;   % power of this term
        R = R + a(ii)*sum(abs(w).^b);
        dRdw = dRdw + b*a(ii)*abs(w).^(b-1).*sign(w);   % d/dw |w|^b = b*|w|^(b-1)*sign(w)
        dRda(ii) = sum(abs(w).^b);
    end
    dRdloga = dRda.*a;
end

outer_obj = @(loga) back_lbfgs(@training_risk3,@validation_risk,w0,loga,opt_iters);
loga = [0; 0; 0; 0];
loga = minFunc(outer_obj,loga,options);
```
Running the optimization, we see:

```
 Iteration   FunEvals     Step Length    Function Val        Opt Cond
         1          2     1.00000e+00     8.73982e-01     1.02060e-02
         2          3     1.00000e+00     8.73524e-01     3.30445e-03
         3          4     1.00000e+00     8.73447e-01     1.74779e-03
         4          5     1.00000e+00     8.73435e-01     1.43655e-03
         5          6     1.00000e+00     8.73345e-01     4.98340e-03
         6          7     1.00000e+00     8.73295e-01     1.85535e-03
         7          8     1.00000e+00     8.73231e-01     1.81136e-03
...
        38         41     1.00000e+00     8.69758e-01     3.44848e-06
```
Yielding the (strange) regularizer
and a final test-error of 0.67187.
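The polynomial regularizers above need hand-written derivatives with respect to both `w` and `loga`, and a missing power factor is an easy mistake to make. As a hedged illustration, here is a Python sketch (hypothetical function names, not part of the released files) of a `training_risk3`-style regularizer with powers 2 through `len(loga) + 1`, plus a central-difference check of its derivatives.

```python
import math


def sign(x):
    """MATLAB-style sign: -1, 0 or 1."""
    return (x > 0) - (x < 0)


def poly_regularizer(w, loga):
    """R(w) = sum_k a_k * sum_i |w_i|^(k+2) with a_k = exp(loga_k), i.e. powers 2, 3, ...
    Returns R, dR/dw, and dR/dloga."""
    R = 0.0
    dRdw = [0.0] * len(w)
    dRdloga = []
    for k, lk in enumerate(loga):
        a_k = math.exp(lk)
        p = k + 2                                   # power of this term
        s = sum(abs(wi) ** p for wi in w)
        R += a_k * s
        for i, wi in enumerate(w):
            # d/dw_i |w_i|^p = p * |w_i|^(p-1) * sign(w_i)
            dRdw[i] += a_k * p * abs(wi) ** (p - 1) * sign(wi)
        dRdloga.append(a_k * s)                     # dR/dloga_k = (dR/da_k) * a_k
    return R, dRdw, dRdloga


def finite_difference_check(w, loga, eps=1e-6):
    """Largest absolute gap between the analytic derivatives and central differences."""
    _, dRdw, dRdloga = poly_regularizer(w, loga)
    worst = 0.0
    for i in range(len(w)):
        wp = list(w)
        wp[i] += eps
        wm = list(w)
        wm[i] -= eps
        fd = (poly_regularizer(wp, loga)[0] - poly_regularizer(wm, loga)[0]) / (2 * eps)
        worst = max(worst, abs(fd - dRdw[i]))
    for k in range(len(loga)):
        lp = list(loga)
        lp[k] += eps
        lm = list(loga)
        lm[k] -= eps
        fd = (poly_regularizer(w, lp)[0] - poly_regularizer(w, lm)[0]) / (2 * eps)
        worst = max(worst, abs(fd - dRdloga[k]))
    return worst
```

Running `finite_difference_check` on a few random points before handing a risk function to the outer optimizer is a cheap way to catch an inconsistent gradient.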