Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding RandomVariable container class around Distribution #2448

Merged
merged 9 commits into from
May 5, 2020
11 changes: 11 additions & 0 deletions docs/source/contrib.randomvariable.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Random Variables
================

.. automodule:: pyro.contrib.randomvariable

Random Variable
---------------
.. autoclass:: pyro.contrib.randomvariable.random_variable.RandomVariable
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ Pyro Documentation
contrib.gp
contrib.minipyro
contrib.oed
contrib.randomvariable
contrib.timeseries
contrib.tracking

Expand Down
5 changes: 5 additions & 0 deletions pyro/contrib/randomvariable/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from pyro.contrib.randomvariable.random_variable import RandomVariable

__all__ = [
"RandomVariable",
]
155 changes: 155 additions & 0 deletions pyro/contrib/randomvariable/random_variable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from typing import Union

from torch import Tensor
from pyro.distributions import TransformedDistribution
from pyro.distributions.transforms import (
Transform,
AffineTransform,
AbsTransform,
PowerTransform,
ExpTransform,
TanhTransform,
SoftmaxTransform,
SigmoidTransform
)


class RVMagicOps:
"""Mixin class for overloading __magic__ operations on random variables.
"""

def __add__(self, x: Union[float, Tensor]):
return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(x, 1)))

def __radd__(self, x: Union[float, Tensor]):
return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(x, 1)))

def __sub__(self, x: Union[float, Tensor]):
return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(-x, 1)))

def __rsub__(self, x: Union[float, Tensor]):
return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(x, -1)))

def __mul__(self, x: Union[float, Tensor]):
return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(0, x)))

def __rmul__(self, x: Union[float, Tensor]):
return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(0, x)))

def __truediv__(self, x: Union[float, Tensor]):
return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(0, 1/x)))

def __neg__(self):
return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(0, -1)))

def __abs__(self):
return RandomVariable(TransformedDistribution(self.distribution, AbsTransform()))

def __pow__(self, x):
return RandomVariable(TransformedDistribution(self.distribution, PowerTransform(x)))


class RVChainOps:
"""Mixin class for performing common unary/binary operations on/between
random variables/constant tensors using method chaining syntax.
"""

def add(self, x):
return self + x

def sub(self, x):
return self - x

def mul(self, x):
return self * x

def div(self, x):
return self / x

def abs(self):
return abs(self)

def pow(self, x):
return self ** x

def neg(self):
return -self

def exp(self):
return self.transform(ExpTransform())

def log(self):
return self.transform(ExpTransform().inv)

def sigmoid(self):
return self.transform(SigmoidTransform())

def tanh(self):
return self.transform(TanhTransform())

def softmax(self):
return self.transform(SoftmaxTransform())


class RandomVariable(RVMagicOps, RVChainOps):
"""EXPERIMENTAL random variable container class around a distribution

Representation of a distribution interpreted as a random variable. Rather
than directly manipulating a probability density by applying pointwise
transformations to it, this allows for simple arithmetic transformations of
the random variable the distribution represents. For more flexibility,
consider using the `transform` method. Note that if you perform a
non-invertible transform (like `abs(X)` or `X**2`), certain things might
not work properly.

Can switch between `RandomVariable` and `Distribution` objects with the
convenient `Distribution.rv` and `RandomVariable.dist` properties.

Supports either chaining operations or arithmetic operator overloading.

Example usage::

# This should be equivalent to an Exponential distribution.
RandomVariable(Uniform(0, 1)).log().neg().dist

# These two distributions Y1, Y2 should be the same
X = Uniform(0, 1).rv
Y1 = X.mul(4).pow(0.5).sub(1).abs().neg().dist
Y2 = (-abs((4*X)**(0.5) - 1)).dist
"""

def __init__(self, distribution):
"""Wraps a distribution as a RandomVariable

:param distribution: The `Distribution` object to wrap
:type distribution: ~pyro.distributions.distribution.Distribution
"""
self.distribution = distribution

