Why Parallel Tempering (PT)? An example.

Consider a Bayesian model where the likelihood is a binomial distribution with success probability $p$. We over-parameterize the model by writing $p = p_1 p_2$ and give each $p_i$ a uniform prior on the interval $[0, 1]$. This is a toy example of an unidentifiable parameterization; in practice, many popular Bayesian models are unidentifiable.
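
For concreteness, the model could be written in Turing along the following lines (a minimal sketch; the name unid_model is illustrative, and the examples below instead use an equivalent target that ships with Pigeons):

using Turing

# Sketch of the unidentifiable model: the data only constrain the product p1 * p2,
# so the posterior concentrates near the curve p1 * p2 ≈ p̂ rather than on a point.
@model function unid_model(n_trials, n_successes)
    p1 ~ Uniform(0, 1)                          # uniform prior on [0, 1]
    p2 ~ Uniform(0, 1)                          # uniform prior on [0, 1]
    n_successes ~ Binomial(n_trials, p1 * p2)   # binomial likelihood with p = p1 * p2
end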

When there are many observations, the posterior of an unidentifiable model concentrates on a sub-manifold, making sampling difficult, as shown in the following pair plots:

Unidentifiable example without PT

Let us look at trace plots obtained from performing single-chain MCMC on this problem. The key part of the code below is the argument n_chains = 1: we have designed our PT implementation so that setting the number of chains to one reduces to a standard MCMC algorithm.

using DynamicPPL
using Pigeons
using MCMCChains
using StatsPlots
plotlyjs()

# The model described above implemented in Turing
# note we are using a large observation size here
an_unidentifiable_model = Pigeons.toy_turing_unid_target(100000, 50000)

pt = pigeons(
        target = an_unidentifiable_model,
        n_chains = 1, # <- corresponds to single chain MCMC
        record = [traces])

# collect the statistics and convert to MCMCChains' Chains
samples = Chains(pt)
# create the trace plots
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "no_pt_posterior_densities_and_traces.html");
───────────────────────────────────────────
  scans    log(Z₁/Z₀)   min(αₑ)   mean(αₑ)
────────── ────────── ────────── ──────────
        2          0          1          1
        4          0          1          1
        8          0          1          1
       16          0          1          1
       32          0          1          1
       64          0          1          1
      128          0          1          1
      256          0          1          1
      512          0          1          1
 1.02e+03          0          1          1
───────────────────────────────────────────

The trace plots show that mixing is poor, which is confirmed by the effective sample size (ESS) estimates:

samples
Chains MCMC chain (1024×3×1 Array{Float64, 3}):

Iterations        = 1:1:1024
Number of chains  = 1
Samples per chain = 1024
parameters        = p1, p2
internals         = log_density

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

          p1    0.7498    0.1107    0.0672     2.8759    29.9758    1.7339     ⋯
          p2    0.6799    0.0889    0.0547     2.8784    30.6664    1.7317     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          p1    0.6379    0.6679    0.6975    0.8143    0.9817
          p2    0.5090    0.6135    0.7168    0.7490    0.7841
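
Rather than reading the printed report, the same diagnostics can be extracted programmatically from the Chains object (a sketch assuming MCMCChains' summarystats and ess functions):

# summary table (mean, std, mcse, ess_bulk, ess_tail, rhat, ...) as an object
stats = summarystats(samples)

# effective sample size estimates on their own
ess_estimates = ess(samples)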

Unidentifiable example with PT

Let us now enable parallel tempering by setting n_chains to a value greater than one:

pt = pigeons(
        target = an_unidentifiable_model,
        n_chains = 10,
        record = [traces, round_trip])

# collect the statistics and convert to MCMCChains' Chains
samples = Chains(pt)
# create the trace plots
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "with_pt_posterior_densities_and_traces.html");
───────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ      log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0       1.04  -4.24e+03          0      0.885          1          1
        4          0       4.06      -16.3   4.63e-06      0.549          1          1
        8          0       3.49      -12.1      0.215      0.612          1          1
       16          0       2.68      -10.2      0.518      0.703          1          1
       32          0       4.29      -11.8      0.222      0.524          1          1
       64          3       3.17      -11.5      0.529      0.648          1          1
      128          8       3.56      -11.5      0.523      0.605          1          1
      256         12       3.38      -11.6      0.526      0.625          1          1
      512         37       3.48        -12      0.527      0.614          1          1
 1.02e+03         77       3.55      -11.8      0.571      0.605          1          1
───────────────────────────────────────────────────────────────────────────────────────

There is a marked difference. Thanks to round trips through the reference distribution, from which we can sample i.i.d., the chain is able to jump to different parts of the state space.
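
The restart counts in the report above can also be queried directly from the pt object (a sketch assuming Pigeons exposes the n_tempered_restarts and n_round_trips accessors; both rely on the round_trip recorder enabled above):

# number of tempered restarts and of full reference-to-target round trips;
# requires `round_trip` in the `record` argument, as used above
n_restarts = Pigeons.n_tempered_restarts(pt)
n_trips    = Pigeons.n_round_trips(pt)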

This is also confirmed by the PT ESS estimates:

samples
Chains MCMC chain (1024×3×1 Array{Float64, 3}):

Iterations        = 1:1:1024
Number of chains  = 1
Samples per chain = 1024
parameters        = p1, p2
internals         = log_density

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

          p1    0.7002    0.1380    0.0125   122.5666   159.2296    1.0070     ⋯
          p2    0.7419    0.1430    0.0127   122.1132   170.5302    1.0108     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          p1    0.5130    0.5702    0.6919    0.8044    0.9703
          p2    0.5160    0.6214    0.7222    0.8771    0.9751