From 5c8ccf4455e3eef3f863573958853bd6240800bd Mon Sep 17 00:00:00 2001 From: Igor Rukhovich Date: Wed, 16 Dec 2020 07:59:24 +0300 Subject: [PATCH] Improved InitSampling function speed by 2.12 times (#6410) * Improved InitSampling function speed by 2.12 times * Added explicit conversion --- src/tree/updater_quantile_hist.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 2a180fc82..b0fa98c85 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -695,17 +695,18 @@ void QuantileHistMaker::Builder::InitSampling(const std::vector(std::numeric_limits::max()); + uint32_t coin_flip_border = static_cast(upper_border * param_.subsample); #pragma omp parallel num_threads(nthread) { const size_t tid = omp_get_thread_num(); const size_t ibegin = tid * discard_size; const size_t iend = (tid == (nthread - 1)) ? info.num_row_ : ibegin + discard_size; - std::bernoulli_distribution coin_flip(param_.subsample); - rnds[tid].discard(2*discard_size * tid); + rnds[tid].discard(discard_size * tid); for (size_t i = ibegin; i < iend; ++i) { - if (gpair[i].GetHess() >= 0.0f && coin_flip(rnds[tid])) { + if (gpair[i].GetHess() >= 0.0f && rnds[tid]() < coin_flip_border) { p_row_indices[ibegin + row_offsets[tid]++] = i; } }