Why Parallel Tempering (PT)? An example.
Consider a Bayesian model where the likelihood is a binomial distribution with probability parameter $p$. Let us consider an over-parameterized model where we write $p = p_1 p_2$. Assume that each $p_i$ has a uniform prior on the interval $[0, 1]$. This is a toy example of an unidentifiable parameterization. In practice, many popular Bayesian models are unidentifiable.
When there are many observations, the posterior of an unidentifiable model concentrates on a sub-manifold, making sampling difficult, as shown in the following pair plots:
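The root cause of the difficulty is easy to verify directly: the likelihood depends on $(p_1, p_2)$ only through the product $p_1 p_2$, so it is constant along the curves $p_1 p_2 = c$, and the posterior concentrates near one such curve. A quick numerical check (written in Python purely for illustration; it is not part of the Pigeons example):

```python
from math import comb, log

def binom_loglik(n, k, p):
    # binomial log-likelihood of k successes in n trials
    return log(comb(n, k)) + k * log(p) + (n - k) * log(1 - p)

n, k = 100, 50  # smaller than the 100000 / 50000 used below; same reasoning applies

# two different (p1, p2) pairs with the same product p1 * p2 = 0.5
pair_a = (0.5 / 0.8, 0.8)
pair_b = (0.5 / 0.6, 0.6)

ll_a = binom_loglik(n, k, pair_a[0] * pair_a[1])
ll_b = binom_loglik(n, k, pair_b[0] * pair_b[1])

# the data cannot distinguish the two parameter settings
print(abs(ll_a - ll_b))
```

Since the prior is flat, the posterior inherits this flatness along each curve $p_1 p_2 = c$, which is exactly the sub-manifold structure described above.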
Unidentifiable example without PT
Let us look at trace plots obtained from performing single-chain MCMC on this problem. The key part of the code below is the argument n_chains = 1: we have designed our PT implementation so that setting the number of chains to one reduces to a standard MCMC algorithm.
using DynamicPPL
using Pigeons
using MCMCChains
using StatsPlots
plotlyjs()
# The model described above implemented in Turing
# note we are using a large observation size here
an_unidentifiable_model = Pigeons.toy_turing_unid_target(100000, 50000)
pt = pigeons(
    target = an_unidentifiable_model,
    n_chains = 1, # <- corresponds to single-chain MCMC
    record = [traces])
# collect the statistics and convert to MCMCChains' Chains
samples = Chains(pt)
# create the trace plots
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "no_pt_posterior_densities_and_traces.html");
───────────────────────────────────────────
scans log(Z₁/Z₀) min(αₑ) mean(αₑ)
────────── ────────── ────────── ──────────
2 0 1 1
4 0 1 1
8 0 1 1
16 0 1 1
32 0 1 1
64 0 1 1
128 0 1 1
256 0 1 1
512 0 1 1
1.02e+03 0 1 1
───────────────────────────────────────────
The trace plots make it obvious that mixing is poor, and the effective sample size (ESS) estimates confirm it:
samples
Chains MCMC chain (1024×3×1 Array{Float64, 3}):
Iterations = 1:1:1024
Number of chains = 1
Samples per chain = 1024
parameters = p1, p2
internals = log_density
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯
p1 0.7498 0.1107 0.0672 2.8759 29.9758 1.7339 ⋯
p2 0.6799 0.0889 0.0547 2.8784 30.6664 1.7317 ⋯
1 column omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
p1 0.6379 0.6679 0.6975 0.8143 0.9817
p2 0.5090 0.6135 0.7168 0.7490 0.7841
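For intuition on why a bulk ESS near 3 out of 1024 iterations signals trouble: strong autocorrelation means successive samples carry almost no new information, and the ESS discounts the nominal sample count by the summed autocorrelations. Here is a self-contained sketch of that calculation on a deliberately sticky chain (in Python, for illustration only; MCMCChains uses a more sophisticated estimator):

```python
import random

random.seed(1)
n, rho = 1024, 0.99
# strongly autocorrelated AR(1) chain, mimicking a poorly mixing sampler
x = [0.0]
for _ in range(n - 1):
    x.append(rho * x[-1] + random.gauss(0.0, 1.0))

mean = sum(x) / n
var = sum((v - mean) ** 2 for v in x)

def autocorr(lag):
    return sum((x[t] - mean) * (x[t + lag] - mean) for t in range(n - lag)) / var

# sum autocorrelations, truncating at the first non-positive term
acf_sum, lag = 0.0, 1
while lag < n // 2:
    r = autocorr(lag)
    if r <= 0:
        break
    acf_sum += r
    lag += 1

ess = n / (1 + 2 * acf_sum)
print(round(ess, 1))  # far below the nominal 1024 samples
```

A chain that barely moves between iterations can thus report an ESS orders of magnitude below its length, exactly as in the table above.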
Unidentifiable example with PT
Let us now enable parallel tempering by setting n_chains to a value greater than one:
pt = pigeons(
    target = an_unidentifiable_model,
    n_chains = 10,
    record = [traces, round_trip])
# collect the statistics and convert to MCMCChains' Chains
samples = Chains(pt)
# create the trace plots
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "with_pt_posterior_densities_and_traces.html");
───────────────────────────────────────────────────────────────────────────────────────
scans restarts Λ log(Z₁/Z₀) min(α) mean(α) min(αₑ) mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
2 0 1.04 -4.24e+03 0 0.885 1 1
4 0 4.06 -16.3 4.63e-06 0.549 1 1
8 0 3.49 -12.1 0.215 0.612 1 1
16 0 2.68 -10.2 0.518 0.703 1 1
32 0 4.29 -11.8 0.222 0.524 1 1
64 3 3.17 -11.5 0.529 0.648 1 1
128 8 3.56 -11.5 0.523 0.605 1 1
256 12 3.38 -11.6 0.526 0.625 1 1
512 37 3.48 -12 0.527 0.614 1 1
1.02e+03 77 3.55 -11.8 0.571 0.605 1 1
───────────────────────────────────────────────────────────────────────────────────────
There is a marked difference. Thanks to round trips through the reference distribution, from which we can sample i.i.d., the chain is able to jump between distant parts of the state space.
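The mechanism can be sketched in a few lines: chains at inverse temperatures $\beta$ interpolate between an easy-to-sample reference ($\beta = 0$, where we draw i.i.d.) and the target ($\beta = 1$), and neighbouring chains propose to swap states via a Metropolis accept/reject step. Below is a minimal toy implementation (in Python, for illustration only; a bimodal 1D target stands in for the unidentifiable posterior, and none of these names correspond to the Pigeons API):

```python
import math
import random

random.seed(0)

def log_target(x):
    # bimodal target with well-separated modes at -10 and +10
    return math.log(math.exp(-0.5 * (x - 10) ** 2)
                    + math.exp(-0.5 * (x + 10) ** 2) + 1e-300)

def log_ref(x):
    # broad reference covering both modes; we can sample it i.i.d.
    return -0.5 * (x / 20.0) ** 2

def log_anneal(x, beta):
    # geometric path between reference (beta = 0) and target (beta = 1)
    return (1 - beta) * log_ref(x) + beta * log_target(x)

def mh_step(x, beta, step=2.0):
    # one random-walk Metropolis step targeting the annealed distribution
    y = x + random.gauss(0.0, step)
    if math.log(random.random()) < log_anneal(y, beta) - log_anneal(x, beta):
        return y
    return x

betas = [0.0, 0.25, 0.5, 0.75, 1.0]
states = [0.0] * len(betas)
target_samples = []
for scan in range(5000):
    states[0] = random.gauss(0.0, 20.0)  # i.i.d. draw from the reference
    for j in range(1, len(betas)):
        states[j] = mh_step(states[j], betas[j])
    # swap proposals between neighbouring chains
    for i in range(len(betas) - 1):
        log_alpha = (log_anneal(states[i], betas[i + 1])
                     + log_anneal(states[i + 1], betas[i])
                     - log_anneal(states[i], betas[i])
                     - log_anneal(states[i + 1], betas[i + 1]))
        if math.log(random.random()) < log_alpha:
            states[i], states[i + 1] = states[i + 1], states[i]
    target_samples.append(states[-1])

# fresh reference draws percolate up to the target chain,
# so it visits both modes instead of getting stuck in one
pos = sum(s > 0 for s in target_samples)
neg = sum(s < 0 for s in target_samples)
```

The swap acceptance ratio is exactly the restarts machinery the table above is counting: each time a state born at the reference reaches the target chain, one more regeneration of the target chain has occurred.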
This is also confirmed by the PT ESS estimates:
samples
Chains MCMC chain (1024×3×1 Array{Float64, 3}):
Iterations = 1:1:1024
Number of chains = 1
Samples per chain = 1024
parameters = p1, p2
internals = log_density
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯
p1 0.7002 0.1380 0.0125 122.5666 159.2296 1.0070 ⋯
p2 0.7419 0.1430 0.0127 122.1132 170.5302 1.0108 ⋯
1 column omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
p1 0.5130 0.5702 0.6919 0.8044 0.9703
p2 0.5160 0.6214 0.7222 0.8771 0.9751