Probabilistic programming languages like Stan and PyMC3 have special DSLs for generalized linear models with data-driven priors (stan_glm and bambi, respectively). In this post, we'll throw together something comparable for NumPyro.

In [1]:
import os
os.environ["JAX_PLATFORMS"] = "cpu"
import numpy as np
import jax.numpy as jnp
import jax
import numpyro
numpyro.set_host_device_count(4)
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist
from numpyro import handlers
from typing import Optional

We'll use the patsy library for an R-like GLM formula syntax.

In [2]:
from patsy import dmatrices, Term
import pandas as pd

Let's start with a fake dataset to test our tool on.

In [3]:
year = [0, 0, 1, 1, 2, 2]  # three "years", two observations each
x = np.random.randn(6)
bases = 2 * np.random.randn(3)  # per-year intercepts
df = pd.DataFrame({'year': pd.CategoricalIndex(year), 'x': x,
                   'y': x + bases[year] + np.random.randn(6)})

We know the data-generating process obeys the following formula:

In [4]:
formula = 'y ~ year + x'

Patsy can turn this formula and dataframe into a design matrix.

In [5]:
y, design = dmatrices(formula, df)
In [6]:
design
Out[6]:
DesignMatrix with shape (6, 4)
  Intercept  year[T.1]  year[T.2]         x
          1          0          0   0.82525
          1          0          0  -1.33658
          1          1          0  -2.10089
          1          1          0   0.12643
          1          0          1   0.34730
          1          0          1   0.55807
  Terms:
    'Intercept' (column 0)
    'year' (columns 1:3)
    'x' (column 3)

This design-matrix object knows how to relate each term in the formula to a slice of column indices.

In [7]:
design.design_info.term_slices
Out[7]:
OrderedDict([(Term([]), slice(0, 1, None)),
             (Term([EvalFactor('year')]), slice(1, 3, None)),
             (Term([EvalFactor('x')]), slice(3, 4, None))])
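
Each `Term` key also carries a display name ('Intercept', 'year', and 'x' here), which we'll reuse below as the NumPyro sample-site name. A quick throwaway check:

In [ ]:
{k.name(): v for k, v in design.design_info.term_slices.items()}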

We'll give each non-intercept coefficient a Normal prior centered at zero with a weakly informative standard deviation $2.5 \sigma_y / \sigma_x$, where $\sigma_x$ is the standard deviation of the corresponding design-matrix column. The intercept will have a Normal prior centered at its maximum likelihood estimate with weakly informative standard deviation $2.5 \sigma_y$. This is roughly the same approach taken by stan_glm.
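
To make that concrete, here's a minimal sketch of those prior scales computed by hand from the design matrix above (the intercept column has zero standard deviation, so it's excluded here and special-cased later):

In [ ]:
X_np = np.asarray(design)
2.5 * np.asarray(y)[:, 0].std() / X_np[:, 1:].std(axis=0)  # scales for year[T.1], year[T.2], x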

Putting these pieces together gives the following general purpose GLM function.

In [8]:
def glm(formula: str, df: pd.DataFrame, family: dist.Distribution = dist.Normal,
        groups: Optional[str] = None, weights: Optional[str] = None, predictive: bool = False):
    """
    Create a generalized linear model following `formula`.
    We assume the link function has already been applied to the response,
    so e.g. count observations for Poisson families should be supplied as log counts.
    If the `groups` argument is provided, a separate noise scale is fit for each group.
    """
    y, design = dmatrices(formula, df)
    y = jnp.array(y[:,0])
    X = jnp.array(design)
    mle_params = jnp.linalg.solve(X.T @ X, X.T @ y)  # OLS estimates, used to center the intercept prior
    stdy = y.std()
    stds = 2.5 * stdy / X.std(axis=0)  # weakly informative prior scales (intercept column special-cased below)
    mu = 0.0 # Observation mean for each unit
    for k, v in design.design_info.term_slices.items():
        subX = X[:, v]
        loc = mle_params[v] if k == Term([]) else jnp.zeros(1)  # Term([]) is the intercept
        K = subX.shape[1]
        if K == 1:
            beta = numpyro.sample(k.name(), dist.Normal(loc[0], 2.5 * stdy if k == Term([]) else stds[v][0]))
            mu = mu + subX[:, 0] * beta
        else: # Stack categorical factors into their own plate
            with numpyro.plate(k.name() + "s", K):
                beta = numpyro.sample(k.name(), dist.Normal(loc, 2.5 * stdy if k == Term([]) else stds[v]))
                mu = mu + subX @ beta

    if family is not dist.Poisson: # every other family needs a scale / dispersion parameter
        if groups is not None:
            with numpyro.plate("groups", df[groups].nunique()):
                sigmas = numpyro.sample("sigma", dist.Exponential(1 / stdy))
                sigma = sigmas[df[groups].cat.codes.to_numpy()]
        else:
            sigma = numpyro.sample("sigma", dist.Exponential(1 / stdy))
    if family is dist.Poisson:
        data_dist = dist.Poisson(jnp.exp(mu))
        obs = jnp.exp(y) # Assuming data is log counts
    elif family is dist.NegativeBinomial2:
        data_dist = dist.NegativeBinomial2(jnp.exp(mu), 1 / sigma)
        obs = jnp.exp(y) # Assuming data is log counts
    elif family in [dist.StudentT, dist.Normal, dist.Cauchy, dist.Laplace]:
        data_dist = family(mu, sigma)
        obs = y
    else:
        raise ValueError(f"Unknown family: {family}")
    with handlers.scale(scale=jnp.array(df[weights]) if weights is not None else 1.0):
        with numpyro.plate("obs", X.shape[0]):
            numpyro.deterministic("mu", mu)
            return numpyro.sample('y', data_dist, obs=None if predictive else obs)
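
The `groups` and `weights` arguments are optional: `groups` names a categorical column that gets its own noise scale per level, while `weights` scales each row's contribution to the log-likelihood. A hypothetical call on our toy data (the weight column `w` is made up for illustration):

In [ ]:
df['w'] = 1.0  # hypothetical per-row likelihood weights
def grouped_model():
    glm('y ~ year + x', df, groups='year', weights='w')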

Let's try it out!

In [9]:
def model():
    glm(formula, df)
In [10]:
def fit_nuts(model, *args, num_samples=1000, **kwargs):
    "Run four chains with 500 warmup samples using the NUTS kernel."
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=num_samples, num_chains=4)
    mcmc.run(jax.random.PRNGKey(0), *args, **kwargs)
    return mcmc
In [11]:
mcmc = fit_nuts(model)
In [12]:
mcmc.print_summary()
                 mean       std    median      5.0%     95.0%     n_eff     r_hat
  Intercept     -2.80      0.83     -2.83     -4.11     -1.65   1370.63      1.00
      sigma      0.97      0.70      0.76      0.23      1.83    532.04      1.02
          x      1.24      0.50      1.24      0.43      1.95   1963.75      1.00
    year[0]      3.72      1.23      3.76      1.94      5.69   1542.31      1.00
    year[1]      4.03      1.23      4.05      2.41      6.09   1414.52      1.00

Number of divergences: 33
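
Finally, the `predictive` flag lets us reuse the same model for posterior predictive checks via the `Predictive` helper imported at the top. A minimal sketch (the `predictive_model` wrapper is our own naming):

In [ ]:
def predictive_model():
    glm(formula, df, predictive=True)

y_rep = Predictive(predictive_model, mcmc.get_samples())(jax.random.PRNGKey(1))['y']
y_rep.shape  # (4000, 6): one replicated dataset per posterior draw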