[python][sklearn] add __sklearn_is_fitted__() method to be better compatible with scikit-learn API #4636

Merged
merged 2 commits into from
Oct 5, 2021
3 changes: 3 additions & 0 deletions python-package/lightgbm/sklearn.py
@@ -529,6 +529,9 @@ def _more_tags(self):
            }
        }

    def __sklearn_is_fitted__(self) -> bool:
        return getattr(self, "fitted_", False)
Collaborator Author

We set the fitted_ attribute at the end of the fit() method:

self.fitted_ = True
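
To illustrate how the two pieces fit together, here is a minimal, self-contained sketch. `ToyEstimator` is a hypothetical stand-in, not LightGBM's actual class; it only mirrors the pattern in this diff: `fit()` sets `fitted_` as its last step, and `__sklearn_is_fitted__()` reads it back with a safe default.

```python
class ToyEstimator:
    """Hypothetical stand-in for an estimator; not LightGBM's real class."""

    def fit(self, X, y):
        # ... training would happen here ...
        # Set the flag last, so a fit() that raises leaves the estimator unfitted.
        self.fitted_ = True
        return self

    def __sklearn_is_fitted__(self) -> bool:
        # getattr with a default avoids AttributeError before fit() has run.
        return getattr(self, "fitted_", False)


est = ToyEstimator()
before = est.__sklearn_is_fitted__()
after = est.fit([[0.0], [1.0]], [0, 1]).__sklearn_is_fitted__()
print(before, after)
```

As far as I know, recent scikit-learn versions let `sklearn.utils.validation.check_is_fitted` consult this method when an estimator defines it, which is what makes the hook useful beyond our own code.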

Collaborator

@jameslamb jameslamb Oct 2, 2021

I noticed that in dmlc/xgboost#7230 they also replaced some other internal checks of estimator attributes with this method call.

I think that's a good idea, so only this method needs to know about the fitted_ property. What do you think?

That would mean changing the following to if not self.__sklearn_is_fitted__():

if not getattr(self, "fitted_", False):
    raise LGBMNotFittedError('Cannot access property client_ before calling fit().')
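
A rough sketch of what the suggested change could look like, using hypothetical stand-ins (`ToyNotFittedError`, `ToyEstimator`, and the `_client` attribute are illustrative names, not the real LightGBM code). The point is that the property guard calls the method, so only the method knows about `fitted_`.

```python
class ToyNotFittedError(ValueError):
    """Hypothetical stand-in for LGBMNotFittedError."""


class ToyEstimator:
    """Hypothetical stand-in for the estimator that exposes client_."""

    def __init__(self):
        self._client = None

    def fit(self, X, y):
        self._client = "client-handle"  # placeholder value for illustration
        self.fitted_ = True
        return self

    def __sklearn_is_fitted__(self) -> bool:
        # The only place that needs to know about the fitted_ attribute.
        return getattr(self, "fitted_", False)

    @property
    def client_(self):
        # Guard goes through the method instead of reading fitted_ directly.
        if not self.__sklearn_is_fitted__():
            raise ToyNotFittedError('Cannot access property client_ before calling fit().')
        return self._client


est = ToyEstimator()
try:
    est.client_
    raised = False
except ToyNotFittedError:
    raised = True

client = est.fit([[0.0]], [0]).client_
```

If the fitted state is ever tracked differently, only `__sklearn_is_fitted__()` would need to change.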

Collaborator Author

I agree. I remember you asking why we use if self._n_features is None everywhere to check that the estimator is fitted when we have self.fitted_.

if self._n_features is None:

#3883 (comment)
My answer was that it's because self.fitted_ was introduced much later. Now we can replace everything with self.__sklearn_is_fitted__(). And I think it will be clearer.

But I'd prefer to do it in a follow-up PR, if you do not mind.

Collaborator

ok yep, I'm fine with it being a followup


    def get_params(self, deep=True):
        """Get parameters for this estimator.
