[SYCL] Add sampling initialization (#10216)
--------- Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
committed by
GitHub
parent
59d7b8dc72
commit
58513dc288
58
plugin/sycl/tree/hist_updater.cc
Normal file
58
plugin/sycl/tree/hist_updater.cc
Normal file
@@ -0,0 +1,58 @@
|
||||
/*!
|
||||
* Copyright 2017-2024 by Contributors
|
||||
* \file hist_updater.cc
|
||||
*/
|
||||
|
||||
#include "hist_updater.h"
|
||||
|
||||
#include <oneapi/dpl/random>
|
||||
|
||||
namespace xgboost {
|
||||
namespace sycl {
|
||||
namespace tree {
|
||||
|
||||
template<typename GradientSumT>
|
||||
void HistUpdater<GradientSumT>::InitSampling(
|
||||
const USMVector<GradientPair, MemoryType::on_device> &gpair,
|
||||
USMVector<size_t, MemoryType::on_device>* row_indices) {
|
||||
const size_t num_rows = row_indices->Size();
|
||||
auto* row_idx = row_indices->Data();
|
||||
const auto* gpair_ptr = gpair.DataConst();
|
||||
uint64_t num_samples = 0;
|
||||
const auto subsample = param_.subsample;
|
||||
::sycl::event event;
|
||||
|
||||
{
|
||||
::sycl::buffer<uint64_t, 1> flag_buf(&num_samples, 1);
|
||||
uint64_t seed = seed_;
|
||||
seed_ += num_rows;
|
||||
event = qu_.submit([&](::sycl::handler& cgh) {
|
||||
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh);
|
||||
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
|
||||
[=](::sycl::item<1> pid) {
|
||||
uint64_t i = pid.get_id(0);
|
||||
|
||||
// Create minstd_rand engine
|
||||
oneapi::dpl::minstd_rand engine(seed, i);
|
||||
oneapi::dpl::bernoulli_distribution coin_flip(subsample);
|
||||
|
||||
auto rnd = coin_flip(engine);
|
||||
if (gpair_ptr[i].GetHess() >= 0.0f && rnd) {
|
||||
AtomicRef<uint64_t> num_samples_ref(flag_buf_acc[0]);
|
||||
row_idx[num_samples_ref++] = i;
|
||||
}
|
||||
});
|
||||
});
|
||||
/* After calling a destructor for flag_buf, content will be copyed to num_samples */
|
||||
}
|
||||
|
||||
row_indices->Resize(&qu_, num_samples, 0, &event);
|
||||
qu_.wait();
|
||||
}
|
||||
|
||||
template class HistUpdater<float>;
|
||||
template class HistUpdater<double>;
|
||||
|
||||
} // namespace tree
|
||||
} // namespace sycl
|
||||
} // namespace xgboost
|
||||
72
plugin/sycl/tree/hist_updater.h
Normal file
72
plugin/sycl/tree/hist_updater.h
Normal file
@@ -0,0 +1,72 @@
|
||||
/*!
|
||||
* Copyright 2017-2024 by Contributors
|
||||
* \file hist_updater.h
|
||||
*/
|
||||
#ifndef PLUGIN_SYCL_TREE_HIST_UPDATER_H_
|
||||
#define PLUGIN_SYCL_TREE_HIST_UPDATER_H_
|
||||
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||
#pragma GCC diagnostic ignored "-W#pragma-messages"
|
||||
#include <xgboost/tree_updater.h>
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
|
||||
#include "../common/partition_builder.h"
|
||||
#include "split_evaluator.h"
|
||||
|
||||
#include "../data.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace sycl {
|
||||
namespace tree {
|
||||
|
||||
template<typename GradientSumT>
|
||||
class HistUpdater {
|
||||
public:
|
||||
explicit HistUpdater(::sycl::queue qu,
|
||||
const xgboost::tree::TrainParam& param,
|
||||
std::unique_ptr<TreeUpdater> pruner,
|
||||
FeatureInteractionConstraintHost int_constraints_,
|
||||
DMatrix const* fmat)
|
||||
: qu_(qu), param_(param),
|
||||
tree_evaluator_(qu, param, fmat->Info().num_col_),
|
||||
pruner_(std::move(pruner)),
|
||||
interaction_constraints_{std::move(int_constraints_)},
|
||||
p_last_tree_(nullptr), p_last_fmat_(fmat) {
|
||||
builder_monitor_.Init("SYCL::Quantile::HistUpdater");
|
||||
kernel_monitor_.Init("SYCL::Quantile::HistUpdater");
|
||||
const auto sub_group_sizes =
|
||||
qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>();
|
||||
sub_group_size_ = sub_group_sizes.back();
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSampling(const USMVector<GradientPair, MemoryType::on_device> &gpair,
|
||||
USMVector<size_t, MemoryType::on_device>* row_indices);
|
||||
|
||||
size_t sub_group_size_;
|
||||
const xgboost::tree::TrainParam& param_;
|
||||
TreeEvaluator<GradientSumT> tree_evaluator_;
|
||||
std::unique_ptr<TreeUpdater> pruner_;
|
||||
FeatureInteractionConstraintHost interaction_constraints_;
|
||||
|
||||
// back pointers to tree and data matrix
|
||||
const RegTree* p_last_tree_;
|
||||
DMatrix const* const p_last_fmat_;
|
||||
|
||||
xgboost::common::Monitor builder_monitor_;
|
||||
xgboost::common::Monitor kernel_monitor_;
|
||||
|
||||
uint64_t seed_ = 0;
|
||||
|
||||
::sycl::queue qu_;
|
||||
};
|
||||
|
||||
} // namespace tree
|
||||
} // namespace sycl
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // PLUGIN_SYCL_TREE_HIST_UPDATER_H_
|
||||
Reference in New Issue
Block a user