A Divergence Bound For Hybrids of MCMC and Variational Inference and an Application to Langevin Dynamics and SGVI

At ICML I recently published a paper that I somehow decided to title “A Divergence Bound for Hybrids of MCMC and Variational Inference and an Application to Langevin Dynamics and SGVI”. This paper gives one framework for building “hybrid” algorithms that interpolate between Markov chain Monte Carlo (MCMC) and variational inference (VI). Then, it gives an example for a particular pair of algorithms, namely:

  • MCMC ⇔ Stochastic Gradient Langevin Dynamics [1] [2] [3]
  • VI ⇔ Stochastic Gradient VI [4] [5] [6]


Is there a pleasing visualization?

Here are three different views of the algorithm on a one-dimensional problem, interpolating between VI-like and MCMC-like behavior as β goes from 0 (VI) to 1 (MCMC).

[Figure: three views of the 1-D example across β]

(Admittedly, this might not make that much sense at this point.)

What is VI?

The goal of “inference” is to be able to evaluate expectations with respect to some “target” distribution p(z). Variational inference (VI) converts this problem into the minimization of the KL-divergence KL(q_w(z) \Vert p(z)) for some simple class of distributions q_w(z) parameterized by w. For example, if p is a mixture of two Gaussians (in red), and q is a single Gaussian (in blue), the VI optimization in one dimension would arrive at the solution below.

[Figure: VI solution (β = 0), a single Gaussian q (blue) fit to a two-Gaussian mixture p (red)]
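As a concrete sketch of what this optimization looks like (my illustration, not code from the paper): fit a Gaussian q_w to a two-Gaussian mixture p by stochastic gradient descent on KL(q_w \Vert p), using the reparameterization trick. The mixture parameters, step-size, and finite-difference gradient are all arbitrary illustrative choices.

```python
import numpy as np

# A minimal sketch of VI: stochastic gradient descent on KL(q_w || p)
# with a single Gaussian q_w, w = (mu, log_sigma), via reparameterization.

def log_p(z):
    # Illustrative target: equal mixture of N(-2, 0.5^2) and N(2, 0.5^2).
    def log_norm(z, m, s):
        return -0.5 * np.log(2 * np.pi * s**2) - (z - m)**2 / (2 * s**2)
    return np.logaddexp(log_norm(z, -2, 0.5), log_norm(z, 2, 0.5)) + np.log(0.5)

mu, log_sigma = 0.0, 0.0
eps = 0.01
rng = np.random.default_rng(0)
for _ in range(5000):
    eta = rng.standard_normal()
    sigma = np.exp(log_sigma)
    z = mu + sigma * eta                       # reparameterized sample
    h = 1e-4                                   # finite difference stands in for autodiff
    dlogp_dz = (log_p(z + h) - log_p(z - h)) / (2 * h)
    # Single-sample gradients of KL(q||p) = -H(q) - E_q[log p]:
    grad_mu = -dlogp_dz
    grad_log_sigma = -dlogp_dz * sigma * eta - 1.0   # dH/dlog_sigma = 1
    mu -= eps * grad_mu
    log_sigma -= eps * grad_log_sigma
```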

What is MCMC?

Given that same target distribution p(z), Markov chain Monte Carlo creates a random walk over z. The random walk is carefully constructed so that if you run it a long time, the probability it will end up in any given state is proportional to p(z). You can picture it as follows.

[Figure: MCMC samples (β = 1) from the same target]
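And a matching sketch of a simple MCMC method, a Metropolis random walk, on the same target (again just an illustration; the paper's MCMC algorithm is Langevin dynamics, discussed below). It reuses log_p from the sketch above.

```python
import numpy as np

# A minimal sketch of MCMC: a Metropolis random walk over z,
# reusing log_p from the VI sketch above.
rng = np.random.default_rng(0)
z, samples = 0.0, []
for _ in range(10000):
    z_prop = z + 0.5 * rng.standard_normal()   # propose a local move
    # Accept with probability min(1, p(z_prop)/p(z)); otherwise stay put.
    if np.log(rng.uniform()) < log_p(z_prop) - log_p(z):
        z = z_prop
    samples.append(z)
# The long-run histogram of `samples` approximates p(z).
```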

Why would you want to interpolate between them?

In short, VI is only an approximate algorithm, but MCMC can be very slow. In practice, the difference can be enormous: MCMC may require many orders of magnitude more time just to equal the performance of VI. This presents a user with an awkward situation: if one chooses the best algorithm for each time horizon, there’s a “gap” between when VI converges and when MCMC overtakes it. Informally, you can get performance that looks like this:

[Figure: cartoon of error vs. time for VI and MCMC]

Intuitively, it seems like something better should be achievable at those intermediate times.

But they are so different. Is it even possible to combine them?

Very roughly speaking, you can define a random walk over the space of variational distributions. Then, you trade off between how random the walk is and how random the distributions are. You arrive at something like this:

[Figure: the hybrid algorithm at β = 0.32]

Put another way, both VI and MCMC seek “high probability” regions in z, but with different coverage strategies:

  • VI explicitly includes the entropy in its objective
  • MCMC injects randomness into its walk

It is therefore natural to define a random walk over w, where we trade off between “how random” the walk is and “how strongly” high-entropy w are favored.

That’s fine intuition, but can you guarantee anything formally?

Yes! Or, at least, sort of. To define a bit of notation, we start with a fixed variational family q(z|w) and a target distribution p(z). Now, we want to define a distribution q(w) (so we can do a random walk) so that

q(z) = \int_w q(w) q(z|w) \approx p(z).

The natural goal would be to minimize the KL-divergence KL(q(Z) \Vert p(Z))=\int_z q(z) \log(q(z)/p(z)). That’s difficult, since q(z) is defined by marginalizing out w, so you can’t evaluate it. What you can do is set up two upper bounds on this quantity.

The first bound is the conditional divergence:

KL(q(Z) \Vert p(Z)) \leq D_0 := \int_w q(w) \int_z q(z|w) \log \frac{q(z|w)}{p(z)}

The second bound is the joint divergence. You need to augment p(z) with some distribution p(w|z) and then you have the bound

KL(q(Z) \Vert p(Z)) \leq D_1 := \int_w q(w) \int_z q(z|w) \log \frac{q(w)q(z|w)}{p(z)p(w|z)}

Since these are both upper bounds, any convex combination of them is one as well. Thus, the goal is to find the distribution q(w) that minimizes D_\beta = (1-\beta)D_0 + \beta D_1, for any \beta in the interval [0,1].
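To spell out why these are upper bounds (a standard argument, included for completeness): D_0 bounds the marginal divergence because the KL-divergence is jointly convex, so averaging q(z|w) over q(w) inside the divergence can only decrease it. D_1 is exactly the joint divergence KL(q(W,Z) \Vert p(W,Z)), and the chain rule for the KL-divergence gives

KL(q(W,Z) \Vert p(W,Z)) = KL(q(Z) \Vert p(Z)) + E_{q(Z)} \big[ KL(q(W|Z) \Vert p(W|Z)) \big] \geq KL(q(Z) \Vert p(Z)),

since the conditional term is non-negative.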

What distribution optimizes that bound?

First, note that D_\beta depends on the choice of p(w|z). You get a valid upper bound for any choice, but the tightness changes. The paper uses p(w|z) = r(w) q(z|w) / r_z, where r_z = \int_w r(w) q(z|w) is a normalizing constant. Here, you can think of r(w) as something akin to a “base measure”. r_z is restricted to be constant over z. (This isn’t a terrible restriction: it essentially means that if r(w) were a prior for q(z|w), it wouldn’t favor any point.)

Taking that choice of p(w|z) the solution turns out to be:

q^*(w) = \exp( s(w) - A)
s(w) = \log r(w) - \log r_z + E_{q(Z|w)} [\beta^{-1} \log p(Z) + (1-\beta^{-1}) \log q(Z|w)]
A = \log \int_w \exp s(w)

Furthermore, the value of the divergence bound at the solution turns out to be, up to the factor -\beta, just the normalizing constant A:

D^*_\beta = - \beta A.
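As an illustration of this formula (again my sketch, not the paper's code): evaluate the unnormalized log q^*(w) on a grid for the 1-D Gaussian family, assuming a flat improper r(w) so that the \log r(w) and \log r_z terms are constant and can be dropped.

```python
import numpy as np

# Sketch: evaluate s(w), up to a constant, on a grid, for Gaussian q(z|w)
# with w = (mu, log_sigma), reusing log_p from the earlier sketches.
# Assumes a flat improper r(w), so log r(w) - log r_z drops out.
def s_w(mu, log_sigma, beta, n_mc=500, rng=np.random.default_rng(0)):
    sigma = np.exp(log_sigma)
    z = mu + sigma * rng.standard_normal(n_mc)
    E_log_p = np.mean(log_p(z))                              # Monte Carlo
    E_log_q = -(log_sigma + 0.5 * np.log(2 * np.pi * np.e))  # = -H(w), exact
    return E_log_p / beta + (1 - 1 / beta) * E_log_q

# Normalizing exp(s) over the grid estimates A, and hence D*_beta = -beta * A.
mus = np.linspace(-4, 4, 81)
log_sigmas = np.linspace(-3, 1, 41)
S = np.array([[s_w(m, ls, beta=0.5) for m in mus] for ls in log_sigmas])
```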

How would you apply this to real algorithms?

To do anything concrete, you need to look at a specific VI algorithm and a specific MCMC algorithm. The paper uses

  • Langevin dynamics as the MCMC algorithm, which iterates the updates

    z \leftarrow z + \frac{\epsilon}{2} \nabla_z \log p(z) + \sqrt{\epsilon}\, \eta

    where \eta is noise from a standard Gaussian and \epsilon is a step-size.

  • Stochastic gradient VI as the VI algorithm, which uses the iteration

    w \leftarrow w - \frac{\epsilon}{2} \nabla_w KL(q(Z|w) \Vert p(Z))
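As a concrete sketch of the first of these (reusing log_p from the earlier sketches; the finite-difference gradient is again a stand-in for an exact or autodiff gradient):

```python
import numpy as np

# A minimal sketch of (unadjusted) Langevin dynamics on the earlier
# 1-D target, reusing log_p from the VI sketch above.
rng = np.random.default_rng(0)
eps, z, samples = 0.01, 0.0, []
for _ in range(10000):
    h = 1e-4
    grad_log_p = (log_p(z + h) - log_p(z - h)) / (2 * h)
    z = z + 0.5 * eps * grad_log_p + np.sqrt(eps) * rng.standard_normal()
    samples.append(z)
```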

To get the novel algorithm in this paper, all that really needs to be done is to apply Langevin dynamics to the distribution q^*(w) derived above. Then, after a re-scaling, this becomes the new hybrid algorithm

w \leftarrow w - \frac{\epsilon}{2} \nabla_w \Big( KL(q(Z|w) \Vert p(Z)) + \beta H(w) - \beta \log r_\beta(w) \Big) + \sqrt{\beta \epsilon}\, \eta.

Here, H(w) is the entropy of q(Z|w). This clearly becomes the previous VI algorithm in the limit \beta \rightarrow 0. It also essentially becomes Langevin dynamics as \beta \rightarrow 1. That’s because the distribution r_\beta (not yet defined!) will prefer w where q(Z|w) is highly concentrated. Thus, only the mean parameters of w matter, and sampling w becomes equivalent to just sampling z.
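To make the update concrete, here is a minimal sketch of one hybrid step for the 1-D Gaussian family with w = (\mu, \log \sigma) (natural log here for simplicity; the experiments below use base 10). Taking r_\beta flat is a placeholder assumption of mine, not the choice used in the paper, which is described in the next section.

```python
import numpy as np

# One hybrid step for w = (mu, log_sigma), reusing log_p from the
# earlier sketches. r_beta is taken flat here, so its gradient term
# vanishes. Gradients are single-sample reparameterization estimates,
# with a finite difference standing in for autodiff.
def hybrid_step(mu, log_sigma, beta, eps, rng):
    sigma = np.exp(log_sigma)
    eta_z = rng.standard_normal()
    z = mu + sigma * eta_z                      # reparameterized sample
    h = 1e-4
    dlogp = (log_p(z + h) - log_p(z - h)) / (2 * h)
    # Stochastic gradient of KL(q(Z|w) || p(Z)) + beta * H(w):
    g_mu = -dlogp
    g_ls = -dlogp * sigma * eta_z - 1.0 + beta  # dH/dlog_sigma = 1
    mu = mu - 0.5 * eps * g_mu + np.sqrt(beta * eps) * rng.standard_normal()
    log_sigma = log_sigma - 0.5 * eps * g_ls + np.sqrt(beta * eps) * rng.standard_normal()
    return mu, log_sigma
```

At beta=0 this reduces exactly to the VI sketch above; at beta=1 the full injected noise of Langevin dynamics is recovered.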

I feel like you’re skipping some details here.

Yes. First, the experiments below use a diagonal Gaussian for q(z|w), with w=(\mu, \nu) and \nu_i = \log_{10} \sigma_i. Second, the gradient of the objective involves a KL-divergence. Computing this exactly is intractable, but it can be approximated with standard tricks from stochastic VI, namely data subsampling and the “reparameterization trick”. Third, r_\beta needs to be chosen. The experiments below use the (improper) distribution r_\beta(w) \propto \prod_i \mathcal{N}(\nu_i \vert u_\beta,1), where u_\beta is a universal constant chosen for each \beta to minimize the divergence bound when p(z) is a standard Gaussian. (If you, like me, find this displeasing, see below.)
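For what it's worth, this choice makes the r_\beta term a simple quadratic penalty pulling each \nu_i toward u_\beta, so its gradient is trivial (a sketch, with u_beta treated as a given constant):

```python
import numpy as np

# Gradient of log r_beta(w) = sum_i log N(nu_i | u_beta, 1) with
# respect to nu; u_beta is the per-beta constant described above.
def grad_log_r_beta(nu, u_beta):
    return -(np.asarray(nu) - u_beta)
```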

Are there some more pictures of samples?

Here are a couple of 2-D examples, sampling from a “doughnut” distribution and a “three peaks” mixture distribution. The samples are visualized by drawing a curve at one standard deviation around the mean. Notice how the samples smoothly become more “MCMC-like” as \beta increases.

[Figure: samples from the “three peaks” mixture for a range of β]

[Figure: samples from the “doughnut” distribution for a range of β]

What about “real” problems? Can you show samples there?

Sure, but of course in more than two dimensions it’s hard to show samples. Here are some results sampling from a logistic regression model on the classic ionosphere dataset. As a comparison, I implemented the same model in Stan and ran it for a very long time to generate “presumed correct” samples. I then projected all samples onto the first two principal components.

[Figure: ionosphere logistic-regression samples projected onto the first two principal components]

(Note: technically, what’s shown here is a sample z drawn from each sampled w.)

The top row shows the results after 10^4 iterations, the middle row after 10^5, and the bottom row after 10^6. You can roughly see that for short time horizons you are better off using a lower value of \beta, but for longer time horizons you should use a larger value.

Do you get the desired speed/accuracy tradeoff?

Here, you need to compare the error each value of \beta yields at each time horizon. This is made difficult by the fact that you also need to select a step-size \epsilon, and the best step-size changes depending on the time horizon and on \beta. To be as fair as possible, I ran 100 experiments over a range of step-sizes and averaged the performance. Then, for each value of \beta and each time horizon, the results are shown for the best step-size. (This same procedure was used to generate the previous plots of samples as well.)

[Figure: MMD error vs. time for each β on the logistic-regression model]

The above plot shows the error (measured in MMD) on the y-axis against time on the x-axis. Note that both axes are logarithmic. There are often several orders of magnitude of time horizons where an intermediate algorithm performs better than either pure VI (β=0) or pure MCMC (β=1).

Is there anything that remains unsatisfying about this approach?

The most unsatisfying thing about this approach is the need to choose p(w|z). This is a bit disturbing, since this is not an object that “exists” in either the pure MCMC or pure VI worlds. On the other hand, there is a strong argument that it needs to exist here. If you look carefully at q^*(w) above, you’ll notice that it depends on the particular parameterization of w. So, for example, if we “stretched out” part of the space of w, this would change the marginal density q(z). That would be truly disturbing, but if p(w|z) is transformed in the opposite way, it counteracts the change. So, p(w|z) needs to exist to reflect how we’ve parameterized w.

On the other hand, simply picking a single static distribution r_\beta(w) is pretty simplistic. (Recall, p(w|z) was defined in terms of r(w) above.) It would be natural to adjust this distribution during inference to tighten the bound D^*_\beta. Using the fact that D^*_\beta=-\beta A, you can show that it’s possible to compute derivatives of D^*_\beta with respect to the parameters of r online, and thus tighten the bound while the algorithm is running. (I didn’t do this in this paper, since neither VI nor MCMC does anything comparable, and it would complicate the interpretation of the experiments.)

Finally, the main question is whether this can be extended to other pairs of VI / MCMC algorithms. I actually first derived this algorithm by looking at simple discrete graphical models, e.g. Ising models. There, you can use the following pair of algorithms (a rough sketch follows the list):

  • VI: Use a fully-factorized variational distribution and single-site coordinate ascent updates
  • MCMC: Use single-site Gibbs sampling.
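Here is a minimal sketch of those two endpoints on a small Ising model (my illustration, not code from the paper; the grid size and coupling are arbitrary choices):

```python
import numpy as np

# Two endpoint algorithms on a 2-D Ising model with coupling J, no field,
# and periodic boundaries. Spins s_ij take values in {-1, +1}.
rng = np.random.default_rng(0)
n, J = 16, 0.4

def neighbor_sum(S, i, j):
    return S[(i-1) % n, j] + S[(i+1) % n, j] + S[i, (j-1) % n] + S[i, (j+1) % n]

# MCMC endpoint: single-site Gibbs sampling.
S = rng.choice([-1, 1], size=(n, n))
for _ in range(100):
    for i in range(n):
        for j in range(n):
            p_up = 1 / (1 + np.exp(-2 * J * neighbor_sum(S, i, j)))
            S[i, j] = 1 if rng.uniform() < p_up else -1

# VI endpoint: fully-factorized (mean-field) coordinate ascent.
# M[i, j] holds the mean E[s_ij] under the factorized distribution.
M = np.zeros((n, n))
for _ in range(100):
    for i in range(n):
        for j in range(n):
            M[i, j] = np.tanh(J * neighbor_sum(M, i, j))
```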

You do in fact get a useful hybrid algorithm in the middle. However, the unfortunate reality is that both of these endpoints are considered pretty weak algorithms, so it’s hard to get too excited about the interpolation.

Also, do note that there are other ideas out there for combining MCMC and VI. However, these usually fall into the camps of “putting MCMC inside of VI” [7] [8] [9] or “putting VI inside of MCMC” [10], rather than a straight interpolation between the two.
