-
Notifications
You must be signed in to change notification settings - Fork 3.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[python-package] Support 2d collections as input for init_score
in multiclass classification task
#4150
[python-package] Support 2d collections as input for init_score
in multiclass classification task
#4150
Changes from 7 commits
d387d6c
75bb7ef
cd9a41c
5644295
0edd5e7
e11a44c
a6b1744
ad44959
6222521
2c7ef3c
d3b763f
1676e55
16f0b9d
9bb4454
a3e890d
94aa6a9
9fb7f2b
91e7429
7822810
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -134,8 +134,8 @@ def is_numpy_column_array(data): | |||||||||
return len(shape) == 2 and shape[1] == 1 | ||||||||||
|
||||||||||
|
||||||||||
def cast_numpy_1d_array_to_dtype(array, dtype): | ||||||||||
"""Cast numpy 1d array to given dtype.""" | ||||||||||
def cast_numpy_array_to_dtype(array, dtype): | ||||||||||
"""Cast numpy array to given dtype.""" | ||||||||||
if array.dtype == dtype: | ||||||||||
return array | ||||||||||
return array.astype(dtype=dtype, copy=False) | ||||||||||
|
@@ -146,14 +146,24 @@ def is_1d_list(data): | |||||||||
return isinstance(data, list) and (not data or is_numeric(data[0])) | ||||||||||
|
||||||||||
|
||||||||||
def is_1d_collection(data): | ||||||||||
"""Check whether data is a 1-D collection.""" | ||||||||||
return ( | ||||||||||
is_numpy_1d_array(data) | ||||||||||
or is_numpy_column_array(data) | ||||||||||
or is_1d_list(data) | ||||||||||
or isinstance(data, pd_Series) | ||||||||||
) | ||||||||||
|
||||||||||
|
||||||||||
def list_to_1d_numpy(data, dtype=np.float32, name='list'): | ||||||||||
"""Convert data to numpy 1-D array.""" | ||||||||||
if is_numpy_1d_array(data): | ||||||||||
return cast_numpy_1d_array_to_dtype(data, dtype) | ||||||||||
return cast_numpy_array_to_dtype(data, dtype) | ||||||||||
elif is_numpy_column_array(data): | ||||||||||
_log_warning('Converting column-vector to 1d array') | ||||||||||
array = data.ravel() | ||||||||||
return cast_numpy_1d_array_to_dtype(array, dtype) | ||||||||||
return cast_numpy_array_to_dtype(array, dtype) | ||||||||||
elif is_1d_list(data): | ||||||||||
return np.array(data, dtype=dtype, copy=False) | ||||||||||
elif isinstance(data, pd_Series): | ||||||||||
|
@@ -165,6 +175,39 @@ def list_to_1d_numpy(data, dtype=np.float32, name='list'): | |||||||||
"It should be list, numpy 1-D array or pandas Series") | ||||||||||
|
||||||||||
|
||||||||||
def is_numpy_2d_array(data): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
"""Check whether data is a numpy 2-D array.""" | ||||||||||
return isinstance(data, np.ndarray) and len(data.shape) == 2 and data.shape[1] > 1 | ||||||||||
|
||||||||||
|
||||||||||
def is_2d_list(data): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
"""Check whether data is a 2-D list.""" | ||||||||||
return isinstance(data, list) and len(data) > 0 and is_1d_list(data[0]) | ||||||||||
jameslamb marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
|
||||||||||
|
||||||||||
def is_2d_collection(data): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
"""Check whether data is a 2-D collection.""" | ||||||||||
return ( | ||||||||||
is_numpy_2d_array(data) | ||||||||||
or is_2d_list(data) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
or isinstance(data, pd_DataFrame) | ||||||||||
) | ||||||||||
|
||||||||||
|
||||||||||
def data_to_2d_numpy(data, dtype=np.float32, name='list'): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
Could you also add type hints here? |
||||||||||
"""Convert data to numpy 2-D array.""" | ||||||||||
if is_numpy_2d_array(data): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
return cast_numpy_array_to_dtype(data, dtype) | ||||||||||
if is_2d_list(data): | ||||||||||
return np.array(data, dtype=dtype) | ||||||||||
if isinstance(data, pd_DataFrame): | ||||||||||
if _get_bad_pandas_dtypes(data.dtypes): | ||||||||||
raise ValueError('DataFrame.dtypes must be int, float or bool') | ||||||||||
return cast_numpy_array_to_dtype(data.values, dtype) | ||||||||||
raise TypeError(f"Wrong type({type(data).__name__}) for {name}.\n" | ||||||||||
"It should be list of lists, numpy 2-D array or pandas DataFrame") | ||||||||||
|
||||||||||
|
||||||||||
def cfloat32_array_to_numpy(cptr, length): | ||||||||||
"""Convert a ctypes float pointer array to a numpy array.""" | ||||||||||
if isinstance(cptr, ctypes.POINTER(ctypes.c_float)): | ||||||||||
|
@@ -1070,7 +1113,7 @@ def __init__(self, data, label=None, reference=None, | |||||||||
sum(group) = n_samples. | ||||||||||
For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups, | ||||||||||
where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc. | ||||||||||
init_score : list, numpy 1-D array, pandas Series or None, optional (default=None) | ||||||||||
init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task) or None, optional (default=None) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @StrikerRUS do you think this should have a comma before […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I like how you did it. I believe a comma before […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I personally prefer […] |
||||||||||
Init score for Dataset. | ||||||||||
silent : bool, optional (default=False) | ||||||||||
Whether to print messages during construction. | ||||||||||
|
@@ -1487,7 +1530,7 @@ def create_valid(self, data, label=None, weight=None, group=None, | |||||||||
sum(group) = n_samples. | ||||||||||
For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups, | ||||||||||
where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc. | ||||||||||
init_score : list, numpy 1-D array, pandas Series or None, optional (default=None) | ||||||||||
init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task) or None, optional (default=None) | ||||||||||
Init score for Dataset. | ||||||||||
silent : bool, optional (default=False) | ||||||||||
Whether to print messages during construction. | ||||||||||
|
@@ -1823,7 +1866,7 @@ def set_init_score(self, init_score): | |||||||||
|
||||||||||
Parameters | ||||||||||
---------- | ||||||||||
init_score : list, numpy 1-D array, pandas Series or None | ||||||||||
init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task) or None, optional (default=None) | ||||||||||
jmoralez marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
Init score for Booster. | ||||||||||
|
||||||||||
Returns | ||||||||||
|
@@ -1833,7 +1876,16 @@ def set_init_score(self, init_score): | |||||||||
""" | ||||||||||
self.init_score = init_score | ||||||||||
if self.handle is not None and init_score is not None: | ||||||||||
init_score = list_to_1d_numpy(init_score, np.float64, name='init_score') | ||||||||||
if is_1d_collection(init_score): | ||||||||||
init_score = list_to_1d_numpy(init_score, np.float64, name='init_score') | ||||||||||
elif is_2d_collection(init_score): | ||||||||||
jmoralez marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
init_score = data_to_2d_numpy(init_score, np.float64, name='init_score') | ||||||||||
init_score = init_score.ravel(order='F') | ||||||||||
else: | ||||||||||
raise TypeError( | ||||||||||
'init_score must be list, numpy 1-D array or pandas Series.\n' | ||||||||||
'In multiclass classification init_score can also be a list of lists, numpy 2-D array or pandas DataFrame.' | ||||||||||
) | ||||||||||
self.set_field('init_score', init_score) | ||||||||||
StrikerRUS marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
self.init_score = self.get_field('init_score') # original values can be modified at cpp side | ||||||||||
return self | ||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -1277,17 +1277,14 @@ def test_init_score(task, output, cluster): | |||||||||||||||||||||||||||||||||||||||||
'time_out': 5 | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
init_score = random.random() | ||||||||||||||||||||||||||||||||||||||||||
# init_scores must be a 1D array, even for multiclass classification | ||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I guess we need updates in type hints and docstrings for the Dask module: LightGBM/python-package/lightgbm/dask.py Line 395 in cfe8eb1
LightGBM/python-package/lightgbm/dask.py Line 401 in cfe8eb1
LightGBM/python-package/lightgbm/dask.py Lines 423 to 424 in cfe8eb1
LightGBM/python-package/lightgbm/dask.py Lines 442 to 443 in cfe8eb1
LightGBM/python-package/lightgbm/dask.py Line 1024 in cfe8eb1
LightGBM/python-package/lightgbm/dask.py Line 1030 in cfe8eb1
LightGBM/python-package/lightgbm/dask.py Line 1162 in cfe8eb1
LightGBM/python-package/lightgbm/dask.py Line 1167 in cfe8eb1
LightGBM/python-package/lightgbm/dask.py Line 1195 in cfe8eb1
LightGBM/python-package/lightgbm/dask.py Line 1198 in cfe8eb1
LightGBM/python-package/lightgbm/dask.py Line 1341 in cfe8eb1
LightGBM/python-package/lightgbm/dask.py Line 1345 in cfe8eb1
LightGBM/python-package/lightgbm/dask.py Line 1372 in cfe8eb1
LightGBM/python-package/lightgbm/dask.py Line 1375 in cfe8eb1
LightGBM/python-package/lightgbm/dask.py Line 1502 in cfe8eb1
LightGBM/python-package/lightgbm/dask.py Line 1507 in cfe8eb1
LightGBM/python-package/lightgbm/dask.py Line 1539 in cfe8eb1
LightGBM/python-package/lightgbm/dask.py Line 1542 in cfe8eb1
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Addressed in 2c7ef3c. I think the docstrings maybe ended up a bit too verbose; let me know what you think. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||||||
# where you need to provide 1 score per class for each row in X | ||||||||||||||||||||||||||||||||||||||||||
# https://github.com/microsoft/LightGBM/issues/4046 | ||||||||||||||||||||||||||||||||||||||||||
size_factor = 1 | ||||||||||||||||||||||||||||||||||||||||||
if task == 'multiclass-classification': | ||||||||||||||||||||||||||||||||||||||||||
size_factor = 3 # number of classes | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
if output.startswith('dataframe'): | ||||||||||||||||||||||||||||||||||||||||||
init_scores = dy.map_partitions(lambda x: pd.Series([init_score] * x.size * size_factor)) | ||||||||||||||||||||||||||||||||||||||||||
init_scores = dy.map_partitions(lambda x: pd.DataFrame([[init_score] * size_factor] * x.size)) | ||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||
init_scores = dy.map_blocks(lambda x: np.repeat(init_score, x.size * size_factor)) | ||||||||||||||||||||||||||||||||||||||||||
init_scores = dy.map_blocks(lambda x: np.full((x.size, size_factor), init_score)) | ||||||||||||||||||||||||||||||||||||||||||
model = model_factory(client=client, **params) | ||||||||||||||||||||||||||||||||||||||||||
model.fit(dX, dy, sample_weight=dw, init_score=init_scores, group=dg) | ||||||||||||||||||||||||||||||||||||||||||
# value of the root node is 0 when init_score is set | ||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should be adding type hints for new code when possible, to increase the chance of catching bugs with `mypy` and reduce the amount of effort needed for #3756. I'd also like to recommend prefixing objects that we don't want to encourage people to import with `_`, to make it clearer that they're intended to be internal.