Pass obj info by reference instead of by value. (#8889)
- Pass obj info into tree updater as const pointer. This way we don't have to initialize the learner model param before configuring gbm, hence breaking up the dependency of configurations.
This commit is contained in:
@@ -15,12 +15,12 @@
|
||||
#include "../collective/device_communicator.cuh"
|
||||
#include "../common/bitfield.h"
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/cuda_context.cuh" // CUDAContext
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/timer.h"
|
||||
#include "../data/ellpack_page.cuh"
|
||||
#include "../common/cuda_context.cuh" // CUDAContext
|
||||
#include "constraints.cuh"
|
||||
#include "driver.h"
|
||||
#include "gpu_hist/evaluate_splits.cuh"
|
||||
@@ -39,11 +39,10 @@
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/parameter.h"
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/task.h"
|
||||
#include "xgboost/task.h" // for ObjInfo
|
||||
#include "xgboost/tree_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
namespace xgboost::tree {
|
||||
#if !defined(GTEST_TEST)
|
||||
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
|
||||
#endif // !defined(GTEST_TEST)
|
||||
@@ -106,12 +105,12 @@ class DeviceHistogramStorage {
|
||||
nidx_map_.clear();
|
||||
overflow_nidx_map_.clear();
|
||||
}
|
||||
bool HistogramExists(int nidx) const {
|
||||
[[nodiscard]] bool HistogramExists(int nidx) const {
|
||||
return nidx_map_.find(nidx) != nidx_map_.cend() ||
|
||||
overflow_nidx_map_.find(nidx) != overflow_nidx_map_.cend();
|
||||
}
|
||||
int Bins() const { return n_bins_; }
|
||||
size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; }
|
||||
[[nodiscard]] int Bins() const { return n_bins_; }
|
||||
[[nodiscard]] size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; }
|
||||
dh::device_vector<typename GradientSumT::ValueT>& Data() { return data_; }
|
||||
|
||||
void AllocateHistograms(const std::vector<int>& new_nidxs) {
|
||||
@@ -690,8 +689,9 @@ struct GPUHistMakerDevice {
|
||||
return root_entry;
|
||||
}
|
||||
|
||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo task,
|
||||
RegTree* p_tree, collective::DeviceCommunicator* communicator,
|
||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
|
||||
ObjInfo const* task, RegTree* p_tree,
|
||||
collective::DeviceCommunicator* communicator,
|
||||
HostDeviceVector<bst_node_t>* p_out_position) {
|
||||
auto& tree = *p_tree;
|
||||
// Process maximum 32 nodes at a time
|
||||
@@ -741,7 +741,7 @@ struct GPUHistMakerDevice {
|
||||
}
|
||||
|
||||
monitor.Start("FinalisePosition");
|
||||
this->FinalisePosition(p_tree, p_fmat, task, p_out_position);
|
||||
this->FinalisePosition(p_tree, p_fmat, *task, p_out_position);
|
||||
monitor.Stop("FinalisePosition");
|
||||
}
|
||||
};
|
||||
@@ -750,7 +750,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
using GradientSumT = GradientPairPrecise;
|
||||
|
||||
public:
|
||||
explicit GPUHistMaker(Context const* ctx, ObjInfo task)
|
||||
explicit GPUHistMaker(Context const* ctx, ObjInfo const* task)
|
||||
: TreeUpdater(ctx), task_{task} {};
|
||||
void Configure(const Args& args) override {
|
||||
// Used in test to count how many configurations are performed
|
||||
@@ -872,8 +872,8 @@ class GPUHistMaker : public TreeUpdater {
|
||||
|
||||
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT
|
||||
|
||||
char const* Name() const override { return "grow_gpu_hist"; }
|
||||
bool HasNodePosition() const override { return true; }
|
||||
[[nodiscard]] char const* Name() const override { return "grow_gpu_hist"; }
|
||||
[[nodiscard]] bool HasNodePosition() const override { return true; }
|
||||
|
||||
private:
|
||||
bool initialised_{false};
|
||||
@@ -882,7 +882,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
|
||||
DMatrix* p_last_fmat_{nullptr};
|
||||
RegTree const* p_last_tree_{nullptr};
|
||||
ObjInfo task_;
|
||||
ObjInfo const* task_{nullptr};
|
||||
|
||||
common::Monitor monitor_;
|
||||
};
|
||||
@@ -890,8 +890,8 @@ class GPUHistMaker : public TreeUpdater {
|
||||
#if !defined(GTEST_TEST)
|
||||
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
|
||||
.describe("Grow tree with GPU.")
|
||||
.set_body([](Context const* ctx, ObjInfo task) { return new GPUHistMaker(ctx, task); });
|
||||
.set_body([](Context const* ctx, ObjInfo const* task) {
|
||||
return new GPUHistMaker(ctx, task);
|
||||
});
|
||||
#endif // !defined(GTEST_TEST)
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::tree
|
||||
|
||||
Reference in New Issue
Block a user