diff --git a/python/mxnet/gluon/__init__.py b/python/mxnet/gluon/__init__.py index 514087049edb..f43da1dae738 100644 --- a/python/mxnet/gluon/__init__.py +++ b/python/mxnet/gluon/__init__.py @@ -40,3 +40,5 @@ from . import model_zoo from . import contrib + +from . import probability diff --git a/python/mxnet/gluon/probability/__init__.py b/python/mxnet/gluon/probability/__init__.py new file mode 100644 index 000000000000..bcc32fb2833c --- /dev/null +++ b/python/mxnet/gluon/probability/__init__.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Probability module""" + +from .block import * + +from .distributions import * + +from .transformation import * diff --git a/python/mxnet/gluon/probability/block/__init__.py b/python/mxnet/gluon/probability/block/__init__.py new file mode 100644 index 000000000000..f6c817533c6b --- /dev/null +++ b/python/mxnet/gluon/probability/block/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Stochastic block.""" + +from .stochastic_block import * diff --git a/python/mxnet/gluon/probability/block/stochastic_block.py b/python/mxnet/gluon/probability/block/stochastic_block.py new file mode 100644 index 000000000000..64602145e613 --- /dev/null +++ b/python/mxnet/gluon/probability/block/stochastic_block.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=abstract-method +"""Stochastic block class.""" +__all__ = ['StochasticBlock', 'StochasticSequential'] + +from functools import wraps +from ...block import HybridBlock +from ...utils import _indent + + +class StochasticBlock(HybridBlock): + """`StochasticBlock` extends `HybridBlock` to support accumulating loss + in the forward phase, which is extremely useful in building Bayesian Neural Network, + where the loss function is composed of a classification loss and a KL loss. + + """ + + def __init__(self, **kwargs): + super(StochasticBlock, self).__init__(**kwargs) + self._losses = [] + self._losscache = [] + # Recording whether collectLoss is invoked. + self._flag = False + + def add_loss(self, loss): + self._losscache.append(loss) + + @staticmethod + def collectLoss(func): + """To accumulate loss during the forward phase, one could first decorate + hybrid_forward with `StochasticBlock.collectLoss, + and then collect the loss tensor `x` by calling self.add_loss(x). + For example, in the following forward function, + we generate samples from a Gaussian parameterized by `loc` and `scale` and + accumulate the KL-divergence between it and its prior into the block's loss storage.: + @StochasticBlock.collectLoss + def forward(self, loc, scale): + qz = mgp.Normal(loc, scale) + # prior + pz = mgp.Normal(np.zeros_like(loc), np.ones_like(scale)) + self.add_loss(mgp.kl_divergence(qz, pz)) + return qz.sample() + """ + @wraps(func) + def inner(self, *args, **kwargs): + # Loss from hybrid_forward + func_out = func(self, *args, **kwargs) + collected_loss = self._losscache + self._losscache = [] + self._flag = True + return (func_out, collected_loss) + + return inner + + def __call__(self, *args, **kwargs): + # pylint: disable=arguments-differ + self._flag = False + out = super().__call__(*args, **kwargs) + if not self._flag: + raise ValueError("The forward function should be decorated by " + + "StochasticBlock.collectLoss") + self._losses = out[1] + return out[0] + + @property + def losses(self): + return self._losses + + +class StochasticSequential(StochasticBlock): + """Stack StochasticBlock sequentially. 
+ """ + + def __init__(self, **kwargs): + super(StochasticSequential, self).__init__(**kwargs) + self._layers = [] + + def add(self, *blocks): + """Adds block on top of the stack.""" + for block in blocks: + self._layers.append(block) + self.register_child(block) + + @StochasticBlock.collectLoss + def forward(self, x, *args): + # pylint: disable=arguments-differ + for block in self._children.values(): + x = block()(x, *args) + args = [] + if isinstance(x, (tuple, list)): + args = x[1:] + x = x[0] + if args: + x = tuple([x] + list(args)) + for block in self._layers: + if hasattr(block, '_losses'): + self.add_loss(block._losses) + return x + + def __repr__(self): + s = '{name}(\n{modstr}\n)' + modstr = '\n'.join([' ({key}): {block}'.format(key=key, + block=_indent(block().__repr__(), 2)) + for key, block in self._children.items()]) + return s.format(name=self.__class__.__name__, modstr=modstr) + + def __getitem__(self, key): + layers = list(self._children.values())[key] + if isinstance(layers, list): + net = type(self)() + net.add(*(l() for l in layers)) + return net + else: + return layers() + + def __len__(self): + return len(self._children) diff --git a/python/mxnet/gluon/probability/distributions/__init__.py b/python/mxnet/gluon/probability/distributions/__init__.py new file mode 100644 index 000000000000..3fd16273dc48 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/__init__.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# coding: utf-8 +# pylint: disable=wildcard-import +"""Distribution classes.""" + +from .distribution import * + +from .exp_family import * + +from .exponential import * + +from .weibull import * + +from .pareto import * + +from .uniform import * + +from .normal import * + +from .laplace import * + +from .cauchy import * + +from .half_cauchy import * + +from .poisson import * + +from .geometric import * + +from .negative_binomial import * + +from .gamma import * + +from .dirichlet import * + +from .beta import * + +from .chi2 import * + +from .fishersnedecor import * + +from .studentT import * + +from .half_normal import * + +from .independent import * + +from .bernoulli import * + +from .binomial import * + +from .relaxed_bernoulli import * + +from .gumbel import * + +from .categorical import * + +from .one_hot_categorical import * + +from .relaxed_one_hot_categorical import * + +from .multinomial import * + +from .multivariate_normal import * + +from .transformed_distribution import * + +from .divergence import * + +from .utils import * diff --git a/python/mxnet/gluon/probability/distributions/bernoulli.py b/python/mxnet/gluon/probability/distributions/bernoulli.py new file mode 100644 index 000000000000..f61189c13bc6 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/bernoulli.py @@ -0,0 +1,139 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Bernoulli class.""" +__all__ = ['Bernoulli'] + +from .exp_family import ExponentialFamily +from .utils import prob2logit, logit2prob, getF, cached_property, sample_n_shape_converter +from .constraint import Boolean, Interval, Real + + +class Bernoulli(ExponentialFamily): + r"""Create a bernoulli distribution object. + + Parameters + ---------- + prob : Tensor or scalar, default None + Probability of sampling `1`. + logit : Tensor or scalar, default None + The log-odds of sampling `1`. + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + support = Boolean() + arg_constraints = {'prob': Interval(0, 1), + 'logit': Real()} + + def __init__(self, prob=None, logit=None, F=None, validate_args=None): + _F = F if F is not None else getF(prob, logit) + + if (prob is None) == (logit is None): + raise ValueError( + "Either `prob` or `logit` must be specified, but not both. " + + "Received prob={}, logit={}".format(prob, logit)) + + if prob is not None: + self.prob = prob + else: + self.logit = logit + + super(Bernoulli, self).__init__( + F=_F, event_dim=0, validate_args=validate_args) + + @cached_property + def prob(self): + """Get the probability of sampling `1`. 
+ + Returns + ------- + Tensor + Parameter tensor. + """ + # pylint: disable=method-hidden + return logit2prob(self.logit, True, self.F) + + @cached_property + def logit(self): + """Get the log-odds of sampling `1`. + + Returns + ------- + Tensor + Parameter tensor. + """ + # pylint: disable=method-hidden + return prob2logit(self.prob, True, self.F) + + @property + def mean(self): + return self.prob + + @property + def variance(self): + return self.prob * (1 - self.prob) + + def broadcast_to(self, batch_shape): + new_instance = self.__new__(type(self)) + F = self.F + if 'prob' in self.__dict__: + new_instance.prob = F.np.broadcast_to(self.prob, batch_shape) + else: + new_instance.logit = F.np.broadcast_to(self.logit, batch_shape) + super(Bernoulli, new_instance).__init__(F=F, + event_dim=self.event_dim, + validate_args=False) + new_instance._validate_args = self._validate_args + return new_instance + + def log_prob(self, value): + if self._validate_args: + self._validate_samples(value) + F = self.F + if self.prob is None: + logit = self.logit + return logit * (value - 1) - F.np.log(F.np.exp(-logit) + 1) + else: + # Parameterized by probability + eps = 1e-12 + return (self.F.np.log(self.prob + eps) * value + + self.F.np.log1p(-self.prob + eps) * (1 - value)) + + def sample(self, size=None): + return self.F.npx.random.bernoulli(self.prob, self.logit, size) + + def sample_n(self, size=None): + return self.F.npx.random.bernoulli(self.prob, self.logit, sample_n_shape_converter(size)) + + @property + def _natural_params(self): + return (self.logit,) + + def _log_normalizer(self, x): + # pylint: disable=arguments-differ + return self.F.np.log(1 + self.F.np.exp(x)) + + def entropy(self): + F = self.F + logit = self.logit + prob = self.prob + return -(logit * (prob - 1) - F.np.log(F.np.exp(-logit) + 1)) diff --git a/python/mxnet/gluon/probability/distributions/beta.py b/python/mxnet/gluon/probability/distributions/beta.py new file mode 100644 index 000000000000..dea7dc728a8d --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/beta.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Beta Distribution.""" +__all__ = ['Beta'] + +from .exp_family import ExponentialFamily +from .constraint import UnitInterval, Positive +from .utils import getF, sample_n_shape_converter, gammaln, digamma, _clip_prob + + +class Beta(ExponentialFamily): + r"""Create a Beta distribution object. 
+ + Parameters + ---------- + alpha : Tensor or scalar + The first shape parameter + beta : Tensor or scalar + The second shape parameter + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + """ + # pylint: disable=abstract-method + + has_grad = False + support = UnitInterval() + arg_constraints = {'alpha': Positive(), + 'beta': Positive()} + + def __init__(self, alpha, beta, F=None, validate_args=None): + _F = F if F is not None else getF(alpha, beta) + self.alpha = alpha + self.beta = beta + super(Beta, self).__init__( + F=_F, event_dim=0, validate_args=validate_args) + + def sample(self, size=None): + F = self.F + X = F.np.random.gamma(self.alpha, 1, size=size) + Y = F.np.random.gamma(self.beta, 1, size=size) + out = X / (X + Y) + return _clip_prob(out, F) + + def sample_n(self, size=None): + return self.sample(sample_n_shape_converter(size)) + + @property + def mean(self): + a = self.alpha + b = self.beta + return a / (a + b) + + @property + def variance(self): + a = self.alpha + b = self.beta + return (a * b / + ((a + b) ** 2 * (a + b + 1))) + + def log_prob(self, value): + if self._validate_args: + self._validate_samples(value) + F = self.F + lgamma = gammaln(F) + log = F.np.log + log1p = F.np.log1p + a = self.alpha + b = self.beta + lgamma_term = lgamma(a + b) - lgamma(a) - lgamma(b) + return (a - 1) * log(value) + (b - 1) * log1p(-value) + lgamma_term + + def entropy(self): + F = self.F + lgamma = gammaln(F) + dgamma = digamma(F) + a = self.alpha + b = self.beta + lgamma_term = lgamma(a + b) - lgamma(a) - lgamma(b) + return (-lgamma_term - (a - 1) * dgamma(a) - (b - 1) * dgamma(b) + + (a + b - 2) * dgamma(a + b)) diff --git a/python/mxnet/gluon/probability/distributions/binomial.py b/python/mxnet/gluon/probability/distributions/binomial.py new file mode 100644 index 000000000000..e99acb5d0bba --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/binomial.py @@ -0,0 +1,145 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Binomial distribution class.""" +__all__ = ['Binomial'] + +from .distribution import Distribution +from .utils import prob2logit, logit2prob, getF, cached_property, sample_n_shape_converter +from .utils import gammaln +from .constraint import Interval, Real, NonNegativeInteger + + +class Binomial(Distribution): + r"""Create a binomial distribution object. + + Parameters + ---------- + n : scalar + Non-negative interger of Bernoulli trials to stop. + prob : Tensor or scalar, default None + Probability of sampling `1`. + logit : Tensor or scalar, default None + The log-odds of sampling `1`. 
+ F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + support = NonNegativeInteger() + arg_constraints = {'prob': Interval(0, 1), + 'logit': Real()} + + def __init__(self, n=1, prob=None, logit=None, F=None, validate_args=None): + if (n < 0) or (n % 1 != 0): + raise ValueError( + "Expect `n` to be non-negative integer, received n={}".format(n)) + _F = F if F is not None else getF(n, prob, logit) + if (prob is None) == (logit is None): + raise ValueError( + "Either `prob` or `logit` must be specified, but not both. " + + "Received prob={}, logit={}".format(prob, logit)) + + if prob is not None: + self.prob = prob + else: + self.logit = logit + self.n = n + super(Binomial, self).__init__( + F=_F, event_dim=0, validate_args=validate_args) + + @cached_property + def prob(self): + """Get the probability of sampling `1`. + + Returns + ------- + Tensor + Parameter tensor. + """ + # pylint: disable=method-hidden + return logit2prob(self.logit, True, self.F) + + @cached_property + def logit(self): + """Get the log-odds of sampling `1`. + + Returns + ------- + Tensor + Parameter tensor. + """ + # pylint: disable=method-hidden + return prob2logit(self.prob, True, self.F) + + @property + def mean(self): + return self.n * self.prob + + @property + def variance(self): + p = self.prob + return self.n * p * (1 - p) + + def broadcast_to(self, batch_shape): + new_instance = self.__new__(type(self)) + F = self.F + if 'prob' in self.__dict__: + new_instance.prob = F.np.broadcast_to(self.prob, batch_shape) + else: + new_instance.logit = F.np.broadcast_to(self.logit, batch_shape) + new_instance.n = self.n + super(Binomial, new_instance).__init__(F=F, + event_dim=self.event_dim, + validate_args=False) + new_instance._validate_args = self._validate_args + return new_instance + + def log_prob(self, value): + if self._validate_args: + self._validate_samples(value) + F = self.F + lgamma = gammaln(F) + binomal_coef = lgamma(self.n + 1) - lgamma(1 + + value) - lgamma(self.n - value + 1) + # log(prob) may have numerical issue. + unnormalized_log_prob = (value * F.np.log(self.prob) + + (self.n - value) * F.np.log1p(-self.prob)) + return binomal_coef + unnormalized_log_prob + + def sample(self, size=None): + F = self.F + if size is not None: + logit = F.np.broadcast_to(self.logit, size) + else: + logit = self.logit + expanded_logit = F.np.repeat( + F.np.expand_dims(logit, -1), int(self.n), -1) + return F.npx.random.bernoulli(logit=expanded_logit).sum(-1) + + def sample_n(self, size=None): + F = self.F + logit = self.logit + expanded_logit = F.np.repeat( + F.np.expand_dims(logit, -1), int(self.n), -1) + return F.npx.random.bernoulli( + logit=expanded_logit, + size=sample_n_shape_converter(size) + ).sum(-1) diff --git a/python/mxnet/gluon/probability/distributions/categorical.py b/python/mxnet/gluon/probability/distributions/categorical.py new file mode 100644 index 000000000000..8633ba979b32 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/categorical.py @@ -0,0 +1,168 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Categorical class.""" +__all__ = ['Categorical'] + +from .distribution import Distribution +from .utils import prob2logit, logit2prob, getF, cached_property, sample_n_shape_converter +from .constraint import Simplex, Real, IntegerInterval + + +class Categorical(Distribution): + """Create a categorical distribution object. + + Parameters + ---------- + num_events : Int + Number of events. + prob : Tensor + Probabilities of each event. + logit : Tensor + The log-odds of each event + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + has_enumerate_support = True + arg_constraints = {'prob': Simplex(), + 'logit': Real()} + + def __init__(self, num_events, prob=None, logit=None, F=None, validate_args=None): + _F = F if F is not None else getF(prob, logit) + if (num_events > 0): + num_events = int(num_events) + self.num_events = num_events + else: + raise ValueError("`num_events` should be greater than zero. " + + "Received num_events={}".format(num_events)) + if (prob is None) == (logit is None): + raise ValueError( + "Either `prob` or `logit` must be specified, but not both. " + + "Received prob={}, logit={}".format(prob, logit)) + + if prob is not None: + self.prob = prob + else: + self.logit = logit + + super(Categorical, self).__init__( + F=_F, event_dim=0, validate_args=validate_args) + + @cached_property + def prob(self): + # pylint: disable=method-hidden + """Get the probability of sampling each class. + + Returns + ------- + Tensor + Parameter tensor. + """ + return logit2prob(self.logit, False, self.F) + + @cached_property + def logit(self): + # pylint: disable=method-hidden + """Get the log probability of sampling each class. + + Returns + ------- + Tensor + Parameter tensor. + """ + return prob2logit(self.prob, False, self.F) + + @property + def support(self): + return IntegerInterval(0, self.num_events) + + def log_prob(self, value): + """Compute the log-likelihood of `value` + + Parameters + ---------- + value : Tensor + samples from Categorical distribution + + Returns + ------- + Tensor + log-likelihood of `value` + """ + if self._validate_args: + self._validate_samples(value) + F = self.F + logit = self.logit + indices = F.np.expand_dims(value, -1).astype('int') + expanded_logit = logit * F.np.ones_like(logit + indices) + return F.npx.pick(expanded_logit, indices).squeeze() + + def sample(self, size=None): + """Sample from categorical distribution. + Given logit/prob of size `(batch_size, num_events)`, + `batch_size` samples will be drawn. + If `size` is given, `np.broadcast(size, batch_size)` samples will be drawn. + + Parameters + ---------- + size : int or tuple of ints + + Returns + ------- + out : Tensor + Samples from the categorical distribution. 
+ """ + F = self.F + if size is None: + size = () + logit = self.logit + else: + if isinstance(size, int): + logit = F.np.broadcast_to(self.logit, (size,) + (-2,)) + else: + logit = F.np.broadcast_to(self.logit, size + (-2,)) + gumbel_samples = F.np.random.gumbel(logit) + return F.np.argmax(gumbel_samples, axis=-1) + + def sample_n(self, size=None): + F = self.F + size = sample_n_shape_converter(size) + gumbel_samples = F.np.random.gumbel(self.logit, size=size) + return F.np.argmax(gumbel_samples, axis=-1) + + def broadcast_to(self, batch_shape): + new_instance = self.__new__(type(self)) + F = self.F + new_instance.prob = F.np.broadcast_to(self.prob, batch_shape + (-2,)) + new_instance.logit = F.np.broadcast_to(self.logit, batch_shape + (-2,)) + new_instance.num_events = self.num_events + super(Categorical, new_instance).__init__(F=F, + event_dim=self.event_dim, + validate_args=False) + new_instance._validate_args = self._validate_args + return new_instance + + def enumerate_support(self): + num_events = self.num_events + F = self.F + value = F.npx.arange_like(self.logit) % num_events + return F.np.moveaxis(value, -1, 0) diff --git a/python/mxnet/gluon/probability/distributions/cauchy.py b/python/mxnet/gluon/probability/distributions/cauchy.py new file mode 100644 index 000000000000..90e16b93a8de --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/cauchy.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Cauchy distribution""" + +__all__ = ['Cauchy'] + +from numbers import Number +from numpy import nan, pi +from .constraint import Real +from .distribution import Distribution +from .utils import getF, sample_n_shape_converter + + +class Cauchy(Distribution): + r"""Create a relaxed Cauchy distribution object. + + Parameters + ---------- + loc : Tensor or scalar, default 0 + mode or median of the distribution + scale : Tensor or scalar, default 1 + half width at half maximum + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + has_grad = True + support = Real() + arg_constraints = {'loc': Real(), 'scale': Real()} + + def __init__(self, loc=0.0, scale=1.0, F=None, validate_args=None): + _F = F if F is not None else getF(loc, scale) + self.loc = loc + self.scale = scale + super(Cauchy, self).__init__( + F=_F, event_dim=0, validate_args=validate_args) + + @property + def mean(self): + return nan + + @property + def variance(self): + return nan + + def sample(self, size=None): + # TODO: Implement sampling op in the backend. + F = self.F + # `np.zeros_like` does not support scalar at this moment. 
if isinstance(self.loc, Number) and isinstance(self.scale, Number): + u = F.np.random.uniform(size=size) + else: + u = F.np.random.uniform(F.np.zeros_like( + self.loc + self.scale), size=size) + return self.icdf(u) + + def sample_n(self, size=None): + return self.sample(sample_n_shape_converter(size)) + + def log_prob(self, value): + if self._validate_args: + self._validate_samples(value) + log = self.F.np.log + return (-log(pi) - log(self.scale) - + log(1 + ((value - self.loc) / self.scale) ** 2)) + + def cdf(self, value): + if self._validate_args: + self._validate_samples(value) + return self.F.np.arctan((value - self.loc) / self.scale) / pi + 0.5 + + def icdf(self, value): + return self.F.np.tan(pi * (value - 0.5)) * self.scale + self.loc + + def entropy(self): + log = self.F.np.log + return log(4 * pi) + log(self.scale) diff --git a/python/mxnet/gluon/probability/distributions/chi2.py b/python/mxnet/gluon/probability/distributions/chi2.py new file mode 100644 index 000000000000..7b74683cb09c --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/chi2.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Chi-square distribution""" +__all__ = ['Chi2'] + +from .gamma import Gamma +from .constraint import Positive + + +class Chi2(Gamma): + r"""Create a Chi2 distribution object. + Chi2(df) is equivalent to Gamma(shape=df / 2, scale=2) + + Parameters + ---------- + df : Tensor or scalar + Degrees of freedom of the distribution. + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + arg_constraints = {'df': Positive()} + + def __init__(self, df, F=None, validate_args=None): + super(Chi2, self).__init__(df / 2, 2, F, validate_args) + + @property + def df(self): + return self.shape * 2 diff --git a/python/mxnet/gluon/probability/distributions/constraint.py b/python/mxnet/gluon/probability/distributions/constraint.py new file mode 100644 index 000000000000..a27850f08e51 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/constraint.py @@ -0,0 +1,548 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Base class and implementations of constraint""" +__all__ = ["Constraint", "Real", "Boolean", + "Interval", "OpenInterval", "HalfOpenInterval", "UnitInterval", + "IntegerInterval", "IntegerOpenInterval", "IntegerHalfOpenInterval", + "GreaterThan", "GreaterThanEq", "IntegerGreaterThan", "IntegerGreaterThanEq", + "LessThan", "LessThanEq", "IntegerLessThan", "IntegerLessThanEq", + "Positive", "NonNegative", "PositiveInteger", "NonNegativeInteger", + "Simplex", "LowerTriangular", "LowerCholesky", "PositiveDefinite", + "Cat", "Stack"] + +from .utils import getF, constraint_check +from .... import ndarray as nd + + +class Constraint(object): + """Base class for constraints. + + A constraint object represents a region over which a variable + is valid. + """ + + def check(self, value): + """Check if `value` satisfies the constraint, + return the origin value if valid, + raise `ValueError` with given message otherwise. + + Parameters + ---------- + value : Tensor + Input tensor to be checked. + """ + raise NotImplementedError + + +class _Dependent(Constraint): + """ + Placeholder for variables whose support depends on other variables. + """ + + def check(self, value): + raise ValueError('Cannot validate dependent constraint') + + +def is_dependent(constraint): + return isinstance(constraint, _Dependent) + + +class _DependentProperty(property, _Dependent): + """ + Decorator that extends @property to act like a `_Dependent` constraint when + called on a class and act like a property when called on an object. + Example:: + class Uniform(Distribution): + def __init__(self, low, high): + self.low = low + self.high = high + @constraint.dependent_property + def support(self): + return constraint.Interval(self.low, self.high) + """ + pass # pylint: disable=unnecessary-pass + + +class Real(Constraint): + """ + Constrain to be a real number. (exclude `np.nan`) + """ + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be a real tensor".format( + value) + # False when value has NANs + condition = (value == value) # pylint: disable=comparison-with-itself + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class Boolean(Constraint): + """ + Constrain to `{0, 1}`. 
+ """ + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be either 0 or 1.".format( + value) + condition = (value == 0) | (value == 1) + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class Interval(Constraint): + """ + Constrain to a real interval `[lower_bound, upper_bound]` + """ + + def __init__(self, lower_bound, upper_bound): + super(Interval, self).__init__() + self._lower_bound = lower_bound + self._upper_bound = upper_bound + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be >= {} and <= {}.".format( + value, self._lower_bound, self._upper_bound) + condition = (value >= self._lower_bound) & (value <= self._upper_bound) + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class OpenInterval(Constraint): + """ + Constrain to a real interval `(lower_bound, upper_bound)` + """ + + def __init__(self, lower_bound, upper_bound): + super(OpenInterval, self).__init__() + self._lower_bound = lower_bound + self._upper_bound = upper_bound + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be > {} and < {}.".format( + value, self._lower_bound, self._upper_bound) + condition = (value > self._lower_bound) & (value < self._upper_bound) + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class HalfOpenInterval(Constraint): + """ + Constrain to a real interval `[lower_bound, upper_bound)` + """ + + def __init__(self, lower_bound, upper_bound): + super(HalfOpenInterval, self).__init__() + self._lower_bound = lower_bound + self._upper_bound = upper_bound + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be >= {} and < {}.".format( + value, self._lower_bound, self._upper_bound) + condition = (value >= self._lower_bound) & (value < self._upper_bound) + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class IntegerInterval(Constraint): + """ + Constrain to an integer interval `[lower_bound, upper_bound]` + """ + + def __init__(self, lower_bound, upper_bound): + super(IntegerInterval, self).__init__() + self._lower_bound = lower_bound + self._upper_bound = upper_bound + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be integer and be >= {} and <= {}.".format( + value, self._lower_bound, self._upper_bound) + condition = value % 1 == 0 + condition = condition & (value >= self._lower_bound) & ( + value <= self._upper_bound) + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class IntegerOpenInterval(Constraint): + """ + Constrain to an integer interval `(lower_bound, upper_bound)` + """ + + def __init__(self, lower_bound, upper_bound): + super(IntegerOpenInterval, self).__init__() + self._lower_bound = lower_bound + self._upper_bound = upper_bound + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be integer and be > {} and < {}.".format( + value, self._lower_bound, self._upper_bound) + condition = value % 1 == 0 + condition = condition & (value > self._lower_bound) & ( + value < self._upper_bound) + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class IntegerHalfOpenInterval(Constraint): + """ + Constrain to an integer interval `[lower_bound, upper_bound)` + """ + + def __init__(self, lower_bound, upper_bound): + super(IntegerHalfOpenInterval, self).__init__() + self._lower_bound = lower_bound + 
self._upper_bound = upper_bound + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be integer and be >= {} and < {}.".format( + value, self._lower_bound, self._upper_bound) + condition = value % 1 == 0 + condition = condition & (value >= self._lower_bound) & ( + value < self._upper_bound) + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class GreaterThan(Constraint): + """ + Constrain to be greater than `lower_bound`. + """ + + def __init__(self, lower_bound): + super(GreaterThan, self).__init__() + self._lower_bound = lower_bound + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be greater than {}".format( + value, self._lower_bound) + condition = value > self._lower_bound + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class UnitInterval(Interval): + """ + Constrain to an unit interval `[0, 1]` + """ + + def __init__(self): + super(UnitInterval, self).__init__(0, 1) + + +class GreaterThanEq(Constraint): + """ + Constrain to be greater than or equal to `lower_bound`. + """ + + def __init__(self, lower_bound): + super(GreaterThanEq, self).__init__() + self._lower_bound = lower_bound + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be greater than or equal to {}".format( + value, self._lower_bound) + condition = value >= self._lower_bound + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class LessThan(Constraint): + """ + Constrain to be less than `upper_bound`. + """ + + def __init__(self, upper_bound): + super(LessThan, self).__init__() + self._upper_bound = upper_bound + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be less than {}".format( + value, self._upper_bound) + condition = value < self._upper_bound + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class LessThanEq(Constraint): + """ + Constrain to be less than `upper_bound`. + """ + + def __init__(self, upper_bound): + super(LessThanEq, self).__init__() + self._upper_bound = upper_bound + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be less than or equal to {}".format( + value, self._upper_bound) + condition = value <= self._upper_bound + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class IntegerGreaterThan(Constraint): + """ + Constrain to be integer and be greater than `lower_bound`. + """ + + def __init__(self, lower_bound): + super(IntegerGreaterThan, self).__init__() + self._lower_bound = lower_bound + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be integer and be greater than {}".format( + value, self._lower_bound) + condition = value % 1 == 0 + condition = F.np.bitwise_and(condition, value > self._lower_bound) + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class IntegerGreaterThanEq(Constraint): + """ + Constrain to be integer and be greater than or equal to `lower_bound`. 
+ """ + + def __init__(self, lower_bound): + super(IntegerGreaterThanEq, self).__init__() + self._lower_bound = lower_bound + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be integer and" \ + " be greater than or equal to {}".format( + value, self._lower_bound) + condition = value % 1 == 0 + condition = F.np.bitwise_and(condition, value >= self._lower_bound) + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class IntegerLessThan(Constraint): + """ + Constrain to be integer and be less than `upper_bound`. + """ + + def __init__(self, upper_bound): + super(IntegerLessThan, self).__init__() + self._upper_bound = upper_bound + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be integer and be less than {}".format( + value, self._upper_bound) + condition = value % 1 == 0 + condition = condition & (value < self._upper_bound) + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class IntegerLessThanEq(Constraint): + """ + Constrain to be integer and be less than or equal to `upper_bound`. + """ + + def __init__(self, upper_bound): + super(IntegerLessThanEq, self).__init__() + self._upper_bound = upper_bound + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be integer and" \ + " be less than or equal to {}".format( + value, self._upper_bound) + condition = value % 1 == 0 + condition = condition & (value <= self._upper_bound) + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class Positive(GreaterThan): + """ + Constrain to be greater than zero. + """ + + def __init__(self): + super(Positive, self).__init__(0) + + +class NonNegative(GreaterThanEq): + """ + Constrain to be greater than or equal to zero. + """ + + def __init__(self): + super(NonNegative, self).__init__(0) + + +class PositiveInteger(IntegerGreaterThan): + """ + Constrain to be positive integer. + """ + + def __init__(self): + super(PositiveInteger, self).__init__(0) + + +class NonNegativeInteger(IntegerGreaterThanEq): + """ + Constrain to be non-negative integer. + """ + + def __init__(self): + super(NonNegativeInteger, self).__init__(0) + + +class Simplex(Constraint): + """ + Constraint to the simplex that rightmost dimension lies on a simplex. + `x >= 0` and `x.sum(-1) == 1`. + """ + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be >= 0 and" \ + " its rightmost dimension should sum up to 1".format(value) + condition = F.np.all(value >= 0, axis=-1) + condition = condition & (F.np.abs(value.sum(-1) - 1) < 1e-6) + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class LowerTriangular(Constraint): + """ + Constraint to square lower triangular matrices. + """ + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be" \ + " square lower triangular matrices".format(value) + condition = F.np.tril(value) == value + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class LowerCholesky(Constraint): + """ + Constraint to square lower triangular matrices with real and positive diagonal entries. 
+ """ + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be" \ + " square lower triangular matrices" \ + " with real and positive diagonal entries".format(value) + condition = F.np.all(F.np.tril(value) == value, axis=-1) + condition = condition & (F.np.diagonal(value, axis1=-2, axis2=-1) > 0) + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class PositiveDefinite(Constraint): + """ + Constraint to positive-definite matrices. + """ + + def check(self, value): + F = getF(value) + err_msg = "Constraint violated: {} should be" \ + " positive definite matrices".format(value) + eps = 1e-5 + condition = F.np.all( + F.np.abs(value - F.np.swapaxes(value, -1, -2)) < eps, axis=-1) + condition = condition & (F.np.linalg.eigvals(value) > 0) + _value = constraint_check(F)(condition, err_msg) * value + return _value + + +class Cat(Constraint): + """ + Constraint functor that applies a sequence of constraints + `constraint_seq` at the submatrices at `axis`, each of size `lengths[axis]`, + in compatible with :func:`np.concatenate`. + """ + + def __init__(self, constraint_seq, axis=0, lengths=None): + assert all(isinstance(c, Constraint) for c in constraint_seq) + self._constraint_seq = list(constraint_seq) + if lengths is None: + lengths = [1] * len(self._constraint_seq) + self._lengths = list(lengths) + assert len(self._lengths) == len(self._constraint_seq),\ + "The number of lengths {} should be equal to number" \ + " of constraints {}".format( + len(self._lengths), len(self._constraint_seq)) + self._axis = axis + + def check(self, value): + F = getF(value) + _values = [] + start = 0 + for length in self._lengths: + v = F.np.take(value, indices=F.np.arange( + start, start + length), axis=self._axis) + _values.append(v) + start = start + length + _value = F.np.concatenate(_values, self._axis) + return _value + + +class Stack(Constraint): + """ + Constraint functor that applies a sequence of constraints + `constraint_seq` at the submatrices at `axis`, + in compatible with :func:`np.stack`. + + Stack is currently only supported in imperative mode. + """ + + def __init__(self, constraint_seq, axis=0): + assert all(isinstance(c, Constraint) for c in constraint_seq) + self._constraint_seq = list(constraint_seq) + self._axis = axis + + def check(self, value): + F = getF(value) + assert F is nd, "mxnet.probability.distributions.constraint.Stack" \ + " is only supported when hybridization is turned off" + size = value.shape[self._axis] + value_array = F.np.split(value, size, axis=self._axis) + value_array = [constraint.check(F.np.squeeze(v)) for v, constraint + in zip(value_array, self._constraint_seq)] + _value = F.np.stack(value_array, self._axis) + return _value + + +dependent_property = _DependentProperty diff --git a/python/mxnet/gluon/probability/distributions/dirichlet.py b/python/mxnet/gluon/probability/distributions/dirichlet.py new file mode 100644 index 000000000000..205b5bb9e9e5 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/dirichlet.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Dirichlet Distribution.""" +__all__ = ['Dirichlet'] + +from .exp_family import ExponentialFamily +from .constraint import Positive, Simplex +from .utils import getF, gammaln, digamma, sample_n_shape_converter, _clip_float_eps + + +class Dirichlet(ExponentialFamily): + r"""Create a Dirichlet distribution object. + + Parameters + ---------- + alpha : Tensor or scalar + Shape parameter of the distribution + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + """ + # pylint: disable=abstract-method + + has_grad = False + support = Simplex() + arg_constraints = {'alpha': Positive()} + + def __init__(self, alpha, F=None, validate_args=None): + _F = F if F is not None else getF(alpha) + self.alpha = alpha + super(Dirichlet, self).__init__( + F=_F, event_dim=1, validate_args=validate_args) + + def sample(self, size=None): + F = self.F + if size is None: + size = () + alpha = self.alpha + else: + if isinstance(size, int): + alpha = F.np.broadcast_to(self.alpha, (size,) + (-2,)) + else: + alpha = F.np.broadcast_to(self.alpha, size + (-2,)) + gamma_samples = F.np.random.gamma(alpha, 1) + s = gamma_samples.sum(-1, keepdims=True) + return _clip_float_eps(gamma_samples / s, F) + + def sample_n(self, size=None): + F = self.F + alpha = self.alpha + if size is None: + return self.sample() + gamma_samples = F.np.random.gamma( + alpha, 1, sample_n_shape_converter(size)) + s = gamma_samples.sum(-1, keepdims=True) + return _clip_float_eps(gamma_samples / s, F) + + def log_prob(self, value): + if self._validate_args: + self._validate_samples(value) + F = self.F + lgamma = gammaln(F) + log = F.np.log + alpha = self.alpha + return (log(value) * (alpha - 1.0)).sum(-1) +\ + lgamma(alpha.sum(-1)) - lgamma(alpha).sum(-1) + + @property + def mean(self): + alpha = self.alpha + return alpha / alpha.sum(-1, keepdims=True) + + @property + def variance(self): + a = self.alpha + s = a.sum(-1, keepdims=True) + return a * (s - a) / ((s + 1) * s ** 2) + + def entropy(self): + F = self.F + lgamma = gammaln(F) + dgamma = digamma(F) + a0 = self.alpha.sum(-1) + log_B_alpha = lgamma(self.alpha).sum(-1) - lgamma(a0) + return (log_B_alpha + (self.alpha - 1).sum(-1) * dgamma(a0) - + ((self.alpha - 1) * dgamma(self.alpha)).sum(-1)) diff --git a/python/mxnet/gluon/probability/distributions/distribution.py b/python/mxnet/gluon/probability/distributions/distribution.py new file mode 100644 index 000000000000..36002485ff1d --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/distribution.py @@ -0,0 +1,198 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Base distribution class.""" +__all__ = ['Distribution'] + +from numbers import Number +from .utils import cached_property + + +class Distribution(object): + r"""Base class for distribution. + + Parameters + ---------- + F : mx.ndarray or mx.symbol.numpy._Symbol + Variable that stores the running mode. + event_dim : int, default None + Variable indicating the dimension of the distribution's support. + validate_args : bool, default None + Whether to validate the distribution parameters + """ + + # Variable indicating whether the sampling method has + # pathwise gradient. + has_grad = False + support = None + has_enumerate_support = False + arg_constraints = {} + _validate_args = False + + @staticmethod + def set_default_validate_args(value): + if value not in [True, False]: + raise ValueError + Distribution._validate_args = value + + def __init__(self, F=None, event_dim=None, validate_args=None): + self.F = F + self.event_dim = event_dim + if validate_args is not None: + self._validate_args = validate_args + if self._validate_args: + for param, constraint in self.arg_constraints.items(): + if param not in self.__dict__ and isinstance(getattr(type(self), param), + cached_property): + # skip param that is decorated by cached_property + continue + setattr(self, param, constraint.check(getattr(self, param))) + super(Distribution, self).__init__() + + def log_prob(self, value): + r""" + Returns the log of the probability density/mass function evaluated at `value`. + """ + raise NotImplementedError() + + def pdf(self, value): + r""" + Returns the probability density/mass function evaluated at `value`. + """ + return self.F.np.exp(self.log_prob(value)) + + def cdf(self, value): + r""" + Returns the cumulative density/mass function evaluated at `value`. + """ + raise NotImplementedError + + def icdf(self, value): + r""" + Returns the inverse cumulative density/mass function evaluated at `value`. + """ + raise NotImplementedError + + def sample(self, size=None): + r""" + Generates a `shape` shaped sample. + """ + raise NotImplementedError + + def sample_n(self, size): + r""" + Generate samples of (n + parameter_shape) from the distribution. + """ + raise NotImplementedError + + def broadcast_to(self, batch_shape): + r""" + Returns a new distribution instance with parameters expanded + to `batch_shape`. This method calls `numpy.broadcast_to` on + the parameters. + + Parameters + ---------- + batch_shape : Tuple + The batch shape of the desired distribution. + + """ + raise NotImplementedError + + def enumerate_support(self): + r""" + Returns a tensor that contains all values supported + by a discrete distribution. + """ + raise NotImplementedError + + @property + def arg_constraints(self): + """ + Returns a dictionary from parameter names to + :class:`~mxnet.gluon.probability.distributions.constraint.Constraint` objects that + should be satisfied by each parameter of this distribution. Args that + are not ndarray/symbol need not appear in this dict. 
+ """ + # pylint: disable=function-redefined + raise NotImplementedError + + @property + def mean(self): + r""" + Returns the mean of the distribution. + """ + raise NotImplementedError + + @property + def variance(self): + r""" + Returns the variance of the distribution. + """ + raise NotImplementedError + + @property + def stddev(self): + """ + Returns the standard deviation of the distribution. + """ + return self.variance.sqrt() + + @property + def support(self): + r""" + Returns a function representing the distribution's support. + """ + # pylint: disable=function-redefined + raise NotImplementedError + + def entropy(self): + r""" + Returns entropy of distribution. + """ + raise NotImplementedError + + def perplexity(self): + r""" + Returns perplexity of distribution. + """ + F = self.F + return F.np.exp(self.entropy()) + + def __repr__(self): + mode = self.F + args_string = '' + if 'symbol' not in mode.__name__: + for k, _ in self.arg_constraints.items(): + v = self.__dict__[k] + if isinstance(v, Number): + shape_v = () + else: + shape_v = v.shape + args_string += '{}: size {}'.format(k, shape_v) + ', ' + args_string += ', '.join(['F: {}'.format(mode.__name__), + 'event_dim: {}'.format(self.event_dim)]) + return self.__class__.__name__ + '(' + args_string + ')' + + def _validate_samples(self, value): + """ + Validate samples for methods like `log_prob`, `cdf`. + Check if `value` lies in `self.support` + """ + return self.support.check(value) diff --git a/python/mxnet/gluon/probability/distributions/divergence.py b/python/mxnet/gluon/probability/distributions/divergence.py new file mode 100644 index 000000000000..f58c578edd2f --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/divergence.py @@ -0,0 +1,382 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""KL divergence functions.""" +__all__ = ['register_kl', 'kl_divergence', 'empirical_kl'] + +import math +import numpy as _np + +from .utils import gammaln, digamma +from .exponential import Exponential +from .pareto import Pareto +from .uniform import Uniform +from .normal import Normal +from .laplace import Laplace +from .cauchy import Cauchy +from .poisson import Poisson +from .geometric import Geometric +from .gamma import Gamma +from .dirichlet import Dirichlet +from .beta import Beta +from .half_normal import HalfNormal +from .bernoulli import Bernoulli +from .binomial import Binomial +from .gumbel import Gumbel +from .categorical import Categorical +from .one_hot_categorical import OneHotCategorical +from .multivariate_normal import MultivariateNormal + + +def empirical_kl(p, q, n_samples=1): + r"""Estimate KL(p||q) through monte-carlo estimation, i.e. 
approximate + KL(p||q) with: + + 1/M * \Sum_{i=1}^{M} log(p(x_i) / q(x_i)), x_i ~ p(x) + + Parameters + ---------- + p : Distribution + q : Distribution + n_samples : int, optional + Number of monte-carlo samples, by default 1 + """ + samples = p.sample_n(n_samples) + return (p.log_prob(samples) - q.log_prob(samples)).mean(0) + + +def register_kl(typeP, typeQ): + """Decorator for registering custom implementation of kl divergence between + distribution `typeP` and `typeQ` + + Returns + ------- function + """ + func_name = "_kl_" + str(typeP.__name__) \ + + "_" + str(typeQ.__name__) + + def decorator(func): + func_arg_num = func.__code__.co_argcount + if (func_arg_num != 2): + raise TypeError('Expect kl_divergence implementation ' + + 'to have exactly two arguments, but got {}'.format(func_arg_num)) + if not hasattr(_KL_storage, func_name): + setattr(_KL_storage, func_name, func) + else: + # Behavior TBD. + print("Error: Duplicate definition") + return func + return decorator + + +def kl_divergence(p, q): + r""" + Return the kl divergence between p and q, + this method will automatically dispatch + to the corresponding function based on q's type. + + Parameters + ---------- + p : Distribution + lhs distribution. + q : Distribution + rhs distribution. + + Returns + ------- + Tensor + KL(p||q) + """ + func = _dispatch_kl(p.__class__.__name__, q.__class__.__name__) + return func(p, q) # pylint: disable=not-callable + + +def _dispatch_kl(type_p, type_q): + r"""KL divergence methods should be registered + with distribution name, + i.e. the implementation of KL(P(\theta)||Q(\theta)) + should be named after _kl_{P}_{Q} + + Parameters + ---------- + type_q : Typename of a distribution + type_q : Typename of a distribution + + + Returns + ------- + Get a class method with function name. + """ + func_name = "_kl_" + str(type_p) + "_" + str(type_q) + func_impl = getattr(_KL_storage, func_name, None) + if (not callable(func_impl)): + raise NotImplementedError( + "KL divergence between {} and {} is not implemented.".format(type_p, type_q)) + return func_impl + + +class _KL_storage(): + r"""Class for storing the definition of kl divergence + between distributions. 
+ All the class methods should be static + """ + + @staticmethod + def _kl_Normal_Normal(p, q): + F = p.F + var_ratio = (p.scale / q.scale) ** 2 + t1 = ((p.loc - q.loc) / q.scale) ** 2 + return 0.5 * (var_ratio + t1 - 1 - F.np.log(var_ratio)) + + +@register_kl(Bernoulli, Bernoulli) +def _kl_bernoulli_bernoulli(p, q): + F = p.F + log_fn = F.np.log + prob_p = p.prob + prob_q = q.prob + t1 = prob_p * log_fn(prob_p / prob_q) + t2 = (1 - prob_p) * log_fn((1 - prob_p) / (1 - prob_q)) + return t1 + t2 + + +@register_kl(Categorical, Categorical) +def _kl_categorical_categorical(p, q): + return (p.prob * (p.logit - q.logit)).sum(-1) + + +@register_kl(OneHotCategorical, OneHotCategorical) +def _kl_onehotcategorical_onehotcategorical(p, q): + return _kl_categorical_categorical(p._categorical, q._categorical) + + +@register_kl(Uniform, Uniform) +def _kl_uniform_uniform(p, q): + F = p.F + result = F.np.log((q.high - q.low) / (p.high - p.low)) + result = F.np.where((q.low > p.low) | (q.high < p.high), _np.inf, result) + return result + + +@register_kl(Cauchy, Cauchy) +def _kl_cauchy_cauchy(p, q): + F = p.F + t1 = F.np.log((p.scale + q.scale) ** 2 + (p.loc - q.loc) ** 2) + t2 = F.np.log(4 * p.scale * q.scale) + return t1 - t2 + + +@register_kl(Laplace, Laplace) +def _kl_laplace_laplace(p, q): + F = p.F + scale_ratio = p.scale / q.scale + loc_abs_diff = F.np.abs(p.loc - q.loc) + t1 = -F.np.log(scale_ratio) + t2 = loc_abs_diff / q.scale + t3 = scale_ratio * F.np.exp(-loc_abs_diff / p.scale) + return t1 + t2 + t3 - 1 + + +@register_kl(Poisson, Poisson) +def _kl_poisson_poisson(p, q): + F = p.F + t1 = p.rate * (F.np.log(p.rate) - F.np.log(q.rate)) + t2 = (p.rate - q.rate) + return t1 - t2 + + +@register_kl(Geometric, Geometric) +def _kl_geometric_geometric(p, q): + F = p.F + return (-p.entropy() - F.np.log1p(-q.prob) / p.prob - q.logit) + + +@register_kl(Exponential, Exponential) +def _kl_exponential_exponential(p, q): + F = p.F + scale_ratio = p.scale / q.scale + t1 = -F.np.log(scale_ratio) + return t1 + scale_ratio - 1 + + +@register_kl(Pareto, Pareto) +def _kl_pareto_pareto(p, q): + F = p.F + scale_ratio = p.scale / q.scale + alpha_ratio = q.alpha / p.alpha + t1 = q.alpha * F.np.log(scale_ratio) + t2 = -F.np.log(alpha_ratio) + result = t1 + t2 + alpha_ratio - 1 + result = F.np.where(p.support._lower_bound < + q.support._lower_bound, _np.nan, result) + return result + + +@register_kl(Gumbel, Gumbel) +def _kl_gumbel_gumbel(p, q): + F = p.F + lgamma = gammaln(F) + _euler_gamma = _np.euler_gamma + ct1 = p.scale / q.scale + ct2 = q.loc / q.scale + ct3 = p.loc / q.scale + t1 = -F.np.log(ct1) - ct2 + ct3 + t2 = ct1 * _euler_gamma + t3 = F.np.exp(ct2 + lgamma(1 + ct1) - ct3) + return t1 + t2 + t3 - (1 + _euler_gamma) + + +@register_kl(Gamma, Gamma) +def _kl_gamma_gamma(p, q): + F = p.F + lgamma = gammaln(F) + dgamma = digamma(F) + return ( + q.shape * F.np.log(q.scale / p.scale) + + lgamma(q.shape) - lgamma(p.shape) + + (p.shape - q.shape) * dgamma(p.shape) + + (p.shape * p.scale) * (1 / q.scale - 1 / p.scale) + ) + + +@register_kl(Beta, Beta) +def _kl_beta_beta(p, q): + F = p.F + lgamma = gammaln(F) + dgamma = digamma(F) + sum_params_p = p.beta + p.alpha + sum_params_q = q.beta + q.alpha + t1 = lgamma(q.alpha) + lgamma(q.beta) + lgamma(sum_params_p) + t2 = lgamma(p.alpha) + lgamma(p.beta) + lgamma(sum_params_q) + t3 = (p.beta - q.beta) * dgamma(p.beta) + t4 = (p.alpha - q.alpha) * dgamma(p.alpha) + t5 = (sum_params_q - sum_params_p) * dgamma(sum_params_p) + return t1 - t2 + t3 + t4 + t5 + +# 
http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/ + + +@register_kl(Dirichlet, Dirichlet) +def _kl_dirichlet_dirichlet(p, q): + F = p.F + lgamma = gammaln(F) + dgamma = digamma(F) + sum_p_concentration = p.alpha.sum(-1) + sum_q_concentration = q.alpha.sum(-1) + t1 = lgamma(sum_p_concentration) - lgamma(sum_q_concentration) + t2 = (lgamma(p.alpha) - lgamma(q.alpha)).sum(-1) + t3 = p.alpha - q.alpha + t4 = dgamma(p.alpha) - F.np.expand_dims(dgamma(sum_p_concentration), -1) + return t1 - t2 + (t3 * t4).sum(-1) + + +@register_kl(HalfNormal, HalfNormal) +def _kl_halfNormal_halfNormal(p, q): + F = p.F + var_ratio = (p.scale / q.scale) ** 2 + t1 = ((p.loc - q.loc) / q.scale) ** 2 + return 0.5 * (var_ratio + t1 - 1 - F.np.log(var_ratio)) + + +@register_kl(Binomial, Binomial) +def _kl_binomial_binomial(p, q): + F = p.F + kl = p.n * (p.prob * (p.logit - q.logit) + + F.np.log1p(-p.prob) - F.np.log1p(-q.prob)) + kl = F.np.where(p.n > q.n, _np.inf, kl) + return kl + + +@register_kl(MultivariateNormal, MultivariateNormal) +def _kl_mvn_mvn(p, q): + F = p.F + log_det = (lambda mvn: + F.np.log( + F.np.diagonal(mvn.scale_tril, axis1=-2, axis2=-1) + ).sum(-1) + ) + # log(det(\Sigma_1) / det(\Sigma_2)) + term1 = log_det(q) - log_det(p) + + # tr(inv(\Sigma_2) * \Sigma_1) + term2 = F.np.trace(F.np.matmul(q.precision, p.cov), axis1=-2, axis2=-1) + + # (\mu_2 - \mu_1).T * inv(\Sigma_2) * (\mu_2 - \mu_1) + diff = q.loc - p.loc + term3 = F.np.einsum( + '...i,...i->...', + diff, + # Batch matrix vector multiply + F.np.einsum('...jk,...j->...k', q.precision, diff) + ) * -0.5 + n = F.np.ones_like(diff).sum(-1) + return 0.5 * (term1 + term2 + term3 - n) + + +@register_kl(Uniform, Normal) +def _kl_uniform_normal(p, q): + F = p.F + common_term = p.high - p.low + t1 = F.np.log(math.sqrt(math.pi * 2) * q.scale / common_term) + t2 = (common_term) ** 2 / 12 + t3 = ((p.high + p.low - 2 * q.loc) / 2) ** 2 + return t1 + 0.5 * (t2 + t3) / (q.scale ** 2) + + +@register_kl(Uniform, Gumbel) +def _kl_uniform_gumbel(p, q): + F = p.F + common_term = q.scale / (p.high - p.low) + high_loc_diff = (p.high - q.loc) / q.scale + low_loc_diff = (p.low - q.loc) / q.scale + t1 = F.np.log(common_term) + 0.5 * (high_loc_diff + low_loc_diff) + t2 = common_term * (F.np.exp(-high_loc_diff) - F.np.exp(-low_loc_diff)) + return t1 - t2 + + +@register_kl(Exponential, Gumbel) +def _kl_exponential_gumbel(p, q): + F = p.F + scale_rate_prod = q.scale / p.scale + loc_scale_ratio = q.loc / q.scale + t1 = F.np.log(scale_rate_prod) - 1 + t2 = F.np.exp(loc_scale_ratio) * scale_rate_prod / (scale_rate_prod + 1) + t3 = scale_rate_prod ** -1 + return t1 - loc_scale_ratio + t2 + t3 + + +@register_kl(Exponential, Normal) +def _kl_exponential_normal(p, q): + F = p.F + var_normal = q.variance + rate_sqr = p.scale ** (-2) + t1 = 0.5 * F.np.log(rate_sqr * var_normal * 2 * _np.pi) + t2 = rate_sqr ** -1 + t3 = q.loc * p.scale + t4 = (q.loc ** 2) * 0.5 + return t1 - 1 + (t2 - t3 + t4) / var_normal + + +@register_kl(Exponential, Gamma) +def _kl_exponential_gamma(p, q): + F = p.F + lgamma = gammaln(F) + ratio = p.scale / q.scale + t1 = -q.shape * F.np.log(ratio) + return t1 + ratio + lgamma(q.shape) + q.shape * _np.euler_gamma - (1 + _np.euler_gamma) diff --git a/python/mxnet/gluon/probability/distributions/exp_family.py b/python/mxnet/gluon/probability/distributions/exp_family.py new file mode 100644 index 000000000000..8bacc64983e7 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/exp_family.py @@ -0,0 +1,68 @@ +# 
Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Exponential family class""" +__all__ = ['ExponentialFamily'] + +from .distribution import Distribution + + +class ExponentialFamily(Distribution): + r""" + ExponentialFamily inherits from Distribution. ExponentialFamily is a base + class for distributions whose density function has the form: + p_F(x;\theta) = exp( + - + F(\theta) + + k(x) + ) where + t(x): sufficient statistics + \theta: natural parameters + F(\theta): log_normalizer + k(x): carrier measure + """ + + @property + def _natural_params(self): + r""" + Return a tuple that stores natural parameters of the distribution. + """ + raise NotImplementedError + + def _log_normalizer(self, *natural_params): + r""" + Return the log_normalizer F(\theta) based the natural parameters. + """ + raise NotImplementedError + + def _mean_carrier_measure(self, x): + r""" + Return the mean of carrier measure k(x) based on input x, + this method is required for calculating the entropy. + """ + raise NotImplementedError + + def entropy(self): + r""" + Return the entropy of a distribution. + The entropy of distributions in exponential families + could be computed by: + H(P) = F(\theta) - <\theta, F(\theta)'> - E_p[k(x)] + """ + raise NotImplementedError diff --git a/python/mxnet/gluon/probability/distributions/exponential.py b/python/mxnet/gluon/probability/distributions/exponential.py new file mode 100644 index 000000000000..19ddd58ed74b --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/exponential.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Exponential Distribution.""" +__all__ = ['Exponential'] + +from .exp_family import ExponentialFamily +from .constraint import Positive +from .utils import getF, sample_n_shape_converter, cached_property + + +class Exponential(ExponentialFamily): + r"""Create a Exponential distribution object parameterized by `scale`. 
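Spelled out on one line, the density and entropy that the `ExponentialFamily` contract above refers to, with t(x) the sufficient statistics, \theta the natural parameters, F(\theta) the log-normalizer, and k(x) the carrier measure, read:

```latex
p_F(x;\theta) = \exp\bigl(\langle t(x),\,\theta\rangle - F(\theta) + k(x)\bigr),
\qquad
H(P) = F(\theta) - \langle\theta,\,\nabla_\theta F(\theta)\rangle - \mathbb{E}_p[k(x)]
```

Concrete subclasses only need to expose `_natural_params`, `_log_normalizer`, and `_mean_carrier_measure`; the entropy identity above is the reason those three hooks exist.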
+ + Parameters + ---------- + scale : Tensor or scalar + Scale of the distribution. (scale = 1 /rate) + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + """ + # pylint: disable=abstract-method + + has_grad = True + support = Positive() + arg_constraints = {'scale': Positive()} + + def __init__(self, scale=1.0, F=None, validate_args=None): + _F = F if F is not None else getF(scale) + self.scale = scale + super(Exponential, self).__init__( + F=_F, event_dim=0, validate_args=validate_args) + + @cached_property + def rate(self): + return 1 / self.scale + + @property + def mean(self): + return self.scale + + @property + def variance(self): + return self.scale ** 2 + + @property + def stddev(self): + return self.scale + + def sample(self, size=None): + return self.F.np.random.exponential(self.scale, size=size) + + def sample_n(self, size=None): + return self.F.np.random.exponential(self.scale, + size=sample_n_shape_converter(size)) + + def broadcast_to(self, batch_shape): + new_instance = self.__new__(type(self)) + F = self.F + new_instance.scale = F.np.broadcast_to(self.scale, batch_shape) + super(Exponential, new_instance).__init__(F=F, + event_dim=self.event_dim, + validate_args=False) + new_instance._validate_args = self._validate_args + return new_instance + + def log_prob(self, value): + if self._validate_args: + self._validate_samples(value) + F = self.F + return F.np.log(self.rate) - self.rate * value + + def cdf(self, value): + if self._validate_args: + self._validate_samples(value) + F = self.F + return 1 - F.np.exp(-self.rate * value) + + def icdf(self, value): + F = self.F + return - self.scale * F.np.log(1 - value) + + def entropy(self): + F = self.F + return 1.0 + F.np.log(self.scale) + + @property + def _natural_params(self): + return (-self.rate,) + + def _log_normalizer(self, x): + # pylint: disable=arguments-differ + F = self.F + return -F.np.log(-x) diff --git a/python/mxnet/gluon/probability/distributions/fishersnedecor.py b/python/mxnet/gluon/probability/distributions/fishersnedecor.py new file mode 100644 index 000000000000..f4d06a2f3e8f --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/fishersnedecor.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Snedecor's F Distribution.""" +__all__ = ['FisherSnedecor'] + +from numpy import nan +from .distribution import Distribution +from .gamma import Gamma +from .constraint import Positive +from .utils import getF, gammaln + + +class FisherSnedecor(Distribution): + r"""Create a FisherSnedecor distribution object, often known as F distribution. 
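A short sketch of the `Exponential` API just defined (the rate is derived lazily from `scale` through `cached_property`); import path, `mgp` alias, and numpy mode are the same assumptions as in the earlier sketch:

```python
from mxnet import np, npx
import mxnet.gluon.probability as mgp

npx.set_np()

d = mgp.Exponential(scale=np.array([0.5, 1.0, 2.0]))   # rate = 1 / scale
print(d.log_prob(d.sample()))                          # log(rate) - rate * x
print(d.cdf(d.icdf(np.array([0.1, 0.5, 0.9]))))        # round-trips to ~[0.1, 0.5, 0.9]
print(d.mean, d.variance)                              # scale and scale**2

# KL(p || q) between two Exponentials dispatches to the closed form
# registered in divergence.py above.
p, q = mgp.Exponential(np.array(1.0)), mgp.Exponential(np.array(2.0))
print(mgp.kl_divergence(p, q))
```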
+ + Parameters + ---------- + df1 : Tensor or scalar + degree of freedom parameter 1 + scale : Tensor or scalar + degree of freedom parameter 2 + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + support = Positive() + arg_constraints = {'df1': Positive(), 'df2': Positive()} + + def __init__(self, df1, df2, F=None, validate_args=None): + _F = F if F is not None else getF(df1, df2) + self.df1 = df1 + self.df2 = df2 + self._gamma1 = Gamma(0.5 * self.df1, 1 / self.df1) + self._gamma2 = Gamma(0.5 * self.df2, 1 / self.df2) + super(FisherSnedecor, self).__init__( + F=_F, event_dim=0, validate_args=validate_args) + + def broadcast_to(self, batch_shape): + new_instance = self.__new__(type(self)) + F = self.F + new_instance.df1 = F.np.broadcast_to(self.df1, batch_shape) + new_instance.df2 = F.np.broadcast_to(self.df2, batch_shape) + new_instance._gamma1 = self._gamma1.broadcast_to(batch_shape) + new_instance._gamma2 = self._gamma2.broadcast_to(batch_shape) + super(FisherSnedecor, new_instance).__init__(F=F, + event_dim=0, validate_args=False) + new_instance._validate_args = self._validate_args + return new_instance + + @property + def mean(self): + # mean is only defined for df2 > 2 + df2 = self.F.np.where(self.df2 <= 2, nan, self.df2) + return df2 / (df2 - 2) + + @property + def variance(self): + # variance is only define for df2 > 4 + df2 = self.F.np.where(self.df2 <= 4, nan, self.df2) + df1 = self.df1 + numerator = 2 * df2 ** 2 * (df1 + df2 - 2) + denominator = df1 * (df2 - 2) ** 2 * (df2 - 4) + return numerator / denominator + + def sample(self, size=None): + X1 = self._gamma1.sample(size) + X2 = self._gamma2.sample(size) + return X1 / X2 + + def sample_n(self, size=None): + X1 = self._gamma1.sample_n(size) + X2 = self._gamma2.sample_n(size) + return X1 / X2 + + def log_prob(self, value): + if self._validate_args: + self._validate_samples(value) + F = self.F + lgamma = gammaln(F) + log = F.np.log + ct1 = self.df1 / 2 + ct2 = self.df2 / 2 + ct3 = self.df1 / self.df2 + t1 = lgamma(ct1 + ct2) - lgamma(ct1) - \ + lgamma(ct2) # Beta(df1/2, df2/2) + t2 = log(ct3) * ct1 + (ct1 - 1) * log(value) + t3 = (ct1 + ct2) * log(ct3 * value + 1) + return t1 + t2 - t3 diff --git a/python/mxnet/gluon/probability/distributions/gamma.py b/python/mxnet/gluon/probability/distributions/gamma.py new file mode 100644 index 000000000000..348a0a51f0a4 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/gamma.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
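Back to `FisherSnedecor`: samples are produced as the ratio of the two internal `Gamma` variables, so a large batch should land near the analytic mean. A sketch under the same import assumptions:

```python
from mxnet import np, npx
import mxnet.gluon.probability as mgp

npx.set_np()

d = mgp.FisherSnedecor(df1=np.array(5.0), df2=np.array(10.0))
s = d.sample_n(10000)
print(s.mean())                    # roughly df2 / (df2 - 2) = 1.25
print(d.mean, d.variance)          # nan when df2 <= 2 (resp. df2 <= 4)
print(d.log_prob(np.array(1.0)))
```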
+ +# coding: utf-8 +# pylint: disable=wildcard-import +"""Gamma Distribution.""" +__all__ = ['Gamma'] + +from .exp_family import ExponentialFamily +from .constraint import Real, Positive +from .utils import getF, sample_n_shape_converter, gammaln, digamma + + +class Gamma(ExponentialFamily): + r"""Create a Gamma distribution object. + + Parameters + ---------- + shape : Tensor or scalar + shape parameter of the distribution, often represented by `k` or `\alpha` + scale : Tensor or scalar, default 1 + scale parameter of the distribution, often represented by `\theta`, + `\theta` = 1 / `\beta`, where `\beta` stands for the rate parameter. + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + # TODO: Implement implicit reparameterization gradient for Gamma. + has_grad = False + support = Real() + arg_constraints = {'shape': Positive(), 'scale': Positive()} + + def __init__(self, shape, scale=1.0, F=None, validate_args=None): + _F = F if F is not None else getF(shape, scale) + self.shape = shape + self.scale = scale + super(Gamma, self).__init__( + F=_F, event_dim=0, validate_args=validate_args) + + def log_prob(self, value): + if self._validate_args: + self._validate_samples(value) + F = self.F + log_fn = F.np.log + lgamma = gammaln(F) + # alpha (concentration) + a = self.shape + # beta (rate) + b = 1 / self.scale + return a * log_fn(b) + (a - 1) * log_fn(value) - b * value - lgamma(a) + + def broadcast_to(self, batch_shape): + new_instance = self.__new__(type(self)) + F = self.F + new_instance.shape = F.np.broadcast_to(self.shape, batch_shape) + new_instance.scale = F.np.broadcast_to(self.scale, batch_shape) + super(Gamma, new_instance).__init__(F=F, + event_dim=self.event_dim, + validate_args=False) + new_instance._validate_args = self._validate_args + return new_instance + + def sample(self, size=None): + return self.F.np.random.gamma(self.shape, self.scale, size) + + def sample_n(self, size=None): + return self.F.np.random.gamma(self.shape, self.scale, sample_n_shape_converter(size)) + + @property + def mean(self): + return self.shape * self.scale + + @property + def variance(self): + return self.shape * (self.scale ** 2) + + def entropy(self): + F = self.F + lgamma = gammaln(F) + dgamma = digamma(F) + return (self.shape + F.np.log(self.scale) + lgamma(self.shape) + + (1 - self.shape) * dgamma(self.shape)) + + @property + def _natural_params(self): + return (self.shape - 1, -1 / self.scale) diff --git a/python/mxnet/gluon/probability/distributions/geometric.py b/python/mxnet/gluon/probability/distributions/geometric.py new file mode 100644 index 000000000000..170edfec9912 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/geometric.py @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
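A small sketch of the `Gamma` parameterization used throughout this patch (`shape` and `scale`, with rate = 1/scale); same import assumptions as before:

```python
from mxnet import np, npx
import mxnet.gluon.probability as mgp

npx.set_np()

d = mgp.Gamma(shape=np.array(3.0), scale=np.array(2.0))
print(d.mean, d.variance)          # shape * scale = 6, shape * scale**2 = 12
print(d.log_prob(d.sample()))
print(d.entropy())

# Closed-form KL between two Gammas, as registered in divergence.py.
print(mgp.kl_divergence(d, mgp.Gamma(np.array(2.0), np.array(1.0))))
```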
See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Geometric distribution class.""" +__all__ = ['Geometric'] + +from numbers import Number +from .distribution import Distribution +from .utils import prob2logit, logit2prob, getF, cached_property, sample_n_shape_converter +from .constraint import NonNegativeInteger, Interval, Real + + +class Geometric(Distribution): + r"""Create a geometric distribution object. + + Parameters + ---------- + prob : Tensor or scalar, default None + Probability of sampling `1`. + logit : Tensor or scalar, default None + The log-odds of sampling `1`. + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + support = NonNegativeInteger() + arg_constraints = {'prob': Interval(0, 1), + 'logit': Real()} + + def __init__(self, prob=None, logit=None, F=None, validate_args=None): + _F = F if F is not None else getF(prob, logit) + if (prob is None) == (logit is None): + raise ValueError( + "Either `prob` or `logit` must be specified, but not both. " + + "Received prob={}, logit={}".format(prob, logit)) + + if prob is not None: + self.prob = prob + else: + self.logit = logit + super(Geometric, self).__init__( + F=_F, event_dim=0, validate_args=validate_args) + + @cached_property + def prob(self): + """Get the probability of sampling `1`. + + Returns + ------- + Tensor + Parameter tensor. + """ + # pylint: disable=method-hidden + return logit2prob(self.logit, True, self.F) + + @cached_property + def logit(self): + """Get the log-odds of sampling `1`. + + Returns + ------- + Tensor + Parameter tensor. + """ + # pylint: disable=method-hidden + return prob2logit(self.prob, True, self.F) + + @property + def mean(self): + return 1 / self.prob - 1 + + @property + def variance(self): + return (1 / self.prob - 1) / self.prob + + def broadcast_to(self, batch_shape): + new_instance = self.__new__(type(self)) + F = self.F + if 'prob' in self.__dict__: + new_instance.prob = F.np.broadcast_to(self.prob, batch_shape) + else: + new_instance.logit = F.np.broadcast_to(self.logit, batch_shape) + super(Geometric, new_instance).__init__(F=F, + event_dim=self.event_dim, + validate_args=False) + new_instance._validate_args = self._validate_args + return new_instance + + def log_prob(self, value): + if self._validate_args: + self._validate_samples(value) + F = self.F + prob = self.prob + return value * F.np.log1p(-prob) + F.np.log(prob) + + def sample(self, size=None): + F = self.F + if isinstance(self.prob, Number): + shape_tensor = F.np.zeros(()) + else: + shape_tensor = F.np.zeros_like(self.prob) + u = F.np.random.uniform(shape_tensor, size=size) + samples = F.np.floor( + F.np.log(u) / F.np.log1p(-self.prob) + ) + return samples + + def sample_n(self, size=None): + return self.sample(sample_n_shape_converter(size)) + + def entropy(self): + F = self.F + logit = self.logit + prob = self.prob + return -(logit * (prob - 1) - F.np.log1p(F.np.exp(-logit))) / prob diff --git a/python/mxnet/gluon/probability/distributions/gumbel.py b/python/mxnet/gluon/probability/distributions/gumbel.py new file mode 100644 index 000000000000..7094a5a0d90a --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/gumbel.py @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
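`Geometric` counts failures before the first success and is sampled by inverting the CDF (the `floor(log(u) / log1p(-prob))` expression above), so the empirical mean should sit near `1 / prob - 1`. A sketch, same assumptions:

```python
from mxnet import np, npx
import mxnet.gluon.probability as mgp

npx.set_np()

# Exactly one of `prob` / `logit` may be given; the other is derived lazily.
d = mgp.Geometric(prob=np.array(0.25))
s = d.sample_n(10000)
print(s.mean())                     # roughly 1 / prob - 1 = 3
print(d.log_prob(np.array(2.0)))    # 2 * log(1 - prob) + log(prob)
print(d.variance, d.entropy())
```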
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Gumbel Distribution.""" +__all__ = ['Gumbel'] + +import math +from numpy import euler_gamma # Euler-Mascheroni constant +from .distribution import Distribution +from .constraint import Real, Positive +from .utils import getF, sample_n_shape_converter + + +class Gumbel(Distribution): + r"""Create a Gumble distribution object + + Parameters + ---------- + loc : Tensor or scalar, default 0 + Location parameter of the distribution. + scale : Tensor or scalar, default 1 + Scale parameter of the distribution + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + has_grad = True + support = Real() + arg_constraints = {'loc': Real(), + 'scale': Positive()} + + def __init__(self, loc, scale=1, F=None, validate_args=None): + _F = F if F is not None else getF(loc, scale) + self.loc = loc + self.scale = scale + super(Gumbel, self).__init__( + F=_F, event_dim=0, validate_args=validate_args) + + def log_prob(self, value): + if self._validate_args: + self._validate_samples(value) + F = self.F + # Standardized sample + y = (self.loc - value) / self.scale + return (y - F.np.exp(y)) - F.np.log(self.scale) + + def broadcast_to(self, batch_shape): + new_instance = self.__new__(type(self)) + F = self.F + new_instance.loc = F.np.broadcast_to(self.loc, batch_shape) + new_instance.scale = F.np.broadcast_to(self.scale, batch_shape) + super(Gumbel, new_instance).__init__(F=F, + event_dim=self.event_dim, + validate_args=False) + new_instance._validate_args = self._validate_args + return new_instance + + def cdf(self, value): + if self._validate_args: + self._validate_samples(value) + F = self.F + y = (value - self.loc) / self.scale + exp_fn = F.np.exp + return exp_fn(-exp_fn(-y)) + + def icdf(self, value): + F = self.F + log_fn = F.np.log + return self.loc + self.scale * (-log_fn(-log_fn(value))) + + def sample(self, size=None): + return self.F.np.random.gumbel(self.loc, self.scale, size) + + def sample_n(self, size=None): + return self.F.np.random.gumbel(self.loc, self.scale, sample_n_shape_converter(size)) + + @property + def mean(self): + return self.loc + self.scale * euler_gamma + + @property + def stddev(self): + return (math.pi / math.sqrt(6)) * self.scale + + @property + def variance(self): + return self.stddev ** 2 + + def entropy(self): + F = self.F + return F.np.log(self.scale) + (1 + euler_gamma) diff --git a/python/mxnet/gluon/probability/distributions/half_cauchy.py b/python/mxnet/gluon/probability/distributions/half_cauchy.py new file mode 100644 index 000000000000..a39236b81784 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/half_cauchy.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) 
under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Half-cauchy Distribution""" +__all__ = ["HalfCauchy"] + +import math +from numpy import inf +from .transformed_distribution import TransformedDistribution +from ..transformation import AbsTransform +from .cauchy import Cauchy +from .constraint import Positive + + +class HalfCauchy(TransformedDistribution): + r"""Create a half cauchy object, where + X ~ Cauchy(0, scale) + Y = |X| ~ HalfCauchy(scale) + + Parameters + ---------- + scale : Tensor or scalar, default 1 + Scale of the full Cauchy distribution. + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + has_grad = True + support = Positive() + arg_constraints = {'scale': Positive()} + + def __init__(self, scale=1.0, F=None, validate_args=None): + base_dist = Cauchy(0, scale, F) + self.scale = scale + super(HalfCauchy, self).__init__( + base_dist, AbsTransform(), validate_args=validate_args) + + def log_prob(self, value): + if self._validate_args: + self._validate_samples(value) + log_prob = self._base_dist.log_prob(value) + math.log(2) + log_prob = self.F.np.where(value < 0, -inf, log_prob) + return log_prob + + def cdf(self, value): + if self._validate_args: + self._validate_samples(value) + return 2 * self._base_dist.cdf(value) - 1 + + def icdf(self, value): + return self._base_dist.icdf((value + 1) / 2) + + def entropy(self): + return self._base_dist.entropy() - math.log(2) + + @property + def mean(self): + return self.scale * math.sqrt(2 / math.pi) + + @property + def variance(self): + pow_fn = self.F.np.power + return pow_fn(self.scale, 2) * (1 - 2 / math.pi) diff --git a/python/mxnet/gluon/probability/distributions/half_normal.py b/python/mxnet/gluon/probability/distributions/half_normal.py new file mode 100644 index 000000000000..7e93b7b5837d --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/half_normal.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
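`HalfCauchy` illustrates the folding pattern used for the half distributions: take `Cauchy(0, scale)`, push it through `AbsTransform`, and correct densities and CDFs by a factor of 2. A sketch of the resulting surface, same assumptions:

```python
from mxnet import np, npx
import mxnet.gluon.probability as mgp

npx.set_np()

d = mgp.HalfCauchy(scale=np.array(1.0))
print(d.log_prob(np.array(0.5)))     # Cauchy(0, 1).log_prob(0.5) + log(2)
print(d.log_prob(np.array(-1.0)))    # -inf: negative values are outside the support
print(d.cdf(np.array(0.5)))          # 2 * Cauchy(0, 1).cdf(0.5) - 1
print(d.icdf(d.cdf(np.array(0.5))))  # ~0.5
```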
See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Half-normal Distribution""" +__all__ = ["HalfNormal"] + +import math +from numpy import inf +from .transformed_distribution import TransformedDistribution +from ..transformation import AbsTransform +from .normal import Normal +from .constraint import Positive + + +class HalfNormal(TransformedDistribution): + r"""Create a half normal object, where + X ~ Normal(0, scale) + Y = |X| ~ HalfNormal(scale) + + Parameters + ---------- + scale : Tensor or scalar, default 1 + Scale of the full Normal distribution. + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + has_grad = True + support = Positive() + arg_constraints = {'scale': Positive()} + + def __init__(self, scale=1.0, F=None, validate_args=None): + base_dist = Normal(0, scale, F) + self.scale = scale + super(HalfNormal, self).__init__( + base_dist, AbsTransform(), validate_args=validate_args) + + def log_prob(self, value): + if self._validate_args: + self._validate_samples(value) + log_prob = self._base_dist.log_prob(value) + math.log(2) + log_prob = self.F.np.where(value < 0, -inf, log_prob) + return log_prob + + def cdf(self, value): + if self._validate_args: + self._validate_samples(value) + return 2 * self._base_dist.cdf(value) - 1 + + def icdf(self, value): + return self._base_dist.icdf((value + 1) / 2) + + @property + def loc(self): + return self._base_dist.loc + + @property + def mean(self): + return self.scale * math.sqrt(2 / math.pi) + + @property + def variance(self): + pow_fn = self.F.np.power + return pow_fn(self.scale, 2) * (1 - 2 / math.pi) diff --git a/python/mxnet/gluon/probability/distributions/independent.py b/python/mxnet/gluon/probability/distributions/independent.py new file mode 100644 index 000000000000..25c846d656cc --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/independent.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Independent class.""" +__all__ = ['Independent'] + +from .distribution import Distribution +from .constraint import dependent_property +from .utils import sum_right_most + + +class Independent(Distribution): + r""" + Reinterprets some collection of independent, non-identical distributions as + a single multivariate random variable (convert some `batch_dim` to `event_dim`). 
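The reinterpretation is easiest to see in a sketch: wrapping a batch of univariate Normals with `reinterpreted_batch_ndims=1` makes `log_prob` sum over the last axis, so the batch behaves as one diagonal multivariate variable. Same import assumptions as the earlier sketches:

```python
from mxnet import np, npx
import mxnet.gluon.probability as mgp

npx.set_np()

base = mgp.Normal(loc=np.zeros((4, 3)), scale=np.ones((4, 3)))   # batch_shape (4, 3)
joint = mgp.Independent(base, reinterpreted_batch_ndims=1)       # last batch dim becomes the event dim

x = joint.sample()                  # still shape (4, 3)
print(base.log_prob(x).shape)       # (4, 3): one log density per scalar
print(joint.log_prob(x).shape)      # (4,): summed over the reinterpreted axis
```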
+ """ + # pylint: disable=abstract-method + + arg_constraints = {} + + def __init__(self, base_distribution, reinterpreted_batch_ndims, validate_args=None): + event_dim = reinterpreted_batch_ndims + base_distribution.event_dim + self.base_dist = base_distribution + self.reinterpreted_batch_ndims = reinterpreted_batch_ndims + super(Independent, self).__init__(F=base_distribution.F, + event_dim=event_dim, + validate_args=validate_args) + + def broadcast_to(self, batch_shape): + new_instance = self.__new__(type(self)) + F = self.F + # we use -2 to copy the sizes of reinterpreted batch dimensions + reinterpreted_axes = (-2,) * self.reinterpreted_batch_ndims + new_instance.base_dist = self.base_dist.broadcast_to( + batch_shape + reinterpreted_axes) + new_instance.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims + super(Independent, new_instance).__init__(F=F, event_dim=self.event_dim, + validate_args=False) + new_instance._validate_args = self._validate_args + return new_instance + + @property + def has_enumerate_support(self): + if self.reinterpreted_batch_ndims > 0: + return False + return self.base_dist.has_enumerate_support + + @dependent_property + def support(self): + return self.base_dist.support + + @property + def mean(self): + return self.base_dist.mean + + @property + def variance(self): + return self.base_dist.variance + + def sample(self, size=None): + return self.base_dist.sample(size) + + def sample_n(self, size): + return self.base_dist.sample_n(size) + + def log_prob(self, value): + log_prob = self.base_dist.log_prob(value) + return sum_right_most(log_prob, self.reinterpreted_batch_ndims) + + def entropy(self): + entropy = self.base_dist.entropy() + return sum_right_most(entropy, self.reinterpreted_batch_ndims) + + def enumerate_support(self): + if self.reinterpreted_batch_ndims > 0: + raise NotImplementedError( + "Enumeration over cartesian product is not implemented") + return self.base_dist.enumerate_support() diff --git a/python/mxnet/gluon/probability/distributions/laplace.py b/python/mxnet/gluon/probability/distributions/laplace.py new file mode 100644 index 000000000000..1bc88e94017c --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/laplace.py @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Laplace distribution""" +__all__ = ['Laplace'] + +from .constraint import Real, Positive +from .distribution import Distribution +from .utils import getF, sample_n_shape_converter + + +class Laplace(Distribution): + r"""Create a laplace distribution object. + + Parameters + ---------- + loc : Tensor or scalar, default 0 + mean of the distribution. 
+ scale : Tensor or scalar, default 1 + scale of the distribution + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + + """ + # pylint: disable=abstract-method + + has_grad = False + support = Real() + arg_constraints = {'loc': Real(), 'scale': Positive()} + + def __init__(self, loc=0.0, scale=1.0, F=None, validate_args=None): + _F = F if F is not None else getF(loc, scale) + self.loc = loc + self.scale = scale + super(Laplace, self).__init__( + F=_F, event_dim=0, validate_args=validate_args) + + def log_prob(self, value): + """Compute the log likelihood of `value`. + + Parameters + ---------- + value : Tensor + Input data. + + Returns + ------- + Tensor + Log likelihood of the input. + """ + if self._validate_args: + self._validate_samples(value) + F = self.F + return -F.np.log(2 * self.scale) - F.np.abs(value - self.loc) / self.scale + + def sample(self, size=None): + r"""Generate samples of `size` from the normal distribution + parameterized by `self._loc` and `self._scale` + + Parameters + ---------- + size : Tuple, Scalar, or None + Size of samples to be generated. If size=None, the output shape + will be `broadcast(loc, scale).shape` + + Returns + ------- + Tensor + Samples from Normal distribution. + """ + return self.F.np.random.laplace(self.loc, self.scale, size) + + def sample_n(self, size=None): + r"""Generate samples of (batch_size + broadcast(loc, scale).shape) + from the normal distribution parameterized by `self._loc` and `self._scale` + + Parameters + ---------- + size : Tuple, Scalar, or None + Size of independent batch to be generated from the distribution. + + Returns + ------- + Tensor + Samples from Normal distribution. + """ + return self.F.np.random.laplace(self.loc, self.scale, sample_n_shape_converter(size)) + + def broadcast_to(self, batch_shape): + new_instance = self.__new__(type(self)) + F = self.F + new_instance.loc = F.np.broadcast_to(self.loc, batch_shape) + new_instance.scale = F.np.broadcast_to(self.scale, batch_shape) + super(Laplace, new_instance).__init__(F=F, + event_dim=self.event_dim, + validate_args=False) + new_instance._validate_args = self._validate_args + return new_instance + + def cdf(self, value): + if self._validate_args: + self._validate_samples(value) + F = self.F + value = value - self.loc + return 0.5 - 0.5 * F.np.sign(value) * F.np.expm1(-F.np.abs(value) / self.scale) + + def icdf(self, value): + F = self.F + value = value - 0.5 + return self.loc - self.scale * F.np.sign(value) * F.np.log1p(-2 * F.np.abs(value)) + + @property + def mean(self): + return self.loc + + @property + def stddev(self): + return (2 ** 0.5) * self.scale + + @property + def variance(self): + return 2 * (self.scale ** 2) + + def entropy(self): + F = self.F + return 1 + F.np.log(2 * self.scale) diff --git a/python/mxnet/gluon/probability/distributions/multinomial.py b/python/mxnet/gluon/probability/distributions/multinomial.py new file mode 100644 index 000000000000..875125eb3e02 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/multinomial.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
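A quick check of the `Laplace` surface; the `cdf`/`icdf` pair uses the sign/expm1 form above, so the round trip below should recover the quantiles. Same assumptions:

```python
from mxnet import np, npx
import mxnet.gluon.probability as mgp

npx.set_np()

d = mgp.Laplace(loc=np.array(0.0), scale=np.array(2.0))
q = np.array([0.1, 0.5, 0.9])
print(d.cdf(d.icdf(q)))             # ~[0.1, 0.5, 0.9]
print(d.mean, d.variance)           # 0 and 2 * scale**2 = 8
print(d.entropy())                  # 1 + log(2 * scale)
```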
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Multinomial Distribution""" +__all__ = ['Multinomial'] + +from numbers import Number +from .distribution import Distribution +from .one_hot_categorical import OneHotCategorical +from .utils import getF, cached_property, logit2prob, prob2logit, gammaln +from .constraint import Simplex, Real, IntegerInterval + + +class Multinomial(Distribution): + r"""Create a multinomial distribution object. + + Parameters + ---------- + num_events : int + number of events. + prob : Tensor + probability of each event. + logit : Tensor + unnormalized probability of each event. + total_count : int + number of trials. + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + arg_constraints = {'prob': Simplex(), 'logit': Real()} + + def __init__(self, num_events, + prob=None, logit=None, total_count=1, F=None, validate_args=None): + _F = F if F is not None else getF(prob, logit) + if not isinstance(total_count, Number): + raise ValueError("Expect `total_conut` to be scalar value") + self.total_count = total_count + if (prob is None) == (logit is None): + raise ValueError( + "Either `prob` or `logit` must be specified, but not both. " + + "Received prob={}, logit={}".format(prob, logit)) + if prob is not None: + self.prob = prob + else: + self.logit = logit + self._categorical = OneHotCategorical( + num_events, prob, logit, F, validate_args) + super(Multinomial, self).__init__( + F=_F, event_dim=1, validate_args=validate_args) + + @property + def mean(self): + return self.prob * self.total_count + + @property + def variance(self): + return self.total_count * self.prob * (1 - self.prob) + + @cached_property + def prob(self): + # pylint: disable=method-hidden + return logit2prob(self.logit, False, self.F) + + @cached_property + def logit(self): + # pylint: disable=method-hidden + return prob2logit(self.prob, False, self.F) + + @property + def support(self): + return IntegerInterval(0, self.total_count) + + def sample(self, size=None): + if size is not None: + categorical = self._categorical.broadcast_to(size) + else: + categorical = self._categorical + return categorical.sample_n(self.total_count).sum(0) + + def sample_n(self, size=None): + if isinstance(size, Number): + size = (size,) + size = () if size is None else size + return self._categorical.sample_n((self.total_count,) + size).sum(0) + + def log_prob(self, value): + if self._validate_args: + self._validate_samples(value) + F = self.F + lgamma = gammaln(F) + log_factorial_n = lgamma(value.sum(-1) + 1) + log_factorial_x = lgamma(value + 1).sum(-1) + log_power = (self.logit * value).sum(-1) + return log_factorial_n - log_factorial_x + log_power + + def broadcast_to(self, batch_shape): + new_instance = self.__new__(type(self)) + F = self.F + new_instance._categorical = self._categorical.broadcast_to(batch_shape) + new_instance.num_events = self.num_events + new_instance.total_conut = self.total_count + super(Multinomial, new_instance).__init__(F=F, + 
event_dim=self.event_dim, + validate_args=False) + new_instance._validate_args = self._validate_args + return new_instance diff --git a/python/mxnet/gluon/probability/distributions/multivariate_normal.py b/python/mxnet/gluon/probability/distributions/multivariate_normal.py new file mode 100644 index 000000000000..1eaa41449261 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/multivariate_normal.py @@ -0,0 +1,174 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Multivariate Normal Distribution""" +__all__ = ['MultivariateNormal'] + +import math +from .distribution import Distribution +from .constraint import Real, PositiveDefinite, LowerCholesky +from .utils import getF, cached_property + + +class MultivariateNormal(Distribution): + r"""Create a multivaraite Normal distribution object. + + Parameters + ---------- + loc : Tensor + mean of the distribution. + cov : Tensor + covariance matrix of the distribution + precision : Tensor + precision matrix of the distribution + scale_tril : Tensor + lower-triangular factor of the covariance + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. 
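A short sketch of `Multinomial` as defined above: `total_count` trials over `num_events` categories, with samples obtained by summing `total_count` one-hot draws from the internal `OneHotCategorical`. Same import assumptions:

```python
from mxnet import np, npx
import mxnet.gluon.probability as mgp

npx.set_np()

prob = np.array([0.2, 0.3, 0.5])
d = mgp.Multinomial(num_events=3, prob=prob, total_count=10)

counts = d.sample()                 # one draw of per-category counts, sums to 10
print(counts, counts.sum())
print(d.mean)                       # total_count * prob = [2, 3, 5]
print(d.log_prob(np.array([2.0, 3.0, 5.0])))
```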
+ """ + # pylint: disable=abstract-method + + has_grad = True + support = Real() + arg_constraints = {'loc': Real(), + 'cov': PositiveDefinite(), + 'precision': PositiveDefinite(), + 'scale_tril': LowerCholesky()} + + def __init__(self, loc, cov=None, precision=None, scale_tril=None, F=None, validate_args=None): + if (cov is not None) + (precision is not None) + (scale_tril is not None) != 1: + raise ValueError("Exactly one onf `cov` or `precision` or " + + "`scale_tril` may be specified") + _F = F if F is not None else getF(cov, precision, scale_tril) + self.loc = loc + if cov is not None: + self.cov = cov + elif precision is not None: + self.precision = precision + else: + self.scale_tril = scale_tril + super(MultivariateNormal, self).__init__( + F=_F, event_dim=1, validate_args=validate_args) + + def _precision_to_scale_tril(self, P): + """ + P = inv(L * L.T) = inv(L.T) * inv(L) + flip(P) = flip(inv(L.T)) * flip(inv(L)) + flip(inv(L.T)) = Cholesky(flip(P)) + L = flip(Cholesky(flip(P))).T + """ + F = self.F + L_flip_inv_T = F.np.linalg.cholesky(F.np.flip(P, (-1, -2))) + L = F.np.linalg.inv(F.np.swapaxes( + F.np.flip(L_flip_inv_T, (-1, -2)), -1, -2)) + return L + + @cached_property + def scale_tril(self): + # pylint: disable=method-hidden + F = self.F + if 'cov' in self.__dict__: + return F.np.linalg.cholesky(self.cov) + return self._precision_to_scale_tril(self.precision) + + @cached_property + def cov(self): + # pylint: disable=method-hidden + F = self.F + if 'scale_tril' in self.__dict__: + scale_triu = F.np.swapaxes(self.scale_tril, -1, -2) + return F.np.matmul(self.scale_tril, scale_triu) + return F.np.linalg.inv(self.precision) + + @cached_property + def precision(self): + # pylint: disable=method-hidden + F = self.F + if 'cov' in self.__dict__: + return F.np.linalg.inv(self.cov) + scale_tril_inv = F.np.linalg.inv(self.scale_tril) + scale_triu_inv = F.np.swapaxes(scale_tril_inv, -1, -2) + return F.np.matmul(scale_triu_inv, scale_tril_inv) + + @property + def mean(self): + return self.loc + + @property + def variance(self): + return (self.scale_tril ** 2).sum(-1) + + def sample(self, size=None): + F = self.F + # symbol does not support `np.broadcast` + shape_tensor = self.loc + self.scale_tril.sum(-1) + if size is not None: + if isinstance(size, int): + size = (size,) + shape_tensor = F.np.broadcast_to(shape_tensor, size + (-2,)) + noise = F.np.random.normal(F.np.zeros_like( + shape_tensor), F.np.ones_like(shape_tensor)) + samples = self.loc + \ + F.np.einsum('...jk,...j->...k', self.scale_tril, noise) + return samples + + def sample_n(self, size=None): + if size is None: + return self.sample() + F = self.F + # symbol does not support `np.broadcast` + shape_tensor = self.loc + self.scale_tril[..., 0] + if isinstance(size, int): + size = (size,) + noise = F.np.random.normal(F.np.zeros_like(shape_tensor), F.np.ones_like(shape_tensor), + (-2,) + size) + samples = self.loc + \ + F.np.einsum('...jk,...j->...k', self.scale_tril, noise) + return samples + + def log_prob(self, value): + if self._validate_args: + self._validate_samples(value) + F = self.F + diff = value - self.loc + # diff.T * inv(\Sigma) * diff + M = F.np.einsum( + '...i,...i->...', + diff, + F.np.einsum('...jk,...j->...k', self.precision, + diff) # Batch matrix vector multiply + ) * -0.5 + # (2 * \pi)^{-k/2} * det(\Sigma)^{-1/2} + # = det(2 * \pi * L * L.T)^{-1/2} + # = det(\sqrt(2 * \pi) * L)^{-1} + half_log_det = F.np.log( + F.np.diagonal(F.np.sqrt(2 * math.pi) * + self.scale_tril, axis1=-2, axis2=-1) + ).sum(-1) + return 
M - half_log_det + + def entropy(self): + F = self.F + # det(2 * \pi * e * \Sigma) + # = det(\sqrt(2 * \pi * e) * L)^2 + return F.np.log(F.np.diagonal( + F.np.sqrt(2 * math.pi * math.e) * self.scale_tril, + axis1=-2, axis2=-1 + )).sum(-1) diff --git a/python/mxnet/gluon/probability/distributions/negative_binomial.py b/python/mxnet/gluon/probability/distributions/negative_binomial.py new file mode 100644 index 000000000000..d360d48f4d61 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/negative_binomial.py @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Negative binomial distribution class.""" +__all__ = ['NegativeBinomial'] + +from .distribution import Distribution +from .poisson import Poisson +from .gamma import Gamma +from .utils import prob2logit, logit2prob, getF, cached_property +from .utils import gammaln +from .constraint import GreaterThanEq, Interval, Real, NonNegativeInteger + + +class NegativeBinomial(Distribution): + r"""Create a negative binomial distribution object. + + Parameters + ---------- + n : Tensor or scalar + Non-negative number of negative Bernoulli trials to stop. + prob : Tensor or scalar, default None + Probability of sampling `1`. + logit : Tensor or scalar, default None + The log-odds of sampling `1`. + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + support = NonNegativeInteger() + arg_constraints = {'n': GreaterThanEq(0), + 'prob': Interval(0, 1), + 'logit': Real()} + + def __init__(self, n, prob=None, logit=None, F=None, validate_args=None): + _F = F if F is not None else getF(n, prob, logit) + if (prob is None) == (logit is None): + raise ValueError( + "Either `prob` or `logit` must be specified, but not both. " + + "Received prob={}, logit={}".format(prob, logit)) + + if prob is not None: + self.prob = prob + else: + self.logit = logit + self.n = n + super(NegativeBinomial, self).__init__( + F=_F, event_dim=0, validate_args=validate_args) + + @cached_property + def prob(self): + """Get the probability of sampling `1`. + + Returns + ------- + Tensor + Parameter tensor. + """ + # pylint: disable=method-hidden + return logit2prob(self.logit, True, self.F) + + @cached_property + def logit(self): + """Get the log-odds of sampling `1`. + + Returns + ------- + Tensor + Parameter tensor. 
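`MultivariateNormal` accepts exactly one of `cov`, `precision`, or `scale_tril` and lazily derives the other two through the `cached_property` conversions above. A minimal sketch, same assumptions:

```python
from mxnet import np, npx
import mxnet.gluon.probability as mgp

npx.set_np()

loc = np.zeros((2,))
cov = np.array([[2.0, 0.3],
                [0.3, 1.0]])
d = mgp.MultivariateNormal(loc=loc, cov=cov)

x = d.sample((5,))                  # (5, 2): loc + scale_tril @ standard normal noise
print(d.log_prob(x))                # one log density per row
print(d.scale_tril)                 # Cholesky factor derived from `cov`
print(d.variance)                   # diagonal of cov: [2., 1.]
print(d.entropy())
```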
+ """ + # pylint: disable=method-hidden + return prob2logit(self.prob, True, self.F) + + @property + def mean(self): + F = self.F + return self.n * F.np.exp(self.logit) + + @property + def variance(self): + prob = self.prob + return self.n * prob / (1 - prob) ** 2 + + def broadcast_to(self, batch_shape): + new_instance = self.__new__(type(self)) + F = self.F + if 'prob' in self.__dict__: + new_instance.prob = F.np.broadcast_to(self.prob, batch_shape) + else: + new_instance.logit = F.np.broadcast_to(self.logit, batch_shape) + new_instance.n = F.np.broadcast_to(self.n, batch_shape) + super(NegativeBinomial, new_instance).__init__(F=F, + event_dim=self.event_dim, + validate_args=False) + new_instance._validate_args = self._validate_args + return new_instance + + def log_prob(self, value): + if self._validate_args: + self._validate_samples(value) + F = self.F + lgamma = gammaln(F) + binomal_coef = lgamma(value + self.n) - \ + lgamma(1 + value) - lgamma(self.n) + # log(prob) may have numerical issue. + unnormalized_log_prob = self.n * \ + F.np.log(self.prob) + value * F.np.log1p(-self.prob) + return binomal_coef + unnormalized_log_prob + + def sample(self, size=None): + F = self.F + # Sample via Poisson-Gamma mixture + rate = Gamma(shape=self.n, scale=F.np.exp( + self.logit), F=F).sample(size) + return Poisson(rate, F=F).sample() + + def sample_n(self, size=None): + F = self.F + # Sample via Poisson-Gamma mixture + rate = Gamma(shape=self.n, scale=F.np.exp( + self.logit), F=F).sample_n(size) + return Poisson(rate, F=F).sample() diff --git a/python/mxnet/gluon/probability/distributions/normal.py b/python/mxnet/gluon/probability/distributions/normal.py new file mode 100644 index 000000000000..d0f1b1fbb8b0 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/normal.py @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Normal distribution""" +__all__ = ['Normal'] + +import math +from .constraint import Real, Positive +from .exp_family import ExponentialFamily +from .utils import getF, erf, erfinv + + +class Normal(ExponentialFamily): + r"""Create a Normal distribution object. + + Parameters + ---------- + loc : Tensor or scalar, default 0 + mean of the distribution. + scale : Tensor or scalar, default 1 + standard deviation of the distribution + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. 
+ """ + # pylint: disable=abstract-method + + has_grad = True + support = Real() + arg_constraints = {'loc': Real(), 'scale': Positive()} + + def __init__(self, loc=0.0, scale=1.0, F=None, validate_args=None): + _F = F if F is not None else getF(loc, scale) + self.loc = loc + self.scale = scale + super(Normal, self).__init__( + F=_F, event_dim=0, validate_args=validate_args) + + def log_prob(self, value): + """Compute the log likelihood of `value`. + + Parameters + ---------- + value : Tensor + Input data. + + Returns + ------- + Tensor + Log likelihood of the input. + """ + if self._validate_args: + self._validate_samples(value) + F = self.F + log_scale = F.np.log(self.scale) + log_prob = -((value - self.loc) ** 2) / (2 * self.variance) + log_prob = log_prob - log_scale + log_prob = log_prob - F.np.log(F.np.sqrt(2 * math.pi)) + return log_prob + + def sample(self, size=None): + r"""Generate samples of `size` from the normal distribution + parameterized by `self._loc` and `self._scale` + + Parameters + ---------- + size : Tuple, Scalar, or None + Size of samples to be generated. If size=None, the output shape + will be `broadcast(loc, scale).shape` + + Returns + ------- + Tensor + Samples from Normal distribution. + """ + return self.F.np.random.normal(self.loc, self.scale, size) + + def sample_n(self, size=None): + r"""Generate samples of (batch_size + broadcast(loc, scale).shape) + from the normal distribution parameterized by `self._loc` and `self._scale` + + Parameters + ---------- + size : Tuple, Scalar, or None + Size of independent batch to be generated from the distribution. + + Returns + ------- + Tensor + Samples from Normal distribution. + """ + return self.F.npx.random.normal_n(self.loc, self.scale, size) + + def broadcast_to(self, batch_shape): + new_instance = self.__new__(type(self)) + F = self.F + new_instance.loc = F.np.broadcast_to(self.loc, batch_shape) + new_instance.scale = F.np.broadcast_to(self.scale, batch_shape) + super(Normal, new_instance).__init__(F=F, + event_dim=self.event_dim, + validate_args=False) + new_instance._validate_args = self._validate_args + return new_instance + + def cdf(self, value): + if self._validate_args: + self._validate_samples(value) + erf_func = erf(self.F) + standarized_samples = ((value - self.loc) / + (math.sqrt(2) * self.scale)) + erf_term = erf_func(standarized_samples) + return 0.5 * (1 + erf_term) + + def icdf(self, value): + erfinv_func = erfinv(self.F) + return self.loc + self.scale * erfinv_func(2 * value - 1) * math.sqrt(2) + + @property + def mean(self): + return self.loc + + @property + def stddev(self): + return self.scale + + @property + def variance(self): + return self.scale ** 2 + + def entropy(self): + F = self.F + return 0.5 + 0.5 * math.log(2 * math.pi) + F.np.log(self.scale) + + @property + def _natural_params(self): + r"""Return the natural parameters of normal distribution, + which are (\frac{\mu}{\sigma^2}, -0.5 / (\sigma^2)) + + Returns + ------- + Tuple + Natural parameters of normal distribution. 
+ """ + return (self.loc / (self.scale ** 2), + -0.5 * self.F.np.reciprocal(self.scale ** 2)) + + def _log_normalizer(self, x, y): + # pylint: disable=arguments-differ + F = self.F + return -0.25 * F.np.pow(x, 2) / y + 0.5 * F.np.log(-math.pi / y) diff --git a/python/mxnet/gluon/probability/distributions/one_hot_categorical.py b/python/mxnet/gluon/probability/distributions/one_hot_categorical.py new file mode 100644 index 000000000000..8729cd81b3a1 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/one_hot_categorical.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""One-hot Categorical Distribution""" +__all__ = ['OneHotCategorical'] + +from .distribution import Distribution +from .categorical import Categorical +from .utils import getF, cached_property +from .constraint import Simplex, Real + + +class OneHotCategorical(Distribution): + """Create a one-hot categorical distribution object. + + Parameters + ---------- + num_events : Int + Number of events. + prob : Tensor + Probabilities of each event. + logit : Tensor + The log-odds of each event + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + arg_constraints = {'prob': Simplex(), 'logit': Real()} + + def __init__(self, num_events, prob=None, logit=None, F=None, validate_args=None): + _F = F if F is not None else getF(prob, logit) + if (num_events > 0): + num_events = int(num_events) + self.num_events = num_events + else: + raise ValueError("`num_events` should be greater than zero. 
" + + "Received num_events={}".format(num_events)) + self._categorical = Categorical( + num_events, prob, logit, _F, validate_args) + super(OneHotCategorical, self).__init__( + _F, event_dim=1, validate_args=validate_args) + + @cached_property + def prob(self): + return self._categorical.prob + + @cached_property + def logit(self): + return self._categorical.logit + + @property + def mean(self): + return self._categorical.prob + + @property + def variance(self): + prob = self.prob + return prob * (1 - prob) + + def sample(self, size=None): + indices = self._categorical.sample(size) + return self.F.npx.one_hot(indices, self.num_events) + + def sample_n(self, size=None): + indices = self._categorical.sample_n(size) + return self.F.npx.one_hot(indices, self.num_events) + + def log_prob(self, value): + if self._validate_args: + self._validate_samples(value) + logit = self.logit + return (value * logit).sum(-1) + + def broadcast_to(self, batch_shape): + new_instance = self.__new__(type(self)) + F = self.F + new_instance._categorical = self._categorical.broadcast_to(batch_shape) + new_instance.num_events = self.num_events + super(OneHotCategorical, new_instance).__init__(F=F, + event_dim=self.event_dim, + validate_args=False) + new_instance._validate_args = self._validate_args + return new_instance + + def enumerate_support(self): + value = self._categorical.enumerate_support() + return self.F.npx.one_hot(value, self.num_events) diff --git a/python/mxnet/gluon/probability/distributions/pareto.py b/python/mxnet/gluon/probability/distributions/pareto.py new file mode 100644 index 000000000000..309d49dce2ed --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/pareto.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Pareto Distribution.""" +__all__ = ['Pareto'] + +from .transformed_distribution import TransformedDistribution +from .exponential import Exponential +from .constraint import Positive, dependent_property, GreaterThan +from ..transformation import ExpTransform, AffineTransform +from .utils import getF, sample_n_shape_converter + + +class Pareto(TransformedDistribution): + r"""Create a Pareto Type I distribution object. + + Parameters + ---------- + alpha : Tensor or scalar + shape parameter of the distribution. + scale : Tensor or scalar, default 1 + scale parameter of the distribution. + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. 
+ """ + # pylint: disable=abstract-method + + has_grad = True + arg_constraints = {'scale': Positive(), + 'alpha': Positive()} + + def __init__(self, alpha, scale=1.0, F=None, validate_args=None): + _F = F if F is not None else getF(alpha, scale) + self.alpha = alpha + self.scale = scale + base_dist = Exponential(1 / self.alpha) + super(Pareto, self).__init__(base_dist, [ + ExpTransform(), AffineTransform(0, self.scale)]) + + def sample(self, size=None): + F = self.F + return self.scale * (F.np.random.pareto(self.alpha, size) + 1) + + def sample_n(self, size=None): + F = self.F + return self.scale * (F.np.random.pareto(self.alpha, sample_n_shape_converter(size)) + 1) + + @dependent_property + def support(self): + return GreaterThan(self.scale) + + @property + def mean(self): + F = self.F + a = F.np.clip(self.alpha, min=1) + return a * self.scale / (a - 1) + + @property + def variance(self): + F = self.F + a = F.np.clip(self.alpha, min=2) + return (self.scale ** 2) * a / ((a - 1) ** 2 * (a - 2)) + + def entropy(self): + F = self.F + return F.np.log(self.scale / self.alpha) + 1 / self.alpha + 1 diff --git a/python/mxnet/gluon/probability/distributions/poisson.py b/python/mxnet/gluon/probability/distributions/poisson.py new file mode 100644 index 000000000000..ff32379424eb --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/poisson.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Poisson distribution.""" +__all__ = ['Poisson'] + +from numbers import Number +from .exp_family import ExponentialFamily +from .constraint import Positive, NonNegativeInteger +from .utils import getF, gammaln + + +class Poisson(ExponentialFamily): + r"""Create a Poisson distribution object. + + Parameters + ---------- + rate : Tensor or scalar, default 1 + rate parameter of the distribution. + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. 
+ """ + # pylint: disable=abstract-method + + arg_constraints = {'rate': Positive()} + support = NonNegativeInteger() + + def __init__(self, rate=1.0, F=None, validate_args=None): + _F = F if F is not None else getF(rate) + self.rate = rate + super(Poisson, self).__init__( + F=_F, event_dim=0, validate_args=validate_args) + + @property + def mean(self): + return self.rate + + @property + def variance(self): + return self.rate + + def broadcast_to(self, batch_shape): + new_instance = self.__new__(type(self)) + F = self.F + new_instance.rate = F.np.broadcast_to(self.rate, batch_shape) + super(Poisson, new_instance).__init__(F=F, + event_dim=self.event_dim, + validate_args=False) + new_instance._validate_args = self._validate_args + return new_instance + + def sample(self, size=None): + F = self.F + lam = self.rate + if size is None: + size = () + if isinstance(lam, Number): + # Scalar case + return F.npx.scalar_poisson(lam, size) + else: + # Tensor case + shape_tensor = F.np.ones(size) + # shape = () currently not supported + return F.npx.tensor_poisson(lam * shape_tensor) + + def sample_n(self, size=None): + F = self.F + lam = self.rate + if isinstance(lam, Number): + # Scalar case + if size is None: + size = () + return F.npx.scalar_poisson(lam, size) + else: + return F.np.moveaxis(F.npx.tensor_poisson(lam, size), -1, 0) + + def log_prob(self, value): + if self._validate_args: + self._validate_samples(value) + F = self.F + lgamma = gammaln(F) + rate = self.rate + return value * F.np.log(rate) - rate - lgamma(value + 1) + + @property + def _natural_params(self): + F = self.F + return (F.np.log(self.rate),) + + def _log_normalizer(self, x): + # pylint: disable=arguments-differ + F = self.F + return F.np.exp(x) diff --git a/python/mxnet/gluon/probability/distributions/relaxed_bernoulli.py b/python/mxnet/gluon/probability/distributions/relaxed_bernoulli.py new file mode 100644 index 000000000000..faae9aed0cd4 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/relaxed_bernoulli.py @@ -0,0 +1,138 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Relaxed Bernoulli class.""" +__all__ = ['RelaxedBernoulli'] + +from .distribution import Distribution +from .transformed_distribution import TransformedDistribution +from ..transformation import SigmoidTransform +from .utils import prob2logit, logit2prob, getF, cached_property +from .constraint import OpenInterval, Real, Interval + + +class _LogitRelaxedBernoulli(Distribution): + r"""Helper class for creating an unnormalized relaxed Bernoulli object. + + Parameters + ---------- + T : scalar, default None + Relaxation temperature + prob : Tensor or scalar, default None + Probability of sampling `1`. 
+ logit : Tensor or scalar, default None + The log-odds of sampling `1`. + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + has_grad = True + support = Real() + arg_constraints = {'prob': Interval(0, 1), + 'logit': Real()} + + def __init__(self, T, prob=None, logit=None, F=None, validate_args=None): + _F = F if F is not None else getF(prob, logit) + self.T = T + if (prob is None) == (logit is None): + raise ValueError( + "Either `prob` or `logit` must be specified, but not both. " + + "Received prob={}, logit={}".format(prob, logit)) + if prob is not None: + self.prob = prob + else: + self.logit = logit + super(_LogitRelaxedBernoulli, self).__init__( + F=_F, event_dim=0, validate_args=validate_args + ) + + @cached_property + def prob(self): + # pylint: disable=method-hidden + return logit2prob(self.logit, True, self.F) + + @cached_property + def logit(self): + # pylint: disable=method-hidden + return prob2logit(self.prob, True, self.F) + + def sample(self, size=None): + F = self.F + logit = self.logit + return F.np.random.logistic(loc=logit, scale=1, size=size) / self.T + + def log_prob(self, value): + F = self.F + # log-likelihood of `value` from (Logistic(logit, 1) / T) + diff = self.logit - self.T * value + return F.np.log(self.T) + diff - 2 * F.np.log1p(F.np.exp(diff)) + + +class RelaxedBernoulli(TransformedDistribution): + r"""Create a relaxed Bernoulli distribution object. + + Parameters + ---------- + T : scalar, default None + Relaxation temperature + prob : Tensor or scalar, default None + Probability of sampling `1`. + logit : Tensor or scalar, default None + The log-odds of sampling `1`. + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + has_grad = True + support = OpenInterval(0, 1) + arg_constraints = {'prob': Interval(0, 1), + 'logit': Real()} + + def __init__(self, T, prob=None, logit=None, F=None, validate_args=None): + base_dist = _LogitRelaxedBernoulli(T, prob, logit, F, validate_args) + super(RelaxedBernoulli, self).__init__(base_dist, SigmoidTransform()) + + @property + def T(self): + return self._base_dist.T + + @property + def prob(self): + return self._base_dist.prob + + @property + def logit(self): + return self._base_dist.logit + + def broadcast_to(self, batch_shape): + new_instance = self.__new__(type(self)) + F = self.F + if 'prob' in self.__dict__: + new_instance.prob = F.np.broadcast_to(self.prob, batch_shape) + else: + new_instance.logit = F.np.broadcast_to(self.logit, batch_shape) + super(RelaxedBernoulli, new_instance).__init__(F=F, + event_dim=self.event_dim, + validate_args=False) + new_instance._validate_args = self._validate_args + return new_instance diff --git a/python/mxnet/gluon/probability/distributions/relaxed_one_hot_categorical.py b/python/mxnet/gluon/probability/distributions/relaxed_one_hot_categorical.py new file mode 100644 index 000000000000..9d5f172cc865 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/relaxed_one_hot_categorical.py @@ -0,0 +1,177 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Relaxed Bernoulli class.""" +__all__ = ['RelaxedOneHotCategorical'] + +from math import lgamma +from .distribution import Distribution +from .transformed_distribution import TransformedDistribution +from ..transformation import ExpTransform +from .utils import prob2logit, logit2prob, getF, cached_property +from .constraint import Real, Simplex + + +class _LogRelaxedOneHotCategorical(Distribution): + """Helper class for creating the log of a + categorical distribution object. + + Parameters + ---------- + T : scalar, default None + Relaxation temperature + num_events : Int + Number of events. + prob : Tensor + Probabilities of each event. + logit : Tensor + The log-odds of each event + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + has_grad = True + arg_constraints = {'prob': Simplex(), + 'logit': Real()} + + def __init__(self, T, num_events, prob=None, logit=None, F=None, validate_args=None): + self.T = T + _F = F if F is not None else getF(prob, logit) + if (num_events > 0): + num_events = int(num_events) + self.num_events = num_events + else: + raise ValueError("`num_events` should be greater than zero. " + + "Received num_events={}".format(num_events)) + if (prob is None) == (logit is None): + raise ValueError( + "Either `prob` or `logit` must be specified, but not both. " + + "Received prob={}, logit={}".format(prob, logit)) + + if prob is not None: + self.prob = prob + else: + self.logit = logit + + super(_LogRelaxedOneHotCategorical, self).__init__( + _F, event_dim=1, validate_args=validate_args) + + @cached_property + def prob(self): + """Get the probability of sampling each class. + + Returns + ------- + Tensor + Parameter tensor. + """ + # pylint: disable=method-hidden + return logit2prob(self.logit, False, self.F) + + @cached_property + def logit(self): + """Get the log probability of sampling each class. + + Returns + ------- + Tensor + Parameter tensor. 
+ """ + # pylint: disable=method-hidden + return prob2logit(self.prob, False, self.F) + + def log_prob(self, value): + """Compute the log-likelihood of `value` + + Parameters + ---------- + value : Tensor + samples from Relaxed Categorical distribution + + Returns + ------- + Tensor + log-likelihood of `value` + """ + F = self.F + K = self.num_events # Python scalar + log = F.np.log + exp = F.np.exp + logit = self.logit + y = logit - value * self.T + log_sum_exp = log(exp(y).sum(-1, keepdims=True) + 1e-20) + log_scale = lgamma(K) - log(self.T) * (-(K - 1)) + return (y - log_sum_exp).sum(-1) + log_scale + + def sample(self, size=None): + F = self.F + if size is None: + size = () + logit = self.logit + else: + if isinstance(size, int): + logit = F.np.broadcast_to(self.logit, (size) + (-2,)) + else: + logit = F.np.broadcast_to(self.logit, size + (-2,)) + scores = F.np.random.gumbel(logit) / self.T + return F.np.log(F.npx.softmax(scores, axis=-1) + 1e-20) + + +class RelaxedOneHotCategorical(TransformedDistribution): + """Create a relaxed one hot categorical distribution object. + + Parameters + ---------- + T : scalar, default None + Relaxation temperature + num_events : Int + Number of events. + prob : Tensor + Probabilities of each event. + logit : Tensor + The log-odds of each event + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + has_grad = True + arg_constraints = {'prob': Simplex(), + 'logit': Real()} + + def __init__(self, T, num_events, prob=None, logit=None, F=None, validate_args=None): + base_dist = _LogRelaxedOneHotCategorical( + T, num_events, prob, logit, F, validate_args) + super(RelaxedOneHotCategorical, self).__init__( + base_dist, ExpTransform()) + + @property + def T(self): + return self._base_dist.T + + @property + def prob(self): + return self._base_dist.prob + + @property + def logit(self): + return self._base_dist.logit diff --git a/python/mxnet/gluon/probability/distributions/studentT.py b/python/mxnet/gluon/probability/distributions/studentT.py new file mode 100644 index 000000000000..45a4e1c4d385 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/studentT.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Student T distribution""" +__all__ = ['StudentT'] + +from numpy import nan, inf, pi +from .distribution import Distribution +from .constraint import Real, Positive +from .chi2 import Chi2 +from .utils import getF, gammaln, digamma, sample_n_shape_converter + + +class StudentT(Distribution): + r"""Create a studentT distribution object, often known as t distribution. 
+ + Parameters + ---------- + df : Tensor or scalar + degree of freedom. + loc : Tensor or scalar, default 0 + mean of the distribution. + scale : Tensor or scalar, default 1 + scale of the distribution + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + support = Real() + arg_constraints = {'df': Positive(), 'loc': Real(), 'scale': Real()} + + def __init__(self, df, loc=0.0, scale=1.0, F=None, validate_args=None): + _F = F if F is not None else getF(df, loc, scale) + self.df = df + self.loc = loc + self.scale = scale + self._chi2 = Chi2(self.df) + super(StudentT, self).__init__( + F=_F, event_dim=0, validate_args=validate_args) + + def broadcast_to(self, batch_shape): + new_instance = self.__new__(type(self)) + F = self.F + new_instance.loc = F.np.broadcast_to(self.loc, batch_shape) + new_instance.scale = F.np.broadcast_to(self.scale, batch_shape) + new_instance.df = F.np.broadcast_to(self.df, batch_shape) + new_instance._chi2 = self._chi2.broadcast_to(batch_shape) + super(StudentT, new_instance).__init__( + F=F, event_dim=0, validate_args=False) + new_instance._validate_args = self._validate_args + return new_instance + + @property + def mean(self): + # mean is only defined for df > 1 + m = self.F.np.where(self.df <= 1, nan, self.loc) + return m + + @property + def variance(self): + F = self.F + df = self.df + v = self.scale ** 2 * self.df / (self.df - 2) + v = F.np.where(df <= 2, inf, v) + v = F.np.where(df <= 1, nan, v) + return v + + def sample(self, size=None): + F = self.F + X = F.np.random.normal(size=size) + Z = self._chi2.sample(size) + Y = X * F.np.sqrt(self.df / Z) + return self.loc + Y * self.scale + + def sample_n(self, size=None): + return self.sample(sample_n_shape_converter(size)) + + def log_prob(self, value): + if self._validate_args: + self._validate_samples(value) + F = self.F + lgamma = gammaln(F) + df = self.df + value = (value - self.loc) / self.scale + return ( + lgamma((df + 1) / 2) - lgamma(df / 2) - + F.np.log(self.scale) - 0.5 * F.np.log(df * pi) + - 0.5 * (df + 1) * F.np.log1p(value ** 2 / df) + ) + + def entropy(self): + F = self.F + lgamma = gammaln(F) + dgamma = digamma(F) + log_fn = F.np.log + lbeta = lgamma(0.5 * self.df) + lgamma(0.5) - \ + lgamma(0.5 * (self.df + 1)) + return (log_fn(self.scale) + + 0.5 * (self.df + 1) * + (dgamma(0.5 * (self.df + 1)) - dgamma(0.5 * self.df)) + + 0.5 * log_fn(self.df) + lbeta) diff --git a/python/mxnet/gluon/probability/distributions/transformed_distribution.py b/python/mxnet/gluon/probability/distributions/transformed_distribution.py new file mode 100644 index 000000000000..c5cf3625e348 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/transformed_distribution.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +"""Transformed distribution""" +__all__ = ['TransformedDistribution'] + +from ..transformation import Transformation +from .distribution import Distribution +from .utils import sum_right_most + + +class TransformedDistribution(Distribution): + """A distribution generated by applying a sequence of transformations to + a base distribution/ + + Parameters + ---------- + base_dist : Distribution + Base distribution + transforms : Transformation or List + Transformation to be applied + """ + # pylint: disable=abstract-method + + def __init__(self, base_dist, transforms, validate_args=None): + self._base_dist = base_dist + if isinstance(transforms, Transformation): + transforms = [transforms, ] + self._transforms = transforms + _F = base_dist.F + # Overwrite the F in transform + for t in self._transforms: + t.F = _F + event_dim = max([self._base_dist.event_dim] + + [t.event_dim for t in self._transforms]) + super(TransformedDistribution, self).__init__( + _F, event_dim=event_dim, validate_args=validate_args) + + def sample(self, size=None): + x = self._base_dist.sample(size) + for t in self._transforms: + x = t(x) + return x + + def sample_n(self, size=None): + x = self._base_dist.sample_n(size) + for t in self._transforms: + x = t(x) + return x + + def log_prob(self, value): + """ + Compute log-likelihood of `value` with `log_det_jacobian` and + log-likelihood of the base distribution according to the following conclusion: + + Given that Y = T(X), + log(p(y)) = log(p(x)) - log(|dy/dx|) + """ + log_prob = 0.0 + y = value # T_n(T_{n-1}(...T_1(x))) + # Reverse `_transforms` to transform to the base distribution. + for t in reversed(self._transforms): + x = t.inv(y) + log_prob = log_prob - sum_right_most(t.log_det_jacobian(x, y), + self.event_dim - t.event_dim) + y = x + log_prob = log_prob + sum_right_most(self._base_dist.log_prob(y), + self.event_dim - self._base_dist.event_dim) + return log_prob + + def cdf(self, value): + """ + Compute the cumulative distribution function(CDF) p(Y < `value`) + """ + sign = self.F.np.ones_like(value) + for t in reversed(self._transforms): + value = t.inv(value) + sign = sign * t.sign + value = self._base_dist.cdf(value) + return sign * (value - 0.5) + 0.5 + + def icdf(self, value): + sign = self.F.np.ones_like(value) + for t in self._transforms: + sign = sign * t.sign + value = sign * (value - 0.5) + 0.5 # value or (1 - value) + samples_base = self._base_dist.icdf(value) + for t in self._transforms: + samples_base = t(samples_base) + return samples_base diff --git a/python/mxnet/gluon/probability/distributions/uniform.py b/python/mxnet/gluon/probability/distributions/uniform.py new file mode 100644 index 000000000000..e2d237418c18 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/uniform.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Uniform distribution""" +__all__ = ['Uniform'] + +from .distribution import Distribution +from .constraint import Real, Interval +from .utils import getF, sample_n_shape_converter + + +class Uniform(Distribution): + r"""Create a uniform distribution object. + + Parameters + ---------- + low : Tensor or scalar, default 0 + lower range of the distribution. + high : Tensor or scalar, default 1 + upper range of the distribution. + F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + + # Reparameterization gradient for Uniform is currently not implemented + # in the backend at this moment. + has_grad = False + arg_constraints = {'low': Real(), 'high': Real()} + + def __init__(self, low=0.0, high=1.0, F=None, validate_args=None): + _F = F if F is not None else getF(low, high) + self.low = low + self.high = high + super(Uniform, self).__init__( + F=_F, event_dim=0, validate_args=validate_args) + + def log_prob(self, value): + if self._validate_args: + self._validate_samples(value) + F = self.F + def type_converter(x): + return float(x) if isinstance(x, bool) else x.astype('float') + lower_bound = type_converter(self.low < value) + upper_bound = type_converter(self.high > value) + # 0 if value \in [low, high], -inf otherwise. + out_of_support_value = F.np.log(lower_bound * upper_bound) + return out_of_support_value - F.np.log(self.high - self.low) + + def sample(self, size=None): + F = self.F + return F.np.random.uniform(self.low, self.high, size=size) + + def sample_n(self, size=None): + F = self.F + return F.np.random.uniform(self.low, self.high, + size=sample_n_shape_converter(size)) + + @property + def support(self): + return Interval(self.low, self.high) + + def broadcast_to(self, batch_shape): + new_instance = self.__new__(type(self)) + F = self.F + new_instance.low = F.np.broadcast_to(self.low, batch_shape) + new_instance.high = F.np.broadcast_to(self.high, batch_shape) + super(Uniform, new_instance).__init__(F=F, + event_dim=self.event_dim, + validate_args=False) + new_instance._validate_args = self._validate_args + return new_instance + + def cdf(self, value): + if self._validate_args: + self._validate_samples(value) + x = (value - self.low) / (self.high - self.low) + return x.clip(0, 1) + + def icdf(self, value): + return value * (self.high - self.low) + self.low + + def entropy(self): + return self.F.np.log(self.high - self.low) diff --git a/python/mxnet/gluon/probability/distributions/utils.py b/python/mxnet/gluon/probability/distributions/utils.py new file mode 100644 index 000000000000..f8a03c49e33b --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/utils.py @@ -0,0 +1,202 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Distribution utilities""" +__all__ = ['getF', 'prob2logit', 'logit2prob', 'cached_property', 'sample_n_shape_converter', + 'constraint_check', 'digamma', 'gammaln', 'erfinv', 'erf'] + +from functools import update_wrapper +from numbers import Number +import numpy as onp +import scipy.special as sc +from .... import symbol as sym +from .... import ndarray as nd + + +def constraint_check(F): + """Unified check_constraint interface for both scalar and tensor + """ + def _check(condition, err_msg): + if isinstance(condition, bool): + if not condition: + raise ValueError(err_msg) + return 1.0 + return F.npx.constraint_check(condition, err_msg) + return _check + + +def digamma(F): + """Unified digamma interface for both scalar and tensor + """ + def compute(value): + """Return digamma(value) + """ + if isinstance(value, Number): + return sc.digamma(value, dtype='float32') + return F.npx.digamma(value) + return compute + + +def gammaln(F): + """Unified gammaln interface for both scalar and tensor + """ + def compute(value): + """Return log(gamma(value)) + """ + if isinstance(value, Number): + return sc.gammaln(value, dtype='float32') + return F.npx.gammaln(value) + return compute + + +def erf(F): + """Unified erf interface for both scalar and tensor + """ + def compute(value): + if isinstance(value, Number): + return sc.erf(value) + return F.npx.erf(value) + return compute + + +def erfinv(F): + """Unified erfinv interface for both scalar and tensor + """ + def compute(value): + if isinstance(value, Number): + return sc.erfinv(value) + return F.npx.erfinv(value) + return compute + + +def sample_n_shape_converter(size): + """Convert `size` to the proper format for performing sample_n. + """ + if size is None: + return size + if size == (): + size = None + else: + if isinstance(size, int): + size = (size,) + size = (-2,) + size + return size + + +def getF(*params): + """Get running mode from parameters, + return mx.ndarray if inputs are python scalar. + + Returns + ------- + ndarray or _Symbol + the running mode inferred from `*params` + """ + mode_flag = 0 + for param in params: + if isinstance(param, nd.NDArray): + if mode_flag < 0: + raise TypeError("Expect parameters to have consistent running mode," + + " got {}".format([type(p) for p in params])) + mode_flag = 1 + elif isinstance(param, sym.Symbol): + if mode_flag > 0: + raise TypeError("Expect parameters to have consistent running mode," + + " got {}".format([type(p) for p in params])) + mode_flag = -1 + # In case of scalar params, we choose to use the imperative mode. + if mode_flag < 0: + return sym + return nd + + +def sum_right_most(x, ndim): + """Sum along the right most `ndim` dimensions of `x`, + + Parameters + ---------- + x : Tensor + Input tensor. + ndim : Int + Number of dimensions to be summed. 
+ + Returns + ------- + Tensor + """ + if ndim == 0: + return x + axes = list(range(-ndim, 0)) + return x.sum(axes) + + +def _clip_prob(prob, F): + eps = onp.finfo('float32').eps + return F.np.clip(prob, eps, 1 - eps) + + +def _clip_float_eps(value, F): + eps = onp.finfo('float32').eps + return F.np.maximum(value, eps) + + +def prob2logit(prob, binary=True, F=None): + r"""Convert probability to logit form. + For the binary case, the logit stands for log(p / (1 - p)). + Whereas for the multinomial case, the logit denotes log(p). + """ + if F is None: + F = getF(prob) + _clipped_prob = _clip_prob(prob, F) + if binary: + return F.np.log(_clipped_prob) - F.np.log1p(-_clipped_prob) + # The clipped prob would cause numerical error in the categorical case, + # no idea about the reason behind. + return F.np.log(_clipped_prob) + + +def logit2prob(logit, binary=True, F=None): + r"""Convert logit into probability form. + For the binary case, `sigmoid()` is applied on the logit tensor. + Whereas for the multinomial case, `softmax` is applied along the last + dimension of the logit tensor. + """ + if F is None: + F = getF(logit) + if binary: + return F.npx.sigmoid(logit) + return F.npx.softmax(logit) + + +class _CachedProperty(object): + r"""Use as a decorator for loading class attribute, but caches the value.""" + + def __init__(self, func): + self._func = func + update_wrapper(self, self._func) + + def __get__(self, instance, cls=None): + if instance is None: + return self + value = self._func(instance) + setattr(instance, self._func.__name__, value) + return value + + +cached_property = _CachedProperty diff --git a/python/mxnet/gluon/probability/distributions/weibull.py b/python/mxnet/gluon/probability/distributions/weibull.py new file mode 100644 index 000000000000..358765b815e0 --- /dev/null +++ b/python/mxnet/gluon/probability/distributions/weibull.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Weibull Distribution.""" +__all__ = ['Weibull'] + +# Euler-Mascheroni constant +from numpy import euler_gamma +from .transformed_distribution import TransformedDistribution +from .exponential import Exponential +from .constraint import Positive +from ..transformation import PowerTransform, AffineTransform +from .utils import getF, sample_n_shape_converter, gammaln + + +class Weibull(TransformedDistribution): + r"""Create a two parameter Weibull distribution object. + + Parameters + ---------- + concentration : Tensor or scalar + Concentration/shape parameter of the distribution. + scale : Tensor or scalar, default 1 + scale parameter of the distribution. 
+ F : mx.ndarray or mx.symbol.numpy._Symbol or None + Variable recording running mode, will be automatically + inferred from parameters if declared None. + """ + # pylint: disable=abstract-method + has_grad = True + support = Positive() + arg_constraints = {'scale': Positive(), + 'concentration': Positive()} + + def __init__(self, concentration, scale=1.0, F=None, validate_args=None): + _F = F if F is not None else getF(scale, concentration) + self.concentration = concentration + self.scale = scale + base_dist = Exponential(F=_F) + super(Weibull, self).__init__(base_dist, [PowerTransform(1 / self.concentration), + AffineTransform(0, self.scale)]) + + def sample(self, size=None): + F = self.F + return self.scale * F.np.random.weibull(self.concentration, size) + + def sample_n(self, size=None): + F = self.F + return self.scale * F.np.random.weibull(self.concentration, + sample_n_shape_converter(size)) + + @property + def mean(self): + F = self.F + return self.scale * F.np.exp(F.npx.gammaln(1 + 1 / self.concentration)) + + @property + def variance(self): + F = self.F + exp = F.np.exp + lgamma = gammaln(F) + term1 = exp(lgamma(1 + 2 / self.concentration)) + term2 = exp(2 * lgamma(1 + 1 / self.concentration)) + return (self.scale ** 2) * (term1 - term2) + + def entropy(self): + F = self.F + return (euler_gamma * (1 - 1 / self.concentration) + + F.np.log(self.scale / self.concentration) + 1) diff --git a/python/mxnet/gluon/probability/transformation/__init__.py b/python/mxnet/gluon/probability/transformation/__init__.py new file mode 100644 index 000000000000..58f381840d2a --- /dev/null +++ b/python/mxnet/gluon/probability/transformation/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Transformation classes.""" + +from .transformation import * +from .domain_map import * diff --git a/python/mxnet/gluon/probability/transformation/domain_map.py b/python/mxnet/gluon/probability/transformation/domain_map.py new file mode 100644 index 000000000000..3c79d2d7268c --- /dev/null +++ b/python/mxnet/gluon/probability/transformation/domain_map.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +"""Classes for registering and storing bijection/transformations from +unconstrained space to a given domain. +""" + +from numbers import Number +from .transformation import ( + ExpTransform, AffineTransform, SigmoidTransform, ComposeTransform) +from ..distributions.constraint import (Constraint, Positive, GreaterThan, GreaterThanEq, + LessThan, Interval, HalfOpenInterval) + + +__all__ = ['domain_map', 'biject_to', 'transform_to'] + + +class domain_map(): + """ + Abstract Class for registering and storing mappings from domain + to bijections/transformations + """ + def __init__(self): + # constraint -> constraint -> transformation + self._storage = {} + super(domain_map, self).__init__() + + def register(self, constraint, factory=None): + """Register a bijection/transformation from unconstrained space to the domain + specified by `constraint`. + + Parameters + ---------- + constraint : Type or Object + A class of constraint or an object of constraint + factory : callable + A function that outputs a `transformation` given a `constraint`, + by default None. + """ + # Decorator mode + if factory is None: + return lambda factory: self.register(constraint, factory) + + if isinstance(constraint, Constraint): + constraint = type(constraint) + + if not isinstance(constraint, type) or not issubclass(constraint, Constraint): + raise TypeError('Expected constraint to be either a Constraint subclass or instance, ' + 'but got {}'.format(constraint)) + + self._storage[constraint] = factory + return factory + + def __call__(self, constraint): + try: + factory = self._storage[type(constraint)] + except KeyError: + raise NotImplementedError( + 'Cannot transform {} constraints'.format(type(constraint).__name__)) + return factory(constraint) + + +biject_to = domain_map() +transform_to = domain_map() + + +@biject_to.register(Positive) +@transform_to.register(Positive) +def _transform_to_positive(constraint): + # Although `constraint` is not used in this factory function, + # we decide to keep it for the purpose of consistency. + # pylint: disable=unused-argument + return ExpTransform() + + +@biject_to.register(GreaterThan) +@biject_to.register(GreaterThanEq) +@transform_to.register(GreaterThan) +@transform_to.register(GreaterThanEq) +def _transform_to_greater_than(constraint): + return ComposeTransform([ExpTransform(), + AffineTransform(constraint._lower_bound, 1)]) + + +@biject_to.register(LessThan) +@transform_to.register(LessThan) +def _transform_to_less_than(constraint): + return ComposeTransform([ExpTransform(), + AffineTransform(constraint._upper_bound, -1)]) + + +@biject_to.register(Interval) +@biject_to.register(HalfOpenInterval) +@transform_to.register(Interval) +@transform_to.register(HalfOpenInterval) +def _transform_to_interval(constraint): + # Handle the special case of the unit interval. 
+ lower_is_0 = isinstance(constraint._lower_bound, + Number) and constraint._lower_bound == 0 + upper_is_1 = isinstance(constraint._upper_bound, + Number) and constraint._upper_bound == 1 + if lower_is_0 and upper_is_1: + return SigmoidTransform() + + loc = constraint._lower_bound + scale = constraint._upper_bound - constraint._lower_bound + return ComposeTransform([SigmoidTransform(), + AffineTransform(loc, scale)]) diff --git a/python/mxnet/gluon/probability/transformation/transformation.py b/python/mxnet/gluon/probability/transformation/transformation.py new file mode 100644 index 000000000000..4599a483d5dc --- /dev/null +++ b/python/mxnet/gluon/probability/transformation/transformation.py @@ -0,0 +1,305 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=abstract-method +# pylint: disable=arguments-differ +"""Transformation Classes""" +__all__ = ["Transformation", "TransformBlock", "ComposeTransform", "ExpTransform", + "AffineTransform", "PowerTransform", "AbsTransform", 'SigmoidTransform', + 'SoftmaxTransform'] + +import weakref +from ..distributions.utils import _clip_prob, cached_property, sum_right_most +from ...block import HybridBlock +from .... import ndarray as nd + + +class Transformation(object): + r"""Abstract class for implementing invertible transformation + with computable log det jacobians + + Attributes + ---------- + bijective : bool + + """ + bijective = False + event_dim = 0 + + def __init__(self, F=nd): + self._inv = None + self._F = F + super(Transformation, self).__init__() + + @property + def F(self): + return self._F + + @F.setter + def F(self, value): + self._F = value + + @property + def sign(self): + """ + Returns the sign of the determinant of the Jacobian. + """ + raise NotImplementedError + + @property + def inv(self): + inv = None + if self._inv is not None: + inv = self._inv() + if inv is None: + inv = _InverseTransformation(self) + self._inv = weakref.ref(inv) + return inv + + def __call__(self, x): + return self._forward_compute(x) + + def _inv_call(self, y): + return self._inverse_compute(y) + + def _forward_compute(self, x): + raise NotImplementedError + + def _inverse_compute(self, x): + raise NotImplementedError + + def log_det_jacobian(self, x, y): + """ + Compute the value of log(|dy/dx|) + """ + raise NotImplementedError + + +class _InverseTransformation(Transformation): + """ + A private class representing the invert of `Transformation`, + which should be accessed through `Transformation.inv` property. 
+ """ + + def __init__(self, forward_transformation): + super(_InverseTransformation, self).__init__() + self._inv = forward_transformation + + @property + def inv(self): + return self._inv + + @property + def sign(self): + return self._inv.sign + + @property + def event_dim(self): + return self._inv.event_dim + + def __call__(self, x): + return self._inv._inverse_compute(x) + + def log_det_jacobian(self, x, y): + return -self._inv.log_det_jacobian(y, x) + + +class TransformBlock(Transformation, HybridBlock): + """Transform with learnable parameters should inherit from this class + rather than `Transformation`. + For example: normalization flow. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class ComposeTransform(Transformation): + r""" + Composes multiple transforms in a chain. + """ + def __init__(self, parts): + super(ComposeTransform, self).__init__() + self._parts = parts + + def _forward_compute(self, x): + for t in self._parts: + x = t(x) + return x + + @property + def F(self): + return self._parts[0].F + + @F.setter + def F(self, value): + for t in self._parts: + t.F = value + + # @cached_property is, in essence, @property with lazy evaluation. + # pylint: disable=invalid-overridden-method + @cached_property + def sign(self): + sign = 1 + for p in self._parts: + sign = sign * p.sign + return sign + + @cached_property + def event_dim(self): + return max(p.event_dim for p in self._parts) if self._parts else 0 + + @property + def inv(self): + inv = None + if self._inv is not None: + inv = self._inv() + if inv is None: + inv = ComposeTransform([t.inv for t in reversed(self._parts)]) + self._inv = weakref.ref(inv) + inv._inv = weakref.ref(self) + return inv + + def log_det_jacobian(self, x, y): + if not self._parts: + return self.F.np.zeros_like(x) + result = 0 + x_prime = None + for t in self._parts[:-1]: + x_prime = t(x) + result = result + sum_right_most(t.log_det_jacobian(x, x_prime), + self.event_dim - t.event_dim) + x = x_prime + t_last = self._parts[-1] + result = result + sum_right_most(t_last.log_det_jacobian(x, y), + self.event_dim - t_last.event_dim) + + return result + + +class ExpTransform(Transformation): + r""" + Perform the exponential transform: y = exp{x}. + """ + bijective = True + sign = 1 + + def _forward_compute(self, x): + return self.F.np.exp(x) + + def _inverse_compute(self, y): + return self.F.np.log(y) + + def log_det_jacobian(self, x, y): + return x + + +class AffineTransform(Transformation): + r""" + Perform *pointwise* affine transform: y = loc + scale * x. + """ + bijective = True + + def __init__(self, loc, scale, event_dim=0): + super(AffineTransform, self).__init__() + self._loc = loc + self._scale = scale + self.event_dim = event_dim + + def _forward_compute(self, x): + return self._loc + self._scale * x + + def _inverse_compute(self, y): + return (y - self._loc) / self._scale + + def log_det_jacobian(self, x, y): + abs_fn = self.F.np.abs + log_fn = self.F.np.log + ones_fn = self.F.np.ones_like + # element-wise abs(log(dy/dx)) + value = ones_fn(x) * log_fn(abs_fn(self._scale)) + return sum_right_most(value, self.event_dim) + + @property + def sign(self): + return self.F.np.sign(self._scale) + + +class PowerTransform(Transformation): + r""" + Perform *pointwise* power transform: y = pow(x, exponent). 
+ """ + bijective = True + sign = 1 + + def __init__(self, exponent): + super(PowerTransform, self).__init__() + self._exponent = exponent + + def _forward_compute(self, x): + return self.F.np.power(x, self._exponent) + + def _inverse_compute(self, y): + return self.F.np.power(y, 1 / self._exponent) + + def log_det_jacobian(self, x, y): + log_fn = self.F.np.log + abs_fn = self.F.np.abs + return log_fn(abs_fn(self._exponent * y / x)) + + +class SigmoidTransform(Transformation): + r""" + Perform *pointwise* sigmoid transform: y = 1 / (1 + exp(-x)). + """ + bijective = True + sign = 1 + + def _forward_compute(self, x): + F = self.F + return _clip_prob(F.npx.sigmoid(x), F) + + def _inverse_compute(self, y): + F = self.F + clipped_prob = _clip_prob(y, F) + return F.np.log(clipped_prob) - F.np.log1p(-clipped_prob) + + def log_det_jacobian(self, x, y): + F = self.F + log = F.np.log + exp = F.np.exp + softplus_fn = lambda x: log(1 + exp(x)) + return -softplus_fn(-x) - softplus_fn(x) + + +class SoftmaxTransform(Transformation): + event_dim = 1 + + def _forward_compute(self, x): + return self.F.npx.softmax(x, -1) + + def _inverse_compute(self, y): + return self.F.log(y) + + +class AbsTransform(Transformation): + def _forward_compute(self, x): + return self.F.np.abs(x) + + def _inverse_compute(self, y): + return y diff --git a/python/mxnet/ndarray/numpy_extension/random.py b/python/mxnet/ndarray/numpy_extension/random.py index 1ddd28f9e013..fcc65b3084ee 100644 --- a/python/mxnet/ndarray/numpy_extension/random.py +++ b/python/mxnet/ndarray/numpy_extension/random.py @@ -23,7 +23,7 @@ __all__ = ['bernoulli', 'normal_n', 'uniform_n'] -def bernoulli(prob, logit, size, dtype, ctx, out): +def bernoulli(prob=None, logit=None, size=None, dtype=None, ctx=None, out=None): """Creates a Bernoulli distribution parameterized by :attr:`prob` or :attr:`logit` (but not both). 
diff --git a/src/operator/random/multisample_op.cc b/src/operator/random/multisample_op.cc index 240126b17b79..b4c5be160101 100644 --- a/src/operator/random/multisample_op.cc +++ b/src/operator/random/multisample_op.cc @@ -58,13 +58,13 @@ DMLC_REGISTER_PARAMETER(MultiSampleParam); #define MXNET_OPERATOR_REGISTER_SAMPLING1(distr, sampler, input_name, input_desc, \ description) \ MXNET_OPERATOR_REGISTER_SAMPLING(distr, sampler, 1, input_name, input_name, \ - input_desc, input_desc, description); + input_desc, input_desc, description) #define MXNET_OPERATOR_REGISTER_SAMPLING2(distr, sampler, input_name_1, input_name_2, \ input_desc_1, input_desc_2, description) \ MXNET_OPERATOR_REGISTER_SAMPLING(distr, sampler, 2, input_name_1, input_name_2, \ input_desc_1, input_desc_2, description) \ - .add_argument(input_name_2, "NDArray-or-Symbol", input_desc_2); + .add_argument(input_name_2, "NDArray-or-Symbol", input_desc_2) inline std::string uniform_desc() { return std::string(R"code(Concurrent sampling from multiple @@ -274,23 +274,24 @@ Examples:: } MXNET_OPERATOR_REGISTER_SAMPLING2(uniform, UniformSampler, "low", "high", - "Lower bounds of the distributions.", "Upper bounds of the distributions.", uniform_desc) + "Lower bounds of the distributions.", "Upper bounds of the distributions.", uniform_desc); MXNET_OPERATOR_REGISTER_SAMPLING2(normal, NormalSampler, "mu", "sigma", - "Means of the distributions.", "Standard deviations of the distributions.", normal_desc) + "Means of the distributions.", "Standard deviations of the distributions.", normal_desc); MXNET_OPERATOR_REGISTER_SAMPLING2(gamma, GammaSampler, "alpha", "beta", "Alpha (shape) parameters of the distributions.", "Beta (scale) parameters of the distributions.", - gamma_desc) + gamma_desc); MXNET_OPERATOR_REGISTER_SAMPLING1(exponential, ExponentialSampler, "lam", - "Lambda (rate) parameters of the distributions.", exponential_desc) + "Lambda (rate) parameters of the distributions.", exponential_desc); MXNET_OPERATOR_REGISTER_SAMPLING1(poisson, PoissonSampler, "lam", "Lambda (rate) parameters of the distributions.", poisson_desc) +.add_alias("_npx_tensor_poisson"); MXNET_OPERATOR_REGISTER_SAMPLING2(negative_binomial, NegativeBinomialSampler, "k", "p", "Limits of unsuccessful experiments.", "Failure probabilities in each experiment.", - negative_binomial_desc) + negative_binomial_desc); MXNET_OPERATOR_REGISTER_SAMPLING2(generalized_negative_binomial, GeneralizedNegativeBinomialSampler, "mu", "alpha", "Means of the distributions.", "Alpha (dispersion) parameters of the distributions.", - generalized_negative_binomial_desc) + generalized_negative_binomial_desc); } // namespace op } // namespace mxnet diff --git a/src/operator/random/multisample_op.h b/src/operator/random/multisample_op.h index d1d5b3607f0d..ff413bd6f286 100644 --- a/src/operator/random/multisample_op.h +++ b/src/operator/random/multisample_op.h @@ -67,7 +67,7 @@ inline bool MultiSampleOpShape(const nnvm::NodeAttrs& attrs, const MultiSampleParam& param = nnvm::get(attrs.parsed); mxnet::TShape sshape = param.shape; for (int i = 0; i < sshape.ndim(); ++i) { - CHECK_GT(sshape[i], 0) << "shape parameter must be non-zero within each dimension"; + CHECK_GE(sshape[i], 0) << "shape parameter must be non-negative within each dimension"; } // Examine output shape whether it is already defined. 
mxnet::TShape tshape((*out_attrs)[0]); @@ -177,7 +177,9 @@ void MultiSampleOpForward(const nnvm::NodeAttrs& attrs, using namespace mxnet_op; CHECK_EQ(inputs.size(), inum); CHECK_EQ(outputs.size(), 1); - CHECK_GT(inputs[0].Size(), 0); + if (inputs[0].Size() == 0) { + return; + } mshadow::Stream *s = ctx.get_stream(); MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { diff --git a/src/operator/random/sample_op.cc b/src/operator/random/sample_op.cc index 5fe97c8bfc9c..01c21e54072e 100644 --- a/src/operator/random/sample_op.cc +++ b/src/operator/random/sample_op.cc @@ -138,6 +138,7 @@ Example:: MXNET_OPERATOR_REGISTER_SAMPLE(_random_poisson, SamplePoissonParam) .add_alias("random_poisson") +.add_alias("_npx_scalar_poisson") .describe(R"code(Draw random samples from a Poisson distribution. Samples are distributed according to a Poisson distribution parametrized by *lambda* (rate). diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 7bd6dfe98ef5..9b969607666d 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -38,6 +38,8 @@ from test_numpy_ndarray import * from test_numpy_op import * from test_numpy_interoperability import * +from test_gluon_probability_v1 import * +from test_gluon_probability_v2 import * from test_optimizer import * from test_random import * from test_exc_handling import * diff --git a/tests/python/unittest/test_gluon_probability_v1.py b/tests/python/unittest/test_gluon_probability_v1.py new file mode 100644 index 000000000000..92721f610495 --- /dev/null +++ b/tests/python/unittest/test_gluon_probability_v1.py @@ -0,0 +1,2435 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Test gluon.probability with HybridBlock.hybrid_forward api +""" +import mxnet as mx +import numpy as _np +from mxnet import np, npx, autograd +from mxnet import gluon +import mxnet.gluon.probability as mgp +from mxnet.gluon.probability import StochasticBlock, StochasticSequential +from mxnet.gluon import HybridBlock +from mxnet.test_utils import use_np, assert_almost_equal + +from common import with_seed +from numpy.testing import assert_array_equal +import pytest +import scipy.stats as ss +import scipy.special as scipy_special +import itertools +from numbers import Number + + +def prob_to_logit(prob): + return np.log(prob) - np.log1p(-prob) + + +def _distribution_method_invoker(dist, func, *args): + """Wrapper for invoking different types of class methods with one unified + interface. 
+ + Parameters + ---------- + dist : Distribution + func : method + """ + if (len(args) == 0): + out = getattr(dist, func) + if callable(out): + return out() + else: + return out + return getattr(dist, func)(*args) + + +def test_mgp_getF_v1(): + # Test getF + getF = mgp.utils.getF + nd = mx.nd + sym = mx.sym + assert getF(nd.ones((2, 2)), nd.ones((2, 2))) == nd + assert getF(sym.ones((2, 2)), sym.ones((2, 2))) == sym + assert getF(1.0, 2.0) == nd + + # Test exception + with pytest.raises(TypeError): + getF(nd.ones((2, 2)), sym.ones((2, 2))) + getF(sym.ones((2, 2)), nd.ones((2, 2))) + + +@with_seed() +@use_np +def test_gluon_uniform_v1(): + class TestUniform(HybridBlock): + def __init__(self, func): + super(TestUniform, self).__init__() + self._func = func + + def hybrid_forward(self, F, low, high, *args): + uniform = mgp.Uniform(low, high, validate_args=True) + return _distribution_method_invoker(uniform, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + low = np.random.uniform(-1, 1, shape) + high = low + np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(low, high) + net = TestUniform("log_prob") + if hybridize: + net.hybridize() + for i in range(2): + mx_out = net(low, high, samples).asnumpy() + np_out = ss.uniform(low.asnumpy(), + (high - low).asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + low = np.random.uniform(-1, 1, shape) + high = low + np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(low, high) + net = TestUniform("cdf") + if hybridize: + net.hybridize() + mx_out = net(low, high, samples).asnumpy() + np_out = ss.uniform(low.asnumpy(), + (high - low).asnumpy()).cdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in itertools.product(shapes, [True, False]): + low = np.random.uniform(-1, 1, shape) + high = low + np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(size=shape) + net = TestUniform("icdf") + if hybridize: + net.hybridize() + mx_out = net(low, high, samples).asnumpy() + np_out = ss.uniform(low.asnumpy(), + (high - low).asnumpy()).ppf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize in itertools.product(shapes, [True, False]): + low = np.random.uniform(-1, 1, shape) + high = low + np.random.uniform(0.5, 1.5, shape) + net = TestUniform("entropy") + if hybridize: + net.hybridize() + mx_out = net(low, high).asnumpy() + np_out = ss.uniform(low.asnumpy(), + (high - low).asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_normal_v1(): + class TestNormal(HybridBlock): + def __init__(self, func): + super(TestNormal, self).__init__() + self._func = func + + def hybrid_forward(self, F, loc, scale, *args): + normal = mgp.Normal(loc, scale, validate_args=True) + return _distribution_method_invoker(normal, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.normal(size=shape) + net = TestNormal("log_prob") + if hybridize: + 
net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.norm(loc.asnumpy(), + scale.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.normal(size=shape) + net = TestNormal("cdf") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.norm(loc.asnumpy(), + scale.asnumpy()).cdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(size=shape) + net = TestNormal("icdf") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.norm(loc.asnumpy(), + scale.asnumpy()).ppf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + net = TestNormal("entropy") + if hybridize: + net.hybridize() + mx_out = net(loc, scale).asnumpy() + np_out = ss.norm(loc.asnumpy(), + scale.asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_laplace_v1(): + class TestLaplace(HybridBlock): + def __init__(self, func): + super(TestLaplace, self).__init__() + self._func = func + + def hybrid_forward(self, F, loc, scale, *args): + laplace = mgp.Laplace(loc, scale, validate_args=True) + return _distribution_method_invoker(laplace, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.laplace(size=shape) + net = TestLaplace("log_prob") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.laplace(loc.asnumpy(), + scale.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.laplace(size=shape) + net = TestLaplace("cdf") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.laplace(loc.asnumpy(), + scale.asnumpy()).cdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(size=shape) + net = TestLaplace("icdf") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.laplace(loc.asnumpy(), + scale.asnumpy()).ppf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, 
shape) + scale = np.random.uniform(0.5, 1.5, shape) + net = TestLaplace("entropy") + if hybridize: + net.hybridize() + mx_out = net(loc, scale).asnumpy() + np_out = ss.laplace(loc.asnumpy(), + scale.asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_cauchy_v1(): + class TestCauchy(HybridBlock): + def __init__(self, func): + self._func = func + super(TestCauchy, self).__init__() + + def hybrid_forward(self, F, loc, scale, *args): + cauchy = mgp.Cauchy(loc, scale, F, validate_args=True) + return _distribution_method_invoker(cauchy, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test sampling + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.normal(size=shape) + net = TestCauchy("sample") + if hybridize: + net.hybridize() + mx_out = net(loc, scale) + desired_shape = (shape,) if isinstance(shape, Number) else shape + assert mx_out.shape == desired_shape + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.normal(size=shape) + net = TestCauchy("log_prob") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.cauchy(loc.asnumpy(), + scale.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.normal(size=shape) + net = TestCauchy("cdf") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.cauchy(loc.asnumpy(), + scale.asnumpy()).cdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(size=shape) + net = TestCauchy("icdf") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.cauchy(loc.asnumpy(), + scale.asnumpy()).ppf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + net = TestCauchy("entropy") + if hybridize: + net.hybridize() + mx_out = net(loc, scale).asnumpy() + np_out = ss.cauchy(loc.asnumpy(), + scale.asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_half_cauchy_v1(): + class TestHalfCauchy(HybridBlock): + def __init__(self, func): + super(TestHalfCauchy, self).__init__() + self._func = func + + def hybrid_forward(self, F, scale, *args): + half_normal = mgp.HalfCauchy(scale, F, validate_args=True) + return getattr(half_normal, self._func)(*args) + + shapes = [(), (1,), (2, 3), 6] + + # Test sampling + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + net = TestHalfCauchy("sample") + if hybridize: + net.hybridize() + mx_out = 
net(scale).asnumpy() + if isinstance(shape, Number): + shape = (shape,) + assert mx_out.shape == shape + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.abs(np.random.normal(size=shape)) + net = TestHalfCauchy("log_prob") + if hybridize: + net.hybridize() + mx_out = net(scale, samples).asnumpy() + np_out = ss.halfcauchy(0, scale.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.abs(np.random.normal(size=shape)) + net = TestHalfCauchy("cdf") + if hybridize: + net.hybridize() + mx_out = net(scale, samples).asnumpy() + np_out = ss.halfcauchy(0, scale.asnumpy()).cdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(size=shape) + net = TestHalfCauchy("icdf") + if hybridize: + net.hybridize() + mx_out = net(scale, samples).asnumpy() + np_out = ss.halfcauchy(0, scale.asnumpy()).ppf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_poisson_v1(): + class TestPoisson(HybridBlock): + def __init__(self, func): + self._func = func + super(TestPoisson, self).__init__() + + def hybrid_forward(self, F, rate, *args): + poisson = mgp.Poisson(rate, F, validate_args=True) + return _distribution_method_invoker(poisson, self._func, *args) + + shapes = [(1,), (2, 3), 6] + # Test sampling + for shape, hybridize in itertools.product(shapes, [False]): + rate = np.random.uniform(0.5, 1.5, shape) + net = TestPoisson("sample") + if hybridize: + net.hybridize() + mx_out = net(rate).asnumpy() + assert mx_out.shape == rate.shape + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + rate = np.random.uniform(0.5, 1.5, shape) + samples = np.random.randint(0, 5, shape).astype('float') + net = TestPoisson("log_prob") + if hybridize: + net.hybridize() + mx_out = net(rate, samples).asnumpy() + np_out = ss.poisson(mu=rate.asnumpy()).logpmf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_geometric_v1(): + class TestGeometric(HybridBlock): + def __init__(self, func, is_logit=False): + super(TestGeometric, self).__init__() + self._is_logit = is_logit + self._func = func + + def hybrid_forward(self, F, params, *args): + dist = mgp.Geometric(logit=params, validate_args=True) if self._is_logit else \ + mgp.Geometric(prob=params, validate_args=True) + return _distribution_method_invoker(dist, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + # Test log_prob + for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + prob = np.random.uniform(size=shape) + sample = np.random.randint(0, 10, size=shape).astype('float32') + param = prob + if use_logit: + param = prob_to_logit(param) + net = TestGeometric("log_prob", use_logit) + if hybridize: + net.hybridize() + mx_out = net(param, sample).asnumpy() + np_out = ss.geom.logpmf(sample.asnumpy() + 1, prob.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test variance 
+ for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + prob = np.random.uniform(size=shape) + param = prob + if use_logit: + param = prob_to_logit(param) + net = TestGeometric("variance", use_logit) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + np_out = ss.geom(prob.asnumpy()).var() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + # Add lower bound constraint, otherwise scipy would raise warning. + prob = np.random.uniform(low=0.1, size=shape) + param = prob + if use_logit: + param = prob_to_logit(param) + net = TestGeometric("entropy", use_logit) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + np_out = ss.geom(prob.asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_negative_binomial_v1(): + class TestNegativeBinomial(HybridBlock): + def __init__(self, func, is_logit=False): + super(TestNegativeBinomial, self).__init__() + self._is_logit = is_logit + self._func = func + + def hybrid_forward(self, F, n, params, *args): + dist = mgp.NegativeBinomial(n=n, logit=params, validate_args=True) if self._is_logit else \ + mgp.NegativeBinomial(n=n, prob=params, validate_args=True) + return _distribution_method_invoker(dist, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + # Test log_prob + for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + n = np.random.randint(1, 10, size=shape).astype('float32') + prob = np.random.uniform(low=0.1, size=shape) + sample = np.random.randint(0, 10, size=shape).astype('float32') + param = prob + if use_logit: + param = prob_to_logit(param) + net = TestNegativeBinomial("log_prob", use_logit) + if hybridize: + net.hybridize() + mx_out = net(n, param, sample).asnumpy() + np_out = ss.nbinom(n=n.asnumpy(), p=prob.asnumpy() + ).logpmf(sample.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test mean and variance + for shape, hybridize in itertools.product(shapes, [True, False]): + for func in ['mean', 'variance']: + for use_logit in [True, False]: + n = np.random.randint(1, 10, size=shape).astype('float32') + prob = np.random.uniform(low=0.1, size=shape) + net = TestNegativeBinomial(func, use_logit) + param = prob + if use_logit: + param = prob_to_logit(param) + if hybridize: + net.hybridize() + mx_out = net(n, param).asnumpy() + ss_nbinom = ss.nbinom(n=n.asnumpy(), p=1 - prob.asnumpy()) + if func == 'mean': + np_out = ss_nbinom.mean() + else: + np_out = ss_nbinom.var() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_exponential_v1(): + class TestExponential(HybridBlock): + def __init__(self, func): + self._func = func + super(TestExponential, self).__init__() + + def hybrid_forward(self, F, scale, *args): + exponential = mgp.Exponential(scale, F, validate_args=True) + return _distribution_method_invoker(exponential, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(0.2, 1.2, size=shape) + net = TestExponential("log_prob") + if hybridize: + net.hybridize() + mx_out = net(scale, samples).asnumpy() + np_out = 
ss.expon(scale=scale.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(0.2, 1.2, size=shape) + net = TestExponential("cdf") + if hybridize: + net.hybridize() + mx_out = net(scale, samples).asnumpy() + np_out = ss.expon(scale=scale.asnumpy()).cdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(0.0, 1.0, size=shape) + net = TestExponential("icdf") + if hybridize: + net.hybridize() + mx_out = net(scale, samples).asnumpy() + np_out = ss.expon(scale=scale.asnumpy()).ppf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + net = TestExponential("entropy") + if hybridize: + net.hybridize() + mx_out = net(scale).asnumpy() + np_out = ss.expon(scale=scale.asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_weibull_v1(): + class TestWeibull(HybridBlock): + def __init__(self, func): + super(TestWeibull, self).__init__() + self._func = func + + def hybrid_forward(self, F, concentration, scale, *args): + weibull = mgp.Weibull(concentration, scale, F, validate_args=True) + return _distribution_method_invoker(weibull, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + concentration = np.random.uniform(size=shape) + scale = np.random.uniform(size=shape) + samples = np.random.uniform(size=shape) + net = TestWeibull("log_prob") + if hybridize: + net.hybridize() + mx_out = net(concentration, scale, samples).asnumpy() + np_out = ss.weibull_min(c=concentration.asnumpy( + ), scale=scale.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + concentration = np.random.uniform(size=shape) + scale = np.random.uniform(size=shape) + samples = np.random.uniform(size=shape) + net = TestWeibull("cdf") + if hybridize: + net.hybridize() + mx_out = net(concentration, scale, samples).asnumpy() + np_out = ss.weibull_min(c=concentration.asnumpy( + ), scale=scale.asnumpy()).cdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in itertools.product(shapes, [True, False]): + concentration = np.random.uniform(size=shape) + scale = np.random.uniform(size=shape) + samples = np.random.uniform(size=shape) + net = TestWeibull("icdf") + if hybridize: + net.hybridize() + mx_out = net(concentration, scale, samples).asnumpy() + np_out = ss.weibull_min(c=concentration.asnumpy( + ), scale=scale.asnumpy()).ppf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize in itertools.product(shapes, [True, False]): + concentration = np.random.uniform(size=shape) + scale = np.random.uniform(size=shape) + net = 
TestWeibull("entropy") + if hybridize: + net.hybridize() + mx_out = net(concentration, scale).asnumpy() + np_out = ss.weibull_min(c=concentration.asnumpy(), + scale=scale.asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_pareto_v1(): + class TestPareto(HybridBlock): + def __init__(self, func): + super(TestPareto, self).__init__() + self._func = func + + def hybrid_forward(self, F, alpha, scale, *args): + pareto = mgp.Pareto(alpha, scale, F, validate_args=True) + return _distribution_method_invoker(pareto, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + alpha = np.random.uniform(size=shape) + scale = np.random.uniform(size=shape) + samples = np.random.uniform(1, 2, size=shape) + net = TestPareto("log_prob") + if hybridize: + net.hybridize() + mx_out = net(alpha, scale, samples).asnumpy() + np_out = ss.pareto(b=alpha.asnumpy(), scale=scale.asnumpy()).logpdf( + samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + alpha = np.random.uniform(size=shape) + scale = np.random.uniform(size=shape) + samples = np.random.uniform(1.0, 2.0, size=shape) + net = TestPareto("cdf") + if hybridize: + net.hybridize() + mx_out = net(alpha, scale, samples).asnumpy() + np_out = ss.pareto(b=alpha.asnumpy(), scale=scale.asnumpy()).cdf( + samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in itertools.product(shapes, [True, False]): + alpha = np.random.uniform(size=shape) + scale = np.random.uniform(size=shape) + samples = np.random.uniform(size=shape) + net = TestPareto("icdf") + if hybridize: + net.hybridize() + mx_out = net(alpha, scale, samples).asnumpy() + np_out = ss.pareto(b=alpha.asnumpy(), scale=scale.asnumpy()).ppf( + samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize in itertools.product(shapes, [True, False]): + alpha = np.random.uniform(size=shape) + scale = np.random.uniform(size=shape) + net = TestPareto("entropy") + if hybridize: + net.hybridize() + mx_out = net(alpha, scale).asnumpy() + np_out = ss.pareto(b=alpha.asnumpy(), scale=scale.asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_gamma_v1(): + class TestGamma(HybridBlock): + def __init__(self, func): + super(TestGamma, self).__init__() + self._func = func + + def hybrid_forward(self, F, shape, scale, *args): + gamma = mgp.Gamma(shape, scale, F, validate_args=True) + return _distribution_method_invoker(gamma, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + alpha = np.random.uniform(0.5, 1.5, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(size=shape) + net = TestGamma("log_prob") + if hybridize: + net.hybridize() + mx_out = net(alpha, scale, samples).asnumpy() + np_out = ss.gamma(a=alpha.asnumpy(), loc=0, + scale=scale.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test `mean`, `var` and `entropy` + for shape, hybridize in 
itertools.product(shapes, [True, False]): + for func in ['mean', 'variance', 'entropy']: + alpha = np.random.uniform(0.5, 1.5, shape) + scale = np.random.uniform(0.5, 1.5, shape) + net = TestGamma(func) + if hybridize: + net.hybridize() + mx_out = net(alpha, scale).asnumpy() + ss_gamma = ss.gamma(a=alpha.asnumpy(), loc=0, + scale=scale.asnumpy()) + if func == 'mean': + np_out = ss_gamma.mean() + elif func == 'variance': + np_out = ss_gamma.var() + else: + np_out = ss_gamma.entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_dirichlet_v1(): + class TestDirichlet(HybridBlock): + def __init__(self, func): + super(TestDirichlet, self).__init__() + self._func = func + + def hybrid_forward(self, F, alpha, *args): + dirichlet = mgp.Dirichlet(alpha, F, validate_args=True) + return _distribution_method_invoker(dirichlet, self._func, *args) + + event_shapes = [2, 5, 10] + batch_shapes = [None, (2, 3)] + + # Test sampling + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes): + for hybridize in [True, False]: + desired_shape = ( + batch_shape if batch_shape is not None else ()) + (event_shape,) + alpha = np.random.uniform(size=desired_shape) + net = TestDirichlet("sample") + if hybridize: + net.hybridize() + mx_out = net(alpha).asnumpy() + # Check shape + assert mx_out.shape == desired_shape + # Check simplex + assert_almost_equal(mx_out.sum(-1), _np.ones_like(mx_out.sum(-1)), atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test log_prob + # Scipy does not support batch `alpha`, thus we skip multi-dimensional batch_shape case. + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes[:1]): + for hybridize in [True, False]: + desired_shape = ( + batch_shape if batch_shape is not None else ()) + (event_shape,) + alpha = np.random.uniform(size=desired_shape) + np_samples = _np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape) + net = TestDirichlet("log_prob") + if hybridize: + net.hybridize() + mx_out = net(alpha, np.array(np_samples)).asnumpy() + np_out = ss.dirichlet(alpha=alpha.asnumpy()).logpdf(np_samples) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test `mean`, `var` and `entropy` + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes[:1]): + for hybridize in [False]: + for func in ['mean', 'variance', 'entropy']: + desired_shape = ( + batch_shape if batch_shape is not None else ()) + (event_shape,) + alpha = np.random.uniform(size=desired_shape) + net = TestDirichlet(func) + if hybridize: + net.hybridize() + mx_out = net(alpha).asnumpy() + ss_dir = ss.dirichlet(alpha=alpha.asnumpy()) + if func == 'mean': + np_out = ss_dir.mean() + elif func == 'variance': + np_out = ss_dir.var() + else: + np_out = ss_dir.entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_beta_v1(): + class TestBeta(HybridBlock): + def __init__(self, func): + super(TestBeta, self).__init__() + self._func = func + + def hybrid_forward(self, F, alpha, beta, *args): + beta_dist = mgp.Beta(alpha, beta, F, validate_args=True) + return _distribution_method_invoker(beta_dist, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + alpha = np.random.uniform(0.5, 1.5, shape) + beta = np.random.uniform(0.5, 1.5, shape) + samples = 
np.random.uniform(size=shape) + net = TestBeta("log_prob") + if hybridize: + net.hybridize() + mx_out = net(alpha, beta, samples).asnumpy() + np_out = ss.beta(alpha.asnumpy(), beta.asnumpy() + ).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test `mean`, `var` and `entropy` + for shape, hybridize in itertools.product(shapes, [True, False]): + for func in ['mean', 'variance', 'entropy']: + alpha = np.random.uniform(0.5, 1.5, shape) + beta = np.random.uniform(0.5, 1.5, shape) + net = TestBeta(func) + if hybridize: + net.hybridize() + mx_out = net(alpha, beta).asnumpy() + ss_beta = ss.beta(alpha.asnumpy(), beta.asnumpy()) + if func == 'mean': + np_out = ss_beta.mean() + elif func == 'variance': + np_out = ss_beta.var() + else: + np_out = ss_beta.entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_fisher_snedecor_v1(): + class TestFisherSnedecor(HybridBlock): + def __init__(self, func): + super(TestFisherSnedecor, self).__init__() + self._func = func + + def hybrid_forward(self, F, df1, df2, *args): + beta_dist = mgp.FisherSnedecor(df1, df2, F, validate_args=True) + return _distribution_method_invoker(beta_dist, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + df1 = np.random.uniform(0.5, 1.5, shape) + df2 = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(size=shape) + net = TestFisherSnedecor("log_prob") + if hybridize: + net.hybridize() + mx_out = net(df1, df2, samples).asnumpy() + np_out = ss.f(dfn=df1.asnumpy(), dfd=df2.asnumpy() + ).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test `mean` and `var` + for shape, hybridize in itertools.product(shapes, [True, False]): + for func in ['mean', 'variance']: + df1 = np.random.uniform(0.5, 1.5, shape) + df2 = np.random.uniform(4.0, 6.0, shape) + net = TestFisherSnedecor(func) + if hybridize: + net.hybridize() + mx_out = net(df1, df2).asnumpy() + ss_f = ss.f(dfn=df1.asnumpy(), dfd=df2.asnumpy()) + if func == 'mean': + np_out = ss_f.mean() + else: + np_out = ss_f.var() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_student_t_v1(): + class TestT(HybridBlock): + def __init__(self, func): + super(TestT, self).__init__() + self._func = func + + def hybrid_forward(self, F, df, loc, scale, *args): + t_dist = mgp.StudentT(df, loc, scale, F, validate_args=True) + return _distribution_method_invoker(t_dist, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.zeros(shape) + scale = np.random.uniform(0.5, 1.5, shape) + df = np.random.uniform(2, 4, shape) + samples = np.random.uniform(0, 4, size=shape) + net = TestT("log_prob") + if hybridize: + net.hybridize() + mx_out = net(df, loc, scale, samples).asnumpy() + np_out = ss.t(loc=0, scale=scale.asnumpy(), + df=df.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test `mean`, `var` and `entropy` + for shape, hybridize in itertools.product(shapes, [False, True]): + for func in ['mean', 'variance', 'entropy']: + loc = np.zeros(shape) + scale = np.random.uniform(0.5, 1.5, shape) + df = np.random.uniform(3, 4, shape) + net = TestT(func) 
+ if hybridize: + net.hybridize() + mx_out = net(df, loc, scale).asnumpy() + ss_f = ss.t(loc=0, scale=scale.asnumpy(), df=df.asnumpy()) + if func == 'mean': + np_out = ss_f.mean() + elif func == 'variance': + np_out = ss_f.var() + else: + np_out = ss_f.entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_gumbel_v1(): + class TestGumbel(HybridBlock): + def __init__(self, func): + super(TestGumbel, self).__init__() + self._func = func + + def hybrid_forward(self, F, loc, scale, *args): + normal = mgp.Gumbel(loc, scale, F, validate_args=True) + return getattr(normal, self._func)(*args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.normal(size=shape) + net = TestGumbel("log_prob") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.gumbel_r(loc=loc.asnumpy(), + scale=scale.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.normal(size=shape) + net = TestGumbel("cdf") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.gumbel_r(loc.asnumpy(), + scale.asnumpy()).cdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(size=shape) + net = TestGumbel("icdf") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.gumbel_r(loc.asnumpy(), + scale.asnumpy()).ppf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + net = TestGumbel("entropy") + if hybridize: + net.hybridize() + mx_out = net(loc, scale).asnumpy() + np_out = ss.gumbel_r(loc.asnumpy(), + scale.asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_multinomial_v1(): + class TestMultinomial(HybridBlock): + def __init__(self, func, num_events, total_count, is_logit, batch_shape=None, sample_shape=None): + super(TestMultinomial, self).__init__() + self._num_events = num_events + self._total_count = total_count + self._is_logit = is_logit + self._func = func + self._batch_shape = batch_shape + self._sample_shape = sample_shape + + def hybrid_forward(self, F, params, *args): + multinomial = ( + mgp.Multinomial(self._num_events, logit=params, total_count=self._total_count, + validate_args=True) + if self._is_logit else + mgp.Multinomial(self._num_events, prob=params, total_count=self._total_count, + validate_args=True) + ) + if self._func == 'sample': + return multinomial.sample(self._batch_shape) + if self._func == 'sample_n': + return multinomial.sample_n(self._sample_shape) + return _distribution_method_invoker(multinomial, self._func, *args) + + def 
one_hot(a, num_classes): + return np.identity(num_classes)[a] + + event_shapes = [2, 5, 10] + batch_shapes = [None, (2, 3)] # , (4, 0, 5)] + sample_shapes = [None, (2,), (3, 4)] + + # Test sampling + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + param = prob + if use_logit: + param = np.log(param) + net = TestMultinomial("sample", event_shape, _np.random.randint(1, 5), + use_logit, batch_shape) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + desired_shape = batch_shape if batch_shape is not None else () + assert mx_out.shape == desired_shape + (event_shape,) + + # Test sample_n + for event_shape, batch_shape, sample_shape in itertools.product(event_shapes, batch_shapes, sample_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + param = prob + if use_logit: + param = np.log(param) + net = TestMultinomial("sample_n", event_shape, _np.random.randint(1, 5), + use_logit, batch_shape, sample_shape) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + sample_shape = () if sample_shape is None else sample_shape + desired_shape = sample_shape + \ + (batch_shape if batch_shape is not None else ()) + assert mx_out.shape == desired_shape + (event_shape,) + + # Test log_prob + for event_shape, batch_shape, sample_shape in itertools.product(event_shapes, batch_shapes, sample_shapes): + for use_logit, hybridize in itertools.product([True, False], [False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + eps = _np.finfo('float32').eps + prob = np.clip(prob, eps, 1 - eps) + param = prob + sample_shape = () if sample_shape is None else sample_shape + desired_shape = sample_shape + \ + (batch_shape if batch_shape is not None else ()) + samples = np.random.choice(event_shape, size=desired_shape) + samples = one_hot(samples, event_shape) + if use_logit: + param = np.log(param) + net = TestMultinomial("log_prob", event_shape, + _np.random.randint(1, 5), use_logit) + if hybridize: + net.hybridize() + mx_out = net(param, samples).asnumpy() + # Check shape + assert mx_out.shape == desired_shape + + +@with_seed() +@use_np +def test_gluon_binomial_v1(): + class TestBinomial(HybridBlock): + def __init__(self, func, is_logit=False, n=1): + super(TestBinomial, self).__init__() + self._is_logit = is_logit + self._func = func + self._n = n + + def hybrid_forward(self, F, params, *args): + dist = mgp.Binomial(n=self._n, logit=params, validate_args=True) \ + if self._is_logit else \ + mgp.Binomial(n=self._n, prob=params, validate_args=True) + return _distribution_method_invoker(dist, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + # Test sampling + for shape, hybridize in itertools.product(shapes, [True, False]): + for use_logit in [True, False]: + n = _np.random.randint(5, 10) + prob = np.random.uniform(low=0.1, size=shape) + net = TestBinomial('sample', use_logit, n=float(n)) + param = prob + if use_logit: + param = prob_to_logit(param) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + desired_shape = (shape,) if isinstance(shape, int) else shape + assert mx_out.shape == desired_shape + + # Test sample_n + prefix_shape = (2, 3) + for shape in shapes: + n = _np.random.randint(5, 
10) + prob = np.random.uniform(low=0.1, size=shape) + dist = mgp.Binomial(n=n, prob=prob) + samples = dist.sample_n(prefix_shape) + assert samples.shape == (prefix_shape + prob.shape) + + # Test log_prob + for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + n = _np.random.randint(5, 10) + prob = np.random.uniform(low=0.1, size=shape) + sample = np.random.randint(0, n, size=shape).astype('float32') + param = prob + if use_logit: + param = prob_to_logit(param) + net = TestBinomial("log_prob", use_logit, n=float(n)) + if hybridize: + net.hybridize() + mx_out = net(param, sample).asnumpy() + np_out = ss.binom(n=n, p=prob.asnumpy()).logpmf(sample.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test mean and variance + for shape, hybridize in itertools.product(shapes, [True, False]): + for func in ['mean', 'variance']: + for use_logit in [True, False]: + n = _np.random.randint(5, 10) + prob = np.random.uniform(low=0.1, size=shape) + net = TestBinomial(func, use_logit, n=float(n)) + param = prob + if use_logit: + param = prob_to_logit(param) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + ss_binom = ss.binom(n=n, p=prob.asnumpy()) + if func == 'mean': + np_out = ss_binom.mean() + else: + np_out = ss_binom.var() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_bernoulli_v1(): + class TestBernoulli(HybridBlock): + def __init__(self, func, is_logit=False): + super(TestBernoulli, self).__init__() + self._is_logit = is_logit + self._func = func + + def hybrid_forward(self, F, params, *args): + bernoulli = mgp.Bernoulli(logit=params, validate_args=True) if self._is_logit else \ + mgp.Bernoulli(prob=params, validate_args=True) + return _distribution_method_invoker(bernoulli, self._func, *args) + + # Test log_prob + shapes = [(), (1,), (2, 3), 6] + for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + prob = np.random.uniform(size=shape) + sample = npx.random.bernoulli(prob=0.5, size=shape) + param = prob + if use_logit: + param = prob_to_logit(param) + net = TestBernoulli("log_prob", use_logit) + if hybridize: + net.hybridize() + mx_out = net(param, sample).asnumpy() + np_out = _np.log(ss.bernoulli.pmf(sample.asnumpy(), prob.asnumpy())) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test variance + for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + prob = np.random.uniform(size=shape) + sample = npx.random.bernoulli(prob=0.5, size=shape) + param = prob + if use_logit: + param = prob_to_logit(param) + net = TestBernoulli("variance", use_logit) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + np_out = ss.bernoulli(prob.asnumpy()).var() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + prob = np.random.uniform(size=shape) + sample = npx.random.bernoulli(prob=0.5, size=shape) + param = prob + if use_logit: + param = prob_to_logit(param) + net = TestBernoulli("entropy", use_logit) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + np_out = ss.bernoulli(prob.asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def 
test_relaxed_bernoulli_v1(): + class TestRelaxedBernoulli(HybridBlock): + def __init__(self, func, is_logit=False): + super(TestRelaxedBernoulli, self).__init__() + self._is_logit = is_logit + self._func = func + + def hybrid_forward(self, F, params, *args): + relaxed_bernoulli = mgp.RelaxedBernoulli(T=1.0, logit=params, validate_args=True)\ + if self._is_logit else \ + mgp.RelaxedBernoulli(T=1.0, prob=params, validate_args=True) + if self._func == "sample": + return relaxed_bernoulli.sample() + return _distribution_method_invoker(relaxed_bernoulli, self._func, *args) + + def prob_to_logit(prob): + return np.log(prob) - np.log1p(-prob) + + shapes = [(), (1,), (2, 3), 6] + # Test sampling + for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + prob = np.random.uniform(size=shape) + param = prob + if use_logit: + param = prob_to_logit(param) + param.attach_grad() + net = TestRelaxedBernoulli("sample", use_logit) + if hybridize: + net.hybridize() + with autograd.record(): + mx_out = net(param) + mx_out.backward() + desired_shape = (shape,) if isinstance(shape, int) else shape + assert param.grad.shape == desired_shape + + for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + prob = np.random.uniform(size=shape) + sample = np.random.uniform(0.1, 0.9, size=shape) + param = prob + if use_logit: + param = prob_to_logit(param) + net = TestRelaxedBernoulli("log_prob", use_logit) + if hybridize: + net.hybridize() + mx_out = net(param, sample).asnumpy() + desired_shape = (shape,) if isinstance(shape, int) else shape + assert mx_out.shape == desired_shape + + +@with_seed() +@use_np +def test_gluon_categorical_v1(): + class TestCategorical(HybridBlock): + def __init__(self, func, is_logit=False, batch_shape=None, num_events=None, sample_shape=None): + super(TestCategorical, self).__init__() + self._is_logit = is_logit + self._func = func + self._batch_shape = batch_shape + self._num_events = num_events + self._sample_shape = sample_shape + + def hybrid_forward(self, F, params, *args): + categorical = mgp.Categorical(self._num_events, logit=params, validate_args=True)\ + if self._is_logit else \ + mgp.Categorical(self._num_events, prob=params, + validate_args=True) + if self._func == "sample": + return categorical.sample(self._batch_shape) + if self._func == "sample_n": + return categorical.sample_n(self._sample_shape) + return _distribution_method_invoker(categorical, self._func, *args) + + event_shapes = [2, 5, 10] + batch_shapes = [None, (2, 3)] # , (4, 0, 5)] + sample_shapes = [(), (2,), (3, 4)] + + # Test sampling + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + param = prob.astype('float32') + if use_logit: + param = np.log(param) + net = TestCategorical("sample", use_logit, + batch_shape, event_shape) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + desired_shape = batch_shape if batch_shape is not None else () + assert mx_out.shape == desired_shape + + # Test sample_n + for event_shape, batch_shape, sample_shape in itertools.product(event_shapes, batch_shapes, sample_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + param = prob.astype('float32') + if use_logit: + 
param = np.log(param) + net = TestCategorical("sample_n", + is_logit=use_logit, batch_shape=batch_shape, + num_events=event_shape, sample_shape=sample_shape + ) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + desired_shape = sample_shape + \ + (batch_shape if batch_shape is not None else ()) + assert mx_out.shape == desired_shape + + # Test log_prob + for event_shape, batch_shape, sample_shape in itertools.product(event_shapes, batch_shapes, sample_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + eps = _np.finfo('float32').eps + prob = np.clip(prob, eps, 1 - eps) + param = prob.astype('float32') + desired_shape = sample_shape + \ + (batch_shape if batch_shape is not None else ()) + samples = np.random.choice(event_shape, size=desired_shape) + if use_logit: + param = np.log(param) + net = TestCategorical("log_prob", use_logit, + batch_shape, event_shape) + if hybridize: + net.hybridize() + mx_out = net(param, samples) + # Check shape + assert mx_out.shape == desired_shape + # Check value + log_pmf, indices = np.broadcast_arrays( + np.log(prob), np.expand_dims(samples, -1)) + if indices.ndim >= 1: + indices = indices[..., :1] + expect_log_prob = _np.take_along_axis( + log_pmf, indices.astype('int'), axis=-1).asnumpy() + assert_almost_equal(mx_out.asnumpy(), expect_log_prob.squeeze(), atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test enumerate_support + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + param = prob.astype('float32') + if use_logit: + param = np.log(param) + net = TestCategorical("enumerate_support", + use_logit, batch_shape, event_shape) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + desired_shape = (event_shape,) + \ + (batch_shape if batch_shape is not None else ()) + assert mx_out.shape == desired_shape + + +@with_seed() +@use_np +def test_gluon_one_hot_categorical_v1(): + def one_hot(a, num_classes): + return np.identity(num_classes)[a] + + class TestOneHotCategorical(HybridBlock): + def __init__(self, func, is_logit=False, batch_shape=None, num_events=None): + super(TestOneHotCategorical, self).__init__() + self._is_logit = is_logit + self._func = func + self._batch_shape = batch_shape + self._num_events = num_events + + def hybrid_forward(self, F, params, *args): + categorical = mgp.OneHotCategorical(num_events=self._num_events, logit=params) \ + if self._is_logit else \ + mgp.OneHotCategorical(num_events=self._num_events, prob=params) + if self._func == "sample": + return categorical.sample(self._batch_shape) + return _distribution_method_invoker(categorical, self._func, *args) + + event_shapes = [2, 5, 10] + batch_shapes = [None, (2, 3)] # , (4, 0, 5)] + sample_shapes = [(), (2,), (3, 4)] + + # Test sampling + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + param = prob + if use_logit: + param = np.log(param) + net = TestOneHotCategorical( + "sample", use_logit, batch_shape, event_shape) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + desired_shape = batch_shape if 
batch_shape is not None else () + assert mx_out.shape == desired_shape + (event_shape,) + + # Test log_prob + for event_shape, batch_shape, sample_shape in itertools.product(event_shapes, batch_shapes, sample_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + eps = _np.finfo('float32').eps + prob = np.clip(prob, eps, 1 - eps) + param = prob + desired_shape = sample_shape + \ + (batch_shape if batch_shape is not None else ()) + samples = np.random.choice(event_shape, size=desired_shape) + samples = one_hot(samples, event_shape) + if use_logit: + param = np.log(param) + net = TestOneHotCategorical( + "log_prob", use_logit, batch_shape, event_shape) + if hybridize: + net.hybridize() + mx_out = net(param, samples) + # Check shape + assert mx_out.shape == desired_shape + + # Test enumerate support + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + param = prob + if use_logit: + param = np.log(param) + net = TestOneHotCategorical( + "enumerate_support", use_logit, batch_shape, event_shape) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + desired_shape = batch_shape if batch_shape is not None else () + assert mx_out.shape == (event_shape,) + \ + desired_shape + (event_shape,) + + +@with_seed() +@use_np +def test_relaxed_one_hot_categorical_v1(): + class TestRelaxedOneHotCategorical(HybridBlock): + def __init__(self, func, is_logit=False, batch_shape=None, num_events=None): + super(TestRelaxedOneHotCategorical, self).__init__() + self._is_logit = is_logit + self._func = func + self._batch_shape = batch_shape + self._num_events = num_events + + def hybrid_forward(self, F, params, *args): + categorical = mgp.RelaxedOneHotCategorical(T=1.0, num_events=self._num_events, logit=params) \ + if self._is_logit else \ + mgp.RelaxedOneHotCategorical( + T=1.0, num_events=self._num_events, prob=params) + if self._func == "sample": + return categorical.sample(self._batch_shape) + return _distribution_method_invoker(categorical, self._func, *args) + + event_shapes = [2, 5, 10] + batch_shapes = [None, (2, 3)] # , (4, 0, 5)] + sample_shapes = [(), (2,), (3, 4)] + + # Test sampling + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + prob = prob.astype('float32') + param = prob + if use_logit: + param = np.log(param) + param.attach_grad() + net = TestRelaxedOneHotCategorical( + "sample", use_logit, batch_shape, event_shape) + if hybridize: + net.hybridize() + with autograd.record(): + mx_out = net(param) + mx_out.backward() + desired_shape = batch_shape if batch_shape is not None else () + assert mx_out.shape == desired_shape + (event_shape,) + assert param.grad.shape == param.shape + + # Test log_prob + for event_shape, batch_shape, sample_shape in itertools.product(event_shapes, batch_shapes, sample_shapes): + for use_logit, hybridize in itertools.product([True, False], [False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + eps = _np.finfo('float32').eps + prob = np.clip(prob, eps, 1 - eps) + param = prob + 
desired_shape = sample_shape + \ + (batch_shape if batch_shape is not None else ()) + # Samples from a Relaxed One-hot Categorical lie on a simplex. + samples = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=desired_shape)) + if use_logit: + param = np.log(param) + net = TestRelaxedOneHotCategorical( + "log_prob", use_logit, batch_shape, event_shape) + if hybridize: + net.hybridize() + mx_out = net(param, samples) + # Check shape + assert mx_out.shape == desired_shape + + +@with_seed() +@use_np +def test_gluon_mvn_v1(): + class TestMVN(HybridBlock): + def __init__(self, func, param_type): + super(TestMVN, self).__init__() + self._func = func + # cov, precision or scale_tril + self._param_type = param_type + + def hybrid_forward(self, F, loc, cov, *args): + mvn = mgp.MultivariateNormal(loc=loc, **{self._param_type: cov}, + validate_args=True) + return _distribution_method_invoker(mvn, self._func, *args) + + def _stable_inv(cov): + """ + Force the precision matrix to be symmetric. + """ + precision = np.linalg.inv(cov) + precision_t = np.swapaxes(precision, -1, -2) + return (precision + precision_t) / 2 + + event_shapes = [3, 5] + loc_shapes = [(), (2,), (4, 2)] + cov_shapes = [(), (2,), (4, 2)] + cov_func = { + 'cov': lambda s: s, + 'precision': lambda s: _stable_inv(s), + 'scale_tril': lambda s: np.linalg.cholesky(s) + } + + # Test sampling + for loc_shape, cov_shape, event_shape in itertools.product(loc_shapes, cov_shapes, event_shapes): + for cov_type in cov_func.keys(): + for hybridize in [False]: + loc = np.random.randn(*(loc_shape + (event_shape,))) + _s = np.random.randn(*(cov_shape + (event_shape, event_shape))) + loc.attach_grad() + _s.attach_grad() + # Full covariance matrix + sigma = np.matmul(_s, np.swapaxes( + _s, -1, -2)) + np.eye(event_shape) + cov_param = cov_func[cov_type](sigma) + net = TestMVN('sample', cov_type) + if hybridize: + net.hybridize() + with autograd.record(): + mx_out = net(loc, cov_param) + desired_shape = (loc + sigma[..., 0]).shape + assert mx_out.shape == desired_shape + mx_out.backward() + assert loc.grad.shape == loc.shape + assert _s.grad.shape == _s.shape + + # Test log_prob + for loc_shape, cov_shape, event_shape in itertools.product(loc_shapes, cov_shapes, event_shapes): + for cov_type in cov_func.keys(): + for hybridize in [True, False]: + loc = np.random.randn(*(loc_shape + (event_shape,))) + _s = np.random.randn(*(cov_shape + (event_shape, event_shape))) + samples = np.random.normal( + np.zeros_like(loc), np.ones_like(_s[..., 0])) + loc.attach_grad() + _s.attach_grad() + # Full covariance matrix + sigma = np.matmul(_s, np.swapaxes( + _s, -1, -2)) + np.eye(event_shape) + cov_param = cov_func[cov_type](sigma) + net = TestMVN('log_prob', cov_type) + if hybridize: + net.hybridize() + mx_out = net(loc, cov_param, samples) + assert mx_out.shape == samples.shape[:-1] + # Select the first element in the batch, because scipy does not support batching. 
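+ # mx_out is a scalar ndarray when neither loc nor cov carries a batch dimension, hence the shape check before flattening.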
+ loc_t = loc.reshape(-1, event_shape)[0].asnumpy() + sigma_t = sigma.reshape(-1, event_shape, + event_shape)[0].asnumpy() + if mx_out.shape == (): + mx_out_t = mx_out.asnumpy() + else: + mx_out_t = mx_out.flatten()[0].asnumpy() + samples_t = samples.reshape(-1, event_shape).asnumpy()[0] + scipy_mvn = ss.multivariate_normal(loc_t, sigma_t) + ss_out = scipy_mvn.logpdf(samples_t) + assert_almost_equal(mx_out_t, ss_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for loc_shape, cov_shape, event_shape in itertools.product(loc_shapes, cov_shapes, event_shapes): + for cov_type in cov_func.keys(): + for hybridize in [True, False]: + loc = np.random.randn(*(loc_shape + (event_shape,))) + _s = np.random.randn(*(cov_shape + (event_shape, event_shape))) + loc.attach_grad() + _s.attach_grad() + # Full covariance matrix + sigma = np.matmul(_s, np.swapaxes( + _s, -1, -2)) + np.eye(event_shape) + cov_param = cov_func[cov_type](sigma) + net = TestMVN('entropy', cov_type) + if hybridize: + net.hybridize() + mx_out = net(loc, cov_param) + assert mx_out.shape == sigma.shape[:-2] + # Select the first element in the batch, because scipy does not support batching. + loc_t = loc.reshape(-1, event_shape)[0].asnumpy() + sigma_t = sigma.reshape(-1, event_shape, + event_shape)[0].asnumpy() + if mx_out.shape == (): + mx_out_t = mx_out.asnumpy() + else: + mx_out_t = mx_out.flatten()[0].asnumpy() + scipy_mvn = ss.multivariate_normal(loc_t, sigma_t) + ss_out = scipy_mvn.entropy() + assert_almost_equal(mx_out_t, ss_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_half_normal_v1(): + class TestHalfNormal(HybridBlock): + def __init__(self, func): + super(TestHalfNormal, self).__init__() + self._func = func + + def hybrid_forward(self, F, scale, *args): + half_normal = mgp.HalfNormal(scale, F, validate_args=True) + return getattr(half_normal, self._func)(*args) + + shapes = [(), (1,), (2, 3), 6] + + # Test sampling + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + net = TestHalfNormal("sample") + if hybridize: + net.hybridize() + mx_out = net(scale).asnumpy() + if isinstance(shape, Number): + shape = (shape,) + assert mx_out.shape == shape + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.abs(np.random.normal(size=shape)) + net = TestHalfNormal("log_prob") + if hybridize: + net.hybridize() + mx_out = net(scale, samples).asnumpy() + np_out = ss.halfnorm(0, scale.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.abs(np.random.normal(size=shape)) + net = TestHalfNormal("cdf") + if hybridize: + net.hybridize() + mx_out = net(scale, samples).asnumpy() + np_out = ss.halfnorm(0, scale.asnumpy()).cdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(size=shape) + net = TestHalfNormal("icdf") + if hybridize: + net.hybridize() + mx_out = net(scale, samples).asnumpy() + np_out = ss.halfnorm(0, scale.asnumpy()).ppf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, 
use_broadcast=False) + + +@with_seed() +@use_np +def test_affine_transform_v1(): + r""" + Test the correctness of affine transformation by performing it + on a standard normal, since N(\mu, \sigma^2) = \mu + \sigma * N(0, 1) + """ + class TestAffineTransform(HybridBlock): + def __init__(self, func): + super(TestAffineTransform, self).__init__() + self._func = func + + def hybrid_forward(self, F, loc, scale, *args): + std_normal = mgp.Normal(F.np.zeros_like(loc), + F.np.ones_like(scale), F) + transforms = [mgp.AffineTransform(loc=0, scale=scale), + mgp.AffineTransform(loc=loc, scale=1)] + transformed_normal = mgp.TransformedDistribution( + std_normal, transforms) + if (len(args) == 0): + return getattr(transformed_normal, self._func) + return getattr(transformed_normal, self._func)(*args) + + shapes = [(1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + loc.attach_grad() + scale = np.random.uniform(0.5, 1.5, shape) + scale.attach_grad() + samples = np.random.normal(size=shape) + net = TestAffineTransform('log_prob') + if hybridize: + net.hybridize() + with autograd.record(): + mx_out = net(loc, scale, samples) + np_out = _np.log(ss.norm(loc.asnumpy(), + scale.asnumpy()).pdf(samples.asnumpy())) + assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + mx_out.backward() + loc_expected_grad = ((samples - loc) / scale ** 2).asnumpy() + scale_expected_grad = (samples - loc) ** 2 * \ + np.power(scale, -3) - (1 / scale) + assert_almost_equal(loc.grad.asnumpy(), loc_expected_grad, atol=1e-4, + rtol=1e-3, use_broadcast=False) + assert_almost_equal(scale.grad.asnumpy(), scale_expected_grad, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test sampling + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + loc.attach_grad() + scale = np.random.uniform(0.5, 1.5, shape) + scale.attach_grad() + if not isinstance(shape, tuple): + shape = (shape,) + expected_shape = (4, 5) + shape + net = TestAffineTransform('sample') + mx_out = net(loc, scale, expected_shape).asnumpy() + assert mx_out.shape == expected_shape + + +@with_seed() +@use_np +def test_compose_transform_v1(): + class TestComposeTransform(HybridBlock): + def __init__(self, func): + super(TestComposeTransform, self).__init__() + self._func = func + + def hybrid_forward(self, F, loc, scale, *args): + # Generate a log_normal distribution. 
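+ # Applying Affine(scale=scale), Affine(loc=loc) and Exp in turn maps z ~ N(0, 1) to exp(loc + scale * z), i.e. LogNormal(loc, scale), so ss.lognorm(s=scale, scale=exp(loc)) serves as the reference below.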
+ std_normal = mgp.Normal(F.np.zeros_like(loc), + F.np.ones_like(scale), F) + transforms = mgp.ComposeTransform([ + mgp.AffineTransform(loc=0, scale=scale), + mgp.AffineTransform(loc=loc, scale=1), + mgp.ExpTransform() + ]) + transformed_normal = mgp.TransformedDistribution( + std_normal, transforms) + if (len(args) == 0): + return getattr(transformed_normal, self._func) + return getattr(transformed_normal, self._func)(*args) + + shapes = [(1,), (2, 3), 6] + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + loc.attach_grad() + scale = np.random.uniform(0.5, 1.5, shape) + scale.attach_grad() + samples = np.random.uniform(1, 2, size=shape) + net = TestComposeTransform('log_prob') + if hybridize: + net.hybridize() + with autograd.record(): + mx_out = net(loc, scale, samples) + np_out = ss.lognorm(s=scale.asnumpy(), scale=np.exp( + loc).asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@use_np +def test_cached_property_v1(): + x = np.random.normal() + x.attach_grad() + scale = 0.1 + + class Dummy(object): + def __init__(self, x): + super(Dummy, self).__init__() + self.x = x + + @mgp.cached_property + def y(self): + return scale * self.x + 1 + + with autograd.record(): + obj = Dummy(x) + obj.y.backward() + assert_almost_equal(x.grad.asnumpy(), scale * np.ones((1,))) + + class DummyBlock(HybridBlock): + def hybrid_forward(self, F, x): + obj = Dummy(x) + return obj.y + + x = np.random.normal() + x.attach_grad() + net = DummyBlock() + with autograd.record(): + y = net(x) + y.backward() + assert_almost_equal(x.grad.asnumpy(), scale * np.ones((1,))) + + x = np.random.normal() + x.attach_grad() + net.hybridize() + with autograd.record(): + y = net(x) + y.backward() + assert_almost_equal(x.grad.asnumpy(), scale * np.ones((1,))) + + +@use_np +def test_independent_v1(): + class TestIndependent(HybridBlock): + def __init__(self, event_dim, func): + super(TestIndependent, self).__init__() + self._event_dim = event_dim + self._func = func + + def hybrid_forward(self, F, logit, *args): + base_dist = mgp.Bernoulli(logit=logit) + reshaped_dist = mgp.Independent(base_dist, self._event_dim) + return getattr(reshaped_dist, self._func)(*args) + + event_shapes = [(1,), (4,), (2, 2)] + batch_shapes = [(2, 3), (2,)] + for (batch_shape, event_shape) in itertools.product(batch_shapes, event_shapes): + for hybridize in [False, True]: + for func in ['log_prob']: + full_shape = batch_shape + event_shape + logit = np.random.normal(0, 2, size=full_shape) + samples = np.round(np.random.uniform(size=full_shape)) + net = TestIndependent(len(event_shape), func) + if hybridize: + net.hybridize() + mx_out = net(logit, samples) + assert mx_out.shape == batch_shape + + +@with_seed() +@use_np +def test_gluon_kl_v1(): + def _test_zero_kl(p, shape): + """Check if KL(p || p) = 0 + + Parameters + ---------- + p : Distribution + """ + mx_out = mgp.kl_divergence(p, p).asnumpy() + np_out = _np.zeros(shape) + assert_almost_equal(mx_out, np_out, atol=1e-3, + rtol=1e-2, use_broadcast=False) + + def _test_monte_carlo(p, q, M=50000): + r"""Check if KL(p || q) is approximately equal to + 1/M * \Sum_{i=1}^{M} log(p(x_i) / q(x_i)), x_i ~ p(x) + """ + kl = mgp.kl_divergence(p, q) + mc_approx = mgp.empirical_kl(p, q, M) + assert_almost_equal(mc_approx.asnumpy(), kl.asnumpy(), atol=1e-1, + rtol=1e-1, use_broadcast=False) + + def _dist_factory(dist, *param_funcs): + """Generate a distribution 
object with parameters of random value. + + Parameters + ---------- + dist : Type + A type of distribution. + param_funcs : List + A list of functions that generate valid parameters for `dist` + """ + params = [f() if callable(f) else f for f in param_funcs] + return dist(*params) + + # could cause longer runtime and potential flaky tests + monte_carlo_test = False + repeated_times = 50000 + shapes = [(), (1,), (2, 3), 6] + + # Test kl between same distributions + # uniform + for shape in shapes: + dist = mgp.Uniform + def low(): return np.random.uniform(0, 1, shape) + def high(): return np.random.uniform(1, 2, shape) + _test_zero_kl(_dist_factory(dist, low, high), shape) + + # normal, laplace, cauchy, gumbel + for dist in [mgp.Normal, mgp.Laplace, mgp.Cauchy, mgp.Gumbel]: + for shape in shapes: + def loc(): return np.random.uniform(-1, 1, shape) + def scale(): return np.random.uniform(0.5, 1.5, shape) + _test_zero_kl(_dist_factory(dist, loc, scale), shape) + if monte_carlo_test: + _test_monte_carlo(_dist_factory(dist, loc, scale), + _dist_factory(dist, loc, scale), + repeated_times) + + # poisson + for shape in shapes[1:]: + dist = mgp.Poisson + def rate(): return np.random.uniform(0.5, 1.5, shape) + _test_zero_kl(_dist_factory(dist, rate), shape) + if monte_carlo_test: + _test_monte_carlo(_dist_factory(dist, rate), + _dist_factory(dist, rate), + repeated_times) + + # exponential, geometric + for dist in [mgp.Exponential, mgp.Geometric]: + for shape in shapes: + def s(): return np.random.uniform(size=shape) + _test_zero_kl(_dist_factory(dist, s), shape) + if monte_carlo_test: + _test_monte_carlo(_dist_factory(dist, s), + _dist_factory(dist, s), + repeated_times) + + # pareto + for shape in shapes: + dist = mgp.Pareto + def alpha(): return np.random.uniform(size=shape) + def scale(): return np.random.uniform(size=shape) + _test_zero_kl(_dist_factory(dist, alpha, scale), shape) + + for shape in shapes: + dist = mgp.HalfNormal + def scale(): return np.random.uniform(0.5, 1.5, shape) + _test_zero_kl(_dist_factory(dist, scale), shape) + if monte_carlo_test: + _test_monte_carlo(_dist_factory(dist, scale), + _dist_factory(dist, scale), + repeated_times) + + # gamma, beta + for dist in [mgp.Gamma, mgp.Beta]: + for shape in shapes: + def param1(): return np.random.uniform(0.5, 1.5, shape) + def param2(): return np.random.uniform(0.5, 1.5, shape) + _test_zero_kl(_dist_factory(dist, param1, param2), shape) + if monte_carlo_test: + _test_monte_carlo(_dist_factory(dist, param1, param2), + _dist_factory(dist, param1, param2), + 50000) + + # binomial + for shape in shapes: + n = _np.random.randint(5, 10) + prob = np.random.uniform(low=0.1, size=shape) + dist = mgp.Binomial(n=n, prob=prob) + _test_zero_kl(dist, shape) + + # bernoulli + for shape in shapes: + prob = np.random.uniform(size=shape) + dist = mgp.Bernoulli(prob=prob) + _test_zero_kl(dist, shape) + + event_shapes = [3, 5, 10] + loc_shapes = [(), (2,), (4, 2)] + cov_shapes = [(), (2,), (4, 2)] + for loc_shape, cov_shape, event_shape in itertools.product(loc_shapes, cov_shapes, event_shapes): + loc = np.random.randn(*(loc_shape + (event_shape,))) + _s = np.random.randn(*(cov_shape + (event_shape, event_shape))) + sigma = np.matmul(_s, np.swapaxes(_s, -1, -2)) + np.eye(event_shape) + dist = mgp.MultivariateNormal(loc, cov=sigma) + desired_shape = (loc + sigma[..., 0]).shape[:-1] + _test_zero_kl(dist, desired_shape) + + batch_shapes = loc_shapes + # dirichlet + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes): + 
desired_shape = (batch_shape if batch_shape is not None else ()) + dist = mgp.Dirichlet + def alpha(): return np.random.uniform( + 0.5, 1.5, size=(desired_shape + (event_shape,))) + _test_zero_kl(_dist_factory(dist, alpha), desired_shape) + if monte_carlo_test: + _test_monte_carlo(_dist_factory(dist, alpha), + _dist_factory(dist, alpha), + 50000) + + # categorical, One-hot categorical + for dist in [mgp.Categorical, mgp.OneHotCategorical]: + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes): + prob = (lambda: + np.array(_np.random.dirichlet([1 / event_shape] * event_shape, size=batch_shape))) + _test_zero_kl(_dist_factory(dist, event_shape, prob), batch_shape) + if monte_carlo_test: + _test_monte_carlo(_dist_factory(dist, event_shape, prob), + _dist_factory(dist, event_shape, prob), + repeated_times) + + # Test kl between different distributions + # KL(Uniform || ...) + for shape in shapes: + rhs_dists = [ + mgp.Normal(np.random.uniform(-1, 1, shape), + np.random.uniform(0.5, 1.5, shape)), + mgp.Gumbel(np.random.uniform(-1, 1, shape), + np.random.uniform(0.5, 1.5, shape)), + ] + for rhs_dist in rhs_dists: + low = np.random.uniform(-1, 1, shape) + high = low + np.random.uniform(0.5, 1.5, shape) + lhs_dist = mgp.Uniform(low, high) + kl = mgp.kl_divergence(lhs_dist, rhs_dist) + assert kl.shape == low.shape + if monte_carlo_test: + _test_monte_carlo(lhs_dist, rhs_dist, repeated_times) + + # KL(Exponential || ...) + for shape in shapes: + rhs_dists = [ + mgp.Normal(np.random.uniform(-1, 1, shape), + np.random.uniform(0.5, 1.5, shape)), + mgp.Gumbel(np.random.uniform(-1, 1, shape), + np.random.uniform(0.5, 1.5, shape)), + mgp.Gamma(np.random.uniform(0.5, 1.5, shape), + np.random.uniform(0.5, 1.5, shape)) + ] + for rhs_dist in rhs_dists: + s = np.random.uniform(size=shape) + lhs_dist = mgp.Exponential(s) + kl = mgp.kl_divergence(lhs_dist, rhs_dist) + assert kl.shape == s.shape + if monte_carlo_test: + _test_monte_carlo(lhs_dist, rhs_dist, repeated_times) + + +@pytest.mark.garbage_expected +@with_seed() +@use_np +def test_gluon_stochastic_block_v1(): + class dummyBlock(StochasticBlock): + """In this test case, we generate samples from a Gaussian parameterized + by `loc` and `scale` and accumulate the KL-divergence between it and + its prior and the l2 norm of `loc` into the block's loss storage.""" + @StochasticBlock.collectLoss + def hybrid_forward(self, F, loc, scale): + qz = mgp.Normal(loc, scale) + # prior + pz = mgp.Normal(F.np.zeros_like(loc), F.np.ones_like(scale)) + self.add_loss(mgp.kl_divergence(qz, pz)) + self.add_loss((loc ** 2).sum(1)) + return qz.sample() + + shape = (4, 4) + for hybridize in [True, False]: + net = dummyBlock() + if hybridize: + net.hybridize() + loc = np.random.randn(*shape) + scale = np.random.rand(*shape) + mx_out = net(loc, scale).asnumpy() + kl = net.losses[0].asnumpy() + l2_norm = net.losses[1].asnumpy() + assert mx_out.shape == loc.shape + assert kl.shape == loc.shape + assert l2_norm.shape == shape[:-1] + + +@with_seed() +@use_np +def test_gluon_stochastic_block_exception_v1(): + class problemBlock(StochasticBlock): + def hybrid_forward(self, F, loc, scale): + qz = mgp.Normal(loc, scale) + # prior + pz = mgp.Normal(F.np.zeros_like(loc), F.np.ones_like(scale)) + self.add_loss(mgp.kl_divergence(qz, pz)) + self.add_loss((loc ** 2).sum(1)) + return qz.sample() + + shape = (4, 4) + for hybridize in [True, False]: + net = problemBlock() + if hybridize: + net.hybridize() + loc = np.random.randn(*shape) + scale = np.random.rand(*shape) + 
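# problemBlock's hybrid_forward is not decorated with StochasticBlock.collectLoss, so calling the block is expected to raise ValueError. +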
with pytest.raises(ValueError): + mx_out = net(loc, scale).asnumpy() + + +@pytest.mark.garbage_expected +@with_seed() +@use_np +def test_gluon_stochastic_sequential_v1(): + class normalBlock(HybridBlock): + def hybrid_forward(self, F, x): + return (x + 1) + + class stochasticBlock(StochasticBlock): + @StochasticBlock.collectLoss + def hybrid_forward(self, F, x): + self.add_loss(x ** 2) + self.add_loss(x - 1) + return (x + 1) + + class problemBlock(StochasticBlock): + def hybrid_forward(self, F, x): + self.add_loss(x ** 2) + self.add_loss(x - 1) + return (x + 1) + + shape = (4, 4) + for hybridize in [True, False]: + initial_value = np.ones(shape) + net = StochasticSequential() + net.add(stochasticBlock()) + net.add(normalBlock()) + net.add(stochasticBlock()) + net.add(normalBlock()) + if hybridize: + net.hybridize() + mx_out = net(initial_value).asnumpy() + assert_almost_equal(mx_out, _np.ones(shape) * 5) + accumulated_loss = net.losses + assert len(accumulated_loss) == 2 + assert_almost_equal(accumulated_loss[0][0].asnumpy(), _np.ones(shape)) + assert_almost_equal( + accumulated_loss[0][1].asnumpy(), _np.ones(shape) - 1) + assert_almost_equal( + accumulated_loss[1][0].asnumpy(), _np.ones(shape) * 9) + assert_almost_equal( + accumulated_loss[1][1].asnumpy(), _np.ones(shape) + 1) + + for hybridize in [True, False]: + initial_value = np.ones(shape) + net = StochasticSequential() + net.add(stochasticBlock()) + net.add(normalBlock()) + net.add(problemBlock()) + net.add(normalBlock()) + if hybridize: + net.hybridize() + with pytest.raises(ValueError): + mx_out = net(initial_value).asnumpy() + + +@with_seed() +@use_np +def test_gluon_constraint_v1(): + class TestConstraint(HybridBlock): + def __init__(self, constraint_type): + super(TestConstraint, self).__init__() + self._constraint_type = getattr(mgp.constraint, constraint_type) + + def hybrid_forward(self, F, *params): + value = params[0] + constraint_param = params[1:] + if len(constraint_param) == 0: + constraint = self._constraint_type() + else: + constraint = self._constraint_type(*constraint_param) + return constraint.check(value) + + _s = np.random.randn(5, 10, 10) + psd_matrix = np.matmul(_s, np.swapaxes(_s, -1, -2)) + np.eye(_s.shape[-1]) + + constraints_zoo = [ + # (constraint_type, constraint_param, test_samples) + ('Real', (), [np.random.randn(2, 2)]), + ('Boolean', (), [np.random.randint(0, 20, size=(2, 2)) % 2 == 0]), + ('Interval', [np.zeros((2, 2)), np.ones( + (2, 2))], [np.random.rand(2, 2)]), + ('OpenInterval', [np.zeros((2, 2)), np.ones( + (2, 2))], [np.random.rand(2, 2)]), + ('HalfOpenInterval', [np.zeros((2, 2)), + np.ones((2, 2))], [np.random.rand(2, 2)]), + ('IntegerInterval', [np.zeros((2, 2)), np.ones((2, 2)) * 10], + [np.random.randint(0, 10, size=(2, 2)).astype('float32')]), + ('IntegerOpenInterval', [np.zeros((2, 2)), np.ones((2, 2)) * 10], + [np.random.randint(1, 9, size=(2, 2)).astype('float32')]), + ('IntegerHalfOpenInterval', [np.zeros((2, 2)), np.ones((2, 2)) * 10], + [np.random.randint(1, 9, size=(2, 2)).astype('float32')]), + ('GreaterThan', [np.zeros((2, 2))], [np.random.rand(2, 2)]), + ('GreaterThanEq', [np.zeros((2, 2))], [np.random.rand(2, 2)]), + ('LessThan', [np.ones((2, 2))], [np.random.rand(2, 2)]), + ('LessThanEq', [np.ones((2, 2))], [np.random.rand(2, 2)]), + ('IntegerGreaterThan', [np.zeros((2, 2))], + [np.random.randint(1, 10, size=(2, 2)).astype('float32')]), + ('IntegerGreaterThanEq', [np.zeros((2, 2))], + [np.random.randint(0, 10, size=(2, 2)).astype('float32')]), + ('IntegerLessThan', 
[np.ones((2, 2)) * 10], + [np.random.randint(0, 9, size=(2, 2)).astype('float32')]), + ('IntegerLessThanEq', [np.ones((2, 2)) * 10], + [np.random.randint(0, 10, size=(2, 2)).astype('float32')]), + ('Positive', (), [np.random.rand(2, 2)]), + ('NonNegative', (), [np.random.rand(2, 2)]), + ('PositiveInteger', (), [np.random.randint( + 1, 5, size=(2, 2)).astype('float32')]), + ('NonNegativeInteger', (), [np.random.randint( + 0, 5, size=(2, 2)).astype('float32')]), + ('Simplex', (), [npx.softmax(np.random.randn(4, 4), axis=-1)]), + ('LowerTriangular', (), [np.tril(np.random.randn(5, 3, 3))]), + ('LowerCholesky', (), [np.linalg.cholesky(psd_matrix)]), + ('PositiveDefinite', (), [psd_matrix]), + ] + + for (constraint_type, constraint_arg, test_samples) in constraints_zoo: + for hybridize in [True, False]: + net = TestConstraint(constraint_type) + if hybridize: + net.hybridize() + for test_sample in test_samples: + mx_out = net(test_sample, *constraint_arg).asnumpy() + assert_almost_equal(mx_out, test_sample.asnumpy()) + + +@with_seed() +@use_np +def test_gluon_domain_map_v1(): + class TestDomainMap(HybridBlock): + def __init__(self, constraint_type, bijective): + super(TestDomainMap, self).__init__() + self._constraint_type = getattr(mgp.constraint, constraint_type) + + def hybrid_forward(self, F, *params): + value = params[0] + constraint_param = params[1:] + if len(constraint_param) == 0: + constraint = self._constraint_type() + else: + constraint = self._constraint_type(*constraint_param) + if bijective: + bijector = mgp.biject_to(constraint) + bijector.F = F + value = bijector(value) + else: + transformation = mgp.transform_to(constraint) + transformation.F = F + value = transformation(value) + return (value, constraint.check(value)) + + constraints_zoo = [ + # (constraint_type, constraint_param) + ('Positive', ()), + ('GreaterThan', [np.random.randn(2, 2)]), + ('GreaterThanEq', [np.random.randn(2, 2)]), + ('LessThan', [np.random.randn(2, 2)]), + ('Interval', [np.random.uniform(0, 1, (2, 2)), + np.random.uniform(2, 3, (2, 2))]), + ('HalfOpenInterval', [np.random.uniform( + 0, 1, (2, 2)), np.random.uniform(2, 3, (2, 2))]) + ] + + test_sample = np.random.randn(2, 2) + + for (constraint_type, constraint_arg) in constraints_zoo: + for bijective in [True, False]: + for hybridize in [True, False]: + net = TestDomainMap(constraint_type, bijective) + if hybridize: + net.hybridize() + constrained_out, constraint_status = net( + test_sample, *constraint_arg) + assert_almost_equal(constrained_out.asnumpy(), + constraint_status.asnumpy()) diff --git a/tests/python/unittest/test_gluon_probability_v2.py b/tests/python/unittest/test_gluon_probability_v2.py new file mode 100644 index 000000000000..9a36b4fc7056 --- /dev/null +++ b/tests/python/unittest/test_gluon_probability_v2.py @@ -0,0 +1,2365 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Test gluon.probability with HybridBlock.forward api +""" +import mxnet as mx +import numpy as _np +from mxnet import np, npx, autograd +from mxnet import gluon +import mxnet.gluon.probability as mgp +from mxnet.gluon.probability import StochasticBlock, StochasticSequential +from mxnet.gluon import HybridBlock +from mxnet.test_utils import use_np, assert_almost_equal + +from common import with_seed +from numpy.testing import assert_array_equal +import pytest +import scipy.stats as ss +import scipy.special as scipy_special +import itertools +from numbers import Number + + +def prob_to_logit(prob): + return np.log(prob) - np.log1p(-prob) + + +def _distribution_method_invoker(dist, func, *args): + """Wrapper for invoking different types of class methods with one unified + interface. + + Parameters + ---------- + dist : Distribution + func : method + """ + if (len(args) == 0): + out = getattr(dist, func) + if callable(out): + return out() + else: + return out + return getattr(dist, func)(*args) + + +def test_mgp_getF(): + # Test getF + getF = mgp.utils.getF + nd = mx.nd + sym = mx.sym + assert getF(nd.ones((2, 2)), nd.ones((2, 2))) == nd + assert getF(sym.ones((2, 2)), sym.ones((2, 2))) == sym + assert getF(1.0, 2.0) == nd + + # Test exception + with pytest.raises(TypeError): + getF(nd.ones((2, 2)), sym.ones((2, 2))) + getF(sym.ones((2, 2)), nd.ones((2, 2))) + + +@with_seed() +@use_np +def test_gluon_uniform(): + class TestUniform(HybridBlock): + def __init__(self, func): + super(TestUniform, self).__init__() + self._func = func + + def forward(self, low, high, *args): + uniform = mgp.Uniform(low, high, validate_args=True) + return _distribution_method_invoker(uniform, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + low = np.random.uniform(-1, 1, shape) + high = low + np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(low, high) + net = TestUniform("log_prob") + if hybridize: + net.hybridize() + for i in range(2): + mx_out = net(low, high, samples).asnumpy() + np_out = ss.uniform(low.asnumpy(), + (high - low).asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + low = np.random.uniform(-1, 1, shape) + high = low + np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(low, high) + net = TestUniform("cdf") + if hybridize: + net.hybridize() + mx_out = net(low, high, samples).asnumpy() + np_out = ss.uniform(low.asnumpy(), + (high - low).asnumpy()).cdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in itertools.product(shapes, [True, False]): + low = np.random.uniform(-1, 1, shape) + high = low + np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(size=shape) + net = TestUniform("icdf") + if hybridize: + net.hybridize() + mx_out = net(low, high, samples).asnumpy() + np_out = ss.uniform(low.asnumpy(), + (high - low).asnumpy()).ppf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize in itertools.product(shapes, [True, False]): + low = np.random.uniform(-1, 1, shape) + high = low + np.random.uniform(0.5, 1.5, shape) + net = 
TestUniform("entropy") + if hybridize: + net.hybridize() + mx_out = net(low, high).asnumpy() + np_out = ss.uniform(low.asnumpy(), + (high - low).asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_normal(): + class TestNormal(HybridBlock): + def __init__(self, func): + super(TestNormal, self).__init__() + self._func = func + + def forward(self, loc, scale, *args): + normal = mgp.Normal(loc, scale, validate_args=True) + return _distribution_method_invoker(normal, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.normal(size=shape) + net = TestNormal("log_prob") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.norm(loc.asnumpy(), + scale.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.normal(size=shape) + net = TestNormal("cdf") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.norm(loc.asnumpy(), + scale.asnumpy()).cdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(size=shape) + net = TestNormal("icdf") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.norm(loc.asnumpy(), + scale.asnumpy()).ppf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + net = TestNormal("entropy") + if hybridize: + net.hybridize() + mx_out = net(loc, scale).asnumpy() + np_out = ss.norm(loc.asnumpy(), + scale.asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_laplace(): + class TestLaplace(HybridBlock): + def __init__(self, func): + super(TestLaplace, self).__init__() + self._func = func + + def forward(self, loc, scale, *args): + laplace = mgp.Laplace(loc, scale, validate_args=True) + return _distribution_method_invoker(laplace, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.laplace(size=shape) + net = TestLaplace("log_prob") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.laplace(loc.asnumpy(), + scale.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.laplace(size=shape) + net 
= TestLaplace("cdf") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.laplace(loc.asnumpy(), + scale.asnumpy()).cdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(size=shape) + net = TestLaplace("icdf") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.laplace(loc.asnumpy(), + scale.asnumpy()).ppf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + net = TestLaplace("entropy") + if hybridize: + net.hybridize() + mx_out = net(loc, scale).asnumpy() + np_out = ss.laplace(loc.asnumpy(), + scale.asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_cauchy(): + class TestCauchy(HybridBlock): + def __init__(self, func): + self._func = func + super(TestCauchy, self).__init__() + + def forward(self, loc, scale, *args): + cauchy = mgp.Cauchy(loc, scale, validate_args=True) + return _distribution_method_invoker(cauchy, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test sampling + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.normal(size=shape) + net = TestCauchy("sample") + if hybridize: + net.hybridize() + mx_out = net(loc, scale) + desired_shape = (shape,) if isinstance(shape, Number) else shape + assert mx_out.shape == desired_shape + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.normal(size=shape) + net = TestCauchy("log_prob") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.cauchy(loc.asnumpy(), + scale.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.normal(size=shape) + net = TestCauchy("cdf") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.cauchy(loc.asnumpy(), + scale.asnumpy()).cdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(size=shape) + net = TestCauchy("icdf") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.cauchy(loc.asnumpy(), + scale.asnumpy()).ppf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, 
shape) + net = TestCauchy("entropy") + if hybridize: + net.hybridize() + mx_out = net(loc, scale).asnumpy() + np_out = ss.cauchy(loc.asnumpy(), + scale.asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_half_cauchy(): + class TestHalfCauchy(HybridBlock): + def __init__(self, func): + super(TestHalfCauchy, self).__init__() + self._func = func + + def forward(self, scale, *args): + half_normal = mgp.HalfCauchy(scale, validate_args=True) + return getattr(half_normal, self._func)(*args) + + shapes = [(), (1,), (2, 3), 6] + + # Test sampling + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + net = TestHalfCauchy("sample") + if hybridize: + net.hybridize() + mx_out = net(scale).asnumpy() + if isinstance(shape, Number): + shape = (shape,) + assert mx_out.shape == shape + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.abs(np.random.normal(size=shape)) + net = TestHalfCauchy("log_prob") + if hybridize: + net.hybridize() + mx_out = net(scale, samples).asnumpy() + np_out = ss.halfcauchy(0, scale.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.abs(np.random.normal(size=shape)) + net = TestHalfCauchy("cdf") + if hybridize: + net.hybridize() + mx_out = net(scale, samples).asnumpy() + np_out = ss.halfcauchy(0, scale.asnumpy()).cdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(size=shape) + net = TestHalfCauchy("icdf") + if hybridize: + net.hybridize() + mx_out = net(scale, samples).asnumpy() + np_out = ss.halfcauchy(0, scale.asnumpy()).ppf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_poisson(): + class TestPoisson(HybridBlock): + def __init__(self, func): + self._func = func + super(TestPoisson, self).__init__() + + def forward(self, rate, *args): + poisson = mgp.Poisson(rate, validate_args=True) + return _distribution_method_invoker(poisson, self._func, *args) + + shapes = [(1,), (2, 3), 6] + # Test sampling + for shape, hybridize in itertools.product(shapes, [True, False]): + rate = np.random.uniform(0.5, 1.5, shape) + net = TestPoisson("sample") + if hybridize: + net.hybridize() + mx_out = net(rate).asnumpy() + assert mx_out.shape == rate.shape + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + rate = np.random.uniform(0.5, 1.5, shape) + samples = np.random.randint(0, 5, shape).astype('float') + net = TestPoisson("log_prob") + if hybridize: + net.hybridize() + mx_out = net(rate, samples).asnumpy() + np_out = ss.poisson(mu=rate.asnumpy()).logpmf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_geometric(): + class TestGeometric(HybridBlock): + def __init__(self, func, is_logit=False): + super(TestGeometric, self).__init__() + self._is_logit = is_logit + self._func = func + + def forward(self, 
params, *args): + dist = mgp.Geometric(logit=params, validate_args=True) if self._is_logit else \ + mgp.Geometric(prob=params, validate_args=True) + return _distribution_method_invoker(dist, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + # Test log_prob + for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + prob = np.random.uniform(size=shape) + sample = np.random.randint(0, 10, size=shape).astype('float32') + param = prob + if use_logit: + param = prob_to_logit(param) + net = TestGeometric("log_prob", use_logit) + if hybridize: + net.hybridize() + mx_out = net(param, sample).asnumpy() + np_out = ss.geom.logpmf(sample.asnumpy() + 1, prob.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test variance + for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + prob = np.random.uniform(size=shape) + param = prob + if use_logit: + param = prob_to_logit(param) + net = TestGeometric("variance", use_logit) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + np_out = ss.geom(prob.asnumpy()).var() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + # Add lower bound constraint, otherwise scipy would raise warning. + prob = np.random.uniform(low=0.1, size=shape) + param = prob + if use_logit: + param = prob_to_logit(param) + net = TestGeometric("entropy", use_logit) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + np_out = ss.geom(prob.asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_negative_binomial(): + class TestNegativeBinomial(HybridBlock): + def __init__(self, func, is_logit=False): + super(TestNegativeBinomial, self).__init__() + self._is_logit = is_logit + self._func = func + + def forward(self, n, params, *args): + dist = mgp.NegativeBinomial(n=n, logit=params, validate_args=True) if self._is_logit else \ + mgp.NegativeBinomial(n=n, prob=params, validate_args=True) + return _distribution_method_invoker(dist, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + # Test log_prob + for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + n = np.random.randint(1, 10, size=shape).astype('float32') + prob = np.random.uniform(low=0.1, size=shape) + sample = np.random.randint(0, 10, size=shape).astype('float32') + param = prob + if use_logit: + param = prob_to_logit(param) + net = TestNegativeBinomial("log_prob", use_logit) + if hybridize: + net.hybridize() + mx_out = net(n, param, sample).asnumpy() + np_out = ss.nbinom(n=n.asnumpy(), p=prob.asnumpy() + ).logpmf(sample.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test mean and variance + for shape, hybridize in itertools.product(shapes, [True, False]): + for func in ['mean', 'variance']: + for use_logit in [True, False]: + n = np.random.randint(1, 10, size=shape).astype('float32') + prob = np.random.uniform(low=0.1, size=shape) + net = TestNegativeBinomial(func, use_logit) + param = prob + if use_logit: + param = prob_to_logit(param) + if hybridize: + net.hybridize() + mx_out = net(n, param).asnumpy() + ss_nbinom = ss.nbinom(n=n.asnumpy(), p=1 - prob.asnumpy()) + if func == 'mean': + np_out = ss_nbinom.mean() + else: + np_out = ss_nbinom.var() + 
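# Note: unlike the log_prob reference above, the scipy nbinom here is constructed with p = 1 - prob. +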
assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_exponential(): + class TestExponential(HybridBlock): + def __init__(self, func): + self._func = func + super(TestExponential, self).__init__() + + def forward(self, scale, *args): + exponential = mgp.Exponential(scale, validate_args=True) + return _distribution_method_invoker(exponential, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(0.2, 1.2, size=shape) + net = TestExponential("log_prob") + if hybridize: + net.hybridize() + mx_out = net(scale, samples).asnumpy() + np_out = ss.expon(scale=scale.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(0.2, 1.2, size=shape) + net = TestExponential("cdf") + if hybridize: + net.hybridize() + mx_out = net(scale, samples).asnumpy() + np_out = ss.expon(scale=scale.asnumpy()).cdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(0.0, 1.0, size=shape) + net = TestExponential("icdf") + if hybridize: + net.hybridize() + mx_out = net(scale, samples).asnumpy() + np_out = ss.expon(scale=scale.asnumpy()).ppf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + net = TestExponential("entropy") + if hybridize: + net.hybridize() + mx_out = net(scale).asnumpy() + np_out = ss.expon(scale=scale.asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_weibull(): + class TestWeibull(HybridBlock): + def __init__(self, func): + super(TestWeibull, self).__init__() + self._func = func + + def forward(self, concentration, scale, *args): + weibull = mgp.Weibull(concentration, scale, validate_args=True) + return _distribution_method_invoker(weibull, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + concentration = np.random.uniform(size=shape) + scale = np.random.uniform(size=shape) + samples = np.random.uniform(size=shape) + net = TestWeibull("log_prob") + if hybridize: + net.hybridize() + mx_out = net(concentration, scale, samples).asnumpy() + np_out = ss.weibull_min(c=concentration.asnumpy( + ), scale=scale.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + concentration = np.random.uniform(size=shape) + scale = np.random.uniform(size=shape) + samples = np.random.uniform(size=shape) + net = TestWeibull("cdf") + if hybridize: + net.hybridize() + mx_out = net(concentration, scale, samples).asnumpy() + np_out = ss.weibull_min(c=concentration.asnumpy( + ), scale=scale.asnumpy()).cdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, 
atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in itertools.product(shapes, [True, False]): + concentration = np.random.uniform(size=shape) + scale = np.random.uniform(size=shape) + samples = np.random.uniform(size=shape) + net = TestWeibull("icdf") + if hybridize: + net.hybridize() + mx_out = net(concentration, scale, samples).asnumpy() + np_out = ss.weibull_min(c=concentration.asnumpy( + ), scale=scale.asnumpy()).ppf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize in itertools.product(shapes, [True, False]): + concentration = np.random.uniform(size=shape) + scale = np.random.uniform(size=shape) + net = TestWeibull("entropy") + if hybridize: + net.hybridize() + mx_out = net(concentration, scale).asnumpy() + np_out = ss.weibull_min(c=concentration.asnumpy(), + scale=scale.asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_pareto(): + class TestPareto(HybridBlock): + def __init__(self, func): + super(TestPareto, self).__init__() + self._func = func + + def forward(self, alpha, scale, *args): + pareto = mgp.Pareto(alpha, scale, validate_args=True) + return _distribution_method_invoker(pareto, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + alpha = np.random.uniform(size=shape) + scale = np.random.uniform(size=shape) + samples = np.random.uniform(1, 2, size=shape) + net = TestPareto("log_prob") + if hybridize: + net.hybridize() + mx_out = net(alpha, scale, samples).asnumpy() + np_out = ss.pareto(b=alpha.asnumpy(), scale=scale.asnumpy()).logpdf( + samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + alpha = np.random.uniform(size=shape) + scale = np.random.uniform(size=shape) + samples = np.random.uniform(1.0, 2.0, size=shape) + net = TestPareto("cdf") + if hybridize: + net.hybridize() + mx_out = net(alpha, scale, samples).asnumpy() + np_out = ss.pareto(b=alpha.asnumpy(), scale=scale.asnumpy()).cdf( + samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in itertools.product(shapes, [True, False]): + alpha = np.random.uniform(size=shape) + scale = np.random.uniform(size=shape) + samples = np.random.uniform(size=shape) + net = TestPareto("icdf") + if hybridize: + net.hybridize() + mx_out = net(alpha, scale, samples).asnumpy() + np_out = ss.pareto(b=alpha.asnumpy(), scale=scale.asnumpy()).ppf( + samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize in itertools.product(shapes, [True, False]): + alpha = np.random.uniform(size=shape) + scale = np.random.uniform(size=shape) + net = TestPareto("entropy") + if hybridize: + net.hybridize() + mx_out = net(alpha, scale).asnumpy() + np_out = ss.pareto(b=alpha.asnumpy(), scale=scale.asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_gamma(): + class TestGamma(HybridBlock): + def __init__(self, func): + super(TestGamma, self).__init__() + self._func = func + + def forward(self, shape, scale, *args): + gamma = mgp.Gamma(shape, scale, 
validate_args=True) + return _distribution_method_invoker(gamma, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + alpha = np.random.uniform(0.5, 1.5, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(size=shape) + net = TestGamma("log_prob") + if hybridize: + net.hybridize() + mx_out = net(alpha, scale, samples).asnumpy() + np_out = ss.gamma(a=alpha.asnumpy(), loc=0, + scale=scale.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test `mean`, `var` and `entropy` + for shape, hybridize in itertools.product(shapes, [True, False]): + for func in ['mean', 'variance', 'entropy']: + alpha = np.random.uniform(0.5, 1.5, shape) + scale = np.random.uniform(0.5, 1.5, shape) + net = TestGamma(func) + if hybridize: + net.hybridize() + mx_out = net(alpha, scale).asnumpy() + ss_gamma = ss.gamma(a=alpha.asnumpy(), loc=0, + scale=scale.asnumpy()) + if func == 'mean': + np_out = ss_gamma.mean() + elif func == 'variance': + np_out = ss_gamma.var() + else: + np_out = ss_gamma.entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_dirichlet(): + class TestDirichlet(HybridBlock): + def __init__(self, func): + super(TestDirichlet, self).__init__() + self._func = func + + def forward(self, alpha, *args): + dirichlet = mgp.Dirichlet(alpha, validate_args=True) + return _distribution_method_invoker(dirichlet, self._func, *args) + + event_shapes = [2, 5, 10] + batch_shapes = [None, (2, 3)] + + # Test sampling + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes): + for hybridize in [True, False]: + desired_shape = ( + batch_shape if batch_shape is not None else ()) + (event_shape,) + alpha = np.random.uniform(size=desired_shape) + net = TestDirichlet("sample") + if hybridize: + net.hybridize() + mx_out = net(alpha).asnumpy() + # Check shape + assert mx_out.shape == desired_shape + # Check simplex + assert_almost_equal(mx_out.sum(-1), _np.ones_like(mx_out.sum(-1)), atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test log_prob + # Scipy does not support batch `alpha`, thus we skip multi-dimensional batch_shape case. 
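+ # batch_shapes[:1] keeps only the batch_shape=None case, since ss.dirichlet accepts a single concentration vector.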
+ for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes[:1]): + for hybridize in [True, False]: + desired_shape = ( + batch_shape if batch_shape is not None else ()) + (event_shape,) + alpha = np.random.uniform(size=desired_shape) + np_samples = _np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape) + net = TestDirichlet("log_prob") + if hybridize: + net.hybridize() + mx_out = net(alpha, np.array(np_samples)).asnumpy() + np_out = ss.dirichlet(alpha=alpha.asnumpy()).logpdf(np_samples) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test `mean`, `var` and `entropy` + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes[:1]): + for hybridize in [True, False]: + for func in ['mean', 'variance', 'entropy']: + desired_shape = ( + batch_shape if batch_shape is not None else ()) + (event_shape,) + alpha = np.random.uniform(size=desired_shape) + net = TestDirichlet(func) + if hybridize: + net.hybridize() + mx_out = net(alpha).asnumpy() + ss_dir = ss.dirichlet(alpha=alpha.asnumpy()) + if func == 'mean': + np_out = ss_dir.mean() + elif func == 'variance': + np_out = ss_dir.var() + else: + np_out = ss_dir.entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_beta(): + class TestBeta(HybridBlock): + def __init__(self, func): + super(TestBeta, self).__init__() + self._func = func + + def forward(self, alpha, beta, *args): + beta_dist = mgp.Beta(alpha, beta, validate_args=True) + return _distribution_method_invoker(beta_dist, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + alpha = np.random.uniform(0.5, 1.5, shape) + beta = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(size=shape) + net = TestBeta("log_prob") + if hybridize: + net.hybridize() + mx_out = net(alpha, beta, samples).asnumpy() + np_out = ss.beta(alpha.asnumpy(), beta.asnumpy() + ).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test `mean`, `var` and `entropy` + for shape, hybridize in itertools.product(shapes, [True, False]): + for func in ['mean', 'variance', 'entropy']: + alpha = np.random.uniform(0.5, 1.5, shape) + beta = np.random.uniform(0.5, 1.5, shape) + net = TestBeta(func) + if hybridize: + net.hybridize() + mx_out = net(alpha, beta).asnumpy() + ss_beta = ss.beta(alpha.asnumpy(), beta.asnumpy()) + if func == 'mean': + np_out = ss_beta.mean() + elif func == 'variance': + np_out = ss_beta.var() + else: + np_out = ss_beta.entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_fisher_snedecor(): + class TestFisherSnedecor(HybridBlock): + def __init__(self, func): + super(TestFisherSnedecor, self).__init__() + self._func = func + + def forward(self, df1, df2, *args): + beta_dist = mgp.FisherSnedecor(df1, df2, validate_args=True) + return _distribution_method_invoker(beta_dist, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + df1 = np.random.uniform(0.5, 1.5, shape) + df2 = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(size=shape) + net = TestFisherSnedecor("log_prob") + if hybridize: + net.hybridize() + mx_out = net(df1, df2, samples).asnumpy() + np_out = ss.f(dfn=df1.asnumpy(), 
dfd=df2.asnumpy() + ).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test `mean` and `var` + for shape, hybridize in itertools.product(shapes, [True, False]): + for func in ['mean', 'variance']: + df1 = np.random.uniform(0.5, 1.5, shape) + df2 = np.random.uniform(4.0, 6.0, shape) + net = TestFisherSnedecor(func) + if hybridize: + net.hybridize() + mx_out = net(df1, df2).asnumpy() + ss_f = ss.f(dfn=df1.asnumpy(), dfd=df2.asnumpy()) + if func == 'mean': + np_out = ss_f.mean() + else: + np_out = ss_f.var() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_student_t(): + class TestT(HybridBlock): + def __init__(self, func): + super(TestT, self).__init__() + self._func = func + + def forward(self, df, loc, scale, *args): + t_dist = mgp.StudentT(df, loc, scale, validate_args=True) + return _distribution_method_invoker(t_dist, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.zeros(shape) + scale = np.random.uniform(0.5, 1.5, shape) + df = np.random.uniform(2, 4, shape) + samples = np.random.uniform(0, 4, size=shape) + net = TestT("log_prob") + if hybridize: + net.hybridize() + mx_out = net(df, loc, scale, samples).asnumpy() + np_out = ss.t(loc=0, scale=scale.asnumpy(), + df=df.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test `mean`, `var` and `entropy` + for shape, hybridize in itertools.product(shapes, [False, True]): + for func in ['mean', 'variance', 'entropy']: + loc = np.zeros(shape) + scale = np.random.uniform(0.5, 1.5, shape) + df = np.random.uniform(3, 4, shape) + net = TestT(func) + if hybridize: + net.hybridize() + mx_out = net(df, loc, scale).asnumpy() + ss_f = ss.t(loc=0, scale=scale.asnumpy(), df=df.asnumpy()) + if func == 'mean': + np_out = ss_f.mean() + elif func == 'variance': + np_out = ss_f.var() + else: + np_out = ss_f.entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_gumbel(): + class TestGumbel(HybridBlock): + def __init__(self, func): + super(TestGumbel, self).__init__() + self._func = func + + def forward(self, loc, scale, *args): + normal = mgp.Gumbel(loc, scale, validate_args=True) + return getattr(normal, self._func)(*args) + + shapes = [(), (1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.normal(size=shape) + net = TestGumbel("log_prob") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.gumbel_r(loc=loc.asnumpy(), + scale=scale.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.normal(size=shape) + net = TestGumbel("cdf") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.gumbel_r(loc.asnumpy(), + scale.asnumpy()).cdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in 
itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(size=shape) + net = TestGumbel("icdf") + if hybridize: + net.hybridize() + mx_out = net(loc, scale, samples).asnumpy() + np_out = ss.gumbel_r(loc.asnumpy(), + scale.asnumpy()).ppf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + scale = np.random.uniform(0.5, 1.5, shape) + net = TestGumbel("entropy") + if hybridize: + net.hybridize() + mx_out = net(loc, scale).asnumpy() + np_out = ss.gumbel_r(loc.asnumpy(), + scale.asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_multinomial(): + class TestMultinomial(HybridBlock): + def __init__(self, func, num_events, total_count, is_logit, batch_shape=None, sample_shape=None): + super(TestMultinomial, self).__init__() + self._num_events = num_events + self._total_count = total_count + self._is_logit = is_logit + self._func = func + self._batch_shape = batch_shape + self._sample_shape = sample_shape + + def forward(self, params, *args): + multinomial = ( + mgp.Multinomial(self._num_events, logit=params, total_count=self._total_count, + validate_args=True) + if self._is_logit else + mgp.Multinomial(self._num_events, prob=params, total_count=self._total_count, + validate_args=True) + ) + if self._func == 'sample': + return multinomial.sample(self._batch_shape) + if self._func == 'sample_n': + return multinomial.sample_n(self._sample_shape) + return _distribution_method_invoker(multinomial, self._func, *args) + + def one_hot(a, num_classes): + return np.identity(num_classes)[a] + + event_shapes = [2, 5, 10] + batch_shapes = [None, (2, 3)] # , (4, 0, 5)] + sample_shapes = [None, (2,), (3, 4)] + + # Test sampling + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + param = prob + if use_logit: + param = np.log(param) + net = TestMultinomial("sample", event_shape, _np.random.randint(1, 5), + use_logit, batch_shape) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + desired_shape = batch_shape if batch_shape is not None else () + assert mx_out.shape == desired_shape + (event_shape,) + + # Test sample_n + for event_shape, batch_shape, sample_shape in itertools.product(event_shapes, batch_shapes, sample_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + param = prob + if use_logit: + param = np.log(param) + net = TestMultinomial("sample_n", event_shape, _np.random.randint(1, 5), + use_logit, batch_shape, sample_shape) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + sample_shape = () if sample_shape is None else sample_shape + desired_shape = sample_shape + \ + (batch_shape if batch_shape is not None else ()) + assert mx_out.shape == desired_shape + (event_shape,) + + # Test log_prob + for event_shape, batch_shape, sample_shape in itertools.product(event_shapes, batch_shapes, sample_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): 
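+            # Draw a probability vector and clip it away from 0 and 1 so that the log-transform to logits below stays finite.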
+ prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + eps = _np.finfo('float32').eps + prob = np.clip(prob, eps, 1 - eps) + param = prob + sample_shape = () if sample_shape is None else sample_shape + desired_shape = sample_shape + \ + (batch_shape if batch_shape is not None else ()) + samples = np.random.choice(event_shape, size=desired_shape) + samples = one_hot(samples, event_shape) + if use_logit: + param = np.log(param) + net = TestMultinomial("log_prob", event_shape, + _np.random.randint(1, 5), use_logit) + if hybridize: + net.hybridize() + mx_out = net(param, samples).asnumpy() + # Check shape + assert mx_out.shape == desired_shape + + +@with_seed() +@use_np +def test_gluon_binomial(): + class TestBinomial(HybridBlock): + def __init__(self, func, is_logit=False, n=1): + super(TestBinomial, self).__init__() + self._is_logit = is_logit + self._func = func + self._n = n + + def forward(self, params, *args): + dist = mgp.Binomial(n=self._n, logit=params, validate_args=True) \ + if self._is_logit else \ + mgp.Binomial(n=self._n, prob=params, validate_args=True) + return _distribution_method_invoker(dist, self._func, *args) + + shapes = [(), (1,), (2, 3), 6] + # Test sampling + for shape, hybridize in itertools.product(shapes, [True, False]): + for use_logit in [True, False]: + n = _np.random.randint(5, 10) + prob = np.random.uniform(low=0.1, size=shape) + net = TestBinomial('sample', use_logit, n=float(n)) + param = prob + if use_logit: + param = prob_to_logit(param) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + desired_shape = (shape,) if isinstance(shape, int) else shape + assert mx_out.shape == desired_shape + + # Test sample_n + prefix_shape = (2, 3) + for shape in shapes: + n = _np.random.randint(5, 10) + prob = np.random.uniform(low=0.1, size=shape) + dist = mgp.Binomial(n=n, prob=prob) + samples = dist.sample_n(prefix_shape) + assert samples.shape == (prefix_shape + prob.shape) + + # Test log_prob + for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + n = _np.random.randint(5, 10) + prob = np.random.uniform(low=0.1, size=shape) + sample = np.random.randint(0, n, size=shape).astype('float32') + param = prob + if use_logit: + param = prob_to_logit(param) + net = TestBinomial("log_prob", use_logit, n=float(n)) + if hybridize: + net.hybridize() + mx_out = net(param, sample).asnumpy() + np_out = ss.binom(n=n, p=prob.asnumpy()).logpmf(sample.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test mean and variance + for shape, hybridize in itertools.product(shapes, [True, False]): + for func in ['mean', 'variance']: + for use_logit in [True, False]: + n = _np.random.randint(5, 10) + prob = np.random.uniform(low=0.1, size=shape) + net = TestBinomial(func, use_logit, n=float(n)) + param = prob + if use_logit: + param = prob_to_logit(param) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + ss_binom = ss.binom(n=n, p=prob.asnumpy()) + if func == 'mean': + np_out = ss_binom.mean() + else: + np_out = ss_binom.var() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_bernoulli(): + class TestBernoulli(HybridBlock): + def __init__(self, func, is_logit=False): + super(TestBernoulli, self).__init__() + self._is_logit = is_logit + self._func = func + + def forward(self, params, *args): + bernoulli = mgp.Bernoulli(logit=params, validate_args=True) if 
self._is_logit else \ + mgp.Bernoulli(prob=params, validate_args=True) + return _distribution_method_invoker(bernoulli, self._func, *args) + + # Test log_prob + shapes = [(), (1,), (2, 3), 6] + for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + prob = np.random.uniform(size=shape) + sample = npx.random.bernoulli(prob=0.5, size=shape) + param = prob + if use_logit: + param = prob_to_logit(param) + net = TestBernoulli("log_prob", use_logit) + if hybridize: + net.hybridize() + mx_out = net(param, sample).asnumpy() + np_out = _np.log(ss.bernoulli.pmf(sample.asnumpy(), prob.asnumpy())) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test variance + for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + prob = np.random.uniform(size=shape) + sample = npx.random.bernoulli(prob=0.5, size=shape) + param = prob + if use_logit: + param = prob_to_logit(param) + net = TestBernoulli("variance", use_logit) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + np_out = ss.bernoulli(prob.asnumpy()).var() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + prob = np.random.uniform(size=shape) + sample = npx.random.bernoulli(prob=0.5, size=shape) + param = prob + if use_logit: + param = prob_to_logit(param) + net = TestBernoulli("entropy", use_logit) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + np_out = ss.bernoulli(prob.asnumpy()).entropy() + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_relaxed_bernoulli(): + class TestRelaxedBernoulli(HybridBlock): + def __init__(self, func, is_logit=False): + super(TestRelaxedBernoulli, self).__init__() + self._is_logit = is_logit + self._func = func + + def forward(self, params, *args): + relaxed_bernoulli = mgp.RelaxedBernoulli(T=1.0, logit=params, validate_args=True)\ + if self._is_logit else \ + mgp.RelaxedBernoulli(T=1.0, prob=params, validate_args=True) + if self._func == "sample": + return relaxed_bernoulli.sample() + return _distribution_method_invoker(relaxed_bernoulli, self._func, *args) + + def prob_to_logit(prob): + return np.log(prob) - np.log1p(-prob) + + shapes = [(), (1,), (2, 3), 6] + # Test sampling + for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + prob = np.random.uniform(size=shape) + param = prob + if use_logit: + param = prob_to_logit(param) + param.attach_grad() + net = TestRelaxedBernoulli("sample", use_logit) + if hybridize: + net.hybridize() + with autograd.record(): + mx_out = net(param) + mx_out.backward() + desired_shape = (shape,) if isinstance(shape, int) else shape + assert param.grad.shape == desired_shape + + for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): + prob = np.random.uniform(size=shape) + sample = np.random.uniform(0.1, 0.9, size=shape) + param = prob + if use_logit: + param = prob_to_logit(param) + net = TestRelaxedBernoulli("log_prob", use_logit) + if hybridize: + net.hybridize() + mx_out = net(param, sample).asnumpy() + desired_shape = (shape,) if isinstance(shape, int) else shape + assert mx_out.shape == desired_shape + + +@with_seed() +@use_np +def test_gluon_categorical(): + class TestCategorical(HybridBlock): + def __init__(self, func, is_logit=False, 
batch_shape=None, num_events=None, sample_shape=None): + super(TestCategorical, self).__init__() + self._is_logit = is_logit + self._func = func + self._batch_shape = batch_shape + self._num_events = num_events + self._sample_shape = sample_shape + + def forward(self, params, *args): + categorical = mgp.Categorical(self._num_events, logit=params, validate_args=True)\ + if self._is_logit else \ + mgp.Categorical(self._num_events, prob=params, + validate_args=True) + if self._func == "sample": + return categorical.sample(self._batch_shape) + if self._func == "sample_n": + return categorical.sample_n(self._sample_shape) + return _distribution_method_invoker(categorical, self._func, *args) + + event_shapes = [2, 5, 10] + batch_shapes = [None, (2, 3)] # , (4, 0, 5)] + sample_shapes = [(), (2,), (3, 4)] + + # Test sampling + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + param = prob.astype('float32') + if use_logit: + param = np.log(param) + net = TestCategorical("sample", use_logit, + batch_shape, event_shape) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + desired_shape = batch_shape if batch_shape is not None else () + assert mx_out.shape == desired_shape + + # Test sample_n + for event_shape, batch_shape, sample_shape in itertools.product(event_shapes, batch_shapes, sample_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + param = prob.astype('float32') + if use_logit: + param = np.log(param) + net = TestCategorical("sample_n", + is_logit=use_logit, batch_shape=batch_shape, + num_events=event_shape, sample_shape=sample_shape + ) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + desired_shape = sample_shape + \ + (batch_shape if batch_shape is not None else ()) + assert mx_out.shape == desired_shape + + # Test log_prob + for event_shape, batch_shape, sample_shape in itertools.product(event_shapes, batch_shapes, sample_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + eps = _np.finfo('float32').eps + prob = np.clip(prob, eps, 1 - eps) + param = prob.astype('float32') + desired_shape = sample_shape + \ + (batch_shape if batch_shape is not None else ()) + samples = np.random.choice(event_shape, size=desired_shape) + if use_logit: + param = np.log(param) + net = TestCategorical("log_prob", use_logit, + batch_shape, event_shape) + if hybridize: + net.hybridize() + mx_out = net(param, samples) + # Check shape + assert mx_out.shape == desired_shape + # Check value + log_pmf, indices = np.broadcast_arrays( + np.log(prob), np.expand_dims(samples, -1)) + if indices.ndim >= 1: + indices = indices[..., :1] + expect_log_prob = _np.take_along_axis( + log_pmf, indices.astype('int'), axis=-1).asnumpy() + assert_almost_equal(mx_out.asnumpy(), expect_log_prob.squeeze(), atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test enumerate_support + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + param = 
prob.astype('float32') + if use_logit: + param = np.log(param) + net = TestCategorical("enumerate_support", + use_logit, batch_shape, event_shape) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + desired_shape = (event_shape,) + \ + (batch_shape if batch_shape is not None else ()) + assert mx_out.shape == desired_shape + + +@with_seed() +@use_np +def test_gluon_one_hot_categorical(): + def one_hot(a, num_classes): + return np.identity(num_classes)[a] + + class TestOneHotCategorical(HybridBlock): + def __init__(self, func, is_logit=False, batch_shape=None, num_events=None): + super(TestOneHotCategorical, self).__init__() + self._is_logit = is_logit + self._func = func + self._batch_shape = batch_shape + self._num_events = num_events + + def forward(self, params, *args): + categorical = mgp.OneHotCategorical(num_events=self._num_events, logit=params) \ + if self._is_logit else \ + mgp.OneHotCategorical(num_events=self._num_events, prob=params) + if self._func == "sample": + return categorical.sample(self._batch_shape) + return _distribution_method_invoker(categorical, self._func, *args) + + event_shapes = [2, 5, 10] + batch_shapes = [None, (2, 3)] # , (4, 0, 5)] + sample_shapes = [(), (2,), (3, 4)] + + # Test sampling + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + param = prob + if use_logit: + param = np.log(param) + net = TestOneHotCategorical( + "sample", use_logit, batch_shape, event_shape) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + desired_shape = batch_shape if batch_shape is not None else () + assert mx_out.shape == desired_shape + (event_shape,) + + # Test log_prob + for event_shape, batch_shape, sample_shape in itertools.product(event_shapes, batch_shapes, sample_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + eps = _np.finfo('float32').eps + prob = np.clip(prob, eps, 1 - eps) + param = prob + desired_shape = sample_shape + \ + (batch_shape if batch_shape is not None else ()) + samples = np.random.choice(event_shape, size=desired_shape) + samples = one_hot(samples, event_shape) + if use_logit: + param = np.log(param) + net = TestOneHotCategorical( + "log_prob", use_logit, batch_shape, event_shape) + if hybridize: + net.hybridize() + mx_out = net(param, samples) + # Check shape + assert mx_out.shape == desired_shape + + # Test enumerate support + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + param = prob + if use_logit: + param = np.log(param) + net = TestOneHotCategorical( + "enumerate_support", use_logit, batch_shape, event_shape) + if hybridize: + net.hybridize() + mx_out = net(param).asnumpy() + desired_shape = batch_shape if batch_shape is not None else () + assert mx_out.shape == (event_shape,) + \ + desired_shape + (event_shape,) + + +@with_seed() +@use_np +def test_relaxed_one_hot_categorical(): + class TestRelaxedOneHotCategorical(HybridBlock): + def __init__(self, func, is_logit=False, batch_shape=None, num_events=None): + super(TestRelaxedOneHotCategorical, self).__init__() + 
self._is_logit = is_logit + self._func = func + self._batch_shape = batch_shape + self._num_events = num_events + + def forward(self, params, *args): + categorical = mgp.RelaxedOneHotCategorical(T=1.0, num_events=self._num_events, logit=params) \ + if self._is_logit else \ + mgp.RelaxedOneHotCategorical( + T=1.0, num_events=self._num_events, prob=params) + if self._func == "sample": + return categorical.sample(self._batch_shape) + return _distribution_method_invoker(categorical, self._func, *args) + + event_shapes = [2, 5, 10] + batch_shapes = [None, (2, 3)] # , (4, 0, 5)] + sample_shapes = [(), (2,), (3, 4)] + + # Test sampling + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + prob = prob.astype('float32') + param = prob + if use_logit: + param = np.log(param) + param.attach_grad() + net = TestRelaxedOneHotCategorical( + "sample", use_logit, batch_shape, event_shape) + if hybridize: + net.hybridize() + with autograd.record(): + mx_out = net(param) + mx_out.backward() + desired_shape = batch_shape if batch_shape is not None else () + assert mx_out.shape == desired_shape + (event_shape,) + assert param.grad.shape == param.shape + + # Test log_prob + for event_shape, batch_shape, sample_shape in itertools.product(event_shapes, batch_shapes, sample_shapes): + for use_logit, hybridize in itertools.product([True, False], [True, False]): + prob = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=batch_shape)) + eps = _np.finfo('float32').eps + prob = np.clip(prob, eps, 1 - eps) + param = prob + desired_shape = sample_shape + \ + (batch_shape if batch_shape is not None else ()) + # Samples from a Relaxed One-hot Categorical lie on a simplex. + samples = np.array(_np.random.dirichlet( + [1 / event_shape] * event_shape, size=desired_shape)) + if use_logit: + param = np.log(param) + net = TestRelaxedOneHotCategorical( + "log_prob", use_logit, batch_shape, event_shape) + if hybridize: + net.hybridize() + mx_out = net(param, samples) + # Check shape + assert mx_out.shape == desired_shape + + +@with_seed() +@use_np +def test_gluon_mvn(): + class TestMVN(HybridBlock): + def __init__(self, func, param_type): + super(TestMVN, self).__init__() + self._func = func + # cov, precision or scale_tril + self._param_type = param_type + + def forward(self, loc, cov, *args): + mvn = mgp.MultivariateNormal(loc=loc, **{self._param_type: cov}, + validate_args=True) + return _distribution_method_invoker(mvn, self._func, *args) + + def _stable_inv(cov): + """ + Force the precision matrix to be symmetric. 
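+        np.linalg.inv may return a slightly asymmetric matrix, so the result is averaged with its transpose.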
+ """ + precision = np.linalg.inv(cov) + precision_t = np.swapaxes(precision, -1, -2) + return (precision + precision_t) / 2 + + event_shapes = [3, 5] + loc_shapes = [(), (2,), (4, 2)] + cov_shapes = [(), (2,), (4, 2)] + cov_func = { + 'cov': lambda s: s, + 'precision': lambda s: _stable_inv(s), + 'scale_tril': lambda s: np.linalg.cholesky(s) + } + + # Test sampling + for loc_shape, cov_shape, event_shape in itertools.product(loc_shapes, cov_shapes, event_shapes): + for cov_type in cov_func.keys(): + for hybridize in [True, False]: + loc = np.random.randn(*(loc_shape + (event_shape,))) + _s = np.random.randn(*(cov_shape + (event_shape, event_shape))) + loc.attach_grad() + _s.attach_grad() + # Full covariance matrix + sigma = np.matmul(_s, np.swapaxes( + _s, -1, -2)) + np.eye(event_shape) + cov_param = cov_func[cov_type](sigma) + net = TestMVN('sample', cov_type) + if hybridize: + net.hybridize() + with autograd.record(): + mx_out = net(loc, cov_param) + desired_shape = (loc + sigma[..., 0]).shape + assert mx_out.shape == desired_shape + mx_out.backward() + assert loc.grad.shape == loc.shape + assert _s.grad.shape == _s.shape + + # Test log_prob + for loc_shape, cov_shape, event_shape in itertools.product(loc_shapes, cov_shapes, event_shapes): + for cov_type in cov_func.keys(): + for hybridize in [True, False]: + loc = np.random.randn(*(loc_shape + (event_shape,))) + _s = np.random.randn(*(cov_shape + (event_shape, event_shape))) + samples = np.random.normal( + np.zeros_like(loc), np.ones_like(_s[..., 0])) + loc.attach_grad() + _s.attach_grad() + # Full covariance matrix + sigma = np.matmul(_s, np.swapaxes( + _s, -1, -2)) + np.eye(event_shape) + cov_param = cov_func[cov_type](sigma) + net = TestMVN('log_prob', cov_type) + if hybridize: + net.hybridize() + mx_out = net(loc, cov_param, samples) + assert mx_out.shape == samples.shape[:-1] + # Select the first element in the batch, because scipy does not support batching. + loc_t = loc.reshape(-1, event_shape)[0].asnumpy() + sigma_t = sigma.reshape(-1, event_shape, + event_shape)[0].asnumpy() + if mx_out.shape == (): + mx_out_t = mx_out.asnumpy() + else: + mx_out_t = mx_out.asnumpy().flatten()[0] + samples_t = samples.reshape(-1, event_shape).asnumpy()[0] + scipy_mvn = ss.multivariate_normal(loc_t, sigma_t) + ss_out = scipy_mvn.logpdf(samples_t) + assert_almost_equal(mx_out_t, ss_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test entropy + for loc_shape, cov_shape, event_shape in itertools.product(loc_shapes, cov_shapes, event_shapes): + for cov_type in cov_func.keys(): + for hybridize in [True, False]: + loc = np.random.randn(*(loc_shape + (event_shape,))) + _s = np.random.randn(*(cov_shape + (event_shape, event_shape))) + loc.attach_grad() + _s.attach_grad() + # Full covariance matrix + sigma = np.matmul(_s, np.swapaxes( + _s, -1, -2)) + np.eye(event_shape) + cov_param = cov_func[cov_type](sigma) + net = TestMVN('entropy', cov_type) + if hybridize: + net.hybridize() + mx_out = net(loc, cov_param) + assert mx_out.shape == sigma.shape[:-2] + # Select the first element in the batch, because scipy does not support batching. 
+ loc_t = loc.reshape(-1, event_shape)[0].asnumpy() + sigma_t = sigma.reshape(-1, event_shape, + event_shape)[0].asnumpy() + if mx_out.shape == (): + mx_out_t = mx_out.asnumpy() + else: + mx_out_t = mx_out.asnumpy().flatten()[0] + scipy_mvn = ss.multivariate_normal(loc_t, sigma_t) + ss_out = scipy_mvn.entropy() + assert_almost_equal(mx_out_t, ss_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_gluon_half_normal(): + class TestHalfNormal(HybridBlock): + def __init__(self, func): + super(TestHalfNormal, self).__init__() + self._func = func + + def forward(self, scale, *args): + half_normal = mgp.HalfNormal(scale, validate_args=True) + return getattr(half_normal, self._func)(*args) + + shapes = [(), (1,), (2, 3), 6] + + # Test sampling + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + net = TestHalfNormal("sample") + if hybridize: + net.hybridize() + mx_out = net(scale).asnumpy() + if isinstance(shape, Number): + shape = (shape,) + assert mx_out.shape == shape + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.abs(np.random.normal(size=shape)) + net = TestHalfNormal("log_prob") + if hybridize: + net.hybridize() + mx_out = net(scale, samples).asnumpy() + np_out = ss.halfnorm(0, scale.asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test cdf + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.abs(np.random.normal(size=shape)) + net = TestHalfNormal("cdf") + if hybridize: + net.hybridize() + mx_out = net(scale, samples).asnumpy() + np_out = ss.halfnorm(0, scale.asnumpy()).cdf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test icdf + for shape, hybridize in itertools.product(shapes, [True, False]): + scale = np.random.uniform(0.5, 1.5, shape) + samples = np.random.uniform(size=shape) + net = TestHalfNormal("icdf") + if hybridize: + net.hybridize() + mx_out = net(scale, samples).asnumpy() + np_out = ss.halfnorm(0, scale.asnumpy()).ppf(samples.asnumpy()) + assert_almost_equal(mx_out, np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@with_seed() +@use_np +def test_affine_transform(): + r""" + Test the correctness of affine transformation by performing it + on a standard normal, since N(\mu, \sigma^2) = \mu + \sigma * N(0, 1) + """ + class TestAffineTransform(HybridBlock): + def __init__(self, func): + super(TestAffineTransform, self).__init__() + self._func = func + + def forward(self, loc, scale, *args): + std_normal = mgp.Normal(np.zeros_like(loc), + np.ones_like(scale)) + transforms = [mgp.AffineTransform(loc=0, scale=scale), + mgp.AffineTransform(loc=loc, scale=1)] + transformed_normal = mgp.TransformedDistribution( + std_normal, transforms) + if (len(args) == 0): + return getattr(transformed_normal, self._func) + return getattr(transformed_normal, self._func)(*args) + + shapes = [(1,), (2, 3), 6] + + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + loc.attach_grad() + scale = np.random.uniform(0.5, 1.5, shape) + scale.attach_grad() + samples = np.random.normal(size=shape) + net = TestAffineTransform('log_prob') + if hybridize: + net.hybridize() + with autograd.record(): + mx_out = net(loc, scale, 
samples) + np_out = _np.log(ss.norm(loc.asnumpy(), + scale.asnumpy()).pdf(samples.asnumpy())) + assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + mx_out.backward() + loc_expected_grad = ((samples - loc) / scale ** 2).asnumpy() + scale_expected_grad = (samples - loc) ** 2 * \ + np.power(scale, -3) - (1 / scale) + assert_almost_equal(loc.grad.asnumpy(), loc_expected_grad, atol=1e-4, + rtol=1e-3, use_broadcast=False) + assert_almost_equal(scale.grad.asnumpy(), scale_expected_grad, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + # Test sampling + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + loc.attach_grad() + scale = np.random.uniform(0.5, 1.5, shape) + scale.attach_grad() + if not isinstance(shape, tuple): + shape = (shape,) + expected_shape = (4, 5) + shape + net = TestAffineTransform('sample') + mx_out = net(loc, scale, expected_shape).asnumpy() + assert mx_out.shape == expected_shape + + +@with_seed() +@use_np +def test_compose_transform(): + class TestComposeTransform(HybridBlock): + def __init__(self, func): + super(TestComposeTransform, self).__init__() + self._func = func + + def forward(self, loc, scale, *args): + # Generate a log_normal distribution. + std_normal = mgp.Normal(np.zeros_like(loc), + np.ones_like(scale)) + transforms = mgp.ComposeTransform([ + mgp.AffineTransform(loc=0, scale=scale), + mgp.AffineTransform(loc=loc, scale=1), + mgp.ExpTransform() + ]) + transformed_normal = mgp.TransformedDistribution( + std_normal, transforms) + if (len(args) == 0): + return getattr(transformed_normal, self._func) + return getattr(transformed_normal, self._func)(*args) + + shapes = [(1,), (2, 3), 6] + # Test log_prob + for shape, hybridize in itertools.product(shapes, [True, False]): + loc = np.random.uniform(-1, 1, shape) + loc.attach_grad() + scale = np.random.uniform(0.5, 1.5, shape) + scale.attach_grad() + samples = np.random.uniform(1, 2, size=shape) + net = TestComposeTransform('log_prob') + if hybridize: + net.hybridize() + with autograd.record(): + mx_out = net(loc, scale, samples) + np_out = ss.lognorm(s=scale.asnumpy(), scale=np.exp( + loc).asnumpy()).logpdf(samples.asnumpy()) + assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-4, + rtol=1e-3, use_broadcast=False) + + +@use_np +def test_cached_property(): + x = np.random.normal() + x.attach_grad() + scale = 0.1 + + class Dummy(object): + def __init__(self, x): + super(Dummy, self).__init__() + self.x = x + + @mgp.cached_property + def y(self): + return scale * self.x + 1 + + with autograd.record(): + obj = Dummy(x) + obj.y.backward() + assert_almost_equal(x.grad.asnumpy(), scale * np.ones((1,))) + + class DummyBlock(HybridBlock): + def forward(self, x): + obj = Dummy(x) + return obj.y + + x = np.random.normal() + x.attach_grad() + net = DummyBlock() + with autograd.record(): + y = net(x) + y.backward() + assert_almost_equal(x.grad.asnumpy(), scale * np.ones((1,))) + + x = np.random.normal() + x.attach_grad() + net.hybridize() + with autograd.record(): + y = net(x) + y.backward() + assert_almost_equal(x.grad.asnumpy(), scale * np.ones((1,))) + + +@use_np +def test_independent(): + class TestIndependent(HybridBlock): + def __init__(self, event_dim, func): + super(TestIndependent, self).__init__() + self._event_dim = event_dim + self._func = func + + def forward(self, logit, *args): + base_dist = mgp.Bernoulli(logit=logit) + reshaped_dist = mgp.Independent(base_dist, self._event_dim) + return getattr(reshaped_dist, 
self._func)(*args) + + event_shapes = [(1,), (4,), (2, 2)] + batch_shapes = [(2, 3), (2,)] + for (batch_shape, event_shape) in itertools.product(batch_shapes, event_shapes): + for hybridize in [False, True]: + for func in ['log_prob']: + full_shape = batch_shape + event_shape + logit = np.random.normal(0, 2, size=full_shape) + samples = np.round(np.random.uniform(size=full_shape)) + net = TestIndependent(len(event_shape), func) + if hybridize: + net.hybridize() + mx_out = net(logit, samples) + assert mx_out.shape == batch_shape + + +@with_seed() +@use_np +def test_gluon_kl(): + def _test_zero_kl(p, shape): + """Check if KL(p || p) = 0 + + Parameters + ---------- + p : Distribution + """ + mx_out = mgp.kl_divergence(p, p).asnumpy() + np_out = _np.zeros(shape) + assert_almost_equal(mx_out, np_out, atol=1e-3, + rtol=1e-2, use_broadcast=False) + + def _test_monte_carlo(p, q, M=50000): + r"""Check if KL(p || q) is approximately equal to + 1/M * \Sum_{i=1}^{M} log(p(x_i) / q(x_i)), x_i ~ p(x) + """ + kl = mgp.kl_divergence(p, q) + mc_approx = mgp.empirical_kl(p, q, M) + assert_almost_equal(mc_approx.asnumpy(), kl.asnumpy(), atol=1e-1, + rtol=1e-1, use_broadcast=False) + + def _dist_factory(dist, *param_funcs): + """Generate a distribution object with parameters of random value. + + Parameters + ---------- + dist : Type + A type of distribution. + param_funcs : List + A list of functions that generate valid parameters for `dist` + """ + params = [f() if callable(f) else f for f in param_funcs] + return dist(*params) + + # could cause longer runtime and potential flaky tests + monte_carlo_test = False + repeated_times = 50000 + shapes = [(), (1,), (2, 3), 6] + + # Test kl between same distributions + # uniform + for shape in shapes: + dist = mgp.Uniform + def low(): return np.random.uniform(0, 1, shape) + def high(): return np.random.uniform(1, 2, shape) + _test_zero_kl(_dist_factory(dist, low, high), shape) + + # normal, laplace, cauchy, gumbel + for dist in [mgp.Normal, mgp.Laplace, mgp.Cauchy, mgp.Gumbel]: + for shape in shapes: + def loc(): return np.random.uniform(-1, 1, shape) + def scale(): return np.random.uniform(0.5, 1.5, shape) + _test_zero_kl(_dist_factory(dist, loc, scale), shape) + if monte_carlo_test: + _test_monte_carlo(_dist_factory(dist, loc, scale), + _dist_factory(dist, loc, scale), + repeated_times) + + # poisson + for shape in shapes[1:]: + dist = mgp.Poisson + def rate(): return np.random.uniform(0.5, 1.5, shape) + _test_zero_kl(_dist_factory(dist, rate), shape) + if monte_carlo_test: + _test_monte_carlo(_dist_factory(dist, rate), + _dist_factory(dist, rate), + repeated_times) + + # exponential, geometric + for dist in [mgp.Exponential, mgp.Geometric]: + for shape in shapes: + def s(): return np.random.uniform(size=shape) + _test_zero_kl(_dist_factory(dist, s), shape) + if monte_carlo_test: + _test_monte_carlo(_dist_factory(dist, s), + _dist_factory(dist, s), + repeated_times) + + # pareto + for shape in shapes: + dist = mgp.Pareto + def alpha(): return np.random.uniform(size=shape) + def scale(): return np.random.uniform(size=shape) + _test_zero_kl(_dist_factory(dist, alpha, scale), shape) + + for shape in shapes: + dist = mgp.HalfNormal + def scale(): return np.random.uniform(0.5, 1.5, shape) + _test_zero_kl(_dist_factory(dist, scale), shape) + if monte_carlo_test: + _test_monte_carlo(_dist_factory(dist, scale), + _dist_factory(dist, scale), + repeated_times) + + # gamma, beta + for dist in [mgp.Gamma, mgp.Beta]: + for shape in shapes: + def param1(): return 
np.random.uniform(0.5, 1.5, shape) + def param2(): return np.random.uniform(0.5, 1.5, shape) + _test_zero_kl(_dist_factory(dist, param1, param2), shape) + if monte_carlo_test: + _test_monte_carlo(_dist_factory(dist, param1, param2), + _dist_factory(dist, param1, param2), + 50000) + + # binomial + for shape in shapes: + n = _np.random.randint(5, 10) + prob = np.random.uniform(low=0.1, size=shape) + dist = mgp.Binomial(n=n, prob=prob) + _test_zero_kl(dist, shape) + + # bernoulli + for shape in shapes: + prob = np.random.uniform(size=shape) + dist = mgp.Bernoulli(prob=prob) + _test_zero_kl(dist, shape) + + event_shapes = [3, 5, 10] + loc_shapes = [(), (2,), (4, 2)] + cov_shapes = [(), (2,), (4, 2)] + for loc_shape, cov_shape, event_shape in itertools.product(loc_shapes, cov_shapes, event_shapes): + loc = np.random.randn(*(loc_shape + (event_shape,))) + _s = np.random.randn(*(cov_shape + (event_shape, event_shape))) + sigma = np.matmul(_s, np.swapaxes(_s, -1, -2)) + np.eye(event_shape) + dist = mgp.MultivariateNormal(loc, cov=sigma) + desired_shape = (loc + sigma[..., 0]).shape[:-1] + _test_zero_kl(dist, desired_shape) + + batch_shapes = loc_shapes + # dirichlet + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes): + desired_shape = (batch_shape if batch_shape is not None else ()) + dist = mgp.Dirichlet + def alpha(): return np.random.uniform( + 0.5, 1.5, size=(desired_shape + (event_shape,))) + _test_zero_kl(_dist_factory(dist, alpha), desired_shape) + if monte_carlo_test: + _test_monte_carlo(_dist_factory(dist, alpha), + _dist_factory(dist, alpha), + 50000) + + # categorical, One-hot categorical + for dist in [mgp.Categorical, mgp.OneHotCategorical]: + for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes): + prob = (lambda: + np.array(_np.random.dirichlet([1 / event_shape] * event_shape, size=batch_shape))) + _test_zero_kl(_dist_factory(dist, event_shape, prob), batch_shape) + if monte_carlo_test: + _test_monte_carlo(_dist_factory(dist, event_shape, prob), + _dist_factory(dist, event_shape, prob), + repeated_times) + + # Test kl between different distributions + # KL(Uniform || ...) + for shape in shapes: + rhs_dists = [ + mgp.Normal(np.random.uniform(-1, 1, shape), + np.random.uniform(0.5, 1.5, shape)), + mgp.Gumbel(np.random.uniform(-1, 1, shape), + np.random.uniform(0.5, 1.5, shape)), + ] + for rhs_dist in rhs_dists: + low = np.random.uniform(-1, 1, shape) + high = low + np.random.uniform(0.5, 1.5, shape) + lhs_dist = mgp.Uniform(low, high) + kl = mgp.kl_divergence(lhs_dist, rhs_dist) + assert kl.shape == low.shape + if monte_carlo_test: + _test_monte_carlo(lhs_dist, rhs_dist, repeated_times) + + # KL(Exponential || ...) 
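+    # As with the Uniform case above, only the output shape is asserted; value correctness can be probed via _test_monte_carlo when monte_carlo_test is enabled.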
+ for shape in shapes: + rhs_dists = [ + mgp.Normal(np.random.uniform(-1, 1, shape), + np.random.uniform(0.5, 1.5, shape)), + mgp.Gumbel(np.random.uniform(-1, 1, shape), + np.random.uniform(0.5, 1.5, shape)), + mgp.Gamma(np.random.uniform(0.5, 1.5, shape), + np.random.uniform(0.5, 1.5, shape)) + ] + for rhs_dist in rhs_dists: + s = np.random.uniform(size=shape) + lhs_dist = mgp.Exponential(s) + kl = mgp.kl_divergence(lhs_dist, rhs_dist) + assert kl.shape == s.shape + if monte_carlo_test: + _test_monte_carlo(lhs_dist, rhs_dist, repeated_times) + + +@pytest.mark.garbage_expected +@with_seed() +@use_np +def test_gluon_stochastic_block(): + class dummyBlock(StochasticBlock): + """In this test case, we generate samples from a Gaussian parameterized + by `loc` and `scale` and accumulate the KL-divergence between it and + its prior and the l2 norm of `loc` into the block's loss storage.""" + @StochasticBlock.collectLoss + def forward(self, loc, scale): + qz = mgp.Normal(loc, scale) + # prior + pz = mgp.Normal(np.zeros_like(loc), np.ones_like(scale)) + self.add_loss(mgp.kl_divergence(qz, pz)) + self.add_loss((loc ** 2).sum(1)) + return qz.sample() + + shape = (4, 4) + for hybridize in [True, False]: + net = dummyBlock() + if hybridize: + net.hybridize() + loc = np.random.randn(*shape) + scale = np.random.rand(*shape) + mx_out = net(loc, scale).asnumpy() + kl = net.losses[0].asnumpy() + l2_norm = net.losses[1].asnumpy() + assert mx_out.shape == loc.shape + assert kl.shape == loc.shape + assert l2_norm.shape == shape[:-1] + if hybridize: + net.export('dummyBlock', epoch=0) + + +@with_seed() +@use_np +def test_gluon_stochastic_block_exception(): + class problemBlock(StochasticBlock): + def forward(self, loc, scale): + qz = mgp.Normal(loc, scale) + # prior + pz = mgp.Normal(np.zeros_like(loc), np.ones_like(scale)) + self.add_loss(mgp.kl_divergence(qz, pz)) + self.add_loss((loc ** 2).sum(1)) + return qz.sample() + + shape = (4, 4) + for hybridize in [True, False]: + net = problemBlock() + if hybridize: + net.hybridize() + loc = np.random.randn(*shape) + scale = np.random.rand(*shape) + with pytest.raises(ValueError): + mx_out = net(loc, scale).asnumpy() + + +@pytest.mark.garbage_expected +@with_seed() +@use_np +def test_gluon_stochastic_sequential(): + class normalBlock(HybridBlock): + def forward(self, x): + return (x + 1) + + class stochasticBlock(StochasticBlock): + @StochasticBlock.collectLoss + def forward(self, x): + self.add_loss(x ** 2) + self.add_loss(x - 1) + return (x + 1) + + class problemBlock(StochasticBlock): + def forward(self, x): + self.add_loss(x ** 2) + self.add_loss(x - 1) + return (x + 1) + + shape = (4, 4) + for hybridize in [True, False]: + initial_value = np.ones(shape) + net = StochasticSequential() + net.add(stochasticBlock()) + net.add(normalBlock()) + net.add(stochasticBlock()) + net.add(normalBlock()) + if hybridize: + net.hybridize() + mx_out = net(initial_value).asnumpy() + assert_almost_equal(mx_out, _np.ones(shape) * 5) + accumulated_loss = net.losses + assert len(accumulated_loss) == 2 + assert_almost_equal(accumulated_loss[0][0].asnumpy(), _np.ones(shape)) + assert_almost_equal( + accumulated_loss[0][1].asnumpy(), _np.ones(shape) - 1) + assert_almost_equal( + accumulated_loss[1][0].asnumpy(), _np.ones(shape) * 9) + assert_almost_equal( + accumulated_loss[1][1].asnumpy(), _np.ones(shape) + 1) + + for hybridize in [True, False]: + initial_value = np.ones(shape) + net = StochasticSequential() + net.add(stochasticBlock()) + net.add(normalBlock()) + 
net.add(problemBlock()) + net.add(normalBlock()) + if hybridize: + net.hybridize() + with pytest.raises(ValueError): + mx_out = net(initial_value).asnumpy() + + + @with_seed() + @use_np + def test_gluon_domain_map(): + class TestDomainMap(HybridBlock): + def __init__(self, constraint_type, bijective): + super(TestDomainMap, self).__init__() + self._constraint_type = getattr(mgp.constraint, constraint_type) + self._bijective = bijective + + def forward(self, *params): + value = params[0] + constraint_param = params[1:] + if len(constraint_param) == 0: + constraint = self._constraint_type() + else: + constraint = self._constraint_type(*constraint_param) + if self._bijective: + bijector = mgp.biject_to(constraint) + value = bijector(value) + else: + transformation = mgp.transform_to(constraint) + value = transformation(value) + return (value, constraint.check(value)) + + constraints_zoo = [ + # (constraint_type, constraint_param) + ('Positive', ()), + ('GreaterThan', [np.random.randn(2, 2)]), + ('GreaterThanEq', [np.random.randn(2, 2)]), + ('LessThan', [np.random.randn(2, 2)]), + ('Interval', [np.random.uniform(0, 1, (2, 2)), + np.random.uniform(2, 3, (2, 2))]), + ('HalfOpenInterval', [np.random.uniform( + 0, 1, (2, 2)), np.random.uniform(2, 3, (2, 2))]) + ] + + test_sample = np.random.randn(2, 2) + + for (constraint_type, constraint_arg) in constraints_zoo: + for bijective in [True, False]: + for hybridize in [True, False]: + net = TestDomainMap(constraint_type, bijective) + if hybridize: + net.hybridize() + constrained_out, constraint_status = net( + test_sample, *constraint_arg) + assert_almost_equal(constrained_out.asnumpy(), + constraint_status.asnumpy())
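+                # constraint.check(value) is expected to return the validated value itself, hence it should match the transformed output.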