Why Parallel Tempering (PT)? An example.
Consider a Bayesian model where the likelihood is a binomial distribution with probability parameter $p$. Let us consider an over-parameterized model where we write $p = p_1 p_2$. Assume that each $p_i$ has a uniform prior on the interval $[0, 1]$. This is a toy example of an unidentifiable parameterization. In practice many popular Bayesian models are unidentifiable.
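To see concretely why this parameterization is unidentifiable, note that with uniform priors the posterior depends on $(p_1, p_2)$ only through the product $p = p_1 p_2$. Here is a small self-contained sketch of the unnormalized log-posterior (the function name is ours for illustration, not part of Pigeons):

```julia
# Unnormalized log-posterior of the toy model with n trials and k successes,
# under uniform priors on p1 and p2 (illustrative sketch, not Pigeons code).
function unid_log_posterior(p1, p2; n = 100_000, k = 50_000)
    p = p1 * p2
    return k * log(p) + (n - k) * log(1 - p)
end

# Any two pairs with the same product p1 * p2 have the same posterior
# density, so the posterior concentrates on the curve p1 * p2 ≈ k / n = 0.5:
on_curve_a = unid_log_posterior(0.9, 0.5 / 0.9)
on_curve_b = unid_log_posterior(0.6, 0.5 / 0.6)
off_curve  = unid_log_posterior(0.9, 0.9)

on_curve_a ≈ on_curve_b  # same product, same density
off_curve < on_curve_a   # off the curve, much lower density
```

The larger the observation count, the tighter the posterior hugs this curve, which is what makes single-chain exploration slow.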
When there are many observations, the posteriors of unidentifiable models concentrate on a sub-manifold of the parameter space, making sampling difficult, as shown in the following pair plots:
Unidentifiable example without PT
Let us look at trace plots obtained from performing single-chain MCMC on this problem. The key part of the code below is the argument n_chains = 1: we have designed our PT implementation so that setting the number of chains to one reduces to a standard MCMC algorithm.
using DynamicPPL
using Pigeons
using MCMCChains
using StatsPlots
plotlyjs()
# The model described above implemented in Turing
# note we are using a large observation size here
an_unidentifiable_model = Pigeons.toy_turing_unid_target(100000, 50000)
pt = pigeons(
    target = an_unidentifiable_model,
    n_chains = 1, # <- corresponds to single-chain MCMC
    record = [traces])
# collect the statistics and convert to MCMCChains' Chains
samples = Chains(pt)
# create the trace plots
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "no_pt_posterior_densities_and_traces.html");
───────────────────────────────────────────
scans log(Z₁/Z₀) min(αₑ) mean(αₑ)
────────── ────────── ────────── ──────────
2 0 1 1
4 0 1 1
8 0 1 1
16 0 1 1
32 0 1 1
64 0 1 1
128 0 1 1
256 0 1 1
512 0 1 1
1.02e+03 0 1 1
───────────────────────────────────────────
It is quite obvious that mixing is poor, as confirmed by the effective sample size (ESS) estimates:
samples
Chains MCMC chain (1024×3×1 Array{Float64, 3}):
Iterations = 1:1:1024
Number of chains = 1
Samples per chain = 1024
parameters = p1, p2
internals = log_density
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat e ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯
p1 0.9646 0.0244 0.0047 29.1938 76.8349 1.0231 ⋯
p2 0.5187 0.0135 0.0026 29.2742 72.9775 1.0226 ⋯
1 column omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
p1 0.9122 0.9491 0.9683 0.9846 0.9983
p2 0.5004 0.5080 0.5166 0.5269 0.5488
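For intuition, ESS discounts the nominal number of draws by the chain's autocorrelation, roughly $n / (1 + 2\sum_t \rho_t)$. The following is a crude stand-alone sketch of that formula (not the rank-normalized estimator MCMCChains actually uses), truncating the sum at the first non-positive autocorrelation:

```julia
# Crude ESS estimate: n / (1 + 2 * sum of autocorrelations). Sketch of the
# idea only; MCMCChains' ess uses a more robust estimator.
function crude_ess(x::AbstractVector{<:Real})
    n = length(x)
    μ = sum(x) / n
    c0 = sum(abs2, x .- μ) / n
    s = 0.0
    for t in 1:(n - 1)
        # lag-t autocorrelation
        ρ = sum((x[i] - μ) * (x[i + t] - μ) for i in 1:(n - t)) / (n * c0)
        ρ <= 0 && break
        s += ρ
    end
    return n / (1 + 2s)
end

# A strongly autocorrelated AR(1) chain, like the sticky traces above,
# yields a much smaller ESS than the same number of i.i.d. draws:
slow = zeros(4096)
for i in 2:length(slow)
    slow[i] = 0.95 * slow[i - 1] + randn()
end
crude_ess(slow)        # far below 4096
crude_ess(randn(4096)) # close to 4096
```

An ESS around 30 out of 1024 draws, as in the table above, means the sampler is paying for roughly 35 iterations per effectively independent sample.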
Unidentifiable example with PT
Let us enable parallel tempering now by setting n_chains to a value greater than one:
pt = pigeons(
    target = an_unidentifiable_model,
    n_chains = 10,
    record = [traces, round_trip])
# collect the statistics and convert to MCMCChains' Chains
samples = Chains(pt)
# create the trace plots
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "with_pt_posterior_densities_and_traces.html");
───────────────────────────────────────────────────────────────────────────────────────
scans restarts Λ log(Z₁/Z₀) min(α) mean(α) min(αₑ) mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
2 0 2 -3.54e+03 0 0.778 1 1
4 0 2.35 -52.9 5.81e-38 0.739 1 1
8 0 3.2 -10.1 0.256 0.645 1 1
16 0 3.93 -10.8 0.319 0.563 1 1
32 0 3.51 -12.3 0.313 0.61 1 1
64 4 3.22 -10.9 0.41 0.642 1 1
128 7 3.63 -11.9 0.419 0.596 1 1
256 16 3.44 -11.5 0.566 0.618 1 1
512 40 3.43 -11.6 0.527 0.619 1 1
1.02e+03 63 3.49 -11.8 0.563 0.612 1 1
───────────────────────────────────────────────────────────────────────────────────────
There is a marked difference. Thanks to round trips through the reference distribution, from which we can sample i.i.d., the chain is able to jump between distant parts of the state space.
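The mechanism behind these jumps is the swap move between adjacent chains. As a sketch (generic reversible PT for illustration; Pigeons actually uses a non-reversible swap schedule), chains at inverse temperatures $\beta < \beta'$ holding states $x, x'$ accept a state exchange with a Metropolis ratio:

```julia
# Generic PT swap acceptance between adjacent chains (illustration only;
# Pigeons schedules swaps non-reversibly). ℓx and ℓx′ are the target
# log-densities of the states held by the β- and β′-chains; the power
# posteriors are proportional to exp(β * ℓ).
swap_log_ratio(β, β′, ℓx, ℓx′) = (β′ - β) * (ℓx - ℓx′)
swap_accepts(β, β′, ℓx, ℓx′) = log(rand()) < swap_log_ratio(β, β′, ℓx, ℓx′)

# When the hotter chain holds the higher-density state, the swap is always
# accepted, passing good states toward the target chain:
swap_log_ratio(0.5, 1.0, -10.0, -20.0) # = 5.0 ≥ 0, accepted with probability 1
```

States that reach the reference (β near 0) are refreshed by i.i.d. sampling, then travel back up the ladder, which is what the restarts column counts.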
This is also confirmed by the PT ESS estimates:
samples
Chains MCMC chain (1024×3×1 Array{Float64, 3}):
Iterations = 1:1:1024
Number of chains = 1
Samples per chain = 1024
parameters = p1, p2
internals = log_density
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat e ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯
p1 0.7375 0.1556 0.0207 87.7657 128.1847 1.0502 ⋯
p2 0.7103 0.1552 0.0153 85.8256 107.0261 1.0252 ⋯
1 column omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
p1 0.5069 0.5914 0.7321 0.8824 0.9814
p2 0.5088 0.5668 0.6832 0.8460 0.9877