Turing.jl model as input to pigeons

To target the posterior distribution specified by a Turing.jl model, first load Turing or DynamicPPL, then wrap the model in a TuringLogPotential:

using DynamicPPL, Pigeons, Distributions

DynamicPPL.@model function my_turing_model(n_trials, n_successes)
    p1 ~ Uniform(0, 1)
    p2 ~ Uniform(0, 1)
    n_successes ~ Binomial(n_trials, p1*p2)
    return n_successes
end

my_turing_target = TuringLogPotential(my_turing_model(100, 50))
pt = pigeons(target = my_turing_target);
┌ Info: Neither traces, disk, nor online recorders included.
│    You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
└    To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
──────────────────────────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2       3.24    0.00886   1.14e+06      -8.14    0.00178       0.64          1          1
        4       1.64    0.00223   2.09e+06      -5.04     0.0352      0.818          1          1
        8       1.17    0.00404   4.09e+06      -4.42      0.708      0.871          1          1
       16        1.2    0.00785   8.63e+06      -4.03      0.549      0.867          1          1
       32       1.11     0.0153   1.69e+07      -4.77      0.754      0.877          1          1
       64       1.35     0.0305   3.39e+07      -4.79      0.698       0.85          1          1
      128        1.6     0.0697   6.71e+07      -4.97      0.725      0.823          1          1
      256       1.51      0.173   1.33e+08      -4.92      0.758      0.832          1          1
      512       1.48       0.32   2.69e+08      -4.94      0.806      0.836          1          1
 1.02e+03       1.53      0.524   5.36e+08      -5.08      0.808       0.83          1          1
──────────────────────────────────────────────────────────────────────────────────────────────────

At the moment, only Turing models with fixed dimensionality are supported; both real- and integer-valued random variables are handled. For a TuringLogPotential, the default_explorer() is the SliceSampler and the default_reference() is the prior distribution encoded in the Turing model.
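These defaults can also be passed explicitly. The sketch below is equivalent to the earlier call, spelling out the slice sampler explorer and adding the traces recorder so that samples are kept (the keyword names follow the pigeons() signature used elsewhere in this page):

```julia
using DynamicPPL, Pigeons, Distributions

# Equivalent to pigeons(target = my_turing_target), but with the
# default explorer written out and sample recording enabled.
pt = pigeons(
    target   = my_turing_target,
    explorer = SliceSampler(),                    # default for Turing targets
    record   = [traces; record_default()]         # keep samples this time
);
```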

Gradient-based sampling with AutoMALA

For Turing models with fully continuous state spaces, as is the case for my_turing_model defined above, AutoMALA can be an effective alternative to SliceSampler, especially for high-dimensional problems. Because Turing targets conform to the LogDensityProblemsAD.jl interface, automatic differentiation (AD) backends can be used to obtain the gradients needed by AutoMALA.

The default AD backend for AutoMALA is ForwardDiff. However, Turing supports other backends that may offer better performance. One such backend is Mooncake, which can be used in Pigeons via

using ADTypes, Mooncake
pt = pigeons(
    target = my_turing_target,
    explorer = AutoMALA(default_autodiff_backend = AutoMooncake(nothing))
);
┌ Info: Neither traces, disk, nor online recorders included.
│    You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
└    To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
──────────────────────────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2       1.26       61.4   4.35e+09      -6.65    0.00185       0.86          0      0.517
        4      0.485      0.129   2.67e+07      -4.07      0.742      0.946          0      0.522
        8      0.105       0.12   4.23e+07      -3.63      0.933      0.988      0.406      0.575
       16      0.708        0.1   8.74e+07      -3.33      0.613      0.921      0.226      0.565
       32       1.15      0.275   1.85e+08      -3.88      0.576      0.872      0.269      0.481
       64      0.757      0.423   3.58e+08      -3.65      0.853      0.916      0.242      0.551
      128      0.857      0.833   7.06e+08       -3.7      0.798      0.905      0.357      0.539
      256      0.932       1.72   1.46e+09      -3.93      0.865      0.896      0.313      0.564
      512      0.934       3.31    2.8e+09      -3.86      0.862      0.896      0.266      0.499
 1.02e+03       1.02       6.68   5.64e+09      -4.01      0.851      0.887       0.24      0.473
──────────────────────────────────────────────────────────────────────────────────────────────────

Alternatively, in the special case where the Turing model involves no branching decisions (`if`, `while`, etc.) that depend on latent variables, ReverseDiff with a compiled tape may offer better performance still. Since my_turing_target satisfies this criterion, we can use AutoMALA with the ReverseDiff AD backend via

using ADTypes, ReverseDiff
pt = pigeons(
    target = my_turing_target,
    explorer = AutoMALA(default_autodiff_backend = AutoReverseDiff(compile=true))
);
┌ Info: Neither traces, disk, nor online recorders included.
│    You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
└    To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
──────────────────────────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2       1.26          2   1.26e+08      -6.65    0.00185       0.86          0      0.517
        4      0.485      0.223    1.1e+07      -4.07      0.742      0.946          0      0.522
        8      0.105     0.0195   8.93e+06      -3.63      0.933      0.988      0.406      0.575
       16      0.708     0.0392   1.84e+07      -3.33      0.613      0.921      0.226      0.565
       32       1.15     0.0822   3.81e+07      -3.88      0.576      0.872      0.269      0.481
       64      0.757      0.162   7.46e+07      -3.65      0.853      0.916      0.242      0.551
      128      0.857      0.389   1.48e+08       -3.7      0.798      0.905      0.357      0.539
      256      0.932      0.704   3.03e+08      -3.93      0.865      0.896      0.313      0.564
      512      0.934       1.33   5.86e+08      -3.86      0.862      0.896      0.266      0.499
 1.02e+03       1.02        2.6   1.17e+09      -4.01      0.851      0.887       0.24      0.473
──────────────────────────────────────────────────────────────────────────────────────────────────

Using DynamicPPL.@addlogprob!

The macro DynamicPPL.@addlogprob! is sometimes used when additional flexibility is needed to increment the log probability. To use it with Pigeons.jl, you must enclose the call to DynamicPPL.@addlogprob! in an if statement, as shown below. Failing to do so will lead to invalid results.

DynamicPPL.@model function my_turing_model(my_data)
    # code here
    if DynamicPPL.leafcontext(__context__) !== DynamicPPL.PriorContext() 
        DynamicPPL.@addlogprob! logpdf(MyDist(parms), my_data)
    end
end
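For concreteness, here is a hedged sketch of the same pattern with a specific likelihood filled in; the names `mu`, `my_data`, and the Normal likelihood are illustrative placeholders, not part of the original example:

```julia
using DynamicPPL, Distributions

DynamicPPL.@model function addlogprob_example(my_data)
    mu ~ Normal(0, 1)
    # Guard the manual likelihood increment: it must be skipped when
    # Pigeons evaluates the prior, which serves as the reference.
    if DynamicPPL.leafcontext(__context__) !== DynamicPPL.PriorContext()
        DynamicPPL.@addlogprob! sum(logpdf.(Normal(mu, 1), my_data))
    end
end
```

Without the guard, the manual likelihood term would also be added when the model is evaluated under PriorContext, corrupting the reference distribution.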

Manipulating the output

Internally, the states of a Turing target (of type DynamicPPL.TypedVarInfo) are stored in an unconstrained parameterization provided by Turing (for example, variables with bounded support are mapped to the full real line). However, sample post-processing functions such as sample_array() and process_sample() convert back to the original ("constrained") parameterization via extract_sample().

As a result, parameterization issues can essentially be ignored during post-processing. Some common post-processing steps are shown below; see the section on output processing for more information.

using MCMCChains
using StatsPlots
plotlyjs()

pt = pigeons(target = my_turing_target, record = [traces])
samples = Chains(pt)
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "turing_posterior_densities_and_traces.html");

samples
Chains MCMC chain (1024×3×1 Array{Float64, 3}):

Iterations        = 1:1:1024
Number of chains  = 1
Samples per chain = 1024
parameters        = p1, p2
internals         = log_density

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64   ⋯

          p1    0.7152    0.1518    0.0068   491.5895   649.6713    1.0026   ⋯
          p2    0.7185    0.1507    0.0067   525.7289   783.7176    1.0040   ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          p1    0.4533    0.5884    0.7068    0.8417    0.9804
          p2    0.4729    0.5932    0.7008    0.8467    0.9856
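Besides the MCMCChains route above, the raw draws can be accessed directly with sample_array(), mentioned earlier; the draws come back in the constrained parameterization. A sketch, assuming the usual scans-by-variables-by-chains layout with the log density appended as the last variable column:

```julia
# Raw access to the recorded samples (constrained parameterization).
mtx = sample_array(pt)      # here: 1024 × 3 × 1 (p1, p2, log density)
p1_draws = mtx[:, 1, 1]     # draws for the first parameter
```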

Custom initialization

It is sometimes useful to provide a custom initialization, for example to start in a feasible region. This can be done as follows:

using DynamicPPL, Pigeons, Distributions, Random

DynamicPPL.@model function toy_beta_binom_model(n_trials, n_successes)
    p ~ Uniform(0, 1)
    n_successes ~ Binomial(n_trials, p)
    return n_successes
end

function toy_beta_binom_target(n_trials = 10, n_successes = 2)
    return Pigeons.TuringLogPotential(toy_beta_binom_model(n_trials, n_successes))
end

const ToyBetaBinomType = typeof(toy_beta_binom_target())

function Pigeons.initialization(target::ToyBetaBinomType, rng::AbstractRNG, ::Int64)
    result = DynamicPPL.VarInfo(rng, target.model, DynamicPPL.SampleFromPrior(), DynamicPPL.PriorContext())
    result = DynamicPPL.link(result, target.model)

    # custom init goes here: for example here setting the variable p to 0.5
    Pigeons.update_state!(result, :p, 1, 0.5)

    return result
end

pt = pigeons(target = toy_beta_binom_target(), n_rounds = 0)
@assert Pigeons.variable(pt.replicas[1].state, :p) == [0.5]
┌ Info: Neither traces, disk, nor online recorders included.
│    You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
└    To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
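The run above uses n_rounds = 0 purely to check the initialization. Once the custom Pigeons.initialization method is in place, a normal run with sample recording proceeds as usual; a sketch:

```julia
# Full run starting from the custom initialization (p = 0.5),
# this time recording traces so the samples are accessible.
pt = pigeons(
    target = toy_beta_binom_target(),
    record = [traces; record_default()]
);
```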