Skip to content

Commit

Permalink
Vectorize metric computation (#1347)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Geekman committed Aug 8, 2023
1 parent d23ec9d commit 26be1b2
Show file tree
Hide file tree
Showing 10 changed files with 331 additions and 100 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Remove upper limitation on version of numba ([#1321](https://github.com/tinkoff-ai/etna/pull/1321))
- Optimize `TSDataset.describe` and `TSDataset.info` by vectorization ([#1344](https://github.com/tinkoff-ai/etna/pull/1344))
- Add documentation warning about using dill during loading ([#1346](https://github.com/tinkoff-ai/etna/pull/1346))
- Vectorize metric computation ([#1347](https://github.com/tinkoff-ai/etna/pull/1347))

### Fixed
- Pipeline ensembles fail in `etna forecast` CLI ([#1331](https://github.com/tinkoff-ai/etna/pull/1331))
Expand Down
11 changes: 5 additions & 6 deletions etna/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from sklearn.metrics import mean_absolute_error as mae
from sklearn.metrics import mean_squared_error as mse
from sklearn.metrics import mean_squared_log_error as msle
from sklearn.metrics import median_absolute_error as medae
from sklearn.metrics import r2_score

from etna.metrics.base import Metric
from etna.metrics.base import MetricAggregationMode
from etna.metrics.functional_metrics import mae
from etna.metrics.functional_metrics import mape
from etna.metrics.functional_metrics import max_deviation
from etna.metrics.functional_metrics import medae
from etna.metrics.functional_metrics import mse
from etna.metrics.functional_metrics import msle
from etna.metrics.functional_metrics import r2_score
from etna.metrics.functional_metrics import rmse
from etna.metrics.functional_metrics import sign
from etna.metrics.functional_metrics import smape
Expand Down
73 changes: 61 additions & 12 deletions etna/metrics/base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from abc import ABC
from abc import abstractmethod
from enum import Enum
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Union

import numpy as np
import pandas as pd
from typing_extensions import Protocol
from typing_extensions import assert_never

from etna.core import BaseMixin
from etna.datasets.tsdataset import TSDataset
from etna.loggers import tslogger
from etna.metrics.functional_metrics import ArrayLike


class MetricAggregationMode(str, Enum):
Expand All @@ -27,6 +29,31 @@ def _missing_(cls, value):
)


class MetricFunctionSignature(str, Enum):
"""Enum for different metric function signatures."""

#: function should expect arrays of y_pred and y_true with length ``n_timestamps`` and return scalar
array_to_scalar = "array_to_scalar"

#: function should expect matrices of y_pred and y_true with shape ``(n_timestamps, n_segments)``
#: and return vector of length ``n_segments``
matrix_to_array = "matrix_to_array"

@classmethod
def _missing_(cls, value):
raise NotImplementedError(
f"{value} is not a valid {cls.__name__}. Only {', '.join([repr(m.value) for m in cls])} signatures allowed"
)


class MetricFunction(Protocol):
"""Protocol for ``metric_fn`` parameter."""

@abstractmethod
def __call__(self, y_true: ArrayLike, y_pred: ArrayLike) -> ArrayLike:
pass


class AbstractMetric(ABC):
"""Abstract class for metric."""

Expand Down Expand Up @@ -74,7 +101,13 @@ class Metric(AbstractMetric, BaseMixin):
dataset and aggregates it according to mode.
"""

def __init__(self, metric_fn: Callable[..., float], mode: str = MetricAggregationMode.per_segment, **kwargs):
def __init__(
self,
metric_fn: MetricFunction,
mode: str = MetricAggregationMode.per_segment,
metric_fn_signature: str = "array_to_scalar",
**kwargs,
):
"""
Init Metric.
Expand All @@ -89,21 +122,29 @@ def __init__(self, metric_fn: Callable[..., float], mode: str = MetricAggregatio
* if "per-segment" -- does not aggregate metrics
metric_fn_signature:
type of signature of ``metric_fn`` (see :py:class:`~etna.metrics.base.MetricFunctionSignature`)
kwargs:
functional metric's params
Raises
------
NotImplementedError:
it non existent mode is used
If non-existent ``mode`` is used.
NotImplementedError:
If non-existent ``metric_fn_signature`` is used.
"""
self.metric_fn = metric_fn
self.kwargs = kwargs
if MetricAggregationMode(mode) == MetricAggregationMode.macro:
if MetricAggregationMode(mode) is MetricAggregationMode.macro:
self._aggregate_metrics = self._macro_average
elif MetricAggregationMode(mode) == MetricAggregationMode.per_segment:
elif MetricAggregationMode(mode) is MetricAggregationMode.per_segment:
self._aggregate_metrics = self._per_segment_average

self._metric_fn_signature = MetricFunctionSignature(metric_fn_signature)

self.metric_fn = metric_fn
self.kwargs = kwargs
self.mode = mode
self.metric_fn_signature = metric_fn_signature

@property
def name(self) -> str:
Expand Down Expand Up @@ -276,13 +317,21 @@ def __call__(self, y_true: TSDataset, y_pred: TSDataset) -> Union[float, Dict[st
df_true = y_true[:, :, "target"].sort_index(axis=1)
df_pred = y_pred[:, :, "target"].sort_index(axis=1)

metrics_per_segment = {}
segments = df_true.columns.get_level_values("segment").unique()

for i, segment in enumerate(segments):
cur_y_true = df_true.iloc[:, i]
cur_y_pred = df_pred.iloc[:, i]
metrics_per_segment[segment] = self.metric_fn(y_true=cur_y_true, y_pred=cur_y_pred, **self.kwargs)
metrics_per_segment: Dict[str, float]
if self._metric_fn_signature is MetricFunctionSignature.array_to_scalar:
metrics_per_segment = {}
for i, segment in enumerate(segments):
cur_y_true = df_true.iloc[:, i].values
cur_y_pred = df_pred.iloc[:, i].values
metrics_per_segment[segment] = self.metric_fn(y_true=cur_y_true, y_pred=cur_y_pred, **self.kwargs) # type: ignore
elif self._metric_fn_signature is MetricFunctionSignature.matrix_to_array:
values = self.metric_fn(y_true=df_true.values, y_pred=df_pred.values, **self.kwargs)
metrics_per_segment = dict(zip(segments, values)) # type: ignore
else:
assert_never(self._metric_fn_signature)

metrics = self._aggregate_metrics(metrics_per_segment)
return metrics

Expand Down
118 changes: 94 additions & 24 deletions etna/metrics/functional_metrics.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,47 @@
from enum import Enum
from functools import partial
from typing import List
from typing import Optional
from typing import Sequence
from typing import Union

import numpy as np
from sklearn.metrics import mean_absolute_error as mae
from sklearn.metrics import mean_squared_error as mse
from sklearn.metrics import mean_squared_log_error as msle
from sklearn.metrics import median_absolute_error as medae
from sklearn.metrics import r2_score
from typing_extensions import assert_never

ArrayLike = List[Union[float, List[float]]]
ArrayLike = Union[float, Sequence[float], Sequence[Sequence[float]]]


def mape(y_true: ArrayLike, y_pred: ArrayLike, eps: float = 1e-15) -> float:
class FunctionalMetricMultioutput(str, Enum):
"""Enum for different functional metric multioutput modes."""

#: Compute one scalar value taking into account all outputs.
joint = "joint"

#: Compute one value per each output.
raw_values = "raw_values"

@classmethod
def _missing_(cls, value):
raise NotImplementedError(
f"{value} is not a valid {cls.__name__}. Only {', '.join([repr(m.value) for m in cls])} options allowed"
)


def _get_axis_by_multioutput(multioutput: str) -> Optional[int]:
multioutput_enum = FunctionalMetricMultioutput(multioutput)
if multioutput_enum is FunctionalMetricMultioutput.joint:
return None
elif multioutput_enum is FunctionalMetricMultioutput.raw_values:
return 0
else:
assert_never(multioutput_enum)


def mape(y_true: ArrayLike, y_pred: ArrayLike, eps: float = 1e-15, multioutput: str = "joint") -> ArrayLike:
"""Mean absolute percentage error.
`Wikipedia entry on the Mean absolute percentage error
Expand All @@ -26,14 +59,19 @@ def mape(y_true: ArrayLike, y_pred: ArrayLike, eps: float = 1e-15) -> float:
Estimated target values.
eps: float=1e-15
eps:
MAPE is undefined for ``y_true[i]==0`` for any ``i``, so all zeros ``y_true[i]`` are
clipped to ``max(eps, abs(y_true))``.
multioutput:
Defines aggregating of multiple output values
(see :py:class:`~etna.metrics.functional_metrics.FunctionalMetricMultioutput`).
Returns
-------
float
A non-negative floating point value (the best value is 0.0).
:
A non-negative floating point value (the best value is 0.0), or an array of floating point values,
one for each individual target.
"""
y_true_array, y_pred_array = np.asarray(y_true), np.asarray(y_pred)

Expand All @@ -42,10 +80,12 @@ def mape(y_true: ArrayLike, y_pred: ArrayLike, eps: float = 1e-15) -> float:

y_true_array = y_true_array.clip(eps)

return np.mean(np.abs((y_true_array - y_pred_array) / y_true_array)) * 100
axis = _get_axis_by_multioutput(multioutput)

return np.mean(np.abs((y_true_array - y_pred_array) / y_true_array), axis=axis) * 100


def smape(y_true: ArrayLike, y_pred: ArrayLike, eps: float = 1e-15) -> float:
def smape(y_true: ArrayLike, y_pred: ArrayLike, eps: float = 1e-15, multioutput: str = "joint") -> ArrayLike:
"""Symmetric mean absolute percentage error.
`Wikipedia entry on the Symmetric mean absolute percentage error
Expand All @@ -70,22 +110,29 @@ def smape(y_true: ArrayLike, y_pred: ArrayLike, eps: float = 1e-15) -> float:
SMAPE is undefined for ``y_true[i] + y_pred[i] == 0`` for any ``i``, so all zeros ``y_true[i] + y_pred[i]`` are
clipped to ``max(eps, abs(y_true) + abs(y_pred))``.
multioutput:
Defines aggregating of multiple output values
(see :py:class:`~etna.metrics.functional_metrics.FunctionalMetricMultioutput`).
Returns
-------
float
A non-negative floating point value (the best value is 0.0).
:
A non-negative floating point value (the best value is 0.0), or an array of floating point values,
one for each individual target.
"""
y_true_array, y_pred_array = np.asarray(y_true), np.asarray(y_pred)

if len(y_true_array.shape) != len(y_pred_array.shape):
raise ValueError("Shapes of the labels must be the same")

axis = _get_axis_by_multioutput(multioutput)

return 100 * np.mean(
2 * np.abs(y_pred_array - y_true_array) / (np.abs(y_true_array) + np.abs(y_pred_array)).clip(eps)
2 * np.abs(y_pred_array - y_true_array) / (np.abs(y_true_array) + np.abs(y_pred_array)).clip(eps), axis=axis
)


def sign(y_true: ArrayLike, y_pred: ArrayLike) -> float:
def sign(y_true: ArrayLike, y_pred: ArrayLike, multioutput: str = "joint") -> ArrayLike:
"""Sign error metric.
.. math::
Expand All @@ -103,20 +150,27 @@ def sign(y_true: ArrayLike, y_pred: ArrayLike) -> float:
Estimated target values.
multioutput:
Defines aggregating of multiple output values
(see :py:class:`~etna.metrics.functional_metrics.FunctionalMetricMultioutput`).
Returns
-------
float
A floating point value (the best value is 0.0).
:
A floating point value, or an array of floating point values,
one for each individual target.
"""
y_true_array, y_pred_array = np.asarray(y_true), np.asarray(y_pred)

if len(y_true_array.shape) != len(y_pred_array.shape):
raise ValueError("Shapes of the labels must be the same")

return np.mean(np.sign(y_true_array - y_pred_array))
axis = _get_axis_by_multioutput(multioutput)

return np.mean(np.sign(y_true_array - y_pred_array), axis=axis)

def max_deviation(y_true: ArrayLike, y_pred: ArrayLike) -> float:

def max_deviation(y_true: ArrayLike, y_pred: ArrayLike, multioutput: str = "joint") -> ArrayLike:
"""Max Deviation metric.
Parameters
Expand All @@ -131,25 +185,31 @@ def max_deviation(y_true: ArrayLike, y_pred: ArrayLike) -> float:
Estimated target values.
multioutput:
Defines aggregating of multiple output values
(see :py:class:`~etna.metrics.functional_metrics.FunctionalMetricMultioutput`).
Returns
-------
float
A floating point value (the best value is 0.0).
:
A non-negative floating point value (the best value is 0.0), or an array of floating point values,
one for each individual target.
"""
y_true_array, y_pred_array = np.asarray(y_true), np.asarray(y_pred)

if len(y_true_array.shape) != len(y_pred_array.shape):
raise ValueError("Shapes of the labels must be the same")

prefix_error_sum = np.cumsum(y_pred_array - y_true_array)
axis = _get_axis_by_multioutput(multioutput)

return max(np.abs(prefix_error_sum))
prefix_error_sum = np.cumsum(y_pred_array - y_true_array, axis=axis)
return np.max(np.abs(prefix_error_sum), axis=axis)


rmse = partial(mse, squared=False)


def wape(y_true: ArrayLike, y_pred: ArrayLike) -> float:
def wape(y_true: ArrayLike, y_pred: ArrayLike, multioutput: str = "joint") -> ArrayLike:
"""Weighted average percentage Error metric.
.. math::
Expand All @@ -167,14 +227,24 @@ def wape(y_true: ArrayLike, y_pred: ArrayLike) -> float:
Estimated target values.
multioutput:
Defines aggregating of multiple output values
(see :py:class:`~etna.metrics.functional_metrics.FunctionalMetricMultioutput`).
Returns
-------
float
A floating point value (the best value is 0.0).
:
A non-negative floating point value (the best value is 0.0), or an array of floating point values,
one for each individual target.
"""
y_true_array, y_pred_array = np.asarray(y_true), np.asarray(y_pred)

if len(y_true_array.shape) != len(y_pred_array.shape):
raise ValueError("Shapes of the labels must be the same")

return np.sum(np.abs(y_true_array - y_pred_array)) / np.sum(np.abs(y_true_array))
axis = _get_axis_by_multioutput(multioutput)

return np.sum(np.abs(y_true_array - y_pred_array), axis=axis) / np.sum(np.abs(y_true_array), axis=axis) # type: ignore


__all__ = ["mae", "mse", "msle", "medae", "r2_score", "mape", "smape", "sign", "max_deviation", "rmse", "wape"]
Loading

1 comment on commit 26be1b2

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.