Add Model and Configurable interface. (#4945)

* Apply Configurable to objective functions.
* Apply Model to Learner and Regtree, gbm.
* Add Load/SaveConfig to objs.
* Refactor obj tests to use smart pointer.
* Dummy methods for Save/Load Model.
This commit is contained in:
Jiaming Yuan
2019-10-18 01:56:02 -04:00
committed by GitHub
parent 9fc681001a
commit ae536756ae
31 changed files with 521 additions and 187 deletions

View File

@@ -7,7 +7,7 @@
#include <dmlc/parameter.h>
#include <xgboost/logging.h>
#include <xgboost/enum_class_param.h>
#include <xgboost/parameter.h>
#include <string>

View File

@@ -322,6 +322,9 @@ class Json {
static void Dump(Json json, std::ostream* stream,
bool pretty = ConsoleLogger::ShouldLog(
ConsoleLogger::LogVerbosity::kDebug));
static void Dump(Json json, std::string* out,
bool pretty = ConsoleLogger::ShouldLog(
ConsoleLogger::LogVerbosity::kDebug));
Json() : ptr_{new JsonNull} {}
@@ -400,6 +403,13 @@ class Json {
return *ptr_ == *(rhs.ptr_);
}
friend std::ostream& operator<<(std::ostream& os, Json const& j) {
std::string str;
Json::Dump(j, &str);
os << str;
return os;
}
private:
std::shared_ptr<Value> ptr_;
};

View File

@@ -16,6 +16,7 @@
#include <xgboost/objective.h>
#include <xgboost/feature_map.h>
#include <xgboost/generic_parameters.h>
#include <xgboost/model.h>
#include <utility>
#include <map>
@@ -41,7 +42,7 @@ namespace xgboost {
*
* \endcode
*/
class Learner : public rabit::Serializable {
class Learner : public Model, public rabit::Serializable {
public:
/*! \brief virtual destructor */
~Learner() override = default;

View File

@@ -1,9 +1,9 @@
/*!
* Copyright (c) 2015-2019 by Contributors
* \file logging.h
* \brief defines console logging options for xgboost.
* Use to enforce unified print behavior.
* For debug loggers, use LOG(INFO) and LOG(ERROR).
*
* \brief defines console logging options for xgboost. Use to enforce unified print
* behavior.
*/
#ifndef XGBOOST_LOGGING_H_
#define XGBOOST_LOGGING_H_

44
include/xgboost/model.h Normal file
View File

@@ -0,0 +1,44 @@
/*!
* Copyright (c) 2019 by Contributors
* \file model.h
* \brief Defines the abstract interface for different components in XGBoost.
*/
#ifndef XGBOOST_MODEL_H_
#define XGBOOST_MODEL_H_
namespace dmlc {
class Stream;
} // namespace dmlc
namespace xgboost {
class Json;
struct Model {
/*!
* \brief Save the model to stream.
* \param fo output write stream
*/
virtual void SaveModel(dmlc::Stream* fo) const = 0;
/*!
* \brief Load the model from stream.
* \param fi input read stream
*/
virtual void LoadModel(dmlc::Stream* fi) = 0;
};
struct Configurable {
/*!
* \brief Load configuration from JSON object
* \param in JSON object containing the configuration
*/
virtual void LoadConfig(Json const& in) = 0;
/*!
* \brief Save configuration to JSON object
* \param out pointer to output JSON object
*/
virtual void SaveConfig(Json* out) const = 0;
};
} // namespace xgboost
#endif // XGBOOST_MODEL_H_

View File

@@ -10,6 +10,7 @@
#include <dmlc/registry.h>
#include <xgboost/base.h>
#include <xgboost/data.h>
#include <xgboost/model.h>
#include <xgboost/generic_parameters.h>
#include <xgboost/host_device_vector.h>
@@ -21,7 +22,7 @@
namespace xgboost {
/*! \brief interface of objective function */
class ObjFunction {
class ObjFunction : public Configurable {
protected:
GenericParameter const* tparam_;

View File

@@ -5,10 +5,11 @@
* \author Hyunsu Philip Cho
*/
#ifndef XGBOOST_ENUM_CLASS_PARAM_H_
#define XGBOOST_ENUM_CLASS_PARAM_H_
#ifndef XGBOOST_PARAMETER_H_
#define XGBOOST_PARAMETER_H_
#include <dmlc/parameter.h>
#include <xgboost/base.h>
#include <string>
#include <type_traits>
@@ -78,4 +79,27 @@ class FieldEntry<EnumClass> : public FieldEntry<int> { \
} /* namespace parameter */ \
} /* namespace dmlc */
#endif // XGBOOST_ENUM_CLASS_PARAM_H_
namespace xgboost {
template <typename Type>
struct XGBoostParameter : public dmlc::Parameter<Type> {
protected:
bool initialised_ {false};
public:
template <typename Container>
Args UpdateAllowUnknown(Container const& kwargs, bool* out_changed = nullptr) {
if (initialised_) {
return dmlc::Parameter<Type>::UpdateAllowUnknown(kwargs, out_changed);
} else {
auto unknown = dmlc::Parameter<Type>::InitAllowUnknown(kwargs);
if (out_changed) {
*out_changed = true;
}
initialised_ = true;
return unknown;
}
}
};
} // namespace xgboost
#endif // XGBOOST_PARAMETER_H_

View File

@@ -14,6 +14,7 @@
#include <xgboost/data.h>
#include <xgboost/logging.h>
#include <xgboost/feature_map.h>
#include <xgboost/model.h>
#include <limits>
#include <vector>
@@ -93,7 +94,7 @@ struct RTreeNodeStat {
* \brief define regression tree to be the most common tree model.
* This is the data structure used in xgboost's major tree models.
*/
class RegTree {
class RegTree : public Model {
public:
/*! \brief auxiliary statistics of node to help tree building */
using SplitCondT = bst_float;
@@ -289,38 +290,17 @@ class RegTree {
const RTreeNodeStat& Stat(int nid) const {
return stats_[nid];
}
/*!
* \brief load model from stream
* \param fi input stream
*/
void Load(dmlc::Stream* fi) {
CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam));
nodes_.resize(param.num_nodes);
stats_.resize(param.num_nodes);
CHECK_NE(param.num_nodes, 0);
CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size()),
sizeof(Node) * nodes_.size());
CHECK_EQ(fi->Read(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * stats_.size()),
sizeof(RTreeNodeStat) * stats_.size());
// chg deleted nodes
deleted_nodes_.resize(0);
for (int i = param.num_roots; i < param.num_nodes; ++i) {
if (nodes_[i].IsDeleted()) deleted_nodes_.push_back(i);
}
CHECK_EQ(static_cast<int>(deleted_nodes_.size()), param.num_deleted);
}
void LoadModel(dmlc::Stream* fi) override;
/*!
* \brief save model to stream
* \param fo output stream
*/
void Save(dmlc::Stream* fo) const {
CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size()));
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
fo->Write(&param, sizeof(TreeParam));
CHECK_NE(param.num_nodes, 0);
fo->Write(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size());
fo->Write(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * nodes_.size());
}
void SaveModel(dmlc::Stream* fo) const override;
bool operator==(const RegTree& b) const {
return nodes_ == b.nodes_ && stats_ == b.stats_ &&