Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] Better column dtype logging when column has "bad dtype" #5065

Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,7 @@ def list_to_1d_numpy(data, dtype=np.float32, name='list'):
elif is_1d_list(data):
return np.array(data, dtype=dtype, copy=False)
elif isinstance(data, pd_Series):
if _get_bad_pandas_dtypes([data.dtypes]):
raise ValueError('Series.dtypes must be int, float or bool')
_check_for_bad_pandas_dtypes(data.to_frame().dtypes)
return np.array(data, dtype=dtype, copy=False) # SparseArray should be supported as well
else:
raise TypeError(f"Wrong type({type(data).__name__}) for {name}.\n"
Expand Down Expand Up @@ -217,8 +216,7 @@ def _data_to_2d_numpy(data: Any, dtype: type = np.float32, name: str = 'list') -
if _is_2d_list(data):
return np.array(data, dtype=dtype)
if isinstance(data, pd_DataFrame):
if _get_bad_pandas_dtypes(data.dtypes):
raise ValueError('DataFrame.dtypes must be int, float or bool')
_check_for_bad_pandas_dtypes(data.dtypes)
return cast_numpy_array_to_dtype(data.values, dtype)
raise TypeError(f"Wrong type({type(data).__name__}) for {name}.\n"
"It should be list of lists, numpy 2-D array or pandas DataFrame")
Expand Down Expand Up @@ -500,7 +498,7 @@ def c_int_array(data):
return (ptr_data, type_data, data) # return `data` to avoid the temporary copy is freed


def _get_bad_pandas_dtypes(dtypes):
def _check_for_bad_pandas_dtypes(pandas_dtypes_series):
float128 = getattr(np, 'float128', type(None))

def is_allowed_numpy_dtype(dtype):
Expand All @@ -509,7 +507,14 @@ def is_allowed_numpy_dtype(dtype):
and not issubclass(dtype, (np.timedelta64, float128))
)

return [i for i, dtype in enumerate(dtypes) if not is_allowed_numpy_dtype(dtype.type)]
bad_pandas_dtypes = [
f'{column_name}: {pandas_dtype}'
for column_name, pandas_dtype in pandas_dtypes_series.iteritems()
if not is_allowed_numpy_dtype(pandas_dtype)
hsorsky marked this conversation as resolved.
Show resolved Hide resolved
]
if bad_pandas_dtypes:
raise ValueError('DataFrame.dtypes must be int, float or bool.\n'
hsorsky marked this conversation as resolved.
Show resolved Hide resolved
f'Fields with bad pandas dtypes: {", ".join(bad_pandas_dtypes)}')


def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical):
Expand Down Expand Up @@ -540,12 +545,7 @@ def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorica
categorical_feature = list(categorical_feature)
if feature_name == 'auto':
feature_name = list(data.columns)
bad_indices = _get_bad_pandas_dtypes(data.dtypes)
if bad_indices:
bad_index_cols_str = ', '.join(data.columns[bad_indices])
raise ValueError("DataFrame.dtypes for data must be int, float or bool.\n"
"Did not expect the data types in the following fields: "
f"{bad_index_cols_str}")
_check_for_bad_pandas_dtypes(data.dtypes)
df_dtypes = [dtype.type for dtype in data.dtypes]
df_dtypes.append(np.float32) # so that the target dtype considers floats
target_dtype = np.find_common_type(df_dtypes, [])
Expand All @@ -562,8 +562,7 @@ def _label_from_pandas(label):
if isinstance(label, pd_DataFrame):
if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns')
if _get_bad_pandas_dtypes(label.dtypes):
raise ValueError('DataFrame.dtypes for label must be int, float or bool')
_check_for_bad_pandas_dtypes(label.dtypes)
label = np.ravel(label.values.astype(np.float32, copy=False))
return label

Expand Down