Output for custom types
Much of the discussion on sample post-processing so far has focused on cases where states are real or integer vectors.
We discuss here how to post-process samples when the states are not real or integer vectors ("custom types").
Example of custom state
As a concrete example, consider an implementation of an Ising model where a state contains a matrix of binary variables as well as some cached quantities. The full example can be found in examples/ising.jl; the only snippet from that file needed to understand what follows is:
mutable struct IsingState
    matrix::BitMatrix
    # ... cache variables, etc. (omitted here)
end
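Because IsingState is mutable, it helps to see why the extract_sample methods on this page copy the matrix instead of returning it. The sketch below assumes a sampler that updates the state in place between iterations; ToyState, its n_updates field, and flip! are hypothetical stand-ins for IsingState, its caches, and a sampler step, not code from the Ising example.

```julia
# Hypothetical stand-in for IsingState: a spin grid plus a cache-like field
mutable struct ToyState
    matrix::BitMatrix
    n_updates::Int  # plays the role of "some cache variable"
end

# Hypothetical in-place update standing in for a sampler step:
# flip spin (i, j) and update the cache
function flip!(state::ToyState, i::Int, j::Int)
    state.matrix[i, j] = !state.matrix[i, j]
    state.n_updates += 1
    return state
end

state = ToyState(falses(2, 2), 0)
aliased = BitMatrix[]    # stores the state's matrix object itself
snapshots = BitMatrix[]  # stores a copy of the matrix after each update

for (i, j) in [(1, 1), (2, 1), (1, 2)]
    flip!(state, i, j)
    push!(aliased, state.matrix)
    push!(snapshots, copy(state.matrix))
end

# Every aliased entry is the same object, so all of them show the final grid
@assert all(a === state.matrix for a in aliased)
# The copies preserve the trajectory: one more spin is true after each update
@assert [count(s) for s in snapshots] == [1, 2, 3]
```

Storing state.matrix directly leaves aliased holding three references to the final grid, whereas snapshots keeps one independent grid per update.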
Flattening into a vector
The sample_array function is convenient, but it assumes that the variables are real or integer vectors (the latter coerced into the former).
Sometimes, custom types can be "flattened" into a real vector. For example, a 2D Ising grid can be reshaped into a vector using vec().
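A minimal standalone sketch of what vec does to a BitMatrix, using toy values rather than the Ising example. It shows the column-major order of the flattened entries (which determines how flattened components map back to grid cells), that vec shares storage with the original matrix (as documented for Base.vec), and that reshape undoes the flattening.

```julia
# Toy 2x2 grid of binary spins (illustrative values)
m = BitMatrix([true true; false false])

# vec flattens in column-major order: m[1,1], m[2,1], m[1,2], m[2,2]
v = vec(m)
@assert v == [true, false, true, false]

# vec shares storage with m, so later changes to m show up in v
m[1, 1] = false
@assert v[1] == false

# copy(vec(m)) is independent of m, which is what the extract_sample
# method in this section relies on
snapshot = copy(vec(m))
m[2, 2] = true
@assert snapshot == [false, false, true, false]

# reshape undoes the flattening, recovering the grid as it was at copy time
@assert reshape(snapshot, 2, 2) == BitMatrix([false true; false false])
```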
To perform flattening, add a method to Pigeons' extract_sample that dispatches on the custom state type. Here is how this would be done for the Ising example above:
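using Pigeons
using MCMCChains
using StatsPlots
# Load the Ising example, which defines IsingState and IsingLogPotential
include("../../examples/ising.jl")
# Flatten each state into a vector; copy so the sample does not share memory with the mutable state
Pigeons.extract_sample(state::IsingState, log_potential) = copy(vec(state.matrix))
# Run Pigeons on the Ising target, recording traces in memory
pt = pigeons(target = IsingLogPotential(1.0, 2), record = [traces])
# Convert the flattened samples into an MCMCChains object and plot them
samples = Chains(sample_array(pt))
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "posterior_densities_and_traces_ising.html");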
──────────────────────────────────────────────────────
  scans        Λ      log(Z₁/Z₀)   min(α)    mean(α)
────────── ────────── ────────── ────────── ──────────
         2          0       6.67          1          1
         4      0.883       5.33      0.706      0.902
         8      0.561       6.21      0.805      0.938
        16       1.15       5.64      0.378      0.872
        32      0.786        5.8      0.752      0.913
        64      0.805       6.04      0.806      0.911
       128      0.944       5.95      0.817      0.895
       256      0.935       6.08      0.861      0.896
       512      0.933       5.79      0.822      0.896
  1.02e+03      0.943        5.9      0.873      0.895
──────────────────────────────────────────────────────
This plots the traces and densities of the 4 components of a two-by-two Ising model; the figure is saved to posterior_densities_and_traces_ising.html.
Trace processing without flattening
It is also possible to process in-memory traces without flattening. To do so, extract_sample should still be extended to copy the relevant parts of the state. The trace can then be accessed using get_sample:
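using Pigeons
# Load the Ising example, which defines IsingState and IsingLogPotential
include("../../examples/ising.jl")
# Keep each sample as a BitMatrix (no flattening), copied out of the mutable state
Pigeons.extract_sample(state::IsingState, log_potential) = copy(state.matrix)
pt = pigeons(target = IsingLogPotential(1.0, 2), record = [traces])
# Retrieve the in-memory trace
vector = get_sample(pt)
# a vector of 2^10 samples, each extracted into a BitMatrix:
length(vector), eltype(vector)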
(1024, BitMatrix)
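Once the trace is a vector of BitMatrix samples, summaries can be computed on the matrices directly. The sketch below uses a hand-written samples vector as a hypothetical stand-in for the output of get_sample(pt), and computes the per-site frequency of true and the per-sample fraction of sites set to true; these particular summaries are illustrative choices, not part of Pigeons.

```julia
# Hypothetical stand-in for get_sample(pt): a short vector of 2x2 BitMatrix samples
samples = [
    BitMatrix([true false; false true]),
    BitMatrix([true true; false false]),
    BitMatrix([false true; false true]),
    BitMatrix([true true; true true]),
]

# Per-site frequency of `true`: accumulate counts elementwise, then normalize
counts = zeros(Int, size(first(samples)))
for s in samples
    counts .+= s  # Bool entries promote to Int (true -> 1, false -> 0)
end
site_freq = counts ./ length(samples)

# Per-sample fraction of sites set to `true`, computed on each matrix directly
fraction_true = [count(s) / length(s) for s in samples]
```

With the trace from the snippet above, replacing samples by vector would give the corresponding estimates over the 2^10 recorded samples.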