Truncated Bi-Level Optimization

In 2012, I wrote a paper that I probably should have called “truncated bi-level optimization”.  I vaguely remembered telling the reviewers I would release some code, so I’m finally getting around to it.

The idea of bilevel optimization is quite simple.  Imagine that you would like to minimize some function L(w).  However, L itself is defined through some optimization.  More formally, suppose we would like to solve

\min_{w,y} Q(y), \text{ s.t. } y = \arg\min_y E(y,w).

Or, equivalently,

\min_{w} L(w)=Q(y^*(w)),

where y^* is defined as y^*(w) := \arg\min_y E(y,w).  This seems a little bit obscure at first, but actually comes up in several different ways in machine learning and related fields.

Hyper-parameter learning

The first example would be in learning hyperparameters, such as regularization constants.  Inevitably in machine learning, one fits parameters to optimize some tradeoff between the quality of a fit to training data and a regularization function being small.  Traditionally, a regularization constant is selected by training with a variety of values, and then picking the one that performs best on a held-out dataset.  However, if there are a large number of regularization parameters, such a high-dimensional grid-search will not be practical.  In the notation above, suppose that w is a vector of regularization constants, and that y are training parameters.  Let E be the regularized empirical risk on a training dataset, and let Q be how well the parameters y^*(w) perform on some held-out validation dataset.
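
To make the mapping concrete, one possible instantiation (just an illustration; here \ell(y;x_i) denotes the loss of parameters y on example x_i, and the regularization weights are kept positive by writing them as e^{w_j}) is

E(y,w) = \sum_{i \in \text{train}} \ell(y;x_i) + \sum_j e^{w_j} R_j(y), \qquad Q(y) = \sum_{i \in \text{val}} \ell(y;x_i),

so that minimizing L(w)=Q(y^*(w)) tunes all of the regularization weights at once against the validation risk.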

Energy-based models

Another example (and the one suggesting the notation) is an energy-based model.  Suppose that we have some “energy” function E_x(y,w) which measures how well an output y fits an input x, with lower energy meaning a better fit.  The energy is parametrized by w.  For a given training input/output pair (\hat{x},\hat{y}), we might have that Q_{\hat{y}}(y^*(w)) measures how well the predicted output y^*(w)=\arg\min_y E_{\hat{x}}(y,w) compares to the true output \hat{y}.
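
To give a purely illustrative instance: in image denoising, \hat{x} could be a noisy image, y a candidate clean image, the energy could be E_{\hat{x}}(y,w) = \sum_i (y_i - \hat{x}_i)^2 + \sum_{(i,j)} w_{ij}(y_i - y_j)^2 with weights w controlling how strongly neighboring pixels are smoothed, and Q_{\hat{y}}(y) = \|y - \hat{y}\|^2 the distance of the prediction from the true clean image.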

Computing the gradient exactly

Even if we just have the modest goal of following the gradient of L(w) to a local minimum, computing that gradient is not so simple.  Clearly, even to evaluate L(w) requires solving the “inner” minimization of E. It turns out that one can compute \nabla_w L(w) by first solving the inner minimization, and then solving a linear system.

  1. Input w.
  2. Solve y^* \leftarrow \arg\min_y E(y,w).
  3. Compute:
     (a) the loss L(w) = Q(y^*)
     (b) the gradient g = \nabla_y Q(y^*)
     (c) the Hessian H = \frac{\partial^2 E(y^*,w)}{\partial y\partial y^T}
  4. Solve the linear system z = H^{-1} g.
  5. Compute the parameter gradient \nabla_w L(w) = - \frac{\partial^2 E(y^*,w)}{\partial w\partial y^T} z.
  6. Return L(w) and \nabla_w L(w).
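
Where do these formulas come from? Since y^*(w) is an (unconstrained) minimizer, it satisfies \nabla_y E(y^*(w),w)=0 for all w.  Differentiating this identity with respect to w (the implicit function theorem) gives

\frac{\partial^2 E}{\partial y \partial y^T} \frac{\partial y^*}{\partial w^T} = -\frac{\partial^2 E}{\partial y \partial w^T},

and so, by the chain rule,

\nabla_w L(w) = \left(\frac{\partial y^*}{\partial w^T}\right)^T \nabla_y Q(y^*) = -\frac{\partial^2 E}{\partial w \partial y^T} H^{-1} g = -\frac{\partial^2 E}{\partial w \partial y^T} z.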

This looks a bit nasty, since we need to compute second-derivative matrices of E.  In fact, as long as one has a routine to compute \nabla_y E and \nabla_w E, this can be avoided through efficient matrix-vector products.  This is essentially what was proposed by Do, Foo, and Ng in “Efficient multiple hyperparameter learning for log-linear models”.
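
To illustrate the flavor of that trick, here is a rough sketch (not code from the paper; grad_y and grad_w are hypothetical handles returning \nabla_y E and \nabla_w E, Q returns the loss and its gradient, and the second-derivative products are approximated by finite differences of the first derivatives):

function [L, dLdw] = bilevel_grad_sketch(grad_y, grad_w, Q, ystar, w)
    % Gradient of L(w) = Q(y*(w)), given y* = argmin_y E(y,w), using only
    % first-derivative routines for E.
    ep     = 1e-7;                                          % finite-difference step
    [L, g] = Q(ystar);                                      % loss and g = nabla_y Q(y*)
    gy0    = grad_y(ystar, w);
    Hv     = @(v) (grad_y(ystar + ep*v, w) - gy0) / ep;     % v -> (d^2E/dy dy') * v
    z      = pcg(Hv, g, 1e-8, 250);                         % solve H z = g by conjugate gradients
    dLdw   = -(grad_w(ystar + ep*z, w) - grad_w(ystar, w)) / ep;   % -(d^2E/dw dy') * z
end

Note that pcg only ever touches H through matrix-vector products, so the Hessian is never formed explicitly.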

Overall, this is a decent approach, but it can be quite slow, simply because one must solve an “inner” optimization in order to compute each gradient of the “outer” optimization.  Often, the inner optimization needs to be solved to very high accuracy in order to estimate a gradient accurately enough to reduce L(w), higher accuracy than is needed when one is simply using the predicted value y^* itself.

Truncated optimization

To get around this expense, a fairly obvious idea is to re-define the problem.  The slowness of exactly computing the gradient stems from needing to exactly solve the inner optimization.  Hence, perhaps we re-define the problem such that an inexact solve of the inner problem nevertheless yields an “exact” gradient?

Re-define the problem as solving

\min_{w} L(w)=Q(y^*(w)), \text{ } y^*(w) := \text{opt-alg}_y E(y,w),

where \text{opt-alg} denotes an approximate solve of the inner optimization.  In order for this to work, \text{opt-alg} must be defined in such a way that y^*(w) is a continuous function of w.  With standard optimization methods such as gradient descent or BFGS, this can be achieved by running a fixed number of iterations with a fixed step-size.  Since each iteration of these algorithms is a continuous function of its inputs, this clearly defines y^*(w) as a continuous function.  Thus, in principle, it could be optimized efficiently through automatic differentiation of the code that optimizes E.  That’s fine in principle, but often inconvenient in practice.

It turns out, however, that one can derive “backpropagating” versions of algorithms like gradient descent, which take as input only a procedure to compute E along with its first derivatives.  These algorithms can then produce the gradient of L in about the same time as automatic differentiation would.

Back Gradient-Descent

If the inner optimization is gradient descent for N steps with a step-size of \lambda (starting from some fixed initial point y_0), the algorithm to compute the loss is simple:

  1. Input w.
  2. For k=0,1,...,N-1:
     (a) y_{k+1} \leftarrow y_k - \lambda \nabla_y E(y_k,w)
  3. Return L(w) = Q(y_N).

How to compute the gradient of this quantity?  The following algorithm does the trick.

  1. \overleftarrow{y_N} \leftarrow \nabla Q(y_N)
  2. \overleftarrow{w} \leftarrow 0
  3. For k=N-1,...,0:
     (a) \overleftarrow{w} \leftarrow \overleftarrow{w} - \lambda \frac{\partial^2 E(y_k,w)}{\partial w \partial y^T} \overleftarrow{y_{k+1}}
     (b) \overleftarrow{y_k} \leftarrow \overleftarrow{y_{k+1}} - \lambda \frac{\partial^2 E(y_k,w)}{\partial y \partial y^T} \overleftarrow{y_{k+1}}
  4. Return \nabla_w L = \overleftarrow{w}.
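
As an illustration of how this translates into code, here is a minimal sketch (not back_gd.m itself; grad_y and grad_w are hypothetical handles returning \nabla_y E and \nabla_w E, the second-derivative products in steps (a) and (b) are approximated by finite differences, and the iterates y_k are simply stored during the forward pass):

function [L, dLdw] = back_gd_sketch(grad_y, grad_w, Q, y0, w, lambda, N)
    % Forward pass: N steps of gradient descent on E, storing the iterates.
    Y = zeros(numel(y0), N+1);
    Y(:,1) = y0;
    for k = 1:N
        Y(:,k+1) = Y(:,k) - lambda*grad_y(Y(:,k), w);
    end
    [L, ybar] = Q(Y(:,N+1));              % L(w) and ybar = nabla Q(y_N)
    % Backward pass: propagate ybar back through the iterates, accumulating nabla_w L.
    ep   = 1e-7;
    dLdw = zeros(size(w));
    for k = N:-1:1
        yk    = Y(:,k);
        Hyy_v = (grad_y(yk + ep*ybar, w) - grad_y(yk, w)) / ep;   % (d^2E/dy dy') * ybar
        Hwy_v = (grad_w(yk + ep*ybar, w) - grad_w(yk, w)) / ep;   % (d^2E/dw dy') * ybar
        dLdw  = dLdw - lambda*Hwy_v;
        ybar  = ybar - lambda*Hyy_v;
    end
end

The main memory cost is storing all N iterates for the backward pass.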

Similar algorithms can be derived for the heavy-ball algorithm (with a little more complexity) and limited memory BFGS (with a lot more complexity).

Code

So, finally, here is the code, and I’ll give a simple example of how to use it. There are just four simple files:

  1. back_gd.m
  2. back_hb.m
  3. back_lbfgs.m
  4. demo.m

I think the meanings of these are pretty straightforward, so I’ll just quickly step through the demo here.  I’ll start off by taking one of Matlab’s built-in datasets (on cities), so that we are trying to predict a measure of crime from measures of climate, housing, health, transportation, education, arts, recreation, and economy, as well as a constant.  There are 329 data points in total, which I split into a training set of size 50, a validation set of size 150, and a test set of size 129.

load cities ratings
X = ratings;
% standardize each column to zero mean and unit variance
for i=1:size(X,2)
    X(:,i) = X(:,i) - mean(X(:,i));
    X(:,i) = X(:,i) / std( X(:,i));
end

% predict crime (column 4) from climate, housing, health, trans, edu, arts, rec, econ, plus a constant
Y = X(:,4);
X = [X(:,[1:3 5:9]) 1+0*X(:,1)];

p = randperm(length(Y));
X = X(p,:);
Y = Y(p);

whotrain = 1:50;
whoval   = 51:200;
whotest  = 201:329;

Xtrain = X(whotrain,:);
Xval   = X(whoval  ,:);
Xtest  = X(whotest ,:);
Ytrain = Y(whotrain);
Yval   = Y(whoval  );
Ytest  = Y(whotest );

Next, I’ll set up some simple constants that will be used later on, and define the optimization parameters for minFunc, which I will be using for the outer optimization. In particular, here I choose the inner optimization to use 20 iterations.

opt_iters = 20;               % number of inner-optimization iterations
ndims     = size(Xtrain,2);
w0        = zeros(ndims,1);
ndata     = size(Xtrain,1);
ndata_val = size(Xval,1);

options         = [];
options.Method  = 'gd';
options.LS      = 1;
options.MaxIter = 100;

Now, I’ll define the training risk function (E in the notation above). This computes the risk with a regularization constant of a = exp(loga), as well as its derivatives with respect to w and loga. I’ll also define the validation risk (Q in the notation above).

    function [R dRdw dRdloga] = training_risk(w,loga)
        a       = exp(loga);
        R       = sum( (Xtrain*w - Ytrain).^2 )/ndata + a*sum(w.^2);
        dRdw    = 2*Xtrain'*(Xtrain*w-Ytrain)  /ndata + 2*a*w;
        dRda    = sum(w.^2);
        dRdloga = dRda*a;   % chain rule: a = exp(loga)
    end

    function [R g] = validation_risk(w)
        R = sum( (Xval*w - Yval).^2 ) / ndata_val;
        g = 2*Xval'*(Xval*w-Yval)     / ndata_val;
    end

Now, before going any further, let’s do a traditional sweep through regularization constants to see what that looks like.

LAMBDA = -5:.25:2;
VAL_RISK = 0*LAMBDA;
for i=1:length(LAMBDA)
    VAL_RISK(i) = back_lbfgs(@training_risk,@validation_risk,w0,LAMBDA(i),opt_iters);
end

Plotting, we get the following:

[Plot: validation risk as a function of the log regularization constant.]

This is a reasonable looking curve. Instead, let’s ask the algorithm to find the constant by gradient descent.

eval = @(loga) back_lbfgs(@training_risk,@validation_risk,w0,loga,opt_iters);        
loga = 0;
[loga fval] = minFunc(eval,loga,options);

Running the optimization, we see:

Iteration   FunEvals     Step Length    Function Val        Opt Cond
         1          2     1.00000e+00     8.74176e-01     3.71997e-03
         2          3     1.00000e+00     8.73910e-01     1.86453e-04
         3          4     1.00000e+00     8.73909e-01     1.06619e-05
         4          5     1.00000e+00     8.73909e-01     3.60499e-08

This leads to a regularizer of the form:

0.860543 ||w||_2^2.

We can plot this on the graph, and see that it matches the result of cross-validation.

[Plot: the same curve, with the constant found by gradient descent marked.]

If we actually compute the test-set error, this is 0.708217.
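
As a quick sanity check on the hyper-gradient (assuming, as the call above implies, that back_lbfgs returns the validation risk together with its gradient with respect to loga), one can compare it against finite differences:

% finite-difference check of the hyper-gradient at loga = 0
[f0, g0] = eval(0);
ep   = 1e-5;
g_fd = (eval(ep) - eval(-ep)) / (2*ep);
fprintf('analytic: %g   finite difference: %g\n', g0, g_fd);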

OK, let’s be a little bit more adventurous, and use a third-order regularizer. This is done like so:

    function [R dRdw dRdloga] = training_risk2(w,loga)
        a       = exp(loga(1));
        b       = exp(loga(2));
        R       = sum( (Xtrain*w - Ytrain).^2 ) / ndata + a*sum(abs(w).^2)    + b*sum(abs(w).^3);
        dRdw    = 2*Xtrain'*(Xtrain*w-Ytrain)   / ndata + 2*a*abs(w).*sign(w) + 3*b*abs(w).^2.*sign(w);
        dRda    = sum(abs(w).^2);
        dRdb    = sum(abs(w).^3);
        dRdloga = [dRda*a; dRdb*b];
    end

Running the optimization, we see:

Iteration   FunEvals     Step Length    Function Val        Opt Cond
         1          2     1.00000e+00     8.74445e-01     1.17262e-02
         2          3     1.00000e+00     8.73685e-01     3.21956e-03
         3          4     1.00000e+00     8.73608e-01     1.41744e-03
         4          5     1.00000e+00     8.73598e-01     8.20040e-04
         5          6     1.00000e+00     8.73567e-01     1.39830e-03
         6          7     1.00000e+00     8.73513e-01     2.52994e-03
         7          8     1.00000e+00     8.73471e-01     1.77157e-03
...
        23         28     1.00000e+00     8.70741e-01     8.42628e-06

With a final regularizer of the form

0.000160 ||w||_2^2 + 5.117057 ||w||_3^3

and a hardly-improved test error of 0.679155.

Finally, let’s go further and use a regularizer with four terms, with powers of the weights ranging from two to five.

    function [R dRdw dRdloga] = training_risk3(w,loga)
        a    = exp(loga);
        R    = sum( (Xtrain*w - Ytrain).^2 ) / ndata;
        dRdw = 2*Xtrain'*(Xtrain*w-Ytrain)   / ndata;
        dRda = zeros(length(a),1);
        for ii=1:length(a)
            b    = ii+1;                   % power of this regularization term
            R    = R    + a(ii)*sum(abs(w).^b);
            dRdw = dRdw + b*a(ii)*abs(w).^(b-1).*sign(w);
            dRda(ii) = sum(abs(w).^b);
        end
        dRdloga = dRda.*a;
    end

    eval = @(loga) back_lbfgs(@training_risk3,@validation_risk,w0,loga,opt_iters);
    loga = [0; 0; 0; 0];
    loga = minFunc(eval,loga,options);

Running the optimization, we see:

Iteration   FunEvals     Step Length    Function Val        Opt Cond
         1          2     1.00000e+00     8.73982e-01     1.02060e-02
         2          3     1.00000e+00     8.73524e-01     3.30445e-03
         3          4     1.00000e+00     8.73447e-01     1.74779e-03
         4          5     1.00000e+00     8.73435e-01     1.43655e-03
         5          6     1.00000e+00     8.73345e-01     4.98340e-03
         6          7     1.00000e+00     8.73295e-01     1.85535e-03
         7          8     1.00000e+00     8.73231e-01     1.81136e-03
...
        38         41     1.00000e+00     8.69758e-01     3.44848e-06

Yielding the (strange) regularizer

0.000000 ||w||_2^2 + 0.000010 ||w||_3^3 + 24.310631 ||w||_4^4 + 0.325565 ||w||_5^5

and a final test-error of 0.67187.


