Add Model and Configurable interface. (#4945)
* Apply Configurable to objective functions. * Apply Model to Learner and Regtree, gbm. * Add Load/SaveConfig to objs. * Refactor obj tests to use smart pointer. * Dummy methods for Save/Load Model.
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
|
||||
#include <dmlc/parameter.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/enum_class_param.h>
|
||||
#include <xgboost/parameter.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
|
||||
@@ -322,6 +322,9 @@ class Json {
|
||||
static void Dump(Json json, std::ostream* stream,
|
||||
bool pretty = ConsoleLogger::ShouldLog(
|
||||
ConsoleLogger::LogVerbosity::kDebug));
|
||||
static void Dump(Json json, std::string* out,
|
||||
bool pretty = ConsoleLogger::ShouldLog(
|
||||
ConsoleLogger::LogVerbosity::kDebug));
|
||||
|
||||
Json() : ptr_{new JsonNull} {}
|
||||
|
||||
@@ -400,6 +403,13 @@ class Json {
|
||||
return *ptr_ == *(rhs.ptr_);
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, Json const& j) {
|
||||
std::string str;
|
||||
Json::Dump(j, &str);
|
||||
os << str;
|
||||
return os;
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<Value> ptr_;
|
||||
};
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include <xgboost/objective.h>
|
||||
#include <xgboost/feature_map.h>
|
||||
#include <xgboost/generic_parameters.h>
|
||||
#include <xgboost/model.h>
|
||||
|
||||
#include <utility>
|
||||
#include <map>
|
||||
@@ -41,7 +42,7 @@ namespace xgboost {
|
||||
*
|
||||
* \endcode
|
||||
*/
|
||||
class Learner : public rabit::Serializable {
|
||||
class Learner : public Model, public rabit::Serializable {
|
||||
public:
|
||||
/*! \brief virtual destructor */
|
||||
~Learner() override = default;
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
/*!
|
||||
* Copyright (c) 2015-2019 by Contributors
|
||||
* \file logging.h
|
||||
* \brief defines console logging options for xgboost.
|
||||
* Use to enforce unified print behavior.
|
||||
* For debug loggers, use LOG(INFO) and LOG(ERROR).
|
||||
*
|
||||
* \brief defines console logging options for xgboost. Use to enforce unified print
|
||||
* behavior.
|
||||
*/
|
||||
#ifndef XGBOOST_LOGGING_H_
|
||||
#define XGBOOST_LOGGING_H_
|
||||
|
||||
44
include/xgboost/model.h
Normal file
44
include/xgboost/model.h
Normal file
@@ -0,0 +1,44 @@
|
||||
/*!
|
||||
* Copyright (c) 2019 by Contributors
|
||||
* \file model.h
|
||||
* \brief Defines the abstract interface for different components in XGBoost.
|
||||
*/
|
||||
#ifndef XGBOOST_MODEL_H_
|
||||
#define XGBOOST_MODEL_H_
|
||||
|
||||
namespace dmlc {
|
||||
class Stream;
|
||||
} // namespace dmlc
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
class Json;
|
||||
|
||||
struct Model {
|
||||
/*!
|
||||
* \brief Save the model to stream.
|
||||
* \param fo output write stream
|
||||
*/
|
||||
virtual void SaveModel(dmlc::Stream* fo) const = 0;
|
||||
/*!
|
||||
* \brief Load the model from stream.
|
||||
* \param fi input read stream
|
||||
*/
|
||||
virtual void LoadModel(dmlc::Stream* fi) = 0;
|
||||
};
|
||||
|
||||
struct Configurable {
|
||||
/*!
|
||||
* \brief Load configuration from JSON object
|
||||
* \param in JSON object containing the configuration
|
||||
*/
|
||||
virtual void LoadConfig(Json const& in) = 0;
|
||||
/*!
|
||||
* \brief Save configuration to JSON object
|
||||
* \param out pointer to output JSON object
|
||||
*/
|
||||
virtual void SaveConfig(Json* out) const = 0;
|
||||
};
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_MODEL_H_
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <dmlc/registry.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/model.h>
|
||||
#include <xgboost/generic_parameters.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
|
||||
@@ -21,7 +22,7 @@
|
||||
namespace xgboost {
|
||||
|
||||
/*! \brief interface of objective function */
|
||||
class ObjFunction {
|
||||
class ObjFunction : public Configurable {
|
||||
protected:
|
||||
GenericParameter const* tparam_;
|
||||
|
||||
|
||||
@@ -5,10 +5,11 @@
|
||||
* \author Hyunsu Philip Cho
|
||||
*/
|
||||
|
||||
#ifndef XGBOOST_ENUM_CLASS_PARAM_H_
|
||||
#define XGBOOST_ENUM_CLASS_PARAM_H_
|
||||
#ifndef XGBOOST_PARAMETER_H_
|
||||
#define XGBOOST_PARAMETER_H_
|
||||
|
||||
#include <dmlc/parameter.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
@@ -78,4 +79,27 @@ class FieldEntry<EnumClass> : public FieldEntry<int> { \
|
||||
} /* namespace parameter */ \
|
||||
} /* namespace dmlc */
|
||||
|
||||
#endif // XGBOOST_ENUM_CLASS_PARAM_H_
|
||||
namespace xgboost {
|
||||
template <typename Type>
|
||||
struct XGBoostParameter : public dmlc::Parameter<Type> {
|
||||
protected:
|
||||
bool initialised_ {false};
|
||||
|
||||
public:
|
||||
template <typename Container>
|
||||
Args UpdateAllowUnknown(Container const& kwargs, bool* out_changed = nullptr) {
|
||||
if (initialised_) {
|
||||
return dmlc::Parameter<Type>::UpdateAllowUnknown(kwargs, out_changed);
|
||||
} else {
|
||||
auto unknown = dmlc::Parameter<Type>::InitAllowUnknown(kwargs);
|
||||
if (out_changed) {
|
||||
*out_changed = true;
|
||||
}
|
||||
initialised_ = true;
|
||||
return unknown;
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_PARAMETER_H_
|
||||
@@ -14,6 +14,7 @@
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/feature_map.h>
|
||||
#include <xgboost/model.h>
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
@@ -93,7 +94,7 @@ struct RTreeNodeStat {
|
||||
* \brief define regression tree to be the most common tree model.
|
||||
* This is the data structure used in xgboost's major tree models.
|
||||
*/
|
||||
class RegTree {
|
||||
class RegTree : public Model {
|
||||
public:
|
||||
/*! \brief auxiliary statistics of node to help tree building */
|
||||
using SplitCondT = bst_float;
|
||||
@@ -289,38 +290,17 @@ class RegTree {
|
||||
const RTreeNodeStat& Stat(int nid) const {
|
||||
return stats_[nid];
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief load model from stream
|
||||
* \param fi input stream
|
||||
*/
|
||||
void Load(dmlc::Stream* fi) {
|
||||
CHECK_EQ(fi->Read(¶m, sizeof(TreeParam)), sizeof(TreeParam));
|
||||
nodes_.resize(param.num_nodes);
|
||||
stats_.resize(param.num_nodes);
|
||||
CHECK_NE(param.num_nodes, 0);
|
||||
CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size()),
|
||||
sizeof(Node) * nodes_.size());
|
||||
CHECK_EQ(fi->Read(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * stats_.size()),
|
||||
sizeof(RTreeNodeStat) * stats_.size());
|
||||
// chg deleted nodes
|
||||
deleted_nodes_.resize(0);
|
||||
for (int i = param.num_roots; i < param.num_nodes; ++i) {
|
||||
if (nodes_[i].IsDeleted()) deleted_nodes_.push_back(i);
|
||||
}
|
||||
CHECK_EQ(static_cast<int>(deleted_nodes_.size()), param.num_deleted);
|
||||
}
|
||||
void LoadModel(dmlc::Stream* fi) override;
|
||||
/*!
|
||||
* \brief save model to stream
|
||||
* \param fo output stream
|
||||
*/
|
||||
void Save(dmlc::Stream* fo) const {
|
||||
CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size()));
|
||||
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
|
||||
fo->Write(¶m, sizeof(TreeParam));
|
||||
CHECK_NE(param.num_nodes, 0);
|
||||
fo->Write(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size());
|
||||
fo->Write(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * nodes_.size());
|
||||
}
|
||||
void SaveModel(dmlc::Stream* fo) const override;
|
||||
|
||||
bool operator==(const RegTree& b) const {
|
||||
return nodes_ == b.nodes_ && stats_ == b.stats_ &&
|
||||
|
||||
Reference in New Issue
Block a user