Implement JSON IO for updaters (#5094)
* Implement JSON IO for updaters. * Remove parameters in split evaluator.
This commit is contained in:
@@ -69,10 +69,8 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
|
||||
// whether refresh updater needs to update the leaf values
|
||||
bool refresh_leaf;
|
||||
|
||||
// FIXME(trivialfis): Following constraints are used by gpu
|
||||
// algorithm, duplicated with those defined split evaluator due to
|
||||
// their different code paths.
|
||||
std::vector<int> monotone_constraints;
|
||||
// Stored as a JSON string.
|
||||
std::string interaction_constraints;
|
||||
|
||||
// the criteria to use for ranking splits
|
||||
|
||||
@@ -46,7 +46,7 @@ SplitEvaluator* SplitEvaluator::Create(const std::string& name) {
|
||||
}
|
||||
|
||||
// Default implementations of some virtual methods that aren't always needed
|
||||
void SplitEvaluator::Init(const Args& args) {}
|
||||
void SplitEvaluator::Init(const TrainParam* param) {}
|
||||
void SplitEvaluator::Reset() {}
|
||||
void SplitEvaluator::AddSplit(bst_uint nodeid,
|
||||
bst_uint leftid,
|
||||
@@ -64,36 +64,6 @@ bst_float SplitEvaluator::ComputeSplitScore(bst_uint nodeid,
|
||||
return ComputeSplitScore(nodeid, featureid, left_stats, right_stats, left_weight, right_weight);
|
||||
}
|
||||
|
||||
//! \brief Encapsulates the parameters for ElasticNet
|
||||
struct ElasticNetParams : public XGBoostParameter<ElasticNetParams> {
|
||||
bst_float reg_lambda;
|
||||
bst_float reg_alpha;
|
||||
// maximum delta update we can add in weight estimation
|
||||
// this parameter can be used to stabilize update
|
||||
// default=0 means no constraint on weight delta
|
||||
float max_delta_step;
|
||||
|
||||
DMLC_DECLARE_PARAMETER(ElasticNetParams) {
|
||||
DMLC_DECLARE_FIELD(reg_lambda)
|
||||
.set_lower_bound(0.0)
|
||||
.set_default(1.0)
|
||||
.describe("L2 regularization on leaf weight");
|
||||
DMLC_DECLARE_FIELD(reg_alpha)
|
||||
.set_lower_bound(0.0)
|
||||
.set_default(0.0)
|
||||
.describe("L1 regularization on leaf weight");
|
||||
DMLC_DECLARE_FIELD(max_delta_step)
|
||||
.set_lower_bound(0.0f)
|
||||
.set_default(0.0f)
|
||||
.describe("Maximum delta step we allow each tree's weight estimate to be. "\
|
||||
"If the value is set to 0, it means there is no constraint");
|
||||
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
|
||||
DMLC_DECLARE_ALIAS(reg_alpha, alpha);
|
||||
}
|
||||
};
|
||||
|
||||
DMLC_REGISTER_PARAMETER(ElasticNetParams);
|
||||
|
||||
/*! \brief Applies an elastic net penalty and per-leaf penalty. */
|
||||
class ElasticNet final : public SplitEvaluator {
|
||||
public:
|
||||
@@ -102,13 +72,14 @@ class ElasticNet final : public SplitEvaluator {
|
||||
LOG(FATAL) << "ElasticNet does not accept an inner SplitEvaluator";
|
||||
}
|
||||
}
|
||||
void Init(const Args& args) override {
|
||||
params_.UpdateAllowUnknown(args);
|
||||
void Init(const TrainParam* param) override {
|
||||
params_ = param;
|
||||
}
|
||||
|
||||
SplitEvaluator* GetHostClone() const override {
|
||||
auto r = new ElasticNet(nullptr);
|
||||
r->params_ = this->params_;
|
||||
CHECK(r->params_);
|
||||
|
||||
return r;
|
||||
}
|
||||
@@ -133,14 +104,14 @@ class ElasticNet final : public SplitEvaluator {
|
||||
bst_float ComputeScore(bst_uint parentID, const GradStats &stats, bst_float weight)
|
||||
const override {
|
||||
auto loss = weight * (2.0 * stats.sum_grad + stats.sum_hess * weight
|
||||
+ params_.reg_lambda * weight)
|
||||
+ 2.0 * params_.reg_alpha * std::abs(weight);
|
||||
+ params_->reg_lambda * weight)
|
||||
+ 2.0 * params_->reg_alpha * std::abs(weight);
|
||||
return -loss;
|
||||
}
|
||||
|
||||
bst_float ComputeScore(bst_uint parentID, const GradStats &stats) const {
|
||||
if (params_.max_delta_step == 0.0f) {
|
||||
return Sqr(ThresholdL1(stats.sum_grad)) / (stats.sum_hess + params_.reg_lambda);
|
||||
if (params_->max_delta_step == 0.0f) {
|
||||
return Sqr(ThresholdL1(stats.sum_grad)) / (stats.sum_hess + params_->reg_lambda);
|
||||
} else {
|
||||
return ComputeScore(parentID, stats, ComputeWeight(parentID, stats));
|
||||
}
|
||||
@@ -148,21 +119,21 @@ class ElasticNet final : public SplitEvaluator {
|
||||
|
||||
bst_float ComputeWeight(bst_uint parentID, const GradStats& stats)
|
||||
const override {
|
||||
bst_float w = -ThresholdL1(stats.sum_grad) / (stats.sum_hess + params_.reg_lambda);
|
||||
if (params_.max_delta_step != 0.0f && std::abs(w) > params_.max_delta_step) {
|
||||
w = std::copysign(params_.max_delta_step, w);
|
||||
bst_float w = -ThresholdL1(stats.sum_grad) / (stats.sum_hess + params_->reg_lambda);
|
||||
if (params_->max_delta_step != 0.0f && std::abs(w) > params_->max_delta_step) {
|
||||
w = std::copysign(params_->max_delta_step, w);
|
||||
}
|
||||
return w;
|
||||
}
|
||||
|
||||
private:
|
||||
ElasticNetParams params_;
|
||||
TrainParam const* params_;
|
||||
|
||||
inline double ThresholdL1(double g) const {
|
||||
if (g > params_.reg_alpha) {
|
||||
return g - params_.reg_alpha;
|
||||
} else if (g < -params_.reg_alpha) {
|
||||
return g + params_.reg_alpha;
|
||||
if (g > params_->reg_alpha) {
|
||||
return g - params_->reg_alpha;
|
||||
} else if (g < -params_->reg_alpha) {
|
||||
return g + params_->reg_alpha;
|
||||
} else {
|
||||
return 0.0;
|
||||
}
|
||||
@@ -175,22 +146,6 @@ XGBOOST_REGISTER_SPLIT_EVALUATOR(ElasticNet, "elastic_net")
|
||||
return new ElasticNet(std::move(inner));
|
||||
});
|
||||
|
||||
/*! \brief Encapsulates the parameters required by the MonotonicConstraint
|
||||
split evaluator
|
||||
*/
|
||||
struct MonotonicConstraintParams
|
||||
: public XGBoostParameter<MonotonicConstraintParams> {
|
||||
std::vector<bst_int> monotone_constraints;
|
||||
|
||||
DMLC_DECLARE_PARAMETER(MonotonicConstraintParams) {
|
||||
DMLC_DECLARE_FIELD(monotone_constraints)
|
||||
.set_default(std::vector<bst_int>())
|
||||
.describe("Constraint of variable monotonicity");
|
||||
}
|
||||
};
|
||||
|
||||
DMLC_REGISTER_PARAMETER(MonotonicConstraintParams);
|
||||
|
||||
/*! \brief Enforces that the tree is monotonically increasing/decreasing with respect to a user specified set of
|
||||
features.
|
||||
*/
|
||||
@@ -203,10 +158,9 @@ class MonotonicConstraint final : public SplitEvaluator {
|
||||
inner_ = std::move(inner);
|
||||
}
|
||||
|
||||
void Init(const Args& args)
|
||||
override {
|
||||
inner_->Init(args);
|
||||
params_.UpdateAllowUnknown(args);
|
||||
void Init(const TrainParam* param) override {
|
||||
inner_->Init(param);
|
||||
params_ = param;
|
||||
Reset();
|
||||
}
|
||||
|
||||
@@ -216,13 +170,14 @@ class MonotonicConstraint final : public SplitEvaluator {
|
||||
}
|
||||
|
||||
SplitEvaluator* GetHostClone() const override {
|
||||
if (params_.monotone_constraints.size() == 0) {
|
||||
if (params_->monotone_constraints.size() == 0) {
|
||||
// No monotone constraints specified, just return a clone of inner to speed things up
|
||||
return inner_->GetHostClone();
|
||||
} else {
|
||||
auto c = new MonotonicConstraint(
|
||||
std::unique_ptr<SplitEvaluator>(inner_->GetHostClone()));
|
||||
c->params_ = this->params_;
|
||||
CHECK(c->params_);
|
||||
c->Reset();
|
||||
return c;
|
||||
}
|
||||
@@ -300,14 +255,14 @@ class MonotonicConstraint final : public SplitEvaluator {
|
||||
}
|
||||
|
||||
private:
|
||||
MonotonicConstraintParams params_;
|
||||
TrainParam const* params_;
|
||||
std::unique_ptr<SplitEvaluator> inner_;
|
||||
std::vector<bst_float> lower_;
|
||||
std::vector<bst_float> upper_;
|
||||
|
||||
inline bst_int GetConstraint(bst_uint featureid) const {
|
||||
if (featureid < params_.monotone_constraints.size()) {
|
||||
return params_.monotone_constraints[featureid];
|
||||
if (featureid < params_->monotone_constraints.size()) {
|
||||
return params_->monotone_constraints[featureid];
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "param.h"
|
||||
|
||||
#define ROOT_PARENT_ID (-1 & ((1U << 31) - 1))
|
||||
|
||||
namespace xgboost {
|
||||
@@ -32,7 +34,7 @@ class SplitEvaluator {
|
||||
virtual ~SplitEvaluator() = default;
|
||||
|
||||
// Used to initialise any regularisation hyperparameters provided by the user
|
||||
virtual void Init(const Args& args);
|
||||
virtual void Init(const TrainParam* param);
|
||||
|
||||
// Resets the SplitEvaluator to the state it was in after the Init was called
|
||||
virtual void Reset();
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include <utility>
|
||||
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/tree_updater.h"
|
||||
#include "param.h"
|
||||
#include "constraints.h"
|
||||
@@ -37,6 +38,15 @@ class BaseMaker: public TreeUpdater {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
}
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
fromJson(config.at("train_param"), &this->param_);
|
||||
}
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["train_param"] = toJson(param_);
|
||||
}
|
||||
|
||||
protected:
|
||||
// helper to collect and query feature meta information
|
||||
struct FMetaHelper {
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
|
||||
#include "xgboost/json.h"
|
||||
#include "param.h"
|
||||
#include "constraints.h"
|
||||
#include "../common/random.h"
|
||||
@@ -28,8 +29,19 @@ class ColMaker: public TreeUpdater {
|
||||
public:
|
||||
void Configure(const Args& args) override {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator));
|
||||
spliteval_->Init(args);
|
||||
if (!spliteval_) {
|
||||
spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator));
|
||||
}
|
||||
spliteval_->Init(&param_);
|
||||
}
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
fromJson(config.at("train_param"), &this->param_);
|
||||
}
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["train_param"] = toJson(param_);
|
||||
}
|
||||
|
||||
char const* Name() const override {
|
||||
@@ -705,7 +717,7 @@ class DistColMaker : public ColMaker {
|
||||
pruner_.reset(TreeUpdater::Create("prune", tparam_));
|
||||
pruner_->Configure(args);
|
||||
spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator));
|
||||
spliteval_->Init(args);
|
||||
spliteval_->Init(&param_);
|
||||
}
|
||||
|
||||
char const* Name() const override {
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/parameter.h"
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/json.h"
|
||||
|
||||
#include "../common/common.h"
|
||||
#include "../common/compressed_iterator.h"
|
||||
@@ -1028,7 +1029,6 @@ class GPUHistMakerSpecialised {
|
||||
hist_maker_param_.UpdateAllowUnknown(args);
|
||||
device_ = generic_param_->gpu_id;
|
||||
CHECK_GE(device_, 0) << "Must have at least one device";
|
||||
|
||||
dh::CheckComputeCapability();
|
||||
|
||||
monitor_.Init("updater_gpu_hist");
|
||||
@@ -1129,8 +1129,7 @@ class GPUHistMakerSpecialised {
|
||||
maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(
|
||||
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
|
||||
bool UpdatePredictionCache(const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
|
||||
if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
|
||||
return false;
|
||||
}
|
||||
@@ -1141,8 +1140,8 @@ class GPUHistMakerSpecialised {
|
||||
return true;
|
||||
}
|
||||
|
||||
TrainParam param_; // NOLINT
|
||||
MetaInfo* info_{}; // NOLINT
|
||||
TrainParam param_; // NOLINT
|
||||
MetaInfo* info_{}; // NOLINT
|
||||
|
||||
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT
|
||||
|
||||
@@ -1175,6 +1174,27 @@ class GPUHistMaker : public TreeUpdater {
|
||||
}
|
||||
}
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
fromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
|
||||
if (hist_maker_param_.single_precision_histogram) {
|
||||
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>());
|
||||
fromJson(config.at("train_param"), &float_maker_->param_);
|
||||
} else {
|
||||
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>());
|
||||
fromJson(config.at("train_param"), &double_maker_->param_);
|
||||
}
|
||||
}
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["gpu_hist_train_param"] = toJson(hist_maker_param_);
|
||||
if (hist_maker_param_.single_precision_histogram) {
|
||||
out["train_param"] = toJson(float_maker_->param_);
|
||||
} else {
|
||||
out["train_param"] = toJson(double_maker_->param_);
|
||||
}
|
||||
}
|
||||
|
||||
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) override {
|
||||
if (hist_maker_param_.single_precision_histogram) {
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "xgboost/json.h"
|
||||
#include "./param.h"
|
||||
#include "../common/io.h"
|
||||
|
||||
@@ -33,6 +34,16 @@ class TreePruner: public TreeUpdater {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
syncher_->Configure(args);
|
||||
}
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
fromJson(config.at("train_param"), &this->param_);
|
||||
}
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["train_param"] = toJson(param_);
|
||||
}
|
||||
|
||||
// update the tree, do pruning
|
||||
void Update(HostDeviceVector<GradientPair> *gpair,
|
||||
DMatrix *p_fmat,
|
||||
|
||||
@@ -6,8 +6,6 @@
|
||||
*/
|
||||
#include <dmlc/timer.h>
|
||||
#include <rabit/rabit.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
@@ -19,10 +17,13 @@
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "./param.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/tree_updater.h"
|
||||
|
||||
#include "constraints.h"
|
||||
#include "param.h"
|
||||
#include "./updater_quantile_hist.h"
|
||||
#include "./split_evaluator.h"
|
||||
#include "constraints.h"
|
||||
#include "../common/random.h"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/row_set.h"
|
||||
@@ -47,21 +48,19 @@ void QuantileHistMaker::Configure(const Args& args) {
|
||||
spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator));
|
||||
}
|
||||
|
||||
spliteval_->Init(args);
|
||||
spliteval_->Init(&param_);
|
||||
}
|
||||
|
||||
void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
|
||||
DMatrix *dmat,
|
||||
const std::vector<RegTree *> &trees) {
|
||||
if (is_gmat_initialized_ == false) {
|
||||
double tstart = dmlc::GetTime();
|
||||
gmat_.Init(dmat, static_cast<uint32_t>(param_.max_bin));
|
||||
column_matrix_.Init(gmat_, param_.sparse_threshold);
|
||||
if (param_.enable_feature_grouping > 0) {
|
||||
gmatb_.Init(gmat_, column_matrix_, param_);
|
||||
}
|
||||
is_gmat_initialized_ = true;
|
||||
LOG(INFO) << "Generating gmat: " << dmlc::GetTime() - tstart << " sec";
|
||||
}
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param_.learning_rate;
|
||||
@@ -386,7 +385,6 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "xgboost/json.h"
|
||||
#include "constraints.h"
|
||||
#include "./param.h"
|
||||
#include "./split_evaluator.h"
|
||||
@@ -79,6 +80,7 @@ using xgboost::common::Column;
|
||||
/*! \brief construct a tree using quantized feature values */
|
||||
class QuantileHistMaker: public TreeUpdater {
|
||||
public:
|
||||
QuantileHistMaker() : is_gmat_initialized_{ false } {}
|
||||
void Configure(const Args& args) override;
|
||||
|
||||
void Update(HostDeviceVector<GradientPair>* gpair,
|
||||
@@ -88,6 +90,15 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* out_preds) override;
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
fromJson(config.at("train_param"), &this->param_);
|
||||
}
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["train_param"] = toJson(param_);
|
||||
}
|
||||
|
||||
char const* Name() const override {
|
||||
return "grow_quantile_histmaker";
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
|
||||
#include "xgboost/json.h"
|
||||
#include "./param.h"
|
||||
#include "../common/io.h"
|
||||
|
||||
@@ -24,6 +25,14 @@ class TreeRefresher: public TreeUpdater {
|
||||
void Configure(const Args& args) override {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
}
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
fromJson(config.at("train_param"), &this->param_);
|
||||
}
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["train_param"] = toJson(param_);
|
||||
}
|
||||
char const* Name() const override {
|
||||
return "refresh";
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* Copyright 2014-2019 by Contributors
|
||||
* \file updater_sync.cc
|
||||
* \brief synchronize the tree in all distributed nodes
|
||||
*/
|
||||
@@ -7,6 +7,8 @@
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <limits>
|
||||
|
||||
#include "xgboost/json.h"
|
||||
#include "../common/io.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -22,6 +24,9 @@ class TreeSyncher: public TreeUpdater {
|
||||
public:
|
||||
void Configure(const Args& args) override {}
|
||||
|
||||
void LoadConfig(Json const& in) override {}
|
||||
void SaveConfig(Json* p_out) const override {}
|
||||
|
||||
char const* Name() const override {
|
||||
return "prune";
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user