Add Model and Configurable interface. (#4945)
* Apply Configurable to objective functions. * Apply Model to Learner and Regtree, gbm. * Add Load/SaveConfig to objs. * Refactor obj tests to use smart pointer. * Dummy methods for Save/Load Model.
This commit is contained in:
@@ -718,5 +718,13 @@ void Json::Dump(Json json, std::ostream *stream, bool pretty) {
|
||||
writer.Save(json);
|
||||
}
|
||||
|
||||
void Json::Dump(Json json, std::string* str, bool pretty) {
|
||||
GlobalCLocale guard;
|
||||
std::stringstream ss;
|
||||
JsonWriter writer(&ss, pretty);
|
||||
writer.Save(json);
|
||||
*str = ss.str();
|
||||
}
|
||||
|
||||
Json& Json::operator=(Json const &other) = default;
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <dmlc/parameter.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/feature_map.h>
|
||||
#include <xgboost/model.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
@@ -34,7 +35,7 @@ struct GBLinearModelParam : public dmlc::Parameter<GBLinearModelParam> {
|
||||
};
|
||||
|
||||
// model for linear booster
|
||||
class GBLinearModel {
|
||||
class GBLinearModel : public Model {
|
||||
public:
|
||||
// parameter
|
||||
GBLinearModelParam param;
|
||||
@@ -57,6 +58,17 @@ class GBLinearModel {
|
||||
CHECK_EQ(fi->Read(¶m, sizeof(param)), sizeof(param));
|
||||
fi->Read(&weight);
|
||||
}
|
||||
|
||||
void LoadModel(dmlc::Stream* fi) override {
|
||||
// They are the same right now until we can split up the saved parameter from model.
|
||||
this->Load(fi);
|
||||
}
|
||||
|
||||
void SaveModel(dmlc::Stream* fo) const override {
|
||||
// They are the same right now until we can split up the saved parameter from model.
|
||||
this->Save(fo);
|
||||
}
|
||||
|
||||
// model bias
|
||||
inline bst_float* bias() {
|
||||
return &weight[param.num_feature * param.num_output_group];
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#include <xgboost/gbm.h>
|
||||
#include <xgboost/predictor.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <xgboost/enum_class_param.h>
|
||||
#include <xgboost/parameter.h>
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
#include <dmlc/parameter.h>
|
||||
#include <dmlc/io.h>
|
||||
#include <xgboost/model.h>
|
||||
#include <xgboost/tree_model.h>
|
||||
|
||||
#include <memory>
|
||||
@@ -61,7 +62,7 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
|
||||
}
|
||||
};
|
||||
|
||||
struct GBTreeModel {
|
||||
struct GBTreeModel : public Model {
|
||||
explicit GBTreeModel(bst_float base_margin) : base_margin(base_margin) {}
|
||||
void Configure(const Args& cfg) {
|
||||
// initialize model parameters if not yet been initialized.
|
||||
@@ -81,6 +82,15 @@ struct GBTreeModel {
|
||||
}
|
||||
}
|
||||
|
||||
void LoadModel(dmlc::Stream* fi) override {
|
||||
// They are the same right now until we can split up the saved parameter from model.
|
||||
this->Load(fi);
|
||||
}
|
||||
void SaveModel(dmlc::Stream* fo) const override {
|
||||
// They are the same right now until we can split up the saved parameter from model.
|
||||
this->Save(fo);
|
||||
}
|
||||
|
||||
void Load(dmlc::Stream* fi) {
|
||||
CHECK_EQ(fi->Read(¶m, sizeof(param)), sizeof(param))
|
||||
<< "GBTree: invalid model file";
|
||||
@@ -88,7 +98,7 @@ struct GBTreeModel {
|
||||
trees_to_update.clear();
|
||||
for (int i = 0; i < param.num_trees; ++i) {
|
||||
std::unique_ptr<RegTree> ptr(new RegTree());
|
||||
ptr->Load(fi);
|
||||
ptr->LoadModel(fi);
|
||||
trees.push_back(std::move(ptr));
|
||||
}
|
||||
tree_info.resize(param.num_trees);
|
||||
@@ -103,7 +113,7 @@ struct GBTreeModel {
|
||||
CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
|
||||
fo->Write(¶m, sizeof(param));
|
||||
for (const auto & tree : trees) {
|
||||
tree->Save(fo);
|
||||
tree->SaveModel(fo);
|
||||
}
|
||||
if (tree_info.size() != 0) {
|
||||
fo->Write(dmlc::BeginPtr(tree_info), sizeof(int) * tree_info.size());
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <dmlc/any.h>
|
||||
#include <xgboost/feature_map.h>
|
||||
#include <xgboost/learner.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/generic_parameters.h>
|
||||
#include <algorithm>
|
||||
@@ -196,6 +197,16 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
}
|
||||
|
||||
void LoadModel(dmlc::Stream* fi) override {
|
||||
// They are the same right now until we can split up the saved parameter from model.
|
||||
this->Load(fi);
|
||||
}
|
||||
|
||||
void SaveModel(dmlc::Stream* fo) const override {
|
||||
// They are the same right now until we can split up the saved parameter from model.
|
||||
this->Save(fo);
|
||||
}
|
||||
|
||||
void Load(dmlc::Stream* fi) override {
|
||||
generic_param_.InitAllowUnknown(Args{});
|
||||
tparam_.Init(std::vector<std::pair<std::string, std::string>>{});
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
/*!
|
||||
* Copyright 2018 by Contributors
|
||||
* Copyright 2018-2019 by Contributors
|
||||
* \file hinge.cc
|
||||
* \brief Provides an implementation of the hinge loss function
|
||||
* \author Henry Gouk
|
||||
*/
|
||||
#include "xgboost/objective.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
@@ -76,6 +77,12 @@ class HingeObj : public ObjFunction {
|
||||
const char* DefaultEvalMetric() const override {
|
||||
return "error";
|
||||
}
|
||||
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["name"] = String("binary:hinge");
|
||||
}
|
||||
void LoadConfig(Json const& in) override {}
|
||||
};
|
||||
|
||||
// register the objective functions
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
|
||||
#include "xgboost/json.h"
|
||||
#include "../common/common.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/transform.h"
|
||||
@@ -25,7 +26,7 @@ namespace obj {
|
||||
DMLC_REGISTRY_FILE_TAG(multiclass_obj_gpu);
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
|
||||
struct SoftmaxMultiClassParam : public dmlc::Parameter<SoftmaxMultiClassParam> {
|
||||
struct SoftmaxMultiClassParam : public XGBoostParameter<SoftmaxMultiClassParam> {
|
||||
int num_class;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(SoftmaxMultiClassParam) {
|
||||
@@ -37,10 +38,10 @@ struct SoftmaxMultiClassParam : public dmlc::Parameter<SoftmaxMultiClassParam> {
|
||||
class SoftmaxMultiClassObj : public ObjFunction {
|
||||
public:
|
||||
explicit SoftmaxMultiClassObj(bool output_prob)
|
||||
: output_prob_(output_prob) {
|
||||
}
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
: output_prob_(output_prob) {}
|
||||
|
||||
void Configure(Args const& args) override {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
}
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo& info,
|
||||
@@ -155,6 +156,20 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
||||
}
|
||||
}
|
||||
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
if (this->output_prob_) {
|
||||
out["name"] = String("multi:softprob");
|
||||
} else {
|
||||
out["name"] = String("multi:softmax");
|
||||
}
|
||||
out["softmax_multiclass_param"] = toJson(param_);
|
||||
}
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
fromJson(in["softmax_multiclass_param"], ¶m_);
|
||||
}
|
||||
|
||||
private:
|
||||
// output probability
|
||||
bool output_prob_;
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
#include <xgboost/objective.h>
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
namespace dmlc {
|
||||
@@ -17,10 +19,12 @@ namespace xgboost {
|
||||
ObjFunction* ObjFunction::Create(const std::string& name, GenericParameter const* tparam) {
|
||||
auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(name);
|
||||
if (e == nullptr) {
|
||||
std::stringstream ss;
|
||||
for (const auto& entry : ::dmlc::Registry< ::xgboost::ObjFunctionReg>::List()) {
|
||||
LOG(INFO) << "Objective candidate: " << entry->name;
|
||||
ss << "Objective candidate: " << entry->name << "\n";
|
||||
}
|
||||
LOG(FATAL) << "Unknown objective function " << name;
|
||||
LOG(FATAL) << "Unknown objective function: `" << name << "`\n"
|
||||
<< ss.str();
|
||||
}
|
||||
auto pobj = (e->body)();
|
||||
pobj->tparam_ = tparam;
|
||||
|
||||
@@ -10,6 +10,10 @@
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/parameter.h"
|
||||
|
||||
#include "../common/math.h"
|
||||
#include "../common/random.h"
|
||||
|
||||
@@ -18,7 +22,7 @@ namespace obj {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(rank_obj);
|
||||
|
||||
struct LambdaRankParam : public dmlc::Parameter<LambdaRankParam> {
|
||||
struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
|
||||
int num_pairsample;
|
||||
float fix_list_weight;
|
||||
// declare parameters
|
||||
@@ -35,7 +39,7 @@ struct LambdaRankParam : public dmlc::Parameter<LambdaRankParam> {
|
||||
class LambdaRankObj : public ObjFunction {
|
||||
public:
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
param_.UpdateAllowUnknown(args);
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
@@ -170,7 +174,16 @@ class LambdaRankObj : public ObjFunction {
|
||||
virtual void GetLambdaWeight(const std::vector<ListEntry> &sorted_list,
|
||||
std::vector<LambdaPair> *io_pairs) = 0;
|
||||
|
||||
private:
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["name"] = String("LambdaRankObj");
|
||||
out["lambda_rank_param"] = Object();
|
||||
for (auto const& kv : param_.__DICT__()) {
|
||||
out["lambda_rank_param"][kv.first] = kv.second;
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
LambdaRankParam param_;
|
||||
};
|
||||
|
||||
@@ -178,6 +191,15 @@ class PairwiseRankObj: public LambdaRankObj{
|
||||
protected:
|
||||
void GetLambdaWeight(const std::vector<ListEntry> &sorted_list,
|
||||
std::vector<LambdaPair> *io_pairs) override {}
|
||||
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["name"] = String("rank:pairwise");
|
||||
out["lambda_rank_param"] = toJson(LambdaRankObj::param_);
|
||||
}
|
||||
void LoadConfig(Json const& in) override {
|
||||
fromJson(in["lambda_rank_param"], &(LambdaRankObj::param_));
|
||||
}
|
||||
};
|
||||
|
||||
// beta version: NDCG lambda rank
|
||||
@@ -228,6 +250,14 @@ class LambdaRankObjNDCG : public LambdaRankObj {
|
||||
}
|
||||
return static_cast<bst_float>(sumdcg);
|
||||
}
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["name"] = String("rank:ndcg");
|
||||
out["lambda_rank_param"] = toJson(LambdaRankObj::param_);
|
||||
}
|
||||
void LoadConfig(Json const& in) override {
|
||||
fromJson(in["lambda_rank_param"], &(LambdaRankObj::param_));
|
||||
}
|
||||
};
|
||||
|
||||
class LambdaRankObjMAP : public LambdaRankObj {
|
||||
@@ -315,6 +345,15 @@ class LambdaRankObjMAP : public LambdaRankObj {
|
||||
pair.neg_index, &map_stats);
|
||||
}
|
||||
}
|
||||
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["name"] = String("rank:map");
|
||||
out["lambda_rank_param"] = toJson(LambdaRankObj::param_);
|
||||
}
|
||||
void LoadConfig(Json const& in) override {
|
||||
fromJson(in["lambda_rank_param"], &(LambdaRankObj::param_));
|
||||
}
|
||||
};
|
||||
|
||||
// register the objective functions
|
||||
|
||||
@@ -34,6 +34,8 @@ struct LinearSquareLoss {
|
||||
static bst_float ProbToMargin(bst_float base_score) { return base_score; }
|
||||
static const char* LabelErrorMsg() { return ""; }
|
||||
static const char* DefaultEvalMetric() { return "rmse"; }
|
||||
|
||||
static const char* Name() { return "reg:squarederror"; }
|
||||
};
|
||||
|
||||
struct SquaredLogError {
|
||||
@@ -57,6 +59,8 @@ struct SquaredLogError {
|
||||
return "label must be greater than -1 for rmsle so that log(label + 1) can be valid.";
|
||||
}
|
||||
static const char* DefaultEvalMetric() { return "rmsle"; }
|
||||
|
||||
static const char* Name() { return "reg:squaredlogerror"; }
|
||||
};
|
||||
|
||||
// logistic loss for probability regression task
|
||||
@@ -83,18 +87,21 @@ struct LogisticRegression {
|
||||
}
|
||||
static bst_float ProbToMargin(bst_float base_score) {
|
||||
CHECK(base_score > 0.0f && base_score < 1.0f)
|
||||
<< "base_score must be in (0,1) for logistic loss";
|
||||
<< "base_score must be in (0,1) for logistic loss, got: " << base_score;
|
||||
return -logf(1.0f / base_score - 1.0f);
|
||||
}
|
||||
static const char* LabelErrorMsg() {
|
||||
return "label must be in [0,1] for logistic regression";
|
||||
}
|
||||
static const char* DefaultEvalMetric() { return "rmse"; }
|
||||
|
||||
static const char* Name() { return "reg:logistic"; }
|
||||
};
|
||||
|
||||
// logistic loss for binary classification task
|
||||
struct LogisticClassification : public LogisticRegression {
|
||||
static const char* DefaultEvalMetric() { return "error"; }
|
||||
static const char* Name() { return "binary:logistic"; }
|
||||
};
|
||||
|
||||
// logistic loss, but predict un-transformed margin
|
||||
@@ -125,6 +132,8 @@ struct LogisticRaw : public LogisticRegression {
|
||||
return std::max(predt * (T(1.0f) - predt), eps);
|
||||
}
|
||||
static const char* DefaultEvalMetric() { return "auc"; }
|
||||
|
||||
static const char* Name() { return "binary:logitraw"; }
|
||||
};
|
||||
|
||||
} // namespace obj
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/parameter.h"
|
||||
#include "xgboost/span.h"
|
||||
|
||||
#include "../common/transform.h"
|
||||
#include "../common/common.h"
|
||||
@@ -27,7 +29,7 @@ namespace obj {
|
||||
DMLC_REGISTRY_FILE_TAG(regression_obj_gpu);
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
|
||||
struct RegLossParam : public dmlc::Parameter<RegLossParam> {
|
||||
struct RegLossParam : public XGBoostParameter<RegLossParam> {
|
||||
float scale_pos_weight;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(RegLossParam) {
|
||||
@@ -45,7 +47,7 @@ class RegLossObj : public ObjFunction {
|
||||
RegLossObj() = default;
|
||||
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
param_.UpdateAllowUnknown(args);
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
@@ -114,6 +116,16 @@ class RegLossObj : public ObjFunction {
|
||||
return Loss::ProbToMargin(base_score);
|
||||
}
|
||||
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["name"] = String(Loss::Name());
|
||||
out["reg_loss_param"] = toJson(param_);
|
||||
}
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
fromJson(in["reg_loss_param"], ¶m_);
|
||||
}
|
||||
|
||||
protected:
|
||||
RegLossParam param_;
|
||||
};
|
||||
@@ -121,23 +133,23 @@ class RegLossObj : public ObjFunction {
|
||||
// register the objective functions
|
||||
DMLC_REGISTER_PARAMETER(RegLossParam);
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegression, "reg:squarederror")
|
||||
XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegression, LinearSquareLoss::Name())
|
||||
.describe("Regression with squared error.")
|
||||
.set_body([]() { return new RegLossObj<LinearSquareLoss>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(SquareLogError, "reg:squaredlogerror")
|
||||
XGBOOST_REGISTER_OBJECTIVE(SquareLogError, SquaredLogError::Name())
|
||||
.describe("Regression with root mean squared logarithmic error.")
|
||||
.set_body([]() { return new RegLossObj<SquaredLogError>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(LogisticRegression, "reg:logistic")
|
||||
XGBOOST_REGISTER_OBJECTIVE(LogisticRegression, LogisticRegression::Name())
|
||||
.describe("Logistic regression for probability regression task.")
|
||||
.set_body([]() { return new RegLossObj<LogisticRegression>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(LogisticClassification, "binary:logistic")
|
||||
XGBOOST_REGISTER_OBJECTIVE(LogisticClassification, LogisticClassification::Name())
|
||||
.describe("Logistic regression for binary classification task.")
|
||||
.set_body([]() { return new RegLossObj<LogisticClassification>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, "binary:logitraw")
|
||||
XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, LogisticRaw::Name())
|
||||
.describe("Logistic regression for classification, output score "
|
||||
"before logistic transformation.")
|
||||
.set_body([]() { return new RegLossObj<LogisticRaw>(); });
|
||||
@@ -151,7 +163,7 @@ XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear")
|
||||
// End deprecated
|
||||
|
||||
// declare parameter
|
||||
struct PoissonRegressionParam : public dmlc::Parameter<PoissonRegressionParam> {
|
||||
struct PoissonRegressionParam : public XGBoostParameter<PoissonRegressionParam> {
|
||||
float max_delta_step;
|
||||
DMLC_DECLARE_PARAMETER(PoissonRegressionParam) {
|
||||
DMLC_DECLARE_FIELD(max_delta_step).set_lower_bound(0.0f).set_default(0.7f)
|
||||
@@ -165,7 +177,7 @@ class PoissonRegression : public ObjFunction {
|
||||
public:
|
||||
// declare functions
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
param_.UpdateAllowUnknown(args);
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
@@ -227,6 +239,16 @@ class PoissonRegression : public ObjFunction {
|
||||
return "poisson-nloglik";
|
||||
}
|
||||
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["name"] = String("count:poisson");
|
||||
out["poisson_regression_param"] = toJson(param_);
|
||||
}
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
fromJson(in["poisson_regression_param"], ¶m_);
|
||||
}
|
||||
|
||||
private:
|
||||
PoissonRegressionParam param_;
|
||||
HostDeviceVector<int> label_correct_;
|
||||
@@ -321,6 +343,12 @@ class CoxRegression : public ObjFunction {
|
||||
const char* DefaultEvalMetric() const override {
|
||||
return "cox-nloglik";
|
||||
}
|
||||
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["name"] = String("survival:cox");
|
||||
}
|
||||
void LoadConfig(Json const&) override {}
|
||||
};
|
||||
|
||||
// register the objective function
|
||||
@@ -391,6 +419,11 @@ class GammaRegression : public ObjFunction {
|
||||
const char* DefaultEvalMetric() const override {
|
||||
return "gamma-nloglik";
|
||||
}
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["name"] = String("reg:gamma");
|
||||
}
|
||||
void LoadConfig(Json const&) override {}
|
||||
|
||||
private:
|
||||
HostDeviceVector<int> label_correct_;
|
||||
@@ -403,7 +436,7 @@ XGBOOST_REGISTER_OBJECTIVE(GammaRegression, "reg:gamma")
|
||||
|
||||
|
||||
// declare parameter
|
||||
struct TweedieRegressionParam : public dmlc::Parameter<TweedieRegressionParam> {
|
||||
struct TweedieRegressionParam : public XGBoostParameter<TweedieRegressionParam> {
|
||||
float tweedie_variance_power;
|
||||
DMLC_DECLARE_PARAMETER(TweedieRegressionParam) {
|
||||
DMLC_DECLARE_FIELD(tweedie_variance_power).set_range(1.0f, 2.0f).set_default(1.5f)
|
||||
@@ -416,7 +449,7 @@ class TweedieRegression : public ObjFunction {
|
||||
public:
|
||||
// declare functions
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
param_.UpdateAllowUnknown(args);
|
||||
std::ostringstream os;
|
||||
os << "tweedie-nloglik@" << param_.tweedie_variance_power;
|
||||
metric_ = os.str();
|
||||
@@ -485,6 +518,15 @@ class TweedieRegression : public ObjFunction {
|
||||
return metric_.c_str();
|
||||
}
|
||||
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["name"] = String("reg:tweedie");
|
||||
out["tweedie_regression_param"] = toJson(param_);
|
||||
}
|
||||
void LoadConfig(Json const& in) override {
|
||||
fromJson(in["tweedie_regression_param"], ¶m_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::string metric_;
|
||||
TweedieRegressionParam param_;
|
||||
|
||||
@@ -617,6 +617,35 @@ std::string RegTree::DumpModel(const FeatureMap& fmap,
|
||||
return result;
|
||||
}
|
||||
|
||||
void RegTree::LoadModel(dmlc::Stream* fi) {
|
||||
CHECK_EQ(fi->Read(¶m, sizeof(TreeParam)), sizeof(TreeParam));
|
||||
nodes_.resize(param.num_nodes);
|
||||
stats_.resize(param.num_nodes);
|
||||
CHECK_NE(param.num_nodes, 0);
|
||||
CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size()),
|
||||
sizeof(Node) * nodes_.size());
|
||||
CHECK_EQ(fi->Read(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * stats_.size()),
|
||||
sizeof(RTreeNodeStat) * stats_.size());
|
||||
// chg deleted nodes
|
||||
deleted_nodes_.resize(0);
|
||||
for (int i = param.num_roots; i < param.num_nodes; ++i) {
|
||||
if (nodes_[i].IsDeleted()) deleted_nodes_.push_back(i);
|
||||
}
|
||||
CHECK_EQ(static_cast<int>(deleted_nodes_.size()), param.num_deleted);
|
||||
}
|
||||
/*!
|
||||
* \brief save model to stream
|
||||
* \param fo output stream
|
||||
*/
|
||||
void RegTree::SaveModel(dmlc::Stream* fo) const {
|
||||
CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size()));
|
||||
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
|
||||
fo->Write(¶m, sizeof(TreeParam));
|
||||
CHECK_NE(param.num_nodes, 0);
|
||||
fo->Write(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size());
|
||||
fo->Write(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * nodes_.size());
|
||||
}
|
||||
|
||||
void RegTree::FillNodeMeanValues() {
|
||||
size_t num_nodes = this->param.num_nodes;
|
||||
if (this->node_mean_values_.size() == num_nodes) {
|
||||
|
||||
@@ -1053,12 +1053,12 @@ class GPUHistMakerSpecialised {
|
||||
common::MemoryBufferStream fs(&s_model);
|
||||
int rank = rabit::GetRank();
|
||||
if (rank == 0) {
|
||||
local_trees.front().Save(&fs);
|
||||
local_trees.front().SaveModel(&fs);
|
||||
}
|
||||
fs.Seek(0);
|
||||
rabit::Broadcast(&s_model, 0);
|
||||
RegTree reference_tree{};
|
||||
reference_tree.Load(&fs);
|
||||
reference_tree.LoadModel(&fs);
|
||||
for (const auto& tree : local_trees) {
|
||||
CHECK(tree == reference_tree);
|
||||
}
|
||||
|
||||
@@ -35,13 +35,13 @@ class TreeSyncher: public TreeUpdater {
|
||||
int rank = rabit::GetRank();
|
||||
if (rank == 0) {
|
||||
for (auto tree : trees) {
|
||||
tree->Save(&fs);
|
||||
tree->SaveModel(&fs);
|
||||
}
|
||||
}
|
||||
fs.Seek(0);
|
||||
rabit::Broadcast(&s_model, 0);
|
||||
for (auto tree : trees) {
|
||||
tree->Load(&fs);
|
||||
tree->LoadModel(&fs);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user