Julia code as input to pigeons
In typical Bayesian statistics applications, it is easiest to specify the model in a modelling language such as Turing. Sometimes, however, it is useful to implement the density evaluation manually as a "black-box" Julia function, either for more flexibility or for speed.
Here we show how this is done using our familiar unidentifiable toy example.
We first create a custom type, MyLogPotential, to control dispatch on the interface target.
using Pigeons
using Random
using Distributions
struct MyLogPotential
    n_trials::Int
    n_successes::Int
end
Next, we make MyLogPotential a function-like object, so that we can write expressions of the form my_log_potential([0.5, 0.5]). MyLogPotential then satisfies the log_potential interface:
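function (log_potential::MyLogPotential)(x)
    p1, p2 = x
    # Outside the unit square the density is zero, i.e. the log density is -Inf.
    if !(0 < p1 < 1) || !(0 < p2 < 1)
        return -Inf64
    end
    # The likelihood depends on (p1, p2) only through their product, hence unidentifiable.
    p = p1 * p2
    return logpdf(Binomial(log_potential.n_trials, p), log_potential.n_successes)
end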
# e.g.:
my_log_potential = MyLogPotential(100, 50)
my_log_potential([0.5, 0.5])
-16.91498002656617
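The following sketch shows where this number comes from and why the model is unidentifiable. It reimplements the same log potential in plain Julia, without Distributions.jl. The names binomial_logpmf and toy_log_potential are hypothetical and are introduced only for this illustration.

```julia
# Hypothetical helper: log of the Binomial(n, p) probability mass at k, in plain Julia.
function binomial_logpmf(n::Int, k::Int, p::Float64)
    # log(n choose k) = sum(log, k+1:n) - sum(log, 1:n-k)
    log_choose = sum(log, (k + 1):n; init = 0.0) - sum(log, 1:(n - k); init = 0.0)
    return log_choose + k * log(p) + (n - k) * log1p(-p)
end

# Same density as MyLogPotential, written as a plain function of x and the data.
function toy_log_potential(x; n_trials::Int = 100, n_successes::Int = 50)
    p1, p2 = x
    if !(0 < p1 < 1) || !(0 < p2 < 1)
        return -Inf
    end
    return binomial_logpmf(n_trials, n_successes, p1 * p2)
end

println(toy_log_potential([0.5, 0.5]))    # agrees with the value printed above up to rounding
println(toy_log_potential([0.4, 0.625]))  # same product p1 * p2 = 0.25, hence the same value
println(toy_log_potential([1.5, 0.5]))    # outside the unit square: -Inf
```

Every point on the curve p1 * p2 = 0.25 receives the same value. The data therefore cannot distinguish p1 from p2, which is what makes this toy example unidentifiable.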
Next, we need to specify how to create fresh state objects:
Pigeons.initialization(::MyLogPotential, ::AbstractRNG, ::Int) = [0.5, 0.5]
We can now run the sampler:
pt = pigeons(
    target = MyLogPotential(100, 50),
    reference = MyLogPotential(0, 0)
)
┌ Info: Neither traces, disk, nor online recorders included.
│ You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
└ To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
┌ Warning: It looks like sample_iid!() is not implemented for a
│ reference_log_potential of type Main.MyLogPotential.
│ Instead, using step!().
└ @ Pigeons ~/work/Pigeons.jl/Pigeons.jl/src/targets/target.jl:51
──────────────────────────────────────────────────────────────────────────────────────────────────
     scans          Λ    time(s)    allc(B) log(Z₁/Z₀)     min(α)    mean(α)    min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
         2       2.03   0.000371   1.84e+05      -8.65   6.03e-05      0.775          1          1
         4      0.994   0.000227   3.27e+04      -4.82      0.578       0.89          1          1
         8       2.25   0.000346   5.92e+04      -5.81      0.201       0.75          1          1
        16       1.62   0.000657   7.25e+04         -5      0.609       0.82          1          1
        32       1.25    0.00127   1.39e+05      -4.69      0.706      0.861          1          1
        64       1.52    0.00241   1.26e+05      -5.05      0.684      0.831          1          1
       128       1.67    0.00475   1.22e+05      -5.16      0.661      0.815          1          1
       256       1.56    0.00978   1.22e+05      -5.09      0.762      0.827          1          1
       512       1.54      0.019   1.22e+05      -5.07      0.786      0.829          1          1
  1.02e+03       1.51      0.038   1.22e+05      -4.91      0.807      0.832          1          1
──────────────────────────────────────────────────────────────────────────────────────────────────
Notice that we have specified a reference distribution. Here it is the same model with no observations, hence the prior: with zero trials the Binomial log-likelihood is identically zero, so the reference is uniform on the unit square. The reference must be given explicitly because, unlike targets specified using Turing.jl, "black-box" Julia targets do not allow a reference to be constructed automatically.
The default explorer, given by default_explorer(), is the SliceSampler.
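Parallel tempering moves between the reference (β = 0) and the target (β = 1) along a path of intermediate distributions. The sketch below assumes the linear interpolation of log potentials, which we understand to be Pigeons' default path. It evaluates one intermediate log potential. The helper names are hypothetical, and the binomial coefficient is dropped, which only shifts the log normalizing constant.

```julia
# Unnormalized toy log potential (binomial coefficient dropped). Names are illustrative.
function toy_logdensity(x; n = 100, k = 50)
    p1, p2 = x
    (0 < p1 < 1 && 0 < p2 < 1) || return -Inf
    p = p1 * p2
    return k * log(p) + (n - k) * log1p(-p)
end

target_lp(x) = toy_logdensity(x)                   # 100 trials, 50 successes
reference_lp(x) = toy_logdensity(x; n = 0, k = 0)  # zero trials: flat on the square

# Hypothetical helper: linear path in log space,
# log π_β(x) = (1 - β) log π_0(x) + β log π_1(x), for β in [0, 1].
function annealed_log_potential(beta, x)
    0 <= beta <= 1 || throw(ArgumentError("beta must lie in [0, 1]"))
    lr, lt = reference_lp(x), target_lp(x)
    # Outside the support both are -Inf; return early to avoid 0 * -Inf = NaN.
    (isinf(lr) || isinf(lt)) && return -Inf
    return (1 - beta) * lr + beta * lt
end

for beta in (0.0, 0.25, 0.5, 1.0)
    println("β = ", beta, ": ", annealed_log_potential(beta, [0.5, 0.5]))
end
```

At β = 0 the sketch reduces to the flat reference, and at β = 1 it reduces to the (unnormalized) target.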
Sampling from the reference distribution
The ability to sample i.i.d. from the reference distribution can be beneficial, e.g., to jump between modes in multi-modal distributions. For black-box Julia function targets, this is done as follows:
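function Pigeons.sample_iid!(::MyLogPotential, replica, shared)
    # The reference MyLogPotential(0, 0) is uniform on the unit square, so
    # filling the state in place with uniform draws samples from it directly.
    state = replica.state
    rng = replica.rng
    rand!(rng, state)
end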
pt = pigeons(
    target = MyLogPotential(100, 50),
    reference = MyLogPotential(0, 0)
)
┌ Info: Neither traces, disk, nor online recorders included.
│ You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
└ To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
──────────────────────────────────────────────────────────────────────────────────────────────────
     scans          Λ    time(s)    allc(B) log(Z₁/Z₀)     min(α)    mean(α)    min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
         2       2.03   0.000122   2.37e+04      -12.1   5.75e-08      0.775          1          1
         4       1.13   0.000207   3.08e+04      -5.26     0.0854      0.874          1          1
         8      0.989   0.000294   5.56e+04      -4.09      0.509       0.89          1          1
        16        1.5   0.000604   9.34e+04      -5.16      0.675      0.834          1          1
        32       1.58     0.0012   1.17e+05      -5.29       0.73      0.825          1          1
        64       1.34    0.00235   1.25e+05      -4.79      0.758      0.852          1          1
       128       1.65    0.00453   1.18e+05      -4.91      0.739      0.817          1          1
       256       1.46    0.00897   1.18e+05      -5.01      0.732      0.838          1          1
       512       1.52     0.0183   1.18e+05       -5.1      0.759      0.831          1          1
  1.02e+03       1.55     0.0361   1.18e+05         -5      0.796      0.828          1          1
──────────────────────────────────────────────────────────────────────────────────────────────────
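The sample_iid! method above relies on the reference being flat on the unit square. Under that assumption, rand! is an exact i.i.d. sampler: it fills the state with uniform draws on [0, 1), which matches the reference up to a set of probability zero. The stand-alone check below runs outside Pigeons, with a plain vector standing in for replica.state.

```julia
using Random
using Statistics

# Stand-alone check (outside Pigeons) that rand! yields i.i.d. uniform draws
# on the unit square, the reference MyLogPotential(0, 0) up to a null set.
rng = MersenneTwister(1)
state = zeros(2)                        # plays the role of replica.state
draws = Matrix{Float64}(undef, 10_000, 2)
for i in axes(draws, 1)
    rand!(rng, state)                   # same call as in sample_iid! above
    draws[i, :] .= state                # copy, since state is overwritten on the next draw
end
println(mean(draws; dims = 1))          # each entry should be close to 0.5
println(var(draws; dims = 1))           # each entry should be close to 1/12
```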
Changing the explorer
Here is an example using AutoMALA, a gradient-based sampler, instead of the default SliceSampler. We'll use the Enzyme backend, a state-of-the-art AD system that supports targets written in plain Julia. Enzyme is considerably faster than the default ForwardDiff, whose main advantage is compatibility with a broader range of targets. Many other AD backends are supported through the LogDensityProblemsAD.jl interface (:Enzyme, :ForwardDiff, :Zygote, :ReverseDiff, etc.).
To proceed, we only need to add methods that make our custom type MyLogPotential conform to the LogDensityProblems interface:
using Enzyme
using LogDensityProblems
LogDensityProblems.dimension(lp::MyLogPotential) = 2
LogDensityProblems.logdensity(lp::MyLogPotential, x) = lp(x)
pt = pigeons(
    target = MyLogPotential(100, 50),
    reference = MyLogPotential(0, 0),
    explorer = AutoMALA(default_autodiff_backend = :Enzyme)
)
┌ Info: Neither traces, disk, nor online recorders included.
│ You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
└ To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
──────────────────────────────────────────────────────────────────────────────────────────────────
     scans          Λ    time(s)    allc(B) log(Z₁/Z₀)     min(α)    mean(α)    min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
         2       2.07       13.2   8.81e+08        -12   4.85e-08      0.771      0.266      0.508
         4       1.24      0.105   5.09e+06      -5.87      0.282      0.862      0.369      0.549
         8       1.28    0.00209   3.83e+05      -5.71      0.554      0.858      0.413      0.511
        16       1.18    0.00427   7.76e+05      -4.92       0.63      0.869      0.426      0.519
        32       1.61     0.0088   1.53e+06      -4.57      0.391      0.821       0.41      0.525
        64       1.18     0.0166   2.81e+06      -4.91      0.753      0.869      0.455      0.531
       128       1.42     0.0341   5.54e+06      -4.84      0.767      0.842      0.475      0.544
       256       1.45     0.0666   1.08e+07      -4.83      0.763      0.839      0.501       0.55
       512       1.52      0.135   2.17e+07      -5.04      0.795      0.832      0.488      0.547
  1.02e+03        1.5      0.308   4.31e+07      -4.95      0.815      0.833      0.489      0.548
──────────────────────────────────────────────────────────────────────────────────────────────────
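AutoMALA needs the gradient of the log potential, which Enzyme computes automatically in the run above. For intuition, the gradient of this toy model can also be derived by hand with the chain rule and checked against central finite differences. This is a plain-Julia sketch. The helper names are hypothetical, and the binomial coefficient is dropped because it does not affect gradients.

```julia
# Unnormalized toy log potential (binomial coefficient dropped). Names are illustrative.
function toy_logdensity(x; n = 100, k = 50)
    p1, p2 = x
    (0 < p1 < 1 && 0 < p2 < 1) || return -Inf
    p = p1 * p2
    return k * log(p) + (n - k) * log1p(-p)
end

# Chain rule: d/dp [k log p + (n - k) log(1 - p)] = k/p - (n - k)/(1 - p),
# with dp/dp1 = p2 and dp/dp2 = p1.
function toy_gradient(x; n = 100, k = 50)
    p1, p2 = x
    p = p1 * p2
    dlogp = k / p - (n - k) / (1 - p)
    return [dlogp * p2, dlogp * p1]
end

# Central finite differences, used as an independent check.
function fd_gradient(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

x = [0.6, 0.7]
println(toy_gradient(x))
println(fd_gradient(toy_logdensity, x))   # should agree to several digits
println(toy_gradient([0.625, 0.8]))       # on the ridge p1 * p2 = 0.5: zero gradient
```

The gradient vanishes along the whole ridge p1 * p2 = 0.5. This is the unidentifiability seen from the sampler's point of view: there is no gradient signal pushing the chain along the ridge.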
Pigeons has several built-in explorer kernels, such as AutoMALA and the SliceSampler. However, custom exploration MCMC kernels may be necessary when the state space is neither the reals nor the integers, or for performance reasons. This is described on the custom explorers page.
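To give some intuition for what an exploration kernel does, here is a minimal random-walk Metropolis step in plain Julia. Each call updates the state in place while leaving exp(log_potential) invariant. The names rw_metropolis_step! and run_chain are hypothetical. This is not the Pigeons explorer interface, which is documented on the custom explorers page.

```julia
using Random

# Unnormalized toy log potential (binomial coefficient dropped). Names are illustrative.
function toy_logdensity(x; n = 100, k = 50)
    p1, p2 = x
    (0 < p1 < 1 && 0 < p2 < 1) || return -Inf
    p = p1 * p2
    return k * log(p) + (n - k) * log1p(-p)
end

# Hypothetical kernel: one random-walk Metropolis step that leaves
# exp(log_potential) invariant. NOT the Pigeons explorer interface.
function rw_metropolis_step!(rng::AbstractRNG, log_potential, state::Vector{Float64};
                             step_size = 0.1)
    current = log_potential(state)
    proposal = state .+ step_size .* randn(rng, length(state))
    # Accept with probability min(1, exp(proposed - current)); proposals outside
    # the support have log potential -Inf and are always rejected.
    if log(rand(rng)) < log_potential(proposal) - current
        state .= proposal
        return true
    end
    return false
end

function run_chain(rng::AbstractRNG, n_steps::Int)
    state = [0.5, 0.5]   # same starting point as Pigeons.initialization above
    n_accepted = 0
    for _ in 1:n_steps
        n_accepted += rw_metropolis_step!(rng, toy_logdensity, state)
    end
    return state, n_accepted / n_steps
end

state, acceptance_rate = run_chain(MersenneTwister(2024), 5_000)
println("acceptance rate: ", acceptance_rate)
println("final state: ", state, ", product: ", prod(state))
```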
Custom gradients
In some situations it may be helpful to compute gradients explicitly, for example for performance or to handle primitives that the AD system does not support. One way to do so is to use autodiff-specific machinery; see, for example, the Enzyme documentation. Pigeons also provides an AD framework-agnostic method for supplying explicit gradients, with support for replica-specific, in-place buffers. This functionality was developed to support efficient interfacing with Stan. Its use is demonstrated below:
using Pigeons
using Random
using LogDensityProblems
using LogDensityProblemsAD
# An isotropic Gaussian log potential (mean zero, given precision) in `dim` dimensions.
struct CustomGradientLogPotential
    precision::Float64
    dim::Int
end
function (log_potential::CustomGradientLogPotential)(x)
    -0.5 * log_potential.precision * sum(abs2, x)
end
Pigeons.initialization(lp::CustomGradientLogPotential, ::AbstractRNG, ::Int) = zeros(lp.dim)
LogDensityProblems.dimension(lp::CustomGradientLogPotential) = lp.dim
LogDensityProblems.logdensity(lp::CustomGradientLogPotential, x) = lp(x)
# Wrap the log potential together with a replica-specific, in-place gradient buffer.
LogDensityProblemsAD.ADgradient(::Val, log_potential::CustomGradientLogPotential, replica::Pigeons.Replica) =
    Pigeons.BufferedAD(log_potential, replica.recorders.buffers)
# Flag used only to confirm below that the hand-written gradient was actually called.
const check_custom_grad_called = Ref(false)
function LogDensityProblems.logdensity_and_gradient(log_potential::Pigeons.BufferedAD{CustomGradientLogPotential}, x)
    logdens = log_potential.enclosed(x)
    check_custom_grad_called[] = true   # mutating a const Ref needs no `global`
    # The gradient of -0.5 * precision * |x|^2 is -precision * x, written into the buffer in place.
    log_potential.buffer .= -log_potential.enclosed.precision .* x
    return logdens, log_potential.buffer
end
pigeons(
    target = CustomGradientLogPotential(2.1, 4),
    reference = CustomGradientLogPotential(1.1, 4),
    n_chains = 1,
    n_rounds = 5,
    explorer = AutoMALA())
@assert check_custom_grad_called[]
┌ Info: Neither traces, disk, nor online recorders included.
│ You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
└ To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
─────────────────────────────────────────────────────────────────
     scans    time(s)    allc(B) log(Z₁/Z₀)    min(αₑ)   mean(αₑ)
────────── ────────── ────────── ────────── ────────── ──────────
         2      0.351   1.61e+07          0      0.488      0.488
         4     0.0615   4.68e+06          0      0.771      0.771
         8   9.95e-05   6.11e+03          0       0.56       0.56
        16   0.000132   9.57e+03          0      0.703      0.703
        32   0.000228   1.65e+04          0      0.652      0.652
─────────────────────────────────────────────────────────────────
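The custom gradient above writes into log_potential.buffer and returns that same array. Repeated calls therefore do not allocate a new gradient vector. The plain-Julia sketch below makes the consequence of this pattern explicit: every returned gradient aliases the buffer, so a caller that needs to keep one across calls must copy it. The type BufferedGaussian and the function logdensity_and_gradient! are hypothetical stand-ins, not Pigeons API.

```julia
# Hypothetical stand-in for a log potential with a preallocated gradient buffer,
# mirroring the buffered pattern used above (names are illustrative only).
struct BufferedGaussian
    precision::Float64
    buffer::Vector{Float64}
end

BufferedGaussian(precision, dim::Int) = BufferedGaussian(precision, zeros(dim))

function logdensity_and_gradient!(lp::BufferedGaussian, x)
    logdens = -0.5 * lp.precision * sum(abs2, x)
    lp.buffer .= -lp.precision .* x   # overwrite in place, no new array
    return logdens, lp.buffer
end

lp = BufferedGaussian(2.1, 4)
_, g1 = logdensity_and_gradient!(lp, [1.0, 0.0, 0.0, 0.0])
saved = copy(g1)                      # private copy of the first gradient
_, g2 = logdensity_and_gradient!(lp, [0.0, 1.0, 0.0, 0.0])

println(g1 === g2)   # true: both names refer to the same buffer
println(g1)          # now holds the second gradient
println(saved)       # still holds the first gradient
```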
Manipulating the output
Some common post-processing steps are shown below; see the section on output processing for more information.
using MCMCChains
using StatsPlots
plotlyjs()
pt = pigeons(
    target = MyLogPotential(100, 50),
    reference = MyLogPotential(0, 0),
    explorer = AutoMALA(default_autodiff_backend = :Enzyme),
    record = [traces])
samples = Chains(pt)
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "julia_posterior_densities_and_traces.html");
samples
Chains MCMC chain (1024×3×1 Array{Float64, 3}):
Iterations        = 1:1:1024
Number of chains  = 1
Samples per chain = 1024
parameters        = param_1, param_2
internals         = log_density
Summary Statistics
  parameters      mean       std      mcse  ess_bulk  ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64   Float64   Float64   Float64     ⋯
     param_1    0.7139    0.1519    0.0072  476.2129  813.4181    1.0009     ⋯
     param_2    0.7196    0.1521    0.0068  531.4729  717.6050    0.9994     ⋯
1 column omitted
Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64
     param_1    0.4607    0.5920    0.7058    0.8373    0.9825
     param_2    0.4781    0.5885    0.7118    0.8463    0.9859
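The toy problem is only two-dimensional, so a brute-force midpoint-grid approximation of the posterior can serve as a rough, independent cross-check of the posterior means reported above. The two should agree up to Monte Carlo error (see the mcse column). The sketch assumes the flat prior on the unit square implied by the reference. The helper names are hypothetical, and the approach does not scale beyond a few dimensions.

```julia
# Unnormalized toy log potential (binomial coefficient dropped). Names are illustrative.
function toy_logdensity(x; n = 100, k = 50)
    p1, p2 = x
    (0 < p1 < 1 && 0 < p2 < 1) || return -Inf
    p = p1 * p2
    return k * log(p) + (n - k) * log1p(-p)
end

# Midpoint-grid approximation of the posterior means of p1 and p2.
function grid_posterior_means(n_grid::Int)
    xs = ((1:n_grid) .- 0.5) ./ n_grid               # midpoints of a uniform grid on (0, 1)
    logw = [toy_logdensity((a, b)) for a in xs, b in xs]
    w = exp.(logw .- maximum(logw))                  # stabilize before exponentiating
    w ./= sum(w)                                     # normalized grid weights
    mean_p1 = sum(w[i, j] * xs[i] for i in eachindex(xs), j in eachindex(xs))
    mean_p2 = sum(w[i, j] * xs[j] for i in eachindex(xs), j in eachindex(xs))
    return mean_p1, mean_p2
end

mean_p1, mean_p2 = grid_posterior_means(400)
println((mean_p1, mean_p2))   # the two means coincide by symmetry of p1 * p2
```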