Correctness checks for distributed/parallel algorithms

It is notoriously difficult to implement correct parallel/distributed algorithms. One strategy we use to address this is to guarantee that the code produces precisely the same output no matter how many threads or machines are used. We describe how this is done under the hood in the page Distributed PT.
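A common ingredient behind this kind of guarantee is to give each chain its own random stream that depends only on a seed and the chain index, never on which thread runs the chain. The sketch below illustrates that idea with hypothetical names (`toy_explore`, `explore_threaded`, `explore_serial`) and Julia's built-in `Xoshiro` generator. It is not Pigeons' actual scheme, which is described in Distributed PT.

```julia
using Random

# Hypothetical stand-in for exploring one chain: 100 Gaussian increments.
toy_explore(rng) = sum(randn(rng) for _ in 1:100)

# Each chain gets an RNG derived only from (seed, chain index), so the result
# does not depend on which thread happens to process which chain.
function explore_threaded(n_chains; seed = 1)
    results = zeros(n_chains)
    Threads.@threads for c in 1:n_chains
        results[c] = toy_explore(Xoshiro(seed + c))
    end
    return results
end

# The same computation, run strictly sequentially.
explore_serial(n_chains; seed = 1) = [toy_explore(Xoshiro(seed + c)) for c in 1:n_chains]

println(explore_threaded(8) == explore_serial(8))  # true for any Threads.nthreads()
```

Because no random draw depends on thread scheduling, the threaded and sequential runs agree bit for bit. Seeding with `seed + c` is only the simplest choice for illustration.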

How is this guarantee useful in practice? Suppose you have developed a new target and want to make sure that it works correctly in a multi-threaded environment. To do so, add a flag indicating which PT round to "check", and enable checkpointing, as follows:

using Pigeons
pigeons(target = toy_mvn_target(100), checked_round = 3, checkpoint = true)
┌ Info: Neither traces, disk, nor online recorders included.
   You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
   To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)
────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2        7.5    4.4e-05   1.28e+04       -122   2.66e-15      0.167
        4       5.59   6.18e-05   1.47e+04       -119    2.6e-07      0.378
        8       6.04   9.89e-05   1.86e+04       -115    0.00129      0.329
┌ Info: Neither traces, disk, nor online recorders included.
   You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
   To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
┌ Info: Neither traces, disk, nor online recorders included.
   You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
   To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)
────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2        7.5    6.6e-05   1.28e+04       -122   2.66e-15      0.167
        4       5.59    8.7e-05   1.47e+04       -119    2.6e-07      0.378
        8       6.04   0.000105   1.86e+04       -115    0.00129      0.329
────────────────────────────────────────────────────────────────────────────
       16       7.27    0.00017   2.42e+04       -118     0.0134      0.193
       32       6.97   0.000256   3.05e+04       -114      0.107      0.225
       64       7.03   0.000446   4.13e+04       -117     0.0531      0.219
      128       7.23   0.000836   6.08e+04       -114     0.0944      0.196
      256       7.05    0.00157   6.77e+04       -115       0.13      0.217
      512       7.14    0.00292   7.18e+04       -115      0.171      0.207
 1.02e+03       7.19    0.00564   7.91e+04       -115      0.172      0.201
────────────────────────────────────────────────────────────────────────────

The above line does the following: the PT algorithm pauses at the end of round 3, spawns a separate process with a single thread, runs 3 rounds of PT in it with the same Inputs object, and verifies that the checkpoint of this single-threaded run is identical to the one produced by the main process. If not, an error is raised with information on where the discrepancy comes from. Since the check runs in single-threaded, single-process mode, pick the checked round small enough that it does not dominate the running time, but large enough to achieve the same code coverage as the full algorithm. Setting it to zero (or omitting the argument) disables this functionality.
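To see why an early checked round is cheap, look at the log: the number of scans doubles every round (2, 4, 8, ..., 1024). The sketch below assumes that round r performs 2^r scans, which is read off this particular log. It compares the scans redone by a check at round 3 with the total over the 10 rounds shown, ignoring the fact that each scan of the check runs on a single thread.

```julia
# Assumption read off the log above: round r performs 2^r scans.
scans_in_round(r) = 2^r

n_rounds = 10       # number of rounds shown in the log above
checked_round = 3

checked_scans = sum(scans_in_round, 1:checked_round)  # 2 + 4 + 8 = 14
total_scans = sum(scans_in_round, 1:n_rounds)         # 2^11 - 2 = 2046
println("checked run: ", checked_scans, " of ", total_scans, " scans (",
        round(100 * checked_scans / total_scans; digits = 2), "%)")
```

Under this doubling schedule, checking round k costs about 2^(k+1) scans, so each extra checked round roughly doubles the cost of the check.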

Did the pigeons call above actually use multiple threads? This depends on the value of Threads.nthreads(). Julia currently does not allow you to change this value at runtime, so for convenience we provide the following way to run the job in a child process with a specified number of Julia threads:

pt_result = pigeons(target = toy_mvn_target(100), multithreaded = true, checked_round = 3, checkpoint = true, on = ChildProcess(n_threads = 4))
Result{PT}("/home/runner/work/Pigeons.jl/Pigeons.jl/docs/build/results/all/2024-05-13-17-50-11-Yd2l9iJD")

Notice that we also set multithreaded = true. This instructs Pigeons to use the available threads to parallelize exploration across chains. In other use cases, parallelism might instead be used internally, e.g., to parallelize likelihood evaluation.
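Because the thread count is fixed when Julia starts, the only way to get more threads is to launch a fresh Julia process. The conceptual sketch below does this with plain Base functionality (`Base.julia_cmd()` and the `--threads` startup flag). It is not how ChildProcess is implemented. It only shows why a child process is needed.

```julia
# Command that launches the same Julia binary as the current session.
julia_exe = Base.julia_cmd()

# Start a child Julia with 4 threads and ask it to report its thread count.
cmd = `$julia_exe --startup-file=no --threads=4 -e "println(Threads.nthreads())"`
child_nthreads = parse(Int, strip(read(cmd, String)))

println("parent: ", Threads.nthreads(), " thread(s); child: ", child_nthreads, " thread(s)")
```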

Here the check passed, as expected. But what if you have a third-party target distribution that is not thread-safe? For example, its code might write to global variables or use other non-thread-safe constructs. In such situations, you can still use your thread-naive target by parallelizing over MPI processes instead of threads. If the thread-unsafety comes from global variables, for instance, each process will have its own copy of those variables.
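The per-process isolation of globals can be seen without MPI. The sketch below uses two local workers from Julia's Distributed standard library as a stand-in for MPI processes. This substitution is an assumption made only to keep the example self-contained. The names `call_count` and `bump!` are hypothetical.

```julia
using Distributed

# Two local worker processes stand in for MPI processes.
worker_ids = addprocs(2)

# A mutable global: the kind of state that is unsafe to share between threads.
@everywhere const call_count = Ref(0)
@everywhere bump!() = (call_count[] += 1; call_count[])

# Each worker increments only its own copy of call_count.
counts = [remotecall_fetch(bump!, w) for w in worker_ids for _ in 1:3]
println(counts)        # [1, 2, 3, 1, 2, 3]
println(call_count[])  # 0: the parent's copy was never touched

rmprocs(worker_ids)
```

Each worker counts independently from 1, so no process ever observes another's writes.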

Failed equality check

If you are using a custom struct that is mutable or contains mutable fields, the check may fail even if your implementation is sound. This happens because == falls back to === on your type, and === is too strict for the purpose of comparing two deserialized checkpoints. See recursive_equal for instructions on how to prevent this behavior.
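The failure mode is easy to reproduce in plain Julia. In the sketch below, a hypothetical mutable struct is serialized and deserialized, much as a checkpoint would be. The copy then compares unequal under `==`, while a field-by-field comparison succeeds. The helper `fieldwise_equal` is hypothetical and only illustrates the idea. It is not Pigeons' `recursive_equal`, and it does not handle cyclic references or undefined fields.

```julia
using Serialization

# Hypothetical user type: a mutable struct with a mutable (Vector) field.
mutable struct MyState
    x::Vector{Float64}
    n::Int
end

# Write an object to a buffer and read it back, as a checkpoint round trip would.
function roundtrip(obj)
    io = IOBuffer()
    serialize(io, obj)
    seekstart(io)
    return deserialize(io)
end

# Hypothetical helper: compare leaf values with ==, arrays elementwise,
# and any other object field by field.
function fieldwise_equal(a, b)
    typeof(a) === typeof(b) || return false
    if a isa Union{Number, AbstractString, Symbol, Nothing}
        return a == b
    elseif a isa AbstractArray
        return size(a) == size(b) && all(fieldwise_equal(x, y) for (x, y) in zip(a, b))
    else
        return all(fieldwise_equal(getfield(a, f), getfield(b, f)) for f in fieldnames(typeof(a)))
    end
end

s1 = MyState([1.0, 2.0], 3)
s2 = roundtrip(s1)
println(s1 == s2)                 # false: == falls back to === for MyState
println(fieldwise_equal(s1, s2))  # true
```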