Skip to content

Commit

Permalink
add 'auto' value for importance_type param in plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Aug 29, 2021
1 parent ee5636f commit 569fcc8
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
12 changes: 9 additions & 3 deletions python-package/lightgbm/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def plot_importance(
title: Optional[str] = 'Feature importance',
xlabel: Optional[str] = 'Feature importance',
ylabel: Optional[str] = 'Features',
importance_type: str = 'split',
importance_type: str = 'auto',
max_num_features: Optional[int] = None,
ignore_zero: bool = True,
figsize: Optional[Tuple[float, float]] = None,
Expand Down Expand Up @@ -65,8 +65,9 @@ def plot_importance(
ylabel : str or None, optional (default="Features")
Y-axis title label.
If None, Y-axis title is disabled.
importance_type : str, optional (default="split")
importance_type : str, optional (default="auto")
How the importance is calculated.
If "auto": when the ``booster`` parameter is an LGBMModel, its ``booster.importance_type`` attribute is used; otherwise "split" is used.
If "split", result contains numbers of times the feature is used in a model.
If "gain", result contains total gains of splits which use the feature.
max_num_features : int or None, optional (default=None)
Expand Down Expand Up @@ -96,8 +97,13 @@ def plot_importance(
raise ImportError('You must install matplotlib and restart your session to plot importance.')

if isinstance(booster, LGBMModel):
if importance_type == "auto":
importance_type = booster.importance_type
booster = booster.booster_
elif not isinstance(booster, Booster):
elif isinstance(booster, Booster):
if importance_type == "auto":
importance_type = "split"
else:
raise TypeError('booster must be Booster or LGBMModel.')

importance = booster.feature_importance(importance_type=importance_type)
Expand Down
22 changes: 20 additions & 2 deletions tests/python_package_test/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ def test_plot_importance(params, breast_cancer_split, train_data):
for patch in ax1.patches:
assert patch.get_facecolor() == (1., 0, 0, 1.) # red

ax2 = lgb.plot_importance(gbm0, color=['r', 'y', 'g', 'b'],
title=None, xlabel=None, ylabel=None)
ax2 = lgb.plot_importance(gbm0, color=['r', 'y', 'g', 'b'], title=None, xlabel=None, ylabel=None)
assert isinstance(ax2, matplotlib.axes.Axes)
assert ax2.get_title() == ''
assert ax2.get_xlabel() == ''
Expand All @@ -69,6 +68,25 @@ def test_plot_importance(params, breast_cancer_split, train_data):
assert ax2.patches[2].get_facecolor() == (0, .5, 0, 1.) # g
assert ax2.patches[3].get_facecolor() == (0, 0, 1., 1.) # b

gbm2 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True, importance_type="gain")
gbm2.fit(X_train, y_train)

def get_bounds_of_first_patch(axes):
return axes.patches[0].get_extents().bounds

first_bar1 = get_bounds_of_first_patch(lgb.plot_importance(gbm1))
first_bar2 = get_bounds_of_first_patch(lgb.plot_importance(gbm1, importance_type="split"))
first_bar3 = get_bounds_of_first_patch(lgb.plot_importance(gbm1, importance_type="gain"))
first_bar4 = get_bounds_of_first_patch(lgb.plot_importance(gbm2))
first_bar5 = get_bounds_of_first_patch(lgb.plot_importance(gbm2, importance_type="split"))
first_bar6 = get_bounds_of_first_patch(lgb.plot_importance(gbm2, importance_type="gain"))

assert first_bar1 == first_bar2
assert first_bar1 == first_bar5
assert first_bar3 == first_bar4
assert first_bar3 == first_bar6
assert first_bar1 != first_bar3


@pytest.mark.skipif(not MATPLOTLIB_INSTALLED, reason='matplotlib is not installed')
def test_plot_split_value_histogram(params, breast_cancer_split, train_data):
Expand Down

0 comments on commit 569fcc8

Please sign in to comment.