Skip to content

Commit

Permalink
Clean up handling of global settings (#3152)
Browse files Browse the repository at this point in the history
* Add a global settings module

* Fancy docs stuff

* Add cholesky_relative_jitter setting

* lint

* Add explicit stage mark
  • Loading branch information
fritzo authored Nov 1, 2022
1 parent ed54fe8 commit 891880f
Show file tree
Hide file tree
Showing 10 changed files with 264 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Pyro Documentation
optimization
poutine
ops
settings
testing

.. toctree::
Expand Down
6 changes: 6 additions & 0 deletions docs/source/settings.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Settings
--------

.. automodule:: pyro.settings
:members:
:member-order: bysource
3 changes: 3 additions & 0 deletions pyro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
)
from pyro.util import set_rng_seed

from . import settings

# After changing this, run scripts/update_version.py
version_prefix = "1.8.2"

Expand Down Expand Up @@ -58,6 +60,7 @@
"render_model",
"sample",
"set_rng_seed",
"settings",
"subsample",
"validation_enabled",
]
17 changes: 17 additions & 0 deletions pyro/distributions/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pyro.distributions.util import broadcast_shape, sum_rightmost
from pyro.ops.special import log_binomial

from .. import settings
from . import constraints


Expand Down Expand Up @@ -98,6 +99,22 @@ def log_prob(self, value):
)


@settings.register(
"binomial_approx_sample_thresh", __name__, "Binomial.approx_sample_thresh"
)
def _validate_thresh(thresh):
assert isinstance(thresh, float)
assert 0 < thresh


@settings.register(
"binomial_approx_log_prob_tol", __name__, "Binomial.approx_log_prob_tol"
)
def _validate_tol(tol):
assert isinstance(tol, float)
assert 0 <= tol


# This overloads .log_prob() and .enumerate_support() to speed up evaluating
# log_prob on the support of this variable: we can completely avoid tensor ops
# and merely reshape the self.logits tensor. This is especially important for
Expand Down
9 changes: 9 additions & 0 deletions pyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,18 @@

from pyro.util import ignore_jit_warnings

from .. import settings

_VALIDATION_ENABLED = __debug__
torch_dist.Distribution.set_default_validate_args(__debug__)

settings.register("validate_distributions_pyro", __name__, "_VALIDATION_ENABLED")
settings.register(
"validate_distributions_torch",
"torch.distributions.distribution",
"Distribution._validate_args",
)

log_sum_exp = logsumexp # DEPRECATED


Expand Down
4 changes: 4 additions & 0 deletions pyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
from pyro.ops.rings import MarginalRing
from pyro.poutine.util import site_is_subsample

from .. import settings

_VALIDATION_ENABLED = __debug__
settings.register("validate_infer", __name__, "_VALIDATION_ENABLED")

LAST_CACHE_SIZE = [Counter()] # for profiling


Expand Down
8 changes: 8 additions & 0 deletions pyro/ops/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@
import torch
from torch.fft import irfft, rfft

from .. import settings

_ROOT_TWO_INVERSE = 1.0 / math.sqrt(2.0)
CHOLESKY_RELATIVE_JITTER = 4.0 # in units of finfo.eps


@settings.register("cholesky_relative_jitter", __name__, "CHOLESKY_RELATIVE_JITTER")
def _validate_jitter(value):
assert isinstance(value, (float, int))
assert 0 <= value


def as_complex(x):
"""
Similar to :func:`torch.view_as_complex` but copies data in case strides
Expand Down
3 changes: 3 additions & 0 deletions pyro/poutine/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from .. import settings

_VALIDATION_ENABLED = __debug__
settings.register("validate_poutine", __name__, "_VALIDATION_ENABLED")


def enable_validation(is_validate):
Expand Down
163 changes: 163 additions & 0 deletions pyro/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
Example usage::
# Simple getting and setting.
print(pyro.settings.get()) # print all settings
print(pyro.settings.get("cholesky_relative_jitter")) # print one
pyro.settings.set(cholesky_relative_jitter=0.5) # set one
pyro.settings.set(**my_settings) # set many
# Use as a contextmanager.
with pyro.settings.context(cholesky_relative_jitter=0.5):
my_function()
# Use as a decorator.
fn = pyro.settings.context(cholesky_relative_jitter=0.5)(my_function)
fn()
# Register a new setting.
pyro.settings.register(
"binomial_approx_sample_thresh", # alias
"pyro.distributions.torch", # module
"Binomial.approx_sample_thresh", # deep name
)
# Register a new setting on a user-provided validator.
@pyro.settings.register(
"binomial_approx_sample_thresh", # alias
"pyro.distributions.torch", # module
"Binomial.approx_sample_thresh", # deep name
)
def validate_thresh(thresh): # called each time setting is set
assert isinstance(thresh, float)
assert thresh > 0
Default Settings
----------------
{defaults}
Settings Interface
------------------
"""

