Online (constant memory) statistics

When the dimensionality of a model is large and/or the number of MCMC samples is large, the samples may not fit in memory. The most flexible way to deal with this situation is to write samples to disk and process them one at a time, as described in the off-memory processing documentation. However, certain statistics can be computed from fixed-dimensional sufficient statistics, which yields more efficient algorithms. We describe this alternative here.
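To make the idea of fixed-dimensional sufficient statistics concrete, here is a minimal sketch of a running mean and variance based on Welford's algorithm. It is for illustration only and is not the code Pigeons uses internally; the names `RunningMoments`, `update!`, `running_mean` and `running_var` are hypothetical. The accumulator stores three numbers no matter how many samples it has seen.

```julia
# Running mean and variance via Welford's algorithm.
# All names below are hypothetical and used only for this sketch.
mutable struct RunningMoments
    n::Int          # number of observations seen so far
    mean::Float64   # running mean
    m2::Float64     # running sum of squared deviations from the mean
end

RunningMoments() = RunningMoments(0, 0.0, 0.0)

# Fold one observation into the accumulator; memory use does not grow.
function update!(s::RunningMoments, x::Real)
    s.n += 1
    delta = x - s.mean
    s.mean += delta / s.n
    s.m2 += delta * (x - s.mean)
    return s
end

running_mean(s::RunningMoments) = s.mean
# Unbiased sample variance; undefined for fewer than two observations.
running_var(s::RunningMoments) = s.n > 1 ? s.m2 / (s.n - 1) : NaN

s = RunningMoments()
for x in (2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0)
    update!(s, x)   # each value can be discarded right after this call
end
```

After the loop, `running_mean(s)` and `running_var(s)` give the mean and sample variance of the eight values without ever holding them in a vector.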

Built-in online statistics: mean and variance

Include the online() recorder to compute the mean and variance in constant memory.

using DynamicPPL
using Pigeons

# example target: Binomial likelihood with parameter p = p1 * p2
an_unidentifiable_model = Pigeons.toy_turing_unid_target(100, 50)

pt = pigeons(
        target = an_unidentifiable_model,
        record = [online]
    )

using Statistics
mean(pt)
3-element Vector{Float64}:
  0.7238646296733922
  0.7086193530401038
 -6.815596173012225

To be more precise, the online statistics are computed on the result of calling extract_sample(). Use sample_names() to obtain the description of each coordinate:

sample_names(pt)
3-element Vector{Symbol}:
 :p1
 :p2
 :log_density
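Since the names and the statistics come back as two parallel vectors, it can be convenient to pair them up. The sketch below does this with plain Julia, using literal values copied from the outputs shown above so that it runs on its own; in a live session the two vectors would instead come from `sample_names(pt)` and `mean(pt)`. The variable names are chosen for this example only.

```julia
# Values copied from the outputs above; in a session, use
# sample_names(pt) and mean(pt) instead of these literals.
coordinate_names = [:p1, :p2, :log_density]
coordinate_means = [0.7238646296733922, 0.7086193530401038, -6.815596173012225]

# Pair each coordinate name with its online mean.
mean_by_name = Dict(coordinate_names .=> coordinate_means)

for name in coordinate_names
    println(name, " => ", mean_by_name[name])
end
```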

Including other online statistics

The computation of online statistics makes use of OnlineStats.jl.

The functions mean and var are implemented via the types Mean and Variance from the OnlineStats library, which provides many other constant-memory statistic accumulators. To include more of them, register each type via register_online_type(). The following example adds computation of extrema:

using OnlineStats

# register a type <: OnlineStat to be included
Pigeons.register_online_type(Extrema)

pt = pigeons(
        target = an_unidentifiable_model,
        record = [online]
    )

Pigeons.get_statistic(pt, :singleton_variable, Extrema)
3-element Vector{NamedTuple{(:min, :max, :nmin, :nmax), Tuple{Float64, Float64, Int64, Int64}}}:
 (min = 0.3995428316184216, max = 0.9990893798129334, nmin = 1, nmax = 1)
 (min = 0.3896229038986596, max = 0.9997791042805813, nmin = 1, nmax = 1)
 (min = -13.926399214236575, max = -5.671906123840204, nmin = 1, nmax = 1)
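To see what such an accumulator does on its own, the following sketch uses OnlineStats directly, outside of Pigeons. It assumes a recent OnlineStats version in which the value of `Extrema` is a named tuple with the fields shown above; the `draws` vector is made-up data standing in for one coordinate of a sample stream.

```julia
using OnlineStats

# Made-up values standing in for one coordinate of an MCMC sample stream.
draws = [0.2, 0.9, 0.5, 0.7, 0.4]

# A Series updates several accumulators in a single pass over the data.
stats = Series(Mean(), Variance(), Extrema())
for x in draws
    fit!(stats, x)   # constant-memory update; x is not stored
end

# value(stats) returns one entry per accumulator, in the same order.
m, v, e = value(stats)
println("mean = ", m, ", variance = ", v)
println("min = ", e.min, ", max = ", e.max)
```

Each call to `fit!` touches only the accumulators' internal state, which is why the same types can be registered with Pigeons without increasing memory use as the number of samples grows.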