Add TFTExplainer #1392
Could you add docstrings explaining what this and `decoder_importance` are returning? They can be quite useful, I think.
I have added docstrings to the module and properties. I wasn't 100% sure about the details of the model, so it would be great if you could have a look at them. If everything is fine, you can resolve this conversation.
Hey, I was trying to call this function and ran into an error: it says there is no attribute 'encoder_sparse_weights'. I went to tft_model.py and uncommented this code chunk:
```python
return self.to_network_output(
    prediction=self.transform_output(out, target_scale=x["target_scale"]),
    attention=attn_out_weights,
    static_variables=static_covariate_var,
    encoder_variables=encoder_sparse_weights,
    decoder_variables=decoder_sparse_weights,
    decoder_lengths=decoder_lengths,
    encoder_lengths=encoder_lengths,
)
```
It now says `TFTModule` has no attribute called `to_network_output`. Can I get some help regarding how to call the explainer and use it in my code?
How about taking some of the `predict()` parameters explicitly? At least `series`, `past_covariates`, `future_covariates` and `n` would make sense IMO. It will produce more comprehensible API documentation.
I am not sure if that is relevant here at all. I do not understand why `predict` has to be called to get the proper attention heads of the time series; the learned autoregressive connections shouldn't depend on how `predict` is called. But if `predict` is not called at all, the `attention_heads` saved in `self._model.model._attn_out_weights` do not have the right format. I assume they are still in a state of training and the `predict()` call changes that.

If that is the case, I would rather remove the `**kwargs` completely from the `explain` method here and call `predict` once with `self._model.model.output_chunk_length` to get the correct attention heads.
Yes, I agree with you that we need to call `predict()` here. However, `predict()` takes a lot more arguments than just `n`. It takes `series` (the series to predict), as well as covariates arguments and other arguments: see the API doc.

I think we should probably change the signature of `explain()` to something like the sketch below. This way, in the docstring you can list `series`, `past_covariates` and `future_covariates`, and explain that those are passed down to `predict()`. You can also say that `n` will always be set to `output_chunk_length` (unless I'm wrong, I think that's always what's needed), and that `**kwargs` can contain extra arguments for the predict method, and link to the API documentation of `TFTModel.predict()`.

I hope it makes sense.
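Something along these lines, as a rough sketch of the suggested signature (parameter names follow the comment; the types, defaults and class skeleton are assumptions, not the final implementation):

```python
from typing import Optional, Sequence, Union

from darts import TimeSeries


class TFTExplainer:  # skeleton only, to show the proposed explain() signature
    def explain(
        self,
        series: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
        past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
        future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
        **kwargs,
    ) -> "ExplainabilityResult":
        # series, past_covariates and future_covariates are passed down to
        # TFTModel.predict(); n is always set to output_chunk_length, and any
        # extra predict() arguments can be given through **kwargs.
        ...
```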
I think calling `predict()` is just a technicality to get to the correct attention weights. I don't think the way we call `predict` matters at all for the result; it's just important that it was called (for whatever reason). If I understand it correctly, the attention weights are learned during training and are not impacted by the data used in the `predict` call. They don't have a logic behind them similar to Shapley values; they are learned during training and are a fixed part of the trained model.

Maybe I am wrong, but if I am right I would rather remove all parameters passed to `explain()` and have the `predict()` call happen without the user needing to know about it at all.
Is this always relating to horizon 0 only? How about the cases where `predict()` is called with `n > 1` above?
I had to set the `0` here to be compatible with the `ForecastingModelExplainer` base class. To get the `attention_heads`, the `predict` method of the TFT class has to be called, or the `attention_heads` will not show the correct values; I am not sure why yet. Placing this logic into the `explain()` method, which returns the `ExplainabilityResult`, felt like a sensible choice.

We could deviate from the `ForecastingModelExplainer` class, or add a note to the docstring that the `0` is irrelevant in this context.
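For illustration, the "fake horizon" construction described here looks roughly like this (the import path, the component name "attention_heads" and the dummy data are assumptions; only the nesting under a placeholder 0 key comes from the discussion):

```python
import numpy as np
from darts import TimeSeries
from darts.explainability.explainability_result import ExplainabilityResult

# Stand-in for whatever the explainer actually wraps as a TimeSeries.
attention_ts = TimeSeries.from_values(np.random.rand(12, 1))

# The hack: everything is filed under a dummy horizon 0 purely to satisfy the
# Dict[int, Dict[str, TimeSeries]] structure the base class currently expects.
result = ExplainabilityResult({0: {"attention_heads": attention_ts}})
result.get_explanation(horizon=0, component="attention_heads")
```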
So if I follow well, here the explanation is for all forecasted horizons at once, right?

I would then propose the following. We can adapt the class `ExplainabilityResult` in order to make it a little bit more flexible: the explanation can either be split by horizon (as it is today) or, as here, cover all forecasted horizons at once. To accommodate the second case, we could make it possible to build `ExplainabilityResult` with only a `Dict[str, TimeSeries]` (in addition to `Dict[int, Dict[str, TimeSeries]]`), so we avoid specifying the horizon. We can also adapt `ExplainabilityResult.get_explanation()` to make specifying the `horizon` optional, and not supported if the underlying explanation is not split by horizon.

WDYT? I would find this cleaner than "hacking" the class by using a fake horizon 0.
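A minimal sketch of what that flexibility could look like (the real darts class has more to it; this only illustrates the two accepted input shapes and the optional horizon, and the class and method details are assumptions):

```python
from typing import Dict, Optional, Union

from darts import TimeSeries


class FlexibleExplainabilityResult:
    """Hypothetical sketch: accepts either {horizon: {component: TimeSeries}}
    or a flat {component: TimeSeries} that is not split by horizon."""

    def __init__(
        self,
        explained_forecasts: Union[
            Dict[int, Dict[str, TimeSeries]], Dict[str, TimeSeries]
        ],
    ):
        first_value = next(iter(explained_forecasts.values()))
        # If the values are dicts, the result is split by horizon.
        self._split_by_horizon = isinstance(first_value, dict)
        self._explained = explained_forecasts

    def get_explanation(
        self, component: str, horizon: Optional[int] = None
    ) -> TimeSeries:
        if self._split_by_horizon:
            if horizon is None:
                raise ValueError("This result is split by horizon; pass one.")
            return self._explained[horizon][component]
        if horizon is not None:
            raise ValueError("This result is not split by horizon.")
        return self._explained[component]
```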
@Cattes any thoughts on this? ^
I think it's a good idea to change the class to handle the TFT explainer; I didn't want to do it before discussing it. Having the hack with horizon=0 was just to conform with the given API, it was not an intuitive solution.

I have added `Dict[str, TimeSeries]` to the valid types for class initialization and made the `horizon` optional. I also added a few more validations to deal with the different input types explicitly.
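With those changes, usage along these lines becomes possible (reusing the hypothetical `FlexibleExplainabilityResult` sketch above; the component name and data are again made up):

```python
import numpy as np
from darts import TimeSeries

attention_ts = TimeSeries.from_values(np.random.rand(12, 1))

# Flat form, not split by horizon: what the TFT explainer needs.
flat = FlexibleExplainabilityResult({"attention_heads": attention_ts})
flat.get_explanation(component="attention_heads")  # no horizon required

# Nested form, split by horizon: the original behaviour stays available.
nested = FlexibleExplainabilityResult({1: {"attention_heads": attention_ts}})
nested.get_explanation(component="attention_heads", horizon=1)
```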