The arviz library for plotting the results of MCMC runs relies on special InferenceData objects. Unfortunately, it's not immediately obvious how to translate the results of MCMC runs from Numpyro into these special objects. In this post, we'll build such a translation function.
import os
os.environ["JAX_PLATFORMS"] = "cpu"
import numpy as np
import jax.numpy as jnp
import jax
import numpyro
numpyro.set_host_device_count(4)
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist
from numpyro import handlers
from typing import Optional
import pandas as pd
import pyro_util
import arviz as az
import xarray as xr
I'll motivate our discussion with some synthetic data. The glm and fit_nuts functions were defined in a previous post.
year = [0, 0, 1, 1, 2, 2]
x = np.random.randn(6)
bases = 2 * np.random.randn(3)
df = pd.DataFrame({'year': pd.CategoricalIndex(year), 'x': x, 'y': x + bases[year] + np.random.randn(6)})
def model(**kwargs):
    pyro_util.glm('y ~ year + x', df, **kwargs)
mcmc = pyro_util.fit_nuts(model)
Arviz comes with the function az.from_numpyro, which seems like it should be able to do the job for us. But if we call it naively, the dimension names are all messed up!
az.from_numpyro(mcmc)
posterior
<xarray.Dataset> Size: 184kB
Dimensions:     (chain: 4, draw: 1000, mu_dim_0: 6, year_dim_0: 2)
Coordinates:
  * chain       (chain) int64 32B 0 1 2 3
  * draw        (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
  * mu_dim_0    (mu_dim_0) int64 48B 0 1 2 3 4 5
  * year_dim_0  (year_dim_0) int64 16B 0 1
Data variables:
    Intercept   (chain, draw) float32 16kB 0.7586 2.631 2.299 ... 0.3715 1.361
    mu          (chain, draw, mu_dim_0) float32 96kB 0.4586 0.5923 ... 4.705
    sigma       (chain, draw) float32 16kB 1.11 2.791 1.157 ... 1.344 1.259
    x           (chain, draw) float32 16kB 0.4241 -0.2606 ... -0.1037 0.2269
    year        (chain, draw, year_dim_0) float32 32kB -3.399 2.84 ... 3.188
Attributes:
    created_at:                 2025-01-08T00:36:28.626591
    arviz_version:              0.17.1
    inference_library:          numpyro
    inference_library_version:  0.15.1

log_likelihood
<xarray.Dataset> Size: 104kB
Dimensions:   (chain: 4, draw: 1000, y_dim_0: 6)
Coordinates:
  * chain     (chain) int64 32B 0 1 2 3
  * draw      (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
  * y_dim_0   (y_dim_0) int64 48B 0 1 2 3 4 5
Data variables:
    y         (chain, draw, y_dim_0) float32 96kB -1.926 -1.081 ... -1.153 -2.492
Attributes:
    created_at:                 2025-01-08T00:36:28.684388
    arviz_version:              0.17.1
    inference_library:          numpyro
    inference_library_version:  0.15.1

sample_stats
<xarray.Dataset> Size: 12kB
Dimensions:    (chain: 4, draw: 1000)
Coordinates:
  * chain      (chain) int64 32B 0 1 2 3
  * draw       (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
Data variables:
    diverging  (chain, draw) bool 4kB False False False ... False False False
Attributes:
    created_at:                 2025-01-08T00:36:28.632759
    arviz_version:              0.17.1
    inference_library:          numpyro
    inference_library_version:  0.15.1

observed_data
<xarray.Dataset> Size: 72B
Dimensions:   (y_dim_0: 6)
Coordinates:
  * y_dim_0   (y_dim_0) int64 48B 0 1 2 3 4 5
Data variables:
    y         (y_dim_0) float32 24B 1.951 0.9676 -3.613 -1.39 5.19 2.642
Attributes:
    created_at:                 2025-01-08T00:36:28.685243
    arviz_version:              0.17.1
    inference_library:          numpyro
    inference_library_version:  0.15.1
In the model, observations are stacked in a plate called "obs", but this InferenceData object names that dimension "mu_dim_0" (and "y_dim_0" for y itself). Likewise, the parameters for each year were nicely stacked in a plate called "years", yet the InferenceData object ignores the plate names we worked so hard to build and substitutes its own "year_dim_0".
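To see where names like "mu_dim_0" come from, here is a small sketch using ArviZ's az.dict_to_dataset on fake all-zero draws whose shapes mirror the posterior above. ArviZ treats the first two axes as chain and draw and, unless told otherwise, names every remaining axis `<variable>_dim_<i>`; passing a dims map replaces those invented names.

```python
import numpy as np
import arviz as az

# Fake posterior draws: 4 chains x 1000 draws, plus one per-site axis each.
draws = {"mu": np.zeros((4, 1000, 6)), "year": np.zeros((4, 1000, 2))}

# Without dims, ArviZ invents a name for every axis after chain and draw.
default_ds = az.dict_to_dataset(draws)
print(default_ds["mu"].dims)  # ('chain', 'draw', 'mu_dim_0')

# With dims, the same axes pick up the names we supply instead.
named_ds = az.dict_to_dataset(draws, dims={"mu": ["obs"], "year": ["years"]})
print(named_ds["mu"].dims)  # ('chain', 'draw', 'obs')
```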
To fix this, we need a way to extract plate names from a numpyro model. We'll use the following handler:
class extract_dims(handlers.Messenger):
    """
    This effect handler tracks the plates associated with each sample
    and deterministic site. The resulting map from site names to lists
    of plate names is stored in `self.dims`.
    """
    def __init__(self, fn=None):
        self.dims = {}
        super().__init__(fn)

    def process_message(self, msg):
        if msg["type"] in ("sample", "deterministic"):
            # Order plate frames by their batch dim (most negative first) so the
            # names line up with the site's array axes from left to right.
            frames = sorted(msg["cond_indep_stack"], key=lambda f: f.dim)
            self.dims[msg["name"]] = [f.name for f in frames]
When you run a model under this handler, the handler's dims attribute will contain a map from each site name to the list of plates on which it lies.
dimwrap = extract_dims()
with dimwrap:
    with handlers.seed(rng_seed=1):
        model(predictive=False)
dimwrap.dims
{'Intercept': [], 'year': ['years'], 'x': [], 'sigma': [], 'mu': ['obs'], 'y': ['obs']}
We can pass these dimensions to the dims argument of az.from_numpyro:
def from_numpyro(df, model, mcmc, *args):
    dimwrap = extract_dims()
    with dimwrap:
        with handlers.seed(rng_seed=1):
            model(*args, predictive=False)
    result = az.from_numpyro(mcmc,
                             dims=dimwrap.dims)
    return result
We usually want samples from the prior and posterior predictive distributions too. The az.from_numpyro function accepts these through its prior and posterior_predictive arguments, but leaves us the annoying task of generating the samples. I'll augment the new from_numpyro function to do the grunt work for us. Finally, it's often necessary to plot posterior samples against covariates from the original dataframe. The easiest way to do this is to add the original dataframe to the InferenceData object.
def from_numpyro(df, model, mcmc, *args):
    # Posterior predictive: re-run the model once per posterior draw.
    post_pred = Predictive(model, mcmc.get_samples())(jax.random.PRNGKey(1), *args, predictive=True)
    # Prior: 1000 draws straight from the model.
    prior = Predictive(model, num_samples=1000)(jax.random.PRNGKey(2), *args, predictive=True)
    # Map each site to the plates it lies on.
    dimwrap = extract_dims()
    with dimwrap:
        with handlers.seed(rng_seed=1):
            model(*args, predictive=False)
    result = az.from_numpyro(mcmc,
                             prior=prior,
                             posterior_predictive=post_pred,
                             dims=dimwrap.dims)
    # Attach the original dataframe through the public API rather than
    # appending to the private _groups list.
    result.add_groups(constant_data=xr.Dataset.from_dataframe(df))
    return result
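One wrinkle when plotting against covariates: xr.Dataset.from_dataframe names the row dimension after the dataframe's index name, which is "index" when the index is unnamed, so the constant_data group won't share the "obs" dimension with mu and y. A sketch, assuming the dataframe's rows are the observations in plate order: name the index "obs" before converting, and xarray will then broadcast covariates against posterior draws. The toy dataframe and the fake posterior built with az.from_dict are illustrative only.

```python
import numpy as np
import pandas as pd
import xarray as xr
import arviz as az

# Toy data: one row per observation, assumed to be in "obs" plate order.
df = pd.DataFrame({"x": np.linspace(-1.0, 1.0, 6), "y": np.arange(6.0)})

# Stand-in posterior for a per-observation quantity on the "obs" plate.
rng = np.random.default_rng(0)
idata = az.from_dict(posterior={"mu": rng.normal(size=(4, 100, 6))},
                     dims={"mu": ["obs"]})

# Naming the index "obs" makes from_dataframe use "obs" as the row dimension.
const = xr.Dataset.from_dataframe(df.rename_axis("obs"))
idata.add_groups(constant_data=const)

# Covariates and posterior draws now line up along the shared "obs" dimension.
resid = idata.constant_data["y"] - idata.posterior["mu"]
print(resid.dims)
```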