This package extends the functionality of univariate distributions in torch.distributions
by implementing several new methods:
- sf: survival function (complementary CDF)
- logsf: logarithm of the survival function (negative cumulative hazard function)
- logcdf: logarithm of the CDF
- log_hazard: logarithm of the hazard function (logarithm of the failure rate)
- isf: inverse of the survival function
- sample_cond: instead of sampling from the full support of the distribution, generate samples between lower_bound and upper_bound
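To make the definitions above concrete, here is a minimal sketch that computes each quantity for an exponential distribution with rate λ, where the closed forms are S(x) = exp(-λx), log S(x) = -λx and h(x) = λ. The helper names (`exp_sf`, `exp_logsf`, `exp_sample_cond`, etc.) are hypothetical and are not the survival_distributions API. `exp_sample_cond` assumes conditional sampling by inverse-transform sampling, which may differ from how the package implements `sample_cond`.

```python
import torch

# Hypothetical helpers for Exponential(rate); NOT the survival_distributions API.

def exp_sf(x, rate):
    # Survival function S(x) = P(X > x) = exp(-rate * x)
    return torch.exp(-rate * x)

def exp_logsf(x, rate):
    # log S(x) = -rate * x, i.e. minus the cumulative hazard
    return -rate * x

def exp_logcdf(x, rate):
    # log F(x) = log(1 - exp(-rate * x)); expm1 keeps precision for small rate * x
    return torch.log(-torch.expm1(-rate * x))

def exp_log_hazard(x, rate):
    # log h(x) = log f(x) - log S(x); constant (= log rate) for the exponential
    log_prob = torch.log(rate) - rate * x
    return log_prob - exp_logsf(x, rate)

def exp_isf(q, rate):
    # Inverse survival function: solve S(x) = q for x
    return -torch.log(q) / rate

def exp_sample_cond(rate, lower_bound, upper_bound, sample_shape=torch.Size()):
    # Inverse-transform sampling restricted to [lower_bound, upper_bound]:
    # draw u uniformly between F(lower_bound) and F(upper_bound), then apply F^{-1}.
    dist = torch.distributions.Exponential(rate)
    lo, hi = dist.cdf(lower_bound), dist.cdf(upper_bound)
    u = lo + (hi - lo) * torch.rand(sample_shape)
    return dist.icdf(u)
```

For example, `exp_sample_cond(torch.tensor(2.0), torch.tensor(1.0), torch.tensor(2.0), (1000,))` returns 1000 draws that all fall between 1.0 and 2.0 (up to floating-point rounding).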
This is especially useful when working with temporal point processes or survival analysis.
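For example, in survival analysis with right-censored data, an event observed at time t contributes log f(t) to the log-likelihood, while a subject censored at time c contributes only log S(c), which is exactly where logsf is needed. The sketch below assumes an exponential model and uses the closed form log S(t) = -λt as a stand-in for a library `logsf`; the function name `censored_loglik` is hypothetical.

```python
import torch

def censored_loglik(times, observed, rate):
    # times: event or censoring times
    # observed: True if the event was seen, False if the subject was right-censored
    dist = torch.distributions.Exponential(rate)
    log_f = dist.log_prob(times)  # density term for observed events
    log_s = -rate * times         # log S(t) for the exponential (stand-in for logsf)
    return torch.where(observed, log_f, log_s).sum()

times = torch.tensor([1.0, 2.0, 3.0])
observed = torch.tensor([True, False, True])
loglik = censored_loglik(times, observed, torch.tensor(0.5))
```

For the exponential this reduces to d·log λ - λ·Σt, where d is the number of observed events, so here the value is 2·log 0.5 - 3.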
A naive implementation based on existing PyTorch functionality (e.g., torch.log(1.0 - dist.cdf(x)) for logsf) is often less accurate and less numerically stable than the implementation provided by survival_distributions.
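As a concrete illustration of the accuracy problem, consider an exponential distribution with rate 1 evaluated at x = 50 in float32: F(50) = 1 - e^{-50} rounds to exactly 1.0, so the naive expression returns -inf, whereas the exact value is log S(50) = -50. The closed form used for comparison is specific to the exponential and is only meant to show the size of the error, not how the package computes logsf in general.

```python
import torch

dist = torch.distributions.Exponential(rate=torch.tensor(1.0))
x = torch.tensor(50.0)

# Naive: 1 - cdf(x) cancels catastrophically once cdf(x) rounds to 1.0
naive_logsf = torch.log(1.0 - dist.cdf(x))

# Exact closed form for the exponential: log S(x) = -rate * x
stable_logsf = -dist.rate * x

print(naive_logsf.item(), stable_logsf.item())  # prints: -inf -50.0
```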
Hopefully, these methods will be implemented in PyTorch sometime in the future,
but this package provides an alternative for the time being.
See DISTRIBUTIONS.md
for more details about the implemented functions and supported distributions.
- Install the latest version of PyTorch.
- Install survival_distributions: `pip install survival_distributions`
For these distributions we provide a numerically stable implementation of logsf:
- Exponential
- Logistic
- LogLogistic
- MixtureSameFamily
- TransformedDistribution
- Uniform
- Weibull
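Two standard ideas behind a stable logsf can be sketched in plain PyTorch. For a Weibull with scale λ and concentration k, S(x) = exp(-(x/λ)^k), so log S(x) = -(x/λ)^k is exact without ever forming 1 - F(x). For a mixture with weights w_j, S(x) = Σ_j w_j S_j(x), so log S(x) = logsumexp_j(log w_j + log S_j(x)) stays in log space. These closed forms are general facts, not taken from the survival_distributions source, and `weibull_logsf` and `mixture_weibull_logsf` are hypothetical helper names.

```python
import torch

def weibull_logsf(x, scale, concentration):
    # log S(x) = -(x / scale) ** concentration, exact without computing 1 - cdf
    return -(x / scale).pow(concentration)

def mixture_weibull_logsf(x, logits, scale, concentration):
    # x: shape (...,); logits, scale, concentration: shape (K,) for K components
    log_w = torch.log_softmax(logits, dim=-1)                           # log mixing weights
    comp_logsf = weibull_logsf(x.unsqueeze(-1), scale, concentration)   # shape (..., K)
    # log sum_j w_j S_j(x), computed stably in log space
    return torch.logsumexp(log_w + comp_logsf, dim=-1)

logits = torch.tensor([0.0, 1.0])
scale = torch.tensor([1.0, 2.0])
concentration = torch.tensor([1.5, 0.8])
far_tail = mixture_weibull_logsf(torch.tensor(200.0), logits, scale, concentration)
```

At x = 200 the result stays finite (about -40.1), dominated by the heavier-tailed second component; computing the same mixture as log(1 - cdf) would first have to represent a cdf within about 1e-18 of 1.0, which float32 cannot do.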
For these distributions we implement logsf(x) as log(1.0 - dist.cdf(x)), which is less numerically stable:
- LogNormal
- Normal
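To see why this matters for the normal distribution: in float32, Φ(10) rounds to 1.0, so log(1.0 - dist.cdf(x)) already returns -inf ten standard deviations above the mean. One possible workaround, assuming your PyTorch version provides `torch.special.log_ndtr` (the log of the standard normal CDF), uses the symmetry S(x) = Φ(-(x - μ)/σ). This is only an illustration, not how survival_distributions computes logsf for Normal; a LogNormal could reuse the same idea by evaluating it at log x. The helper name `normal_logsf` is hypothetical.

```python
import torch

def normal_logsf(x, loc, scale):
    # Symmetry of the normal: S(x) = Phi(-(x - loc) / scale), evaluated in log space
    z = (x - loc) / scale
    return torch.special.log_ndtr(-z)

dist = torch.distributions.Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0))
x = torch.tensor(10.0)

naive = torch.log(1.0 - dist.cdf(x))              # 1 - cdf(x) cancels to 0.0, giving -inf
stable = normal_logsf(x, dist.loc, dist.scale)    # approximately -53.23
```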
The package provides a drop-in replacement for torch.distributions, so you only need to modify your code as follows.
Old code:

```python
import torch

dist = torch.distributions.Exponential(rate=torch.tensor(2.0))
x = torch.tensor(1.5)
# Naive log-survival probability: loses accuracy once cdf(x) is close to 1.0
log_survival_proba = torch.log(1.0 - dist.cdf(x))
```
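New code:

```python
import torch
import survival_distributions as sd

# sd.Exponential replaces torch.distributions.Exponential
dist = sd.Exponential(rate=torch.tensor(2.0))
x = torch.tensor(1.5)
# logsf is one of the methods added by survival_distributions
log_survival_proba = dist.logsf(x)
```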