Optimizations of pre-processing for 'hist' tree method (#4310)
* optimizations for pre-processing * code cleaning * code cleaning * code cleaning after review * Apply suggestions from code review Co-Authored-By: SmirnovEgorRu <egor.smirnov@intel.com>
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
207f058711
commit
711397d645
@@ -420,25 +420,74 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
|
||||
|
||||
CHECK_EQ(info.root_index_.size(), 0U);
|
||||
std::vector<size_t>& row_indices = row_set_collection_.row_indices_;
|
||||
row_indices.resize(info.num_row_);
|
||||
auto* p_row_indices = row_indices.data();
|
||||
// mark subsample and build list of member rows
|
||||
|
||||
if (param_.subsample < 1.0f) {
|
||||
std::bernoulli_distribution coin_flip(param_.subsample);
|
||||
auto& rnd = common::GlobalRandom();
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < info.num_row_; ++i) {
|
||||
if (gpair[i].GetHess() >= 0.0f && coin_flip(rnd)) {
|
||||
row_indices.push_back(i);
|
||||
p_row_indices[j++] = i;
|
||||
}
|
||||
}
|
||||
row_indices.resize(j);
|
||||
} else {
|
||||
for (size_t i = 0; i < info.num_row_; ++i) {
|
||||
if (gpair[i].GetHess() >= 0.0f) {
|
||||
row_indices.push_back(i);
|
||||
MemStackAllocator<bool, 128> buff(this->nthread_);
|
||||
bool* p_buff = buff.Get();
|
||||
std::fill(p_buff, p_buff + this->nthread_, false);
|
||||
|
||||
const size_t block_size = info.num_row_ / this->nthread_ + !!(info.num_row_ % this->nthread_);
|
||||
|
||||
#pragma omp parallel num_threads(this->nthread_)
|
||||
{
|
||||
const size_t tid = omp_get_thread_num();
|
||||
const size_t ibegin = tid * block_size;
|
||||
const size_t iend = std::min(static_cast<size_t>(ibegin + block_size),
|
||||
static_cast<size_t>(info.num_row_));
|
||||
|
||||
for (size_t i = ibegin; i < iend; ++i) {
|
||||
if (gpair[i].GetHess() < 0.0f) {
|
||||
p_buff[tid] = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool has_neg_hess = false;
|
||||
for (size_t tid = 0; tid < this->nthread_; ++tid) {
|
||||
if (p_buff[tid]) {
|
||||
has_neg_hess = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (has_neg_hess) {
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < info.num_row_; ++i) {
|
||||
if (gpair[i].GetHess() >= 0.0f) {
|
||||
p_row_indices[j++] = i;
|
||||
}
|
||||
}
|
||||
row_indices.resize(j);
|
||||
} else {
|
||||
#pragma omp parallel num_threads(this->nthread_)
|
||||
{
|
||||
const size_t tid = omp_get_thread_num();
|
||||
const size_t ibegin = tid * block_size;
|
||||
const size_t iend = std::min(static_cast<size_t>(ibegin + block_size),
|
||||
static_cast<size_t>(info.num_row_));
|
||||
for (size_t i = ibegin; i < iend; ++i) {
|
||||
p_row_indices[i] = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
row_set_collection_.Init();
|
||||
}
|
||||
|
||||
row_set_collection_.Init();
|
||||
|
||||
{
|
||||
/* determine layout of data */
|
||||
const size_t nrow = info.num_row_;
|
||||
|
||||
@@ -28,6 +28,41 @@
|
||||
#include "../common/column_matrix.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
/*!
 * \brief A C-style array with in-stack allocation. As long as the array holds no more than
 *        MaxStackSize elements, it lives in a buffer embedded in this object (i.e. on the
 *        stack when the object is a local variable). Otherwise, it is heap-allocated.
 *
 * Storage is acquired lazily on the first call to Get() and released by the destructor.
 * Elements are default-initialized on both paths, so for trivial types such as bool their
 * values are indeterminate until the caller writes them.
 *
 * \tparam T element type; must be default-constructible
 * \tparam MaxStackSize capacity, in elements, of the embedded buffer
 */
template<typename T, size_t MaxStackSize>
class MemStackAllocator {
 public:
  /*!
   * \brief Record the requested array length; no memory is acquired here.
   * \param required_size number of elements the array must be able to hold
   */
  explicit MemStackAllocator(size_t required_size): required_size_(required_size) {
  }

  // ptr_ either points into this object's own stack_mem_ or owns a heap block that the
  // destructor releases. A member-wise copy would therefore dangle or double-free, so
  // copying is disabled (which also suppresses the implicit move operations).
  MemStackAllocator(const MemStackAllocator&) = delete;
  MemStackAllocator& operator=(const MemStackAllocator&) = delete;

  /*!
   * \brief Return a pointer to the array, acquiring the storage on the first call.
   *
   * Later calls return the same pointer. When the heap path is taken, an allocation
   * failure (or an element count too large to represent) is reported by the exception
   * thrown from the array new-expression instead of a null pointer.
   *
   * \return pointer to the first of required_size elements
   */
  T* Get() {
    if (ptr_ == nullptr) {
      if (required_size_ <= MaxStackSize) {
        ptr_ = stack_mem_;
      } else {
        // new[] default-initializes the elements exactly like stack_mem_ does and
        // performs the size-overflow check itself.
        ptr_ = new T[required_size_];
        // Set only after the allocation succeeded, so a throwing new[] leaves the
        // object in its initial state.
        do_free_ = true;
      }
    }

    return ptr_;
  }

  /*! \brief Release the heap block, if one was allocated by Get(). */
  ~MemStackAllocator() {
    if (do_free_) delete[] ptr_;
  }

 private:
  /*! \brief start of the array; nullptr until the first call to Get() */
  T* ptr_ = nullptr;
  /*! \brief true when ptr_ owns a heap block that must be released */
  bool do_free_ = false;
  /*! \brief number of elements requested at construction */
  size_t required_size_;
  /*! \brief embedded buffer used when required_size_ <= MaxStackSize */
  T stack_mem_[MaxStackSize];
};
|
||||
|
||||
namespace tree {
|
||||
|
||||
using xgboost::common::HistCutMatrix;
|
||||
|
||||
Reference in New Issue
Block a user