Implement column sampler in CUDA. (#9785)
- CUDA implementation.
- Extract the broadcasting logic; we will need the context parameter after revamping the collective implementation.
- Some changes to the event loop to fix a deadlock in CI.
- Move argsort into algorithms.cuh and add support for CUDA streams.
This commit is contained in:
@@ -72,7 +72,7 @@ common::Span<bst_feature_t const> GPUHistEvaluator::SortHistogram(
|
||||
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
auto sorted_idx = this->SortedIdx(d_inputs.size(), shared_inputs.feature_values.size());
|
||||
dh::Iota(sorted_idx);
|
||||
dh::Iota(sorted_idx, dh::DefaultStream());
|
||||
auto data = this->SortInput(d_inputs.size(), shared_inputs.feature_values.size());
|
||||
auto it = thrust::make_counting_iterator(0u);
|
||||
auto d_feature_idx = dh::ToSpan(feature_idx_);
|
||||
|
||||
@@ -248,8 +248,7 @@ class GlobalApproxUpdater : public TreeUpdater {
|
||||
std::unique_ptr<GloablApproxBuilder> pimpl_;
|
||||
// pointer to the last DMatrix, used for updating the prediction cache.
|
||||
DMatrix *cached_{nullptr};
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler_ =
|
||||
std::make_shared<common::ColumnSampler>();
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler_;
|
||||
ObjInfo const *task_;
|
||||
HistMakerTrainParam hist_param_;
|
||||
|
||||
@@ -284,6 +283,9 @@ class GlobalApproxUpdater : public TreeUpdater {
|
||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
const std::vector<RegTree *> &trees) override {
|
||||
CHECK(hist_param_.GetInitialised());
|
||||
if (!column_sampler_) {
|
||||
column_sampler_ = common::MakeColumnSampler(ctx_);
|
||||
}
|
||||
pimpl_ = std::make_unique<GloablApproxBuilder>(param, &hist_param_, m->Info(), ctx_,
|
||||
column_sampler_, task_, &monitor_);
|
||||
|
||||
|
||||
@@ -225,9 +225,12 @@ class ColMaker: public TreeUpdater {
|
||||
}
|
||||
}
|
||||
{
|
||||
column_sampler_.Init(ctx_, fmat.Info().num_col_,
|
||||
fmat.Info().feature_weights.ConstHostVector(), param_.colsample_bynode,
|
||||
param_.colsample_bylevel, param_.colsample_bytree);
|
||||
if (!column_sampler_) {
|
||||
column_sampler_ = common::MakeColumnSampler(ctx_);
|
||||
}
|
||||
column_sampler_->Init(
|
||||
ctx_, fmat.Info().num_col_, fmat.Info().feature_weights.ConstHostVector(),
|
||||
param_.colsample_bynode, param_.colsample_bylevel, param_.colsample_bytree);
|
||||
}
|
||||
{
|
||||
// setup temp space for each thread
|
||||
@@ -467,7 +470,7 @@ class ColMaker: public TreeUpdater {
|
||||
RegTree *p_tree) {
|
||||
auto evaluator = tree_evaluator_.GetEvaluator();
|
||||
|
||||
auto feat_set = column_sampler_.GetFeatureSet(depth);
|
||||
auto feat_set = column_sampler_->GetFeatureSet(depth);
|
||||
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>(ctx_)) {
|
||||
this->UpdateSolution(batch, feat_set->HostVector(), gpair, p_fmat);
|
||||
}
|
||||
@@ -586,7 +589,7 @@ class ColMaker: public TreeUpdater {
|
||||
const ColMakerTrainParam& colmaker_train_param_;
|
||||
// number of omp thread used during training
|
||||
Context const* ctx_;
|
||||
common::ColumnSampler column_sampler_;
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler_;
|
||||
// Instance Data: current node position in the tree of each instance
|
||||
std::vector<int> position_;
|
||||
// PerThread x PerTreeNode: statistics for per thread construction
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2017-2023 by XGBoost Contributors
|
||||
* Copyright 2017-2023, XGBoost Contributors
|
||||
* \file updater_quantile_hist.cc
|
||||
* \brief use quantized feature values to construct a tree
|
||||
* \author Philip Cho, Tianqi Chen, Egor Smirnov
|
||||
@@ -470,8 +470,7 @@ class HistUpdater {
|
||||
class QuantileHistMaker : public TreeUpdater {
|
||||
std::unique_ptr<HistUpdater> p_impl_{nullptr};
|
||||
std::unique_ptr<MultiTargetHistBuilder> p_mtimpl_{nullptr};
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler_ =
|
||||
std::make_shared<common::ColumnSampler>();
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler_;
|
||||
common::Monitor monitor_;
|
||||
ObjInfo const *task_{nullptr};
|
||||
HistMakerTrainParam hist_param_;
|
||||
@@ -495,6 +494,10 @@ class QuantileHistMaker : public TreeUpdater {
|
||||
void Update(TrainParam const *param, linalg::Matrix<GradientPair> *gpair, DMatrix *p_fmat,
|
||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
const std::vector<RegTree *> &trees) override {
|
||||
if (!column_sampler_) {
|
||||
column_sampler_ = common::MakeColumnSampler(ctx_);
|
||||
}
|
||||
|
||||
if (trees.front()->IsMultiTarget()) {
|
||||
CHECK(hist_param_.GetInitialised());
|
||||
CHECK(param->monotone_constraints.empty()) << "monotone constraint" << MTNotImplemented();
|
||||
|
||||
Reference in New Issue
Block a user