/*! * Copyright 2018 by Contributors * \file parameter.h * \brief macro for using C++11 enum class as DMLC parameter * \author Hyunsu Philip Cho */ #ifndef XGBOOST_PARAMETER_H_ #define XGBOOST_PARAMETER_H_ #include #include #include #include /*! * \brief Specialization of FieldEntry for enum class (backed by int) * * Use this macro to use C++11 enum class as DMLC parameters * * Usage: * * \code{.cpp} * * // enum class must inherit from int type * enum class Foo : int { * kBar = 0, kFrog = 1, kCat = 2, kDog = 3 * }; * * // This line is needed to prevent compilation error * DECLARE_FIELD_ENUM_CLASS(Foo); * * // Now define DMLC parameter as usual; * // enum classes can now be members. * struct MyParam : dmlc::Parameter { * Foo foo; * DMLC_DECLARE_PARAMETER(MyParam) { * DMLC_DECLARE_FIELD(foo) * .set_default(Foo::kBar) * .add_enum("bar", Foo::kBar) * .add_enum("frog", Foo::kFrog) * .add_enum("cat", Foo::kCat) * .add_enum("dog", Foo::kDog); * } * }; * * DMLC_REGISTER_PARAMETER(MyParam); * \endcode */ #define DECLARE_FIELD_ENUM_CLASS(EnumClass) \ namespace dmlc { \ namespace parameter { \ template <> \ class FieldEntry : public FieldEntry { \ public: \ FieldEntry() { \ static_assert( \ std::is_same::type>::value, \ "enum class must be backed by int"); \ is_enum_ = true; \ } \ using Super = FieldEntry; \ void Set(void *head, const std::string &value) const override { \ Super::Set(head, value); \ } \ inline FieldEntry& add_enum(const std::string &key, EnumClass value) { \ Super::add_enum(key, static_cast(value)); \ return *this; \ } \ inline FieldEntry& set_default(const EnumClass& default_value) { \ default_value_ = static_cast(default_value); \ has_default_ = true; \ return *this; \ } \ inline void Init(const std::string &key, void *head, EnumClass& ref) { /* NOLINT */ \ Super::Init(key, head, *reinterpret_cast(&ref)); \ } \ }; \ } /* namespace parameter */ \ } /* namespace dmlc */ namespace xgboost { template struct XGBoostParameter : public dmlc::Parameter { protected: bool initialised_ {false}; public: template Args UpdateAllowUnknown(Container const& kwargs, bool* out_changed = nullptr) { if (initialised_) { return dmlc::Parameter::UpdateAllowUnknown(kwargs, out_changed); } else { auto unknown = dmlc::Parameter::InitAllowUnknown(kwargs); if (out_changed) { *out_changed = true; } initialised_ = true; return unknown; } } bool GetInitialised() const { return static_cast(this->initialised_); } }; } // namespace xgboost #endif // XGBOOST_PARAMETER_H_