Probabilistic programming languages like Stan and PyMC3 have special DSLs for generalized linear models with data-driven priors (stan_glm and bambi, respectively). In this post, we'll throw together something comparable for NumPyro.
import os
os.environ["JAX_PLATFORMS"] = "cpu"
import numpy as np
import jax.numpy as jnp
import jax
import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist
from numpyro import handlers
from typing import Optional
We'll use the patsy library for an R-like GLM formula syntax.
from patsy import dmatrices, Term
import pandas as pd
Let's start with a fake dataset to test our tool on.
# Six observations: three "years" with two rows each.
year = [0, 0, 1, 1, 2, 2]
# A single continuous covariate.
x = np.random.randn(6)
# One baseline offset per year.
bases = 2 * np.random.randn(3)
# Response = covariate + the year's baseline + unit Gaussian noise.
df = pd.DataFrame(
    dict(
        year=pd.CategoricalIndex(year),
        x=x,
        y=x + bases[year] + np.random.randn(6),
    )
)
We know the data generation process obeys the following formula:
# Patsy formula: response `y` modelled by the categorical `year` and the covariate `x`.
formula = 'y ~ year + x'
Patsy can turn this formula and dataframe into a design matrix.
# Build the response and design matrices for `formula` from the columns of `df`.
y, design = dmatrices(formula, df)
DesignMatrix with shape (6, 4)
  Intercept  year[T.1]  year[T.2]         x
          1          0          0   0.82525
          1          0          0  -1.33658
          1          1          0  -2.10089
          1          1          0   0.12643
          1          0          1   0.34730
          1          0          1   0.55807
  Terms:
    'Intercept' (column 0)
    'year' (columns 1:3)
    'x' (column 3)
