diff --git a/docs/Python-Intro.rst b/docs/Python-Intro.rst index d5e89d153888..77147c0857d6 100644 --- a/docs/Python-Intro.rst +++ b/docs/Python-Intro.rst @@ -36,7 +36,7 @@ The LightGBM Python module can load data from: - libsvm/tsv/csv/txt format file -- NumPy 2D array(s), pandas DataFrame, SciPy sparse matrix +- NumPy 2D array(s), pandas DataFrame, H2O DataTable, SciPy sparse matrix - LightGBM binary file diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 9d09cc392e02..7db64e7125cf 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -13,7 +13,7 @@ import numpy as np import scipy.sparse -from .compat import (DataFrame, Series, +from .compat import (DataFrame, Series, DataTable, decode_string, string_type, integer_types, numeric_types, json, json_default_with_numpy, @@ -409,7 +409,7 @@ def predict(self, data, num_iteration=-1, Parameters ---------- - data : string, numpy array, pandas DataFrame or scipy.sparse + data : string, numpy array, pandas DataFrame, H2O DataTable or scipy.sparse Data source for prediction. When data type is string, it represents the path of txt file. num_iteration : int, optional (default=-1) @@ -471,6 +471,8 @@ def predict(self, data, num_iteration=-1, except BaseException: raise ValueError('Cannot convert data list to numpy array.') preds, nrow = self.__pred_for_np2d(data, num_iteration, predict_type) + elif isinstance(data, DataTable): + preds, nrow = self.__pred_for_np2d(data.to_numpy(), num_iteration, predict_type) else: try: warnings.warn('Converting data to scipy sparse matrix.') @@ -650,7 +652,7 @@ def __init__(self, data, label=None, reference=None, Parameters ---------- - data : string, numpy array, pandas DataFrame, scipy.sparse or list of numpy arrays + data : string, numpy array, pandas DataFrame, H2O DataTable, scipy.sparse or list of numpy arrays Data source of Dataset. If string, it represents the path to txt file. label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None) @@ -789,6 +791,8 @@ def _lazy_init(self, data, label=None, reference=None, self.__init_from_np2d(data, params_str, ref_dataset) elif isinstance(data, list) and len(data) > 0 and all(isinstance(x, np.ndarray) for x in data): self.__init_from_list_np2d(data, params_str, ref_dataset) + elif isinstance(data, DataTable): + self.__init_from_np2d(data.to_numpy(), params_str, ref_dataset) else: try: csr = scipy.sparse.csr_matrix(data) @@ -1005,7 +1009,7 @@ def create_valid(self, data, label=None, weight=None, group=None, Parameters ---------- - data : string, numpy array, pandas DataFrame, scipy.sparse or list of numpy arrays + data : string, numpy array, pandas DataFrame, H2O DataTable, scipy.sparse or list of numpy arrays Data source of Dataset. If string, it represents the path to txt file. label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None) @@ -1395,7 +1399,7 @@ def get_data(self): Returns ------- - data : string, numpy array, pandas DataFrame, scipy.sparse, list of numpy arrays or None + data : string, numpy array, pandas DataFrame, H2O DataTable, scipy.sparse, list of numpy arrays or None Raw data used in the Dataset construction. """ if self.handle is None: @@ -1405,6 +1409,8 @@ def get_data(self): self.data = self.data[self.used_indices, :] elif isinstance(self.data, DataFrame): self.data = self.data.iloc[self.used_indices].copy() + elif isinstance(self.data, DataTable): + self.data = self.data[self.used_indices, :] else: warnings.warn("Cannot subset {} type of raw data.\n" "Returning original raw data".format(type(self.data).__name__)) @@ -2156,7 +2162,7 @@ def predict(self, data, num_iteration=None, Parameters ---------- - data : string, numpy array, pandas DataFrame or scipy.sparse + data : string, numpy array, pandas DataFrame, H2O DataTable or scipy.sparse Data source for prediction. If string, it represents the path to txt file. num_iteration : int or None, optional (default=None) @@ -2201,7 +2207,7 @@ def refit(self, data, label, decay_rate=0.9, **kwargs): Parameters ---------- - data : string, numpy array, pandas DataFrame or scipy.sparse + data : string, numpy array, pandas DataFrame, H2O DataTable or scipy.sparse Data source for refit. If string, it represents the path to txt file. label : list, numpy 1-D array or pandas Series / one-column DataFrame diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py index 74d79fe83912..798aea125748 100644 --- a/python-package/lightgbm/compat.py +++ b/python-package/lightgbm/compat.py @@ -90,6 +90,19 @@ class DataFrame(object): except ImportError: GRAPHVIZ_INSTALLED = False +"""datatable""" +try: + from datatable import DataTable + DATATABLE_INSTALLED = True +except ImportError: + DATATABLE_INSTALLED = False + + class DataTable(object): + """Dummy class for DataTable.""" + + pass + + """sklearn""" try: from sklearn.base import BaseEstimator diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index ed023e8ed5ca..f9b1697e256d 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -11,7 +11,7 @@ LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckConsistentLength, _LGBMAssertAllFinite, _LGBMCheckClassificationTargets, _LGBMComputeSampleWeight, - argc_, range_, string_type, DataFrame) + argc_, range_, string_type, DataFrame, DataTable) from .engine import train @@ -479,7 +479,7 @@ def fit(self, X, y, eval_metric = [eval_metric] if isinstance(eval_metric, (string_type, type(None))) else eval_metric params['metric'] = set(original_metric + eval_metric) - if not isinstance(X, DataFrame): + if not isinstance(X, (DataFrame, DataTable)): _X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2) _LGBMCheckConsistentLength(_X, _y, sample_weight) else: @@ -595,7 +595,7 @@ def predict(self, X, raw_score=False, num_iteration=None, """ if self._n_features is None: raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.") - if not isinstance(X, DataFrame): + if not isinstance(X, (DataFrame, DataTable)): X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False) n_features = X.shape[1] if self._n_features != n_features: