Improved InitSampling function speed by 2.12 times (#6410)
* Improved InitSampling function speed by 2.12 times * Added explicit conversion
This commit is contained in:
parent
3c3f026ec1
commit
5c8ccf4455
@ -695,17 +695,18 @@ void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const std::vector<Gr
|
||||
r = rnd;
|
||||
}
|
||||
const size_t discard_size = info.num_row_ / nthread;
|
||||
auto upper_border = static_cast<float>(std::numeric_limits<uint32_t>::max());
|
||||
uint32_t coin_flip_border = static_cast<uint32_t>(upper_border * param_.subsample);
|
||||
#pragma omp parallel num_threads(nthread)
|
||||
{
|
||||
const size_t tid = omp_get_thread_num();
|
||||
const size_t ibegin = tid * discard_size;
|
||||
const size_t iend = (tid == (nthread - 1)) ?
|
||||
info.num_row_ : ibegin + discard_size;
|
||||
std::bernoulli_distribution coin_flip(param_.subsample);
|
||||
|
||||
rnds[tid].discard(2*discard_size * tid);
|
||||
rnds[tid].discard(discard_size * tid);
|
||||
for (size_t i = ibegin; i < iend; ++i) {
|
||||
if (gpair[i].GetHess() >= 0.0f && coin_flip(rnds[tid])) {
|
||||
if (gpair[i].GetHess() >= 0.0f && rnds[tid]() < coin_flip_border) {
|
||||
p_row_indices[ibegin + row_offsets[tid]++] = i;
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user