Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] convert datatable to numpy directly #1970

Merged
merged 5 commits into from
Feb 4, 2019
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/Python-Intro.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ The LightGBM Python module can load data from:

- libsvm/tsv/csv/txt format file

- NumPy 2D array(s), pandas DataFrame, SciPy sparse matrix
- NumPy 2D array(s), pandas DataFrame, H2O DataTable, SciPy sparse matrix

- LightGBM binary file

Expand Down
20 changes: 13 additions & 7 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
import scipy.sparse

from .compat import (DataFrame, Series,
from .compat import (DataFrame, Series, DataTable,
decode_string, string_type,
integer_types, numeric_types,
json, json_default_with_numpy,
Expand Down Expand Up @@ -402,7 +402,7 @@ def predict(self, data, num_iteration=-1,

Parameters
----------
data : string, numpy array, pandas DataFrame or scipy.sparse
data : string, numpy array, pandas DataFrame, H2O DataTable or scipy.sparse
Data source for prediction.
When data type is string, it represents the path of txt file.
num_iteration : int, optional (default=-1)
Expand Down Expand Up @@ -464,6 +464,8 @@ def predict(self, data, num_iteration=-1,
except BaseException:
raise ValueError('Cannot convert data list to numpy array.')
preds, nrow = self.__pred_for_np2d(data, num_iteration, predict_type)
elif isinstance(data, DataTable):
preds, nrow = self.__pred_for_np2d(data.to_numpy(), num_iteration, predict_type)
else:
try:
warnings.warn('Converting data to scipy sparse matrix.')
Expand Down Expand Up @@ -643,7 +645,7 @@ def __init__(self, data, label=None, reference=None,

Parameters
----------
data : string, numpy array, pandas DataFrame, scipy.sparse or list of numpy arrays
data : string, numpy array, pandas DataFrame, H2O DataTable, scipy.sparse or list of numpy arrays
Data source of Dataset.
If string, it represents the path to txt file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
Expand Down Expand Up @@ -782,6 +784,8 @@ def _lazy_init(self, data, label=None, reference=None,
self.__init_from_np2d(data, params_str, ref_dataset)
elif isinstance(data, list) and len(data) > 0 and all(isinstance(x, np.ndarray) for x in data):
self.__init_from_list_np2d(data, params_str, ref_dataset)
elif isinstance(data, DataTable):
self.__init_from_np2d(data.to_numpy(), params_str, ref_dataset)
else:
try:
csr = scipy.sparse.csr_matrix(data)
Expand Down Expand Up @@ -998,7 +1002,7 @@ def create_valid(self, data, label=None, weight=None, group=None,

Parameters
----------
data : string, numpy array, pandas DataFrame, scipy.sparse or list of numpy arrays
data : string, numpy array, pandas DataFrame, H2O DataTable, scipy.sparse or list of numpy arrays
Data source of Dataset.
If string, it represents the path to txt file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
Expand Down Expand Up @@ -1388,7 +1392,7 @@ def get_data(self):

Returns
-------
data : string, numpy array, pandas DataFrame, scipy.sparse, list of numpy arrays or None
data : string, numpy array, pandas DataFrame, H2O DataTable, scipy.sparse, list of numpy arrays or None
Raw data used in the Dataset construction.
"""
if self.handle is None:
Expand All @@ -1398,6 +1402,8 @@ def get_data(self):
self.data = self.data[self.used_indices, :]
elif isinstance(self.data, DataFrame):
self.data = self.data.iloc[self.used_indices].copy()
elif isinstance(self.data, DataTable):
self.data = self.data[self.used_indices, :]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not very sure about this part — do we need a copy here or not?

Copy link
Collaborator

@StrikerRUS StrikerRUS Jan 26, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Me too... To be honest, this is the first time I've seen DataTable 😄

Speaking about pandas, I used .copy() to prevent SettingWithCopyWarning in users' preprocessing functions.

Maybe address this question to someone from h2o?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I see. So theoretically, we don't need a copy here.
@pseudotensor any idea about this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ping @pseudotensor
Please take a look at this function for getting raw data.

else:
warnings.warn("Cannot subset {} type of raw data.\n"
"Returning original raw data".format(type(self.data).__name__))
Expand Down Expand Up @@ -2149,7 +2155,7 @@ def predict(self, data, num_iteration=None,

Parameters
----------
data : string, numpy array, pandas DataFrame or scipy.sparse
data : string, numpy array, pandas DataFrame, H2O DataTable or scipy.sparse
Data source for prediction.
If string, it represents the path to txt file.
num_iteration : int or None, optional (default=None)
Expand Down Expand Up @@ -2194,7 +2200,7 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):

Parameters
----------
data : string, numpy array, pandas DataFrame or scipy.sparse
data : string, numpy array, pandas DataFrame, H2O DataTable or scipy.sparse
Data source for refit.
If string, it represents the path to txt file.
label : list, numpy 1-D array or pandas Series / one-column DataFrame
Expand Down
13 changes: 13 additions & 0 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,19 @@ class DataFrame(object):
except ImportError:
GRAPHVIZ_INSTALLED = False

"""datatable"""
try:
from datatable import DataTable
DT_INSTALLED = True
StrikerRUS marked this conversation as resolved.
Show resolved Hide resolved
except ImportError:
DT_INSTALLED = False

class DataTable(object):
"""Dummy class for DataTable."""

pass


"""sklearn"""
try:
from sklearn.base import BaseEstimator
Expand Down
6 changes: 3 additions & 3 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase,
_LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckConsistentLength,
_LGBMAssertAllFinite, _LGBMCheckClassificationTargets, _LGBMComputeSampleWeight,
argc_, range_, string_type, DataFrame)
argc_, range_, string_type, DataFrame, DataTable)
from .engine import train


Expand Down Expand Up @@ -479,7 +479,7 @@ def fit(self, X, y,
eval_metric = [eval_metric] if isinstance(eval_metric, (string_type, type(None))) else eval_metric
params['metric'] = set(original_metric + eval_metric)

if not isinstance(X, DataFrame):
if not isinstance(X, (DataFrame, DataTable)):
_X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
_LGBMCheckConsistentLength(_X, _y, sample_weight)
else:
Expand Down Expand Up @@ -595,7 +595,7 @@ def predict(self, X, raw_score=False, num_iteration=None,
"""
if self._n_features is None:
raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.")
if not isinstance(X, DataFrame):
if not isinstance(X, (DataFrame, DataTable)):
X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False)
n_features = X.shape[1]
if self._n_features != n_features:
Expand Down