Skip to content

Commit

Permalink
Save Scikit-Learn attributes into learner attributes. (#5245)
Browse files Browse the repository at this point in the history
* Remove the recommendation for pickle.

* Save skl attributes in booster.attr

* Test loading scikit-learn model with native booster.
  • Loading branch information
trivialfis authored Jan 30, 2020
1 parent c671632 commit 472ded5
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 57 deletions.
8 changes: 8 additions & 0 deletions doc/tutorials/saving_model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ again after the model is loaded. If the customized function is useful, please co
making a PR for implementing it inside XGBoost, this way we can have your functions
working with different language bindings.

******************************************************
Loading pickled file from different version of XGBoost
******************************************************

As noted, a pickled model is neither portable nor stable, but in some cases pickled
models are valuable. One way to restore such a model in the future is to load it back with
the same versions of Python and XGBoost that created it, then export the model by calling
`save_model`.

********************************************************
Saving and Loading the internal parameters configuration
********************************************************
Expand Down
29 changes: 26 additions & 3 deletions python-package/xgboost/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import abc
import os
import sys

from pathlib import PurePath

import numpy as np

assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'

# pylint: disable=invalid-name, redefined-builtin
Expand Down Expand Up @@ -148,7 +149,29 @@ class DataTable(object):

XGBKFold = KFold
XGBStratifiedKFold = StratifiedKFold
XGBLabelEncoder = LabelEncoder

class XGBoostLabelEncoder(LabelEncoder):
    '''Label encoder whose state can be exported to and restored from JSON.'''

    def to_json(self):
        '''Export the encoder state as a JSON compatible dictionary.

        Every instance attribute is copied into the result; numpy arrays are
        converted to plain lists because they are not JSON serializable.
        '''
        return {
            name: value.tolist() if isinstance(value, np.ndarray) else value
            for name, value in self.__dict__.items()
        }

    def from_json(self, doc):
        # pylint: disable=attribute-defined-outside-init
        '''Restore the encoder state from a JSON compatible dictionary.

        ``classes_`` is converted back into a numpy array; all other entries
        of ``doc`` are written into the instance dictionary unchanged.
        '''
        remaining = {}
        for name, value in doc.items():
            if name != 'classes_':
                remaining[name] = value
            else:
                self.classes_ = np.array(value)
        self.__dict__.update(remaining)
except ImportError:
SKLEARN_INSTALLED = False

Expand All @@ -159,7 +182,7 @@ class DataTable(object):

XGBKFold = None
XGBStratifiedKFold = None
XGBLabelEncoder = None
XGBoostLabelEncoder = None


# dask
Expand Down
118 changes: 83 additions & 35 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# Do not use class names on scikit-learn directly. Re-define the classes on
# .compat to guarantee the behavior without scikit-learn
from .compat import (SKLEARN_INSTALLED, XGBModelBase,
XGBClassifierBase, XGBRegressorBase, XGBLabelEncoder)
XGBClassifierBase, XGBRegressorBase, XGBoostLabelEncoder)


def _objective_decorator(func):
Expand Down Expand Up @@ -330,54 +330,96 @@ def get_num_boosting_rounds(self):
"""Gets the number of xgboost boosting rounds."""
return self.n_estimators

