Why Parallel Tempering (PT)? An example.

Consider a Bayesian model whose likelihood is a binomial distribution with probability parameter $p$. We over-parameterize it by writing $p = p_1 p_2$, where each $p_i$ has a uniform prior on the interval $[0, 1]$. This is a toy example of an unidentifiable parameterization: the data inform only the product $p_1 p_2$, not the individual factors. In practice, many popular Bayesian models are unidentifiable.
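To make the unidentifiability concrete, here is a minimal, dependency-free sketch that evaluates the binomial log-likelihood (up to an additive constant) at a few parameter pairs. The function name `unid_loglik` is hypothetical and not part of Pigeons. The default counts assume that the two arguments passed to the toy target used in this example are the number of trials and the number of successes, which is consistent with the posterior summaries where $p_1 p_2 \approx 0.5$.

```julia
# Binomial log-likelihood up to an additive constant (the binomial coefficient
# does not depend on the parameters), with success probability p = p1 * p2.
# The default counts are assumptions chosen to match the example in this page.
function unid_loglik(p1, p2; n_trials = 100_000, n_successes = 50_000)
    p = p1 * p2
    return n_successes * log(p) + (n_trials - n_successes) * log1p(-p)
end

# Three different parameter pairs with the same product p1 * p2 = 0.5 ...
on_ridge = [(0.5, 1.0), (1.0, 0.5), (0.8, 0.625)]
# ... and one pair slightly off that ridge (product 0.5184).
off_ridge = (0.72, 0.72)

for (p1, p2) in on_ridge
    println("p1 = $p1, p2 = $p2: ", unid_loglik(p1, p2))
end
println("off the ridge: ", unid_loglik(off_ridge...))
```

All pairs on the curve $p_1 p_2 = 0.5$ receive the same log-likelihood, while a pair whose product is only slightly different is penalized by tens of log units at this sample size. The likelihood alone therefore cannot distinguish points along the curve.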

When there are many observations, the posterior of an unidentifiable model concentrates on a sub-manifold, making sampling difficult, as shown in the following pair plots:

Unidentifiable example without PT

Let us look at trace plots obtained from performing single-chain MCMC on this problem. The key part of the code below is the argument n_chains = 1: we have designed our PT implementation so that setting the number of chains to one reduces to a standard MCMC algorithm.

using DynamicPPL
using Pigeons
using MCMCChains
using StatsPlots
plotlyjs()

# The model described above implemented in Turing
# note we are using a large observation size here
an_unidentifiable_model = Pigeons.toy_turing_unid_target(100000, 50000)

pt = pigeons(
        target = an_unidentifiable_model,
        n_chains = 1, # <- corresponds to single chain MCMC
        record = [traces])

# collect the statistics and convert to MCMCChains' Chains
samples = Chains(pt)
# create the trace plots
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "no_pt_posterior_densities_and_traces.html");
───────────────────────────────────────────
  scans    log(Z₁/Z₀)   min(αₑ)   mean(αₑ)
────────── ────────── ────────── ──────────
        2          0          1          1
        4          0          1          1
        8          0          1          1
       16          0          1          1
       32          0          1          1
       64          0          1          1
      128          0          1          1
      256          0          1          1
      512          0          1          1
 1.02e+03          0          1          1
───────────────────────────────────────────

Mixing is visibly poor, as confirmed by the effective sample size (ESS) estimates, which are only about 29 for both parameters out of 1024 iterations:

samples
Chains MCMC chain (1024×3×1 Array{Float64, 3}):

Iterations        = 1:1:1024
Number of chains  = 1
Samples per chain = 1024
parameters        = p1, p2
internals         = log_density

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

          p1    0.9646    0.0244    0.0047    29.1938    76.8349    1.0231     ⋯
          p2    0.5187    0.0135    0.0026    29.2742    72.9775    1.0226     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          p1    0.9122    0.9491    0.9683    0.9846    0.9983
          p2    0.5004    0.5080    0.5166    0.5269    0.5488

Unidentifiable example with PT

Let us enable parallel tempering now, by setting n_chains to a value greater than one:

pt = pigeons(
        target = an_unidentifiable_model,
        n_chains = 10,
        record = [traces, round_trip])

# collect the statistics and convert to MCMCChains' Chains
samples = Chains(pt)
# create the trace plots
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "with_pt_posterior_densities_and_traces.html");
───────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ      log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0          2  -3.54e+03          0      0.778          1          1
        4          0       2.35      -52.9   5.81e-38      0.739          1          1
        8          0        3.2      -10.1      0.256      0.645          1          1
       16          0       3.93      -10.8      0.319      0.563          1          1
       32          0       3.51      -12.3      0.313       0.61          1          1
       64          4       3.22      -10.9       0.41      0.642          1          1
      128          7       3.63      -11.9      0.419      0.596          1          1
      256         16       3.44      -11.5      0.566      0.618          1          1
      512         40       3.43      -11.6      0.527      0.619          1          1
 1.02e+03         63       3.49      -11.8      0.563      0.612          1          1
───────────────────────────────────────────────────────────────────────────────────────

There is a marked difference. Thanks to round trips through the reference distribution, where we can sample i.i.d., the sampler is able to jump to different parts of the state space.

The PT ESS estimates confirm this, with bulk ESS values of about 86 to 88, roughly three times those of the single-chain run:

samples
Chains MCMC chain (1024×3×1 Array{Float64, 3}):

Iterations        = 1:1:1024
Number of chains  = 1
Samples per chain = 1024
parameters        = p1, p2
internals         = log_density

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

          p1    0.7375    0.1556    0.0207    87.7657   128.1847    1.0502     ⋯
          p2    0.7103    0.1552    0.0153    85.8256   107.0261    1.0252     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          p1    0.5069    0.5914    0.7321    0.8824    0.9814
          p2    0.5088    0.5668    0.6832    0.8460    0.9877
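
The round trips mentioned above rely on swaps of states between neighbouring chains. As a rough illustration, the sketch below computes the textbook PT swap acceptance probability for this toy model. It assumes annealed targets of the form prior times likelihood raised to the power $\beta$, with $\beta = 0$ the reference and $\beta = 1$ the posterior; this is a common construction and may differ in details from what Pigeons does internally. The names `unid_loglik` and `swap_acceptance` are hypothetical, and the default counts are assumed values matching the example.

```julia
# Log-likelihood of the toy model up to an additive constant, with p = p1 * p2.
# The trial and success counts are assumed values matching the example above.
function unid_loglik(x; n_trials = 100_000, n_successes = 50_000)
    p = x[1] * x[2]
    return n_successes * log(p) + (n_trials - n_successes) * log1p(-p)
end

# Textbook swap acceptance probability for two chains at inverse temperatures
# beta_lo < beta_hi that currently hold the states x_lo and x_hi, assuming
# annealed targets proportional to prior(x) * likelihood(x)^beta.
# The uniform prior cancels in the ratio because both states lie in its support.
function swap_acceptance(beta_lo, beta_hi, x_lo, x_hi)
    log_ratio = (beta_hi - beta_lo) * (unid_loglik(x_lo) - unid_loglik(x_hi))
    return min(1.0, exp(log_ratio))
end

# The chain closer to the posterior would receive a better-fitting state:
# such a swap is always accepted.
println(swap_acceptance(0.5, 1.0, (0.7, 0.72), (0.9, 0.9)))
# The chain closer to the posterior would receive a much worse-fitting state:
# such a swap is essentially never accepted.
println(swap_acceptance(0.5, 1.0, (0.9, 0.9), (0.7, 0.72)))
```

With many observations the likelihood is sharply peaked, so swaps between chains with very different $\beta$ values are rarely accepted. Using several intermediate chains, as with n_chains = 10 above, keeps neighbouring targets similar enough for swaps to succeed, which is what allows states sampled from the reference to travel to the posterior chain.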