Extended output (i.e., for all chains)
So far, when outputting traces (either to memory via the traces recorder or to disk via the disk recorder), we have been storing only the samples from the target distribution. This is the most common scenario and the default. Here we show how to store the samples from all chains instead.
This can be useful in scenarios where all distributions $\pi_i$ are of interest, e.g. in certain statistical mechanics applications and for Bayesian inference under model mis-specification.
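For concreteness, the intermediate distributions can be written out as follows. This is a sketch that assumes the standard log-linear (geometric) annealing path between a reference $\pi_0$ (e.g. the prior) and the target $\pi_1$ (e.g. the posterior). The symbols $\beta_i$, $N$, $\pi_0$ and $\pi_1$ are notation introduced here for illustration, and other paths are possible.

```latex
\pi_i(x) \;\propto\; \pi_0(x)^{1-\beta_i}\, \pi_1(x)^{\beta_i},
\qquad 0 = \beta_1 < \beta_2 < \dots < \beta_N = 1 .
```

Under this assumption, $\beta_i = 0$ recovers the reference and $\beta_i = 1$ recovers the target.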
The key argument to add is extended_traces = true, which we demonstrate for various common scenarios below.
Posterior densities and trace plots for all chains
Make sure to have the third-party DynamicPPL, MCMCChains, and StatsPlots packages installed via
using Pkg; Pkg.add(["DynamicPPL", "MCMCChains", "StatsPlots"])
Then use the following:
using DynamicPPL
using Pigeons
using MCMCChains
using StatsPlots
plotlyjs()
# example target: Binomial likelihood with parameter p = p1 * p2
an_unidentifiable_model = Pigeons.toy_turing_unid_target(100, 50)
pt = pigeons(target = an_unidentifiable_model,
             n_rounds = 12,
             extended_traces = true,
             # make sure to record the trace:
             record = [traces; round_trip; record_default()])
# collect the statistics and convert to MCMCChains' Chains
# to have axes labels matching variable names in Turing and Stan
samples = Chains(pt)
# create the trace plots
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "posterior_densities_and_traces_extended.html");
─────────────────────────────────────────────────────────────────────────────────────────────────────────────
scans      restarts   Λ          time(s)    allc(B)    log(Z₁/Z₀) min(α)     mean(α)    min(αₑ)    mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
2          0          3.24       0.00821    1.25e+06   -8.14      0.00178    0.64       1          1
4          0          1.64       0.00249    2.32e+06   -5.04      0.0352     0.818      1          1
8          0          1.17       0.00457    4.55e+06   -4.42      0.708      0.871      1          1
16         1          1.2        0.00888    9.58e+06   -4.03      0.549      0.867      1          1
32         6          1.11       0.0174     1.87e+07   -4.77      0.754      0.877      1          1
64         11         1.35       0.0348     3.77e+07   -4.79      0.698      0.85       1          1
128        25         1.6        0.0782     7.45e+07   -4.97      0.725      0.823      1          1
256        43         1.51       0.195      1.48e+08   -4.92      0.758      0.832      1          1
512        92         1.46       0.361      2.97e+08   -5         0.806      0.838      1          1
1.02e+03   188        1.49       0.732      5.96e+08   -4.92      0.798      0.834      1          1
2.05e+03   384        1.5        1.43       1.19e+09   -4.96      0.811      0.834      1          1
4.1e+03    748        1.48       2.92       2.38e+09   -4.94      0.826      0.835      1          1
─────────────────────────────────────────────────────────────────────────────────────────────────────────────
In the resulting plot, the ten colours correspond to the ten chains interpolating between the prior (here a uniform distribution) and the posterior.
Off-memory processing for all chains
The same option, extended_traces = true, can be used in the same fashion to save samples from all chains to disk:
using Pigeons
# example target: a 1000-dimensional target
high_d_target = Pigeons.toy_mvn_target(1000)
pt = pigeons(target = high_d_target,
             checkpoint = true,
             extended_traces = true,
             record = [disk])
# rows index the chains, columns the scans of the final round
# (with the defaults used here: 10 chains and 2^10 = 1024 scans)
first_dim_of_each = zeros(10, 1024)
process_sample(pt) do chain, scan, sample # ordered as if we had an inner loop over scans
    # each sample here is a Vector{Float64} of length 1000
    # in general, it is produced by extract_sample()
    first_dim_of_each[chain, scan] = sample[1]
end
──────────────────────────────────────────────────────
scans      Λ          log(Z₁/Z₀) min(α)     mean(α)
────────── ────────── ────────── ────────── ──────────
2          9          -1.18e+03  7.04e-107  0.000359
4          8.97       -1.17e+03  1.31e-102  0.00346
8          8.38       -1.2e+03   1.13e-107  0.0692
16         8.94       -1.18e+03  1.65e-93   0.00618
32         8.96       -1.16e+03  1.45e-70   0.00487
64         8.85       -1.17e+03  2.13e-83   0.0164
128        8.88       -1.16e+03  1.23e-68   0.0129
256        8.92       -1.15e+03  4.23e-64   0.00884
512        8.92       -1.16e+03  5.16e-68   0.00893
1.02e+03   8.93       -1.16e+03  6.65e-66   0.00732
──────────────────────────────────────────────────────
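Once a chains-by-scans matrix such as first_dim_of_each has been filled, a natural next step is to summarize each chain separately. The following is a minimal sketch rather than part of the Pigeons API: it uses only the Statistics standard library, and it fills the matrix with random stand-in values (an assumption made purely so the snippet runs on its own) instead of real samples read from disk.

```julia
using Statistics

# Stand-in for the matrix filled by process_sample
# (rows = chains, columns = scans). The random values are placeholders
# so that this snippet runs on its own; they are not Pigeons output.
n_chains, n_scans = 10, 1024
first_dim_of_each = randn(n_chains, n_scans)

# Mean and standard deviation of the first coordinate, one entry per chain.
chain_means = vec(mean(first_dim_of_each; dims = 2))
chain_stds  = vec(std(first_dim_of_each; dims = 2))

for chain in 1:n_chains
    println("chain ", chain,
            ": mean = ", round(chain_means[chain], digits = 3),
            ", std = ", round(chain_stds[chain], digits = 3))
end
```

Reducing along dims = 2 collapses the scan axis, so each resulting entry describes one chain.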
Accessing the annealing parameters
To obtain the annealing parameter used to define each intermediate distribution, use:
using Pigeons
an_unidentifiable_model = Pigeons.toy_turing_unid_target(100, 50)
pt = pigeons(target = an_unidentifiable_model)
pt.shared.tempering.schedule
Pigeons.Schedule([0.0, 0.007746630632468892, 0.01953727870095436, 0.03794177705127855, 0.06676168905907191, 0.12237384250783831, 0.2078044064064331, 0.3705603026866782, 0.6038337802460468, 1.0])
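To inspect the schedule more closely, the printed values can be treated as an ordinary vector. The sketch below copies the annealing parameters from the output above so that it runs without Pigeons. With a live pt object, the same vector is presumably available as a field of the schedule; the field name grids mentioned in the comment is an assumption and should be checked against the Pigeons documentation.

```julia
# Annealing parameters copied verbatim from the schedule printed above.
# Assumption: with a live `pt`, the same vector may be accessible as
# `pt.shared.tempering.schedule.grids` (field name not confirmed here).
betas = [0.0, 0.007746630632468892, 0.01953727870095436, 0.03794177705127855,
         0.06676168905907191, 0.12237384250783831, 0.2078044064064331,
         0.3705603026866782, 0.6038337802460468, 1.0]

# Print each grid point alongside its annealing parameter.
for (i, beta) in enumerate(betas)
    println("grid point ", i, ": beta = ", round(beta, digits = 4))
end

# Spacing between consecutive annealing parameters.
gaps = diff(betas)
println("smallest gap: ", round(minimum(gaps), digits = 4))
println("largest gap:  ", round(maximum(gaps), digits = 4))
```

For this run the gaps increase monotonically, so the grid points are placed more densely near 0 than near 1.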