[sycl] fix fitting for fp32 devices (#10702)
Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
parent
773ded684b
commit
0def8e0bae
@ -322,23 +322,49 @@ void HistUpdater<GradientSumT>::InitSampling(
|
|||||||
::sycl::buffer<uint64_t, 1> flag_buf(&num_samples, 1);
|
::sycl::buffer<uint64_t, 1> flag_buf(&num_samples, 1);
|
||||||
uint64_t seed = seed_;
|
uint64_t seed = seed_;
|
||||||
seed_ += num_rows;
|
seed_ += num_rows;
|
||||||
event = qu_.submit([&](::sycl::handler& cgh) {
|
|
||||||
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh);
|
|
||||||
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
|
|
||||||
[=](::sycl::item<1> pid) {
|
|
||||||
uint64_t i = pid.get_id(0);
|
|
||||||
|
|
||||||
// Create minstd_rand engine
|
/*
|
||||||
oneapi::dpl::minstd_rand engine(seed, i);
|
* oneDPL bernoulli_distribution implicitly uses double.
|
||||||
oneapi::dpl::bernoulli_distribution coin_flip(subsample);
|
* In case the device doesn't have fp64 support,
|
||||||
|
* we generate bernoulli distributed random values from uniform distribution
|
||||||
|
*/
|
||||||
|
if (has_fp64_support_) {
|
||||||
|
// Use oneDPL bernoulli_distribution for better perf
|
||||||
|
event = qu_.submit([&](::sycl::handler& cgh) {
|
||||||
|
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh);
|
||||||
|
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
|
||||||
|
[=](::sycl::item<1> pid) {
|
||||||
|
uint64_t i = pid.get_id(0);
|
||||||
|
// Create minstd_rand engine
|
||||||
|
oneapi::dpl::minstd_rand engine(seed, i);
|
||||||
|
oneapi::dpl::bernoulli_distribution coin_flip(subsample);
|
||||||
|
auto bernoulli_rnd = coin_flip(engine);
|
||||||
|
|
||||||
auto rnd = coin_flip(engine);
|
if (gpair_ptr[i].GetHess() >= 0.0f && bernoulli_rnd) {
|
||||||
if (gpair_ptr[i].GetHess() >= 0.0f && rnd) {
|
AtomicRef<uint64_t> num_samples_ref(flag_buf_acc[0]);
|
||||||
AtomicRef<uint64_t> num_samples_ref(flag_buf_acc[0]);
|
row_idx[num_samples_ref++] = i;
|
||||||
row_idx[num_samples_ref++] = i;
|
}
|
||||||
}
|
});
|
||||||
});
|
});
|
||||||
});
|
} else {
|
||||||
|
// Use oneDPL uniform_real_distribution<float>, since bernoulli_distribution uses fp64
|
||||||
|
event = qu_.submit([&](::sycl::handler& cgh) {
|
||||||
|
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh);
|
||||||
|
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
|
||||||
|
[=](::sycl::item<1> pid) {
|
||||||
|
uint64_t i = pid.get_id(0);
|
||||||
|
oneapi::dpl::minstd_rand engine(seed, i);
|
||||||
|
oneapi::dpl::uniform_real_distribution<float> distr;
|
||||||
|
const float rnd = distr(engine);
|
||||||
|
const bool bernoulli_rnd = rnd < subsample ? 1 : 0;
|
||||||
|
|
||||||
|
if (gpair_ptr[i].GetHess() >= 0.0f && bernoulli_rnd) {
|
||||||
|
AtomicRef<uint64_t> num_samples_ref(flag_buf_acc[0]);
|
||||||
|
row_idx[num_samples_ref++] = i;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
/* After the destructor for flag_buf is called, its content will be copied to num_samples */
|
/* After the destructor for flag_buf is called, its content will be copied to num_samples */
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -67,6 +67,7 @@ class HistUpdater {
|
|||||||
if (param.max_depth > 0) {
|
if (param.max_depth > 0) {
|
||||||
snode_device_.Resize(&qu, 1u << (param.max_depth + 1));
|
snode_device_.Resize(&qu, 1u << (param.max_depth + 1));
|
||||||
}
|
}
|
||||||
|
has_fp64_support_ = qu_.get_device().has(::sycl::aspect::fp64);
|
||||||
const auto sub_group_sizes =
|
const auto sub_group_sizes =
|
||||||
qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>();
|
qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>();
|
||||||
sub_group_size_ = sub_group_sizes.back();
|
sub_group_size_ = sub_group_sizes.back();
|
||||||
@ -183,6 +184,7 @@ class HistUpdater {
|
|||||||
|
|
||||||
// --data fields--
|
// --data fields--
|
||||||
const Context* ctx_;
|
const Context* ctx_;
|
||||||
|
bool has_fp64_support_;
|
||||||
size_t sub_group_size_;
|
size_t sub_group_size_;
|
||||||
|
|
||||||
// the internal row sets
|
// the internal row sets
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user