Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TFTExplainer #1392

Merged
merged 50 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from 48 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
ce98db1
#675 add first draft for tft_explainer
Cattes Nov 2, 2022
5678f7f
#675 add first working version of TFTExplainer class with tests
Cattes Nov 22, 2022
c22e96d
#675 allow passing of arguments to the explain method of the TFTExpla…
Cattes Nov 27, 2022
598b134
#675 add test for multiple_covariates input to test_tft_explainer.py
Cattes Nov 27, 2022
f7387a4
#675 add correct feature names to vsv
Cattes Nov 27, 2022
68e384d
#675 add TFTExplainer to 13-TFT-examples.ipynb
Cattes Nov 27, 2022
bfffe87
Merge branch 'master' into feature/675_tft_explainer
Cattes Nov 27, 2022
160d196
#675 add CHANGELOG.md entry for the TFTExplainer class
Cattes Nov 27, 2022
60eae66
Merge branch 'unit8co:master' into feature/675_tft_explainer
Cattes Nov 28, 2022
42cfb92
#675 use @MagMueller's plot method for the variable importance plot
Cattes Nov 28, 2022
57e64fb
Merge branch 'master' into feature/675_tft_explainer
Cattes Dec 19, 2022
e59a7ba
#675 allow absolute tolerance of 1% in feature importance test
Cattes Dec 19, 2022
49587c9
Merge branch 'master' into feature/675_tft_explainer
hrzn Jan 18, 2023
706e190
Update CHANGELOG.md
hrzn Jan 18, 2023
eb4bcfc
Merge branch 'master' into feature/675_tft_explainer
hrzn Jan 20, 2023
96a1c2b
Merge branch 'master' into feature/675_tft_explainer
hrzn Jan 23, 2023
4556a44
Merge branch 'unit8co:master' into feature/675_tft_explainer
Cattes Jan 23, 2023
79b7755
#675 work in PR feedback
Cattes Jan 23, 2023
18f9f65
Merge branch 'master' into feature/675_tft_explainer
hrzn Jan 24, 2023
32ec858
Merge branch 'master' into feature/675_tft_explainer
hrzn Jan 31, 2023
6b813f2
Update darts/explainability/tft_explainer.py
Cattes Feb 2, 2023
6a8b0fb
Update darts/timeseries.py
Cattes Feb 2, 2023
b9f0a9e
Merge branch 'master' into feature/675_tft_explainer
hrzn Feb 10, 2023
e5242fb
Merge branch 'master' into feature/675_tft_explainer
madtoinou Feb 14, 2023
ce6a5a3
Merge branch 'master' into feature/675_tft_explainer
hrzn Feb 23, 2023
ca10d0d
#675 Add docstrings to tft_explainer.py
Cattes Feb 23, 2023
4f1409d
#675 Allow Dict[str, TimeSeries] as ExplainabilityResult input
Cattes Feb 23, 2023
ccff53b
#675 remove horizon=0 from the 13-TFT-examples.ipynb notebook
Cattes Feb 23, 2023
53bf28d
Merge branch 'master' into feature/675_tft_explainer
dennisbader Jul 11, 2023
f0e551c
fix failing tests p1
dennisbader Jul 11, 2023
42c9c96
refactor ForecastingModelExplainer.__init__
dennisbader Jul 12, 2023
9464c8f
further explainability refactoring for input processing
dennisbader Jul 12, 2023
71e1e81
Merge branch 'master' into feature/675_tft_explainer
dennisbader Jul 19, 2023
7ef185a
refactor ForecastingModelExplainer p3
dennisbader Jul 19, 2023
ce4e233
Merge branch 'master' into feature/675_tft_explainer
dennisbader Jul 21, 2023
8979bd8
full refactor of ForecastingModelExplainer
dennisbader Jul 21, 2023
40c5aef
update component naming
dennisbader Jul 21, 2023
10fc42b
add static covariates importance
dennisbader Jul 24, 2023
052680a
improved attention head plots
dennisbader Jul 24, 2023
29e5aa3
multiple time series support
dennisbader Jul 24, 2023
090419b
update explainability documnetation
dennisbader Jul 25, 2023
72f8ded
update TFTModel full attention
dennisbader Jul 25, 2023
8a3af53
remove optional horizon from HorizonBasedExplainabilityResult
dennisbader Jul 25, 2023
9ffcd5f
update TFTModel example notebook
dennisbader Jul 26, 2023
04daed1
fix covariates issue when supplying covariates at predict time
dennisbader Jul 26, 2023
cd79722
update unit tests
dennisbader Jul 27, 2023
a66c86e
update changelog
dennisbader Jul 27, 2023
7e80088
Merge branch 'master' into feature/675_tft_explainer
dennisbader Jul 31, 2023
d0c472c
applied suggestions from PR review
dennisbader Jul 31, 2023
92391a8
Merge branch 'master' into feature/675_tft_explainer
dennisbader Jul 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- `RegressionEnsembleModel` and `NaiveEnsembleModel` can generate probabilistic forecasts, probabilistics `forecasting_models` can be sampled to train the `regression_model`, updated the documentation (stacking technique). [#1692](https://github.com/unit8co/darts/pull/#1692) by [Antoine Madrona](https://github.com/madtoinou).
- Improvements to `ShapExplainer`:
- Added static covariates support to `ShapeExplainer`. [#1803](https://github.com/unit8co/darts/pull/#1803) by [Anne de Vries](https://github.com/anne-devries) and [Dennis Bader](https://github.com/dennisbader).
- Improved static covariates column naming when applying a `sklearn.preprocessing.OneHotEncoder` with `StaticCovariatesTransformer` [#1863](https://github.com/unit8co/darts/pull/1863) by [Anne de Vries](https://github.com/anne-devries)
- Added `MSTL` (Season-Trend decomposition using LOESS for multiple seasonalities) as a `method` option for `extract_trend_and_seasonality()`. [#1879](https://github.com/unit8co/darts/pull/1879) by [Alex Colpitts](https://github.com/alexcolpitts96)
- Added `RINorm` (Reversible Instance Norm) as a new layer normalization option. [#1121](https://github.com/unit8co/darts/issues/1121) by [Alex Colpitts](https://github.com/alexcolpitts96)
- Improved static covariates column naming when applying a `sklearn.preprocessing.OneHotEncoder` with `StaticCovariatesTransformer` [#1863](https://github.com/unit8co/darts/pull/1863) by [Anne de Vries](https://github.com/anne-devries).
- Added `MSTL` (Season-Trend decomposition using LOESS for multiple seasonalities) as a `method` option for `extract_trend_and_seasonality()`. [#1879](https://github.com/unit8co/darts/pull/1879) by [Alex Colpitts](https://github.com/alexcolpitts96).
- Added `RINorm` (Reversible Instance Norm) as a new layer normalization option. [#1121](https://github.com/unit8co/darts/issues/1121) by [Alex Colpitts](https://github.com/alexcolpitts96).
- New forecasting model: `TiDEModel` as proposed in [this paper](https://arxiv.org/abs/2304.08424). An MLP based encoder-decoder model that outperforms many Transformer-based architectures. [#1727](https://github.com/unit8co/darts/pull/1727) by [Alex Colpitts](https://github.com/alexcolpitts96).
- New forecasting model explainer: `TFTExplainer` for `TFTModel`. You can now access and visualize the trained model's feature importances and self attention. [#1392](https://github.com/unit8co/darts/issues/1392) by [Sebastian Cattes](https://github.com/Cattes) and [Dennis Bader](https://github.com/dennisbader).

**Fixed**
- Fixed an issue not considering original component names for `TimeSeries.plot()` when providing a label prefix. [#1783](https://github.com/unit8co/darts/pull/1783) by [Simon Sudrich](https://github.com/sudrich).
Expand All @@ -36,10 +37,16 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Fixed `TimeSeries.__getitem__()` for series with a RangeIndex with start != 0 and freq != 1. [#1868](https://github.com/unit8co/darts/pull/#1868) by [Dennis Bader](https://github.com/dennisbader).
- Fixed an issue where `DTWAlignment.plot_alignment()` was not plotting the alignment plot of series with a RangeIndex correctly. [#1880](https://github.com/unit8co/darts/pull/1880) by [Ahmet Zamanis](https://github.com/AhmetZamanis) and [Dennis Bader](https://github.com/dennisbader).
- Fixed an issue when calling `ARIMA.predict()` and `num_samples > 1` (probabilistic forecasting), where the start point of the simulation was not anchored to the end of the target series. [#1893](https://github.com/unit8co/darts/pull/1893) by [Dennis Bader](https://github.com/dennisbader).
- Fixed an issue when using `TFTModel.predict()` with `full_attention=True` where the attention mask was not applied properly. [#1392](https://github.com/unit8co/darts/issues/1392) by [Dennis Bader](https://github.com/dennisbader).

**Removed**
- Removed support for Python 3.7 [#1864](https://github.com/unit8co/darts/pull/#1864) by [Dennis Bader](https://github.com/dennisbader).

### For developers of the library:

**Improvements**
- Refactored the `ForecastingModelExplainer` and `ExplainabilityResult` to simplify implementation of new explainers. [#1392](https://github.com/unit8co/darts/issues/1392) by [Dennis Bader](https://github.com/dennisbader).

## [0.24.0](https://github.com/unit8co/darts/tree/0.24.0) (2023-04-12)
### For users of the library:

Expand Down Expand Up @@ -95,6 +102,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Major refactor of data transformers which simplifies implementation of new transformers. [#1409](https://github.com/unit8co/darts/pull/1409) by [Matt Bilton](https://github.com/mabilton).


- Added new `TFTExplainer` class to implement the Explainable AI part described in [the paper](https://arxiv.org/abs/1912.09363) of the `TFT` model. [#1392](https://github.com/unit8co/darts/pull/1392) by [Sebastian Cattes](https://github.com/cattes).

## [0.23.1](https://github.com/unit8co/darts/tree/0.23.1) (2023-01-12)
Patch release

Expand Down Expand Up @@ -166,7 +175,6 @@ Patch release
by [Antoine Madrona](https://github.com/madtoinou).



**Fixed**
- Fixed edge case in ShapExplainer for regression models where covariates series > target series
[#1310](https://https://github.com/unit8co/darts/pull/1310) by [Rijk van der Meulen](https://github.com/rijkvandermeulen)
Expand Down
7 changes: 6 additions & 1 deletion darts/explainability/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,10 @@
--------------
"""

from darts.explainability.explainability_result import ExplainabilityResult
from darts.explainability.explainability_result import (
ShapExplainabilityResult,
TFTExplainabilityResult,
_ExplainabilityResult,
)
from darts.explainability.shap_explainer import ShapExplainer
from darts.explainability.tft_explainer import TFTExplainer
Loading
Loading