From 75d0da41a5975f634c834ac5bfbf964ceefb7445 Mon Sep 17 00:00:00 2001 From: ecotner <2.71828cotner@gmail.com> Date: Mon, 27 Apr 2020 18:13:21 -0400 Subject: [PATCH 1/9] added RandomVariable container class around Distribution --- pyro/distributions/random_variable.py | 102 ++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 pyro/distributions/random_variable.py diff --git a/pyro/distributions/random_variable.py b/pyro/distributions/random_variable.py new file mode 100644 index 0000000000..91f6c73e51 --- /dev/null +++ b/pyro/distributions/random_variable.py @@ -0,0 +1,102 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union, Callable + +from torch import Tensor +from pyro.distributions import TransformedDistribution +from pyro.distributions.transforms import ( + Transform, + AffineTransform, + AbsTransform, + PowerTransform +) +from pyro.distributions import Distribution + + +class RVArithmeticMixin: + """Mixin class for overloading arithmetic operations on random variables + """ + + def __add__(self, x: Union[float, Tensor]): + return RandomVariable( + TransformedDistribution(self.distribution, AffineTransform(x, 1)) + ) + + def __radd__(self, x: Union[float, Tensor]): + return RandomVariable( + TransformedDistribution(self.distribution, AffineTransform(x, 1)) + ) + + def __sub__(self, x: Union[float, Tensor]): + return RandomVariable( + TransformedDistribution(self.distribution, AffineTransform(-x, 1)) + ) + + def __rsub__(self, x: Union[float, Tensor]): + return RandomVariable( + TransformedDistribution(self.distribution, AffineTransform(x, -1)) + ) + + def __mul__(self, x: Union[float, Tensor]): + return RandomVariable( + TransformedDistribution(self.distribution, AffineTransform(0, x)) + ) + + def __rmul__(self, x: Union[float, Tensor]): + return RandomVariable( + TransformedDistribution(self.distribution, AffineTransform(0, x)) + ) + + def __truediv__(self, x: Union[float, Tensor]): + return RandomVariable( + TransformedDistribution(self.distribution, AffineTransform(0, 1/x)) + ) + + def __neg__(self): + return RandomVariable( + TransformedDistribution(self.distribution, AffineTransform(0, -1)) + ) + + def __abs__(self): + return RandomVariable( + TransformedDistribution(self.distribution, AbsTransform()) + ) + + def __pow__(self, x): + return RandomVariable( + TransformedDistribution(self.distribution, PowerTransform(x)) + ) + + +class RandomVariable(RVArithmeticMixin): + """Random variable container class around a distribution + + Representation of a distribution interpreted as a random variable. Rather + than directly manipulating a probability density by applying pointwise + transformations to it, this allows for simple arithmetic transformations of + the random variable the distribution represents. For more flexibility, + consider using the `transform` method. Note that if you + perform a non-invertible transform (like `abs(X)` or `X**2`), certain + things might not work properly. + """ + + def __init__(self, distribution: Distribution): + self.distribution = distribution + + def __getattr__(self, name): + return self.distribution.__getattribute__(name) + + def __call__(self, *args, **kwargs): + return self.distribution(*args, **kwargs) + + def tranform(self, t: Transform): + """Performs a transformation on the distribution underlying the RV. + + :param t: The transformation (or sequence of transformations) to be + applied to the distribution. There are many examples to be found in + `torch.distributions.transforms` and `pyro.distributions.transforms`, + or you can subclass directly from `Transform`. + :type t: `Transform` + """ + self.distribution = TransformedDistribution(self.distribution, t) From 3a15bc164f40f87e2ffddc14d2b52fe31bd2f147 Mon Sep 17 00:00:00 2001 From: ecotner <2.71828cotner@gmail.com> Date: Mon, 27 Apr 2020 18:13:40 -0400 Subject: [PATCH 2/9] added tests for RandomVariable arithmetic --- tests/distributions/test_random_variable.py | 70 +++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 tests/distributions/test_random_variable.py diff --git a/tests/distributions/test_random_variable.py b/tests/distributions/test_random_variable.py new file mode 100644 index 0000000000..1d34c8faaf --- /dev/null +++ b/tests/distributions/test_random_variable.py @@ -0,0 +1,70 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +import torch.tensor as tt +import pyro +from pyro.distributions.random_variable import RandomVariable as RV +from pyro.distributions import Uniform + +N_SAMPLES = 100 + + +def test_add(): + X = RV(Uniform(0, 1)) # (0, 1) + X = X + 1 # (1, 2) + X = 1 + X # (2, 3) + X += 1 # (3, 4) + x = X.sample([N_SAMPLES]) + assert ((3 <= x) & (x <= 4)).all().item() + + +def test_subtract(): + X = RV(Uniform(0, 1)) # (0, 1) + X = 1 - X # (0, 1) + X = X - 1 # (-1, 0) + X -= 1 # (-2, -1) + x = X.sample([N_SAMPLES]) + assert ((-2 <= x) & (x <= -1)).all().item() + + +def test_multiply_divide(): + X = RV(Uniform(0, 1)) # (0, 1) + X *= 4 # (0, 4) + X /= 2 # (0, 2) + x = X.sample([N_SAMPLES]) + assert ((0 <= x) & (x <= 2)).all().item() + + +def test_abs(): + X = RV(Uniform(0, 1)) # (0, 1) + X = 2*(X - 0.5) # (-1, 1) + X = abs(X) # (0, 1) + x = X.sample([N_SAMPLES]) + assert ((0 <= x) & (x <= 1)).all().item() + + +def test_neg(): + X = RV(Uniform(0, 1)) # (0, 1) + X = -X # (-1, 0) + x = X.sample([N_SAMPLES]) + assert ((-1 <= x) & (x <= 0)).all().item() + + +def test_pow(): + X = RV(Uniform(-1, 1)) # (-1, 1) + X = X**2 # (0, 1) + x = X.sample([N_SAMPLES]) + assert ((0 <= x) & (x <= 1)).all().item() + + +def test_tensor_ops(): + pi = 3.141592654 + X = RV(Uniform(0, 1).expand([5, 5])) + a = tt([[1, 2, 3, 4, 5]]) + b = a.T + X = abs(pi*(-X + a - 3*b)) + x = X.sample() + assert x.shape == (5, 5) + assert (x >= 0).all().item() From b7eee8746582eb24942aeb8834cddd9050b419f4 Mon Sep 17 00:00:00 2001 From: ecotner <2.71828cotner@gmail.com> Date: Mon, 27 Apr 2020 18:19:12 -0400 Subject: [PATCH 3/9] changed transform method to return new RandomVariable rather than modifying distribution in-place --- pyro/distributions/random_variable.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyro/distributions/random_variable.py b/pyro/distributions/random_variable.py index 91f6c73e51..675e8780f1 100644 --- a/pyro/distributions/random_variable.py +++ b/pyro/distributions/random_variable.py @@ -99,4 +99,5 @@ def tranform(self, t: Transform): or you can subclass directly from `Transform`. :type t: `Transform` """ - self.distribution = TransformedDistribution(self.distribution, t) + dist = TransformedDistribution(self.distribution, t) + return RandomVariable(dist) From 3424cab11ad4f67371ad5b3b508adb2a2ca1e68c Mon Sep 17 00:00:00 2001 From: ecotner <2.71828cotner@gmail.com> Date: Wed, 29 Apr 2020 01:46:45 -0400 Subject: [PATCH 4/9] moved RandomVariable to pyro.contrib; added method chaining --- pyro/contrib/randomvariable/__init__.py | 1 + .../contrib/randomvariable/random_variable.py | 126 ++++++++++++++++++ pyro/distributions/distribution.py | 5 + pyro/distributions/random_variable.py | 103 -------------- .../randomvariable}/test_random_variable.py | 32 +++-- 5 files changed, 156 insertions(+), 111 deletions(-) create mode 100644 pyro/contrib/randomvariable/__init__.py create mode 100644 pyro/contrib/randomvariable/random_variable.py delete mode 100644 pyro/distributions/random_variable.py rename tests/{distributions => contrib/randomvariable}/test_random_variable.py (66%) diff --git a/pyro/contrib/randomvariable/__init__.py b/pyro/contrib/randomvariable/__init__.py new file mode 100644 index 0000000000..83a9448362 --- /dev/null +++ b/pyro/contrib/randomvariable/__init__.py @@ -0,0 +1 @@ +from pyro.contrib.randomvariable.random_variable import RandomVariable diff --git a/pyro/contrib/randomvariable/random_variable.py b/pyro/contrib/randomvariable/random_variable.py new file mode 100644 index 0000000000..ab491b486f --- /dev/null +++ b/pyro/contrib/randomvariable/random_variable.py @@ -0,0 +1,126 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union, Callable + +from torch import Tensor +from pyro.distributions import TransformedDistribution +from pyro.distributions.transforms import ( + Transform, + AffineTransform, + AbsTransform, + PowerTransform, + ExpTransform, + TanhTransform, + SoftmaxTransform, + SigmoidTransform +) + + +class RVMagicOps: + """Mixin class for overloading __magic__ operations on random variables + """ + + def __add__(self, x: Union[float, Tensor]): + return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(x, 1))) + + def __radd__(self, x: Union[float, Tensor]): + return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(x, 1))) + + def __sub__(self, x: Union[float, Tensor]): + return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(-x, 1))) + + def __rsub__(self, x: Union[float, Tensor]): + return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(x, -1))) + + def __mul__(self, x: Union[float, Tensor]): + return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(0, x))) + + def __rmul__(self, x: Union[float, Tensor]): + return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(0, x))) + + def __truediv__(self, x: Union[float, Tensor]): + return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(0, 1/x))) + + def __neg__(self): + return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(0, -1))) + + def __abs__(self): + return RandomVariable(TransformedDistribution(self.distribution, AbsTransform())) + + def __pow__(self, x): + return RandomVariable(TransformedDistribution(self.distribution, PowerTransform(x))) + + +class RVChainOps: + """Mixin class for performing common unary/binary operations on/between + random variables/constant tensors + """ + + def add(self, x): + return self + x + + def sub(self, x): + return self - x + + def mul(self, x): + return self * x + + def div(self, x): + return self / x + + def abs(self): + return abs(self) + + def pow(self, x): + return self ** x + + def neg(self, x): + return -self + + def exp(self): + return self.transform(ExpTransform()) + + def log(self): + return self.transform(ExpTransform().inv) + + def sigmoid(self): + return self.transform(SigmoidTransform()) + + def tanh(self): + return self.transform(TanhTransform()) + + def softmax(self): + return self.transform(SoftmaxTransform()) + + +class RandomVariable(RVMagicOps, RVChainOps): + """Random variable container class around a distribution + + Representation of a distribution interpreted as a random variable. Rather + than directly manipulating a probability density by applying pointwise + transformations to it, this allows for simple arithmetic transformations of + the random variable the distribution represents. For more flexibility, + consider using the `transform` method. Note that if you perform a + non-invertible transform (like `abs(X)` or `X**2`), certain things might + not work properly. + """ + + def __init__(self, distribution): + self.distribution = distribution + + def transform(self, t: Transform): + """Performs a transformation on the distribution underlying the RV. + + :param t: The transformation (or sequence of transformations) to be + applied to the distribution. There are many examples to be found in + `torch.distributions.transforms` and `pyro.distributions.transforms`, + or you can subclass directly from `Transform`. + :type t: `Transform` + """ + dist = TransformedDistribution(self.distribution, t) + return RandomVariable(dist) + + @property + def dist(self): + return self.distribution diff --git a/pyro/distributions/distribution.py b/pyro/distributions/distribution.py index 40f0a9bf61..ed696801ae 100644 --- a/pyro/distributions/distribution.py +++ b/pyro/distributions/distribution.py @@ -4,6 +4,7 @@ from abc import ABCMeta, abstractmethod from pyro.distributions.score_parts import ScoreParts +# from pyro.contrib.randomvariable import RandomVariable as RV class Distribution(object, metaclass=ABCMeta): @@ -166,3 +167,7 @@ def has_rsample_(self, value): raise ValueError("Expected value in {False,True}, actual {}".format(value)) self.has_rsample = value return self + + # @property + # def rv(self): + # return RV(self) diff --git a/pyro/distributions/random_variable.py b/pyro/distributions/random_variable.py deleted file mode 100644 index 675e8780f1..0000000000 --- a/pyro/distributions/random_variable.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright Contributors to the Pyro project. -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union, Callable - -from torch import Tensor -from pyro.distributions import TransformedDistribution -from pyro.distributions.transforms import ( - Transform, - AffineTransform, - AbsTransform, - PowerTransform -) -from pyro.distributions import Distribution - - -class RVArithmeticMixin: - """Mixin class for overloading arithmetic operations on random variables - """ - - def __add__(self, x: Union[float, Tensor]): - return RandomVariable( - TransformedDistribution(self.distribution, AffineTransform(x, 1)) - ) - - def __radd__(self, x: Union[float, Tensor]): - return RandomVariable( - TransformedDistribution(self.distribution, AffineTransform(x, 1)) - ) - - def __sub__(self, x: Union[float, Tensor]): - return RandomVariable( - TransformedDistribution(self.distribution, AffineTransform(-x, 1)) - ) - - def __rsub__(self, x: Union[float, Tensor]): - return RandomVariable( - TransformedDistribution(self.distribution, AffineTransform(x, -1)) - ) - - def __mul__(self, x: Union[float, Tensor]): - return RandomVariable( - TransformedDistribution(self.distribution, AffineTransform(0, x)) - ) - - def __rmul__(self, x: Union[float, Tensor]): - return RandomVariable( - TransformedDistribution(self.distribution, AffineTransform(0, x)) - ) - - def __truediv__(self, x: Union[float, Tensor]): - return RandomVariable( - TransformedDistribution(self.distribution, AffineTransform(0, 1/x)) - ) - - def __neg__(self): - return RandomVariable( - TransformedDistribution(self.distribution, AffineTransform(0, -1)) - ) - - def __abs__(self): - return RandomVariable( - TransformedDistribution(self.distribution, AbsTransform()) - ) - - def __pow__(self, x): - return RandomVariable( - TransformedDistribution(self.distribution, PowerTransform(x)) - ) - - -class RandomVariable(RVArithmeticMixin): - """Random variable container class around a distribution - - Representation of a distribution interpreted as a random variable. Rather - than directly manipulating a probability density by applying pointwise - transformations to it, this allows for simple arithmetic transformations of - the random variable the distribution represents. For more flexibility, - consider using the `transform` method. Note that if you - perform a non-invertible transform (like `abs(X)` or `X**2`), certain - things might not work properly. - """ - - def __init__(self, distribution: Distribution): - self.distribution = distribution - - def __getattr__(self, name): - return self.distribution.__getattribute__(name) - - def __call__(self, *args, **kwargs): - return self.distribution(*args, **kwargs) - - def tranform(self, t: Transform): - """Performs a transformation on the distribution underlying the RV. - - :param t: The transformation (or sequence of transformations) to be - applied to the distribution. There are many examples to be found in - `torch.distributions.transforms` and `pyro.distributions.transforms`, - or you can subclass directly from `Transform`. - :type t: `Transform` - """ - dist = TransformedDistribution(self.distribution, t) - return RandomVariable(dist) diff --git a/tests/distributions/test_random_variable.py b/tests/contrib/randomvariable/test_random_variable.py similarity index 66% rename from tests/distributions/test_random_variable.py rename to tests/contrib/randomvariable/test_random_variable.py index 1d34c8faaf..75483b6adc 100644 --- a/tests/distributions/test_random_variable.py +++ b/tests/contrib/randomvariable/test_random_variable.py @@ -1,11 +1,13 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import math + import pytest import torch.tensor as tt import pyro -from pyro.distributions.random_variable import RandomVariable as RV +from pyro.contrib.randomvariable import RandomVariable as RV from pyro.distributions import Uniform N_SAMPLES = 100 @@ -16,7 +18,7 @@ def test_add(): X = X + 1 # (1, 2) X = 1 + X # (2, 3) X += 1 # (3, 4) - x = X.sample([N_SAMPLES]) + x = X.dist.sample([N_SAMPLES]) assert ((3 <= x) & (x <= 4)).all().item() @@ -25,7 +27,7 @@ def test_subtract(): X = 1 - X # (0, 1) X = X - 1 # (-1, 0) X -= 1 # (-2, -1) - x = X.sample([N_SAMPLES]) + x = X.dist.sample([N_SAMPLES]) assert ((-2 <= x) & (x <= -1)).all().item() @@ -33,7 +35,7 @@ def test_multiply_divide(): X = RV(Uniform(0, 1)) # (0, 1) X *= 4 # (0, 4) X /= 2 # (0, 2) - x = X.sample([N_SAMPLES]) + x = X.dist.sample([N_SAMPLES]) assert ((0 <= x) & (x <= 2)).all().item() @@ -41,21 +43,21 @@ def test_abs(): X = RV(Uniform(0, 1)) # (0, 1) X = 2*(X - 0.5) # (-1, 1) X = abs(X) # (0, 1) - x = X.sample([N_SAMPLES]) + x = X.dist.sample([N_SAMPLES]) assert ((0 <= x) & (x <= 1)).all().item() def test_neg(): X = RV(Uniform(0, 1)) # (0, 1) X = -X # (-1, 0) - x = X.sample([N_SAMPLES]) + x = X.dist.sample([N_SAMPLES]) assert ((-1 <= x) & (x <= 0)).all().item() def test_pow(): X = RV(Uniform(-1, 1)) # (-1, 1) X = X**2 # (0, 1) - x = X.sample([N_SAMPLES]) + x = X.dist.sample([N_SAMPLES]) assert ((0 <= x) & (x <= 1)).all().item() @@ -65,6 +67,20 @@ def test_tensor_ops(): a = tt([[1, 2, 3, 4, 5]]) b = a.T X = abs(pi*(-X + a - 3*b)) - x = X.sample() + x = X.dist.sample() assert x.shape == (5, 5) assert (x >= 0).all().item() + + +def test_chaining(): + X = ( + RV(Uniform(0, 1)) # (0, 1) + .add(1) # (1, 2) + .pow(2) # (1, 4) + .mul(2) # (2, 8) + .sub(5) # (-3, 3) + .tanh() # (-1, 1); more like (-0.995, +0.995) + .exp() # (1/e, e) + ) + x = X.dist.sample([N_SAMPLES]) + assert ((1/math.e <= x) & (x <= math.e)).all().item() From fd462c682d002f0d8974a2130351e9c64ded58ed Mon Sep 17 00:00:00 2001 From: ecotner <2.71828cotner@gmail.com> Date: Wed, 29 Apr 2020 01:51:37 -0400 Subject: [PATCH 5/9] moved RandomVariable import into .rv property to prevent circular import --- pyro/distributions/distribution.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyro/distributions/distribution.py b/pyro/distributions/distribution.py index ed696801ae..d0f487589b 100644 --- a/pyro/distributions/distribution.py +++ b/pyro/distributions/distribution.py @@ -4,7 +4,6 @@ from abc import ABCMeta, abstractmethod from pyro.distributions.score_parts import ScoreParts -# from pyro.contrib.randomvariable import RandomVariable as RV class Distribution(object, metaclass=ABCMeta): @@ -168,6 +167,7 @@ def has_rsample_(self, value): self.has_rsample = value return self - # @property - # def rv(self): - # return RV(self) + @property + def rv(self): + from pyro.contrib.randomvariable import RandomVariable + return RandomVariable(self) From 239de06538a3282fa9a268b8a6c337c30604de37 Mon Sep 17 00:00:00 2001 From: ecotner <2.71828cotner@gmail.com> Date: Thu, 30 Apr 2020 18:32:25 -0400 Subject: [PATCH 6/9] tried to fix linting errors; updated tests to use .rv notation --- .../contrib/randomvariable/random_variable.py | 4 ++-- .../randomvariable/test_random_variable.py | 19 ++++++++----------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/pyro/contrib/randomvariable/random_variable.py b/pyro/contrib/randomvariable/random_variable.py index ab491b486f..2b147edb1a 100644 --- a/pyro/contrib/randomvariable/random_variable.py +++ b/pyro/contrib/randomvariable/random_variable.py @@ -1,7 +1,7 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from typing import Union, Callable +from typing import Union from torch import Tensor from pyro.distributions import TransformedDistribution @@ -112,7 +112,7 @@ def __init__(self, distribution): def transform(self, t: Transform): """Performs a transformation on the distribution underlying the RV. - :param t: The transformation (or sequence of transformations) to be + :param t: The transformation (or sequence of transformations) to be applied to the distribution. There are many examples to be found in `torch.distributions.transforms` and `pyro.distributions.transforms`, or you can subclass directly from `Transform`. diff --git a/tests/contrib/randomvariable/test_random_variable.py b/tests/contrib/randomvariable/test_random_variable.py index 75483b6adc..5b233d4e2b 100644 --- a/tests/contrib/randomvariable/test_random_variable.py +++ b/tests/contrib/randomvariable/test_random_variable.py @@ -3,10 +3,7 @@ import math -import pytest - import torch.tensor as tt -import pyro from pyro.contrib.randomvariable import RandomVariable as RV from pyro.distributions import Uniform @@ -14,7 +11,7 @@ def test_add(): - X = RV(Uniform(0, 1)) # (0, 1) + X = Uniform(0, 1).rv # (0, 1) X = X + 1 # (1, 2) X = 1 + X # (2, 3) X += 1 # (3, 4) @@ -23,7 +20,7 @@ def test_add(): def test_subtract(): - X = RV(Uniform(0, 1)) # (0, 1) + X = Uniform(0, 1).rv # (0, 1) X = 1 - X # (0, 1) X = X - 1 # (-1, 0) X -= 1 # (-2, -1) @@ -32,7 +29,7 @@ def test_subtract(): def test_multiply_divide(): - X = RV(Uniform(0, 1)) # (0, 1) + X = Uniform(0, 1).rv # (0, 1) X *= 4 # (0, 4) X /= 2 # (0, 2) x = X.dist.sample([N_SAMPLES]) @@ -40,7 +37,7 @@ def test_multiply_divide(): def test_abs(): - X = RV(Uniform(0, 1)) # (0, 1) + X = Uniform(0, 1).rv # (0, 1) X = 2*(X - 0.5) # (-1, 1) X = abs(X) # (0, 1) x = X.dist.sample([N_SAMPLES]) @@ -48,14 +45,14 @@ def test_abs(): def test_neg(): - X = RV(Uniform(0, 1)) # (0, 1) + X = Uniform(0, 1).rv # (0, 1) X = -X # (-1, 0) x = X.dist.sample([N_SAMPLES]) assert ((-1 <= x) & (x <= 0)).all().item() def test_pow(): - X = RV(Uniform(-1, 1)) # (-1, 1) + X = Uniform(0, 1).rv # (0, 1) X = X**2 # (0, 1) x = X.dist.sample([N_SAMPLES]) assert ((0 <= x) & (x <= 1)).all().item() @@ -63,7 +60,7 @@ def test_pow(): def test_tensor_ops(): pi = 3.141592654 - X = RV(Uniform(0, 1).expand([5, 5])) + X = Uniform(0, 1).expand([5, 5]).rv a = tt([[1, 2, 3, 4, 5]]) b = a.T X = abs(pi*(-X + a - 3*b)) @@ -74,7 +71,7 @@ def test_tensor_ops(): def test_chaining(): X = ( - RV(Uniform(0, 1)) # (0, 1) + Uniform(0, 1).rv # (0, 1) .add(1) # (1, 2) .pow(2) # (1, 4) .mul(2) # (2, 8) From 245ee3c0118435c8a8e7b535e8082cec32ed8b95 Mon Sep 17 00:00:00 2001 From: ecotner <2.71828cotner@gmail.com> Date: Thu, 30 Apr 2020 18:51:03 -0400 Subject: [PATCH 7/9] fix more linting errors --- pyro/contrib/randomvariable/__init__.py | 4 ++++ tests/contrib/randomvariable/test_random_variable.py | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/pyro/contrib/randomvariable/__init__.py b/pyro/contrib/randomvariable/__init__.py index 83a9448362..e9d97ab2e0 100644 --- a/pyro/contrib/randomvariable/__init__.py +++ b/pyro/contrib/randomvariable/__init__.py @@ -1 +1,5 @@ from pyro.contrib.randomvariable.random_variable import RandomVariable + +__all__ = [ + "RandomVariable", +] diff --git a/tests/contrib/randomvariable/test_random_variable.py b/tests/contrib/randomvariable/test_random_variable.py index 5b233d4e2b..5a1a43c194 100644 --- a/tests/contrib/randomvariable/test_random_variable.py +++ b/tests/contrib/randomvariable/test_random_variable.py @@ -4,7 +4,6 @@ import math import torch.tensor as tt -from pyro.contrib.randomvariable import RandomVariable as RV from pyro.distributions import Uniform N_SAMPLES = 100 From c6c733c3db36a1520bf4ff17e7b71992554074c8 Mon Sep 17 00:00:00 2001 From: ecotner <2.71828cotner@gmail.com> Date: Sat, 2 May 2020 11:23:08 -0400 Subject: [PATCH 8/9] fixed negation chain operation requiring too many arguments --- pyro/contrib/randomvariable/random_variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/contrib/randomvariable/random_variable.py b/pyro/contrib/randomvariable/random_variable.py index 2b147edb1a..1aa1355b43 100644 --- a/pyro/contrib/randomvariable/random_variable.py +++ b/pyro/contrib/randomvariable/random_variable.py @@ -75,7 +75,7 @@ def abs(self): def pow(self, x): return self ** x - def neg(self, x): + def neg(self): return -self def exp(self): From 60c2f9c72c3fad7841f79da0fca83212189d098b Mon Sep 17 00:00:00 2001 From: ecotner <2.71828cotner@gmail.com> Date: Mon, 4 May 2020 18:31:17 -0400 Subject: [PATCH 9/9] adding/updating sphinx-style documentation --- docs/source/contrib.randomvariable.rst | 11 ++++++ docs/source/index.rst | 1 + .../contrib/randomvariable/random_variable.py | 37 +++++++++++++++++-- pyro/distributions/distribution.py | 20 ++++++++++ 4 files changed, 65 insertions(+), 4 deletions(-) create mode 100644 docs/source/contrib.randomvariable.rst diff --git a/docs/source/contrib.randomvariable.rst b/docs/source/contrib.randomvariable.rst new file mode 100644 index 0000000000..a6313bc395 --- /dev/null +++ b/docs/source/contrib.randomvariable.rst @@ -0,0 +1,11 @@ +Random Variables +================ + +.. automodule:: pyro.contrib.randomvariable + +Random Variable +--------------- +.. autoclass:: pyro.contrib.randomvariable.random_variable.RandomVariable + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/index.rst b/docs/source/index.rst index 956b71ecc3..a888f602af 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -39,6 +39,7 @@ Pyro Documentation contrib.gp contrib.minipyro contrib.oed + contrib.randomvariable contrib.timeseries contrib.tracking diff --git a/pyro/contrib/randomvariable/random_variable.py b/pyro/contrib/randomvariable/random_variable.py index 1aa1355b43..1bc28f1caa 100644 --- a/pyro/contrib/randomvariable/random_variable.py +++ b/pyro/contrib/randomvariable/random_variable.py @@ -18,7 +18,7 @@ class RVMagicOps: - """Mixin class for overloading __magic__ operations on random variables + """Mixin class for overloading __magic__ operations on random variables. """ def __add__(self, x: Union[float, Tensor]): @@ -54,7 +54,7 @@ def __pow__(self, x): class RVChainOps: """Mixin class for performing common unary/binary operations on/between - random variables/constant tensors + random variables/constant tensors using method chaining syntax. """ def add(self, x): @@ -95,7 +95,7 @@ def softmax(self): class RandomVariable(RVMagicOps, RVChainOps): - """Random variable container class around a distribution + """EXPERIMENTAL random variable container class around a distribution Representation of a distribution interpreted as a random variable. Rather than directly manipulating a probability density by applying pointwise @@ -104,9 +104,29 @@ class RandomVariable(RVMagicOps, RVChainOps): consider using the `transform` method. Note that if you perform a non-invertible transform (like `abs(X)` or `X**2`), certain things might not work properly. + + Can switch between `RandomVariable` and `Distribution` objects with the + convenient `Distribution.rv` and `RandomVariable.dist` properties. + + Supports either chaining operations or arithmetic operator overloading. + + Example usage:: + + # This should be equivalent to an Exponential distribution. + RandomVariable(Uniform(0, 1)).log().neg().dist + + # These two distributions Y1, Y2 should be the same + X = Uniform(0, 1).rv + Y1 = X.mul(4).pow(0.5).sub(1).abs().neg().dist + Y2 = (-abs((4*X)**(0.5) - 1)).dist """ def __init__(self, distribution): + """Wraps a distribution as a RandomVariable + + :param distribution: The `Distribution` object to wrap + :type distribution: ~pyro.distributions.distribution.Distribution + """ self.distribution = distribution def transform(self, t: Transform): @@ -116,11 +136,20 @@ def transform(self, t: Transform): applied to the distribution. There are many examples to be found in `torch.distributions.transforms` and `pyro.distributions.transforms`, or you can subclass directly from `Transform`. - :type t: `Transform` + :type t: ~pyro.distributions.transforms.Transform + + :return: The transformed `RandomVariable` + :rtype: ~pyro.contrib.randomvariable.random_variable.RandomVariable """ dist = TransformedDistribution(self.distribution, t) return RandomVariable(dist) @property def dist(self): + """Convenience property for exposing the distribution underlying the + random variable. + + :return: The `Distribution` object underlying the random variable + :rtype: ~pyro.distributions.distribution.Distribution + """ return self.distribution diff --git a/pyro/distributions/distribution.py b/pyro/distributions/distribution.py index d0f487589b..9cbd4a899a 100644 --- a/pyro/distributions/distribution.py +++ b/pyro/distributions/distribution.py @@ -169,5 +169,25 @@ def has_rsample_(self, value): @property def rv(self): + """ + EXPERIMENTAL Switch to the Random Variable DSL for applying transformations + to random variables. Supports either chaining operations or arithmetic + operator overloading. + + Example usage:: + + # This should be equivalent to an Exponential distribution. + Uniform(0, 1).rv.log().neg().dist + + # These two distributions Y1, Y2 should be the same + X = Uniform(0, 1).rv + Y1 = X.mul(4).pow(0.5).sub(1).abs().neg().dist + Y2 = (-abs((4*X)**(0.5) - 1)).dist + + + :return: A :class: `~pyro.contrib.randomvariable.random_variable.RandomVariable` + object wrapping this distribution. + :rtype: ~pyro.contrib.randomvariable.random_variable.RandomVariable + """ from pyro.contrib.randomvariable import RandomVariable return RandomVariable(self)