-
Notifications
You must be signed in to change notification settings - Fork 726
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enabled feature_importances_ for our ForestDML and ForestDRLearner es…
…timators (#306) This required changing the subsampled honest forest code a bit so that it does not alter the arrays of the tree structures of sklearn but rather stores two additional arrays required for prediction. This does add around 1.5 times the original running time, so makes it slightly slower due to the extra memory allocation. However this enables correct feature_importance calculation and also in the future correct SHAP calculation (fixes #297), as now the tree entries are consistent with a tree in a randomforestregressor and so shap logic can be applied if we recast the subsampled honest forest as a randomforestregressor (additivity of shap will still be violated since the prediction of the subsample honest forest is not just the aggregation of the predictions across the trees but more complex weighted average). But we can still call shap and still get meaningful shap numbers. One discrepancy is that shap is explaining a different value that what effect returns, since it explains the value that corresponds to the average of the predictions of each honest tree regressor. however, the prediction of an honest forest is not the average of the tree predictions. For a full solution to this small discrepancy, one would need a full re-working of Shap's tree explainer and the tree explainer algorithm to account for such alternative aggregations of tree predictors. * changed subsampledhonest forest to not alter the entries of each tree but rather create auxiliary numpy arrays that store the numerator and denominator of every node. This enables consistent feature_importance calculation and also potentially more accurate shap_values calcualtion. * added feature improtances in dr learner example notebook * added feature_importances_ to DML example notebook * enabled feature_importances_ for forestDML and forestDRLearner as an attribute * fixed doctest in subsample honest forest which was producing old feature_importances_. Added tests that the feature_importances_ API is working in test_drlearner and test_dml. * Transformed sparse matrices to dense matrices after dot product in parallel_add_trees_ of ensemble.py. This leads to 6 fold speed-up as we were doing many slicing operations to sparse matrices before, which are very slow!
- Loading branch information
1 parent
0a66aa9
commit 61cd136
Showing
9 changed files
with
399 additions
and
268 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.