Implement JSON IO for updaters (#5094)

* Implement JSON IO for updaters.

* Remove parameters in split evaluator.
This commit is contained in:
Jiaming Yuan 2019-12-07 00:24:00 +08:00 committed by GitHub
parent 2dcb62ddfb
commit 7ef5b78003
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 145 additions and 92 deletions

View File

@ -14,6 +14,7 @@
#include <xgboost/tree_model.h> #include <xgboost/tree_model.h>
#include <xgboost/generic_parameters.h> #include <xgboost/generic_parameters.h>
#include <xgboost/host_device_vector.h> #include <xgboost/host_device_vector.h>
#include <xgboost/model.h>
#include <functional> #include <functional>
#include <vector> #include <vector>
@ -24,7 +25,7 @@ namespace xgboost {
/*! /*!
* \brief interface of tree update module, that performs update of a tree. * \brief interface of tree update module, that performs update of a tree.
*/ */
class TreeUpdater { class TreeUpdater : public Configurable {
protected: protected:
GenericParameter const* tparam_; GenericParameter const* tparam_;

View File

@ -69,10 +69,8 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
// whether refresh updater needs to update the leaf values // whether refresh updater needs to update the leaf values
bool refresh_leaf; bool refresh_leaf;
// FIXME(trivialfis): Following constraints are used by gpu
// algorithm, duplicated with those defined split evaluator due to
// their different code paths.
std::vector<int> monotone_constraints; std::vector<int> monotone_constraints;
// Stored as a JSON string.
std::string interaction_constraints; std::string interaction_constraints;
// the criteria to use for ranking splits // the criteria to use for ranking splits

View File

@ -46,7 +46,7 @@ SplitEvaluator* SplitEvaluator::Create(const std::string& name) {
} }
// Default implementations of some virtual methods that aren't always needed // Default implementations of some virtual methods that aren't always needed
void SplitEvaluator::Init(const Args& args) {} void SplitEvaluator::Init(const TrainParam* param) {}
void SplitEvaluator::Reset() {} void SplitEvaluator::Reset() {}
void SplitEvaluator::AddSplit(bst_uint nodeid, void SplitEvaluator::AddSplit(bst_uint nodeid,
bst_uint leftid, bst_uint leftid,
@ -64,36 +64,6 @@ bst_float SplitEvaluator::ComputeSplitScore(bst_uint nodeid,
return ComputeSplitScore(nodeid, featureid, left_stats, right_stats, left_weight, right_weight); return ComputeSplitScore(nodeid, featureid, left_stats, right_stats, left_weight, right_weight);
} }
//! \brief Encapsulates the parameters for ElasticNet
struct ElasticNetParams : public XGBoostParameter<ElasticNetParams> {
bst_float reg_lambda;
bst_float reg_alpha;
// maximum delta update we can add in weight estimation
// this parameter can be used to stabilize update
// default=0 means no constraint on weight delta
float max_delta_step;
DMLC_DECLARE_PARAMETER(ElasticNetParams) {
DMLC_DECLARE_FIELD(reg_lambda)
.set_lower_bound(0.0)
.set_default(1.0)
.describe("L2 regularization on leaf weight");
DMLC_DECLARE_FIELD(reg_alpha)
.set_lower_bound(0.0)
.set_default(0.0)
.describe("L1 regularization on leaf weight");
DMLC_DECLARE_FIELD(max_delta_step)
.set_lower_bound(0.0f)
.set_default(0.0f)
.describe("Maximum delta step we allow each tree's weight estimate to be. "\
"If the value is set to 0, it means there is no constraint");
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
DMLC_DECLARE_ALIAS(reg_alpha, alpha);
}
};
DMLC_REGISTER_PARAMETER(ElasticNetParams);
/*! \brief Applies an elastic net penalty and per-leaf penalty. */ /*! \brief Applies an elastic net penalty and per-leaf penalty. */
class ElasticNet final : public SplitEvaluator { class ElasticNet final : public SplitEvaluator {
public: public:
@ -102,13 +72,14 @@ class ElasticNet final : public SplitEvaluator {
LOG(FATAL) << "ElasticNet does not accept an inner SplitEvaluator"; LOG(FATAL) << "ElasticNet does not accept an inner SplitEvaluator";
} }
} }
void Init(const Args& args) override { void Init(const TrainParam* param) override {
params_.UpdateAllowUnknown(args); params_ = param;
} }
SplitEvaluator* GetHostClone() const override { SplitEvaluator* GetHostClone() const override {
auto r = new ElasticNet(nullptr); auto r = new ElasticNet(nullptr);
r->params_ = this->params_; r->params_ = this->params_;
CHECK(r->params_);
return r; return r;
} }
@ -133,14 +104,14 @@ class ElasticNet final : public SplitEvaluator {
bst_float ComputeScore(bst_uint parentID, const GradStats &stats, bst_float weight) bst_float ComputeScore(bst_uint parentID, const GradStats &stats, bst_float weight)
const override { const override {
auto loss = weight * (2.0 * stats.sum_grad + stats.sum_hess * weight auto loss = weight * (2.0 * stats.sum_grad + stats.sum_hess * weight
+ params_.reg_lambda * weight) + params_->reg_lambda * weight)
+ 2.0 * params_.reg_alpha * std::abs(weight); + 2.0 * params_->reg_alpha * std::abs(weight);
return -loss; return -loss;
} }
bst_float ComputeScore(bst_uint parentID, const GradStats &stats) const { bst_float ComputeScore(bst_uint parentID, const GradStats &stats) const {
if (params_.max_delta_step == 0.0f) { if (params_->max_delta_step == 0.0f) {
return Sqr(ThresholdL1(stats.sum_grad)) / (stats.sum_hess + params_.reg_lambda); return Sqr(ThresholdL1(stats.sum_grad)) / (stats.sum_hess + params_->reg_lambda);
} else { } else {
return ComputeScore(parentID, stats, ComputeWeight(parentID, stats)); return ComputeScore(parentID, stats, ComputeWeight(parentID, stats));
} }
@ -148,21 +119,21 @@ class ElasticNet final : public SplitEvaluator {
bst_float ComputeWeight(bst_uint parentID, const GradStats& stats) bst_float ComputeWeight(bst_uint parentID, const GradStats& stats)
const override { const override {
bst_float w = -ThresholdL1(stats.sum_grad) / (stats.sum_hess + params_.reg_lambda); bst_float w = -ThresholdL1(stats.sum_grad) / (stats.sum_hess + params_->reg_lambda);
if (params_.max_delta_step != 0.0f && std::abs(w) > params_.max_delta_step) { if (params_->max_delta_step != 0.0f && std::abs(w) > params_->max_delta_step) {
w = std::copysign(params_.max_delta_step, w); w = std::copysign(params_->max_delta_step, w);
} }
return w; return w;
} }
private: private:
ElasticNetParams params_; TrainParam const* params_;
inline double ThresholdL1(double g) const { inline double ThresholdL1(double g) const {
if (g > params_.reg_alpha) { if (g > params_->reg_alpha) {
return g - params_.reg_alpha; return g - params_->reg_alpha;
} else if (g < -params_.reg_alpha) { } else if (g < -params_->reg_alpha) {
return g + params_.reg_alpha; return g + params_->reg_alpha;
} else { } else {
return 0.0; return 0.0;
} }
@ -175,22 +146,6 @@ XGBOOST_REGISTER_SPLIT_EVALUATOR(ElasticNet, "elastic_net")
return new ElasticNet(std::move(inner)); return new ElasticNet(std::move(inner));
}); });
/*! \brief Encapsulates the parameters required by the MonotonicConstraint
split evaluator
*/
struct MonotonicConstraintParams
: public XGBoostParameter<MonotonicConstraintParams> {
std::vector<bst_int> monotone_constraints;
DMLC_DECLARE_PARAMETER(MonotonicConstraintParams) {
DMLC_DECLARE_FIELD(monotone_constraints)
.set_default(std::vector<bst_int>())
.describe("Constraint of variable monotonicity");
}
};
DMLC_REGISTER_PARAMETER(MonotonicConstraintParams);
/*! \brief Enforces that the tree is monotonically increasing/decreasing with respect to a user specified set of /*! \brief Enforces that the tree is monotonically increasing/decreasing with respect to a user specified set of
features. features.
*/ */
@ -203,10 +158,9 @@ class MonotonicConstraint final : public SplitEvaluator {
inner_ = std::move(inner); inner_ = std::move(inner);
} }
void Init(const Args& args) void Init(const TrainParam* param) override {
override { inner_->Init(param);
inner_->Init(args); params_ = param;
params_.UpdateAllowUnknown(args);
Reset(); Reset();
} }
@ -216,13 +170,14 @@ class MonotonicConstraint final : public SplitEvaluator {
} }
SplitEvaluator* GetHostClone() const override { SplitEvaluator* GetHostClone() const override {
if (params_.monotone_constraints.size() == 0) { if (params_->monotone_constraints.size() == 0) {
// No monotone constraints specified, just return a clone of inner to speed things up // No monotone constraints specified, just return a clone of inner to speed things up
return inner_->GetHostClone(); return inner_->GetHostClone();
} else { } else {
auto c = new MonotonicConstraint( auto c = new MonotonicConstraint(
std::unique_ptr<SplitEvaluator>(inner_->GetHostClone())); std::unique_ptr<SplitEvaluator>(inner_->GetHostClone()));
c->params_ = this->params_; c->params_ = this->params_;
CHECK(c->params_);
c->Reset(); c->Reset();
return c; return c;
} }
@ -300,14 +255,14 @@ class MonotonicConstraint final : public SplitEvaluator {
} }
private: private:
MonotonicConstraintParams params_; TrainParam const* params_;
std::unique_ptr<SplitEvaluator> inner_; std::unique_ptr<SplitEvaluator> inner_;
std::vector<bst_float> lower_; std::vector<bst_float> lower_;
std::vector<bst_float> upper_; std::vector<bst_float> upper_;
inline bst_int GetConstraint(bst_uint featureid) const { inline bst_int GetConstraint(bst_uint featureid) const {
if (featureid < params_.monotone_constraints.size()) { if (featureid < params_->monotone_constraints.size()) {
return params_.monotone_constraints[featureid]; return params_->monotone_constraints[featureid];
} else { } else {
return 0; return 0;
} }

View File

@ -16,6 +16,8 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "param.h"
#define ROOT_PARENT_ID (-1 & ((1U << 31) - 1)) #define ROOT_PARENT_ID (-1 & ((1U << 31) - 1))
namespace xgboost { namespace xgboost {
@ -32,7 +34,7 @@ class SplitEvaluator {
virtual ~SplitEvaluator() = default; virtual ~SplitEvaluator() = default;
// Used to initialise any regularisation hyperparameters provided by the user // Used to initialise any regularisation hyperparameters provided by the user
virtual void Init(const Args& args); virtual void Init(const TrainParam* param);
// Resets the SplitEvaluator to the state it was in after the Init was called // Resets the SplitEvaluator to the state it was in after the Init was called
virtual void Reset(); virtual void Reset();

View File

@ -17,6 +17,7 @@
#include <utility> #include <utility>
#include "xgboost/base.h" #include "xgboost/base.h"
#include "xgboost/json.h"
#include "xgboost/tree_updater.h" #include "xgboost/tree_updater.h"
#include "param.h" #include "param.h"
#include "constraints.h" #include "constraints.h"
@ -37,6 +38,15 @@ class BaseMaker: public TreeUpdater {
param_.UpdateAllowUnknown(args); param_.UpdateAllowUnknown(args);
} }
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
fromJson(config.at("train_param"), &this->param_);
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["train_param"] = toJson(param_);
}
protected: protected:
// helper to collect and query feature meta information // helper to collect and query feature meta information
struct FMetaHelper { struct FMetaHelper {

View File

@ -12,6 +12,7 @@
#include <cmath> #include <cmath>
#include <algorithm> #include <algorithm>
#include "xgboost/json.h"
#include "param.h" #include "param.h"
#include "constraints.h" #include "constraints.h"
#include "../common/random.h" #include "../common/random.h"
@ -28,8 +29,19 @@ class ColMaker: public TreeUpdater {
public: public:
void Configure(const Args& args) override { void Configure(const Args& args) override {
param_.UpdateAllowUnknown(args); param_.UpdateAllowUnknown(args);
spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator)); if (!spliteval_) {
spliteval_->Init(args); spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator));
}
spliteval_->Init(&param_);
}
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
fromJson(config.at("train_param"), &this->param_);
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["train_param"] = toJson(param_);
} }
char const* Name() const override { char const* Name() const override {
@ -705,7 +717,7 @@ class DistColMaker : public ColMaker {
pruner_.reset(TreeUpdater::Create("prune", tparam_)); pruner_.reset(TreeUpdater::Create("prune", tparam_));
pruner_->Configure(args); pruner_->Configure(args);
spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator)); spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator));
spliteval_->Init(args); spliteval_->Init(&param_);
} }
char const* Name() const override { char const* Name() const override {

View File

@ -18,6 +18,7 @@
#include "xgboost/host_device_vector.h" #include "xgboost/host_device_vector.h"
#include "xgboost/parameter.h" #include "xgboost/parameter.h"
#include "xgboost/span.h" #include "xgboost/span.h"
#include "xgboost/json.h"
#include "../common/common.h" #include "../common/common.h"
#include "../common/compressed_iterator.h" #include "../common/compressed_iterator.h"
@ -1028,7 +1029,6 @@ class GPUHistMakerSpecialised {
hist_maker_param_.UpdateAllowUnknown(args); hist_maker_param_.UpdateAllowUnknown(args);
device_ = generic_param_->gpu_id; device_ = generic_param_->gpu_id;
CHECK_GE(device_, 0) << "Must have at least one device"; CHECK_GE(device_, 0) << "Must have at least one device";
dh::CheckComputeCapability(); dh::CheckComputeCapability();
monitor_.Init("updater_gpu_hist"); monitor_.Init("updater_gpu_hist");
@ -1129,8 +1129,7 @@ class GPUHistMakerSpecialised {
maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_); maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
} }
bool UpdatePredictionCache( bool UpdatePredictionCache(const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) { if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
return false; return false;
} }
@ -1141,8 +1140,8 @@ class GPUHistMakerSpecialised {
return true; return true;
} }
TrainParam param_; // NOLINT TrainParam param_; // NOLINT
MetaInfo* info_{}; // NOLINT MetaInfo* info_{}; // NOLINT
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT
@ -1175,6 +1174,27 @@ class GPUHistMaker : public TreeUpdater {
} }
} }
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
fromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
if (hist_maker_param_.single_precision_histogram) {
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>());
fromJson(config.at("train_param"), &float_maker_->param_);
} else {
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>());
fromJson(config.at("train_param"), &double_maker_->param_);
}
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["gpu_hist_train_param"] = toJson(hist_maker_param_);
if (hist_maker_param_.single_precision_histogram) {
out["train_param"] = toJson(float_maker_->param_);
} else {
out["train_param"] = toJson(double_maker_->param_);
}
}
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat, void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override { const std::vector<RegTree*>& trees) override {
if (hist_maker_param_.single_precision_histogram) { if (hist_maker_param_.single_precision_histogram) {

View File

@ -10,6 +10,7 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include "xgboost/json.h"
#include "./param.h" #include "./param.h"
#include "../common/io.h" #include "../common/io.h"
@ -33,6 +34,16 @@ class TreePruner: public TreeUpdater {
param_.UpdateAllowUnknown(args); param_.UpdateAllowUnknown(args);
syncher_->Configure(args); syncher_->Configure(args);
} }
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
fromJson(config.at("train_param"), &this->param_);
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["train_param"] = toJson(param_);
}
// update the tree, do pruning // update the tree, do pruning
void Update(HostDeviceVector<GradientPair> *gpair, void Update(HostDeviceVector<GradientPair> *gpair,
DMatrix *p_fmat, DMatrix *p_fmat,

View File

@ -6,8 +6,6 @@
*/ */
#include <dmlc/timer.h> #include <dmlc/timer.h>
#include <rabit/rabit.h> #include <rabit/rabit.h>
#include <xgboost/logging.h>
#include <xgboost/tree_updater.h>
#include <cmath> #include <cmath>
#include <memory> #include <memory>
@ -19,10 +17,13 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include "./param.h" #include "xgboost/logging.h"
#include "xgboost/tree_updater.h"
#include "constraints.h"
#include "param.h"
#include "./updater_quantile_hist.h" #include "./updater_quantile_hist.h"
#include "./split_evaluator.h" #include "./split_evaluator.h"
#include "constraints.h"
#include "../common/random.h" #include "../common/random.h"
#include "../common/hist_util.h" #include "../common/hist_util.h"
#include "../common/row_set.h" #include "../common/row_set.h"
@ -47,21 +48,19 @@ void QuantileHistMaker::Configure(const Args& args) {
spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator)); spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator));
} }
spliteval_->Init(args); spliteval_->Init(&param_);
} }
void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair, void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
DMatrix *dmat, DMatrix *dmat,
const std::vector<RegTree *> &trees) { const std::vector<RegTree *> &trees) {
if (is_gmat_initialized_ == false) { if (is_gmat_initialized_ == false) {
double tstart = dmlc::GetTime();
gmat_.Init(dmat, static_cast<uint32_t>(param_.max_bin)); gmat_.Init(dmat, static_cast<uint32_t>(param_.max_bin));
column_matrix_.Init(gmat_, param_.sparse_threshold); column_matrix_.Init(gmat_, param_.sparse_threshold);
if (param_.enable_feature_grouping > 0) { if (param_.enable_feature_grouping > 0) {
gmatb_.Init(gmat_, column_matrix_, param_); gmatb_.Init(gmat_, column_matrix_, param_);
} }
is_gmat_initialized_ = true; is_gmat_initialized_ = true;
LOG(INFO) << "Generating gmat: " << dmlc::GetTime() - tstart << " sec";
} }
// rescale learning rate according to size of trees // rescale learning rate according to size of trees
float lr = param_.learning_rate; float lr = param_.learning_rate;
@ -386,7 +385,6 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache(
} }
} }
} }
return true; return true;
} }

