Skip to content

Commit

Permalink
rsample
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif authored Nov 15, 2024
1 parent f0b9887 commit 4dfee43
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/gluonts/torch/distributions/generalized_pareto.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class GeneralizedPareto(Distribution):
"scale": constraints.positive,
"concentration": constraints.real,
}
has_rsample = False
has_rsample = True

def __init__(self, loc, scale, concentration, validate_args=None):
self.loc, self.scale, self.concentration = broadcast_all(
Expand Down Expand Up @@ -80,11 +80,10 @@ def expand(self, batch_shape, _instance=None):
new._validate_args = self._validate_args
return new

def sample(self, sample_shape=torch.Size()):
def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
with torch.no_grad():
u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device)
return self.icdf(u)
u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device)
return self.icdf(u)

def log_prob(self, value):
if self._validate_args:
Expand Down

0 comments on commit 4dfee43

Please sign in to comment.