# This library must have no dependencies on other pyro modules.
import functools
from contextlib import contextmanager
from importlib import import_module
from typing import Any, Callable, Dict, Iterator, Optional, Tuple

# Docs are updated by register().
_doc_template = __doc__

# Global registry mapping alias:str to (modulename, deepname, validator)
# triples where deepname may have dots to indicate e.g. class variables.
_REGISTRY: Dict[str, Tuple[str, str, Optional[Callable]]] = {}


def get(alias: Optional[str] = None) -> Any:
"""
Gets one or all global settings.
:param str alias: The name of a registered setting.
:returns: The currently set value.
"""
if alias is None:
# Return dict of all settings.
return {alias: get(alias) for alias in sorted(_REGISTRY)}
# Get a single setting.
module, deepname, validator = _REGISTRY[alias]
value = import_module(module)
for name in deepname.split("."):
value = getattr(value, name)
return value


def set(**kwargs) -> None:
r"""
Sets one or more settings.
:param \*\*kwargs: alias=value pairs.
"""
for alias, value in kwargs.items():
module, deepname, validator = _REGISTRY[alias]
if validator is not None:
validator(value)
destin = import_module(module)
names = deepname.split(".")
for name in names[:-1]:
destin = getattr(destin, name)
setattr(destin, names[-1], value)


@contextmanager
def context(**kwargs) -> Iterator[None]:
r"""
Context manager to temporarily override one or more settings. This also
works as a decorator.
:param \*\*kwargs: alias=value pairs.
"""
old = {alias: get(alias) for alias in kwargs}
try:
set(**kwargs)
yield
finally:
set(**old)


def register(
alias: str,
modulename: str,
deepname: str,
validator: Optional[Callable] = None,
) -> Callable:
"""
Register a global settings.
This should be declared in the module where the setting is defined.
This can be used either as a declaration::
settings.register("my_setting", __name__, "MY_SETTING")
or as a decorator on a user-defined validator function::
@settings.register("my_setting", __name__, "MY_SETTING")
def _validate_my_setting(value):
assert isinstance(value, float)
assert 0 < value
:param str alias: A valid python identifier serving as a settings alias.
Lower snake case preferred, e.g. ``my_setting``.
:param str modulename: The module name where the setting is declared,
typically ``__name__``.
:param str deepname: A ``.``-separated string of names. E.g. for a module
constant, use ``MY_CONSTANT``. For a class attributue, use
``MyClass.my_attribute``.
:param callable validator: Optional validator that inputs a value,
possibly raises validation errors, and returns None.
"""
global __doc__
assert isinstance(alias, str)
assert alias.isidentifier()
assert isinstance(modulename, str)
assert isinstance(deepname, str)
_REGISTRY[alias] = modulename, deepname, validator

# Add default value to module docstring.
__doc__ = _doc_template.format(
defaults="\n".join(f"- {a} = {get(a)}" for a in sorted(_REGISTRY))
)

# Support use as a decorator on an optional user-provided validator.
if validator is None:
# Return a decorator, but its fine if user discards this.
return functools.partial(register, alias, modulename, deepname)
else:
# Test current value passes validation.
validator(get(alias))
return validator
50 changes: 50 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import pytest

from pyro import settings

_TEST_SETTING: float = 0.1

pytestmark = pytest.mark.stage("unit")


def test_settings():
v0 = settings.get()
assert isinstance(v0, dict)
assert all(isinstance(alias, str) for alias in v0)
assert settings.get("validate_distributions_pyro") is True
assert settings.get("validate_distributions_torch") is True
assert settings.get("validate_poutine") is True
assert settings.get("validate_infer") is True


def test_register():
with pytest.raises(KeyError):
settings.get("test_setting")

@settings.register("test_setting", "tests.test_settings", "_TEST_SETTING")
def _validate(value):
assert isinstance(value, float)
assert 0 < value

# Test simple get and set.
assert settings.get("test_setting") == 0.1
settings.set(test_setting=0.2)
assert settings.get("test_setting") == 0.2
with pytest.raises(AssertionError):
settings.set(test_setting=-0.1)

# Test context manager.
with settings.context(test_setting=0.3):
assert settings.get("test_setting") == 0.3
assert settings.get("test_setting") == 0.2

# Test decorator.
@settings.context(test_setting=0.4)
def fn():
assert settings.get("test_setting") == 0.4

fn()
assert settings.get("test_setting") == 0.2

0 comments on commit 891880f

Please sign in to comment.