Say you’ve got a positive dataset and you want to calculate the variance. However, the numbers in your dataset are huge, so huge you need to represent them in the log-domain. How do you compute the log-variance without things blowing up?
I ran into this problem today. To my surprise, I couldn’t find a standard solution.
The bad solution
Suppose that your data is $x_1, \dots, x_n$, which you have stored as $\ell_1, \dots, \ell_n$, where $\ell_i = \log x_i$. The obvious thing to do is to just exponentiate and then compute the variance. That would be something like the following:
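(A minimal sketch of that naive approach; the function name is just for illustration.)

import numpy as np

def naive_log_domain_var(logx):
    "np.log(np.var(np.exp(logx))): exponentiate, take the variance, take the log"
    return np.log(np.var(np.exp(logx)))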
This of course is a terrible idea: when $\ell_i$ is large, you can’t even write down $\exp(\ell_i)$ without running into numerical problems.
The mediocre solution
The first idea I had for this problem was relatively elegant. We can of course represent the variance as

$$\mathrm{Var}[x] = \mathbb{E}[x^2] - \mathbb{E}[x]^2,$$

where $\mathbb{E}$ denotes the empirical mean over the dataset. Instead of calculating $\mathbb{E}[x^2]$ and $\mathbb{E}[x]^2$, why not calculate the log of these quantities?
To do this, we can introduce a “log domain mean” operator, a close relative of the good-old scipy.special.logsumexp
def log_domain_mean(logx):
    "np.log(np.mean(np.exp(logx))) but more stable"
    n = len(logx)
    damax = np.max(logx)
    return np.log(np.sum(np.exp(logx - damax))) \
           + damax - np.log(n)
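As a sanity check (the test values here are arbitrary), this agrees with scipy.special.logsumexp minus $\log n$, while exponentiating directly would overflow:

from scipy.special import logsumexp

logx = np.array([710.0, 711.0, 712.0])      # np.exp of these overflows in float64
print(log_domain_mean(logx))                # ≈ 711.31
print(logsumexp(logx) - np.log(len(logx)))  # same value, via scipy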
Next, introduce a “log-sub-add” operator. (A variant of np.logaddexp.)
def logsubadd(a, b):
    "np.log(np.exp(a)-np.exp(b)) but more stable"
    return a + np.log(1 - np.exp(b - a))
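A quick check on arbitrary values:

print(logsubadd(np.log(5.0), np.log(3.0)))  # ≈ np.log(2.0) ≈ 0.693
print(logsubadd(1000.0, 999.0))             # ≈ 999.54, even though exp(1000) overflows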
Then, we can compute the log-variance as
def log_domain_var(logx):
    a = log_domain_mean(2*logx)
    b = log_domain_mean(logx)*2
    c = logsubadd(a, b)
    return c
Here a is $\log \mathbb{E}[x^2]$ while b is $\log \mathbb{E}[x]^2$ (computed as $2 \log \mathbb{E}[x]$).
Nice, right? Well, it’s much better than the first solution. But it isn’t that good. The problem is that when the variance is small, a and b are close. When they are both close and large, logsubadd runs into numerical problems. It’s not clear that there is a way to fix this problem with logsubadd.
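To see the failure mode concretely, here is a small made-up example: the data is huge with a tiny relative spread, so a and b agree to far more digits than float64 carries, and the subtraction inside logsubadd destroys everything.

x = np.array([1e9, 1e9 + 0.1, 1e9 + 0.2])  # true variance ≈ 0.0067, so log-variance ≈ -5.0
logx = np.log(x)
a = log_domain_mean(2*logx)   # log E[x^2]
b = log_domain_mean(logx)*2   # log E[x]^2
print(a - b)                  # ~1e-15 or exactly 0: pure rounding noise
print(logsubadd(a, b))        # -inf, nan, or garbage -- nowhere near -5.0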
To solve this, abandon elegance!
The good solution
For the good solution, the math is a series of not-too-intuitive transformations. (I put them at the end.) These start with

$$\mathrm{Var}[x] = \frac{1}{n}\sum_{i=1}^{n} \left(x_i - \bar{x}\right)^2, \qquad \bar{x} = \mathbb{E}[x],$$

and end with

$$\log \mathrm{Var}[x] = \log \sum_{i=1}^{n} \left(\exp\left(\log x_i - \log \bar{x}\right) - 1\right)^2 + 2\log\bar{x} - \log n.$$
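For concreteness, here is a sketch of the kind of transformations involved: factor $\bar{x}^2$ out of the sum, rewrite each ratio in the log domain, and then take logs.

$$\frac{1}{n}\sum_{i=1}^{n}\left(x_i-\bar{x}\right)^2
= \frac{\bar{x}^2}{n}\sum_{i=1}^{n}\left(\frac{x_i}{\bar{x}}-1\right)^2
= \frac{\bar{x}^2}{n}\sum_{i=1}^{n}\left(\exp\left(\log x_i-\log\bar{x}\right)-1\right)^2.$$

Taking the log of both sides gives the expression above.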
Why this form? Well, we’ve reduced the problem to things we can do relatively stably: compute the log-mean, and do a (small variant of) log-sum-exp.
def log_domain_var(logx):
    """like np.log(np.var(np.exp(logx))) except more stable"""
    n = len(logx)
    log_xmean = log_domain_mean(logx)
    return np.log(np.sum(np.expm1(logx - log_xmean)**2)) \
           + 2*log_xmean - np.log(n)
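Continuing the made-up example from above, where the mediocre solution fell over:

x = np.array([1e9, 1e9 + 0.1, 1e9 + 0.2])
print(log_domain_var(np.log(x)))  # ≈ -5.01
print(np.log(np.var(x)))          # reference answer; fine here only because x itself fits in float64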
This uses the log_domain_mean implementation from above, and also np.expm1 to compute $\exp(a) - 1$ in a more stable way when a is close to zero.
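A tiny illustration of why np.expm1 matters (the argument is chosen just to make the effect obvious):

print(np.exp(1e-20) - 1.0)  # 0.0 -- the information is lost when exp() rounds to 1.0
print(np.expm1(1e-20))      # 1e-20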
Why is this stable? Is it really stable? Well, umm, I’m not sure. I derived transformations that “looked stable” to me, but there’s no proof that this is best. I’d be surprised if a better solution wasn’t possible. (I’d also be surprised if there isn’t a paper from 25+ years ago that describes that better solution.)
In any case, I’ve experimentally found that this function will (while working in single precision) happily compute the variance for values of logx far beyond what the naive solution can handle, which is about 28 orders of magnitude better and sufficient for my needs.
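A sketch of the kind of experiment I mean, with an arbitrary shift and a float64 reference computed on shifted data (shifting logx by a constant c multiplies the variance by exp(2c), so it adds 2c to the log-variance):

def check_log_domain_var(shift, n=1000, seed=0):
    """Compare float32 log_domain_var against a float64 reference."""
    rng = np.random.default_rng(seed)
    logx = shift + rng.normal(size=n)
    # reference: remove the shift so np.exp is safe in float64, then add 2*shift back
    ref = np.log(np.var(np.exp(logx - shift))) + 2.0*shift
    got = log_domain_var(logx.astype(np.float32))
    return got, ref

print(check_log_domain_var(1e4))  # the two values should agree to a few decimal places;
                                  # the naive solution overflows long before logx reaches 1e4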
As always, failure cases are probably out there. Numerical instability always wins when it can be bothered to make an effort.