Probabilistic programming languages like Stan and PyMC3 have special DSLs for generalized linear models with data-driven priors (stan_glm and bambi, respectively). In this post, we'll throw together something comparable for NumPyro.
import os
os.environ["JAX_PLATFORMS"] = "cpu"
import numpy as np
import jax.numpy as jnp
import jax
import numpyro
numpyro.set_host_device_count(4)  # Allow four chains to run in parallel
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist
from numpyro import handlers
from typing import Optional
We'll use the patsy library for an R-like GLM formula syntax.
from patsy import dmatrices, Term
import pandas as pd
Let's start with a fake dataset to test our tool on.
year = [0, 0, 1, 1, 2, 2]       # three years, two observations each
x = np.random.randn(6)          # one continuous covariate
bases = 2 * np.random.randn(3)  # a random offset for each year
df = pd.DataFrame({'year': pd.CategoricalIndex(year), 'x': x, 'y': x + bases[year] + np.random.randn(6)})
We know the data-generating process obeys the following formula:
formula = 'y ~ year + x'
Patsy can turn this formula and DataFrame into a design matrix.
y, design = dmatrices(formula, df)
design
DesignMatrix with shape (6, 4)
  Intercept  year[T.1]  year[T.2]        x
          1          0          0  0.82525
          1          0          0 -1.33658
          1          1          0 -2.10089
          1          1          0  0.12643
          1          0          1  0.34730
          1          0          1  0.55807
  Terms:
    'Intercept' (column 0)
    'year' (columns 1:3)
    'x' (column 3)
This design-matrix object knows how to relate each term in the formula to a slice of column indices.
design.design_info.term_slices
OrderedDict([(Term([]), slice(0, 1, None)), (Term([EvalFactor('year')]), slice(1, 3, None)), (Term([EvalFactor('x')]), slice(3, 4, None))])
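For example, we can loop over these slices to pull out each term's sub-matrix, exactly as the glm function below will do. A quick check (variable names here are ours):

X = np.asarray(design)
for term, cols in design.design_info.term_slices.items():
    # Each term maps to a contiguous block of design-matrix columns
    print(term.name(), X[:, cols].shape)
# Intercept (6, 1)
# year (6, 2)
# x (6, 1)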
We'll give each non-intercept coefficient a Normal prior centered on zero with a weakly informative standard deviation $2.5 * \sigma_y / \sigma_x$, where $\sigma_x$ is the standard deviation of that covariate's column. The intercept will have a Normal prior centered at the maximum likelihood estimate with weakly informative standard deviation $2.5 * \sigma_y$. This is roughly the same approach taken by stan_glm.
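Concretely, those prior scales are just column statistics of the design matrix. Here is a minimal sketch for the toy data (the glm function below recomputes the same quantities internally; the variable names are ours):

X = np.asarray(design)
y_obs = np.asarray(y)[:, 0]
intercept_scale = 2.5 * y_obs.std()                     # prior sd for the intercept
coef_scales = 2.5 * y_obs.std() / X[:, 1:].std(axis=0)  # prior sds for year[T.1], year[T.2], x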
Putting these pieces together gives the following general-purpose GLM function.
def glm(formula: str, df: pd.DataFrame, family: dist.Distribution = dist.Normal,
        groups: Optional[str] = None, weights: Optional[str] = None, predictive: bool = False):
    """
    Create a generalized linear model following `formula`.
    We assume the link function has already been applied to the response on the
    left-hand side of the formula, so e.g. count observations for Poisson
    families arrive on a log scale.
    If the `groups` argument is provided, a separate variance parameter is
    inferred for each group.
    """
    y, design = dmatrices(formula, df)
    y = jnp.array(y[:, 0])
    X = jnp.array(design)
    mle_params = jnp.linalg.solve(X.T @ X, X.T @ y)
    stdy = y.std()
    stds = 2.5 * stdy / X.std(axis=0)  # The intercept entry is unused (and infinite)
    mu = 0.0  # Observation mean for each unit
    for (k, v) in design.design_info.term_slices.items():
        subX = X[:, v]
        loc = mle_params[v] if k == Term([]) else jnp.zeros(1)
        K = subX.shape[1]
        if K == 1:
            beta = numpyro.sample(k.name(), dist.Normal(loc[0], 2.5 * stdy if k == Term([]) else stds[v][0]))
            mu = mu + jnp.array(subX[:, 0]) * beta
        else:  # Stack categorical factors into their own plate
            with numpyro.plate(k.name() + "s", K):
                beta = numpyro.sample(k.name(), dist.Normal(loc, 2.5 * stdy if k == Term([]) else stds[v]))
            mu = mu + subX @ beta
    if family is not dist.Poisson:
        if groups is not None:
            with numpyro.plate("groups", df[groups].nunique()):
                sigmas = numpyro.sample("sigma", dist.Exponential(1 / stdy))
            sigma = sigmas[df[groups].cat.codes.to_numpy()]
        else:
            sigma = numpyro.sample("sigma", dist.Exponential(1 / stdy))
    if family is dist.Poisson:
        data_dist = dist.Poisson(jnp.exp(mu))
        obs = jnp.exp(y)  # Assuming the data is log counts
    elif family is dist.NegativeBinomial2:
        data_dist = dist.NegativeBinomial2(jnp.exp(mu), 1 / sigma)
        obs = jnp.exp(y)  # Assuming the data is log counts
    elif family is dist.StudentT:
        # StudentT's first argument is the degrees of freedom, not the location;
        # we fix it at 4 (an arbitrary but common choice) rather than inferring it.
        data_dist = dist.StudentT(4.0, mu, sigma)
        obs = y
    elif family in [dist.Normal, dist.Cauchy, dist.Laplace]:
        data_dist = family(mu, sigma)
        obs = y
    else:
        raise Exception("Unknown family")
    with handlers.scale(scale=jnp.array(df[weights]) if weights is not None else 1.0):
        with numpyro.plate("obs", X.shape[0]):
            numpyro.deterministic("mu", mu)
            return numpyro.sample('y', data_dist, obs=None if predictive else obs)
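Because family, groups, and weights are plain keyword arguments, variations are one-liners. Before fitting the basic model below, here is a hedged sketch (not run in this post) of a version with a separate noise scale for each year:

def grouped_model():
    # Same formula, but one sigma per level of the categorical `year` column
    glm('y ~ year + x', df, groups='year')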
Let's try it out!
def model():
    glm(formula, df)

def fit_nuts(model, *args, num_samples=1000, **kwargs):
    "Run four chains with 500 warmup samples using the NUTS kernel."
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=num_samples, num_chains=4)
    mcmc.run(jax.random.PRNGKey(0), *args, **kwargs)
    return mcmc
mcmc = fit_nuts(model)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
 Intercept     -2.80      0.83     -2.83     -4.11     -1.65   1370.63      1.00
     sigma      0.97      0.70      0.76      0.23      1.83    532.04      1.02
         x      1.24      0.50      1.24      0.43      1.95   1963.75      1.00
   year[0]      3.72      1.23      3.76      1.94      5.69   1542.31      1.00
   year[1]      4.03      1.23      4.05      2.41      6.09   1414.52      1.00

Number of divergences: 33
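The predictive flag is what lets us reuse the same model for posterior-predictive checks: with predictive=True the y site is sampled instead of conditioned on the data, so it plugs straight into the Predictive helper we imported earlier. A short sketch (names are ours):

def predictive_model():
    glm(formula, df, predictive=True)

# One replicated dataset per posterior draw
posterior_predictive = Predictive(predictive_model, mcmc.get_samples())
y_rep = posterior_predictive(jax.random.PRNGKey(1))["y"]
y_rep.shape  # (4000, 6): 4 chains x 1000 samples, 6 observations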