Skip to content

Commit

Permalink
[python] convert datatable to numpy directly (#1970)
Browse files (browse the repository at this point in the history)
* convert datatable to numpy directly

* fix according to comments

* updated more docstrings

* simplified isinstance check

* Update compat.py
  • Loading branch information
guolinke authored Feb 4, 2019
1 parent 107b50b commit 2c9d332
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 11 deletions.
2 changes: 1 addition & 1 deletion docs/Python-Intro.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ The LightGBM Python module can load data from:

- libsvm/tsv/csv/txt format file

- NumPy 2D array(s), pandas DataFrame, SciPy sparse matrix
- NumPy 2D array(s), pandas DataFrame, H2O DataTable, SciPy sparse matrix

- LightGBM binary file

Expand Down
20 changes: 13 additions & 7 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
import scipy.sparse

from .compat import (DataFrame, Series,
from .compat import (DataFrame, Series, DataTable,
decode_string, string_type,
integer_types, numeric_types,
json, json_default_with_numpy,
Expand Down Expand Up @@ -409,7 +409,7 @@ def predict(self, data, num_iteration=-1,
Parameters
----------
data : string, numpy array, pandas DataFrame or scipy.sparse
data : string, numpy array, pandas DataFrame, H2O DataTable or scipy.sparse
Data source for prediction.
When data type is string, it represents the path of txt file.
num_iteration : int, optional (default=-1)
Expand Down Expand Up @@ -471,6 +471,8 @@ def predict(self, data, num_iteration=-1,
except BaseException:
raise ValueError('Cannot convert data list to numpy array.')
preds, nrow = self.__pred_for_np2d(data, num_iteration, predict_type)
elif isinstance(data, DataTable):
preds, nrow = self.__pred_for_np2d(data.to_numpy(), num_iteration, predict_type)
else:
try:
warnings.warn('Converting data to scipy sparse matrix.')
Expand Down Expand Up @@ -650,7 +652,7 @@ def __init__(self, data, label=None, reference=None,
Parameters
----------
data : string, numpy array, pandas DataFrame, scipy.sparse or list of numpy arrays
data : string, numpy array, pandas DataFrame, H2O DataTable, scipy.sparse or list of numpy arrays
Data source of Dataset.
If string, it represents the path to txt file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
Expand Down Expand Up @@ -789,6 +791,8 @@ def _lazy_init(self, data, label=None, reference=None,
self.__init_from_np2d(data, params_str, ref_dataset)
elif isinstance(data, list) and len(data) > 0 and all(isinstance(x, np.ndarray) for x in data):
self.__init_from_list_np2d(data, params_str, ref_dataset)
elif isinstance(data, DataTable):
self.__init_from_np2d(data.to_numpy(), params_str, ref_dataset)
else:
try:
csr = scipy.sparse.csr_matrix(data)
Expand Down Expand Up @@ -1005,7 +1009,7 @@ def create_valid(self, data, label=None, weight=None, group=None,
Parameters
----------
data : string, numpy array, pandas DataFrame, scipy.sparse or list of numpy arrays
data : string, numpy array, pandas DataFrame, H2O DataTable, scipy.sparse or list of numpy arrays
Data source of Dataset.
If string, it represents the path to txt file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
Expand Down Expand Up @@ -1395,7 +1399,7 @@ def get_data(self):
Returns
-------
data : string, numpy array, pandas DataFrame, scipy.sparse, list of numpy arrays or None
data : string, numpy array, pandas DataFrame, H2O DataTable, scipy.sparse, list of numpy arrays or None
Raw data used in the Dataset construction.
"""
if self.handle is None:
Expand All @@ -1405,6 +1409,8 @@ def get_data(self):
self.data = self.data[self.used_indices, :]
elif isinstance(self.data, DataFrame):
self.data = self.data.iloc[self.used_indices].copy()
elif isinstance(self.data, DataTable):
self.data = self.data[self.used_indices, :]
else:
warnings.warn("Cannot subset {} type of raw data.\n"
"Returning original raw data".format(type(self.data).__name__))
Expand Down Expand Up @@ -2156,7 +2162,7 @@ def predict(self, data, num_iteration=None,
Parameters
----------
data : string, numpy array, pandas DataFrame or scipy.sparse
data : string, numpy array, pandas DataFrame, H2O DataTable or scipy.sparse
Data source for prediction.
If string, it represents the path to txt file.
num_iteration : int or None, optional (default=None)
Expand Down Expand Up @@ -2201,7 +2207,7 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
Parameters
----------
data : string, numpy array, pandas DataFrame or scipy.sparse
data : string, numpy array, pandas DataFrame, H2O DataTable or scipy.sparse
Data source for refit.
If string, it represents the path to txt file.
label : list, numpy 1-D array or pandas Series / one-column DataFrame
Expand Down
13 changes: 13 additions & 0 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,19 @@ class DataFrame(object):
except ImportError:
GRAPHVIZ_INSTALLED = False

"""datatable"""
try:
    # The datatable package has exposed its frame class under two names:
    # newer releases call it ``Frame``, older ones ``DataTable``. Try the
    # newer name first so an installed, up-to-date datatable is not
    # mistaken for a missing one.
    try:
        from datatable import Frame as DataTable
    except ImportError:
        from datatable import DataTable
    DATATABLE_INSTALLED = True
except ImportError:
    DATATABLE_INSTALLED = False

    class DataTable(object):
        """Dummy class for DataTable.

        Stand-in used when datatable cannot be imported, so that
        ``isinstance(x, DataTable)`` checks elsewhere remain valid and
        simply evaluate to False for every real input.
        """

        pass


"""sklearn"""
try:
from sklearn.base import BaseEstimator
Expand Down
6 changes: 3 additions & 3 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase,
_LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckConsistentLength,
_LGBMAssertAllFinite, _LGBMCheckClassificationTargets, _LGBMComputeSampleWeight,
argc_, range_, string_type, DataFrame)
argc_, range_, string_type, DataFrame, DataTable)
from .engine import train


Expand Down Expand Up @@ -479,7 +479,7 @@ def fit(self, X, y,
eval_metric = [eval_metric] if isinstance(eval_metric, (string_type, type(None))) else eval_metric
params['metric'] = set(original_metric + eval_metric)

if not isinstance(X, DataFrame):
if not isinstance(X, (DataFrame, DataTable)):
_X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
_LGBMCheckConsistentLength(_X, _y, sample_weight)
else:
Expand Down Expand Up @@ -595,7 +595,7 @@ def predict(self, X, raw_score=False, num_iteration=None,
"""
if self._n_features is None:
raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.")
if not isinstance(X, DataFrame):
if not isinstance(X, (DataFrame, DataTable)):
X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False)
n_features = X.shape[1]
if self._n_features != n_features:
Expand Down

0 comments on commit 2c9d332

Please sign in to comment.