Feature/add feature values to explainability result #1546
Conversation
Codecov Report

@@            Coverage Diff             @@
##           master    #1546      +/-   ##
==========================================
- Coverage   94.12%   94.05%    -0.08%
==========================================
  Files         125      125
  Lines       11308    11326       +18
==========================================
+ Hits        10644    10653        +9
- Misses        664      673        +9

... and 9 files with indirect coverage changes. ☔ View full report in Codecov by Sentry.
@rijkvandermeulen thanks again for another great PR :)
Looks good in general! I added a suggestion that we should try to keep ExplainabilityResult as a base class, and opt for a new ShapExplainabilityResult subclass that carries the added logic specific to Shap Explainers.
Apart from that, only minor suggestions, and some fixes in the unit tests.
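A minimal sketch of what that split could look like. Class names follow the discussion, but the constructor signatures and method names are illustrative, not the actual darts API:

```python
# Hypothetical sketch of the suggested design: a generic base class that all
# explainers share, plus a SHAP-specific subclass carrying the extra attributes.
class ExplainabilityResult:
    def __init__(self, explained_forecasts):
        # nested dict: {horizon: {component: TimeSeries-of-shap-values}}
        self.explained_forecasts = explained_forecasts

    def get_explanation(self, horizon, component):
        # validation/lookup logic shared by every explainer
        return self.explained_forecasts[horizon][component]


class ShapExplainabilityResult(ExplainabilityResult):
    def __init__(self, explained_forecasts, feature_values, shap_explanation_object):
        super().__init__(explained_forecasts)
        # SHAP-specific additions from this PR
        self.feature_values = feature_values
        self.shap_explanation_object = shap_explanation_object

    def get_feature_values(self, horizon, component):
        return self.feature_values[horizon][component]

    def get_shap_explanation_object(self, horizon, component):
        return self.shap_explanation_object[horizon][component]
```

The point of the subclass is that the SHAP-only attributes are guaranteed to exist there, so the base class stays free of explainer-specific checks.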
if len(feature_values_list) == 1:
    feature_values_list = feature_values_list[0]
if len(shap_explanation_object_list) == 1:
    shap_explanation_object_list = shap_explanation_object_list[0]
Suggested change:
- if len(feature_values_list) == 1:
-     feature_values_list = feature_values_list[0]
- if len(shap_explanation_object_list) == 1:
-     shap_explanation_object_list = shap_explanation_object_list[0]
+ feature_values_list = feature_values_list[0]
+ shap_explanation_object_list = shap_explanation_object_list[0]
    The component for which to return the `shap.Explanations` object(s). Does not
    need to be specified for univariate series.
    """
if not self.shap_explanation_object:
If we have the ShapExplainabilityResult subclass, it is guaranteed to have a shap_explanation_object, so this check would no longer be needed.
@@ -580,6 +601,7 @@ def shap_explanations(
                :, :, self.target_dim * (h - 1) + t_idx
            ]
        )
        tmp_t.data = shap_explanation_tmp.data
is this required?
Yep, I believe it is. Reason being that in this particular if condition the shap.Explanation object is 'manually' constructed so the tmp_t dict doesn't contain the feature values (i.e., data) yet.
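A shap-free toy illustration of this point. `Explanation` below is a hand-rolled stand-in for `shap.Explanation`, not the real class, used only to show why a manually assembled object lacks its `data` attribute until it is copied over:

```python
# Stand-in for shap.Explanation: holds shapley values plus the raw feature
# values ("data") they were computed from.
class Explanation:
    def __init__(self, values, data=None):
        self.values = values
        self.data = data

# The full explanation object returned by the explainer carries both.
full = Explanation(values=[[0.1, 0.2], [0.3, 0.4]], data=[[10, 20], [30, 40]])

# "Manual" construction from a slice of the values only, as in the diff above:
# the data attribute is not carried along.
tmp_t = Explanation(values=[row[0] for row in full.values])
assert tmp_t.data is None

# The fix mirrors the added line: copy the feature values over explicitly.
tmp_t.data = full.data
```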
@@ -16,6 +18,45 @@
logger = get_logger(__name__)


def check_input_validity_for_querying_explainability_result(func: Callable) -> Callable:
Could we simply make this a normal function and call it at the beginning of each method?
Wrappers can swallow the actual method/function arguments in the documentation (from an IDE, for example) unless some additional code is added.
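A small sketch of the trade-off being discussed here (class and method names are illustrative): a decorator needs `functools.wraps` to preserve the wrapped method's name and docstring, while the plain-helper alternative called at the top of each method keeps the real signature fully visible to IDEs with no extra machinery:

```python
import functools
from typing import Callable


def validate_inputs(func: Callable) -> Callable:
    # Without functools.wraps, `wrapper` would replace the wrapped method's
    # __name__, __doc__, and (for many tools) its visible signature.
    @functools.wraps(func)
    def wrapper(self, horizon, component=None):
        if horizon is None:
            raise ValueError("horizon must be provided")
        return func(self, horizon, component)
    return wrapper


class Result:
    @validate_inputs
    def get_explanation(self, horizon, component=None):
        """Return the explanation for the given horizon/component."""
        return (horizon, component)

    def get_feature_values(self, horizon, component=None):
        # Plain-function alternative: call a shared helper explicitly.
        self._validate(horizon)
        return (horizon, component)

    def _validate(self, horizon):
        if horizon is None:
            raise ValueError("horizon must be provided")
```

Both variants avoid duplicating the checks; the helper version is just more transparent to tooling.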
@dennisbader thanks for the review! I believe I've addressed all your comments and suggestions. Let me know if you have any additional remarks :)
Looks great @rijkvandermeulen, thanks for the updates!
I only had one suggestion for refactoring :) After that we can merge this one! 🚀
    Dict[integer, Dict[str, TimeSeries]],
    Sequence[Dict[integer, Dict[str, TimeSeries]]],
],
shap_explanation_object: Optional[
I guess the Optional can now be removed
).data

assert_array_equal(feature_values.values(), comparison)
self.assertTrue(
shouldn't this be assertEqual?
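A toy illustration of why this matters (the test case below is illustrative, not from the PR): `assertTrue(a == b)` only reports "False is not true" on failure, while `assertEqual(a, b)` prints both values, which makes a failing comparison much easier to diagnose:

```python
import unittest


class ComparisonStyleDemo(unittest.TestCase):
    def test_equality(self):
        expected, actual = [1, 2, 3], [1, 2, 3]
        # On failure, reports only "False is not true":
        self.assertTrue(actual == expected)
        # On failure, reports "[...] != [...]" with both operands:
        self.assertEqual(actual, expected)
```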
@dennisbader implemented your suggestions! :) BTW, not very important and also not something for this PR, but I was wondering whether it would make sense to at some point refactor the attributes in ExplainabilityResult (which are now nested dictionaries) into dataclasses? It might make them even more robust and easier to maintain.
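Purely illustrative, one possible dataclass shape for what are currently nested dictionaries. The class and field names here are hypothetical, not part of darts:

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ComponentExplanation:
    # e.g. a TimeSeries of shap values for one target component
    explained_forecast: Any
    # e.g. a TimeSeries of the raw feature values (SHAP explainers only)
    feature_values: Any = None
    # the underlying shap.Explanation object (SHAP explainers only)
    shap_explanation_object: Any = None


@dataclass
class HorizonExplanation:
    # one ComponentExplanation per target component name
    components: Dict[str, ComponentExplanation] = field(default_factory=dict)
```

Compared to `dict[horizon][component]` lookups, this gives typed attribute access and makes missing fields explicit instead of raising KeyError.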
Thanks a lot for this @rijkvandermeulen, looks great :) 🚀
And I like your idea about refactoring the ExplainabilityResult attributes with dataclasses 👍
That could make our lives much easier also for extending the Explainability module.
* store feature values and shap.Explanation object in ExplainabilityResult
* accounted for is_multioutputregressor
* unit8co#1545 added entry to CHANGELOG.md
* unit8co#1545 update docstrings for correctness API reference docs
* unit8co#1580 create ShapExplainabilityResult subclass and remove decorator
* unit8co#1580 adjust unit tests to have dedicated with statement per assert + other small stuff
* unit8co#1580 change asserts in unit test from ExplainabilityResult to ShapExplainabilityResult
* unit8co#1580 test get_feature_values() against raw output shap
* unit8co#1580 adjust docstring
* unit8co#1580 fixing small stuff
* unit8co#1580 added one assert to unit test
* unit8co#1545 implement _query_explainability_result() helper to avoid code duplication

Co-authored-by: Rijk van der Meulen <rijk.vandermeulen@eyeon.nl>
Co-authored-by: madtoinou <32447896+madtoinou@users.noreply.github.com>
Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
#1545
Summary
Currently the ExplainabilityResult object only contains the explained forecast (i.e., the Shapley values for each prediction). I think it would be useful to also store the corresponding feature values and the SHAP explanation object itself. This would give the end user more options for working with the explainability results (e.g., creating plots, running analyses, etc.). In any case, having this functionality would help how we use darts a great deal, so from a personal point of view I'd love to see it added :)
Other Information
Since I added two methods that closely resemble `get_explanation`, I've written a decorator to take care of the validation steps. The main goal was to avoid duplicating the validation checks across multiple methods. If you think this makes the code unnecessarily complex and/or less readable, let me know and we can change it back.
@hrzn @dumjax
WDYT? Do you see the added value of this? Any thoughts on how to make it better?