Skip to content

Commit

Permalink
module
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy committed Oct 22, 2023
1 parent 27969dc commit 559ffce
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 28 deletions.
5 changes: 1 addition & 4 deletions pyro/contrib/autoname/autoname.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,11 @@ def _pyro_genname(msg):
msg["stop"] = True


@_make_handler(AutonameMessenger)
@_make_handler(AutonameMessenger, __name__)
def autoname(fn=None, name=None):
...


autoname.__module__ = __name__


@singledispatch
def sample(*args):
raise NotImplementedError
Expand Down
42 changes: 19 additions & 23 deletions pyro/contrib/funsor/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,54 +17,50 @@

from .enum_messenger import EnumMessenger, queue # noqa: F401
from .named_messenger import MarkovMessenger, NamedMessenger
from .plate_messenger import VectorizedMarkovMessenger
from .plate_messenger import PlateMessenger, VectorizedMarkovMessenger
from .replay_messenger import ReplayMessenger
from .trace_messenger import TraceMessenger


@_make_handler(EnumMessenger)
@_make_handler(EnumMessenger, __name__)
def enum(fn=None, first_available_dim=None):
...


enum.__module__ = __name__


@_make_handler(MarkovMessenger)
@_make_handler(MarkovMessenger, __name__)
def markov(fn=None, history=1, keep=False):
...


markov.__module__ = __name__


@_make_handler(NamedMessenger)
@_make_handler(NamedMessenger, __name__)
def named(fn=None, first_available_dim=None):
...


named.__module__ = __name__
@_make_handler(PlateMessenger, __name__)
def plate(
fn=None,
name=None,
size=None,
subsample_size=None,
subsample=None,
dim=None,
use_cuda=None,
device=None,
):
...


@_make_handler(ReplayMessenger)
@_make_handler(ReplayMessenger, __name__)
def replay(fn=None, trace=None, params=None):
...


replay.__module__ = __name__


@_make_handler(TraceMessenger)
@_make_handler(TraceMessenger, __name__)
def trace(fn=None, graph_type=None, param_only=None, pack_online=True):
...


trace.__module__ = __name__


@_make_handler(VectorizedMarkovMessenger)
@_make_handler(VectorizedMarkovMessenger, __name__)
def vectorized_markov(fn=None, name=None, size=None, dim=None, history=1):
...


vectorized_markov.__module__ = __name__
4 changes: 3 additions & 1 deletion pyro/poutine/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
############################################


def _make_handler(msngr_cls):
def _make_handler(msngr_cls, module=None):
def handler_decorator(func):
def handler(fn=None, *args, **kwargs):
if fn is not None and not (
Expand All @@ -102,6 +102,8 @@ def handler(fn=None, *args, **kwargs):
+ (msngr_cls.__doc__ if msngr_cls.__doc__ else "")
)
handler.__name__ = func.__name__
if module is not None:
handler.__module__ = module
return handler

return handler_decorator
Expand Down

0 comments on commit 559ffce

Please sign in to comment.