Pass obj info by reference instead of by value. (#8889)

- Pass obj info into tree updater as const pointer.

This way we don't have to initialize the learner model param before configuring gbm, hence
breaking up the dependency of configurations.
This commit is contained in:
Jiaming Yuan
2023-03-11 01:38:28 +08:00
committed by GitHub
parent 54e001bbf4
commit 6deaec8027
18 changed files with 125 additions and 112 deletions

View File

@@ -15,12 +15,12 @@
#include "../collective/device_communicator.cuh"
#include "../common/bitfield.h"
#include "../common/categorical.h"
#include "../common/cuda_context.cuh" // CUDAContext
#include "../common/device_helpers.cuh"
#include "../common/hist_util.h"
#include "../common/io.h"
#include "../common/timer.h"
#include "../data/ellpack_page.cuh"
#include "../common/cuda_context.cuh" // CUDAContext
#include "constraints.cuh"
#include "driver.h"
#include "gpu_hist/evaluate_splits.cuh"
@@ -39,11 +39,10 @@
#include "xgboost/json.h"
#include "xgboost/parameter.h"
#include "xgboost/span.h"
#include "xgboost/task.h"
#include "xgboost/task.h" // for ObjInfo
#include "xgboost/tree_model.h"
namespace xgboost {
namespace tree {
namespace xgboost::tree {
#if !defined(GTEST_TEST)
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
#endif // !defined(GTEST_TEST)
@@ -106,12 +105,12 @@ class DeviceHistogramStorage {
nidx_map_.clear();
overflow_nidx_map_.clear();
}
bool HistogramExists(int nidx) const {
[[nodiscard]] bool HistogramExists(int nidx) const {
return nidx_map_.find(nidx) != nidx_map_.cend() ||
overflow_nidx_map_.find(nidx) != overflow_nidx_map_.cend();
}
int Bins() const { return n_bins_; }
size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; }
[[nodiscard]] int Bins() const { return n_bins_; }
[[nodiscard]] size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; }
dh::device_vector<typename GradientSumT::ValueT>& Data() { return data_; }
void AllocateHistograms(const std::vector<int>& new_nidxs) {
@@ -690,8 +689,9 @@ struct GPUHistMakerDevice {
return root_entry;
}
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo task,
RegTree* p_tree, collective::DeviceCommunicator* communicator,
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
ObjInfo const* task, RegTree* p_tree,
collective::DeviceCommunicator* communicator,
HostDeviceVector<bst_node_t>* p_out_position) {
auto& tree = *p_tree;
// Process maximum 32 nodes at a time
@@ -741,7 +741,7 @@ struct GPUHistMakerDevice {
}
monitor.Start("FinalisePosition");
this->FinalisePosition(p_tree, p_fmat, task, p_out_position);
this->FinalisePosition(p_tree, p_fmat, *task, p_out_position);
monitor.Stop("FinalisePosition");
}
};
@@ -750,7 +750,7 @@ class GPUHistMaker : public TreeUpdater {
using GradientSumT = GradientPairPrecise;
public:
explicit GPUHistMaker(Context const* ctx, ObjInfo task)
explicit GPUHistMaker(Context const* ctx, ObjInfo const* task)
: TreeUpdater(ctx), task_{task} {};
void Configure(const Args& args) override {
// Used in test to count how many configurations are performed
@@ -872,8 +872,8 @@ class GPUHistMaker : public TreeUpdater {
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT
char const* Name() const override { return "grow_gpu_hist"; }
bool HasNodePosition() const override { return true; }
[[nodiscard]] char const* Name() const override { return "grow_gpu_hist"; }
[[nodiscard]] bool HasNodePosition() const override { return true; }
private:
bool initialised_{false};
@@ -882,7 +882,7 @@ class GPUHistMaker : public TreeUpdater {
DMatrix* p_last_fmat_{nullptr};
RegTree const* p_last_tree_{nullptr};
ObjInfo task_;
ObjInfo const* task_{nullptr};
common::Monitor monitor_;
};
@@ -890,8 +890,8 @@ class GPUHistMaker : public TreeUpdater {
#if !defined(GTEST_TEST)
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
.describe("Grow tree with GPU.")
.set_body([](Context const* ctx, ObjInfo task) { return new GPUHistMaker(ctx, task); });
.set_body([](Context const* ctx, ObjInfo const* task) {
return new GPUHistMaker(ctx, task);
});
#endif // !defined(GTEST_TEST)
} // namespace tree
} // namespace xgboost
} // namespace xgboost::tree