Implement column sampler in CUDA. (#9785)

- CUDA implementation.
- Extract the broadcasting logic; the context parameter will be needed once the collective implementation is revamped.
- Change the event loop to fix a deadlock in CI.
- Move argsort into algorithms.cuh and add support for CUDA streams.
This commit is contained in:
Jiaming Yuan
2023-11-17 04:29:08 +08:00
committed by GitHub
parent 178cfe70a8
commit fedd9674c8
20 changed files with 447 additions and 232 deletions

View File

@@ -72,7 +72,7 @@ common::Span<bst_feature_t const> GPUHistEvaluator::SortHistogram(
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator) {
dh::XGBCachingDeviceAllocator<char> alloc;
auto sorted_idx = this->SortedIdx(d_inputs.size(), shared_inputs.feature_values.size());
dh::Iota(sorted_idx);
dh::Iota(sorted_idx, dh::DefaultStream());
auto data = this->SortInput(d_inputs.size(), shared_inputs.feature_values.size());
auto it = thrust::make_counting_iterator(0u);
auto d_feature_idx = dh::ToSpan(feature_idx_);

View File

@@ -248,8 +248,7 @@ class GlobalApproxUpdater : public TreeUpdater {
std::unique_ptr<GloablApproxBuilder> pimpl_;
// pointer to the last DMatrix, used for update prediction cache.
DMatrix *cached_{nullptr};
std::shared_ptr<common::ColumnSampler> column_sampler_ =
std::make_shared<common::ColumnSampler>();
std::shared_ptr<common::ColumnSampler> column_sampler_;
ObjInfo const *task_;
HistMakerTrainParam hist_param_;
@@ -284,6 +283,9 @@ class GlobalApproxUpdater : public TreeUpdater {
common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) override {
CHECK(hist_param_.GetInitialised());
if (!column_sampler_) {
column_sampler_ = common::MakeColumnSampler(ctx_);
}
pimpl_ = std::make_unique<GloablApproxBuilder>(param, &hist_param_, m->Info(), ctx_,
column_sampler_, task_, &monitor_);

View File

@@ -225,9 +225,12 @@ class ColMaker: public TreeUpdater {
}
}
{
column_sampler_.Init(ctx_, fmat.Info().num_col_,
fmat.Info().feature_weights.ConstHostVector(), param_.colsample_bynode,
param_.colsample_bylevel, param_.colsample_bytree);
if (!column_sampler_) {
column_sampler_ = common::MakeColumnSampler(ctx_);
}
column_sampler_->Init(
ctx_, fmat.Info().num_col_, fmat.Info().feature_weights.ConstHostVector(),
param_.colsample_bynode, param_.colsample_bylevel, param_.colsample_bytree);
}
{
// setup temp space for each thread
@@ -467,7 +470,7 @@ class ColMaker: public TreeUpdater {
RegTree *p_tree) {
auto evaluator = tree_evaluator_.GetEvaluator();
auto feat_set = column_sampler_.GetFeatureSet(depth);
auto feat_set = column_sampler_->GetFeatureSet(depth);
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>(ctx_)) {
this->UpdateSolution(batch, feat_set->HostVector(), gpair, p_fmat);
}
@@ -586,7 +589,7 @@ class ColMaker: public TreeUpdater {
const ColMakerTrainParam& colmaker_train_param_;
// number of omp thread used during training
Context const* ctx_;
common::ColumnSampler column_sampler_;
std::shared_ptr<common::ColumnSampler> column_sampler_;
// Instance Data: current node position in the tree of each instance
std::vector<int> position_;
// PerThread x PerTreeNode: statistics for per thread construction

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2017-2023 by XGBoost Contributors
* Copyright 2017-2023, XGBoost Contributors
* \file updater_quantile_hist.cc
* \brief use quantized feature values to construct a tree
* \author Philip Cho, Tianqi Chen, Egor Smirnov
@@ -470,8 +470,7 @@ class HistUpdater {
class QuantileHistMaker : public TreeUpdater {
std::unique_ptr<HistUpdater> p_impl_{nullptr};
std::unique_ptr<MultiTargetHistBuilder> p_mtimpl_{nullptr};
std::shared_ptr<common::ColumnSampler> column_sampler_ =
std::make_shared<common::ColumnSampler>();
std::shared_ptr<common::ColumnSampler> column_sampler_;
common::Monitor monitor_;
ObjInfo const *task_{nullptr};
HistMakerTrainParam hist_param_;
@@ -495,6 +494,10 @@ class QuantileHistMaker : public TreeUpdater {
void Update(TrainParam const *param, linalg::Matrix<GradientPair> *gpair, DMatrix *p_fmat,
common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) override {
if (!column_sampler_) {
column_sampler_ = common::MakeColumnSampler(ctx_);
}
if (trees.front()->IsMultiTarget()) {
CHECK(hist_param_.GetInitialised());
CHECK(param->monotone_constraints.empty()) << "monotone constraint" << MTNotImplemented();