Parameters¶
Parameters in Pyro are basically thin wrappers around PyTorch tensors that carry unique names. As such, parameters are the primary stateful objects in Pyro. Users typically interact with parameters via the Pyro primitive pyro.param. Parameters play a central role in stochastic variational inference, where they are used to represent point estimates for the parameters in parameterized families of models and guides.
ParamStore¶
class ParamStoreDict[source]¶
Bases: object
Global store for parameters in Pyro. This is basically a key-value store. The typical user interacts with the ParamStore primarily through the primitive pyro.param.
See Intro Part II for further discussion and SVI Part I for some examples.
Some things to bear in mind when using parameters in Pyro:
- Parameters must be assigned unique names.
- The init_tensor argument to pyro.param is only used the first time that a given (named) parameter is registered with Pyro.
- For this reason, a user may need to use the clear() method if working in a REPL in order to get the desired behavior. This method can also be invoked with pyro.clear_param_store().
- The internal name of a parameter within a PyTorch nn.Module that has been registered with Pyro is prepended with the Pyro name of the module. Consequently, nothing prevents the user from having two different modules, each of which contains a parameter named weight. By contrast, a user can only have one top-level parameter named weight (outside of any module).
- Parameters can be saved to and loaded from disk using save and load.
setdefault(name, init_constrained_value, constraint=Real())[source]¶
Retrieve a constrained parameter value from the ParamStore if it exists; otherwise set it to the initial value. Note that this is a little fancier than dict.setdefault().

If the parameter already exists, init_constrained_value will be ignored. To avoid expensive creation of init_constrained_value, you can wrap it in a lambda that will only be evaluated if the parameter does not already exist:

param_store.setdefault("foo", lambda: (0.001 * torch.randn(1000, 1000)).exp(), constraint=constraints.positive)
Parameters: - name (str) – parameter name
- init_constrained_value (torch.Tensor or callable returning a torch.Tensor) – initial constrained value
- constraint (torch.distributions.constraints.Constraint) – torch constraint object
Returns: constrained parameter value
Return type: torch.Tensor
named_parameters()[source]¶
Returns an iterator over (name, unconstrained_value) tuples for each parameter in the ParamStore.
get_param(name, init_tensor=None, constraint=Real(), event_dim=None)[source]¶
Get parameter from its name. If it does not yet exist in the ParamStore, it will be created and stored. The Pyro primitive pyro.param dispatches to this method.
Parameters: - name (str) – parameter name
- init_tensor (torch.Tensor) – initial tensor
- constraint (torch.distributions.constraints.Constraint) – torch constraint
- event_dim (int) – (ignored)
Returns: parameter
Return type: torch.Tensor
match(name)[source]¶
Get all parameters that match the regex. The parameters must already exist.
Parameters: name (str) – regular expression
Returns: dict with parameter names as keys and torch Tensors as values
param_name(p)[source]¶
Get parameter name from parameter.
Parameters: p – parameter
Returns: parameter name
load(filename, map_location=None)[source]¶
Loads parameters from disk.

Note

If using pyro.module() on parameters loaded from disk, be sure to set the update_module_params flag:

pyro.get_param_store().load('saved_params.save')
pyro.module('module', nn, update_module_params=True)
Parameters: - filename (str) – file name to load from
- map_location (function, torch.device, string or a dict) – specifies how to remap storage locations