Primitives
-
clear_param_store()
Clears the ParamStore. This is especially useful if you're working in a REPL.
-
param(name, *args, **kwargs)
Saves the variable as a parameter in the param store. To interact with the param store or write to disk, see Parameters.
Parameters: - name (str) – name of parameter
- init_tensor (torch.Tensor or callable) – initial tensor or lazy callable that returns a tensor. For large tensors, it may be cheaper to write e.g. lambda: torch.randn(100000), which is only evaluated when the parameter is first created.
- constraint (torch.distributions.constraints.Constraint) – torch constraint, defaults to constraints.real.
- event_dim (int) – (optional) number of rightmost dimensions unrelated to batching. Dimensions to the left of these are considered batch dimensions; if the param statement is inside a subsampled plate, then the corresponding batch dimensions of the parameter are subsampled accordingly. If unspecified, all dimensions are considered event dims and no subsampling is performed.
Returns: parameter
Return type: torch.Tensor
-
sample(name, fn, *args, **kwargs)
Calls the stochastic function fn with additional side-effects depending on name and the enclosing context (e.g. an inference algorithm). See Intro I and Intro II for a discussion.
Parameters: - name – name of sample
- fn – distribution class or function
- obs – observed datum (optional; should only be used in the context of inference), specified in kwargs
- infer (dict) – Optional dictionary of inference parameters specified in kwargs. See inference documentation for details.
Returns: sample
-
factor(name, log_factor)
Factor statement to add an arbitrary log probability factor to a probabilistic model.
Parameters: - name (str) – Name of the trivial sample
- log_factor (torch.Tensor) – A possibly batched log probability factor.
-
deterministic(name, value, event_dim=None)
EXPERIMENTAL Deterministic statement to add a Delta site with the given name and value to the trace. This is useful when we want to record values which are completely determined by their parents. For example:

    x = sample("x", dist.Normal(0, 1))
    x2 = deterministic("x2", x ** 2)
Note
The site does not affect the model density. This currently converts to a sample() statement, but may change in the future.
Parameters: - name (str) – Name of the site.
- value (torch.Tensor) – Value of the site.
- event_dim (int) – Optional event dimension, defaults to value.ndim.
-
subsample(data, event_dim)
EXPERIMENTAL Subsampling statement to subsample data based on enclosing plates.
This is typically called on arguments to model() when subsampling is performed automatically by plates by passing either the subsample or subsample_size kwarg. For example, the following are equivalent:

    # Version 1. using indexing
    def model(data):
        with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()) as ind:
            data = data[ind]
            # ...

    # Version 2. using pyro.subsample()
    def model(data):
        with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()):
            data = pyro.subsample(data, event_dim=0)
            # ...
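Parameters: - data (torch.Tensor) – A tensor of batched data.
- event_dim (int) – The event dimension of the data tensor. Dimensions to the left are considered batch dimensions.
Returns: A subsampled version of data
Return type: torch.Tensor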
-
class plate(name, size=None, subsample_size=None, subsample=None, dim=None, use_cuda=None, device=None)
Bases: pyro.poutine.plate_messenger.PlateMessenger
Construct for conditionally independent sequences of variables.
plate can be used either sequentially as a generator or in parallel as a context manager (formerly irange and iarange, respectively).
Sequential plate is similar to range() in that it generates a sequence of values.
Vectorized plate is similar to torch.arange() in that it yields an array of indices by which other tensors can be indexed. plate differs from torch.arange() in that it also informs inference algorithms that the variables being indexed are conditionally independent. To do this, plate is provided as a context manager rather than a function, and users must guarantee that all computation within a plate context is conditionally independent:

    with plate("name", size) as ind:
        ...  # do conditionally independent stuff with ind
Additionally, plate can take advantage of the conditional independence assumptions by subsampling the indices and informing inference algorithms to scale various computed values. This is typically used to subsample minibatches of data:

    with plate("data", len(data), subsample_size=100) as ind:
        batch = data[ind]
        assert len(batch) == 100
By default subsample_size=None and this simply yields torch.arange(0, size). If 0 < subsample_size <= size, this yields a single random batch of indices of size subsample_size and scales all log likelihood terms by size/subsample_size within this context.
Warning
This is only correct if all computation is conditionally independent within the context.
Parameters: - name (str) – A unique name to help inference algorithms match plate sites between models and guides.
- size (int) – Optional size of the collection being subsampled (like stop in builtin range).
- subsample_size (int) – Size of minibatches used in subsampling. Defaults to size.
- subsample (Anything supporting len().) – Optional custom subsample for user-defined subsampling schemes. If specified, then subsample_size will be set to len(subsample).
- dim (int) – An optional dimension to use for this independence index. If specified, dim should be negative, i.e. should index from the right. If not specified, dim is set to the rightmost dim that is left of all enclosing plate contexts.
- use_cuda (bool) – DEPRECATED, use the device arg instead. Optional bool specifying whether to use cuda tensors for subsample and log_prob. Defaults to torch.Tensor.is_cuda.
- device (str) – Optional keyword specifying which device to place the results of subsample and log_prob on. By default, results are placed on the same device as the default tensor.
Returns: A reusable context manager yielding a single 1-dimensional torch.Tensor of indices.
Examples:
>>> # This version declares sequential independence and subsamples data:
>>> for i in plate('data', 100, subsample_size=10):
...     if z[i]:  # Control flow in this example prevents vectorization.
...         obs = sample('obs_{}'.format(i), dist.Normal(loc, scale), obs=data[i])

>>> # This version declares vectorized independence:
>>> with plate('data'):
...     obs = sample('obs', dist.Normal(loc, scale), obs=data)

>>> # This version subsamples data in vectorized way:
>>> with plate('data', 100, subsample_size=10) as ind:
...     obs = sample('obs', dist.Normal(loc, scale), obs=data[ind])

>>> # This wraps a user-defined subsampling method for use in pyro:
>>> ind = torch.randint(0, 100, (10,)).long()  # custom subsample
>>> with plate('data', 100, subsample=ind):
...     obs = sample('obs', dist.Normal(loc, scale), obs=data[ind])

>>> # This reuses two different independence contexts.
>>> x_axis = plate('outer', 320, dim=-1)
>>> y_axis = plate('inner', 200, dim=-2)
>>> with x_axis:
...     x_noise = sample("x_noise", dist.Normal(loc, scale))
...     assert x_noise.shape == (320,)
>>> with y_axis:
...     y_noise = sample("y_noise", dist.Normal(loc, scale))
...     assert y_noise.shape == (200, 1)
>>> with x_axis, y_axis:
...     xy_noise = sample("xy_noise", dist.Normal(loc, scale))
...     assert xy_noise.shape == (200, 320)
See SVI Part II for an extended discussion.
-
class iarange(*args, **kwargs)
Bases: pyro.primitives.plate
-
plate_stack(prefix, sizes, rightmost_dim=-1)
Create a contiguous stack of plates with dimensions:

    rightmost_dim - len(sizes) + 1, ..., rightmost_dim
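Parameters: - prefix (str) – Name prefix for plates.
- sizes (iterable) – An iterable of plate sizes.
- rightmost_dim (int) – The rightmost dim, counting from the right.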
-
module(name, nn_module, update_module_params=False)
Takes a torch.nn.Module and registers its parameters with the ParamStore. In conjunction with the ParamStore save() and load() functionality, this allows the user to save and load modules.
Parameters: - name (str) – name of module
- nn_module (torch.nn.Module) – the module to be registered with Pyro
- update_module_params (bool) – determines whether Parameters in the PyTorch module get overridden with the values found in the ParamStore (if any). Defaults to False.
Returns: torch.nn.Module
-
random_module(name, nn_module, prior, *args, **kwargs)
Warning
The random_module primitive is deprecated, and will be removed in a future release. Use PyroModule instead to create Bayesian modules from torch.nn.Module instances. See the Bayesian Regression tutorial for an example.
Places a prior over the parameters of the module nn_module. Returns a distribution (callable) over nn.Modules, which upon calling returns a sampled nn.Module.
Parameters: - name (str) – name of pyro module
- nn_module (torch.nn.Module) – the module to be registered with pyro
- prior – pyro distribution, stochastic function, or python dict with parameter names as keys and respective distributions/stochastic functions as values.
Returns: a callable which returns a sampled module
-
enable_validation(is_validate=True)
Enable or disable validation checks in Pyro. Validation checks provide useful warnings and errors, e.g. NaN checks, validating distribution arguments and support values, etc., which are useful for debugging. Since some of these checks may be expensive, we recommend turning this off for mature models.
Parameters: is_validate (bool) – (optional; defaults to True) whether to enable validation checks.
-
validation_enabled(is_validate=True)
Context manager that is useful when temporarily enabling/disabling validation checks.
Parameters: is_validate (bool) – (optional; defaults to True) temporary validation check override.
-
trace(fn=None, ignore_warnings=False, jit_options=None)
Lazy replacement for torch.jit.trace() that works with Pyro functions that call pyro.param().
The actual compilation artifact is stored in the compiled attribute of the output. Call diagnostic methods on this attribute.
Example:

    import torch
    import pyro
    import pyro.ops.jit
    import pyro.distributions as dist
    from torch.distributions import constraints

    def model(x):
        scale = pyro.param("scale", torch.tensor(0.5), constraint=constraints.positive)
        return pyro.sample("y", dist.Normal(x, scale))

    @pyro.ops.jit.trace
    def model_log_prob_fn(x, y):
        cond_model = pyro.condition(model, data={"y": y})
        tr = pyro.poutine.trace(cond_model).get_trace(x)
        return tr.log_prob_sum()
Parameters: - fn (callable) – The function to be traced.
- ignore_warnings (bool) – Whether to ignore jit warnings.
- jit_options (dict) – Optional dict of options to pass to torch.jit.trace(), e.g. {"optimize": False}.