Online (constant memory) statistics
When the dimensionality of a model and/or the number of MCMC samples is large, the samples may not fit in memory. The most flexible way to deal with this situation is to write samples to disk and process them one at a time, as described in the off-memory processing documentation. However, certain statistics can be computed from fixed-dimensional sufficient statistics, which yields more efficient algorithms. We describe this alternative here.
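To make the idea of fixed-dimensional sufficient statistics concrete, here is a minimal sketch of a constant-memory accumulator for the mean and variance, based on Welford's update. The type RunningMoments and the functions update!, running_mean and running_var are hypothetical names introduced for illustration only; they are not part of Pigeons or OnlineStats, and this is not necessarily how those libraries implement the computation.

```julia
# Hypothetical accumulator, for illustration only (not part of Pigeons or OnlineStats).
# Only three numbers are stored, regardless of how many samples are processed.
mutable struct RunningMoments
    n::Int         # number of samples seen so far
    mean::Float64  # running mean
    m2::Float64    # running sum of squared deviations from the mean
end

RunningMoments() = RunningMoments(0, 0.0, 0.0)

# Welford's update: constant time and memory per sample.
function update!(acc::RunningMoments, x::Real)
    acc.n += 1
    delta = x - acc.mean
    acc.mean += delta / acc.n
    acc.m2 += delta * (x - acc.mean)
    return acc
end

running_mean(acc::RunningMoments) = acc.mean

# Unbiased sample variance; undefined (NaN) with fewer than two samples.
running_var(acc::RunningMoments) = acc.n > 1 ? acc.m2 / (acc.n - 1) : NaN

acc = RunningMoments()
for x in (2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0)
    update!(acc, x)  # each sample can be discarded after this call
end

println("mean = ", running_mean(acc), ", variance = ", running_var(acc))
```

Because the accumulator never stores the samples themselves, its memory footprint does not grow with the length of the MCMC run.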
Built-in online statistics: mean and variance
Simply include the online() recorder to get access to constant-memory computation of the mean and variance.
using DynamicPPL
using Pigeons
# example target: Binomial likelihood with parameter p = p1 * p2
an_unidentifiable_model = Pigeons.toy_turing_unid_target(100, 50)
pt = pigeons(
    target = an_unidentifiable_model,
    record = [online]
)
using Statistics
mean(pt)
3-element Vector{Float64}:
0.7238646296733922
0.7086193530401038
-6.815596173012225
To be more precise, the online statistics are computed on the result of calling extract_sample(). Use sample_names() to obtain the description of each coordinate:
sample_names(pt)
3-element Vector{Symbol}:
:p1
:p2
:log_density
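As a small usage sketch, the coordinate names can be paired with the online means and variances. This assumes that var(pt) returns one entry per coordinate in the same order as mean(pt) and sample_names(pt); the variable names coord_names, means and vars are illustrative only.

```julia
using DynamicPPL
using Pigeons
using Statistics

# Same toy target as above: Binomial likelihood with parameter p = p1 * p2.
target = Pigeons.toy_turing_unid_target(100, 50)
pt = pigeons(target = target, record = [online])

coord_names = sample_names(pt)  # labels of the coordinates
means = mean(pt)                # online mean of each coordinate
vars = var(pt)                  # online variance (assumed aligned with the means)

# Print one line per coordinate.
for (name, m, v) in zip(coord_names, means, vars)
    println(name, ": mean = ", m, ", var = ", v)
end
```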
Including other online statistics
The computation of online statistics makes use of OnlineStats.jl. The functions mean and var are implemented via the types Mean and Variance from the OnlineStats library. Many other constant-memory accumulators are available in that library. To include additional ones, register them via register_online_type(). Here is an example that adds the computation of extrema:
using OnlineStats
# register a type <: OnlineStat to be included
Pigeons.register_online_type(Extrema)
pt = pigeons(
    target = an_unidentifiable_model,
    record = [online]
)
Pigeons.get_statistic(pt, :singleton_variable, Extrema)
3-element Vector{NamedTuple{(:min, :max, :nmin, :nmax), Tuple{Float64, Float64, Int64, Int64}}}:
(min = 0.3995428316184216, max = 0.9990893798129334, nmin = 1, nmax = 1)
(min = 0.3896229038986596, max = 0.9997791042805813, nmin = 1, nmax = 1)
(min = -13.926399214236575, max = -5.671906123840204, nmin = 1, nmax = 1)
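To get a feel for what accumulators such as Mean, Variance and Extrema do independently of Pigeons, they can be fed a small stream of numbers directly with OnlineStats. The sketch below assumes a recent OnlineStats release; the data vector is made up for illustration.

```julia
using OnlineStats

data = [1.0, 2.0, 3.0, 4.0]  # illustrative data, standing in for a stream of samples

m = Mean()
v = Variance()
e = Extrema()

# Feed the observations one at a time; each accumulator keeps only
# a fixed-size summary, never the observations themselves.
for x in data
    fit!(m, x)
    fit!(v, x)
    fit!(e, x)
end

println("mean     = ", value(m))
println("variance = ", value(v))
println("extrema  = ", extrema(e))
```

Any type that is a subtype of OnlineStat and follows this fit!/value pattern is a candidate for registration with register_online_type().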