Why does regularization work?

When fitting statistical models, we usually need to “regularize” the model. The simplest example is probably linear regression. Take some training data, {({\hat {\bf x}}, {\hat y})}. Given a vector of weights {\bf w}, the total squared error is

\sum_{ ({\hat {\bf x}}, {\hat y})} ({\bf w}^T{\hat {\bf x}} - {\hat y})^2

So to fit the model, we might find {\bf w} to minimize the above loss. Commonly (particularly when {\bf x} has many dimensions), we find that the above procedure leads to overfitting: very large weights {\bf w} that fit the training data very well, but predict future data poorly. “Regularization” means modifying the optimization problem to prefer small weights. Consider

\sum_{ ({\hat {\bf x}}, {\hat y})} ({\bf w}^T{\hat {\bf x}} - {\hat y})^2 + \lambda ||{\bf w}||^2.

Here, \lambda trades off how much we care about the model fitting well versus how much we care about {\bf w} being small. Of course, this is a much more general concept than linear regression, and one could pick many alternatives to the squared norm to measure how simple a model is. Let’s abstract a bit, and just write some function M to represent how well the model fits the data, and some function R to represent how simple it is.

M({\bf w}) + \lambda R({\bf w})
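For the linear regression case, with M the squared error and R the squared norm, the minimizer can be written down directly. Here is a minimal sketch, assuming NumPy and synthetic data; the names `fit_ridge`, `X`, `y` and `lam` are my own choices for illustration, not part of any library.

```python
import numpy as np

def fit_ridge(X, y, lam):
    """Minimize sum_i (w^T x_i - y_i)^2 + lam * ||w||^2 in closed form."""
    d = X.shape[1]
    # Setting the gradient to zero gives (X^T X + lam * I) w = X^T y.
    return np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)

# Synthetic data (an assumption for illustration): 20 points in 10 dimensions,
# where only the first feature actually influences y.
rng = np.random.default_rng(0)
X = rng.normal(size=(20, 10))
y = X[:, 0] + 0.5 * rng.normal(size=20)

# The norm of the fitted weights shrinks as lam grows.
norms = [np.linalg.norm(fit_ridge(X, y, lam)) for lam in (0.0, 1.0, 10.0, 100.0)]
```

With `lam = 0.0` this is ordinary least squares; larger values pull the weights toward zero, which is the trade-off described above.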

The obvious question is: how do we pick \lambda? This, of course, has been the subject of much theoretical study. What is commonly done in practice is this:

  1. Divide the data into a training set and test set.
  2. Fit the model for a range of different \lambdas using only the training set.
  3. For each of the models fit in step 2, check how well the resulting weights fit the test data.
  4. Output the weights that perform best on test data.

(One can also retrain on all the data using the \lambda that did best in step 3.)
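The four steps above can be sketched as follows for the ridge regression example. This assumes NumPy, synthetic data, and a hand-picked list of candidate values; `fit_ridge`, `squared_loss` and the particular `lambdas` are hypothetical names and choices of mine.

```python
import numpy as np

def fit_ridge(X, y, lam):
    """Closed-form minimizer of squared error plus lam * ||w||^2."""
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)

def squared_loss(X, y, w):
    return float(np.sum((X @ w - y) ** 2))

# Synthetic data (an assumption): 30 features, only 3 of which matter.
rng = np.random.default_rng(0)
w_true = np.zeros(30)
w_true[:3] = 1.0
X = rng.normal(size=(80, 30))
y = X @ w_true + rng.normal(size=80)

# Step 1: divide the data into a training set and a test set.
X_train, y_train = X[:40], y[:40]
X_test, y_test = X[40:], y[40:]

# Step 2: fit the model for a range of lambdas using only the training set.
lambdas = [0.01, 0.1, 1.0, 10.0, 100.0]
fits = {lam: fit_ridge(X_train, y_train, lam) for lam in lambdas}

# Step 3: check how well each set of weights fits the test data.
test_loss = {lam: squared_loss(X_test, y_test, w) for lam, w in fits.items()}

# Step 4: output the weights that perform best on the test data.
best_lam = min(test_loss, key=test_loss.get)
best_w = fits[best_lam]

# Optional: retrain on all the data with the selected lambda.
w_final = fit_ridge(X, y, best_lam)
```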

Now, there are many ways to measure simplicity. (In the example above, we might consider ||{\bf w}||^2, ||{\bf w}||, ||{\bf w}||_1, ||{\bf w}||_0, ….) Which one to use? If you drank too much coffee one afternoon, you might decide to include a bunch of regularizers simultaneously, each with a corresponding regularization parameter, ending up with a problem like

M({\bf w}) + \sum_i \lambda_i R_i({\bf w}).

And yet, staring at that equation, things take on a sinister hue: if we include too many regularization parameters, won’t we eventually overfit the regularization parameters? And furthermore, won’t it be very hard to test all possible combinations of \lambda_i to some reasonable resolution? What should we do?

And now I will finally get to the point.

Recently, I’ve been looking at regularization in a different way, which seems very obvious in retrospect. The idea is that when searching for the optimal regularization parameters, we are fitting a model: a model that just happens to be defined in an odd way.

First, let me define some notation. Let T denote the training set, and let S denote the test set. Now, define

{\bf f}(\lambda) = \arg\min_{\bf w} M_T({\bf w}) + \lambda R({\bf w})

where M_T is the function measuring how well the model fits the training data (and, likewise, M_S measures the fit to the test data). So, for a given regularization parameter, {\bf f} returns the weights solving the optimization problem. Given that, the optimal \lambda will be the one minimizing this equation:

\lambda^* = \arg\min_\lambda M_S( {\bf f}(\lambda) )

This extends naturally to the setting where we have a vector of regularization parameters {\boldsymbol \lambda}:

{\bf f}({\boldsymbol \lambda}) = \arg\min_{\bf w} M_T({\bf w}) + \sum_i \lambda_i R_i({\bf w})

{\boldsymbol \lambda}^* = \arg\min_{\boldsymbol \lambda} M_S( {\bf f}({\boldsymbol \lambda}) )

From this perspective, doing an optimization over regularization parameters makes sense: it is just fitting the parameters {\boldsymbol \lambda} to the data S. It just so happens that the objective function is implicitly defined by a minimization over the data T.
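To make the two-level structure concrete, here is a rough sketch with two regularization parameters, assuming NumPy and synthetic data. To keep {\bf f}({\boldsymbol \lambda}) in closed form I have assumed two quadratic regularizers, the squared norm and a squared first-difference (smoothness) penalty; those choices, and the names `f`, `M_S`, `D` and `grid`, are mine and purely illustrative.

```python
import itertools
import numpy as np

# Synthetic data (an assumption): a smooth "true" weight vector, with a
# training set T and a test set S drawn from the same model.
rng = np.random.default_rng(0)
d = 15
w_true = np.linspace(0.0, 1.0, d)
X_T = rng.normal(size=(30, d))
y_T = X_T @ w_true + rng.normal(size=30)
X_S = rng.normal(size=(30, d))
y_S = X_S @ w_true + rng.normal(size=30)

# First-difference matrix: (D @ w)[j] = w[j + 1] - w[j].
D = np.diff(np.eye(d), axis=0)

def f(lams):
    """argmin_w M_T(w) + lam1 * ||w||^2 + lam2 * ||D w||^2, in closed form."""
    lam1, lam2 = lams
    H = X_T.T @ X_T + lam1 * np.eye(d) + lam2 * (D.T @ D)
    return np.linalg.solve(H, X_T.T @ y_T)

def M_S(w):
    """Squared error of weights w on the test set S."""
    return float(np.sum((X_S @ w - y_S) ** 2))

# Fit the parameter vector lambda to S by brute-force search over a grid.
grid = [0.01, 0.1, 1.0, 10.0, 100.0]
candidates = list(itertools.product(grid, repeat=2))
lam_star = min(candidates, key=lambda lams: M_S(f(lams)))
```

Note that the number of candidates is `len(grid) ** 2` here and grows exponentially with the number of regularizers, which is the resolution problem raised earlier.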

Regularization works because it is fitting a single parameter to some data. (In general, one is safe fitting a single parameter to any reasonably sized dataset, although there are exceptions.) Fitting multiple regularization parameters should work, supposing the test data is big enough to support them. (There are computational issues with doing this that I’m not discussing here, stemming from the fact that {\bf f}({\boldsymbol \lambda}) isn’t known in closed form, but is implicitly defined by a minimization. So far as I know, this is an open problem, although I suspect the solution is “implicit differentiation”.)
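As a hedged illustration of what implicit differentiation would look like, take the single-parameter ridge case. The weights {\bf f}(\lambda) satisfy the stationarity condition (X_T^T X_T + \lambda I) {\bf w} = X_T^T y_T; differentiating both sides with respect to \lambda gives (X_T^T X_T + \lambda I) d{\bf w}/d\lambda + {\bf w} = 0, and the chain rule then gives the derivative of M_S. The sketch below assumes NumPy and synthetic data, uses names of my own (`f`, `M_S`, `dM_S_dlam`), and checks the result against a finite difference. It covers this special case only and is not taken from any particular paper.

```python
import numpy as np

def f(lam, X_T, y_T):
    """Ridge weights: argmin_w ||X_T w - y_T||^2 + lam * ||w||^2."""
    d = X_T.shape[1]
    return np.linalg.solve(X_T.T @ X_T + lam * np.eye(d), X_T.T @ y_T)

def M_S(w, X_S, y_S):
    return float(np.sum((X_S @ w - y_S) ** 2))

def dM_S_dlam(lam, X_T, y_T, X_S, y_S):
    """d/dlam of M_S(f(lam)), obtained by implicit differentiation."""
    d = X_T.shape[1]
    w = f(lam, X_T, y_T)
    # From (X_T^T X_T + lam I) dw/dlam + w = 0:
    dw_dlam = -np.linalg.solve(X_T.T @ X_T + lam * np.eye(d), w)
    # Gradient of the test loss with respect to w, then the chain rule.
    grad_w = 2.0 * X_S.T @ (X_S @ w - y_S)
    return float(grad_w @ dw_dlam)

# Synthetic data (an assumption for illustration).
rng = np.random.default_rng(0)
w_true = rng.normal(size=10)
X_T = rng.normal(size=(25, 10))
y_T = X_T @ w_true + rng.normal(size=25)
X_S = rng.normal(size=(25, 10))
y_S = X_S @ w_true + rng.normal(size=25)

lam, eps = 1.0, 1e-5
analytic = dM_S_dlam(lam, X_T, y_T, X_S, y_S)
# Central finite difference as a sanity check on the derivation.
numeric = (M_S(f(lam + eps, X_T, y_T), X_S, y_S)
           - M_S(f(lam - eps, X_T, y_T), X_S, y_S)) / (2 * eps)
```

The sign of `analytic` says whether increasing \lambda would raise or lower the test loss, so in principle it could drive a gradient-based search over \lambda instead of a grid.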

Even regularizing the regularization parameters isn’t necessarily a crazy idea, though clearly ||{\boldsymbol \lambda}||^2 isn’t the way to do it. (Simpler models correspond to larger regularization constants. Maybe 1/||{\boldsymbol \lambda}||^2?) Then, you can regularize those parameters!

Well. In retrospect that wasn’t quite so simple as it seemed.

It’s a strong bet that this is a standard interpretation in some communities.

Update: The idea of fitting multiple parameters using implicit differentiation is published in the 1996 paper “Adaptive regularization in neural network modeling” by Jan Larsen, Claus Svarer, Lars Nonboe Andersen and Lars Kai Hansen.

4 thoughts on “Why does regularization work?”

  1. Hello,

    I just happened to see your website where you explain regularization. Very well explained; just a small detail that I think is missing: how to apply lambda back as input.

    M({\bf w}) + \lambda R({\bf w})

    You see from the equation that you have given, lambda is just a multiplier. You can choose any arbitrary number and the output may not change, or am I missing something?

  2. Hello,

    I’m not sure I understand your question, but in general it is not obvious exactly how to choose the value of lambda. I suppose that the standard “best practice” would be to try a range of different lambdas and see which leads to the best performance on a validation set (or, better, to do this using cross-validation).
