I've been reading through Causal Inference: The Mixtape, which describes a variety of quasi-experimental designs that attempt to estimate causal effects from observational studies. In this post, I'll port one of the book's frequentist examples to a Bayesian workflow in numpyro.
Import Statements
To allow numpyro to access multiple CPUs during sampling, we need to run the following commands at the very start of the program, before JAX initializes its devices.
import os
os.environ["JAX_PLATFORMS"] = "cpu"  # run JAX on the CPU backend only
import numpyro
numpyro.set_host_device_count(4)  # expose 4 CPU devices so chains can run in parallel
The utilities described in previous posts are gathered in the pyro_util module.
import pyro_util
import pandas as pd
import numpy as np
import numpyro.distributions as dist
import arviz as az
import xarray as xr
import seaborn as sns
Difference in Differences
The difference-in-differences design is a model of outcomes $y_{t,u}$, where $u$ is a binary variable indexing populations, $t$ is a categorical variable indexing time, and $x$ is a variable giving the amount of treatment each unit received, which should differ between the two populations over some time frame. We will assume that $E[y_{t,u}] = a + bt + cu + dx$. In words, each population has a different baseline level for $y$ at time $0$ (given by $a$ and $a+c$ respectively). Over time, in the absence of treatment, we would expect both populations to increase by the same amount $b$. Because the treatment occurred, however, we expect to see an additional increase of $d$ per unit of treatment $x$.
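With two periods ($t \in \{0, 1\}$) and treatment only in population $u = 1$ at $t = 1$, the model implies $(E[y_{1,1}] - E[y_{0,1}]) - (E[y_{1,0}] - E[y_{0,0}]) = (b + d) - b = d$. The population baselines $a$ and $c$ cancel, and so does the shared trend $b$. The following is a minimal numerical check of that identity on simulated data. The parameter values, cell sizes, and Gaussian noise are arbitrary assumptions and do not come from the gonorrhea dataset.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical parameters for E[y] = a + b*t + c*u + d*x
a, b, c, d = 2.0, 0.5, -1.0, -0.3
n = 10_000  # simulated observations per (t, u) cell

def cell_mean(t, u):
    """Simulate one (time, population) cell and return its sample mean."""
    x = t * u  # treatment is present only in population u=1 at t=1
    y = a + b * t + c * u + d * x + rng.normal(0.0, 1.0, size=n)
    return y.mean()

means = {(t, u): cell_mean(t, u) for t in (0, 1) for u in (0, 1)}

# Change over time within each population, then the difference of those changes
change_treated = means[(1, 1)] - means[(0, 1)]  # estimates b + d
change_control = means[(1, 0)] - means[(0, 0)]  # estimates b
did = change_treated - change_control           # estimates d
print(f"DiD estimate: {did:.3f} (true d = {d})")
```

The control population's change stands in for what the treated population would have done without treatment. That substitution is exactly the parallel-trends assumption the rest of this post leans on.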
The Mixtape applies this idea to estimate whether the Roe decision in 1973 resulted in decreased teen gonorrhea rates 15 years later (the implication being that mothers who were previously forced to bear children they felt ill-equipped to raise would have created home environments more conducive to risky sexual behavior in those children).
df = pd.read_stata("https://github.com/scunning1975/mixtape/raw/master/abortion%202.dta")
We can do a difference-in-differences study here because five states legalized abortion in 1970, before the Roe decision. These five states will constitute our first population, and the rest of the states will be our second population. In 1985, none of the teenagers in the 15-20 age group would have been born in a post-legalization world, so this year sets the baseline for both populations. In 1986, however, one of the five ages in the first population would have been born post-legalization, so we expect to see some of the treatment effect brought by legalization. By 1991, this entire age group in the first population will consist of teens born post-legalization.
We should see this effect in the second population as well, but shifted forward three years. As a result, we should expect the difference in gonorrhea rates between the populations to peak in '88, when three years' worth of post-legalization teens are included in one group but none are included in the other. After this point, the fraction of teens born post-legalization will grow in the second group as well, so the treatment effect should fall back to baseline levels by '93.
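To make the timing argument concrete, here is a small sketch that tabulates the expected treated fraction for the age-15 group in each population. It uses the same clipping rule this post applies when computing `treated_frac`, with 1970 as the legalization year for repeal states and 1973 for the rest. The helper name `treated_frac` and its keyword arguments are illustrative only.

```python
import numpy as np

def treated_frac(year, legal_year, youngest_age=15, n_ages=5):
    """Fraction of an n_ages-wide age group born after legalization.

    Mirrors the clipping rule used for the `treated_frac` column in this post.
    """
    return np.clip(year - legal_year - youngest_age, 0, n_ages) / n_ages

years = np.arange(1985, 1994)
repeal = treated_frac(years, 1970)  # early-repeal states
roe = treated_frac(years, 1973)     # remaining states, legalized by Roe
for y, r, o in zip(years, repeal, roe):
    print(y, f"repeal={r:.1f}", f"other={o:.1f}", f"gap={r - o:.1f}")
```

Under this rule the gap first reaches its maximum of 0.6 in 1988. It stays there through 1990, then shrinks back to zero by 1993.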
The time period we're investigating only extends to 1993.
df = df[(df.year <= 1993)]
The amount of 'treatment' present (i.e., the fraction of the given demographic group born after legalization in their state) can be computed as follows:
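treatment_group = df.repeal == 1.0  # states that legalized abortion in 1970
df['treated_frac'] = 0.0
age = df.age.astype(int)
# Number of the five single-year cohorts born after legalization, as a fraction
df.loc[treatment_group, 'treated_frac'] = np.clip(
    df[treatment_group].year - 1970 - age, 0, 5) / 5
# Remaining states only legalized with Roe in 1973
df.loc[~treatment_group, 'treated_frac'] = np.clip(
    df[~treatment_group].year - 1973 - age, 0, 5) / 5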
Here's what the treatment fractions look like over time for the 15-20 age group.
sns.relplot(df[df.age == 15], x="year", y="treated_frac", kind="line", hue="repeal");
We'll treat year, treatment status and state as categorical variables.
df['year'] = pd.CategoricalIndex(df.year)
df['repeal'] = pd.CategoricalIndex(df.repeal)
df['fip'] = pd.CategoricalIndex(df.fip)
To decrease the amount of heterogeneity our model needs to explain, we'll restrict our analysis to black females in the 15-20 year old age group.
bf15_mask = (df.race == 2) & (df.sex == 2) & (df.age == 15)
bf15 = df[bf15_mask]
We'll attempt to control for a few potential confounding factors, such as poverty and alcohol levels.
bf15 = bf15.loc[:,['fip', 'year', 'lnr', 'treated_frac', 'repeal', 'poverty', 'alcohol']]
We'll also assume the data are missing at random, so rows with null values can simply be dropped.
bf15.dropna(inplace=True)
Numpyro ignores pandas' index information, so for consistency we need to reset the pandas index.
bf15 = bf15.reset_index(drop=True)
bf15.index.name = 'obs'
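Why the reset matters: filtering and `dropna` leave gaps in the pandas index, while the arrays handed to numpyro are purely positional. After `reset_index(drop=True)`, label `i` and array position `i` refer to the same row. That is presumably what lets pyro_util map results back onto an `obs` coordinate, although its internals aren't shown here. The toy frame below is made up for illustration.

```python
import numpy as np
import pandas as pd

frame = pd.DataFrame({"lnr": [5.1, np.nan, 4.8, 4.9]})
kept = frame.dropna()
print(kept.index.tolist())  # → [0, 2, 3]  (gap left by the dropped row)
values = kept["lnr"].to_numpy()  # positional array with no index attached

kept = kept.reset_index(drop=True)
kept.index.name = "obs"
print(kept.index.tolist())  # → [0, 1, 2]  (labels now match positions in `values`)
```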
At first glance, the hypothesized treatment effect looks plausible. Gonorrhea rates in repeal states seem to decrease faster than in non-repeal states up to '88, after which the two rates settle back to their baseline separation.
sns.relplot(bf15, x="year", y="lnr", hue="repeal", kind="line", errorbar="sd");
A Linear Model
As a first pass, we might try a model in which the observed difference in gonorrhea rates is directly proportional to the fraction of 15-20 year olds born after legalization, given by the treated_frac variable we computed above.
def model0(predictive=False):
    pyro_util.glm('lnr ~ year + treated_frac + center(alcohol) + center(poverty) - 1',
                  bf15, family=dist.NegativeBinomial2, predictive=predictive)
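The formula string packs a lot into one line. pyro_util.glm's internals were covered in earlier posts and aren't repeated here. Assuming it follows standard Wilkinson/Patsy formula conventions, `year + treated_frac + center(alcohol) - 1` expands to a design matrix like the one below. There is one indicator column per year and no separate intercept (the `- 1`), then the `treated_frac` slope, then a mean-centered covariate. The toy data are made up, and treating `center()` as plain mean subtraction is an assumption.

```python
import numpy as np
import pandas as pd

# Toy stand-in for bf15: two states observed in three years (values are made up)
toy = pd.DataFrame({
    "year": pd.Categorical([1985, 1986, 1987, 1985, 1986, 1987]),
    "treated_frac": [0.0, 0.2, 0.4, 0.0, 0.0, 0.0],
    "alcohol": [1.0, 1.2, 1.1, 0.8, 0.9, 1.0],
})

# `year` with `- 1`: one indicator column per year instead of an intercept
X_year = pd.get_dummies(toy["year"], prefix="year", dtype=float)
# Assumed meaning of center(alcohol): subtract the sample mean
alcohol_c = (toy["alcohol"] - toy["alcohol"].mean()).rename("center(alcohol)")

X = pd.concat([X_year, toy[["treated_frac"]], alcohol_c], axis=1)
print(X)
```

With this coding, each year coefficient is that year's baseline level at average covariate values. `treated_frac` then carries the single linear treatment effect.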
mcmc = pyro_util.fit_nuts(model0, num_samples=1500, predictive=False)
results = pyro_util.from_numpyro(bf15, model0, mcmc)
As expected, the posterior treatment effect has almost all its probability mass over negative values.
az.plot_posterior(results, "treated_frac", ref_val=0.0);
To check that this generative model produces data similar to what we observed, we can sample from the posterior predictive distribution. These samples look much like the observed data, including the parabola-shaped dip in the rates for repeal states that we hypothesize is caused by the treatment effect.
pred_plot = xr.Dataset({
    'repeal': results.constant_data.repeal,
    'lnr': np.log(az.extract(results, group='posterior_predictive',
                             num_samples=1000, var_names='y')),
    'year': results.constant_data.year}).to_dataframe()
sns.relplot(pred_plot, x='year', y='lnr', hue='repeal', kind='line', errorbar="sd");
az.plot_ppc(results, var_names='y');
The fraction of posterior predictive samples below each observation is fairly evenly distributed: posterior predictive samples from this model seem to be neither under- nor over-dispersed compared to our observations.
az.plot_loo_pit(results, y="y");
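The idea behind this check can be sketched with plain probability integral transform (PIT) values. For each observation, compute the fraction of predictive draws at or below it. A well-calibrated predictive distribution gives roughly uniform PIT values, while one that is too narrow piles them up near 0 and 1. This simplified version scores every observation against draws from a single predictive distribution. ArviZ's LOO-PIT additionally reweights draws using PSIS leave-one-out weights, so each observation is judged by a fit that effectively didn't see it. The simulated normal data and the helper `pit_values` are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)

def pit_values(observed, predictive_draws):
    """Fraction of predictive draws at or below each observation.

    observed: shape (n_obs,); predictive_draws: shape (n_draws, n_obs).
    """
    return (predictive_draws <= observed[None, :]).mean(axis=0)

n_obs, n_draws = 2000, 500
observed = rng.normal(size=n_obs)
# Calibrated: predictive draws come from the same distribution as the data
calibrated = rng.normal(size=(n_draws, n_obs))
# Under-dispersed: predictive spread is too narrow
narrow = rng.normal(scale=0.3, size=(n_draws, n_obs))

for name, draws in [("calibrated", calibrated), ("too narrow", narrow)]:
    pit = pit_values(observed, draws)
    counts, _ = np.histogram(pit, bins=5, range=(0.0, 1.0))
    print(name, counts)
```

The calibrated case yields roughly equal counts per bin. The narrow case is U-shaped, which is the pattern that would signal under-dispersion in the plot.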
The only point that might give us pause is that the predicted national trend in gonorrhea rates in the absence of treatment curves sharply upwards. Unless we can think of other events in the late '80s that would explain such a national trend, this suggests that our assumption of a linear treatment effect might be overly restrictive.
trend_plot = xr.Dataset({
'trend': np.log(az.extract(results, var_names="year", num_samples=1000)),
'year': 1985 + results.posterior.years}).to_dataframe()
sns.relplot(trend_plot, x='year', y='trend', kind='line', errorbar="sd");
Relaxing Linearity Assumptions
Instead of assuming a linear treatment effect, we can look at the posterior rate differences between the treatment and control groups in each year. Plotted over time, these differences should form a U-shape with its minimum around '88.
def model1(predictive=False):
    pyro_util.glm('lnr ~ year + year:repeal + center(alcohol) + center(poverty) - 1',
                  bf15, family=dist.NegativeBinomial2, predictive=predictive)
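The `year:repeal` term replaces the single `treated_frac` slope with a separate repeal-vs-non-repeal gap for every year, so the data rather than the model decide the shape of the effect over time. The sketch below assumes `year:repeal` is coded as one repeal indicator per year, which matches the per-year `year:repeal` coefficients plotted in this section. It fits ordinary least squares on simulated Gaussian data instead of the negative binomial GLM, purely to show that this design recovers an arbitrary per-year gap. The U-shaped gap, the state counts, and the noise level are all made up.

```python
import numpy as np

rng = np.random.default_rng(2)
years = np.arange(1985, 1994)
# Hypothetical per-year gaps between repeal and non-repeal states (U-shaped)
true_gap = np.array([0.0, -0.1, -0.2, -0.3, -0.3, -0.3, -0.2, -0.1, 0.0])
true_trend = np.linspace(4.0, 4.8, len(years))  # hypothetical national level per year

# 40 hypothetical states per group, each observed in every year
n_states = 40
year_idx = np.tile(np.arange(len(years)), 2 * n_states)
repeal = np.repeat([0.0, 1.0], n_states * len(years))
y = true_trend[year_idx] + repeal * true_gap[year_idx] + rng.normal(0, 0.2, year_idx.size)

# Design for `year + year:repeal - 1`: one indicator per year,
# plus one repeal indicator per year
year_dummies = np.eye(len(years))[year_idx]
X = np.hstack([year_dummies, year_dummies * repeal[:, None]])
coef, *_ = np.linalg.lstsq(X, y, rcond=None)
gap_hat = coef[len(years):]
for yr, g in zip(years, gap_hat):
    print(yr, round(float(g), 3))
```

The price of this flexibility is one extra parameter per year, each estimated from that year's data alone. That helps explain why the posterior for these coefficients is wide.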
mcmc = pyro_util.fit_nuts(model1, num_samples=1500, predictive=False)
results = pyro_util.from_numpyro(bf15, model1, mcmc)
Although a parabola shape is visible to an extent, the uncertainty in our posterior is pretty large.
coeffs = az.extract(results, var_names='year:repeal', num_samples=1000,
keep_dataset=True).rename_dims({'year:repeals': 'years'})['year:repeal'].drop_indexes('year:repeals')
diff_plot = xr.Dataset({'trend': coeffs, 'year': 1985 + results.posterior.years}).to_dataframe()
sns.relplot(diff_plot, x='year', y='trend', kind='line', errorbar="sd");
Triple Differences
What if the states with early repeals shared other factors that led to changes in their gonorrhea rates? Controlling for poverty and alcohol might not be enough. How might we rule out other potential confounders?
Instead of assuming that, but for the difference in legalization timing, repeal states would have followed the same trend as non-repeal states, we can allow repeal states to follow their own trend. This repeal-specific trend will affect both the 15-20 year old age group we've been studying and the 25-30 age group, for which we also have data. But legalization clearly had no effect on the 25-30 age group, as its members were all born before any state legalized abortion. We'll instead assume that the difference in gonorrhea rates between the two age groups follows the same trend in repeal states as in non-repeal states. This gives us a triple difference, as we're dealing with differences in time, in state, and in age group.
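The cancellation can be seen with noise-free cell means. Below, repeal states receive a hypothetical post-period shock that hits both age groups, and the treatment effect hits only the young. A difference-in-differences on the young group alone mixes the shock into the estimate. Subtracting the same difference-in-differences computed for the older group removes the shock and leaves the effect. All component values are made up. The key assumption, as in the text, is that the repeal-specific shock is shared by both age groups.

```python
# Hypothetical additive components of a log rate (no noise)
base = 5.0
national_trend = 0.3  # post-period change everywhere
state_gap = -0.4      # repeal states differ at baseline
age_gap = 0.8         # older group differs at baseline
repeal_shock = -0.25  # post-period shift in repeal states, hitting BOTH ages
effect = -0.15        # true treatment effect: young people in repeal states only

def cell(post, repeal, young):
    """Expected outcome for one (period, state group, age group) cell."""
    return (base + national_trend * post + state_gap * repeal
            + age_gap * (1 - young) + repeal_shock * post * repeal
            + effect * post * repeal * young)

def did(young):
    """Repeal vs. non-repeal, post vs. pre, within one age group."""
    return ((cell(1, 1, young) - cell(0, 1, young))
            - (cell(1, 0, young) - cell(0, 0, young)))

dd_young = did(young=1)             # effect + repeal_shock: confounded
ddd = did(young=1) - did(young=0)   # older group cancels the shared shock
print(f"DiD (young only): {dd_young:.2f}, DDD: {ddd:.2f}, true effect: {effect}")
```

In a regression, this corresponds to giving the repeal states their own time trend while letting the age gap evolve over time. That is the role of the `year*repeal` and `year:older` terms in the model below.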
First, we must assemble the combined dataset of both 15-20 and 25-30 age groups.
bf25_mask = (df.race == 2) & (df.sex == 2) & (df.age == 25)
df["older"] = True
df.loc[bf15_mask, "older"] = False
df.older = pd.CategoricalIndex(df.older)
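bf = df[bf15_mask | bf25_mask]
bf = bf.dropna()
# Reset the index as before so it lines up with numpyro's positional arrays
bf = bf.reset_index(drop=True)
bf.index.name = 'obs'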
Below, we can see the trends in gonorrhea rates over time for both age groups. The rate differences between the younger and older demographics seem to follow the same trend in the control group as in the treatment group, just as we expect. But the older population's rates seem to closely mirror those of the younger population: legalization doesn't seem to be doing that much. We'll quantify this observation more formally later.
sns.relplot(bf, x="year", y="lnr", hue="repeal", kind="line", style="older", errorbar="sd");
Below we define a triple difference model with a year:older term capturing the national trend in differences between the older and younger populations, a set of year*repeal terms capturing separate time trends for the control and treatment populations, and a treated_frac term allowing for a linear treatment effect.
def model2(predictive=False):
    pyro_util.glm('lnr ~ treated_frac + year:older + year*repeal', bf,
                  family=dist.NegativeBinomial2, predictive=predictive)
mcmc = pyro_util.fit_nuts(model2, num_samples=1500, predictive=False)
results = pyro_util.from_numpyro(bf, model2, mcmc)
When we consider both the older and younger populations, a negative treatment effect no longer seems as likely. There's enough uncertainty in the posterior that it's unclear what effect, if any, legalization was having. This conclusion agrees with the textbook's frequentist analysis.
az.plot_posterior(results, "treated_frac", ref_val=0.0);