def transform(self, t: Transform):
"""Performs a transformation on the distribution underlying the RV.

:param t: The transformation (or sequence of transformations) to be
applied to the distribution. There are many examples to be found in
`torch.distributions.transforms` and `pyro.distributions.transforms`,
or you can subclass directly from `Transform`.
:type t: ~pyro.distributions.transforms.Transform

:return: The transformed `RandomVariable`
:rtype: ~pyro.contrib.randomvariable.random_variable.RandomVariable
"""
ecotner marked this conversation as resolved.
Show resolved Hide resolved
dist = TransformedDistribution(self.distribution, t)
return RandomVariable(dist)

@property
def dist(self):
"""Convenience property for exposing the distribution underlying the
random variable.

:return: The `Distribution` object underlying the random variable
:rtype: ~pyro.distributions.distribution.Distribution
"""
return self.distribution
25 changes: 25 additions & 0 deletions pyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,28 @@ def has_rsample_(self, value):
raise ValueError("Expected value in {False,True}, actual {}".format(value))
self.has_rsample = value
return self

@property
def rv(self):
ecotner marked this conversation as resolved.
Show resolved Hide resolved
"""
EXPERIMENTAL Switch to the Random Variable DSL for applying transformations
to random variables. Supports either chaining operations or arithmetic
operator overloading.

Example usage::

# This should be equivalent to an Exponential distribution.
Uniform(0, 1).rv.log().neg().dist

# These two distributions Y1, Y2 should be the same
X = Uniform(0, 1).rv
Y1 = X.mul(4).pow(0.5).sub(1).abs().neg().dist
Y2 = (-abs((4*X)**(0.5) - 1)).dist


:return: A :class: `~pyro.contrib.randomvariable.random_variable.RandomVariable`
object wrapping this distribution.
:rtype: ~pyro.contrib.randomvariable.random_variable.RandomVariable
"""
from pyro.contrib.randomvariable import RandomVariable
return RandomVariable(self)
82 changes: 82 additions & 0 deletions tests/contrib/randomvariable/test_random_variable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math

import torch.tensor as tt
from pyro.distributions import Uniform

N_SAMPLES = 100


def test_add():
X = Uniform(0, 1).rv # (0, 1)
X = X + 1 # (1, 2)
X = 1 + X # (2, 3)
X += 1 # (3, 4)
x = X.dist.sample([N_SAMPLES])
assert ((3 <= x) & (x <= 4)).all().item()


def test_subtract():
X = Uniform(0, 1).rv # (0, 1)
X = 1 - X # (0, 1)
X = X - 1 # (-1, 0)
X -= 1 # (-2, -1)
x = X.dist.sample([N_SAMPLES])
assert ((-2 <= x) & (x <= -1)).all().item()


def test_multiply_divide():
X = Uniform(0, 1).rv # (0, 1)
X *= 4 # (0, 4)
X /= 2 # (0, 2)
x = X.dist.sample([N_SAMPLES])
assert ((0 <= x) & (x <= 2)).all().item()


def test_abs():
X = Uniform(0, 1).rv # (0, 1)
X = 2*(X - 0.5) # (-1, 1)
X = abs(X) # (0, 1)
x = X.dist.sample([N_SAMPLES])
assert ((0 <= x) & (x <= 1)).all().item()


def test_neg():
X = Uniform(0, 1).rv # (0, 1)
X = -X # (-1, 0)
x = X.dist.sample([N_SAMPLES])
assert ((-1 <= x) & (x <= 0)).all().item()


def test_pow():
X = Uniform(0, 1).rv # (0, 1)
X = X**2 # (0, 1)
x = X.dist.sample([N_SAMPLES])
assert ((0 <= x) & (x <= 1)).all().item()


def test_tensor_ops():
pi = 3.141592654
X = Uniform(0, 1).expand([5, 5]).rv
a = tt([[1, 2, 3, 4, 5]])
b = a.T
X = abs(pi*(-X + a - 3*b))
x = X.dist.sample()
assert x.shape == (5, 5)
assert (x >= 0).all().item()


def test_chaining():
X = (
Uniform(0, 1).rv # (0, 1)
.add(1) # (1, 2)
.pow(2) # (1, 4)
.mul(2) # (2, 8)
.sub(5) # (-3, 3)
.tanh() # (-1, 1); more like (-0.995, +0.995)
.exp() # (1/e, e)
)
x = X.dist.sample([N_SAMPLES])
assert ((1/math.e <= x) & (x <= math.e)).all().item()