diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h
index 09c16eff6..3f842f422 100644
--- a/include/xgboost/learner.h
+++ b/include/xgboost/learner.h
@@ -11,15 +11,16 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
+#include
+#include
-#include
 #include
 #include
 #include
+#include
 #include

 namespace xgboost {
@@ -307,11 +308,13 @@ struct LearnerModelParam {
   uint32_t num_feature { 0 };
   /* \brief number of classes, if it is multi-class classification */
   uint32_t num_output_group { 0 };
+  /* \brief Current task, determined by objective. */
+  ObjInfo task{ObjInfo::kRegression};

   LearnerModelParam() = default;
   // As the old `LearnerModelParamLegacy` is still used by binary IO, we keep
   // this one as an immutable copy.
-  LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin);
+  LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin, ObjInfo t);
   /* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */
   bool Initialized() const { return num_feature != 0; }
 };
diff --git a/include/xgboost/objective.h b/include/xgboost/objective.h
index 3e722a18f..3cf85c41d 100644
--- a/include/xgboost/objective.h
+++ b/include/xgboost/objective.h
@@ -13,6 +13,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -72,6 +73,11 @@ class ObjFunction : public Configurable {
   virtual bst_float ProbToMargin(bst_float base_score) const {
     return base_score;
   }
+  /*!
+   * \brief Return task of this objective.
+   */
+  virtual struct ObjInfo Task() const = 0;
+
   /*!
    * \brief Create an objective function according to name.
    * \param tparam Generic parameters.
diff --git a/include/xgboost/task.h b/include/xgboost/task.h
new file mode 100644
index 000000000..6430794c3
--- /dev/null
+++ b/include/xgboost/task.h
@@ -0,0 +1,39 @@
+/*!
+ * Copyright 2021 by XGBoost Contributors
+ */
+#ifndef XGBOOST_TASK_H_
+#define XGBOOST_TASK_H_
+
+#include
+
+namespace xgboost {
+/*!
+ * \brief A struct returned by objective, which determines task at hand.  The struct is
+ *        not used by any algorithm yet, only for future development like categorical
+ *        split.
+ *
+ * The task field is useful for tree split finding, also for some metrics like auc.
+ * Lastly, knowing whether hessian is constant can allow some optimizations like skipping
+ * the quantile sketching.
+ *
+ * This struct should not be serialized since it can be recovered from objective function,
+ * hence it doesn't need to be stable.
+ */
+struct ObjInfo {
+  // What kind of problem are we trying to solve
+  enum Task : uint8_t {
+    kRegression = 0,
+    kBinary = 1,
+    kClassification = 2,
+    kSurvival = 3,
+    kRanking = 4,
+    kOther = 5,
+  } task;
+  // Does the objective have constant hessian value?
+  bool const_hess{false};
+
+  explicit ObjInfo(Task t) : task{t} {}
+  ObjInfo(Task t, bool khess) : task{t}, const_hess{khess} {}
+};
+}  // namespace xgboost
+#endif  // XGBOOST_TASK_H_
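Note: nothing in this patch consumes ObjInfo yet; the struct is only threaded through the constructors below so that later work can use it. As a rough sketch of the intended use described in the header comment above (the helper functions here are hypothetical, not code from this change), an updater holding an ObjInfo could branch on it like this:

    // Hypothetical illustration only -- not part of this patch.
    #include "xgboost/task.h"

    namespace xgboost {
    // With a constant hessian (e.g. reg:squarederror) the weighted quantile sketch
    // degenerates to an unweighted one, so it could be built once and reused instead
    // of being recomputed from the hessian every iteration.
    inline bool CanReuseSketch(ObjInfo const& info) { return info.const_hess; }

    // A split finder or metric could special-case classification-style tasks,
    // e.g. for the categorical splits and AUC mentioned in the comment above.
    inline bool IsClassificationTask(ObjInfo const& info) {
      return info.task == ObjInfo::kBinary || info.task == ObjInfo::kClassification;
    }
    }  // namespace xgboost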
diff --git a/include/xgboost/tree_updater.h b/include/xgboost/tree_updater.h
index f36005a9a..8e1a5bdb3 100644
--- a/include/xgboost/tree_updater.h
+++ b/include/xgboost/tree_updater.h
@@ -11,16 +11,17 @@
 #include
 #include
 #include
-#include
 #include
 #include
-#include
 #include
+#include
+#include
+#include
 #include
-#include
-#include
 #include
+#include
+#include

 namespace xgboost {
@@ -83,7 +84,7 @@ class TreeUpdater : public Configurable {
    * \param name Name of the tree updater.
    * \param tparam A global runtime parameter
    */
-  static TreeUpdater* Create(const std::string& name, GenericParameter const* tparam);
+  static TreeUpdater* Create(const std::string& name, GenericParameter const* tparam, ObjInfo task);
 };

 /*!
@@ -91,8 +92,7 @@ class TreeUpdater : public Configurable {
  */
 struct TreeUpdaterReg
     : public dmlc::FunctionRegEntryBase<TreeUpdaterReg,
-                                        std::function<TreeUpdater* ()> > {
-};
+                                        std::function<TreeUpdater* (ObjInfo task)> > {};

 /*!
  * \brief Macro to register tree updater.
diff --git a/plugin/example/custom_obj.cc b/plugin/example/custom_obj.cc
index a18d8aecc..c38ad4fbd 100644
--- a/plugin/example/custom_obj.cc
+++ b/plugin/example/custom_obj.cc
@@ -34,6 +34,11 @@ class MyLogistic : public ObjFunction {
   void Configure(const std::vector >& args) override {
     param_.UpdateAllowUnknown(args);
   }
+
+  struct ObjInfo Task() const override {
+    return {ObjInfo::kRegression, false};
+  }
+
   void GetGradient(const HostDeviceVector &preds,
                    const MetaInfo &info,
                    int iter,
diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc
index 859e5ba9d..500140ce6 100644
--- a/src/gbm/gbtree.cc
+++ b/src/gbm/gbtree.cc
@@ -306,7 +306,8 @@ void GBTree::InitUpdater(Args const& cfg) {
   // create new updaters
   for (const std::string& pstr : ups) {
-    std::unique_ptr<TreeUpdater> up(TreeUpdater::Create(pstr.c_str(), generic_param_));
+    std::unique_ptr<TreeUpdater> up(
+        TreeUpdater::Create(pstr.c_str(), generic_param_, model_.learner_model_param->task));
     up->Configure(cfg);
     updaters_.push_back(std::move(up));
   }
@@ -391,7 +392,8 @@ void GBTree::LoadConfig(Json const& in) {
   auto const& j_updaters = get(in["updater"]);
   updaters_.clear();
   for (auto const& kv : j_updaters) {
-    std::unique_ptr<TreeUpdater> up(TreeUpdater::Create(kv.first, generic_param_));
+    std::unique_ptr<TreeUpdater> up(
+        TreeUpdater::Create(kv.first, generic_param_, model_.learner_model_param->task));
     up->LoadConfig(kv.second);
     updaters_.push_back(std::move(up));
   }
diff --git a/src/learner.cc b/src/learner.cc
index 399d299f5..bd5a845c6 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -159,13 +159,12 @@ struct LearnerModelParamLegacy : public dmlc::Parameter
   }
 };

-LearnerModelParam::LearnerModelParam(
-    LearnerModelParamLegacy const &user_param, float base_margin)
-    : base_score{base_margin}, num_feature{user_param.num_feature},
-      num_output_group{user_param.num_class == 0
-                           ? 1
-                           : static_cast<uint32_t>(user_param.num_class)}
-{}
+LearnerModelParam::LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin,
+                                     ObjInfo t)
+    : base_score{base_margin},
+      num_feature{user_param.num_feature},
+      num_output_group{user_param.num_class == 0 ? 1 : static_cast<uint32_t>(user_param.num_class)},
+      task{t} {}

 struct LearnerTrainParam : public XGBoostParameter {
   // data split mode, can be row, col, or none.
@@ -339,8 +338,8 @@ class LearnerConfiguration : public Learner {
     // - model is created from scratch.
     // - model is configured second time due to change of parameter
     if (!learner_model_param_.Initialized() || mparam_.base_score != mparam_backup.base_score) {
-      learner_model_param_ = LearnerModelParam(mparam_,
-                                               obj_->ProbToMargin(mparam_.base_score));
+      learner_model_param_ =
+          LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), obj_->Task());
     }

     this->ConfigureGBM(old_tparam, args);
@@ -832,7 +831,7 @@ class LearnerIO : public LearnerConfiguration {
     }

     learner_model_param_ =
-        LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score));
+        LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), obj_->Task());
     if (attributes_.find("objective") != attributes_.cend()) {
       auto obj_str = attributes_.at("objective");
       auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()});
diff --git a/src/objective/aft_obj.cu b/src/objective/aft_obj.cu
index 95e4f4c55..882402a0c 100644
--- a/src/objective/aft_obj.cu
+++ b/src/objective/aft_obj.cu
@@ -38,6 +38,8 @@ class AFTObj : public ObjFunction {
     param_.UpdateAllowUnknown(args);
   }

+  ObjInfo Task() const override { return {ObjInfo::kSurvival, false}; }
+
   template
   void GetGradientImpl(const HostDeviceVector &preds,
                        const MetaInfo &info,
diff --git a/src/objective/hinge.cu b/src/objective/hinge.cu
index 0c8c2f317..068a4eea6 100644
--- a/src/objective/hinge.cu
+++ b/src/objective/hinge.cu
@@ -27,6 +27,8 @@ class HingeObj : public ObjFunction {
   void Configure(
       const std::vector > &args) override {}

+  ObjInfo Task() const override { return {ObjInfo::kRegression, false}; }
+
   void GetGradient(const HostDeviceVector &preds,
                    const MetaInfo &info,
                    int iter,
diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu
index 6ffa6eac2..710428b00 100644
--- a/src/objective/multiclass_obj.cu
+++ b/src/objective/multiclass_obj.cu
@@ -45,6 +45,9 @@ class SoftmaxMultiClassObj : public ObjFunction {
   void Configure(Args const& args) override {
     param_.UpdateAllowUnknown(args);
   }
+
+  ObjInfo Task() const override { return {ObjInfo::kClassification, false}; }
+
   void GetGradient(const HostDeviceVector& preds,
                    const MetaInfo& info,
                    int iter,
diff --git a/src/objective/rank_obj.cu b/src/objective/rank_obj.cu
index 164b60611..228c54642 100644
--- a/src/objective/rank_obj.cu
+++ b/src/objective/rank_obj.cu
@@ -754,6 +754,8 @@ class LambdaRankObj : public ObjFunction {
     param_.UpdateAllowUnknown(args);
   }

+  ObjInfo Task() const override { return {ObjInfo::kRanking, false}; }
+
   void GetGradient(const HostDeviceVector& preds,
                    const MetaInfo& info,
                    int iter,
diff --git a/src/objective/regression_loss.h b/src/objective/regression_loss.h
index 54b95cfe1..30605b348 100644
--- a/src/objective/regression_loss.h
+++ b/src/objective/regression_loss.h
@@ -7,6 +7,8 @@
 #include
 #include
 #include
+
+#include "xgboost/task.h"
 #include "../common/math.h"

 namespace xgboost {
@@ -36,6 +38,7 @@ struct LinearSquareLoss {
   static const char* DefaultEvalMetric() { return "rmse"; }

   static const char* Name() { return "reg:squarederror"; }
+  static ObjInfo Info() { return {ObjInfo::kRegression, true}; }
 };

 struct SquaredLogError {
@@ -61,6 +64,8 @@ struct SquaredLogError {
   static const char* DefaultEvalMetric() { return "rmsle"; }

   static const char* Name() { return "reg:squaredlogerror"; }
+
+  static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
 };

 // logistic loss for probability regression task
@@ -96,6 +101,8 @@ struct LogisticRegression {
   static const char* DefaultEvalMetric() { return "rmse"; }

   static const char* Name() { return "reg:logistic"; }
+
+  static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
 };

 struct PseudoHuberError {
@@ -127,12 +134,14 @@
   static const char* Name() { return "reg:pseudohubererror"; }

+  static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
 };

 // logistic loss for binary classification task
 struct LogisticClassification : public LogisticRegression {
   static const char* DefaultEvalMetric() { return "logloss"; }
   static const char* Name() { return "binary:logistic"; }
+  static ObjInfo Info() { return {ObjInfo::kBinary, false}; }
 };

 // logistic loss, but predict un-transformed margin
@@ -168,6 +177,8 @@ struct LogisticRaw : public LogisticRegression {
   static const char* DefaultEvalMetric() { return "logloss"; }

   static const char* Name() { return "binary:logitraw"; }
+
+  static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
 };

 }  // namespace obj
diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu
index ccb3a723d..ca9ec2c70 100644
--- a/src/objective/regression_obj.cu
+++ b/src/objective/regression_obj.cu
@@ -52,6 +52,10 @@ class RegLossObj : public ObjFunction {
     param_.UpdateAllowUnknown(args);
   }

+  struct ObjInfo Task() const override {
+    return Loss::Info();
+  }
+
   void GetGradient(const HostDeviceVector& preds,
                    const MetaInfo &info, int,
                    HostDeviceVector* out_gpair) override {
@@ -207,6 +211,10 @@ class PoissonRegression : public ObjFunction {
     param_.UpdateAllowUnknown(args);
   }

+  struct ObjInfo Task() const override {
+    return {ObjInfo::kRegression, false};
+  }
+
   void GetGradient(const HostDeviceVector& preds,
                    const MetaInfo &info, int,
                    HostDeviceVector *out_gpair) override {
@@ -298,6 +306,10 @@ class CoxRegression : public ObjFunction {
   void Configure(
       const std::vector >&) override {}

+  struct ObjInfo Task() const override {
+    return {ObjInfo::kRegression, false};
+  }
+
   void GetGradient(const HostDeviceVector& preds,
                    const MetaInfo &info, int,
                    HostDeviceVector *out_gpair) override {
@@ -395,6 +407,10 @@ class GammaRegression : public ObjFunction {
   void Configure(
       const std::vector >&) override {}

+  struct ObjInfo Task() const override {
+    return {ObjInfo::kRegression, false};
+  }
+
   void GetGradient(const HostDeviceVector &preds,
                    const MetaInfo &info, int,
                    HostDeviceVector *out_gpair) override {
@@ -491,6 +507,10 @@ class TweedieRegression : public ObjFunction {
     metric_ = os.str();
   }

+  struct ObjInfo Task() const override {
+    return {ObjInfo::kRegression, false};
+  }
+
   void GetGradient(const HostDeviceVector& preds,
                    const MetaInfo &info, int,
                    HostDeviceVector *out_gpair) override {
diff --git a/src/tree/tree_updater.cc b/src/tree/tree_updater.cc
index a619713e0..293dfb53a 100644
--- a/src/tree/tree_updater.cc
+++ b/src/tree/tree_updater.cc
@@ -14,12 +14,13 @@ DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg);

 namespace xgboost {

-TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter const* tparam) {
-  auto *e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name);
+TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter const* tparam,
+                                 ObjInfo task) {
+  auto* e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name);
   if (e == nullptr) {
     LOG(FATAL) << "Unknown tree updater " << name;
   }
-  auto p_updater = (e->body)();
+  auto p_updater = (e->body)(task);
   p_updater->tparam_ = tparam;
   return p_updater;
 }
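Note: since the registry functor now receives an ObjInfo argument, every .set_body() registration has to change its lambda signature, which is what the remaining updater files below do. An out-of-tree updater registered through XGBOOST_REGISTER_TREE_UPDATER would need the same one-line change; a minimal sketch with a made-up MyUpdater class:

    // Hypothetical plugin updater, shown only to illustrate the new functor shape.
    XGBOOST_REGISTER_TREE_UPDATER(MyUpdater, "grow_my_updater")
        .describe("Example third-party updater.")
        .set_body([](ObjInfo task) {
          // Accept the task if the updater cares about it, or leave the parameter
          // unnamed as the simple in-tree updaters (colmaker, refresh, sync) do.
          return new MyUpdater(task);
        });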
diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc
index 952a60f0f..3b0a74f36 100644
--- a/src/tree/updater_colmaker.cc
+++ b/src/tree/updater_colmaker.cc
@@ -628,7 +628,7 @@ class ColMaker: public TreeUpdater {

 XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
 .describe("Grow tree with parallelization over columns.")
-.set_body([]() {
+.set_body([](ObjInfo) {
     return new ColMaker();
   });
 }  // namespace tree
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 561a364b5..708e4e8e5 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -698,7 +698,7 @@ struct GPUHistMakerDevice {
       int right_child_nidx = tree[candidate.nid].RightChild();
       // Only create child entries if needed
       if (GPUExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
-                                      num_leaves)) {
+                                       num_leaves)) {
         monitor.Start("UpdatePosition");
         this->UpdatePosition(candidate.nid, p_tree);
         monitor.Stop("UpdatePosition");
@@ -732,7 +732,7 @@ struct GPUHistMakerDevice {
 template
 class GPUHistMakerSpecialised {
  public:
-  GPUHistMakerSpecialised() = default;
+  explicit GPUHistMakerSpecialised(ObjInfo task) : task_{task} {};
   void Configure(const Args& args, GenericParameter const* generic_param) {
     param_.UpdateAllowUnknown(args);
     generic_param_ = generic_param;
@@ -859,12 +859,14 @@ class GPUHistMakerSpecialised {
   DMatrix* p_last_fmat_ { nullptr };
   int device_{-1};
+  ObjInfo task_;

   common::Monitor monitor_;
 };

 class GPUHistMaker : public TreeUpdater {
  public:
+  explicit GPUHistMaker(ObjInfo task) : task_{task} {}
   void Configure(const Args& args) override {
     // Used in test to count how many configurations are performed
     LOG(DEBUG) << "[GPU Hist]: Configure";
@@ -878,11 +880,11 @@ class GPUHistMaker : public TreeUpdater {
     param = double_maker_->param_;
   }
   if (hist_maker_param_.single_precision_histogram) {
-    float_maker_.reset(new GPUHistMakerSpecialised());
+    float_maker_.reset(new GPUHistMakerSpecialised(task_));
     float_maker_->param_ = param;
     float_maker_->Configure(args, tparam_);
   } else {
-    double_maker_.reset(new GPUHistMakerSpecialised());
+    double_maker_.reset(new GPUHistMakerSpecialised(task_));
     double_maker_->param_ = param;
     double_maker_->Configure(args, tparam_);
   }
@@ -892,10 +894,10 @@ class GPUHistMaker : public TreeUpdater {
     auto const& config = get(in);
     FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
     if (hist_maker_param_.single_precision_histogram) {
-      float_maker_.reset(new GPUHistMakerSpecialised());
+      float_maker_.reset(new GPUHistMakerSpecialised(task_));
       FromJson(config.at("train_param"), &float_maker_->param_);
     } else {
-      double_maker_.reset(new GPUHistMakerSpecialised());
+      double_maker_.reset(new GPUHistMakerSpecialised(task_));
       FromJson(config.at("train_param"), &double_maker_->param_);
     }
   }
@@ -933,6 +935,7 @@ class GPUHistMaker : public TreeUpdater {

  private:
   GPUHistMakerTrainParam hist_maker_param_;
+  ObjInfo task_;
   std::unique_ptr> float_maker_;
   std::unique_ptr> double_maker_;
 };
@@ -940,7 +943,7 @@ class GPUHistMaker : public TreeUpdater {

 #if !defined(GTEST_TEST)
 XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
     .describe("Grow tree with GPU.")
-    .set_body([]() { return new GPUHistMaker(); });
+    .set_body([](ObjInfo task) { return new GPUHistMaker(task); });
 #endif  // !defined(GTEST_TEST)
 }  // namespace tree
diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc
index 1c086b69a..ac040f14e 100644
--- a/src/tree/updater_histmaker.cc
+++ b/src/tree/updater_histmaker.cc
@@ -750,14 +750,14 @@ class GlobalProposalHistMaker: public CQHistMaker {

 XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
.describe("Tree constructor that uses approximate histogram construction.") -.set_body([]() { +.set_body([](ObjInfo) { return new CQHistMaker(); }); // The updater for approx tree method. XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker") .describe("Tree constructor that uses approximate global of histogram construction.") -.set_body([]() { +.set_body([](ObjInfo) { return new GlobalProposalHistMaker(); }); } // namespace tree diff --git a/src/tree/updater_prune.cc b/src/tree/updater_prune.cc index 76a8916a0..293f302cf 100644 --- a/src/tree/updater_prune.cc +++ b/src/tree/updater_prune.cc @@ -23,8 +23,8 @@ DMLC_REGISTRY_FILE_TAG(updater_prune); /*! \brief pruner that prunes a tree after growing finishes */ class TreePruner: public TreeUpdater { public: - TreePruner() { - syncher_.reset(TreeUpdater::Create("sync", tparam_)); + explicit TreePruner(ObjInfo task) { + syncher_.reset(TreeUpdater::Create("sync", tparam_, task)); pruner_monitor_.Init("TreePruner"); } char const* Name() const override { @@ -113,8 +113,8 @@ class TreePruner: public TreeUpdater { XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune") .describe("Pruner that prune the tree according to statistics.") -.set_body([]() { - return new TreePruner(); +.set_body([](ObjInfo task) { + return new TreePruner(task); }); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 19c300b30..ab0dd7082 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -40,7 +40,7 @@ DMLC_REGISTER_PARAMETER(CPUHistMakerTrainParam); void QuantileHistMaker::Configure(const Args& args) { // initialize pruner if (!pruner_) { - pruner_.reset(TreeUpdater::Create("prune", tparam_)); + pruner_.reset(TreeUpdater::Create("prune", tparam_, task_)); } pruner_->Configure(args); param_.UpdateAllowUnknown(args); @@ -52,7 +52,7 @@ void QuantileHistMaker::SetBuilder(const size_t n_trees, std::unique_ptr>* builder, DMatrix *dmat) { builder->reset( - new Builder(n_trees, param_, std::move(pruner_), dmat)); + new Builder(n_trees, param_, std::move(pruner_), dmat, task_)); } template @@ -529,11 +529,11 @@ void QuantileHistMaker::Builder::InitData( // store a pointer to the tree p_last_tree_ = &tree; if (data_layout_ == DataLayout::kDenseDataOneBased) { - evaluator_.reset(new HistEvaluator{ - param_, info, this->nthread_, column_sampler_, true}); + evaluator_.reset(new HistEvaluator{param_, info, this->nthread_, + column_sampler_, true}); } else { - evaluator_.reset(new HistEvaluator{ - param_, info, this->nthread_, column_sampler_, false}); + evaluator_.reset(new HistEvaluator{param_, info, this->nthread_, + column_sampler_, false}); } if (data_layout_ == DataLayout::kDenseDataZeroBased @@ -677,17 +677,17 @@ XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker") .describe("(Deprecated, use grow_quantile_histmaker instead.)" " Grow tree using quantized histogram.") .set_body( - []() { + [](ObjInfo task) { LOG(WARNING) << "grow_fast_histmaker is deprecated, " << "use grow_quantile_histmaker instead."; - return new QuantileHistMaker(); + return new QuantileHistMaker(task); }); XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker") .describe("Grow tree using quantized histogram.") .set_body( - []() { - return new QuantileHistMaker(); + [](ObjInfo task) { + return new QuantileHistMaker(task); }); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index 
index 9654ab00a..a324d8a0c 100644
--- a/src/tree/updater_quantile_hist.h
+++ b/src/tree/updater_quantile_hist.h
@@ -95,7 +95,7 @@ using xgboost::common::Column;
 /*! \brief construct a tree using quantized feature values */
 class QuantileHistMaker: public TreeUpdater {
  public:
-  QuantileHistMaker() {
+  explicit QuantileHistMaker(ObjInfo task) : task_{task} {
     updater_monitor_.Init("QuantileHistMaker");
   }
   void Configure(const Args& args) override;
@@ -154,12 +154,15 @@ class QuantileHistMaker: public TreeUpdater {
    using GHistRowT = GHistRow;
    using GradientPairT = xgboost::detail::GradientPairInternal;
    // constructor
-   explicit Builder(const size_t n_trees, const TrainParam &param,
-                    std::unique_ptr pruner, DMatrix const *fmat)
-       : n_trees_(n_trees), param_(param), pruner_(std::move(pruner)),
-         p_last_tree_(nullptr), p_last_fmat_(fmat),
-         histogram_builder_{
-             new HistogramBuilder} {
+   explicit Builder(const size_t n_trees, const TrainParam& param,
+                    std::unique_ptr pruner, DMatrix const* fmat, ObjInfo task)
+       : n_trees_(n_trees),
+         param_(param),
+         pruner_(std::move(pruner)),
+         p_last_tree_(nullptr),
+         p_last_fmat_(fmat),
+         histogram_builder_{new HistogramBuilder},
+         task_{task} {
     builder_monitor_.Init("Quantile::Builder");
   }
   ~Builder();
@@ -261,6 +264,7 @@ class QuantileHistMaker: public TreeUpdater {
   DataLayout data_layout_;

   std::unique_ptr> histogram_builder_;
+  ObjInfo task_;
   common::Monitor builder_monitor_;
 };
@@ -281,6 +285,7 @@ class QuantileHistMaker: public TreeUpdater {
   std::unique_ptr> double_builder_;

   std::unique_ptr pruner_;
+  ObjInfo task_;
 };
 }  // namespace tree
 }  // namespace xgboost
diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc
index 1d54ad9e3..993899c7b 100644
--- a/src/tree/updater_refresh.cc
+++ b/src/tree/updater_refresh.cc
@@ -161,7 +161,7 @@ class TreeRefresher: public TreeUpdater {

 XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh")
 .describe("Refresher that refreshes the weight and statistics according to data.")
-.set_body([]() {
+.set_body([](ObjInfo) {
     return new TreeRefresher();
   });
 }  // namespace tree
diff --git a/src/tree/updater_sync.cc b/src/tree/updater_sync.cc
index 7979d10c2..4f7c7a1a8 100644
--- a/src/tree/updater_sync.cc
+++ b/src/tree/updater_sync.cc
@@ -53,7 +53,7 @@ class TreeSyncher: public TreeUpdater {

 XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync")
 .describe("Syncher that synchronize the tree in all distributed nodes.")
-.set_body([]() {
+.set_body([](ObjInfo) {
     return new TreeSyncher();
   });
 }  // namespace tree
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index 72c225396..faa51eac2 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -275,7 +275,8 @@ void TestHistogramIndexImpl() {
   int constexpr kNRows = 1000, kNCols = 10;

   // Build 2 matrices and build a histogram maker with that
-  tree::GPUHistMakerSpecialised hist_maker, hist_maker_ext;
+  tree::GPUHistMakerSpecialised hist_maker{ObjInfo{ObjInfo::kRegression}},
+      hist_maker_ext{ObjInfo{ObjInfo::kRegression}};
   std::unique_ptr hist_maker_dmat(
       CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));
@@ -333,7 +334,7 @@ int32_t TestMinSplitLoss(DMatrix* dmat, float gamma, HostDeviceVector
-  tree::GPUHistMakerSpecialised hist_maker;
+  tree::GPUHistMakerSpecialised hist_maker{ObjInfo{ObjInfo::kRegression}};
   GenericParameter generic_param(CreateEmptyGenericParam(0));
   hist_maker.Configure(args, &generic_param);
@@ -394,7 +395,7 @@ void UpdateTree(HostDeviceVector* gpair, DMatrix* dmat,
     {"sampling_method", sampling_method},
   };

-  tree::GPUHistMakerSpecialised hist_maker;
+  tree::GPUHistMakerSpecialised hist_maker{ObjInfo{ObjInfo::kRegression}};
   GenericParameter generic_param(CreateEmptyGenericParam(0));
   hist_maker.Configure(args, &generic_param);
@@ -539,7 +540,8 @@ TEST(GpuHist, ExternalMemoryWithSampling) {

 TEST(GpuHist, ConfigIO) {
   GenericParameter generic_param(CreateEmptyGenericParam(0));
-  std::unique_ptr<TreeUpdater> updater {TreeUpdater::Create("grow_gpu_hist", &generic_param) };
+  std::unique_ptr<TreeUpdater> updater{
+      TreeUpdater::Create("grow_gpu_hist", &generic_param, ObjInfo{ObjInfo::kRegression})};
   updater->Configure(Args{});

   Json j_updater { Object() };
diff --git a/tests/cpp/tree/test_histmaker.cc b/tests/cpp/tree/test_histmaker.cc
index e1cb3568d..56878b159 100644
--- a/tests/cpp/tree/test_histmaker.cc
+++ b/tests/cpp/tree/test_histmaker.cc
@@ -34,7 +34,8 @@ TEST(GrowHistMaker, InteractionConstraint) {
     RegTree tree;
     tree.param.num_feature = kCols;

-    std::unique_ptr<TreeUpdater> updater { TreeUpdater::Create("grow_histmaker", &param) };
+    std::unique_ptr<TreeUpdater> updater{
+        TreeUpdater::Create("grow_histmaker", &param, ObjInfo{ObjInfo::kRegression})};
     updater->Configure(Args{
         {"interaction_constraints", "[[0, 1]]"},
         {"num_feature", std::to_string(kCols)}});
@@ -51,7 +52,8 @@ TEST(GrowHistMaker, InteractionConstraint) {
     RegTree tree;
     tree.param.num_feature = kCols;

-    std::unique_ptr<TreeUpdater> updater { TreeUpdater::Create("grow_histmaker", &param) };
+    std::unique_ptr<TreeUpdater> updater{
+        TreeUpdater::Create("grow_histmaker", &param, ObjInfo{ObjInfo::kRegression})};
     updater->Configure(Args{{"num_feature", std::to_string(kCols)}});
     updater->Update(&gradients, p_dmat.get(), {&tree});
diff --git a/tests/cpp/tree/test_prune.cc b/tests/cpp/tree/test_prune.cc
index dbe910a8f..dc6a8da21 100644
--- a/tests/cpp/tree/test_prune.cc
+++ b/tests/cpp/tree/test_prune.cc
@@ -38,7 +38,8 @@ TEST(Updater, Prune) {
   tree.param.UpdateAllowUnknown(cfg);
   std::vector trees {&tree};
   // prepare pruner
-  std::unique_ptr<TreeUpdater> pruner(TreeUpdater::Create("prune", &lparam));
+  std::unique_ptr<TreeUpdater> pruner(
+      TreeUpdater::Create("prune", &lparam, ObjInfo{ObjInfo::kRegression}));
   pruner->Configure(cfg);

   // loss_chg < min_split_loss;
diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc
index 938205aae..534dd2a9e 100644
--- a/tests/cpp/tree/test_quantile_hist.cc
+++ b/tests/cpp/tree/test_quantile_hist.cc
@@ -28,7 +28,7 @@ class QuantileHistMock : public QuantileHistMaker {
     BuilderMock(const TrainParam &param,
                 std::unique_ptr pruner,
                 DMatrix const *fmat)
-        : RealImpl(1, param, std::move(pruner), fmat) {}
+        : RealImpl(1, param, std::move(pruner), fmat, ObjInfo{ObjInfo::kRegression}) {}

   public:
    void TestInitData(const GHistIndexMatrix& gmat,
@@ -230,7 +230,7 @@ class QuantileHistMock : public QuantileHistMaker {
   explicit QuantileHistMock(
       const std::vector >& args,
       const bool single_precision_histogram = false, bool batch = true) :
-      cfg_{args} {
+      QuantileHistMaker{ObjInfo{ObjInfo::kRegression}}, cfg_{args} {
     QuantileHistMaker::Configure(args);
     dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
     if (single_precision_histogram) {
diff --git a/tests/cpp/tree/test_refresh.cc b/tests/cpp/tree/test_refresh.cc
index 3689940fd..5b71f0841 100644
--- a/tests/cpp/tree/test_refresh.cc
+++ b/tests/cpp/tree/test_refresh.cc
@@ -32,7 +32,8 @@ TEST(Updater, Refresh) {
   auto lparam = CreateEmptyGenericParam(GPUIDX);
   tree.param.UpdateAllowUnknown(cfg);
   std::vector trees {&tree};
-  std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh", &lparam));
+  std::unique_ptr<TreeUpdater> refresher(
TreeUpdater::Create("refresh", &lparam, ObjInfo{ObjInfo::kRegression})); tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f, /*left_sum=*/0.0f, /*right_sum=*/0.0f); diff --git a/tests/cpp/tree/test_tree_stat.cc b/tests/cpp/tree/test_tree_stat.cc index eb8a7c5d9..de9c53f35 100644 --- a/tests/cpp/tree/test_tree_stat.cc +++ b/tests/cpp/tree/test_tree_stat.cc @@ -23,7 +23,7 @@ class UpdaterTreeStatTest : public ::testing::Test { void RunTest(std::string updater) { auto tparam = CreateEmptyGenericParam(0); auto up = std::unique_ptr{ - TreeUpdater::Create(updater, &tparam)}; + TreeUpdater::Create(updater, &tparam, ObjInfo{ObjInfo::kRegression})}; up->Configure(Args{}); RegTree tree; tree.param.num_feature = kCols;