Why Parallel Tempering (PT)? An example.

Consider a Bayesian model whose likelihood is a binomial distribution with success probability $p$, and over-parameterize it by writing $p = p_1 p_2$, where each $p_i$ has a uniform prior on the interval $[0, 1]$. This is a toy example of an unidentifiable parameterization. In practice, many popular Bayesian models are unidentifiable.
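To make the unidentifiability explicit, write $k$ for the number of successes observed in $n$ trials (symbols introduced here for illustration). Up to a normalizing constant, the posterior is

$$
\pi(p_1, p_2 \mid k) \;\propto\; (p_1 p_2)^{k}\,(1 - p_1 p_2)^{n-k}, \qquad (p_1, p_2) \in [0, 1]^2.
$$

The right-hand side depends on $(p_1, p_2)$ only through the product $p_1 p_2$. Every point on a curve $p_1 p_2 = c$ therefore has the same posterior density, so the data cannot tell those points apart.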

When there are many observations, the posterior of an unidentifiable model concentrates near a sub-manifold, which makes sampling difficult, as shown in the following pair plots:
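The concentration can be checked numerically with a small plain-Julia sketch that does not use Pigeons. It assumes $n = 100000$ trials and $k = 50000$ successes, a hypothetical choice that mirrors the large observation size used below; `log_posterior` is a hypothetical helper, not part of any library.

```julia
# Hypothetical helper: unnormalized log posterior of the over-parameterized
# binomial model with n trials and k successes. The uniform priors only add a constant.
function log_posterior(p1, p2; n = 100_000, k = 50_000)
    # Outside the unit square the prior density is zero.
    (0 <= p1 <= 1 && 0 <= p2 <= 1) || return -Inf
    p = p1 * p2
    # The likelihood vanishes at p = 0 and p = 1 when 0 < k < n.
    (0 < p < 1) || return -Inf
    return k * log(p) + (n - k) * log1p(-p)
end

# Three points on the curve p1 * p2 = 0.5 ...
a = log_posterior(0.5, 1.0)
b = log_posterior(1.0, 0.5)
c = log_posterior(sqrt(0.5), sqrt(0.5))
# ... and one point just off it, with p1 * p2 = 0.51.
d = log_posterior(0.51, 1.0)

println(a - b)   # 0.0: same product, same density
println(a - c)   # zero up to floating-point rounding
println(a - d)   # about 20 nats, i.e. a density ratio of roughly e^20
```

Moving along the curve costs nothing, while shifting the product by only 0.01 costs about 20 nats: the posterior is a thin ridge around the curve, which is what makes local moves hard.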

Unidentifiable example without PT

Let us first look at trace plots from single-chain MCMC on this problem. The key part of the code below is the argument n_chains = 1: we have designed our PT implementation so that setting the number of chains to one reduces to a standard MCMC algorithm.

using DynamicPPL
using Pigeons
using MCMCChains
using StatsPlots
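plotlyjs() # interactive PlotlyJS backend, used to export the plots as HTML

# The model described above, implemented in Turing (DynamicPPL).
# Note the large observation size: 100000 trials with 50000 successes.
an_unidentifiable_model = Pigeons.toy_turing_unid_target(100000, 50000)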

pt = pigeons(
        target = an_unidentifiable_model,
        n_chains = 1, # <- corresponds to single chain MCMC
        record = [traces])

# collect the statistics and convert to MCMCChains' Chains
samples = Chains(pt)
# create the trace plots
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "no_pt_posterior_densities_and_traces.html");
───────────────────────────────────────────
  scans    log(Z₁/Z₀)   min(αₑ)   mean(αₑ)
────────── ────────── ────────── ──────────
        2          0          1          1
        4          0          1          1
        8          0          1          1
       16          0          1          1
       32          0          1          1
       64          0          1          1
      128          0          1          1
      256          0          1          1
      512          0          1          1
 1.02e+03          0          1          1
───────────────────────────────────────────

The trace plots show clearly poor mixing, which the effective sample size (ESS) estimates confirm:

samples
Chains MCMC chain (1024×3×1 Array{Float64, 3}):

Iterations        = 1:1:1024
Number of chains  = 1
Samples per chain = 1024
parameters        = p1, p2
internals         = log_density

Use `describe(chains)` for summary statistics and quantiles.
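To see why a single chain struggles, here is a plain-Julia sketch of single-chain MCMC on the same posterior. It is an illustration under stated assumptions, not Pigeons' code: it swaps in random-walk Metropolis (not necessarily the explorer Pigeons uses) and a crude ESS estimator that truncates the autocorrelation sum at the first non-positive lag, a simplified cousin of Geyer's initial positive sequence. The function names, the step size and the data sizes ($n = 100000$, $k = 50000$) are hypothetical.

```julia
using Random

# Hypothetical helper: unnormalized log posterior for n trials and k successes.
function log_target(x; n = 100_000, k = 50_000)
    p1, p2 = x
    (0 <= p1 <= 1 && 0 <= p2 <= 1) || return -Inf
    p = p1 * p2
    (0 < p < 1) || return -Inf
    return k * log(p) + (n - k) * log1p(-p)
end

# Single-chain random-walk Metropolis with isotropic Gaussian proposals.
function random_walk_metropolis(rng, x0, n_iters; step = 0.001)
    x = copy(x0)
    lp = log_target(x)
    samples = Matrix{Float64}(undef, n_iters, 2)
    accepted = 0
    for t in 1:n_iters
        proposal = x .+ step .* randn(rng, 2)
        lp_prop = log_target(proposal)
        # Accept with probability min(1, exp(lp_prop - lp)); -Inf proposals are always rejected.
        if log(rand(rng)) < lp_prop - lp
            x, lp = proposal, lp_prop
            accepted += 1
        end
        samples[t, :] = x
    end
    return samples, accepted / n_iters
end

# Crude ESS: N / (1 + 2 * sum of lag autocorrelations), stopping at the
# first non-positive autocorrelation.
function simple_ess(chain::AbstractVector{<:Real})
    N = length(chain)
    centered = chain .- sum(chain) / N
    variance = sum(abs2, centered) / N
    variance > 0 || return NaN   # a chain that never moved carries no information
    total = 0.0
    for lag in 1:(N - 1)
        rho = sum(centered[i] * centered[i + lag] for i in 1:(N - lag)) / (N * variance)
        rho <= 0 && break
        total += rho
    end
    return N / (1 + 2 * total)
end

rng = MersenneTwister(1)
samples, acceptance = random_walk_metropolis(rng, [0.7, 0.7], 5_000)
println("acceptance rate: ", acceptance)
println("ESS of p1: ", simple_ess(samples[:, 1]), " out of ", size(samples, 1), " draws")
```

Comparing `simple_ess` with the number of draws gives a rough measure of how many effectively independent samples the chain produced.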

Unidentifiable example with PT

Now let us enable parallel tempering by setting n_chains to a value greater than one:

pt = pigeons(
        target = an_unidentifiable_model,
        n_chains = 10, # <- more than one chain enables parallel tempering
        record = [traces, round_trip]) # also record round-trip statistics

# collect the statistics and convert to MCMCChains' Chains
samples = Chains(pt)
# create the trace plots
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "with_pt_posterior_densities_and_traces.html");
───────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ      log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0          2  -3.54e+03          0      0.778          1          1
        4          0       2.35      -52.9   5.81e-38      0.739          1          1
        8          0        3.2      -10.1      0.256      0.645          1          1
       16          0       3.93      -10.8      0.319      0.563          1          1
       32          0       3.51      -12.3      0.313       0.61          1          1
       64          4       3.22      -10.9       0.41      0.642          1          1
      128          7       3.63      -11.9      0.419      0.596          1          1
      256         16       3.44      -11.5      0.566      0.618          1          1
      512         40       3.43      -11.6      0.527      0.619          1          1
 1.02e+03         63       3.49      -11.8      0.563      0.612          1          1
───────────────────────────────────────────────────────────────────────────────────────

The difference is marked. Thanks to round trips through the reference distribution (by default the prior), from which we can sample iid, the chain is able to jump to distant parts of the state space.
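The mechanism can be sketched in plain Julia. This is a minimal illustration, not Pigeons' implementation. It assumes a fixed, hypothetical annealing schedule $0 = \beta_1 < \beta_2 < \dots < \beta_N = 1$ (Pigeons tunes its schedule adaptively), tempered targets $\pi_\beta \propto \text{prior} \times L^\beta$, random-walk Metropolis as the local explorer, and alternating even/odd swaps between neighboring chains. The $\beta = 0$ chain is the uniform prior, so it is refreshed with exact iid draws. A swap between chains at $\beta_i < \beta_j$ with log-likelihoods $\ell_i, \ell_j$ is accepted with probability $\min(1, e^{(\beta_j - \beta_i)(\ell_i - \ell_j)})$. All names and the data sizes ($n = 100000$, $k = 50000$) are hypothetical.

```julia
using Random

# Hypothetical helper: log-likelihood for n trials and k successes, with p = p1 * p2.
function log_likelihood(x; n = 100_000, k = 50_000)
    p = x[1] * x[2]
    (0 < p < 1) || return -Inf
    return k * log(p) + (n - k) * log1p(-p)
end

in_unit_square(x) = all(0 .<= x .<= 1)

# Local exploration for the chain at inverse temperature beta.
function local_step(rng, x, beta; step = 0.002)
    beta == 0 && return rand(rng, 2)           # reference (uniform prior): exact iid draw
    proposal = x .+ step .* randn(rng, 2)
    in_unit_square(proposal) || return x        # zero prior density: reject
    log_ratio = beta * (log_likelihood(proposal) - log_likelihood(x))
    return log(rand(rng)) < log_ratio ? proposal : x
end

# Log acceptance ratio for swapping the states of chains at beta_i < beta_j.
swap_log_ratio(beta_i, beta_j, ll_i, ll_j) = (beta_j - beta_i) * (ll_i - ll_j)

function parallel_tempering(rng, betas, n_scans)
    N = length(betas)
    states = [rand(rng, 2) for _ in 1:N]        # start every chain from the reference
    target_samples = Matrix{Float64}(undef, n_scans, 2)
    swaps_accepted = zeros(Int, N - 1)
    for scan in 1:n_scans
        for i in 1:N
            states[i] = local_step(rng, states[i], betas[i])
        end
        # Pairs (1,2), (3,4), ... on odd scans; (2,3), (4,5), ... on even scans.
        for i in (isodd(scan) ? 1 : 2):2:(N - 1)
            log_ratio = swap_log_ratio(betas[i], betas[i + 1],
                                       log_likelihood(states[i]), log_likelihood(states[i + 1]))
            if log(rand(rng)) < log_ratio
                states[i], states[i + 1] = states[i + 1], states[i]
                swaps_accepted[i] += 1
            end
        end
        target_samples[scan, :] = states[N]     # the beta = 1 chain targets the posterior
    end
    return target_samples, swaps_accepted
end

rng = MersenneTwister(1)
n_chains = 10
betas = [(i / (n_chains - 1))^4 for i in 0:(n_chains - 1)]   # hypothetical fixed schedule
samples, swaps_accepted = parallel_tempering(rng, betas, 2_000)
# Each neighboring pair is proposed a swap on every other scan, i.e. 1000 times here.
println("swap acceptance rates: ", round.(swaps_accepted ./ 1_000; digits = 2))
```

The quartic schedule is an arbitrary stand-in that places more chains close to $\beta = 0$; Pigeons instead adapts the schedule from round to round.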

The improvement is confirmed by the PT ESS estimates as well:

samples
Chains MCMC chain (1024×3×1 Array{Float64, 3}):

Iterations        = 1:1:1024
Number of chains  = 1
Samples per chain = 1024
parameters        = p1, p2
internals         = log_density

Use `describe(chains)` for summary statistics and quantiles.