
Commit

Updated Dependencies
Alexander März committed Jan 19, 2024
1 parent fbf0f75 commit 7b80dac
Showing 2 changed files with 86 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unit-tests.yml
@@ -16,7 +16,7 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
-       python-version: ["3.9", "3.10", "3.11"]
+       python-version: ["3.10", "3.11"]

    steps:
      - uses: actions/checkout@v3
86 changes: 85 additions & 1 deletion xgboostlss/utils.py
@@ -1,6 +1,7 @@
import torch
from torch.nn.functional import softplus, gumbel_softmax, softmax


def nan_to_num(predt: torch.tensor) -> torch.tensor:
    """
    Replace nan, inf and -inf with the mean of predt.
@@ -119,6 +120,71 @@ def softplus_fn_df(predt: torch.tensor) -> torch.tensor:
    return predt + torch.tensor(2.0, dtype=predt.dtype)


def softplus_fn_quantile(predt: torch.tensor,
                         beta: int,
                         threshold: int) -> torch.tensor:
    """
    Softplus function with configurable beta and threshold, used to ensure predt is strictly positive.
    Arguments
    ---------
    predt: torch.tensor
        Predicted values.
    beta: int
        Beta parameter for the softplus function.
    threshold: int
        Threshold parameter for the softplus function.
    Returns
    -------
    predt: torch.tensor
        Predicted values.
    """
    predt = softplus(input=nan_to_num(predt), beta=beta, threshold=threshold) + torch.tensor(1e-06, dtype=predt.dtype)

    return predt
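
For illustration, a minimal usage sketch of this new helper; the beta and threshold values below are arbitrary and not taken from the commit:

# Illustrative usage sketch; beta/threshold values are arbitrary, not from the commit.
raw = torch.tensor([0.5, float("nan"), -3.0])
scale = softplus_fn_quantile(raw, beta=10, threshold=20)  # nan_to_num handles the NaN, output is strictly positive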


def squareplus_fn(predt: torch.tensor) -> torch.tensor:
    """
    Square-Plus function used to ensure predt is strictly positive.
    Arguments
    ---------
    predt: torch.tensor
        Predicted values.
    Returns
    -------
    predt: torch.tensor
        Predicted values.
    """
    b = torch.tensor(4., dtype=predt.dtype)
    predt = 0.5 * (predt + torch.sqrt(predt ** 2 + b)) + torch.tensor(1e-06, dtype=predt.dtype)

    return predt
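
As an illustrative aside, the Square-Plus transform computes 0.5 * (x + sqrt(x**2 + 4)), so even negative inputs map to positive outputs:

# Illustrative example, not part of the commit.
x = torch.tensor([-2.0, 0.0, 2.0])
squareplus_fn(x)  # ≈ tensor([0.4142, 1.0000, 2.4142]) plus 1e-06, always > 0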


def squareplus_fn_df(predt: torch.tensor) -> torch.tensor:
    """
    Square-Plus function used to ensure predt is strictly greater than two.
    Arguments
    ---------
    predt: torch.tensor
        Predicted values.
    Returns
    -------
    predt: torch.tensor
        Predicted values.
    """
    b = torch.tensor(4., dtype=predt.dtype)
    predt = 0.5 * (predt + torch.sqrt(predt ** 2 + b)) + torch.tensor(1e-06, dtype=predt.dtype)

    return predt + torch.tensor(2.0, dtype=predt.dtype)
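
Illustrative check of the _df variant: the trailing + 2.0 keeps the output above two, mirroring the softplus_fn_df return shown earlier in this diff:

# Illustrative example, not part of the commit.
squareplus_fn_df(torch.tensor([0.0]))  # ≈ tensor([3.0]), i.e. squareplus(0) + 2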


def sigmoid_fn(predt: torch.tensor) -> torch.tensor:
    """
    Function used to ensure predt is scaled to (0,1).
@@ -158,6 +224,25 @@ def relu_fn(predt: torch.tensor) -> torch.tensor:
    return predt


def relu_fn_df(predt: torch.tensor) -> torch.tensor:
    """
    Function used to ensure predt is scaled to max(0, predt) + 2.
    Arguments
    ---------
    predt: torch.tensor
        Predicted values.
    Returns
    -------
    predt: torch.tensor
        Predicted values.
    """
    predt = torch.relu(nan_to_num(predt)) + torch.tensor(1e-06, dtype=predt.dtype)

    return predt + torch.tensor(2.0, dtype=predt.dtype)
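
A short illustrative check of the max(0, x) + 2 behaviour:

# Illustrative example, not part of the commit.
relu_fn_df(torch.tensor([-1.0, 1.5]))  # ≈ tensor([2.0, 3.5]), always above 2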


def softmax_fn(predt: torch.tensor) -> torch.tensor:
    """
    Softmax function used to ensure predt sums to one.
@@ -216,5 +301,4 @@ def gumbel_softmax_fn(predt: torch.tensor,
    torch.manual_seed(123)
    predt = gumbel_softmax(nan_to_num(predt), tau=tau, dim=1) + torch.tensor(0, dtype=predt.dtype)


    return predt
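
A hedged usage sketch: the full gumbel_softmax_fn signature is not visible in this hunk, so the tau keyword below is assumed from the body shown above:

# Illustrative usage sketch, not part of the commit; tau keyword assumed from the body above.
logits = torch.randn(4, 3)
probs = gumbel_softmax_fn(logits, tau=1.0)  # each row sums to ~1 along dim=1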
