Automatic Guide Generation

AutoGuide

class AutoGuide(model, *, create_plates=None)
Bases: pyro.nn.module.PyroModule

Base class for automatic guides.

Derived classes must implement the forward() method, with the same *args, **kwargs as the base model.

Auto guides can be used individually or combined in an AutoGuideList object.

Parameters:
- model (callable) – A Pyro model.
- create_plates (callable) – An optional function taking the same *args, **kwargs as model() and returning a pyro.plate or an iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
- call(*args, **kwargs)
  Method that calls forward() and returns the parameter values of the guide as a tuple instead of a dict, which is a requirement for JIT tracing. Unlike forward(), this method can be traced by torch.jit.trace_module().
  Warning: This method may be removed once the PyTorch JIT tracer starts accepting dict as a valid return type. See issue https://github.com/pytorch/pytorch/issues/27743.
- median(*args, **kwargs)
  Returns the posterior median value of each latent variable.
  Returns: A dict mapping sample site name to median tensor.
  Return type: dict

- model
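The only job of call() is to turn the dict returned by forward() into a tuple that the JIT tracer accepts. A minimal plain-Python sketch of that conversion follows. It assumes the values are ordered by sorted site name so that positions are deterministic; the helper `guide_result_to_tuple` and the example values are hypothetical, not Pyro's actual code.

```python
from typing import Any, Dict, Tuple


def guide_result_to_tuple(result: Dict[str, Any]) -> Tuple[Any, ...]:
    """Hypothetical helper: flatten a {site_name: value} dict into a tuple.

    Sorting by site name (an assumed ordering) means the same guide always
    places each site's value at the same tuple position.
    """
    return tuple(value for _, value in sorted(result.items()))


# Stand-in for what forward() might return: site name -> sampled value.
forward_result = {"scale": 0.7, "loc": 1.5, "assignment": 2}
print(guide_result_to_tuple(forward_result))  # (2, 1.5, 0.7)
```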
AutoGuideList

class AutoGuideList(model, *, create_plates=None)
Bases: pyro.infer.autoguide.guides.AutoGuide, torch.nn.modules.container.ModuleList

Container class to combine multiple automatic guides.

Example usage:

    guide = AutoGuideList(model)
    guide.append(AutoDiagonalNormal(poutine.block(model, hide=["assignment"])))
    guide.append(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))
    svi = SVI(model, guide, optim, Trace_ELBO())

Parameters:
- model (callable) – a Pyro model
- append(part)
  Add an automatic guide for part of the model. The guide should have been created by blocking the model to restrict it to a subset of sample sites. No two parts should operate on the same sample site.
  Parameters: part (AutoGuide or callable) – a partial guide to add
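The rule that no two parts may operate on the same sample site amounts to a disjointness check over site names. A minimal plain-Python sketch of that check; `check_disjoint_parts` and the site names `locs` and `scales` are hypothetical, and Pyro does not necessarily validate parts this way.

```python
from typing import Dict, Iterable


def check_disjoint_parts(parts: Dict[str, Iterable[str]]) -> None:
    """Hypothetical helper: raise if two guide parts claim the same sample site.

    `parts` maps a label for each part to the sample-site names it covers,
    e.g. the sites left visible by its poutine.block(...) call.
    """
    owner: Dict[str, str] = {}
    for part_name, sites in parts.items():
        for site in sites:
            if site in owner:
                raise ValueError(
                    f"site {site!r} is covered by both {owner[site]!r} and {part_name!r}"
                )
            owner[site] = part_name


# Mirrors the AutoGuideList example: "assignment" is hidden from one part and
# exposed to the other. "locs" and "scales" are hypothetical continuous sites.
check_disjoint_parts({
    "AutoDiagonalNormal": {"locs", "scales"},
    "AutoDiscreteParallel": {"assignment"},
})

try:
    check_disjoint_parts({"first": {"locs"}, "second": {"locs", "assignment"}})
except ValueError as err:
    print(err)
```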
AutoCallable

class AutoCallable(model, guide, median=<function AutoCallable.<lambda>>)
Bases: pyro.infer.autoguide.guides.AutoGuide

AutoGuide wrapper for simple callable guides.

This is used internally for composing autoguides with custom user-defined guides that are simple callables, e.g.:

    def my_local_guide(*args, **kwargs):
        ...

    guide = AutoGuideList(model)
    guide.append(AutoDelta(poutine.block(model, expose=['my_global_param'])))
    guide.append(my_local_guide)  # automatically wrapped in an AutoCallable

To specify a median callable, you can instead write:

    def my_local_median(*args, **kwargs):
        ...

    guide.append(AutoCallable(model, my_local_guide, my_local_median))

For more complex guides that need e.g. access to plates, users should instead subclass AutoGuide.

Parameters:
- model (callable) – a Pyro model
- guide (callable) – a Pyro guide (typically over only part of the model)
- median (callable) – an optional callable returning a dict mapping sample site name to computed median tensor.
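A median callable matters once guides are combined: the combined median can be assembled by merging each part's median dict, and a wrapped callable with no median contributes nothing. A plain-Python sketch of that merge, assuming the default median returns an empty dict; `combined_median`, `default_median`, and the two example median functions are hypothetical.

```python
from typing import Any, Callable, Dict, List

MedianFn = Callable[..., Dict[str, Any]]


def default_median(*args: Any, **kwargs: Any) -> Dict[str, Any]:
    """Assumed default for a wrapped callable guide: no median information."""
    return {}


def combined_median(part_medians: List[MedianFn], *args: Any, **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical helper: merge per-part median dicts into one dict.

    Parts cover disjoint sample sites, so update() never overwrites
    another part's entry.
    """
    result: Dict[str, Any] = {}
    for median_fn in part_medians:
        result.update(median_fn(*args, **kwargs))
    return result


def my_global_median(*args: Any, **kwargs: Any) -> Dict[str, Any]:
    # Stand-in for an AutoDelta part over a global parameter.
    return {"my_global_param": 0.25}


def my_local_median(*args: Any, **kwargs: Any) -> Dict[str, Any]:
    # Stand-in for a user-supplied median passed to AutoCallable.
    return {"my_local_param": -1.0}


print(combined_median([my_global_median, my_local_median]))
print(combined_median([my_global_median, default_median]))
```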
AutoNormal

class AutoNormal(model, *, init_loc_fn=<function init_to_feasible>, init_scale=0.1, create_plates=None)
Bases: pyro.infer.autoguide.guides.AutoGuide

This implementation of AutoGuide uses Normal distributions with learnable locations and scales to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

It should be equivalent to AutoDiagonalNormal, but with more convenient site names and with better support for TraceMeanField_ELBO.

In AutoDiagonalNormal, if your model has N named latent variables with dimensions k_i and \(\sum_{i=1}^{N} k_i = D\), you get a single vector of length D for your mean and a single vector of length D for sigmas. This guide gives you N distinct normals that you can call by name.

Usage:

    guide = AutoNormal(model)
    svi = SVI(model, guide, ...)

Parameters:
- model (callable) – A Pyro model.
- init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
- init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
- create_plates (callable) – An optional function taking the same *args, **kwargs as model() and returning a pyro.plate or an iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
- forward(*args, **kwargs)
  An automatic guide with the same *args, **kwargs as the base model.
  Returns: A dict mapping sample site name to sampled value.
  Return type: dict

- median(*args, **kwargs)
  Returns the posterior median value of each latent variable.
  Returns: A dict mapping sample site name to median tensor.
  Return type: dict

- quantiles(quantiles, *args, **kwargs)
  Returns posterior quantiles of each latent variable. Example:

      print(guide.quantiles([0.05, 0.5, 0.95]))

  Parameters: quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.
  Returns: A dict mapping sample site name to a list of quantile values.
  Return type: dict
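Because each site has its own Normal over an unconstrained value, per-site quantiles can be obtained from the Normal inverse CDF and then mapped back through the site's bijection; monotone bijections preserve quantiles. A minimal standard-library sketch of that idea, assuming an `exp` bijection for positive sites. The site names and (loc, scale) values are hypothetical, and this is not Pyro's implementation.

```python
import math
from statistics import NormalDist
from typing import Dict, List


def site_quantiles(loc: float, scale: float, quantiles: List[float],
                   positive: bool = False) -> List[float]:
    """Quantiles of Normal(loc, scale) over an unconstrained value.

    For a positive-constrained site we assume exp() maps the unconstrained
    value into the support; since exp() is monotone, each quantile can be
    transformed directly.
    """
    unconstrained = NormalDist(loc, scale)
    values = [unconstrained.inv_cdf(q) for q in quantiles]
    return [math.exp(v) for v in values] if positive else values


# Hypothetical per-site parameters: name -> (loc, scale, is_positive).
params = {"weight": (0.5, 0.2, False), "noise_scale": (-1.0, 0.1, True)}
result: Dict[str, List[float]] = {
    name: site_quantiles(loc, scale, [0.05, 0.5, 0.95], positive)
    for name, (loc, scale, positive) in params.items()
}
print(result)
```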
AutoDelta

class AutoDelta(model, init_loc_fn=<function init_to_median>, *, create_plates=None)
Bases: pyro.infer.autoguide.guides.AutoGuide

This implementation of AutoGuide uses Delta distributions to construct a MAP guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Note: This class does MAP inference in constrained space.

Usage:

    guide = AutoDelta(model)
    svi = SVI(model, guide, ...)

Latent variables are initialized using init_loc_fn(). To change the default behavior, create a custom init_loc_fn() as described in Initialization, for example:

    import torch
    from pyro.infer.autoguide.initialization import init_to_sample

    k = 3  # hypothetical number of components; use your model's value

    def my_init_fn(site):
        if site["name"] == "level":
            return torch.tensor([-1., 0., 1.])
        if site["name"] == "concentration":
            return torch.ones(k)
        # Every other site starts from a draw from its prior.
        return init_to_sample(site)

    guide = AutoDelta(model, init_loc_fn=my_init_fn)

Parameters:
- model (callable) – A Pyro model.
- init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
- create_plates (callable) – An optional function taking the same *args, **kwargs as model() and returning a pyro.plate or an iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
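Doing MAP in constrained space means maximizing log p(x, z) over the constrained value z itself, with no change-of-variables term. The standalone sketch below (not Pyro code) shows why that distinction matters for a positive variable. Assuming a hypothetical Gamma(shape a=3, rate b=2) density, the maximizer over z is (a - 1)/b = 1, while maximizing the density of u = log z, which gains the Jacobian term u, gives exp(u) = a/b = 1.5.

```python
import math
from typing import Callable


def gamma_log_density(z: float, a: float, b: float) -> float:
    """Unnormalized log density of Gamma(shape=a, rate=b) at z > 0."""
    return (a - 1.0) * math.log(z) - b * z


def argmax_golden(f: Callable[[float], float], lo: float, hi: float,
                  tol: float = 1e-10) -> float:
    """Golden-section search for the maximizer of a unimodal f on [lo, hi]."""
    inv_phi = (math.sqrt(5.0) - 1.0) / 2.0
    x1 = hi - inv_phi * (hi - lo)
    x2 = lo + inv_phi * (hi - lo)
    while hi - lo > tol:
        if f(x1) < f(x2):          # maximum lies in [x1, hi]
            lo, x1 = x1, x2
            x2 = lo + inv_phi * (hi - lo)
        else:                      # maximum lies in [lo, x2]
            hi, x2 = x2, x1
            x1 = hi - inv_phi * (hi - lo)
    return (lo + hi) / 2.0


a, b = 3.0, 2.0

# MAP in constrained space: maximize log p(z) directly over z > 0.
z_constrained = argmax_golden(lambda z: gamma_log_density(z, a, b), 1e-6, 20.0)

# MAP over u = log z: the density of u gains the Jacobian term log|dz/du| = u.
u_map = argmax_golden(lambda u: gamma_log_density(math.exp(u), a, b) + u, -10.0, 5.0)
z_unconstrained = math.exp(u_map)

print(z_constrained, (a - 1.0) / b)  # mode of Gamma(a, b) over z
print(z_unconstrained, a / b)        # mode over log z, mapped back to z
```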
AutoContinuous

class AutoContinuous(model, init_loc_fn=<function init_to_median>)
Bases: pyro.infer.autoguide.guides.AutoGuide

Base class for implementations of continuous-valued Automatic Differentiation Variational Inference [1].

This uses torch.distributions.transforms to transform each constrained latent variable to an unconstrained space, then concatenates all variables into a single unconstrained latent variable. Each derived class implements a get_posterior() method returning a distribution over this single unconstrained latent variable.

Assumes model structure and latent dimension are fixed, and all latent variables are continuous.

Reference:
[1] Automatic Differentiation Variational Inference, Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei

Parameters:
- model (callable) – A Pyro model.
- init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
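The transform-then-concatenate step can be illustrated without Pyro: each site is mapped to unconstrained space and appended to one flat vector, and unpacking splits the vector by site size and maps each piece back. This sketch assumes two hypothetical sites and uses log/exp as the positive-support bijection; it illustrates the idea and is not Pyro's internal packing code.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical sites as (name, constraint, size). "real" values are used as-is;
# "positive" values go through log() on the way in and exp() on the way out.
SITES: List[Tuple[str, str, int]] = [("weights", "real", 2), ("noise", "positive", 1)]


def pack(values: Dict[str, List[float]]) -> List[float]:
    """Map each constrained site to unconstrained space, then concatenate."""
    flat: List[float] = []
    for name, constraint, size in SITES:
        site_values = values[name]
        assert len(site_values) == size
        if constraint == "positive":
            site_values = [math.log(v) for v in site_values]
        flat.extend(site_values)
    return flat


def unpack(flat: List[float]) -> Dict[str, List[float]]:
    """Split the single unconstrained vector by site and map back to each support."""
    values: Dict[str, List[float]] = {}
    pos = 0
    for name, constraint, size in SITES:
        chunk = flat[pos:pos + size]
        pos += size
        if constraint == "positive":
            chunk = [math.exp(v) for v in chunk]
        values[name] = chunk
    assert pos == len(flat), "latent dimension must be fixed"
    return values


latent = pack({"weights": [0.3, -1.2], "noise": [0.5]})
print(len(latent))  # 3, the total latent dimension
print(unpack(latent))
```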
- forward(*args, **kwargs)
  An automatic guide with the same *args, **kwargs as the base model.
  Returns: A dict mapping sample site name to sampled value.
  Return type: dict

- get_base_dist()
  Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.

      posterior = TransformedDistribution(self.get_base_dist(), self.get_transform(*args, **kwargs))

  Returns: TorchDistribution instance representing the base distribution.

- get_transform(*args, **kwargs)
  Returns the transform applied to the base distribution when the posterior is reparameterized as a TransformedDistribution. This may depend on the model’s *args, **kwargs.

      posterior = TransformedDistribution(self.get_base_dist(), self.get_transform(*args, **kwargs))

  Returns: a Transform instance.
- median(*args, **kwargs)
  Returns the posterior median value of each latent variable.
  Returns: A dict mapping sample site name to median tensor.
  Return type: dict

- quantiles(quantiles, *args, **kwargs)
  Returns posterior quantiles of each latent variable. Example:

      print(guide.quantiles([0.05, 0.5, 0.95]))

  Parameters: quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.
  Returns: A dict mapping sample site name to a list of quantile values.
  Return type: dict
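The split into get_base_dist() and get_transform() is a reparameterization: sample from a fixed base distribution, then push the sample through a transform. For a diagonal-Normal guide such as AutoDiagonalNormal this can be viewed as a standard Normal base followed by an elementwise affine map. A standard-library sketch of that sampling path, with hypothetical `loc`/`scale` values (not Pyro's code):

```python
import random
import statistics
from typing import List


def sample_unconstrained(loc: List[float], scale: List[float], rng: random.Random) -> List[float]:
    """One draw of the single unconstrained latent vector.

    Base distribution: independent Normal(0, 1) per coordinate, which does not
    depend on the model's arguments. Transform: elementwise affine map
    z = loc + scale * eps.
    """
    eps = [rng.gauss(0.0, 1.0) for _ in loc]                # sample the base distribution
    return [m + s * e for m, s, e in zip(loc, scale, eps)]  # apply the transform


rng = random.Random(0)
loc, scale = [1.0, -2.0], [0.5, 0.1]  # hypothetical learned values
draws = [sample_unconstrained(loc, scale, rng) for _ in range(20000)]
means = [statistics.fmean(d[i] for d in draws) for i in range(2)]
stdevs = [statistics.stdev([d[i] for d in draws]) for i in range(2)]
print(means, stdevs)  # close to loc and scale
```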
AutoMultivariateNormal

class AutoMultivariateNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1)
Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a Cholesky factorization of a Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

    guide = AutoMultivariateNormal(model)
    svi = SVI(model, guide, ...)

By default the mean vector is initialized by init_loc_fn() and the Cholesky factor is initialized to the identity times a small factor.

Parameters:
- model (callable) – A generative model.
- init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
- init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
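With a Cholesky parameterization, a sample is loc + L eps for a lower-triangular L and eps drawn from a standard Normal, so the covariance is L L^T. A standard-library sketch showing the identity-times-init_scale starting point and how an off-diagonal entry of L introduces correlation; the factor values are hypothetical and this is not Pyro's code.

```python
import random
from typing import List

Matrix = List[List[float]]


def matmul_transpose(lower: Matrix) -> Matrix:
    """Return L @ L^T for a square matrix L (the covariance implied by L)."""
    n = len(lower)
    return [[sum(lower[i][k] * lower[j][k] for k in range(n)) for j in range(n)]
            for i in range(n)]


def sample_mvn(loc: List[float], scale_tril: Matrix, rng: random.Random) -> List[float]:
    """Reparameterized draw z = loc + L @ eps, eps ~ Normal(0, I); Cov(z) = L @ L^T."""
    eps = [rng.gauss(0.0, 1.0) for _ in loc]
    # L is lower triangular, so row i only needs eps[0..i].
    return [loc[i] + sum(scale_tril[i][k] * eps[k] for k in range(i + 1))
            for i in range(len(loc))]


# Starting point described above: identity times a small factor (init_scale).
init_scale = 0.1
L_init = [[init_scale if i == j else 0.0 for j in range(2)] for i in range(2)]
print(matmul_transpose(L_init))  # diagonal, approximately init_scale ** 2

# A hypothetical learned factor: the off-diagonal entry introduces correlation.
L = [[1.0, 0.0], [0.6, 0.8]]
cov = matmul_transpose(L)  # approximately [[1.0, 0.6], [0.6, 1.0]]
rng = random.Random(1)
draws = [sample_mvn([0.0, 0.0], L, rng) for _ in range(20000)]
sample_cov_01 = sum(z0 * z1 for z0, z1 in draws) / len(draws)
print(cov, sample_cov_01)
```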
AutoDiagonalNormal

class AutoDiagonalNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1)
Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, ...)

By default the mean vector is initialized by init_loc_fn() and the scale is initialized to the identity times a small factor.

Parameters:
- model (callable) – A generative model.
- init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
- init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
AutoLowRankMultivariateNormal

class AutoLowRankMultivariateNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1, rank=None)
Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a low rank plus diagonal Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

    guide = AutoLowRankMultivariateNormal(model, rank=10)
    svi = SVI(model, guide, ...)

By default the cov_diag is initialized to a small constant and the cov_factor is initialized randomly such that on average cov_factor.matmul(cov_factor.t()) has the same scale as cov_diag.

Parameters:
- model (callable) – A generative model.
- rank (int or None) – The rank of the low-rank part of the covariance matrix. Defaults to approximately sqrt(latent dim).
- init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
- init_scale (float) – Approximate initial scale for the standard deviation of each (unconstrained transformed) latent variable.
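The covariance here has the form W W^T + diag(d), where W is the D x rank cov_factor and d is cov_diag. The sketch below assumes the entries of W are drawn with variance d/rank, which is one way to satisfy the "same scale on average" condition but not necessarily Pyro's exact scheme, and a hypothetical `default_rank` of round(sqrt(D)). It checks that the average diagonal of W W^T matches cov_diag.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def default_rank(latent_dim: int) -> int:
    """Hypothetical default: rank approximately sqrt(latent dim), at least 1."""
    return max(1, round(math.sqrt(latent_dim)))


def low_rank_cov(cov_factor: Matrix, cov_diag: List[float]) -> Matrix:
    """Covariance = W @ W^T + diag(d) for a D x rank factor W."""
    d = len(cov_diag)
    return [[sum(wi * wj for wi, wj in zip(cov_factor[i], cov_factor[j]))
             + (cov_diag[i] if i == j else 0.0)
             for j in range(d)] for i in range(d)]


def init_cov_factor(latent_dim: int, rank: int, diag_value: float,
                    rng: random.Random) -> Matrix:
    """Random W with entry variance diag_value / rank, so E[(W W^T)_ii] = diag_value."""
    std = math.sqrt(diag_value / rank)
    return [[rng.gauss(0.0, std) for _ in range(rank)] for _ in range(latent_dim)]


D = 100
rank = default_rank(D)   # 10
diag_value = 0.01        # hypothetical small constant for cov_diag
rng = random.Random(0)
W = init_cov_factor(D, rank, diag_value, rng)
cov = low_rank_cov(W, [diag_value] * D)
avg_low_rank_diag = sum(cov[i][i] - diag_value for i in range(D)) / D
print(rank, avg_low_rank_diag)  # average diagonal of W W^T is close to diag_value
```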
AutoNormalizingFlow

class AutoNormalizingFlow(model, init_transform_fn)
Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a Diagonal Normal distribution transformed via a sequence of bijective transforms (e.g. various TransformModule subclasses) to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

    from functools import partial
    from pyro.distributions.transforms import block_autoregressive, iterated

    # iterated(repeats, base_fn, *args) composes `repeats` transforms built by base_fn
    transform_init = partial(iterated, 2, block_autoregressive)
    guide = AutoNormalizingFlow(model, transform_init)
    svi = SVI(model, guide, ...)

Parameters:
- model (callable) – a generative model
- init_transform_fn (callable) – a callable which, when provided with the latent dimension, returns an instance of Transform, or of TransformModule if the transform has trainable params.
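A guide built from a sequence of bijections evaluates densities with the change-of-variables formula: the log-density of the output equals the base log-density of the inverted sample minus the summed log-Jacobian determinants. A one-dimensional standard-library sketch, in which plain affine bijections stand in for learnable TransformModule subclasses; all helper names are hypothetical.

```python
import math
from statistics import NormalDist
from typing import Callable, List, Tuple

# A 1-D bijection as (forward, inverse, log|dy/dx| evaluated at the input x).
Bijection = Tuple[Callable[[float], float], Callable[[float], float], Callable[[float], float]]


def affine(a: float, b: float) -> Bijection:
    """y = a * x + b with a != 0; its log|dy/dx| is the constant log|a|."""
    return (lambda x: a * x + b, lambda y: (y - b) / a, lambda x: math.log(abs(a)))


def flow_sample(x: float, transforms: List[Bijection]) -> float:
    """Push a base sample through the transforms in order."""
    for forward, _, _ in transforms:
        x = forward(x)
    return x


def flow_log_prob(y: float, transforms: List[Bijection]) -> float:
    """log q(y) = log N(x0; 0, 1) - sum_k log|dy_k/dx_k| (change of variables)."""
    xs = [y]
    for _, inverse, _ in reversed(transforms):
        xs.append(inverse(xs[-1]))
    xs.reverse()  # xs[0] is the base sample; xs[k] is the input to transforms[k]
    log_det = sum(t[2](xs[k]) for k, t in enumerate(transforms))
    return math.log(NormalDist(0.0, 1.0).pdf(xs[0])) - log_det


# Two stacked affine maps give y = 3 * (2 * x + 1) = 6 * x + 3, so q is Normal(3, 6).
flow = [affine(2.0, 1.0), affine(3.0, 0.0)]
y = 0.7
print(flow_log_prob(y, flow), math.log(NormalDist(3.0, 6.0).pdf(y)))
```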
AutoIAFNormal

class AutoIAFNormal(model, hidden_dim=None, init_loc_fn=None, num_transforms=1, **init_transform_kwargs)
Bases: pyro.infer.autoguide.guides.AutoNormalizingFlow

This implementation of AutoContinuous uses a Diagonal Normal distribution transformed via an AffineAutoregressive transform to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

    guide = AutoIAFNormal(model, hidden_dim=latent_dim)
    svi = SVI(model, guide, ...)

Parameters:
- model (callable) – a generative model
- hidden_dim (int) – number of hidden dimensions in the IAF
- init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
  Warning: This argument is only to preserve backwards compatibility and has no effect in practice.
- num_transforms (int) – number of AffineAutoregressive transforms to use in sequence.
- init_transform_kwargs – other keyword arguments taken by affine_autoregressive().
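An affine autoregressive transform shifts and scales each coordinate using only the coordinates before it, so its Jacobian is triangular and the log-determinant is the sum of the log-scales. A standard-library sketch in which a hypothetical hand-written conditioner stands in for the autoregressive neural network; it is not Pyro's AffineAutoregressive code.

```python
import math
from typing import Callable, List, Tuple

# Hypothetical conditioner: maps the prefix x[:i] to (mu_i, log_sigma_i).
# In an IAF this role is played by a masked neural network.
Conditioner = Callable[[List[float]], Tuple[float, float]]


def toy_conditioner(prefix: List[float]) -> Tuple[float, float]:
    s = sum(prefix)
    return 0.5 * s, 0.1 * s  # depends only on earlier coordinates


def iaf_forward(x: List[float], cond: Conditioner) -> Tuple[List[float], float]:
    """y_i = mu_i(x_<i) + sigma_i(x_<i) * x_i; returns (y, log|det J|).

    dy/dx is lower triangular with diagonal sigma_i, so
    log|det J| = sum_i log sigma_i. Every (mu_i, sigma_i) depends on x only,
    so this direction needs no recursion over y.
    """
    y: List[float] = []
    log_det = 0.0
    for i, xi in enumerate(x):
        mu, log_sigma = cond(x[:i])
        y.append(mu + math.exp(log_sigma) * xi)
        log_det += log_sigma
    return y, log_det


def iaf_inverse(y: List[float], cond: Conditioner) -> List[float]:
    """Recover x one coordinate at a time: x_i needs the already-recovered x_<i."""
    x: List[float] = []
    for yi in y:
        mu, log_sigma = cond(x)
        x.append((yi - mu) / math.exp(log_sigma))
    return x


x = [0.3, -1.0, 2.0]
y, log_det = iaf_forward(x, toy_conditioner)
print(y, log_det)
print(iaf_inverse(y, toy_conditioner))  # recovers x up to rounding
```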
AutoLaplaceApproximation

class AutoLaplaceApproximation(model, init_loc_fn=<function init_to_median>)
Bases: pyro.infer.autoguide.guides.AutoContinuous

Laplace approximation (quadratic approximation) approximates the posterior \(p(z | x)\) by a multivariate normal distribution in the unconstrained space. Under the hood, it uses Delta distributions to construct a MAP guide over the entire (unconstrained) latent space. Its covariance is given by the inverse of the Hessian of \(-\log p(x, z)\) at the MAP point of z.

Usage:

    delta_guide = AutoLaplaceApproximation(model)
    svi = SVI(model, delta_guide, ...)
    # ...then train the delta_guide...
    guide = delta_guide.laplace_approximation()

By default the mean vector is initialized to an empirical prior median.

Parameters:
- model (callable) – a generative model
- init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.

- laplace_approximation(*args, **kwargs)
  Returns an AutoMultivariateNormal instance whose posterior’s loc and scale_tril are given by the Laplace approximation.
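In one dimension the recipe reduces to two steps: find the MAP point of z, then take the inverse of the second derivative of -log p(x, z) there as the variance. A standard-library sketch on a hypothetical conjugate model (z ~ Normal(0, 1), x_i ~ Normal(z, 1)), whose exact posterior Normal(sum(x)/(n+1), 1/(n+1)) lets us check the result. Newton's method with finite differences stands in for the SVI training and Hessian computation described above.

```python
from typing import Callable, List, Tuple


def neg_log_joint(z: float, data: List[float]) -> float:
    """-log p(x, z) up to a constant for z ~ Normal(0, 1), x_i ~ Normal(z, 1)."""
    return 0.5 * z * z + 0.5 * sum((x - z) ** 2 for x in data)


def laplace_1d(f: Callable[[float], float], z0: float, h: float = 1e-4,
               steps: int = 50) -> Tuple[float, float]:
    """Newton's method for the MAP point, then variance = 1 / f''(z_map).

    Derivatives use central finite differences with step h.
    """
    z = z0
    for _ in range(steps):
        grad = (f(z + h) - f(z - h)) / (2 * h)
        hess = (f(z + h) - 2 * f(z) + f(z - h)) / (h * h)
        z -= grad / hess
    hess = (f(z + h) - 2 * f(z) + f(z - h)) / (h * h)
    return z, 1.0 / hess


data = [1.2, 0.8, 1.9, 1.1]
z_map, var = laplace_1d(lambda z: neg_log_joint(z, data), z0=0.0)
print(z_map, var)  # compare with sum(data) / 5 and 1 / 5
```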
AutoDiscreteParallel
Initialization

The pyro.infer.autoguide.initialization module contains initialization functions for automatic guides.

The standard interface for initialization is a function that takes a Pyro trace site dict and returns an appropriately sized value to serve as an initial constrained value for a guide estimate.
- init_to_feasible(site)
  Initialize to an arbitrary feasible point, ignoring distribution parameters.

- init_to_median(site, num_samples=15)
  Initialize to the prior median; fall back to a feasible point if the median is undefined.
- class InitMessenger(init_fn)
  Bases: pyro.poutine.messenger.Messenger
  Initializes a site by replacing .sample() calls with values drawn from an initialization strategy. This is mainly for internal use by autoguide classes.
  Parameters: init_fn (callable) – An initialization function.
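Following the interface described above (site dict in, initial constrained value out), init_to_median can be sketched as "median of a few prior draws, else a feasible point". This standard-library version is hypothetical: the `sample` and `support` keys stand in for the site's distribution, and the feasible point is assumed to be the image of 0 under the support's bijection (exp(0) = 1 for positive support, 0 otherwise).

```python
import math
import random
import statistics
from typing import Any, Dict

Site = Dict[str, Any]  # hypothetical stand-in for a Pyro trace site dict


def init_to_feasible_sketch(site: Site) -> float:
    """Map 0 in unconstrained space into the support: exp(0) = 1 if positive, else 0."""
    return math.exp(0.0) if site["support"] == "positive" else 0.0


def init_to_median_sketch(site: Site, num_samples: int = 15) -> float:
    """Median of num_samples prior draws; fall back to a feasible point if undefined."""
    try:
        samples = [site["sample"]() for _ in range(num_samples)]
        value = statistics.median(samples)
        if math.isnan(value) or math.isinf(value):
            raise ValueError("median undefined")
        return value
    except (ValueError, ArithmeticError):
        return init_to_feasible_sketch(site)


rng = random.Random(0)
normal_site = {"name": "loc", "support": "real", "sample": lambda: rng.gauss(3.0, 1.0)}
broken_site = {"name": "scale", "support": "positive", "sample": lambda: float("nan")}
print(init_to_median_sketch(normal_site))  # near the prior median 3.0
print(init_to_median_sketch(broken_site))  # 1.0 from the fallback
```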