diff --git a/plugin/sycl/tree/hist_updater.cc b/plugin/sycl/tree/hist_updater.cc new file mode 100644 index 000000000..7ac5924f4 --- /dev/null +++ b/plugin/sycl/tree/hist_updater.cc @@ -0,0 +1,58 @@ +/*! + * Copyright 2017-2024 by Contributors + * \file hist_updater.cc + */ + +#include "hist_updater.h" + +#include + +namespace xgboost { +namespace sycl { +namespace tree { + +template +void HistUpdater::InitSampling( + const USMVector &gpair, + USMVector* row_indices) { + const size_t num_rows = row_indices->Size(); + auto* row_idx = row_indices->Data(); + const auto* gpair_ptr = gpair.DataConst(); + uint64_t num_samples = 0; + const auto subsample = param_.subsample; + ::sycl::event event; + + { + ::sycl::buffer flag_buf(&num_samples, 1); + uint64_t seed = seed_; + seed_ += num_rows; + event = qu_.submit([&](::sycl::handler& cgh) { + auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)), + [=](::sycl::item<1> pid) { + uint64_t i = pid.get_id(0); + + // Create minstd_rand engine + oneapi::dpl::minstd_rand engine(seed, i); + oneapi::dpl::bernoulli_distribution coin_flip(subsample); + + auto rnd = coin_flip(engine); + if (gpair_ptr[i].GetHess() >= 0.0f && rnd) { + AtomicRef num_samples_ref(flag_buf_acc[0]); + row_idx[num_samples_ref++] = i; + } + }); + }); + /* After calling a destructor for flag_buf, content will be copyed to num_samples */ + } + + row_indices->Resize(&qu_, num_samples, 0, &event); + qu_.wait(); +} + +template class HistUpdater; +template class HistUpdater; + +} // namespace tree +} // namespace sycl +} // namespace xgboost diff --git a/plugin/sycl/tree/hist_updater.h b/plugin/sycl/tree/hist_updater.h new file mode 100644 index 000000000..9efc402c0 --- /dev/null +++ b/plugin/sycl/tree/hist_updater.h @@ -0,0 +1,72 @@ +/*! + * Copyright 2017-2024 by Contributors + * \file hist_updater.h + */ +#ifndef PLUGIN_SYCL_TREE_HIST_UPDATER_H_ +#define PLUGIN_SYCL_TREE_HIST_UPDATER_H_ + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#pragma GCC diagnostic ignored "-W#pragma-messages" +#include +#pragma GCC diagnostic pop + +#include +#include + +#include "../common/partition_builder.h" +#include "split_evaluator.h" + +#include "../data.h" + +namespace xgboost { +namespace sycl { +namespace tree { + +template +class HistUpdater { + public: + explicit HistUpdater(::sycl::queue qu, + const xgboost::tree::TrainParam& param, + std::unique_ptr pruner, + FeatureInteractionConstraintHost int_constraints_, + DMatrix const* fmat) + : qu_(qu), param_(param), + tree_evaluator_(qu, param, fmat->Info().num_col_), + pruner_(std::move(pruner)), + interaction_constraints_{std::move(int_constraints_)}, + p_last_tree_(nullptr), p_last_fmat_(fmat) { + builder_monitor_.Init("SYCL::Quantile::HistUpdater"); + kernel_monitor_.Init("SYCL::Quantile::HistUpdater"); + const auto sub_group_sizes = + qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>(); + sub_group_size_ = sub_group_sizes.back(); + } + + protected: + void InitSampling(const USMVector &gpair, + USMVector* row_indices); + + size_t sub_group_size_; + const xgboost::tree::TrainParam& param_; + TreeEvaluator tree_evaluator_; + std::unique_ptr pruner_; + FeatureInteractionConstraintHost interaction_constraints_; + + // back pointers to tree and data matrix + const RegTree* p_last_tree_; + DMatrix const* const p_last_fmat_; + + xgboost::common::Monitor builder_monitor_; + xgboost::common::Monitor kernel_monitor_; + + uint64_t seed_ = 0; + + ::sycl::queue qu_; +}; + +} // namespace tree +} // namespace sycl +} // namespace xgboost + +#endif // PLUGIN_SYCL_TREE_HIST_UPDATER_H_ diff --git a/tests/ci_build/conda_env/linux_sycl_test.yml b/tests/ci_build/conda_env/linux_sycl_test.yml index bb14c1e77..7335b7f20 100644 --- a/tests/ci_build/conda_env/linux_sycl_test.yml +++ b/tests/ci_build/conda_env/linux_sycl_test.yml @@ -18,3 +18,4 @@ dependencies: - pytest-timeout - pytest-cov - dpcpp_linux-64 +- onedpl-devel diff --git a/tests/cpp/plugin/test_sycl_hist_updater.cc b/tests/cpp/plugin/test_sycl_hist_updater.cc new file mode 100644 index 000000000..81bb9cb7f --- /dev/null +++ b/tests/cpp/plugin/test_sycl_hist_updater.cc @@ -0,0 +1,104 @@ +/** + * Copyright 2020-2024 by XGBoost contributors + */ +#include + +#include + +#include "../../../plugin/sycl/tree/hist_updater.h" +#include "../../../plugin/sycl/device_manager.h" + +#include "../helpers.h" + +namespace xgboost::sycl::tree { + +template +class TestHistUpdater : public HistUpdater { + public: + TestHistUpdater(::sycl::queue qu, + const xgboost::tree::TrainParam& param, + std::unique_ptr pruner, + FeatureInteractionConstraintHost int_constraints_, + DMatrix const* fmat) : HistUpdater(qu, param, std::move(pruner), + int_constraints_, fmat) {} + + void TestInitSampling(const USMVector &gpair, + USMVector* row_indices) { + HistUpdater::InitSampling(gpair, row_indices); + } +}; + +template +void TestHistUpdaterSampling(const xgboost::tree::TrainParam& param) { + const size_t num_rows = 1u << 12; + const size_t num_columns = 1; + + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + DeviceManager device_manager; + auto qu = device_manager.GetQueue(ctx.Device()); + ObjInfo task{ObjInfo::kRegression}; + + auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0}.GenerateDMatrix(); + + FeatureInteractionConstraintHost int_constraints; + std::unique_ptr pruner{TreeUpdater::Create("prune", &ctx, &task)}; + + TestHistUpdater updater(qu, param, std::move(pruner), int_constraints, p_fmat.get()); + + USMVector row_indices_0(&qu, num_rows); + USMVector row_indices_1(&qu, num_rows); + USMVector gpair(&qu, num_rows); + auto* gpair_ptr = gpair.Data(); + qu.submit([&](::sycl::handler& cgh) { + cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)), + [=](::sycl::item<1> pid) { + uint64_t i = pid.get_linear_id(); + + constexpr uint32_t seed = 777; + oneapi::dpl::minstd_rand engine(seed, i); + oneapi::dpl::uniform_real_distribution distr(-1., 1.); + gpair_ptr[i] = {distr(engine), distr(engine)}; + }); + }).wait(); + + updater.TestInitSampling(gpair, &row_indices_0); + + size_t n_samples = row_indices_0.Size(); + // Half of gpairs have neg hess + ASSERT_LT(n_samples, num_rows * 0.5 * param.subsample * 1.2); + ASSERT_GT(n_samples, num_rows * 0.5 * param.subsample / 1.2); + + // Check if two lanunches generate different realisations: + updater.TestInitSampling(gpair, &row_indices_1); + if (row_indices_1.Size() == n_samples) { + std::vector row_indices_0_host(n_samples); + std::vector row_indices_1_host(n_samples); + qu.memcpy(row_indices_0_host.data(), row_indices_0.Data(), n_samples * sizeof(size_t)).wait(); + qu.memcpy(row_indices_1_host.data(), row_indices_1.Data(), n_samples * sizeof(size_t)).wait(); + + // The order in row_indices_0 and row_indices_1 can be different + std::set rows; + for (auto row : row_indices_0_host) { + rows.insert(row); + } + + size_t num_diffs = 0; + for (auto row : row_indices_1_host) { + if (rows.count(row) == 0) num_diffs++; + } + + ASSERT_NE(num_diffs, 0); + } + +} + +TEST(SyclHistUpdater, Sampling) { + xgboost::tree::TrainParam param; + param.UpdateAllowUnknown(Args{{"subsample", "0.7"}}); + + TestHistUpdaterSampling(param); + TestHistUpdaterSampling(param); +} +} // namespace xgboost::sycl::tree