From 0def8e0bae41e0eeeb4ee078e6e699ad6bef1986 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Wed, 14 Aug 2024 21:50:17 +0200 Subject: [PATCH] [sycl] fix fitting for fp32 devices (#10702) Co-authored-by: Dmitry Razdoburdin <> --- plugin/sycl/tree/hist_updater.cc | 54 +++++++++++++++++++++++--------- plugin/sycl/tree/hist_updater.h | 2 ++ 2 files changed, 42 insertions(+), 14 deletions(-) diff --git a/plugin/sycl/tree/hist_updater.cc b/plugin/sycl/tree/hist_updater.cc index 18c1f02a3..efaddafdb 100644 --- a/plugin/sycl/tree/hist_updater.cc +++ b/plugin/sycl/tree/hist_updater.cc @@ -322,23 +322,49 @@ void HistUpdater::InitSampling( ::sycl::buffer flag_buf(&num_samples, 1); uint64_t seed = seed_; seed_ += num_rows; - event = qu_.submit([&](::sycl::handler& cgh) { - auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh); - cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)), - [=](::sycl::item<1> pid) { - uint64_t i = pid.get_id(0); - // Create minstd_rand engine - oneapi::dpl::minstd_rand engine(seed, i); - oneapi::dpl::bernoulli_distribution coin_flip(subsample); + /* + * oneDPL bernoulli_distribution implicitly uses double. 
+ * If the device doesn't have fp64 support, + * we generate bernoulli distributed random values from a uniform distribution + */ + if (has_fp64_support_) { + // Use oneDPL bernoulli_distribution for better perf + event = qu_.submit([&](::sycl::handler& cgh) { + auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)), + [=](::sycl::item<1> pid) { + uint64_t i = pid.get_id(0); + // Create minstd_rand engine + oneapi::dpl::minstd_rand engine(seed, i); + oneapi::dpl::bernoulli_distribution coin_flip(subsample); + auto bernoulli_rnd = coin_flip(engine); - auto rnd = coin_flip(engine); - if (gpair_ptr[i].GetHess() >= 0.0f && rnd) { - AtomicRef num_samples_ref(flag_buf_acc[0]); - row_idx[num_samples_ref++] = i; - } + if (gpair_ptr[i].GetHess() >= 0.0f && bernoulli_rnd) { + AtomicRef num_samples_ref(flag_buf_acc[0]); + row_idx[num_samples_ref++] = i; + } + }); }); - }); + } else { + // Use oneDPL uniform, because bernoulli_distribution uses fp64 + event = qu_.submit([&](::sycl::handler& cgh) { + auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)), + [=](::sycl::item<1> pid) { + uint64_t i = pid.get_id(0); + oneapi::dpl::minstd_rand engine(seed, i); + oneapi::dpl::uniform_real_distribution distr; + const float rnd = distr(engine); + const bool bernoulli_rnd = rnd < subsample ? 
1 : 0; + + if (gpair_ptr[i].GetHess() >= 0.0f && bernoulli_rnd) { + AtomicRef num_samples_ref(flag_buf_acc[0]); + row_idx[num_samples_ref++] = i; + } + }); + }); + } /* After calling a destructor for flag_buf, content will be copyed to num_samples */ } diff --git a/plugin/sycl/tree/hist_updater.h b/plugin/sycl/tree/hist_updater.h index 6d7d84070..5e0ca6645 100644 --- a/plugin/sycl/tree/hist_updater.h +++ b/plugin/sycl/tree/hist_updater.h @@ -67,6 +67,7 @@ class HistUpdater { if (param.max_depth > 0) { snode_device_.Resize(&qu, 1u << (param.max_depth + 1)); } + has_fp64_support_ = qu_.get_device().has(::sycl::aspect::fp64); const auto sub_group_sizes = qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>(); sub_group_size_ = sub_group_sizes.back(); @@ -183,6 +184,7 @@ class HistUpdater { // --data fields-- const Context* ctx_; + bool has_fp64_support_; size_t sub_group_size_; // the internal row sets