From 4d44aa08a9ae714accc7b52e94ae2d372546c5ee Mon Sep 17 00:00:00 2001 From: ola Date: Fri, 4 Jun 2021 10:23:21 +0200 Subject: [PATCH] Fixed DocString example. --- pyro/distributions/sine_skewed.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index 1d103d4da8..eb049009d6 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -41,10 +41,10 @@ def model(obs): with pyro.plate('obs_plate'): sine = SineBivariateVonMises(phi_loc=phi_loc, psi_loc=psi_loc, - phi_concentration=1000 * phi_conc, - psi_concentration=1000 * psi_conc, - weighted_correlation=corr_scale) - return pyro.sample('phi_psi', SineSkewed(sine, skewness), obs=obs) + phi_concentration=1000 * phi_conc, + psi_concentration=1000 * psi_conc, + weighted_correlation=corr_scale) + return pyro.sample(' phi_psi', SineSkewed(sine, skewness), obs=obs) To ensure the skewing does not alter the normalization constant of the (Sine Bivaraite von Mises) base distribution the skewness parameters are constraint. The constraint requires the sum of the absolute values of @@ -109,7 +109,7 @@ def __repr__(self): def sample(self, sample_shape=torch.Size()): bd = self.base_dist ys = bd.sample(sample_shape) - u = Uniform(0., torch.ones(torch.Size([]), device=self.skewness.device)).sample(sample_shape + self.batch_shape) + u = Uniform(0., self.skewness.new_ones(())).sample(sample_shape + self.batch_shape) # Section 2.3 step 3 in [1] mask = u <= .5 + .5 * (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).sum(-1)