The arviz library for plotting the results of MCMC runs relies on special InferenceData objects. Unfortunately, it's not immediately obvious how to translate the results of MCMC runs from Numpyro into these special objects. In this post, we'll build such a translation function.

In [1]:
import os
os.environ["JAX_PLATFORMS"] = "cpu"
import numpy as np
import jax.numpy as jnp
import jax
import numpyro
numpyro.set_host_device_count(4)
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist
from numpyro import handlers
from typing import Optional
import pandas as pd
import pyro_util
import arviz as az
import xarray as xr

I'll motivate our discussion with some synthetic data. The glm and fit_nuts functions were defined in a previous post.

In [2]:
year = [0, 0, 1, 1, 2, 2]
x = np.random.randn(6)
bases = 2 * np.random.randn(3)
df = pd.DataFrame({'year': pd.CategoricalIndex(year), 'x': x, 'y': x + bases[year] + np.random.randn(6)})
In [3]:
def model(**kwargs):
    pyro_util.glm('y ~ year + x', df, **kwargs)
In [4]:
mcmc = pyro_util.fit_nuts(model)

ArviZ comes with the function az.from_numpyro, which seems like it should be able to do the job for us. But if we call it naively, the dimension names are all messed up!

In [5]:
az.from_numpyro(mcmc)
Out[5]:
arviz.InferenceData
    • <xarray.Dataset> Size: 184kB
      Dimensions:     (chain: 4, draw: 1000, mu_dim_0: 6, year_dim_0: 2)
      Coordinates:
        * chain       (chain) int64 32B 0 1 2 3
        * draw        (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
        * mu_dim_0    (mu_dim_0) int64 48B 0 1 2 3 4 5
        * year_dim_0  (year_dim_0) int64 16B 0 1
      Data variables:
          Intercept   (chain, draw) float32 16kB 0.7586 2.631 2.299 ... 0.3715 1.361
          mu          (chain, draw, mu_dim_0) float32 96kB 0.4586 0.5923 ... 4.705
          sigma       (chain, draw) float32 16kB 1.11 2.791 1.157 ... 1.344 1.259
          x           (chain, draw) float32 16kB 0.4241 -0.2606 ... -0.1037 0.2269
          year        (chain, draw, year_dim_0) float32 32kB -3.399 2.84 ... 3.188
      Attributes:
          created_at:                 2025-01-08T00:36:28.626591
          arviz_version:              0.17.1
          inference_library:          numpyro
          inference_library_version:  0.15.1

    • <xarray.Dataset> Size: 104kB
      Dimensions:  (chain: 4, draw: 1000, y_dim_0: 6)
      Coordinates:
        * chain    (chain) int64 32B 0 1 2 3
        * draw     (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
        * y_dim_0  (y_dim_0) int64 48B 0 1 2 3 4 5
      Data variables:
          y        (chain, draw, y_dim_0) float32 96kB -1.926 -1.081 ... -1.153 -2.492
      Attributes:
          created_at:                 2025-01-08T00:36:28.684388
          arviz_version:              0.17.1
          inference_library:          numpyro
          inference_library_version:  0.15.1

    • <xarray.Dataset> Size: 12kB
      Dimensions:    (chain: 4, draw: 1000)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
      Data variables:
          diverging  (chain, draw) bool 4kB False False False ... False False False
      Attributes:
          created_at:                 2025-01-08T00:36:28.632759
          arviz_version:              0.17.1
          inference_library:          numpyro
          inference_library_version:  0.15.1

    • <xarray.Dataset> Size: 72B
      Dimensions:  (y_dim_0: 6)
      Coordinates:
        * y_dim_0  (y_dim_0) int64 48B 0 1 2 3 4 5
      Data variables:
          y        (y_dim_0) float32 24B 1.951 0.9676 -3.613 -1.39 5.19 2.642
      Attributes:
          created_at:                 2025-01-08T00:36:28.685243
          arviz_version:              0.17.1
          inference_library:          numpyro
          inference_library_version:  0.15.1

In the model, observations are stacked in a plate called "obs", but this InferenceData object calls that dimension "mu_dim_0" (and "y_dim_0" in the other groups). Likewise, the parameters for each year were nicely stacked in a plate called "years", yet the InferenceData object ignores the plate names we worked so hard to build and substitutes its own "year_dim_0".
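To see the naming rule in isolation, here is a small sketch that skips NumPyro entirely and feeds random arrays (shaped chain × draw × ...) to az.from_dict. The arrays and their sizes are made up for illustration; the point is that ArviZ falls back to `<name>_dim_<i>` names unless it is given a `dims` mapping.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(0)
# Fake posterior: 4 chains x 100 draws of a 6-vector and a 2-vector
posterior = {
    "mu": rng.normal(size=(4, 100, 6)),
    "year": rng.normal(size=(4, 100, 2)),
}

# Without dims, ArviZ invents generic dimension names
default = az.from_dict(posterior=posterior)
# With dims, each variable's extra axes get the names we supply
named = az.from_dict(posterior=posterior, dims={"mu": ["obs"], "year": ["years"]})

print(default.posterior["mu"].dims)  # ('chain', 'draw', 'mu_dim_0')
print(named.posterior["mu"].dims)    # ('chain', 'draw', 'obs')
```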

To fix this, we need a way to extract plate names from a numpyro model. We'll use the following handler:

In [6]:
class extract_dims(handlers.Messenger):
    """
    This effect handler tracks the plates associated with each sample site.
    The resulting map from sample names to lists of plates is stored
    in `self.dims`.
    """
    def __init__(self, fn=None):
        self.dims = {}
        super().__init__(fn)

    def process_message(self, msg):
        if msg["type"] in ("sample", "deterministic"):
            dims = [a.name for a in msg['cond_indep_stack']]
            dims.reverse()
            self.dims[msg['name']] = dims

When you run a model under this handler, the handler's dims attribute will contain a map from each site name to the list of plate names the site lies in.

In [7]:
dimwrap = extract_dims()
with dimwrap:
    with handlers.seed(rng_seed=1):
        model(predictive=False)
In [8]:
dimwrap.dims
Out[8]:
{'Intercept': [],
 'year': ['years'],
 'x': [],
 'sigma': [],
 'mu': ['obs'],
 'y': ['obs']}

We can use these dimensions in the dims argument of az.from_numpyro:

In [9]:
def from_numpyro(df, model, mcmc, *args):
    dimwrap = extract_dims()
    with dimwrap:
        with handlers.seed(rng_seed=1):
            model(*args, predictive=False)
    result = az.from_numpyro(mcmc,
        dims=dimwrap.dims)
    return result

We usually also want samples from the prior and posterior predictive distributions. The az.from_numpyro function accepts these as arguments, but leaves us the annoying task of generating the samples, so I'll extend the new from_numpyro function to do the grunt work for us. Finally, it's often necessary to plot posterior samples against covariates from the original dataframe. The easiest way to do this is to add the original dataframe to the InferenceData object as its constant_data group.

In [10]:
def from_numpyro(df, model, mcmc, *args):
    post_pred = Predictive(model, mcmc.get_samples())(jax.random.PRNGKey(1), *args, predictive=True)
    prior = Predictive(model, num_samples=1000)(jax.random.PRNGKey(2), *args, predictive=True)
    dimwrap = extract_dims()
    with dimwrap:
        with handlers.seed(rng_seed=1):
            model(*args, predictive=False)
    result = az.from_numpyro(mcmc,
        prior=prior,
        posterior_predictive=post_pred,
        dims=dimwrap.dims)
    # Attach the original dataframe as the constant_data group
    # (public API instead of mutating the private _groups list)
    result.add_groups(constant_data=xr.Dataset.from_dataframe(df))
    return result
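Once the dataframe sits in constant_data, covariates can be combined with posterior draws through xarray's label-based broadcasting, but only if the two groups share a dimension name. xr.Dataset.from_dataframe names the row dimension after the dataframe's index, which is `index` for a default RangeIndex, so renaming it to the observation plate's name lines things up. Here is a minimal sketch with made-up arrays, assuming the plate is called `obs` as in the model above.

```python
import numpy as np
import pandas as pd
import xarray as xr
import arviz as az

rng = np.random.default_rng(0)
df = pd.DataFrame({"x": np.linspace(-1.0, 1.0, 6)})

# Fake posterior for mu over the 'obs' plate: 2 chains x 50 draws x 6 observations
idata = az.from_dict(posterior={"mu": rng.normal(size=(2, 50, 6))},
                     dims={"mu": ["obs"]})

# The default RangeIndex becomes a dimension called 'index'; rename it to match the plate
const = xr.Dataset.from_dataframe(df).rename({"index": "obs"})
idata.add_groups(constant_data=const)

# xarray aligns the two groups on the shared 'obs' dimension (coords 0..5 on both sides)
resid = idata.posterior["mu"] - idata.constant_data["x"]
print(resid.dims)  # ('chain', 'draw', 'obs')
```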