Skip to content

Commit

Permalink
Fixed DocString example.
Browse files Browse the repository at this point in the history
  • Loading branch information
OlaRonning committed Jun 4, 2021
1 parent 0b5af73 commit 4d44aa0
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions pyro/distributions/sine_skewed.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def model(obs):
with pyro.plate('obs_plate'):
sine = SineBivariateVonMises(phi_loc=phi_loc, psi_loc=psi_loc,
phi_concentration=1000 * phi_conc,
psi_concentration=1000 * psi_conc,
weighted_correlation=corr_scale)
return pyro.sample('phi_psi', SineSkewed(sine, skewness), obs=obs)
phi_concentration=1000 * phi_conc,
psi_concentration=1000 * psi_conc,
weighted_correlation=corr_scale)
return pyro.sample(' phi_psi', SineSkewed(sine, skewness), obs=obs)
To ensure the skewing does not alter the normalization constant of the (Sine Bivaraite von Mises) base
distribution the skewness parameters are constraint. The constraint requires the sum of the absolute values of
Expand Down Expand Up @@ -109,7 +109,7 @@ def __repr__(self):
def sample(self, sample_shape=torch.Size()):
bd = self.base_dist
ys = bd.sample(sample_shape)
u = Uniform(0., torch.ones(torch.Size([]), device=self.skewness.device)).sample(sample_shape + self.batch_shape)
u = Uniform(0., self.skewness.new_ones(())).sample(sample_shape + self.batch_shape)

# Section 2.3 step 3 in [1]
mask = u <= .5 + .5 * (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).sum(-1)
Expand Down

0 comments on commit 4d44aa0

Please sign in to comment.