diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 3aa100f6e304..3b8c73fc9d54 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -5,6 +5,11 @@ import numpy as np import warnings +try: + import pandas as pd + _IS_PANDAS_INSTALLED = True +except ImportError: + _IS_PANDAS_INSTALLED = False from .basic import Dataset, LightGBMError from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase, LGBMDeprecated, @@ -332,7 +337,7 @@ def fit(self, X, y, categorical_feature : list of strings or int, or 'auto', optional (default="auto") Categorical features. If list of int, interpreted as indices. - If list of strings, interpreted as feature names (need to specify feature_name as well). + If list of strings, interpreted as feature names (need to specify ``feature_name`` as well). If 'auto' and data is pandas DataFrame, pandas categorical columns are used. callbacks : list of callback functions or None, optional (default=None) List of callback functions that are applied at each iteration. @@ -407,8 +412,10 @@ def fit(self, X, y, feval = None params['metric'] = eval_metric - X, y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2) - _LGBMCheckConsistentLength(X, y, sample_weight) + if not _IS_PANDAS_INSTALLED or not isinstance(X, pd.DataFrame): + X, y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2) + _LGBMCheckConsistentLength(X, y, sample_weight) + self._n_features = X.shape[1] def _construct_dataset(X, y, sample_weight, init_score, group, params): @@ -482,7 +489,8 @@ def predict(self, X, raw_score=False, num_iteration=0): """ if self._n_features is None: raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.") - X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False) + if not _IS_PANDAS_INSTALLED or not isinstance(X, pd.DataFrame): + X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False) n_features = X.shape[1] if self._n_features != n_features: raise ValueError("Number of features of the model must " @@ -508,7 +516,8 @@ def apply(self, X, num_iteration=0): """ if self._n_features is None: raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.") - X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False) + if not _IS_PANDAS_INSTALLED or not isinstance(X, pd.DataFrame): + X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False) n_features = X.shape[1] if self._n_features != n_features: raise ValueError("Number of features of the model must " @@ -686,7 +695,8 @@ def predict_proba(self, X, raw_score=False, num_iteration=0): """ if self._n_features is None: raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.") - X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False) + if not _IS_PANDAS_INSTALLED or not isinstance(X, pd.DataFrame): + X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False) n_features = X.shape[1] if self._n_features != n_features: raise ValueError("Number of features of the model must " diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index e38854089b5c..d9e7cfd3a024 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -78,7 +78,7 @@ def test_rf(self): self.assertLess(ret, 0.25) self.assertAlmostEqual(evals_result['valid_0']['binary_logloss'][-1], ret, places=5) - def test_regreesion(self): + def test_regression(self): X, y = load_boston(True) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) params = { @@ -444,7 +444,6 @@ def test_pandas_categorical(self): gbm3 = lgb.train(params, lgb_train, num_boost_round=10, verbose_eval=False, categorical_feature=['A', 'B', 'C', 'D']) pred3 = list(gbm3.predict(X_test)) - lgb_train = lgb.Dataset(X, y) gbm3.save_model('categorical.model') gbm4 = lgb.Booster(model_file='categorical.model') pred4 = list(gbm4.predict(X_test)) diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index 0ee655d6799a..606395bb5053 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -19,6 +19,11 @@ sklearn_at_least_019 = True except ImportError: sklearn_at_least_019 = False +try: + import pandas as pd + IS_PANDAS_INSTALLED = True +except ImportError: + IS_PANDAS_INSTALLED = False def multi_error(y_true, y_pred): @@ -40,7 +45,7 @@ def test_binary(self): self.assertLess(ret, 0.15) self.assertAlmostEqual(ret, gbm.evals_result_['valid_0']['binary_logloss'][gbm.best_iteration_ - 1], places=5) - def test_regreesion(self): + def test_regression(self): X, y = load_boston(True) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) gbm = lgb.LGBMRegressor(n_estimators=50, silent=True) @@ -194,3 +199,34 @@ def test_sklearn_integration(self): check(name, estimator) except SkipTest as message: warnings.warn(message, SkipTestWarning) + + @unittest.skipIf(not IS_PANDAS_INSTALLED, 'pandas not installed') + def test_pandas_categorical(self): + X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), # str + "B": np.random.permutation([1, 2, 3] * 100), # int + "C": np.random.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60), # float + "D": np.random.permutation([True, False] * 150)}) # bool + y = np.random.permutation([0, 1] * 150) + X_test = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'e'] * 20), + "B": np.random.permutation([1, 3] * 30), + "C": np.random.permutation([0.1, -0.1, 0.2, 0.2] * 15), + "D": np.random.permutation([True, False] * 30)}) + for col in ["A", "B", "C", "D"]: + X[col] = X[col].astype('category') + X_test[col] = X_test[col].astype('category') + gbm0 = lgb.sklearn.LGBMClassifier().fit(X, y) + pred0 = list(gbm0.predict(X_test)) + gbm1 = lgb.sklearn.LGBMClassifier().fit(X, y, categorical_feature=[0]) + pred1 = list(gbm1.predict(X_test)) + gbm2 = lgb.sklearn.LGBMClassifier().fit(X, y, categorical_feature=['A']) + pred2 = list(gbm2.predict(X_test)) + gbm3 = lgb.sklearn.LGBMClassifier().fit(X, y, categorical_feature=['A', 'B', 'C', 'D']) + pred3 = list(gbm3.predict(X_test)) + gbm3.booster_.save_model('categorical.model') + gbm4 = lgb.Booster(model_file='categorical.model') + pred4 = list(gbm4.predict(X_test)) + pred_prob = list(gbm0.predict_proba(X_test)[:, 1]) + np.testing.assert_almost_equal(pred0, pred1) + np.testing.assert_almost_equal(pred0, pred2) + np.testing.assert_almost_equal(pred0, pred3) + np.testing.assert_almost_equal(pred_prob, pred4)