Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor LGBM_DatasetGetFeatureNames #3022

Merged
merged 3 commits into from
Jun 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,16 +158,23 @@ LGBM_SE LGBM_DatasetGetFeatureNames_R(LGBM_SE handle,
R_API_BEGIN();
int len = 0;
CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &len));
const size_t reserved_string_size = 256;
std::vector<std::vector<char>> names(len);
std::vector<char*> ptr_names(len);
for (int i = 0; i < len; ++i) {
names[i].resize(256);
names[i].resize(reserved_string_size);
ptr_names[i] = names[i].data();
}
int out_len;
CHECK_CALL(LGBM_DatasetGetFeatureNames(R_GET_PTR(handle),
ptr_names.data(), &out_len));
size_t required_string_size;
CHECK_CALL(
LGBM_DatasetGetFeatureNames(
R_GET_PTR(handle),
len, &out_len,
reserved_string_size, &required_string_size,
ptr_names.data()));
CHECK_EQ(len, out_len);
CHECK_GE(reserved_string_size, required_string_size);
auto merge_str = Join<char*>(ptr_names, "\t");
EncodeChar(feature_names, merge_str.c_str(), buf_len, actual_len, merge_str.size() + 1);
R_API_END();
Expand Down
14 changes: 11 additions & 3 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,13 +280,21 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetFeatureNames(DatasetHandle handle,
/*!
* \brief Get feature names of dataset.
* \param handle Handle of dataset
* \param[out] feature_names Feature names, should pre-allocate memory
* \param len Number of ``char*`` pointers stored at ``out_strs``.
* If smaller than the max size, only this many strings are copied
* \param[out] num_feature_names Number of feature names
* \param buffer_len Size of pre-allocated strings.
* Content is copied up to ``buffer_len - 1`` and null-terminated
* \param[out] out_buffer_len String sizes required to do the full string copies
* \param[out] feature_names Feature names, should pre-allocate memory
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetGetFeatureNames(DatasetHandle handle,
char** feature_names,
int* num_feature_names);
const int len,
int* num_feature_names,
const size_t buffer_len,
size_t* out_buffer_len,
char** feature_names);

/*!
* \brief Free space for dataset.
Expand Down
32 changes: 32 additions & 0 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1553,6 +1553,38 @@ def set_group(self, group):
self.set_field('group', group)
return self

def get_feature_name(self):
"""Get the names of columns (features) in the Dataset.

Returns
-------
feature_names : list
The names of columns (features) in the Dataset.
"""
if self.handle is None:
raise LightGBMError("Cannot get feature_name before construct dataset")
num_feature = self.num_feature()
tmp_out_len = ctypes.c_int(0)
reserved_string_buffer_size = 255
required_string_buffer_size = ctypes.c_size_t(0)
string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for i in range_(num_feature)]
ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))
_safe_call(_LIB.LGBM_DatasetGetFeatureNames(
self.handle,
num_feature,
ctypes.byref(tmp_out_len),
reserved_string_buffer_size,
ctypes.byref(required_string_buffer_size),
ptr_string_buffers))
if num_feature != tmp_out_len.value:
raise ValueError("Length of feature names doesn't equal with num_feature")
if reserved_string_buffer_size < required_string_buffer_size.value:
raise BufferError(
"Allocated feature name buffer size ({}) was inferior to the needed size ({})."
.format(reserved_string_buffer_size, required_string_buffer_size.value)
)
return [string_buffers[i].value.decode('utf-8') for i in range_(num_feature)]

def get_label(self):
"""Get the label of the Dataset.

Expand Down
18 changes: 13 additions & 5 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1110,15 +1110,23 @@ int LGBM_DatasetSetFeatureNames(
}

int LGBM_DatasetGetFeatureNames(
DatasetHandle handle,
char** feature_names,
int* num_feature_names) {
API_BEGIN();
DatasetHandle handle,
const int len,
int* num_feature_names,
const size_t buffer_len,
size_t* out_buffer_len,
char** feature_names) {
API_BEGIN();
*out_buffer_len = 0;
auto dataset = reinterpret_cast<Dataset*>(handle);
auto inside_feature_name = dataset->feature_names();
*num_feature_names = static_cast<int>(inside_feature_name.size());
for (int i = 0; i < *num_feature_names; ++i) {
std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1);
if (i < len) {
std::memcpy(feature_names[i], inside_feature_name[i].c_str(), std::min(inside_feature_name[i].size() + 1, buffer_len));
feature_names[i][buffer_len - 1] = '\0';
}
*out_buffer_len = std::max(inside_feature_name[i].size() + 1, *out_buffer_len);
}
API_END();
}
Expand Down
7 changes: 6 additions & 1 deletion tests/python_package_test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,15 +271,20 @@ def check_asserts(data):
self.assertTrue(np.all(np.isclose([data.label[0], data.weight[0], data.init_score[0]],
data.label[0])))
self.assertAlmostEqual(data.label[1], data.weight[1])
self.assertListEqual(data.feature_name, data.get_feature_name())

X, y = load_breast_cancer(True)
sequence = np.ones(y.shape[0])
sequence[0] = np.nan
sequence[1] = np.inf
lgb_data = lgb.Dataset(X, sequence, weight=sequence, init_score=sequence).construct()
feature_names = ['f{0}'.format(i) for i in range(X.shape[1])]
lgb_data = lgb.Dataset(X, sequence,
weight=sequence, init_score=sequence,
feature_name=feature_names).construct()
check_asserts(lgb_data)
lgb_data = lgb.Dataset(X, y).construct()
lgb_data.set_label(sequence)
lgb_data.set_weight(sequence)
lgb_data.set_init_score(sequence)
lgb_data.set_feature_name(feature_names)
check_asserts(lgb_data)