Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/add feature values to explainability result #1546

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ We do our best to avoid the introduction of breaking changes,
but cannot always guarantee backwards compatibility. Changes that may **break code which uses a previous release of Darts** are marked with a "🔴".

## [Unreleased](https://github.com/unit8co/darts/tree/master)
- Created `ShapExplainabilityResult` by extending `ExplainabilityResult`. This subclass carries additional information
specific to Shap Explainers (i.e., the corresponding feature values and the underlying `shap.Explanation` object).
[#1545](https://github.com/unit8co/darts/pull/1545) by [Rijk van der Meulen](https://github.com/rijkvandermeulen).

[Full Changelog](https://github.com/unit8co/darts/compare/0.23.1...master)

## [0.23.1](https://github.com/unit8co/darts/tree/0.23.1) (2023-01-12)
Expand Down
123 changes: 114 additions & 9 deletions darts/explainability/explainability_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
"""

from abc import ABC
from typing import Dict, Optional, Sequence, Union
from typing import Any, Dict, Optional, Sequence, Union

import shap
from numpy import integer

from darts import TimeSeries
Expand All @@ -29,7 +30,6 @@ def __init__(
Sequence[Dict[integer, Dict[str, TimeSeries]]],
],
):

self.explained_forecasts = explained_forecasts
if isinstance(self.explained_forecasts, list):
self.available_horizons = list(self.explained_forecasts[0].keys())
Expand All @@ -55,7 +55,54 @@ def get_explanation(
The component for which to return the explanation. Does not
need to be specified for univariate series.
"""
return self._query_explainability_result(
self.explained_forecasts, horizon, component
)

def _query_explainability_result(
self,
attr: Union[
Dict[integer, Dict[str, Any]], Sequence[Dict[integer, Dict[str, Any]]]
],
horizon: int,
component: Optional[str] = None,
) -> Any:
"""
Helper that extracts and returns the explainability result attribute for a specified horizon and component from
the input attribute.

Parameters
----------
attr
An explainability result attribute from which to extract the content for a certain horizon and component.
horizon
The horizon for which to return the content of the attribute.
component
The component for which to return the content of the attribute. Does not
need to be specified for univariate series.
"""
self._validate_input_for_querying_explainability_result(horizon, component)
if component is None:
component = self.available_components[0]
if isinstance(attr, list):
return [attr[i][horizon][component] for i in range(len(attr))]
else:
return attr[horizon][component]

def _validate_input_for_querying_explainability_result(
self, horizon: int, component: Optional[str] = None
) -> None:
"""
Helper that validates the input parameters of a method that queries the ExplainabilityResult.

Parameters
----------
horizon
The horizon for which to return the explanation.
component
The component for which to return the explanation. Does not
need to be specified for univariate series.
"""
raise_if(
component is None and len(self.available_components) > 1,
ValueError(
Expand All @@ -81,10 +128,68 @@ def get_explanation(
),
)

if isinstance(self.explained_forecasts, list):
return [
self.explained_forecasts[i][horizon][component]
for i in range(len(self.explained_forecasts))
]
else:
return self.explained_forecasts[horizon][component]

class ShapExplainabilityResult(ExplainabilityResult):
    """
    Stores the explainability results of a :class:`ShapExplainer`
    with convenient access to the results. It extends the :class:`ExplainabilityResult` and carries additional
    information specific to the Shap explainers. In particular, in addition to the `explained_forecasts` (which in
    the case of the :class:`ShapExplainer` are the shap values), it also provides access to the corresponding
    `feature_values` and the underlying `shap.Explanation` object.
    """

    def __init__(
        self,
        explained_forecasts: Union[
            Dict[integer, Dict[str, TimeSeries]],
            Sequence[Dict[integer, Dict[str, TimeSeries]]],
        ],
        feature_values: Union[
            Dict[integer, Dict[str, TimeSeries]],
            Sequence[Dict[integer, Dict[str, TimeSeries]]],
        ],
        shap_explanation_object: Union[
            Dict[integer, Dict[str, shap.Explanation]],
            Sequence[Dict[integer, Dict[str, shap.Explanation]]],
        ],
    ):
        super().__init__(explained_forecasts)
        # Feature values aligned one-to-one with the shap values stored in
        # `explained_forecasts` (same horizon/component structure).
        self.feature_values = feature_values
        # Raw shap.Explanation object(s) produced by the underlying explainer.
        self.shap_explanation_object = shap_explanation_object

    def get_feature_values(
        self, horizon: int, component: Optional[str] = None
    ) -> Union[TimeSeries, Sequence[TimeSeries]]:
        """
        Returns one or several `TimeSeries` representing the feature values
        for a given horizon and component.

        Parameters
        ----------
        horizon
            The horizon for which to return the feature values.
        component
            The component for which to return the feature values. Does not
            need to be specified for univariate series.
        """
        return self._query_explainability_result(
            self.feature_values, horizon, component
        )

    def get_shap_explanation_object(
        self, horizon: int, component: Optional[str] = None
    ) -> Union[shap.Explanation, Sequence[shap.Explanation]]:
        """
        Returns the underlying `shap.Explanation` object for a given horizon
        and component.

        Parameters
        ----------
        horizon
            The horizon for which to return the `shap.Explanation` object.
        component
            The component for which to return the `shap.Explanation` object.
            Does not need to be specified for univariate series.
        """
        return self._query_explainability_result(
            self.shap_explanation_object, horizon, component
        )
38 changes: 28 additions & 10 deletions darts/explainability/shap_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,8 @@
from sklearn.multioutput import MultiOutputRegressor

from darts import TimeSeries
from darts.explainability.explainability import (
ExplainabilityResult,
ForecastingModelExplainer,
)
from darts.explainability.explainability import ForecastingModelExplainer
from darts.explainability.explainability_result import ShapExplainabilityResult
from darts.logging import get_logger, raise_if, raise_log
from darts.models.forecasting.regression_model import RegressionModel
from darts.utils.data.tabularization import create_lagged_prediction_data
Expand Down Expand Up @@ -187,7 +185,7 @@ def explain(
] = None,
horizons: Optional[Sequence[int]] = None,
target_components: Optional[Sequence[str]] = None,
) -> ExplainabilityResult:
) -> ShapExplainabilityResult:
super().explain(
foreground_series, foreground_past_covariates, foreground_future_covariates
)
Expand Down Expand Up @@ -216,7 +214,8 @@ def explain(
)

shap_values_list = []

feature_values_list = []
shap_explanation_object_list = []
for idx, foreground_ts in enumerate(foreground_series):

foreground_past_cov_ts = None
Expand All @@ -240,22 +239,40 @@ def explain(
)

shap_values_dict = {}
feature_values_dict = {}
shap_explanation_object_dict = {}
for h in horizons:
tmp = {}
shap_values_dict_single_h = {}
feature_values_dict_single_h = {}
shap_explanation_object_dict_single_h = {}
for t in target_names:
tmp[t] = TimeSeries.from_times_and_values(
shap_values_dict_single_h[t] = TimeSeries.from_times_and_values(
shap_[h][t].time_index,
shap_[h][t].values,
columns=shap_[h][t].feature_names,
)
shap_values_dict[h] = tmp
feature_values_dict_single_h[t] = TimeSeries.from_times_and_values(
shap_[h][t].time_index,
shap_[h][t].data,
columns=shap_[h][t].feature_names,
)
shap_explanation_object_dict_single_h[t] = shap_[h][t]
shap_values_dict[h] = shap_values_dict_single_h
feature_values_dict[h] = feature_values_dict_single_h
shap_explanation_object_dict[h] = shap_explanation_object_dict_single_h

shap_values_list.append(shap_values_dict)
feature_values_list.append(feature_values_dict)
shap_explanation_object_list.append(shap_explanation_object_dict)

if len(shap_values_list) == 1:
shap_values_list = shap_values_list[0]
feature_values_list = feature_values_list[0]
shap_explanation_object_list = shap_explanation_object_list[0]

return ExplainabilityResult(shap_values_list)
return ShapExplainabilityResult(
shap_values_list, feature_values_list, shap_explanation_object_list
)

def summary_plot(
self,
Expand Down Expand Up @@ -580,6 +597,7 @@ def shap_explanations(
:, :, self.target_dim * (h - 1) + t_idx
]
)
tmp_t.data = shap_explanation_tmp.data
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this required?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, I believe it is. Reason being that in this particular if condition the shap.Explanation object is 'manually' constructed so the tmp_t dict doesn't contain the feature values (i.e., data) yet.

tmp_t.base_values = shap_explanation_tmp.base_values[
:, self.target_dim * (h - 1) + t_idx
].ravel()
Expand Down
Loading