Pass infomation about objective to tree methods. (#7385)
* Define the `ObjInfo` and pass it down to every tree updater.
This commit is contained in:
parent
ccdabe4512
commit
4100827971
@ -11,15 +11,16 @@
|
|||||||
#include <dmlc/any.h>
|
#include <dmlc/any.h>
|
||||||
#include <xgboost/base.h>
|
#include <xgboost/base.h>
|
||||||
#include <xgboost/feature_map.h>
|
#include <xgboost/feature_map.h>
|
||||||
#include <xgboost/predictor.h>
|
|
||||||
#include <xgboost/generic_parameters.h>
|
#include <xgboost/generic_parameters.h>
|
||||||
#include <xgboost/host_device_vector.h>
|
#include <xgboost/host_device_vector.h>
|
||||||
#include <xgboost/model.h>
|
#include <xgboost/model.h>
|
||||||
|
#include <xgboost/predictor.h>
|
||||||
|
#include <xgboost/task.h>
|
||||||
|
|
||||||
#include <utility>
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -307,11 +308,13 @@ struct LearnerModelParam {
|
|||||||
uint32_t num_feature { 0 };
|
uint32_t num_feature { 0 };
|
||||||
/* \brief number of classes, if it is multi-class classification */
|
/* \brief number of classes, if it is multi-class classification */
|
||||||
uint32_t num_output_group { 0 };
|
uint32_t num_output_group { 0 };
|
||||||
|
/* \brief Current task, determined by objective. */
|
||||||
|
ObjInfo task{ObjInfo::kRegression};
|
||||||
|
|
||||||
LearnerModelParam() = default;
|
LearnerModelParam() = default;
|
||||||
// As the old `LearnerModelParamLegacy` is still used by binary IO, we keep
|
// As the old `LearnerModelParamLegacy` is still used by binary IO, we keep
|
||||||
// this one as an immutable copy.
|
// this one as an immutable copy.
|
||||||
LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin);
|
LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin, ObjInfo t);
|
||||||
/* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */
|
/* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */
|
||||||
bool Initialized() const { return num_feature != 0; }
|
bool Initialized() const { return num_feature != 0; }
|
||||||
};
|
};
|
||||||
|
|||||||
@ -13,6 +13,7 @@
|
|||||||
#include <xgboost/model.h>
|
#include <xgboost/model.h>
|
||||||
#include <xgboost/generic_parameters.h>
|
#include <xgboost/generic_parameters.h>
|
||||||
#include <xgboost/host_device_vector.h>
|
#include <xgboost/host_device_vector.h>
|
||||||
|
#include <xgboost/task.h>
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@ -72,6 +73,11 @@ class ObjFunction : public Configurable {
|
|||||||
virtual bst_float ProbToMargin(bst_float base_score) const {
|
virtual bst_float ProbToMargin(bst_float base_score) const {
|
||||||
return base_score;
|
return base_score;
|
||||||
}
|
}
|
||||||
|
/*!
|
||||||
|
* \brief Return task of this objective.
|
||||||
|
*/
|
||||||
|
virtual struct ObjInfo Task() const = 0;
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Create an objective function according to name.
|
* \brief Create an objective function according to name.
|
||||||
* \param tparam Generic parameters.
|
* \param tparam Generic parameters.
|
||||||
|
|||||||
39
include/xgboost/task.h
Normal file
39
include/xgboost/task.h
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2021 by XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#ifndef XGBOOST_TASK_H_
|
||||||
|
#define XGBOOST_TASK_H_
|
||||||
|
|
||||||
|
#include <cinttypes>
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
/*!
|
||||||
|
* \brief A struct returned by objective, which determines task at hand. The struct is
|
||||||
|
* not used by any algorithm yet, only for future development like categorical
|
||||||
|
* split.
|
||||||
|
*
|
||||||
|
* The task field is useful for tree split finding, also for some metrics like auc.
|
||||||
|
* Lastly, knowing whether hessian is constant can allow some optimizations like skipping
|
||||||
|
* the quantile sketching.
|
||||||
|
*
|
||||||
|
* This struct should not be serialized since it can be recovered from objective function,
|
||||||
|
* hence it doesn't need to be stable.
|
||||||
|
*/
|
||||||
|
struct ObjInfo {
|
||||||
|
// What kind of problem are we trying to solve
|
||||||
|
enum Task : uint8_t {
|
||||||
|
kRegression = 0,
|
||||||
|
kBinary = 1,
|
||||||
|
kClassification = 2,
|
||||||
|
kSurvival = 3,
|
||||||
|
kRanking = 4,
|
||||||
|
kOther = 5,
|
||||||
|
} task;
|
||||||
|
// Does the objective have constant hessian value?
|
||||||
|
bool const_hess{false};
|
||||||
|
|
||||||
|
explicit ObjInfo(Task t) : task{t} {}
|
||||||
|
ObjInfo(Task t, bool khess) : const_hess{khess} {}
|
||||||
|
};
|
||||||
|
} // namespace xgboost
|
||||||
|
#endif // XGBOOST_TASK_H_
|
||||||
@ -11,16 +11,17 @@
|
|||||||
#include <dmlc/registry.h>
|
#include <dmlc/registry.h>
|
||||||
#include <xgboost/base.h>
|
#include <xgboost/base.h>
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h>
|
||||||
#include <xgboost/tree_model.h>
|
|
||||||
#include <xgboost/generic_parameters.h>
|
#include <xgboost/generic_parameters.h>
|
||||||
#include <xgboost/host_device_vector.h>
|
#include <xgboost/host_device_vector.h>
|
||||||
#include <xgboost/model.h>
|
|
||||||
#include <xgboost/linalg.h>
|
#include <xgboost/linalg.h>
|
||||||
|
#include <xgboost/model.h>
|
||||||
|
#include <xgboost/task.h>
|
||||||
|
#include <xgboost/tree_model.h>
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <vector>
|
|
||||||
#include <utility>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
@ -83,7 +84,7 @@ class TreeUpdater : public Configurable {
|
|||||||
* \param name Name of the tree updater.
|
* \param name Name of the tree updater.
|
||||||
* \param tparam A global runtime parameter
|
* \param tparam A global runtime parameter
|
||||||
*/
|
*/
|
||||||
static TreeUpdater* Create(const std::string& name, GenericParameter const* tparam);
|
static TreeUpdater* Create(const std::string& name, GenericParameter const* tparam, ObjInfo task);
|
||||||
};
|
};
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
@ -91,8 +92,7 @@ class TreeUpdater : public Configurable {
|
|||||||
*/
|
*/
|
||||||
struct TreeUpdaterReg
|
struct TreeUpdaterReg
|
||||||
: public dmlc::FunctionRegEntryBase<TreeUpdaterReg,
|
: public dmlc::FunctionRegEntryBase<TreeUpdaterReg,
|
||||||
std::function<TreeUpdater* ()> > {
|
std::function<TreeUpdater*(ObjInfo task)> > {};
|
||||||
};
|
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Macro to register tree updater.
|
* \brief Macro to register tree updater.
|
||||||
|
|||||||
@ -34,6 +34,11 @@ class MyLogistic : public ObjFunction {
|
|||||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||||
param_.UpdateAllowUnknown(args);
|
param_.UpdateAllowUnknown(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ObjInfo Task() const override {
|
||||||
|
return {ObjInfo::kRegression, false};
|
||||||
|
}
|
||||||
|
|
||||||
void GetGradient(const HostDeviceVector<bst_float> &preds,
|
void GetGradient(const HostDeviceVector<bst_float> &preds,
|
||||||
const MetaInfo &info,
|
const MetaInfo &info,
|
||||||
int iter,
|
int iter,
|
||||||
|
|||||||
@ -306,7 +306,8 @@ void GBTree::InitUpdater(Args const& cfg) {
|
|||||||
|
|
||||||
// create new updaters
|
// create new updaters
|
||||||
for (const std::string& pstr : ups) {
|
for (const std::string& pstr : ups) {
|
||||||
std::unique_ptr<TreeUpdater> up(TreeUpdater::Create(pstr.c_str(), generic_param_));
|
std::unique_ptr<TreeUpdater> up(
|
||||||
|
TreeUpdater::Create(pstr.c_str(), generic_param_, model_.learner_model_param->task));
|
||||||
up->Configure(cfg);
|
up->Configure(cfg);
|
||||||
updaters_.push_back(std::move(up));
|
updaters_.push_back(std::move(up));
|
||||||
}
|
}
|
||||||
@ -391,7 +392,8 @@ void GBTree::LoadConfig(Json const& in) {
|
|||||||
auto const& j_updaters = get<Object const>(in["updater"]);
|
auto const& j_updaters = get<Object const>(in["updater"]);
|
||||||
updaters_.clear();
|
updaters_.clear();
|
||||||
for (auto const& kv : j_updaters) {
|
for (auto const& kv : j_updaters) {
|
||||||
std::unique_ptr<TreeUpdater> up(TreeUpdater::Create(kv.first, generic_param_));
|
std::unique_ptr<TreeUpdater> up(
|
||||||
|
TreeUpdater::Create(kv.first, generic_param_, model_.learner_model_param->task));
|
||||||
up->LoadConfig(kv.second);
|
up->LoadConfig(kv.second);
|
||||||
updaters_.push_back(std::move(up));
|
updaters_.push_back(std::move(up));
|
||||||
}
|
}
|
||||||
|
|||||||
@ -159,13 +159,12 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
LearnerModelParam::LearnerModelParam(
|
LearnerModelParam::LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin,
|
||||||
LearnerModelParamLegacy const &user_param, float base_margin)
|
ObjInfo t)
|
||||||
: base_score{base_margin}, num_feature{user_param.num_feature},
|
: base_score{base_margin},
|
||||||
num_output_group{user_param.num_class == 0
|
num_feature{user_param.num_feature},
|
||||||
? 1
|
num_output_group{user_param.num_class == 0 ? 1 : static_cast<uint32_t>(user_param.num_class)},
|
||||||
: static_cast<uint32_t>(user_param.num_class)}
|
task{t} {}
|
||||||
{}
|
|
||||||
|
|
||||||
struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
|
struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
|
||||||
// data split mode, can be row, col, or none.
|
// data split mode, can be row, col, or none.
|
||||||
@ -339,8 +338,8 @@ class LearnerConfiguration : public Learner {
|
|||||||
// - model is created from scratch.
|
// - model is created from scratch.
|
||||||
// - model is configured second time due to change of parameter
|
// - model is configured second time due to change of parameter
|
||||||
if (!learner_model_param_.Initialized() || mparam_.base_score != mparam_backup.base_score) {
|
if (!learner_model_param_.Initialized() || mparam_.base_score != mparam_backup.base_score) {
|
||||||
learner_model_param_ = LearnerModelParam(mparam_,
|
learner_model_param_ =
|
||||||
obj_->ProbToMargin(mparam_.base_score));
|
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), obj_->Task());
|
||||||
}
|
}
|
||||||
|
|
||||||
this->ConfigureGBM(old_tparam, args);
|
this->ConfigureGBM(old_tparam, args);
|
||||||
@ -832,7 +831,7 @@ class LearnerIO : public LearnerConfiguration {
|
|||||||
}
|
}
|
||||||
|
|
||||||
learner_model_param_ =
|
learner_model_param_ =
|
||||||
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score));
|
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), obj_->Task());
|
||||||
if (attributes_.find("objective") != attributes_.cend()) {
|
if (attributes_.find("objective") != attributes_.cend()) {
|
||||||
auto obj_str = attributes_.at("objective");
|
auto obj_str = attributes_.at("objective");
|
||||||
auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()});
|
auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()});
|
||||||
|
|||||||
@ -38,6 +38,8 @@ class AFTObj : public ObjFunction {
|
|||||||
param_.UpdateAllowUnknown(args);
|
param_.UpdateAllowUnknown(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ObjInfo Task() const override { return {ObjInfo::kSurvival, false}; }
|
||||||
|
|
||||||
template <typename Distribution>
|
template <typename Distribution>
|
||||||
void GetGradientImpl(const HostDeviceVector<bst_float> &preds,
|
void GetGradientImpl(const HostDeviceVector<bst_float> &preds,
|
||||||
const MetaInfo &info,
|
const MetaInfo &info,
|
||||||
|
|||||||
@ -27,6 +27,8 @@ class HingeObj : public ObjFunction {
|
|||||||
void Configure(
|
void Configure(
|
||||||
const std::vector<std::pair<std::string, std::string> > &args) override {}
|
const std::vector<std::pair<std::string, std::string> > &args) override {}
|
||||||
|
|
||||||
|
ObjInfo Task() const override { return {ObjInfo::kRegression, false}; }
|
||||||
|
|
||||||
void GetGradient(const HostDeviceVector<bst_float> &preds,
|
void GetGradient(const HostDeviceVector<bst_float> &preds,
|
||||||
const MetaInfo &info,
|
const MetaInfo &info,
|
||||||
int iter,
|
int iter,
|
||||||
|
|||||||
@ -45,6 +45,9 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
|||||||
void Configure(Args const& args) override {
|
void Configure(Args const& args) override {
|
||||||
param_.UpdateAllowUnknown(args);
|
param_.UpdateAllowUnknown(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ObjInfo Task() const override { return {ObjInfo::kClassification, false}; }
|
||||||
|
|
||||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||||
const MetaInfo& info,
|
const MetaInfo& info,
|
||||||
int iter,
|
int iter,
|
||||||
|
|||||||
@ -754,6 +754,8 @@ class LambdaRankObj : public ObjFunction {
|
|||||||
param_.UpdateAllowUnknown(args);
|
param_.UpdateAllowUnknown(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ObjInfo Task() const override { return {ObjInfo::kRanking, false}; }
|
||||||
|
|
||||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||||
const MetaInfo& info,
|
const MetaInfo& info,
|
||||||
int iter,
|
int iter,
|
||||||
|
|||||||
@ -7,6 +7,8 @@
|
|||||||
#include <dmlc/omp.h>
|
#include <dmlc/omp.h>
|
||||||
#include <xgboost/logging.h>
|
#include <xgboost/logging.h>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "xgboost/task.h"
|
||||||
#include "../common/math.h"
|
#include "../common/math.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -36,6 +38,7 @@ struct LinearSquareLoss {
|
|||||||
static const char* DefaultEvalMetric() { return "rmse"; }
|
static const char* DefaultEvalMetric() { return "rmse"; }
|
||||||
|
|
||||||
static const char* Name() { return "reg:squarederror"; }
|
static const char* Name() { return "reg:squarederror"; }
|
||||||
|
static ObjInfo Info() { return {ObjInfo::kRegression, true}; }
|
||||||
};
|
};
|
||||||
|
|
||||||
struct SquaredLogError {
|
struct SquaredLogError {
|
||||||
@ -61,6 +64,8 @@ struct SquaredLogError {
|
|||||||
static const char* DefaultEvalMetric() { return "rmsle"; }
|
static const char* DefaultEvalMetric() { return "rmsle"; }
|
||||||
|
|
||||||
static const char* Name() { return "reg:squaredlogerror"; }
|
static const char* Name() { return "reg:squaredlogerror"; }
|
||||||
|
|
||||||
|
static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
|
||||||
};
|
};
|
||||||
|
|
||||||
// logistic loss for probability regression task
|
// logistic loss for probability regression task
|
||||||
@ -96,6 +101,8 @@ struct LogisticRegression {
|
|||||||
static const char* DefaultEvalMetric() { return "rmse"; }
|
static const char* DefaultEvalMetric() { return "rmse"; }
|
||||||
|
|
||||||
static const char* Name() { return "reg:logistic"; }
|
static const char* Name() { return "reg:logistic"; }
|
||||||
|
|
||||||
|
static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
|
||||||
};
|
};
|
||||||
|
|
||||||
struct PseudoHuberError {
|
struct PseudoHuberError {
|
||||||
@ -127,12 +134,14 @@ struct PseudoHuberError {
|
|||||||
static const char* Name() {
|
static const char* Name() {
|
||||||
return "reg:pseudohubererror";
|
return "reg:pseudohubererror";
|
||||||
}
|
}
|
||||||
|
static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
|
||||||
};
|
};
|
||||||
|
|
||||||
// logistic loss for binary classification task
|
// logistic loss for binary classification task
|
||||||
struct LogisticClassification : public LogisticRegression {
|
struct LogisticClassification : public LogisticRegression {
|
||||||
static const char* DefaultEvalMetric() { return "logloss"; }
|
static const char* DefaultEvalMetric() { return "logloss"; }
|
||||||
static const char* Name() { return "binary:logistic"; }
|
static const char* Name() { return "binary:logistic"; }
|
||||||
|
static ObjInfo Info() { return {ObjInfo::kBinary, false}; }
|
||||||
};
|
};
|
||||||
|
|
||||||
// logistic loss, but predict un-transformed margin
|
// logistic loss, but predict un-transformed margin
|
||||||
@ -168,6 +177,8 @@ struct LogisticRaw : public LogisticRegression {
|
|||||||
static const char* DefaultEvalMetric() { return "logloss"; }
|
static const char* DefaultEvalMetric() { return "logloss"; }
|
||||||
|
|
||||||
static const char* Name() { return "binary:logitraw"; }
|
static const char* Name() { return "binary:logitraw"; }
|
||||||
|
|
||||||
|
static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace obj
|
} // namespace obj
|
||||||
|
|||||||
@ -52,6 +52,10 @@ class RegLossObj : public ObjFunction {
|
|||||||
param_.UpdateAllowUnknown(args);
|
param_.UpdateAllowUnknown(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ObjInfo Task() const override {
|
||||||
|
return Loss::Info();
|
||||||
|
}
|
||||||
|
|
||||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||||
const MetaInfo &info, int,
|
const MetaInfo &info, int,
|
||||||
HostDeviceVector<GradientPair>* out_gpair) override {
|
HostDeviceVector<GradientPair>* out_gpair) override {
|
||||||
@ -207,6 +211,10 @@ class PoissonRegression : public ObjFunction {
|
|||||||
param_.UpdateAllowUnknown(args);
|
param_.UpdateAllowUnknown(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ObjInfo Task() const override {
|
||||||
|
return {ObjInfo::kRegression, false};
|
||||||
|
}
|
||||||
|
|
||||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||||
const MetaInfo &info, int,
|
const MetaInfo &info, int,
|
||||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||||
@ -298,6 +306,10 @@ class CoxRegression : public ObjFunction {
|
|||||||
void Configure(
|
void Configure(
|
||||||
const std::vector<std::pair<std::string, std::string> >&) override {}
|
const std::vector<std::pair<std::string, std::string> >&) override {}
|
||||||
|
|
||||||
|
struct ObjInfo Task() const override {
|
||||||
|
return {ObjInfo::kRegression, false};
|
||||||
|
}
|
||||||
|
|
||||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||||
const MetaInfo &info, int,
|
const MetaInfo &info, int,
|
||||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||||
@ -395,6 +407,10 @@ class GammaRegression : public ObjFunction {
|
|||||||
void Configure(
|
void Configure(
|
||||||
const std::vector<std::pair<std::string, std::string> >&) override {}
|
const std::vector<std::pair<std::string, std::string> >&) override {}
|
||||||
|
|
||||||
|
struct ObjInfo Task() const override {
|
||||||
|
return {ObjInfo::kRegression, false};
|
||||||
|
}
|
||||||
|
|
||||||
void GetGradient(const HostDeviceVector<bst_float> &preds,
|
void GetGradient(const HostDeviceVector<bst_float> &preds,
|
||||||
const MetaInfo &info, int,
|
const MetaInfo &info, int,
|
||||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||||
@ -491,6 +507,10 @@ class TweedieRegression : public ObjFunction {
|
|||||||
metric_ = os.str();
|
metric_ = os.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ObjInfo Task() const override {
|
||||||
|
return {ObjInfo::kRegression, false};
|
||||||
|
}
|
||||||
|
|
||||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||||
const MetaInfo &info, int,
|
const MetaInfo &info, int,
|
||||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||||
|
|||||||
@ -14,12 +14,13 @@ DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg);
|
|||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter const* tparam) {
|
TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter const* tparam,
|
||||||
auto *e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name);
|
ObjInfo task) {
|
||||||
|
auto* e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name);
|
||||||
if (e == nullptr) {
|
if (e == nullptr) {
|
||||||
LOG(FATAL) << "Unknown tree updater " << name;
|
LOG(FATAL) << "Unknown tree updater " << name;
|
||||||
}
|
}
|
||||||
auto p_updater = (e->body)();
|
auto p_updater = (e->body)(task);
|
||||||
p_updater->tparam_ = tparam;
|
p_updater->tparam_ = tparam;
|
||||||
return p_updater;
|
return p_updater;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -628,7 +628,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
|
XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
|
||||||
.describe("Grow tree with parallelization over columns.")
|
.describe("Grow tree with parallelization over columns.")
|
||||||
.set_body([]() {
|
.set_body([](ObjInfo) {
|
||||||
return new ColMaker();
|
return new ColMaker();
|
||||||
});
|
});
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
|
|||||||
@ -732,7 +732,7 @@ struct GPUHistMakerDevice {
|
|||||||
template <typename GradientSumT>
|
template <typename GradientSumT>
|
||||||
class GPUHistMakerSpecialised {
|
class GPUHistMakerSpecialised {
|
||||||
public:
|
public:
|
||||||
GPUHistMakerSpecialised() = default;
|
explicit GPUHistMakerSpecialised(ObjInfo task) : task_{task} {};
|
||||||
void Configure(const Args& args, GenericParameter const* generic_param) {
|
void Configure(const Args& args, GenericParameter const* generic_param) {
|
||||||
param_.UpdateAllowUnknown(args);
|
param_.UpdateAllowUnknown(args);
|
||||||
generic_param_ = generic_param;
|
generic_param_ = generic_param;
|
||||||
@ -859,12 +859,14 @@ class GPUHistMakerSpecialised {
|
|||||||
|
|
||||||
DMatrix* p_last_fmat_ { nullptr };
|
DMatrix* p_last_fmat_ { nullptr };
|
||||||
int device_{-1};
|
int device_{-1};
|
||||||
|
ObjInfo task_;
|
||||||
|
|
||||||
common::Monitor monitor_;
|
common::Monitor monitor_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class GPUHistMaker : public TreeUpdater {
|
class GPUHistMaker : public TreeUpdater {
|
||||||
public:
|
public:
|
||||||
|
explicit GPUHistMaker(ObjInfo task) : task_{task} {}
|
||||||
void Configure(const Args& args) override {
|
void Configure(const Args& args) override {
|
||||||
// Used in test to count how many configurations are performed
|
// Used in test to count how many configurations are performed
|
||||||
LOG(DEBUG) << "[GPU Hist]: Configure";
|
LOG(DEBUG) << "[GPU Hist]: Configure";
|
||||||
@ -878,11 +880,11 @@ class GPUHistMaker : public TreeUpdater {
|
|||||||
param = double_maker_->param_;
|
param = double_maker_->param_;
|
||||||
}
|
}
|
||||||
if (hist_maker_param_.single_precision_histogram) {
|
if (hist_maker_param_.single_precision_histogram) {
|
||||||
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>());
|
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>(task_));
|
||||||
float_maker_->param_ = param;
|
float_maker_->param_ = param;
|
||||||
float_maker_->Configure(args, tparam_);
|
float_maker_->Configure(args, tparam_);
|
||||||
} else {
|
} else {
|
||||||
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>());
|
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>(task_));
|
||||||
double_maker_->param_ = param;
|
double_maker_->param_ = param;
|
||||||
double_maker_->Configure(args, tparam_);
|
double_maker_->Configure(args, tparam_);
|
||||||
}
|
}
|
||||||
@ -892,10 +894,10 @@ class GPUHistMaker : public TreeUpdater {
|
|||||||
auto const& config = get<Object const>(in);
|
auto const& config = get<Object const>(in);
|
||||||
FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
|
FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
|
||||||
if (hist_maker_param_.single_precision_histogram) {
|
if (hist_maker_param_.single_precision_histogram) {
|
||||||
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>());
|
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>(task_));
|
||||||
FromJson(config.at("train_param"), &float_maker_->param_);
|
FromJson(config.at("train_param"), &float_maker_->param_);
|
||||||
} else {
|
} else {
|
||||||
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>());
|
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>(task_));
|
||||||
FromJson(config.at("train_param"), &double_maker_->param_);
|
FromJson(config.at("train_param"), &double_maker_->param_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -933,6 +935,7 @@ class GPUHistMaker : public TreeUpdater {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
GPUHistMakerTrainParam hist_maker_param_;
|
GPUHistMakerTrainParam hist_maker_param_;
|
||||||
|
ObjInfo task_;
|
||||||
std::unique_ptr<GPUHistMakerSpecialised<GradientPair>> float_maker_;
|
std::unique_ptr<GPUHistMakerSpecialised<GradientPair>> float_maker_;
|
||||||
std::unique_ptr<GPUHistMakerSpecialised<GradientPairPrecise>> double_maker_;
|
std::unique_ptr<GPUHistMakerSpecialised<GradientPairPrecise>> double_maker_;
|
||||||
};
|
};
|
||||||
@ -940,7 +943,7 @@ class GPUHistMaker : public TreeUpdater {
|
|||||||
#if !defined(GTEST_TEST)
|
#if !defined(GTEST_TEST)
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
|
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
|
||||||
.describe("Grow tree with GPU.")
|
.describe("Grow tree with GPU.")
|
||||||
.set_body([]() { return new GPUHistMaker(); });
|
.set_body([](ObjInfo task) { return new GPUHistMaker(task); });
|
||||||
#endif // !defined(GTEST_TEST)
|
#endif // !defined(GTEST_TEST)
|
||||||
|
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
|
|||||||
@ -750,14 +750,14 @@ class GlobalProposalHistMaker: public CQHistMaker {
|
|||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
|
XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
|
||||||
.describe("Tree constructor that uses approximate histogram construction.")
|
.describe("Tree constructor that uses approximate histogram construction.")
|
||||||
.set_body([]() {
|
.set_body([](ObjInfo) {
|
||||||
return new CQHistMaker();
|
return new CQHistMaker();
|
||||||
});
|
});
|
||||||
|
|
||||||
// The updater for approx tree method.
|
// The updater for approx tree method.
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker")
|
XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker")
|
||||||
.describe("Tree constructor that uses approximate global of histogram construction.")
|
.describe("Tree constructor that uses approximate global of histogram construction.")
|
||||||
.set_body([]() {
|
.set_body([](ObjInfo) {
|
||||||
return new GlobalProposalHistMaker();
|
return new GlobalProposalHistMaker();
|
||||||
});
|
});
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
|
|||||||
@ -23,8 +23,8 @@ DMLC_REGISTRY_FILE_TAG(updater_prune);
|
|||||||
/*! \brief pruner that prunes a tree after growing finishes */
|
/*! \brief pruner that prunes a tree after growing finishes */
|
||||||
class TreePruner: public TreeUpdater {
|
class TreePruner: public TreeUpdater {
|
||||||
public:
|
public:
|
||||||
TreePruner() {
|
explicit TreePruner(ObjInfo task) {
|
||||||
syncher_.reset(TreeUpdater::Create("sync", tparam_));
|
syncher_.reset(TreeUpdater::Create("sync", tparam_, task));
|
||||||
pruner_monitor_.Init("TreePruner");
|
pruner_monitor_.Init("TreePruner");
|
||||||
}
|
}
|
||||||
char const* Name() const override {
|
char const* Name() const override {
|
||||||
@ -113,8 +113,8 @@ class TreePruner: public TreeUpdater {
|
|||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
|
XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
|
||||||
.describe("Pruner that prune the tree according to statistics.")
|
.describe("Pruner that prune the tree according to statistics.")
|
||||||
.set_body([]() {
|
.set_body([](ObjInfo task) {
|
||||||
return new TreePruner();
|
return new TreePruner(task);
|
||||||
});
|
});
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -40,7 +40,7 @@ DMLC_REGISTER_PARAMETER(CPUHistMakerTrainParam);
|
|||||||
void QuantileHistMaker::Configure(const Args& args) {
|
void QuantileHistMaker::Configure(const Args& args) {
|
||||||
// initialize pruner
|
// initialize pruner
|
||||||
if (!pruner_) {
|
if (!pruner_) {
|
||||||
pruner_.reset(TreeUpdater::Create("prune", tparam_));
|
pruner_.reset(TreeUpdater::Create("prune", tparam_, task_));
|
||||||
}
|
}
|
||||||
pruner_->Configure(args);
|
pruner_->Configure(args);
|
||||||
param_.UpdateAllowUnknown(args);
|
param_.UpdateAllowUnknown(args);
|
||||||
@ -52,7 +52,7 @@ void QuantileHistMaker::SetBuilder(const size_t n_trees,
|
|||||||
std::unique_ptr<Builder<GradientSumT>>* builder,
|
std::unique_ptr<Builder<GradientSumT>>* builder,
|
||||||
DMatrix *dmat) {
|
DMatrix *dmat) {
|
||||||
builder->reset(
|
builder->reset(
|
||||||
new Builder<GradientSumT>(n_trees, param_, std::move(pruner_), dmat));
|
new Builder<GradientSumT>(n_trees, param_, std::move(pruner_), dmat, task_));
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename GradientSumT>
|
template<typename GradientSumT>
|
||||||
@ -529,11 +529,11 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(
|
|||||||
// store a pointer to the tree
|
// store a pointer to the tree
|
||||||
p_last_tree_ = &tree;
|
p_last_tree_ = &tree;
|
||||||
if (data_layout_ == DataLayout::kDenseDataOneBased) {
|
if (data_layout_ == DataLayout::kDenseDataOneBased) {
|
||||||
evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
|
evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{param_, info, this->nthread_,
|
||||||
param_, info, this->nthread_, column_sampler_, true});
|
column_sampler_, true});
|
||||||
} else {
|
} else {
|
||||||
evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
|
evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{param_, info, this->nthread_,
|
||||||
param_, info, this->nthread_, column_sampler_, false});
|
column_sampler_, false});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (data_layout_ == DataLayout::kDenseDataZeroBased
|
if (data_layout_ == DataLayout::kDenseDataZeroBased
|
||||||
@ -677,17 +677,17 @@ XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker")
|
|||||||
.describe("(Deprecated, use grow_quantile_histmaker instead.)"
|
.describe("(Deprecated, use grow_quantile_histmaker instead.)"
|
||||||
" Grow tree using quantized histogram.")
|
" Grow tree using quantized histogram.")
|
||||||
.set_body(
|
.set_body(
|
||||||
[]() {
|
[](ObjInfo task) {
|
||||||
LOG(WARNING) << "grow_fast_histmaker is deprecated, "
|
LOG(WARNING) << "grow_fast_histmaker is deprecated, "
|
||||||
<< "use grow_quantile_histmaker instead.";
|
<< "use grow_quantile_histmaker instead.";
|
||||||
return new QuantileHistMaker();
|
return new QuantileHistMaker(task);
|
||||||
});
|
});
|
||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
|
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
|
||||||
.describe("Grow tree using quantized histogram.")
|
.describe("Grow tree using quantized histogram.")
|
||||||
.set_body(
|
.set_body(
|
||||||
[]() {
|
[](ObjInfo task) {
|
||||||
return new QuantileHistMaker();
|
return new QuantileHistMaker(task);
|
||||||
});
|
});
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -95,7 +95,7 @@ using xgboost::common::Column;
|
|||||||
/*! \brief construct a tree using quantized feature values */
|
/*! \brief construct a tree using quantized feature values */
|
||||||
class QuantileHistMaker: public TreeUpdater {
|
class QuantileHistMaker: public TreeUpdater {
|
||||||
public:
|
public:
|
||||||
QuantileHistMaker() {
|
explicit QuantileHistMaker(ObjInfo task) : task_{task} {
|
||||||
updater_monitor_.Init("QuantileHistMaker");
|
updater_monitor_.Init("QuantileHistMaker");
|
||||||
}
|
}
|
||||||
void Configure(const Args& args) override;
|
void Configure(const Args& args) override;
|
||||||
@ -154,12 +154,15 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
using GHistRowT = GHistRow<GradientSumT>;
|
using GHistRowT = GHistRow<GradientSumT>;
|
||||||
using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;
|
using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;
|
||||||
// constructor
|
// constructor
|
||||||
explicit Builder(const size_t n_trees, const TrainParam ¶m,
|
explicit Builder(const size_t n_trees, const TrainParam& param,
|
||||||
std::unique_ptr<TreeUpdater> pruner, DMatrix const *fmat)
|
std::unique_ptr<TreeUpdater> pruner, DMatrix const* fmat, ObjInfo task)
|
||||||
: n_trees_(n_trees), param_(param), pruner_(std::move(pruner)),
|
: n_trees_(n_trees),
|
||||||
p_last_tree_(nullptr), p_last_fmat_(fmat),
|
param_(param),
|
||||||
histogram_builder_{
|
pruner_(std::move(pruner)),
|
||||||
new HistogramBuilder<GradientSumT, CPUExpandEntry>} {
|
p_last_tree_(nullptr),
|
||||||
|
p_last_fmat_(fmat),
|
||||||
|
histogram_builder_{new HistogramBuilder<GradientSumT, CPUExpandEntry>},
|
||||||
|
task_{task} {
|
||||||
builder_monitor_.Init("Quantile::Builder");
|
builder_monitor_.Init("Quantile::Builder");
|
||||||
}
|
}
|
||||||
~Builder();
|
~Builder();
|
||||||
@ -261,6 +264,7 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
DataLayout data_layout_;
|
DataLayout data_layout_;
|
||||||
std::unique_ptr<HistogramBuilder<GradientSumT, CPUExpandEntry>>
|
std::unique_ptr<HistogramBuilder<GradientSumT, CPUExpandEntry>>
|
||||||
histogram_builder_;
|
histogram_builder_;
|
||||||
|
ObjInfo task_;
|
||||||
|
|
||||||
common::Monitor builder_monitor_;
|
common::Monitor builder_monitor_;
|
||||||
};
|
};
|
||||||
@ -281,6 +285,7 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
std::unique_ptr<Builder<double>> double_builder_;
|
std::unique_ptr<Builder<double>> double_builder_;
|
||||||
|
|
||||||
std::unique_ptr<TreeUpdater> pruner_;
|
std::unique_ptr<TreeUpdater> pruner_;
|
||||||
|
ObjInfo task_;
|
||||||
};
|
};
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -161,7 +161,7 @@ class TreeRefresher: public TreeUpdater {
|
|||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh")
|
XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh")
|
||||||
.describe("Refresher that refreshes the weight and statistics according to data.")
|
.describe("Refresher that refreshes the weight and statistics according to data.")
|
||||||
.set_body([]() {
|
.set_body([](ObjInfo) {
|
||||||
return new TreeRefresher();
|
return new TreeRefresher();
|
||||||
});
|
});
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
|
|||||||
@ -53,7 +53,7 @@ class TreeSyncher: public TreeUpdater {
|
|||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync")
|
XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync")
|
||||||
.describe("Syncher that synchronize the tree in all distributed nodes.")
|
.describe("Syncher that synchronize the tree in all distributed nodes.")
|
||||||
.set_body([]() {
|
.set_body([](ObjInfo) {
|
||||||
return new TreeSyncher();
|
return new TreeSyncher();
|
||||||
});
|
});
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
|
|||||||
@ -275,7 +275,8 @@ void TestHistogramIndexImpl() {
|
|||||||
int constexpr kNRows = 1000, kNCols = 10;
|
int constexpr kNRows = 1000, kNCols = 10;
|
||||||
|
|
||||||
// Build 2 matrices and build a histogram maker with that
|
// Build 2 matrices and build a histogram maker with that
|
||||||
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker, hist_maker_ext;
|
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker{ObjInfo{ObjInfo::kRegression}},
|
||||||
|
hist_maker_ext{ObjInfo{ObjInfo::kRegression}};
|
||||||
std::unique_ptr<DMatrix> hist_maker_dmat(
|
std::unique_ptr<DMatrix> hist_maker_dmat(
|
||||||
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));
|
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));
|
||||||
|
|
||||||
@ -333,7 +334,7 @@ int32_t TestMinSplitLoss(DMatrix* dmat, float gamma, HostDeviceVector<GradientPa
|
|||||||
{"gamma", std::to_string(gamma)}
|
{"gamma", std::to_string(gamma)}
|
||||||
};
|
};
|
||||||
|
|
||||||
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker;
|
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker{ObjInfo{ObjInfo::kRegression}};
|
||||||
GenericParameter generic_param(CreateEmptyGenericParam(0));
|
GenericParameter generic_param(CreateEmptyGenericParam(0));
|
||||||
hist_maker.Configure(args, &generic_param);
|
hist_maker.Configure(args, &generic_param);
|
||||||
|
|
||||||
@ -394,7 +395,7 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
|||||||
{"sampling_method", sampling_method},
|
{"sampling_method", sampling_method},
|
||||||
};
|
};
|
||||||
|
|
||||||
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker;
|
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker{ObjInfo{ObjInfo::kRegression}};
|
||||||
GenericParameter generic_param(CreateEmptyGenericParam(0));
|
GenericParameter generic_param(CreateEmptyGenericParam(0));
|
||||||
hist_maker.Configure(args, &generic_param);
|
hist_maker.Configure(args, &generic_param);
|
||||||
|
|
||||||
@ -539,7 +540,8 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
|
|||||||
|
|
||||||
TEST(GpuHist, ConfigIO) {
|
TEST(GpuHist, ConfigIO) {
|
||||||
GenericParameter generic_param(CreateEmptyGenericParam(0));
|
GenericParameter generic_param(CreateEmptyGenericParam(0));
|
||||||
std::unique_ptr<TreeUpdater> updater {TreeUpdater::Create("grow_gpu_hist", &generic_param) };
|
std::unique_ptr<TreeUpdater> updater{
|
||||||
|
TreeUpdater::Create("grow_gpu_hist", &generic_param, ObjInfo{ObjInfo::kRegression})};
|
||||||
updater->Configure(Args{});
|
updater->Configure(Args{});
|
||||||
|
|
||||||
Json j_updater { Object() };
|
Json j_updater { Object() };
|
||||||
|
|||||||
@ -34,7 +34,8 @@ TEST(GrowHistMaker, InteractionConstraint) {
|
|||||||
RegTree tree;
|
RegTree tree;
|
||||||
tree.param.num_feature = kCols;
|
tree.param.num_feature = kCols;
|
||||||
|
|
||||||
std::unique_ptr<TreeUpdater> updater { TreeUpdater::Create("grow_histmaker", ¶m) };
|
std::unique_ptr<TreeUpdater> updater{
|
||||||
|
TreeUpdater::Create("grow_histmaker", ¶m, ObjInfo{ObjInfo::kRegression})};
|
||||||
updater->Configure(Args{
|
updater->Configure(Args{
|
||||||
{"interaction_constraints", "[[0, 1]]"},
|
{"interaction_constraints", "[[0, 1]]"},
|
||||||
{"num_feature", std::to_string(kCols)}});
|
{"num_feature", std::to_string(kCols)}});
|
||||||
@ -51,7 +52,8 @@ TEST(GrowHistMaker, InteractionConstraint) {
|
|||||||
RegTree tree;
|
RegTree tree;
|
||||||
tree.param.num_feature = kCols;
|
tree.param.num_feature = kCols;
|
||||||
|
|
||||||
std::unique_ptr<TreeUpdater> updater { TreeUpdater::Create("grow_histmaker", ¶m) };
|
std::unique_ptr<TreeUpdater> updater{
|
||||||
|
TreeUpdater::Create("grow_histmaker", ¶m, ObjInfo{ObjInfo::kRegression})};
|
||||||
updater->Configure(Args{{"num_feature", std::to_string(kCols)}});
|
updater->Configure(Args{{"num_feature", std::to_string(kCols)}});
|
||||||
updater->Update(&gradients, p_dmat.get(), {&tree});
|
updater->Update(&gradients, p_dmat.get(), {&tree});
|
||||||
|
|
||||||
|
|||||||
@ -38,7 +38,8 @@ TEST(Updater, Prune) {
|
|||||||
tree.param.UpdateAllowUnknown(cfg);
|
tree.param.UpdateAllowUnknown(cfg);
|
||||||
std::vector<RegTree*> trees {&tree};
|
std::vector<RegTree*> trees {&tree};
|
||||||
// prepare pruner
|
// prepare pruner
|
||||||
std::unique_ptr<TreeUpdater> pruner(TreeUpdater::Create("prune", &lparam));
|
std::unique_ptr<TreeUpdater> pruner(
|
||||||
|
TreeUpdater::Create("prune", &lparam, ObjInfo{ObjInfo::kRegression}));
|
||||||
pruner->Configure(cfg);
|
pruner->Configure(cfg);
|
||||||
|
|
||||||
// loss_chg < min_split_loss;
|
// loss_chg < min_split_loss;
|
||||||
|
|||||||
@ -28,7 +28,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
|||||||
|
|
||||||
BuilderMock(const TrainParam ¶m, std::unique_ptr<TreeUpdater> pruner,
|
BuilderMock(const TrainParam ¶m, std::unique_ptr<TreeUpdater> pruner,
|
||||||
DMatrix const *fmat)
|
DMatrix const *fmat)
|
||||||
: RealImpl(1, param, std::move(pruner), fmat) {}
|
: RealImpl(1, param, std::move(pruner), fmat, ObjInfo{ObjInfo::kRegression}) {}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
void TestInitData(const GHistIndexMatrix& gmat,
|
void TestInitData(const GHistIndexMatrix& gmat,
|
||||||
@ -230,7 +230,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
|||||||
explicit QuantileHistMock(
|
explicit QuantileHistMock(
|
||||||
const std::vector<std::pair<std::string, std::string> >& args,
|
const std::vector<std::pair<std::string, std::string> >& args,
|
||||||
const bool single_precision_histogram = false, bool batch = true) :
|
const bool single_precision_histogram = false, bool batch = true) :
|
||||||
cfg_{args} {
|
QuantileHistMaker{ObjInfo{ObjInfo::kRegression}}, cfg_{args} {
|
||||||
QuantileHistMaker::Configure(args);
|
QuantileHistMaker::Configure(args);
|
||||||
dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
|
dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
|
||||||
if (single_precision_histogram) {
|
if (single_precision_histogram) {
|
||||||
|
|||||||
@ -32,7 +32,8 @@ TEST(Updater, Refresh) {
|
|||||||
auto lparam = CreateEmptyGenericParam(GPUIDX);
|
auto lparam = CreateEmptyGenericParam(GPUIDX);
|
||||||
tree.param.UpdateAllowUnknown(cfg);
|
tree.param.UpdateAllowUnknown(cfg);
|
||||||
std::vector<RegTree*> trees {&tree};
|
std::vector<RegTree*> trees {&tree};
|
||||||
std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh", &lparam));
|
std::unique_ptr<TreeUpdater> refresher(
|
||||||
|
TreeUpdater::Create("refresh", &lparam, ObjInfo{ObjInfo::kRegression}));
|
||||||
|
|
||||||
tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f,
|
tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f,
|
||||||
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||||
|
|||||||
@ -23,7 +23,7 @@ class UpdaterTreeStatTest : public ::testing::Test {
|
|||||||
void RunTest(std::string updater) {
|
void RunTest(std::string updater) {
|
||||||
auto tparam = CreateEmptyGenericParam(0);
|
auto tparam = CreateEmptyGenericParam(0);
|
||||||
auto up = std::unique_ptr<TreeUpdater>{
|
auto up = std::unique_ptr<TreeUpdater>{
|
||||||
TreeUpdater::Create(updater, &tparam)};
|
TreeUpdater::Create(updater, &tparam, ObjInfo{ObjInfo::kRegression})};
|
||||||
up->Configure(Args{});
|
up->Configure(Args{});
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
tree.param.num_feature = kCols;
|
tree.param.num_feature = kCols;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user