Neural Networks
The module pyro.nn provides implementations of neural network modules that are useful in the context of deep probabilistic programming.
Pyro Modules
Pyro includes an experimental class PyroModule, a subclass of torch.nn.Module, whose attributes can be modified by Pyro effects. To create a poutine-aware attribute, use either the PyroParam struct or the PyroSample struct:

    my_module = PyroModule()
    my_module.x = PyroParam(torch.tensor(1.), constraint=constraints.positive)
    my_module.y = PyroSample(dist.Normal(0, 1))
class PyroParam
    Bases: tuple

    Structure to declare a Pyro-managed learnable parameter of a PyroModule.

    init_value
        Alias for field number 0
    constraint
        Alias for field number 1
    event_dim
        Alias for field number 2
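Since PyroParam is documented as a tuple subclass with named fields, its struct-like behavior can be sketched with a plain namedtuple. This is a hypothetical re-creation for illustration only (PyroParamSketch is not Pyro's actual definition); the field order follows the "Alias for field number" entries above.

```python
from collections import namedtuple

# Hypothetical sketch of PyroParam's tuple structure (not Pyro's code):
# field 0 = init_value, field 1 = constraint, field 2 = event_dim.
PyroParamSketch = namedtuple(
    "PyroParamSketch", ["init_value", "constraint", "event_dim"])

p = PyroParamSketch(init_value=1.0, constraint="positive", event_dim=0)
assert p[0] == p.init_value   # alias for field number 0
assert p[1] == p.constraint   # alias for field number 1
assert p[2] == p.event_dim    # alias for field number 2
assert isinstance(p, tuple)   # matches the documented base class
```

Because it is a tuple, a PyroParam value can be unpacked positionally or accessed by name, which is how PyroModule reads its parts on setattr.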
class
PyroSample
¶ Bases:
tuple
Structure to declare a Pyro-managed random parameter of a
PyroModule
.-
prior
¶ Alias for field number 0
-
-
class PyroModule(name='')[source]
    Bases: torch.nn.modules.module.Module

EXPERIMENTAL Subclass of torch.nn.Module whose attributes can be modified by Pyro effects. Attributes can be set using the helpers PyroParam and PyroSample, and methods can be decorated with pyro_method().

Parameters

To create a Pyro-managed parameter attribute, set that attribute using either torch.nn.Parameter (for unconstrained parameters) or PyroParam (for constrained parameters). Reading that attribute will then trigger a pyro.param() statement. For example:

    # Create Pyro-managed parameter attributes.
    my_module = PyroModule()
    my_module.loc = nn.Parameter(torch.tensor(0.))
    my_module.scale = PyroParam(torch.tensor(1.),
                                constraint=constraints.positive)

    # Read the attributes.
    loc = my_module.loc      # Triggers a pyro.param statement.
    scale = my_module.scale  # Triggers another pyro.param statement.
Note that, unlike normal torch.nn.Modules, PyroModules should not be registered with pyro.module() statements. PyroModules can contain other PyroModules and normal torch.nn.Modules. Accessing a normal torch.nn.Module attribute of a PyroModule triggers a pyro.module() statement. If multiple PyroModules appear in a single Pyro model or guide, they should be included in a single root PyroModule for that model.

PyroModules synchronize data with the param store at each setattr, getattr, and delattr event, based on the nested name of an attribute:

- Setting mod.x = x_init tries to read x from the param store. If a value is found in the param store, that value is copied into mod and x_init is ignored; otherwise x_init is copied into both mod and the param store.
- Reading mod.x tries to read x from the param store. If a value is found in the param store, that value is copied into mod; otherwise mod's value is copied into the param store. Finally mod and the param store agree on a single value to return.
- Deleting del mod.x removes a value from both mod and the param store.
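The three synchronization rules above can be sketched with a plain dict standing in for the global param store. This is a simplified toy (SyncedModule is hypothetical, not Pyro's implementation): the real PyroModule also handles constraints, nested names, and tensor copying.

```python
# Toy sketch of PyroModule's param-store synchronization (not Pyro's code).
param_store = {}  # stands in for Pyro's global param store

class SyncedModule:
    """Toy module mirroring the three setattr/getattr/delattr rules."""
    def __init__(self):
        self._data = {}  # stands in for the module's own attributes

    def set(self, name, value):
        # Setting mod.x = x_init: a value already in the store wins.
        self._data[name] = param_store.setdefault(name, value)

    def get(self, name):
        # Reading mod.x: the store's value wins; otherwise the module's
        # value is copied into the store.
        if name in param_store:
            self._data[name] = param_store[name]
        else:
            param_store[name] = self._data[name]
        return self._data[name]

    def delete(self, name):
        # Deleting del mod.x removes the value from both places.
        del self._data[name]
        param_store.pop(name, None)

# Two modules using the same attribute name share data via the store:
a, b = SyncedModule(), SyncedModule()
a.set("x", 1.0)
b.set("x", 2.0)              # store already holds 1.0, so 2.0 is ignored
assert b.get("x") == 1.0     # both modules see the same data
a.delete("x")                # removed from a and from the store
b.set("x", 2.0)              # store is empty again, so 2.0 sticks
assert b.get("x") == 2.0
```

The example also previews the persistence caveat described next: a second module with the same name silently inherits the first one's data unless the store is cleared.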
Note that two PyroModules with the same name will both synchronize with the global param store and thus contain the same data. When creating a PyroModule, then deleting it, then creating another with the same name, the latter will be populated with the former's data from the param store. To avoid this persistence, either call pyro.clear_param_store() or call clear() before deleting a PyroModule.

PyroModules can be saved and loaded either directly using torch.save() / torch.load() or indirectly using the param store's save() / load(). Note that values loaded via torch.load() will be overridden by any values already in the param store, so it is safest to call pyro.clear_param_store() before loading.

Samples
To create a Pyro-managed random attribute, set that attribute using the PyroSample helper, specifying a prior distribution. Reading that attribute will then trigger a pyro.sample() statement. For example:

    # Create Pyro-managed random attributes.
    my_module.x = PyroSample(dist.Normal(0, 1))
    my_module.y = PyroSample(lambda self: dist.Normal(self.loc, self.scale))

    # Sample the attributes.
    x = my_module.x  # Triggers a pyro.sample statement.
    y = my_module.y  # Triggers one pyro.sample + two pyro.param statements.
Sampling is cached within each invocation of .__call__() or of a method decorated by pyro_method(). Because sample statements can appear only once in a Pyro trace, you should ensure that traced access to sample attributes is wrapped in a single invocation of .__call__() or of a method decorated by pyro_method().

To make an existing module probabilistic, you can create a subclass and overwrite some parameters with PyroSamples:

    class RandomLinear(nn.Linear, PyroModule):  # used as a mixin
        def __init__(self, in_features, out_features):
            super().__init__(in_features, out_features)
            self.weight = PyroSample(
                lambda self: dist.Normal(0, 1)
                                 .expand([self.out_features,
                                          self.in_features])
                                 .to_event(2))
Mixin classes

PyroModule can be used as a mixin class, and supports simple syntax for dynamically creating mixins. For example, the following are equivalent:

    # Version 1. create a named mixin class
    class PyroLinear(nn.Linear, PyroModule):
        pass

    m.linear = PyroLinear(m, n)

    # Version 2. create a dynamic mixin class
    m.linear = PyroModule[nn.Linear](m, n)
This notation can be used recursively to create Bayesian modules, e.g.:

    model = PyroModule[nn.Sequential](
        PyroModule[nn.Linear](28 * 28, 100),
        PyroModule[nn.Sigmoid](),
        PyroModule[nn.Linear](100, 100),
        PyroModule[nn.Sigmoid](),
        PyroModule[nn.Linear](100, 10),
    )
    assert isinstance(model, nn.Sequential)
    assert isinstance(model, PyroModule)

    # Now we can be Bayesian about weights in the first layer.
    # Note nn.Linear stores weight with shape (out_features, in_features),
    # so the prior's event shape must be (100, 28 * 28).
    model[0].weight = PyroSample(
        prior=dist.Normal(0, 1).expand([100, 28 * 28]).to_event(2))
    guide = AutoDiagonalNormal(model)
Note that PyroModule[...] does not recursively mix in PyroModule to submodules of the input Module; hence we needed to wrap each submodule of the nn.Sequential above.

Parameters: name (str) – Optional name for a root PyroModule. This is ignored in sub-PyroModules of another PyroModule.
pyro_method(fn)[source]
    Decorator for top-level methods of a PyroModule to enable Pyro effects and cache pyro.sample statements.

    This should be applied to all public methods that read Pyro-managed attributes, but is not needed for .forward().
clear(mod)[source]
    Removes data from both a PyroModule and the param store.

    Parameters: mod (PyroModule) – A module to clear.
AutoRegressiveNN
class AutoRegressiveNN(input_dim, hidden_dims, param_dims=[1, 1], permutation=None, skip_connections=False, nonlinearity=ReLU())[source]
    Bases: pyro.nn.auto_reg_nn.ConditionalAutoRegressiveNN

An implementation of a MADE-like auto-regressive neural network.

Example usage:

    >>> x = torch.randn(100, 10)
    >>> arn = AutoRegressiveNN(10, [50], param_dims=[1])
    >>> p = arn(x)  # 1 parameter of size (100, 10)
    >>> arn = AutoRegressiveNN(10, [50], param_dims=[1, 1])
    >>> m, s = arn(x)  # 2 parameters of size (100, 10)
    >>> arn = AutoRegressiveNN(10, [50], param_dims=[1, 5, 3])
    >>> a, b, c = arn(x)  # 3 parameters of sizes (100, 1, 10), (100, 5, 10), (100, 3, 10)

Parameters:
- input_dim (int) – the dimensionality of the input variable
- hidden_dims (list[int]) – the dimensionality of the hidden units per layer
- param_dims (list[int]) – shape the output into parameters of dimension (p_n, input_dim) for each p_n in param_dims when p_n > 1, and of dimension (input_dim) when p_n == 1. The default is [1, 1], i.e. output two parameters of dimension (input_dim), which is useful for inverse autoregressive flow.
- permutation (torch.LongTensor) – an optional permutation that is applied to the inputs and controls the order of the autoregressive factorization. In particular, for the identity permutation the autoregressive structure is such that the Jacobian is upper triangular. By default this is chosen at random.
- skip_connections (bool) – whether to add skip connections from the input to the output.
- nonlinearity (torch.nn.Module) – the nonlinearity to use in the feedforward network, such as torch.nn.ReLU(). Note that no nonlinearity is applied to the final network output, so the output is an unbounded real number.
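The param_dims shaping rule can be sketched in plain Python, shapes only, with no network. output_shapes is a hypothetical helper written here for illustration, not part of Pyro's API.

```python
def output_shapes(batch, input_dim, param_dims):
    """Shapes produced for each p_n, per the param_dims rule:
    (batch, input_dim) when p_n == 1, else (batch, p_n, input_dim)."""
    return [
        (batch, input_dim) if p_n == 1 else (batch, p_n, input_dim)
        for p_n in param_dims
    ]

# Matches the example usage: AutoRegressiveNN(10, [50], param_dims=[1, 5, 3])
assert output_shapes(100, 10, [1, 5, 3]) == [
    (100, 10), (100, 5, 10), (100, 3, 10)]

# The default [1, 1] gives two (batch, input_dim) parameters,
# e.g. a location and a scale for inverse autoregressive flow.
assert output_shapes(100, 10, [1, 1]) == [(100, 10), (100, 10)]
```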
Reference:
MADE: Masked Autoencoder for Distribution Estimation [arXiv:1502.03509] Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle
ConditionalAutoRegressiveNN
class ConditionalAutoRegressiveNN(input_dim, context_dim, hidden_dims, param_dims=[1, 1], permutation=None, skip_connections=False, nonlinearity=ReLU())[source]
    Bases: torch.nn.modules.module.Module

An implementation of a MADE-like auto-regressive neural network that can input an additional context variable. (See Reference [2] Section 3.3 for an explanation of how the conditional MADE architecture works.)

Example usage:

    >>> x = torch.randn(100, 10)
    >>> y = torch.randn(100, 5)
    >>> arn = ConditionalAutoRegressiveNN(10, 5, [50], param_dims=[1])
    >>> p = arn(x, context=y)  # 1 parameter of size (100, 10)
    >>> arn = ConditionalAutoRegressiveNN(10, 5, [50], param_dims=[1, 1])
    >>> m, s = arn(x, context=y)  # 2 parameters of size (100, 10)
    >>> arn = ConditionalAutoRegressiveNN(10, 5, [50], param_dims=[1, 5, 3])
    >>> a, b, c = arn(x, context=y)  # 3 parameters of sizes (100, 1, 10), (100, 5, 10), (100, 3, 10)

Parameters:
- input_dim (int) – the dimensionality of the input variable
- context_dim (int) – the dimensionality of the context variable
- hidden_dims (list[int]) – the dimensionality of the hidden units per layer
- param_dims (list[int]) – shape the output into parameters of dimension (p_n, input_dim) for each p_n in param_dims when p_n > 1, and of dimension (input_dim) when p_n == 1. The default is [1, 1], i.e. output two parameters of dimension (input_dim), which is useful for inverse autoregressive flow.
- permutation (torch.LongTensor) – an optional permutation that is applied to the inputs and controls the order of the autoregressive factorization. In particular, for the identity permutation the autoregressive structure is such that the Jacobian is upper triangular. By default this is chosen at random.
- skip_connections (bool) – whether to add skip connections from the input to the output.
- nonlinearity (torch.nn.Module) – the nonlinearity to use in the feedforward network, such as torch.nn.ReLU(). Note that no nonlinearity is applied to the final network output, so the output is an unbounded real number.
Reference:
1. MADE: Masked Autoencoder for Distribution Estimation [arXiv:1502.03509] Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle
2. Inference Networks for Sequential Monte Carlo in Graphical Models [arXiv:1602.06701] Brooks Paige, Frank Wood
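The masking idea behind the (conditional) MADE architecture can be sketched without torch: assign each unit a degree, allow only connections that respect the autoregressive ordering, and connect the context to every hidden unit with no mask. This is a simplified one-hidden-layer illustration under the identity permutation; made_masks and reachable are hypothetical helpers written for this sketch, not the library's implementation.

```python
def made_masks(input_dim, hidden_degrees):
    """Build 0/1 masks for one hidden layer under the identity permutation.

    Input i gets degree i + 1; a hidden unit of degree d may see inputs
    with degree <= d, and output o may see hidden units of degree < o + 1,
    so each output depends only on strictly earlier inputs (a triangular
    Jacobian).
    """
    in_deg = [i + 1 for i in range(input_dim)]
    out_deg = in_deg
    mask_in = [[1 if in_deg[i] <= hidden_degrees[h] else 0
                for i in range(input_dim)]
               for h in range(len(hidden_degrees))]
    mask_out = [[1 if hidden_degrees[h] < out_deg[o] else 0
                 for h in range(len(hidden_degrees))]
                for o in range(input_dim)]
    return mask_in, mask_out

def reachable(mask_in, mask_out):
    """Which inputs each output can depend on: boolean product of masks."""
    n_hidden, n_in = len(mask_in), len(mask_in[0])
    return [[any(mask_out[o][h] and mask_in[h][i] for h in range(n_hidden))
             for i in range(n_in)]
            for o in range(len(mask_out))]

deps = reachable(*made_masks(4, hidden_degrees=[1, 2, 3, 1, 2]))
# Output o depends only on inputs with strictly smaller index:
for o in range(4):
    for i in range(4):
        assert not (deps[o][i] and i >= o)
# In the conditional variant, the context vector is additionally fed in
# unmasked, so every output may depend on all of the context.
```

A non-identity permutation simply relabels the degrees before building the masks, which changes the order of the autoregressive factorization without changing the construction.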