
Add TFTExplainer #1392

Merged: 50 commits from feature/675_tft_explainer into master on Jul 31, 2023
Changes from 19 commits
Commits (50)
ce98db1
#675 add first draft for tft_explainer
Cattes Nov 2, 2022
5678f7f
#675 add first working version of TFTExplainer class with tests
Cattes Nov 22, 2022
c22e96d
#675 allow passing of arguments to the explain method of the TFTExpla…
Cattes Nov 27, 2022
598b134
#675 add test for multiple_covariates input to test_tft_explainer.py
Cattes Nov 27, 2022
f7387a4
#675 add correct feature names to vsv
Cattes Nov 27, 2022
68e384d
#675 add TFTExplainer to 13-TFT-examples.ipynb
Cattes Nov 27, 2022
bfffe87
Merge branch 'master' into feature/675_tft_explainer
Cattes Nov 27, 2022
160d196
#675 add CHANGELOG.md entry for the TFTExplainer class
Cattes Nov 27, 2022
60eae66
Merge branch 'unit8co:master' into feature/675_tft_explainer
Cattes Nov 28, 2022
42cfb92
#675 use @MagMueller's plot method for the variable importance plot
Cattes Nov 28, 2022
57e64fb
Merge branch 'master' into feature/675_tft_explainer
Cattes Dec 19, 2022
e59a7ba
#675 allow absolute tolerance of 1% in feature importance test
Cattes Dec 19, 2022
49587c9
Merge branch 'master' into feature/675_tft_explainer
hrzn Jan 18, 2023
706e190
Update CHANGELOG.md
hrzn Jan 18, 2023
eb4bcfc
Merge branch 'master' into feature/675_tft_explainer
hrzn Jan 20, 2023
96a1c2b
Merge branch 'master' into feature/675_tft_explainer
hrzn Jan 23, 2023
4556a44
Merge branch 'unit8co:master' into feature/675_tft_explainer
Cattes Jan 23, 2023
79b7755
#675 work in PR feedback
Cattes Jan 23, 2023
18f9f65
Merge branch 'master' into feature/675_tft_explainer
hrzn Jan 24, 2023
32ec858
Merge branch 'master' into feature/675_tft_explainer
hrzn Jan 31, 2023
6b813f2
Update darts/explainability/tft_explainer.py
Cattes Feb 2, 2023
6a8b0fb
Update darts/timeseries.py
Cattes Feb 2, 2023
b9f0a9e
Merge branch 'master' into feature/675_tft_explainer
hrzn Feb 10, 2023
e5242fb
Merge branch 'master' into feature/675_tft_explainer
madtoinou Feb 14, 2023
ce6a5a3
Merge branch 'master' into feature/675_tft_explainer
hrzn Feb 23, 2023
ca10d0d
#675 Add docstrings to tft_explainer.py
Cattes Feb 23, 2023
4f1409d
#675 Allow Dict[str, TimeSeries] as ExplainabilityResult input
Cattes Feb 23, 2023
ccff53b
#675 remove horizon=0 from the 13-TFT-examples.ipynb notebook
Cattes Feb 23, 2023
53bf28d
Merge branch 'master' into feature/675_tft_explainer
dennisbader Jul 11, 2023
f0e551c
fix failing tests p1
dennisbader Jul 11, 2023
42c9c96
refactor ForecastingModelExplainer.__init__
dennisbader Jul 12, 2023
9464c8f
further explainability refactoring for input processing
dennisbader Jul 12, 2023
71e1e81
Merge branch 'master' into feature/675_tft_explainer
dennisbader Jul 19, 2023
7ef185a
refactor ForecastingModelExplainer p3
dennisbader Jul 19, 2023
ce4e233
Merge branch 'master' into feature/675_tft_explainer
dennisbader Jul 21, 2023
8979bd8
full refactor of ForecastingModelExplainer
dennisbader Jul 21, 2023
40c5aef
update component naming
dennisbader Jul 21, 2023
10fc42b
add static covariates importance
dennisbader Jul 24, 2023
052680a
improved attention head plots
dennisbader Jul 24, 2023
29e5aa3
multiple time series support
dennisbader Jul 24, 2023
090419b
update explainability documentation
dennisbader Jul 25, 2023
72f8ded
update TFTModel full attention
dennisbader Jul 25, 2023
8a3af53
remove optional horizon from HorizonBasedExplainabilityResult
dennisbader Jul 25, 2023
9ffcd5f
update TFTModel example notebook
dennisbader Jul 26, 2023
04daed1
fix covariates issue when supplying covariates at predict time
dennisbader Jul 26, 2023
cd79722
update unit tests
dennisbader Jul 27, 2023
a66c86e
update changelog
dennisbader Jul 27, 2023
7e80088
Merge branch 'master' into feature/675_tft_explainer
dennisbader Jul 31, 2023
d0c472c
applied suggestions from PR review
dennisbader Jul 31, 2023
92391a8
Merge branch 'master' into feature/675_tft_explainer
dennisbader Jul 31, 2023
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -7,6 +7,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
## [Unreleased](https://github.com/unit8co/darts/tree/master)
[Full Changelog](https://github.com/unit8co/darts/compare/0.23.1...master)

- Added new `TFTExplainer` class to implement the Explainable AI part described in [the paper](https://arxiv.org/abs/1912.09363) of the `TFT` model. [#1392](https://github.com/unit8co/darts/pull/1392) by [Sebastian Cattes](https://github.com/cattes).

## [0.23.1](https://github.com/unit8co/darts/tree/0.23.1) (2023-01-12)
Patch release

@@ -78,7 +80,6 @@ Patch release
by [Antoine Madrona](https://github.com/madtoinou).



**Fixed**
- Fixed edge case in ShapExplainer for regression models where covariates series > target series
[#1310](https://https://github.com/unit8co/darts/pull/1310) by [Rijk van der Meulen](https://github.com/rijkvandermeulen)
1 change: 1 addition & 0 deletions darts/explainability/__init__.py
@@ -5,3 +5,4 @@

from darts.explainability.explainability_result import ExplainabilityResult
from darts.explainability.shap_explainer import ShapExplainer
from darts.explainability.tft_explainer import TFTExplainer
4 changes: 1 addition & 3 deletions darts/explainability/explainability_result.py
@@ -58,9 +58,7 @@ def get_explanation(

raise_if(
component is None and len(self.available_components) > 1,
ValueError(
"The component parameter is required when the model has more than one component."
),
"The component parameter is required when the model has more than one component.",
logger,
)

252 changes: 252 additions & 0 deletions darts/explainability/tft_explainer.py
@@ -0,0 +1,252 @@
from typing import Dict, List, Optional

import matplotlib.pyplot as plt
import pandas as pd
from torch import Tensor

from darts import TimeSeries
from darts.explainability.explainability import (
ExplainabilityResult,
ForecastingModelExplainer,
)
from darts.models import TFTModel

try:
from typing import Literal
except ImportError:
from typing_extensions import Literal


class TFTExplainer(ForecastingModelExplainer):
def __init__(
self,
model: TFTModel,
):
"""
Explainer class for the TFT model.

Parameters
----------
model
The fitted TFT model to be explained.
"""
super().__init__(model)

if not model._fit_called:
raise ValueError("The model needs to be trained before explaining it.")

self._model = model

@property
def encoder_importance(self):
Contributor:
Could you add docstrings explaining what this and decoder_importance are returning? They can be quite useful I think.

Contributor Author:
I have added docstrings to the module and properties. I wasn't 100% sure on the details of the model, so it would be great if you could have a look. If everything is fine you can resolve this conversation.

Comment:
Hey, I was trying to call this function and ran into an error: it says there is no attribute 'encoder_sparse_weights'. I went to tft_model.py and uncommented this code chunk:
return self.to_network_output(
prediction=self.transform_output(out, target_scale=x["target_scale"]),
attention=attn_out_weights,
static_variables=static_covariate_var,
encoder_variables=encoder_sparse_weights,
decoder_variables=decoder_sparse_weights,
decoder_lengths=decoder_lengths,
encoder_lengths=encoder_lengths,
)

It now says TFTModule has no attribute called to_network_output.

Can I get some help regarding how to call the explainer and use it in my code?

return self._get_importance(
weight=self._model.model._encoder_sparse_weights,
names=self._model.model.encoder_variables,
)
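A minimal end-to-end usage sketch addressing the question above, based on the API as it appears at this stage of the PR (TFTExplainer, get_variable_selection_weight, explain and plot_attention_heads as defined in this file). The toy series, covariates and hyperparameters below are made up for illustration, and the interface may differ in later commits:

import numpy as np

from darts import TimeSeries
from darts.explainability.tft_explainer import TFTExplainer
from darts.models import TFTModel

# toy data: a sine target with one past and one future covariate (values are arbitrary)
target = TimeSeries.from_values(np.sin(np.linspace(0, 20, 200)).reshape(-1, 1))
past_covs = TimeSeries.from_values(np.cos(np.linspace(0, 20, 200)).reshape(-1, 1))
# the future covariate must extend output_chunk_length steps beyond the target so predict() works
future_covs = TimeSeries.from_values(np.arange(212, dtype=np.float64).reshape(-1, 1))

model = TFTModel(input_chunk_length=24, output_chunk_length=12, n_epochs=1)
model.fit(target, past_covariates=past_covs, future_covariates=future_covs)

explainer = TFTExplainer(model)  # the model must be fitted first
importances = explainer.get_variable_selection_weight(plot=True)  # encoder/decoder importance DataFrames
result = explainer.explain()  # calls predict() internally to populate the attention weights
TFTExplainer.plot_attention_heads(result, plot_type="time")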

@property
def decoder_importance(self):
return self._get_importance(
weight=self._model.model._decoder_sparse_weights,
names=self._model.model.decoder_variables,
)

def get_variable_selection_weight(self, plot=False) -> Dict[str, pd.DataFrame]:
"""Returns the variable selection weight of the TFT model.

Parameters
----------
plot
Whether to plot the variable selection weight.

Returns
-------
Dict[str, pd.DataFrame]
The encoder and decoder variable importances.

"""

if plot:
# plot the encoder and decoder weights
self._plot_cov_selection(
self.encoder_importance,
title="Encoder variable importance",
)
self._plot_cov_selection(
self.decoder_importance,
title="Decoder variable importance",
)

return {
"encoder_importance": self.encoder_importance,
"decoder_importance": self.decoder_importance,
}
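The returned dictionary holds two one-row pandas DataFrames, sorted by importance in descending order; a small access sketch, assuming explainer was built as in the example further up:

weights = explainer.get_variable_selection_weight()
print(weights["encoder_importance"])  # importance in % per encoder variable, largest first
print(weights["decoder_importance"])  # same for the decoder variables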

def explain(self, **kwargs) -> ExplainabilityResult:
Contributor:
How about taking some of the predict() parameters explicitly? At least series, past_covariates, future_covariates and n would make sense IMO. It will produce more comprehensible API documentation.

Contributor Author:
I am not sure if that is relevant here at all.

I do not understand why predict has to be called to get the proper attention heads of the time series. The learned autoregressive connections should not depend on how predict is called. But if predict is not called at all, the attention heads saved in self._model.model._attn_out_weights do not have the right format. I assume they are still in a training state and the predict() call changes that.

If that is the case I would rather remove the **kwargs completely from the explain method here and call predict once with n=self._model.model.output_chunk_length to get the correct attention heads.

Contributor (hrzn), Feb 26, 2023:
Yes I agree with you we need to call predict() here. However predict() takes a lot more arguments than just n. It takes series (the series to predict), as well as covariates arguments and other arguments: see the API doc.
I think we should probably change the signature of explain() to something like

def explain(self, series, past_covariates, future_covariates, **kwargs) -> ExplainabilityResult

This way in the docstring you can list series, past_covariates and future_covariates, and explain that those are passed down to predict(). You can also say that n will always be set to output_chunk_length (unless I'm wrong I think that's always what's needed), and that **kwargs can contain extra arguments for the predict method and link to the API documentation of TFTModel.predict().
I hope it makes sense.

Contributor Author:
I think calling predict() is just a technicality to get to the correct attention weights. I don't think the way we call predict matters at all for the result; it's just important that it was called (for whatever reason).
If I understand it correctly, the attention weights are learned during training and are not impacted by the data used in the predict call.

They don't follow a logic similar to Shapley values; they are learned during training and are a fixed part of the trained model.

Maybe I am wrong, but if I am right I would rather remove all parameters passed to explain() and have the predict() call happen without the user needing to know about it at all.
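For concreteness, a minimal sketch of the signature proposed by hrzn above, assuming n is always fixed to the model's output_chunk_length and the remaining arguments are simply forwarded to TFTModel.predict(); this only illustrates the proposal and is not the merged implementation:

def explain(
    self,
    series: Optional[TimeSeries] = None,
    past_covariates: Optional[TimeSeries] = None,
    future_covariates: Optional[TimeSeries] = None,
    **kwargs,
) -> ExplainabilityResult:
    # n is always the model's output_chunk_length; everything else is passed down to predict()
    _ = self._model.predict(
        n=self._model.model.output_chunk_length,
        series=series,
        past_covariates=past_covariates,
        future_covariates=future_covariates,
        **kwargs,
    )
    ...  # read the attention weights from the network, as in the current implementation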

"""Returns the explainability result of the TFT model.

The explainability result contains the attention heads of the TFT model.
The attention heads determine the contribution of time-varying inputs.

Parameters
----------
kwargs
Arguments passed to the `predict` method of the TFT model.

Returns
-------
ExplainabilityResult
The explainability result containing the attention heads.

"""
super().explain()
# without the predict call, the weights would still be set to the last iteration of the forward() method
# of the _TFTModule class
if "n" not in kwargs:
kwargs["n"] = self._model.model.output_chunk_length

_ = self._model.predict(**kwargs)

# get the weights and the attention head from the trained model for the prediction
attention_heads = (
self._model.model._attn_out_weights.squeeze().sum(axis=1).detach()
)

# return the explainer result to be used in other methods
return ExplainabilityResult(
{
0: {
Contributor:
Is this always relating to horizon 0 only? How about the cases where predict() is called with n > 1 above?

Contributor Author:
I had to set the 0 here to be compatible with the ForecastingModelExplainer base class. To get the attention_heads the predict method of the TFT class has to be called or the attention_heads will not show the correct values. I am not sure why yet. Placing this logic into the explain() method as the ExplainabilityResult felt like a sensible choice.
We could deviate from the ForecastingModelExplainer class or add a note to the docstring that the 0 is irrelevant in this context.

Contributor:
So if I follow well, here the explanation is for all forecasted horizons at once, right?
I would then propose the following. We can adapt the class ExplainabilityResult in order to make it a little bit more flexible:

  • It could be used with one explanation per horizon (as now), or
  • with one single explanation for all horizons (as required in this case for the TFT).

To accommodate the second case, we could make it possible to build ExplainabilityResult with only a Dict[str, TimeSeries] (in addition to Dict[integer, Dict[str, TimeSeries]]), so we avoid specifying the horizon. We can also adapt ExplainabilityResult.get_explanation() to make specifying the horizon optional, and not supported if the underlying explanation is not split by horizon.

WDYT? I would find this cleaner than "hacking" the class by using a fake horizon 0.

Contributor:
@Cattes any thoughts on this? ^

Contributor Author:
I think it's a good idea to change the class to handle the TFT explainer; I didn't want to do it before discussing it. Having the hack with horizon=0 was just to conform with the given API; it was not an intuitive solution.
I have added Dict[str, TimeSeries] to the valid types for class initialization and made the horizon optional.
I also added a few more validations to deal with the different input types explicitly.

"attention_heads": TimeSeries.from_values(attention_heads.T),
}
},
)
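A small sketch of the two construction styles discussed in the thread above, assuming the relaxed ExplainabilityResult accepts both a per-horizon mapping and a plain component mapping; the exact argument names and validation live in explainability_result.py and may differ:

import numpy as np

from darts import TimeSeries
from darts.explainability.explainability_result import ExplainabilityResult

attention = TimeSeries.from_values(np.random.rand(24, 12))  # made-up attention values

# one explanation per horizon (the original API): horizon -> component -> TimeSeries
per_horizon = ExplainabilityResult({1: {"attention_heads": attention}})
per_horizon.get_explanation(horizon=1, component="attention_heads")

# one explanation for all horizons (what the TFT case needs): component -> TimeSeries
all_horizons = ExplainabilityResult({"attention_heads": attention})
all_horizons.get_explanation(component="attention_heads")  # horizon can be omitted here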

@staticmethod
def plot_attention_heads(
expl_result: ExplainabilityResult,
plot_type: Optional[Literal["all", "time", "heatmap"]] = "time",
):
"""Plots the attention heads of the TFT model."""
attention_heads = expl_result.get_explanation(
component="attention_heads",
horizon=0,
)
if plot_type == "all":
fig = plt.figure()
attention_heads.plot(
label="Attention Head",
max_nr_components=-1,
figure=fig,
)
# move legend to the right side of the figure
plt.legend(bbox_to_anchor=(0.95, 1), loc="upper left")
plt.xlabel("Time steps in the past (# lags)")
plt.ylabel("Attention")
elif plot_type == "time":
fig = plt.figure()
attention_heads.mean(1).plot(label="Mean Attention Head", figure=fig)
plt.xlabel("Time steps in the past (# lags)")
plt.ylabel("Attention")
elif plot_type == "heatmap":
avg_attention = attention_heads.values().transpose()
fig = plt.figure()
plt.imshow(avg_attention, cmap="hot", interpolation="nearest", figure=fig)
plt.xlabel("Time steps in the past (# lags)")
plt.ylabel("Horizon")
else:
raise ValueError("`plot_type` must be either 'all', 'time' or 'heatmap'")
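A short usage sketch for the three plot_type options, assuming result is the ExplainabilityResult returned by explain() in the example further up:

TFTExplainer.plot_attention_heads(result, plot_type="time")     # mean attention across components, per past lag
TFTExplainer.plot_attention_heads(result, plot_type="all")      # one curve per attention component
TFTExplainer.plot_attention_heads(result, plot_type="heatmap")  # heatmap of past lags (x-axis) vs. horizon (y-axis)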

def _get_importance(
self,
weight: Tensor,
names: List[str],
n_decimals=3,
) -> pd.DataFrame:
"""Returns the encoder or decoder variable of the TFT model.

Parameters
----------
weight
The weights of the encoder or decoder of the trained TFT model.
names
The encoder or decoder names saved in the TFT model class.
n_decimals
The number of decimals to round the importance to.

Returns
-------
pd.DataFrame
The importance of the variables.
"""
# transform the encoder/decoder weights to percentages, rounded to n_decimals
weights_percentage = (
weight.mean(axis=1).detach().numpy().mean(axis=0).round(n_decimals) * 100
)

# create a dataframe with the variable names and the weights
name_mapping = self._name_mapping
importance = pd.DataFrame(
weights_percentage,
columns=[name_mapping[name] for name in names],
)

# return the importance sorted descending
return importance.transpose().sort_values(0, ascending=False).transpose()
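A small numeric illustration of the weight-to-percentage step above, with a made-up sparse-weight tensor of shape (batch, time, n_variables); the column names are hypothetical:

import pandas as pd
import torch

weight = torch.tensor([
    [[0.7, 0.3], [0.6, 0.4], [0.8, 0.2]],  # batch element 0: three time steps, two variables
    [[0.5, 0.5], [0.7, 0.3], [0.9, 0.1]],  # batch element 1
])
# average over time (axis=1), then over the batch (axis=0), round, and scale to percent
weights_percentage = weight.mean(axis=1).detach().numpy().mean(axis=0).round(3) * 100
# -> array([70., 30.]), i.e. 70% / 30%
importance = pd.DataFrame([weights_percentage], columns=["past_covariate_0", "target_0"])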

@property
def _name_mapping(self) -> Dict[str, str]:
"""Returns the feature name mapping of the TFT model.

Returns
-------
Dict[str, str]
The feature name mapping. For example
{
'past_covariate_0': 'heater',
'past_covariate_1': 'year',
'past_covariate_2': 'month',
'future_covariate_0': 'darts_enc_fc_cyc_month_sin',
'future_covariate_1': 'darts_enc_fc_cyc_month_cos',
'target_0': 'ice cream',
}

"""
past_covariates_name_mapping = {
f"past_covariate_{i}": colname
for i, colname in enumerate(self._model.past_covariate_series.components)
}
future_covariates_name_mapping = {
f"future_covariate_{i}": colname
for i, colname in enumerate(self._model.future_covariate_series.components)
}
target_name_mapping = {
f"target_{i}": colname
for i, colname in enumerate(self._model.training_series.components)
}

return {
**past_covariates_name_mapping,
**future_covariates_name_mapping,
**target_name_mapping,
}

@staticmethod
def _plot_cov_selection(
importance: pd.DataFrame, title: str = "Variable importance"
):
"""Plots the variable importance of the TFT model.

Parameters
----------
importance
The encoder / decoder importance.
title
The title of the plot.

"""
fig = plt.figure()
plt.bar(importance.columns.tolist(), importance.values[0].tolist(), figure=fig)
plt.title(title)
plt.xlabel("Variable", fontsize=12)
plt.ylabel("Variable importance in %")
plt.show()
7 changes: 7 additions & 0 deletions darts/models/forecasting/tft_model.py
@@ -325,6 +325,10 @@ def __init__(

self.output_layer = nn.Linear(self.hidden_size, self.n_targets * self.loss_size)

self._encoder_sparse_weights = None
self._decoder_sparse_weights = None
self._attn_out_weights = None

@property
def reals(self) -> List[str]:
"""
@@ -633,6 +637,9 @@ def forward(
out = out.view(
batch_size, self.output_chunk_length, self.n_targets, self.loss_size
)
self._encoder_sparse_weights = encoder_sparse_weights
self._decoder_sparse_weights = decoder_sparse_weights
self._attn_out_weights = attn_out_weights

# TODO: (Darts) remember this in case we want to output interpretation
# return self.to_network_output(