Pass information about objective to tree methods. (#7385)

* Define the `ObjInfo` and pass it down to every tree updater.
This commit is contained in:
Jiaming Yuan
2021-11-04 01:52:44 +08:00
committed by GitHub
parent ccdabe4512
commit 4100827971
28 changed files with 178 additions and 69 deletions

View File

@@ -14,12 +14,13 @@ DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg);
namespace xgboost {
// Factory for tree updaters: looks `name` up in the dmlc registry of
// TreeUpdaterReg entries, aborts through LOG(FATAL) when no entry is found,
// and otherwise returns the object produced by the entry's body after
// storing `tparam` in its `tparam_` member.
// NOTE(review): this span looks like a rendered diff without +/- markers --
// the two-parameter signature and `(e->body)()` appear to be the pre-change
// lines, while the `ObjInfo task` signature and `(e->body)(task)` appear to
// be their replacements; confirm against the repository.
TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter const* tparam) {
auto *e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name);
TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter const* tparam,
ObjInfo task) {
auto* e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name);
if (e == nullptr) {
LOG(FATAL) << "Unknown tree updater " << name;
}
auto p_updater = (e->body)();
auto p_updater = (e->body)(task);
p_updater->tparam_ = tparam;
return p_updater;
}

View File

@@ -628,7 +628,7 @@ class ColMaker: public TreeUpdater {
// Registers ColMaker under the updater name "grow_colmaker".
// The `[](ObjInfo)` form of the factory lambda accepts the objective info but
// leaves it unnamed and unused; the updater is default-constructed either way.
// NOTE(review): the two consecutive `.set_body(` openers look like the
// old/new pair of a diff rendered without +/- markers -- confirm.
XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
.describe("Grow tree with parallelization over columns.")
.set_body([]() {
.set_body([](ObjInfo) {
return new ColMaker();
});
} // namespace tree

View File

@@ -698,7 +698,7 @@ struct GPUHistMakerDevice {
int right_child_nidx = tree[candidate.nid].RightChild();
// Only create child entries if needed
if (GPUExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
num_leaves)) {
num_leaves)) {
monitor.Start("UpdatePosition");
this->UpdatePosition(candidate.nid, p_tree);
monitor.Stop("UpdatePosition");
@@ -732,7 +732,7 @@ struct GPUHistMakerDevice {
template <typename GradientSumT>
class GPUHistMakerSpecialised {
public:
GPUHistMakerSpecialised() = default;
explicit GPUHistMakerSpecialised(ObjInfo task) : task_{task} {};
void Configure(const Args& args, GenericParameter const* generic_param) {
param_.UpdateAllowUnknown(args);
generic_param_ = generic_param;
@@ -859,12 +859,14 @@ class GPUHistMakerSpecialised {
DMatrix* p_last_fmat_ { nullptr };
int device_{-1};
ObjInfo task_;
common::Monitor monitor_;
};
class GPUHistMaker : public TreeUpdater {
public:
explicit GPUHistMaker(ObjInfo task) : task_{task} {}
void Configure(const Args& args) override {
// Used in test to count how many configurations are performed
LOG(DEBUG) << "[GPU Hist]: Configure";
@@ -878,11 +880,11 @@ class GPUHistMaker : public TreeUpdater {
param = double_maker_->param_;
}
if (hist_maker_param_.single_precision_histogram) {
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>());
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>(task_));
float_maker_->param_ = param;
float_maker_->Configure(args, tparam_);
} else {
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>());
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>(task_));
double_maker_->param_ = param;
double_maker_->Configure(args, tparam_);
}
@@ -892,10 +894,10 @@ class GPUHistMaker : public TreeUpdater {
auto const& config = get<Object const>(in);
FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
if (hist_maker_param_.single_precision_histogram) {
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>());
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>(task_));
FromJson(config.at("train_param"), &float_maker_->param_);
} else {
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>());
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>(task_));
FromJson(config.at("train_param"), &double_maker_->param_);
}
}
@@ -933,6 +935,7 @@ class GPUHistMaker : public TreeUpdater {
private:
GPUHistMakerTrainParam hist_maker_param_;
ObjInfo task_;
std::unique_ptr<GPUHistMakerSpecialised<GradientPair>> float_maker_;
std::unique_ptr<GPUHistMakerSpecialised<GradientPairPrecise>> double_maker_;
};
@@ -940,7 +943,7 @@ class GPUHistMaker : public TreeUpdater {
#if !defined(GTEST_TEST)
// Registers GPUHistMaker under the updater name "grow_gpu_hist".
// The `[](ObjInfo task)` form of the factory forwards `task` to the
// GPUHistMaker constructor; the `[]()` form default-constructs it.
// NOTE(review): the two `.set_body(...)` lines look like the old/new pair of
// a diff rendered without +/- markers -- confirm.
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
.describe("Grow tree with GPU.")
.set_body([]() { return new GPUHistMaker(); });
.set_body([](ObjInfo task) { return new GPUHistMaker(task); });
#endif // !defined(GTEST_TEST)
} // namespace tree

