Parallel Tempering-specific diagnostics

This page shows how to compute some key diagnostics for non-reversible parallel tempering, as introduced in Syed et al., 2021.

Global communication barrier

The global communication barrier, denoted Λ, can be used to set the number of chains. The theoretical framework of Syed et al., 2021 shows that, under simplifying assumptions, it is optimal to set the number of chains (the argument n_chains in Inputs or pigeons()) to roughly 2Λ.

An estimate of the global communication barrier is printed at each round in the Λ column of the report, and can also be accessed via global_barrier().

using DynamicPPL
using Pigeons

pt = pigeons(target = Pigeons.toy_turing_unid_target(100, 50))
Pigeons.global_barrier(pt)
1.4911469199502336
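
As a rough illustration of the 2Λ guideline, the sketch below turns a barrier estimate into a suggested chain count. The helper suggested_n_chains is hypothetical (it is not part of Pigeons), and both the rounding up and the minimum of two chains are assumptions made for this example rather than anything prescribed by Syed et al., 2021.

```julia
# Hypothetical helper (not part of Pigeons): round the 2Λ guideline up to a
# whole number of chains. The minimum of 2 chains (one reference, one target)
# is an assumption made for this sketch.
suggested_n_chains(Λ::Real) = max(2, ceil(Int, 2 * Λ))

Λ = 1.4911469199502336       # value returned by Pigeons.global_barrier(pt) above
n = suggested_n_chains(Λ)    # 2Λ ≈ 2.98, so n == 3

# The suggestion could then be passed back through the n_chains argument, e.g.
# pt = pigeons(target = Pigeons.toy_turing_unid_target(100, 50), n_chains = n)
```

Since the guideline holds only under simplifying assumptions, the result is best read as a starting point rather than a hard rule.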

When both a fixed and a variational reference are used, their global barriers are printed separately, labelled Λ for the fixed reference and Λ_var for the variational reference:

using DynamicPPL
using Pigeons

pt = pigeons(target = Pigeons.toy_turing_unid_target(100, 50),
                variational = GaussianReference(),
                n_chains_variational = 10)
┌ Info: Neither traces, disk, nor online recorders included.
   You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
   To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
─────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans        Λ        Λ_var      time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2       1.39       3.24      0.283   8.82e+06      -8.14    0.00178      0.756          1          1
        4        2.4       1.55    0.00443    4.4e+06      -5.27     0.0352      0.792          1          1
        8       1.65       1.71    0.00802   8.51e+06       -4.4      0.331      0.823          1          1
       16       1.76       1.86     0.0159   1.72e+07      -4.51      0.545       0.81          1          1
       32        1.4       1.64     0.0321   3.49e+07      -4.95      0.636       0.84          1          1
       64       1.69       1.54        0.1   6.81e+07      -4.95      0.538       0.83          1          1
      128        1.5       1.03      0.412   1.51e+08      -5.19      0.687      0.867          1          1
      256       1.54       0.97      0.294   2.83e+08      -4.91      0.786      0.868          1          1
      512       1.51       1.01      0.605   5.64e+08      -5.08      0.795      0.868          1          1
 1.02e+03       1.48       1.02       1.18   1.12e+09      -4.93      0.815      0.868          1          1
─────────────────────────────────────────────────────────────────────────────────────────────────────────────

Round trips and tempered restarts

A tempered restart happens when a sample from the reference percolates to the target. When the reference supports iid sampling, tempered restarts can enable large jumps in the state space.

A round trip happens when a sample completes a full cycle from the reference to the target and back to the reference.
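
The two definitions above can be made concrete with a small, self-contained sketch that counts both events along the path of a single replica. This is an illustration only and is not how Pigeons implements its recorder: the function count_trips is hypothetical, and it assumes that chain 1 is the reference and chain n_chains is the target.

```julia
# Hypothetical illustration (not the Pigeons implementation): count tempered
# restarts and round trips from the sequence of chain indices visited by one
# replica. Assumption: chain 1 is the reference, chain n_chains is the target.
function count_trips(chain_path::AbstractVector{<:Integer}, n_chains::Integer)
    restarts = 0
    round_trips = 0
    from_reference = false  # visited the reference since the last counted restart
    from_target = false     # reached the target since the last reference visit
    for chain in chain_path
        if chain == 1
            if from_target
                round_trips += 1      # reference -> target -> reference completed
                from_target = false
            end
            from_reference = true
        elseif chain == n_chains
            if from_reference
                restarts += 1         # a reference sample reached the target
                from_reference = false
                from_target = true
            end
        end
    end
    return (restarts = restarts, round_trips = round_trips)
end

path = [1, 2, 3, 2, 3, 4, 3, 2, 1, 2, 3, 4]
result = count_trips(path, 4)   # (restarts = 2, round_trips = 1)
```

In this example the replica reaches the target twice after leaving the reference but returns to the reference only once, so the restart count exceeds the round trip count.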

To count tempered restarts and round trips, add the round_trip() recorder:

pt = pigeons(target = Pigeons.toy_turing_unid_target(100, 50),
           record = [round_trip; record_default()])
┌ Info: Neither traces, disk, nor online recorders included.
   You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
   To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
─────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0       3.24    0.00786   1.13e+06      -8.14    0.00178       0.64          1          1
        4          0       1.64    0.00223   2.08e+06      -5.04     0.0352      0.818          1          1
        8          0       1.17    0.00387   4.07e+06      -4.42      0.708      0.871          1          1
       16          1        1.2    0.00783   8.59e+06      -4.03      0.549      0.867          1          1
       32          6       1.11     0.0154   1.68e+07      -4.77      0.754      0.877          1          1
       64         11       1.35     0.0309   3.37e+07      -4.79      0.698       0.85          1          1
      128         25        1.6     0.0922   6.68e+07      -4.97      0.725      0.823          1          1
      256         43       1.51      0.126   1.32e+08      -4.92      0.758      0.832          1          1
      512         92       1.46      0.344   2.66e+08         -5      0.806      0.838          1          1
 1.02e+03        188       1.49      0.586   5.33e+08      -4.92      0.798      0.834          1          1
─────────────────────────────────────────────────────────────────────────────────────────────────────────────

The values can also be accessed as follows:

Pigeons.n_tempered_restarts(pt), Pigeons.n_round_trips(pt)
(188, 182)

Local communication barrier

When the global communication barrier is large, many chains may be required to obtain tempered restarts.

The local communication barrier can be used to visualize the cause of a high global communication barrier. For example, if the local barrier has a sharp peak close to the reference and the reference is constructed from the prior, it may be useful to switch to a variational approximation.
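
For context, the relation between the two barriers in Syed et al., 2021 can be written as follows. The symbols β (the annealing parameter, with β = 0 at the reference and β = 1 at the target) and λ(β) (the local barrier) are notation borrowed from that paper and are not otherwise introduced on this page.

```latex
\Lambda = \int_0^1 \lambda(\beta)\, \mathrm{d}\beta
```

Under this reading, a region of β where λ(β) is large contributes disproportionately to Λ, which is why plotting the local barrier helps locate the source of a large global barrier.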

The local barrier can be plotted as follows:

using Plots
plotlyjs()
myplot = plot(pt.shared.tempering.communication_barriers.localbarrier);
savefig(myplot, "local_barrier_plot.html");

Index process

The index process tracks the permutation of chains as machines exchange annealing parameters. Each row is a chain and each connected line corresponds to a replica. To record it, use the index_process recorder:

pt = pigeons(
        target = toy_mvn_target(1),
        record = [index_process],
        n_rounds = 5)
myplot = plot(pt.reduced_recorders.index_process)
savefig(myplot, "index_process_plot.html");
┌ Info: Neither traces, disk, nor online recorders included.
   You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
   To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
──────────────────────────────────────────────────────
  scans        Λ      log(Z₁/Z₀)   min(α)     mean(α)
────────── ────────── ────────── ────────── ──────────
        2      0.417      -1.56      0.782      0.954
        4      0.564     -0.653       0.61      0.937
        8      0.597     -0.833      0.869      0.934
       16      0.686      -1.35      0.729      0.924
       32      0.818      -1.19        0.8      0.909
──────────────────────────────────────────────────────