Skip to content

Commit

Permalink
add constraint kwarg to fake param statements
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 committed May 7, 2024
1 parent 3492301 commit f8cfb35
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions pyro/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,12 @@ def __getattr__(self, name: str) -> Any:
constrained_value.unconstrained = weakref.ref(unconstrained_value)
return pyro.poutine.runtime.effectful(type="param")(
lambda *_, **__: constrained_value
)(fullname, event_dim=event_dim, name=fullname)
)(
fullname,
constraint=constraint,
event_dim=event_dim,
name=fullname,
)
else: # Cannot determine supermodule and hence cannot compute fullname.
constrained_value = transform_to(constraint)(unconstrained_value)
constrained_value.unconstrained = weakref.ref(unconstrained_value)
Expand Down Expand Up @@ -621,7 +626,7 @@ def __getattr__(self, name: str) -> Any:
# even though we don't use the contents of the local parameter store
fullname = self._pyro_get_fullname(name)
pyro.poutine.runtime.effectful(type="param")(lambda *_, **__: result)(
fullname, result, name=fullname
fullname, result, constraint=constraints.real, name=fullname
)

if isinstance(result, torch.nn.Module):
Expand All @@ -645,7 +650,12 @@ def __getattr__(self, name: str) -> Any:
)
pyro.poutine.runtime.effectful(type="param")(
lambda *_, **__: param_value
)(fullname_param, param_value, name=fullname_param)
)(
fullname_param,
param_value,
constraint=constraints.real,
name=fullname_param,
)

return result

Expand Down

0 comments on commit f8cfb35

Please sign in to comment.