View File

@@ -750,14 +750,14 @@ class GlobalProposalHistMaker: public CQHistMaker {
// Registers the updater name "grow_local_histmaker"; its factory returns a
// CQHistMaker instance. The `[](ObjInfo)` form accepts the objective info
// but does not use it.
// NOTE(review): the two consecutive `.set_body(` openers look like the
// old/new pair of a diff rendered without +/- markers -- confirm.
XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
.describe("Tree constructor that uses approximate histogram construction.")
.set_body([]() {
.set_body([](ObjInfo) {
return new CQHistMaker();
});
// The updater for approx tree method.
// Registers the updater name "grow_histmaker"; its factory returns a
// GlobalProposalHistMaker instance. The `[](ObjInfo)` form accepts the
// objective info but does not use it.
// NOTE(review): the two consecutive `.set_body(` openers look like the
// old/new pair of a diff rendered without +/- markers -- confirm.
XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker")
.describe("Tree constructor that uses approximate global of histogram construction.")
.set_body([]() {
.set_body([](ObjInfo) {
return new GlobalProposalHistMaker();
});
} // namespace tree

View File

@@ -23,8 +23,8 @@ DMLC_REGISTRY_FILE_TAG(updater_prune);
/*! \brief pruner that prunes a tree after growing finishes */
class TreePruner: public TreeUpdater {
public:
TreePruner() {
syncher_.reset(TreeUpdater::Create("sync", tparam_));
explicit TreePruner(ObjInfo task) {
syncher_.reset(TreeUpdater::Create("sync", tparam_, task));
pruner_monitor_.Init("TreePruner");
}
char const* Name() const override {
@@ -113,8 +113,8 @@ class TreePruner: public TreeUpdater {
// Registers TreePruner under the updater name "prune".
// The `[](ObjInfo task)` form of the factory forwards `task` to the
// TreePruner constructor; the `[]()` form default-constructs it.
// NOTE(review): the paired `.set_body(` / `return new ...` lines look like
// the old/new halves of a diff rendered without +/- markers -- confirm.
XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
.describe("Pruner that prune the tree according to statistics.")
.set_body([]() {
return new TreePruner();
.set_body([](ObjInfo task) {
return new TreePruner(task);
});
} // namespace tree
} // namespace xgboost

View File

@@ -40,7 +40,7 @@ DMLC_REGISTER_PARAMETER(CPUHistMakerTrainParam);
void QuantileHistMaker::Configure(const Args& args) {
// initialize pruner
if (!pruner_) {
pruner_.reset(TreeUpdater::Create("prune", tparam_));
pruner_.reset(TreeUpdater::Create("prune", tparam_, task_));
}
pruner_->Configure(args);
param_.UpdateAllowUnknown(args);
@@ -52,7 +52,7 @@ void QuantileHistMaker::SetBuilder(const size_t n_trees,
std::unique_ptr<Builder<GradientSumT>>* builder,
DMatrix *dmat) {
builder->reset(
new Builder<GradientSumT>(n_trees, param_, std::move(pruner_), dmat));
new Builder<GradientSumT>(n_trees, param_, std::move(pruner_), dmat, task_));
}
template<typename GradientSumT>
@@ -529,11 +529,11 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(
// store a pointer to the tree
p_last_tree_ = &tree;
if (data_layout_ == DataLayout::kDenseDataOneBased) {
evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
param_, info, this->nthread_, column_sampler_, true});
evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{param_, info, this->nthread_,
column_sampler_, true});
} else {
evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
param_, info, this->nthread_, column_sampler_, false});
evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{param_, info, this->nthread_,
column_sampler_, false});
}
if (data_layout_ == DataLayout::kDenseDataZeroBased
@@ -677,17 +677,17 @@ XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker")
.describe("(Deprecated, use grow_quantile_histmaker instead.)"
" Grow tree using quantized histogram.")
.set_body(
[]() {
[](ObjInfo task) {
LOG(WARNING) << "grow_fast_histmaker is deprecated, "
<< "use grow_quantile_histmaker instead.";
return new QuantileHistMaker();
return new QuantileHistMaker(task);
});
// Registers QuantileHistMaker under the updater name
// "grow_quantile_histmaker".
// The `[](ObjInfo task)` form of the factory forwards `task` to the
// QuantileHistMaker constructor; the `[]()` form default-constructs it.
// NOTE(review): the paired lambda openers and `return new ...` lines look
// like the old/new halves of a diff rendered without +/- markers -- confirm.
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
.describe("Grow tree using quantized histogram.")
.set_body(
[]() {
return new QuantileHistMaker();
[](ObjInfo task) {
return new QuantileHistMaker(task);
});
} // namespace tree
} // namespace xgboost

View File

@@ -95,7 +95,7 @@ using xgboost::common::Column;
/*! \brief construct a tree using quantized feature values */
class QuantileHistMaker: public TreeUpdater {
public:
QuantileHistMaker() {
explicit QuantileHistMaker(ObjInfo task) : task_{task} {
updater_monitor_.Init("QuantileHistMaker");
}
void Configure(const Args& args) override;
@@ -154,12 +154,15 @@ class QuantileHistMaker: public TreeUpdater {
using GHistRowT = GHistRow<GradientSumT>;
using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;
// constructor
// Builds a Builder from the tree count, the training parameters, a pruner
// (ownership is taken via std::move), and the feature matrix pointer.
// `p_last_tree_` starts as nullptr, `p_last_fmat_` is set to `fmat`, a fresh
// HistogramBuilder is allocated for `histogram_builder_`, and the monitor is
// initialised with the label "Quantile::Builder". The five-argument form
// additionally stores the objective info in `task_`.
// NOTE(review): the two constructor heads and initializer lists look like
// the old/new halves of a diff rendered without +/- markers -- confirm.
explicit Builder(const size_t n_trees, const TrainParam &param,
std::unique_ptr<TreeUpdater> pruner, DMatrix const *fmat)
: n_trees_(n_trees), param_(param), pruner_(std::move(pruner)),
p_last_tree_(nullptr), p_last_fmat_(fmat),
histogram_builder_{
new HistogramBuilder<GradientSumT, CPUExpandEntry>} {
explicit Builder(const size_t n_trees, const TrainParam& param,
std::unique_ptr<TreeUpdater> pruner, DMatrix const* fmat, ObjInfo task)
: n_trees_(n_trees),
param_(param),
pruner_(std::move(pruner)),
p_last_tree_(nullptr),
p_last_fmat_(fmat),
histogram_builder_{new HistogramBuilder<GradientSumT, CPUExpandEntry>},
task_{task} {
builder_monitor_.Init("Quantile::Builder");
}
~Builder();
@@ -261,6 +264,7 @@ class QuantileHistMaker: public TreeUpdater {
DataLayout data_layout_;
std::unique_ptr<HistogramBuilder<GradientSumT, CPUExpandEntry>>
histogram_builder_;
ObjInfo task_;
common::Monitor builder_monitor_;
};
@@ -281,6 +285,7 @@ class QuantileHistMaker: public TreeUpdater {
std::unique_ptr<Builder<double>> double_builder_;
std::unique_ptr<TreeUpdater> pruner_;
ObjInfo task_;
};
} // namespace tree
} // namespace xgboost

View File

@@ -161,7 +161,7 @@ class TreeRefresher: public TreeUpdater {
// Registers TreeRefresher under the updater name "refresh".
// The `[](ObjInfo)` form of the factory lambda accepts the objective info but
// leaves it unnamed and unused; the updater is default-constructed either way.
// NOTE(review): the two consecutive `.set_body(` openers look like the
// old/new pair of a diff rendered without +/- markers -- confirm.
XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh")
.describe("Refresher that refreshes the weight and statistics according to data.")
.set_body([]() {
.set_body([](ObjInfo) {
return new TreeRefresher();
});
} // namespace tree

View File

@@ -53,7 +53,7 @@ class TreeSyncher: public TreeUpdater {
// Registers TreeSyncher under the updater name "sync".
// The `[](ObjInfo)` form of the factory lambda accepts the objective info but
// leaves it unnamed and unused; the updater is default-constructed either way.
// NOTE(review): the two consecutive `.set_body(` openers look like the
// old/new pair of a diff rendered without +/- markers -- confirm.
XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync")
.describe("Syncher that synchronize the tree in all distributed nodes.")
.set_body([]() {
.set_body([](ObjInfo) {
return new TreeSyncher();
});
} // namespace tree