This design-matrix object knows how to relate each term in the formula to a slice of column indices (available as `design.design_info.term_slices`).
OrderedDict([(Term([]), slice(0, 1, None)), (Term([EvalFactor('year')]), slice(1, 3, None)), (Term([EvalFactor('x')]), slice(3, 4, None))])
We'll give the coefficient of each non-intercept covariate $x$ a Normal prior centered on zero with a weakly informative standard deviation $2.5 \, \sigma_y / \sigma_x$. The intercept will have a Normal prior centered at the maximum likelihood estimate with a weakly informative standard deviation $2.5 \, \sigma_y$. This is roughly the same approach taken by stan_glm.
Putting these pieces together gives the following general purpose GLM function.
def glm(formula: str, df: pd.DataFrame, family: dist.Distribution = dist.Normal,
        groups: Optional[str] = None, weights: Optional[str] = None, predictive: bool = False):
    """
    Create a generalized linear model following `formula`.

    We assume that the associated link function has already been applied to the RHS,
    so e.g. count observations with Poisson families are on a log scale.

    If the `groups` argument is provided, separate variance parameters are found for each group.

    Sample sites registered: one per formula term (named by the patsy term name,
    e.g. "Intercept"), "sigma" (for every family except Poisson), the
    deterministic site "mu", and the observation site "y".

    Args:
        formula: patsy formula describing the response and the covariates.
        df: data frame providing the columns referenced by `formula`, `groups`
            and `weights`.
        family: distribution class of the observations. One of `dist.Normal`,
            `dist.StudentT`, `dist.Cauchy`, `dist.Laplace`, `dist.Poisson` or
            `dist.NegativeBinomial2`.
        groups: optional name of a column of `df`; each distinct value gets
            its own "sigma".
        weights: optional name of a column of `df` used to scale the
            log-density of each observation.
        predictive: if True the "y" site is left unobserved, so it is sampled
            rather than conditioned on the data.

    Returns:
        The value of the "y" sample site.

    Raises:
        ValueError: if `family` is not one of the supported families.
    """
    y, design = dmatrices(formula, df)
    y = jnp.array(y[:, 0])
    X = jnp.array(design)

    # Least-squares fit; only the intercept entry is used below, to centre the
    # intercept prior.
    mle_params = jnp.linalg.solve(X.T @ X, X.T @ y)
    stdy = y.std()
    # Per-column prior scales. The entry for the constant intercept column is
    # never used: the intercept gets the scale 2.5 * stdy instead.
    stds = 2.5 * stdy / X.std(axis=0)

    intercept = Term([])
    mu = 0.0  # Observation mean for each unit
    for term, cols in design.design_info.term_slices.items():
        name = term.name()
        is_intercept = term == intercept
        subX = X[:, cols]
        loc = mle_params[cols] if is_intercept else jnp.zeros(1)
        scale = 2.5 * stdy * jnp.ones(1) if is_intercept else stds[cols]
        K = subX.shape[1]
        if K == 1:
            beta = numpyro.sample(name, dist.Normal(loc[0], scale[0]))
            mu = mu + subX[:, 0] * beta
        else:  # Stack categorical factors into their own plate
            with numpyro.plate(name + "s", K):
                beta = numpyro.sample(name, dist.Normal(loc, scale))
            mu = mu + subX @ beta

    if family is not dist.Poisson:
        if groups is not None:
            # Map group labels to 0..G-1 in sorted order so that arbitrary
            # labels (not just integer codes) can index the per-group scales.
            group_idx, group_levels = pd.factorize(df[groups], sort=True)
            with numpyro.plate("groups", len(group_levels)):
                sigmas = numpyro.sample("sigma", dist.Exponential(1 / stdy))
            sigma = sigmas[jnp.array(group_idx)]
        else:
            sigma = numpyro.sample("sigma", dist.Exponential(1 / stdy))

    if family is dist.Poisson:
        data_dist = dist.Poisson(jnp.exp(mu))
        obs = jnp.exp(y)  # Assuming data is log counts
    elif family is dist.NegativeBinomial2:
        data_dist = dist.NegativeBinomial2(jnp.exp(mu), 1 / sigma)
        obs = jnp.exp(y)  # Assuming data is log counts
    elif family in [dist.StudentT, dist.Normal, dist.Cauchy, dist.Laplace]:
        data_dist = family(mu, sigma)
        obs = y
    else:
        raise ValueError("Unknown family")

    with handlers.scale(scale=jnp.array(df[weights]) if weights is not None else 1.0):
        with numpyro.plate("obs", X.shape[0]):
            numpyro.deterministic("mu", mu)
            return numpyro.sample('y', data_dist, obs=None if predictive else obs)
Let's try it out!
def model():
    """Example model: apply `glm` to the module-level `formula` and `df`."""
    glm(formula, df)
def fit_nuts(model, *args, num_samples=1000, **kwargs):
    """Run four chains with 500 warmup samples using the NUTS kernel.

    Args:
        model: the NumPyro model to sample from.
        *args: positional arguments forwarded to `model`.
        num_samples: number of post-warmup samples to draw per chain.
        **kwargs: keyword arguments forwarded to `model`.

    Returns:
        The `MCMC` object after sampling has run.
    """
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=num_samples, num_chains=4)
    # NOTE(review): the `run` call was garbled in the source; the fixed seed 0
    # is an assumption -- confirm against the original post.
    mcmc.run(jax.random.PRNGKey(0), *args, **kwargs)
    return mcmc
# Fit the example model with NUTS and keep the resulting sampler object.
mcmc = fit_nuts(model)
0%| | 0/1500 [00:00<?, ?it/s]
0%| | 0/1500 [00:00<?, ?it/s]
0%| | 0/1500 [00:00<?, ?it/s]
0%| | 0/1500 [00:00<?, ?it/s]
                mean       std    median      5.0%     95.0%     n_eff     r_hat
 Intercept     -2.80      0.83     -2.83     -4.11     -1.65   1370.63      1.00
     sigma      0.97      0.70      0.76      0.23      1.83    532.04      1.02
         x      1.24      0.50      1.24      0.43      1.95   1963.75      1.00
   year[0]      3.72      1.23      3.76      1.94      5.69   1542.31      1.00
   year[1]      4.03      1.23      4.05      2.41      6.09   1414.52      1.00

Number of divergences: 33