Stan model as input to pigeons
Pigeons uses BridgeStan.jl through a package extension, which will attempt to install Stan automatically. BridgeStan.jl needs a C++ compiler and make to work; see the BridgeStan requirements.
To target the posterior distribution specified by a Stan model, use a StanLogPotential.
Here we show how this is done using our familiar unidentifiable toy example ported to the Stan language.
using BridgeStan
using Pigeons
using Random
# We will use this type to make sure our iid sampler (next section) will
# be used only for this model
struct StanUnidentifiableExample end
function stan_unid(n_trials, n_successes)
    # path to a .stan file (compiled files will be cached in the same directory)
    stan_file = dirname(dirname(pathof(Pigeons))) * "/examples/stan/unid.stan"
    # data can be specified either using...
    # - a path to a json file with suffix .json containing the data to condition on
    # - the JSON string itself (here via the utility Pigeons.json())
    stan_data = Pigeons.json(; n_trials, n_successes)
    return StanLogPotential(stan_file, stan_data, StanUnidentifiableExample())
end
pt = pigeons(target = stan_unid(100, 50), reference = stan_unid(0, 0))
BridgeStan not found at location specified by $BRIDGESTAN environment variable, downloading version 2.5.0 to /home/runner/.bridgestan/bridgestan-2.5.0
Done!
┌ Warning: Loading a shared object '/home/runner/work/Pigeons.jl/Pigeons.jl/examples/stan/unid_model.so' which is already loaded.
│ If the file has changed since the last time it was loaded, this load may not update the library!
└ @ BridgeStan ~/.julia/packages/BridgeStan/ccAxe/src/model.jl:59
┌ Info: Neither traces, disk, nor online recorders included.
│ You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
└ To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
┌ Warning: It looks like sample_iid!() is not implemented for a
│ reference_log_potential of type StanLogPotential{...}.
│ Instead, using step!().
└ @ Pigeons ~/work/Pigeons.jl/Pigeons.jl/src/targets/target.jl:51
──────────────────────────────────────────────────────────────────────────────────────────────────
scans Λ time(s) allc(B) log(Z₁/Z₀) min(α) mean(α) min(αₑ) mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
2 1.3 0.672 1.21e+07 -5.89 0.0027 0.856 0.351 0.569
4 1.85 0.116 4.15e+06 -4.38 0.592 0.794 0.369 0.585
8 1.36 0.00828 1.12e+05 -4.83 0.282 0.849 0.519 0.618
16 1.56 0.0171 2.13e+05 -4.88 0.595 0.827 0.503 0.586
32 1.49 0.0303 3.67e+05 -4.55 0.67 0.835 0.568 0.654
64 1.4 0.0612 5.88e+05 -4.84 0.695 0.844 0.57 0.628
128 1.53 0.124 1.04e+06 -4.95 0.754 0.83 0.552 0.618
256 1.56 0.245 1.94e+06 -5.04 0.737 0.827 0.556 0.632
512 1.54 0.497 3.75e+06 -5.01 0.798 0.829 0.523 0.619
1.02e+03 1.5 0.991 7.37e+06 -4.94 0.788 0.833 0.544 0.622
──────────────────────────────────────────────────────────────────────────────────────────────────
Notice that we have specified a reference distribution, in this case the same model but with no observations (hence the prior). This needs to be done with Stan targets because it is not possible to automatically extract a prior from a .stan file.
For a StanLogPotential, the default_explorer() is AutoMALA [1].
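The comments in stan_unid above note that the data can also be given as a path to a file with suffix .json instead of a JSON string. A minimal sketch of that alternative follows, using only Julia's standard library to write the file. The file name unid_data.json and the temporary directory are illustrative choices, and the final StanLogPotential call is left as a comment because it needs the compiled Stan model and the stan_file variable from stan_unid.

```julia
# Write the data for the unidentifiable example to a JSON file.
# The directory and file name are illustrative, not something Pigeons requires.
data_dir = mktempdir()
data_path = joinpath(data_dir, "unid_data.json")
write(data_path, """{"n_trials": 100, "n_successes": 50}""")

# Per the comment in stan_unid, a data file path is expected to carry the .json suffix.
@assert endswith(data_path, ".json")
@assert isfile(data_path)

# With Pigeons and BridgeStan loaded, the path would take the place of the JSON string:
# target = StanLogPotential(stan_file, data_path, StanUnidentifiableExample())
```

Keeping the data in a file can be convenient when the same dataset is shared with other Stan interfaces, while the JSON-string form avoids touching the file system for small examples like this one.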
Sampling from the reference distribution
The ability to sample from the reference distribution can be beneficial, e.g. to jump between modes in a multimodal distribution. For Stan targets, this is done as follows:
using BridgeStan
function Pigeons.sample_iid!(
        log_potential::StanLogPotential{M, S, D, StanUnidentifiableExample}, replica, shared) where {M, S, D}
    # sample in constrained space
    constrained = rand(replica.rng, 2)
    # transform to unconstrained space
    replica.state.unconstrained_parameters .= BridgeStan.param_unconstrain(log_potential.model, constrained)
end
pt = pigeons(target = stan_unid(100, 50), reference = stan_unid(0, 0))
┌ Warning: Loading a shared object '/home/runner/work/Pigeons.jl/Pigeons.jl/examples/stan/unid_model.so' which is already loaded.
│ If the file has changed since the last time it was loaded, this load may not update the library!
└ @ BridgeStan ~/.julia/packages/BridgeStan/ccAxe/src/model.jl:59
┌ Warning: Loading a shared object '/home/runner/work/Pigeons.jl/Pigeons.jl/examples/stan/unid_model.so' which is already loaded.
│ If the file has changed since the last time it was loaded, this load may not update the library!
└ @ BridgeStan ~/.julia/packages/BridgeStan/ccAxe/src/model.jl:59
┌ Info: Neither traces, disk, nor online recorders included.
│ You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
└ To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
──────────────────────────────────────────────────────────────────────────────────────────────────
scans Λ time(s) allc(B) log(Z₁/Z₀) min(α) mean(α) min(αₑ) mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
2 1.24 0.00151 4.24e+04 -4.29 0.0658 0.863 0.383 0.593
4 1.73 0.00402 6.08e+04 -4.2 0.429 0.808 0.476 0.607
8 1.2 0.00688 1.15e+05 -4.52 0.706 0.867 0.481 0.59
16 1.12 0.0138 2.01e+05 -4.64 0.702 0.875 0.598 0.65
32 1.77 0.0293 3.25e+05 -5 0.593 0.803 0.506 0.622
64 1.45 0.0549 5.67e+05 -4.84 0.735 0.839 0.534 0.63
128 1.5 0.111 9.82e+05 -4.81 0.733 0.834 0.555 0.63
256 1.39 0.22 1.83e+06 -5.03 0.787 0.846 0.603 0.635
512 1.51 0.442 3.54e+06 -4.87 0.802 0.832 0.57 0.633
1.02e+03 1.51 0.889 6.94e+06 -4.96 0.799 0.832 0.536 0.623
──────────────────────────────────────────────────────────────────────────────────────────────────
Manipulating the output
Internally, the states of a Stan target are stored in an unconstrained parameterization provided by Stan (for example, variables with bounded support are mapped to the full real line). However, sample post-processing functions such as sample_array() and process_sample() convert back to the original ("constrained") parameterization via extract_sample().
As a result, parameterization issues can essentially be ignored when post-processing. Some common post-processing steps are shown below; see the section on output processing for more information.
using MCMCChains
using StatsPlots
plotlyjs()
pt = pigeons(
    target = stan_unid(100, 50),
    reference = stan_unid(0, 0),
    record = [traces])
samples = Chains(pt)
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "stan_posterior_densities_and_traces.html");
samples
Chains MCMC chain (1024×3×1 Array{Float64, 3}):
Iterations = 1:1:1024
Number of chains = 1
Samples per chain = 1024
parameters = p1, p2
internals = log_density
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat e ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯
p1 0.7039 0.1417 0.0063 491.0540 494.2846 1.0008 ⋯
p2 0.7241 0.1416 0.0065 461.4249 519.3181 1.0024 ⋯
1 column omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
p1 0.4667 0.5914 0.6945 0.8046 0.9719
p2 0.4853 0.6109 0.7140 0.8322 0.9795
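To make the constrained versus unconstrained distinction concrete, the sketch below shows the log-odds (logit) transform that Stan uses for a variable with lower and upper bounds, specialized to the interval (0, 1). It assumes that p1 and p2 are declared with bounds 0 and 1 in unid.stan, which is consistent with the uniform draws used in sample_iid! above but is not shown on this page. The names to_unconstrained and to_constrained are hypothetical helpers written for illustration; they are not Pigeons or BridgeStan functions.

```julia
# Hypothetical helpers illustrating the transform for a variable bounded in (0, 1).
to_unconstrained(p) = log(p / (1 - p))   # logit: (0, 1) -> real line
to_constrained(y) = 1 / (1 + exp(-y))    # inverse logit: real line -> (0, 1)

p = 0.7
y = to_unconstrained(p)

@assert y > 0                        # values above 0.5 map to positive reals
@assert to_constrained(y) ≈ p        # the round trip recovers the constrained value
@assert to_unconstrained(0.5) == 0.0 # the midpoint maps to zero
```

In the sampler itself this mapping is handled by BridgeStan for the whole parameter vector (as in the BridgeStan.param_unconstrain call above), so the helpers here only illustrate why the reported values of p1 and p2 stay inside (0, 1) even though the explorer moves on the real line.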
[1] Biron-Lattes, M., Surjanovic, N., Syed, S., Campbell, T., & Bouchard-Côté, A. (2024). autoMALA: Locally adaptive Metropolis-adjusted Langevin algorithm. Proceedings of The 27th International Conference on Artificial Intelligence and Statistics, in Proceedings of Machine Learning Research 238:4600-4608.