Stan model as input to pigeons

Note

We use BridgeStan.jl through a package extension, which will attempt to install Stan automatically. BridgeStan.jl needs a C++ compiler and make to work; see the BridgeStan requirements.
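Before the first run, it can help to check that the prerequisites named above are visible to Julia. The following is a minimal sketch that uses only Base. The variable names are illustrative, the compiler names are common defaults and not an exhaustive list, and passing this check does not guarantee that BridgeStan's build will succeed.

```julia
# Pre-flight sketch: look for make and a C++ compiler on the PATH.
# "g++" and "clang++" are common compiler names, not an exhaustive list.
toolchain = Dict(tool => Sys.which(tool) for tool in ("make", "g++", "clang++"))

has_make = toolchain["make"] !== nothing
has_cxx  = toolchain["g++"] !== nothing || toolchain["clang++"] !== nothing

# BridgeStan looks for an existing install at the location given by the
# BRIDGESTAN environment variable; when none is found there, the log in the
# example below shows a copy being downloaded instead.
bridgestan_dir = get(ENV, "BRIDGESTAN", nothing)

println("make: ", has_make, ", C++ compiler: ", has_cxx,
        ", BRIDGESTAN: ", something(bridgestan_dir, "unset"))
```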

To target the posterior distribution specified by a Stan model, use a StanLogPotential.

Here we show how this is done using our familiar unidentifiable toy example ported to the Stan language.

using BridgeStan
using Pigeons
using Random

# We will use this type to make sure our iid sampler (next section) will
# be used only for this model
struct StanUnidentifiableExample end

function stan_unid(n_trials, n_successes)
    # path to a .stan file (compiled files will be cached in the same directory)
    stan_file = joinpath(dirname(dirname(pathof(Pigeons))), "examples", "stan", "unid.stan")

    # data can be specified either using...
    #   - a path to a json file with suffix .json containing the data to condition on
    #   - the JSON string itself (here via the utility Pigeons.json())
    stan_data = Pigeons.json(; n_trials, n_successes)

    return StanLogPotential(stan_file, stan_data, StanUnidentifiableExample())
end

pt = pigeons(target = stan_unid(100, 50), reference = stan_unid(0, 0))
BridgeStan not found at location specified by $BRIDGESTAN environment variable, downloading version 2.4.1 to /home/runner/.bridgestan/bridgestan-2.4.1
Done!
┌ Warning: Loading a shared object '/home/runner/work/Pigeons.jl/Pigeons.jl/examples/stan/unid_model.so' which is already loaded.
If the file has changed since the last time it was loaded, this load may not update the library!
@ BridgeStan ~/.julia/packages/BridgeStan/aKsXp/src/model.jl:58
┌ Info: Neither traces, disk, nor online recorders included.
   You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
   To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
┌ Warning: It looks like sample_iid!() is not implemented for a
reference_log_potential of type StanLogPotential{...}.
Instead, using step!().
@ Pigeons ~/work/Pigeons.jl/Pigeons.jl/src/targets/target.jl:51
──────────────────────────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2        1.3      0.038   7.47e+05      -5.89     0.0027      0.856      0.351      0.569
        4       1.85     0.0042   4.52e+04      -4.38      0.592      0.794      0.369      0.585
        8       1.36    0.00805   7.63e+04      -4.83      0.282      0.849      0.519      0.618
       16       1.56     0.0164   1.36e+05      -4.88      0.595      0.827      0.503      0.586
       32       1.49     0.0291   2.07e+05      -4.55       0.67      0.835      0.568      0.654
       64        1.4     0.0583    2.6e+05      -4.84      0.695      0.844       0.57      0.628
      128       1.53      0.118   3.72e+05      -4.95      0.754       0.83      0.552      0.618
      256       1.56      0.234   6.01e+05      -5.04      0.737      0.827      0.556      0.632
      512       1.54      0.472   1.06e+06      -5.01      0.798      0.829      0.523      0.619
 1.02e+03        1.5      0.943   1.98e+06      -4.94      0.788      0.833      0.544      0.622
──────────────────────────────────────────────────────────────────────────────────────────────────
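The comments in stan_unid mention that the data can also be given as a path to a file with suffix .json. Here is a minimal sketch of that alternative using only Base. The helper unid_data and the file name unid_data.json are hypothetical, and the JSON is written by hand, assuming it is equivalent to what Pigeons.json(; n_trials, n_successes) produces for these two integers.

```julia
# Hypothetical helper: write the toy model's data to a file with a .json suffix
# and return both the path and the JSON string itself.
function unid_data(n_trials::Integer, n_successes::Integer;
                   dir::AbstractString = mktempdir())
    json = "{\"n_trials\": $n_trials, \"n_successes\": $n_successes}"
    path = joinpath(dir, "unid_data.json")   # the .json suffix marks it as a data file
    write(path, json)
    return path, json
end

path, json = unid_data(100, 50)
# Either value could then be passed as stan_data, e.g.
# StanLogPotential(stan_file, path, StanUnidentifiableExample())
```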

Notice that we have specified a reference distribution, in this case the same model but with no observations (hence the prior). This needs to be done with Stan targets because it is not possible to automatically extract a prior from a .stan file.
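To see why a model with no observations serves as the prior, here is a pure-Julia sketch. It assumes that unid.stan encodes p1, p2 ~ Uniform(0, 1) and n_successes ~ Binomial(n_trials, p1 * p2); the .stan file itself is not shown here, and the helpers logbinom, unid_loglik and log_evidence_mc are hypothetical names. With n_trials = n_successes = 0 the log-likelihood is identically zero, so the reference is exactly the prior. The same helper also gives a naive Monte Carlo estimate of log(Z₁/Z₀) = log E_prior[likelihood], taking Z₀ = 1 for the normalized prior. This estimate need not match the log(Z₁/Z₀) column above exactly, because that column depends on which normalizing constants the Stan model keeps.

```julia
using Random

# Log of the binomial coefficient C(n, k), computed as a sum of logs to avoid
# Int64 overflow (binomial(100, 50) does not fit in an Int64).
logbinom(n, k) = sum((log(n - k + i) - log(i) for i in 1:k); init = 0.0)

# Hypothetical log-likelihood of the assumed toy model:
# n_successes ~ Binomial(n_trials, p1 * p2).
function unid_loglik(p1, p2, n_trials, n_successes)
    p = p1 * p2
    return logbinom(n_trials, n_successes) +
           n_successes * log(p) +
           (n_trials - n_successes) * log1p(-p)
end

# Naive Monte Carlo estimate of log E_prior[likelihood]: draw (p1, p2) from the
# assumed Uniform(0, 1) priors and average the likelihoods via log-sum-exp.
function log_evidence_mc(rng::AbstractRNG, n_trials, n_successes; n_draws = 100_000)
    logliks = [unid_loglik(rand(rng), rand(rng), n_trials, n_successes) for _ in 1:n_draws]
    m = maximum(logliks)
    return m + log(sum(exp.(logliks .- m)) / n_draws)
end

# No observations: the likelihood is identically 1, so the "posterior" is the prior.
unid_loglik(0.3, 0.7, 0, 0)                    # 0.0
log_evidence_mc(MersenneTwister(1), 100, 50)  # log(Z₁/Z₀) under the assumptions above
```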

For a StanLogPotential, the default_explorer() is AutoMALA[1].

Sampling from the reference distribution

Being able to sample iid from the reference distribution can be beneficial, e.g., to jump between the modes of a multi-modal distribution. For Stan targets, this is done by adding a method to Pigeons.sample_iid! as follows:

using BridgeStan

function Pigeons.sample_iid!(
        log_potential::StanLogPotential{M, S, D, StanUnidentifiableExample}, replica, shared) where {M, S, D}
    # sample in constrained space
    constrained = rand(replica.rng, 2)
    # transform to unconstrained space
    replica.state.unconstrained_parameters .= BridgeStan.param_unconstrain(log_potential.model, constrained)
end

pt = pigeons(target = stan_unid(100, 50), reference = stan_unid(0, 0))
┌ Warning: Loading a shared object '/home/runner/work/Pigeons.jl/Pigeons.jl/examples/stan/unid_model.so' which is already loaded.
If the file has changed since the last time it was loaded, this load may not update the library!
@ BridgeStan ~/.julia/packages/BridgeStan/aKsXp/src/model.jl:58
┌ Info: Neither traces, disk, nor online recorders included.
   You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
   To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
──────────────────────────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2       1.24    0.00147   3.62e+04      -4.29     0.0658      0.863      0.383      0.593
        4       1.73    0.00392   4.75e+04       -4.2      0.429      0.808      0.476      0.607
        8        1.2    0.00664   8.36e+04      -4.52      0.706      0.867      0.481       0.59
       16       1.12     0.0133   1.31e+05      -4.64      0.702      0.875      0.598       0.65
       32       1.77     0.0262   1.82e+05         -5      0.593      0.803      0.506      0.622
       64       1.45     0.0522   2.72e+05      -4.84      0.735      0.839      0.534       0.63
      128        1.5      0.105   3.84e+05      -4.81      0.733      0.834      0.555       0.63
      256       1.39      0.209   6.26e+05      -5.03      0.787      0.846      0.603      0.635
      512       1.51      0.419   1.11e+06      -4.87      0.802      0.832       0.57      0.633
 1.02e+03       1.51      0.843   2.09e+06      -4.96      0.799      0.832      0.536      0.623
──────────────────────────────────────────────────────────────────────────────────────────────────
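The sample_iid! method above draws p1 and p2 uniformly in constrained space and relies on BridgeStan.param_unconstrain to map them to Stan's unconstrained space. For a parameter declared with <lower=0, upper=1> bounds, Stan's unconstraining transform is the logit. Such bounds are assumed here for p1 and p2, which is consistent with drawing them with rand. The pure-Julia sketch below mimics that step without BridgeStan. The names logit, inv_logit and draw_reference! are illustrative and are not part of Pigeons or BridgeStan.

```julia
using Random

# Assumed transform for a parameter declared real<lower=0, upper=1> in Stan:
# unconstrained y = logit(x); constrained x = inverse logit of y.
logit(x) = log(x / (1 - x))
inv_logit(y) = 1 / (1 + exp(-y))

# Hypothetical stand-in for the body of sample_iid! above: draw (p1, p2)
# exactly from a Uniform(0, 1)² reference, then overwrite the preallocated
# unconstrained state in place.
function draw_reference!(state::AbstractVector{Float64}, rng::AbstractRNG)
    constrained = rand(rng, length(state))
    state .= logit.(constrained)
    return state
end

rng = MersenneTwister(1)
state = zeros(2)
draw_reference!(state, rng)
inv_logit.(state)   # maps back to the constrained draw in [0, 1)
```

Writing into the existing state vector with .= mirrors the in-place update of replica.state.unconstrained_parameters in the method above.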

Manipulating the output

Internally, the states of Stan targets are stored in an unconstrained parameterization provided by Stan (for example, variables with bounded support are mapped to the full real line). However, sample post-processing functions such as sample_array() and process_sample() convert back to the original ("constrained") parameterization via extract_sample().

As a result, parameterization issues can essentially be ignored when post-processing. Some common post-processing steps are shown below; see the section on output processing for more information.

using MCMCChains
using StatsPlots
plotlyjs()

pt = pigeons(
        target = stan_unid(100, 50),
        reference = stan_unid(0, 0),
        record = [traces])
samples = Chains(pt)
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "stan_posterior_densities_and_traces.html");

samples
Chains MCMC chain (1024×3×1 Array{Float64, 3}):

Iterations        = 1:1:1024
Number of chains  = 1
Samples per chain = 1024
parameters        = p1, p2
internals         = log_density

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

          p1    0.7039    0.1417    0.0063   491.0540   494.2846    1.0008     ⋯
          p2    0.7241    0.1416    0.0065   461.4249   519.3181    1.0024     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          p1    0.4667    0.5914    0.6945    0.8046    0.9719
          p2    0.4853    0.6109    0.7140    0.8322    0.9795
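The summaries above are on the original scale, with p1 and p2 in [0, 1], thanks to the conversion performed by extract_sample(). The tiny pure-Julia illustration below shows why that conversion matters. It assumes an inverse-logit link like the one Stan uses for a [0, 1]-bounded parameter. Summarizing unconstrained draws and transforming afterwards gives a different answer from transforming each draw first. The draws are made up for illustration and are not Pigeons output.

```julia
inv_logit(y) = 1 / (1 + exp(-y))

# Hypothetical unconstrained draws of one [0, 1]-bounded parameter.
unconstrained = [0.0, 1.0, 4.0]

# Wrong scale: average the unconstrained draws, then transform the mean.
naive = inv_logit(sum(unconstrained) / length(unconstrained))

# Right scale: transform every draw first, then average.
constrained = inv_logit.(unconstrained)
posterior_mean = sum(constrained) / length(constrained)

naive, posterior_mean   # the two disagree because inv_logit is nonlinear
```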