def save_model(self, fname):
"""
Save the model to a file.
def save_model(self, fname: str):
"""Save the model to a file.
The model is saved in an XGBoost internal format which is universal
among the various XGBoost interfaces. Auxiliary attributes of the
Python Booster object (such as feature names) will not be saved.
.. note::
See:
The model is saved in an XGBoost internal binary format which is
universal among the various XGBoost interfaces. Auxiliary attributes of
the Python Booster object (such as feature names) will not be loaded.
Label encodings (text labels to numeric labels) will be also lost.
**If you are using only the Python interface, we recommend pickling the
model object for best results.**
https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
Parameters
----------
fname : string
Output file name
"""
warnings.warn("save_model: Useful attributes in the Python " +
"object {} will be lost. ".format(type(self).__name__) +
"If you did not mean to export the model to " +
"a non-Python binding of XGBoost, consider " +
"using `pickle` or `joblib` to save your model.",
Warning)
meta = dict()
for k, v in self.__dict__.items():
if k == '_le':
meta['_le'] = self._le.to_json()
continue
if k == '_Booster':
continue
if k == 'classes_':
# numpy array is not JSON serializable
meta['classes_'] = self.classes_.tolist()
continue
try:
json.dumps({k: v})
meta[k] = v
except TypeError:
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
meta['type'] = type(self).__name__
meta = json.dumps(meta)
self.get_booster().set_attr(scikit_learn=meta)
self.get_booster().save_model(fname)
# Delete the attribute after save
self.get_booster().set_attr(scikit_learn=None)

def load_model(self, fname):
"""
Load the model from a file.
# pylint: disable=attribute-defined-outside-init
"""Load the model from a file.
The model is loaded from an XGBoost internal binary format which is
universal among the various XGBoost interfaces. Auxiliary attributes of
the Python Booster object (such as feature names) will not be loaded.
Label encodings (text labels to numeric labels) will be also lost.
**If you are using only the Python interface, we recommend pickling the
model object for best results.**
The model is loaded from an XGBoost internal format which is universal
among the various XGBoost interfaces. Auxiliary attributes of the
Python Booster object (such as feature names) will not be loaded.
Parameters
----------
fname : string or a memory buffer
Input file name or memory buffer(see also save_raw)
fname : string
Input file name.
"""
if self._Booster is None:
self._Booster = Booster({'n_jobs': self.n_jobs})
self._Booster.load_model(fname)
meta = self._Booster.attr('scikit_learn')
if meta is None:
warnings.warn(
'Loading a native XGBoost model with Scikit-Learn interface.')
return
meta = json.loads(meta)
states = dict()
for k, v in meta.items():
if k == '_le':
self._le = XGBoostLabelEncoder()
self._le.from_json(v)
continue
if k == 'classes_':
self.classes_ = np.array(v)
continue
if k == 'type' and type(self).__name__ != v:
msg = f'Current model type: {type(self).__name__}, ' + \
f'type of model in file: {v}'
raise TypeError(msg)
if k == 'type':
continue
states[k] = v
self.__dict__.update(states)
# Delete the attribute after load
self.get_booster().set_attr(scikit_learn=None)

