Correctness checks for distributed/parallel algorithms

It is notoriously difficult to implement correct parallel/distributed algorithms. One strategy we use to address this is to guarantee that the code produces exactly the same output no matter how many threads/machines are used. We describe how this is done under the hood in the page Distributed PT.

In practice, how is this useful? Let us say you developed a new target and you would like to make sure that it works correctly in a multi-threaded environment. To do so, add a flag indicating which PT round to "check", and enable checkpointing, as follows:

using Pigeons
pigeons(target = toy_mvn_target(100), checked_round = 3, checkpoint = true)
┌ Info: Neither traces, disk, nor online recorders included.
   You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
   To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)
────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2        7.5   3.13e-05   1.28e+04       -122   2.66e-15      0.167
        4       5.59   3.82e-05   1.47e+04       -119    2.6e-07      0.378
        8       6.04   6.21e-05   1.86e+04       -115    0.00129      0.329
┌ Info: Neither traces, disk, nor online recorders included.
   You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
   To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
┌ Info: Neither traces, disk, nor online recorders included.
   You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
   To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)
────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2        7.5   6.56e-05   1.28e+04       -122   2.66e-15      0.167
        4       5.59   8.77e-05   1.47e+04       -119    2.6e-07      0.378
        8       6.04   0.000103   1.86e+04       -115    0.00129      0.329
────────────────────────────────────────────────────────────────────────────
       16       7.27   0.000132   2.42e+04       -118     0.0134      0.193
       32       6.97   0.000204   3.05e+04       -114      0.107      0.225
       64       7.03   0.000372   4.13e+04       -117     0.0531      0.219
      128       7.23   0.000712   6.08e+04       -114     0.0944      0.196
      256       7.05    0.00138   6.77e+04       -115       0.13      0.217
      512       7.14    0.00272   7.18e+04       -115      0.171      0.207
 1.02e+03       7.19    0.00534   7.91e+04       -115      0.172      0.201
────────────────────────────────────────────────────────────────────────────

The above line does the following: the PT algorithm pauses at the end of round 3, spawns a separate process with a single thread, runs 3 rounds of PT in it with the same Inputs object, and verifies that the checkpoints of the single-threaded run are identical to those of the main process. If they are not, an error is raised with some information on where the discrepancy comes from. Try to pick a checked round small enough that it does not dominate the running time (since the check runs in single-threaded, single-process mode), but large enough to achieve the same code coverage as the full algorithm. Setting it to zero (or omitting the argument) disables this functionality.
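The principle behind such a check can be illustrated with a minimal sketch in plain Julia. It is not how Pigeons implements the check: `explore_chain`, `run_serial` and `run_threaded` are hypothetical names, and seeding one `Xoshiro` generator per chain is a simplification assumed here for illustration. The point is only that when each chain's randomness depends on its own seed rather than on thread scheduling, a serial run and a multi-threaded run can be compared for exact equality.

```julia
using Random

# Hypothetical stand-in for the exploration of one chain:
# each chain owns its RNG, so its result depends only on its own seed.
function explore_chain(seed::Int, n_steps::Int)
    rng = Xoshiro(seed)
    state = 0.0
    for _ in 1:n_steps
        state += randn(rng)
    end
    return state
end

# Serial reference run (plays the role of the single-threaded check process).
run_serial(n_chains, n_steps) = [explore_chain(i, n_steps) for i in 1:n_chains]

# Multi-threaded run: chains are processed in parallel, in arbitrary order.
function run_threaded(n_chains, n_steps)
    results = Vector{Float64}(undef, n_chains)
    Threads.@threads for i in 1:n_chains
        results[i] = explore_chain(i, n_steps)
    end
    return results
end

serial   = run_serial(8, 1_000)
threaded = run_threaded(8, 1_000)

# The comparison is exact (==), not approximate.
@assert serial == threaded
```

If the chains instead shared a single global random number generator, the threaded result would depend on scheduling, and an exact comparison of this kind would be expected to fail.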

Did the code above actually use many threads? This depends on the value of Threads.nthreads(). Julia currently does not allow you to change this value at runtime, so for convenience we provide the following way to run the job in a child process with a set number of Julia threads:
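As a quick sanity check, the number of threads in the current session can be inspected directly. This small snippet uses only Base Julia; the thread count is fixed when Julia starts, for example with `julia --threads 4` or through the `JULIA_NUM_THREADS` environment variable.

```julia
# Number of threads available to the current Julia session (fixed at startup).
n = Threads.nthreads()
println("Running with $n thread(s)")

if n == 1
    # With a single thread there is nothing to run in parallel in this session.
    println("Single-threaded session: chains cannot be explored in parallel here")
end
```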

pt_result = pigeons(target = toy_mvn_target(100), multithreaded = true, checked_round = 3, checkpoint = true, on = ChildProcess(n_threads = 4))
Result{PT}("/home/runner/work/Pigeons.jl/Pigeons.jl/docs/build/results/all/2024-11-14-21-21-03-wypGFatX")

Notice that we also add the flag multithreaded = true, to instruct Pigeons to use the multiple threads available to parallelize exploration across chains (in other use cases, parallelization might get used internally e.g. to parallelize likelihood evaluation).

Here the check passed successfully, as expected. But what if you had a third-party target distribution that is not multi-threaded friendly? For example, some code writes to global variables or uses other non-thread-safe constructs. In such a situation, you can still use your thread-naive target over MPI processes. For example, if the thread-unsafety comes from the use of global variables, then each process will have its own copy of the global variables.

Failed equality check

If you are using a custom struct that is either mutable or contains mutable fields, it is possible that the check will fail even if your implementation is sound. This is caused by == falling back to === on your type, which is too strict for the purpose of comparing two deserialized checkpoints. See recursive_equal for instructions on how to prevent this behavior.
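The fallback behavior can be reproduced in plain Julia, independently of Pigeons. In the sketch below, `MyState` and `fieldwise_equal` are hypothetical names introduced for illustration only; `fieldwise_equal` compares fields one level deep and is not the Pigeons implementation.

```julia
# Hypothetical custom state type; not part of Pigeons.
mutable struct MyState
    values::Vector{Float64}
end

a = MyState([1.0, 2.0])
b = deepcopy(a)   # same contents, different object (as after deserialization)

# No `==` method is defined for MyState, so `==` falls back to `===`,
# which for mutable structs compares object identity.
@assert a == a
@assert !(a == b)

# Hypothetical field-by-field comparison (one level deep), illustrating
# what a structural equality check does instead.
function fieldwise_equal(x::T, y::T) where {T}
    return all(getfield(x, f) == getfield(y, f) for f in fieldnames(T))
end

@assert fieldwise_equal(a, b)
```

Two objects with identical contents thus compare as different under `==`, which is exactly the situation that arises when comparing two deserialized checkpoints.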

Correctness checks of MCMC kernels

Pigeons offers a tool, the Exact Invariance Test (EIT), to help validate the correctness of MCMC kernels. It formulates a hypothesis test where the null hypothesis is that the provided explorer kernel is invariant with respect to the target distribution. For details, see Bouchard-Côté, 2022, Section 10.5 or this tutorial.

Using EIT is as simple as defining a Bayesian model with a proper prior using the Turing syntax, and then calling invariance_test():

using Pigeons
using Distributions
using DynamicPPL
using HypothesisTests

# note: observation should not be an argument of the Turing model
DynamicPPL.@model function some_generative_model(n_trials)
    p1 ~ Uniform()
    p2 ~ Uniform()
    n_successes ~ Binomial(n_trials, p1*p2)
    return n_successes
end

model = some_generative_model(100)
target = TuringLogPotential(model)
explorer = SliceSampler()
test_result = Pigeons.invariance_test(
                target,
                explorer;
                condition_on=(:n_successes,))

@assert test_result.passed

EIT checks for invariance but not irreducibility. Being able to check invariance irrespective of irreducibility is beneficial: for example, if a Gibbs sampler has two moves, one can test each of the two moves in isolation. Moreover, in case of failure, we can determine which of the two moves is problematic.
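To make the distinction concrete, here is a minimal sketch in plain Julia of checking a single Gibbs move in isolation. It is a simplified moment comparison, not the formal hypothesis test performed by `invariance_test`, and all names (`RHO`, `forward_sample`, `good_move`, `bad_move`, `var_x_after`) are hypothetical. The target is assumed to be a bivariate normal with unit variances and correlation `RHO`. The move updates `x` only, so on its own it is not irreducible, yet its invariance can still be checked: draw exact samples from the target, apply the move once, and verify that the distribution is unchanged.

```julia
using Random, Statistics

const RHO = 0.8

# Exact draw from the target: y ~ N(0, 1), x | y ~ N(RHO * y, 1 - RHO^2).
function forward_sample(rng)
    y = randn(rng)
    x = RHO * y + sqrt(1 - RHO^2) * randn(rng)
    return (x, y)
end

# Gibbs move updating x only: invariant, but not irreducible by itself
# since y is never moved.
good_move(rng, (x, y)) = (RHO * y + sqrt(1 - RHO^2) * randn(rng), y)

# Deliberately buggy variant: wrong conditional variance.
bad_move(rng, (x, y)) = (RHO * y + randn(rng), y)

# Draw exact samples, apply the move once, and return the variance of x afterwards.
function var_x_after(move, rng, n)
    xs = [first(move(rng, forward_sample(rng))) for _ in 1:n]
    return var(xs)
end

rng = Xoshiro(1)
n = 100_000
v_good = var_x_after(good_move, rng, n)   # should stay close to the target value 1
v_bad  = var_x_after(bad_move, rng, n)    # drifts towards RHO^2 + 1 = 1.64

println("variance of x after correct move: ", v_good)
println("variance of x after buggy move:   ", v_bad)
```

Under these assumptions the correct move leaves the variance of `x` near 1, while the buggy move inflates it, so the faulty move is identified without running the other move of the sampler.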