Target and Count encodings for categorical features #3234
Conversation
include/LightGBM/ctr_provider.hpp
Outdated

inline double EncodeCatValueForValidation(const int fid, const double feature_value) const {
  const auto& ctr_encoding = ctr_encodings_.at(fid)[config_.num_ctr_folds];
Avoid using .at(), which is slow.
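As an illustration of this suggestion, here is a minimal sketch (hypothetical names, not the PR's actual code) that replaces the repeated bounds-checked `.at()` lookup with a single `find()`:

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

// Sketch only: .at() repeats the key lookup and carries exception machinery.
// When the caller guarantees the key exists, a single find() avoids that
// cost in a hot loop.
inline double EncodeForValidation(
    const std::unordered_map<int, std::vector<double>>& encodings,
    int fid, int fold_index) {
  auto it = encodings.find(fid);   // one lookup, no throw path
  assert(it != encodings.end());   // precondition: fid is always present
  return it->second[fold_index];
}
```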
At your earliest convenience, can you please merge in
@shiyu1994 can you fix the conflict?
@shiyu1994 for the R side, I see you have this error in CI:
Can you please update
LightGBM/R-package/src/lightgbm_R.cpp, line 681 in 82e2ff7
{"LGBM_DatasetCreateFromCSC_R" , (DL_FUNC) &LGBM_DatasetCreateFromCSC_R , 10},
to
{"LGBM_DatasetCreateFromCSC_R" , (DL_FUNC) &LGBM_DatasetCreateFromCSC_R , 12},
(The last field of an R_CallMethodDef entry is the number of arguments, which grows from 10 to 12 with the two new parameters.)
Comparison between CTR and the old categorical feature approach:
The hyper-parameters of the old categorical method are
For the CTR version, we fix the number of folds for CTR calculation to 4, and try both CTR with count and CTR only. We use 5-fold cross-validation. For the number of trials allowed to tune hyperparameters, we allow the old categorical method 500 trials, and 200 trials for each of CTR with count and CTR only. The tuning for each algorithm and each dataset is repeated 5 times (with 5 different CV fold partitions). The following table shows the AUC on the test sets:
// transform categorical features to encoded numerical values before the bin construction process
class CategoryEncodingProvider {
 public:
  class CatConverter {
Can we put CatConverter as a separate class in a separate file, to reduce the complexity of CategoryEncodingProvider? It looks like we might need to access properties of CategoryEncodingProvider; in that case we can pass the CategoryEncodingProvider to the CatConverter.
Could we have separate unit tests for every converter?
virtual std::string DumpToString() const = 0;

virtual json11::Json DumpToJSONObject() const = 0;
To make sure we have a common pattern for serialization, could we use JSON as the general format? JSON is just a special kind of string, so we would not need to define our own custom format.
And I believe we should not do serialization & deserialization very often?
return cat_fid_to_convert_fid_.at(cat_fid);
}

static CatConverter* CreateFromCharPointer(const char* char_pointer, size_t* used_len, double prior_weight) {
We could have a separate file for the class initialization logic.
return str_stream.str();
}

json11::Json DumpToJSONObject() const override {
Looks like this is another piece of serialization logic split across the sub/base classes. Could we have a pattern like the one below, where the sub-class first calls the base-class dump and then appends its own fields?
DumpToJSONObject(json11::Json context)
{
  CatConverter::DumpToJSONObject(context);
  // dump-to-JSON logic for the sub class
}
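To illustrate the suggested pattern, here is a hypothetical sketch (using std::map as a stand-in for a json11::Json object; none of these names are from the PR): the derived converter first lets the base class dump its fields into the shared context, then appends its own.

```cpp
#include <cassert>
#include <map>
#include <string>

// Stand-in for a mutable JSON object.
using JsonObject = std::map<std::string, std::string>;

class CatConverter {
 public:
  virtual ~CatConverter() = default;
  // Base class dumps the fields common to all converters.
  virtual void DumpToJSONObject(JsonObject* context) const {
    (*context)["type"] = "base";
  }
};

class TargetEncoder : public CatConverter {
 public:
  void DumpToJSONObject(JsonObject* context) const override {
    CatConverter::DumpToJSONObject(context);  // base part first
    (*context)["prior"] = "0.5";              // sub-class fields
    (*context)["type"] = "target";            // may override base entries
  }
};
```

This keeps a single serialization entry point per class and avoids duplicating the base-class dump logic in every sub-class.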
}
};

class TargetEncoderLabelMean: public CatConverter {
For these converters, it looks like separate classes in separate files would be clearer?
}

// parameter configuration
Config config_;
It looks like we should not need to keep the Config after initialization?
Config config_;

// size of training data
data_size_t num_data_;
For the properties below, we had better have a single unified pair of serialization and deserialization functions?
bool accumulated_from_file_;
};

class CategoryEncodingParser : public Parser {
Using the parser as middleware looks weird; it is more like a step in a processing pipeline. Is it possible to refactor the processing steps into a pipeline pattern?
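For reference, the suggested pipeline pattern could look like the following sketch (hypothetical types, not LightGBM code): preprocessing becomes an explicit, ordered list of steps, each transforming the data and handing it to the next, instead of a transformation hidden inside a Parser.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// A row of numeric feature values, and a pipeline step transforming a row.
using Row = std::vector<double>;
using Step = std::function<Row(Row)>;

// Run the row through every step in order.
inline Row RunPipeline(Row row, const std::vector<Step>& steps) {
  for (const auto& step : steps) {
    row = step(std::move(row));  // each step consumes and produces a row
  }
  return row;
}
```

Category encoding would then be just one Step in the list, registered alongside parsing and bin construction.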
/*!
 * \brief Constructor for customized parser. The constructor accepts content, not a path, because we need to save/load the config along with the model string.
 */
explicit Parser(std::string) {}
Having a constructor that takes a string looks weird; I'm not sure what purpose it serves in the base class?
namespace LightGBM {

CategoryEncodingProvider::CategoryEncodingProvider(Config* config) {
Should we have one unified constructor for CategoryEncodingProvider, with the others just transforming their inputs into that constructor? Something like:
CategoryEncodingProvider(a, b, c, d, e);
CategoryEncodingProvider(a, b, c, d)
    : CategoryEncodingProvider(a, b, c, d, default_value) {}
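C++11 delegating constructors support exactly this. A minimal sketch with hypothetical fields (not the real class):

```cpp
#include <cassert>

// All overloads funnel into one unified constructor, so the initialization
// logic lives in a single place.
class Provider {
 public:
  Provider(int num_folds, double prior_weight)
      : num_folds_(num_folds), prior_weight_(prior_weight) {}
  // Overload with a default prior weight delegates to the unified constructor.
  explicit Provider(int num_folds) : Provider(num_folds, 1.0) {}
  int num_folds() const { return num_folds_; }
  double prior_weight() const { return prior_weight_; }
 private:
  int num_folds_;
  double prior_weight_;
};
```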
tmp_parser_ = nullptr;
}

std::string CategoryEncodingProvider::DumpToJSON() const {
Same as for the converters: a single unified serialization/deserialization pair would reduce a lot of complexity.
Thanks @tongwu-msft for the valuable suggestions for refactoring the code. We can have follow-up PRs for the refactoring after this one is merged.
TBH, I'm not sure this is a good approach to adding new features in an open-source project, where there are no deadlines and the work is based on collaboration among multiple people. We may end up with a lot of such "valuable suggestions for refactoring the code" that never get enough attention in the future. As a result, the whole LightGBM codebase will become less readable, unoptimized, and extremely hard for outside contributors to work with.
@StrikerRUS Thanks for the reminder. I've discussed with @tongwu-msft, and he will do the refactor and then push to this branch.
I agree with @StrikerRUS's comments in #3234 (comment). Right now, while attention is focused on this PR, is the best time to address review comments.
@shiyu1994 should we close this PR? I think this is hard to merge now.
It's been 6 months since @guolinke asked if we should close this PR with no response, and more than a year since the most recent commit. @shiyu1994 I'm closing this. If you'd like to continue this work, please propose a new PR and we can start a new review cycle. |
This pull request has been automatically locked since there has not been any recent activity since it was closed. |
This pull request is to support converting categorical features into CTR and count values. The CTR values are calculated by dividing the training data into folds, with a cross-validation style.
The performance evaluation has been added in #3234 (comment). Note that this is a basic version without ensemble trick. We will have a separate PR to implement ensemble trick to boost the performance of this categorical encoding method.
Detailed descriptions and guidelines for reviews can be found below.
Description
Idea
LightGBM can handle categorical features internally, without any manual encoding into numerical values by users. The current approach that LightGBM handles categorical features can be found here https://lightgbm.readthedocs.io/en/latest/Features.html#optimal-split-for-categorical-features.
However, this approach has two known drawbacks:
cat_l2 (https://lightgbm.readthedocs.io/en/latest/Parameters.html#cat_l2), cat_smooth (https://lightgbm.readthedocs.io/en/latest/Parameters.html#cat_smooth) and max_cat_to_onehot (https://lightgbm.readthedocs.io/en/latest/Parameters.html#max_cat_to_onehot) need to be tuned carefully.

Inspired by CatBoost, we implement two new approaches for the internal encoding of categorical features: target encoding and count encoding.
For target encoding, a category value c of feature j is encoded as

$$\frac{\sum_{i=1}^{n} \mathbb{1}\left[x_{i,j} = c\right]\, y_i + w \cdot p}{\sum_{i=1}^{n} \mathbb{1}\left[x_{i,j} = c\right] + w},$$

where $x_{i,j}$ is the value of feature j for data i in the training set, $\mathbb{1}[\cdot]$ is the indicator function, $y_i$ is the label for data i, p is the prior value (which is the mean of labels over the whole dataset, by default), and w is the weight for the prior value.

For better generalization ability, target encoding uses a cross-validation-style approach. The training data is first divided into K folds. When encoding the category values in fold k, we use only the other K-1 folds (excluding fold k) to calculate the encoded value. Thus, for each category c, we will have K different encoded values under a target encoding approach.

New Parameters
We add a new parameter category_encoders to specify the encoding methods for categorical features. Users can specify multiple encoding methods at the same time, separated by commas. For example, with category_encoders=target:0.5,count,target, we'll use 3 encoding methods at the same time: target encoding with prior value p=0.5, count encoding, and target encoding with the default prior value. Each encoding method creates a new feature for each categorical feature for training. We also allow a raw encoding method, which indicates LightGBM's current approach of handling categorical features.

Besides, we also add two parameters, prior_weight and num_target_encoding_folds, to allow users to specify the p and K in the target encoding.

Implementation
The core of this PR is a new class CategoryEncodingProvider, which is defined in https://github.com/shiyu1994/LightGBM/blob/ctr/include/LightGBM/category_encoding_provider.hpp and https://github.com/shiyu1994/LightGBM/blob/ctr/src/io/category_encoding_provider.cpp. These two files amount to 2/3 of the code changes in the PR. CategoryEncodingProvider works in 2 steps:

Since we need to support multiple data input formats, CategoryEncodingProvider is used in the corresponding functions in src/c_api.cpp that are called when processing inputs of these formats, including:
LGBM_DatasetCreateFromFile,
LGBM_DatasetCreateFromMat and LGBM_DatasetCreateFromMats,
LGBM_DatasetCreateFromCSR,
LGBM_DatasetCreateFromCSC.
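To make the fold-based target encoding described in the Idea section concrete, here is a small standalone sketch (a hypothetical function, not the PR's actual implementation) that computes out-of-fold encoded values with a weighted prior:

```cpp
#include <cassert>
#include <cmath>
#include <map>
#include <vector>

// Sketch of K-fold target encoding with prior p and prior weight w: the
// encoding of row i uses label statistics from every fold except the fold
// containing row i, matching the cross-validation-style scheme above.
std::vector<double> KFoldTargetEncode(const std::vector<int>& category,
                                      const std::vector<double>& label,
                                      const std::vector<int>& fold,  // fold id per row
                                      int num_folds,
                                      double prior, double prior_weight) {
  // per-fold label sum and count for each category value
  std::vector<std::map<int, double>> fold_sum(num_folds);
  std::vector<std::map<int, int>> fold_count(num_folds);
  for (size_t i = 0; i < category.size(); ++i) {
    fold_sum[fold[i]][category[i]] += label[i];
    fold_count[fold[i]][category[i]] += 1;
  }
  std::vector<double> encoded(category.size());
  for (size_t i = 0; i < category.size(); ++i) {
    double sum = 0.0;
    int count = 0;
    for (int k = 0; k < num_folds; ++k) {
      if (k == fold[i]) continue;  // leave out the row's own fold
      auto it = fold_sum[k].find(category[i]);
      if (it != fold_sum[k].end()) {
        sum += it->second;
        count += fold_count[k][category[i]];
      }
    }
    // (sum of labels + w * p) / (count + w); a category unseen in the
    // other folds falls back to the prior p
    encoded[i] = (sum + prior_weight * prior) / (count + prior_weight);
  }
  return encoded;
}
```

With prior_weight as w and prior as p, this matches the encoding formula in the Idea section, producing one out-of-fold value per row.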