Skip to content

Commit

Permalink
Add params_to_tune for linear models (#1204)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Geekman committed Apr 6, 2023
1 parent e919876 commit 2ae199b
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Remove version python-3.7 from `pyproject.toml`, update lock ([#1183](https://github.com/tinkoff-ai/etna/pull/1183))
- Add default `params_to_tune` for catboost models ([#1185](https://github.com/tinkoff-ai/etna/pull/1185))
- Add default `params_to_tune` for `ProphetModel` ([#1203](https://github.com/tinkoff-ai/etna/pull/1203))
- Add default `params_to_tune` for linear models ([#1204](https://github.com/tinkoff-ai/etna/pull/1204))
### Fixed
- Fix bug in `GaleShapleyFeatureSelectionTransform` with wrong number of remaining features ([#1110](https://github.com/tinkoff-ai/etna/pull/1110))
- `ProphetModel` fails with additional seasonality set ([#1157](https://github.com/tinkoff-ai/etna/pull/1157))
Expand Down
60 changes: 60 additions & 0 deletions etna/models/linear.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,33 @@
from typing import Dict

import pandas as pd
from sklearn.linear_model import ElasticNet
from sklearn.linear_model import LinearRegression

from etna import SETTINGS
from etna.models.base import NonPredictionIntervalContextIgnorantAbstractModel
from etna.models.mixins import MultiSegmentModelMixin
from etna.models.mixins import NonPredictionIntervalContextIgnorantModelMixin
from etna.models.mixins import PerSegmentModelMixin
from etna.models.sklearn import _SklearnAdapter

if SETTINGS.auto_required:
from optuna.distributions import BaseDistribution
from optuna.distributions import CategoricalDistribution
from optuna.distributions import LogUniformDistribution
from optuna.distributions import UniformDistribution


LINEAR_GRID: Dict[str, "BaseDistribution"] = {
"fit_intercept": CategoricalDistribution([False, True]),
}

ELASTIC_GRID: Dict[str, "BaseDistribution"] = {
"fit_intercept": CategoricalDistribution([False, True]),
"l1_ratio": UniformDistribution(0, 1),
"alpha": LogUniformDistribution(low=1e-5, high=1e3),
}


class _LinearAdapter(_SklearnAdapter):
def predict_components(self, df: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -64,6 +84,16 @@ def __init__(self, fit_intercept: bool = True, **kwargs):
base_model=_LinearAdapter(regressor=LinearRegression(fit_intercept=self.fit_intercept, **self.kwargs))
)

def params_to_tune(self) -> Dict[str, "BaseDistribution"]:
"""Get default grid for tuning hyperparameters.
Returns
-------
:
Grid to tune.
"""
return LINEAR_GRID


class ElasticPerSegmentModel(
PerSegmentModelMixin,
Expand Down Expand Up @@ -117,6 +147,16 @@ def __init__(self, alpha: float = 1.0, l1_ratio: float = 0.5, fit_intercept: boo
)
)

def params_to_tune(self) -> Dict[str, "BaseDistribution"]:
"""Get default grid for tuning hyperparameters.
Returns
-------
:
Grid to tune.
"""
return ELASTIC_GRID


class LinearMultiSegmentModel(
MultiSegmentModelMixin,
Expand Down Expand Up @@ -147,6 +187,16 @@ def __init__(self, fit_intercept: bool = True, **kwargs):
base_model=_LinearAdapter(regressor=LinearRegression(fit_intercept=self.fit_intercept, **self.kwargs))
)

def params_to_tune(self) -> Dict[str, "BaseDistribution"]:
"""Get default grid for tuning hyperparameters.
Returns
-------
:
Grid to tune.
"""
return LINEAR_GRID


class ElasticMultiSegmentModel(
MultiSegmentModelMixin,
Expand Down Expand Up @@ -199,3 +249,13 @@ def __init__(self, alpha: float = 1.0, l1_ratio: float = 0.5, fit_intercept: boo
)
)
)

def params_to_tune(self) -> Dict[str, "BaseDistribution"]:
"""Get default grid for tuning hyperparameters.
Returns
-------
:
Grid to tune.
"""
return ELASTIC_GRID
15 changes: 15 additions & 0 deletions tests/test_models/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pandas as pd
import pytest
from optuna.samplers import RandomSampler
from sklearn.linear_model import ElasticNet
from sklearn.linear_model import LinearRegression

Expand Down Expand Up @@ -325,3 +326,17 @@ def test_linear_adapter_predict_components_sum_up_to_target(df_with_regressors,
target = adapter.predict(df)
target_components = adapter.predict_components(df)
np.testing.assert_array_almost_equal(target, target_components.sum(axis=1), decimal=10)


@pytest.mark.parametrize(
"model", [LinearPerSegmentModel(), LinearMultiSegmentModel(), ElasticPerSegmentModel(), ElasticMultiSegmentModel()]
)
def test_params_to_tune(model):
grid = model.params_to_tune()
# we need sampler to get a value from distribution
sampler = RandomSampler()

assert len(grid) > 0
for name, distribution in grid.items():
value = sampler.sample_independent(study=None, trial=None, param_name=name, param_distribution=distribution)
_ = model.set_params(**{name: value})

0 comments on commit 2ae199b

Please sign in to comment.