View File

@ -19,6 +19,7 @@
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include "xgboost/json.h"
#include "constraints.h" #include "constraints.h"
#include "./param.h" #include "./param.h"
#include "./split_evaluator.h" #include "./split_evaluator.h"
@ -79,6 +80,7 @@ using xgboost::common::Column;
/*! \brief construct a tree using quantized feature values */ /*! \brief construct a tree using quantized feature values */
class QuantileHistMaker: public TreeUpdater { class QuantileHistMaker: public TreeUpdater {
public: public:
QuantileHistMaker() : is_gmat_initialized_{ false } {}
void Configure(const Args& args) override; void Configure(const Args& args) override;
void Update(HostDeviceVector<GradientPair>* gpair, void Update(HostDeviceVector<GradientPair>* gpair,
@ -88,6 +90,15 @@ class QuantileHistMaker: public TreeUpdater {
bool UpdatePredictionCache(const DMatrix* data, bool UpdatePredictionCache(const DMatrix* data,
HostDeviceVector<bst_float>* out_preds) override; HostDeviceVector<bst_float>* out_preds) override;
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
fromJson(config.at("train_param"), &this->param_);
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["train_param"] = toJson(param_);
}
char const* Name() const override { char const* Name() const override {
return "grow_quantile_histmaker"; return "grow_quantile_histmaker";
} }

View File

@ -10,6 +10,7 @@
#include <vector> #include <vector>
#include <limits> #include <limits>
#include "xgboost/json.h"
#include "./param.h" #include "./param.h"
#include "../common/io.h" #include "../common/io.h"
@ -24,6 +25,14 @@ class TreeRefresher: public TreeUpdater {
void Configure(const Args& args) override { void Configure(const Args& args) override {
param_.UpdateAllowUnknown(args); param_.UpdateAllowUnknown(args);
} }
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
fromJson(config.at("train_param"), &this->param_);
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["train_param"] = toJson(param_);
}
char const* Name() const override { char const* Name() const override {
return "refresh"; return "refresh";
} }

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2014 by Contributors * Copyright 2014-2019 by Contributors
* \file updater_sync.cc * \file updater_sync.cc
* \brief synchronize the tree in all distributed nodes * \brief synchronize the tree in all distributed nodes
*/ */
@ -7,6 +7,8 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include <limits> #include <limits>
#include "xgboost/json.h"
#include "../common/io.h" #include "../common/io.h"
namespace xgboost { namespace xgboost {
@ -22,6 +24,9 @@ class TreeSyncher: public TreeUpdater {
public: public:
void Configure(const Args& args) override {} void Configure(const Args& args) override {}
void LoadConfig(Json const& in) override {}
void SaveConfig(Json* p_out) const override {}
char const* Name() const override { char const* Name() const override {
return "prune"; return "prune";
} }

View File

@ -11,6 +11,7 @@
#include "../helpers.h" #include "../helpers.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "xgboost/json.h"
#include "../../../src/data/sparse_page_source.h" #include "../../../src/data/sparse_page_source.h"
#include "../../../src/gbm/gbtree_model.h" #include "../../../src/gbm/gbtree_model.h"
#include "../../../src/tree/updater_gpu_hist.cu" #include "../../../src/tree/updater_gpu_hist.cu"
@ -424,5 +425,24 @@ TEST(GpuHist, ExternalMemory) {
} }
} }
TEST(GpuHist, Config_IO) {
GenericParameter generic_param(CreateEmptyGenericParam(0));
std::unique_ptr<TreeUpdater> updater {TreeUpdater::Create("grow_gpu_hist", &generic_param) };
updater->Configure(Args{});
Json j_updater { Object() };
updater->SaveConfig(&j_updater);
ASSERT_TRUE(IsA<Object>(j_updater["gpu_hist_train_param"]));
ASSERT_TRUE(IsA<Object>(j_updater["train_param"]));
updater->LoadConfig(j_updater);
Json j_updater_roundtrip { Object() };
updater->SaveConfig(&j_updater_roundtrip);
ASSERT_TRUE(IsA<Object>(j_updater_roundtrip["gpu_hist_train_param"]));
ASSERT_TRUE(IsA<Object>(j_updater_roundtrip["train_param"]));
ASSERT_EQ(j_updater, j_updater_roundtrip);
}
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost

View File

@ -162,7 +162,7 @@ class QuantileHistMock : public QuantileHistMaker {
} }
// Initialize split evaluator // Initialize split evaluator
std::unique_ptr<SplitEvaluator> evaluator(SplitEvaluator::Create("elastic_net")); std::unique_ptr<SplitEvaluator> evaluator(SplitEvaluator::Create("elastic_net"));
evaluator->Init({}); evaluator->Init(&param_);
// Now enumerate all feature*threshold combination to get best split // Now enumerate all feature*threshold combination to get best split
// To simplify logic, we make some assumptions: // To simplify logic, we make some assumptions:
@ -235,6 +235,7 @@ class QuantileHistMock : public QuantileHistMaker {
const std::vector<std::pair<std::string, std::string> >& args) : const std::vector<std::pair<std::string, std::string> >& args) :
cfg_{args} { cfg_{args} {
QuantileHistMaker::Configure(args); QuantileHistMaker::Configure(args);
spliteval_->Init(&param_);
builder_.reset( builder_.reset(
new BuilderMock( new BuilderMock(
param_, param_,