I recently published a paper at ICML that I somehow decided to title “A Divergence Bound for Hybrids of MCMC and Variational Inference and an Application to Langevin Dynamics and SGVI”. The paper gives one framework for building “hybrid” algorithms between Markov chain Monte Carlo (MCMC) and variational inference (VI). Then, it works out an example for particular algorithms, namely Langevin dynamics and stochastic gradient VI (SGVI). This post is organized around the following questions:
- Is there a pleasing visualization?
- What is VI?
- What is MCMC?
- Why would you want to interpolate between them?
- But they are so different. Is it even possible to combine them?
- That’s fine intuition but how can you guarantee anything formally?
- What distribution optimizes that bound?
- How would you apply this to real algorithms?
- I feel like you’re skipping some details here.
- Are there some more pictures of samples?
- What about “real” problems? Can you show samples there?
- Do you get the desired speed/accuracy tradeoff?
- Is there anything that remains unsatisfying about this approach?
Here are three different views of the algorithm for a one-dimensional problem, interpolating between VI-like algorithms and MCMC-like algorithms as β goes from 0 (VI) to 1 (MCMC).
(Admittedly, this might not make that much sense at this point.)
The goal of “inference” is to be able to evaluate expectations with respect to some “target” distribution $p(z)$. Variational inference (VI) converts this problem into the minimization of the KL-divergence $\mathrm{KL}(q(z|w) \,\|\, p(z))$ over some simple class of distributions $q(z|w)$ parameterized by $w$. For example, if $p$ is a mixture of two Gaussians (in red), and $q$ is a single Gaussian (in blue), the VI optimization in one dimension would arrive at the solution below.
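To make the VI objective concrete, here is a minimal sketch that fits a single Gaussian $q$ to a two-component Gaussian mixture by brute-force minimization of $\mathrm{KL}(q \,\|\, p)$, with the KL computed by numerical quadrature on a grid. The mixture (components at ±4 with unit variance), the grid ranges, and the helper names are illustrative assumptions, not taken from the paper.

```python
import numpy as np

def log_normal(z, mean, std):
    """Log density of N(mean, std^2) evaluated at z."""
    return -0.5 * np.log(2 * np.pi * std**2) - (z - mean) ** 2 / (2 * std**2)

def log_target(z):
    """Hypothetical target p(z): equal mixture of N(-4, 1) and N(4, 1)."""
    return np.logaddexp(np.log(0.5) + log_normal(z, -4.0, 1.0),
                        np.log(0.5) + log_normal(z, 4.0, 1.0))

def kl_q_p(mus, std, z):
    """KL(q || p) for q = N(mu, std^2), one value per entry of mus, by quadrature on grid z."""
    dz = z[1] - z[0]
    log_q = log_normal(z[None, :], mus[:, None], std)   # shape (len(mus), len(z))
    q = np.exp(log_q)
    return np.sum(q * (log_q - log_target(z)[None, :]), axis=1) * dz

z = np.linspace(-20.0, 20.0, 8001)      # quadrature grid
mus = np.linspace(-6.0, 6.0, 121)       # candidate means (step 0.1)
stds = np.linspace(0.2, 3.0, 57)        # candidate standard deviations (step 0.05)

# Brute-force search over the (mean, std) grid for the KL-minimizing Gaussian.
best = (np.inf, None, None)
for std in stds:
    kls = kl_q_p(mus, std, z)
    i = np.argmin(kls)
    if kls[i] < best[0]:
        best = (kls[i], mus[i], std)

best_kl, best_mu, best_std = best
print(f"mu={best_mu:.2f}, std={best_std:.2f}, KL={best_kl:.3f}")
```

Because $\mathrm{KL}(q \,\|\, p)$ heavily penalizes putting mass where $p$ is small, the search settles on one of the two components (mean near ±4, standard deviation near 1) rather than a broad Gaussian covering both.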
Given that same target distribution $p(z)$, Markov chain Monte Carlo creates a random walk over $z$. The random walk is carefully constructed so that if you run it for a long time, the probability that it ends up in any given state is proportional to $p(z)$. You can picture it as follows.
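As one concrete instance of such a random walk, here is a minimal random-walk Metropolis sketch (the paper itself builds on Langevin dynamics instead). Proposals are symmetric Gaussian steps, and the accept/reject rule makes $p$ the stationary distribution. The function name, the standard Gaussian test target, and the step size are illustrative assumptions.

```python
import numpy as np

def random_walk_metropolis(log_p, z0, n_steps, step_size, rng):
    """Random-walk Metropolis: Gaussian proposals, accepted with prob min(1, p(z') / p(z))."""
    z = np.atleast_1d(np.asarray(z0, dtype=float))
    log_pz = log_p(z)
    samples = np.empty((n_steps, z.size))
    for t in range(n_steps):
        proposal = z + step_size * rng.standard_normal(z.size)
        log_p_prop = log_p(proposal)
        # Accept with probability min(1, p(proposal) / p(z)); otherwise stay put.
        if np.log(rng.random()) < log_p_prop - log_pz:
            z, log_pz = proposal, log_p_prop
        samples[t] = z
    return samples

# Example: a standard Gaussian target, log p(z) = -|z|^2 / 2 up to a constant.
rng = np.random.default_rng(0)
samples = random_walk_metropolis(lambda z: -0.5 * np.sum(z**2), 0.0, 50_000, 2.4, rng)
kept = samples[1_000:, 0]   # discard a short burn-in
print(kept.mean(), kept.var())
```

The empirical mean and variance should come out close to 0 and 1. On a multimodal target such as the two-Gaussian mixture above, a walk like this can take a very long time to hop between modes.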
In short, VI is only an approximate algorithm, but MCMC can be very slow. In practice, the difference can be enormous: MCMC may require many orders of magnitude more time just to equal the performance of VI. This leaves the user in an awkward situation: if one chooses the best algorithm for each time horizon, there is a “gap” between when VI finishes and when MCMC becomes better. Informally, you can get performance that looks like this:
Intuitively, it seems like something better should be achievable at those intermediate times.
Very roughly speaking, you can define a random walk over the space of variational distributions. Then, you trade off between how random the walk is and how random the distributions are. You arrive at something like this:
Put another way, both VI and MCMC seek “high probability” regions in $z$, but with different coverage strategies:
- VI explicitly includes the entropy in its objective
- MCMC injects randomness into its walk
It is therefore natural to define a random walk over $w$, where we trade off between “how random” the walk is and how strongly high-entropy distributions $q(z|w)$ are favored.
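Yes! Or, at least, sort of. To fix some notation, we start with a fixed variational family $q(z|w)$ and a target distribution $p(z)$. Now, we want to define a distribution $r(w)$ over the variational parameters (so we can do a random walk) so that the resulting marginal over $z$ is close to the target:

$$ r(z) = \int r(w)\, q(z|w)\, dw \approx p(z). $$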
The natural goal would be to minimize the KL-divergence $\mathrm{KL}(r(z) \,\|\, p(z))$. That's difficult, since $r(z)$ is defined by marginalizing out $w$, so you can't evaluate it. What you can do is set up two upper bounds on this quantity.
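The first bound is the conditional divergence:

$$ \mathrm{KL}\big(r(z) \,\|\, p(z)\big) \le D_C(r) := \mathbb{E}_{r(w)}\, \mathrm{KL}\big(q(z|w) \,\|\, p(z)\big). $$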
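The second bound is the joint divergence. You need to augment $p$ with some conditional distribution $p(w|z)$, and then you have the bound

$$ \mathrm{KL}\big(r(z) \,\|\, p(z)\big) \le D_J(r) := \mathrm{KL}\big(r(w)\, q(z|w) \,\|\, p(z)\, p(w|z)\big). $$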
Since these are both upper bounds, any convex combination of them is also an upper bound. Thus, the goal is to find the distribution $r(w)$ that minimizes

$$ D_\beta(r) = (1-\beta)\, D_C(r) + \beta\, D_J(r) $$

for any $\beta$ in the $[0,1]$ interval.
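Both bounds are easy to check numerically when $z$ and $w$ are discrete. The sketch below draws a random target $p(z)$, a random family $q(z|w)$, a random $r(w)$, and an arbitrary augmentation $p(w|z)$, then confirms that $\mathrm{KL}(r(z) \,\|\, p(z))$ sits below the conditional bound $D_C = \mathbb{E}_{r(w)}\mathrm{KL}(q(z|w)\,\|\,p(z))$, the joint bound $D_J = \mathrm{KL}(r(w)q(z|w)\,\|\,p(z)p(w|z))$, and every convex combination of the two. Everything here is a made-up toy; the sizes and seed are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
n_w, n_z = 4, 6

def random_dist(shape):
    """Random strictly positive distribution(s), normalized over the last axis."""
    x = rng.random(shape) + 0.1
    return x / x.sum(axis=-1, keepdims=True)

def kl(a, b):
    """KL divergence between (possibly multi-dimensional) discrete distributions."""
    return float(np.sum(a * np.log(a / b)))

p_z = random_dist(n_z)                  # target p(z)
q = random_dist((n_w, n_z))             # q[w, z] = q(z | w)
r_w = random_dist(n_w)                  # distribution r(w) over variational parameters
p_w_given_z = random_dist((n_z, n_w))   # arbitrary augmentation p(w | z), rows indexed by z

r_z = r_w @ q                           # marginal r(z) = sum_w r(w) q(z | w)
true_kl = kl(r_z, p_z)

# Conditional divergence: E_{r(w)} KL(q(z|w) || p(z)).
d_c = sum(r_w[w] * kl(q[w], p_z) for w in range(n_w))

# Joint divergence: KL(r(w) q(z|w) || p(z) p(w|z)), both arrays indexed as [w, z].
d_j = kl(r_w[:, None] * q, p_z[None, :] * p_w_given_z.T)

for beta in np.linspace(0.0, 1.0, 5):
    d_beta = (1 - beta) * d_c + beta * d_j
    print(f"beta={beta:.2f}  KL(r(z)||p(z))={true_kl:.4f}  bound={d_beta:.4f}")
```

Changing the seed or `p_w_given_z` changes how tight $D_J$ is, but never makes it drop below the true marginal KL.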
First, note that $D_J$ depends on the choice of $p(w|z)$. You get a valid upper bound for any choice, but the tightness changes. The paper uses $p(w|z) = q(z|w)\, s(w) / s(z)$, where $s(z) = \int s(w)\, q(z|w)\, dw$ is a normalizing constant. Here, you can think of $s(w)$ as something akin to a “base measure”. $s(z)$ is restricted to be constant over $z$. (This isn't a terrible restriction: it essentially means that if $s(w)$ were a prior for $w$, it wouldn't favor any point $z$.)
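Taking that choice of $p(w|z)$, the solution turns out to be

$$ r_\beta(w) = \frac{1}{Z_\beta}\, s(w) \exp\Big( -\frac{1}{\beta}\, \mathrm{KL}\big(q(z|w) \,\|\, p(z)\big) - H\big(q(z|w)\big) \Big), $$

where $H$ denotes entropy and $Z_\beta$ is a normalizing constant.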
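Furthermore, the value of the divergence bound at the solution is, up to scaling and an additive constant, just the log of the normalizing constant:

$$ D_\beta(r_\beta) = -\beta \log Z_\beta + \beta \log s(z), $$

where the last term does not depend on $z$ because $s(z)$ is constant.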
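This closed form can be sanity-checked on a discrete toy. The sketch assumes a hypothetical family in which each $w$ is a (kernel, shift) pair on a cycle of $n$ states, $q(z|w)$ is that kernel shifted by that amount, and $s(w)$ weights the kernels arbitrarily but is uniform over shifts, so that $s(z) = \sum_w s(w)\, q(z|w)$ is constant in $z$ as required. It forms $r_\beta(w) \propto s(w) \exp(-\mathrm{KL}(q(z|w)\,\|\,p(z))/\beta - H(q(z|w)))$, checks that $D_\beta(r_\beta) = -\beta \log Z_\beta + \beta \log s(z)$, and checks that random choices of $r(w)$ never achieve a smaller bound.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 6                                              # z ranges over {0, ..., n-1}, treated as a cycle

# Hypothetical family: three circular "bump" kernels of different widths, every cyclic shift of each.
dist = np.minimum(np.arange(n), n - np.arange(n))  # circular distance from state 0
kernels = [np.exp(-a * dist**2) for a in (0.2, 1.0, 3.0)]
kernels = [k / k.sum() for k in kernels]
q = np.array([np.roll(k, c) for k in kernels for c in range(n)])   # q[w, z] = q(z | w)

# Hypothetical s(w): arbitrary weight per kernel, uniform over shifts, so s(z) is constant.
s_w = np.repeat([0.2, 0.3, 0.5], n) / n
s_z = s_w @ q                                      # s(z) = sum_w s(w) q(z | w)

p = rng.random(n) + 0.1
p /= p.sum()                                       # random target p(z)

def kl(a, b):
    return float(np.sum(a * np.log(a / b)))

kl_w = np.array([kl(qw, p) for qw in q])           # KL(q(z|w) || p(z)) for each w
ent_w = -np.sum(q * np.log(q), axis=1)             # H(q(z|w)) for each w

def bound(r, beta):
    """D_beta(r) = (1 - beta) D_C(r) + beta D_J(r) with p(w|z) = q(z|w) s(w) / s(z)."""
    d_c = r @ kl_w
    joint_r = r[:, None] * q                       # r(w) q(z|w)
    joint_p = p[None, :] * q * s_w[:, None] / s_z[None, :]   # p(z) p(w|z)
    return (1 - beta) * d_c + beta * kl(joint_r, joint_p)

def optimal_r(beta):
    """r_beta(w) proportional to s(w) exp(-KL_w / beta - H_w); also returns log Z_beta."""
    log_unnorm = np.log(s_w) - kl_w / beta - ent_w
    log_z = np.logaddexp.reduce(log_unnorm)
    return np.exp(log_unnorm - log_z), log_z

beta = 0.5
r_opt, log_z = optimal_r(beta)
value = bound(r_opt, beta)
predicted = -beta * log_z + beta * np.log(s_z[0])
random_values = [bound(rng.dirichlet(np.ones(len(q))), beta) for _ in range(20)]
print(value, predicted, min(random_values))
```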
To do anything concrete, you need to look at a specific VI algorithm and a specific MCMC algorithm. The paper uses
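- Langevin dynamics as an MCMC algorithm, which iterates the updates
  $$ z \leftarrow z + \epsilon\, \nabla_z \log p(z) + \sqrt{2\epsilon}\, \eta, $$
  where $\eta$ is noise from a standard Gaussian and $\epsilon$ is a step-size.
- Stochastic gradient VI, which uses the iteration
  $$ w \leftarrow w - \epsilon\, \nabla_w \mathrm{KL}\big(q(z|w) \,\|\, p(z)\big). $$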
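For concreteness, here is a minimal sketch of both base iterations on a one-dimensional Gaussian target $p(z) = N(M, V)$, where everything is available in closed form. It assumes the Langevin form $z \leftarrow z + \epsilon \nabla \log p(z) + \sqrt{2\epsilon}\,\eta$ (the unadjusted discretization, with no Metropolis correction) and parameterizes $q(z|w) = N(\mu, \sigma^2)$ with $w = (\mu, \log \sigma)$; the target, the parameterization, and the step sizes are illustrative assumptions.

```python
import numpy as np

M, V = 1.0, 4.0                         # hypothetical target p(z) = N(M, V)

def grad_log_p(z):
    return -(z - M) / V

def grad_kl(w):
    """Gradient of KL(N(mu, sigma^2) || N(M, V)) with respect to w = (mu, rho), sigma = exp(rho)."""
    mu, rho = w
    return np.array([(mu - M) / V, -1.0 + np.exp(2 * rho) / V])

def langevin_step(z, eps, rng):
    """Unadjusted Langevin: z <- z + eps * grad log p(z) + sqrt(2 eps) * eta."""
    return z + eps * grad_log_p(z) + np.sqrt(2 * eps) * rng.standard_normal(np.shape(z))

def sgvi_step(w, eps):
    """Gradient descent on KL(q(z|w) || p(z)); exact gradient here, a stochastic estimate in general."""
    return w - eps * grad_kl(w)

rng = np.random.default_rng(0)
z = np.zeros(4000)                      # 4000 independent Langevin chains
for _ in range(500):
    z = langevin_step(z, 0.1, rng)

w = np.array([0.0, 0.0])
for _ in range(2000):
    w = sgvi_step(w, 0.1)

print(z.mean(), z.var())                # roughly M and V (the discretization inflates V slightly)
print(w[0], np.exp(w[1]))               # converges to M and sqrt(V)
```

Langevin returns a cloud of samples whose spread matches $p$; SGVI returns a single parameter vector $w$ whose $q(z|w)$ matches $p$ here only because the target happens to be Gaussian.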
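To get the novel algorithm in this paper, all that really needs to be done is to apply Langevin dynamics to the distribution $r_\beta(w)$ derived above. Then, after rescaling the step-size, this becomes the new hybrid algorithm

$$ w \leftarrow w + \epsilon\, \nabla_w \Big[ \beta \log s(w) - \mathrm{KL}\big(q(z|w) \,\|\, p(z)\big) - \beta\, H\big(q(z|w)\big) \Big] + \sqrt{2\beta\epsilon}\, \eta. $$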
Here, $H(q(z|w))$ is the entropy of $q(z|w)$. This clearly becomes the previous VI algorithm in the limit $\beta \to 0$. It also essentially becomes Langevin dynamics on $p(z)$ with $\beta = 1$. That's because the distribution $s(w)$ (not yet specified!) will prefer $w$ where $q(z|w)$ is highly concentrated. Thus, only the mean parameters of $q$ matter, and sampling $r(w)$ becomes equivalent to just sampling $p(z)$.
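Here is a minimal sketch of the hybrid update $w \leftarrow w + \epsilon \nabla_w[\beta \log s(w) - \mathrm{KL}(q(z|w)\,\|\,p(z)) - \beta H(q(z|w))] + \sqrt{2\beta\epsilon}\,\eta$ in a one-dimensional Gaussian setting: target $p(z) = N(M, V)$ and $q(z|w) = N(\mu, \sigma^2)$ with $w = (\mu, \log \sigma)$, so the KL and entropy gradients are closed-form. The flat `grad_log_s_flat` is a purely hypothetical placeholder for $s(w)$, not the choice used in the paper's experiments, and the step-size scaling convention is an assumption.

```python
import numpy as np

M, V = 1.0, 4.0                          # hypothetical Gaussian target p(z) = N(M, V)

def grad_kl(w):
    """d/dw KL(N(mu, sigma^2) || N(M, V)) for w = (mu, rho), sigma = exp(rho)."""
    mu, rho = w
    return np.array([(mu - M) / V, -1.0 + np.exp(2 * rho) / V])

def grad_entropy(w):
    """H(N(mu, sigma^2)) = 0.5 * log(2 pi e) + rho, so its gradient is (0, 1)."""
    return np.array([0.0, 1.0])

def grad_log_s_flat(w):
    """Hypothetical placeholder: flat (improper) s(w). Not the paper's choice."""
    return np.zeros(2)

def hybrid_step(w, beta, eps, rng, grad_log_s=grad_log_s_flat):
    """One Langevin step on r_beta(w), written with the step size rescaled by beta."""
    drift = beta * grad_log_s(w) - grad_kl(w) - beta * grad_entropy(w)
    return w + eps * drift + np.sqrt(2 * beta * eps) * rng.standard_normal(2)

rng = np.random.default_rng(0)
w = np.array([0.0, 0.0])
trace = []
for _ in range(40_000):
    w = hybrid_step(w, beta=0.1, eps=0.05, rng=rng)
    trace.append(w.copy())
trace = np.array(trace)
print(trace[5_000:, 0].mean(), np.exp(trace[5_000:, 1]).mean())   # average mu and sigma after burn-in
```

Setting $\beta = 0$ removes the noise and the $s$ and entropy terms, leaving exactly the SGVI step. With this flat placeholder and $\beta = 1$, the drift on $\log \sigma$ is always negative and nothing keeps $\sigma$ from shrinking toward zero, which is one way to see why the choice of $s(w)$ matters.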
Yes. First, the experiments below use a diagonal Gaussian for $q(z|w)$, with $w$ holding its mean and scale parameters. Second, the gradient of the objective involves a KL-divergence. Computing this exactly is intractable, but it can be approximated with standard tricks from stochastic VI, namely data subsampling and the “reparameterization trick”. Third, $s(w)$ needs to be chosen. The experiments below use an (improper) choice of $s(w)$ involving a universal constant, which is chosen separately for each $\beta$ to minimize the divergence bound when $p$ is a standard Gaussian. (If you, like me, find this displeasing, see below.)
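The two tricks can be sketched as follows, under assumptions that are mine rather than the paper's: $q$ is a diagonal Gaussian with $w = (\mu, \rho)$ and $\sigma = \exp(\rho)$, and the target is the posterior of a Bayesian logistic regression with a standard Gaussian prior and labels in $\{-1, +1\}$. The reparameterization trick writes $z = \mu + \sigma \odot \varepsilon$ with $\varepsilon \sim N(0, I)$, and subsampling replaces the sum over all $N$ data points with a rescaled sum over a random minibatch; combined, they give an unbiased estimate of $\nabla_w \mathrm{KL}(q \,\|\, p)$. The function names are hypothetical.

```python
import numpy as np

def grad_log_post_minibatch(z, X, y, batch_size, rng):
    """Unbiased estimate of grad log p(z | data) for logistic regression with a N(0, I) prior.

    z: (d,) or (S, d); X: (N, d); y: (N,) with entries in {-1, +1}.
    """
    n = X.shape[0]
    idx = rng.choice(n, size=batch_size, replace=False)
    Xb, yb = X[idx], y[idx]
    margins = yb * (z @ Xb.T)                             # y_n * (x_n . z) for each sample in z
    weights = yb * np.exp(-np.logaddexp(0.0, margins))    # y_n * sigmoid(-margin), computed stably
    return -z + (n / batch_size) * (weights @ Xb)         # prior term + rescaled likelihood term

def reparam_grad_kl(mu, rho, grad_log_p, n_samples, rng):
    """Reparameterization-trick estimate of grad KL(q || p) for q = N(mu, diag(exp(rho))^2)."""
    sigma = np.exp(rho)
    eps = rng.standard_normal((n_samples, mu.size))
    z = mu + sigma * eps
    g = grad_log_p(z)                                     # (n_samples, d)
    grad_mu = -g.mean(axis=0)                             # -E[grad log p(z)]
    grad_rho = -1.0 - (g * sigma * eps).mean(axis=0)      # entropy contributes -1 per coordinate
    return grad_mu, grad_rho

# Usage on synthetic data (made up purely for illustration).
rng = np.random.default_rng(0)
X = rng.standard_normal((500, 3))
y = np.where(rng.random(500) < 0.5, -1.0, 1.0)
grad_fn = lambda z: grad_log_post_minibatch(z, X, y, batch_size=50, rng=rng)
g_mu, g_rho = reparam_grad_kl(np.zeros(3), np.zeros(3), grad_fn, n_samples=10, rng=rng)
```

A few samples and a small minibatch give a noisy but unbiased gradient, which is all the stochastic updates need.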
Here are a couple of 2-D examples sampling from a “doughnut” distribution and a “three peaks” mixture distribution. Each sampled $w$ is visualized by drawing a curve one standard deviation around the mean of $q(z|w)$. Notice that the samples smoothly become more “MCMC-like” as $\beta$ increases.
Sure, but of course in more than 2 dimensions it's hard to show samples. Here are some results sampling from a logistic regression model on the classic ionosphere dataset. As a comparison, I implemented the same model in Stan and ran it for a very long time to generate “presumed correct” samples. I then projected all samples onto the first two principal components.
(Note: technically, what's shown here is a sample $z$ drawn from $q(z|w)$ for each sampled $w$.)
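The projection step can be sketched as below. Which samples define the principal components is an assumption here: this sketch fits them on the long reference run and applies the same projection to every other sampler's output, so all plots share one coordinate system. The function name and the synthetic data are hypothetical.

```python
import numpy as np

def pca_project(reference, samples, k=2):
    """Project `samples` onto the top-k principal components of `reference` (rows are draws)."""
    center = reference.mean(axis=0)
    # Rows of vt are principal directions, sorted by decreasing singular value.
    _, _, vt = np.linalg.svd(reference - center, full_matrices=False)
    return (samples - center) @ vt[:k].T

rng = np.random.default_rng(0)
reference = rng.standard_normal((2000, 5)) * np.array([5.0, 2.0, 1.0, 0.5, 0.1])
other = reference[:300] + 0.5                       # stand-in for another sampler's output
ref_2d = pca_project(reference, reference)
other_2d = pca_project(reference, other)
```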
The top row shows the results after $10^4$ iterations, the middle row after $10^5$, and the bottom row after $10^6$. You can roughly see that for short time horizons you are better off using a lower value of $\beta$, while for longer time horizons you should use a larger value.
Here, you need to compare the error each value of $\beta$ yields at each time horizon. This is made difficult by the fact that you also need to select a step-size, and the best step-size changes depending on the time and on $\beta$. To be as fair as possible, for each of a range of step-sizes I ran 100 experiments and averaged the performance. Then, for each value of $\beta$ and each time horizon, the results are shown with the best step-size. (Actually, this same procedure was used to generate the previous plots of samples as well.)
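The selection protocol amounts to averaging over repeated runs and then minimizing over step-sizes separately for each $\beta$ and time horizon. A minimal NumPy sketch, assuming the errors are stored in a hypothetical array indexed as [β, step-size, time horizon, repetition]:

```python
import numpy as np

def best_over_step_sizes(errors):
    """errors[b, s, t, r]: error for beta b, step-size s, horizon t, repetition r.

    Returns the best mean error and the index of the step-size achieving it,
    both with shape (n_betas, n_horizons).
    """
    mean_err = errors.mean(axis=3)              # average over repeated runs
    return mean_err.min(axis=1), mean_err.argmin(axis=1)

# Tiny made-up example: one beta, two step-sizes, one horizon, two repetitions.
errors = np.array([[[[1.0, 3.0]],               # step-size 0: mean error 2.0
                    [[1.0, 2.0]]]])             # step-size 1: mean error 1.5
best_err, best_idx = best_over_step_sizes(errors)
```

Averaging before minimizing matters: picking the best single run per step-size would reward lucky noise rather than a genuinely better step-size.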
The above plot shows the error (measured by maximum mean discrepancy, MMD) on the y-axis against time on the x-axis. Note that both axes are logarithmic. There are often several orders of magnitude of time horizons where an intermediate algorithm performs better than pure VI (β=0) or pure MCMC (β=1).
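For reference, here is a minimal sketch of the biased (V-statistic) estimate of squared MMD between two sample sets with a Gaussian kernel. The kernel and its bandwidth used in the experiments aren't specified here, so both are assumptions.

```python
import numpy as np

def gaussian_kernel(A, B, bandwidth):
    """k(a, b) = exp(-|a - b|^2 / (2 * bandwidth^2)) for all pairs of rows of A and B."""
    sq_dists = np.sum(A**2, axis=1)[:, None] + np.sum(B**2, axis=1)[None, :] - 2.0 * A @ B.T
    return np.exp(-np.maximum(sq_dists, 0.0) / (2.0 * bandwidth**2))

def mmd2_biased(X, Y, bandwidth=1.0):
    """Biased (V-statistic) estimate of squared MMD between samples X (n, d) and Y (m, d)."""
    return (gaussian_kernel(X, X, bandwidth).mean()
            + gaussian_kernel(Y, Y, bandwidth).mean()
            - 2.0 * gaussian_kernel(X, Y, bandwidth).mean())

rng = np.random.default_rng(0)
X = rng.standard_normal((500, 2))
Y = rng.standard_normal((500, 2)) + 3.0   # same shape of distribution, shifted mean
print(mmd2_biased(X, X), mmd2_biased(X, Y))
```

Identical sample sets give zero, and the value grows as the two distributions separate; comparing each sampler's output to the long reference run this way yields one error number per setting.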
The most unsatisfying thing about this approach is the need to choose $s(w)$. This is a bit disturbing, since it is not an object that “exists” in either the pure MCMC or the pure VI world. On the other hand, there is a strong argument that it needs to exist here. If you carefully observe the solution above, you'll notice that it depends on the particular parameterization of $q$. So, e.g., if we “stretched out” part of the space of $w$, this would change the marginal density $r(z)$. That would be truly disturbing, but if $s(w)$ is transformed in the opposite way, it counteracts that change. So, $s$ needs to exist to reflect how we've parameterized $q$.
On the other hand, simply picking a single static distribution $s(w)$ is pretty simplistic. (Recall that $p(w|z)$ was defined in terms of $s$ above.) It would be natural to try to adjust this distribution during inference to tighten the bound $D_\beta$. Using the expression for the bound at the optimum in terms of the normalizing constant $Z_\beta$, you can show that it's possible to find derivatives of the bound with respect to the parameters of $s$ online, and thus tighten the bound while the algorithm is running. (I didn't want to do this in this paper since neither VI nor MCMC does this, and it complicates the interpretation of the experiments.)
Perhaps the main open question is whether this can be extended to other pairs of VI / MCMC algorithms. I actually first derived this algorithm by looking at simple discrete graphical models, e.g. Ising models. There, you can use the algorithms:
- VI: Use a fully-factorized variational distribution and single-site coordinate ascent updates
- MCMC: Use single-site Gibbs sampling.
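Both endpoints are short to write down. A minimal sketch, assuming spins $x_i \in \{-1, +1\}$ and $p(x) \propto \exp(\tfrac{1}{2} x^\top J x + h^\top x)$ with $J$ symmetric and zero on the diagonal (a common convention, assumed here): the mean-field coordinate update sets $m_i = \tanh(h_i + \sum_j J_{ij} m_j)$, and single-site Gibbs resamples $x_i$ with $p(x_i = +1 \mid \text{rest}) = \tfrac{1}{2}(1 + \tanh(h_i + \sum_j J_{ij} x_j))$. The model size and couplings are hypothetical.

```python
import numpy as np

def mean_field_sweep(m, J, h):
    """One pass of single-site coordinate ascent; m[i] = E_q[x_i] under a fully factorized q."""
    m = m.copy()
    for i in range(len(m)):
        m[i] = np.tanh(h[i] + J[i] @ m)          # J[i, i] = 0, so m[i] itself does not enter
    return m

def gibbs_sweep(x, J, h, rng):
    """One pass of single-site Gibbs sampling."""
    x = x.copy()
    for i in range(len(x)):
        field = h[i] + J[i] @ x
        p_plus = 0.5 * (1.0 + np.tanh(field))     # p(x_i = +1 | rest) = sigmoid(2 * field)
        x[i] = 1.0 if rng.random() < p_plus else -1.0
    return x

# A small random Ising model (hypothetical sizes and couplings).
rng = np.random.default_rng(0)
n = 8
J = np.triu(rng.normal(scale=0.3, size=(n, n)), 1)
J = J + J.T                                      # symmetric, zero diagonal
h = rng.normal(scale=0.5, size=n)

m = np.zeros(n)
for _ in range(100):
    m = mean_field_sweep(m, J, h)

x = np.ones(n)
draws = []
for _ in range(5000):
    x = gibbs_sweep(x, J, h, rng)
    draws.append(x)
gibbs_means = np.mean(draws, axis=0)
print(np.round(m, 2))
print(np.round(gibbs_means, 2))
```

The two loops have nearly the same structure: mean-field plugs in the deterministic expectation $\tanh(\cdot)$ where Gibbs draws a random spin, which is what makes an interpolation between them feel natural.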
You do in fact get a useful hybrid algorithm in the middle. However, the unfortunate reality is that both of the endpoints are considered pretty bad algorithms, so it's hard to get too excited about the interpolation.
Finally, do note that there are other ideas out there for combining MCMC and VI. However, these usually fall into the camps of “putting MCMC inside of VI” or “putting VI inside of MCMC”, rather than a straight interpolation of the two.