Variational PT

This page describes the implementation of variational PT (Surjanovic et al., 2022) included in Pigeons. Both the basic and the stabilized variants introduced in that paper are available.

Basic variational PT

Enable variational PT by supplying the variational option to pigeons(...):

using DynamicPPL
using Pigeons

pigeons(
    target = Pigeons.toy_turing_unid_target(100, 50),
    variational = GaussianReference(first_tuning_round = 5))
┌ Info: Neither traces, disk, nor online recorders included.
   You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
   To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
──────────────────────────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2       3.24    0.00178    7.2e+05      -8.14    0.00178       0.64          1          1
        4       1.64    0.00266    1.4e+06      -5.04     0.0352      0.818          1          1
        8       1.17    0.00467   2.72e+06      -4.42      0.708      0.871          1          1
       16        1.2    0.00891   5.69e+06      -4.03      0.549      0.867          1          1
       32       1.11     0.0174   1.11e+07      -4.77      0.754      0.877          1          1
       64       1.08     0.0419   3.98e+07      -4.84      0.754      0.879          1          1
      128       1.06     0.0858   7.93e+07      -4.97      0.804      0.882          1          1
      256       1.05      0.215   1.58e+08      -4.96      0.858      0.884          1          1
      512       1.04      0.387   3.16e+08      -5.01      0.868      0.885          1          1
 1.02e+03      0.994      0.766   6.31e+08      -4.98      0.876       0.89          1          1
──────────────────────────────────────────────────────────────────────────────────────────────────

Note that variational fitting only starts at round first_tuning_round; before that point, the fixed reference is used.
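To make the timing concrete, here is a minimal sketch in plain Julia (it does not call Pigeons). It assumes that rounds are numbered from 1 and that round r performs 2^r scans, as the scans column above suggests. The names schedule and n_rounds are ours, for illustration only, and the sketch does not model exactly when a newly fitted reference is first used for sampling.

```julia
# Sketch under stated assumptions: rounds are numbered from 1, round r runs
# 2^r scans (matching the "scans" column above), and variational fitting is
# active from first_tuning_round onward.
first_tuning_round = 5
n_rounds = 10

schedule = map(1:n_rounds) do r
    (round = r,
     scans = 2^r,
     reference = r < first_tuning_round ? "fixed" : "variational fitting active")
end

for s in schedule
    println("round ", s.round, ": ", s.scans, " scans, ", s.reference)
end
```

Under these assumptions, with first_tuning_round = 5 the first four rounds (2 to 16 scans) rely on the fixed reference alone.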

Stabilized variational PT

Surjanovic et al. (2022) describe situations where variational fitting can cause catastrophic forgetting of modes. The remedy is to use both a fixed and a variational reference, each connected to its own copy of the target (two copies in total). The two target copies are also swapped, following a non-reversible swapping scheme.
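The arrangement can be pictured as two legs of chains that meet at the target. The following plain-Julia sketch only prints such a layout. The function name stabilized_layout, the labels, and the ordering of the chains are assumptions made for exposition, not Pigeons' internal chain indexing.

```julia
# Illustrative layout only: labels and ordering are assumptions for exposition,
# not Pigeons' internal chain indexing.
function stabilized_layout(n_chains_fixed::Int, n_chains_variational::Int)
    @assert n_chains_fixed >= 2 && n_chains_variational >= 2
    # Leg 1: fixed reference annealed toward its own copy of the target.
    fixed_leg = ["fixed reference";
                 fill("intermediate", n_chains_fixed - 2);
                 "target"]
    # Leg 2: a second copy of the target annealed toward the variational reference.
    variational_leg = ["target";
                       fill("intermediate", n_chains_variational - 2);
                       "variational reference"]
    # The two target copies end up adjacent, so they can swap with each other.
    return vcat(fixed_leg, variational_leg)
end

layout = stabilized_layout(4, 4)
foreach(println, layout)
```

In this picture, a mode lost by the variational reference can still be reached through the fixed-reference leg, since the two target copies exchange states.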

Enable stabilized variational PT by adding the n_chains_variational option to pigeons(...):

pigeons(
    target = Pigeons.toy_turing_unid_target(100, 50),
    variational = GaussianReference(first_tuning_round = 5),
    n_chains_variational = 10)
┌ Info: Neither traces, disk, nor online recorders included.
   You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
   To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
─────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans        Λ        Λ_var      time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2       1.39       3.24    0.00256   1.47e+06      -8.14    0.00178      0.756          1          1
        4        2.4       1.55    0.00475   2.91e+06      -5.27     0.0352      0.792          1          1
        8       1.65       1.71     0.0089   5.63e+06       -4.4      0.331      0.823          1          1
       16       1.76       1.86     0.0178   1.14e+07      -4.51      0.545       0.81          1          1
       32        1.4       1.64     0.0354    2.3e+07      -4.95      0.636       0.84          1          1
       64       1.58       0.96     0.0774   6.21e+07      -4.83      0.711      0.866          1          1
      128       1.55      0.919      0.209   1.24e+08      -4.86      0.776       0.87          1          1
      256        1.4       1.04      0.302   2.48e+08      -4.97      0.812      0.872          1          1
      512       1.48       1.02       0.72   4.95e+08         -5      0.809      0.868          1          1
 1.02e+03        1.5       1.06       1.36   9.92e+08      -4.99        0.8      0.865          1          1
─────────────────────────────────────────────────────────────────────────────────────────────────────────────