Skip to content

Commit

Permalink
use __save_model_to_string for feature importance
Browse files Browse the repository at this point in the history
  • Loading branch information
wxchan committed Jun 13, 2017
1 parent 760e9eb commit dca6a85
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 15 deletions.
37 changes: 23 additions & 14 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1661,20 +1661,29 @@ def feature_importance(self, importance_type='split'):
"""
if importance_type not in ["split", "gain"]:
raise KeyError("importance_type must be split or gain")
dump_model = self.dump_model()
ret = [0] * (dump_model["max_feature_idx"] + 1)

def dfs(root):
if "split_feature" in root:
if root['split_gain'] > 0:
if importance_type == 'split':
ret[root["split_feature"]] += 1
elif importance_type == 'gain':
ret[root["split_feature"]] += root["split_gain"]
dfs(root["left_child"])
dfs(root["right_child"])
for tree in dump_model["tree_info"]:
dfs(tree["tree_structure"])
dump_model = self.__save_model_to_string()
type_is_gain = importance_type == 'gain'
split_feature = []
split_gain = []
max_feature_idx = None
for line in dump_model.split('\n'):
if line.startswith('split_feature='):
tokens = [int(token) for token in line[len('split_feature='):].split()]
split_feature.append(tokens)
elif type_is_gain and line.startswith('split_gain='):
tokens = [float(token) for token in line[len('split_gain='):].split()]
split_gain.append(tokens)
elif max_feature_idx is None and line.startswith('max_feature_idx='):
max_feature_idx = int(line[len('max_feature_idx='):])
ret = [0] * (max_feature_idx + 1)
if type_is_gain:
for features, gains in zip(split_feature, split_gain):
for feature, gain in zip(features, gains):
ret[feature] += gain
else:
for features in split_feature:
for feature in features:
ret[feature] += 1
return np.array(ret)

def __inner_eval(self, data_name, data_idx, feval=None):
Expand Down
2 changes: 1 addition & 1 deletion src/io/tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ int Tree::Split(int leaf, int feature, BinType bin_type, uint32_t threshold_bin,
}

threshold_in_bin_[new_node_idx] = threshold_bin;
threshold_[new_node_idx] = threshold_double;
threshold_[new_node_idx] = Common::AvoidInf(threshold_double);
split_gain_[new_node_idx] = Common::AvoidInf(gain);
// add two new leaves
left_child_[new_node_idx] = ~leaf;
Expand Down

0 comments on commit dca6a85

Please sign in to comment.