# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, Optional, Tuple, Union
import torch
from torch.distributions import biject_to, constraints
import pyro.distributions as dist
from pyro.distributions.distribution import Distribution
from pyro.nn.module import PyroModule, PyroParam, pyro_method
from pyro.ops.tensor_utils import periodic_repeat
from pyro.poutine.guide import GuideMessenger
from pyro.poutine.runtime import get_plates
from .initialization import init_to_feasible, init_to_mean
from .utils import deep_getattr, deep_setattr, helpful_support_errors
class AutoMessengerMeta(type(GuideMessenger), type(PyroModule)):
pass
class AutoMessenger(GuideMessenger, PyroModule, metaclass=AutoMessengerMeta):
"""
Base class for :class:`~pyro.poutine.guide.GuideMessenger` autoguides.
:param callable model: A Pyro model.
    :param tuple amortized_plates: A tuple of names of plates over which guide
        parameters should be shared. This is useful for subsampling, where a
        guide parameter can be shared across all elements of a plate.
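
    For example, a guide for a model with a subsampled plate named "data"
    could share its parameters across that plate (a sketch; the plate name
    and model are placeholders)::

        guide = AutoNormalMessenger(model, amortized_plates=("data",))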
"""
def __init__(self, model: Callable, *, amortized_plates: Tuple[str, ...] = ()):
self.amortized_plates = amortized_plates
super().__init__(model)
@pyro_method
def __call__(self, *args, **kwargs):
# Since this guide creates parameters lazily, we need to avoid batching
# those parameters by a particle plate, in case the first time this
# guide is called is inside a particle plate. We assume all plates
# outside the model are particle plates.
self._outer_plates = tuple(f.name for f in get_plates())
try:
return super().__call__(*args, **kwargs)
finally:
del self._outer_plates
    def call(self, *args, **kwargs):
"""
Method that calls :meth:`forward` and returns parameter values of the
guide as a `tuple` instead of a `dict`, which is a requirement for
JIT tracing. Unlike :meth:`forward`, this method can be traced by
:func:`torch.jit.trace_module`.
.. warning::
This method may be removed once PyTorch JIT tracer starts accepting
`dict` as valid return types. See
            `issue <https://github.com/pytorch/pytorch/issues/27743>`_.
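
        For example, tracing might look roughly like this (a sketch; ``data``
        is a placeholder, and the guide should be called once beforehand so
        that its parameters exist)::

            guide(data)  # initialize parameters lazily
            traced = torch.jit.trace_module(guide, {"call": (data,)})
            values = traced.call(data)  # tuple of latent values, ordered by site name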
"""
result = self(*args, **kwargs)
return tuple(v for _, v in sorted(result.items()))
@torch.no_grad()
def _adjust_plates(self, value: torch.Tensor, event_dim: int) -> torch.Tensor:
"""
Adjusts plates for generating initial values of parameters.
"""
for f in get_plates():
full_size = f.full_size or f.size
dim = f.dim - event_dim
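            # For particle/amortized plates, share the parameter by averaging the
            # initial value over that dim; for subsampled plates, tile the value
            # up to the full plate size.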
if f in self._outer_plates or f.name in self.amortized_plates:
if -value.dim() <= dim:
value = value.mean(dim, keepdim=True)
elif f.size != full_size:
value = periodic_repeat(value, full_size, dim).contiguous()
for dim in range(value.dim() - event_dim):
value = value.squeeze(0)
return value
class AutoNormalMessenger(AutoMessenger):
"""
:class:`AutoMessenger` with mean-field normal posterior.
The mean-field posterior at any site is a transformed normal distribution.
This posterior is equivalent to :class:`~pyro.infer.autoguide.AutoNormal`
or :class:`~pyro.infer.autoguide.AutoDiagonalNormal`, but allows
customization via subclassing.
Derived classes may override the :meth:`get_posterior` behavior at
particular sites and use the mean-field normal behavior simply as a
    default, e.g.::

        def model(data):
a = pyro.sample("a", dist.Normal(0, 1))
b = pyro.sample("b", dist.Normal(0, 1))
c = pyro.sample("c", dist.Normal(a + b, 1))
pyro.sample("obs", dist.Normal(c, 1), obs=data)
class MyGuideMessenger(AutoNormalMessenger):
def get_posterior(self, name, prior):
if name == "c":
# Use a custom distribution at site c.
bias = pyro.param("c_bias", lambda: torch.zeros(()))
weight = pyro.param("c_weight", lambda: torch.ones(()),
constraint=constraints.positive)
scale = pyro.param("c_scale", lambda: torch.ones(()),
constraint=constraints.positive)
a = self.upstream_value("a")
b = self.upstream_value("b")
loc = bias + weight * (a + b)
return dist.Normal(loc, scale)
# Fall back to mean field.
                return super().get_posterior(name, prior)

    Note that above we manually computed ``loc = bias + weight * (a + b)``.
    Alternatively we could reuse the model-side computation by setting
    ``loc = bias + weight * prior.loc``::

        class MyGuideMessenger_v2(AutoNormalMessenger):
def get_posterior(self, name, prior):
if name == "c":
# Use a custom distribution at site c.
bias = pyro.param("c_bias", lambda: torch.zeros(()))
scale = pyro.param("c_scale", lambda: torch.ones(()),
constraint=constraints.positive)
weight = pyro.param("c_weight", lambda: torch.ones(()),
constraint=constraints.positive)
loc = bias + weight * prior.loc
return dist.Normal(loc, scale)
# Fall back to mean field.
return super().get_posterior(name, prior)
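
    As with other autoguides, this guide can be trained using ``SVI`` (a
    minimal sketch; the optimizer settings and ``data`` are placeholders)::

        guide = AutoNormalMessenger(model)
        svi = SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), Trace_ELBO())
        for step in range(1000):
            svi.step(data)
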
:param callable model: A Pyro model.
:param callable init_loc_fn: A per-site initialization function.
See :ref:`autoguide-initialization` section for available functions.
:param float init_scale: Initial scale for the standard deviation of each
(unconstrained transformed) latent variable.
    :param tuple amortized_plates: A tuple of names of plates over which guide
        parameters should be shared. This is useful for subsampling, where a
        guide parameter can be shared across all elements of a plate.
"""
def __init__(
self,
model: Callable,
*,
init_loc_fn: Callable = init_to_mean(fallback=init_to_feasible),
init_scale: float = 0.1,
amortized_plates: Tuple[str, ...] = (),
):
if not isinstance(init_scale, float) or not (init_scale > 0):
raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
super().__init__(model, amortized_plates=amortized_plates)
self.init_loc_fn = init_loc_fn
self._init_scale = init_scale
self._computing_median = False
    def get_posterior(
self, name: str, prior: Distribution
) -> Union[Distribution, torch.Tensor]:
if self._computing_median:
return self._get_posterior_median(name, prior)
with helpful_support_errors({"name": name, "fn": prior}):
transform = biject_to(prior.support)
loc, scale = self._get_params(name, prior)
posterior = dist.TransformedDistribution(
dist.Normal(loc, scale).to_event(transform.domain.event_dim),
transform.with_cache(),
)
return posterior
def _get_params(self, name: str, prior: Distribution):
try:
loc = deep_getattr(self.locs, name)
scale = deep_getattr(self.scales, name)
return loc, scale
except AttributeError:
pass
# Initialize.
with torch.no_grad():
transform = biject_to(prior.support)
event_dim = transform.domain.event_dim
constrained = self.init_loc_fn({"name": name, "fn": prior}).detach()
unconstrained = transform.inv(constrained)
init_loc = self._adjust_plates(unconstrained, event_dim)
init_scale = torch.full_like(init_loc, self._init_scale)
deep_setattr(self, "locs." + name, PyroParam(init_loc, event_dim=event_dim))
deep_setattr(
self,
"scales." + name,
PyroParam(init_scale, constraint=constraints.positive, event_dim=event_dim),
)
return self._get_params(name, prior)
def _get_posterior_median(self, name, prior):
transform = biject_to(prior.support)
loc, scale = self._get_params(name, prior)
return transform(loc)
class AutoHierarchicalNormalMessenger(AutoNormalMessenger):
"""
:class:`AutoMessenger` with mean-field normal posterior conditional on all dependencies.
The mean-field posterior at any site is a transformed normal distribution,
    the mean of which depends on the value of that site given its dependencies in the model::

        loc_total = loc + transform.inv(prior.mean) * weight

    where the value of ``prior.mean`` is conditional on upstream sites in the model,
    ``loc`` is the independent component of the mean in the untransformed space,
    and ``weight`` is an element-wise factor that scales the prior mean.

    This approach does not work for distributions that do not have a well-defined
    mean (e.g. Cauchy).

Derived classes may override particular sites and use this simply as a
default, see :class:`AutoNormalMessenger` documentation for example.
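
    For example, in the following hierarchical model the guide would learn a
    posterior for ``x`` whose mean tracks its prior mean ``z`` (a sketch; the
    site names and data are illustrative)::

        def model(data):
            z = pyro.sample("z", dist.Normal(0.0, 1.0))
            x = pyro.sample("x", dist.Normal(z, 1.0))
            pyro.sample("obs", dist.Normal(x, 1.0), obs=data)

        guide = AutoHierarchicalNormalMessenger(model, hierarchical_sites=["x"])
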
:param callable model: A Pyro model.
:param callable init_loc_fn: A per-site initialization function.
See :ref:`autoguide-initialization` section for available functions.
:param float init_scale: Initial scale for the standard deviation of each
(unconstrained transformed) latent variable.
:param float init_weight: Initial value for the weight of the contribution
of hierarchical sites to posterior mean for each latent variable.
    :param list hierarchical_sites: List of latent variables (model sites)
        that have hierarchical dependencies.
        If None, all sites are assumed to have hierarchical dependencies;
        in that case, for sites that have no upstream sites, the guide's
        ``loc`` and ``weight`` represent the learned deviation from the prior.
"""
# 'element-wise' or 'scalar'
weight_type = "element-wise"
def __init__(
self,
model: Callable,
*,
init_loc_fn: Callable = init_to_mean(fallback=init_to_feasible),
init_scale: float = 0.1,
amortized_plates: Tuple[str, ...] = (),
init_weight: float = 1.0,
hierarchical_sites: Optional[list] = None,
):
if not isinstance(init_scale, float) or not (init_scale > 0):
raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
super().__init__(model, amortized_plates=amortized_plates)
self.init_loc_fn = init_loc_fn
self._init_scale = init_scale
self._init_weight = init_weight
self._hierarchical_sites = hierarchical_sites
self._computing_median = False
    def get_posterior(
self, name: str, prior: Distribution
) -> Union[Distribution, torch.Tensor]:
if self._computing_median:
return self._get_posterior_median(name, prior)
with helpful_support_errors({"name": name, "fn": prior}):
transform = biject_to(prior.support)
            if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
                # If hierarchical_sites is not specified, all sites are assumed to be hierarchical.
loc, scale, weight = self._get_params(name, prior)
loc = loc + transform.inv(prior.mean) * weight
posterior = dist.TransformedDistribution(
dist.Normal(loc, scale).to_event(transform.domain.event_dim),
transform.with_cache(),
)
return posterior
else:
                # Fall back to mean field when hierarchical_sites is given and the site is not in it.
return super().get_posterior(name, prior)
def _get_params(self, name: str, prior: Distribution):
try:
loc = deep_getattr(self.locs, name)
scale = deep_getattr(self.scales, name)
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
weight = deep_getattr(self.weights, name)
return loc, scale, weight
else:
return loc, scale
except AttributeError:
pass
# Initialize.
with torch.no_grad():
transform = biject_to(prior.support)
event_dim = transform.domain.event_dim
constrained = self.init_loc_fn({"name": name, "fn": prior}).detach()
unconstrained = transform.inv(constrained)
init_loc = self._adjust_plates(unconstrained, event_dim)
init_scale = torch.full_like(init_loc, self._init_scale)
if self.weight_type == "scalar":
# weight is a single value parameter
init_weight = torch.full((), self._init_weight)
if self.weight_type == "element-wise":
# weight is element-wise
init_weight = torch.full_like(init_loc, self._init_weight)
            # If the site is hierarchical, subtract the contribution of dependencies from init_loc.
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
init_prior_mean = transform.inv(prior.mean)
init_prior_mean = self._adjust_plates(init_prior_mean, event_dim)
init_loc = init_loc - init_weight * init_prior_mean
deep_setattr(self, "locs." + name, PyroParam(init_loc, event_dim=event_dim))
deep_setattr(
self,
"scales." + name,
PyroParam(init_scale, constraint=constraints.positive, event_dim=event_dim),
)
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
if self.weight_type == "scalar":
# weight is a single value parameter
deep_setattr(
self,
"weights." + name,
PyroParam(init_weight, constraint=constraints.positive),
)
if self.weight_type == "element-wise":
# weight is element-wise
deep_setattr(
self,
"weights." + name,
PyroParam(
init_weight,
constraint=constraints.positive,
event_dim=event_dim,
),
)
return self._get_params(name, prior)
def _get_posterior_median(self, name, prior):
transform = biject_to(prior.support)
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
loc, scale, weight = self._get_params(name, prior)
loc = loc + transform.inv(prior.mean) * weight
else:
loc, scale = self._get_params(name, prior)
return transform(loc)
class AutoRegressiveMessenger(AutoMessenger):
"""
:class:`AutoMessenger` with recursively affine-transformed priors using
prior dependency structure.
The posterior at any site is a learned affine transform of the prior,
conditioned on upstream posterior samples. The affine transform operates in
unconstrained space. This supports only continuous latent variables.
Derived classes may override the :meth:`get_posterior` behavior at
particular sites and use the regressive behavior simply as a default,
    e.g.::

        class MyGuideMessenger(AutoRegressiveMessenger):
def get_posterior(self, name, prior):
if name == "x":
# Use a custom distribution at site x.
loc = pyro.param("x_loc", lambda: torch.zeros(prior.shape()))
scale = pyro.param("x_scale", lambda: torch.ones(prior.shape())),
constraint=constraints.positive
return dist.Normal(loc, scale).to_event(prior.event_dim())
# Fall back to autoregressive.
return super().get_posterior(name, prior)
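
    Schematically, the default posterior sample at a site is obtained from a
    prior sample as (pseudocode summarizing what :meth:`get_posterior`
    constructs; ``loc`` and ``scale`` are the learned parameters)::

        unconstrained = transform.inv(prior_sample)
        posterior_sample = transform(loc + scale * unconstrained)
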
.. warning:: This guide currently does not support jit-based elbos.
:param callable model: A Pyro model.
:param callable init_loc_fn: A per-site initialization function.
See :ref:`autoguide-initialization` section for available functions.
:param float init_scale: Initial scale for the standard deviation of each
(unconstrained transformed) latent variable.
    :param tuple amortized_plates: A tuple of names of plates over which guide
        parameters should be shared. This is useful for subsampling, where a
        guide parameter can be shared across all elements of a plate.
"""
def __init__(
self,
model: Callable,
*,
init_loc_fn: Callable = init_to_mean(fallback=init_to_feasible),
init_scale: float = 0.1,
amortized_plates: Tuple[str, ...] = (),
):
if not isinstance(init_scale, float) or not (init_scale > 0):
raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
super().__init__(model, amortized_plates=amortized_plates)
self.init_loc_fn = init_loc_fn
self._init_scale = init_scale
    def get_posterior(
self, name: str, prior: Distribution
) -> Union[Distribution, torch.Tensor]:
with helpful_support_errors({"name": name, "fn": prior}):
transform = biject_to(prior.support)
loc, scale = self._get_params(name, prior)
affine = dist.transforms.AffineTransform(
loc, scale, event_dim=transform.domain.event_dim, cache_size=1
)
posterior = dist.TransformedDistribution(
prior, [transform.inv.with_cache(), affine, transform.with_cache()]
)
return posterior
def _get_params(self, name: str, prior: Distribution):
try:
loc = deep_getattr(self.locs, name)
scale = deep_getattr(self.scales, name)
return loc, scale
except AttributeError:
pass
# Initialize.
with torch.no_grad():
transform = biject_to(prior.support)
event_dim = transform.domain.event_dim
constrained = self.init_loc_fn({"name": name, "fn": prior}).detach()
unconstrained = transform.inv(constrained)
# Initialize the distribution to be an affine combination:
# init_scale * prior + (1 - init_scale) * init_loc
init_loc = self._adjust_plates(unconstrained, event_dim)
init_loc = init_loc * (1 - self._init_scale)
init_scale = torch.full_like(init_loc, self._init_scale)
deep_setattr(self, "locs." + name, PyroParam(init_loc, event_dim=event_dim))
deep_setattr(
self,
"scales." + name,
PyroParam(init_scale, constraint=constraints.positive, event_dim=event_dim),
)
return self._get_params(name, prior)