Causal Effect VAE
This module implements the Causal Effect Variational Autoencoder [1], which demonstrates a number of innovations including:
- a generative model for causal effect inference with hidden confounders;
- a model and guide with twin neural nets to allow imbalanced treatment; and
- a custom training loss that includes both ELBO terms and extra terms needed to train the guide to be able to answer counterfactual queries.
The main interface is the CEVAE class, but users may customize by using the components Model, Guide, TraceCausalEffect_ELBO, and the utilities.
References
- [1] C. Louizos, U. Shalit, J. Mooij, D. Sontag, R. Zemel, M. Welling (2017). Causal Effect Inference with Deep Latent-Variable Models.
CEVAE Class
class CEVAE(feature_dim, outcome_dist='bernoulli', latent_dim=20, hidden_dim=200, num_layers=3, num_samples=100)
Bases: torch.nn.modules.module.Module
Main class implementing a Causal Effect VAE [1]. This assumes a graphical model
where t is a binary treatment variable, y is an outcome, Z is an unobserved confounder, and X is a noisy function of the hidden confounder Z.
Example:
    cevae = CEVAE(feature_dim=5)
    cevae.fit(x_train, t_train, y_train)
    ite = cevae.ite(x_test)  # individual treatment effect
    ate = ite.mean()         # average treatment effect
Parameters:
- feature_dim (int) – Dimension of the feature space x.
- outcome_dist (str) – One of: “bernoulli” (default), “exponential”, “laplace”, “normal”, “studentt”.
- latent_dim (int) – Dimension of the latent variable z. Defaults to 20.
- hidden_dim (int) – Dimension of hidden layers of fully connected networks. Defaults to 200.
- num_layers (int) – Number of hidden layers in fully connected networks. Defaults to 3.
- num_samples (int) – Default number of samples for the ite() method. Defaults to 100.
fit(x, t, y, num_epochs=100, batch_size=100, learning_rate=0.001, learning_rate_decay=0.1, weight_decay=0.0001)
Train using SVI with the TraceCausalEffect_ELBO loss.
Parameters:
- x (Tensor) – Features.
- t (Tensor) – Binary treatments.
- y (Tensor) – Outcomes.
- num_epochs (int) – Number of training epochs. Defaults to 100.
- batch_size (int) – Batch size. Defaults to 100.
- learning_rate (float) – Learning rate. Defaults to 1e-3.
- learning_rate_decay (float) – Learning rate decay over all epochs; the per-step decay rate depends on batch size and number of epochs such that the initial learning rate is learning_rate and the final learning rate is learning_rate * learning_rate_decay. Defaults to 0.1.
- weight_decay (float) – Weight decay. Defaults to 1e-4.
Returns: list of epoch losses
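The relationship between learning_rate_decay and the per-step decay rate can be sketched with plain arithmetic. This is a minimal sketch, assuming a multiplicative decay applied once per mini-batch step; the helper name per_step_decay and the dataset size of 1000 are hypothetical and not part of the library.

```python
def per_step_decay(learning_rate_decay, num_epochs, num_datapoints, batch_size):
    """Per-step multiplicative factor whose product over all training steps
    equals learning_rate_decay (assumes one optimizer step per mini-batch)."""
    steps_per_epoch = -(-num_datapoints // batch_size)  # ceiling division
    num_steps = num_epochs * steps_per_epoch
    return learning_rate_decay ** (1.0 / num_steps)


learning_rate = 1e-3
# Hypothetical dataset of 1000 points with the documented defaults.
gamma = per_step_decay(0.1, num_epochs=100, num_datapoints=1000, batch_size=100)
num_steps = 100 * 10  # epochs * steps per epoch
final_lr = learning_rate * gamma ** num_steps  # close to learning_rate * 0.1
```

Under these assumptions, changing batch_size or num_epochs changes gamma but leaves the final learning rate at learning_rate * learning_rate_decay.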
ite(x, num_samples=None, batch_size=None)
Computes Individual Treatment Effect for a batch of data x.
\[ITE(x) = \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=1) \bigr] - \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=0) \bigr]\]
This has complexity O(len(x) * num_samples ** 2).
Returns: A len(x)-sized tensor of estimated effects.
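To make the ITE formula concrete, here is a small sketch that evaluates it exactly on a toy discrete model with a binary confounder z, a binary proxy x, and a binary outcome y. All probability tables are hypothetical, and the sketch uses exact enumeration over z rather than the sampling-based estimate that a trained CEVAE would produce.

```python
# Hypothetical toy model: binary confounder z, binary proxy x, binary outcome y.
P_Z = {0: 0.5, 1: 0.5}                      # prior p(z)
P_X_GIVEN_Z = {0: {0: 0.8, 1: 0.2},         # p(x | z): x is a noisy copy of z
               1: {0: 0.2, 1: 0.8}}
P_Y1_GIVEN_TZ = {(0, 0): 0.2, (1, 0): 0.5,  # p(y=1 | t, z), keyed by (t, z)
                 (0, 1): 0.6, (1, 1): 0.7}


def posterior_z(x):
    """p(z | X=x) by Bayes' rule."""
    joint = {z: P_Z[z] * P_X_GIVEN_Z[z][x] for z in P_Z}
    total = sum(joint.values())
    return {z: p / total for z, p in joint.items()}


def expected_y_do(x, t):
    """E[y | X=x, do(t)]: intervening on t leaves p(z | x) unchanged,
    so average p(y=1 | t, z) over the posterior of z."""
    post = posterior_z(x)
    return sum(post[z] * P_Y1_GIVEN_TZ[(t, z)] for z in post)


def ite(x):
    """ITE(x) = E[y | X=x, do(t=1)] - E[y | X=x, do(t=0)]."""
    return expected_y_do(x, 1) - expected_y_do(x, 0)
```

With these tables, ite(0) is approximately 0.26 and ite(1) is approximately 0.14, so the treatment helps more when the proxy suggests z = 0.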
to_script_module()
Compile this module using torch.jit.trace_module(), assuming self has already been fit to data.
Returns: A traced version of self with an ite() method.
Return type: torch.jit.ScriptModule
CEVAE Components
class Model(config)
Bases: pyro.nn.module.PyroModule
Generative model for a causal model with latent confounder z and binary treatment t:
    z ~ p(z)      # latent confounder
    x ~ p(x|z)    # partial noisy observation of z
    t ~ p(t|z)    # treatment, whose application is biased by z
    y ~ p(y|t,z)  # outcome
Each of these distributions is defined by a neural network. The y distribution is defined by a disjoint pair of neural networks defining p(y|t=0,z) and p(y|t=1,z); this allows highly imbalanced treatment.
Parameters: config (dict) – A dict specifying feature_dim, latent_dim, hidden_dim, num_layers, and outcome_dist.
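The sampling order of the generative model can be illustrated without neural networks. The following is a minimal sketch that replaces each network with a fixed hand-written function; the coefficients and the name sample_once are hypothetical, and a one-dimensional z and x are assumed for brevity. The two outcome logits stand in for the disjoint pair p(y|t=0,z) and p(y|t=1,z).

```python
import math
import random


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def sample_once(rng):
    """Draw (z, x, t, y) in the order given by the generative model."""
    z = rng.gauss(0.0, 1.0)                           # z ~ p(z)
    x = rng.gauss(z, 0.5)                             # x ~ p(x|z)
    t = 1 if rng.random() < sigmoid(2.0 * z) else 0   # t ~ p(t|z)
    # Twin outcome heads: a separate function per treatment arm.
    y_logit_t0 = z - 1.0
    y_logit_t1 = z + 1.0
    y_logit = y_logit_t1 if t == 1 else y_logit_t0
    y = 1 if rng.random() < sigmoid(y_logit) else 0   # y ~ p(y|t,z)
    return z, x, t, y


rng = random.Random(0)
samples = [sample_once(rng) for _ in range(1000)]
```

Because each treatment arm has its own outcome function, data from the rarer arm only ever updates that arm's head, which is the motivation the text gives for supporting imbalanced treatment.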
class Guide(config)
Bases: pyro.nn.module.PyroModule
Inference model for causal effect estimation with latent confounder z and binary treatment t:
    t ~ p(t|x)      # treatment
    y ~ p(y|t,x)    # outcome
    z ~ p(z|y,t,x)  # latent confounder, an embedding
Each of these distributions is defined by a neural network. The y and z distributions are defined by disjoint pairs of neural networks defining p(-|t=0,...) and p(-|t=1,...); this allows highly imbalanced treatment.
Parameters: config (dict) – A dict specifying feature_dim, latent_dim, hidden_dim, num_layers, and outcome_dist.
class TraceCausalEffect_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)
Bases: pyro.infer.trace_elbo.Trace_ELBO
Loss function for training a CEVAE. From [1], the CEVAE objective (to maximize) is:
    -loss = ELBO + log q(t|x) + log q(y|t,x)
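The sign convention of this objective can be checked with a small numeric sketch for a single datapoint. It assumes Bernoulli guide distributions for both t and y and takes the ELBO value as a given number; the function names and the example probabilities are hypothetical and do not reflect how the library computes the terms internally.

```python
import math


def bernoulli_log_prob(p, value):
    """log Bernoulli(value; p) for value in {0, 1}."""
    return math.log(p) if value == 1 else math.log(1.0 - p)


def cevae_loss(elbo, q_t1_given_x, t, q_y1_given_tx, y):
    """Return loss = -(ELBO + log q(t|x) + log q(y|t,x)) for one datapoint."""
    aux = bernoulli_log_prob(q_t1_given_x, t) + bernoulli_log_prob(q_y1_given_tx, y)
    return -(elbo + aux)


# Hypothetical values: observed t=1 and y=0, with guide probabilities 0.7 and 0.4.
loss = cevae_loss(elbo=-3.0, q_t1_given_x=0.7, t=1, q_y1_given_tx=0.4, y=0)
```

Both extra terms are log-probabilities and therefore non-positive, so they can only increase the loss relative to the negative ELBO; minimizing the loss pushes the guide to predict the observed t and y from x.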
Utilities
class FullyConnected(sizes, final_activation=None)
Bases: torch.nn.modules.container.Sequential
Fully connected multi-layer network with ELU activations.
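For readers unfamiliar with ELU, the forward pass of such a network can be sketched in pure Python. This is an illustration only: it assumes ELU is applied between layers and that no activation follows the last layer when final_activation is None. The function names and the example weights are hypothetical.

```python
import math


def elu(v, alpha=1.0):
    """ELU activation: identity for v > 0, alpha * (exp(v) - 1) otherwise."""
    return v if v > 0 else alpha * (math.exp(v) - 1.0)


def linear(inputs, weights, bias):
    """Affine map; weights has one row per output unit."""
    return [sum(w * v for w, v in zip(row, inputs)) + b
            for row, b in zip(weights, bias)]


def fully_connected(inputs, layers):
    """Apply linear layers with ELU between them (none after the last)."""
    h = inputs
    for i, (weights, bias) in enumerate(layers):
        h = linear(h, weights, bias)
        if i < len(layers) - 1:
            h = [elu(v) for v in h]
    return h


# Hypothetical weights for layer sizes [2, 3, 1].
layers = [
    ([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [0.0, 0.0, 0.0]),
    ([[1.0, 1.0, 1.0]], [0.0]),
]
out = fully_connected([1.0, -1.0], layers)
```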
class DistributionNet
Bases: torch.nn.modules.module.Module
Base class for distribution nets.
class BernoulliNet(sizes)
Bases: pyro.contrib.cevae.DistributionNet
FullyConnected network outputting a single logits value.
This is used to represent a conditional probability distribution of a single Bernoulli random variable conditioned on a sizes[0]-sized real value, for example:
    net = BernoulliNet([3, 4])
    z = torch.randn(3)
    logits, = net(z)
    t = net.make_dist(logits).sample()
class ExponentialNet(sizes)
Bases: pyro.contrib.cevae.DistributionNet
FullyConnected network outputting a constrained rate.
This is used to represent a conditional probability distribution of a single Exponential random variable conditioned on a sizes[0]-sized real value, for example:
    net = ExponentialNet([3, 4])
    x = torch.randn(3)
    rate, = net(x)
    y = net.make_dist(rate).sample()
class LaplaceNet(sizes)
Bases: pyro.contrib.cevae.DistributionNet
FullyConnected network outputting a constrained loc,scale pair.
This is used to represent a conditional probability distribution of a single Laplace random variable conditioned on a sizes[0]-sized real value, for example:
    net = LaplaceNet([3, 4])
    x = torch.randn(3)
    loc, scale = net(x)
    y = net.make_dist(loc, scale).sample()
class NormalNet(sizes)
Bases: pyro.contrib.cevae.DistributionNet
FullyConnected network outputting a constrained loc,scale pair.
This is used to represent a conditional probability distribution of a single Normal random variable conditioned on a sizes[0]-sized real value, for example:
    net = NormalNet([3, 4])
    x = torch.randn(3)
    loc, scale = net(x)
    y = net.make_dist(loc, scale).sample()
class StudentTNet(sizes)
Bases: pyro.contrib.cevae.DistributionNet
FullyConnected network outputting a constrained df,loc,scale triple, with shared df > 1.
This is used to represent a conditional probability distribution of a single Student’s t random variable conditioned on a sizes[0]-sized real value, for example:
    net = StudentTNet([3, 4])
    x = torch.randn(3)
    df, loc, scale = net(x)
    y = net.make_dist(df, loc, scale).sample()
class DiagNormalNet(sizes)
Bases: torch.nn.modules.module.Module
FullyConnected network outputting a constrained loc,scale pair.
This is used to represent a conditional probability distribution of a sizes[-1]-sized diagonal Normal random variable conditioned on a sizes[0]-sized real value, for example:
    net = DiagNormalNet([3, 4, 5])
    z = torch.randn(3)
    loc, scale = net(z)
    x = dist.Normal(loc, scale).sample()
This is intended for the latent z distribution and the prewhitened x features, and conservatively clips loc and scale values.
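One common way to turn unconstrained network outputs into a valid loc,scale pair is to pass the raw scale through softplus and then clip both values. The sketch below illustrates that idea only; whether the library uses exactly this transform is an assumption, and the bounds and function names are illustrative placeholders, not the library's values.

```python
import math


def softplus(v):
    """Numerically stable log(1 + exp(v)); nonnegative in floating point."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def clip(v, lo, hi):
    return min(max(v, lo), hi)


def constrain_loc_scale(raw_loc, raw_scale,
                        loc_bound=1e2, scale_min=1e-3, scale_max=1e2):
    """Map two unconstrained reals to a valid (loc, scale) pair.

    The bounds are hypothetical placeholders chosen for illustration.
    """
    loc = clip(raw_loc, -loc_bound, loc_bound)
    scale = clip(softplus(raw_scale), scale_min, scale_max)
    return loc, scale


# Extreme raw outputs are pulled back into the allowed range.
clipped = constrain_loc_scale(1e6, -1000.0)
typical = constrain_loc_scale(0.5, 0.0)
```

Clipping the scale away from zero keeps log-densities finite, which matters when the extreme values come from an untrained or diverging network.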