Pass infomation about objective to tree methods. (#7385)
* Define the `ObjInfo` and pass it down to every tree updater.
This commit is contained in:
@@ -698,7 +698,7 @@ struct GPUHistMakerDevice {
|
||||
int right_child_nidx = tree[candidate.nid].RightChild();
|
||||
// Only create child entries if needed
|
||||
if (GPUExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
|
||||
num_leaves)) {
|
||||
num_leaves)) {
|
||||
monitor.Start("UpdatePosition");
|
||||
this->UpdatePosition(candidate.nid, p_tree);
|
||||
monitor.Stop("UpdatePosition");
|
||||
@@ -732,7 +732,7 @@ struct GPUHistMakerDevice {
|
||||
template <typename GradientSumT>
|
||||
class GPUHistMakerSpecialised {
|
||||
public:
|
||||
GPUHistMakerSpecialised() = default;
|
||||
explicit GPUHistMakerSpecialised(ObjInfo task) : task_{task} {};
|
||||
void Configure(const Args& args, GenericParameter const* generic_param) {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
generic_param_ = generic_param;
|
||||
@@ -859,12 +859,14 @@ class GPUHistMakerSpecialised {
|
||||
|
||||
DMatrix* p_last_fmat_ { nullptr };
|
||||
int device_{-1};
|
||||
ObjInfo task_;
|
||||
|
||||
common::Monitor monitor_;
|
||||
};
|
||||
|
||||
class GPUHistMaker : public TreeUpdater {
|
||||
public:
|
||||
explicit GPUHistMaker(ObjInfo task) : task_{task} {}
|
||||
void Configure(const Args& args) override {
|
||||
// Used in test to count how many configurations are performed
|
||||
LOG(DEBUG) << "[GPU Hist]: Configure";
|
||||
@@ -878,11 +880,11 @@ class GPUHistMaker : public TreeUpdater {
|
||||
param = double_maker_->param_;
|
||||
}
|
||||
if (hist_maker_param_.single_precision_histogram) {
|
||||
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>());
|
||||
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>(task_));
|
||||
float_maker_->param_ = param;
|
||||
float_maker_->Configure(args, tparam_);
|
||||
} else {
|
||||
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>());
|
||||
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>(task_));
|
||||
double_maker_->param_ = param;
|
||||
double_maker_->Configure(args, tparam_);
|
||||
}
|
||||
@@ -892,10 +894,10 @@ class GPUHistMaker : public TreeUpdater {
|
||||
auto const& config = get<Object const>(in);
|
||||
FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
|
||||
if (hist_maker_param_.single_precision_histogram) {
|
||||
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>());
|
||||
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>(task_));
|
||||
FromJson(config.at("train_param"), &float_maker_->param_);
|
||||
} else {
|
||||
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>());
|
||||
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>(task_));
|
||||
FromJson(config.at("train_param"), &double_maker_->param_);
|
||||
}
|
||||
}
|
||||
@@ -933,6 +935,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
|
||||
private:
|
||||
GPUHistMakerTrainParam hist_maker_param_;
|
||||
ObjInfo task_;
|
||||
std::unique_ptr<GPUHistMakerSpecialised<GradientPair>> float_maker_;
|
||||
std::unique_ptr<GPUHistMakerSpecialised<GradientPairPrecise>> double_maker_;
|
||||
};
|
||||
@@ -940,7 +943,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
#if !defined(GTEST_TEST)
|
||||
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
|
||||
.describe("Grow tree with GPU.")
|
||||
.set_body([]() { return new GPUHistMaker(); });
|
||||
.set_body([](ObjInfo task) { return new GPUHistMaker(task); });
|
||||
#endif // !defined(GTEST_TEST)
|
||||
|
||||
} // namespace tree
|
||||
|
||||
Reference in New Issue
Block a user