Julia code as input to pigeons

In typical Bayesian statistics applications, it is easiest to specify the model in a modelling language such as Turing.jl. Sometimes, however, more flexibility or speed can be gained by implementing the density evaluation manually as a "black-box" Julia function.

Here we show how this is done using our familiar unidentifiable toy example.
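For reference, the toy model encoded by MyLogPotential below has two parameters with flat priors on (0, 1) and a binomial likelihood that depends on them only through their product p1 * p2, which is what makes the pair unidentifiable. Writing n for n_trials and k for n_successes (n = 100, k = 50 in the examples on this page):

```latex
\begin{aligned}
p_1, p_2 &\overset{\text{iid}}{\sim} \mathrm{Uniform}(0, 1), \\
k \mid p_1, p_2 &\sim \mathrm{Binomial}(n, p_1 p_2), \\
\log \pi(p_1, p_2) &= \log \binom{n}{k} + k \log(p_1 p_2) + (n - k) \log(1 - p_1 p_2),
  \qquad (p_1, p_2) \in (0, 1)^2 .
\end{aligned}
```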

We first create a custom type, MyLogPotential, used to control dispatch in the target interface.

using Pigeons
using Random
using Distributions

struct MyLogPotential
    n_trials::Int
    n_successes::Int
end

Next, we make MyLogPotential a function-like object, so that we can write expressions of the form my_log_potential([0.5, 0.5]) and hence MyLogPotential satisfies the log_potential interface:

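function (log_potential::MyLogPotential)(x)
    p1, p2 = x
    # Zero density (log density -Inf) outside the unit square
    if !(0 < p1 < 1) || !(0 < p2 < 1)
        return -Inf
    end
    # The data only inform the product p1 * p2, hence the unidentifiability
    p = p1 * p2
    return logpdf(Binomial(log_potential.n_trials, p), log_potential.n_successes)
end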

# e.g.:
my_log_potential = MyLogPotential(100, 50)
my_log_potential([0.5, 0.5])
-16.91498002656617
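As a cross-check of the value printed above, the same log potential can be written in plain Julia without Distributions.jl, using the binomial log-pmf log C(n, k) + k log p + (n - k) log(1 - p) with p = p1 * p2. At (0.5, 0.5) this is log C(100, 50) + 50 log(0.1875), roughly 66.784 - 83.699 = -16.915. This is only a sketch: the names `BaseLogPotential` and `log_binomial_coefficient` are hypothetical, and the rest of this page uses the Distributions.jl version.

```julia
# Hypothetical helper: log of C(n, k) = prod_{i=1}^k (n - k + i) / i, computed with Base only
function log_binomial_coefficient(n::Integer, k::Integer)
    acc = 0.0
    for i in 1:k
        acc += log(n - k + i) - log(i)
    end
    return acc
end

# Hypothetical Distributions-free variant of MyLogPotential
struct BaseLogPotential
    n_trials::Int
    n_successes::Int
end

function (lp::BaseLogPotential)(x)
    p1, p2 = x
    if !(0 < p1 < 1) || !(0 < p2 < 1)
        return -Inf
    end
    p = p1 * p2
    n, k = lp.n_trials, lp.n_successes
    # binomial log-pmf: log C(n, k) + k log p + (n - k) log(1 - p)
    return log_binomial_coefficient(n, k) + k * log(p) + (n - k) * log1p(-p)
end

BaseLogPotential(100, 50)([0.5, 0.5])   # ≈ -16.91498, matching the logpdf-based version
```

With no observations, `BaseLogPotential(0, 0)` evaluates to 0 everywhere on the unit square, i.e. it is the flat prior.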

Next, we need to specify how to create fresh state objects:

Pigeons.initialization(::MyLogPotential, ::AbstractRNG, ::Int) = [0.5, 0.5]

We can now run the sampler:

pt = pigeons(
        target = MyLogPotential(100, 50),
        reference = MyLogPotential(0, 0)
    )
┌ Info: Neither traces, disk, nor online recorders included.
   You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
   To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
┌ Warning: It looks like sample_iid!() is not implemented for a
reference_log_potential of type Main.MyLogPotential.
Instead, using step!().
@ Pigeons ~/work/Pigeons.jl/Pigeons.jl/src/targets/target.jl:51
──────────────────────────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2       2.03   0.000371   1.84e+05      -8.65   6.03e-05      0.775          1          1
        4      0.994   0.000227   3.27e+04      -4.82      0.578       0.89          1          1
        8       2.25   0.000346   5.92e+04      -5.81      0.201       0.75          1          1
       16       1.62   0.000657   7.25e+04         -5      0.609       0.82          1          1
       32       1.25    0.00127   1.39e+05      -4.69      0.706      0.861          1          1
       64       1.52    0.00241   1.26e+05      -5.05      0.684      0.831          1          1
      128       1.67    0.00475   1.22e+05      -5.16      0.661      0.815          1          1
      256       1.56    0.00978   1.22e+05      -5.09      0.762      0.827          1          1
      512       1.54      0.019   1.22e+05      -5.07      0.786      0.829          1          1
 1.02e+03       1.51      0.038   1.22e+05      -4.91      0.807      0.832          1          1
──────────────────────────────────────────────────────────────────────────────────────────────────

Notice that we had to specify a reference distribution, here the same model with no observations (i.e., the prior). In contrast to targets specified using Turing.jl, Pigeons cannot construct a reference automatically from a "black-box" Julia target.
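The reference matters because parallel tempering moves between the reference and the target through a sequence of intermediate distributions. As an illustration only, assuming a linear interpolation of log potentials (see the Pigeons documentation for the paths it actually uses), the intermediate log potential at inverse temperature β can be sketched in plain Julia. All helper names below are hypothetical, and the toy target drops its constant log-binomial term.

```julia
# Toy reference (flat on the unit square) and unnormalized toy target (n = 100, k = 50)
in_unit_square(x) = all(0 .< x .< 1)
toy_reference(x) = in_unit_square(x) ? 0.0 : -Inf
function toy_target(x; n = 100, k = 50)
    in_unit_square(x) || return -Inf
    p = x[1] * x[2]
    return k * log(p) + (n - k) * log1p(-p)
end

# Assumed linear path: (1 - β) * reference + β * target, for β in [0, 1]
function interpolated_log_potential(reference, target, β, x)
    β == 0 && return reference(x)   # avoids 0 * (-Inf) = NaN at the endpoints
    β == 1 && return target(x)
    return (1 - β) * reference(x) + β * target(x)
end

x = [0.5, 0.5]
values = [interpolated_log_potential(toy_reference, toy_target, β, x) for β in (0.0, 0.5, 1.0)]
```

At β = 0 the chain sees only the reference, which is why it must be easy to explore (or, as in the next section, to sample exactly).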

Since no explorer was specified, the default_explorer(), the SliceSampler, was used.
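Since the reference MyLogPotential(0, 0) is flat on the unit square, its normalizing constant is Z₀ = 1, so the log(Z₁/Z₀) column estimates the log marginal likelihood of the data. For this toy model that quantity has a closed form: under the prior, p = p1 * p2 has density -log p on (0, 1), and integrating the binomial likelihood against it gives Z₁ = (H_{n+1} - H_k)/(n + 1), where H_j is the j-th harmonic number. For n = 100, k = 50 this is about exp(-4.97), consistent with the estimates near -5 in the table above. The plain-Julia sketch below (helper names hypothetical) compares the formula with a crude Monte Carlo estimate over the prior.

```julia
using Random

n, k = 100, 50

# Exact log marginal likelihood: log((H_{n+1} - H_k) / (n + 1))
log_Z_exact = log(sum(1 / j for j in (k + 1):(n + 1)) / (n + 1))

# Simple Monte Carlo over the uniform prior: Z = E_prior[binomial likelihood]
function log_Z_monte_carlo(rng, n, k; n_draws = 1_000_000)
    # log C(n, k), computed once
    log_binom = sum((log(n - k + i) - log(i) for i in 1:k); init = 0.0)
    total = 0.0
    for _ in 1:n_draws
        p = rand(rng) * rand(rng)   # p1 * p2 with p1, p2 ~ Uniform(0, 1)
        total += exp(log_binom + k * log(p) + (n - k) * log1p(-p))
    end
    return log(total / n_draws)
end

log_Z_mc = log_Z_monte_carlo(MersenneTwister(1), n, k)
```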

Sampling from the reference distribution

The ability to sample i.i.d. from the reference distribution can be beneficial, e.g., to jump between modes of a multi-modal distribution. For black-box Julia function targets, this is done as follows:

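function Pigeons.sample_iid!(::MyLogPotential, replica, shared)
    state = replica.state
    rng = replica.rng
    # The reference MyLogPotential(0, 0) is flat on the unit square, so overwriting
    # the state in place with i.i.d. Uniform(0, 1) draws gives exact reference samples
    rand!(rng, state)
end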
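The method above relies on the fact that `rand!` fills a `Vector{Float64}` with independent Uniform(0, 1) draws, which is exactly the flat reference on the unit square. A quick sanity check of that claim, in plain Julia without Pigeons, using a stand-in vector in place of `replica.state`:

```julia
using Random
using Statistics

rng = MersenneTwister(1)
state = [0.5, 0.5]   # stand-in for replica.state: the vector (p1, p2)

n_draws = 100_000
draws = Matrix{Float64}(undef, 2, n_draws)
for i in 1:n_draws
    rand!(rng, state)        # overwrite in place with i.i.d. Uniform(0, 1) draws
    draws[:, i] .= state
end

sample_means = vec(mean(draws; dims = 2))      # should be close to 1/2
sample_variances = vec(var(draws; dims = 2))   # should be close to 1/12
```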

pt = pigeons(
        target = MyLogPotential(100, 50),
        reference = MyLogPotential(0, 0)
    )
┌ Info: Neither traces, disk, nor online recorders included.
   You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
   To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
──────────────────────────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2       2.03   0.000122   2.37e+04      -12.1   5.75e-08      0.775          1          1
        4       1.13   0.000207   3.08e+04      -5.26     0.0854      0.874          1          1
        8      0.989   0.000294   5.56e+04      -4.09      0.509       0.89          1          1
       16        1.5   0.000604   9.34e+04      -5.16      0.675      0.834          1          1
       32       1.58     0.0012   1.17e+05      -5.29       0.73      0.825          1          1
       64       1.34    0.00235   1.25e+05      -4.79      0.758      0.852          1          1
      128       1.65    0.00453   1.18e+05      -4.91      0.739      0.817          1          1
      256       1.46    0.00897   1.18e+05      -5.01      0.732      0.838          1          1
      512       1.52     0.0183   1.18e+05       -5.1      0.759      0.831          1          1
 1.02e+03       1.55     0.0361   1.18e+05         -5      0.796      0.828          1          1
──────────────────────────────────────────────────────────────────────────────────────────────────

Changing the explorer

Here is an example using AutoMALA, a gradient-based sampler, instead of the default SliceSampler. We use the Enzyme backend, a state-of-the-art automatic differentiation (AD) system that supports targets written in plain Julia. Enzyme is considerably faster than the default ForwardDiff backend, whose main advantage is compatibility with a broader range of targets. Many other AD backends are supported through the LogDensityProblemsAD.jl interface (:Enzyme, :ForwardDiff, :Zygote, :ReverseDiff, etc.).

To proceed, we only need to add methods to make our custom type MyLogPotential conform to the LogDensityProblems interface:

using Enzyme
using LogDensityProblems

LogDensityProblems.dimension(lp::MyLogPotential) = 2
LogDensityProblems.logdensity(lp::MyLogPotential, x) = lp(x)

pt = pigeons(
        target = MyLogPotential(100, 50),
        reference = MyLogPotential(0, 0),
        explorer = AutoMALA(default_autodiff_backend = :Enzyme)
    )
┌ Info: Neither traces, disk, nor online recorders included.
   You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
   To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
──────────────────────────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2       2.07       13.2   8.81e+08        -12   4.85e-08      0.771      0.266      0.508
        4       1.24      0.105   5.09e+06      -5.87      0.282      0.862      0.369      0.549
        8       1.28    0.00209   3.83e+05      -5.71      0.554      0.858      0.413      0.511
       16       1.18    0.00427   7.76e+05      -4.92       0.63      0.869      0.426      0.519
       32       1.61     0.0088   1.53e+06      -4.57      0.391      0.821       0.41      0.525
       64       1.18     0.0166   2.81e+06      -4.91      0.753      0.869      0.455      0.531
      128       1.42     0.0341   5.54e+06      -4.84      0.767      0.842      0.475      0.544
      256       1.45     0.0666   1.08e+07      -4.83      0.763      0.839      0.501       0.55
      512       1.52      0.135   2.17e+07      -5.04      0.795      0.832      0.488      0.547
 1.02e+03        1.5      0.308   4.31e+07      -4.95      0.815      0.833      0.489      0.548
──────────────────────────────────────────────────────────────────────────────────────────────────
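AutoMALA relies on gradients of the log potential, which the AD backend computes automatically here. For this toy model the gradient is also easy to derive by hand, which gives an independent check when debugging a black-box target. A plain-Julia sketch (helper names hypothetical; the constant log C(n, k) term is dropped since it does not affect the gradient):

```julia
# Hypothetical helper: unnormalized toy log potential (constant log C(n, k) dropped)
function toy_log_potential(x; n = 100, k = 50)
    p1, p2 = x
    (0 < p1 < 1 && 0 < p2 < 1) || return -Inf
    p = p1 * p2
    return k * log(p) + (n - k) * log1p(-p)
end

# Hand-derived gradient inside the unit square. With p = p1 * p2:
#   d/dp [k log p + (n - k) log(1 - p)] = k / p - (n - k) / (1 - p),
# and the chain rule multiplies by dp/dp1 = p2 and dp/dp2 = p1.
function toy_gradient(x; n = 100, k = 50)
    p1, p2 = x
    p = p1 * p2
    dlogL_dp = k / p - (n - k) / (1 - p)
    return [dlogL_dp * p2, dlogL_dp * p1]
end

# Central finite differences, used as an independent check
function finite_difference_gradient(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

x = [0.7, 0.6]
analytic = toy_gradient(x)
numeric = finite_difference_gradient(toy_log_potential, x)
```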

Pigeons has several built-in explorer kernels, such as AutoMALA and SliceSampler. However, when the state space is neither real- nor integer-valued, or for performance reasons, it may be necessary to create custom MCMC exploration kernels. This is described on the custom explorers page.
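The Pigeons explorer interface itself is described on the custom explorers page and is not reproduced here. For intuition only, the kind of MCMC kernel an explorer wraps can be sketched in plain Julia as a random-walk Metropolis step. The function names are hypothetical and nothing below uses the Pigeons API.

```julia
using Random
using Statistics

# Hypothetical helper: one random-walk Metropolis step leaving exp(log_potential) invariant.
# Returns the next state and whether the proposal was accepted.
function rw_metropolis_step(rng::AbstractRNG, log_potential, x::Vector{Float64}, step_size::Float64)
    proposal = x .+ step_size .* randn(rng, length(x))   # symmetric Gaussian proposal
    log_ratio = log_potential(proposal) - log_potential(x)
    if log(rand(rng)) < log_ratio
        return proposal, true
    else
        return x, false
    end
end

# Run the kernel repeatedly and record the first coordinate
function run_chain(rng, log_potential, x0, n_steps, step_size)
    x = copy(x0)
    trace = Vector{Float64}(undef, n_steps)
    n_accepted = 0
    for i in 1:n_steps
        x, accepted = rw_metropolis_step(rng, log_potential, x, step_size)
        n_accepted += accepted
        trace[i] = x[1]
    end
    return trace, n_accepted / n_steps
end

standard_normal(x) = -0.5 * sum(abs2, x)   # unnormalized N(0, 1) log density
trace, acceptance_rate = run_chain(MersenneTwister(42), standard_normal, [0.0], 200_000, 2.0)
```

Because the proposal is symmetric, the acceptance ratio only involves the log potential, which is why a black-box log potential is all such a kernel needs.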

Custom gradients

In some situations (performance, unsupported primitives, etc.) it may be helpful to compute gradients explicitly. One way to do so is to use autodiff-specific machinery; see for example the Enzyme documentation. In addition, Pigeons provides an AD-framework-agnostic method to supply explicit gradients, supporting replica-specific, in-place buffers (this functionality was developed to support efficient interfacing with Stan). Its use is demonstrated below:

using Pigeons
using Random
using LogDensityProblems
using LogDensityProblemsAD

struct CustomGradientLogPotential
    precision::Float64
    dim::Int
end
function (log_potential::CustomGradientLogPotential)(x)
    -0.5 * log_potential.precision * sum(abs2, x)
end

Pigeons.initialization(lp::CustomGradientLogPotential, ::AbstractRNG, ::Int) = zeros(lp.dim)

LogDensityProblems.dimension(lp::CustomGradientLogPotential) = lp.dim
LogDensityProblems.logdensity(lp::CustomGradientLogPotential, x) = lp(x)

LogDensityProblemsAD.ADgradient(::Val, log_potential::CustomGradientLogPotential, replica::Pigeons.Replica) =
     Pigeons.BufferedAD(log_potential, replica.recorders.buffers)

const check_custom_grad_called = Ref(false)

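function LogDensityProblems.logdensity_and_gradient(log_potential::Pigeons.BufferedAD{CustomGradientLogPotential}, x)
    logdens = log_potential.enclosed(x)
    # Record that the custom gradient was used (mutating the Ref needs no `global`)
    check_custom_grad_called[] = true
    # Gradient of -0.5 * precision * sum(abs2, x) is -precision .* x, written in place
    log_potential.buffer .= -log_potential.enclosed.precision .* x
    return logdens, log_potential.buffer
end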

pigeons(
    target = CustomGradientLogPotential(2.1, 4),
    reference = CustomGradientLogPotential(1.1, 4),
    n_chains = 1,
    n_rounds = 5,
    explorer = AutoMALA())

@assert check_custom_grad_called[]
┌ Info: Neither traces, disk, nor online recorders included.
   You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
   To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
─────────────────────────────────────────────────────────────────
  scans      time(s)    allc(B)  log(Z₁/Z₀)   min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ──────────
        2      0.351   1.61e+07          0      0.488      0.488
        4     0.0615   4.68e+06          0      0.771      0.771
        8   9.95e-05   6.11e+03          0       0.56       0.56
       16   0.000132   9.57e+03          0      0.703      0.703
       32   0.000228   1.65e+04          0      0.652      0.652
─────────────────────────────────────────────────────────────────
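The replica-specific, in-place buffers mentioned above mean that the gradient is written into a preallocated vector rather than a freshly allocated one at every call. The same pattern can be sketched in plain Julia, independently of Pigeons; the type `BufferedGaussian` and the function `logdensity_and_gradient!` are hypothetical names for illustration.

```julia
# Hypothetical stand-in: a Gaussian log potential paired with its own gradient buffer
struct BufferedGaussian
    precision::Float64
    buffer::Vector{Float64}   # allocated once, overwritten on every call
end

BufferedGaussian(precision::Float64, dim::Int) = BufferedGaussian(precision, zeros(dim))

# Log density -0.5 * precision * |x|^2 and its gradient -precision * x;
# the gradient is written into the buffer in place and the buffer itself is returned.
function logdensity_and_gradient!(lp::BufferedGaussian, x::AbstractVector{Float64})
    logdens = -0.5 * lp.precision * sum(abs2, x)
    lp.buffer .= -lp.precision .* x
    return logdens, lp.buffer
end

lp = BufferedGaussian(2.1, 4)
x = [1.0, -2.0, 0.5, 0.0]
logdens, grad = logdensity_and_gradient!(lp, x)
```

Because the returned gradient aliases the buffer, a caller that needs to keep it across calls should copy it first.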

Manipulating the output

Some common post-processing steps are shown below; see the section on output processing for more information.

using MCMCChains
using StatsPlots
plotlyjs()

pt = pigeons(
        target = MyLogPotential(100, 50),
        reference = MyLogPotential(0, 0),
        explorer = AutoMALA(default_autodiff_backend = :Enzyme),
        record = [traces])
samples = Chains(pt)
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "julia_posterior_densities_and_traces.html");

samples
Chains MCMC chain (1024×3×1 Array{Float64, 3}):

Iterations        = 1:1:1024
Number of chains  = 1
Samples per chain = 1024
parameters        = param_1, param_2
internals         = log_density

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

     param_1    0.7139    0.1519    0.0072   476.2129   813.4181    1.0009     ⋯
     param_2    0.7196    0.1521    0.0068   531.4729   717.6050    0.9994     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

     param_1    0.4607    0.5920    0.7058    0.8373    0.9825
     param_2    0.4781    0.5885    0.7118    0.8463    0.9859
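Because only p = p1 * p2 enters the likelihood, the posterior summaries above can be cross-checked in closed form. Under the uniform prior, given p the parameter p1 lies in (p, 1) with density proportional to 1/p1, so E[p1 | p] = (1 - p)/(-log p). Averaging over the posterior of p gives E[p1 | data] = ((n - k + 1)/(n + 2)) / (H_{n+1} - H_k), with H_j the j-th harmonic number; for n = 100, k = 50 this is about 0.716, in line with the param_1 mean (0.7139, MCSE 0.0072) reported above. The sketch below checks the formula by numerical integration in plain Julia (helper names hypothetical).

```julia
n, k = 100, 50

# Closed form: ((n - k + 1) / (n + 2)) / (H_{n+1} - H_k)
harmonic_gap = sum(1 / j for j in (k + 1):(n + 1))   # H_{n+1} - H_k
closed_form = ((n - k + 1) / (n + 2)) / harmonic_gap

# Numerical check: E[p1 | data] = ∫ L(p) (1 - p) dp / ∫ L(p) (-log p) dp,
# with L(p) ∝ p^k (1 - p)^(n - k), using the midpoint rule on (0, 1)
function posterior_mean_p1(n, k; m = 100_000)
    num = 0.0
    den = 0.0
    for i in 1:m
        p = (i - 0.5) / m
        L = p^k * (1 - p)^(n - k)
        num += L * (1 - p)
        den += L * (-log(p))
    end
    return num / den   # the common step size 1/m cancels
end

numerical = posterior_mean_p1(n, k)
```

By symmetry the same value applies to param_2.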