⁂ George Ho

Cookbook — Bayesian Modelling with PyMC3

Recently I’ve started using PyMC3 for Bayesian modelling, and it’s an amazing piece of software! The API only exposes as much of the heavy machinery of MCMC as you need, by which I mean just the pm.sample() method (a.k.a., as Thomas Wiecki puts it, the Magic Inference Button™). This frees up your mind to think about your data and model, which is really the heart and soul of data science!

That being said, I quickly realized that the water gets very deep very fast: I explored my data set, specified a hierarchical model that made sense to me, hit the Magic Inference Button™, and… uh, what now? I blinked at the angry red warnings the sampler spat out.

So began my long, rewarding and ongoing exploration of Bayesian modelling. This is a compilation of notes, tips, tricks and recipes that I’ve collected from everywhere: papers, documentation, and peppering my more experienced colleagues with questions. It’s still very much a work in progress, but hopefully somebody else finds it useful!



For the Uninitiated

EDIT (1/24/2020): I published a subsequent blog post with a reading list for Bayesian inference and modelling. Check it out for reading material in addition to the ones I list below!

Bayesian modelling

Markov-chain Monte Carlo

Variational inference

Model Formulation

Hierarchical models

Model Implementation

MCMC Initialization and Sampling

MCMC Trace Diagnostics

  1. Theoretically, run the chain for as long as you have the patience or resources for. In practice, just use the PyMC3 defaults (500 tuning iterations and 1000 sampling iterations here, though the defaults have changed between releases, so check pm.sample in your version).

  2. Check for divergences. PyMC3’s sampler will spit out a warning if there are divergent transitions, but the following code snippet may make things easier:

    # Display the total number and percentage of divergent transitions
    diverging = trace['diverging']
    n_divergent = diverging.nonzero()[0].size
    print('Number of divergent transitions: {}'.format(n_divergent))
    # trace['diverging'] pools every chain, so divide by its size rather than len(trace)
    diverging_pct = n_divergent / diverging.size * 100
    print('Percentage of divergent transitions: {:.1f}'.format(diverging_pct))

  3. Check the traceplot (pm.traceplot(trace)). You’re looking for traceplots that look like “fuzzy caterpillars”. If the trace moves into some region and stays there for a long time (i.e. there are some “sticky regions”), that’s cause for concern! That indicates that once the sampler moves into some region of parameter space, it gets stuck there (probably due to high curvature or other bad geometric properties).

  4. In addition to the traceplot, there are a ton of other plots you can make with your trace:

    • pm.plot_posterior(trace): check if your posteriors look reasonable.
    • pm.forestplot(trace): check if your variables have reasonable credible intervals, and Gelman–Rubin scores close to 1.
    • pm.autocorrplot(trace): check if your chains are impaired by high autocorrelation. Also remember that thinning your chains is a waste of time at best, and deluding yourself at worst. See Chris Fonnesbeck’s comment on this GitHub issue and Junpeng Lao’s reply to Michael Betancourt’s tweet.
    • pm.energyplot(trace): ideally the energy and marginal energy distributions should look very similar. Long tails in the distribution of energy levels indicate deteriorated sampler efficiency.
    • pm.densityplot(trace): a souped-up version of pm.plot_posterior. It doesn’t seem to be wildly useful unless you’re plotting posteriors from multiple models.
  5. PyMC3 has a nice helper function to pretty-print a summary table of the trace: pm.summary(trace) (I usually tack on a .round(2) for my sanity). Look out for:

    • the $\hat{R}$ values (a.k.a. the Gelman–Rubin statistic, a.k.a. the potential scale reduction factor, a.k.a. the PSRF): are they all close to 1? If not, something is horribly wrong. Consider respecifying or reparameterizing your model. You can also inspect these in the forest plot.
    • the sign and magnitude of the inferred values: do they make sense, or are they unexpected and unreasonable? This could indicate a poorly specified model. (E.g. parameters of the unexpected sign that have low uncertainties might indicate that your model needs interaction terms.)
  6. As a drastic debugging measure, try to pm.sample with draws=1, tune=500, and discard_tuned_samples=False, and inspect the traceplot. During the tuning phase, we don’t expect to see friendly fuzzy caterpillars, but we do expect to see good (if noisy) exploration of parameter space. So if the sampler is getting stuck during the tuning phase, that might explain why the trace looks horrible.

  7. If you get scary errors that describe mathematical problems (e.g. ValueError: Mass matrix contains zeros on the diagonal. Some derivatives might always be zero.), then you’re exceptionally unlucky: those kinds of errors are notoriously hard to debug. I can only point to the Folk Theorem of Statistical Computing:

    If you’re having computational problems, probably your model is wrong.
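To make "close to 1" a bit more concrete, here is a rough sketch of the classic (non-split) Gelman–Rubin statistic for a single scalar parameter, written in plain Python. The function name gelman_rubin and the toy chains mixed and stuck are made up for illustration, and PyMC3's own computation may differ in its details, so treat this as a way to build intuition rather than a replacement for pm.summary.

```python
import random
from statistics import mean, variance


def gelman_rubin(chains):
    """Classic potential scale reduction factor for one scalar parameter.

    `chains` is a list of equal-length lists of draws, one list per chain.
    """
    n = len(chains[0])
    # W: average of the within-chain sample variances
    within = mean(variance(chain) for chain in chains)
    # B / n: sample variance of the per-chain means
    between_over_n = variance([mean(chain) for chain in chains])
    # Pooled estimate of the posterior variance
    var_hat = (n - 1) / n * within + between_over_n
    return (var_hat / within) ** 0.5


rng = random.Random(0)

# Four chains that all explore the same distribution
mixed = [[rng.gauss(0, 1) for _ in range(1000)] for _ in range(4)]

# Four chains where one is stuck in a different region of parameter space
stuck = [[rng.gauss(shift, 1) for _ in range(1000)] for shift in (0, 0, 0, 3)]

print(round(gelman_rubin(mixed), 2))  # very close to 1
print(round(gelman_rubin(stuck), 2))  # clearly larger than 1
```

The statistic compares the variance between chains to the variance within each chain: if the chains disagree about where the posterior mass is, the between-chain term inflates the ratio.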

Fixing divergences

There were N divergences after tuning. Increase 'target_accept' or reparameterize.

— The Magic Inference Button™

Other common warnings

Model reparameterization

Model Diagnostics

  1. Run the following snippet of code to inspect the pairplot of your variables two at a time (if you have a plate of variables, it’s fine to pick a couple at random). It’ll tell you if the two random variables are correlated, and help identify any troublesome neighborhoods in the parameter space (divergent samples will be colored differently, and will cluster near such neighborhoods).

    pm.pairplot(trace,
                sub_varnames=[variable_1, variable_2],
                divergences=True,
                kwargs_divergence={'color': 'C2'})

  2. Look at your posteriors (either from the traceplot, density plots or posterior plots). Do they even make sense? E.g. are there outliers or long tails that you weren’t expecting? Do their uncertainties look reasonable to you? If you had a plate of variables, are their posteriors different? Did you expect them to be that way? If not, what about the data made the posteriors different? You’re the only one who knows your problem/use case, so the posteriors had better look good to you!

  3. Broadly speaking, there are four kinds of bad geometries that your posterior can suffer from:

    • highly correlated posteriors: this will probably cause divergences or traces that don’t look like “fuzzy caterpillars”. Either look at the jointplots of each pair of variables, or look at the correlation matrix of all variables. Try using a centered parameterization, or reparameterize in some other way, to remove these correlations.
    • posteriors that form “funnels”: this will probably cause divergences. Try using a noncentered parameterization.
    • heavy-tailed posteriors: this will probably raise warnings about max_treedepth being exceeded. If your data has long tails, you should model that with a long-tailed distribution. If your data doesn’t have long tails, then your model is ill-specified: perhaps a more informative prior would help.
    • multimodal posteriors: right now this is pretty much a death blow. At the time of writing, all samplers have a hard time with multimodality, and there’s not much you can do about that. Try reparameterizing to get a unimodal posterior. If that’s not possible (perhaps you’re modelling multimodality using a mixture model), you’re out of luck: just let NUTS sample for a day or so, and hopefully you’ll get a good trace.
  4. Pick a small subset of your raw data, and see what exactly your model does with that data (i.e. run the model on a specific subset of your data). I find that a lot of problems with your model can be found this way.

  5. Run posterior predictive checks (a.k.a. PPCs): sample from your posterior, plug it back in to your model, and “generate new data sets”. PyMC3 even has a nice function to do all this for you: pm.sample_ppc (renamed pm.sample_posterior_predictive in later releases). But what do you do with these new data sets? That’s a question only you can answer! The point of a PPC is to see if the generated data sets reproduce patterns you care about in the observed real data set, and only you know what patterns you care about. E.g. how close are the PPC means to the observed sample mean? What about the variance?

    • For example, suppose you were modelling the levels of radon gas in different counties in a country (through a hierarchical model). Then you could sample radon gas levels from the posterior for each county, and take the maximum within each county. You’d then have a distribution of maximum radon gas levels across counties. You could then check if the actual maximum radon gas level (in your observed data set) is acceptably within that distribution. If it’s much larger than the maxima, then you would know that the actual likelihood has longer tails than you assumed (e.g. perhaps you should use a Student’s T instead of a normal?)
    • Remember that how well the posterior predictive distribution fits the data is of little consequence (e.g. the expectation that 90% of the data should fall within the 90% credible interval of the posterior). The posterior predictive distribution tells you what values for data you would expect if we were to remeasure, given that you’ve already observed the data you did. As such, it’s informed by your prior as well as your data, and it’s not its job to adequately fit your data!
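The "compare the observed maximum against the replicated maxima" idea from the radon example can be sketched in a few lines of plain Python. Everything here is a stand-in: ppc_tail_fraction is a hypothetical helper, and the observed and replicated lists are simulated with the random module, whereas in practice the replicated data sets would come from posterior predictive sampling on your own model.

```python
import random


def ppc_tail_fraction(observed, replicated, statistic=max):
    """Fraction of replicated data sets whose statistic is >= the observed one."""
    observed_stat = statistic(observed)
    replicated_stats = [statistic(rep) for rep in replicated]
    exceed = sum(stat >= observed_stat for stat in replicated_stats)
    return exceed / len(replicated_stats)


rng = random.Random(42)

# Stand-in "observed" data: mostly standard normal, plus one extreme value
observed = [rng.gauss(0, 1) for _ in range(200)] + [8.0]

# Stand-in "replicated" data sets, as if generated by a model with a normal likelihood
replicated = [[rng.gauss(0, 1) for _ in range(201)] for _ in range(500)]

fraction = ppc_tail_fraction(observed, replicated)
print(fraction)
```

A fraction near 0 says the model essentially never generates a maximum as large as the one you observed, which is the hint that the likelihood may need longer tails. You can pass any other function as statistic (a mean, a variance, a count of zeros) to check whichever pattern you actually care about.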