def fit(self, X, y, sample_weight=None, base_margin=None,
eval_set=None, eval_metric=None, early_stopping_rounds=None,
verbose=True, xgb_model=None, sample_weight_eval_set=None, callbacks=None):
# pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
verbose=True, xgb_model=None, sample_weight_eval_set=None,
callbacks=None):
# pylint: disable=invalid-name,attribute-defined-outside-init
"""Fit gradient boosting model
Parameters
Expand Down Expand Up @@ -678,7 +720,7 @@ def intercept_(self):
"Implementation of the scikit-learn API for XGBoost classification.",
['model', 'objective'])
class XGBClassifier(XGBModel, XGBClassifierBase):
# pylint: disable=missing-docstring,too-many-arguments,invalid-name,too-many-instance-attributes
# pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
def __init__(self, objective="binary:logistic", **kwargs):
super().__init__(objective=objective, **kwargs)

Expand Down Expand Up @@ -714,7 +756,7 @@ def fit(self, X, y, sample_weight=None, base_margin=None,
else:
xgb_options.update({"eval_metric": eval_metric})

self._le = XGBLabelEncoder().fit(y)
self._le = XGBoostLabelEncoder().fit(y)
training_labels = self._le.transform(y)

if eval_set is not None:
Expand Down Expand Up @@ -809,10 +851,11 @@ def predict(self, data, output_margin=False, ntree_limit=None,
missing=self.missing, nthread=self.n_jobs)
if ntree_limit is None:
ntree_limit = getattr(self, "best_ntree_limit", 0)
class_probs = self.get_booster().predict(test_dmatrix,
output_margin=output_margin,
ntree_limit=ntree_limit,
validate_features=validate_features)
class_probs = self.get_booster().predict(
test_dmatrix,
output_margin=output_margin,
ntree_limit=ntree_limit,
validate_features=validate_features)
if output_margin:
# If output_margin is active, simply return the scores
return class_probs
Expand All @@ -822,7 +865,12 @@ def predict(self, data, output_margin=False, ntree_limit=None,
else:
column_indexes = np.repeat(0, class_probs.shape[0])
column_indexes[class_probs > 0.5] = 1
return self._le.inverse_transform(column_indexes)

if hasattr(self, '_le'):
return self._le.inverse_transform(column_indexes)
warnings.warn(
'Label encoder is not defined. Returning class probability.')
return class_probs

def predict_proba(self, data, ntree_limit=None, validate_features=True,
base_margin=None):
Expand Down
2 changes: 1 addition & 1 deletion src/common/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class PeekableInStream : public dmlc::Stream {
class FixedSizeStream : public PeekableInStream {
public:
explicit FixedSizeStream(PeekableInStream* stream);
~FixedSizeStream() = default;
~FixedSizeStream() override = default;

size_t Read(void* dptr, size_t size) override;
size_t PeekRead(void* dptr, size_t size) override;
Expand Down
12 changes: 11 additions & 1 deletion src/common/json.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/*!
* Copyright (c) by Contributors 2019
*/
#include <cctype>
#include <sstream>
#include <limits>
#include <cmath>
Expand Down Expand Up @@ -351,7 +352,9 @@ Json JsonReader::Parse() {
return ParseObject();
} else if ( c == '[' ) {
return ParseArray();
} else if ( c == '-' || std::isdigit(c) ) {
} else if ( c == '-' || std::isdigit(c) ||
c == 'N' ) {
// For now we only accept `NaN`, not `nan`, as the latter violates LR(1) with `null`.
return ParseNumber();
} else if ( c == '\"' ) {
return ParseString();
Expand Down Expand Up @@ -547,6 +550,13 @@ Json JsonReader::ParseNumber() {

// TODO(trivialfis): Add back all the checks for number
bool negative = false;
if (XGBOOST_EXPECT(*p == 'N', false)) {
GetChar('N');
GetChar('a');
GetChar('N');
return Json(static_cast<Number::Float>(std::numeric_limits<float>::quiet_NaN()));
}

if ('-' == *p) {
++p;
negative = true;
Expand Down
10 changes: 5 additions & 5 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -661,13 +661,13 @@ class LearnerImpl : public Learner {
CHECK(header == serialisation_header_) // NOLINT
<< R"doc(
If you are loading a serialized model (like pickle in Python) generated by older XGBoost,
please export the model by calling `Booster.save_model` from that version first, then load
it back in current version. See:
If you are loading a serialized model (like pickle in Python) generated by older
XGBoost, please export the model by calling `Booster.save_model` from that version
first, then load it back in current version. See:
https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
for more details about differences between saving model and serializing.
for more details about differences between saving model and serializing.
)doc";
int64_t json_offset {-1};
Expand Down
3 changes: 2 additions & 1 deletion tests/python/test_basic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def json_model(model_path, parameters):
class TestModels(unittest.TestCase):
def test_glm(self):
param = {'verbosity': 0, 'objective': 'binary:logistic',
'booster': 'gblinear', 'alpha': 0.0001, 'lambda': 1, 'nthread': 1}
'booster': 'gblinear', 'alpha': 0.0001, 'lambda': 1,
'nthread': 1}
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 4
bst = xgb.train(param, dtrain, num_round, watchlist)
Expand Down
Loading

0 comments on commit 472ded5

Please sign in to comment.