Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Blocking] Fix #3840: Clean up logic for parsing tree_method parameter #3849

Merged
merged 11 commits into from
Nov 2, 2018
82 changes: 82 additions & 0 deletions src/common/enum_class_param.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*!
* Copyright 2018 by Contributors
* \file enum_class_param.h
* \brief macro for using C++11 enum class as DMLC parameter
* \author Hyunsu Philip Cho
*/

#ifndef XGBOOST_COMMON_ENUM_CLASS_PARAM_H_
#define XGBOOST_COMMON_ENUM_CLASS_PARAM_H_

#include <dmlc/parameter.h>
#include <string>
#include <type_traits>

/*!
* \brief Specialization of FieldEntry for enum class (backed by int)
*
* Use this macro to use C++11 enum class as DMLC parameters
*
* Usage:
*
* \code{.cpp}
*
* // enum class must inherit from int type
* enum class Foo : int {
* kBar = 0, kFrog = 1, kCat = 2, kDog = 3
* };
*
* // This line is needed to prevent compilation error
* DECLARE_FIELD_ENUM_CLASS(Foo);
*
* // Now define DMLC parameter as usual;
* // enum classes can now be members.
* struct MyParam : dmlc::Parameter<MyParam> {
* Foo foo;
* DMLC_DECLARE_PARAMETER(MyParam) {
* DMLC_DECLARE_FIELD(foo)
* .set_default(Foo::kBar)
* .add_enum("bar", Foo::kBar)
* .add_enum("frog", Foo::kFrog)
* .add_enum("cat", Foo::kCat)
* .add_enum("dog", Foo::kDog);
* }
* };
*
* DMLC_REGISTER_PARAMETER(MyParam);
* \endcode
*/
#define DECLARE_FIELD_ENUM_CLASS(EnumClass) \
hcho3 marked this conversation as resolved.
Show resolved Hide resolved
namespace dmlc { \
namespace parameter { \
template <> \
class FieldEntry<EnumClass> : public FieldEntry<int> { \
public: \
FieldEntry<EnumClass>() { \
static_assert( \
std::is_same<int, typename std::underlying_type<EnumClass>::type>::value, \
"enum class must be backed by int"); \
is_enum_ = true; \
} \
using Super = FieldEntry<int>; \
void Set(void *head, const std::string &value) const override { \
Super::Set(head, value); \
} \
inline FieldEntry<EnumClass>& add_enum(const std::string &key, EnumClass value) { \
Super::add_enum(key, static_cast<int>(value)); \
return *this; \
} \
inline FieldEntry<EnumClass>& set_default(const EnumClass& default_value) { \
default_value_ = static_cast<int>(default_value); \
has_default_ = true; \
return *this; \
} \
/* NOLINTNEXTLINE */ \
inline void Init(const std::string &key, void *head, EnumClass& ref) { \
Super::Init(key, head, *reinterpret_cast<int*>(&ref)); \
} \
}; \
} /* namespace parameter */ \
} /* namespace dmlc */

#endif // XGBOOST_COMMON_ENUM_CLASS_PARAM_H_
Loading