Enabled feature_importances_ for our ForestDML and ForestDRLearner estimators (#306)

This required changing the subsampled honest forest code a bit so that it no longer alters the arrays of the sklearn tree structures, but instead stores two additional arrays required for prediction. This does add around 1.5 times the original running time, so the estimator is slightly slower due to the extra memory allocation.
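For intuition, here is a minimal sketch of that prediction path in the single-output case (the function and variable names below are illustrative only, not the actual implementation in ensemble.py, where the arrays are shaped like tree_.value):

import numpy as np

def honest_forest_predict(trees, numerators, denominators, X):
    # Sketch: each tree keeps its sklearn structure untouched; the honest
    # estimation-sample statistics live in the two auxiliary arrays
    # (one numerator and one denominator entry per tree node).
    num_acc = np.zeros(X.shape[0])
    den_acc = np.zeros(X.shape[0])
    for tree, num, den in zip(trees, numerators, denominators):
        leaves = tree.apply(X)   # node id that each sample falls into
        num_acc += num[leaves]   # weighted sum of y on the estimation sample
        den_acc += den[leaves]   # total weight on the estimation sample
    # The forest prediction is a ratio of averages rather than an average of
    # per-tree predictions, which is the "more complex weighted average"
    # discussed below.
    return (num_acc / len(trees)) / (den_acc / len(trees))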

However, this enables correct feature_importances_ calculation and, in the future, correct SHAP calculation (fixes #297): the tree entries are now consistent with a tree in a RandomForestRegressor, so SHAP's tree logic can be applied if we recast the subsampled honest forest as a RandomForestRegressor. SHAP additivity will still be violated, since the prediction of the subsampled honest forest is not just the aggregation of the per-tree predictions but a more complex weighted average; even so, we can still call shap and get meaningful SHAP values. One remaining discrepancy is that SHAP then explains a different value than what effect returns, namely the average of the predictions of the individual honest tree regressors, whereas the prediction of an honest forest is not the average of the tree predictions. A full solution to this small discrepancy would require reworking SHAP's tree explainer algorithm to account for such alternative aggregations of tree predictors.
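A hypothetical sketch of that recast (the wrapper construction, the copied attributes, and the exact shap calls below are assumptions for illustration; nothing like this ships in this commit):

import shap  # assumes the shap package is installed
from sklearn.ensemble import RandomForestRegressor

def shap_values_via_recast(honest_forest, X):
    # Build an unfitted RandomForestRegressor shell and graft the honest trees
    # onto it, so that shap's TreeExplainer can walk their tree_ arrays.
    shell = RandomForestRegressor(n_estimators=len(honest_forest.estimators_))
    shell.estimators_ = honest_forest.estimators_
    shell.n_outputs_ = honest_forest.n_outputs_
    explainer = shap.TreeExplainer(shell)
    # Additivity checking is disabled because the honest forest aggregates
    # trees differently from the plain average that shap assumes.
    return explainer.shap_values(X, check_additivity=False)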

* Changed SubsampledHonestForest to not alter the entries of each tree, but rather to create auxiliary numpy arrays that store the numerator and denominator of every node. This enables consistent feature_importances_ calculation and also potentially more accurate shap_values calculation.

* Added feature_importances_ to the DR learner example notebook.

* Added feature_importances_ to the DML example notebook.

* Enabled feature_importances_ for ForestDML and ForestDRLearner as an attribute (a short usage sketch follows this list).

* Fixed the doctest in SubsampledHonestForest, which was still producing the old feature_importances_ values. Added tests in test_drlearner and test_dml that the feature_importances_ API works.

* Transformed sparse matrices to dense matrices after the dot product in _parallel_add_trees of ensemble.py. This gives a roughly 6-fold speed-up, since we were previously doing many slicing operations on sparse matrices, which are very slow (a small timing illustration follows).
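A small self-contained illustration of that speed-up claim (sizes and timings below are made up for the example and are not the benchmark behind the 6-fold figure): element-wise lookups on a scipy CSR matrix in a tight loop are far slower than the same lookups on a dense array, so densifying once after the dot product pays off.

import time
import numpy as np
import scipy.sparse as sp

rng = np.random.default_rng(0)
dense = rng.random((1, 5000))
sparse = sp.csr_matrix(dense)

def time_lookups(mat, n_iter=20000):
    # Repeated per-node lookups, similar in spirit to the pruning loop.
    start = time.perf_counter()
    for j in range(n_iter):
        _ = mat[0, j % mat.shape[1]]
    return time.perf_counter() - start

print("csr_matrix lookups:", time_lookups(sparse))
print("ndarray lookups:   ", time_lookups(dense))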
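And a short usage sketch of the new attribute itself (fitting code omitted; the estimators are assumed to be already fitted, and the treatment value 1 is just an example):

# est_dml: a fitted ForestDMLCateEstimator; est_dr: a fitted ForestDRLearner
importances = est_dml.feature_importances_         # one importance per feature of X
importances_t1 = est_dr.feature_importances_(1)    # importances of the CATE model for treatment value 1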
vsyrgkanis authored Nov 9, 2020
1 parent 0a66aa9 commit 61cd136
Showing 9 changed files with 399 additions and 268 deletions.
31 changes: 30 additions & 1 deletion econml/cate_estimator.py
@@ -11,7 +11,8 @@
from .inference import BootstrapInference
from .utilities import tensordot, ndim, reshape, shape, parse_final_model_params, inverse_onehot, Summary
from .inference import StatsModelsInference, StatsModelsInferenceDiscrete, LinearModelFinalInference,\
LinearModelFinalInferenceDiscrete, NormalInferenceResults
LinearModelFinalInferenceDiscrete, NormalInferenceResults, GenericSingleTreatmentModelFinalInference,\
GenericModelFinalInferenceDiscrete


class BaseCateEstimator(metaclass=abc.ABCMeta):
@@ -664,6 +665,19 @@ def _get_inference_options(self):
return options


class ForestModelFinalCateEstimatorMixin(BaseCateEstimator):

def _get_inference_options(self):
# add blb to parent's options
options = super()._get_inference_options()
options.update(blb=GenericSingleTreatmentModelFinalInference)
return options

@property
def feature_importances_(self):
return self.model_final.feature_importances_


class LinearModelFinalCateEstimatorDiscreteMixin(BaseCateEstimator):
# TODO Share some logic with non-discrete version
"""
@@ -863,3 +877,18 @@ def _get_inference_options(self):
options = super()._get_inference_options()
options.update(debiasedlasso=LinearModelFinalInferenceDiscrete)
return options


class ForestModelFinalCateEstimatorDiscreteMixin(BaseCateEstimator):

def _get_inference_options(self):
# add blb to parent's options
options = super()._get_inference_options()
options.update(blb=GenericModelFinalInferenceDiscrete)
return options

def feature_importances_(self, T):
_, T = self._expand_treatments(None, T)
ind = inverse_onehot(T).item() - 1
assert ind >= 0, "No model was fitted for the control"
return self.fitted_models_final[ind].feature_importances_
11 changes: 3 additions & 8 deletions econml/dml.py
@@ -50,7 +50,8 @@
from sklearn.utils import check_random_state
from .cate_estimator import (BaseCateEstimator, LinearCateEstimator,
TreatmentExpansionMixin, StatsModelsCateEstimatorMixin,
LinearModelFinalCateEstimatorMixin, DebiasedLassoCateEstimatorMixin)
LinearModelFinalCateEstimatorMixin, DebiasedLassoCateEstimatorMixin,
ForestModelFinalCateEstimatorMixin)
from .inference import StatsModelsInference, GenericSingleTreatmentModelFinalInference
from ._rlearner import _RLearner
from .sklearn_extensions.model_selection import WeightedStratifiedKFold
@@ -866,7 +867,7 @@ def __init__(self,
random_state=random_state)


class ForestDMLCateEstimator(NonParamDMLCateEstimator):
class ForestDMLCateEstimator(ForestModelFinalCateEstimatorMixin, NonParamDMLCateEstimator):
""" Instance of NonParamDMLCateEstimator with a
:class:`~econml.sklearn_extensions.ensemble.SubsampledHonestForest`
as a final model, so as to enable non-parametric inference.
@@ -1058,12 +1059,6 @@ def __init__(self,
categories=categories,
n_splits=n_crossfit_splits, random_state=random_state)

def _get_inference_options(self):
# add statsmodels to parent's options
options = super()._get_inference_options()
options.update(blb=GenericSingleTreatmentModelFinalInference)
return options

def fit(self, Y, T, X=None, W=None, sample_weight=None, sample_var=None, inference=None):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·).
11 changes: 3 additions & 8 deletions econml/drlearner.py
@@ -36,7 +36,8 @@
from econml.sklearn_extensions.linear_model import WeightedLassoCVWrapper, DebiasedLasso
from econml.sklearn_extensions.ensemble import SubsampledHonestForest
from econml._ortho_learner import _OrthoLearner
from econml.cate_estimator import StatsModelsCateEstimatorDiscreteMixin, DebiasedLassoCateEstimatorDiscreteMixin
from econml.cate_estimator import StatsModelsCateEstimatorDiscreteMixin, DebiasedLassoCateEstimatorDiscreteMixin,\
ForestModelFinalCateEstimatorDiscreteMixin
from econml.inference import GenericModelFinalInferenceDiscrete


@@ -949,7 +950,7 @@ def fitted_models_final(self):
return super().model_final.models_cate


class ForestDRLearner(DRLearner):
class ForestDRLearner(ForestModelFinalCateEstimatorDiscreteMixin, DRLearner):
""" Instance of DRLearner with a :class:`~econml.sklearn_extensions.ensemble.SubsampledHonestForest`
as a final model, so as to enable non-parametric inference.
@@ -1144,12 +1145,6 @@ def __init__(self,
categories=categories,
n_splits=n_crossfit_splits, random_state=random_state)

def _get_inference_options(self):
# add statsmodels to parent's options
options = super()._get_inference_options()
options.update(blb=GenericModelFinalInferenceDiscrete)
return options

def fit(self, Y, T, X=None, W=None, sample_weight=None, sample_var=None, inference=None):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·).
93 changes: 66 additions & 27 deletions econml/sklearn_extensions/ensemble.py
@@ -1,7 +1,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

""" Subsampled honest forest extension to scikit-learn's forest methods.
""" Subsampled honest forest extension to scikit-learn's forest methods. Contains pieces of code from
scikit-learn's random forest implementation.
"""

import numpy as np
@@ -57,20 +58,42 @@ def _parallel_add_trees(tree, forest, X, y, sample_weight, s_inds, tree_idx, n_t
path_est = tree.decision_path(X_est)
# Calculate the total weight of estimation samples on each tree node:
# \sum_i sample_weight[i] * 1{i \\in node}
weight_est = scipy.sparse.csr_matrix(
sample_weight_est.reshape(1, -1)).dot(path_est)
weight_est = sample_weight_est.reshape(1, -1) @ path_est
# Calculate the total number of estimation samples on each tree node:
# |node| = \sum_{i} 1{i \\in node}
count_est = np.sum(path_est, axis=0)
count_est = path_est.sum(axis=0)
# Calculate the weighted sum of responses on the estimation sample on each node:
# \sum_{i} sample_weight[i] 1{i \\in node} Y_i
value_est = scipy.sparse.csr_matrix(
(sample_weight_est.reshape(-1, 1) * y_est).T).dot(path_est)
num_est = (sample_weight_est.reshape(-1, 1) * y_est).T @ path_est
# Calculate the predicted value on each node based on the estimation sample:
# weighted sum of responses / total weight
value_est = num_est / weight_est

# Calculate the criterion on each node based on the estimation sample and for each output dimension,
# summing the impurity across dimensions.
# First we calculate the difference of observed label y of each node and predicted value for each
# node that the sample falls in: y[i] - value_est[node]
impurity_est = np.zeros((1, path_est.shape[1]))
for i in range(tree.n_outputs_):
diff = path_est.multiply(y_est[:, [i]]) - path_est.multiply(value_est[[i], :])
if tree.criterion == 'mse':
# If criterion is mse then calculate weighted sum of squared differences for each node
impurity_est_i = sample_weight_est.reshape(1, -1) @ diff.power(2)
elif tree.criterion == 'mae':
# If criterion is mae then calculate weighted sum of absolute differences for each node
impurity_est_i = sample_weight_est.reshape(1, -1) @ np.abs(diff)
else:
raise AttributeError("Criterion {} not yet supported by SubsampledHonestForest!".format(tree.criterion))
# Normalize each weighted sum of criterion for each node by the total weight of each node
impurity_est += impurity_est_i / weight_est

# Prune tree to remove leafs that don't satisfy the leaf requirements on the estimation sample
# and for each un-pruned tree set the value and the weight appropriately.
children_left = tree.tree_.children_left
children_right = tree.tree_.children_right
stack = [(0, -1)] # seed is the root node id and its parent depth
numerator = np.empty_like(tree.tree_.value)
denominator = np.empty_like(tree.tree_.weighted_n_node_samples)
while len(stack) > 0:
node_id, parent_id = stack.pop()
# If minimum weight requirement or minimum leaf size requirement is not satisfied on estimation
@@ -81,17 +104,24 @@ def _parallel_add_trees(tree, forest, X, y, sample_weight, s_inds, tree_idx, n_t
tree.tree_.children_right[parent_id] = -1
else:
for i in range(tree.n_outputs_):
# Set the value of the node to: \sum_{i} sample_weight[i] 1{i \\in node} Y_i / |node|
tree.tree_.value[node_id, i] = value_est[i, node_id] / count_est[0, node_id]
# Set the weight of the node to: \sum_{i} sample_weight[i] 1{i \\in node} / |node|
tree.tree_.weighted_n_node_samples[node_id] = weight_est[0, node_id] / count_est[0, node_id]
# Set the numerator of the node to: \sum_{i} sample_weight[i] 1{i \\in node} Y_i / |node|
numerator[node_id, i] = num_est[i, node_id] / count_est[0, node_id]
# Set the value of the node to:
# \sum_{i} sample_weight[i] 1{i \\in node} Y_i / \sum_{i} sample_weight[i] 1{i \\in node}
tree.tree_.value[node_id, i] = value_est[i, node_id]
# Set the denominator of the node to: \sum_{i} sample_weight[i] 1{i \\in node} / |node|
denominator[node_id] = weight_est[0, node_id] / count_est[0, node_id]
# Set the weight of the node to: \sum_{i} sample_weight[i] 1{i \\in node}
tree.tree_.weighted_n_node_samples[node_id] = weight_est[0, node_id]
# Set the count to the estimation split count
tree.tree_.n_node_samples[node_id] = count_est[0, node_id]
# Set the node impurity to the estimation split impurity
tree.tree_.impurity[node_id] = impurity_est[0, node_id]
if (children_left[node_id] != children_right[node_id]):
stack.append((children_left[node_id], node_id))
stack.append((children_right[node_id], node_id))

return tree
return tree, numerator, denominator


class SubsampledHonestForest(ForestRegressor, RegressorMixin):
@@ -314,7 +344,7 @@ class SubsampledHonestForest(ForestRegressor, RegressorMixin):
>>> regr.fit(X_train, y_train)
SubsampledHonestForest(n_estimators=1000, random_state=0)
>>> regr.feature_importances_
array([0.40..., 0.35..., 0.11..., 0.11...])
array([0.64..., 0.33..., 0.01..., 0.01...])
>>> regr.predict(np.ones((1, 4)))
array([112.9...])
>>> regr.predict_interval(np.ones((1, 4)), alpha=.05)
@@ -483,6 +513,8 @@ def fit(self, X, y, sample_weight=None, sample_var=None):
if not self.warm_start or not hasattr(self, "estimators_"):
# Free allocated memory, if any
self.estimators_ = []
self.numerators_ = []
self.denominators_ = []

n_more_estimators = self.n_estimators - len(self.estimators_)

@@ -523,21 +555,25 @@
int(np.ceil(self.subsample_fr_ *
(X.shape[0] // 2))),
replace=False)])
trees = Parallel(n_jobs=self.n_jobs, verbose=self.verbose,
**_joblib_parallel_args(prefer='threads'))(
res = Parallel(n_jobs=self.n_jobs, verbose=self.verbose,
**_joblib_parallel_args(prefer='threads'))(
delayed(_parallel_add_trees)(
t, self, X, y, sample_weight, s_inds[i], i, len(trees),
verbose=self.verbose)
for i, t in enumerate(trees))
trees, numerators, denominators = zip(*res)
# Collect newly grown trees
self.estimators_.extend(trees)
self.numerators_.extend(numerators)
self.denominators_.extend(denominators)

return self

def _mean_fn(self, X, fn, acc, slice=None):
# Helper class that accumulates an arbitrary function in parallel on the accumulator acc
# and calls the function fn on each tree e and returns the mean output. The function fn
# should take as input a tree e, and return another function g_e, which takes as input X, check_input
# should take as input a tree e and associated numerator n and denominator d structures and
# return another function g_e, which takes as input X, check_input
# If slice is not None, but rather a tuple (start, end), then a subset of the trees from
# index start to index end will be used. The returned result is essentially:
# (mean over e in slice)(g_e(X)).
@@ -546,18 +582,21 @@ def _mean_fn(self, X, fn, acc, slice=None):
X = self._validate_X_predict(X)

if slice is None:
estimator_slice = self.estimators_
estimator_slice = zip(self.estimators_, self.numerators_, self.denominators_)
n_estimators = len(self.estimators_)
else:
estimator_slice = self.estimators_[slice[0]:slice[1]]
estimator_slice = zip(self.estimators_[slice[0]:slice[1]], self.numerators_[slice[0]:slice[1]],
self.denominators_[slice[0]:slice[1]])
n_estimators = slice[1] - slice[0]

# Assign chunk of trees to jobs
n_jobs, _, _ = _partition_estimators(len(estimator_slice), self.n_jobs)
n_jobs, _, _ = _partition_estimators(n_estimators, self.n_jobs)
lock = threading.Lock()
Parallel(n_jobs=n_jobs, verbose=self.verbose,
**_joblib_parallel_args(require="sharedmem"))(
delayed(_accumulate_prediction)(fn(e), X, [acc], lock)
for e in estimator_slice)
acc /= len(estimator_slice)
delayed(_accumulate_prediction)(fn(e, n, d), X, [acc], lock)
for e, n, d in estimator_slice)
acc /= n_estimators
return acc

def _weight(self, X, slice=None):
Expand All @@ -582,7 +621,7 @@ def _weight(self, X, slice=None):
# Check data
X = self._validate_X_predict(X)
weight_hat = np.zeros((X.shape[0]), dtype=np.float64)
return self._mean_fn(X, lambda e: (lambda x, check_input: e.tree_.weighted_n_node_samples[e.apply(x)]),
return self._mean_fn(X, lambda e, n, d: (lambda x, check_input: d[e.apply(x)]),
weight_hat, slice=slice)

def _predict(self, X, slice=None):
@@ -610,12 +649,12 @@ def _predict(self, X, slice=None):
# Check data
X = self._validate_X_predict(X)
# avoid storing the output of every estimator by summing them here
if self.n_outputs_ > 1:
y_hat = np.zeros((X.shape[0], self.n_outputs_), dtype=np.float64)
else:
y_hat = np.zeros((X.shape[0]), dtype=np.float64)
y_hat = np.zeros((X.shape[0], self.n_outputs_), dtype=np.float64)
y_hat = self._mean_fn(X, lambda e, n, d: (lambda x, check_input: n[e.apply(x), :, 0]), y_hat, slice=slice)
if self.n_outputs_ == 1:
y_hat = y_hat.flatten()

return self._mean_fn(X, lambda e: e.predict, y_hat, slice=slice)
return y_hat

def _inference(self, X, stderr=False):
"""
4 changes: 4 additions & 0 deletions econml/tests/test_dml.py
@@ -405,6 +405,10 @@ def make_random(is_discrete, d):
eff = est.effect(X, T0=T0, T1=T)
self.assertEqual(shape(eff), effect_shape)

if isinstance(est, ForestDMLCateEstimator):
np.testing.assert_array_equal(est.feature_importances_.shape,
[X.shape[1]])

if inf is not None:
const_marg_eff_int = est.const_marginal_effect_interval(X)
marg_eff_int = est.marginal_effect_interval(T, X)
5 changes: 5 additions & 0 deletions econml/tests/test_drlearner.py
@@ -685,6 +685,11 @@ def test_drlearner_with_inference_all_attributes(self):
# test summary function works
est.summary(t)

if isinstance(est, ForestDRLearner):
for t in [1, 2]:
np.testing.assert_array_equal(est.feature_importances_(t).shape,
[X.shape[1]])

@staticmethod
def _check_with_interval(truth, point, lower, upper):
np.testing.assert_allclose(point, truth, rtol=0, atol=.2)