Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TFTExplainer #1392

Merged
merged 50 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
ce98db1
#675 add first draft for tft_explainer
Cattes Nov 2, 2022
5678f7f
#675 add first working version of TFTExplainer class with tests
Cattes Nov 22, 2022
c22e96d
#675 allow passing of arguments to the explain method of the TFTExpla…
Cattes Nov 27, 2022
598b134
#675 add test for multiple_covariates input to test_tft_explainer.py
Cattes Nov 27, 2022
f7387a4
#675 add correct feature names to vsv
Cattes Nov 27, 2022
68e384d
#675 add TFTExplainer to 13-TFT-examples.ipynb
Cattes Nov 27, 2022
bfffe87
Merge branch 'master' into feature/675_tft_explainer
Cattes Nov 27, 2022
160d196
#675 add CHANGELOG.md entry for the TFTExplainer class
Cattes Nov 27, 2022
60eae66
Merge branch 'unit8co:master' into feature/675_tft_explainer
Cattes Nov 28, 2022
42cfb92
#675 use @MagMueller's plot method for the variable importance plot
Cattes Nov 28, 2022
57e64fb
Merge branch 'master' into feature/675_tft_explainer
Cattes Dec 19, 2022
e59a7ba
#675 allow absolute tolerance of 1% in feature importance test
Cattes Dec 19, 2022
49587c9
Merge branch 'master' into feature/675_tft_explainer
hrzn Jan 18, 2023
706e190
Update CHANGELOG.md
hrzn Jan 18, 2023
eb4bcfc
Merge branch 'master' into feature/675_tft_explainer
hrzn Jan 20, 2023
96a1c2b
Merge branch 'master' into feature/675_tft_explainer
hrzn Jan 23, 2023
4556a44
Merge branch 'unit8co:master' into feature/675_tft_explainer
Cattes Jan 23, 2023
79b7755
#675 work in PR feedback
Cattes Jan 23, 2023
18f9f65
Merge branch 'master' into feature/675_tft_explainer
hrzn Jan 24, 2023
32ec858
Merge branch 'master' into feature/675_tft_explainer
hrzn Jan 31, 2023
6b813f2
Update darts/explainability/tft_explainer.py
Cattes Feb 2, 2023
6a8b0fb
Update darts/timeseries.py
Cattes Feb 2, 2023
b9f0a9e
Merge branch 'master' into feature/675_tft_explainer
hrzn Feb 10, 2023
e5242fb
Merge branch 'master' into feature/675_tft_explainer
madtoinou Feb 14, 2023
ce6a5a3
Merge branch 'master' into feature/675_tft_explainer
hrzn Feb 23, 2023
ca10d0d
#675 Add docstrings to tft_explainer.py
Cattes Feb 23, 2023
4f1409d
#675 Allow Dict[str, TimeSeries] as ExplainabilityResult input
Cattes Feb 23, 2023
ccff53b
#675 remove horizon=0 from the 13-TFT-examples.ipynb notebook
Cattes Feb 23, 2023
53bf28d
Merge branch 'master' into feature/675_tft_explainer
dennisbader Jul 11, 2023
f0e551c
fix failing tests p1
dennisbader Jul 11, 2023
42c9c96
refactor ForecastingModelExplainer.__init__
dennisbader Jul 12, 2023
9464c8f
further explainability refactoring for input processing
dennisbader Jul 12, 2023
71e1e81
Merge branch 'master' into feature/675_tft_explainer
dennisbader Jul 19, 2023
7ef185a
refactor ForecastingModelExplainer p3
dennisbader Jul 19, 2023
ce4e233
Merge branch 'master' into feature/675_tft_explainer
dennisbader Jul 21, 2023
8979bd8
full refactor of ForecastingModelExplainer
dennisbader Jul 21, 2023
40c5aef
update component naming
dennisbader Jul 21, 2023
10fc42b
add static covariates importance
dennisbader Jul 24, 2023
052680a
improved attention head plots
dennisbader Jul 24, 2023
29e5aa3
multiple time series support
dennisbader Jul 24, 2023
090419b
update explainability documnetation
dennisbader Jul 25, 2023
72f8ded
update TFTModel full attention
dennisbader Jul 25, 2023
8a3af53
remove optional horizon from HorizonBasedExplainabilityResult
dennisbader Jul 25, 2023
9ffcd5f
update TFTModel example notebook
dennisbader Jul 26, 2023
04daed1
fix covariates issue when supplying covariates at predict time
dennisbader Jul 26, 2023
cd79722
update unit tests
dennisbader Jul 27, 2023
a66c86e
update changelog
dennisbader Jul 27, 2023
7e80088
Merge branch 'master' into feature/675_tft_explainer
dennisbader Jul 31, 2023
d0c472c
applied suggestions from PR review
dennisbader Jul 31, 2023
92391a8
Merge branch 'master' into feature/675_tft_explainer
dennisbader Jul 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
## [Unreleased](https://github.com/unit8co/darts/tree/master)
[Full Changelog](https://github.com/unit8co/darts/compare/0.23.1...master)

- Added new `TFTExplainer` class to implement the Explainable AI part described in [the paper](https://arxiv.org/abs/1912.09363) of the `TFT` model. [#1392](https://github.com/unit8co/darts/pull/1392) by [Sebastian Cattes](https://github.com/cattes).

## [0.23.1](https://github.com/unit8co/darts/tree/0.23.1) (2023-01-12)
Patch release

Expand Down Expand Up @@ -78,7 +80,6 @@ Patch release
by [Antoine Madrona](https://github.com/madtoinou).



**Fixed**
- Fixed edge case in ShapExplainer for regression models where covariates series > target series
[#1310](https://https://github.com/unit8co/darts/pull/1310) by [Rijk van der Meulen](https://github.com/rijkvandermeulen)
Expand Down
1 change: 1 addition & 0 deletions darts/explainability/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@

from darts.explainability.explainability_result import ExplainabilityResult
from darts.explainability.shap_explainer import ShapExplainer
from darts.explainability.tft_explainer import TFTExplainer
76 changes: 60 additions & 16 deletions darts/explainability/explainability_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from numpy import integer

from darts import TimeSeries
from darts.logging import get_logger, raise_if, raise_if_not
from darts.logging import get_logger, raise_if, raise_if_not, raise_log

logger = get_logger(__name__)

Expand All @@ -27,21 +27,50 @@ def __init__(
explained_forecasts: Union[
Dict[integer, Dict[str, TimeSeries]],
Sequence[Dict[integer, Dict[str, TimeSeries]]],
Dict[str, TimeSeries],
],
):

self.explained_forecasts = explained_forecasts
if isinstance(self.explained_forecasts, list):
raise_if_not(
isinstance(self.explained_forecasts[0], dict),
"The explained_forecasts Sequence must consist of dicts.",
logger,
)
raise_if_not(
all(isinstance(key, int) for key in self.explained_forecasts[0].keys()),
"The explained_forecasts dict Sequence must have all integer keys.",
logger,
)
self.available_horizons = list(self.explained_forecasts[0].keys())
h_0 = self.available_horizons[0]
self.available_components = list(self.explained_forecasts[0][h_0].keys())
elif isinstance(self.explained_forecasts, dict):
if all(isinstance(key, int) for key in self.explained_forecasts.keys()):
self.available_horizons = list(self.explained_forecasts.keys())
h_0 = self.available_horizons[0]
self.available_components = list(self.explained_forecasts[h_0].keys())
elif all(isinstance(key, str) for key in self.explained_forecasts.keys()):
self.available_horizons = []
self.available_components = list(self.explained_forecasts.keys())
else:
raise_log(
ValueError(
"The explained_forecasts dictionary must have all integer or all string keys."
),
logger,
)
else:
self.available_horizons = list(self.explained_forecasts.keys())
h_0 = self.available_horizons[0]
self.available_components = list(self.explained_forecasts[h_0].keys())
raise_log(
ValueError(
"The explained_forecasts must be a dictionary or a list of dictionaries."
),
logger,
)

def get_explanation(
self, horizon: int, component: Optional[str] = None
self, horizon: Optional[int] = None, component: Optional[str] = None
) -> Union[TimeSeries, Sequence[TimeSeries]]:
"""
Returns one or several `TimeSeries` representing the explanations
Expand All @@ -56,35 +85,50 @@ def get_explanation(
need to be specified for univariate series.
"""

# validate component argument
raise_if(
component is None and len(self.available_components) > 1,
ValueError(
"The component parameter is required when the model has more than one component."
),
"The component parameter is required when the model has more than one component.",
logger,
)

if component is None:
component = self.available_components[0]

raise_if_not(
horizon in self.available_horizons,
"Horizon {} is not available. Available horizons are: {}".format(
horizon, self.available_horizons
),
)

raise_if_not(
component in self.available_components,
"Component {} is not available. Available components are: {}".format(
component, self.available_components
),
)

# validate horizon argument
if horizon is not None:
raise_if(
len(self.available_horizons) == 0,
"The horizon parameter can not be used for a model where all time horizons are saved in the component.",
)

raise_if_not(
horizon in self.available_horizons,
"Horizon {} is not available. Available horizons are: {}".format(
horizon, self.available_horizons
),
)

if isinstance(self.explained_forecasts, list):
return [
self.explained_forecasts[i][horizon][component]
for i in range(len(self.explained_forecasts))
]
else:
elif all(isinstance(key, int) for key in self.explained_forecasts.keys()):
return self.explained_forecasts[horizon][component]
elif all(isinstance(key, str) for key in self.explained_forecasts.keys()):
return self.explained_forecasts[component]
else:
raise_log(
ValueError(
"Something went wrong. ExplainabilityResult got instantiated with an unexpected type."
),
logger,
)
Loading