[SYCL] Add sampling initialization (#10216)

---------

Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
Dmitry Razdoburdin
2024-04-24 22:35:52 +02:00
committed by GitHub
parent 59d7b8dc72
commit 58513dc288
4 changed files with 235 additions and 0 deletions

View File

@@ -0,0 +1,58 @@
/*!
* Copyright 2017-2024 by Contributors
* \file hist_updater.cc
*/
#include "hist_updater.h"
#include <oneapi/dpl/random>
namespace xgboost {
namespace sycl {
namespace tree {
/*!
 * \brief Select a random subset of rows on the device and store their ids in row_indices.
 *
 * One work-item is launched per entry of row_indices (its size on entry). Work-item i
 * draws a Bernoulli sample with probability param_.subsample; when the draw succeeds and
 * the hessian of gpair[i] is non-negative, i is appended to row_indices through an
 * atomically incremented counter. Afterwards row_indices is resized to the number of
 * ids written.
 *
 * \param gpair       gradient pairs, indexed by row id; only the hessian is read
 * \param row_indices in: its size gives the number of rows to consider;
 *                    out: holds the ids of the selected rows
 */
template<typename GradientSumT>
void HistUpdater<GradientSumT>::InitSampling(
    const USMVector<GradientPair, MemoryType::on_device> &gpair,
    USMVector<size_t, MemoryType::on_device>* row_indices) {
  const size_t n_rows = row_indices->Size();
  const auto* grad = gpair.DataConst();
  auto* selected_rows = row_indices->Data();
  const auto keep_prob = param_.subsample;

  // Take the seed for this launch by value (the kernel cannot read members), then
  // advance the stored seed by the number of rows for the next call.
  const uint64_t launch_seed = seed_;
  seed_ += n_rows;

  uint64_t n_selected = 0;
  ::sycl::event sampling_event;
  {
    // Single-element buffer over n_selected; it serves as the shared output cursor.
    ::sycl::buffer<uint64_t, 1> counter_buf(&n_selected, 1);
    sampling_event = qu_.submit([&](::sycl::handler& cgh) {
      auto counter_acc = counter_buf.get_access<::sycl::access::mode::read_write>(cgh);
      cgh.parallel_for<>(::sycl::range<1>(n_rows),
                         [=](::sycl::item<1> pid) {
        const uint64_t row = pid.get_id(0);

        // minstd_rand engine built from the launch seed with the row id as offset.
        oneapi::dpl::minstd_rand engine(launch_seed, row);
        oneapi::dpl::bernoulli_distribution coin_flip(keep_prob);
        const auto drawn = coin_flip(engine);

        // Skip rows with a negative hessian or an unsuccessful draw.
        if (!(grad[row].GetHess() >= 0.0f && drawn)) return;

        // Reserve the next free slot atomically, then record the row id in it.
        AtomicRef<uint64_t> cursor(counter_acc[0]);
        const auto slot = cursor++;
        selected_rows[slot] = row;
      });
    });
    /* Leaving this scope destroys counter_buf, which copies its content back to n_selected */
  }

  row_indices->Resize(&qu_, n_selected, 0, &sampling_event);
  qu_.wait();
}
// HistUpdater<GradientSumT>::InitSampling is defined in this file rather than in the
// header, so the class template is explicitly instantiated here for float and double.
template class HistUpdater<float>;
template class HistUpdater<double>;
} // namespace tree
} // namespace sycl
} // namespace xgboost

View File

@@ -0,0 +1,72 @@
/*!
* Copyright 2017-2024 by Contributors
* \file hist_updater.h
*/
#ifndef PLUGIN_SYCL_TREE_HIST_UPDATER_H_
#define PLUGIN_SYCL_TREE_HIST_UPDATER_H_
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include <xgboost/tree_updater.h>
#pragma GCC diagnostic pop
#include <utility>
#include <memory>
#include "../common/partition_builder.h"
#include "split_evaluator.h"
#include "../data.h"
namespace xgboost {
namespace sycl {
namespace tree {
/*!
 * \brief State shared by the SYCL histogram tree updater for one gradient-sum type.
 *
 * Holds the SYCL queue, a reference to the training parameters, the tree evaluator,
 * the pruner, the interaction constraints, timing monitors and the sampling seed.
 *
 * \tparam GradientSumT floating point type forwarded to TreeEvaluator
 */
template<typename GradientSumT>
class HistUpdater {
 public:
  /*!
   * \brief Store the collaborators and query the device for its sub-group size.
   *
   * \param qu               SYCL queue; a copy is kept and one is passed to the tree evaluator
   * \param param            training parameters; only a reference is kept, so the object
   *                         must outlive this updater
   * \param pruner           tree pruner; ownership is moved into this object
   * \param int_constraints_ feature interaction constraints; moved into this object
   * \param fmat             data matrix; dereferenced here to read the number of columns,
   *                         and the pointer itself is kept
   */
  explicit HistUpdater(::sycl::queue qu,
                       const xgboost::tree::TrainParam& param,
                       std::unique_ptr<TreeUpdater> pruner,
                       FeatureInteractionConstraintHost int_constraints_,
                       DMatrix const* fmat)
      // Listed in member declaration order, which is the order the initializers run in.
      : param_(param),
        tree_evaluator_(qu, param, fmat->Info().num_col_),
        pruner_(std::move(pruner)),
        interaction_constraints_{std::move(int_constraints_)},
        p_last_tree_(nullptr),
        p_last_fmat_(fmat),
        qu_(qu) {
    builder_monitor_.Init("SYCL::Quantile::HistUpdater");
    kernel_monitor_.Init("SYCL::Quantile::HistUpdater");

    // Keep the last entry of the sub-group sizes reported by the queue's device.
    sub_group_size_ =
        qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>().back();
  }

 protected:
  /*!
   * \brief Fill row_indices with a random subsample of row ids.
   *
   * \param gpair       gradient pairs stored in device memory
   * \param row_indices device vector of row ids, rewritten with the selected rows
   */
  void InitSampling(const USMVector<GradientPair, MemoryType::on_device> &gpair,
                    USMVector<size_t, MemoryType::on_device>* row_indices);

  // Sub-group size taken from the device in the constructor.
  size_t sub_group_size_;
  // Training parameters; held by reference, not owned.
  const xgboost::tree::TrainParam& param_;
  TreeEvaluator<GradientSumT> tree_evaluator_;
  std::unique_ptr<TreeUpdater> pruner_;
  FeatureInteractionConstraintHost interaction_constraints_;

  // back pointers to tree and data matrix
  const RegTree* p_last_tree_;
  DMatrix const* const p_last_fmat_;

  // Both monitors are initialised with the same label in the constructor.
  xgboost::common::Monitor builder_monitor_;
  xgboost::common::Monitor kernel_monitor_;

  // Seed handed to the random engine in InitSampling; starts at zero.
  uint64_t seed_ = 0;

  // Queue used to submit device work. Declared last, so it is initialised last.
  ::sycl::queue qu_;
};
} // namespace tree
} // namespace sycl
} // namespace xgboost
#endif // PLUGIN_SYCL_TREE_HIST_UPDATER_H_